Pyramid Mini-Batching for Optimal Transport

29 Sep 2021  ·  Devin Guillory, Kuniaki Saito, Eric Tzeng, Yannik Pitcan, Kate Saenko, Trevor Darrell

Optimal transport theory provides a useful tool to measure the differences between two distributions. Aligning distributions by minimizing optimal transport distances has been shown to be effective in a variety of machine learning settings, including generative modeling and domain adaptation. However, computing optimal transport distances between discrete distributions with large numbers of data points is time-consuming and often intractable. In this work we propose a geometric sampling scheme which partitions the datasets into pyramid-based encodings. Our approach, Pyramid Mini-Batching, significantly improves the quality of optimal transport approximations and downstream alignments with minimal computational overhead. We perform experiments over the Discrete Optimal Transport benchmark to demonstrate the effectiveness of this strategy over multiple established optimal transport settings and see that our approach improves estimates of OT distances by nearly $30\%$ for single pass estimation. Furthermore, we see that when attempting to minimize optimal transport distance our approach is ten times more effective than random mini-batch sampling. To highlight the practical benefits of this approach, we use optimal transport distance in domain adaptation settings and show our approach produces state-of-the-art results on the large-scale domain adaptation problems VisDA17 and DomainNet. Ablation studies indicate that our sampling approach could be combined with conventional distribution alignment approaches and offer substantial improvements to their results.
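To make the motivation concrete, the following is a minimal sketch of why the choice of mini-batches matters when estimating an OT distance. It is not the authors' algorithm: the abstract does not specify how the pyramid encodings are built, so the example uses a one-dimensional toy problem, where the exact distance is available by sorting, and a hypothetical geometry-aware partition (contiguous chunks of the sorted data) as a stand-in. All function names here are illustrative assumptions.

```python
import random


def w1_sorted(xs, ys):
    """Exact 1-D Wasserstein-1 distance between two equal-size,
    uniformly weighted samples (optimal matching pairs sorted values)."""
    assert len(xs) == len(ys)
    return sum(abs(a - b) for a, b in zip(sorted(xs), sorted(ys))) / len(xs)


def chunks(items, size):
    """Split a list into consecutive chunks of the given size."""
    return [items[i:i + size] for i in range(0, len(items), size)]


def minibatch_w1(xs, ys, batch_size, geometric=False, rng=None):
    """Single-pass mini-batch estimate: partition both datasets into
    batches, solve OT per batch pair, and average the batch costs.

    geometric=False: random partition (the usual baseline).
    geometric=True:  ASSUMED stand-in for a geometry-aware partition;
                     batches are contiguous chunks of the sorted data.
    """
    assert len(xs) == len(ys) and len(xs) % batch_size == 0
    if rng is None:
        rng = random.Random(0)
    xs, ys = list(xs), list(ys)
    if geometric:
        xs.sort()
        ys.sort()
    else:
        rng.shuffle(xs)
        rng.shuffle(ys)
    costs = [w1_sorted(bx, by)
             for bx, by in zip(chunks(xs, batch_size), chunks(ys, batch_size))]
    return sum(costs) / len(costs)


# Toy data: two 1-D Gaussians (synthetic, for illustration only).
data_rng = random.Random(0)
source = [data_rng.gauss(0.0, 1.0) for _ in range(512)]
target = [data_rng.gauss(2.0, 0.5) for _ in range(512)]

full = w1_sorted(source, target)
rand_est = minibatch_w1(source, target, batch_size=32)
geo_est = minibatch_w1(source, target, batch_size=32, geometric=True)
print(f"full={full:.4f} random={rand_est:.4f} geometric={geo_est:.4f}")
```

Because the per-batch matchings together form a valid matching of the full datasets, the random-partition estimate can never fall below the exact distance. In this 1-D toy case the sorted-chunk partition recovers the exact value, which illustrates, under the stated assumptions, how geometry-aware batching can tighten the estimate; in higher dimensions no such exact shortcut exists.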

