Wasserstein Embedding for Graph Learning

We present Wasserstein Embedding for Graph Learning (WEGL), a novel and fast framework for embedding entire graphs in a vector space, to which standard machine learning models can then be applied for graph-level prediction tasks. We define the similarity between graphs as a function of the similarity between their node embedding distributions. Specifically, we use the Wasserstein distance to measure the dissimilarity between the node embeddings of different graphs. Unlike prior work, we avoid pairwise calculation of distances between graphs, which reduces the computational complexity from quadratic to linear in the number of graphs. WEGL calculates Monge maps from a reference distribution to each graph's node embedding distribution and, based on these maps, creates a fixed-size vector representation of the graph. We evaluate our graph embedding approach on various benchmark graph-property prediction tasks, achieving state-of-the-art classification performance with superior computational efficiency. The code is available at https://github.com/navid-naderi/WEGL.
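The abstract describes the pipeline only in outline. The following is a minimal pure-Python sketch of it under several assumptions, not the authors' implementation or the API of the WEGL repository. All function names are hypothetical. Node embeddings come from a hypothetical parameter-free diffusion (`diffusion_node_embeddings`). The Monge map from the reference set `ref` to a graph's node embeddings is approximated by the barycentric projection of an entropy-regularized (Sinkhorn) transport plan, which is a swapped-in solver. The displacement is scaled by 1/sqrt(N), so that when the plan is a permutation between two uniform point sets of equal size, the vector's squared norm equals the squared 2-Wasserstein distance to the reference.

```python
import math


def diffusion_node_embeddings(adj, feats, num_layers=2):
    """Hypothetical parameter-free node embedding (an assumption, not the
    paper's specification): each layer mixes a node's features with its
    neighbours' using weights 1/sqrt(deg(u) * deg(v)), self-loops included.
    adj[v] is a list of neighbour indices; feats[v] is a list of floats."""
    n = len(adj)
    deg = [len(adj[v]) + 1 for v in range(n)]  # +1 counts the self-loop
    layers = [feats]
    h = feats
    for _ in range(num_layers):
        h = [[sum(h[u][k] / math.sqrt(deg[u] * deg[v]) for u in adj[v] + [v])
              for k in range(len(h[v]))]
             for v in range(n)]
        layers.append(h)
    # One vector per node: [layer 0 | layer 1 | ... | layer num_layers]
    return [[x for layer in layers for x in layer[v]] for v in range(n)]


def _logsumexp(vals):
    m = max(vals)
    return m + math.log(sum(math.exp(v - m) for v in vals))


def sinkhorn_plan(ref, pts, reg=0.01, iters=200):
    """Log-domain Sinkhorn (swapped-in solver): entropic OT plan between
    uniform measures on ref (N points) and pts (n points), squared L2 cost."""
    N, n = len(ref), len(pts)
    cost = [[sum((a - b) ** 2 for a, b in zip(z, x)) for x in pts] for z in ref]
    log_a, log_b = -math.log(N), -math.log(n)
    f, g = [0.0] * N, [0.0] * n
    for _ in range(iters):
        # Alternately enforce the row marginals (1/N) and column marginals (1/n).
        f = [reg * (log_a - _logsumexp([(g[k] - cost[j][k]) / reg for k in range(n)]))
             for j in range(N)]
        g = [reg * (log_b - _logsumexp([(f[j] - cost[j][k]) / reg for j in range(N)]))
             for k in range(n)]
    return [[math.exp((f[j] + g[k] - cost[j][k]) / reg) for k in range(n)]
            for j in range(N)]


def wegl_embedding(ref, pts, reg=0.01):
    """Fixed-size vector (length N * d) for one graph: approximate Monge map
    (barycentric projection of the plan) minus the reference, over sqrt(N)."""
    plan = sinkhorn_plan(ref, pts, reg)
    scale = math.sqrt(len(ref))
    vec = []
    for j, z in enumerate(ref):
        mass = sum(plan[j])
        mapped = [sum(p * x[i] for p, x in zip(plan[j], pts)) / mass
                  for i in range(len(z))]
        vec.extend((m - zi) / scale for m, zi in zip(mapped, z))
    return vec


def embed_dataset(node_embeddings_per_graph, ref, reg=0.01):
    # One OT problem per graph: M solves for M graphs, not M * (M - 1) / 2.
    return [wegl_embedding(ref, pts, reg) for pts in node_embeddings_per_graph]
```

Because every graph is mapped against the same reference, a dataset of M graphs needs M transport problems instead of M(M-1)/2 pairwise ones, and the resulting vectors can be passed to any standard classifier. The reference points must have the same dimensionality as the node embeddings; how to choose them (for example by sampling or clustering the pooled node embeddings) is left open in this sketch.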

ICLR 2021

Results from the Paper

Task                       Dataset      Model  Metric              Value            Global rank
Graph Classification       COLLAB       WEGL   Accuracy            79.8%            #14
Graph Classification       D&D          WEGL   Accuracy            78.6%            #23
Graph Classification       ENZYMES      WEGL   Accuracy            60.5%            #20
Graph Classification       IMDb-B       WEGL   Accuracy            75.4%            #13
Graph Classification       IMDb-M       WEGL   Accuracy            52.0%            #12
Graph Classification       MUTAG        WEGL   Accuracy            88.3%            #41
Graph Classification       NCI1         WEGL   Accuracy            76.8%            #34
Graph Property Prediction  ogbg-molhiv  WEGL   Test ROC-AUC        0.7757 ± 0.0111  #33
Graph Property Prediction  ogbg-molhiv  WEGL   Validation ROC-AUC  0.8101 ± 0.0097  #33
Graph Property Prediction  ogbg-molhiv  WEGL   Number of params    361064           #24
Graph Property Prediction  ogbg-molhiv  WEGL   Ext. data           No               #1
Graph Classification       PROTEINS     WEGL   Accuracy            76.5%            #35
Graph Classification       PTC          WEGL   Accuracy            67.5%            #17
Graph Classification       REDDIT-B     WEGL   Accuracy            92.0%            #5
Graph Classification       RE-M12K      WEGL   Accuracy            47.8%            #4
Graph Classification       RE-M5K       WEGL   Accuracy            55.1%            #3

Methods