Argmax Flows: Learning Categorical Distributions with Normalizing Flows

This paper introduces a new method to define and train continuous distributions, such as normalizing flows, directly on categorical data such as text and image segmentations. The generative model is defined by a composition of a normalizing flow and an argmax function. To optimize this model, we dequantize the argmax using a distribution that is a probabilistic right-inverse to the argmax: every continuous sample it produces for a category x has its largest coordinate at x. This distribution lifts the categorical data to a continuous space on which the flow can be trained. We demonstrate that naïvely applying existing dequantization techniques to categorical data leads to suboptimal solutions. In addition, unlike autoregressive models, the model is fast in both the generative direction (sampling) and the inference direction (training).
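To make the argmax dequantization concrete, here is a minimal NumPy sketch under stated assumptions. It uses a thresholding construction for the right-inverse q(v | x): draw noise u, keep T = u[x] for the observed class, and push every other coordinate strictly below T with v_i = T - softplus(T - u_i), so argmax(v) = x always holds. Because P(x | v) = 1 on the support of q, the flow can be trained with the lower bound log P(x) >= E_q[log p(v) - log q(v | x)]. Assumptions: u is standard normal rather than produced by a learned conditional flow, and `ToyAffineFlow`, `threshold_dequantize`, `sample_category` and `elbo_estimate` are hypothetical names; the elementwise affine map only stands in for a real normalizing flow.

```python
import numpy as np


def softplus(a):
    # Numerically stable log(1 + exp(a)).
    return np.logaddexp(0.0, a)


def log_standard_normal(u):
    # Log density of a standard normal vector.
    return -0.5 * np.sum(u ** 2) - 0.5 * u.size * np.log(2.0 * np.pi)


def threshold_dequantize(x, num_classes, rng):
    """Draw v with argmax(v) == x and return (v, log q(v | x)).

    Assumed construction: u ~ N(0, I); keep T = u[x] for class x and map every
    other entry strictly below T via v_i = T - softplus(T - u_i).
    """
    u = rng.standard_normal(num_classes)
    T = u[x]
    v = T - softplus(T - u)
    v[x] = T
    # The Jacobian is triangular when class x is ordered first; its diagonal is
    # 1 for class x and sigmoid(T - u_i) for the others, with
    # log sigmoid(a) = -softplus(-a).
    others = np.arange(num_classes) != x
    log_det = np.sum(-softplus(-(T - u[others])))
    return v, log_standard_normal(u) - log_det


class ToyAffineFlow:
    """Hypothetical stand-in for a normalizing flow: v = scale * z + shift."""

    def __init__(self, scale, shift):
        self.scale = np.asarray(scale, dtype=float)
        self.shift = np.asarray(shift, dtype=float)

    def sample(self, rng):
        z = rng.standard_normal(self.shift.size)
        return self.scale * z + self.shift

    def log_prob(self, v):
        # Change of variables: log p(v) = log N(z) - sum log|scale|.
        z = (v - self.shift) / self.scale
        return log_standard_normal(z) - np.sum(np.log(np.abs(self.scale)))


def sample_category(flow, rng):
    # Generative direction: continuous sample from the flow, then argmax.
    return int(np.argmax(flow.sample(rng)))


def elbo_estimate(flow, x, num_classes, rng, num_samples=64):
    # Monte Carlo estimate of E_q[log p(v) - log q(v | x)], a lower bound on log P(x).
    total = 0.0
    for _ in range(num_samples):
        v, log_q = threshold_dequantize(x, num_classes, rng)
        total += flow.log_prob(v) - log_q
    return total / num_samples


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    flow = ToyAffineFlow(scale=[1.0, 1.0, 1.0], shift=[0.0, 2.0, -1.0])
    v, log_q = threshold_dequantize(1, 3, rng)
    print(v, np.argmax(v), log_q)
    print(sample_category(flow, rng), elbo_estimate(flow, 1, 3, rng))
```

In a real model, maximizing the bound would update both the flow p(v) and the parameters of q(v | x). The thresholding keeps the bound valid because no probability mass of q ever lands outside the region where argmax(v) = x.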
