Attention LLM Inference KV Cache Interactive

MLA: How DeepSeek Shrinks the KV Cache by 93% Without Losing Quality

Standard multi-head attention caches 32,768 floats per token per layer (keys plus values for 128 heads of dimension 128). Multi-head Latent Attention caches just 576, a 57× reduction, by compressing keys and values into a small latent vector and then using a matrix-absorption trick so that full keys and values never have to be rebuilt at decode time. In DeepSeek-V2 the result is 5.76× the generation throughput of DeepSeek 67B, with attention quality that matches or beats MHA in the paper's ablations.

June 1, 2026 ~14 min read Paper: arXiv:2405.04434
01 — The KV Cache Problem

Why KV Cache Is the Bottleneck in Long-Context Inference

Every token you generate reads from the KV cache — the stored keys and values for every past token at every layer. In standard multi-head attention (MHA) with 128 heads of dimension 128, that's 32,768 floats per token per layer. Run a 100-layer model on a 128K-token context and you need 100 × 128K × 32,768 × 2 bytes ≈ 838 GB of KV cache alone.

This is why long-context inference is so expensive: the model parameters fit in GPU memory, but the KV cache grows linearly with sequence length. Batching multiple users together makes it worse, because every concurrent request needs its own KV cache.

The core tension: MHA needs full heads for quality (multi-head = multiple independent subspaces), but caching full heads is prohibitively expensive at scale. GQA and MQA reduce cache by sharing heads, but they sacrifice quality. MLA finds a third way: compress the KV pairs into a tiny latent vector and reconstruct them on the fly.

Cache size grows linearly with sequence length

The chart kv_cache_growth.py plots KV cache memory against sequence length for each attention variant in a 100-layer model (bfloat16). Every variant grows linearly, but MHA's line climbs 57× more steeply than MLA's.

At 128K tokens (DeepSeek-V2's context length), this 100-layer MHA model needs ~838 GB of KV cache per sequence in bfloat16. MLA needs only ~15 GB. The ~824 GB difference is more than ten 80 GB A100s' worth of memory spent purely on KV storage.
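These figures are easy to check by multiplying them out. The sketch below assumes "128K" means 128,000 tokens (which is what reproduces the ~838 GB figure) and 2 bytes per bfloat16 value; `kv_cache_bytes` is a hypothetical helper, not part of any library.

```python
# KV cache for one sequence = layers x tokens x cached floats per token per layer x bytes per float.
def kv_cache_bytes(layers: int, tokens: int, floats_per_token_layer: int,
                   bytes_per_float: int = 2) -> int:
    """Total KV cache in bytes for a single sequence (2 bytes per float for bfloat16)."""
    return layers * tokens * floats_per_token_layer * bytes_per_float


LAYERS = 100                 # the hypothetical 100-layer model from the text
TOKENS = 128_000             # "128K" taken as 128,000 tokens
MHA_FLOATS = 2 * 128 * 128   # keys + values for 128 heads of dim 128 = 32,768
MLA_FLOATS = 512 + 64        # latent c^KV + decoupled RoPE key = 576

mha = kv_cache_bytes(LAYERS, TOKENS, MHA_FLOATS)
mla = kv_cache_bytes(LAYERS, TOKENS, MLA_FLOATS)
print(f"MHA: {mha / 1e9:.1f} GB")                                # MHA: 838.9 GB
print(f"MLA: {mla / 1e9:.1f} GB")                                # MLA: 14.7 GB
print(f"Difference in 80 GB A100s: {(mha - mla) / 80e9:.1f}")    # Difference in 80 GB A100s: 10.3
```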

02 — Low-Rank KV Compression

Compress Down, Expand Up — And Cache Only the Small Vector

The key insight of MLA is embarrassingly simple: keys and values live in a high-dimensional space, but a lower-dimensional latent vector is enough to reconstruct them. Instead of caching the full keys and values, cache the compressed latent.

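Standard MHA (cache both K and V)
$$k_t = W^{K} h_t \qquad \text{cached: } n_h d_h = 16{,}384 \text{ floats}$$
$$v_t = W^{V} h_t \qquad \text{cached: } n_h d_h = 16{,}384 \text{ floats}$$

MLA (cache only the latent c)
$$c_t^{KV} = W^{DKV} h_t \qquad \text{cached: } d_c = 512 \text{ floats}$$
$$k_t^{C} = W^{UK} c_t^{KV} \qquad \text{recomputed at decode time}$$
$$v_t^{C} = W^{UV} c_t^{KV} \qquad \text{recomputed at decode time}$$

The 512-float latent alone is 64× smaller than MHA's 32,768 floats; adding the 64-float RoPE key from §04 brings the total to 576, which is where the 57× figure comes from.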

The down-projection W^DKV maps from d = 5120 to d_c = 512, squeezing the token representation into a latent 10× smaller than the hidden state (and only 4× the size of a single 128-dim head). The up-projections W^UK and W^UV expand it back to all 128 heads during the decode step.
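A minimal NumPy sketch of the shapes involved. It assumes random matrices as stand-ins for the trained W^DKV, W^UK, and W^UV; only the dimensions are taken from the text, and the variable names are illustrative.

```python
import numpy as np

# DeepSeek-V2 dimensions quoted in the text.
d, d_c, n_h, d_h = 5120, 512, 128, 128

rng = np.random.default_rng(0)
# Random stand-ins for the learned weights (real values come from training).
W_DKV = rng.standard_normal((d_c, d), dtype=np.float32) / d ** 0.5             # 5120 -> 512
W_UK = rng.standard_normal((n_h * d_h, d_c), dtype=np.float32) / d_c ** 0.5    # 512 -> 16,384
W_UV = rng.standard_normal((n_h * d_h, d_c), dtype=np.float32) / d_c ** 0.5    # 512 -> 16,384

h_t = rng.standard_normal(d, dtype=np.float32)   # hidden state of one token

c_t = W_DKV @ h_t                         # the tensor MLA caches: shape (512,)
k_t = (W_UK @ c_t).reshape(n_h, d_h)      # rebuilt keys, one 128-dim key per head
v_t = (W_UV @ c_t).reshape(n_h, d_h)      # rebuilt values, one 128-dim value per head

print(c_t.shape, k_t.shape, v_t.shape)    # (512,) (128, 128) (128, 128)
print(k_t.size + v_t.size, "floats for MHA-style K+V vs", c_t.size, "for the latent")
```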

mla_arch.svg (architecture diagram): MLA data flow, compress → cache → expand. h_t (5120-dim input) → W^DKV (5120→512) → c_t^KV (512-dim, the only tensor that is cached) → W^UK and W^UV (512→16,384 each) → k_t^C and v_t^C (128 heads) → attention, softmax(QKᵀ/√d). Keys and values are recomputed at decode time.

The same principle applies to queries: a compressed query latent c^Q (1536-dim) is computed first, then up-projected to the 128-head query space. This saves activation memory during training but matters less for inference speed, since queries aren't cached.

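Why does compression work? In MLA the key and value projections are factored through a shared bottleneck: W^K becomes W^UK · W^DKV and W^V becomes W^UV · W^DKV, and the model is trained with this structure from the start. Every key and value is therefore a linear function of the 512-dim c^KV, so caching c^KV loses nothing relative to caching the 32,768 key and value floats it produces. The real assumption is a modeling one: that the information attention needs from each token fits in a 512-dim subspace of the 5120-dim hidden state, which DeepSeek-V2's results support.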
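The argument can be checked numerically on toy sizes (hypothetical: hidden size 64, latent size 8, key size 128, random weights). The sketch shows that the effective key projection W^UK · W^DKV has rank at most d_c, and that keys rebuilt from the cached latents match keys computed directly from the hidden states.

```python
import numpy as np

# Hypothetical toy sizes, chosen so the rank computation is instant.
d, d_c, d_kv = 64, 8, 128   # hidden dim, latent dim, n_h * d_h

rng = np.random.default_rng(1)
W_DKV = rng.standard_normal((d_c, d))
W_UK = rng.standard_normal((d_kv, d_c))

# The effective key projection is the product of the two trained factors.
W_K_eff = W_UK @ W_DKV                    # shape (128, 64)
print(np.linalg.matrix_rank(W_K_eff))     # 8: never more than d_c

# Keys for a few tokens, computed two ways.
H = rng.standard_normal((5, d))           # 5 tokens, one hidden state per row
K_direct = H @ W_K_eff.T                  # straight from the hidden states
C_cache = H @ W_DKV.T                     # what MLA stores: 5 x 8 floats
K_from_cache = C_cache @ W_UK.T           # rebuilt from the cache alone
print(np.allclose(K_direct, K_from_cache))    # True
```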
03 — The Absorption Trick

Merge W^UK Into W^UQ: Eliminate Key Materialization Entirely

Here's where MLA gets clever. During inference, you're computing attention scores as:

$$\text{score} = q_t^{\top} k_j^{C} = q_t^{\top} W^{UK} c_j^{KV}$$

The query q_t itself comes from multiplying W^UQ by the compressed query latent c_t^Q. So the full chain is:

$$\text{score} = \left(W^{UQ} c_t^{Q}\right)^{\top} W^{UK} c_j^{KV} = \left(c_t^{Q}\right)^{\top} \left(W^{UQ}\right)^{\top} W^{UK} c_j^{KV}$$

The product (W^UQ)^T · W^UK doesn't depend on the current position, so it can be fused into a single weight matrix W^Q' = (W^UQ)^T · W^UK offline, at model-load time (one such matrix per head, since each head has its own slices of W^UQ and W^UK). At decode time you never materialize the full keys at all:

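Naïve (two steps)
$$k_j^{C} = W^{UK} c_j^{KV}, \qquad \text{score} = q_t^{\top} k_j^{C}$$
Materializes a 16,384-dim key (all 128 heads) for every cached token.

Absorbed (one step)
$$W^{Q'} = \left(W^{UQ}\right)^{\top} W^{UK} \ \text{(precomputed)}, \qquad \text{score} = \left(c_t^{Q}\right)^{\top} W^{Q'} c_j^{KV}$$
Scores are computed directly from the 512-dim c_j^KV, with no key expansion.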
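A sketch of the score absorption for a single head, assuming random weights and DeepSeek-V2's dimensions (d_c' = 1536, d_c = 512, d_h = 128). Names such as `W_Qprime` are illustrative. It compares the naïve two-step computation with the fused one.

```python
import numpy as np

# One attention head with DeepSeek-V2's sizes; T = number of cached tokens (arbitrary).
d_cq, d_c, d_h, T = 1536, 512, 128, 10

rng = np.random.default_rng(2)
W_UQ = rng.standard_normal((d_h, d_cq)) / d_cq ** 0.5   # this head's query up-projection
W_UK = rng.standard_normal((d_h, d_c)) / d_c ** 0.5     # this head's key up-projection
c_q = rng.standard_normal(d_cq)                          # compressed query latent c_t^Q
c_cache = rng.standard_normal((T, d_c))                  # cached latents c_j^KV, one per row

# Naive: expand every cached latent into a full key, then dot with the query.
q = W_UQ @ c_q                   # (128,)
K = c_cache @ W_UK.T             # (T, 128): materialized keys
scores_naive = K @ q             # (T,)

# Absorbed: fuse the two up-projections once, at load time.
W_Qprime = W_UQ.T @ W_UK                        # (1536, 512), position-independent
scores_absorbed = c_cache @ (c_q @ W_Qprime)    # (T,), no keys materialized

print(np.allclose(scores_naive, scores_absorbed))   # True
```

The fused matrix is 1536 × 512 per head. The same scores can also be obtained by first up-projecting the query and then multiplying it by the transpose of the head's W^UK, which yields a 512-dim vector to dot against the cache; both orders are algebraically identical, and which is cheaper depends on the shapes.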

The same absorption applies to values: W^UV can be merged into the output projection W^O. At inference time the model never materializes full keys or values; it works directly with the compact 512-dim latent vectors and expands only once, when the fused output projection writes the result.
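The value side can be checked the same way. This sketch uses hypothetical toy sizes, random weights, and random row-normalized attention weights, and it follows the column-vector convention of the equations above, so the fused per-head matrix is W^O_i · W^UV_i, where W^O_i is head i's block of columns in W^O.

```python
import numpy as np

# Hypothetical sizes: 4 heads of dim 16, a 32-dim latent, a 64-dim model output, 6 cached tokens.
n_h, d_h, d_c, d, T = 4, 16, 32, 64, 6

rng = np.random.default_rng(3)
W_UV = rng.standard_normal((n_h, d_h, d_c))    # per-head value up-projections
W_O = rng.standard_normal((d, n_h * d_h))      # output projection over concatenated heads
c_cache = rng.standard_normal((T, d_c))        # cached latents c_j^KV
attn = rng.random((n_h, T))
attn /= attn.sum(axis=1, keepdims=True)        # per-head weights that sum to 1, like a softmax

# Naive: materialize every head's values, mix them, concatenate heads, project.
V = np.einsum("hkc,tc->htk", W_UV, c_cache)    # (n_h, T, d_h)
heads = np.einsum("ht,htk->hk", attn, V)       # (n_h, d_h)
out_naive = W_O @ heads.reshape(-1)            # (d,)

# Absorbed: precompute W^O_i W^UV_i per head, then apply it to the attention-weighted latents.
W_O_heads = W_O.reshape(d, n_h, d_h)                      # split W^O into per-head column blocks
W_Oprime = np.einsum("xhk,hkc->hxc", W_O_heads, W_UV)     # (n_h, d, d_c), computed once
mixed_latents = attn @ c_cache                            # (n_h, d_c): no values materialized
out_absorbed = np.einsum("hxc,hc->x", W_Oprime, mixed_latents)

print(np.allclose(out_naive, out_absorbed))    # True
```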

absorption_trick.py: step-by-step inference walkthrough

Step 1. Prefill: compress all past tokens
For every token j in the context, compute and store the 512-dim latent c_j^KV = W^DKV · h_j. Also store the 64-dim decoupled RoPE key k_j^R (explained in §04).
c_cache[j] = W_DKV @ h[j]  # 512 floats

Step 2. Decode: compress the new query
For the new token t, compute the 1536-dim query latent c_t^Q = W^DQ · h_t, which is also much smaller than the full 16,384-dim query space.
c_q = W_DQ @ h_t  # 1536 floats

Step 3. Compute attention scores (absorbed)
Use the pre-fused weight W^Q' = (W^UQ)^T · W^UK to score directly against the cached c_j^KV, with no key materialization. The decoupled RoPE term from §04 is added to these content scores.
scores = c_q @ W_Qprime @ c_cache.T  # 1536 → 512 → one score per cached token

Step 4. Weighted sum of values (absorbed)
Softmax the scores, take the weighted sum of the c_j^KV vectors, then apply the fused W^O' = W^O · W^UV to get the final output. Values are never materialized as full tensors.
attn_weights = softmax(scores)  # one weight per cached token (1/√(d_h + d_h^R) scaling omitted)
out = W_Oprime @ (attn_weights @ c_cache)

Step 5. Append new latent to cache
Store c_t^KV and k_t^R for this token. Total storage added: 512 + 64 = 576 floats, about 1.1 KB per layer in bfloat16. Because causal attention lets token t attend to itself, this latent must be in the cache before the step 3 scores are computed.
c_cache = np.vstack([c_cache, W_DKV @ h_t])  # 512 floats; k_t^R adds the other 64
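The five steps can be strung together into one runnable decode step for a single head. This sketch assumes toy sizes, random weights, and no RoPE term (§04 adds that back). Following standard causal attention, the new latent is appended before scoring so that token t also attends to itself; the last lines recompute the same output with materialized keys and values as a reference.

```python
import numpy as np

# Hypothetical toy sizes for ONE attention head; weights are random stand-ins.
d, d_cq, d_c, d_h = 64, 48, 32, 16
rng = np.random.default_rng(6)
W_DKV = rng.standard_normal((d_c, d)) / d ** 0.5
W_DQ = rng.standard_normal((d_cq, d)) / d ** 0.5
W_UQ = rng.standard_normal((d_h, d_cq)) / d_cq ** 0.5
W_UK = rng.standard_normal((d_h, d_c)) / d_c ** 0.5
W_UV = rng.standard_normal((d_h, d_c)) / d_c ** 0.5
W_O = rng.standard_normal((d, d_h)) / d_h ** 0.5   # this head's block of the output projection

# Fused weights, computed once at load time.
W_Qprime = W_UQ.T @ W_UK    # (d_cq, d_c)
W_Oprime = W_O @ W_UV       # (d, d_c)


def softmax(x):
    e = np.exp(x - x.max())  # subtract the max for numerical stability
    return e / e.sum()


# Step 1: prefill, one cached latent per context token.
H = rng.standard_normal((7, d))
c_cache = H @ W_DKV.T                            # (7, d_c)

# Step 5, done first: append the new token's latent so it can attend to itself.
h_t = rng.standard_normal(d)
c_cache = np.vstack([c_cache, W_DKV @ h_t])      # (8, d_c)

# Steps 2-4: query latent, absorbed scores, absorbed output.
c_q = W_DQ @ h_t                                 # step 2
scores = c_cache @ (c_q @ W_Qprime)              # step 3: no keys materialized
attn = softmax(scores / d_h ** 0.5)
out = W_Oprime @ (attn @ c_cache)                # step 4: no values materialized

# Reference: the same head with materialized keys and values.
H_all = np.vstack([H, h_t])
K = H_all @ W_DKV.T @ W_UK.T                     # (8, d_h)
V = H_all @ W_DKV.T @ W_UV.T                     # (8, d_h)
ref = W_O @ (softmax(K @ (W_UQ @ c_q) / d_h ** 0.5) @ V)
print(np.allclose(out, ref), c_cache.shape)      # True (8, 32)
```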
04 — Decoupled RoPE

Why Positional Encoding Breaks the Absorption — and the Fix

There's one critical obstacle to the absorption trick: Rotary Position Embeddings (RoPE). RoPE works by rotating the key and query vectors by an angle that depends on position. For standard attention:

RoPE applied to keys, where R_j is the rotation matrix for position j:
$$k_j = \mathrm{RoPE}\left(W^{K} h_j,\ j\right) = R_j\, W^{K} h_j$$

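If you try to apply RoPE to the content key k_j^C = W^UK · c_j^KV (and, as RoPE requires, to the query), position-dependent rotations end up between the query and W^UK in the score computation. Writing R_m for the RoPE rotation matrix at position m:
$$\text{score} = \left(R_t\, q_t\right)^{\top} R_j\, W^{UK} c_j^{KV} = q_t^{\top} R_{j-t}\, W^{UK} c_j^{KV}$$
R_{j−t} depends on the positions of both tokens, so W^UK can no longer be absorbed into W^UQ.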
The problem: the whole point of absorption is computing (W^UQ)^T · W^UK once at load time. With a position-dependent rotation between them, the fused weight changes with the distance between the query and key positions, so you would need a separate W^Q' for every relative position up to the context length. That is impractical to precompute and store.
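To see the problem concretely, the sketch below builds explicit RoPE rotation matrices and forms the matrix that would have to sit between c_t^Q and c_j^KV. It assumes consecutive-pair rotations with base 10000, as in the original RoPE formulation (an illustrative choice, not DeepSeek-V2's exact kernel), plus hypothetical toy sizes and random weights.

```python
import numpy as np


def rope_matrix(pos, dim, base=10000.0):
    """Explicit RoPE rotation R_pos: a 2x2 rotation on each consecutive pair of dims.

    Consecutive pairing and base=10000 follow the original RoPE formulation;
    this is an illustrative assumption, not DeepSeek-V2's exact kernel.
    """
    R = np.zeros((dim, dim))
    for i in range(dim // 2):
        theta = pos * base ** (-2 * i / dim)
        c, s = np.cos(theta), np.sin(theta)
        R[2 * i:2 * i + 2, 2 * i:2 * i + 2] = [[c, -s], [s, c]]
    return R


d_cq, d_c, d_h = 48, 32, 16          # hypothetical toy sizes for one head
rng = np.random.default_rng(4)
W_UQ = rng.standard_normal((d_h, d_cq))
W_UK = rng.standard_normal((d_h, d_c))


def fused(t, j):
    """Matrix between c_t^Q and c_j^KV if RoPE were applied to q_t and k_j^C."""
    return W_UQ.T @ rope_matrix(t, d_h).T @ rope_matrix(j, d_h) @ W_UK


print(np.allclose(fused(5, 2), fused(9, 2)))   # False: offsets -3 and -7 differ
print(np.allclose(fused(5, 2), fused(8, 5)))   # True: both have offset j - t = -3
```

The fused matrix depends on the relative offset j − t, so no single precomputed W^Q' can serve every query/key pair.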

The fix: separate position from content

MLA decouples RoPE from the compressed keys entirely. Each attention head uses a two-part query and key:

  • Content part: derived from c^KV via W^UK (no RoPE, absorbable)
  • Position part: a separate small key k^R computed directly from h_t via W^KR and rotated with RoPE
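Decoupled RoPE
$$q_{t,i} = \left[q_{t,i}^{C};\ q_{t,i}^{R}\right], \qquad k_{t,i} = \left[k_{t,i}^{C};\ k_t^{R}\right]$$
Each query head concatenates its content part with its position part; each key concatenates its content part with a position key shared across heads.

$$q_{t,i}^{R} = \mathrm{RoPE}\left(W^{QR}_{i}\, c_t^{Q}\right) \qquad \text{64 dims per head, RoPE-encoded}$$
$$k_t^{R} = \mathrm{RoPE}\left(W^{KR} h_t\right) \qquad \text{64 dims, shared across all heads, cached}$$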

W^UK is still absorbed into the query side, because the content key carries no RoPE and therefore no position coupling. The position key k^R is cached separately and costs only 64 floats. The final dot product spans both parts:

$$\text{score}(t,j,i) = q_{t,i}^{\top} k_{j,i} = \left(q_{t,i}^{C}\right)^{\top} k_{j,i}^{C} + \left(q_{t,i}^{R}\right)^{\top} k_j^{R}$$
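A sketch of the decoupled score for one head. It assumes hypothetical toy sizes, random weights, and consecutive-pair RoPE with base 10000 (the original RoPE formulation, not necessarily DeepSeek-V2's exact kernel); the 1/√(d_h + d_h^R) softmax scaling is omitted, and k_j^R is recomputed inside `score` for brevity where a real model would read it from the cache. The content term is position-free, and the RoPE term depends only on the offset j − t.

```python
import numpy as np


def rope(x, pos, base=10000.0):
    """Rotate each consecutive pair of x by pos * theta_i (original RoPE formulation)."""
    dim = x.shape[-1]
    theta = pos * base ** (-np.arange(0, dim, 2) / dim)   # one angle per pair
    cos, sin = np.cos(theta), np.sin(theta)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out


# Hypothetical toy sizes: query latent, KV latent, content head dim, RoPE dim, hidden dim.
d_cq, d_c, d_h, d_R, d = 48, 32, 16, 8, 64
rng = np.random.default_rng(5)
W_UQ = rng.standard_normal((d_h, d_cq))   # content query up-projection (this head)
W_UK = rng.standard_normal((d_h, d_c))    # content key up-projection (this head)
W_QR = rng.standard_normal((d_R, d_cq))   # position query projection (this head)
W_KR = rng.standard_normal((d_R, d))      # position key projection (shared by all heads)
W_Qprime = W_UQ.T @ W_UK                  # absorbed once: contains no position

c_q = rng.standard_normal(d_cq)   # c_t^Q for the current token
c_kv = rng.standard_normal(d_c)   # cached c_j^KV of a past token
h_j = rng.standard_normal(d)      # that token's hidden state, used only to form k_j^R


def score(t, j):
    content = c_q @ W_Qprime @ c_kv                        # (q^C)^T k^C via absorption
    position = rope(W_QR @ c_q, t) @ rope(W_KR @ h_j, j)   # (q^R)^T k^R with RoPE
    return content + position


print(np.isclose(score(10, 4), score(110, 104)))   # True: only j - t matters
print(np.isclose(score(10, 4), score(10, 9)))      # False: different offset
```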
decoupled_rope.svg (architecture diagram): content vs. position key paths. Content path (no RoPE, absorbable): h_t (5120-d) → W^DKV (5120→512) → c^KV (512-d, cached) → scored through the absorbed W^Q' = (W^UQ)^T · W^UK. Position path (RoPE, tiny 64-dim cache): h_t → W^KR (5120→64) → RoPE → k^R (cached) → dot product q^R · k^R. The two terms are summed into score(t, j), the softmax input.

The position key k_t^R is shared across all 128 attention heads, so it needs no per-head copy; this is what keeps its cache cost low. Total cache per token per layer: 512 (content latent) + 64 (position key) = 576 floats.
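The saving from sharing k^R is easy to quantify. This sketch compares MLA's shared position key with a hypothetical variant that caches a separate 64-dim k^R per head, using DeepSeek-V2's sizes from the text.

```python
n_h, d_c, d_R = 128, 512, 64   # DeepSeek-V2 values from the text

shared_rope = d_c + d_R            # one k^R shared by all heads (what MLA does)
per_head_rope = d_c + n_h * d_R    # hypothetical alternative: a separate k^R per head

print(shared_rope)      # 576
print(per_head_rope)    # 8704
print(f"{shared_rope * 2} bytes per token per layer in bfloat16")   # 1152 bytes per token per layer in bfloat16
```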

05 — Cache Math

MHA vs GQA vs MQA vs MLA: The Numbers

Let's make the savings concrete using DeepSeek-V2's actual hyperparameters: n_h = 128 heads, d_h = 128 dims/head, 60 layers. All values are floats per token per layer; multiply by 2 for bfloat16 bytes.

| Method | Cached floats / token / layer | Formula | vs MHA |
|---|---|---|---|
| MHA | 32,768 | 2 × n_h × d_h = 2 × 128 × 128 | baseline, too large |
| GQA (g=8) | 2,048 | 2 × 8 × d_h = 2 × 8 × 128 | 16× smaller, quality↓ |
| MQA | 256 | 2 × 1 × d_h = 2 × 1 × 128 | 128× smaller, quality↓↓ |
| MLA | 576 | d_c + d_h^R = 512 + 64 | 57× smaller, quality↑ vs MHA |

The key surprise: MLA's 576-float cache is larger than MQA's 256 floats (it is equivalent to GQA with only 2.25 groups), yet in DeepSeek-V2's ablations MLA outperforms full MHA. Why? MQA forces every head to share one 128-dim key and one 128-dim value, whereas MLA's 512-dim latent is a learned summary from which each of the 128 heads reconstructs its own distinct key and value.

At full 128K context

For a single batch entry across 60 layers at 128K tokens (bfloat16, 2 bytes/float):

| Method | Cache / layer / token | Total at 128K ctx | vs MHA |
|---|---|---|---|
| MHA | 32,768 floats | ~503 GB | baseline |
| GQA (g=8) | 2,048 floats | ~31 GB | 16× |
| MQA | 256 floats | ~3.9 GB | 128× |
| MLA | 576 floats | ~8.8 GB | 57× |
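The per-token formulas and the 128K totals can be reproduced in a few lines, assuming 128K = 128,000 tokens and 2 bytes per bfloat16 value. The variable names are illustrative.

```python
# Cached floats per token per layer, from the formulas listed above.
n_h, d_h, g, d_c, d_R = 128, 128, 8, 512, 64
LAYERS, TOKENS, BYTES = 60, 128_000, 2   # DeepSeek-V2 depth, "128K" as 128,000, bfloat16

floats = {
    "MHA": 2 * n_h * d_h,
    "GQA (g=8)": 2 * g * d_h,
    "MQA": 2 * 1 * d_h,
    "MLA": d_c + d_R,
}
for name, f in floats.items():
    total_gb = LAYERS * TOKENS * f * BYTES / 1e9
    print(f"{name:10s} {f:6,d} floats  {total_gb:6.1f} GB  {floats['MHA'] / f:5.1f}x")
```

The MLA ratio comes out at about 56.9×, which the table rounds to 57×.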

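DeepSeek-V2 reports a 93.3% KV cache reduction relative to DeepSeek 67B. That baseline is not a plain-MHA model: DeepSeek 67B uses GQA with 8 KV heads across 95 layers, i.e. 2 × 8 × 128 × 95 = 194,560 cached elements per token, versus 576 × 60 = 34,560 for DeepSeek-V2 (about 82% fewer). The deployed DeepSeek-V2 also quantizes its KV cache to about 6 bits per element on average; combining the two (34,560 × 6 bits vs. 194,560 × 16 bits) is consistent with the reported 93.3%.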

Throughput impact

throughput_comparison.py: decoding speed on 8× H800 GPUs
DeepSeek 67B (GQA): 8,696 tok/s
DeepSeek-V2 (MLA): 50,083 tok/s

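That is 5.76× faster generation despite 3.5× more total parameters. Two effects drive it: the much smaller KV cache lets the same GPUs serve far larger batches, and the MoE design activates only 21B parameters per token.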

06 — Results

Better Than MHA at a Fraction of the Cost

The payoff: DeepSeek-V2 (236B params, 21B activated) is competitive with LLaMA-3 70B, Mixtral 8×22B, and Qwen1.5 72B across standard benchmarks, leading on MATH and staying within a few points of the best model elsewhere, while costing 42.5% less than DeepSeek 67B to train per trillion tokens. The training savings come mainly from the sparse MoE design; MLA's smaller KV cache pays off at inference, where it allows much larger batch sizes.

benchmark_results.py: base-model comparison charts for MMLU (5-shot) and MATH (4-shot); their values match the MMLU and MATH rows of the full benchmark table below.
The key insight from the results: MLA doesn't trade quality for efficiency. DeepSeek-V2 stays competitive with models that activate far more parameters per token while caching 98.2% less per layer than MHA (576 vs. 32,768 floats per token). These end-to-end scores also reflect MoE, data, and training differences; the direct MLA-versus-MHA comparison comes from the DeepSeek-V2 paper's ablations, where MLA outperforms MHA in otherwise comparable models.

Full benchmark table

| Benchmark | DeepSeek 67B | LLaMA-3 70B | Mixtral 8×22B | Qwen1.5 72B | DeepSeek-V2 |
|---|---|---|---|---|---|
| MMLU (5-shot) | 71.3 | 78.9 | 77.6 | 77.2 | 78.5 |
| BBH (3-shot) | 68.7 | 81.0 | 78.9 | 59.9 | 78.9 |
| DROP (F1, 3-shot) | 69.7 | 82.5 | 80.4 | 71.5 | 80.1 |
| GSM8K (8-shot) | 63.4 | 83.0 | 80.3 | 77.9 | 79.2 |
| MATH (4-shot) | 18.7 | 42.2 | 42.5 | 41.4 | 43.6 |
| HumanEval (0-shot) | 45.1 | 48.2 | 53.1 | 43.9 | 48.8 |
| CMMLU (5-shot) | 70.8 | 69.3 | 60.0 | 84.3 | 84.0 |

DeepSeek-V2 uses only 21B activated parameters per token (MoE) — its full 236B parameter count is not activated simultaneously. Direct parameter comparison with dense models understates the efficiency advantage.