SRU++ is a self-attentive recurrent unit that extends SRU by combining fast recurrence and attention for sequence modeling. The key modification of SRU++ is to incorporate more expressive non-linear operations into the recurrent network. Specifically, given an input sequence represented as a matrix $\mathbf{X} \in \mathbb{R}^{L \times d}$, the attention component computes the query, key, and value representations using the following multiplications:
$$ \mathbf{Q} =\mathbf{W}^{q} \mathbf{X}^{\top} $$
$$ \mathbf{K} =\mathbf{W}^{k} \mathbf{Q} $$
$$ \mathbf{V} =\mathbf{W}^{v} \mathbf{Q} $$
where $\mathbf{W}^{q} \in \mathbb{R}^{d^{\prime} \times d}$ and $\mathbf{W}^{k}, \mathbf{W}^{v} \in \mathbb{R}^{d^{\prime} \times d^{\prime}}$ are model parameters, and $d^{\prime}$ is the attention dimension, which is typically much smaller than $d$. Note that the keys $\mathbf{K}$ and values $\mathbf{V}$ are computed from $\mathbf{Q}$ instead of $\mathbf{X}$, so the weight matrices $\mathbf{W}^{k}$ and $\mathbf{W}^{v}$ are only $d^{\prime} \times d^{\prime}$ rather than $d^{\prime} \times d$, which makes them significantly smaller.
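A minimal NumPy sketch of these projections, assuming random weights and small illustrative sizes (`L`, `d`, and `d_prime` are hypothetical values, not settings from the paper). Each matrix keeps the column-per-position layout of the equations, and the last lines compare the size of $\mathbf{W}^{k}$ and $\mathbf{W}^{v}$ when projecting from $\mathbf{Q}$ versus projecting from $\mathbf{X}$.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumption): sequence length L, model dim d, attention dim d' << d
L, d, d_prime = 8, 64, 16

X = rng.standard_normal((L, d))                # input sequence, one row per position

W_q = rng.standard_normal((d_prime, d))        # W^q in R^{d' x d}
W_k = rng.standard_normal((d_prime, d_prime))  # W^k in R^{d' x d'}
W_v = rng.standard_normal((d_prime, d_prime))  # W^v in R^{d' x d'}

Q = W_q @ X.T   # (d', L)
K = W_k @ Q     # (d', L): keys are computed from Q, not X
V = W_v @ Q     # (d', L): values are computed from Q, not X

# Parameters in W^k and W^v when projecting from Q versus directly from X
params_from_q = W_k.size + W_v.size   # 2 * d' * d'
params_from_x = 2 * d_prime * d       # 2 * d' * d
print(Q.shape, K.shape, V.shape)      # (16, 8) (16, 8) (16, 8)
print(params_from_q, params_from_x)   # 512 2048
```

With these sizes, projecting from $\mathbf{Q}$ needs a quarter of the parameters, since the ratio is $d^{\prime}/d = 16/64$.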
Next, we compute a weighted average output $\mathbf{A} \in \mathbb{R}^{d^{\prime} \times L}$ using scaled dot-product attention:
$$ \mathbf{A}^{\top}=\operatorname{softmax}\left(\frac{\mathbf{Q}^{\top} \mathbf{K}}{\sqrt{d^{\prime}}}\right) \mathbf{V}^{\top} $$
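The attention step can be sketched in NumPy as follows, keeping the $d^{\prime} \times L$ layout of $\mathbf{Q}$, $\mathbf{K}$, and $\mathbf{V}$. This is an illustration under assumptions: the `attention` and `softmax` helpers are hypothetical names, and no attention mask is applied (an autoregressive language model would typically mask future positions).

```python
import numpy as np

def softmax(z, axis=-1):
    # Subtract the row max for numerical stability before exponentiating
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """A^T = softmax(Q^T K / sqrt(d')) V^T, with Q, K, V of shape (d', L).

    Returns A with shape (d', L) and the (L, L) attention weights.
    No mask is applied (assumption).
    """
    d_prime = Q.shape[0]
    scores = (Q.T @ K) / np.sqrt(d_prime)  # (L, L): query positions x key positions
    weights = softmax(scores, axis=-1)     # each row sums to 1
    A_T = weights @ V.T                    # (L, d'): weighted average of the values
    return A_T.T, weights

rng = np.random.default_rng(0)
d_prime, L = 16, 8
Q, K, V = (rng.standard_normal((d_prime, L)) for _ in range(3))
A, weights = attention(Q, K, V)
print(A.shape)  # (16, 8)
```

Because each row of `weights` sums to one, every column of $\mathbf{A}$ is a convex combination of the value vectors, which is what "weighted average output" refers to.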
The final output $\mathbf{U}$ required by the elementwise recurrence is obtained by another linear projection:
$$ \mathbf{U}^{\top}=\mathbf{W}^{o}(\mathbf{Q}+\alpha \cdot \mathbf{A}) $$
where $\alpha \in \mathbb{R}$ is a learned scalar and $\mathbf{W}^{o} \in \mathbb{R}^{3 d \times d^{\prime}}$ is a parameter matrix. The sum $\mathbf{Q}+\alpha \cdot \mathbf{A}$ is a residual connection that improves gradient propagation and stabilizes training. We initialize $\alpha$ to zero, and as a result,
$$ \mathbf{U}^{\top}=\mathbf{W}^{o} \mathbf{Q}=\left(\mathbf{W}^{o} \mathbf{W}^{q}\right) \mathbf{X}^{\top} $$
i.e., the layer initially falls back to a linear transformation of the input $\mathbf{X}$, skipping the attention transformation. Intuitively, skipping attention encourages the model to leverage recurrence to capture sequential patterns during the early stage of training. As $|\alpha|$ grows, attention contributes more to the output and can help the model learn long-range dependencies. In addition, the product $\mathbf{W}^{o} \mathbf{W}^{q}$ can be interpreted as a matrix factorization with a small inner dimension $d^{\prime}<d$, which reduces the total number of parameters. The accompanying figure compares SRU, SRU with this factorization trick (but without attention), and SRU++.
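The output projection and its behavior at initialization can be checked numerically. This NumPy sketch assumes random weights, a random stand-in for the attention output $\mathbf{A}$, and small illustrative sizes; `output_projection` is a hypothetical helper, not the authors' code, and the elementwise recurrence that consumes $\mathbf{U}$ is not modeled. It computes $\mathbf{U}$ for a nonzero $\alpha$, confirms that $\alpha = 0$ reduces the layer to $(\mathbf{W}^{o}\mathbf{W}^{q})\mathbf{X}^{\top}$, and compares parameter counts against an unfactorized $3d \times d$ matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
L, d, d_prime = 8, 64, 16                    # illustrative sizes (assumption)

X = rng.standard_normal((L, d))
W_q = rng.standard_normal((d_prime, d))      # W^q in R^{d' x d}
W_o = rng.standard_normal((3 * d, d_prime))  # W^o in R^{3d x d'}
A = rng.standard_normal((d_prime, L))        # random stand-in for the attention output

def output_projection(alpha):
    """U^T = W^o (Q + alpha * A); returns U with shape (L, 3d)."""
    Q = W_q @ X.T                            # (d', L)
    return (W_o @ (Q + alpha * A)).T

U = output_projection(alpha=0.5)             # alpha is learned; 0.5 is arbitrary here
print(U.shape)                               # (8, 192)

# At initialization alpha = 0: attention is skipped and the layer is a single
# linear map whose 3d x d weight is factorized through the inner dimension d'.
U0 = output_projection(alpha=0.0)
W_factored = W_o @ W_q                       # (3d, d), rank at most d'
assert np.allclose(U0, (W_factored @ X.T).T)

params_full = 3 * d * d                      # unfactorized 3d x d matrix
params_factored = W_o.size + W_q.size        # 3d*d' + d'*d
print(params_full, params_factored)          # 12288 4096
```

The factorized form stores $3dd^{\prime} + d^{\prime}d = 4dd^{\prime}$ numbers instead of $3d^{2}$, so it is smaller whenever $d^{\prime} < 3d/4$.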
The last modification is adding layer normalization to each SRU++ layer. We apply normalization after the attention operation and before the matrix multiplication with $\mathbf{W}^{o}$:
$$ \mathbf{U}^{\top}=\mathbf{W}^{o} \operatorname{layernorm}(\mathbf{Q}+\alpha \cdot \mathbf{A}) $$
This is a post-layer-normalization design, in which normalization is applied after the residual connection.
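Post-layer normalization can be sketched the same way. This minimal NumPy version is an assumption rather than the released implementation: the hypothetical `layernorm` helper normalizes each position (column) over its $d^{\prime}$ features and omits the learnable gain and bias that layer normalization usually carries.

```python
import numpy as np

def layernorm(Z, eps=1e-5):
    """Normalize each column (one position) of a (d', L) matrix over its d' features.

    The learnable gain and bias are omitted (assumption).
    """
    mean = Z.mean(axis=0, keepdims=True)
    var = Z.var(axis=0, keepdims=True)
    return (Z - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
L, d, d_prime = 8, 64, 16                    # illustrative sizes (assumption)
Q = rng.standard_normal((d_prime, L))
A = rng.standard_normal((d_prime, L))        # random stand-in for the attention output
W_o = rng.standard_normal((3 * d, d_prime))
alpha = 0.5

# Post-layer normalization: residual sum first, then normalize, then project with W^o
H = layernorm(Q + alpha * A)
U_T = W_o @ H
print(U_T.shape)                             # (192, 8)
```

Normalizing the $d^{\prime}$-dimensional sum rather than the $3d$-dimensional output keeps the normalization cheap, since it operates on the smaller attention dimension.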
Source: When Attention Meets Fast Recurrence: Training Language Models with Reduced Compute

Task | Papers | Share |
---|---|---|
Language Modelling | 2 | 28.57% |
Machine Translation | 2 | 28.57% |
Automatic Speech Recognition (ASR) | 1 | 14.29% |
Speech Recognition | 1 | 14.29% |
Translation | 1 | 14.29% |

Component | Type |
---|---|
Layer Normalization | Normalization |
Residual Connection | Skip Connections |
Scaled Dot-Product Attention | Attention Mechanisms |