Rethinking Sharpness-Aware Minimization as Variational Inference

19 Oct 2022  ·  Szilvia Ujváry, Zsigmond Telek, Anna Kerekes, Anna Mészáros, Ferenc Huszár

Sharpness-aware minimization (SAM) aims to improve the generalisation of gradient-based learning by seeking out flat minima. In this work, we establish connections between SAM and Mean-Field Variational Inference (MFVI) of neural network parameters. We show that both these methods have interpretations as optimizing notions of flatness, and when using the reparametrisation trick, they both boil down to calculating the gradient at a perturbed version of the current mean parameter. This thinking motivates our study of algorithms that combine or interpolate between SAM and MFVI. We evaluate the proposed variational algorithms on several benchmark datasets, and compare their performance to variants of SAM. Taking a broader perspective, our work suggests that SAM-like updates can be used as a drop-in replacement for the reparametrisation trick.
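
The shared structure described in the abstract, a gradient evaluated at a perturbed copy of the current mean parameter, can be illustrated with a minimal sketch. This is not the authors' implementation: the toy quadratic loss, the function names, and the parameters `rho`, `sigma`, and `lr` are assumptions made for illustration, and the KL term and the update of the variational scale that a full MFVI objective would include are left out.

```python
import math
import random


def grad_loss(w, curvature):
    """Gradient of the toy loss 0.5 * sum(c_i * w_i**2) (an assumed example loss)."""
    return [c * wi for c, wi in zip(curvature, w)]


def sam_perturbation(w, curvature, rho):
    """SAM-style perturbation: a step of length rho along the normalised gradient."""
    g = grad_loss(w, curvature)
    norm = math.sqrt(sum(gi * gi for gi in g))
    if norm == 0.0:
        return [0.0] * len(w)
    return [rho * gi / norm for gi in g]


def mfvi_perturbation(sigma, rng):
    """Reparametrisation-trick perturbation: eps = sigma * z with z ~ N(0, I)."""
    return [s * rng.gauss(0.0, 1.0) for s in sigma]


def perturbed_gradient(w, eps, curvature):
    """Gradient of the loss evaluated at the perturbed parameter w + eps."""
    return grad_loss([wi + ei for wi, ei in zip(w, eps)], curvature)


def step(w, eps, curvature, lr):
    """One gradient step on the mean parameter using the perturbed gradient."""
    g = perturbed_gradient(w, eps, curvature)
    return [wi - lr * gi for wi, gi in zip(w, g)]


# Both methods share the same update; only the choice of eps differs.
w = [1.0, -2.0]
curvature = [1.0, 4.0]
rng = random.Random(0)
w_sam = step(w, sam_perturbation(w, curvature, rho=0.05), curvature, lr=0.1)
w_vi = step(w, mfvi_perturbation([0.05, 0.05], rng), curvature, lr=0.1)
```

In this sketch the only difference between the two updates is how `eps` is chosen: deterministically along the gradient direction for the SAM-style step, and by sampling Gaussian noise for the reparametrised variational step.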
