Adversarial Invariant Learning

Although machine learning algorithms can recognize patterns from the correlation between data and labels, spurious features in the data reduce the robustness of these learned relationships across varied testing environments. This is known as the out-of-distribution (OoD) generalization problem. Invariant risk minimization (IRM) was recently proposed to tackle this issue by penalizing predictions based on spurious features that are unstable across data collected from different environments. However, as in domain adaptation and domain generalization, a prevalent and non-trivial limitation of these works is that the environment information is either assigned a priori by human specialists or determined heuristically. An inappropriate group partitioning can dramatically degrade OoD generalization, and producing the partition is expensive and time-consuming. To address this issue, we propose a novel, theoretically principled min-max framework that iteratively constructs a worst-case splitting, i.e., the most challenging environment split for the backbone learning paradigm (e.g., IRM), so that the backbone learns a robust feature representation. We also design a differentiable training strategy that makes gradient-based optimization feasible. Numerical experiments show that our framework achieves superior and stable performance on various datasets, such as Colored MNIST and Punctuated Stanford Sentiment Treebank (SST). Furthermore, we find that our algorithm is robust even to a strong data poisoning attack. To the best of our knowledge, this is one of the first works to adopt a differentiable environment splitting method to enable stable predictions across environments without environment index information, and it achieves state-of-the-art performance on datasets with strong spurious correlations, such as Colored MNIST.
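The abstract does not state the exact objective, so the following is only a minimal sketch, under stated assumptions, of what the inner (maximization) step of such a min-max scheme could look like. It assumes a binary task with labels in {-1, +1}, a fixed predictor trained with logistic loss, two soft environments in which example i belongs to environment 1 with weight q_i = sigmoid(a_i) and to environment 2 with weight 1 - q_i, and the IRMv1 penalty of Arjovsky et al. (the squared derivative of each environment's risk with respect to a dummy classifier scale w, evaluated at w = 1) as the quantity the splitter maximizes. All function names and the toy data are hypothetical; this is not the authors' implementation, and the outer step that updates the predictor is omitted.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def irm_grad_terms(logits, y):
    """Per-example d/dw of log(1 + exp(-y * w * z)) at w = 1, for labels y in {-1, +1}."""
    margin = y * logits
    return -margin * sigmoid(-margin)


def soft_split_penalty(q, g):
    """IRMv1-style penalty sum_e (dR^e/dw)^2 for two soft environments.

    Environment 1 weights example i by q[i], environment 2 by 1 - q[i];
    each environment's risk gradient is the weighted mean of g.
    Returns the penalty and its gradient with respect to q.
    """
    s1, s2 = q.sum(), (1.0 - q).sum()
    g1 = (q * g).sum() / s1
    g2 = ((1.0 - q) * g).sum() / s2
    penalty = g1 ** 2 + g2 ** 2
    grad_q = 2.0 * g1 * (g - g1) / s1 - 2.0 * g2 * (g - g2) / s2
    return penalty, grad_q


def worst_case_split(logits, y, steps=1000, lr=20.0, seed=0):
    """Hypothetical inner step: gradient ascent on the penalty over soft assignments."""
    rng = np.random.default_rng(seed)
    # q = 0.5 for every example is a stationary point, so start from a small random perturbation.
    a = rng.normal(scale=0.1, size=len(y))
    g = irm_grad_terms(logits, y)  # the predictor is held fixed during this step
    for _ in range(steps):
        q = sigmoid(a)
        _, grad_q = soft_split_penalty(q, g)
        a += lr * grad_q * q * (1.0 - q)  # chain rule through q = sigmoid(a), ascent step
    return sigmoid(a)


def make_toy_data(n=400, flip_rate=0.2, seed=1):
    """Hypothetical toy data: a predictor that relies only on a spurious feature s."""
    rng = np.random.default_rng(seed)
    y = rng.choice([-1.0, 1.0], size=n)
    flip = rng.random(n) < flip_rate  # s disagrees with the label on these examples
    s = np.where(flip, -y, y)
    logits = 2.0 * s
    return y, logits, flip


y, logits, flip = make_toy_data()
q = worst_case_split(logits, y)
g = irm_grad_terms(logits, y)
print("uniform-split penalty:", soft_split_penalty(np.full(len(y), 0.5), g)[0])
print("worst-case penalty:", soft_split_penalty(q, g)[0])
print("mean q (flipped vs. not):", q[flip].mean(), q[~flip].mean())
```

Because the soft assignments q_i are differentiable, the split can be updated by plain gradient ascent; in this toy case the ascent tends to separate the examples on which the spurious feature agrees with the label from those on which it does not, which is the split that most exposes the predictor's reliance on that feature. In a full min-max loop one would alternate this step with minimizing the backbone objective (e.g., empirical risk plus the IRM penalty) on the current split.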
