The teacher loses its mind after 100 student tokens
On-policy distillation (OPD) sounds great in theory: instead of imitating fixed teacher outputs, let the student generate its own rollout, then use the teacher to judge every token. The student learns from its own mistakes.
But there's a hidden problem. At the first token, the teacher sees only the prompt, so it is fully in its natural reasoning state. At token 300, it has to condition on 299 tokens generated by the student, which look nothing like its own training distribution. The teacher starts trying to complete the student's (often wrong) trajectory rather than correcting toward the right answer.
The authors measure this directly: they take a Qwen3-1.7B teacher and ask it to continue from a student-generated prefix of length N, then measure avg@4 accuracy on MATH-500. The teacher starts at 65.30% accuracy on its own, drops to 62.70% after seeing 100 student tokens, and collapses to 51.75% after 300 tokens — nearly the student's own 50.95% baseline.
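For readers unfamiliar with the metric, here is a minimal sketch of avg@4, assuming the common reading: sample 4 completions per problem, score each as correct or not, and average over samples and then over problems. The function name and the toy data are illustrative, not taken from the paper.

```python
from statistics import mean

def avg_at_k(correct_by_problem):
    """Average per-problem accuracy, where each problem has k sampled completions.

    correct_by_problem: list of lists of booleans, one inner list per problem.
    """
    return mean(
        mean(1.0 if ok else 0.0 for ok in samples)
        for samples in correct_by_problem
    )

# Toy data, made up for illustration: 3 problems, 4 samples each.
toy_results = [
    [True, True, False, True],     # 3 of 4 correct
    [False, False, False, False],  # 0 of 4 correct
    [True, True, True, True],      # 4 of 4 correct
]
score = avg_at_k(toy_results)
print(f"avg@4 = {score:.4f}")
```

Averaging over several samples makes the accuracy less sensitive to sampling noise than a single greedy or single-sample run.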
This has a simple name: Off-Policy Teacher Decay. The teacher is only reliable on its own distribution. Beyond roughly 100 student tokens, it is increasingly dragged off-distribution and its token-level scores stop being corrective: they are just completions of whatever wrong path the student started.
The fix they find is elegantly minimal: stop the student's rollout at N=100 tokens. Now the teacher only ever conditions on short student prefixes that stay close to its own distribution.
How OPD works — and why the loss function is uniform
Standard knowledge distillation trains a student to match a teacher's output distribution on fixed data — the student never picks what to generate. On-policy distillation (OPD), introduced by MiniLLM and GKD, flips this: the student generates its own rollout, and the teacher scores each token of that rollout.
The loss is reverse KL divergence, averaged uniformly over all T tokens:
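Written out in one common notation (the symbols are assumptions chosen here, not necessarily the paper's): x is the prompt, y is the rollout sampled from the student π_θ, π_teacher is the teacher, and T is the rollout length.

```latex
\mathcal{L}_{\mathrm{OPD}}(\theta)
  = \mathbb{E}_{x,\; y \sim \pi_\theta(\cdot \mid x)}
    \left[
      \frac{1}{T} \sum_{t=1}^{T}
      D_{\mathrm{KL}}\!\left(
        \pi_\theta(\cdot \mid x, y_{<t})
        \,\middle\|\,
        \pi_{\mathrm{teacher}}(\cdot \mid x, y_{<t})
      \right)
    \right]
```

Every position t carries the same weight 1/T, which is what "uniform" means in the heading above.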
The reverse KL has a mode-seeking property: the student is penalized for putting mass on tokens the teacher doesn't support, but not penalized for concentrating mass on a single supported token. This matters a lot for sub-mode commitment, discussed in the section on why ESR-trained students sometimes beat the teacher.
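The mode-seeking behavior can be checked numerically. The sketch below uses a made-up 3-token vocabulary with a bimodal teacher; all distributions are illustrative assumptions, not values from the paper.

```python
import math

def kl(p, q):
    """KL(p || q) for discrete distributions given as equal-length lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Toy distributions over a 3-token vocabulary (made up for illustration).
teacher = [0.495, 0.495, 0.01]    # two supported modes, one barely supported token
committed = [0.98, 0.01, 0.01]    # student concentrates on one supported mode
off_support = [0.01, 0.01, 0.98]  # student concentrates on the unsupported token

# Reverse KL is KL(student || teacher).
rkl_committed = kl(committed, teacher)
rkl_off_support = kl(off_support, teacher)

# Forward KL, KL(teacher || student), shown only for contrast.
fkl_committed = kl(teacher, committed)

print(f"reverse KL, committed to one mode: {rkl_committed:.3f}")
print(f"reverse KL, mass off support:      {rkl_off_support:.3f}")
print(f"forward KL, committed to one mode: {fkl_committed:.3f}")
```

In this toy case, committing to a single supported mode costs little under reverse KL, while placing mass on the unsupported token is penalized heavily. Forward KL, by contrast, charges the committed student more for ignoring the teacher's second mode.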
Why full rollout is expensive
A single training step with OPD requires generating a full ~1000-token rollout. That's 180 seconds per step on a single A6000 GPU for Qwen3-1.7B. With the rollout truncated at N=100 tokens (the method the paper calls ESR), generation takes only 5 seconds: a 36× speedup on generation alone, and a 24× reduction in total wall-clock time per step.
A one-line fix: stop the rollout at N tokens
ESR's change is deliberately minimal. Instead of generating the full response y₁…y_T (typically ~1000 tokens), truncate the rollout at N tokens and compute the loss only over that window:
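In one plausible notation (an assumption, not necessarily the paper's symbols or normalization): x is the prompt, π_θ the student, π_teacher the teacher, and only the first N tokens y₁…y_N are ever sampled.

```latex
\mathcal{L}_{\mathrm{ESR}}(\theta)
  = \mathbb{E}_{x,\; y_{1:N} \sim \pi_\theta(\cdot \mid x)}
    \left[
      \frac{1}{N} \sum_{t=1}^{N}
      D_{\mathrm{KL}}\!\left(
        \pi_\theta(\cdot \mid x, y_{<t})
        \,\middle\|\,
        \pi_{\mathrm{teacher}}(\cdot \mid x, y_{<t})
      \right)
    \right],
  \qquad N \ll T
```

The per-token term is the same reverse KL as in full-rollout OPD; only the upper limit of the sum, and with it the generated length, changes.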
Everything else — temperature, LoRA configuration, optimizer, scorer — stays identical. The change is literally clipping the generated length.
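To make the "clipping" concrete, here is a minimal sketch of a generation loop with a token cap. The `rollout` helper and the stand-in `toy_student` are hypothetical names for illustration; a real setup would pass the equivalent length limit to whatever generation API it uses.

```python
def rollout(next_token, prompt_tokens, max_new_tokens, eos_token):
    """Generate up to max_new_tokens tokens, stopping early at eos_token."""
    generated = []
    for _ in range(max_new_tokens):
        tok = next_token(prompt_tokens + generated)
        generated.append(tok)
        if tok == eos_token:
            break
    return generated

def toy_student(context):
    """Stand-in for a student model: emits placeholder tokens, then ends
    after 1000 of them. Purely illustrative, no real model is involved."""
    step = len(context)
    return "<eos>" if step >= 1000 else f"tok{step}"

# Full rollout: runs until the student emits <eos>.
full = rollout(toy_student, [], max_new_tokens=2048, eos_token="<eos>")

# Truncated rollout: identical call, only the cap changes.
short = rollout(toy_student, [], max_new_tokens=100, eos_token="<eos>")

print(len(full), len(short))
```

Because the truncated rollout is simply a prefix of what the full rollout would have produced (for the same sampling decisions), no other part of the training loop needs to know about the change.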
Why early tokens encode the strategy
The paper includes a beautiful case study on a MATH-500 problem. In an isosceles right triangle, the altitude to the hypotenuse has length 4√2. What is the area?
Last 100 tokens (execution): "...altitude has length h = a/√2. Setting this equal to 4√2: a = 8. Area = ½ × 8 × 8 = 32."
The first window sets up the geometry, names the unknown, and identifies the key relationship (the altitude bisects the hypotenuse, so h = a/√2). These are the choices that determine success. The last tokens just execute algebra that any solver can finish once the strategy is fixed.
N sweep: robust region 50–200 tokens
A natural concern: is the method sensitive to N? The paper sweeps N on MATH-500 (Qwen2.5-Math-1.5B → Qwen3-1.7B). Performance saturates for N∈[50,200] and all choices beat OPD. The exception is cross-family pairs (Gemma→Qwen), which are sensitive and prefer N=50 — the bigger the student-teacher gap, the earlier decay sets in.
Training the first 100 tokens aligns the other 900 — for free
Here's a puzzle: ESR only supervises the first N=100 tokens. But a full rollout is ~1000 tokens. How can the student possibly learn to write good late-position tokens if it never receives training signal on them?
The authors measure per-position KL divergence (between trained student and teacher) before and after ESR training. The result is striking: positions 100–900+ that receive zero direct training signal drop in KL by 30–40%, matching the improvement in the trained window.
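A measurement of this kind can be sketched as follows, assuming per-position KL values have already been computed before and after training. The helper names and the numbers are synthetic, invented only to exercise the code; they are not the paper's data.

```python
def relative_drop(before, after):
    """Fractional reduction in mean KL over a span of positions."""
    mean_before = sum(before) / len(before)
    mean_after = sum(after) / len(after)
    return (mean_before - mean_after) / mean_before

def window_vs_rest(kl_before, kl_after, n):
    """Relative KL drop inside the trained window [0, n) and outside it."""
    trained = relative_drop(kl_before[:n], kl_after[:n])
    untrained = relative_drop(kl_before[n:], kl_after[n:])
    return trained, untrained

# Synthetic per-position KL values (invented for illustration only).
kl_before = [0.5, 0.5, 0.2, 0.2, 0.2, 0.2]
kl_after = [0.25, 0.25, 0.15, 0.15, 0.15, 0.15]

trained_drop, untrained_drop = window_vs_rest(kl_before, kl_after, n=2)
print(f"trained window: {trained_drop:.0%}, untrained positions: {untrained_drop:.0%}")
```

The quantity of interest is the second number: a nonzero drop at positions that never contributed to the loss.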
Why does this work? The authors propose two reasons:
- Strategic framing propagates. Early tokens encode problem setup and solution strategy. Once the student learns to frame problems correctly (like the altitude example), the execution naturally follows.
- Subliminal learning. Recent work (Cloud et al., 2025) shows LLMs can transmit behavioral traits via hidden signals. Early tokens may inject a "global mindset" rather than just fixing the prefix.
Remarkably, position turns out to be an independent axis from the token saliency signals you'd naively use. The paper tests selecting the same 100 tokens by highest KL divergence, highest entropy, or combined signals — they all underperform ESR, most of them underperforming even full-sequence OPD. The top-100-KL tokens contain 93% of the total trajectory loss but are not the effective tokens.
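The difference between the two selection rules can be sketched with boolean masks over token positions. The helper names and per-token KL values below are made up for illustration and do not come from the paper.

```python
def first_n_mask(num_tokens, n):
    """Position-based selection: keep the first n tokens."""
    return [t < n for t in range(num_tokens)]

def top_k_mask(scores, k):
    """Saliency-based selection: keep the k tokens with the highest score."""
    ranked = sorted(range(len(scores)), key=lambda t: scores[t], reverse=True)
    chosen = set(ranked[:k])
    return [t in chosen for t in range(len(scores))]

def loss_share(scores, mask):
    """Fraction of the total score carried by the selected tokens."""
    return sum(s for s, keep in zip(scores, mask) if keep) / sum(scores)

# Toy per-token KL values (made up): two late spikes dominate the total.
per_token_kl = [0.1, 0.2, 0.1, 3.0, 0.1, 2.5, 0.05, 0.05]

by_position = first_n_mask(len(per_token_kl), 3)
by_kl = top_k_mask(per_token_kl, 3)

print(f"share of loss, first 3 tokens: {loss_share(per_token_kl, by_position):.2f}")
print(f"share of loss, top-3 KL:       {loss_share(per_token_kl, by_kl):.2f}")
```

Both masks select the same number of tokens, but the top-KL mask captures most of the loss mass by construction. The paper's finding is that capturing the loss mass is not what makes a token useful for training.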
Why ESR-trained students sometimes beat the teacher
Across Table 1, ESR-trained students exceed the teacher's reference accuracy in multiple settings (marked ⋆). This seems paradoxical — shouldn't the teacher be an upper bound?
The answer lies in reverse KL's mode-seeking behavior. The teacher's distribution at any planning step is often multi-modal: it supports both a verbose, exploratory reasoning chain and a concise, direct approach. Full-rollout OPD averages the teacher's signal over many positions, which pulls the student toward the teacher's dominant (most probable) mode — which may be the verbose, over-thinking one.
ESR's early-window loss sees only the planning tokens, where the student's distribution is high-entropy and diverges most from the teacher. Reverse KL at these positions penalizes the student for putting mass on tokens the teacher doesn't support, but doesn't penalize concentration within the support. The student therefore commits to the secondary mode — the concise, correct-answer style — rather than averaging toward the verbose dominant mode.
The paper verifies this quantitatively. Among the top-10% highest-KL tokens after training, ESR students choose one of the teacher's top-2–5 tokens (secondary modes) 47.4% of the time vs 44.6% for full OPD, and they are more confident (top-1 probability 0.79 vs 0.77). The ESR model is more decisive and commits to a better sub-mode.
| Metric | Student | OPD | ESR |
|---|---|---|---|
| Top-1 token probability | 0.71 | 0.77 | 0.79 |
| Matches teacher's top-1 | 28.6% | 45.7% | 41.9% |
| In teacher's top 2–5 (sub-mode) | 59.8% | 44.6% | 47.4% |
| Outside teacher's top-5 | 11.5% | 9.7% | 10.7% |
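The three bucket rows in the table can be computed with a simple rank lookup. This is a minimal sketch under the assumption that each bucket is defined by where the student's chosen token falls in the teacher's probability ranking; the function name and toy distribution are illustrative.

```python
from collections import Counter

def rank_bucket(teacher_probs, student_token):
    """Classify the student's token by its rank under the teacher."""
    ranked = sorted(
        range(len(teacher_probs)),
        key=lambda i: teacher_probs[i],
        reverse=True,
    )
    rank = ranked.index(student_token) + 1  # 1 = teacher's most likely token
    if rank == 1:
        return "top1"
    if rank <= 5:
        return "top2-5"
    return "outside-top5"

# Toy teacher distribution over a 7-token vocabulary (made up for illustration).
teacher_probs = [0.05, 0.40, 0.30, 0.10, 0.08, 0.04, 0.03]

# Hypothetical student choices at four different positions.
student_choices = [1, 2, 0, 6]
counts = Counter(rank_bucket(teacher_probs, tok) for tok in student_choices)
print(dict(counts))
```

Dividing each count by the number of positions gives percentages comparable in form to the table rows.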
Interestingly, ESR students also generate responses 2–3× shorter than their teacher's: median ~380 tokens vs ~1,150 for the teacher. By removing late-position supervision, the student preserves its own concise style while inheriting the teacher's strategic reasoning.
ESR dominates across families, scales, and tasks
The main experiment sweeps three distillation regimes — same family / same generation (e.g. Qwen3-1.7B → Qwen3-4B), same family / cross generation (Qwen2.5 → Qwen3), and cross family (Gemma-2 → Qwen3) — at student scales 1.5B–32B and teacher scales 1.7B–72B.
[Table: MATH-500 avg@4, key pairs]
Training efficiency
| Metric | ESR | Full OPD | Improvement |
|---|---|---|---|
| Student generation time/step | 5 s | 180 s | 36× |
| Total wall-clock/step | 8 s | 194 s | 24× |
| Peak GPU memory (A6000) | 24.1 GB | 63.3 GB | 2.6× |
| Fits on single A6000 (48 GB) | ✓ yes | ✗ no | — |
Beyond raw numbers, ESR enables the student and teacher to fit together on a single A6000 GPU (48 GB). Full OPD exceeds that budget and requires loading and unloading models between steps, so the 24× speedup likely underestimates real-world gains.
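The ratios in the efficiency table follow directly from the raw figures, as this short check shows. The variable names are illustrative; the numbers are the ones reported in the table.

```python
# Per-step figures from the efficiency table.
gen_esr_s, gen_opd_s = 5.0, 180.0      # student generation time
step_esr_s, step_opd_s = 8.0, 194.0    # total wall-clock time
mem_esr_gb, mem_opd_gb = 24.1, 63.3    # peak GPU memory

gen_speedup = gen_opd_s / gen_esr_s
step_speedup = step_opd_s / step_esr_s
mem_ratio = mem_opd_gb / mem_esr_gb

# Time per step not spent on student generation (scoring, backward pass, etc.).
other_esr_s = step_esr_s - gen_esr_s
other_opd_s = step_opd_s - gen_opd_s

print(f"generation speedup: {gen_speedup:.0f}x")
print(f"wall-clock speedup: {step_speedup:.2f}x")
print(f"memory ratio:       {mem_ratio:.1f}x")
print(f"non-generation time: {other_esr_s:.0f} s (ESR) vs {other_opd_s:.0f} s (OPD)")
```

Generation accounts for most of each step in both settings, which is why cutting the rollout length dominates the overall speedup.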