SMART: Robust and Efficient Fine-Tuning for Pre-trained Natural Language Models through Principled Regularized Optimization

Transfer learning has fundamentally changed the landscape of natural language processing (NLP) research. Many existing state-of-the-art models are first pre-trained on a large text corpus and then fine-tuned on downstream tasks. However, due to the limited data available for downstream tasks and the extremely large capacity of pre-trained models, aggressive fine-tuning often causes the adapted model to overfit the downstream data and forget the knowledge of the pre-trained model. To address this issue in a more principled manner, we propose a new computational framework for robust and efficient fine-tuning of pre-trained language models. Specifically, the proposed framework contains two important ingredients: (1) smoothness-inducing regularization, which effectively manages the capacity of the model, and (2) Bregman proximal point optimization, a class of trust-region methods that can prevent knowledge forgetting. Our experiments demonstrate that the proposed method achieves state-of-the-art performance on multiple NLP benchmarks.
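As a concrete illustration of the first ingredient, below is a minimal PyTorch-style sketch of a smoothness-inducing adversarial regularizer: the inner maximization over a norm-bounded perturbation of the input embeddings is approximated with a single gradient-ascent step, and the output discrepancy is measured with a symmetrized KL divergence. The function names, the single-step approximation, the infinity-norm projection, and the default hyperparameters are illustrative assumptions, not the authors' exact implementation.

```python
import torch
import torch.nn.functional as F

def symmetric_kl(p_logits, q_logits):
    """Symmetrized KL divergence between two categorical distributions given as logits."""
    p_log, q_log = F.log_softmax(p_logits, dim=-1), F.log_softmax(q_logits, dim=-1)
    p, q = p_log.exp(), q_log.exp()
    return (F.kl_div(q_log, p, reduction="batchmean")
            + F.kl_div(p_log, q, reduction="batchmean"))

def smoothness_regularizer(model, embeds, eps=1e-5, step_size=1e-3, noise_var=1e-5):
    """One-step approximation of R_s(theta) = max_{||delta|| <= eps}
    l_s(f(x + delta; theta), f(x; theta)), with l_s the symmetrized KL above.
    `model` is assumed to map input embeddings to classification logits."""
    clean_logits = model(embeds)

    # Inner maximization: start from small Gaussian noise and take one
    # normalized gradient-ascent step on the perturbation delta.
    delta = torch.randn_like(embeds) * noise_var
    delta.requires_grad_(True)
    adv_loss = symmetric_kl(model(embeds + delta), clean_logits.detach())
    (grad,) = torch.autograd.grad(adv_loss, delta)
    delta = delta + step_size * grad / (grad.norm(dim=-1, keepdim=True) + 1e-12)

    # Project back onto the eps-ball (infinity norm here, for simplicity).
    delta = torch.clamp(delta.detach(), -eps, eps)

    # Outer term: penalize disagreement between perturbed and clean predictions.
    return symmetric_kl(model(embeds + delta), clean_logits)
```

During fine-tuning, this term would simply be added to the task loss, e.g. `loss = task_loss + lambda_s * smoothness_regularizer(model, embeds)`, where `lambda_s` is a tuning parameter.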


Results from the Paper


Task | Dataset | Model | Metric | Value | Global Rank
---|---|---|---|---|---
Semantic Textual Similarity | MRPC | SMART-RoBERTa Large | Accuracy | 93.7% | #1
Natural Language Inference | MultiNLI | SMART-RoBERTa Large | Matched Accuracy | 91.0 | #6
Natural Language Inference | MultiNLI | SMART-RoBERTa Large | Mismatched Accuracy | 90.8 | #4
Natural Language Inference | QNLI | SMART-RoBERTa Large | Accuracy | 95.4% | #6
Natural Language Inference | SciTail | SMART-MT-DNN | Accuracy | 96.1 | #2
Natural Language Inference | SNLI | SMART-MT-DNN | Test Accuracy (%) | 91.6 | #5
Sentiment Analysis | SST-2 (binary classification) | SMART-RoBERTa Large | Accuracy | 97.5 | #1
Semantic Textual Similarity | STS Benchmark | SMART-RoBERTa Large | Pearson Correlation | 0.929 | #1
Semantic Textual Similarity | STS Benchmark | SMART-RoBERTa Large | Spearman Correlation | 0.925 | #2
Natural Language Inference | WNLI | SMART-RoBERTa Large | Accuracy | 91.89% | #4

Methods


SMART combines the two ingredients described in the abstract above: smoothness-inducing adversarial regularization and Bregman proximal point optimization. The results reported here apply them on top of RoBERTa-Large and MT-DNN backbones.
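The second ingredient can be sketched in the same spirit. The following is a minimal sketch of one Bregman proximal point update: each step minimizes the task loss plus a penalty that keeps the model's predictions close to those of a slowly updated reference copy of the weights, with the reference refreshed by a momentum-style (mean-teacher) rule. It reuses the `symmetric_kl` helper from the sketch above; the function name, hyperparameter defaults, and the single-optimizer-step approximation of the proximal subproblem are illustrative assumptions, not the authors' implementation.

```python
import copy
import torch
import torch.nn.functional as F

def bregman_proximal_step(model, teacher, embeds, labels, optimizer, mu=1.0, beta=0.99):
    """One approximate Bregman proximal point update. `teacher` holds the
    trust-region reference weights theta_tilde and is never trained directly."""
    logits = model(embeds)
    with torch.no_grad():
        ref_logits = teacher(embeds)

    task_loss = F.cross_entropy(logits, labels)
    # Bregman divergence measured through model outputs: it keeps the
    # fine-tuned predictions close to the reference predictions.
    proximal_penalty = symmetric_kl(logits, ref_logits)
    loss = task_loss + mu * proximal_penalty

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Momentum ("mean-teacher" style) update of the reference point:
    # theta_tilde <- beta * theta_tilde + (1 - beta) * theta.
    with torch.no_grad():
        for p_ref, p in zip(teacher.parameters(), model.parameters()):
            p_ref.mul_(beta).add_(p, alpha=1 - beta)
    return loss.item()
```

In this sketch `teacher` would typically start as a copy of the pre-trained weights with gradients disabled, e.g. `teacher = copy.deepcopy(model)` followed by `teacher.requires_grad_(False)`.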