Why √d_model Is Your North Star
Before training starts, every LayerNorm (or RMSNorm) weight vector γ is initialized to all ones. This is a deliberate identity. The layer first normalizes its input: to zero mean and unit variance for LayerNorm, or to unit root-mean-square for RMSNorm. Multiplying by γ = 1 then passes the result through unchanged at step 0. This is the "do nothing" starting point that lets the model learn its own rescaling from scratch.
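That all-ones vector has a precise L₂ norm. If your model's hidden dimension is d_model, then every element γᵢ = 1, and:

$$\|\gamma\|_2 = \sqrt{\sum_{i=1}^{d_{\text{model}}} \gamma_i^2} = \sqrt{\sum_{i=1}^{d_{\text{model}}} 1^2} = \sqrt{d_{\text{model}}}$$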
This is the exact norm you will see in your training logs at step 0, every time, for any model initialized this way. It is a mathematical identity, not an empirical observation, and it is the baseline against which all subsequent drift should be measured.
You can verify this in one line of PyTorch at step 0 before any optimizer update:
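`model.final_layernorm.weight.norm().item()` should equal √d_model to floating-point precision. The attribute path varies by codebase, and bf16 weights should be cast with `.float()` before taking the norm. If the value is off, check your initialization. One legitimate exception: some implementations, such as Gemma's RMSNorm in Hugging Face Transformers, store γ as an offset and multiply by (1 + γ), so the stored tensor starts at zeros rather than ones.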
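A self-contained sketch of that step-0 check, assuming PyTorch. It uses a freshly constructed `nn.LayerNorm` in place of a real model's final norm, and the helper name `check_init_norm` is hypothetical. The float32 cast matters because in bf16 the norm itself is rounded to the nearest representable value, and near 55 those values are 0.25 apart, enough to trip a tight tolerance.

```python
import math

import torch
import torch.nn as nn


def check_init_norm(weight: torch.Tensor, d_model: int, tol: float = 1e-3) -> float:
    """Return ||weight||_2 and assert it equals sqrt(d_model) (all-ones init)."""
    # Cast to float32 first: a bf16 norm is rounded to a coarse grid.
    norm = weight.detach().float().norm().item()
    expected = math.sqrt(d_model)
    assert abs(norm - expected) < tol, f"norm {norm:.4f} != sqrt({d_model}) = {expected:.4f}"
    return norm


d_model = 3072
final_norm = nn.LayerNorm(d_model)  # PyTorch initializes weight to ones, bias to zeros
print(check_init_norm(final_norm.weight, d_model))  # ≈ 55.43

# A bf16 copy passes too, because the norm is computed after casting to float32.
print(check_init_norm(final_norm.weight.to(torch.bfloat16), d_model))
```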
Healthy Shrink vs. Dangerous Collapse
During normal pretraining, the γ vector will drift away from all-ones — that is expected and healthy. The model is learning to up-weight informative feature dimensions and down-weight noisy ones. But the direction and magnitude of drift carry distinct signals.
[Figure: per-dimension γ values as training progresses, healthy vs. collapsed. At initialization, all γᵢ = 1.0: the LayerNorm is an identity rescaler and each feature dimension carries equal weight. For this 64-dimensional example, ‖γ‖₂ = √64 = 8.0.]
The three regimes
- Healthy shrink (norm 40–55 for d=3072). Some dimensions are dampened, others amplified. The net norm is lower than init because the residual stream has grown across many layers — the LayerNorm weights compensate by contracting certain channels. This is normal and desirable.
- Dangerous collapse (norm < 20 for d=3072). The γ vector has uniformly shrunk toward zero. The final LayerNorm is actively erasing signal: all feature dimensions are heavily suppressed before they reach the LM head, and any MTP head that branches off at this point receives impoverished representations.
- Explosion (norm > 2× √d). Some γi values have grown large, over-amplifying specific dimensions. The final hidden state is dominated by a few noisy channels. The LM head sees logit explosions, and the MTP auxiliary loss gradients carry enormous scale mismatches backward.
Rule of thumb: if ‖final_layernorm.weight‖₂ has drifted by more than a factor of 2 from √d_model in either direction, treat it as a training health warning. For d=3072 that means values outside roughly [27, 110] need investigation.
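A small helper that applies that rule of thumb, using the step-0 value √d_model as the reference. The function name and return strings are hypothetical, and `factor=2.0` simply encodes the "more than 2×" rule. The regimes above also use other cut-offs (collapse below 20 for d=3072), so treat `factor` as something to tune.

```python
import math


def gamma_norm_status(norm: float, d_model: int, factor: float = 2.0) -> str:
    """Classify ||gamma||_2 against the 'more than factor x from sqrt(d_model)' rule."""
    init = math.sqrt(d_model)  # step-0 value for an all-ones gamma
    if norm < init / factor:
        return "warning: possible collapse"
    if norm > init * factor:
        return "warning: possible explosion"
    return "ok"


# Warning band for d_model = 3072 under the default factor of 2.
low, high = math.sqrt(3072) / 2, math.sqrt(3072) * 2
print(f"[{low:.1f}, {high:.1f}]")     # [27.7, 110.9]
print(gamma_norm_status(45.0, 3072))  # ok
print(gamma_norm_status(15.0, 3072))  # warning: possible collapse
```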
The Exact Point Where Norms Become Critical
To understand why the final LayerNorm norm matters so much, you need to know precisely where the MTP module branches off from the main model — and what state the representations are in at that branching point.
In DeepSeek-V3's implementation (and similar architectures like the EAGLE-style MTP used by Gemma 4), the MTP module receives the hidden states h⁰ᵢ from the last layer of the main transformer. These are the same representations that will pass through final_layernorm before reaching the LM head. The MTP module never sees the normalized output; it sees the raw residual-stream vectors.
The key architectural point: final_layernorm sits exactly at the fork. The main model passes through it to produce logits for t+1, while the MTP module taps the pre-norm hidden state h⁰ᵢ just before it. Even so, the scale of that vector is fundamentally tied to what the LayerNorm has learned to do. If γ has collapsed, the representations that enter the MTP projection are already impoverished.
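In DeepSeek-V3, each MTP module applies its own RMSNorm to its incoming hidden state (h⁰ᵢ for the first module) and a separate RMSNorm to the embedding of the token it conditions on. It then concatenates the two and projects them back to the model dimension with a linear map M_k. So the MTP path has its own normalization, but the information content of the representations still depends on how the main model's final layer has learned to organize its hidden state.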
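A minimal PyTorch sketch of that input stage, following the combine formula from the DeepSeek-V3 report, h′ᵢᵏ = M_k[RMSNorm(hᵢᵏ⁻¹); RMSNorm(Emb(t_{i+k}))]. The class names, the hand-written RMSNorm (used instead of `torch.nn.RMSNorm` so the sketch does not depend on a recent PyTorch release), and the tensor shapes are illustrative assumptions. The caller is assumed to have already shifted the embeddings so that position i holds Emb(t_{i+k}). The transformer block and shared output head that follow in a real MTP module are omitted.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Plain RMSNorm with a learnable scale gamma initialized to ones."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        inv_rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * inv_rms * self.weight


class MTPCombine(nn.Module):
    """Input stage of one MTP module: h'_i^k = M_k [RMSNorm(h_i^{k-1}); RMSNorm(Emb(t_{i+k}))]."""

    def __init__(self, d_model: int):
        super().__init__()
        self.hnorm = RMSNorm(d_model)  # normalizes the incoming (pre-norm) hidden state
        self.enorm = RMSNorm(d_model)  # normalizes the shifted token embedding
        self.proj = nn.Linear(2 * d_model, d_model, bias=False)  # M_k

    def forward(self, h_prev: torch.Tensor, emb_next: torch.Tensor) -> torch.Tensor:
        # h_prev, emb_next: (batch, seq, d_model)
        x = torch.cat([self.hnorm(h_prev), self.enorm(emb_next)], dim=-1)
        return self.proj(x)


d_model = 64
combine = MTPCombine(d_model)
h0 = torch.randn(2, 5, d_model) * 30.0  # large-scale residual-stream vectors
emb = torch.randn(2, 5, d_model)
out = combine(h0, emb)
print(out.shape)  # torch.Size([2, 5, 64])
```

Because `hnorm` rescales its input to unit RMS, multiplying `h0` by 10 leaves the output essentially unchanged. That is the sense in which the MTP path has its own normalization, while the direction of h⁰ᵢ still carries whatever the main model learned.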
LayerNorm as a Gradient Scale Gate
The most dangerous effect of norm drift in an MTP setup isn't representation quality; it's gradient flow. In LayerNorm's backward pass, the upstream gradient is multiplied elementwise by the weight vector γ before it reaches the layer's input. When the MTP auxiliary loss L_MTP backpropagates through the final layers of the main transformer, those gradients are modulated by γ's current magnitude.
This creates a subtle but serious problem: the MTP auxiliary loss has a scaling factor λ (DeepSeek-V3 uses λ=0.3 for the first 10T tokens, then λ=0.1) intended to balance the MTP contribution against the main loss. But if γ has dramatically collapsed, the actual gradient entering the main transformer from the MTP path is far smaller than intended — effectively disabling the MTP training signal even though the loss looks fine.
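The gating mechanism itself is easy to check with autograd. This PyTorch sketch pushes the same upstream gradient through one standalone `nn.LayerNorm` whose weight is filled with a single value, then compares input-gradient norms against γ = 1. It demonstrates only the general mechanism under a uniform γ; it does not model the full MTP path, and the helper name `input_grad_norm` is hypothetical.

```python
import torch
import torch.nn as nn


def input_grad_norm(gamma_value: float, d_model: int = 64, seed: int = 0) -> float:
    """Backpropagate a fixed upstream gradient through a LayerNorm whose weight
    is filled with gamma_value; return the L2 norm of the gradient at its input."""
    torch.manual_seed(seed)  # same x and upstream gradient on every call
    x = torch.randn(8, d_model, requires_grad=True)
    upstream = torch.randn(8, d_model)  # stands in for dL/d(output)
    ln = nn.LayerNorm(d_model)
    with torch.no_grad():
        ln.weight.fill_(gamma_value)  # uniform gamma; 1.0 is the init value
    ln(x).backward(upstream)
    return x.grad.norm().item()


baseline = input_grad_norm(1.0)
for g in (0.1, 0.5, 2.0):
    print(g, round(input_grad_norm(g) / baseline, 4))
```

The ratios match the γ values: because the backward pass is linear in γ ⊙ (upstream gradient), a uniform γ of 0.1 lets through a tenth of the gradient.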
For example, with d_model = 3072 and ‖γ‖₂ = 55.4 (its value at initialization), MTP gradients enter the main transformer at their intended scale, and the λ = 0.1 loss weight works as designed.
The converse failure — γ explosion — is equally dangerous. Oversized γ values amplify the MTP gradient by 2–3×, drowning out the main LM loss and destabilizing the primary pretraining objective. You can observe this as the MTP loss going down while the main language model loss suddenly starts rising or oscillating.
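A plain-Python sketch of the loss arithmetic. `total_loss` follows the DeepSeek-V3 report, which averages the per-depth MTP losses and scales them by λ: L_MTP = (λ/D) Σₖ L_MTPᵏ. `effective_mtp_weight` is a back-of-envelope model only. It assumes γ has been rescaled uniformly and that the MTP gradient reaching the trunk scales by the same factor, as the argument above suggests. Both function names are hypothetical.

```python
import math


def total_loss(lm_loss: float, mtp_losses: list[float], lam: float) -> float:
    """Main LM loss plus the DeepSeek-V3-style MTP term (lam / D) * sum_k L_MTP^k."""
    if not mtp_losses:
        return lm_loss
    depth = len(mtp_losses)  # D, the number of MTP modules
    return lm_loss + lam / depth * sum(mtp_losses)


def effective_mtp_weight(lam: float, gamma_norm: float, d_model: int) -> float:
    """Back-of-envelope only: lam scaled by ||gamma||_2 relative to its init value."""
    return lam * (gamma_norm / math.sqrt(d_model))


print(total_loss(2.0, [2.5], lam=0.3))                        # 2.75
print(effective_mtp_weight(0.1, math.sqrt(3072), 3072))       # 0.1 at init
print(round(effective_mtp_weight(0.1, 11.0, 3072), 4))        # collapsed gamma: ~0.02
```

Under this simplified model, a γ collapsed to about a fifth of its initial norm turns the intended λ = 0.1 into roughly 0.02, even though the logged loss still uses 0.1.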
The stability trinity
- Main and MTP losses must both trend downward. Monitor lm_loss and mtp_1_loss separately in TensorBoard or W&B.
- The gap between main and MTP loss should stay roughly stable. If MTP loss is rising while main loss falls, the auxiliary task is being crowded out. If main loss is rising while MTP loss falls, the auxiliary task is crowding out the main objective; lower mtp_loss_scaling_factor.
- ‖final_layernorm.weight‖₂ should stay in [0.4×√d, 1.8×√d]. Outside this band, the gradient scaling is distorted enough to require intervention.
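The three checks can be sketched as one function over a recent window of logged values. The function names and the `gap_tol` threshold (maximum allowed change in the MTP-minus-main gap per log step) are hypothetical placeholders; the [0.4, 1.8]×√d band comes from the list above. Trends are measured with an ordinary least-squares slope.

```python
import math


def slope(values: list[float]) -> float:
    """Least-squares slope of values against their index 0, 1, 2, ..."""
    n = len(values)
    if n < 2:
        raise ValueError("need at least two logged values")
    x_mean = (n - 1) / 2
    y_mean = sum(values) / n
    num = sum((i - x_mean) * (y - y_mean) for i, y in enumerate(values))
    den = sum((i - x_mean) ** 2 for i in range(n))
    return num / den


def stability_warnings(lm_loss: list[float], mtp_loss: list[float],
                       gamma_norm: float, d_model: int,
                       gap_tol: float = 0.05) -> list[str]:
    """Return a warning string for each of the three conditions that fails."""
    warnings = []
    # 1. Both losses trend downward.
    if slope(lm_loss) >= 0 or slope(mtp_loss) >= 0:
        warnings.append("main and MTP losses are not both trending down")
    # 2. The MTP-minus-main gap stays roughly flat.
    gaps = [mt - lm for lm, mt in zip(lm_loss, mtp_loss)]
    if abs(slope(gaps)) > gap_tol:
        warnings.append("main/MTP loss gap is drifting")
    # 3. The final norm weight stays inside [0.4, 1.8] x sqrt(d_model).
    root = math.sqrt(d_model)
    if not (0.4 * root <= gamma_norm <= 1.8 * root):
        warnings.append("final norm outside [0.4, 1.8] x sqrt(d_model)")
    return warnings


print(stability_warnings([3.0, 2.9, 2.8, 2.7], [3.4, 3.3, 3.2, 3.1], 45.0, 3072))  # []
print(stability_warnings([2.7, 2.8, 2.9, 3.0], [3.4, 3.2, 3.0, 2.8], 15.0, 3072))
```

The second call mimics the failure described earlier (main loss rising while MTP loss falls, with a collapsed γ) and trips all three checks.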
Why Norm Matters at Inference Time Too
During training, the MTP module is an auxiliary task. But in DeepSeek-V3 and similar models, the MTP heads are repurposed at inference time for speculative decoding: the MTP head drafts candidate tokens, and the main model verifies them. The throughput gain (often 1.8–2.5×) depends heavily on the acceptance rate, the fraction of draft tokens the verifier accepts.
Acceptance rate is determined by how well the draft distribution matches the target. A collapsed or exploded γ directly distorts this calibration:
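The mechanism: final_layernorm.weight controls the scale of activations entering the output head. With a bias-free linear LM head (the usual design), rescaling γ uniformly rescales the logits by the same factor. When γ collapses, the logits shrink and the distribution flattens, so the model becomes under-confident. When γ explodes, the logits grow and the distribution becomes artificially peaky, over-confident on a few tokens. Both distort the draft distribution away from what the verifier expects.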
| ‖γ‖₂ range (d=3072) | Training signal | Representation quality for MTP | Speculative decoding calibration |
|---|---|---|---|
| 50–62 (≈√d) | ✓ balanced | ✓ rich, general-purpose features | ✓ high acceptance rate |
| 27–50 (healthy shrink) | ✓ normal | ✓ somewhat specialized | ✓ minor degradation |
| <20 (collapse) | ✗ MTP gradient starved | ✗ impoverished, dimensionally flat | ✗ flat, under-confident drafts, low acceptance |
| >110 (explosion) | ✗ MTP gradient floods main loss | ✗ noisy, spike-dominated | ✗ peaky, over-confident drafts, low acceptance |
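A toy illustration in plain Python of how a rescaled γ shows up in the draft distribution. It rests on a simplifying assumption: the draft head's logits equal the target's logits multiplied by one uniform factor, standing in for a uniformly rescaled γ feeding a bias-free linear head. The logits are made-up values, not from a real model. `acceptance_rate` uses the standard speculative-sampling result that a draft token x ~ q is accepted with probability min(1, p(x)/q(x)), which gives an expected acceptance of Σₓ min(p(x), q(x)).

```python
import math


def softmax(logits: list[float]) -> list[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(p: list[float]) -> float:
    """Shannon entropy in nats; lower means a peakier distribution."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def acceptance_rate(p: list[float], q: list[float]) -> float:
    """Expected acceptance of one drafted token for target p and draft q."""
    return sum(min(pi, qi) for pi, qi in zip(p, q))


target_logits = [4.0, 2.5, 1.0, 0.5, -1.0]  # toy values, not from a real model
p = softmax(target_logits)
for scale in (0.25, 1.0, 4.0):  # uniform gamma shrunk / at init / grown
    q = softmax([scale * z for z in target_logits])
    print(f"scale={scale}: entropy={entropy(q):.3f} nats, "
          f"acceptance={acceptance_rate(p, q):.3f}")
```

Entropy falls as the scale grows, and acceptance peaks at scale 1.0, where draft and target coincide. Shrinking and growing the scale both cost acceptance, in opposite directions of confidence.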
What to Log, When to Act
Here is a concrete step-by-step playbook for monitoring your MTP run. The critical principle: log these metrics from step 0, not after you notice a problem.
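```python
import math
norm0 = model.final_layernorm.weight.float().norm().item()  # step 0, before any optimizer update
assert abs(norm0 - math.sqrt(d_model)) < 0.01  # all-ones init gives exactly sqrt(d_model)
if not (LOW <= norm <= HIGH): alert("layernorm norm out of band")  # every logging step
lambda_mtp = 0.3 if tokens_seen < 10e12 else 0.1  # DeepSeek-V3: 0.3 for the first 10T tokens
h_combined = M_k @ concat([RMSNorm(h_prev), RMSNorm(Emb(tokens[i + k]))])  # MTP input; h_prev = h_i^{k-1}
```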
Common model configurations at a glance
| Model | d_model | Init norm (√d) | Healthy range (0.4×√d to 1.8×√d) | MTP λ |
|---|---|---|---|---|
| LLaMA-3.2-3B | 3072 | 55.4 | 22 – 100 | 0.1–0.3 |
| LLaMA-3.1-8B | 4096 | 64.0 | 26 – 115 | 0.1–0.3 |
| LLaMA-3.1-70B | 8192 | 90.5 | 36 – 163 | 0.05–0.2 |
| DeepSeek-V3 | 7168 | 84.7 | 34 – 152 | 0.3→0.1 |
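The init-norm and healthy-range columns follow directly from d_model, so they can be regenerated for any other configuration. A plain-Python sketch; the helper name `norm_band` is hypothetical, and the 0.4 and 1.8 multipliers are the band used throughout this section.

```python
import math


def norm_band(d_model: int, low: float = 0.4, high: float = 1.8) -> tuple[float, float, float]:
    """Return (sqrt(d_model), low * sqrt(d_model), high * sqrt(d_model))."""
    root = math.sqrt(d_model)
    return root, low * root, high * root


# d_model values as listed in the table
for name, d in [("LLaMA-3.2-3B", 3072), ("LLaMA-3.1-8B", 4096),
                ("LLaMA-3.1-70B", 8192), ("DeepSeek-V3", 7168)]:
    root, lo, hi = norm_band(d)
    print(f"{name:14s} d_model={d:5d}  init={root:5.1f}  range={lo:.0f}-{hi:.0f}")
```

The printed values reproduce the table's init-norm and healthy-range columns after rounding.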
Bottom line: log final_layernorm.weight.norm() from step 0. Expect it to start at √d_model and gently shrink as training converges. If it collapses below 0.4×√d or explodes above 1.8×√d, your MTP gradient balance is broken; adjust λ before the damage propagates into the main model.