Improved transferability of self-supervised learning models through batch normalization finetuning

The abundance of unlabelled data and advances in Self-Supervised Learning (SSL) have made it the preferred choice in many transfer learning scenarios. Due to the rapid and ongoing development of SSL approaches, practitioners are now faced with an overwhelming number of models trained for a specific task/domain, calling for a method to estimate transfer performance on novel tasks/domains. Typically, the role of such an estimator is played by linear probing, which trains a linear classifier on top of the frozen feature extractor. In this work, we address a shortcoming of linear probing: it is not strongly correlated with the performance of models finetuned end-to-end (often the final objective in transfer learning) and, in some cases, catastrophically misestimates a model's potential. We propose a way to obtain a significantly better proxy task by unfreezing and jointly finetuning batch normalization layers together with the classification head. At the cost of training only an extra 0.16% of model parameters (in the case of ResNet-50), we acquire a proxy task that (i) has a stronger correlation with end-to-end finetuned performance, (ii) improves the linear probing performance in the many- and few-shot learning regimes and (iii) in some cases, outperforms both linear probing and end-to-end finetuning, reaching state-of-the-art performance on a pathology dataset. Finally, we analyze and discuss the changes batch normalization training introduces in the feature distributions that may be the reason for the improved performance. The code is available at https://github.com/vpulab/bn_finetuning.
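The recipe described above (freeze the SSL backbone, then unfreeze only the batch normalization layers and a new classification head) can be sketched in PyTorch as follows. This is a minimal illustration, not the paper's exact implementation: the torchvision ResNet-50 backbone, the number of classes, and the optimizer settings are assumptions chosen for clarity.

    # Minimal sketch: finetune only BatchNorm layers + classification head
    # (backbone weights would come from an SSL checkpoint, e.g. MoCo-v2)
    import torch
    import torch.nn as nn
    from torchvision import models

    num_classes = 2  # illustrative; set to the target task's class count
    model = models.resnet50()  # in practice, load SSL-pretrained weights here
    model.fc = nn.Linear(model.fc.in_features, num_classes)  # new head

    # Freeze the whole backbone first
    for p in model.parameters():
        p.requires_grad = False

    # Unfreeze the affine (scale/shift) parameters of every BatchNorm layer
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            for p in m.parameters():
                p.requires_grad = True

    # Unfreeze the classification head
    for p in model.fc.parameters():
        p.requires_grad = True

    # Optimize only the trainable parameters (BN + head)
    optimizer = torch.optim.SGD(
        [p for p in model.parameters() if p.requires_grad],
        lr=1e-3, momentum=0.9,
    )

Note that with the model in training mode, the BatchNorm running statistics are also updated on the target data, which contributes to adapting the feature distributions discussed in the abstract.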

Results from the Paper


 Ranked #1 on Classification on MHIST (using extra training data)

Task            Dataset   Model                      Metric     Value    Global Rank
Classification  MHIST     MoCo-v2 (ResNet-50)        Accuracy   88.03    #1
Classification  MHIST     Barlow Twins (ResNet-50)   Accuracy   84.03    #3
Classification  MHIST     SwAV (ResNet-50)           Accuracy   83.21    #4
