Learning to Reweight for Generalizable Graph Neural Network

AAAI 2024 (2024)

Graph Neural Networks (GNNs) show promising results on graph tasks. However, the generalization ability of existing GNNs degrades when there is a distribution shift between the training and testing graph data. The fundamental reason for this severe degradation is that most GNNs are designed under the i.i.d. assumption. In such a setting, GNNs tend to exploit subtle statistical correlations in the training set for prediction, even when those correlations are spurious. In this paper, we study the generalization ability of GNNs in Out-Of-Distribution (OOD) settings. To address this problem, we propose Learning to Reweight for Generalizable Graph Neural Network (L2R-GNN), which aims to achieve satisfactory performance on unseen testing graphs whose distributions differ from those of the training graphs. We propose a novel nonlinear graph decorrelation method that can substantially improve out-of-distribution generalization and compares favorably to previous methods in limiting the over-reduction of the sample size. The variables of the graph representation are clustered based on the stability of their correlations, and the graph decorrelation method learns sample weights that remove correlations between variables in different clusters rather than between every pair of variables. In addition, we introduce an effective stochastic algorithm based on bi-level optimization for the L2R-GNN framework, which learns the optimal weights and the GNN parameters simultaneously and avoids over-fitting. Experiments show that L2R-GNN greatly outperforms baselines on various graph prediction benchmarks under distribution shifts.
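The reweighting idea in the abstract can be illustrated with a small toy sketch. It is not the L2R-GNN algorithm: it uses plain linear covariance instead of the paper's nonlinear decorrelation, takes the variable clusters as given instead of deriving them from correlation stability, and involves no GNN or bi-level optimization. All names (`learn_sample_weights`, `lam`, the synthetic matrix `Z`) are hypothetical, and the penalty on concentrated weights is an assumed stand-in for keeping the effective sample size from collapsing.

```python
import numpy as np

def softmax(theta):
    e = np.exp(theta - theta.max())      # shift for numerical stability
    return e / e.sum()

def weighted_cov(Z, w):
    """Weighted covariance of the columns of Z; w is non-negative and sums to 1."""
    Zc = Z - w @ Z                       # subtract the weighted mean
    return (Zc * w[:, None]).T @ Zc, Zc

def cross_cluster_loss(Z, w, clusters):
    """Sum of squared covariances between variables in different clusters."""
    cov, _ = weighted_cov(Z, w)
    mask = clusters[:, None] != clusters[None, :]
    return float(np.sum((cov * mask) ** 2))

def objective(Z, w, clusters, lam):
    # Assumed penalty: n * sum(w^2) equals n / (effective sample size).
    return cross_cluster_loss(Z, w, clusters) + lam * len(w) * float(np.sum(w ** 2))

def grad_theta(Z, w, clusters, lam):
    """Gradient of the objective w.r.t. the logits theta, where w = softmax(theta)."""
    cov, Zc = weighted_cov(Z, w)
    mask = clusters[:, None] != clusters[None, :]
    G = 2.0 * cov * mask                          # d(loss) / d(cov)
    g = np.sum((Zc @ G) * Zc, axis=1)             # d(loss) / d(w_i), up to a constant
    g = g + 2.0 * lam * len(w) * w                # penalty term
    return w * (g - w @ g)                        # chain rule through softmax

def learn_sample_weights(Z, clusters, lam=0.1, steps=100):
    """Gradient descent with backtracking; only strictly improving steps are accepted."""
    n = Z.shape[0]
    theta = np.zeros(n)                           # start from uniform weights
    w = softmax(theta)
    best = objective(Z, w, clusters, lam)
    step = float(n)
    for _ in range(steps):
        grad = grad_theta(Z, w, clusters, lam)
        for _ in range(30):
            cand_theta = theta - step * grad
            cand_w = softmax(cand_theta)
            cand = objective(Z, cand_w, clusters, lam)
            if cand < best:
                theta, w, best = cand_theta, cand_w, cand
                step *= 2.0
                break
            step *= 0.5
        else:
            break                                 # no improving step found
    return w

# Synthetic data: columns 0-1 form one cluster, columns 2-3 another;
# column 2 is correlated with column 0 across the cluster boundary.
rng = np.random.default_rng(0)
n = 300
x0 = rng.normal(size=n)
Z = np.column_stack([x0, x0 + 0.5 * rng.normal(size=n),
                     0.8 * x0 + 0.6 * rng.normal(size=n), rng.normal(size=n)])
clusters = np.array([0, 0, 1, 1])
w = learn_sample_weights(Z, clusters)
print("cross-cluster loss, uniform weights:", cross_cluster_loss(Z, np.full(n, 1.0 / n), clusters))
print("cross-cluster loss, learned weights:", cross_cluster_loss(Z, w, clusters))
print("effective sample size:", 1.0 / np.sum(w ** 2))
```

Correlations inside a cluster (columns 0 and 1 here) are left untouched; only cross-cluster covariances enter the loss. In this sketch, restricting the loss in that way constrains the weights less than decorrelating every pair of variables would, and the effective sample size printed at the end indicates how concentrated the learned weights have become.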
DMKM: Graph Mining, Social Network Analysis & Community