In one line: Standard MHA inference saturates memory bandwidth via the KV cache. MQA / GQA let multiple query heads share KV; FlashAttention moves the computation into SRAM — the trifecta speeds up both training and inference several-fold.
The MHA bottleneck#
For every layer and every generated token, the entire KV cache must be read from HBM into the SMs' on-chip memory. At long contexts:
- Memory traffic ≫ compute → the GPU's compute units sit idle waiting for data;
- KV cache size ≈ `2 * num_layers * num_heads * head_dim * seq_len * bytes_per_elem`; doubling the sequence length doubles the cache (worked numbers in the sketch below).
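A back-of-the-envelope sketch of that formula; the shapes are illustrative (roughly a 7B-class model), so check the model card for real values:

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    # 2 = one K tensor + one V tensor per layer; bytes_per_elem=2 for bf16/fp16
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem

GIB = 1024 ** 3
# Full MHA (every query head keeps its own K/V): 32 layers, 32 heads, head_dim 128
print(kv_cache_bytes(32, 32, 128, 4096) / GIB)   # 2.0 GiB per sequence at 4K
print(kv_cache_bytes(32, 32, 128, 8192) / GIB)   # 4.0 GiB -> doubling seq_len doubles it
# Same model with GQA (8 KV heads): cache shrinks 4x
print(kv_cache_bytes(32, 8, 128, 4096) / GIB)    # 0.5 GiB
```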
The three big optimisations#
Analogy#
MHA = every writer has a personal librarian (K/V pair);
MQA = the whole town shares one librarian — fast but coarse;
GQA = a few writers share one librarian — quality and speed both win;
FlashAttention = the librarian brings the books to the desk once and reads everything in place — no more shelf trips for every lookup.
How GQA works#
Query heads are split into groups that share one K/V head: with 16 query heads and only 4 KV heads, each KV head serves 4 query heads, so the KV cache is 4× smaller (sketch below).
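A minimal GQA forward pass, sketched with PyTorch's `scaled_dot_product_attention`; the head counts mirror the 16-to-4 example above, and a real model would add projections, RoPE, and a KV cache around this:

```python
import torch
import torch.nn.functional as F

batch, seq_len = 1, 1024
num_q_heads, num_kv_heads, head_dim = 16, 4, 128
group = num_q_heads // num_kv_heads          # 4 query heads share each KV head

q = torch.randn(batch, num_q_heads, seq_len, head_dim)
k = torch.randn(batch, num_kv_heads, seq_len, head_dim)   # what gets cached: 4x smaller than MHA
v = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# Broadcast each KV head across its group of query heads, then run ordinary attention.
k_exp = k.repeat_interleave(group, dim=1)    # (batch, 16, seq_len, head_dim)
v_exp = v.repeat_interleave(group, dim=1)
out = F.scaled_dot_product_attention(q, k_exp, v_exp, is_causal=True)
print(out.shape)                             # torch.Size([1, 16, 1024, 128])
```

Recent PyTorch releases also expose an `enable_gqa=True` flag on `scaled_dot_product_attention` that skips the manual expand; treat that as version-dependent.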
Practical notes#
- Check the model card for `num_key_value_heads`: less than `num_attention_heads` ⇒ GQA. LLaMA-3 8B: 32 vs 8 (checked in the sketch after this list).
- FlashAttention v2/v3 integrates almost for free: PyTorch 2 SDPA, xformers, and TransformerEngine all ship with it.
- Long context = sliding-window attention (SWA) + positional-encoding extrapolation (RoPE / NTK-aware / YaRN) + long-context SFT data.
- Estimate the KV-cache footprint: `bytes ≈ 2 * num_layers * num_kv_heads * head_dim * seq_len * 2` (bf16). A 70B model at 4K context is usually a few GB.
- Quantise the KV cache: int8 / fp8 shrinks it further; quality loss shows up mostly at the long-context tail.
- Self-hosted fine-tuning: FlashAttention + LoRA + 8/4-bit quantisation lets a 24 GB GPU run 7B SFT.
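A quick sketch of the first two notes: reading the head counts off a Hugging Face config (the LLaMA-3 repo is gated, so substitute any model id you can access) and pinning the FlashAttention backend for SDPA (assumes PyTorch ≥ 2.3 and a CUDA GPU):

```python
import torch
from transformers import AutoConfig
from torch.nn.attention import SDPBackend, sdpa_kernel

# GQA check: fewer KV heads than attention heads means GQA.
cfg = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B")
print(cfg.num_attention_heads, cfg.num_key_value_heads)   # 32 8 -> GQA, group size 4

# FlashAttention via PyTorch 2 SDPA: force the flash backend for this block.
q = k = v = torch.randn(1, 8, 1024, 128, device="cuda", dtype=torch.bfloat16)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
```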
Easy confusions#
- MQA / GQA change the model architecture → they require retraining or a fresh init.
- FlashAttention computes exactly the same attention, only faster → a drop-in replacement.