Prune Once for All: Sparse Pre-Trained Language Models

10 Nov 2021 · Ofir Zafrir, Ariel Larey, Guy Boudoukh, Haihao Shen, Moshe Wasserblat

Transformer-based language models are applied to a wide range of applications in natural language processing. However, they are inefficient and difficult to deploy. In recent years, many compression algorithms have been proposed to improve the implementation efficiency of large Transformer-based models on target hardware. In this work, we present a new method for training sparse pre-trained Transformer language models by integrating weight pruning and model distillation. These sparse pre-trained models can be used for transfer learning to a wide range of tasks while maintaining their sparsity pattern. We demonstrate our method on three well-known architectures to create sparse pre-trained BERT-Base, BERT-Large, and DistilBERT. We show how the compressed sparse pre-trained models we trained transfer their knowledge to five different downstream natural language tasks with minimal accuracy loss. Moreover, we show how to further compress the sparse models' weights to 8-bit precision using quantization-aware training. For example, with our sparse pre-trained BERT-Large fine-tuned on SQuADv1.1 and quantized to 8-bit, we achieve a compression ratio of 40X for the encoder with less than 1% accuracy loss. To the best of our knowledge, our results show the best compression-to-accuracy ratio for BERT-Base, BERT-Large, and DistilBERT.
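To make the recipe concrete, below is a minimal sketch of the two ingredients the abstract names: gradual magnitude pruning during pre-training, distillation from a dense teacher, plus straight-through fake quantization for the 8-bit quantization-aware training step. This is an illustrative sketch, not the authors' implementation; the cubic sparsity schedule follows Zhu & Gupta (2017), and the `train_step` helper, the HuggingFace-style `.logits` model API, the layer-selection heuristic, and all hyperparameters are assumptions made for the example.

```python
# Illustrative sketch (not the paper's code) of sparse pre-training with
# distillation, plus fake quantization for quantization-aware training.
import torch
import torch.nn.functional as F


def sparsity_schedule(step, total_steps, final_sparsity, initial_sparsity=0.0):
    """Cubic schedule (Zhu & Gupta, 2017): ramp the pruned fraction
    from initial_sparsity to final_sparsity over training."""
    t = min(step / total_steps, 1.0)
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - t) ** 3


def magnitude_prune_(weight, sparsity):
    """Zero the smallest-magnitude weights in-place; return the binary mask."""
    k = int(sparsity * weight.numel())
    if k == 0:
        return torch.ones_like(weight)
    threshold = weight.abs().flatten().kthvalue(k).values
    mask = (weight.abs() > threshold).to(weight.dtype)
    weight.mul_(mask)
    return mask


def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """KL divergence between softened teacher and student distributions."""
    t = temperature
    return F.kl_div(
        F.log_softmax(student_logits / t, dim=-1),
        F.softmax(teacher_logits / t, dim=-1),
        reduction="batchmean",
    ) * (t * t)


def fake_quantize(x, num_bits=8):
    """Straight-through fake quantization used in QAT: simulate int8
    rounding in the forward pass while gradients pass through unchanged."""
    qmax = 2 ** (num_bits - 1) - 1
    scale = x.abs().max().clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(x / scale), -qmax - 1, qmax)
    return x + (q * scale - x).detach()  # straight-through estimator


def train_step(student, teacher, batch, optimizer, step, total_steps):
    """One optimization step: distill from the dense teacher, then re-apply
    magnitude pruning so pruned weights stay zero (assumes HF-style models)."""
    with torch.no_grad():
        teacher_logits = teacher(**batch).logits
    student_logits = student(**batch).logits
    loss = distillation_loss(student_logits, teacher_logits)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    s = sparsity_schedule(step, total_steps, final_sparsity=0.9)
    for name, p in student.named_parameters():
        # Illustrative heuristic: prune 2-D weight matrices, skip embeddings.
        if p.dim() == 2 and "embed" not in name:
            magnitude_prune_(p.data, s)
    return loss.item()
```

For intuition on the headline number, the reported 40X encoder compression is consistent with roughly 10X from storing only the ~10% of weights that survive 90% sparsity, times 4X from moving 32-bit floats to 8-bit integers (32/8 = 4), though the exact accounting depends on the sparse storage format used.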

Results

Natural Language Inference on MultiNLI (dev)

| Model | Matched | Global Rank | Mismatched | Global Rank |
| --- | --- | --- | --- | --- |
| BERT-Large-uncased-PruneOFA (90% unstruct sparse) | 83.74 | #2 | 84.2 | #2 |
| BERT-Large-uncased-PruneOFA (90% unstruct sparse, QAT Int8) | 83.47 | #3 | 84.08 | #3 |
| BERT-Base-uncased-PruneOFA (85% unstruct sparse) | 82.71 | #4 | 83.67 | #4 |
| BERT-Base-uncased-PruneOFA (90% unstruct sparse) | 81.45 | #5 | 82.43 | #6 |
| BERT-Base-uncased-PruneOFA (85% unstruct sparse, QAT Int8) | 81.4 | #6 | 82.51 | #5 |
| DistilBERT-uncased-PruneOFA (85% unstruct sparse) | 81.35 | #7 | 82.03 | #7 |
| DistilBERT-uncased-PruneOFA (90% unstruct sparse) | 80.68 | #8 | 81.47 | #8 |
| DistilBERT-uncased-PruneOFA (85% unstruct sparse, QAT Int8) | 80.66 | #9 | 81.14 | #9 |
| DistilBERT-uncased-PruneOFA (90% unstruct sparse, QAT Int8) | 78.8 | #10 | 80.4 | #10 |

Question Answering on SQuAD1.1 (dev)

| Model | EM | Global Rank | F1 | Global Rank |
| --- | --- | --- | --- | --- |
| BERT-Large-uncased-PruneOFA (90% unstruct sparse) | 83.35 | #10 | 90.2 | #12 |
| BERT-Large-uncased-PruneOFA (90% unstruct sparse, QAT Int8) | 83.22 | #11 | 90.02 | #13 |
| BERT-Base-uncased-PruneOFA (85% unstruct sparse) | 81.1 | #12 | 88.42 | #14 |
| BERT-Base-uncased-PruneOFA (85% unstruct sparse, QAT Int8) | 80.84 | #13 | 88.24 | #15 |
| BERT-Base-uncased-PruneOFA (90% unstruct sparse) | 79.83 | #14 | 87.25 | #17 |
| DistilBERT-uncased-PruneOFA (85% unstruct sparse) | 78.1 | #19 | 85.82 | #21 |
| DistilBERT-uncased-PruneOFA (85% unstruct sparse, QAT Int8) | 77.03 | #21 | 85.13 | #24 |
| DistilBERT-uncased-PruneOFA (90% unstruct sparse) | 76.91 | #22 | 84.82 | #26 |
| DistilBERT-uncased-PruneOFA (90% unstruct sparse, QAT Int8) | 75.62 | #25 | 83.87 | #28 |
