Ensemble of Averages: Improving Model Selection and Boosting Performance in Domain Generalization

21 Oct 2021 · Devansh Arpit, Huan Wang, Yingbo Zhou, Caiming Xiong

In Domain Generalization (DG) settings, models trained on a given set of training domains have notoriously chaotic performance on distribution-shifted test domains, and stochasticity in optimization (e.g., the random seed) plays a big role. This makes deep learning models unreliable in real-world settings. We first show that a simple protocol for averaging model parameters along the optimization path, starting early during training, both significantly boosts domain generalization and diminishes the impact of stochasticity by improving the rank correlation between the in-domain validation accuracy and out-domain test accuracy, which is crucial for reliable model selection. Next, we show that an ensemble of independently trained models also exhibits chaotic behavior in the DG setting. Taking advantage of our observation, we show that instead of ensembling unaveraged models, ensembling moving average models (EoA) from different runs does increase stability and further boosts performance. On the DomainBed benchmark, when using a ResNet-50 pre-trained on ImageNet, this ensemble of averages achieves $88.6\%$ on PACS, $79.1\%$ on VLCS, $72.5\%$ on OfficeHome, $52.3\%$ on TerraIncognita, and $47.4\%$ on DomainNet, an average of $68.0\%$, beating ERM (without model averaging) by $\sim 4\%$. We also evaluate a model that is pre-trained on a larger dataset, where we show EoA achieves an average accuracy of $72.7\%$, beating its corresponding ERM baseline by $5\%$.
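
The two ingredients named in the abstract, averaging parameters along one training run and then ensembling the averaged models from several runs, can be sketched roughly as follows. This is a minimal illustration, not the authors' implementation: the abstract does not say whether the average is uniform or exponential, nor whether the ensemble combines probabilities or logits, so the sketch assumes a uniform running average and a mean over class probabilities. The names `RunningAverage`, `start_step`, and `ensemble_probs` are hypothetical, and parameters are represented as flat lists of floats to keep the example dependency-free.

```python
from typing import List, Sequence


class RunningAverage:
    """Uniform running average of a flat parameter vector (assumed form)."""

    def __init__(self, start_step: int) -> None:
        # Hypothetical knob: first optimization step included in the average.
        self.start_step = start_step
        self.count = 0
        self.avg: List[float] = []

    def update(self, params: Sequence[float], step: int) -> None:
        """Fold the parameters observed at `step` into the average."""
        if step < self.start_step:
            return  # averaging has not started yet
        if self.count == 0:
            self.avg = list(params)
        else:
            # Incremental mean: avg <- avg + (p - avg) / (n + 1)
            self.avg = [
                a + (p - a) / (self.count + 1)
                for a, p in zip(self.avg, params)
            ]
        self.count += 1


def ensemble_probs(per_model_probs: Sequence[Sequence[float]]) -> List[float]:
    """Average class probabilities predicted by several (averaged) models.

    Assumption: the ensemble is a plain mean over per-model probabilities.
    """
    n = len(per_model_probs)
    return [sum(column) / n for column in zip(*per_model_probs)]
```

In use, one `RunningAverage` would be updated after every optimizer step of a single run, the averaged parameters would be loaded into a copy of the model for evaluation, and `ensemble_probs` would then combine the predictions of the averaged models obtained from different seeds.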

Results from the Paper


| Task | Dataset | Model | Metric Name | Metric Value | Global Rank |
|---|---|---|---|---|---|
| Domain Generalization | DomainNet | Ensemble of Averages (ResNeXt-50 32x4d) | Average Accuracy | 54.6 | # 1 |
| Domain Generalization | DomainNet | Ensemble of Averages (ResNet-50) | Average Accuracy | 47.4 | # 2 |
| Domain Generalization | Office-Home | Ensemble of Averages (ResNet-50) | Average Accuracy | 72.5 | # 2 |
| Domain Generalization | Office-Home | Ensemble of Averages (ResNeXt-50 32x4d) | Average Accuracy | 80.2 | # 1 |
| Domain Generalization | PACS | Ensemble of Averages (ResNet-50) | Average Accuracy | 88.6 | # 3 |
| Domain Generalization | PACS | Ensemble of Averages (ResNeXt-50 32x4d) | Average Accuracy | 93.2 | # 1 |
| Domain Generalization | TerraIncognita | Ensemble of Averages (ResNeXt-50 32x4d) | Average Accuracy | 55.2 | # 1 |
| Domain Generalization | TerraIncognita | Ensemble of Averages (ResNet-50) | Average Accuracy | 52.3 | # 3 |
| Domain Generalization | VLCS | Ensemble of Averages (ResNet-50) | Average Accuracy | 79.1 | # 3 |
| Domain Generalization | VLCS | Ensemble of Averages (ResNeXt-50 32x4d) | Average Accuracy | 80.4 | # 1 |
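
As a quick arithmetic check, the per-dataset accuracies listed above can be averaged to see how they relate to the overall figures quoted in the abstract. The sketch below only uses numbers that appear on this page; the dictionary and function names are illustrative. The ResNet-50 rows average to 68.0, and the ResNeXt-50 32x4d rows average to 72.7, which coincides with the figure the abstract gives for the model pre-trained on a larger dataset.

```python
# Per-dataset average accuracies (%) copied from the results listed above.
resnet50 = {
    "PACS": 88.6,
    "VLCS": 79.1,
    "OfficeHome": 72.5,
    "TerraIncognita": 52.3,
    "DomainNet": 47.4,
}
resnext50_32x4d = {
    "PACS": 93.2,
    "VLCS": 80.4,
    "OfficeHome": 80.2,
    "TerraIncognita": 55.2,
    "DomainNet": 54.6,
}


def mean_accuracy(scores: dict) -> float:
    """Unweighted mean over datasets."""
    return sum(scores.values()) / len(scores)


print(round(mean_accuracy(resnet50), 1))         # 68.0
print(round(mean_accuracy(resnext50_32x4d), 1))  # 72.7
```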
