Reading the following paper:

As LLMs move toward agentic workflows and reinforcement learning (RL) scaling, the $O(T^2)$ complexity and linearly growing Key-Value (KV) cache of standard Softmax attention have become prohibitive bottlenecks. While linear attention offers $O(T)$ efficiency, it has historically failed to match the recall and “copying” capabilities of full attention.

Kimi Linear introduces Kimi Delta Attention (KDA), a hardware-optimized linear attention module that uses fine-grained gating. By interleaving KDA with standard Multi-Head Latent Attention (MLA) in a 3:1 ratio, the model outperforms a full-attention baseline across short-context, long-context, and RL benchmarks. It also reduces KV cache usage by up to 75% and increases decoding throughput by up to 6$\times$ at 1M context length.


Technical Deep Dive: Kimi Delta Attention (KDA)

The core contribution is KDA, which improves upon the “Gated DeltaNet” (GDN) and Mamba architectures.

A. The Math: Fine-Grained Decay

Gated linear attention variants such as Gated DeltaNet and Mamba2 use coarse, head-wise (scalar) gating. KDA instead introduces channel-wise, fine-grained decay. The state update rule is:

\[S_t = (I - \beta_t k_t k_t^\top) \text{Diag}(\alpha_t) S_{t-1} + \beta_t k_t v_t^\top\]

Here, $\text{Diag}(\alpha_t)$ lets each feature dimension keep an independent forgetting rate, giving precise control over the finite-state memory. This is critical for the “associative recall” and “copying” tasks where linear attention usually struggles.

B. Hardware Efficiency vs. DPLR

KDA is a specialized variant of the Diagonal-Plus-Low-Rank (DPLR) recurrence ($S_t = (D - ab^\top)S_{t-1} + \dots$).

  • The Bottleneck: General DPLR is computationally expensive and difficult to parallelize efficiently due to numerical precision issues that require secondary chunking.
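  • The Kimi Solution: KDA ties the DPLR low-rank vectors $a$ and $b$ to the key vector $k$. With $D = \text{Diag}(\alpha_t)$, it sets $a_t = \beta_t k_t$ and $b_t = k_t \odot \alpha_t$, because $(I - \beta_t k_t k_t^\top)\text{Diag}(\alpha_t) = \text{Diag}(\alpha_t) - \beta_t k_t (k_t \odot \alpha_t)^\top$. This constraint eliminates redundant matrix multiplications and secondary chunking steps, improving operator efficiency by roughly 100% (about 2$\times$) compared to general DPLR formulations.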
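As a quick numerical sanity check of this binding, the sketch below builds the KDA transition and the DPLR transition with $D = \text{Diag}(\alpha_t)$, $a_t = \beta_t k_t$, $b_t = k_t \odot \alpha_t$, and confirms that both give the same next state. It uses NumPy, small made-up sizes, and a unit-norm key; none of these come from the report.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, d_v = 8, 4                      # toy sizes (assumption; the model uses 128)

k = rng.standard_normal(d_k)
k /= np.linalg.norm(k)               # unit-norm key (assumed here for stability)
alpha = rng.uniform(0.5, 1.0, d_k)   # channel-wise decay in (0, 1)
beta = 0.7                           # scalar update strength
S_prev = rng.standard_normal((d_k, d_v))
v = rng.standard_normal(d_v)

# KDA transition: (I - beta k k^T) Diag(alpha)
A_kda = (np.eye(d_k) - beta * np.outer(k, k)) @ np.diag(alpha)

# DPLR transition: D - a b^T with D = Diag(alpha), a = beta k, b = k * alpha
a = beta * k
b = k * alpha
A_dplr = np.diag(alpha) - np.outer(a, b)

# Both transitions produce the same next state S_t = A S_{t-1} + beta k v^T
S_kda = A_kda @ S_prev + beta * np.outer(k, v)
S_dplr = A_dplr @ S_prev + beta * np.outer(k, v)
print(np.allclose(A_kda, A_dplr), np.allclose(S_kda, S_dplr))
```

The identity is exact, so any mismatch would point to a transposition or broadcasting mistake rather than to rounding.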

C. Positional Encoding as Data-Dependent Decay

KDA's decay can also be read as a learnable, data-dependent positional encoding.

  • Standard RoPE rotates keys/queries based on fixed frequencies to encode relative distance.
  • KDA’s decay mechanism ($\alpha_t$) acts similarly to a multiplicative positional encoding but relaxes the orthogonality constraint of RoPE. This allows the model to dynamically decide the “distance” or relevance of past tokens based on data content.

Architecture Design: The Hybrid Strategy

The Kimi Linear model is not purely linear; it is a Hybrid MoE (Mixture-of-Experts) architecture.

  • Hybrid Ratio: The model interleaves 3 KDA layers for every 1 Full Attention (MLA) layer. Ablation studies showed this provided the optimal trade-off between throughput and validation loss (better than 1:1 or 7:1).
  • NoPE (No Positional Encoding) for MLA: The global attention layers (MLA) do not use RoPE. The model delegates all positional awareness and recency bias to the KDA layers. This simplifies long-context training and eliminates the need for frequency base tuning (e.g., YaRN) when extending context windows.
  • Specs: The primary evaluated model has 48B total parameters (3B activated), utilizing 8 out of 64 experts.

Evaluation

Kimi Linear is compared against Full Attention (MLA) and Hybrid Gated DeltaNet (GDN-H) baselines trained with identical recipes (1.4T to 5.7T tokens).

  • Synthetic Tasks: KDA solved the “Palindrome” and “Associative Recall” tasks perfectly up to 2048 tokens, whereas standard Mamba2 failed. This supports the case for combining the delta rule with fine-grained decay.
  • General Performance: Kimi Linear outperformed the full-attention MLA baseline on MMLU, GSM8K, and coding benchmarks.
  • Long Context (128k+): On the RULER benchmark, Kimi Linear scored 84.3 (vs. 81.3 for MLA), suggesting that the hybrid design handles long-range dependencies at least as well as pure full attention.
  • RL Scaling: In reinforcement learning runs (math domain), Kimi Linear converged faster and reached higher accuracy than MLA, which supports its suitability for “test-time scaling”.

Efficiency Gains

Switching to Kimi Linear yields substantial efficiency gains:

  • Prefilling: Matches or slightly exceeds MLA speed at short contexts, but becomes 2.9$\times$ faster at 1M tokens.
  • Decoding: Achieves up to 6$\times$ faster decoding at 1M context compared to MLA, measured by Time Per Output Token (TPOT).
  • Memory: Reduces KV cache memory footprint by 75%, allowing for significantly larger batch sizes during inference.

Insight

“Positional Encoding” and “Memory Management” are two sides of the same coin. By treating the decay matrix in linear attention as a dynamic position encoding, Kimi Linear removes the need for explicit RoPE in the global layers. This is a significant simplification that likely contributes to the model’s stability in extrapolation.

Think of standard Full Attention (MLA) as a lossless .bmp image of the conversation history—it keeps every pixel (token pair) perfect but becomes massive and slow to process as the image gets larger. Think of standard Linear Attention as a .jpeg compression—it is fast and small, but it gets “blurry” (lossy) and loses specific details (like exact phone numbers or code snippets). Kimi Linear acts like a smart .zip archive:

  1. KDA uses “fine-grained decay” to selectively compress the background noise while keeping the important bits in high fidelity (the linear layers).
  2. Periodically, it opens a full “window” (the 1 MLA layer) to check the raw data, ensuring nothing was lost.
  3. Because it compresses efficiently, it can read a “1-million-page book” (1M context) 6 times faster than the model trying to hold every single page in memory at once.

Derivation of KDA

Step-by-step derivation of the Kimi Delta Attention (KDA) mechanism, moving from its recurrent definition to its hardware-efficient chunk-wise formulation.

I have denoted tensor shapes using the following notation:

  • $B$: Batch size
  • $H$: Number of heads
  • $T$: Sequence length
  • $C$: Chunk size
  • $d_k$: Key dimension (head dimension)
  • $d_v$: Value dimension

1. The Core Recurrent Rule

Standard linear attention keeps a running memory state $S_t$. KDA improves on it by combining the “Delta Rule” with Fine-Grained Decay (channel-wise gating). The Delta Rule corrects the memory along the direction of $k_t$, replacing the currently retrieved value $S^\top k_t$ with $v_t$ at rate $\beta_t$.

The fundamental recurrence for a single step $t$ is:

\[S_t = (I - \beta_t k_t k_t^\top) \text{Diag}(\alpha_t) S_{t-1} + \beta_t k_t v_t^\top\]

Tensor Shapes:

  • $S_t, S_{t-1} \in \mathbb{R}^{B \times H \times d_k \times d_v}$ : The recurrent memory state.
  • $k_t \in \mathbb{R}^{B \times H \times d_k \times 1}$ : The key vector.
  • $v_t \in \mathbb{R}^{B \times H \times d_v \times 1}$ : The value vector.
  • $\beta_t \in \mathbb{R}^{B \times H \times 1}$ : Scalar update strength (broadcasted).
  • $\text{Diag}(\alpha_t) \in \mathbb{R}^{B \times H \times d_k \times d_k}$ : Diagonal matrix of the fine-grained decay vector $\alpha_t \in \mathbb{R}^{d_k}$.

Insight: Unlike Mamba2 or Gated DeltaNet, which use a scalar (per-head) decay, the term $\text{Diag}(\alpha_t)$ lets each feature dimension decay at a different rate, providing finer control over memory retention.
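A token-by-token reference implementation is the simplest way to pin down this recurrence, and it serves as ground truth for the chunked forms below. The following is a minimal NumPy sketch for one head, under these assumptions: the function name `kda_recurrent` is hypothetical, the sizes are toy values, and keys are L2-normalized to keep the delta update well behaved. It is not the authors' kernel.

```python
import numpy as np

def kda_recurrent(q, k, v, alpha, beta, S0=None):
    """Reference (token-by-token) KDA for one head.

    q, k: [T, d_k]; v: [T, d_v]; alpha: [T, d_k] in (0, 1); beta: [T].
    Returns outputs [T, d_v] and the final state [d_k, d_v].
    """
    T, d_k = k.shape
    d_v = v.shape[1]
    S = np.zeros((d_k, d_v)) if S0 is None else S0.copy()
    out = np.zeros((T, d_v))
    for t in range(T):
        S = alpha[t][:, None] * S                   # Diag(alpha_t) S_{t-1}
        S = S - beta[t] * np.outer(k[t], k[t] @ S)  # (I - beta k k^T) applied to decayed state
        S = S + beta[t] * np.outer(k[t], v[t])      # + beta k v^T
        out[t] = S.T @ q[t]                         # o_t = S_t^T q_t
    return out, S

rng = np.random.default_rng(0)
T, d_k, d_v = 16, 8, 4
q = rng.standard_normal((T, d_k))
k = rng.standard_normal((T, d_k))
k /= np.linalg.norm(k, axis=1, keepdims=True)
v = rng.standard_normal((T, d_v))
alpha = rng.uniform(0.8, 1.0, (T, d_k))
beta = rng.uniform(0.1, 0.9, T)
o, S = kda_recurrent(q, k, v, alpha, beta)
```

Because this loop runs strictly sequentially over $T$, it is useful for checking correctness but is exactly what the chunk-wise derivation below is designed to avoid.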


2. Chunk-wise Decomposition

To parallelize this on GPUs, the sequence is split into chunks of length $C$. The state at the $r$-th position within chunk $t$ ($S_r[t]$) is derived from the initial state of the chunk ($S_0[t]$).

The expansion is divided into two terms: a Decay Term ($P$) and a History Term ($H$, not to be confused with the head count $H$):

\[S_r[t] = \underbrace{\left( \prod_{i=1}^r (I - \beta_i[t] k_i[t] k_i[t]^\top) \text{Diag}(\alpha_i[t]) \right)}_{P_r[t]} S_0[t] + \underbrace{\sum_{i=1}^r (\dots) \beta_i[t] k_i[t] v_i[t]^\top}_{H_r[t]}\]

Tensor Shapes:

  • $S_r[t] \in \mathbb{R}^{B \times H \times d_k \times d_v}$
  • $P_r[t] \in \mathbb{R}^{B \times H \times d_k \times d_k}$ : The cumulative decay matrix.
  • $H_r[t] \in \mathbb{R}^{B \times H \times d_k \times d_v}$ : The history accumulation matrix.
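The split into $P_r[t]$ and $H_r[t]$ can be checked numerically within one chunk. In the sketch below, the product in $P_r[t]$ is ordered with the latest transition on the left ($A_r \cdots A_1$, where $A_j = (I - \beta_j k_j k_j^\top)\text{Diag}(\alpha_j)$). The elided factor in $H_r[t]$ is expanded as the product of the later transitions $A_r \cdots A_{i+1}$ that carry each write $\beta_i k_i v_i^\top$ forward. That expansion is our own unrolling of the recurrence, and the sizes and the helper name `transition` are illustrative.

```python
import numpy as np

def transition(k, alpha, beta):
    # A_j = (I - beta_j k_j k_j^T) Diag(alpha_j)
    return (np.eye(len(k)) - beta * np.outer(k, k)) @ np.diag(alpha)

rng = np.random.default_rng(1)
C, d_k, d_v = 6, 5, 3
k = rng.standard_normal((C, d_k))
k /= np.linalg.norm(k, axis=1, keepdims=True)
v = rng.standard_normal((C, d_v))
alpha = rng.uniform(0.8, 1.0, (C, d_k))
beta = rng.uniform(0.1, 0.9, C)
S0 = rng.standard_normal((d_k, d_v))   # state at the start of the chunk

r = C
# Step the recurrence r times from S_0
S = S0.copy()
for j in range(r):
    S = transition(k[j], alpha[j], beta[j]) @ S + beta[j] * np.outer(k[j], v[j])

# P_r: product of transitions with the latest on the left (A_r ... A_1)
P = np.eye(d_k)
for j in range(r):
    P = transition(k[j], alpha[j], beta[j]) @ P

# H_r: each write beta_i k_i v_i^T is carried forward by A_r ... A_{i+1}
H = np.zeros((d_k, d_v))
for i in range(r):
    carry = np.eye(d_k)
    for j in range(i + 1, r):
        carry = transition(k[j], alpha[j], beta[j]) @ carry
    H += carry @ (beta[i] * np.outer(k[i], v[i]))

print(np.allclose(S, P @ S0 + H))
```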

3. The WY Representation (Compact Form)

Computing the product of Householder-like matrices in $P_r[t]$ is expensive ($O(C \cdot d_k^2)$). KDA uses the WY representation to compress these updates into low-rank corrections using auxiliary vectors $w$ and $u$.

Derivation of $P_r[t]$ (Decay State):

\(P_r[t] = \text{Diag}(\gamma_r[t]) - \sum_{i=1}^r \text{Diag}(\gamma_{i \to r}[t]) k_i[t] w_i[t]^\top\)

Derivation of $H_r[t]$ (History State):

\(H_r[t] = \sum_{i=1}^r \text{Diag}(\gamma_{i \to r}[t]) k_i[t] u_i[t]^\top\)

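Here, $\gamma$ denotes cumulative products of the decays $\alpha$ within the chunk: $\gamma_r[t] = \prod_{j=1}^{r} \alpha_j[t]$ and $\gamma_{i \to r}[t] = \prod_{j=i+1}^{r} \alpha_j[t]$ (element-wise, so $\gamma_{r \to r}[t] = \mathbf{1}$). The auxiliary vectors $w$ and $u$ are computed recursively: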

  1. Vector $w$ (Key Correction): \(w_r[t] = \beta_r[t] \left( \text{Diag}(\gamma_r[t]) k_r[t] - \sum_{i=1}^{r-1} w_i[t] (k_i[t]^\top \text{Diag}(\gamma_{i \to r}[t]) k_r[t]) \right)\)
  2. Vector $u$ (Value Correction): \(u_r[t] = \beta_r[t] \left( v_r[t] - \sum_{i=1}^{r-1} u_i[t] (k_i[t]^\top \text{Diag}(\gamma_{i \to r}[t]) k_r[t]) \right)\)
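The following sketch computes $w_r$ and $u_r$ with the two recursions above and checks that the compact forms reproduce $P_C[t]$ and $H_C[t]$ obtained by multiplying the transitions explicitly. It uses NumPy and takes $\gamma_{i \to r} = \gamma_r / \gamma_i$ (element-wise). That ratio is fine at this toy scale, but it can lose precision when decays are small or chunks are long, so it is not a recipe for a production kernel. Sizes and the helper `g` are illustrative.

```python
import numpy as np

rng = np.random.default_rng(2)
C, d_k, d_v = 6, 5, 3
k = rng.standard_normal((C, d_k))
k /= np.linalg.norm(k, axis=1, keepdims=True)
v = rng.standard_normal((C, d_v))
alpha = rng.uniform(0.8, 1.0, (C, d_k))
beta = rng.uniform(0.1, 0.9, C)

gamma = np.cumprod(alpha, axis=0)      # gamma[r] = prod_{j<=r} alpha_j  (0-based rows)

def g(i, r):
    # gamma_{i->r} = prod_{j=i+1}^{r} alpha_j, so g(r, r) is all ones
    return gamma[r] / gamma[i]

# Recursive auxiliary vectors w_r (keys) and u_r (values)
w = np.zeros((C, d_k))
u = np.zeros((C, d_v))
for r in range(C):
    w_acc = gamma[r] * k[r]            # Diag(gamma_r) k_r
    u_acc = v[r].copy()
    for i in range(r):
        coeff = k[i] @ (g(i, r) * k[r])  # k_i^T Diag(gamma_{i->r}) k_r
        w_acc = w_acc - w[i] * coeff
        u_acc = u_acc - u[i] * coeff
    w[r] = beta[r] * w_acc
    u[r] = beta[r] * u_acc

# Compact forms at the end of the chunk (r = C)
r = C - 1
P_wy = np.diag(gamma[r]) - sum(np.outer(g(i, r) * k[i], w[i]) for i in range(C))
H_wy = sum(np.outer(g(i, r) * k[i], u[i]) for i in range(C))

# Explicit products of the per-token transitions for comparison
P_ref = np.eye(d_k)
H_ref = np.zeros((d_k, d_v))
for j in range(C):
    A = (np.eye(d_k) - beta[j] * np.outer(k[j], k[j])) @ np.diag(alpha[j])
    P_ref = A @ P_ref
    H_ref = A @ H_ref + beta[j] * np.outer(k[j], v[j])

print(np.allclose(P_wy, P_ref), np.allclose(H_wy, H_ref))
```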

Tensor Shapes:

  • $w_r[t] \in \mathbb{R}^{B \times H \times d_k \times 1}$ : Auxiliary vector for keys.
  • $u_r[t] \in \mathbb{R}^{B \times H \times d_v \times 1}$ : Auxiliary vector for values.
  • $\gamma \in \mathbb{R}^{B \times H \times d_k}$ : Cumulative decay vector.

4. The UT Transform (Hardware Efficient Matrix Form)

To execute the above efficiently on Tensor Cores, KDA reformulates the recursive auxiliary vectors as block-matrix operations using the UT Transform, which minimizes non-matmul FLOPs.

First, an inverse matrix $M[t]$ is computed to handle the interactions between keys within the chunk:

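\[M[t] = \left( I + \text{StrictTril}\left( \text{Diag}(\beta[t]) (\Gamma_{1 \to C}[t] \odot K[t]) \left(K[t] / \Gamma_{1 \to C}[t]\right)^\top \right) \right)^{-1} \text{Diag}(\beta[t])\]

Here $\Gamma_{1 \to C}[t] \in \mathbb{R}^{C \times d_k}$ stacks $\gamma_1[t], \dots, \gamma_C[t]$ row by row, and $/$ is element-wise division. Entry $(r, i)$ of the product inside StrictTril is therefore $\beta_r[t]\, k_r[t]^\top \text{Diag}(\gamma_{i \to r}[t])\, k_i[t]$, which is the coefficient appearing in the recursions for $w$ and $u$.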

Then, the compact block representations $W[t]$ and $U[t]$ are derived:

\[W[t] = M[t] (\Gamma_{1 \to C}[t] \odot K[t]), \qquad U[t] = M[t] V[t]\]

Tensor Shapes:

  • $K[t] \in \mathbb{R}^{B \times H \times C \times d_k}$ : Stacked keys for the chunk.
  • $M[t] \in \mathbb{R}^{B \times H \times C \times C}$ : The interaction matrix (lower triangular).
  • $W[t] \in \mathbb{R}^{B \times H \times C \times d_k}$ : Stacked auxiliary $w$ vectors.
  • $U[t] \in \mathbb{R}^{B \times H \times C \times d_v}$ : Stacked auxiliary $u$ vectors.
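The matrix form can be checked against the recursive definitions of $w$ and $u$. The NumPy sketch below builds $M[t]$ with a dense solve, using $K[t] / \Gamma_{1 \to C}[t]$ (element-wise) as the second factor, since that is what reproduces the coefficient $k_i^\top \text{Diag}(\gamma_{i \to r}) k_r$. A real kernel would exploit the unit-lower-triangular structure of $I + \text{StrictTril}(\cdot)$ rather than calling a general solver. Sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(3)
C, d_k, d_v = 8, 6, 4
K = rng.standard_normal((C, d_k))
K /= np.linalg.norm(K, axis=1, keepdims=True)
V = rng.standard_normal((C, d_v))
alpha = rng.uniform(0.8, 1.0, (C, d_k))
beta = rng.uniform(0.1, 0.9, C)
Gamma = np.cumprod(alpha, axis=0)   # Gamma_{1->C}: row r holds gamma_r

# UT-transform (matrix) form
A = np.diag(beta) @ (Gamma * K) @ (K / Gamma).T     # (r, i): beta_r k_r^T Diag(gamma_r / gamma_i) k_i
L = np.tril(A, k=-1)                                # StrictTril
M = np.linalg.solve(np.eye(C) + L, np.diag(beta))   # (I + L)^{-1} Diag(beta)
W = M @ (Gamma * K)
U = M @ V

# Recursive definitions of w_r and u_r for comparison
W_rec = np.zeros((C, d_k))
U_rec = np.zeros((C, d_v))
for r in range(C):
    w_acc, u_acc = Gamma[r] * K[r], V[r].copy()
    for i in range(r):
        coeff = K[i] @ ((Gamma[r] / Gamma[i]) * K[r])
        w_acc -= W_rec[i] * coeff
        u_acc -= U_rec[i] * coeff
    W_rec[r], U_rec[r] = beta[r] * w_acc, beta[r] * u_acc

print(np.allclose(W, W_rec), np.allclose(U, U_rec))
```

The recursion is sequential in $r$, while the matrix form replaces it with one small triangular solve plus dense matmuls over the whole chunk. That trade is what makes the chunk-wise algorithm Tensor-Core friendly.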

Constraint Insight: KDA binds the variables of the general DPLR form (Diagonal Plus Low Rank) such that $a_t = \beta_t k_t$ and $b_t = k_t \odot \alpha_t$. This constraint reduces the number of secondary chunking steps required for numerical stability from 4 to 2, improving operator efficiency by $\approx 100\%$ compared to standard DPLR.


5. Final State Update & Output

Finally, the state is updated for the next chunk, and the output $O[t]$ is computed via Inter-Chunk (using state) and Intra-Chunk (using current block) parts.

State Update: \(S[t+1] = \text{Diag}(\gamma_C[t]) S[t] + (\Gamma_{i \to C}[t] \odot K[t])^\top (U[t] - W[t]S[t])\)

Output Computation ($O[t]$): \(O[t] = \underbrace{(\Gamma_{1 \to C}[t] \odot Q[t]) S[t]}_{\text{Inter-Chunk}} + \underbrace{\text{Tril}(\dots)(U[t] - W[t]S[t])}_{\text{Intra-Chunk}}\)

Tensor Shapes:

  • $S[t+1] \in \mathbb{R}^{B \times H \times d_k \times d_v}$ : Updated recurrent state.
  • $O[t] \in \mathbb{R}^{B \times H \times C \times d_v}$ : Output for the current chunk.
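Putting the pieces together, here is a minimal NumPy sketch of the full chunk-wise forward pass for one head, compared with the token-by-token recurrence. The `Tril` argument, elided as $(\dots)$ above, is filled in here as $(\Gamma_{1 \to C}[t] \odot Q[t])(K[t] / \Gamma_{1 \to C}[t])^\top$ with the diagonal included. That is what expanding $o_r = S_r^\top q_r$ gives, but treat it as our derivation rather than a quote from the report. Function names, sizes, and the chunk size of 4 are illustrative; the sequence length is deliberately not a multiple of $C$ so that the last chunk is partial.

```python
import numpy as np

def kda_recurrent(q, k, v, alpha, beta):
    d_k, d_v = k.shape[1], v.shape[1]
    S = np.zeros((d_k, d_v))
    out = np.zeros((len(k), d_v))
    for t in range(len(k)):
        S = alpha[t][:, None] * S                            # Diag(alpha_t) S
        S = S + beta[t] * np.outer(k[t], v[t] - k[t] @ S)    # delta-rule write
        out[t] = S.T @ q[t]
    return out, S

def kda_chunkwise(q, k, v, alpha, beta, C=4):
    T, d_k = k.shape
    d_v = v.shape[1]
    S = np.zeros((d_k, d_v))
    out = np.zeros((T, d_v))
    for s in range(0, T, C):
        Q, K, V = q[s:s+C], k[s:s+C], v[s:s+C]
        a, b = alpha[s:s+C], beta[s:s+C]
        n = len(K)
        G = np.cumprod(a, axis=0)                            # Gamma_{1->C}
        L = np.tril(np.diag(b) @ (G * K) @ (K / G).T, k=-1)
        M = np.linalg.solve(np.eye(n) + L, np.diag(b))
        W, U = M @ (G * K), M @ V
        R = U - W @ S                                        # U - W S, shared by both terms
        # Output: inter-chunk (via S) + intra-chunk (causal, diagonal included)
        out[s:s+n] = (G * Q) @ S + np.tril((G * Q) @ (K / G).T) @ R
        # State: decay across the chunk + writes carried to the chunk end
        G_to_end = G[-1] / G                                 # Gamma_{i->C}
        S = G[-1][:, None] * S + (G_to_end * K).T @ R
    return out, S

rng = np.random.default_rng(4)
T, d_k, d_v = 10, 6, 3
q = rng.standard_normal((T, d_k))
k = rng.standard_normal((T, d_k))
k /= np.linalg.norm(k, axis=1, keepdims=True)
v = rng.standard_normal((T, d_v))
alpha = rng.uniform(0.7, 1.0, (T, d_k))
beta = rng.uniform(0.1, 0.9, T)

o_ref, S_ref = kda_recurrent(q, k, v, alpha, beta)
o_chk, S_chk = kda_chunkwise(q, k, v, alpha, beta, C=4)
print(np.allclose(o_ref, o_chk), np.allclose(S_ref, S_chk))
```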

Tensor Shapes

Tensor shapes used in the 48B parameter Kimi Linear model:

1. Fixed Internal Dimensions

These dimensions are constant across the architecture and define the “resolution” of the attention mechanism:

  • Head Dimension ($d_k, d_v$): 128. The report explicitly states that the key ($d_k$) and value ($d_v$) head dimensions are set to 128 for all experiments.
  • Chunk Size ($C$): 64. For the hardware-efficient chunkwise algorithm (used during prefilling and training), the sequence is split into blocks of 64 tokens. This size is chosen to align with GPU Tensor Core operations.
  • Memory State Size ($S_t$): $128 \times 128$. Because KDA is a linear attention mechanism, it maintains a fixed-size recurrent state regardless of sequence length. With $d_k=d_v=128$, the state matrix per head is exactly $128 \times 128$.

2. Concrete Tensor Shapes (Per Chunk)

When processing a single chunk (block) of text during training or prefilling, the tensors involved in the KDA kernel have the following shapes.

Assume a Batch size ($B$) and Number of Heads ($H$).

| Tensor | Symbol | Concrete Shape ($B \times H \times \dots$) | Description |
|---|---|---|---|
| Input keys/values | $K[t], V[t]$ | $B \times H \times \mathbf{64} \times \mathbf{128}$ | The keys/values for one chunk. |
| Decay gates | $\alpha[t]$ (stacked $\alpha_r$) | $B \times H \times \mathbf{64} \times \mathbf{128}$ | The fine-grained (channel-wise) decay, stored as vectors rather than as $128 \times 128$ diagonal matrices $\text{Diag}(\alpha_r)$. |
| Recurrent state | $S[t]$ | $B \times H \times \mathbf{128} \times \mathbf{128}$ | The “memory” passed between chunks. |
| Auxiliary matrices | $W[t], U[t]$ | $B \times H \times \mathbf{64} \times \mathbf{128}$ | Stacked $w$/$u$ vectors of the compact (WY) form. |
| Interaction matrix | $M[t]$ | $B \times H \times \mathbf{64} \times \mathbf{64}$ | The (inverted) intra-chunk interaction matrix. |

3. Variable Dimensions

While the internal kernel shapes are fixed, the external dimensions vary based on the specific experiment:

  • Sequence Length ($T$):
    • Training: 4,096 tokens.
    • Inference: Up to 1,000,000 (1M) tokens.
  • Number of Heads ($H$):
    • In the scaling law experiments (smaller models), $H$ ranged from 16 to 24 heads.
    • In the main 48B Mixture-of-Experts (MoE) model, the head configuration follows the “Moonlight” architecture; the specific head count is not stated here.

Insight: Why these numbers?

The choice of 128 for the head dimension and 64 for the chunk size is likely not arbitrary.

  1. Hardware Alignment: 128 is a multiple of the 32-thread warp size and of 64, which aligns well with NVIDIA GPU Tensor Core matrix-multiply instruction shapes (e.g., $16\times8\times16$) and helps maximize achieved FLOPS.
  2. State Capacity: A $128 \times 128$ state holds 16,384 values per head. This is large enough to capture complex dependencies (unlike the smaller hidden states of early RNNs) but small enough to fit in the fast on-chip SRAM (L1/shared memory) of a GPU during kernel execution.
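The state-size arithmetic is easy to check. The snippet below computes the per-head footprint in two storage precisions. Which precision a kernel actually keeps the state in is an assumption here, not something the report specifies, and the result should be compared against the shared memory available per SM on the target GPU.

```python
d_k = d_v = 128
entries = d_k * d_v                     # values in one head's recurrent state

for name, nbytes in [("fp32", 4), ("bf16", 2)]:
    kib = entries * nbytes // 1024
    print(f"{name}: {entries} entries -> {kib} KiB per head")
```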

KV Cache Implementations

Implementing the “KV cache” for Kimi Delta Attention (KDA) during inference differs fundamentally from standard Transformers. Instead of a linearly growing cache of history tensors, KDA maintains a fixed-size recurrent state.

Step-by-step implementation guide for KDA inference (decoding):

1. The Data Structure: Fixed-Size State

Unlike standard attention, which stores a cache of shape [Batch, Heads, Seq_Len, Head_Dim], KDA compresses the history into a fixed matrix $S$.

  • Tensor Shape: [Batch, Heads, d_k, d_v]
  • Dimensions: With the specific Kimi Linear settings ($d_k = d_v = 128$), the state is a 128 $\times$ 128 matrix per head.
  • Memory Footprint: This size remains constant whether the context length is 1 token or 1 million tokens. Because 3 out of every 4 layers are KDA layers, this is what enables the reported reduction of up to 75% in KV cache memory.
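A per-head, per-layer comparison shows where the fixed state starts to pay off. The sketch counts cached entries for a plain multi-head attention cache (K and V, each $T \times d$) against one $d_k \times d_v$ KDA state. The helper names are hypothetical, and the baseline is ordinary per-head K/V caching rather than MLA's compressed latent cache.

```python
def standard_kv_entries(T, d_head=128):
    # K and V caches for one head: two [T, d_head] tensors, growing with T
    return 2 * T * d_head

def kda_state_entries(d_k=128, d_v=128):
    # One fixed [d_k, d_v] state per head, independent of T
    return d_k * d_v

for T in [1, 64, 4096, 1_000_000]:
    print(T, standard_kv_entries(T), kda_state_entries())
```

With $d = 128$ the two are equal at $T = 64$ ($2 \cdot 64 \cdot 128 = 128 \cdot 128$). For very short contexts the fixed state is actually larger, and beyond that point the gap grows linearly with $T$.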

2. The Algorithm: Recurrent Update Step

During decoding (autoregressive generation), you switch from the parallel “chunkwise” algorithm used in prefilling/training to the Recurrent Form (Eq. 1 in the report).

For a single step $t$, with inputs $q_t, k_t, v_t$ and gates $\alpha_t, \beta_t$:

Step A: Decay the State

First, apply the fine-grained channel-wise decay, which corresponds to the $\text{Diag}(\alpha_t)$ term:

\[S' = \text{Diag}(\alpha_t) S_{t-1}\]

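  • Implementation: Scale row $i$ of the $d_k \times d_v$ state by $\alpha_t[i]$ (element-wise multiplication with $\alpha_t$ broadcast along the $d_v$ axis); with a transposed $d_v \times d_k$ layout, scale the columns instead.
  • Note: $\alpha_t \in \mathbb{R}^{d_k}$ allows different feature channels to have different retention spans.
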
Step B: Compute the Delta (Update)

KDA uses a “Delta Rule” update, which can be factored efficiently to avoid a full matrix multiplication. The update rule is:

\[S_t = (I - \beta_t k_t k_t^\top) S' + \beta_t k_t v_t^\top\]

To implement this efficiently (reducing FLOPs), factor out $\beta_t k_t$:

\[S_t = S' + \beta_t k_t (v_t^\top - k_t^\top S')\]

Efficient Operations:

  1. Project Key: Compute vector $p = k_t^\top S'$ (Vector-Matrix multiplication, $1 \times d_v$).
  2. Compute Residual: Compute $r = v_t^\top - p$ (The difference between the actual value and the recalled value).
  3. Update State: $S_t = S' + (\beta_t k_t) \otimes r$ (Outer product addition).

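To confirm that the factored form is the same update, the NumPy sketch below applies Step A and then compares the unfactored update, which materializes the $d_k \times d_k$ matrix $(I - \beta_t k_t k_t^\top)$, with steps 1 to 3. The sizes follow the report's $d_k = d_v = 128$; the random inputs and unit-norm key are assumptions.

```python
import numpy as np

rng = np.random.default_rng(5)
d_k = d_v = 128
S_prev = rng.standard_normal((d_k, d_v))
k = rng.standard_normal(d_k)
k /= np.linalg.norm(k)
v = rng.standard_normal(d_v)
alpha = rng.uniform(0.9, 1.0, d_k)
beta = 0.5

S_dec = alpha[:, None] * S_prev            # Step A: S' = Diag(alpha) S_{t-1}

# Unfactored: build (I - beta k k^T), a d_k x d_k matrix, then multiply
S_full = (np.eye(d_k) - beta * np.outer(k, k)) @ S_dec + beta * np.outer(k, v)

# Factored: one vector-matrix product and one rank-1 update
p = k @ S_dec                              # 1. recalled value k^T S', shape [d_v]
r = v - p                                  # 2. residual
S_fact = S_dec + np.outer(beta * k, r)     # 3. S' + (beta k) outer r

print(np.allclose(S_full, S_fact))
```

The unfactored path costs on the order of $d_k^2 d_v$ multiply-adds (about 2M here), while the factored path costs on the order of $d_k d_v$.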
Step C: Compute Output

Once the state is updated, project the query against the memory. Eq. 1 uses the updated state $S_t$ (not $S_{t-1}$), so the current token's own write is visible to its query: \(o_t = S_t^\top q_t\)

  • Gating: The final output is then passed through a Grouped-RMSNorm and a Sigmoid output gate as defined in Equation 10.

3. PyTorch-style Pseudo-Code

Here is how you would implement the inference step for a single head in a batch:

import torch

def kda_inference_step(state, q, k, v, alpha, beta):
    """
    Args:
        state: [Batch, D, D] (Recurrent Memory S)
        q, k, v: [Batch, D, 1]
        alpha: [Batch, D, 1] (Fine-grained decay)
        beta: [Batch, 1, 1] (Update strength)
    Returns:
        output: [Batch, D, 1]
        new_state: [Batch, D, D]
    """
    # 1. Decay the state (Channel-wise)
    # Equivalent to Diag(alpha) @ State
    # alpha is [Batch, D, 1], so it broadcasts along the last dim and scales row i of S by alpha_i
    state_decayed = state * alpha 
    
    # 2. Compute the projection (k^T @ S_decayed)
    # Result shape: [Batch, 1, D]
    k_projection = torch.matmul(k.transpose(1, 2), state_decayed)
    
    # 3. Compute the value residual (v^T - k^T @ S)
    # Result shape: [Batch, 1, D]
    v_residual = v.transpose(1, 2) - k_projection
    
    # 4. Update the state
    # S_new = S_decayed + beta * (k @ v_residual)
    # Using outer product for k and residual
    update_term = beta * torch.matmul(k, v_residual)
    new_state = state_decayed + update_term
    
    # 5. Compute the output o_t = S_t^T q_t (Eq. 1), using the updated state
    output = torch.matmul(new_state.transpose(1, 2), q)
    
    return output, new_state
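In practice, the prompt is prefilled with the chunk-wise algorithm and generation continues with the recurrent step, so the only thing handed between the two phases is the final state. The sketch below checks that handoff for one unbatched head in NumPy. `prefill_state` is a hypothetical state-only chunk-wise prefill (outputs omitted), and `decode_step` is a 1-D NumPy port of `kda_inference_step` above. Sizes, prompt length, and chunk size are illustrative.

```python
import numpy as np

def prefill_state(k, v, alpha, beta, C=4):
    """Chunk-wise prefill that returns only the final state (outputs omitted)."""
    d_k, d_v = k.shape[1], v.shape[1]
    S = np.zeros((d_k, d_v))
    for s in range(0, len(k), C):
        K, V, a, b = k[s:s+C], v[s:s+C], alpha[s:s+C], beta[s:s+C]
        G = np.cumprod(a, axis=0)                              # Gamma_{1->C}
        L = np.tril(np.diag(b) @ (G * K) @ (K / G).T, k=-1)    # StrictTril(...)
        M = np.linalg.solve(np.eye(len(K)) + L, np.diag(b))
        W, U = M @ (G * K), M @ V
        S = G[-1][:, None] * S + ((G[-1] / G) * K).T @ (U - W @ S)
    return S

def decode_step(S, q, k, v, alpha, beta):
    """Unbatched NumPy port of kda_inference_step (vectors are 1-D)."""
    S = alpha[:, None] * S                      # decay
    S = S + beta * np.outer(k, v - k @ S)       # delta-rule update
    return S.T @ q, S                           # o_t = S_t^T q_t

rng = np.random.default_rng(6)
T, P, d = 12, 9, 6                              # total tokens, prompt length, head dim
q = rng.standard_normal((T, d))
k = rng.standard_normal((T, d))
k /= np.linalg.norm(k, axis=1, keepdims=True)
v = rng.standard_normal((T, d))
alpha = rng.uniform(0.7, 1.0, (T, d))
beta = rng.uniform(0.1, 0.9, T)

# Reference: run every token through the recurrent step
S_ref = np.zeros((d, d))
outs_ref = []
for t in range(T):
    o, S_ref = decode_step(S_ref, q[t], k[t], v[t], alpha[t], beta[t])
    outs_ref.append(o)

# Hybrid: chunk-wise prefill of the prompt, then recurrent decoding
S = prefill_state(k[:P], v[:P], alpha[:P], beta[:P])
outs = []
for t in range(P, T):
    o, S = decode_step(S, q[t], k[t], v[t], alpha[t], beta[t])
    outs.append(o)

print(np.allclose(S, S_ref), np.allclose(np.array(outs), np.array(outs_ref[P:])))
```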

4. Hybrid Architecture Consideration

It is crucial to note that Kimi Linear is a Hybrid Architecture (3 KDA layers : 1 MLA layer).

  • KDA Layers: Use the fixed-state implementation described above.
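  • MLA Layers: You must still maintain a standard, growing KV cache for these layers (e.g., a contiguous cache or a PagedAttention-style paged cache). A fixed-size ring buffer would evict old tokens and turn global attention into sliding-window attention.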
  • Result: The total cache size is reduced by roughly 75% because 3 out of every 4 layers do not require a growing cache.
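The “roughly 75%” is an asymptote: the KDA states are a fixed overhead, so the saving only approaches 75% as $T$ grows. The sketch below counts cache entries for a 3:1 hybrid against an all-MLA stack. Every configuration number in it (48 layers, 32 KDA heads, 512 cached MLA entries per token per layer) is hypothetical and not Kimi Linear's actual configuration, and the function names are illustrative.

```python
def hybrid_cache_entries(T, n_layers=48, kda_per_mla=3, mla_entries_per_token=512,
                         n_heads=32, d_k=128, d_v=128):
    """Cache entries for one sequence under a hypothetical 3:1 hybrid config."""
    n_mla = n_layers // (kda_per_mla + 1)
    n_kda = n_layers - n_mla
    mla = n_mla * T * mla_entries_per_token   # grows with T
    kda = n_kda * n_heads * d_k * d_v         # fixed-size recurrent states
    return mla + kda

def full_mla_entries(T, n_layers=48, mla_entries_per_token=512):
    """Same hypothetical model with every layer using MLA."""
    return n_layers * T * mla_entries_per_token

for T in [4_096, 131_072, 1_000_000]:
    ratio = hybrid_cache_entries(T) / full_mla_entries(T)
    print(f"T={T:>9,}: hybrid/full = {ratio:.3f}")
```

With these made-up sizes the hybrid needs 0.4375 of the all-MLA cache at 4K tokens, and the ratio falls toward 0.25 (a 75% saving) at long contexts.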

Summary of Differences

| Feature | Standard KV Cache | KDA Inference State |
|---|---|---|
| Shape | [B, H, T, D] (grows with T) | [B, H, D, D] (fixed) |
| Operation | Concatenate new $k, v$ | Mathematical update (decay + add) |
| Memory | Linear, $O(T)$ | Constant, $O(1)$ |
| Speed | Slows down as $T$ increases | Constant speed per token |