Improving Unsupervised Hierarchical Representation with Reinforcement Learning
Learning representations that capture a fundamental understanding of the world is a key challenge in machine learning. The hierarchical structure of explanatory factors hidden in data is one such general representation, and it could potentially be obtained with a hierarchical VAE. However, training a hierarchical VAE typically suffers from "posterior collapse", where data information is hard to propagate to the higher-level latent variables, resulting in a poor hierarchical representation. To address this issue, we first analyze the shortcomings of existing methods for mitigating posterior collapse from an information-theoretic perspective, then highlight the necessity of a regularization that explicitly propagates data information to higher-level latent variables while maintaining the dependency between different levels. This naturally leads to formulating the inference of the hierarchical latent representation as a sequential decision process, which can benefit from reinforcement learning (RL). To align RL's objective with the regularization, we first introduce a "skip-generative path" to obtain a reward that evaluates the information content of an inferred latent representation; the Q-value function built on this reward then has an optimization direction consistent with that of the regularization. Finally, policy gradient, one of the typical RL methods, is employed to train a hierarchical VAE without introducing a gradient estimator. Experimental results firmly support our analysis and demonstrate that our proposed method effectively mitigates the posterior collapse issue, learns an informative hierarchy, acquires explainable latent representations, and significantly outperforms other hierarchical VAE-based methods in downstream tasks.
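To make the sequential-decision view concrete: if inferring each level of the latent hierarchy is treated as one step and each step receives a scalar reward, a one-sample estimate of the Q-value at a level is the (optionally discounted) sum of rewards from that level onward. The sketch below shows only this generic RL computation; it is not the paper's definition of the reward or the Q-value function. The function name `returns_to_go`, the discount `gamma`, and the example reward values are hypothetical.

```python
from typing import List


def returns_to_go(rewards: List[float], gamma: float = 1.0) -> List[float]:
    """One-sample Q-value estimates for a sequence of per-level rewards.

    rewards[l] is an assumed scalar reward for the latent inferred at
    level l (for illustration only; the paper derives its reward from a
    "skip-generative path", which is not modelled here).
    The estimate at level l is G_l = r_l + gamma * G_{l+1}.
    """
    returns = [0.0] * len(rewards)
    running = 0.0
    # Walk from the highest level back to the lowest, accumulating rewards.
    for level in reversed(range(len(rewards))):
        running = rewards[level] + gamma * running
        returns[level] = running
    return returns


# Hypothetical rewards for a three-level hierarchy.
example = returns_to_go([1.0, 2.0, 3.0])
```

With `gamma=1.0`, each level is credited with the total reward of itself and all higher levels, which is one simple way lower-level decisions can be scored by how much information reaches the top of the hierarchy.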