Rollout Routing Replay: Stabilizing MoE Reinforcement Learning
1. The Core Problem: Training-Inference Inconsistency in MoEs
Reinforcement Learning (RL) for Large Language Models (LLMs) typically relies on two separate engines: an inference engine (e.g., SGLang) that generates rollouts and a training engine (e.g., Megatron) that performs policy updates. In dense models this separation causes only minor numerical discrepancies; in MoE models it can cause catastrophic instability.
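The root cause is the discontinuity of the routing mechanism: Top-K expert selection is a discrete function of the router logits, so a tiny change in the logits can flip which experts are chosen. In MoE models, slight perturbations in the input or non-determinism in compute kernels can therefore cause the router to select different experts during training than it selected during inference.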
- Quantified Discrepancy: ~10% of routers select different experts for the same token during training vs. inference.
- Sequence Impact: 94% of tokens in a sequence select a different expert in at least one layer during the forward pass.
- Consequence: This creates “extreme tokens” where probability ratios between training and inference differ by orders of magnitude, corrupting the importance sampling ratios used in PPO-style algorithms (GRPO, etc.) and leading to model collapse.
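The discontinuity described above can be made concrete with a toy router. The sketch below is plain Python with hypothetical values: one token, four experts, Top-2 routing. The two logit vectors differ by only 2e-4, yet the selected expert sets differ and the layer output moves by more than 2. The scalar `expert_outputs` are stand-ins for real expert networks, and the helper names are illustrative, not taken from the paper's code.

```python
import math

def top_k_mask(logits, k):
    # Binary mask marking the k largest router logits (ties broken toward the lower index).
    order = sorted(range(len(logits)), key=lambda i: (-logits[i], i))
    chosen = set(order[:k])
    return [1 if i in chosen else 0 for i in range(len(logits))]

def moe_output(logits, mask, expert_outputs):
    # Softmax over the selected experts only, then a gate-weighted sum of their outputs.
    weights = [math.exp(s) if keep else 0.0 for s, keep in zip(logits, mask)]
    z = sum(weights)
    return sum(w / z * y for w, y in zip(weights, expert_outputs))

# Hypothetical toy values: one token, 4 experts, Top-2 routing.
expert_outputs = [2.0, 3.0, -3.0, 0.5]   # scalar stand-ins for each expert's output
s_infer = [1.0, 0.5001, 0.5000, -0.5]    # router logits seen by the inference engine
s_train = [1.0, 0.4999, 0.5000, -0.5]    # same token, recomputed by the training engine

mask_infer = top_k_mask(s_infer, 2)      # selects experts 0 and 1
mask_train = top_k_mask(s_train, 2)      # selects experts 0 and 2

max_logit_gap = max(abs(a - b) for a, b in zip(s_infer, s_train))
y_infer = moe_output(s_infer, mask_infer, expert_outputs)
y_train = moe_output(s_train, mask_train, expert_outputs)
print(mask_infer, mask_train, round(max_logit_gap, 6), round(y_infer, 3), round(y_train, 3))
```

Stacked over dozens of MoE layers, flips like this one are what turn small kernel-level noise into the large token-probability mismatches described above.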
2. The Solution: Rollout Routing Replay (R3)
The paper proposes R3, a method that enforces consistency by recording the routing decisions made during generation and “replaying” them when the training engine recomputes the forward pass, so the backward pass follows the same expert path.
Technical Implementation: The method distinguishes between the selection of experts and the weighting of experts to ensure gradients still flow correctly.
- Inference (Rollout): The router calculates logits $s_{infer}$ and generates a binary mask $I_{infer}$ using a Top-K operation. This mask records exactly which experts were active.
- Training (Replay): Instead of recalculating the Top-K mask based on training logits $s_{train}$ (which might differ slightly due to engine noise), the model forces the use of $I_{infer}$.
- Gradient Preservation: Crucially, while the mask is fixed to the inference path, the gating weights are recomputed from the training logits, which preserves the computation graph for backpropagation:
  $$g_{replay,i} = \frac{I_{infer,i} \exp(s_{train,i})}{\sum_{j=1}^{M} I_{infer,j} \exp(s_{train,j})}$$
  where $M$ is the number of experts. This allows the router’s linear weights $W_r$ to be optimized while guaranteeing that the experts processed are identical to those that generated the data.
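The replay gate from the formula above can be sketched in plain Python for a single token, assuming the rollout mask `i_infer` has already been recorded. Because no autograd framework is used here, the gradient-preservation property is shown through the analytic Jacobian of the masked softmax, $\partial g_i / \partial s_{train,j} = g_i(\delta_{ij} - g_j)$: entries among selected experts are nonzero, so the router still receives gradient, while unselected experts contribute nothing. The function names are hypothetical; a real training engine would run the same computation on tensors and differentiate it automatically.

```python
import math

def replay_gates(s_train, i_infer):
    # g_replay,i = I_infer,i * exp(s_train,i) / sum_j I_infer,j * exp(s_train,j)
    # Subtracting the max over the selected logits keeps exp() numerically stable.
    m = max(s for s, keep in zip(s_train, i_infer) if keep)
    num = [math.exp(s - m) if keep else 0.0 for s, keep in zip(s_train, i_infer)]
    z = sum(num)
    return [n / z for n in num]

def replay_gate_jacobian(s_train, i_infer):
    # Analytic Jacobian J[i][j] = d g_i / d s_train,j of the masked softmax: g_i * (delta_ij - g_j).
    # Rows and columns of unselected experts are zero because their gates are zero.
    g = replay_gates(s_train, i_infer)
    n = len(g)
    return [[g[i] * ((1.0 if i == j else 0.0) - g[j]) for j in range(n)] for i in range(n)]

i_infer = [1, 1, 0, 0]              # mask recorded during rollout (Top-2 of the inference logits)
s_train = [1.0, 0.4999, 0.5, -0.5]  # training logits; a fresh Top-2 on these would pick experts 0 and 2
g = replay_gates(s_train, i_infer)
J = replay_gate_jacobian(s_train, i_infer)
print([round(x, 4) for x in g])
```

The replayed mask keeps experts 0 and 1, matching the rollout, while the weights over those two experts still come from the training logits. Since $s_{train} = W_r h$ in a standard linear router, the nonzero Jacobian entries are the path through which $W_r$ is updated.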
System Optimization: To limit computational overhead, R3 integrates with KVCache prefix caching. Because routing decisions for prefix tokens are deterministic given the same weights and the same prefix, the routing masks can be cached alongside the KV pairs and reused. This makes the method efficient for multi-turn RL tasks (e.g., agents using tools), keeping latency overhead below 3%.
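The prefix-caching idea can be illustrated with a small dictionary-based cache. This is a hypothetical sketch, not SGLang's or the paper's implementation: `RoutingPrefixCache` and `toy_router` are illustrative names, the router is a deterministic stand-in, and a real engine would keep the masks in the same paged or tree-structured store as the KV entries. The point it shows is that a second turn sharing a prefix with the first only needs routing for its new tokens.

```python
from typing import Callable, Dict, List, Sequence, Tuple

# Hypothetical layout: for each layer, the tuple of expert ids selected for one token.
TokenRouting = Tuple[Tuple[int, ...], ...]

class RoutingPrefixCache:
    """Toy cache storing each token's routing decisions keyed by the prefix ending at that token.

    With causal attention and fixed weights, a token's routing depends only on the tokens up to
    and including it, so the prefix is a valid cache key, just as it is for KV entries.
    """

    def __init__(self, route_fn: Callable[[Tuple[int, ...]], TokenRouting]):
        self._route_fn = route_fn  # stands in for the router inside a real forward pass
        self._store: Dict[Tuple[int, ...], TokenRouting] = {}
        self.computed = 0          # number of tokens that actually had to be routed

    def routing_for(self, tokens: Sequence[int]) -> List[TokenRouting]:
        result = []
        for end in range(1, len(tokens) + 1):
            key = tuple(tokens[:end])
            if key not in self._store:  # cache miss: route this token once and remember it
                self._store[key] = self._route_fn(key)
                self.computed += 1
            result.append(self._store[key])
        return result

def toy_router(prefix: Tuple[int, ...]) -> TokenRouting:
    # Deterministic stand-in for Top-2 routing in a 2-layer, 8-expert model (hypothetical).
    s = sum(prefix)
    return ((s % 8, (s + 1) % 8), ((s + 2) % 8, (s + 5) % 8))

cache = RoutingPrefixCache(toy_router)
turn1 = [11, 12, 13]        # prompt plus first response
turn2 = turn1 + [14, 15]    # same conversation after a tool result appends new tokens
r1 = cache.routing_for(turn1)
after_turn1 = cache.computed
r2 = cache.routing_for(turn2)
print(after_turn1, cache.computed)
```

Only the two appended tokens trigger routing on the second call; the first three reuse the masks recorded in the first turn, which is what keeps replay cheap in multi-turn agent rollouts.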
3. Key Findings and Experimental Results
The authors evaluate R3 on Qwen3-30B-A3B (an MoE model) across math (AIME, MATH500) and code (SWE-bench) benchmarks.
- Reduction in Divergence: R3 reduces the KL divergence between training and inference distributions from $1.5 \times 10^{-3}$ to $7.5 \times 10^{-4}$, effectively bringing MoE consistency to the level of dense models.
- Prevention of Collapse: In “single mini-step” RL settings (which are more aggressive), standard GRPO and TIS (Truncated Importance Sampling) baselines collapsed (diverged) quickly. R3 maintained stability throughout training.
- Performance Gains:
- Math: GRPO+R3 outperformed the GSPO baseline by 1.29 points.
- Code Agents: On SWE-bench Verified, GRPO+R3 achieved a Pass@1 of 38.6, compared to 31.8 for vanilla GRPO, which collapsed early.
- Optimization Dynamics: Models trained with R3 showed lower gradient norms, earlier entropy growth (indicating faster exploration), and smoother sequence length increases compared to baselines.
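The KL figures reported above compare the two engines' token distributions. One common way to estimate such a quantity from sampled tokens is the nonnegative “k3” estimator, sketched below under the assumption that each engine reports a log-probability for every sampled token. The paper's exact estimator may differ, and all log-probabilities here are invented toy values, not results from the paper.

```python
import math

def kl_k3(logp_infer, logp_train):
    # Estimate KL(pi_infer || pi_train) from tokens sampled by the inference engine.
    # Per token, with r = pi_train / pi_infer: k3 = (r - 1) - log r, which is never negative.
    total = 0.0
    for a, b in zip(logp_infer, logp_train):
        log_r = b - a
        total += math.expm1(log_r) - log_r  # expm1 gives r - 1 accurately when r is near 1
    return total / len(logp_infer)

def max_abs_log_ratio(logp_infer, logp_train):
    # Largest per-token |log r|; for example, 3.8 corresponds to roughly a 45x probability mismatch.
    return max(abs(b - a) for a, b in zip(logp_infer, logp_train))

# Hypothetical per-token log-probs of the sampled tokens, as reported by each engine.
logp_infer    = [-0.1, -2.0, -0.5, -1.2]
logp_mismatch = [-0.11, -2.3, -0.48, -5.0]         # training engine routed some tokens differently
logp_replayed = [-0.1005, -2.001, -0.499, -1.203]  # same experts replayed, only kernel-level noise

kl_mismatch = kl_k3(logp_infer, logp_mismatch)
kl_replayed = kl_k3(logp_infer, logp_replayed)
print(f"{kl_mismatch:.2e} {kl_replayed:.2e} {max_abs_log_ratio(logp_infer, logp_mismatch):.2f}")
```

A single token whose routing flips dominates the estimate, which is why removing expert mismatch lowers the measured training-inference KL so sharply.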
4. Insights
- Addressing the Root Cause vs. the Symptoms: Previous methods like TIS (Truncated Importance Sampling) manage the result of the discrepancy by truncating large importance ratios. R3 instead addresses the root cause (expert mismatch). This explains why adding TIS to R3 yielded negligible benefits: R3 had already removed the variance TIS was designed to mitigate.
- Divergence from “Recompute” Logic: Standard Recompute Routing Replay (used in prior work) caches routing from the “Recompute” stage to the “Update” stage. R3 caches from “Rollout” to “Update.” The distinction matters because the gap between the inference engine (Rollout) and the training engine (Update) is the primary source of instability in modern LLM stacks, not just the policy shift between updates.
- Implications for Scaling: As models scale, MoE is becoming the standard architecture for efficient inference. The non-differentiable nature of discrete routing makes MoE uniquely fragile to system-level noise (FP8 vs. BF16, kernel differences). R3 essentially acts as a synchronization lock between the disparate software stacks used in LLM post-training.
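The interaction between TIS and R3 can be sketched with truncated importance weights $w_t = \min(\pi_{train}(y_t) / \pi_{infer}(y_t), C)$, a common formulation of TIS. The cap $C = 2$ and all log-probabilities below are hypothetical. With mismatched routing, one token's ratio is about 55x and gets truncated; with replayed routing every ratio stays near 1, so truncation never activates.

```python
import math

def tis_weights(logp_infer, logp_train, cap):
    # Truncated importance weights: w_t = min(pi_train(y_t) / pi_infer(y_t), cap).
    return [min(math.exp(b - a), cap) for a, b in zip(logp_infer, logp_train)]

def clipped_fraction(logp_infer, logp_train, cap):
    # Share of tokens whose raw ratio exceeds the cap, i.e. where truncation changes the weight.
    ratios = [math.exp(b - a) for a, b in zip(logp_infer, logp_train)]
    return sum(r > cap for r in ratios) / len(ratios)

cap = 2.0                                           # hypothetical truncation threshold C
logp_infer    = [-0.1, -2.0, -4.5, -1.2]
logp_mismatch = [-0.12, -1.7, -0.5, -1.25]          # token 2: ratio e^4, about 55x
logp_replayed = [-0.1004, -1.999, -4.502, -1.2006]  # replayed experts: ratios near 1

w_mismatch = tis_weights(logp_infer, logp_mismatch, cap)
w_replayed = tis_weights(logp_infer, logp_replayed, cap)
print(clipped_fraction(logp_infer, logp_mismatch, cap), clipped_fraction(logp_infer, logp_replayed, cap))
```

In this toy setting the truncation only ever fires on the token whose routing diverged, which mirrors the observation that TIS adds little once R3 has removed the mismatch.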