
# DExTra

Introduced by Mehta et al. in DeLighT: Deep and Light-weight Transformer

DExTra, or Deep and Light-weight Expand-reduce Transformation, is a light-weight expand-reduce transformation that enables learning wider representations efficiently.

DExTra maps a $d_{m}$-dimensional input vector into a high-dimensional space (expansion) and then reduces it to a $d_{o}$-dimensional output vector (reduction) using $N$ layers of group transformations. During the expansion and reduction phases, DExTra uses group linear transformations because they learn local representations by deriving each output from a specific part of the input, and because they are more efficient than linear transformations. To learn global representations, DExTra shares information between the groups of a group linear transformation using feature shuffling.
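The efficiency claim can be made concrete by counting weights. The sketch below is illustrative only: the function names are hypothetical, biases are ignored, and it assumes both dimensions divide evenly by the number of groups. Each group maps `d_in / groups` inputs to `d_out / groups` outputs, so the total weight count shrinks by a factor equal to the number of groups.

```python
def linear_weight_count(d_in: int, d_out: int) -> int:
    """Weights in a dense linear layer (bias excluded)."""
    return d_in * d_out


def group_linear_weight_count(d_in: int, d_out: int, groups: int) -> int:
    """Weights in a group linear layer with `groups` independent groups.

    Assumption: d_in and d_out are both divisible by `groups`.
    """
    if d_in % groups or d_out % groups:
        raise ValueError("d_in and d_out must be divisible by groups")
    return groups * (d_in // groups) * (d_out // groups)


print(linear_weight_count(256, 512))           # 131072
print(group_linear_weight_count(256, 512, 4))  # 32768
```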

Formally, the DExTra transformation is controlled by five configuration parameters: (1) depth $N$, (2) width multiplier $m_{w}$, (3) input dimension $d_{m}$, (4) output dimension $d_{o}$, and (5) maximum groups $g_{max}$ in a group linear transformation. In the expansion phase, DExTra projects the $d_{m}$-dimensional input to a high-dimensional space, $d_{max} = m_{w}d_{m}$, linearly using $\text{ceil}\left(\frac{N}{2}\right)$ layers. In the reduction phase, DExTra projects the $d_{max}$-dimensional vector to a $d_{o}$-dimensional space using the remaining $N -\text{ceil}\left(\frac{N}{2}\right)$ layers. Mathematically, we define the output $Y$ at each layer $l$ as:
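As a quick check on the layer split and the expanded width, the following sketch computes both from the configuration parameters. The function name and the returned dictionary are hypothetical, and $m_{w}$ is assumed to be an integer for simplicity.

```python
import math


def dextra_phase_plan(N: int, m_w: int, d_m: int) -> dict:
    """Split N layers into expansion/reduction phases and compute d_max.

    Assumption: m_w is an integer, so d_max = m_w * d_m needs no rounding.
    """
    expansion_layers = math.ceil(N / 2)
    return {
        "expansion_layers": expansion_layers,
        "reduction_layers": N - expansion_layers,
        "d_max": m_w * d_m,
    }


plan = dextra_phase_plan(N=5, m_w=2, d_m=128)
print(plan["expansion_layers"], plan["reduction_layers"], plan["d_max"])  # 3 2 256
```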

$$\mathbf{Y}^{l} = \begin{cases} \mathcal{F}\left(\mathbf{X}, \mathbf{W}^{l}, \mathbf{b}^{l}, g^{l}\right), & l = 1 \\ \mathcal{F}\left(\mathcal{H}\left(\mathbf{X}, \mathbf{Y}^{l-1}\right), \mathbf{W}^{l}, \mathbf{b}^{l}, g^{l}\right), & \text{otherwise} \end{cases}$$

where the number of groups at each layer $l$ is computed as:

$$g^{l} = \begin{cases} \min\left(2^{l-1}, g_{max}\right), & 1 \leq l \leq \text{ceil}\left(N/2\right) \\ g^{N-l}, & \text{otherwise} \end{cases}$$
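For the expansion layers, the group count doubles at each layer until it reaches $g_{max}$. A minimal sketch of that first case follows; the function name is hypothetical, and the reduction-phase case, which reuses earlier group counts, is deliberately left out.

```python
import math
from typing import List


def expansion_groups(N: int, g_max: int) -> List[int]:
    """Group counts g^l for the expansion layers l = 1 .. ceil(N/2)."""
    last = math.ceil(N / 2)
    return [min(2 ** (l - 1), g_max) for l in range(1, last + 1)]


print(expansion_groups(N=8, g_max=4))  # [1, 2, 4, 4]
```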

In the above equations, $\mathcal{F}$ is a group linear transformation function. The function $\mathcal{F}$ takes the input $\left(\mathbf{X} \text{ or } \mathcal{H}\left(\mathbf{X}, \mathbf{Y}^{l-1}\right) \right)$, splits it into $g^{l}$ groups, and then applies a linear transformation with learnable parameters $\mathbf{W}^{l}$ and bias $\mathbf{b}^{l}$ to each group independently. The outputs of each group are then concatenated to produce the final output $\mathbf{Y}^{l}$. The function $\mathcal{H}$ first shuffles the output of each group in $\mathbf{Y}^{l-1}$ and then combines it with the input $\mathbf{X}$ using an input mixer connection.
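The two building blocks described above can be sketched in plain Python on a single vector. This is an illustration under stated assumptions, not the authors' implementation: the function names are hypothetical, the shuffle is assumed to be a ShuffleNet-style interleaving across groups, and the input mixer connection is omitted because its exact form is not specified here.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # one row per output element of a group


def group_linear(x: Vector, weights: List[Matrix], biases: List[Vector]) -> Vector:
    """Apply an independent linear map to each group of x, then concatenate.

    weights[k] has shape (out_per_group, in_per_group) for group k.
    """
    groups = len(weights)
    if len(x) % groups:
        raise ValueError("input length must be divisible by the number of groups")
    size = len(x) // groups
    out: Vector = []
    for k in range(groups):
        chunk = x[k * size:(k + 1) * size]  # the part of the input this group sees
        for row, b in zip(weights[k], biases[k]):
            out.append(sum(w * v for w, v in zip(row, chunk)) + b)
    return out


def feature_shuffle(y: Vector, groups: int) -> Vector:
    """Interleave elements across groups (assumed ShuffleNet-style shuffle)."""
    if len(y) % groups:
        raise ValueError("length must be divisible by the number of groups")
    size = len(y) // groups
    # View y as a (groups, size) grid and read it out column by column.
    return [y[k * size + i] for i in range(size) for k in range(groups)]


# Two groups, each mapping 2 inputs to 1 output.
y = group_linear([1.0, 2.0, 3.0, 4.0],
                 weights=[[[1.0, 1.0]], [[1.0, -1.0]]],
                 biases=[[0.0], [0.5]])
print(y)                                        # [3.0, -0.5]
print(feature_shuffle([0, 1, 2, 3, 4, 5], 2))   # [0, 3, 1, 4, 2, 5]
```

With a single group, `group_linear` reduces to an ordinary linear layer, since the one group sees the whole input.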

In their experiments, the authors use $g_{max} = \text{ceil}\left(\frac{d_{m}}{32}\right)$ so that each group has at least 32 input elements. Note that (i) a group linear transformation reduces to a linear transformation when $g^{l} = 1$, and (ii) DExTra is equivalent to a multi-layer perceptron when $g_{max} = 1$.
