Sparsifying Transformer Models with Trainable Representation Pooling

We propose a novel method to sparsify attention in the Transformer model by learning to select the most informative token representations during training, thus focusing on the task-specific parts of an input. A robust trainable top-$k$ operator reduces the quadratic time and memory complexity to sublinear. Our experiments on a challenging long-document summarization task show that even our simple baseline performs comparably to the current state of the art (SOTA), and that with trainable pooling we can retain its top quality while being $1.8\times$ faster during training, $4.5\times$ faster during inference, and up to $13\times$ more computationally efficient in the decoder.
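To make the idea of selecting the top-$k$ token representations concrete, here is a minimal sketch of hard top-$k$ pooling in NumPy. It is not the trainable operator described in the abstract: the selection below is a plain, non-differentiable top-$k$, and the function name `topk_pool`, the scoring vector `w`, and the sizes `n`, `d`, `k` are all hypothetical choices made for illustration.

```python
import numpy as np

def topk_pool(tokens, scores, k):
    """Keep the k highest-scoring token representations, preserving input order.

    Illustrative hard selection only; the paper's operator is trainable,
    which this sketch does not attempt to reproduce.
    """
    n = tokens.shape[0]
    if not 0 < k <= n:
        raise ValueError("k must satisfy 1 <= k <= number of tokens")
    # argpartition places the k largest scores in the last k slots
    # without fully sorting the array.
    idx = np.argpartition(scores, -k)[-k:]
    idx.sort()  # restore the original left-to-right token order
    return tokens[idx], idx

rng = np.random.default_rng(0)
n, d, k = 16, 4, 4                 # hypothetical sequence length, width, and k
tokens = rng.normal(size=(n, d))   # stand-in for encoder token representations
w = rng.normal(size=d)             # hypothetical scoring vector (learned in practice)
scores = tokens @ w                # one scalar score per token
pooled, kept = topk_pool(tokens, scores, k)
```

In this toy setting, full self-attention over the input compares `n * n = 256` token pairs, while attention over the pooled sequence compares only `k * k = 16`, which gives a rough sense of where the savings come from.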

Task                Dataset                      Model                 Metric   Value  Global Rank
Text Summarization  arXiv Summarization Dataset  DeepPyramidion        ROUGE-1  47.15  #2
Text Summarization  arXiv Summarization Dataset  DeepPyramidion        ROUGE-2  19.99  #2
Text Summarization  arXiv Summarization Dataset  Blockwise (baseline)  ROUGE-1  46.85  #3
Text Summarization  arXiv Summarization Dataset  Blockwise (baseline)  ROUGE-2  19.39  #3
Text Summarization  Pubmed                       DeepPyramidion        ROUGE-1  47.81  #7
Text Summarization  Pubmed                       DeepPyramidion        ROUGE-2  21.14  #6