Personalized Federated Learning with Gaussian Processes

Federated learning aims to learn a global model that performs well on client devices with limited cross-client communication. Personalized federated learning (PFL) further extends this setup to handle data heterogeneity between clients by learning personalized models. A key challenge in this setting is to learn effectively across clients even though each client has unique data that is often limited in size. Here we present pFedGP, a solution to PFL that is based on Gaussian processes (GPs) with deep kernel learning. GPs are highly expressive models that work well in the low-data regime due to their Bayesian nature. However, applying GPs to PFL raises multiple challenges. Mainly, GP performance depends heavily on access to a good kernel function, and learning a kernel requires a large training set. Therefore, we propose learning a shared kernel function across all clients, parameterized by a neural network, with a personal GP classifier for each client. We further extend pFedGP to include inducing points using two novel methods: the first helps to improve generalization in the low-data regime, and the second reduces the computational cost. We derive a PAC-Bayes generalization bound on novel clients and empirically show that it gives non-vacuous guarantees. Extensive experiments on standard PFL benchmarks with CIFAR-10, CIFAR-100, and CINIC-10, and on a new setup of learning under input noise, show that pFedGP achieves well-calibrated predictions while significantly outperforming baseline methods, with accuracy gains of up to 21%.
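The core idea above, one kernel shared by all clients and one GP per client, can be illustrated with a minimal NumPy sketch. This is not the authors' implementation: the feature network (`shared_feature_map`, a single tanh layer), the RBF kernel on its outputs, the noise level, the toy data, and all function names are hypothetical, and exact GP regression on one-hot labels is swapped in for the paper's GP classifier. Federated training of the shared weights and the two inducing-point variants are omitted.

```python
import numpy as np


def shared_feature_map(x, weights):
    """Hypothetical shared network phi(x): one tanh hidden layer, then a linear map.

    In the federated setting these weights would be learned jointly across
    clients; in this sketch they are simply passed in.
    """
    w1, w2 = weights
    return np.tanh(x @ w1) @ w2


def rbf_kernel(a, b, lengthscale=1.0):
    """Squared-exponential kernel between the rows of a and b."""
    sq_dist = np.sum(a**2, axis=1)[:, None] + np.sum(b**2, axis=1)[None, :] - 2.0 * a @ b.T
    return np.exp(-0.5 * np.maximum(sq_dist, 0.0) / lengthscale**2)


def deep_kernel(xa, xb, weights, lengthscale=1.0):
    """Deep kernel: an RBF kernel evaluated on the shared learned features."""
    return rbf_kernel(shared_feature_map(xa, weights), shared_feature_map(xb, weights), lengthscale)


def personal_gp_predict(x_train, y_train, x_test, weights, num_classes, noise=0.1):
    """Posterior-mean class scores from one client's personal GP.

    Simplification (assumption): exact GP regression on one-hot labels instead
    of a proper GP classifier. Only this client's data enters the posterior.
    """
    targets = np.eye(num_classes)[y_train]                             # (n, C) one-hot
    k_train = deep_kernel(x_train, x_train, weights) + noise * np.eye(len(x_train))
    k_cross = deep_kernel(x_test, x_train, weights)                    # (m, n)
    chol = np.linalg.cholesky(k_train)                                 # k_train = L L^T
    alpha = np.linalg.solve(chol.T, np.linalg.solve(chol, targets))    # (K + noise I)^{-1} Y
    return k_cross @ alpha                                             # (m, C)


rng = np.random.default_rng(0)
# Shared across all clients (hypothetical sizes: 2-d inputs, 16 hidden units, 4 features).
shared_weights = (rng.normal(size=(2, 16)), rng.normal(size=(16, 4)) / 4.0)

# Each client holds a small dataset of its own; client_b's inputs are shifted by 2.0.
clients = {
    "client_a": (rng.normal(size=(8, 2)), rng.integers(0, 3, size=8)),
    "client_b": (rng.normal(size=(5, 2)) + 2.0, rng.integers(0, 3, size=5)),
}
x_query = rng.normal(size=(3, 2))
for name, (x_c, y_c) in clients.items():
    scores = personal_gp_predict(x_c, y_c, x_query, shared_weights, num_classes=3)
    print(name, scores.argmax(axis=1))
```

Only `shared_weights` is common to the two clients; each call to `personal_gp_predict` conditions on that client's data alone, which mirrors the split between a shared kernel and personal GPs described in the abstract.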

NeurIPS 2021

Results from the Paper


 Ranked #1 on Personalized Federated Learning on CIFAR-10 (ACC@1-100Clients metric)

Task: Personalized Federated Learning

Dataset    Model              Metric            Value (%)  Global Rank
CIFAR-10   pFedGP-IP-compute  ACC@1-50Clients   89.9       #4
CIFAR-10   pFedGP-IP-compute  ACC@1-100Clients  88.8       #1
CIFAR-10   pFedGP-IP-compute  ACC@1-500         86.8       #3
CIFAR-10   pFedGP             ACC@1-50Clients   89.2       #3
CIFAR-10   pFedGP             ACC@1-100Clients  88.8       #1
CIFAR-10   pFedGP             ACC@1-500         87.6       #1
CIFAR-10   pFedGP-IP-data     ACC@1-50Clients   88.6       #2
CIFAR-10   pFedGP-IP-data     ACC@1-100Clients  87.4       #5
CIFAR-10   pFedGP-IP-data     ACC@1-500         86.9       #2
CIFAR-100  pFedGP-IP-data     ACC@1-50Clients   60.2       #3
CIFAR-100  pFedGP-IP-data     ACC@1-100Clients  58.5       #3
CIFAR-100  pFedGP-IP-data     ACC@1-500         55.7       #1
CIFAR-100  pFedGP-IP-compute  ACC@1-50Clients   61.2       #2
CIFAR-100  pFedGP-IP-compute  ACC@1-100Clients  59.8       #2
CIFAR-100  pFedGP-IP-compute  ACC@1-500         49.2       #3
CIFAR-100  pFedGP             ACC@1-50Clients   63.3       #1
CIFAR-100  pFedGP             ACC@1-100Clients  61.3       #1
CIFAR-100  pFedGP             ACC@1-500         50.6       #2

Methods