Mean-field Variational Inference via Wasserstein Gradient Flow

17 Jul 2022 · Rentian Yao, Yun Yang

Variational inference methods, such as the mean-field (MF) approximation, require certain conjugacy structures for efficient computation. These requirements can impose unnecessary restrictions on the viable prior distribution family and further constraints on the variational approximation family. In this work, we introduce a general computational framework to implement MF variational inference for Bayesian models, with or without latent variables, using the Wasserstein gradient flow (WGF), a modern mathematical technique for realizing a gradient flow over the space of probability measures. Theoretically, we analyze the algorithmic convergence of the proposed approaches, providing an explicit expression for the contraction factor. We also strengthen existing results on MF variational posterior concentration from a polynomial to an exponential contraction, by utilizing the fixed point equation of the time-discretized WGF. Computationally, we propose a new constraint-free function approximation method using neural networks to numerically realize our algorithm. This method is shown to be more precise and efficient than traditional particle approximation methods based on Langevin dynamics.
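
For orientation, the particle baseline mentioned at the end of the abstract can be sketched as follows. This is not the paper's neural network method and it ignores the mean-field factorization; it is only a minimal illustration of the standard unadjusted Langevin update, whose particle distribution approximately follows the Wasserstein gradient flow of the KL divergence to a target. The one-dimensional Gaussian target, the function names (`langevin_particles`, `grad_log_gaussian`), the step size, and the particle count are all assumptions chosen for illustration.

```python
import numpy as np


def langevin_particles(grad_log_target, x0, step=0.01, n_steps=2000, seed=0):
    """Unadjusted Langevin updates applied independently to each particle.

    Each iteration moves every particle along the gradient of the log target
    density and adds Gaussian noise with variance 2 * step.
    """
    rng = np.random.default_rng(seed)
    x = np.array(x0, dtype=float)
    for _ in range(n_steps):
        noise = rng.standard_normal(x.shape)
        x = x + step * grad_log_target(x) + np.sqrt(2.0 * step) * noise
    return x


# Hypothetical target for illustration: a Gaussian with mean mu and std sigma.
mu, sigma = 2.0, 1.0


def grad_log_gaussian(x):
    # Gradient of log N(x; mu, sigma^2) with respect to x.
    return -(x - mu) / sigma**2


# Start all particles far from the target mean and let them relax.
particles = langevin_particles(grad_log_gaussian, x0=np.full(5000, -3.0))
print(particles.mean(), particles.var())
```

With a small step size, the empirical mean and variance of the particles should land close to those of the target, up to Monte Carlo error and a small discretization bias in the variance.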
