Notes from reading the following paper and blog posts:

Blackwell GPUs introduce native hardware support for Microscaling (MX) formats. While MX formats theoretically offer 2x throughput over BF16, realizing these gains requires a fundamental departure from standard FP8 recipes. Successful implementation relies on a two-part strategy: a numerical recipe that deviates from the OCP v1.0 specification to ensure convergence, and a systems recipe that redesigns memory-hierarchy usage so that quantization overhead does not erase the compute gains.


The Numerical Recipe: Stability via Precision

Standard FP8 training (e.g., on H100s) typically relies on “per-tensor” scaling, necessitating the E5M2 format for gradients to handle high dynamic range. MXFP8 shifts this paradigm by using fine-grained block scaling (groups of 32 elements), which locally absorbs dynamic range requirements.

  • Universal E4M3 is Mandatory: Contrary to conventional wisdom, E4M3 should be used for all tensors: weights, activations, and, notably, activation gradients.
    • Insight: Because the shared block scale ($E8M0$) handles the magnitude, the element type ($E4M3$) can dedicate its bits to precision (3 mantissa bits) rather than range. Ablation studies show that using E5M2 for gradients in MXFP8 leads to perplexity degradation, while E4M3 achieves parity with BF16.
  • The “Ceiling” Rounding Protocol: The OCP v1.0 specification suggests a rounding method that effectively floors the scale factor. This is numerically dangerous for MXFP8.
    • The Flaw: If the scale is rounded down, $V/Scale$ may overflow the maximum representable value of the 8-bit format.
    • The Fix: Use a Ceiling/Round-Up algorithm for scale calculation. This ensures the scale is always large enough to map the input values into the valid quantization range, eliminating this source of training instability (see the sketch after this list).
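
A minimal sketch of the two rounding choices, assuming a scalar helper named `e8m0_scale` (the name and structure are illustrative, not from the source): it derives a power-of-two block scale from the block's absolute maximum and shows how flooring the exponent can push quantized values past the E4M3 maximum of 448, while rounding up keeps them in range.

```python
import math

E4M3_MAX = 448.0  # largest finite magnitude representable in FP8 E4M3

def e8m0_scale(block_amax: float, mode: str = "ceil") -> float:
    """Power-of-two (E8M0-style) block scale derived from the block's amax.

    'floor' mimics the OCP v1.0-style rounding; 'ceil' is the round-up fix.
    """
    if block_amax == 0.0:
        return 1.0
    exp = math.log2(block_amax / E4M3_MAX)
    exp = math.floor(exp) if mode == "floor" else math.ceil(exp)
    return 2.0 ** exp

amax = 1000.0  # toy block amax that is not an exact power-of-two multiple of 448
for mode in ("floor", "ceil"):
    scale = e8m0_scale(amax, mode)
    print(mode, amax / scale)  # floor -> 500.0 (overflows E4M3), ceil -> 250.0 (fits)
```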

The Systems Recipe: The “Tensor Memory” Constraint

Achieving the theoretical hardware speedups requires navigating the Blackwell architecture’s new memory hierarchy, specifically Tensor Memory (TMEM). A direct port of Hopper kernels will likely result in performance worse than BF16.

  • The Trap of Dequantization: Blackwell Tensor Cores accumulate results in TMEM, not registers. Attempting to dequantize results using CUDA cores requires a round-trip (TMEM $\to$ Reg $\to$ CUDA Core $\to$ TMEM) that creates massive pipeline bubbles, potentially taking 1.76x longer than the matrix math itself.
    • Solution: Use the tcgen05.mma instruction. This hardware intrinsic handles block scaling and accumulation entirely within the Tensor Core/TMEM pipeline, avoiding the CUDA core bottleneck.
  • Memory Bandwidth & Quantization Overhead: Quantization is memory-bound. Naive kernels (e.g., stock library implementations) often top out around ~4.5 TB/s, and the resulting quantization step can consume up to 40% of the total step time (see the back-of-envelope sketch after this list).
    • Optimization: Custom kernels are required to push bandwidth to 6.2+ TB/s. Crucially, these kernels must output scale factors directly in the specific, swizzled layout expected by tcgen05, avoiding runtime reshapes that kill performance.
  • The “Persistent Grid” Data Flow: The optimal data path for scale factors avoids the register file entirely. The pipeline must move scales from HBM $\to$ Shared Memory (SMEM) $\to$ TMEM using asynchronous copy instructions (cp.async.bulk and tcgen05.cp), which keeps scale traffic on the asynchronous path and maximizes occupancy.
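
A back-of-envelope sketch of why the quantization step is memory-bound. The traffic model and tensor size below are illustrative assumptions; only the ~4.5 and ~6.2 TB/s bandwidth figures come from the text above.

```python
def quantize_time_us(numel: int, bandwidth_tb_s: float) -> float:
    """Rough memory-bound cost of casting one BF16 tensor to MXFP8.

    Assumed traffic per element: read 2 B (BF16) + write 1 B (FP8) + 1/32 B (E8M0 scale).
    """
    bytes_moved = numel * (2 + 1 + 1 / 32)
    return bytes_moved / (bandwidth_tb_s * 1e12) * 1e6

numel = 8192 * 8192  # one hypothetical activation tensor
print(f"naive (~4.5 TB/s): {quantize_time_us(numel, 4.5):.1f} us")
print(f"tuned (~6.2 TB/s): {quantize_time_us(numel, 6.2):.1f} us")
```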

Domain Specifics: Mixture-of-Experts (MoE)

MoE models introduce unique challenges due to irregular memory access patterns.

  • Grouped GEMMs: Standard dense kernels fail here; grouped Wgrad/Dgrad kernels are required to handle the variable per-expert shapes.
  • Supergrouping: To prevent cache thrashing, “supergrouping” heuristics must be applied per expert. This orders output blocks so that the output-matrix region computed concurrently by the SMs is as square as possible, maximizing L2 cache reuse (a grouped tile-ordering sketch follows this list).
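
The supergrouping idea can be illustrated with a grouped tile ordering, similar to the classic GEMM tile-swizzling trick. The function below is a simplified single-expert sketch (names and the group size are my own), showing how consecutive tile ids are packed into near-square output regions:

```python
def grouped_tile_order(num_m_tiles: int, num_n_tiles: int, group_m: int):
    """Map a linear tile id to (m, n) so that `group_m` rows of output tiles
    are walked together, keeping the active output region roughly square
    and improving L2 reuse of the shared operand."""
    order = []
    tiles_per_group = group_m * num_n_tiles
    for pid in range(num_m_tiles * num_n_tiles):
        group = pid // tiles_per_group
        first_m = group * group_m
        rows = min(num_m_tiles - first_m, group_m)  # handle a ragged last group
        m = first_m + (pid % tiles_per_group) % rows
        n = (pid % tiles_per_group) // rows
        order.append((m, n))
    return order

# First eight tiles sweep a 2-row x 4-column band instead of a full 1 x 4 row:
print(grouped_tile_order(4, 4, group_m=2)[:8])
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2), (0, 3), (1, 3)]
```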

Technical Synthesis & “Gotchas”

| Feature | Standard FP8 (Hopper) | MXFP8 Recipe (Blackwell) | Reason |
|---|---|---|---|
| Gradient format | E5M2 | E4M3 | Block scaling handles range; gradients need precision. |
| Scale rounding | Floor (standard) | Ceil / round up | Prevents overflow in quantized blocks. |
| Accumulation | Registers | TMEM | Hardware architecture change; CUDA cores are too slow for dequant. |
| Kernel mode | Warp-synchronous | 2-CTA / async | tcgen05 allows 2 SMs to share the B matrix, reducing HBM traffic. |

MXFP8 GEMM and Block Scaling Factors

1. Core Concept: The Scaled Block

In MXFP8 (Narrow Precision) Matrix Multiplication, accuracy is maintained by dividing tensors into small blocks rather than scaling the entire tensor with a single scalar. For FP8 data types (CUDA_R_8F_E4M3 or CUDA_R_8F_E5M2), cuBLAS uses 32-element 1D block scaling.

A single Scaled Block consists of:

  1. Vector: 32 adjacent data elements (FP8).
  2. Scale: A single shared 8-bit scaling factor (Format CUDA_R_8F_UE8M0).

The Mathematical Operation

The dot product of two blocks (one from Matrix A, one from Matrix B) is computed by multiplying their shared scales and then summing the products of their elements:

\(\mathrm{Dot}(x, y) = S^x S^y \cdot \sum_{i=1}^{32} x_i y_i\)

The full GEMM operation sums these “Block Dot Products” along the reduction dimension ($K$).
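
As a reference, a NumPy sketch of the block dot product and its reduction over K (the array shapes and function names are my own; a real kernel operates on FP8 payloads and E8M0 scales rather than float arrays):

```python
import numpy as np

def block_dot(x: np.ndarray, y: np.ndarray, s_x: float, s_y: float) -> float:
    """Dot(x, y) = S^x * S^y * sum_i x_i * y_i for one pair of 32-element blocks."""
    assert x.shape == y.shape == (32,)
    return s_x * s_y * float(np.dot(x, y))

def scaled_dot(x_blocks, y_blocks, s_x, s_y) -> float:
    """Full reduction along K: sum of the per-block dot products."""
    return sum(block_dot(xb, yb, sxb, syb)
               for xb, yb, sxb, syb in zip(x_blocks, y_blocks, s_x, s_y))
```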

2. The 1D Block Scaling Factors Layout

The scaling factors are not stored in a linear row-by-row format. They use a specialized 128x4 Tiled Layout to optimize memory access.

Tile Structure
  • Dimensions: A single tile covers 128 elements in the Outer dimension (rows of A or columns of B) and 4 elements in the Inner dimension (blocks along K).
  • Capacity: One tile of scaling factors covers a $128 \times 128$ area of the source data tensor ($128 \text{ rows} \times 4 \text{ blocks} \times 32 \text{ elements/block}$).
Address Mapping (The Swizzle)

To locate a specific scaling factor within a tile, cuBLAS uses a “swizzled” addressing formula that interleaves the data in memory (a small helper reproducing it appears after the definitions below).

The Formula:

offset = (outer % 32) * 16 + (outer / 32) * 4 + inner
  • outer: The Row index (for Matrix A) or Column index (for Matrix B).
  • inner: The Block index along the K dimension ($0 \dots 3$ within a tile).
  • offset: The linear byte index within the tile memory.
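
A small helper (the name `sf_offset` is hypothetical) reproduces the formula and the offsets used in the walkthrough that follows:

```python
def sf_offset(outer: int, inner: int) -> int:
    """Byte offset of a scale factor within one 128x4 tile, per the formula above."""
    assert 0 <= outer < 128 and 0 <= inner < 4
    return (outer % 32) * 16 + (outer // 32) * 4 + inner

# Rows 0, 32, 64, 96 land at offsets 0, 4, 8, 12; row 5 lands at 80/81 for blocks 0/1.
for outer in (0, 1, 5, 32, 64, 96):
    print(outer, [sf_offset(outer, inner) for inner in range(2)])
```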

3. Concrete Example: Tracing a GEMM Element

Let us derive the calculation for a single result element $D_{5,0}$ (Row 5 of Matrix A dot Column 0 of Matrix B).

  • Dimensions: We assume $K=64$ (requiring 2 blocks of 32 elements).
  • Goal: Find the scaling factors $S^A$ (for Row 5) and $S^B$ (for Col 0) to compute the dot product.
Step A: Locating Scales for Matrix A (Row 5)

We need scales for Block 0 ($k=0..31$) and Block 1 ($k=32..63$).

  • Inputs: outer = 5 (Row), inner = 0 and 1.

Calculation for Block 0: \(\text{offset} = (5 \bmod 32) \times 16 + \lfloor 5 / 32 \rfloor \times 4 + 0 = 5 \times 16 + 0 + 0 = \mathbf{80}\)

Calculation for Block 1: \(\text{offset} = 5 \times 16 + 0 + 1 = \mathbf{81}\)

Result: The scales for Row 5 are stored at offsets 80 and 81.

Step B: Locating Scales for Matrix B (Column 0)

We need scales for Block 0 and Block 1.

  • Inputs: outer = 0 (Col), inner = 0 and 1.

Calculation for Block 0: \(\text{offset} = (0 \bmod 32) \times 16 + \lfloor 0 / 32 \rfloor \times 4 + 0 = \mathbf{0}\)

Calculation for Block 1: \(\text{offset} = 0 + 0 + 1 = \mathbf{1}\)

Result: The scales for Column 0 are stored at offsets 0 and 1.

Step C: Understanding the Interleaving

Why is Row 5 at offset 80, but Row 32 at offset 4?

  • Row 32 calculation: (32 % 32) * 16 + (32 / 32) * 4 + 0 = 0 + 4 + 0 = 4.
  • The layout therefore groups rows 0, 32, 64, and 96 adjacently in memory (offsets 0, 4, 8, 12) before moving on to Row 1 (offset 16), which matches the read pattern of GPU threads that consume 32 rows simultaneously.
Step D: Execution
  1. Fetch: Load data vectors and scales from offsets 80, 0 (for Block 0) and 81, 1 (for Block 1).
  2. Compute: \(D_{5,0} = (S^A_{80} S^B_{0} \sum_{k=0}^{31} A_{5,k} B_{k,0}) + (S^A_{81} S^B_{1} \sum_{k=32}^{63} A_{5,k} B_{k,0})\) (a short numeric sketch of this computation follows).
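
A self-contained numeric sketch of Step D. The data and scale values below are made up; only the offsets and the two-block structure of the computation mirror the trace above.

```python
import numpy as np

rng = np.random.default_rng(0)
a_row5 = rng.standard_normal(64).astype(np.float32)  # stand-in for dequantized A[5, 0:64]
b_col0 = rng.standard_normal(64).astype(np.float32)  # stand-in for dequantized B[0:64, 0]

# Scale buffers indexed by the swizzled offsets derived in Steps A and B.
scales_a = {80: 0.5, 81: 0.25}
scales_b = {0: 1.0, 1: 2.0}

d_5_0 = (scales_a[80] * scales_b[0] * float(np.dot(a_row5[:32], b_col0[:32]))
         + scales_a[81] * scales_b[1] * float(np.dot(a_row5[32:], b_col0[32:])))
print(d_5_0)
```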

4. Implementation Constraints

  • Global Stride: When moving to the next set of blocks (incrementing inner by 4, i.e. into the next tile), the memory address jumps by 128 bytes (the size of the tile's outer dimension, at one byte per scale).
    • Formula: offset = (sf_inner + sf_outer * sf_inner_dim) * 128.
  • Alignment: The starting address of the scaling factors must be 16-byte aligned.
  • Transposition: The scaling factor layout is immutable. Even if Matrix A is transposed in the GEMM descriptor (CUBLAS_OP_T), its scaling factors must remain in the standard layout (M-major for A, N-major for B).
  • Padding: If the matrix dimensions are not multiples of the tile size (128), full tiles must still be allocated, and out-of-bounds values should be filled with zeros (a small allocation helper follows).
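
A small allocation helper illustrating the padding rule. Rounding the inner dimension up to 4 blocks per tile is my own reading of the 128x4 tile shape described above; the outer padding to 128 follows directly from the constraint.

```python
import math

def scale_buffer_bytes(outer_dim: int, k_dim: int) -> int:
    """Bytes of UE8M0 scale factors to allocate, rounded up to whole 128x4 tiles."""
    n_blocks = math.ceil(k_dim / 32)            # 32-element blocks along K
    outer_tiles = math.ceil(outer_dim / 128)    # tiles along rows of A / columns of B
    inner_tiles = math.ceil(n_blocks / 4)       # tiles along the K-block dimension
    return outer_tiles * inner_tiles * 128 * 4  # one byte per scale factor

print(scale_buffer_bytes(5, 64))  # a tiny 5 x 64 operand still needs one full tile: 512 bytes
```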