
How Gemma 4's Multi-Token Prediction Works

Gemma 4 ships a compact assistant model that drafts multiple tokens per step — then the large backbone verifies them in a single parallel pass. This post tears apart the implementation: KV sharing, centroid-based vocabulary prediction, and bidirectional attention, with animated, code-annotated diagrams.

May 24, 2026 · ~17 min read · Source: transformers/models/gemma4_assistant
01 — The Problem

Why standard generation is slow

Large language models generate text one token at a time. Every new token requires a full forward pass through all N layers. On a 2B-parameter model that is fast; on a 27B model it becomes the throughput bottleneck. Because each token depends on the previous one, generation is inherently serial.

Multi-Token Prediction (MTP), used here as the drafter for speculative decoding, breaks this constraint. A lightweight draft model runs ahead of the main model and proposes several tokens; the backbone then verifies all of them in a single parallel forward pass. Accepted tokens are kept; the first rejected token and everything after it are discarded, and the backbone's own prediction takes over from that point. Typical throughput gains are 2–4×, depending on how often drafts are accepted.

Figure: generating "The capital of France is Paris ." (7 tokens).
  • Autoregressive (standard): 7 forward passes, baseline speed.
  • Multi-Token Prediction: 2–3 backbone passes, ~3× typical speedup.
💡 The draft model must be extremely fast — otherwise the overhead of drafting cancels the parallel-verification win. Gemma 4's assistant achieves this by reusing the backbone's KV cache and operating at a much smaller hidden dimension.
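To see why drafting cost matters so much, here is the standard back-of-envelope model from the speculative decoding literature, which assumes every draft token is accepted independently with the same probability alpha. This is a minimal sketch under that simplifying assumption: gamma (draft length) and c (cost of one draft step relative to one backbone pass) are illustrative parameters, not values taken from the Gemma 4 source.

```python
def expected_tokens_per_pass(alpha: float, gamma: int) -> float:
    """Expected tokens committed per backbone pass with gamma draft tokens.

    Assumes each draft token is accepted independently with probability alpha.
    The count includes the token the backbone itself contributes on every pass
    (its correction at the first rejection, or a bonus token if all pass).
    """
    if alpha == 1.0:
        return gamma + 1.0
    return (1.0 - alpha ** (gamma + 1)) / (1.0 - alpha)


def expected_speedup(alpha: float, gamma: int, c: float) -> float:
    """Walltime speedup over plain autoregressive decoding.

    One speculative iteration costs gamma draft steps (c each) plus one
    backbone pass, i.e. gamma * c + 1 in units of backbone passes.
    """
    return expected_tokens_per_pass(alpha, gamma) / (gamma * c + 1.0)


if __name__ == "__main__":
    for c in (0.05, 0.5):
        print(f"c={c}: speedup={expected_speedup(0.8, 4, c):.2f}x")
```

With alpha = 0.8 and four draft tokens, a drafter costing 5% of a backbone pass gives roughly 2.8×, while one costing half a backbone pass drops to roughly 1.1×.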
02 — Architecture

The two-model system

The system consists of Gemma4ForCausalLM (the backbone) and Gemma4AssistantForCausalLM (the draft model). The backbone runs a standard forward pass and writes the K/V states from its last non-sharing layer of each attention type into shared_kv_states. The assistant reads those two tensors to draft tokens without recomputing attention from scratch.


Figure (architecture overview):
  • Backbone, Gemma4ForCausalLM (d_model = 1536): ScaledWordEmbedding + Per-Layer Embeddings (PLE) → transformer layers (full_attn + sliding_window) → LM head, Linear(1536 → vocab_size). It exposes shared_kv_states, a dict mapping "full_attention" | "sliding_attention" to (key, value) tensors, which is passed to the assistant.
  • Draft model, Gemma4AssistantForCausalLM (d_model = 768): pre_projection, Linear(2×1536 → 768), over the concatenated backbone inputs → transformer layers that use shared_kv_states with bidirectional attention → post_projection, Linear(768 → 1536), back to the backbone dimension → MaskedEmbedder (centroid-based vocabulary prediction, if use_ordered_embeddings) → draft logits, which the backbone verifies during speculative decoding.
The assistant's transformer is much smaller than the backbone (fewer layers, smaller hidden dim). It is fast enough that drafting 3–4 tokens costs less than one backbone forward pass.
03 — KV Sharing

Reusing the backbone's attention states

The key insight: the assistant does not need its own KV cache. Instead, it reads a shared dict of K/V tensors — one pair per attention type — that the backbone's last non-sharing layers wrote. The backbone's Gemma4TextModel.forward() populates shared_kv_states and returns it when return_shared_kv_states=True. The assistant receives this dict and every one of its layers looks up K/V by type: shared_kv_states[self.layer_type].

Since the assistant config forces num_kv_shared_layers = num_hidden_layers, every assistant layer is a KV-sharing layer: none of them allocate K or V projections at all. The assistant's layer count (e.g., 4) is completely independent of the backbone's (e.g., 26); what matters is that every attention type the assistant uses has an entry in shared_kv_states, i.e. that the two models' layer_types patterns are compatible.

Figure (KV-sharing data flow):
  • Backbone: layers 0, 1, … compute their own K/V. The last non-sharing sliding_attn layer has store_full_length_kv = True and writes the SWA K/V; the last non-sharing full_attn layer has store_full_length_kv = True and writes the full-attention K/V. The tail layers (is_kv_shared_layer = True) reuse the same dict.
  • shared_kv_states: "sliding_attention" → (K, V) and "full_attention" → (K, V); exactly two K/V entries, each written by the last non-sharing layer of its type.
  • Assistant (e.g. 4 layers): layers 0 and 2 (sliding_attn) both read shared_kv_states["sliding_attention"], the same tensor; layers 1 and 3 (full_attn) both read shared_kv_states["full_attention"], the same tensor. The lookup is by self.layer_type; there is no layer-index correspondence between backbone and assistant.

Why this works

Normal attention needs K and V computed from the current sequence context. By borrowing the backbone's K/V, the assistant effectively sees the full context the backbone has built up — without paying to recompute it. The assistant only runs its own Q projection to attend over the borrowed K/V.

Which layers, exactly?

There is no 1:1 layer mapping. The assistant can have 4 layers while the backbone has 26 — that is by design, because shared_kv_states contains exactly two K/V pairs, keyed by attention type, not by layer index.

Here is the actual mechanism. The backbone designates its last num_kv_shared_layers layers as "KV-sharing layers" — they don't compute K or V at all. Among the preceding non-sharing layers, the code (modular_gemma4.py lines 978–983) identifies the last non-sharing layer of each attention type and marks it store_full_length_kv = True. When those layers run, they write into the dict:

  • Last non-sharing full_attention layer → shared_kv_states["full_attention"] = (K, V)
  • Last non-sharing sliding_attention layer → shared_kv_states["sliding_attention"] = (K, V)

That's it — two tensors total, regardless of how many layers the backbone has. Every KV-sharing layer, both the backbone's own tail and every assistant layer, then does a single type-keyed lookup (line 1031):

key_states, value_states = shared_kv_states[self.layer_type]

Two assistant layers of the same type (e.g., two sliding_attn layers) consume the exact same K/V tensor. The backbone's last num_kv_shared_layers layers play the same role as the assistant's layers: both are KV-sharing layers that project only Q and read the same frozen context from the dict.
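Both halves of the mechanism fit in a short sketch. The function, the dataclass, the toy 8-layer pattern and the string placeholders for tensors below are hypothetical (this is not code copied from modular_gemma4.py); the sketch only mirrors the rules described above: flag the last non-sharing layer of each attention type as the writer, then let every KV-sharing layer look up (K, V) by its own layer_type.

```python
from dataclasses import dataclass


def kv_storage_plan(layer_types: list, num_kv_shared_layers: int):
    """Return (first_kv_shared_layer_idx, store_flags) for a backbone.

    store_flags[i] is True only for the last non-sharing layer of each
    attention type, i.e. the layers that write into shared_kv_states.
    """
    first_shared = len(layer_types) - num_kv_shared_layers
    # Last index of each attention type inside the non-sharing region.
    last_idx = {t: i for i, t in enumerate(layer_types[:first_shared])}
    store_flags = [i < first_shared and last_idx[t] == i
                   for i, t in enumerate(layer_types)]
    return first_shared, store_flags


@dataclass
class SharedKVLayer:
    """Stand-in for a KV-sharing layer: it owns no K/V projections."""
    layer_type: str

    def fetch_kv(self, shared_kv_states: dict):
        # Type-keyed lookup, mirroring shared_kv_states[self.layer_type].
        key_states, value_states = shared_kv_states[self.layer_type]
        return key_states, value_states


# Toy 8-layer backbone whose last 2 layers are KV-sharing.
backbone_types = (["sliding_attention"] * 3 + ["full_attention"]) * 2
first_shared, flags = kv_storage_plan(backbone_types, num_kv_shared_layers=2)
writers = [i for i, f in enumerate(flags) if f]
print(first_shared, writers)  # 6 [3, 5]

# The two writer layers fill the dict; strings stand in for K/V tensors.
shared_kv_states = {backbone_types[i]: (f"K<layer {i}>", f"V<layer {i}>")
                    for i in writers}

# A 4-layer assistant, independent of the backbone's depth.
assistant = [SharedKVLayer(t) for t in ["sliding_attention", "full_attention"] * 2]
fetched = [layer.fetch_kv(shared_kv_states) for layer in assistant]
# Sliding layers get layer 5's K/V, full-attention layers get layer 3's.
print(fetched)
```

An assistant layer whose type had no entry in the dict would raise KeyError, which is the practical meaning of "compatible layer_types" here.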

Figure (two tensors, type-keyed lookup):
  • Backbone (e.g. 26 layers): layers 0 … k-3 run normal self-attention and compute their own K/V; layer k-2 (sliding_attn) and layer k-1 (full_attn) have store_full_length_kv = True; layers k … N-1, starting at first_kv_shared_layer_idx = N - num_kv_shared_layers, are KV-sharing and also read shared_kv_states[self.layer_type].
  • shared_kv_states holds exactly 2 entries: "sliding_attention": (K, V) of shape [B, W, n_heads, d_head] (W = window) and "full_attention": (K, V) of shape [B, S, n_heads, d_head] (S = full prefix).
  • Assistant (e.g. 4 layers): layer 0 (sliding_attn) has Q [B, L, n_h, d_h] against K/V [B, W, n_h, d_h] with L ≪ W; layer 1 (full_attn) has Q [B, L, n_h, d_h] against K/V [B, S, n_h, d_h] with L ≪ S; layer 2 uses the same K/V tensor as layer 0, and layer 3 the same as layer 1.
  • Source (modular_gemma4.py): line 981 sets store_full_length_kv = not is_kv_shared_layer and layer_idx == (last index of this type in the non-sharing region); line 1031: if is_kv_shared_layer: key_states, value_states = shared_kv_states[self.layer_type].
  • Assistant config: num_kv_shared_layers = num_hidden_layers, so all assistant layers are KV-sharing (no K/V of their own at any layer). The assistant's num_hidden_layers is independent of the backbone's; only the layer_type pattern must be compatible.

Q-only cross-attention

Structurally this is identical to encoder–decoder cross-attention: Q comes from one stream (the assistant's hidden state), K and V come from another (the backbone's frozen context). The backbone is the "encoder" that has already processed the full prompt; the assistant is a lightweight "decoder" that attends into that encoded memory to draft new tokens.

The FLOP saving is direct. In standard multi-head attention, where Q, K and V have the same number of heads, the K and V projections together account for ⅔ of the QKV projection cost; with grouped-query attention, where K/V use fewer heads than Q, their share is smaller. Removing them from every assistant layer, and skipping the assistant's own KV memory allocation, makes the drafter cheaper per forward pass than a naive small model of the same size that keeps its own K/V.
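The ⅔ figure can be checked with a line of arithmetic. This sketch assumes Q, K and V share the same head dimension, and the head counts in the loop are illustrative, not Gemma 4's actual configuration:

```python
def kv_projection_share(n_q_heads: int, n_kv_heads: int) -> float:
    """Fraction of QKV projection FLOPs spent on K and V.

    Each projection costs d_model * heads * d_head per token; with a shared
    d_head, d_model and d_head cancel and only the head counts matter.
    """
    q_cost = n_q_heads
    kv_cost = 2 * n_kv_heads  # K and V projections
    return kv_cost / (q_cost + kv_cost)


for n_q, n_kv in [(8, 8), (8, 4), (8, 1)]:
    print(f"{n_q} Q heads, {n_kv} KV heads: {kv_projection_share(n_q, n_kv):.3f}")
```

Multi-head attention gives 0.667, while 4 and 1 KV heads give 0.5 and 0.2. In every case the drafter additionally avoids projecting and caching K/V for the S prefix tokens.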

Contrast: EAGLE-3 and DeepSeek MTP

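These designs represent the main bets on how a draft model should access backbone context: injecting backbone hidden features, sharing only embeddings and the output head, re-projecting backbone features into fresh K/V, or reusing the backbone's own K/V. The differences run deeper than engineering choices; they reflect different views on where the useful signal lives.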

Figure (how each drafter accesses backbone context, EAGLE-3 vs Gemma 4 MTP):
  • EAGLE-3: low (l), mid (m) and top (h) backbone-layer features are concatenated into a 3k-wide vector and fused by an FC layer into g (features, not K/V). The draft transformer has its own W_Q, W_K, W_V and builds its own causal KV cache. Signal type: hidden features; own K/V: yes.
  • Gemma 4 MTP: the K/V written by the backbone's last SWA layer and last full-attention layer (and reused by its tail layers) are passed as-is, with no re-projection. The draft transformer computes only Q = RoPE(embed(dₙ₋₁)·W_Q) with q_len = 1, has no W_K or W_V, looks up K/V by layer_type, and samples dₙ, whose embedding feeds the next step. Signal type: backbone K/V; own K/V: no.
Comparison of drafter designs (EAGLE-3, DeepSeek MTP, DFlash, P-EAGLE, DART, Gemma 4 MTP):

Backbone → drafter
  • EAGLE-3: hidden states from selected backbone layers, injected at the drafter input.
  • DeepSeek MTP: shared token embedding and LM head; the last hidden state is passed between sequential modules.
  • DFlash: hidden states from 5 uniformly sampled backbone layers, projected through new W_K, W_V weights into fresh K/V vectors that did not exist in the backbone.
  • P-EAGLE: the same 3-layer FC fusion as EAGLE-3 (layers at indices 2, L/2 and L−1), concatenated and projected to a single vector g; the access pattern is identical to EAGLE-3, only used for parallel prediction.
  • DART: the same 3-layer FC fusion as EAGLE-3 (low/mid/top → g) plus the shifted token embedding e of the next predicted token, concatenated into z = [g; e] before the draft layer.
  • Gemma 4 MTP: the backbone's own K/V cache, computed during the backbone's forward pass, borrowed directly with no re-projection and no extra K/V compute (the pre_projection input additionally carries the backbone's last-layer hidden state and the token embedding).

K/V source
  • EAGLE-3: the drafter computes its own K/V from its hidden state (full self-attention).
  • DeepSeek MTP: each MTP module computes its own K/V (full self-attention within the module).
  • DFlash: new K/V created by projection; hidden states pass through fresh W_K, W_V weights at every draft layer, so the K/V are a transformed derivative of backbone features, not the backbone's own cache.
  • P-EAGLE: the drafter computes its own K/V (full self-attention, as in EAGLE-3); no sharing with backbone K/V.
  • DART: the drafter computes its own K/V (single decoder layer, full self-attention), as in EAGLE-3; no backbone K/V sharing.
  • Gemma 4 MTP: the backbone's precomputed K/V reused as-is. The drafter has no W_K or W_V: the backbone already computed K = h·W_K and V = h·W_V for its own attention, and those exact tensors are shared.

Attention mask
  • EAGLE-3: causal; the drafter generates one token at a time.
  • DeepSeek MTP: causal within each sequential module.
  • DFlash: bidirectional within each draft block; all masked positions in a block attend to each other and to the context c.
  • P-EAGLE: causal with depth-ordering constraints; all K positions are processed together, but each attends only to earlier depth positions, so there is no future-draft leakage.
  • DART: strictly causal over both the prefix and the masked block; MASK tokens cannot attend forward, so there is no bidirectional refinement.
  • Gemma 4 MTP: a bidirectional mask, but with q_len = 1 per draft step it acts as full attention; the drafter attends only to backbone K/V (no draft K/V exists), and that K/V provides the causal history.

Draft structure
  • EAGLE-3: sequential; one drafter forward pass per draft token.
  • DeepSeek MTP: sequential depth-wise: module 1 → t+1, module 2 → t+2, module 3 → t+3.
  • DFlash: block diffusion; all tokens in a block are decoded simultaneously in one forward pass, with no intra-block sequential dependency.
  • P-EAGLE: parallel single pass; all K draft tokens in one forward pass (versus K sequential passes in EAGLE-3), with a 4-layer drafter and a learnable shared hidden state replacing per-step autoregressive context.
  • DART: single forward pass with MASK tokens; the prefix [z₁:ₙ] plus (d−1) MASK tokens is fed to a 1-layer decoder, which yields d logits simultaneously with no iterative denoising; N-gram tree pruning builds the final draft tree.
  • Gemma 4 MTP: sequential; one assistant pass per draft token (q_len = 1), each step fed the previous draft token while the backbone K/V stays fixed.

Vocab prediction
  • EAGLE-3: full LM head (tied with the backbone).
  • DeepSeek MTP: full LM head (shared output projection).
  • DFlash: full LM head.
  • P-EAGLE: full LM head (tied with the backbone).
  • DART: full LM head; trained with annealed KL divergence (position-decaying weights, γ = 0.6, suppress noisy distant targets).
  • Gemma 4 MTP: centroid-based; 32 of 2048 clusters → 4096 tokens scored (98.4% savings).

Core bet
  • EAGLE-3: backbone features are a rich enough signal at the input level.
  • DeepSeek MTP: depth-wise decomposition; each extra token costs one extra transformer layer of compute.
  • DFlash: injecting backbone context into every draft layer's K/V keeps the signal strong at depth; block diffusion decouples draft quality from latency.
  • P-EAGLE: RoPE encodes position through attention alone, so a learnable shared hidden state suffices for all K positions; 4-layer depth substitutes for the lost autoregressive signal, and one forward pass replaces K sequential ones.
  • DART: 3-layer FC-fused features plus the shifted token embedding give the draft model the same signal as EAGLE-3; MASK tokens let a single causal layer predict d future positions in one shot; N-gram continuity prevents semantically incoherent draft trees.
  • Gemma 4 MTP: the backbone's K/V cache is the context; use it directly and skip all re-encoding.
🔑 The sharpest contrast: EAGLE-3 shares hidden-state features (intermediate representations) but still computes its own K/V. DeepSeek MTP shares embedding and output-head weights. Gemma 4 shares the actual K/V tensors the backbone computed, so the drafter's attention layers never recompute anything over the prompt.
04 — Drafter Internals

Inside one assistant transformer layer

The assistant's transformer is structurally almost identical to the backbone's — same RMSNorm, same GatedMLP (SwiGLU), same residual connections — with one surgical difference: the attention sub-layer has no K or V projection matrices. When each layer runs its attention, it looks up (K, V) by its own layer type: shared_kv_states[self.layer_type].

Below is the full data flow: the pipeline overview at the top shows where the assistant sits end-to-end, and the expanded view shows exactly what happens inside one transformer layer.

Figure (drafter pipeline and single-layer zoom):
  • Full pipeline: backbone input h [B, L, 3072] → pre_proj (3072 → 768) → transformer × N (Q-only attention, bidirectional mask, backbone K/V: 2 tensors selected by layer type) → post_proj (768 → 1536) → MaskedEmbedder → draft logits.
  • One assistant layer i: h_in [B, L, 768] from the previous layer → RMSNorm → Q = h · W_Q (no W_K or W_V; those weight matrices do not exist) → K, V [B, S, n_heads, d_head] from shared_kv_states[self.layer_type] ("full_attention" or "sliding_attention") → Attention(Q, K, V) = softmax(QKᵀ/√d) · V with a bidirectional mask → + residual → RMSNorm → GatedMLP (SwiGLU, intermediate_size expansion) → + residual → h_out [B, L, 768] → layer i+1.
  • S is the full prompt length (backbone context) and L the draft length (typically 3–7 tokens): Q attends over the entire backbone prefix despite covering only L tokens itself.

The draft sequence length L (3–7 tokens) is far shorter than the backbone's full context S. Because K and V span the full S-length backbone context, each draft token's Q vector attends over the entire prefix in one shot — without the assistant ever having read that prefix.
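A single-head, pure-Python sketch of this Q-only cross-attention: L query vectors from the drafter attend over S borrowed key/value vectors, and no W_K or W_V appears anywhere. Real implementations are batched, multi-head tensor operations; the function name and the toy numbers here are illustrative assumptions.

```python
import math


def q_only_attention(q_rows, k_rows, v_rows, mask=None):
    """softmax(Q K^T / sqrt(d) + mask) V for L queries over S borrowed keys.

    q_rows: L query vectors (from the drafter's W_Q).
    k_rows, v_rows: S key/value vectors taken as-is from the backbone.
    mask[i][j] is True if query i may attend to key j; None means all True
    (a bidirectional mask). Each query needs at least one visible key.
    """
    d = len(q_rows[0])
    out = []
    for i, q in enumerate(q_rows):
        scores = []
        for j, k in enumerate(k_rows):
            allowed = True if mask is None else mask[i][j]
            s = sum(a * b for a, b in zip(q, k)) / math.sqrt(d)
            scores.append(s if allowed else -math.inf)
        m = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        weights = [w / z for w in weights]
        out.append([sum(w * v[c] for w, v in zip(weights, v_rows))
                    for c in range(len(v_rows[0]))])
    return out


# S = 4 borrowed backbone positions, L = 1 draft query.
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [0.0, 0.0]]
print(q_only_attention([[0.0, 0.0]], K, V))  # [[0.75, 0.75]]
```

A zero query scores every key equally, so the output is the plain mean of the borrowed values; a real query would weight the backbone positions unevenly.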

05 — Projections

Bridging two hidden dimensions

The backbone runs at d_model = 1536; the assistant runs at a smaller hidden_size (e.g. 768). Two projection layers bridge the gap.

Figure (projection pipeline): the backbone hidden state [B, L, 1536] at the last validated position, concatenated with its token embedding from embed_tokens [B, L, 1536], gives [B, L, 3072] → pre_projection, Linear(3072 → 768) → [B, L, 768] → transformer (N layers, bidirectional, shared KV) → [B, L, 768] → post_projection, Linear(768 → 1536) → [B, L, 1536], back to the backbone dimension and returned as hidden_states. The 3072-wide concatenation is the target model's token embedding plus its last-layer hidden state, both for the same last-validated position.

The pre_projection receives the concatenation of the target model's token embedding (from embed_tokens) and the target model's last-layer hidden state, both taken at the same last-validated position, and compresses the resulting 3072-dim vector to the assistant's smaller working dimension. The post_projection expands the result back to 1536 so the assistant's output lives in the backbone's dimension, where a Linear(1536 → vocab_size) output head such as the backbone's lm_head can score it. Without this projection the shapes would be incompatible and the assistant could not reuse the backbone's output head. Note that whatever head scores the assistant's output produces draft logits only: acceptance or rejection is decided by the backbone's own forward pass over the drafted tokens, not by applying lm_head to the assistant's hidden state.
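The dimension bookkeeping can be sketched with scaled-down sizes, a backbone width of 4 and an assistant width of 2 standing in for 1536 and 768. The tiny Linear helper, its constant weights, and the identity placeholder for the transformer stack are hypothetical; only the shapes mirror the description above.

```python
class Linear:
    """Minimal bias-free linear layer: y = W x, with W of shape [out, in]."""

    def __init__(self, d_in: int, d_out: int, fill: float = 0.1):
        self.w = [[fill] * d_in for _ in range(d_out)]
        self.d_in, self.d_out = d_in, d_out

    def __call__(self, x):
        assert len(x) == self.d_in, f"expected {self.d_in} inputs, got {len(x)}"
        return [sum(wi * xi for wi, xi in zip(row, x)) for row in self.w]


D_BACKBONE, D_ASSISTANT = 4, 2  # stand-ins for 1536 and 768

pre_projection = Linear(2 * D_BACKBONE, D_ASSISTANT)  # 2x1536 -> 768
post_projection = Linear(D_ASSISTANT, D_BACKBONE)     # 768 -> 1536

token_embedding = [1.0] * D_BACKBONE  # embed_tokens(last validated token)
backbone_hidden = [0.5] * D_BACKBONE  # backbone last-layer hidden, same position

x = pre_projection(token_embedding + backbone_hidden)  # list concat: width 8
h = x  # placeholder for the assistant's transformer layers
out = post_projection(h)
print(len(x), len(out))  # 2 4
```

Passing post_projection's output into anything built for the assistant width would trip the shape assertion, which is the kind of mismatch the 1536-wide output avoids.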

06 — Efficient Vocabulary

Centroid-based token prediction

Computing logits over the full ~262K-token vocabulary requires a Linear(768, vocab_size) matrix multiply, which is expensive for a small draft model. Gemma4AssistantMaskedEmbedder avoids this by clustering vocabulary tokens into centroids and scoring only the tokens in the top-K clusters.

Figure (centroid-based vocabulary prediction), four steps:
  ① Score all 2048 centroids: hidden [B, L, 768] → centroids(), Linear(768 → 2048).
  ② Take topk(k = 32) of the 2048: 32 centroids selected (1.6%).
  ③ Gather the lm_head rows for their 32 × 128 = 4096 tokens (4096 rows of 768) and take dot products → [B, L, 4096]. Only 4096 tokens are evaluated instead of the full vocabulary, saving 98.4% of the computation.
  ④ Scatter the logits back into a full-vocabulary tensor [B, L, V] with scatter_.
The token_ordering buffer, a tensor of shape [vocab_size], maps centroid_id * 128 + local_idx → global_vocab_idx. It is learned during training to group semantically similar tokens under the same centroid; at inference time it is fixed, just an index lookup that finds the 128 token positions belonging to each selected centroid.
🔬 With num_centroids=2048 and top_k=32, the assistant evaluates 32 × 128 = 4,096 tokens instead of the full vocabulary of ~262K (2,048 × 128 = 262,144), a 98.4% reduction in logit computation.
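A scaled-down, pure-Python sketch of the four steps: score the centroids, keep the top-k, gather the tokens listed in token_ordering, and scatter their logits into a full-vocabulary row. The sizes, the -inf fill for unscored tokens and the function name are illustrative assumptions; Gemma4AssistantMaskedEmbedder's actual code may differ.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def centroid_logits(hidden, centroid_w, token_emb, token_ordering, top_k):
    """Sparse vocabulary logits via centroid routing.

    hidden:         one hidden vector of length d
    centroid_w:     num_centroids vectors; row c scores centroid c
    token_emb:      vocab_size vectors (rows of the output head)
    token_ordering: length vocab_size; entry c * per_c + i is the global id
                    of the i-th token assigned to centroid c
    Returns a full-vocab list with -inf for tokens that were not scored.
    """
    num_centroids = len(centroid_w)
    per_c = len(token_ordering) // num_centroids  # 128 in the post
    # 1) score every centroid, 2) keep the top_k best
    scores = [dot(hidden, w) for w in centroid_w]
    chosen = sorted(range(num_centroids), key=lambda c: scores[c], reverse=True)[:top_k]
    # 3) gather the token ids owned by the chosen centroids
    token_ids = [token_ordering[c * per_c + i] for c in chosen for i in range(per_c)]
    # 4) score only those tokens and scatter them into a full-vocab row
    logits = [-math.inf] * len(token_emb)
    for t in token_ids:
        logits[t] = dot(hidden, token_emb[t])
    return logits


# Toy setup: vocab of 8, 4 centroids with 2 tokens each, top_k = 1.
hidden = [1.0, 0.0]
centroid_w = [[0.0, 1.0], [1.0, 0.0], [-1.0, 0.0], [0.0, -1.0]]
token_ordering = [5, 0, 3, 6, 1, 7, 2, 4]  # centroid 1 owns tokens 3 and 6
token_emb = [[float(t), 0.0] for t in range(8)]
logits = centroid_logits(hidden, centroid_w, token_emb, token_ordering, top_k=1)
print(logits)

# The post's configuration, as plain arithmetic.
print(f"{32 * 128} of {2048 * 128} tokens scored ({1 - 32 / 2048:.1%} saved)")
```

Centroid 1 wins the routing step, so only tokens 3 and 6 receive real logits; everything else stays at -inf and can never be sampled as a draft.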
07 — Bidirectional Attention

Why the assistant uses bidirectional masks

The assistant does use bidirectional attention masks — this is confirmed by the source imports and calls (create_bidirectional_mask, create_bidirectional_sliding_window_mask). But this raises an obvious question: if there is no causal mask, how can the assistant generate autoregressively?

The answer is in the docstring: "no difference for the edge case of q_len == 1 as it acts as full attention no matter what." During each draft step, the assistant processes a single token (q_len = 1). With one query position, bidirectional and causal masks are mathematically identical: there are no other draft positions to mask. Drafting remains autoregressive in the usual sense: each step samples one token and feeds it in as the next step's input, while the backbone K/V stays fixed until the next verification pass.

For full attention this non-difference is easy to see in the matrix: the new draft token sits at the very last row, and a causal lower-triangle already has all 1s there — identical to an all-1s bidirectional row. No special logic needed.
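A toy check of that claim, using hand-written mask rows (hypothetical helpers, not the transformers mask utilities). The single draft query sits at position S, just after the S backbone keys at positions 0 to S−1, so its causal row is already all True:

```python
def causal_row(q_pos: int, kv_len: int) -> list:
    """Causal mask row: the query at q_pos may see keys at positions j <= q_pos."""
    return [j <= q_pos for j in range(kv_len)]


def bidirectional_row(kv_len: int) -> list:
    """Bidirectional mask row: every key is visible."""
    return [True] * kv_len


# q_len == 1 and the query comes after every key: the two rows coincide.
for S in (1, 5, 128):
    assert causal_row(S, S) == bidirectional_row(S)
print("q_len == 1: causal and bidirectional rows match")
```

The rows would only differ for a query positioned before some of the keys, which cannot happen when the drafter's lone query always follows the backbone context.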

Sliding-window attention is different. SWA uses a fixed-width window (W tokens) rather than the full context, so direction matters. The backbone ran prefill with a past-looking window — each token attended to the previous W tokens. Its stored K/V reflects that. The bidirectional SWA utility, however, anchors the window at the start of the query sequence (future-looking by default): for q_len=1, the single Q gets a window opening at K/V slot 0, not at the recent end of the context. Those are opposite ends of the sequence — attending there would be meaningless.

The fix is a one-line tensor flip: swa_mask.flip(dims=(-1,)). Reversing the K/V axis turns [1,1,1,0,…,0] into [0,…,0,1,1,1] — shifting the window from the front of the K/V buffer to the back, exactly where the backbone's most recent SWA K/V lives.

Importantly, the flip changes nothing about the attention mechanism itself — softmax(QKᵀ/√d + mask)·V is identical before and after. It is purely a mask correction: the utility produced the wrong mask values (window at the wrong end), and the flip corrects them. You could equivalently have written a custom function that produced [0,…,0,1,1,1] directly, making the flip unnecessary. The bidirectional mask and the flip together are just the codebase's way of arriving at the mask tensor the drafter actually needs.
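The flip is easy to reproduce on a toy mask. This sketch hand-builds the front-anchored q_len = 1 window that, per the description above, the bidirectional SWA utility produces (a stand-in, not a call into transformers), then reverses the K/V axis the way swa_mask.flip(dims=(-1,)) would:

```python
def front_anchored_window(kv_len: int, window: int) -> list:
    """Stand-in for the utility's q_len=1 row: the window opens at K/V slot 0."""
    return [j < window for j in range(kv_len)]


def flip_kv_axis(mask_row: list) -> list:
    """Pure-Python equivalent of mask.flip(dims=(-1,)) for one query row."""
    return mask_row[::-1]


def back_anchored_window(kv_len: int, window: int) -> list:
    """What the drafter actually needs: the last `window` K/V slots."""
    return [j >= kv_len - window for j in range(kv_len)]


row = front_anchored_window(kv_len=5, window=3)
print(row)                # [True, True, True, False, False]
print(flip_kv_axis(row))  # [False, False, True, True, True]
assert flip_kv_axis(row) == back_anchored_window(5, 3)
```

The final assertion is the equivalence noted above: flipping a front-anchored window yields the same mask a custom back-anchored function would build directly.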

Figure (attention masks), three panels:
  • Full attention, 3 draft tokens × 5 backbone tokens, all backbone tokens in the past (t₀…t₄ at positions 0–4, d₀…d₂ at positions 5–7): the causal mask (j ≤ p is always true) and the bidirectional mask are both all on, so the mask type is irrelevant.
  • SWA flip, all 3 draft steps with q_len = 1 each, kv_len = 5, W = 3: because kv_len is fixed, the bidirectional utility returns the same mask at every step. After flip(dims=(-1,)), d₀, d₁ and d₂ all get the window [t₂, t₃, t₄], whereas ideal per-position SWA would give d₀ → (t₃, t₄), d₁ → (t₄), d₂ → ∅. The flip ignores absolute position, so it is an approximation.
  • Target-model SWA (self-attention) vs drafter SWA (cross-attention), W = 3, kv_len = 5: in the target model, Q and K/V both come from the backbone sequence, each token attends to its past W neighbours, and K/V for all S positions are stored after prefill. In the drafter, Q comes from the draft token and K/V only from the backbone; no draft K/V exists, so kv_len stays fixed at S, and the flip anchors the window at the K/V end. In the target, t₄ attended to t₂, t₃, t₄; after the flip, d₀, which sits at position S = 5 (the slot after t₄), attends to the same t₂, t₃, t₄. The drafter's flipped mask reproduces exactly the window the backbone's own last token used.

How draft tokens stay distinct despite sharing K/V and weights

A natural question: if d₀, d₁, and d₂ all attend to the same backbone K/V with the same model weights, shouldn't they produce the same token? They don't, for two reasons that compound each other.

Different input embeddings. Each draft step is conditioned on the token sampled in the previous step. When generating d₁, the draft model's input is the embedding of d₀ — the token just sampled. Generating d₂ feeds d₁'s embedding. A different input produces a different Q, a different attention output, and a different prediction.

Different positional encoding (RoPE). Even if the token embedding were somehow identical across steps, the Q vector at position 5 is rotated differently from position 6, which is rotated differently from position 7. RoPE bakes position directly into Q (and K), and the resulting QK score depends on the offset between query and key positions. The backbone key at position 2 is therefore scored differently by a query at position 3 than by one at position 6, so the attention pattern shifts at every step even with the same weights.

Concretely, each step computes:

d₀:  Q = RoPE(embed(t₄) · W_Q,  pos=5)  →  attend backbone K/V  →  sample d₀
d₁:  Q = RoPE(embed(d₀) · W_Q,  pos=6)  →  attend backbone K/V  →  sample d₁
d₂:  Q = RoPE(embed(d₁) · W_Q,  pos=7)  →  attend backbone K/V  →  sample d₂

Both the input token and the rotated position change at every step. The backbone K/V and the weight matrix W_Q are fixed; the query is not.
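The position effect can be checked with a two-dimensional RoPE sketch. It uses a single, arbitrarily chosen rotation frequency (real RoPE rotates many dimension pairs at different frequencies), and the vectors are toy values. The same query and key give a different score as the query position moves from 5 to 7, and the score depends only on the offset between query and key positions:

```python
import math


def rope_2d(vec, pos, theta=0.5):
    """Rotate a 2-D vector by the angle pos * theta (single-frequency RoPE)."""
    x, y = vec
    a = pos * theta
    return (x * math.cos(a) - y * math.sin(a), x * math.sin(a) + y * math.cos(a))


def score(q, q_pos, k, k_pos):
    """Unscaled attention score between a rotated query and a rotated key."""
    qr, kr = rope_2d(q, q_pos), rope_2d(k, k_pos)
    return qr[0] * kr[0] + qr[1] * kr[1]


q = (1.0, 0.0)  # identical projected query at every draft step
k = (1.0, 0.0)  # one fixed backbone key at position 2
for pos in (5, 6, 7):
    print(pos, round(score(q, pos, k, 2), 4))
```

For these unit vectors the score is cos((q_pos − k_pos) · theta), so shifting both positions by the same amount leaves it unchanged, while moving only the query changes it.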

Why it works at all. The backbone K/V encodes the full prompt. The draft model's task is narrow: given this context, predict the most likely token at position X. The positional encoding in Q carries the "what position am I at?" signal, and the input embedding carries the "what was just generated?" signal. Together they are usually enough for short continuations — the model is trained end-to-end on exactly this objective. The limitation shows up in the acceptance rate: without explicit draft-to-draft attention (unlike EAGLE), later tokens in the draft sequence are conditioned only on the backbone context plus their immediate predecessor, so quality degrades faster as draft length grows. This is why draft lengths are kept short (3–7 tokens).

The underlying design bet. Viewed this way, each draft token is nearly an independent prediction from backbone context, with only a shallow one-step dependency on its immediate predecessor. The backbone K/V spans hundreds or thousands of prompt tokens; the single preceding draft token adds only a marginal local signal on top of that dominant context. This is the architectural wager: for short continuations, the prompt alone is usually sufficient to predict the next few tokens correctly — the model rarely needs to know what earlier draft tokens were in order to get the later ones right. Most of the information that determines the next token was already encoded in the backbone's representation of the prompt. The draft tokens are closer to parallel lookups into that representation than to a serial reasoning chain. Keeping drafts short (3–7 tokens) is precisely the regime where this bet pays off: the inter-draft dependencies that are being ignored are weak enough that the acceptance rate stays high.

Figure (design bet): the backbone K/V for "The quick brown fox jumped over" is the same context, queried independently at every draft step; it is the dominant signal. The drafts d₀ = "the" (pos 6), d₁ = "lazy" (pos 7) and d₂ = "dog" (pos 8) are linked only by embed(d₀) and embed(d₁), one embedding passed between steps, which is the weak signal.
08 — The Full Loop

Putting it all together

At generation time, HuggingFace's model.generate() orchestrates the speculative decoding loop automatically when assistant_model is provided. The backbone and assistant alternate: the assistant drafts a block of tokens, the backbone verifies it, accepted tokens are committed, and the loop continues.

Figure (generation loop, accept/reject):
  • Draft: is Paris , the most
  • Verify: is Paris , the most
  • Accept: is Paris ,
  • Reject: the (the backbone corrects from here)

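After the first rejected token, all subsequent draft tokens are discarded, and the backbone's own prediction at the rejected position is used instead. In this example, 3 draft tokens are accepted, 1 is rejected and the last one is discarded; the single backbone pass still commits 4 tokens (the 3 accepted plus the backbone's correction), where plain autoregressive decoding would commit 1, before accounting for the cost of drafting.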
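The accept/reject rule in this example can be sketched as greedy verification: accept draft tokens while they match the backbone's own argmax at each position, then take the backbone's token at the first mismatch. This is the simple greedy variant (sampling-based speculative decoding uses a probabilistic acceptance test instead), and the backbone predictions below are made up to reproduce the example.

```python
def greedy_verify(draft: list, target: list) -> list:
    """Return the tokens committed after one backbone verification pass.

    draft:  tokens proposed by the assistant.
    target: the backbone's greedy choice at each draft position plus one
            extra position after the last draft token, all produced by a
            single parallel forward pass (len(target) == len(draft) + 1).
    """
    assert len(target) == len(draft) + 1
    committed = []
    for d, t in zip(draft, target):
        if d != t:
            committed.append(t)  # backbone's correction; later drafts are dropped
            return committed
        committed.append(d)
    committed.append(target[-1])  # every draft accepted: one bonus token
    return committed


draft = ["is", "Paris", ",", "the", "most"]
target = ["is", "Paris", ",", "and", "it", "is"]  # hypothetical backbone argmaxes
print(greedy_verify(draft, target))  # ['is', 'Paris', ',', 'and']
```

Each call commits between 1 and len(draft) + 1 tokens for one backbone pass, which is where the speedup comes from.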

🚀 Use it with: model.generate(inputs, assistant_model=assistant, max_new_tokens=200). No other changes required — HuggingFace handles the entire speculative decoding loop.
Bonus: Could the MTP Drafter Be Trained Post-Pretraining?

EAGLE-3 and DFlash train their draft models post-pretraining on modest datasets (~100K–200K examples). Gemma 4's MTP drafter is trained jointly during pretraining, a much heavier commitment. A natural question: could the same post-training recipe be applied to Gemma 4's Q-only, K/V-sharing architecture, and how close would the acceptance rate get?

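Figure (joint pretraining vs hypothetical post-training):
  • Joint pretraining (current): the main model (Gemma 4) and the MTP drafter (assistant) both receive gradients, and gradients flow into K and V; the K/V projections co-evolve, and the centroids shape the representation space jointly over the full pretraining corpus.
  • Post-training (hypothetical): the main model is frozen; only the drafter receives gradients, with none flowing back into K and V; the K/V projections stay fixed at their pretrained values, and the drafter must adapt to the frozen representations using a small post-training dataset.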
Why K/V Co-Adaptation is Load-Bearing

In joint pretraining, the main model's W_K and W_V projections receive gradients from both the autoregressive loss and the drafter's prediction loss. They co-adapt: the main model learns to surface features the drafter can act on, and the drafter's centroid vocabulary clustering co-evolves with that K/V geometry. In post-training, those projections are frozen, so the drafter inherits a representation space that was never optimized with it in mind.

The Centroid Alignment Challenge

Gemma 4's drafter uses ~2,048 learned centroids to cluster the 262K vocabulary into coarse groups before fine prediction. Trained jointly, these centroids carve the representation space in a way both models agree on. Trained post-hoc with frozen main-model weights, the centroids must rediscover a compatible clustering from scratch — harder, but learnable. A warm-start approach — initializing centroid embeddings via k-means over the main model's frozen output embedding matrix — would give the post-hoc drafter a geometrically aligned starting point and likely close most of the acceptance-rate gap.
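One way the warm start could look, as a hypothetical sketch rather than anything from the Gemma 4 code: run a few Lloyd's k-means iterations over the frozen output-embedding rows, then assign tokens to centroids under a fixed capacity, so every centroid owns the same number of tokens as the token_ordering layout (centroid_id * 128 + local_idx) implies. The greedy capacity-constrained assignment is a simplification of balanced clustering, and the random 2-D vectors stand in for real embeddings.

```python
import random


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans(rows, k, iters=10, seed=0):
    """Plain Lloyd's k-means; returns k centroid vectors."""
    rng = random.Random(seed)
    centroids = [list(r) for r in rng.sample(rows, k)]
    for _ in range(iters):
        buckets = [[] for _ in range(k)]
        for r in rows:
            best = min(range(k), key=lambda c: sq_dist(r, centroids[c]))
            buckets[best].append(r)
        for c, members in enumerate(buckets):
            if members:  # keep the old centroid if its cluster empties out
                centroids[c] = [sum(col) / len(members) for col in zip(*members)]
    return centroids


def balanced_token_ordering(rows, centroids):
    """Greedy capacity-constrained assignment -> flat token_ordering list.

    Entry c * cap + i is the global id of the i-th token of centroid c.
    Requires len(rows) to be a multiple of len(centroids).
    """
    k = len(centroids)
    cap = len(rows) // k
    assert cap * k == len(rows)
    # Visit (token, centroid) pairs from closest to farthest.
    pairs = sorted((sq_dist(r, centroids[c]), t, c)
                   for t, r in enumerate(rows) for c in range(k))
    members = [[] for _ in range(k)]
    assigned = set()
    for _, t, c in pairs:
        if t not in assigned and len(members[c]) < cap:
            members[c].append(t)
            assigned.add(t)
    return [t for group in members for t in group]


rng = random.Random(1)
emb = [[rng.gauss(0, 1), rng.gauss(0, 1)] for _ in range(16)]  # toy embeddings
cents = kmeans(emb, k=4)
ordering = balanced_token_ordering(emb, cents)
print(ordering)
```

Because total capacity equals the vocabulary size, every token is guaranteed a slot, and the result is a permutation of token ids grouped four per centroid in this toy case.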

What EAGLE-3 Tells Us

EAGLE-3's post-training success rests on one key finding: the target model's hidden states are information-dense enough that a drafter can learn to predict well from them without co-training. Gemma 4's drafter accesses a structurally similar signal via shared K/V caches. The key difference is directness: EAGLE-3 feeds hidden states as explicit per-layer features, while Gemma 4's drafter reads K/V caches implicitly through attention. That indirection doesn't make post-training impossible; it just makes it harder for the drafter to extract the full signal from frozen representations.

Hypothesis: Projected Acceptance Rate (relative to joint pretraining)
  • Joint pretraining (current): 100%, reference
  • Post-train, warm-start centroids: ~88–94% (recommended)
  • Post-train, general dataset: ~82–88%
  • Post-train, domain-specific only: ~78–85%, in-domain only

These are hypothetical estimates extrapolated from EAGLE-3 and DFlash benchmarks; no Gemma 4 post-training experiment has been published. The warm-start centroid strategy is the highest-leverage variable: initializing centroids via k-means over the frozen main model's output embedding matrix aligns the vocabulary clustering with the existing representation geometry, likely recovering most of the joint-training advantage at a fraction of the training cost. Domain-specific post-training scores higher within its target domain but degrades on out-of-distribution inputs, making it a poor general-purpose choice.