Domain Generalization by Mutual-Information Regularization with Pre-trained Models

21 Mar 2022 · Junbum Cha, Kyungjae Lee, Sungrae Park, Sanghyuk Chun

Domain generalization (DG) aims to learn a model that generalizes to an unseen target domain using only a limited set of source domains. Previous approaches to DG fail to learn domain-invariant representations from the source domains alone because of the significant domain shift between training and test domains. Instead, we reformulate the DG objective using mutual information with the oracle model, a model that generalizes to any possible domain. By approximating the oracle model with a pre-trained model, we derive a tractable variational lower bound of this objective, yielding a method we call Mutual Information Regularization with Oracle (MIRO). Our extensive experiments show that MIRO significantly improves out-of-distribution performance. Furthermore, our scaling experiments show that the larger the pre-trained model, the greater the performance improvement from MIRO. Source code is available at https://github.com/kakaobrain/miro.
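The abstract describes the lower bound only at a high level. One plausible reading of it, offered here as an assumption rather than the authors' code, is that the frozen pre-trained model's features z_0 are modeled given the trained model's features z_f by a Gaussian, so maximizing the bound amounts to adding a Gaussian negative log-likelihood term to the task loss. The Python sketch below illustrates such a regularizer under simplifications that are ours: a diagonal Gaussian whose mean encoder is the identity, a learnable input-independent variance per feature dimension passed through softplus, and plain lists standing in for tensors. The names `miro_regularizer`, `inverse_softplus`, `miro_loss`, and the weight `lam` are hypothetical; see the linked repository for the actual implementation.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x)); always positive."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def inverse_softplus(y: float) -> float:
    """Raw parameter b with softplus(b) == y, for y > 0 (used to pick an initial variance)."""
    return math.log(math.expm1(y))


def miro_regularizer(
    feats: Sequence[Vector],
    oracle_feats: Sequence[Vector],
    raw_var: Vector,
    eps: float = 1e-5,
) -> float:
    """Gaussian negative log-likelihood of frozen pre-trained features given trained features.

    Hypothetical simplification: mean encoder mu(z_f) = z_f (identity) and a
    per-dimension variance sigma_d^2 = softplus(raw_var[d]) + eps that does not
    depend on the input. Constant terms of the log-density are dropped, and the
    result is averaged over batch and feature dimensions.
    """
    if not feats or not raw_var or len(feats) != len(oracle_feats):
        raise ValueError("need a non-empty batch of paired feature vectors")
    variances = [softplus(b) + eps for b in raw_var]
    total, count = 0.0, 0
    for z_f, z_0 in zip(feats, oracle_feats):
        if len(z_f) != len(variances) or len(z_0) != len(variances):
            raise ValueError("feature dimension must match raw_var")
        for f, o, var in zip(z_f, z_0, variances):
            # 0.5 * (log sigma_d^2 + (z_0 - mu(z_f))^2 / sigma_d^2)
            total += 0.5 * (math.log(var) + (o - f) ** 2 / var)
            count += 1
    return total / count


def miro_loss(task_loss: float, reg: float, lam: float) -> float:
    """Total objective: task loss (e.g. cross-entropy) plus lam times the regularizer."""
    return task_loss + lam * reg


if __name__ == "__main__":
    trained = [[0.5, -1.0], [0.0, 2.0]]  # features from the model being fine-tuned
    frozen = [[0.4, -0.8], [0.3, 1.5]]   # features from the frozen pre-trained model
    raw = [inverse_softplus(0.1)] * 2    # start every variance at 0.1 (arbitrary choice)
    reg = miro_regularizer(trained, frozen, raw)
    print(miro_loss(task_loss=1.2, reg=reg, lam=0.01))
```

Because the variance is learned, the log-variance term keeps it from growing without bound while the squared-error term keeps it from collapsing, so feature dimensions where the two models disagree are penalized less strictly. In a real training loop these parameters would be optimized jointly with the network in a framework such as PyTorch, which this list-based sketch does not attempt.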

Task                   Dataset         Model                      Metric            Value (%)  Global Rank
Domain Generalization  DomainNet       MIRO (ResNet-50, SWAD)     Average Accuracy  47.0       #17
Domain Generalization  DomainNet       MIRO (RegNetY-16GF, SWAD)  Average Accuracy  60.7       #6
Domain Generalization  Office-Home     MIRO (ResNet-50, SWAD)     Average Accuracy  72.4       #19
Domain Generalization  Office-Home     MIRO (RegNetY-16GF, SWAD)  Average Accuracy  83.3       #9
Domain Generalization  PACS            MIRO (ResNet-50, SWAD)     Average Accuracy  88.4       #22
Domain Generalization  PACS            MIRO (RegNetY-16GF, SWAD)  Average Accuracy  96.8       #5
Domain Generalization  TerraIncognita  MIRO (ResNet-50, SWAD)     Average Accuracy  52.9       #12
Domain Generalization  TerraIncognita  MIRO (RegNetY-16GF, SWAD)  Average Accuracy  64.3       #2
Domain Generalization  VLCS            MIRO (ResNet-50, SWAD)     Average Accuracy  79.6       #18
Domain Generalization  VLCS            MIRO (RegNetY-16GF, SWAD)  Average Accuracy  81.7       #11
