STraTA: Self-Training with Task Augmentation for Better Few-shot Learning

Despite their recent successes in tackling many NLP tasks, large-scale pre-trained language models do not perform as well in few-shot settings where only a handful of training examples are available. To address this shortcoming, we propose STraTA, which stands for Self-Training with Task Augmentation, an approach that builds on two key ideas for effectively leveraging unlabeled data. First, STraTA uses task augmentation, a novel technique that synthesizes a large amount of data for auxiliary-task fine-tuning from target-task unlabeled texts. Second, STraTA performs self-training by further fine-tuning the strong base model created by task augmentation on a broad distribution of pseudo-labeled data. Our experiments demonstrate that STraTA can substantially improve sample efficiency across 12 few-shot benchmarks. Remarkably, on the SST-2 sentiment dataset, STraTA, with only 8 training examples per class, achieves comparable results to standard fine-tuning with 67K training examples. Our analyses reveal that task augmentation and self-training are both complementary and independently effective.

Few-Shot NLI on QNLI (8 training examples per class)

| Model | Accuracy | Global Rank |
| --- | --- | --- |
| BERT-Large + STraTA | 86.4±0.8 | #1 |
| BERT-Large + MNLI + ST | 86.1±1.1 | #2 |
| BERT-Large + ST | 85.4±1.7 | #3 |
| BERT-Base + STraTA | 82.1±0.5 | #4 |
| BERT-Base + MNLI + ST | 81.5±1.2 | #5 |
| BERT-Base + ST | 71.6±11.3 | #6 |
| BERT-Large + TA | 71.5±4.0 | #7 |
| BERT-Base + TA | 70.1±3.4 | #8 |
| BERT-Large + MNLI | 64.5±4.4 | #9 |
| BERT-Large | 64.4±6.1 | #10 |
| BERT-Base + MNLI | 62.8±5.1 | #11 |
| BERT-Base | 59.0±10.9 | #12 |
| BERT-Base + LMFT | 57.6±9.1 | #13 |
| BERT-Large + LMFT | 52.2±1.6 | #14 |

Few-Shot NLI on SNLI (8 training examples per class)

| Model | Accuracy | Global Rank |
| --- | --- | --- |
| BERT-Large + STraTA | 87.3±0.3 | #1 |
| BERT-Base + STraTA | 85.7±0.2 | #2 |
| BERT-Base + TA | 83.3±0.8 | #3 |
| BERT-Base + MNLI + ST | 83.2±0.3 | #4 |
| BERT-Base + MNLI | 75.2±5.7 | #5 |
| BERT-Base + ST | 65±5.8 | #6 |
| BERT-Base + LMFT | 45.2±3.9 | #7 |
| BERT-Base | 43.7±2.2 | #8 |

Abbreviations: TA = task augmentation; ST = self-training; LMFT = language-model fine-tuning; MNLI = intermediate fine-tuning on MNLI. STraTA combines TA and ST.
