Why standard generation is slow
Large language models generate text one token at a time. Every new token requires a full forward pass through all N layers. On a 2B-parameter model that is fast; on a 27B model it becomes the throughput bottleneck. Because each token depends on the previous one, generation is inherently serial.
Multi-Token Prediction (MTP), used here as a form of speculative decoding, relaxes this constraint by pairing the main model (the backbone) with a lightweight draft model. The draft model proposes several tokens cheaply; the backbone verifies them all in a single pass. Accepted tokens are kept; the first rejected token and everything after it are discarded, and the backbone corrects from that point. Net throughput gain: typically 2–4×.
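As a rough sanity check on that throughput range, here is a back-of-envelope sketch. It assumes each draft token is accepted independently with the same probability, which is a simplification and not something the source measures. The function name and the example numbers are illustrative only.

```python
def expected_tokens_per_pass(accept_prob: float, draft_len: int) -> float:
    """Expected tokens committed per backbone verification pass.

    Assumption (not from the source): every draft token is accepted
    independently with probability accept_prob. The chance that the first
    i drafts all survive is accept_prob ** i, and the backbone always
    contributes one token of its own (a correction or a bonus token),
    which is the i = 0 term of the sum.
    """
    if not 0.0 <= accept_prob <= 1.0:
        raise ValueError("accept_prob must be in [0, 1]")
    return sum(accept_prob ** i for i in range(draft_len + 1))


# Illustrative numbers only: 80% acceptance with 4 drafted tokens.
tokens_per_pass = expected_tokens_per_pass(0.8, 4)
```

Under these assumptions one backbone pass yields a little over three tokens on average. The real speedup is lower once the drafter's own cost is subtracted.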
Figure: Autoregressive (standard) decoding vs. Multi-Token Prediction.
The two-model system
The system consists of Gemma4ForCausalLM (the backbone) and Gemma4AssistantForCausalLM (the draft model). The backbone runs a standard forward pass and writes the K/V states from its last non-sharing layer of each attention type into shared_kv_states. The assistant reads those two tensors to draft tokens without recomputing attention from scratch.
Reusing the backbone's attention states
The key insight: the assistant does not need its own KV cache. Instead, it reads a shared dict of K/V tensors — one pair per attention type — that the backbone's last non-sharing layers wrote. The backbone's Gemma4TextModel.forward() populates shared_kv_states and returns it when return_shared_kv_states=True. The assistant receives this dict and every one of its layers looks up K/V by type: shared_kv_states[self.layer_type].
Since the assistant config forces num_kv_shared_layers = num_hidden_layers, every assistant layer is a KV-sharing layer; none of them allocate K or V projections at all. The assistant's layer count (e.g., 4) is completely independent of the backbone's (e.g., 26); what matters is that both models draw their layer_types from the same set of attention types, since those types are the keys of shared_kv_states.
Why this works
Normal attention needs K and V computed from the current sequence context. By borrowing the backbone's K/V, the assistant effectively sees the full context the backbone has built up — without paying to recompute it. The assistant only runs its own Q projection to attend over the borrowed K/V.
Which layers, exactly?
There is no 1:1 layer mapping. The assistant can have 4 layers while the backbone has 26 — that is by design, because shared_kv_states contains exactly two K/V pairs, keyed by attention type, not by layer index.
Here is the actual mechanism. The backbone designates its last num_kv_shared_layers layers as "KV-sharing layers" — they don't compute K or V at all. Among the preceding non-sharing layers, the code (modular_gemma4.py lines 978–983) identifies the last non-sharing layer of each attention type and marks it store_full_length_kv = True. When those layers run, they write into the dict:
- Last non-sharing full_attention layer → shared_kv_states["full_attention"] = (K, V)
- Last non-sharing sliding_attention layer → shared_kv_states["sliding_attention"] = (K, V)
That's it: two (K, V) pairs total, regardless of how many layers the backbone has. Every KV-sharing layer, both the backbone's own tail and every assistant layer, then does a single type-keyed lookup (line 1031):
key_states, value_states = shared_kv_states[self.layer_type]
Two assistant layers of the same type (e.g., two sliding_attention layers) consume the exact same K/V tensors. The backbone's last num_kv_shared_layers layers are structurally the same as the assistant's layers: both are just KV-sharing layers running over the same frozen context.
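The selection and lookup described above can be sketched in a few lines of plain Python. This is a toy illustration, not the code in modular_gemma4.py. The helper name find_kv_source_layers, the 8-layer backbone, and the string placeholders standing in for tensors are all assumptions made for the example.

```python
def find_kv_source_layers(layer_types, num_kv_shared_layers):
    """Map each attention type to the index of its last non-sharing layer."""
    first_shared = len(layer_types) - num_kv_shared_layers
    sources = {}
    for idx, layer_type in enumerate(layer_types[:first_shared]):
        sources[layer_type] = idx  # later layers overwrite earlier ones
    return sources


# Hypothetical 8-layer backbone whose last 2 layers are KV-sharing layers.
backbone_layer_types = (["sliding_attention"] * 3 + ["full_attention"]) * 2
sources = find_kv_source_layers(backbone_layer_types, num_kv_shared_layers=2)

# The storing layers write one (K, V) pair per attention type.
# Strings stand in for the real tensors.
shared_kv_states = {
    layer_type: (f"K from backbone layer {idx}", f"V from backbone layer {idx}")
    for layer_type, idx in sources.items()
}

# Hypothetical 4-layer assistant: every layer does a type-keyed lookup.
assistant_layer_types = ["sliding_attention"] * 3 + ["full_attention"]
assistant_reads = [shared_kv_states[t] for t in assistant_layer_types]
```

In this toy setup the three sliding assistant layers all receive the same pair object, which is the sharing behaviour the paragraph describes.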
Q-only cross-attention
Structurally this is identical to encoder–decoder cross-attention: Q comes from one stream (the assistant's hidden state), K and V come from another (the backbone's frozen context). The backbone is the "encoder" that has already processed the full prompt; the assistant is a lightweight "decoder" that attends into that encoded memory to draft new tokens.
The FLOP saving is direct: in standard multi-head self-attention, where K and V have as many heads as Q, the K and V projections together account for ⅔ of the QKV projection cost (the share is smaller under grouped-query attention, where K and V have fewer heads). Removing them across all of the assistant's layers, plus eliminating the assistant's own KV memory allocation, makes the drafter dramatically cheaper per forward pass than even a naive small model with its own attention would be.
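The projection-cost share can be checked with simple arithmetic. The sketch below counts multiply-accumulates per token for the Q, K, and V projections. The head counts and head dimension are illustrative assumptions, not Gemma 4's actual configuration.

```python
def kv_share_of_qkv(d_model, n_q_heads, n_kv_heads, head_dim):
    """Fraction of QKV projection cost spent on the K and V projections.

    Cost is counted as multiply-accumulates per token: each projection maps
    d_model inputs to (heads * head_dim) outputs.
    """
    q_cost = d_model * n_q_heads * head_dim
    k_cost = d_model * n_kv_heads * head_dim
    v_cost = d_model * n_kv_heads * head_dim
    return (k_cost + v_cost) / (q_cost + k_cost + v_cost)


# Multi-head attention: K and V have as many heads as Q.
mha_share = kv_share_of_qkv(d_model=1536, n_q_heads=8, n_kv_heads=8, head_dim=256)

# Grouped-query attention (hypothetical 8 query heads, 2 K/V heads).
gqa_share = kv_share_of_qkv(d_model=1536, n_q_heads=8, n_kv_heads=2, head_dim=256)
```

With equal head counts the share is two thirds; with a quarter as many K/V heads it drops to one third. The saving from a Q-only drafter therefore depends on the attention variant.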
Contrast: EAGLE-3, DeepSeek MTP, and other drafters
The designs in the table below represent the main bets on how a draft model should access backbone context. The differences run deeper than engineering choices: they reflect fundamentally different views on where the useful signal lives.
| Aspect | EAGLE-3 | DeepSeek MTP | DFlash | P-EAGLE | DART | Gemma 4 MTP |
|---|---|---|---|---|---|---|
| Backbone → drafter | Hidden states from selected backbone layers injected at drafter input | Shared token embedding + LM head only; last hidden state passed between sequential modules | Hidden states from 5 uniformly sampled backbone layers; projected through new W_K, W_V weights into fresh K/V vectors — these K/V tensors did not exist in the backbone | Same 3-layer FC-fusion as EAGLE-3 (layers at indices 2, L/2, L−1); concatenated and projected to a single vector g — identical access pattern to EAGLE-3, just used for parallel prediction | Same 3-layer FC-fusion as EAGLE-3 (low/mid/top → g) plus shifted token embedding e of the next predicted token; both concatenated into z = [g; e] before the draft layer | Backbone's own K/V cache (computed during backbone's forward pass) borrowed directly — no re-projection, zero extra compute |
| K/V source | Drafter computes its own K/V from its hidden state (full self-attention) | Each MTP module computes its own K/V (full self-attention within the module) | New K/V created by projection — hidden states passed through fresh W_K, W_V weights at every draft layer; K/V are a transformed derivative of backbone features, not the backbone's own cache | Drafter computes its own K/V (full self-attention, same as EAGLE-3) — no sharing with backbone K/V | Drafter computes its own K/V (single decoder layer, full self-attention) — same as EAGLE-3; no backbone K/V sharing | Backbone's pre-computed K/V reused as-is — no W_K or W_V in the drafter; the backbone already ran K = h·W_K and V = h·W_V for its own attention, and those exact tensors are shared directly |
| Attention mask | Causal: drafter generates one token at a time | Causal within each sequential module | Bidirectional within each draft block; all masked positions in a block attend to each other and to c | Causal with depth-ordering constraints: all K positions processed together but each only attends to earlier depth positions; no future-draft leakage | Strictly causal, over both prefix and within the masked block; MASK tokens cannot attend forward; no bidirectional refinement | Bidirectional mask utilities, but each draft step has q_len = 1, so the mask reduces to full attention over the backbone's K/V (flipped window for sliding layers); no draft-to-draft attention |
| Draft structure | Sequential autoregressive: one drafter pass per draft token (K passes for K tokens) | Sequential depth-wise: module 1→t+1, module 2→t+2, module 3→t+3 | Block diffusion: all tokens in a block decoded simultaneously in one forward pass; no intra-block sequential dependency | Parallel single-pass: all K draft tokens generated in one forward pass (vs K sequential passes in EAGLE-3); 4-layer drafter and a learnable shared hidden state replace per-step autoregressive context | Single forward pass with MASK tokens: prefix [z₁:ₙ] + (d−1) MASK tokens fed to 1-layer decoder; yields d logits simultaneously; no iterative denoising; N-gram tree pruning builds the final draft tree | Sequential autoregressive: one Q-only drafter pass per draft token (q_len = 1), each attending over the backbone's frozen K/V |
| Vocab prediction | Full LM head (tied with backbone) | Full LM head (shared output projection) | Full LM head | Full LM head (tied with backbone) | Full LM head; annealed KL divergence training (position-decaying weights γ=0.6 suppress noisy distant targets) | Centroid-based: 32 of 2048 clusters → 4096 tokens scored (98.4% savings) |
| Core bet | Backbone features are a rich enough signal at the input level | Depth-wise decomposition: each extra token is one extra transformer layer of compute | Injecting backbone context into every draft layer's K/V keeps the signal strong at depth; block diffusion decouples draft quality from latency | RoPE encodes absolute position via attention alone — a learnable shared hidden state suffices for all K positions; 4-layer depth substitutes for lost autoregressive signal; one forward pass replaces K sequential ones | 3-layer FC-fused features + shifted token embedding give the draft model the same signal as EAGLE-3; MASK tokens let a single causal layer predict d future positions in one shot; N-gram continuity prevents semantically incoherent draft trees | The backbone's K/V cache is the context — use it directly, skip all re-encoding |
Inside one assistant transformer layer
The assistant's transformer is structurally almost identical to the backbone's — same RMSNorm, same GatedMLP (SwiGLU), same residual connections — with one surgical difference: the attention sub-layer has no K or V projection matrices. When each layer runs its attention, it looks up (K, V) by its own layer type: shared_kv_states[self.layer_type].
Below is the full data flow: the pipeline overview at the top shows where the assistant sits end-to-end, and the expanded view shows exactly what happens inside one transformer layer.
The draft sequence length L (3–7 tokens) is far shorter than the backbone's full context S. Because K and V span the full S-length backbone context, each draft token's Q vector attends over the entire prefix in one shot — without the assistant ever having read that prefix.
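To make the shapes concrete, here is a minimal pure-Python sketch of Q-only attention: L draft queries attend over S borrowed key/value rows, and no K or V is computed from the draft tokens. The function name and the toy numbers are assumptions for illustration; real implementations work on batched, multi-head tensors.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def q_only_attention(queries, keys, values):
    """queries: L x d, keys: S x d, values: S x d_v. Returns L x d_v."""
    d = len(keys[0])
    outputs = []
    for q in queries:
        # One score per borrowed key: the query sees the whole S-length prefix.
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        weights = softmax(scores)
        outputs.append(
            [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]
        )
    return outputs


# Toy data: S = 5 borrowed rows, L = 2 draft queries.
backbone_keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [1.0, -1.0]]
backbone_values = [[float(i)] for i in range(5)]
draft_queries = [[0.0, 0.0], [2.0, 0.0]]
outputs = q_only_attention(draft_queries, backbone_keys, backbone_values)
```

The all-zero query scores every key equally, so its output is simply the mean of the borrowed values. The second query weights the keys it aligns with more heavily.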
Bridging two hidden dimensions
The backbone runs at d_model = 1536; the assistant runs at a smaller hidden_size (e.g. 768). Two projection layers bridge the gap.
The pre_projection receives a concatenation of the target model's token embedding (from embed_tokens) and the target model's last-layer hidden state, both taken at the same last-validated position, and compresses the resulting 3072-dim vector to the assistant's smaller working dimension. The post_projection expands back to 1536 for one specific reason: the backbone's lm_head is Linear(1536 → vocab_size) — it is the verification head during speculative decoding, and it only accepts inputs in the backbone's dimension. The assistant's last_hidden_state (after post_projection) is fed directly into that same lm_head to compute the reference logits for acceptance/rejection checking. Without this projection, the shapes would be incompatible and the backbone could not re-use its own output head to verify drafts.
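The dimension bookkeeping can be traced with a small sketch. Only the sizes 1536, 768, and 3072 come from the text above. The constant weights, the linear helper, and the omission of the assistant's transformer layers are simplifying assumptions.

```python
def linear(x, weight):
    """Apply y = W x, with the weight stored as out_dim rows of in_dim columns."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def make_weight(out_dim, in_dim, fill):
    """Constant-filled placeholder weight; real weights are learned."""
    return [[fill] * in_dim for _ in range(out_dim)]


BACKBONE_DIM = 1536
ASSISTANT_DIM = 768  # example value from the text

# Both inputs are taken at the same last-validated position.
token_embedding = [0.01] * BACKBONE_DIM
last_hidden_state = [0.02] * BACKBONE_DIM
concatenated = token_embedding + last_hidden_state  # 3072 values

pre_projection = make_weight(ASSISTANT_DIM, 2 * BACKBONE_DIM, 0.001)
post_projection = make_weight(BACKBONE_DIM, ASSISTANT_DIM, 0.001)

drafter_input = linear(concatenated, pre_projection)  # 3072 -> 768
# The assistant's transformer layers would run here at width 768.
drafter_output = linear(drafter_input, post_projection)  # 768 -> 1536
```

The final vector is back at the backbone's width, which is what lets it be passed to a head defined as Linear(1536 → vocab_size).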
Centroid-based token prediction
Computing logits over the full ~262K-token vocabulary requires a Linear(768, 262144) matrix multiply, which is expensive for a small draft model. Gemma4AssistantMaskedEmbedder avoids this by clustering vocabulary tokens into centroids and only scoring the tokens in the top-K clusters.
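The general idea can be sketched with a toy two-stage scorer: score the centroids first, then score only the tokens that belong to the best few clusters. This is an assumed mechanism for illustration and is not taken from Gemma4AssistantMaskedEmbedder. The function names, the 2-dimensional vectors, and the 8-token vocabulary are all hypothetical.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def centroid_topk_logits(hidden, centroid_vectors, cluster_token_ids, token_vectors, top_k):
    """Return {token_id: logit} for tokens in the top_k best-scoring clusters."""
    centroid_scores = [dot(hidden, c) for c in centroid_vectors]
    ranked = sorted(range(len(centroid_scores)), key=lambda i: centroid_scores[i], reverse=True)
    logits = {}
    for cluster in ranked[:top_k]:
        for token_id in cluster_token_ids[cluster]:
            logits[token_id] = dot(hidden, token_vectors[token_id])
    return logits


def scored_fraction(vocab_size, num_centroids, top_k):
    """Share of the vocabulary that is actually scored (equal-sized clusters assumed)."""
    return top_k * (vocab_size // num_centroids) / vocab_size


# Toy vocabulary: 4 clusters of 2 tokens each, 2-d vectors.
centroid_vectors = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
cluster_token_ids = [[0, 1], [2, 3], [4, 5], [6, 7]]
token_vectors = [
    [1.0, 0.1], [1.0, -0.1], [0.1, 1.0], [-0.1, 1.0],
    [-1.0, 0.1], [-1.0, -0.1], [0.1, -1.0], [-0.1, -1.0],
]
logits = centroid_topk_logits([0.9, 0.5], centroid_vectors, cluster_token_ids, token_vectors, top_k=2)
best_token = max(logits, key=logits.get)
```

Only four of the eight toy tokens receive a logit. Applying scored_fraction to 2048 centroids, top_k of 32, and a 262,144-token vocabulary gives 1.5625% of the vocabulary scored.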
With num_centroids=2048 and top_k=32, the assistant evaluates 32 × 128 = 4,096 tokens instead of the full vocabulary of ~262K (2048 clusters × 128 tokens each), a 98.4% reduction in logit computation.
Why the assistant uses bidirectional masks
The assistant does use bidirectional attention masks — this is confirmed by the source imports and calls (create_bidirectional_mask, create_bidirectional_sliding_window_mask). But this raises an obvious question: if there is no causal mask, how can the assistant generate autoregressively?
The answer is in the docstring: "no difference for the edge case of q_len == 1 as it acts as full attention no matter what." During each draft step, the assistant processes a single token (q_len = 1). With one query position, bidirectional vs causal is mathematically identical: there are no other draft positions to mask. Drafting remains autoregressive in the usual sense: each step samples one token, and that token becomes the input to the next step, while the backbone's shared K/V stays fixed until the next verification pass.
For full attention this non-difference is easy to see in the matrix: the new draft token sits at the very last row, and a causal lower-triangle already has all 1s there — identical to an all-1s bidirectional row. No special logic needed.
Sliding-window attention is different. SWA uses a fixed-width window (W tokens) rather than the full context, so direction matters. The backbone ran prefill with a past-looking window — each token attended to the previous W tokens. Its stored K/V reflects that. The bidirectional SWA utility, however, anchors the window at the start of the query sequence (future-looking by default): for q_len=1, the single Q gets a window opening at K/V slot 0, not at the recent end of the context. Those are opposite ends of the sequence — attending there would be meaningless.
The fix is a one-line tensor flip: swa_mask.flip(dims=(-1,)). Reversing the K/V axis turns [1,1,1,0,…,0] into [0,…,0,1,1,1] — shifting the window from the front of the K/V buffer to the back, exactly where the backbone's most recent SWA K/V lives.
Importantly, the flip changes nothing about the attention mechanism itself — softmax(QKᵀ/√d + mask)·V is identical before and after. It is purely a mask correction: the utility produced the wrong mask values (window at the wrong end), and the flip corrects them. You could equivalently have written a custom function that produced [0,…,0,1,1,1] directly, making the flip unnecessary. The bidirectional mask and the flip together are just the codebase's way of arriving at the mask tensor the drafter actually needs.
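The mask arithmetic is small enough to reproduce with plain lists. The sketch below models the behaviour as the text describes it (a window anchored at K/V slot 0, then reversed along the K/V axis). It does not call the real mask utilities, and all function names are illustrative.

```python
def front_anchored_window_row(kv_len, window):
    """Mask row for q_len = 1 with the window opening at K/V slot 0."""
    return [1 if slot < window else 0 for slot in range(kv_len)]


def flip_kv_axis(mask_row):
    """List equivalent of reversing the last (K/V) axis of the mask."""
    return mask_row[::-1]


def causal_last_row(kv_len):
    """Causal mask row for a query sitting at the final position."""
    query_pos = kv_len - 1
    return [1 if slot <= query_pos else 0 for slot in range(kv_len)]


def bidirectional_row(kv_len):
    """Bidirectional full-attention row: every slot is visible."""
    return [1] * kv_len


swa_row = front_anchored_window_row(kv_len=5, window=3)  # [1, 1, 1, 0, 0]
fixed_row = flip_kv_axis(swa_row)                        # [0, 0, 1, 1, 1]
```

For full attention, causal_last_row(5) and bidirectional_row(5) are the same list, which is why no correction is needed there.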
Figure: Full attention, 3 draft tokens × 5 backbone tokens (all backbone tokens are in the past).
Figure: SWA flip, all 3 draft steps; q_len=1 each, kv_len=5, W=3.
Figure: Target model SWA (self-attn) vs drafter SWA (cross-attn); W=3, kv_len=5.
How draft tokens stay distinct despite sharing K/V and weights
A natural question: if d₀, d₁, and d₂ all attend to the same backbone K/V with the same model weights, shouldn't they produce the same token? They don't, for two reasons that compound each other.
Different input embeddings. Each draft step is conditioned on the token sampled in the previous step. When generating d₁, the draft model's input is the embedding of d₀ — the token just sampled. Generating d₂ feeds d₁'s embedding. A different input produces a different Q, a different attention output, and a different prediction.
Different positional encoding (RoPE). Even if the token embedding were somehow identical across steps, the Q vector at position 5 is rotated differently from position 6, which is rotated differently from position 7. RoPE bakes absolute position directly into Q, making the query to the backbone K/V position-sensitive. The backbone K/V at position 2 looks geometrically "close" to Q at position 3 and much further to Q at position 6 — the attention pattern shifts at every step even with the same weights.
Concretely, each step computes:
d₀: Q = RoPE(embed(t₄) · W_Q, pos=5) → attend backbone K/V → sample d₀
d₁: Q = RoPE(embed(d₀) · W_Q, pos=6) → attend backbone K/V → sample d₁
d₂: Q = RoPE(embed(d₁) · W_Q, pos=7) → attend backbone K/V → sample d₂
Both the input token and the rotated position change at every step. The backbone K/V and the weight matrix W_Q are fixed; the query is not.
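The positional effect can be isolated with a small RoPE sketch that rotates the same projected vector at positions 5, 6, and 7. The interleaved pairing and the base of 10000 are common conventions assumed here; library implementations often use a half-split layout instead, and the 4-dimensional vectors are toy values.

```python
import math


def rope(vec, pos, base=10000.0):
    """Rotate consecutive (even, odd) pairs of vec by position-dependent angles."""
    d = len(vec)
    out = []
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)  # lower pairs rotate faster
        c, s = math.cos(theta), math.sin(theta)
        x, y = vec[i], vec[i + 1]
        out += [x * c - y * s, x * s + y * c]
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


# Stand-in for embed(token) · W_Q, held identical across steps on purpose.
projected_q = [1.0, 0.0, 0.5, -0.5]
queries = {pos: rope(projected_q, pos) for pos in (5, 6, 7)}

# A toy backbone key, rotated at its own position.
raw_key = [0.3, 0.7, -0.2, 0.1]
score_offset_3_a = dot(rope(projected_q, 5), rope(raw_key, 2))
score_offset_3_b = dot(rope(projected_q, 8), rope(raw_key, 5))
```

Even with an identical input, the three queries differ. The attention score depends only on the offset between query and key positions, so the two offset-3 scores match.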
Why it works at all. The backbone K/V encodes the full prompt. The draft model's task is narrow: given this context, predict the most likely token at position X. The positional encoding in Q carries the "what position am I at?" signal, and the input embedding carries the "what was just generated?" signal. Together they are usually enough for short continuations — the model is trained end-to-end on exactly this objective. The limitation shows up in the acceptance rate: without explicit draft-to-draft attention (unlike EAGLE), later tokens in the draft sequence are conditioned only on the backbone context plus their immediate predecessor, so quality degrades faster as draft length grows. This is why draft lengths are kept short (3–7 tokens).
The underlying design bet. Viewed this way, each draft token is nearly an independent prediction from backbone context, with only a shallow one-step dependency on its immediate predecessor. The backbone K/V spans hundreds or thousands of prompt tokens; the single preceding draft token adds only a marginal local signal on top of that dominant context. This is the architectural wager: for short continuations, the prompt alone is usually sufficient to predict the next few tokens correctly — the model rarely needs to know what earlier draft tokens were in order to get the later ones right. Most of the information that determines the next token was already encoded in the backbone's representation of the prompt. The draft tokens are closer to parallel lookups into that representation than to a serial reasoning chain. Keeping drafts short (3–7 tokens) is precisely the regime where this bet pays off: the inter-draft dependencies that are being ignored are weak enough that the acceptance rate stays high.
Putting it all together
At generation time, HuggingFace's model.generate() orchestrates the speculative decoding loop automatically when assistant_model is provided. The backbone and assistant alternate: assistant drafts a batch, backbone verifies, accepted tokens are committed, and the loop continues.
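After the first rejected token, all subsequent draft tokens are discarded and the backbone supplies the correct token at that position. In this example, 3 tokens are accepted and 1 is rejected, so a single backbone pass still commits 3 drafted tokens: roughly a 3× gain over pure autoregressive generation, before subtracting the drafter's own (much smaller) cost.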
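The draft, verify, and commit cycle can be sketched with toy stand-ins for both models. This is a greedy-acceptance illustration under stated assumptions, not HuggingFace's implementation. The functions backbone_next and drafter_next are made-up deterministic rules, and the single batched verification pass is simulated by repeated calls.

```python
def backbone_next(tokens):
    """Toy stand-in for the backbone's greedy next token."""
    return (tokens[-1] * 3 + 1) % 11


def drafter_next(tokens):
    """Toy drafter: matches the backbone unless the last token is a multiple of 4."""
    guess = backbone_next(tokens)
    return (guess + 1) % 11 if tokens[-1] % 4 == 0 else guess


def speculative_generate(prompt, max_new_tokens, draft_len=4):
    tokens = list(prompt)
    target_len = len(prompt) + max_new_tokens
    backbone_passes = 0
    while len(tokens) < target_len:
        # 1. The drafter proposes draft_len tokens, one step at a time.
        draft = []
        for _ in range(draft_len):
            draft.append(drafter_next(tokens + draft))
        # 2. The backbone checks every drafted position (one pass, simulated).
        backbone_passes += 1
        committed = []
        for token in draft:
            expected = backbone_next(tokens + committed)
            if token != expected:
                committed.append(expected)  # correction replaces the rejected draft
                break
            committed.append(token)
        else:
            committed.append(backbone_next(tokens + committed))  # bonus token
        tokens.extend(committed)
    return tokens[:target_len], backbone_passes


def autoregressive_generate(prompt, max_new_tokens):
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(backbone_next(tokens))
    return tokens


spec_tokens, passes = speculative_generate([1], max_new_tokens=20)
baseline_tokens = autoregressive_generate([1], max_new_tokens=20)
```

With greedy acceptance the output matches plain autoregressive decoding token for token. Only the number of backbone passes changes, and in this toy run 20 tokens take 8 passes.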
Usage: model.generate(inputs, assistant_model=assistant, max_new_tokens=200). No other changes are required; HuggingFace handles the entire speculative decoding loop.
Bonus: Could the MTP Drafter Be Trained Post-Pretraining?
EAGLE-3 and DFlash train their draft models post-pretraining on modest datasets (~100K–200K examples). Gemma 4's MTP drafter is trained jointly during pretraining, a much heavier commitment. A natural question: could the same post-training recipe be applied to Gemma 4's Q-only K/V-sharing architecture, and how close would the acceptance rate get?
In joint pretraining, the main model's W_K and W_V projections receive gradients from both the autoregressive loss and the drafter's prediction loss. They co-adapt: the main model learns to surface features the drafter can act on, and the drafter's centroid vocabulary clustering co-evolves with that K/V geometry. In a post-training setup, those projections are frozen, so the drafter inherits a representation space that was never optimized with it in mind.
Gemma 4's drafter uses ~2,048 learned centroids to cluster the 262K vocabulary into coarse groups before fine prediction. Trained jointly, these centroids carve the representation space in a way both models agree on. Trained post-hoc with frozen main-model weights, the centroids must rediscover a compatible clustering from scratch — harder, but learnable. A warm-start approach — initializing centroid embeddings via k-means over the main model's frozen output embedding matrix — would give the post-hoc drafter a geometrically aligned starting point and likely close most of the acceptance-rate gap.
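The warm-start idea can be illustrated with plain Lloyd's k-means over a toy embedding matrix. This is a hypothetical sketch of the proposal above, not a published recipe. The six 2-dimensional "output embeddings" and the first-k initialization are assumptions chosen to keep the example deterministic.

```python
def squared_distance(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans(points, k, iterations=10):
    """Plain Lloyd's algorithm; centroids start at the first k points."""
    centroids = [list(p) for p in points[:k]]
    assignment = [0] * len(points)
    for _ in range(iterations):
        # Assignment step: each point joins its nearest centroid.
        for i, point in enumerate(points):
            assignment[i] = min(range(k), key=lambda c: squared_distance(point, centroids[c]))
        # Update step: each centroid moves to the mean of its members.
        for c in range(k):
            members = [points[i] for i in range(len(points)) if assignment[i] == c]
            if members:
                centroids[c] = [sum(col) / len(members) for col in zip(*members)]
    return centroids, assignment


# Toy stand-in for rows of a frozen output embedding matrix.
output_embeddings = [
    [0.0, 0.0], [10.0, 10.0], [0.0, 1.0],
    [1.0, 0.0], [10.0, 11.0], [11.0, 10.0],
]
initial_centroids, token_to_cluster = kmeans(output_embeddings, k=2)
```

The resulting centroids and token-to-cluster map would serve as the starting point for the drafter's clustering, to be refined by training.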
EAGLE-3's post-training success rests on one key finding: the target model's hidden states are information-dense enough that a drafter can learn to predict well from them without co-training. Gemma 4's drafter accesses a structurally similar signal via shared K/V caches. The key difference is directness: EAGLE-3 feeds hidden states as explicit per-layer features, whereas Gemma 4's drafter reads K/V caches implicitly through attention. That indirection doesn't make post-training impossible; it just makes it harder for the drafter to extract the full signal from frozen representations.
Any acceptance-rate estimates for this scenario are hypothetical, extrapolated from EAGLE-3 and DFlash benchmarks; no Gemma 4 post-training experiment has been published. The warm-start centroid strategy is the highest-leverage variable: it aligns the vocabulary clustering with the existing representation geometry and would likely recover most of the joint-training advantage at a fraction of the training cost. Domain-specific post-training would be expected to score higher within its target domain but degrade on out-of-distribution inputs, making it a poor general-purpose choice.