Improving Generalization in Federated Learning by Seeking Flat Minima

22 Mar 2022 · Debora Caldarola, Barbara Caputo, Marco Ciccone

Models trained in federated settings often suffer from degraded performance and fail to generalize, especially in heterogeneous scenarios. In this work, we investigate this behavior through the lens of the geometry of the loss and the Hessian eigenspectrum, linking the model's lack of generalization capacity to the sharpness of the solution. Motivated by prior studies connecting the sharpness of the loss surface to the generalization gap, we show that i) training clients locally with Sharpness-Aware Minimization (SAM) or its adaptive version (ASAM) and ii) averaging stochastic weights (SWA) on the server side can substantially improve generalization in Federated Learning and help bridge the gap with centralized models. By seeking parameters in neighborhoods with uniformly low loss, the model converges towards flatter minima, and its generalization improves significantly in both homogeneous and heterogeneous scenarios. Empirical results demonstrate the effectiveness of these optimizers across a variety of benchmark vision datasets (e.g., CIFAR10/100, Landmarks-User-160k, IDDA) and tasks (large-scale classification, semantic segmentation, domain generalization).
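The two ingredients named in the abstract can be illustrated with a minimal pure-Python sketch on a hypothetical toy problem in which each client minimizes a quadratic loss 0.5 * sum_i a_i * (w_i - c_i)^2. Clients run local SAM steps (perturb the weights by eps = rho * g / ||g||, then descend with the gradient taken at w + eps) or the ASAM variant, which rescales the perturbation with T_w = |w| + eta so that eps = rho * T_w^2 * g / ||T_w * g||. The server aggregates with sample-weighted FedAvg and, from a chosen round onward, keeps a running average of the global models as a simplified form of SWA. All function names and hyperparameters (`rho`, `eta`, `swa_start`, learning rate, client sampling) are illustrative assumptions rather than the authors' code, and the sketch omits any learning-rate schedule the paper may pair with SWA.

```python
import math
import random

# Hypothetical toy client: loss f(w) = 0.5 * sum_i a[i] * (w[i] - c[i])**2,
# stored as a pair (a, c) of per-coordinate curvatures and optima.
def loss(w, client):
    a, c = client
    return 0.5 * sum(ai * (wi - ci) ** 2 for ai, wi, ci in zip(a, w, c))


def grad(w, client):
    a, c = client
    return [ai * (wi - ci) for ai, wi, ci in zip(a, w, c)]


def norm(v):
    return math.sqrt(sum(x * x for x in v))


def sam_step(w, client, lr, rho, adaptive=False, eta=0.01):
    """One local SAM step, or ASAM when adaptive=True (assumed defaults)."""
    g = grad(w, client)
    if adaptive:
        # ASAM: eps = rho * T_w^2 g / ||T_w g||, with T_w = |w| + eta (element-wise)
        t = [abs(wi) + eta for wi in w]
        tg = [ti * gi for ti, gi in zip(t, g)]
        n = norm(tg) + 1e-12
        eps = [rho * ti * tgi / n for ti, tgi in zip(t, tg)]
    else:
        # SAM: eps = rho * g / ||g||
        n = norm(g) + 1e-12
        eps = [rho * gi / n for gi in g]
    # Gradient at the perturbed point w + eps, applied to the unperturbed w
    g_adv = grad([wi + ei for wi, ei in zip(w, eps)], client)
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]


def local_train(w, client, steps, lr, rho, adaptive):
    for _ in range(steps):
        w = sam_step(w, client, lr, rho, adaptive)
    return w


def fedavg(weights, sizes):
    """Server aggregation: average of client models weighted by local dataset size."""
    total = sum(sizes)
    return [sum(s * w[i] for w, s in zip(weights, sizes)) / total
            for i in range(len(weights[0]))]


def run(clients, sizes, rounds=50, per_round=2, steps=5, lr=0.1, rho=0.05,
        adaptive=True, swa_start=30, seed=0):
    """FedAvg with SAM/ASAM clients; returns (last global model, SWA model)."""
    rng = random.Random(seed)
    w = [0.0] * len(clients[0][0])
    w_swa, n_swa = None, 0
    for r in range(rounds):
        chosen = rng.sample(range(len(clients)), per_round)
        local = [local_train(w, clients[k], steps, lr, rho, adaptive) for k in chosen]
        w = fedavg(local, [sizes[k] for k in chosen])
        if r >= swa_start:
            # Simplified server-side SWA: running mean of the global models
            n_swa += 1
            w_swa = list(w) if w_swa is None else [
                ws + (wi - ws) / n_swa for ws, wi in zip(w_swa, w)]
    return w, w_swa


def global_loss(w, clients, sizes):
    return sum(s * loss(w, c) for c, s in zip(clients, sizes)) / sum(sizes)


if __name__ == "__main__":
    # Three hypothetical clients with different curvatures and optima
    clients = [([1.0, 4.0], [1.0, -1.0]), ([4.0, 1.0], [-1.0, 1.0]), ([2.0, 2.0], [0.5, 0.5])]
    sizes = [100, 100, 50]
    w_last, w_swa = run(clients, sizes)
    print("last global model:", w_last, "loss:", global_loss(w_last, clients, sizes))
    print("SWA model:", w_swa, "loss:", global_loss(w_swa, clients, sizes))
```

Passing `adaptive=False` switches the clients to plain SAM, and setting `rho=0` reduces either variant to ordinary gradient descent, which is a convenient sanity check when experimenting with the sketch.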

Task                  Dataset                                       Model                 Metric             Value (%)  Global Rank
Federated Learning    CIFAR-100 (alpha=0, 5 clients per round)      FedASAM + SWA         ACC@1-100Clients   42.01      #1
Federated Learning    CIFAR-100 (alpha=0, 5 clients per round)      FedSAM + SWA          ACC@1-100Clients   39.3       #2
Federated Learning    CIFAR-100 (alpha=0, 5 clients per round)      FedASAM               ACC@1-100Clients   36.04      #3
Federated Learning    CIFAR-100 (alpha=0, 5 clients per round)      FedSAM                ACC@1-100Clients   31.04      #4
Federated Learning    CIFAR-100 (alpha=0, 5 clients per round)      FedAvg                ACC@1-100Clients   30.25      #5
Federated Learning    CIFAR-100 (alpha=0, 10 clients per round)     FedASAM + SWA         ACC@1-100Clients   42.64      #1
Federated Learning    CIFAR-100 (alpha=0, 10 clients per round)     FedASAM               ACC@1-100Clients   39.76      #2
Federated Learning    CIFAR-100 (alpha=0, 10 clients per round)     FedSAM + SWA          ACC@1-100Clients   39.51      #3
Federated Learning    CIFAR-100 (alpha=0, 10 clients per round)     FedSAM                ACC@1-100Clients   36.93      #4
Federated Learning    CIFAR-100 (alpha=0, 10 clients per round)     FedAvg                ACC@1-100Clients   36.74      #5
Federated Learning    CIFAR-100 (alpha=0, 20 clients per round)     FedASAM + SWA         ACC@1-100Clients   41.62      #1
Federated Learning    CIFAR-100 (alpha=0, 20 clients per round)     FedASAM               ACC@1-100Clients   40.81      #2
Federated Learning    CIFAR-100 (alpha=0, 20 clients per round)     FedSAM + SWA          ACC@1-100Clients   39.24      #3
Federated Learning    CIFAR-100 (alpha=0, 20 clients per round)     FedAvg                ACC@1-100Clients   38.59      #4
Federated Learning    CIFAR-100 (alpha=0, 20 clients per round)     FedSAM                ACC@1-100Clients   38.56      #5
Image Classification  CIFAR-100 (alpha=0, 20 clients per round)     FedAvgM + ASAM + SWA  ACC@1-100Clients   51.58      #1
Federated Learning    CIFAR-100 (alpha=0.5, 5 clients per round)    FedASAM + SWA         ACC@1-100Clients   49.17      #1
Federated Learning    CIFAR-100 (alpha=0.5, 5 clients per round)    FedSAM + SWA          ACC@1-100Clients   47.96      #2
Federated Learning    CIFAR-100 (alpha=0.5, 5 clients per round)    FedASAM               ACC@1-100Clients   45.61      #3
Federated Learning    CIFAR-100 (alpha=0.5, 5 clients per round)    FedSAM                ACC@1-100Clients   44.73      #4
Federated Learning    CIFAR-100 (alpha=0.5, 5 clients per round)    FedAvg                ACC@1-100Clients   40.43      #5
Federated Learning    CIFAR-100 (alpha=0.5, 10 clients per round)   FedASAM + SWA         ACC@1-100Clients   48.72      #1
Federated Learning    CIFAR-100 (alpha=0.5, 10 clients per round)   FedSAM + SWA          ACC@1-100Clients   46.76      #2
Federated Learning    CIFAR-100 (alpha=0.5, 10 clients per round)   FedASAM               ACC@1-100Clients   46.58      #3
Federated Learning    CIFAR-100 (alpha=0.5, 10 clients per round)   FedSAM                ACC@1-100Clients   44.84      #4
Federated Learning    CIFAR-100 (alpha=0.5, 10 clients per round)   FedAvg                ACC@1-100Clients   41.27      #5
Federated Learning    CIFAR-100 (alpha=0.5, 20 clients per round)   FedASAM + SWA         ACC@1-100Clients   48.27      #1
Federated Learning    CIFAR-100 (alpha=0.5, 20 clients per round)   FedASAM               ACC@1-100Clients   47.78      #2
Federated Learning    CIFAR-100 (alpha=0.5, 20 clients per round)   FedSAM + SWA          ACC@1-100Clients   46.47      #3
Federated Learning    CIFAR-100 (alpha=0.5, 20 clients per round)   FedSAM                ACC@1-100Clients   46.05      #4
Federated Learning    CIFAR-100 (alpha=0.5, 20 clients per round)   FedAvg                ACC@1-100Clients   42.17      #5
Federated Learning    CIFAR-100 (alpha=1000, 5 clients per round)   FedASAM               ACC@1-100Clients   54.81      #1
Federated Learning    CIFAR-100 (alpha=1000, 5 clients per round)   FedSAM                ACC@1-100Clients   54.01      #2
Federated Learning    CIFAR-100 (alpha=1000, 5 clients per round)   FedSAM + SWA          ACC@1-100Clients   53.9       #3
Federated Learning    CIFAR-100 (alpha=1000, 5 clients per round)   FedASAM + SWA         ACC@1-100Clients   53.86      #4
Federated Learning    CIFAR-100 (alpha=1000, 5 clients per round)   FedAvg                ACC@1-100Clients   49.92      #5
Federated Learning    CIFAR-100 (alpha=1000, 10 clients per round)  FedASAM               ACC@1-100Clients   54.97      #1
Federated Learning    CIFAR-100 (alpha=1000, 10 clients per round)  FedASAM + SWA         ACC@1-100Clients   54.79      #2
Federated Learning    CIFAR-100 (alpha=1000, 10 clients per round)  FedSAM + SWA          ACC@1-100Clients   53.67      #3
Federated Learning    CIFAR-100 (alpha=1000, 10 clients per round)  FedSAM                ACC@1-100Clients   53.39      #4
Federated Learning    CIFAR-100 (alpha=1000, 10 clients per round)  FedAvg                ACC@1-100Clients   50.25      #5
Federated Learning    CIFAR-100 (alpha=1000, 20 clients per round)  FedASAM               ACC@1-100Clients   54.5       #1
Federated Learning    CIFAR-100 (alpha=1000, 20 clients per round)  FedSAM + SWA          ACC@1-100Clients   54.36      #2
Federated Learning    CIFAR-100 (alpha=1000, 20 clients per round)  FedASAM + SWA         ACC@1-100Clients   54.1       #3
Federated Learning    CIFAR-100 (alpha=1000, 20 clients per round)  FedSAM                ACC@1-100Clients   53.97      #4
Federated Learning    CIFAR-100 (alpha=1000, 20 clients per round)  FedAvg                ACC@1-100Clients   50.66      #5
Federated Learning    Cityscapes heterogeneous                      SiloBN + ASAM         mIoU               49.75      #1
Federated Learning    Cityscapes heterogeneous                      SiloBN + SAM          mIoU               49.1       #2
Federated Learning    Cityscapes heterogeneous                      SiloBN                mIoU               45.96      #3
Federated Learning    Cityscapes heterogeneous                      FedSAM + SWA          mIoU               43.42      #4
Federated Learning    Cityscapes heterogeneous                      FedASAM + SWA         mIoU               43.02      #5
Federated Learning    Cityscapes heterogeneous                      FedAvg + SWA          mIoU               42.48      #6
Federated Learning    Cityscapes heterogeneous                      FedASAM               mIoU               42.27      #7
Federated Learning    Cityscapes heterogeneous                      FedSAM                mIoU               41.22      #8
Federated Learning    Cityscapes heterogeneous                      FedAvg                mIoU               38.65      #9
Federated Learning    Landmarks-User-160k                           FedASAM + SWA         Acc@1-1262Clients  68.32      #1
Federated Learning    Landmarks-User-160k                           FedSAM + SWA          Acc@1-1262Clients  68.12      #2
Federated Learning    Landmarks-User-160k                           FedAvg + SWA          Acc@1-1262Clients  67.52      #3
Federated Learning    Landmarks-User-160k                           FedASAM               Acc@1-1262Clients  64.23      #4
Federated Learning    Landmarks-User-160k                           FedSAM                Acc@1-1262Clients  63.72      #5
Federated Learning    Landmarks-User-160k                           FedAvg                Acc@1-1262Clients  61.91      #6
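As a reading aid for the CIFAR-100 rows, the following Python sketch computes the absolute accuracy gain, in percentage points, of FedASAM + SWA over the FedAvg baseline in each of the nine (alpha, clients per round) settings. The values are copied from the table; the pairing of method and baseline is only illustrative, and `gains` is a hypothetical helper.

```python
# Top-1 accuracy (%) on CIFAR-100 copied from the table: (alpha, clients per round) -> value
fedavg = {
    (0, 5): 30.25, (0, 10): 36.74, (0, 20): 38.59,
    (0.5, 5): 40.43, (0.5, 10): 41.27, (0.5, 20): 42.17,
    (1000, 5): 49.92, (1000, 10): 50.25, (1000, 20): 50.66,
}
fedasam_swa = {
    (0, 5): 42.01, (0, 10): 42.64, (0, 20): 41.62,
    (0.5, 5): 49.17, (0.5, 10): 48.72, (0.5, 20): 48.27,
    (1000, 5): 53.86, (1000, 10): 54.79, (1000, 20): 54.1,
}


def gains(baseline, method):
    """Absolute gain in accuracy points per setting, rounded to two decimals."""
    return {k: round(method[k] - baseline[k], 2) for k in baseline}


for (alpha, clients), gain in sorted(gains(fedavg, fedasam_swa).items()):
    print(f"alpha={alpha:<6} clients/round={clients:<3} gain={gain:+.2f}")
```

On these values the gain is positive in every setting, ranging from 3.03 points (alpha=0, 20 clients per round) to 11.76 points (alpha=0, 5 clients per round).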

Methods