FlashAttention
The Attention IO Bottleneck
By 2022 the Transformer was the dominant deep-learning architecture, but its self-attention sub-layer carried a quadratic time and memory cost in sequence length. Most prior work attacked the arithmetic count with sparse attention, low-rank approximations, and kernelized linear attention. The wall-clock numbers refused to match the predicted savings; approximate methods often ran no faster than the dense baseline they replaced.
FlashAttention starts from a different observation. On modern GPUs, attention as usually implemented is memory-bound, not compute-bound: most of the time is spent moving the \( N \times N \) intermediate matrices between high-bandwidth memory (HBM) and the on-chip SRAM where the matmuls actually run. Approximate methods saved arithmetic the hardware had not been waiting on. The right axis to optimize is memory traffic.
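To see why the intermediates cannot simply stay on chip, here is a back-of-envelope size calculation. It assumes fp16 storage (2 bytes per element) and takes 192 KB as the SRAM available to one A100 streaming multiprocessor, the figure usually quoted for that part; both numbers are assumptions to adjust for other dtypes or GPUs, and `score_matrix_bytes` is a hypothetical helper.

```python
# Back-of-envelope footprint of one head's N x N score (or probability) matrix,
# compared with the on-chip SRAM of a single streaming multiprocessor.
# Assumptions: fp16 storage and 192 KB of SRAM per SM (quoted A100 figure).

BYTES_PER_FP16 = 2
SRAM_PER_SM_BYTES = 192 * 1024  # assumed A100 per-SM SRAM


def score_matrix_bytes(n: int, bytes_per_elem: int = BYTES_PER_FP16) -> int:
    """Bytes needed to hold one N x N attention matrix (S or P) for one head."""
    return n * n * bytes_per_elem


for n in (512, 1024, 2048, 4096):
    size = score_matrix_bytes(n)
    print(f"N={n:5d}: {size / 2**20:6.1f} MiB per head, "
          f"{size / SRAM_PER_SM_BYTES:7.1f}x one SM's SRAM")
```

Even at \( N = 512 \) a single head's matrix is larger than one SM's SRAM, so a non-fused implementation has no choice but to spill it to HBM and read it back.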
Standard scaled dot-product attention
Symbols: \( \mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d} \) are the query, key, and value matrices for a single attention head; \( N \) is the sequence length; \( d \) is the per-head dimension (typically 64 or 128); \( \mathbf{O} \in \mathbb{R}^{N \times d} \) is the output. The softmax is applied row-wise. A naive implementation materializes two intermediates explicitly, the score matrix \( \mathbf{S} = \mathbf{Q}\mathbf{K}^{\top} / \sqrt{d} \in \mathbb{R}^{N \times N} \) and the probability matrix \( \mathbf{P} = \text{softmax}(\mathbf{S}) \in \mathbb{R}^{N \times N} \), and then computes the output as \( \mathbf{O} = \mathbf{P}\mathbf{V} \).
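The definitions above translate directly into a reference implementation. This is a minimal pure-Python sketch for one head; the name `naive_attention` and the list-of-lists representation are illustrative choices, not from the paper. It deliberately builds S and P in full, which is exactly the \( N \times N \) materialization FlashAttention avoids.

```python
import math


def naive_attention(Q, K, V):
    """Standard attention for one head, materializing S and P in full.

    Q is N_q x d, K is N_k x d, V is N_k x d_v (lists of lists of floats).
    Returns (O, S, P).
    """
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    # S = Q K^T / sqrt(d): every score is stored at once.
    S = [[scale * sum(qi * ki for qi, ki in zip(q, k_row)) for k_row in K] for q in Q]
    # P = row-wise softmax(S); subtracting the row max avoids overflow in exp.
    P = []
    for row in S:
        m = max(row)
        e = [math.exp(s - m) for s in row]
        z = sum(e)
        P.append([x / z for x in e])
    # O = P V: each output row is a probability-weighted mix of value rows.
    dv = len(V[0])
    O = [[sum(p * v[c] for p, v in zip(p_row, V)) for c in range(dv)] for p_row in P]
    return O, S, P


# All-zero queries give all-zero scores, so every row of P is uniform.
O, S, P = naive_attention([[0.0, 0.0], [0.0, 0.0]],
                          [[1.0, 2.0], [3.0, 4.0]],
                          [[1.0, 0.0], [0.0, 1.0]])
print(O)  # [[0.5, 0.5], [0.5, 0.5]]
```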
Attention is not FLOP-bound
Common misconception: the quadratic FLOP cost of attention is what makes long-sequence transformers slow, so any approach that drops the FLOP count to \( O(N) \) or \( O(N \log N) \) should produce a proportional speedup. On an A100 GPU at \( N \in [512, 4096] \), the attention sub-layer's wall-clock time is dominated by reads and writes of the \( N \times N \) score and probability matrices through HBM, not by the matmuls themselves (Section 3.1, Fig 2). A method that halves FLOPs but moves the same bytes through HBM will not run twice as fast, and several approximate methods actually ran slower than the dense baseline.
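A toy roofline model makes the argument concrete. It assumes runtime is bounded below by the slower of arithmetic and HBM traffic, uses approximate A100 figures (about 312 TFLOP/s fp16 tensor-core peak and about 1.5 TB/s HBM bandwidth), and counts traffic for a simplified non-fused implementation. The helper names are hypothetical, and real kernels are affected by much more than this model captures.

```python
# Toy roofline model: runtime >= max(compute time, memory time).
# Hardware numbers are assumptions (roughly an A100), not measurements.

PEAK_FLOPS = 312e12      # assumed fp16 tensor-core peak, FLOP/s
HBM_BANDWIDTH = 1.5e12   # assumed HBM bandwidth, bytes/s
BYTES_PER_ELEM = 2       # fp16


def naive_attention_cost(n: int, d: int):
    """Return (flops, hbm_bytes) for one head of non-fused attention."""
    flops = 4 * n * n * d  # Q K^T and P V cost 2*N*N*d FLOPs each
    # Write S, read S, write P, read P: four N x N passes through HBM,
    # plus reading Q, K, V and writing O: four N x d passes.
    elems = 4 * n * n + 4 * n * d
    return flops, elems * BYTES_PER_ELEM


def roofline_time(flops: float, hbm_bytes: float) -> float:
    """Lower-bound runtime in seconds under the toy model."""
    return max(flops / PEAK_FLOPS, hbm_bytes / HBM_BANDWIDTH)


flops, hbm = naive_attention_cost(n=1024, d=64)
baseline = roofline_time(flops, hbm)
half_flops = roofline_time(flops / 2, hbm)  # half the math, same HBM traffic
print(f"intensity: {flops / hbm:.1f} FLOP/byte "
      f"(ridge point ~{PEAK_FLOPS / HBM_BANDWIDTH:.0f} FLOP/byte)")
print(f"baseline {baseline * 1e6:.2f} us, half-FLOP variant {half_flops * 1e6:.2f} us")
```

At \( N = 1024 \) and \( d = 64 \) the arithmetic intensity sits far below the ridge point, so the memory term dominates and halving the FLOPs leaves the estimated time unchanged.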
The paper's claim
Standard attention can be computed exactly (no approximation; the output is identical up to the non-associativity of floating point) using a tiled algorithm whose forward and backward passes handle the \( N \times N \) probability matrix only block by block inside on-chip SRAM, never writing it to HBM. The result is \( O(N) \) extra HBM memory instead of \( O(N^{2}) \); end-to-end training that is 15% faster than the MLPerf 1.1 record on BERT-large and roughly 3\( \times \) faster than the HuggingFace baseline on GPT-2 (Tables 1 and 2 of arXiv:2205.14135); and the first transformer ever trained to above-chance accuracy on Long Range Arena's 16K-token Path-X task (Table 5).
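The trick that makes exact tiling possible is an online softmax: each query keeps a running row max, a running normalizer, and an unnormalized output, and rescales them whenever a new block of keys raises the max. The code below is a pure-Python illustration of that idea for one head's forward pass, under stated simplifications: `tiled_attention` and `block_size` are hypothetical names, the block size merely stands in for an SRAM-sized tile, and nothing here models the GPU memory hierarchy, query tiling, or the recomputation-based backward pass of the actual kernel.

```python
import math


def tiled_attention(Q, K, V, block_size=2):
    """Exact single-head attention that visits K and V one block at a time.

    For each query row we keep a running max m, a running softmax
    denominator l, and an unnormalized output acc, so no full row of
    S or P is ever stored. block_size is a stand-in for an SRAM tile.
    """
    n_k, d = len(K), len(Q[0])
    dv = len(V[0])
    scale = 1.0 / math.sqrt(d)
    O = []
    for q in Q:
        m = -math.inf       # running max of scores seen so far
        l = 0.0             # running sum of exp(score - m)
        acc = [0.0] * dv    # running sum of exp(score - m) * v
        for start in range(0, n_k, block_size):
            k_blk = K[start:start + block_size]
            v_blk = V[start:start + block_size]
            # Scores for this key block only.
            s = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_blk]
            m_new = max(m, max(s))
            # Rescale earlier partial sums to the new max (exp(-inf) is 0.0).
            alpha = math.exp(m - m_new)
            p = [math.exp(x - m_new) for x in s]
            l = alpha * l + sum(p)
            acc = [alpha * a + sum(pj * v[c] for pj, v in zip(p, v_blk))
                   for c, a in enumerate(acc)]
            m = m_new
        # Normalize once at the end: acc / l equals the softmax-weighted sum.
        O.append([a / l for a in acc])
    return O


Q = [[0.1, 0.2], [0.3, -0.4]]
K = [[0.5, -0.1], [0.2, 0.3], [-0.7, 0.4]]
V = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5]]
print(tiled_attention(Q, K, V, block_size=2))
```

Because rescaling by \( e^{m_{\text{old}} - m_{\text{new}}} \) is exact algebra, any `block_size` gives the same output as the full row-wise softmax up to floating-point rounding, which is the sense in which the tiled computation involves no approximation.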