Global Self-Attention as a Replacement for Graph Convolution

7 Aug 2021 · Md Shamim Hussain, Mohammed J. Zaki, Dharmashankar Subramanian

We propose an extension to the transformer neural network architecture for general-purpose graph learning that adds a dedicated pathway for pairwise structural information, called edge channels. The resulting framework, which we call the Edge-augmented Graph Transformer (EGT), can directly accept, process and output structural information of arbitrary form, which is important for effective learning on graph-structured data. Our model uses global self-attention exclusively as its aggregation mechanism, rather than static localized convolutional aggregation. This allows unconstrained, long-range, dynamic interactions between nodes. Moreover, the edge channels allow the structural information to evolve from layer to layer, and prediction tasks on edges/links can be performed directly from the output embeddings of these channels. We verify the performance of EGT in a wide range of graph-learning experiments on benchmark datasets, in which it outperforms convolutional/message-passing graph neural networks. EGT sets a new state of the art for the quantum-chemical regression task on the OGB-LSC PCQM4Mv2 dataset, which contains 3.8 million molecular graphs. Our findings indicate that global self-attention-based aggregation can serve as a flexible, adaptive and effective replacement for graph convolution in general-purpose graph learning. Therefore, convolutional local neighborhood aggregation is not an essential inductive bias.
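The aggregation scheme described above can be sketched in plain Python. This is a minimal, hypothetical illustration, not the authors' implementation. It assumes that each edge channel supplies two N x N quantities per attention head. One is added to the attention logits as a bias. The other is passed through a sigmoid to gate the softmax weights. It also assumes that the edge channel receives the pre-softmax logits back. The names `edge_augmented_attention`, `predict_link`, `e_bias` and `e_gate` are illustrative. Multi-head attention, normalization, feed-forward layers and training are omitted.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]
EdgeTensor = List[List[List[float]]]  # N x N x d_e edge-channel embeddings


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python product of an (n x m) matrix and an (m x p) matrix."""
    inner, cols = len(b), len(b[0])
    return [[sum(a[i][t] * b[t][j] for t in range(inner)) for j in range(cols)]
            for i in range(len(a))]


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row of attention logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [x / total for x in exps]


def sigmoid(x: float) -> float:
    """Logistic function, written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def edge_augmented_attention(
    h: Matrix,       # N x d node embeddings
    e_bias: Matrix,  # N x N edge-channel values added to the attention logits (assumed design)
    e_gate: Matrix,  # N x N edge-channel values turned into sigmoid gates (assumed design)
    wq: Matrix,      # d x d_k query projection
    wk: Matrix,      # d x d_k key projection
    wv: Matrix,      # d x d_v value projection
) -> Tuple[Matrix, Matrix]:
    """Hypothetical single head of global self-attention biased and gated by edge channels.

    Returns (updated N x d_v node embeddings, N x N logits that an edge
    channel could take as input in the next layer).
    """
    n = len(h)
    q, k, v = matmul(h, wq), matmul(h, wk), matmul(h, wv)
    scale = 1.0 / math.sqrt(len(wq[0]))
    # Global aggregation: node i scores every node j; no adjacency mask is applied.
    logits = [[scale * sum(a * b for a, b in zip(q[i], k[j])) + e_bias[i][j]
               for j in range(n)] for i in range(n)]
    # Edge gates rescale the softmax weights pair by pair.
    attn = [[w * sigmoid(e_gate[i][j]) for j, w in enumerate(softmax(logits[i]))]
            for i in range(n)]
    return matmul(attn, v), logits


def predict_link(edge_emb: EdgeTensor, w: List[float], b: float, i: int, j: int) -> float:
    """Hypothetical link readout: a logistic score computed directly from the
    edge-channel embedding of the pair (i, j)."""
    return sigmoid(sum(wt * x for wt, x in zip(w, edge_emb[i][j])) + b)
```

No adjacency mask is applied, so any node can draw information from any other node in the graph. Structure enters only through the edge-channel bias and gate. In a full model, the returned logits would presumably be projected and added back into the edge embeddings, which would let structural information change from layer to layer.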


Results from the Paper

Task                       Dataset                Model  Metric          Value            Global Rank
Graph Classification       CIFAR10 100k           EGT    Accuracy (%)    68.702           #7
Node Classification        CLUSTER                EGT    Accuracy        79.232           #2
Graph Classification       MNIST                  EGT    Accuracy        98.173           #2
Graph Property Prediction  ogbg-molhiv            EGT    Test ROC-AUC    0.806 ± 0.0065   #11
Graph Property Prediction  ogbg-molpcba           EGT    Test AP         0.2961 ± 0.0024  #11
Node Classification        PATTERN                EGT    Accuracy        86.821           #2
Node Classification        PATTERN 100k           EGT    Accuracy (%)    86.816           #1
Graph Regression           PCQM4M-LSC             EGT    Validation MAE  0.1224           #3
Graph Regression           PCQM4Mv2-LSC           EGT    Validation MAE  0.0857           #6
Graph Regression           PCQM4Mv2-LSC           EGT    Test MAE        0.0862           #5
Link Prediction            TSP/HCP Benchmark set  EGT    F1              0.853            #2
Graph Regression           ZINC 100k              EGT    MAE             0.143            #4
Graph Regression           ZINC-500k              EGT    MAE             0.108            #17