Scalable Sinkhorn Backpropagation

29 Sep 2021 · Marvin Eisenberger, Aysim Toker, Laura Leal-Taixé, Florian Bernard, Daniel Cremers

Optimal transport has recently gained increasing attention in the context of deep learning. A major contributing factor is the line of work on smooth relaxations that make the classical optimal transport problem differentiable. The most prominent example is entropy-regularized optimal transport, which can be optimized efficiently via an alternating scheme of Sinkhorn projections. This has led to a surge of deep learning techniques that use the Sinkhorn operator to learn matchings, permutations, sorting and ranking, or to construct a geometrically motivated loss function for generative models. The prevalent approach to training such a neural network is first-order optimization by algorithmic unrolling of the forward pass. Hence, the runtime and memory complexity of the backward pass increase linearly with the number of Sinkhorn iterations. This often makes unrolling impractical when computational resources like GPU memory are scarce. A more efficient alternative is to compute the derivative of a Sinkhorn layer via implicit differentiation. Our main contribution is deriving a simple and efficient algorithm that performs this backward pass in closed form. It is based on the Sinkhorn operator in its most general form, with learnable cost matrices and target capacities. We further provide a theoretical analysis with error bounds for approximate inputs. Finally, we demonstrate that, for a number of applications, replacing automatic differentiation with our module often improves the stability and accuracy of the obtained gradients while drastically reducing the computation cost.
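To make the contrast between unrolling and implicit differentiation concrete, here is a minimal pure-Python sketch under stated assumptions. It is not the authors' closed-form algorithm (which also covers learnable target capacities); it only handles a learnable cost matrix C with fixed marginals a and b, and the function names sinkhorn, solve and sinkhorn_grad_cost are hypothetical. The forward pass alternates the two Sinkhorn projections. The backward pass uses the optimality condition P_ij = exp((f_i + g_j - C_ij) / eps): differentiating the marginal constraints P1 = a and P^T 1 = b gives a linear system in the dual variables, and solving its adjoint yields dl/dC_ij = P_ij (x_i + y_j - G_ij) / eps, where G = dl/dP and [x; y] solves [[diag(a), P], [P^T, diag(b)]] [x; y] = [(G*P)1; (G*P)^T 1].

```python
import math


def sinkhorn(C, a, b, eps, n_iter=1000):
    """Entropy-regularized OT via alternating Sinkhorn scaling (sketch).

    Returns the plan P = diag(u) K diag(v) with K = exp(-C / eps),
    whose row sums approach a and whose column sums approach b.
    """
    m, n = len(a), len(b)
    K = [[math.exp(-C[i][j] / eps) for j in range(n)] for i in range(m)]
    u, v = [1.0] * m, [1.0] * n
    for _ in range(n_iter):
        # Project onto the row-marginal constraint, then the column one.
        u = [a[i] / sum(K[i][j] * v[j] for j in range(n)) for i in range(m)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(m)) for j in range(n)]
    return [[u[i] * K[i][j] * v[j] for j in range(n)] for i in range(m)]


def solve(A, rhs):
    """Dense Gaussian elimination with partial pivoting."""
    N = len(rhs)
    M = [row[:] + [rhs[k]] for k, row in enumerate(A)]
    for col in range(N):
        piv = max(range(col, N), key=lambda r: abs(M[r][col]))
        M[col], M[piv] = M[piv], M[col]
        for r in range(col + 1, N):
            f = M[r][col] / M[col][col]
            for c in range(col, N + 1):
                M[r][c] -= f * M[col][c]
    x = [0.0] * N
    for r in range(N - 1, -1, -1):
        x[r] = (M[r][N] - sum(M[r][c] * x[c] for c in range(r + 1, N))) / M[r][r]
    return x


def sinkhorn_grad_cost(P, G, eps):
    """Gradient dl/dC from G = dl/dP by implicit differentiation (sketch).

    Only the converged plan P is needed; no Sinkhorn iterates are stored.
    The full system has a one-dimensional null space (x + c, y - c), so the
    unknown y_{n-1} is pinned to 0 and its redundant equation is dropped.
    """
    m, n = len(P), len(P[0])
    a = [sum(P[i]) for i in range(m)]
    b = [sum(P[i][j] for i in range(m)) for j in range(n)]
    GP = [[G[i][j] * P[i][j] for j in range(n)] for i in range(m)]
    r = [sum(GP[i]) for i in range(m)]
    s = [sum(GP[i][j] for i in range(m)) for j in range(n)]
    N = m + n - 1
    A = [[0.0] * N for _ in range(N)]
    for i in range(m):
        A[i][i] = a[i]
        for j in range(n - 1):
            A[i][m + j] = P[i][j]
            A[m + j][i] = P[i][j]
    for j in range(n - 1):
        A[m + j][m + j] = b[j]
    z = solve(A, r + s[: n - 1])
    x, y = z[:m], z[m:] + [0.0]
    return [[P[i][j] * (x[i] + y[j] - G[i][j]) / eps for j in range(n)]
            for i in range(m)]
```

In this sketch the backward pass touches only the final plan P, so its memory does not grow with n_iter the way an unrolled backward pass would. The dense O((m+n)^3) solve is chosen for readability; a practical implementation would exploit the block structure of the system.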
