Reading the following paper:

Overview

FlashAttention-4 is the latest iteration of the wildly successful hardware-aware attention algorithm, designed specifically to tackle the architectural shifts in NVIDIA’s Blackwell (B200/GB200) GPUs. While FlashAttention-3 was heavily optimized for the Hopper architecture (H100), the transition to Blackwell introduces a phenomenon the authors call “asymmetric hardware scaling.” This paper provides a masterclass in algorithmic and kernel co-design, demonstrating how to shift compute paradigms when the hardware bottlenecks unexpectedly change.

The Core Problem: Asymmetric Hardware Scaling

The fundamental challenge addressed in FlashAttention-4 is that not all parts of the GPU got faster at the same rate. On the Blackwell B200, the Matrix Multiply-Accumulate (MMA) tensor core throughput doubled compared to Hopper, reaching a massive 8192 ops/clock/SM (2.25 PFLOPS for FP16/BF16). However, the shared memory (SMEM) bandwidth remained flat at 128 bytes/clock/SM, and the multi-function unit (MUFU), which handles exponential operations for softmax, remained at 16 ops/clock/SM.

Roofline analysis in the paper reveals a surprising reality for Blackwell: matrix multiplication is no longer the primary bottleneck for attention. Instead, shared memory traffic and exponential operations now dominate execution time, exceeding MMA compute by 25-60%.
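The imbalance is easy to see with back-of-the-envelope arithmetic. The sketch below plugs the quoted per-SM throughput figures into a toy cycle count for a single attention tile; the 128×128×128 tile shape and the simplified op accounting are my assumptions for illustration, not the paper’s exact roofline model.

```python
# Toy per-tile roofline from the per-SM throughput figures quoted above.
# Tile shape and op accounting are simplifying assumptions.
MMA_MACS_PER_CLK = 8192     # FP16/BF16 multiply-accumulates / clock / SM
SMEM_BYTES_PER_CLK = 128    # shared-memory bandwidth / clock / SM
MUFU_EXP_PER_CLK = 16       # hardware exponentials / clock / SM

Br = Bc = d = 128           # assumed tile shape (rows, cols, head dim)

mma_cycles = 2 * Br * Bc * d / MMA_MACS_PER_CLK      # Q@K^T plus P@V
exp_cycles = (Br * Bc) / MUFU_EXP_PER_CLK            # one exp per score
smem_cycles = 2 * (Bc * d * 2) / SMEM_BYTES_PER_CLK  # stage K and V (2 B/elt)

print(mma_cycles, exp_cycles, smem_cycles)
```

Under these (admittedly crude) assumptions the MUFU time alone already exceeds the MMA time, which is the qualitative point of the paper’s roofline.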

Key Technical Innovations

1. Software-Emulated Exponential & Conditional Rescaling

Because the hardware MUFU cannot keep up with the doubled tensor core speed, the exponential calculation in the softmax becomes a major chokepoint.

  • Polynomial Approximation: To bypass the MUFU bottleneck, the authors software-emulate $2^x$ on the FP32 FMA units via a polynomial approximation. By distributing the exponential computations across both the MUFU and the FMA units (emulating 10-25% of entries), they effectively increase the exponential throughput. The genius here is recognizing that while a degree-3 polynomial has a higher FP32-level error than the hardware MUFU, once the output is rounded to BF16 (the standard precision for attention), the quantization error dominates, making the software emulation virtually indistinguishable from the hardware.

  • Conditional Rescaling: FlashAttention relies on online softmax, which requires rescaling previous results when a new maximum value is encountered. FlashAttention-4 introduces a threshold-based skip mechanism: if the new maximum is not significantly larger than the old one, it skips the vector multiplication for rescaling and resolves the normalization at the very end.
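Both tricks are easy to prototype in NumPy. The sketch below is a minimal model, not the paper’s kernel: the cubic coefficients come from a plain least-squares fit (the paper’s polynomial is not reproduced here), and the skip threshold `tau` is an illustrative value.

```python
import numpy as np

# --- Part 1: software-emulated 2^x via a degree-3 polynomial ------------
# Fit 2^f on f in [0, 1); cubic accuracy (~1e-4 relative) sits well below
# BF16's rounding step (~2^-9), which is why the emulation is invisible
# after the output is rounded to BF16.
_F = np.linspace(0.0, 1.0, 4097)
_C = np.polyfit(_F, np.exp2(_F), 3)

def exp2_emulated(x):
    n = np.floor(x)
    f = x - n                                # range reduction: f in [0, 1)
    return np.ldexp(np.polyval(_C, f), n.astype(np.int64))

# --- Part 2: online softmax with threshold-based rescale skipping -------
def online_softmax_av(scores, values, block=32, tau=8.0):
    """softmax(scores) @ values for one row, streamed block by block.

    The running accumulator is rescaled only when the max grows by more
    than `tau` (illustrative threshold, not the paper's); otherwise the
    old max is kept and normalization is resolved once at the end."""
    m, l = -np.inf, 0.0
    acc = np.zeros(values.shape[1])
    for i in range(0, len(scores), block):
        s, v = scores[i:i+block], values[i:i+block]
        bm = s.max()
        if bm > m + tau:                     # max moved a lot: rescale
            c = np.exp(m - bm) if np.isfinite(m) else 0.0
            m, l, acc = bm, l * c, acc * c
        # else: skip the rescale; exponents overshoot 0 by at most tau
        p = np.exp(s - m)
        l += p.sum()
        acc += p @ v
    return acc / l                           # deferred normalization
```

Note that skipping the rescale is mathematically exact: the running accumulator simply stays in the old scale, with exponents allowed to overshoot zero by at most `tau`, so the only cost is a slightly wider dynamic range before the final division.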

2. Taming Shared Memory with 2-CTA MMA Mode

In the backward pass, SMEM bandwidth becomes an acute bottleneck because the algorithm requires five different MMA operations, forcing operands to be repeatedly read from shared memory.

To solve this, FlashAttention-4 leverages a new Blackwell feature: 2-CTA tensor core MMA mode. In this mode, two Cooperative Thread Arrays (CTAs) within the same cluster cooperatively execute a single MMA.

  • Halving SMEM Traffic: Each CTA stages only half of operand B in its own shared memory, while the hardware consumes the combined B tile during the multiply.

  • Halving Atomic Reductions: In the gradient accumulation step (dQ), the researchers use distributed shared memory (DSMEM) to exchange half of the gradient of the softmax (dS) between the two CTAs. This repacking allows each CTA to write only half of the dQ tile, cutting the number of expensive global atomic reductions in half.
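The dQ trick can be modeled in a few lines of NumPy. The sketch treats the two CTAs as plain arrays: each holds dS and K for its own key block, and after the (here, simulated) DSMEM exchange of half the dS rows, each CTA is responsible for only half the rows of the dQ tile. All shapes and names are illustrative.

```python
import numpy as np

# Two CTAs in one cluster, each holding dS and K for a different key block.
rng = np.random.default_rng(0)
Br, Bc, d = 8, 8, 4
dS0, dS1 = rng.standard_normal((Br, Bc)), rng.standard_normal((Br, Bc))
K0, K1 = rng.standard_normal((Bc, d)), rng.standard_normal((Bc, d))

# Baseline: each CTA atomically adds a full Br x d contribution into the
# global dQ tile, i.e. two full-tile atomic reductions (dQ = dS @ K).
dQ_base = dS0 @ K0 + dS1 @ K1

# With the DSMEM exchange, each CTA swaps half of its dS rows with its
# peer, so each one produces (and atomically writes) only half the rows.
h = Br // 2
dQ_top = dS0[:h] @ K0 + dS1[:h] @ K1   # written by CTA 0 (h x d)
dQ_bot = dS0[h:] @ K0 + dS1[h:] @ K1   # written by CTA 1 (h x d)

assert np.allclose(dQ_base, np.vstack([dQ_top, dQ_bot]))
```

Each CTA’s atomic footprint drops from a full Br×d tile to h×d, which is where the 2x reduction in global atomic reductions comes from.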

3. Exploiting TMEM and Full Asynchrony

Blackwell introduces Tensor Memory (TMEM), a 256 KB on-chip memory per SM specifically for storing intermediate tensor core results. Unlike Hopper, where MMAs wrote to registers and caused massive register pressure, Blackwell’s MMAs write asynchronously directly to TMEM. FlashAttention-4 redesigns the software pipeline to aggressively overlap the fully asynchronous tensor core operations with softmax and memory operations, utilizing the larger 128x128 MMA tiles.
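The payoff of full asynchrony can be illustrated with a toy timing model (the stage costs below are invented placeholders, not measurements): once MMAs issue asynchronously into TMEM, the steady-state cost per tile approaches the slowest stage rather than the sum of all stages.

```python
# Toy model of the pipeline redesign. Stage costs (in cycles) are
# illustrative; only the serial-vs-pipelined structure matters.
N_TILES = 64
COST = {"mma": 512, "softmax": 1024, "store": 256}

# Synchronous: every stage waits for the previous one, for every tile.
serial = N_TILES * sum(COST.values())

# Fully asynchronous: after pipeline fill, each new tile costs only the
# slowest stage, since the other stages run under it.
pipelined = sum(COST.values()) + (N_TILES - 1) * max(COST.values())

print(serial, pipelined)
```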

4. A Pythonic Shift: CuTe-DSL

From a developer ecosystem standpoint, one of the most exciting updates is the move away from notoriously slow C++ template metaprogramming. FlashAttention-4 is written entirely in CuTe-DSL embedded in Python. This framework lowers to PTX and compiles just-in-time (JIT), reducing compile times by 20-30x (down to ~1.4–2.5 seconds from 45–55 seconds) while retaining full low-level expressivity.

Performance Outcomes

By directly addressing the shifting hardware bottlenecks, FlashAttention-4 achieves impressive performance on B200 GPUs:

  • Up to 1.3x speedup over cuDNN 9.13 and 2.7x over Triton for BF16.
  • Achieves up to 1613 TFLOPS, utilizing 71% of the theoretical maximum compute.

Insights

  1. The End of “Compute-Bound” Attention (For Now): We are entering an era where keeping the tensor cores fed is harder than the actual matrix multiplication. Kernel developers can no longer treat GPUs as uniform compute scaling machines; they must acutely monitor the ratio of MMA throughput to memory bandwidth and non-linear function units.

  2. Software Emulation is Viable at Lower Precisions: The decision to emulate the exponential function using FMA units is a brilliant application of precision-aware optimization. It highlights a broader insight for AI system design: if your target data type (like BF16) has a high quantization error, you have “budget” to use cheaper, faster mathematical approximations without harming the end result.

  3. Python for Low-Level GPU Kernels is Maturing: The use of CuTe-DSL proves that writing bare-metal, highly optimized GPU kernels no longer strictly requires the painful compilation cycles of complex C++ libraries like CUTLASS. This lowering of the barrier to entry will likely accelerate community experimentation with new attention variants.