Recurrent Neural Networks

SRU++ is a self-attentive recurrent unit that combines fast recurrence and attention for sequence modeling, extending SRU (Simple Recurrent Unit). The key modification of SRU++ is to incorporate more expressive non-linear operations into the recurrent network. Specifically, given the input sequence represented as a matrix $\mathbf{X} \in \mathbb{R}^{L \times d}$, where $L$ is the sequence length and $d$ is the hidden dimension, the attention component computes the query, key and value representations using the following multiplications:

$$ \mathbf{Q} =\mathbf{W}^{q} \mathbf{X}^{\top} $$

$$ \mathbf{K} =\mathbf{W}^{k} \mathbf{Q} $$

$$ \mathbf{V} =\mathbf{W}^{v} \mathbf{Q} $$

where $\mathbf{W}^{q} \in \mathbb{R}^{d^{\prime} \times d}$ and $\mathbf{W}^{k}, \mathbf{W}^{v} \in \mathbb{R}^{d^{\prime} \times d^{\prime}}$ are model parameters, and $d^{\prime}$ is the attention dimension, which is typically much smaller than $d$. Note that the keys $\mathbf{K}$ and values $\mathbf{V}$ are computed from $\mathbf{Q}$ instead of $\mathbf{X}$, so $\mathbf{W}^{k}$ and $\mathbf{W}^{v}$ are $d^{\prime} \times d^{\prime}$ rather than $d^{\prime} \times d$ and therefore significantly smaller.
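A minimal pure-Python sketch of these three projections, assuming toy sizes chosen only for illustration ($L=5$, $d=16$, $d^{\prime}=4$) and random weights. The helpers `matmul`, `transpose`, `rand_matrix` and `shape` are hypothetical. The sketch checks that $\mathbf{Q}$, $\mathbf{K}$ and $\mathbf{V}$ all come out as $d^{\prime} \times L$, and counts how many parameters $\mathbf{W}^{k}$ and $\mathbf{W}^{v}$ need when they read from $\mathbf{Q}$ versus from $\mathbf{X}$.

```python
import random

def matmul(a, b):
    """(m x n) @ (n x p) -> (m x p) for matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(m):
    return [list(row) for row in zip(*m)]

def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]

def shape(m):
    return (len(m), len(m[0]))

# Hypothetical toy sizes: sequence length L, hidden dim d, attention dim d_prime.
L, d, d_prime = 5, 16, 4
rng = random.Random(0)

X = rand_matrix(L, d, rng)                # input sequence, L x d
W_q = rand_matrix(d_prime, d, rng)        # d' x d
W_k = rand_matrix(d_prime, d_prime, rng)  # d' x d'
W_v = rand_matrix(d_prime, d_prime, rng)  # d' x d'

Q = matmul(W_q, transpose(X))  # d' x L
K = matmul(W_k, Q)             # d' x L, computed from Q rather than X
V = matmul(W_v, Q)             # d' x L, computed from Q rather than X

# Parameters in W^k and W^v when projecting from Q (d' x d') versus from X (d' x d).
params_from_q = 2 * d_prime * d_prime
params_from_x = 2 * d_prime * d
print(shape(Q), shape(K), shape(V))  # (4, 5) (4, 5) (4, 5)
print(params_from_q, params_from_x)  # 32 128
```

With these sizes, reading keys and values from $\mathbf{Q}$ cuts their parameter count by a factor of $d / d^{\prime} = 4$.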

Next, we compute a weighted average output $\mathbf{A} \in \mathbb{R}^{d^{\prime} \times L}$ using scaled dot-product attention:

$$ \mathbf{A}^{\top}=\operatorname{softmax}\left(\frac{\mathbf{Q}^{\top} \mathbf{K}}{\sqrt{d^{\prime}}}\right) \mathbf{V}^{\top} $$
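The attention step can be sketched in pure Python as follows, assuming $\mathbf{Q}$, $\mathbf{K}$ and $\mathbf{V}$ are $d^{\prime} \times L$ as defined above. The function names and toy sizes are hypothetical, and no causal mask is applied because the formula as written has none.

```python
import math
import random

def matmul(a, b):
    """(m x n) @ (n x p) -> (m x p) for matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(m):
    return [list(row) for row in zip(*m)]

def softmax_rows(m):
    """Row-wise softmax; subtracting the row max keeps exp() from overflowing."""
    out = []
    for row in m:
        top = max(row)
        exps = [math.exp(v - top) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def attention(Q, K, V):
    """Return (A, weights) with A^T = softmax(Q^T K / sqrt(d')) V^T.

    Q, K, V are d' x L; A is d' x L; weights is the L x L attention matrix.
    No causal mask is applied, matching the formula as written.
    """
    d_prime = len(Q)
    scores = matmul(transpose(Q), K)  # L x L dot products between positions
    scaled = [[s / math.sqrt(d_prime) for s in row] for row in scores]
    weights = softmax_rows(scaled)    # each row sums to 1
    A_t = matmul(weights, transpose(V))  # L x d'
    return transpose(A_t), weights

rng = random.Random(1)
d_prime, L = 4, 6  # hypothetical toy sizes
Q, K, V = ([[rng.gauss(0.0, 1.0) for _ in range(L)] for _ in range(d_prime)]
           for _ in range(3))
A, weights = attention(Q, K, V)
print(len(A), len(A[0]))  # 4 6
```

Because every row of the softmax weights sums to one, each column of $\mathbf{A}$ is a convex combination of the columns of $\mathbf{V}$, which is what "weighted average output" refers to.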

The final output $\mathbf{U}$ required by the elementwise recurrence is obtained by another linear projection,

$$ \mathbf{U}^{\top}=\mathbf{W}^{o}(\mathbf{Q}+\alpha \cdot \mathbf{A}) $$

where $\alpha \in \mathbb{R}$ is a learned scalar and $\mathbf{W}^{o} \in \mathbb{R}^{3 d \times d^{\prime}}$ is a parameter matrix. The term $\mathbf{Q}+\alpha \cdot \mathbf{A}$ is a residual connection, which improves gradient propagation and stabilizes training. We initialize $\alpha$ to zero; as a result, the projection

$$ \mathbf{U}^{\top}=\mathbf{W}^{o} \mathbf{Q}=\left(\mathbf{W}^{o} \mathbf{W}^{q}\right) \mathbf{X}^{\top} $$

initially reduces to a linear transformation of the input $\mathbf{X}$ that skips the attention transformation. Intuitively, skipping attention encourages the model to rely on the recurrence to capture sequential patterns during the early stages of training. As $|\alpha|$ grows, the attention mechanism can learn long-range dependencies for the model. In addition, $\mathbf{W}^{o} \mathbf{W}^{q}$ can be interpreted as a matrix factorization with a small inner dimension $d^{\prime}<d$: it uses $3dd^{\prime} + d^{\prime}d = 4dd^{\prime}$ parameters instead of the $3d^{2}$ of a full $3d \times d$ matrix, reducing the total number of parameters whenever $d^{\prime} < 3d/4$. The figure compares SRU, SRU with this factorization trick (but without attention), and SRU++.
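A pure-Python sketch of the residual projection $\mathbf{U}^{\top}=\mathbf{W}^{o}(\mathbf{Q}+\alpha \cdot \mathbf{A})$, assuming hypothetical toy sizes ($L=3$, $d=8$, $d^{\prime}=2$), random weights, and a random stand-in for the attention output $\mathbf{A}$. It checks numerically that with $\alpha=0$ the layer matches the factorized linear map $(\mathbf{W}^{o}\mathbf{W}^{q})\mathbf{X}^{\top}$, and compares the factorized and full parameter counts.

```python
import random

def matmul(a, b):
    """(m x n) @ (n x p) -> (m x p) for matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(m):
    return [list(row) for row in zip(*m)]

def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

def residual_projection(W_o, Q, A, alpha):
    """U^T = W^o (Q + alpha * A), with W_o of shape 3d x d' and Q, A of shape d' x L."""
    mixed = [[q + alpha * a for q, a in zip(q_row, a_row)] for q_row, a_row in zip(Q, A)]
    return matmul(W_o, mixed)

rng = random.Random(2)
L, d, d_prime = 3, 8, 2                 # hypothetical toy sizes
X = rand_matrix(L, d, rng)              # L x d
W_q = rand_matrix(d_prime, d, rng)      # d' x d
W_o = rand_matrix(3 * d, d_prime, rng)  # 3d x d'
Q = matmul(W_q, transpose(X))           # d' x L
A = rand_matrix(d_prime, L, rng)        # stand-in for the attention output

# With alpha = 0 the layer equals (W^o W^q) X^T up to floating-point rounding.
U_t = residual_projection(W_o, Q, A, alpha=0.0)
U_t_linear = matmul(matmul(W_o, W_q), transpose(X))
max_diff = max(abs(x - y) for r1, r2 in zip(U_t, U_t_linear) for x, y in zip(r1, r2))

full_params = 3 * d * d                          # one full 3d x d matrix
factored_params = 3 * d * d_prime + d_prime * d  # W^o plus W^q
print(max_diff < 1e-9, full_params, factored_params)  # True 192 64
```

Once $\alpha$ moves away from zero, the attention output starts contributing to $\mathbf{U}$, so the same parameters serve both the attention-free start of training and the attentive regime later on.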

The last modification adds layer normalization to each SRU++ layer. Normalization is applied after the attention operation and before the multiplication by $\mathbf{W}^{o}$:

$$ \mathbf{U}^{\top}=\mathbf{W}^{o} \operatorname{layernorm}(\mathbf{Q}+\alpha \cdot \mathbf{A}) $$

This is a post-layer-normalization design: normalization is applied after the residual connection rather than before it.
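The post-layer-normalization ordering (residual first, then layer normalization, then $\mathbf{W}^{o}$) can be sketched as below. Several details are assumptions not stated above: normalization runs over the $d^{\prime}$ features of each time step (each column of $\mathbf{Q}+\alpha \cdot \mathbf{A}$), the learned gain and bias are omitted, and $\epsilon = 10^{-5}$. The function names and toy values are hypothetical.

```python
import math

def matmul(a, b):
    """(m x n) @ (n x p) -> (m x p) for matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def layer_norm_columns(M, eps=1e-5):
    """Normalize each column (one time step, d' features) to zero mean, unit variance.

    The learned gain and bias of layer normalization are omitted in this sketch.
    """
    rows, cols = len(M), len(M[0])
    out = [[0.0] * cols for _ in range(rows)]
    for j in range(cols):
        col = [M[i][j] for i in range(rows)]
        mean = sum(col) / rows
        var = sum((v - mean) ** 2 for v in col) / rows
        inv_std = 1.0 / math.sqrt(var + eps)
        for i in range(rows):
            out[i][j] = (M[i][j] - mean) * inv_std
    return out

def post_ln_projection(W_o, Q, A, alpha):
    """U^T = W^o layernorm(Q + alpha * A): residual, then normalization, then W^o."""
    residual = [[q + alpha * a for q, a in zip(qr, ar)] for qr, ar in zip(Q, A)]
    return matmul(W_o, layer_norm_columns(residual))

# Toy values: d' = 3 features, L = 2 time steps.
Q = [[1.0, 2.0], [3.0, 6.0], [5.0, 10.0]]
A = [[1.0, 2.0], [1.0, 2.0], [1.0, 2.0]]
W_o = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0],
       [1.0, 1.0, 1.0], [1.0, -1.0, 0.0], [0.0, 2.0, -1.0]]  # 3d x d' with d = 2
U_t = post_ln_projection(W_o, Q, A, alpha=1.0)
print([round(v, 4) for v in U_t[0]])  # [-1.2247, -1.2247]
```

With $\alpha = 1$ the residual columns are $(2, 4, 6)$ and $(4, 8, 12)$; the second is twice the first, and layer normalization maps both to roughly $(-1.2247, 0, 1.2247)$, so the two output columns coincide up to the effect of $\epsilon$.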

Source: When Attention Meets Fast Recurrence: Training Language Models with Reduced Compute

Tasks


| Task | Papers | Share |
| --- | --- | --- |
| Language Modelling | 2 | 28.57% |
| Machine Translation | 2 | 28.57% |
| Automatic Speech Recognition (ASR) | 1 | 14.29% |
| Speech Recognition | 1 | 14.29% |
| Translation | 1 | 14.29% |
