AutoDropout: Learning Dropout Patterns to Regularize Deep Networks

5 Jan 2021 · Hieu Pham, Quoc V. Le

Neural networks are often over-parameterized and hence benefit from aggressive regularization. Conventional regularization methods, such as Dropout or weight decay, do not leverage the structures of the network's inputs and hidden states. As a result, these conventional methods are less effective than methods that do leverage these structures, such as SpatialDropout and DropBlock, which randomly drop the values in certain contiguous areas of the hidden states and set them to zero. Although the locations of the dropout areas are random, the patterns of SpatialDropout and DropBlock are manually designed and fixed. Here we propose to learn the dropout patterns. In our method, a controller learns to generate a dropout pattern at every channel and layer of a target network, such as a ConvNet or a Transformer. The target network is then trained with the dropout pattern, and its resulting validation performance is used as a signal for the controller to learn from. We show that this method works well both for image recognition on CIFAR-10 and ImageNet and for language modeling on Penn Treebank and WikiText-2. The learned dropout patterns also transfer to different tasks and datasets, such as from language modeling on Penn Treebank to English-French translation on WMT 2014. Our code will be available.
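The search loop described in the abstract (a controller proposes a pattern, the target network trains with it, and validation performance is fed back to the controller) can be illustrated with a toy REINFORCE search over a single pattern parameter. This is a minimal sketch under heavy assumptions, not the paper's search space or controller: the pattern is a simplified DropBlock-style square mask shared across channels, the controller is one softmax over a few candidate block sizes, and `block_dropout_mask`, `apply_pattern`, `train_and_validate` and `search_block_size` are hypothetical names. In particular, `train_and_validate` returns a synthetic score instead of actually training a network.

```python
import numpy as np


def block_dropout_mask(height, width, block_size, drop_prob, rng):
    """Simplified DropBlock-style mask: zero contiguous block_size x block_size squares.

    gamma is a rough seeding rate aiming at about drop_prob dropped units;
    overlapping squares and clipping at the borders make the real fraction differ.
    """
    gamma = drop_prob / (block_size ** 2)
    seeds = rng.random((height, width)) < gamma
    mask = np.ones((height, width))
    half = block_size // 2
    for i, j in zip(*np.nonzero(seeds)):
        # Zero a square roughly centered on each seed, clipped at the borders.
        top, left = max(0, i - half), max(0, j - half)
        mask[top:i - half + block_size, left:j - half + block_size] = 0.0
    return mask


def apply_pattern(h, mask):
    """Apply an (H, W) mask to hidden states h of shape (C, H, W), rescaling kept units."""
    kept = mask.sum()
    if kept == 0:
        return np.zeros_like(h)
    # Broadcasting shares the same mask across all channels (a simplification).
    return h * mask * (mask.size / kept)


def train_and_validate(block_size):
    # HYPOTHETICAL stand-in: a real search would train the target network with
    # this pattern and return its validation accuracy. Here the score is synthetic.
    return 1.0 - abs(block_size - 5) / 10.0


def search_block_size(choices, reward_fn, steps=2000, lr=0.5, decay=0.9, seed=0):
    """REINFORCE with a moving-average baseline over a softmax 'controller'."""
    rng = np.random.default_rng(seed)
    logits = np.zeros(len(choices))
    baseline = None
    probs = np.full(len(choices), 1.0 / len(choices))
    for _ in range(steps):
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()
        a = rng.choice(len(choices), p=probs)  # controller samples a pattern
        r = reward_fn(choices[a])              # validation signal for that pattern
        if baseline is None:
            baseline = r
        advantage = r - baseline
        baseline = decay * baseline + (1 - decay) * r
        grad = -probs.copy()
        grad[a] += 1.0  # gradient of log p(a) w.r.t. the softmax logits
        logits += lr * advantage * grad
    return choices[int(np.argmax(logits))], probs


if __name__ == "__main__":
    best, probs = search_block_size([1, 3, 5, 7], train_and_validate)
    print("selected block size:", best, "probabilities:", np.round(probs, 3))
    rng = np.random.default_rng(0)
    mask = block_dropout_mask(8, 8, best, 0.2, rng)
    hidden = rng.standard_normal((4, 8, 8))
    print("fraction kept:", mask.mean(), "output shape:", apply_pattern(hidden, mask).shape)
```

The moving-average baseline is subtracted from each reward so that only patterns that beat the recent average are reinforced, which reduces the variance of the REINFORCE update. The paper's actual controller generates richer, per-channel and per-layer patterns; this sketch only shows the shape of the feedback loop.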

Results from the Paper

Task | Dataset | Model | Metric | Value | Global Rank
--- | --- | --- | --- | --- | ---
Image Classification | CIFAR-10 | AutoDropout | Percentage correct | 96.8 | #78
Image Classification | CIFAR-10 | WRN-28-10 + AutoDropout + RandAugment | Percentage correct | 97.9 | #51
Image Classification | CIFAR-10 | WRN-28-10 + AutoDropout + RandAugment | Params | 36.5M | #178
Image Classification | CIFAR-10, 4000 labels | WRN-28-2 + UDA + AutoDropout | Percentage error | 4.2 | #1
Image Classification | ImageNet | EfficientNet-B0 | Top-1 accuracy | 77.5% | #402
Image Classification | ImageNet | ResNet-50 + AutoDropout + RandAugment | Top-1 accuracy | 80.3% | #301
Image Classification | ImageNet | ResNet-50 | Top-1 accuracy | 78.7% | #368
Image Classification | ImageNet-10 | ResNet-50 + UDA + AutoDropout | Top-1 accuracy | 72.9 | #1
Machine Translation | IWSLT2014 German-English | TransformerBase + AutoDropout | BLEU score | 35.8 | #13
Language Modelling | Penn Treebank (Word Level) | Transformer-XL + AutoDropout | Validation perplexity | 58.1 | #19
Language Modelling | Penn Treebank (Word Level) | Transformer-XL + AutoDropout | Test perplexity | 54.9 | #22
Machine Translation | WMT2014 English-French | TransformerBase + AutoDropout | BLEU score | 40 | #32

Methods