CAGE: Probing Causal Relationships in Deep Generative Models

29 Sep 2021 · Joey Bose, Ricardo Pio Monti, Aditya Grover

Deep generative models excel at generating complex, high-dimensional data, often exhibiting impressive generalization beyond the training distribution. The learning principle for these models is, however, purely statistical, and it is unclear to what extent such models have internalized the causal relationships present in the training data, if at all. With increasing real-world deployments, such a causal understanding of generative models is essential for interpreting and controlling their use in high-stakes applications that require synthetic data generation. We propose CAGE, a framework for inferring the cause-effect relationships governing deep generative models. CAGE employs careful geometric manipulations within the latent space of a generative model to generate counterfactuals and estimate unit-level generative causal effects. CAGE does not require any modifications to the training procedure and can be used with any existing pretrained latent variable model. Moreover, the pretraining can be completely unsupervised and does not require any treatment or outcome labels. Empirically, we demonstrate the use of CAGE for: (a) inferring cause-effect relationships within a deep generative model trained on both synthetic and high-resolution images, and (b) guiding data augmentations for robust classification, where CAGE achieves improvements over current default approaches on image datasets.
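The core idea described above — intervening on a latent coordinate of a pretrained generator, decoding both the factual and counterfactual latents, and differencing an outcome of interest — can be sketched as follows. This is a minimal illustrative toy, not the paper's implementation: the linear `decode`, the scalar `outcome` probe, and the `unit_level_effect` helper are all hypothetical stand-ins for a real pretrained decoder and outcome functional.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for a pretrained decoder g: latent z -> observation x.
# (A real model would be a GAN or VAE decoder; this linear map is for illustration.)
W = rng.normal(size=(4, 2))

def decode(z):
    return W @ z

def outcome(x):
    # Hypothetical outcome functional on generated samples,
    # e.g. a property probe or a classifier score.
    return float(x.sum())

def unit_level_effect(z, dim, delta=1.0):
    """Estimate a unit-level causal effect by intervening on latent
    coordinate `dim` and comparing factual vs. counterfactual outcomes."""
    z_cf = z.copy()
    z_cf[dim] += delta  # latent-space intervention
    return outcome(decode(z_cf)) - outcome(decode(z))

z = rng.normal(size=2)          # factual latent for one unit
effect = unit_level_effect(z, dim=0)
```

With the linear toy decoder, intervening on latent dimension 0 by `delta` shifts the outcome by `delta` times the column sum of `W[:, 0]`, which is the kind of per-unit contrast the framework aggregates; a real generator would require the geometric care in latent space that the paper emphasizes.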
