Non-convex Learning via Replica Exchange Stochastic Gradient MCMC

ICML 2020 · Wei Deng, Qi Feng, Liyao Gao, Faming Liang, Guang Lin

Replica exchange Monte Carlo (reMC), also known as parallel tempering, is an important technique for accelerating the convergence of conventional Markov Chain Monte Carlo (MCMC) algorithms. However, such a method requires the evaluation of the energy function based on the full dataset and is not scalable to big data. The naïve implementation of reMC in mini-batch settings introduces large biases, so it cannot be directly extended to stochastic gradient MCMC (SGMCMC), the standard sampling method for simulating from deep neural networks (DNNs). In this paper, we propose an adaptive replica exchange SGMCMC (reSGMCMC) to automatically correct the bias and study the corresponding properties. The analysis implies an acceleration-accuracy trade-off in the numerical discretization of a Markov jump process in a stochastic environment. Empirically, we test the algorithm through extensive experiments on various setups and obtain state-of-the-art results on CIFAR10, CIFAR100, and SVHN in both supervised learning and semi-supervised learning tasks.
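To make the swap mechanism in the abstract concrete, the following is a minimal sketch of replica exchange with Langevin updates on a toy one-dimensional double-well energy. It is not the authors' implementation. The energy function, the step size, the temperatures, and every name (`energy`, `swap_probability`, `run_replica_exchange`, `diff_var`, `noise_std`) are assumptions chosen for illustration. The correction term assumes the noise in the estimated energy difference is Gaussian with known variance `diff_var`, in which case multiplying the swap rate by exp(-c^2 * diff_var / 2) removes the bias in expectation (the log-normal mean identity). The paper's adaptive estimator may differ in its details.

```python
import numpy as np


def energy(x):
    # Toy double-well energy with minima at x = -1 and x = +1 (illustrative only).
    return (x**2 - 1.0) ** 2


def grad_energy(x):
    # Exact gradient of the toy energy above.
    return 4.0 * x * (x**2 - 1.0)


def swap_probability(u_low, u_high, tau_low, tau_high, diff_var=0.0):
    # Standard parallel-tempering swap rate, capped at 1.
    # diff_var is the (assumed known) variance of the noisy estimate of
    # u_low - u_high; the extra term offsets the bias from Gaussian noise.
    c = 1.0 / tau_low - 1.0 / tau_high
    log_s = c * (u_low - u_high) - 0.5 * c**2 * diff_var
    return float(np.exp(min(log_s, 0.0)))


def run_replica_exchange(n_steps=5000, step=1e-2, tau_low=0.1, tau_high=1.0,
                         noise_std=0.0, seed=0):
    rng = np.random.default_rng(seed)
    x_low, x_high = -1.0, -1.0  # both chains start in the left well
    samples = np.empty(n_steps)
    n_swaps = 0
    for t in range(n_steps):
        # Langevin update for each chain at its own temperature.
        x_low += -step * grad_energy(x_low) \
            + np.sqrt(2.0 * step * tau_low) * rng.standard_normal()
        x_high += -step * grad_energy(x_high) \
            + np.sqrt(2.0 * step * tau_high) * rng.standard_normal()
        # Noisy energy estimates stand in for mini-batch evaluations (assumption).
        u_low = energy(x_low) + noise_std * rng.standard_normal()
        u_high = energy(x_high) + noise_std * rng.standard_normal()
        # Two independent estimates, so the difference has variance 2 * noise_std**2.
        p = swap_probability(u_low, u_high, tau_low, tau_high,
                             diff_var=2.0 * noise_std**2)
        if rng.random() < p:
            x_low, x_high = x_high, x_low
            n_swaps += 1
        samples[t] = x_low  # keep samples from the low-temperature chain
    return samples, n_swaps


if __name__ == "__main__":
    samples, n_swaps = run_replica_exchange(noise_std=0.1)
    print("swaps:", n_swaps, "fraction of samples with x > 0:", np.mean(samples > 0))
```

In this sketch the correction shrinks the swap rate as the noise variance or the temperature gap grows, which gives an informal view of the acceleration-accuracy trade-off mentioned in the abstract: noisier energy estimates leave fewer accepted swaps.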


Results from the Paper


Ranked #77 on Image Classification on CIFAR-100 (using extra training data)

| Task | Dataset | Model | Metric Name | Metric Value | Global Rank |
|---|---|---|---|---|---|
| Image Classification | CIFAR-10 | ResNet32 with reSGHMC | Percentage correct | 95.35 | #123 |
| Image Classification | CIFAR-10 | WRN-28-10 with reSGHMC | Percentage correct | 97.42 | #78 |
| Image Classification | CIFAR-10 | WRN-28-10 with reSGHMC | PARAMS | 36.5M | #221 |
| Image Classification | CIFAR-10 | ResNet20 with reSGHMC | Percentage correct | 94.62 | #138 |
| Image Classification | CIFAR-10 | ResNet56 with reSGHMC | Percentage correct | 96.12 | #111 |
| Image Classification | CIFAR-10 | WRN-16-8 with reSGHMC | Percentage correct | 96.87 | #90 |
| Image Classification | CIFAR-100 | WRN-28-10 with reSGHMC | Percentage correct | 84.38 | #77 |
| Image Classification | CIFAR-100 | ResNet20 with reSGHMC | Percentage correct | 74.14 | #150 |
| Image Classification | CIFAR-100 | ResNet32 with reSGHMC | Percentage correct | 76.55 | #142 |
| Image Classification | CIFAR-100 | ResNet56 with reSGHMC | Percentage correct | 80.14 | #127 |
| Image Classification | CIFAR-100 | WRN-16-8 with reSGHMC | Percentage correct | 82.95 | #94 |

Methods