Enhancing Out-of-distribution Generalization on Graphs via Causal Attention Learning

ACM Transactions on Knowledge Discovery from Data (2023)

Abstract
In graph classification, attention- and pooling-based graph neural networks (GNNs) are predominantly used to extract salient features from the input graph and support the prediction. They mostly follow the paradigm of “learning to attend”, which maximizes the mutual information between the attended graph and the ground-truth label. However, this paradigm causes GNN classifiers to indiscriminately absorb all statistical correlations between input features and labels in the training data, without distinguishing the causal and noncausal effects of features. Rather than emphasizing causal features, the attended graphs tend to rely on noncausal features as shortcuts to predictions. These shortcut features may easily change outside the training distribution, thereby leading to poor generalization for GNN classifiers. In this paper, we take a causal view of GNN modeling. Under our causal assumption, the shortcut feature serves as a confounder between the causal feature and the prediction. It misleads the classifier into learning spurious correlations that help prediction in in-distribution (ID) test evaluation, while causing a significant performance drop on out-of-distribution (OOD) test data. To address this issue, we employ the backdoor adjustment from causal theory, combining each causal feature with various shortcut features, to identify causal patterns and mitigate the confounding effect. Specifically, we employ attention modules to estimate the causal and shortcut features of the input graph. Then, a memory bank collects the estimated shortcut features, enhancing the diversity of shortcut features available for combination. Simultaneously, we apply a prototype strategy to improve the consistency of intra-class causal features. We term our method CAL+, which promotes stable relationships between causal estimation and prediction, regardless of distribution changes. Extensive experiments on synthetic and real-world OOD benchmarks demonstrate our method’s effectiveness in improving OOD generalization. Our code is released at https://github.com/shuyao-wang/CAL-plus.
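
As a rough illustration of the backdoor adjustment mentioned in the abstract, the sketch below works through a toy discrete example in plain Python. It is not the CAL+ implementation: the variables, the probability tables, and the function names `observational` and `backdoor_adjusted` are all assumptions made up for illustration. It shows only the general identity P(Y | do(C)) = Σ_s P(Y | C, s) P(s), in which every causal value is paired with every shortcut value, weighted by the shortcut's marginal probability rather than by how often the two co-occur.

```python
# Toy discrete model; all numbers are hypothetical and chosen for illustration.
# S: shortcut feature (confounder), C: causal feature, Y: label.
P_S = {0: 0.5, 1: 0.5}                       # P(S = s)
P_C_GIVEN_S = {0: {0: 0.9, 1: 0.1},          # P(C = c | S = s), indexed [s][c]
               1: {0: 0.1, 1: 0.9}}
P_Y1_GIVEN_CS = {(0, 0): 0.1, (0, 1): 0.5,   # P(Y = 1 | C = c, S = s), indexed (c, s)
                 (1, 0): 0.5, (1, 1): 0.9}


def observational(c):
    """P(Y=1 | C=c): each shortcut value is weighted by P(S=s | C=c)."""
    joint = {s: P_S[s] * P_C_GIVEN_S[s][c] for s in P_S}  # P(S=s, C=c)
    p_c = sum(joint.values())                             # P(C=c)
    return sum(P_Y1_GIVEN_CS[(c, s)] * joint[s] / p_c for s in P_S)


def backdoor_adjusted(c):
    """P(Y=1 | do(C=c)): each shortcut value is weighted by its marginal P(S=s)."""
    return sum(P_Y1_GIVEN_CS[(c, s)] * P_S[s] for s in P_S)


# Difference in P(Y=1) between C=1 and C=0 under each view.
observed_gap = observational(1) - observational(0)
causal_gap = backdoor_adjusted(1) - backdoor_adjusted(0)
```

In this toy setup the shortcut co-occurs strongly with the causal feature, so the observed gap overstates the effect of the causal feature; the adjusted gap removes the part contributed by the shortcut. How CAL+ approximates the sum over shortcut values with attention-estimated features and a memory bank is described in the paper and its released code, not in this sketch.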
Keywords
Graph Learning,Attention Mechanism,Out-of-distribution Generalization