Improving Out-of-Distribution Robustness of Classifiers Through Interpolated Generative Models

29 Sep 2021 · Haoyue Bai, Ceyuan Yang, Yinghao Xu, S.-H. Gary Chan, Bolei Zhou

Out-of-distribution (OoD) generalization is one of the major challenges for deploying machine learning systems in the real world. Learning representations that disentangle the underlying structure of data is key to improving OoD generalization. Recent works suggest that the latent space of GAN models exhibits the properties of a disentangled representation. In this work, we investigate when and how GAN models can be used to improve the OoD robustness of classifiers. Generative models are expected to generate realistic images and increase the diversity of the training set, thereby improving the model's ability to generalize. However, conventional GAN models trained for data augmentation preserve the correlations present in the training data. This hampers training a classifier that is robust to distribution shifts, since the spurious correlations in the biased training data are unrelated to the causal features of interest. Moreover, training GAN models directly on multiple source domains is unreliable and suffers from mode collapse. In this paper, we employ interpolated generative models to generate OoD samples at training time via data augmentation. Specifically, we use a StyleGAN2 model, pre-trained on one source training domain, as the source of generative augmentation. We then fine-tune it on the other source domains with the lower layers of the discriminator frozen. Next, we apply linear interpolation in the parameter space of these correlated networks across the source domains and control the augmentation at training time via the interpolation coefficients. A style-mixing mechanism is further introduced to improve the diversity of the generated OoD samples. Our experiments show that the proposed framework explicitly increases the diversity of training domains and achieves consistent improvements over baselines on both synthesized MNIST and many real-world OoD datasets.
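The abstract does not give implementation details, so the following is only a minimal sketch of the two operations it names, written under stated assumptions. Because every fine-tuned generator starts from the same pre-trained StyleGAN2, all of them share one architecture and have correlated weights. That shared starting point is what makes a per-parameter linear blend meaningful. The sketch assumes each generator's weights are available as a dict mapping parameter names to NumPy arrays. It also assumes the interpolation coefficients form a convex combination (non-negative, summing to 1). For style mixing it uses the standard StyleGAN crossover scheme: one latent code drives the layers before a crossover index and a second code drives the rest. The paper's exact mixing rule may differ. The function names `interpolate_params` and `style_mix` are hypothetical.

```python
import numpy as np


def interpolate_params(state_dicts, coeffs):
    """Linearly interpolate parameters of generators that share one architecture.

    Hypothetical helper: each state dict maps parameter name -> np.ndarray.
    Assumes coeffs are non-negative and sum to 1 (a convex combination).
    """
    coeffs = np.asarray(coeffs, dtype=np.float64)
    if len(state_dicts) != len(coeffs):
        raise ValueError("need exactly one coefficient per model")
    if np.any(coeffs < 0) or not np.isclose(coeffs.sum(), 1.0):
        raise ValueError("coefficients must be non-negative and sum to 1")
    keys = state_dicts[0].keys()
    for sd in state_dicts[1:]:
        if sd.keys() != keys:
            raise ValueError("all models must have the same parameter names")
    # Weighted sum of each parameter tensor across the source-domain generators.
    return {k: sum(c * sd[k] for c, sd in zip(coeffs, state_dicts)) for k in keys}


def style_mix(w_a, w_b, crossover):
    """Standard StyleGAN-style crossover mixing of per-layer latent codes.

    w_a, w_b: arrays of shape (num_layers, latent_dim).
    Layers [0, crossover) take w_a and layers [crossover, num_layers) take w_b.
    """
    if w_a.shape != w_b.shape:
        raise ValueError("latent codes must have the same shape")
    mixed = w_a.copy()
    mixed[crossover:] = w_b[crossover:]
    return mixed
```

For example, `interpolate_params([g_source, g_target], [0.25, 0.75])` would produce a generator whose weights lie three quarters of the way toward the second domain. Sweeping the coefficients during training is one way to vary how far the synthesized samples move between domains.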
