Jet-RL and the Precision Mismatch in Reasoning Models
Reading the following paper:
Unified Precision Flow for On-Policy Stability in Reasoning Models
1. The Core Bottleneck: The “Reasoning” Tax
As Large Language Models (LLMs) move toward “System 2” reasoning (Chain-of-Thought), the computational profile of Reinforcement Learning (RL) training has changed.
- The Rollout Dominance: Effective reasoning requires generating long token sequences ($>6,000$ tokens). Because LLM generation is autoregressive, the rollout phase (data collection) now consumes over 70% of the total training time when context exceeds 8k tokens.
- The Need for Speed: Making RL practical for reasoning models therefore requires accelerating the rollout phase, and the standard solution has been FP8 quantization.
2. The Problem: The “Off-Policy” Trap of Mixed Precision
The standard acceleration strategy, dubbed “BF16-Train-FP8-Rollout,” keeps the training master weights in BF16 (for stability) but casts the actor to FP8 for the rollout phase (for speed). While this works for short-context chat, it causes catastrophic collapse in reasoning tasks.
Why Naive FP8 Fails
The mixed-precision strategy has two failure modes:
- Long-Context Divergence: In autoregressive generation, small numerical errors from FP8 quantization accumulate at every step. By 16k tokens, the trajectory generated by the FP8 actor diverges significantly from what the BF16 training actor would have produced.
- Hard Tasks: On challenging benchmarks (e.g., DeepMATH), or with base models that lack instruction tuning, the model is less confident in its predictions. Quantization noise then disrupts the optimization landscape, causing divergence within 20 training steps.
The Theoretical Insight: This divergence violates the on-policy assumption of algorithms like PPO. The optimizer updates the policy based on data collected by a numerically different policy (the FP8 actor). This introduces a distribution shift that acts like aggressive “off-policy” noise, destabilizing the training.
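To make the on-policy violation concrete, the sketch below compares per-token log-probabilities recorded by the FP8 rollout actor with those recomputed by a BF16 training forward pass, and computes the PPO importance ratio from them. In a truly on-policy setup the ratio is exactly 1 before the first update; with a precision mismatch, some tokens already sit outside PPO's clip range. The log-probability values, the mismatch offsets, and the clip parameter `eps=0.2` are hypothetical, not numbers from the paper, and the function names are illustrative.

```python
import math

def ppo_ratios(train_logps, rollout_logps):
    """Per-token importance ratio pi_train(a|s) / pi_rollout(a|s)."""
    return [math.exp(lt - lr) for lt, lr in zip(train_logps, rollout_logps)]

def clipped_fraction(ratios, eps=0.2):
    """Fraction of tokens whose ratio already lies outside [1 - eps, 1 + eps]."""
    outside = [r for r in ratios if r < 1 - eps or r > 1 + eps]
    return len(outside) / len(ratios)

def sequence_log_ratio(train_logps, rollout_logps):
    """Log of the trajectory-level ratio: per-token log-ratios add up along the sequence."""
    return sum(lt - lr for lt, lr in zip(train_logps, rollout_logps))

# Hypothetical log-probs of four sampled tokens, as recorded by the FP8 rollout actor.
rollout = [-1.20, -0.50, -2.10, -0.05]
# The same tokens re-scored by a BF16 training forward pass (hypothetical mismatch).
train_bf16 = [r + d for r, d in zip(rollout, [0.0, 0.3, -0.3, 0.05])]

print(clipped_fraction(ppo_ratios(rollout, rollout)))     # on-policy: 0.0
print(clipped_fraction(ppo_ratios(train_bf16, rollout)))  # mismatched: 0.5
print(round(sequence_log_ratio(train_bf16, rollout), 6))  # 0.05
```

Because the per-token log-ratios add up, even a small systematic bias per token compounds over a 16k-token trajectory, which is the long-context divergence failure mode in another form.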
3. The Solution: Jet-RL (Unified Precision Flow)
Jet-RL proposes a strict Unified Precision Flow. The core principle is that the inference quantization graph ($G_{infer}$) must be a strict subgraph of the training forward graph ($G_{train}^{fwd}$).
Mechanism:
- The training forward pass must use the exact same FP8 quantization and granularity as the rollout engine.
- This forces the training optimizer to “see” the quantization noise as part of the environment dynamics, effectively realigning the update steps to be truly on-policy.
- No Calibration: Because RL weights change every step, calibration-based quantization (like SmoothQuant) would have to be rerun at every step, which is impractical. Jet-RL casts weights to FP8 directly, avoiding this inter-step synchronization overhead.
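A toy illustration of the unified-precision principle, assuming a one-layer “policy” whose logits are a matrix-vector product. `fake_q` is a hypothetical stand-in for the FP8 cast: it rounds to 4 significant bits (E4M3's 3 stored mantissa bits plus the implicit leading bit) but ignores E4M3's exponent range and scaling. When the training forward pass reuses the rollout's quantizer, the recomputed log-probabilities match the rollout's exactly; a high-precision (BF16-style) training forward pass does not.

```python
import math

def fake_q(x):
    """Stand-in FP8 cast: keep 4 significant bits (ties to even); exponent range ignored."""
    m, e = math.frexp(x)  # x == m * 2**e with 0.5 <= |m| < 1
    return math.ldexp(round(m * 16) / 16, e)

def identity(x):
    """High-precision path: no quantization."""
    return x

def log_softmax(logits):
    mx = max(logits)
    lse = mx + math.log(sum(math.exp(v - mx) for v in logits))
    return [v - lse for v in logits]

def policy_logps(W, x, quant):
    """Log-probs of a toy linear policy; `quant` is applied to both weights and inputs."""
    xq = [quant(v) for v in x]
    logits = [sum(quant(w) * v for w, v in zip(row, xq)) for row in W]
    return log_softmax(logits)

def max_gap(a, b):
    return max(abs(p - q) for p, q in zip(a, b))

W = [[0.37, -1.13], [0.91, 0.22], [-0.48, 0.66]]  # vocabulary of 3 tokens, 2 input features
x = [1.7, -0.29]

rollout = policy_logps(W, x, fake_q)             # FP8 rollout actor
train_unified = policy_logps(W, x, fake_q)       # training forward reusing the rollout quantizer
train_mismatched = policy_logps(W, x, identity)  # BF16-style training forward

print(max_gap(rollout, train_unified))         # 0.0
print(max_gap(rollout, train_mismatched) > 0)  # True
```

In Jet-RL's terms, the rollout graph is a subgraph of the training forward graph, so the probabilities the optimizer differentiates are the same probabilities that generated the data.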
4. Technical Implementation Details
Jet-RL's quantization scheme targets the E4M3 FP8 format (4 exponent bits, 3 mantissa bits, maximum value 448) and runs on NVIDIA H100 FP8 Tensor Cores.
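To make the format concrete, here is a minimal pure-Python sketch that enumerates the non-negative values representable in E4M3 (exponent bias 7, subnormals, no infinities, and the all-ones pattern reserved for NaN, which is why the largest finite value is $1.75 \times 2^8 = 448$) and rounds a float to the nearest one. Saturating at ±448 is a choice made for this sketch; real kernels perform the cast in hardware, and NaN handling is omitted. The names `E4M3_GRID` and `round_to_e4m3` are illustrative.

```python
import bisect

E4M3_MAX = 448.0

def _e4m3_nonnegative_values():
    """All finite non-negative E4M3 values in increasing order (127 entries)."""
    # Exponent field 0: subnormals (m/8) * 2**-6, including zero.
    vals = [(m / 8) * 2.0 ** -6 for m in range(8)]
    # Exponent fields 1..15: normals (1 + m/8) * 2**(e - 7).
    for e in range(1, 16):
        for m in range(8):
            if e == 15 and m == 7:  # bit pattern S.1111.111 encodes NaN
                continue
            vals.append((1 + m / 8) * 2.0 ** (e - 7))
    return vals

E4M3_GRID = _e4m3_nonnegative_values()

def round_to_e4m3(x: float) -> float:
    """Round to the nearest E4M3 value (ties to even mantissa), saturating at +/-448."""
    mag = min(abs(x), E4M3_MAX)
    i = bisect.bisect_left(E4M3_GRID, mag)
    if E4M3_GRID[i] == mag:
        q = mag
    else:
        lo, hi = E4M3_GRID[i - 1], E4M3_GRID[i]
        # Each exponent contributes 8 consecutive entries, so list-index parity equals
        # mantissa parity: an even index i - 1 means the lower neighbour is "even".
        if mag - lo < hi - mag or (mag - lo == hi - mag and (i - 1) % 2 == 0):
            q = lo
        else:
            q = hi
    return -q if x < 0 else q

print(len(E4M3_GRID), E4M3_GRID[-1])  # 127 448.0
print(round_to_e4m3(0.1))             # 0.1015625
print(round_to_e4m3(2.125))           # 2.0 (tie -> even mantissa)
print(round_to_e4m3(1000.0))          # 448.0 (saturated)
```

With only 3 mantissa bits, neighbouring values are 12.5% apart at the bottom of each binade, which is why the scaling strategy below matters so much.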
A. Granularity Strategy
To prevent underflow/overflow without calibration, Jet-RL uses fine-grained quantization:
- Weights: Quantized using $128 \times 128$ per-block granularity.
- Activations & Gradients: Quantized using $1 \times 128$ per-group granularity.
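The sketch below shows how per-group and per-block scales could be derived without calibration: each group's scale is its absolute maximum divided by 448, so the group's largest element maps onto the top of the E4M3 range. It is a pure-Python illustration with assumed names (`group_scales`, `block_scales`, `apply_group_scales`), and the final cast of the scaled values to E4M3 is omitted. The example shows why fine granularity helps: a single outlier only inflates the scale of its own group.

```python
E4M3_MAX = 448.0

def _scale(amax):
    # Map the group's absolute maximum onto the largest E4M3 value; guard all-zero groups.
    return amax / E4M3_MAX if amax > 0 else 1.0

def group_scales(row, group=128):
    """1 x group scaling (activations, gradients): one scale per `group` contiguous elements."""
    return [_scale(max(abs(v) for v in row[s:s + group])) for s in range(0, len(row), group)]

def block_scales(mat, block=128):
    """block x block scaling (weights): one scale per square tile of a 2-D matrix."""
    n_rows, n_cols = len(mat), len(mat[0])
    return [
        [
            _scale(max(abs(mat[r][c])
                       for r in range(r0, min(r0 + block, n_rows))
                       for c in range(c0, min(c0 + block, n_cols))))
            for c0 in range(0, n_cols, block)
        ]
        for r0 in range(0, n_rows, block)
    ]

def apply_group_scales(row, scales, group=128):
    """Divide each element by its group's scale; results lie in [-448, 448]."""
    return [v / scales[i // group] for i, v in enumerate(row)]

# Hypothetical 256-wide activation row for one token, with one outlier in the second group.
row = [0.01] * 128 + [1000.0] + [0.01] * 127
scales = group_scales(row)
scaled = apply_group_scales(row, scales)
print(len(scales), round(max(abs(v) for v in scaled), 6))  # 2 448.0
```

With a single per-tensor scale of 1000 / 448, the first group would instead map to 0.01 × 448 / 1000 = 0.00448, below E4M3's smallest normal value $2^{-6}$, and lose most of its precision.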
B. The GEMM Pipeline (Forward & Backward)
Each linear layer involves three GEMM operations:
- Forward Pass (FProp):
- Input ($X$) and Weight ($W$) are cast to FP8.
- Computes $Y = X W^T$ as an FP8 GEMM on the Tensor Cores.
- Backward Pass - Gradient of Weights (WGrad):
- Calculates $\nabla W = \nabla Y^T \times X$.
- Configuration: Multiplies an operand quantized in $1 \times 128$ groups by an operand quantized in $128 \times 1$ groups, a layout chosen to stabilize training.
- Backward Pass - Gradient of Inputs (DGrad):
- Calculates $\nabla X = \nabla Y \times W$.
- Crucial Stability Hack: Although the GEMM itself runs in FP8, the gradients passed between operators stay in BF16. Quantizing these inter-operator gradients often causes underflow that prevents convergence.
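The three GEMMs can be sketched for a single linear layer, assuming the convention $Y = X W^T$ (the one under which $\nabla W = \nabla Y^T X$ and $\nabla X = \nabla Y W$ hold). `fake_q` is a hypothetical stand-in for the FP8 cast (4 significant bits, no scaling or block granularity), and ordinary Python floats stand in for BF16. The point is the data flow: every GEMM operand is quantized, the quantized input is what gets saved for backward, and the DGrad output handed to the previous operator is left unquantized.

```python
import math

def fake_q(x):
    """Stand-in FP8 cast: keep 4 significant bits (ties to even); no range or scaling."""
    m, e = math.frexp(x)
    return math.ldexp(round(m * 16) / 16, e)

def qmat(A):
    return [[fake_q(v) for v in row] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def matmul(A, B):
    # High-precision accumulation of products of (already quantized) operands.
    Bt = transpose(B)
    return [[sum(a * b for a, b in zip(row, col)) for col in Bt] for row in A]

def linear_forward(X, W):
    """FProp: Y = Q(X) Q(W)^T. Q(X) is saved for backward (stored as FP8 in Jet-RL)."""
    Xq, Wq = qmat(X), qmat(W)
    return matmul(Xq, transpose(Wq)), (Xq, Wq)

def linear_backward(dY, saved):
    """WGrad and DGrad on quantized operands; dX is returned unquantized (BF16 in Jet-RL)."""
    Xq, Wq = saved
    dYq = qmat(dY)                   # incoming gradient quantized only as a GEMM operand
    dW = matmul(transpose(dYq), Xq)  # WGrad: dW = dY^T X
    dX = matmul(dYq, Wq)             # DGrad: dX = dY W
    return dX, dW

X = [[1, 2, 3], [0, 1, -1]]                       # 2 tokens x 3 input features
W = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 1]]  # 4 outputs x 3 inputs
Y, saved = linear_forward(X, W)
dX, dW = linear_backward([[1, 0, 0, 0], [0, 0, 0, 1]], saved)
print(Y)   # [[1.0, 2.0, 3.0, 6.0], [0.0, 1.0, -1.0, 0.0]]
print(dW)  # [[1.0, 2.0, 3.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, -1.0]]
print(dX)  # [[1.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
```

The example uses small integers, which are exactly representable, so the outputs can be checked by hand against the three formulas.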
C. Memory Optimization
Activations saved for the backward pass are stored in FP8, roughly halving their memory footprint relative to BF16 (1 byte instead of 2 per element, plus per-group scales).
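A back-of-the-envelope sketch of the saving, assuming the saved activations use the $1 \times 128$ per-group layout with one FP32 scale per group (the scale dtype is an assumption, not something stated above). The token count and hidden size are hypothetical, and the function name is illustrative.

```python
def saved_activation_bytes(tokens, hidden, group=128, scale_bytes=4):
    """Bytes to keep one [tokens, hidden] activation for backward: BF16 vs FP8 + scales.

    Assumes `hidden` is a multiple of `group`.
    """
    bf16 = 2 * tokens * hidden
    fp8 = tokens * hidden + scale_bytes * tokens * (hidden // group)
    return bf16, fp8

bf16, fp8 = saved_activation_bytes(tokens=16_384, hidden=4_096)
print(bf16, fp8, fp8 / bf16)  # 134217728 69206016 0.515625
```

Under these assumptions the FP8 copy takes about 52% of the BF16 size; the scales add only 1/32 of a byte per element.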
5. Evaluation & Impact
Stability & Accuracy:
- Convergence: On the DeepMATH dataset (16k context), the naive BF16-Train-FP8-Rollout method failed to converge at all, while Jet-RL converged stably.
- Performance Gap:
- Naive method: >10% accuracy drop on challenging tasks (Qwen3-8B-Base on DeepMATH).
- Jet-RL: 0.9% gap compared to the full BF16 baseline.
Efficiency:
- Rollout Speed: 1.07x – 1.33x faster. Larger models (32B) gain more because they are compute-bound, so higher FP8 Tensor Core throughput translates directly into speed. Smaller models (8B) are memory-bound, which limits the gain.
- Training Step: 1.41x faster due to FP8 GEMMs in the forward/backward pass.
- End-to-End: 1.16x total training speedup (on 8B models), with higher gains projected for larger scales.
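The end-to-end figure can be related to the per-phase speedups with Amdahl's law. A minimal sketch, assuming rollout takes 70% of step time (the figure from Section 1) and plugging in the low end of the reported rollout range (1.07x) together with the 1.41x training-step speedup; this is an illustrative consistency check, not the paper's own calculation.

```python
def end_to_end_speedup(rollout_frac, rollout_speedup, train_speedup):
    """Amdahl's law over two phases: new time = sum of each phase's share / its speedup."""
    new_time = rollout_frac / rollout_speedup + (1 - rollout_frac) / train_speedup
    return 1 / new_time

print(round(end_to_end_speedup(0.70, 1.07, 1.41), 3))  # 1.153
print(round(end_to_end_speedup(0.70, 1.33, 1.41), 3))  # upper end of the rollout range
```

With these inputs the estimate is about 1.15x, close to the reported 1.16x, and it shows that at a 70% rollout share the rollout speedup dominates the total.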
6. Takeaways
- Quantization is a Convergence Parameter: In Supervised Fine-Tuning (SFT), inference precision is an efficiency detail. In RL, precision mismatches create distribution shifts. If you roll out in FP8, you must train with an FP8 forward pass to satisfy the on-policy requirement.
- The “Overthinking” Sensitivity: As models generate longer chains of thought, they become hypersensitive to quantization. FP8 error accumulated over 16,000 decoding steps produces a trajectory that a BF16-trained model considers “out of distribution.”
- Future-Proofing for 70B+: The efficiency analysis suggests Jet-RL becomes more valuable as models scale. Because larger models shift from memory-bound to compute-bound regimes, FP8 is likely to become a primary lever for reducing RL training time.