Paper: Residual Matrix Transformers: Scaling the Size of the Residual Stream.

Core Motivation

The current paradigm of LLM scaling (Kaplan et al., 2020) relies heavily on expanding model size, data, and compute. However, the AI field is rapidly approaching the physical limits of available data and energy. While sparse modifications like Mixture of Experts (MoE) scale parameters without scaling per-example compute, the Residual Matrix Transformer (RMT) proposes an entirely new axis for scaling: the residual stream size.

In a standard transformer, the residual stream is a vector of dimension $D$ that acts as a memory bus across layers. Every weight matrix that reads from or writes to the stream has $D$ as one of its dimensions, so scaling $D$ enlarges all of these matrices and inflates parameter and FLOP counts along with it. The RMT addresses this by replacing the residual stream vector with an outer product memory matrix, decoupling the residual stream’s bandwidth from the model’s compute and parameter footprint.

Mathematical Framework: Outer Product Memory

The architecture builds on outer product memory stores (Kohonen, 1972; Anderson, 1972).

Given a set of $N$ key vectors $q^{(p)} \in \mathbb{R}^{D_k}$ and data vectors $x^{(p)} \in \mathbb{R}^{D_v}$ for $p = 1, \ldots, N$, an outer product store $M \in \mathbb{R}^{D_k \times D_v}$ is constructed by summing their outer products:

\[M = \text{Norm}\left(\sum_{p=1}^N q^{(p)} \otimes x^{(p)}\right)\]

where $u \otimes v = uv^T$, and Norm is LayerNorm.

To retrieve a specific data vector $x^{(r)}$ from $M$, we perform a tensor contraction over the first dimension using its associated key vector $q^{(r)}$:

\[x^{(r)} \approx q^{(r)} \cdot_1 M\]
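Expanding the contraction shows why this works. Before normalization, $q^{(r)} \cdot_1 M = \sum_{p} (q^{(r)} \cdot q^{(p)})\, x^{(p)}$. Retrieval is therefore exact when the keys are orthonormal; otherwise it picks up crosstalk from the other stored items.

The minimal Python sketch below illustrates this. It omits the Norm step so the result can be checked exactly, and the helper names `outer`, `store`, and `retrieve` are hypothetical.

```python
def outer(u, v):
    """u ⊗ v = u v^T as a nested list of shape len(u) x len(v)."""
    return [[ui * vj for vj in v] for ui in u]

def store(keys, values):
    """M = sum_p q^(p) ⊗ x^(p), shape D_k x D_v (Norm omitted in this sketch)."""
    d_k, d_v = len(keys[0]), len(values[0])
    M = [[0.0] * d_v for _ in range(d_k)]
    for q, x in zip(keys, values):
        for i, row in enumerate(outer(q, x)):
            for j, val in enumerate(row):
                M[i][j] += val
    return M

def retrieve(q, M):
    """Contract the key over the first dimension: (q ·_1 M)_j = sum_i q_i M_ij."""
    return [sum(q[i] * M[i][j] for i in range(len(q))) for j in range(len(M[0]))]

# Orthonormal keys (standard basis of R^3) paired with data vectors in R^2.
keys = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
values = [[2.0, -1.0], [0.5, 3.0], [-4.0, 1.5]]
M = store(keys, values)
print(retrieve(keys[1], M))  # → [0.5, 3.0]

# A non-orthogonal key mixes stored items: here values[0] + values[1].
print(retrieve([1.0, 1.0, 0.0], M))  # → [2.5, 2.0]
```

With random rather than orthonormal keys, the crosstalk terms shrink as $D_k$ grows relative to the number of stored items, which is why the approximation in the retrieval equation is written with $\approx$.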

RMT Architecture

In the RMT, the batched residual stream for $N$ tokens is represented as a tensor $X \in \mathbb{R}^{D_k \times D_v \times N}$, rather than the standard matrix $X \in \mathbb{R}^{D \times N}$.

Attention Layer

In a standard transformer, features are retrieved using linear transformations (e.g., $Q^{(h)} = W_Q^{(h)} X$). In the RMT, the expensive $W_Q, W_K, W_V \in \mathbb{R}^{D_h \times D}$ weight matrices are completely removed and replaced by learned key vectors $r_Q^{(h)}, r_K^{(h)}, r_V^{(h)} \in \mathbb{R}^{D_k}$.

The attention inputs are retrieved via tensor contraction:

\[Q^{(h)} = r_Q^{(h)} \cdot_1 X\]
\[K^{(h)} = r_K^{(h)} \cdot_1 X\]
\[V^{(h)} = r_V^{(h)} \cdot_1 X\]

These resulting matrices belong to $\mathbb{R}^{D_v \times N}$, so $D_v$ plays the role of the attention head dimension $D_h$. Single-head attention (SHA) is applied to each head as usual. The output of each of the $R$ heads is then written back into the residual matrix using an output key vector $w_O^{(h)} \in \mathbb{R}^{D_k}$:

\[MHA(X) = \sum_{h=1}^R w_O^{(h)} \otimes SHA(Q^{(h)}, K^{(h)}, V^{(h)})\]
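A single RMT attention layer can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the paper's implementation:

- The residual tensor is stored as a list of $N$ per-token $D_k \times D_v$ matrices.
- Attention uses the usual $1/\sqrt{D_v}$ scaling with no causal mask.
- The function names are hypothetical.

```python
import math
import random

def contract(r, Xn):
    """r ·_1 X_n for one token: sum over the D_k axis, giving a length-D_v vector."""
    return [sum(r[i] * Xn[i][j] for i in range(len(r))) for j in range(len(Xn[0]))]

def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def single_head_attention(Q, K, V):
    """Scaled dot-product attention over N tokens; no causal mask in this sketch."""
    d = len(Q[0])
    out = []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        weights = softmax(scores)
        out.append([sum(w * v[j] for w, v in zip(weights, V)) for j in range(len(V[0]))])
    return out

def rmt_attention(X, heads):
    """X: N per-token D_k x D_v matrices. heads: (r_Q, r_K, r_V, w_O) key vectors of length D_k."""
    d_k, d_v = len(X[0]), len(X[0][0])
    out = [[[0.0] * d_v for _ in range(d_k)] for _ in X]
    for r_q, r_k, r_v, w_o in heads:
        Q = [contract(r_q, Xn) for Xn in X]  # Q^(h): N vectors of length D_v
        K = [contract(r_k, Xn) for Xn in X]
        V = [contract(r_v, Xn) for Xn in X]
        A = single_head_attention(Q, K, V)
        for n, a in enumerate(A):  # write back w_O^(h) ⊗ SHA(...) token by token
            for i in range(d_k):
                for j in range(d_v):
                    out[n][i][j] += w_o[i] * a[j]
    return out

# Usage with toy sizes (hypothetical values).
random.seed(0)
D_k, D_v, N, R = 4, 3, 5, 2
X = [[[random.gauss(0, 1) for _ in range(D_v)] for _ in range(D_k)] for _ in range(N)]
heads = [tuple([random.gauss(0, 1) for _ in range(D_k)] for _ in range(4)) for _ in range(R)]
Y = rmt_attention(X, heads)
print(len(Y), len(Y[0]), len(Y[0][0]))  # 5 4 3
```

Each head here carries only $4D_k$ parameters: three read keys and one write key. A standard head's slices of $W_Q$, $W_K$, $W_V$, and $W_O$ carry $4 D_h D$.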

Feed-Forward Layer (FFN)

Unlike attention, the FFN retains its standard linear transformations $W_1$ and $W_2$, since evidence suggests these matrices store factual knowledge rather than merely routing information. The RMT connects them to the residual matrix through key vector “adapters”:

- It retrieves $R$ data vectors from the matrix using $R$ read key vectors.
- It concatenates them into a single input for the standard FFN operation.
- It splits (un-vecs) the output into $R$ pieces and stores each back into the residual matrix via an outer product with a write key vector $w_{FF}^{(h)}$.
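The adapter can be sketched for a single token as below. Several details are assumptions made for illustration:

- the read keys `r_ff`;
- the ReLU activation;
- the absence of biases;
- splitting the FFN output into exactly $R$ chunks of length $D_v$, one per write key $w_{FF}^{(h)}$.

```python
def contract(r, Xn):
    """r ·_1 X_n: sum over the D_k axis, giving a length-D_v vector."""
    return [sum(r[i] * Xn[i][j] for i in range(len(r))) for j in range(len(Xn[0]))]

def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def rmt_ffn(Xn, r_ff, w_ff, W1, W2):
    """One token's FFN through key vector adapters (hypothetical sketch).

    Xn: D_k x D_v residual matrix; r_ff, w_ff: R read / write keys of length D_k;
    W1: D_ff x (R*D_v); W2: (R*D_v) x D_ff.
    """
    d_k, d_v = len(Xn), len(Xn[0])
    # 1) Retrieve R data vectors and concatenate them into one length-(R*D_v) input.
    h_in = [v for r in r_ff for v in contract(r, Xn)]
    # 2) Standard FFN with an assumed ReLU and no biases: W2 · relu(W1 · h_in).
    hidden = [max(0.0, z) for z in matvec(W1, h_in)]
    h_out = matvec(W2, hidden)
    # 3) Un-vec into R chunks of length D_v and store each with its write key.
    out = [[0.0] * d_v for _ in range(d_k)]
    for h, w in enumerate(w_ff):
        chunk = h_out[h * d_v:(h + 1) * d_v]
        for i in range(d_k):
            for j in range(d_v):
                out[i][j] += w[i] * chunk[j]
    return out

# Usage: basis keys and identity FFN weights (illustrative only).
e1, e2 = [1.0, 0.0], [0.0, 1.0]
I4 = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
Xn = [[1.0, -2.0], [3.0, 4.0]]
print(rmt_ffn(Xn, [e1, e2], [e1, e2], I4, I4))  # [[1.0, 0.0], [3.0, 4.0]]
```

With basis keys and identity weights, the round trip reproduces the residual matrix, except that ReLU zeroes the negative entry. $W_1$ and $W_2$ act on an input of size $R \cdot D_v$, so their shapes do not depend on $D_k$.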

Variance Propagation

For deep networks to initialize and train stably, the mean and variance of activations and gradients must stay roughly constant from layer to layer (Glorot & Bengio, 2010). The paper derives closed-form expressions for how the mean and variance change through outer product storage and retrieval. It shows that both can be kept stable with an appropriate initialization.

Let the forward storage operation for a single token be $X_{out} = \sum_{h=1}^R w^{(h)} \otimes x_{in}^{(h)}$, where weights are initialized independently with mean 0.

\[E[X_{out, ij}] = \sum_{h=1}^R E[w_i^{(h)}] E[x_{in, j}^{(h)}] = 0\]

The variance propagates as:

\[Var(X_{out, ij}) = \sum_{h=1}^R Var\left(w_i^{(h)} x_{in, j}^{(h)}\right)\]

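Assume $w^{(h)}$ and $x_{in}^{(h)}$ are independent and $\mu_w = 0$. The cross-covariances between heads then vanish, because the $w^{(h)}$ are independent with mean 0. Each remaining term is $Var(w_i^{(h)} x_{in, j}^{(h)}) = E[(w_i^{(h)})^2]\,E[(x_{in, j}^{(h)})^2] - \left(E[w_i^{(h)}]\,E[x_{in, j}^{(h)}]\right)^2 = \sigma_w^2(\sigma_{x_{in}}^2 + \mu_{x_{in}}^2)$, so: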

\[Var(X_{out, ij}) = R \sigma_w^2 (\sigma_{x_{in}}^2 + \mu_{x_{in}}^2)\]

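For zero-mean inputs the ratio $\frac{\sigma_{x_{out}}^2}{\sigma_{x_{in}}^2}$ reduces to $R\sigma_w^2$, so initializing the write key vectors with $\sigma_w^2 = 1/R$ keeps it at exactly 1. More generally, scaling the initialization variance with $R$ keeps the ratio close to 1, avoiding vanishing or exploding activations and gradients.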
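The variance formula can be checked numerically. The Monte Carlo sketch below draws Gaussian write keys and inputs and compares the empirical variance of one entry $X_{out, ij}$ with $R \sigma_w^2 (\sigma_{x_{in}}^2 + \mu_{x_{in}}^2)$. The Gaussian choice is an assumption; the derivation itself needs only independence and $\mu_w = 0$.

```python
import random

def predicted_var(R, var_w, mu_x, var_x):
    """Closed form: Var(X_out,ij) = R * sigma_w^2 * (sigma_x^2 + mu_x^2)."""
    return R * var_w * (var_x + mu_x**2)

def empirical_var(R, var_w, mu_x, var_x, trials=50_000, seed=0):
    """Monte Carlo estimate of Var(sum_h w_i^(h) * x_j^(h)) with independent Gaussian draws."""
    rng = random.Random(seed)
    sd_w, sd_x = var_w**0.5, var_x**0.5
    samples = [
        sum(rng.gauss(0.0, sd_w) * rng.gauss(mu_x, sd_x) for _ in range(R))
        for _ in range(trials)
    ]
    mean = sum(samples) / trials
    return sum((s - mean) ** 2 for s in samples) / trials

R = 8
print(predicted_var(R, 1.0 / R, 0.0, 1.0))  # 1.0
print(empirical_var(R, 1.0 / R, 0.0, 1.0))  # close to 1.0
print(predicted_var(4, 0.25, 1.0, 1.0))     # 2.0: a nonzero input mean adds mu_x^2
```

With $\sigma_w^2 = 1/R$ and zero-mean inputs, both the predicted and the sampled variance are about 1. A nonzero input mean raises the output variance, as the $\mu_{x_{in}}^2$ term predicts.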

Scaling Performance

The replacement of standard weight matrices with vectors transforms the economics of the network:

  • Cost of Scaling: Increasing the residual stream size by 100% in a standard transformer yields ~94% more FLOPs and 100% more parameters. In the RMT, increasing residual matrix capacity ($D_k$) by 100% increases both parameters and FLOPs by < 1%.
  • Efficiency: To reach the same target loss, the RMT uses 58% fewer FLOPs, 25% fewer parameters, and 41% fewer training tokens than a Chinchilla-optimal baseline transformer.
  • Zero-Shot Dominance: Evaluated on LAMBADA, PIQA, and ARC, an RMT trained on 28% fewer FLOPs outperformed a standard transformer that was 33% larger.
  • “Free” Scaling Axis: When holding parameter count, dataset size, and compute budget constant, expanding the residual stream size $D_k$ monotonically decreased validation loss.
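The contrast in the first bullet can be made concrete by counting attention weights only. The configuration below ($H = 12$ heads of size $D_h = 64$, $D = 768$, and $D_k = 64$) is hypothetical, not one of the paper's models, and biases are ignored. In the RMT, the FFN matrices $W_1$ and $W_2$ act on the $R \cdot D_v$ concatenated retrieved vectors, so apart from their key vector adapters they do not grow with $D_k$.

```python
def standard_attn_params(D, H, D_h):
    """W_Q, W_K, W_V (each H*D_h x D) plus W_O (D x H*D_h); biases ignored."""
    return 4 * H * D_h * D

def rmt_attn_params(D_k, H):
    """One read key each for Q, K, V and one write key w_O per head, each in R^{D_k}."""
    return 4 * H * D_k

H, D_h = 12, 64  # hypothetical head count and head size
print(standard_attn_params(768, H, D_h), standard_attn_params(1536, H, D_h))
# 2359296 4718592  -> doubling D doubles the attention weights
print(rmt_attn_params(64, H), rmt_attn_params(128, H))
# 3072 6144  -> doubling D_k adds only 3072 parameters
```

Doubling $D$ in the standard model adds millions of attention weights. Doubling $D_k$ in the RMT adds a few thousand key vector entries, a negligible share of a model whose FFN weights are unchanged.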

Caveats

  • Memory: Roughly equivalent to a standard transformer during training (the larger residual activations stored for gradient checkpointing are offset by fewer model parameters), but lower during inference.
  • Wall-Clock Time: In current PyTorch/JAX ecosystems, highly optimized GEMM kernels run significantly faster than unoptimized tensor contractions. Despite needing dramatically fewer FLOPs, the RMT takes ~43% longer per training step. A custom CUDA kernel for contracting over small key vectors could bridge this gap.

Summary

The RMT unlocks a new “free” scaling dimension by replacing the residual stream vector with an outer product memory matrix. Weight matrices in attention are replaced by learned key vectors, making residual stream scaling nearly cost-free in parameters and FLOPs. This promises significant reductions in the energy and data required for frontier LLM training, pending kernel-level optimization to close the wall-clock gap.