The Split Matters: Flat Minima Methods for Improving the Performance of GNNs

15 Jun 2023 · Nicolas Lell, Ansgar Scherp

When a neural network is trained, it is optimized on the available training data in the hope that it generalizes well to new or unseen test data. At the same absolute loss value, a flat minimum in the loss landscape is presumed to generalize better than a sharp minimum. Methods for finding flat minima have mostly been studied for independent and identically distributed (i.i.d.) data such as images. Graphs are inherently non-i.i.d., since their vertices are connected by edges. We investigate flat minima methods, and combinations of those methods, for training graph neural networks (GNNs). We use GCN and GAT, and we extend Graph-MLP to work with more layers and larger graphs. We conduct experiments on small and large citation, co-purchase, and protein datasets with different train-test splits in both the transductive and the inductive training procedures. Results show that flat minima methods can improve the performance of GNN models by over 2 points if the train-test split is randomized. Following Shchur et al., randomized splits are essential for a fair evaluation of GNNs, as other (fixed) splits like 'Planetoid' are biased. Overall, we provide important insights for improving and fairly evaluating flat minima methods on GNNs. We recommend that practitioners always use weight averaging techniques, in particular EWA when using early stopping. While weight averaging techniques are only sometimes the best-performing method, they are less sensitive to hyperparameters, need no additional training, and keep the original model unchanged. All source code is available at https://github.com/Foisunt/FMMs-in-GNNs.
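The weight averaging recommendation can be illustrated with a small sketch. It assumes that SWA means a uniform running average of weight snapshots from the later part of training and that EWA means an exponential moving average of the weights updated after every epoch; the paper's exact schedules and hyperparameters may differ. To stay self-contained it trains a toy NumPy linear model, not a GNN, and the names `swa_update`, `ewa_update`, `train_with_averaging`, `swa_start` and `beta` are hypothetical.

```python
import numpy as np


def swa_update(avg, params, n_averaged):
    """SWA-style running uniform average of parameter snapshots.

    Returns the new average and the number of snapshots it now contains.
    """
    if avg is None:
        # First snapshot: copy so the average never aliases the live weights.
        return {k: v.copy() for k, v in params.items()}, 1
    new_avg = {k: avg[k] + (params[k] - avg[k]) / (n_averaged + 1) for k in avg}
    return new_avg, n_averaged + 1


def ewa_update(avg, params, beta=0.9):
    """Exponential moving average of the weights: avg <- beta * avg + (1 - beta) * params."""
    if avg is None:
        return {k: v.copy() for k, v in params.items()}
    return {k: beta * avg[k] + (1.0 - beta) * params[k] for k in avg}


def train_with_averaging(X, y, epochs=100, lr=0.1, swa_start=50, beta=0.9, seed=0):
    """Full-batch gradient descent on a linear model (MSE loss), tracking SWA and EWA copies.

    The averaging never modifies `params`; the averages are kept in separate dicts.
    """
    rng = np.random.default_rng(seed)
    params = {"w": rng.normal(size=X.shape[1]), "b": np.zeros(1)}
    swa, n_swa, ewa = None, 0, None
    for epoch in range(epochs):
        err = X @ params["w"] + params["b"] - y
        # MSE gradients; each step assigns new arrays instead of updating in place.
        params["w"] = params["w"] - lr * (2.0 * X.T @ err / len(y))
        params["b"] = params["b"] - lr * np.array([2.0 * err.mean()])
        ewa = ewa_update(ewa, params, beta)  # updated after every epoch
        if epoch >= swa_start:  # SWA only averages the later part of training
            swa, n_swa = swa_update(swa, params, n_swa)
    return params, swa, ewa


# Toy usage: noiseless linear data with known weights and bias.
data_rng = np.random.default_rng(1)
X = data_rng.normal(size=(200, 3))
true_w = np.array([1.0, -2.0, 0.5])
y = X @ true_w + 0.3
params, swa, ewa = train_with_averaging(X, y)
print("final:", params["w"], "SWA:", swa["w"], "EWA:", ewa["w"])
```

Because both averages are bookkeeping on top of the normal training loop, they need no extra training pass and leave the trained weights untouched. In PyTorch, `torch.optim.swa_utils.AveragedModel` provides a ready-made wrapper for the same idea.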

| Task | Dataset | Model | Metric | Value | Global rank |
|---|---|---|---|---|---|
| Node Classification | Citeseer | Graph-MLP + SWA | Accuracy | 77.99 ± 1.57% | #11 |
| Node Classification | CiteSeer with Public Split: fixed 20 nodes per class | Graph-MLP + PGN | Accuracy | 74.73 ± 0.6% | #4 |
| Node Classification | Cora | GAT + SWA | Accuracy | 88.66 ± 1.38% | #8 |
| Node Classification | Cora with Public Split: fixed 20 nodes per class | GAT + PGN | Accuracy | 83.26 ± 0.69% | #16 |
| Node Classification | PPI | GCN + SAF | F1 | 99.38 ± 0.01% | #8 |
| Node Classification | PPI | GAT + PGN | F1 | 99.34 ± 0.02% | #10 |
| Node Classification | Pubmed | Graph-MLP + SAF | Accuracy | 90.64 ± 0.46% | #5 |
| Node Classification | PubMed (60%/20%/20% random splits) | Graph-MLP + SAF 1:1 | Accuracy | 90.64 ± 0.46% | #7 |
| Node Classification | PubMed with Public Split: fixed 20 nodes per class | Graph-MLP + ASAM | Accuracy | 82.60 ± 0.80% | #4 |
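The benchmark rows distinguish the fixed public (Planetoid) split, with 20 labelled nodes per class, from random splits such as 60%/20%/20%. The following is a minimal NumPy sketch of drawing both kinds of split at random so that every seed yields a different split. It assumes nodes are indexed 0..N-1 and labels are an integer array; the function names and the 30-nodes-per-class validation size are illustrative assumptions, not taken from the paper's code.

```python
import numpy as np


def random_split(num_nodes, train_frac=0.6, val_frac=0.2, seed=0):
    """Shuffle all node indices and cut them into train/val/test parts (e.g. 60%/20%/20%)."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(num_nodes)
    n_train = int(round(train_frac * num_nodes))
    n_val = int(round(val_frac * num_nodes))
    return perm[:n_train], perm[n_train:n_train + n_val], perm[n_train + n_val:]


def per_class_split(labels, train_per_class=20, val_per_class=30, seed=0):
    """Randomly draw a fixed number of nodes per class for training and validation.

    Unlike a fixed public split, a new seed gives different labelled nodes.
    All remaining nodes form the test set.
    """
    rng = np.random.default_rng(seed)
    train, val = [], []
    for c in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == c))
        train.append(idx[:train_per_class])
        val.append(idx[train_per_class:train_per_class + val_per_class])
    train, val = np.concatenate(train), np.concatenate(val)
    test = np.setdiff1d(np.arange(len(labels)), np.concatenate([train, val]))
    return train, val, test


# Toy usage: 300 nodes, 3 classes with 100 nodes each; one split per seed.
labels = np.repeat(np.arange(3), 100)
for seed in range(3):
    tr, va, te = per_class_split(labels, seed=seed)
    print(seed, len(tr), len(va), len(te), tr[:5])
```

Training and evaluating once per seed, rather than on a single fixed split, is what keeps a result from depending on one particular choice of labelled nodes.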

Methods