When Every Token Is Noisy
The standard story of attention-as-memory goes like this: the context tokens act as clean "memories," and attention retrieves the one that best matches a query. Smart et al. (2025) formalized this — a single attention layer provably achieves Bayes-optimal denoising when the context tokens are clean.
But what happens when every token in the context is corrupted by noise? Now the "memories" are blurry. The attention layer can't retrieve the right answer because it doesn't know the true prior distribution — and a single step can't fix that.
This paper studies the all-token corruption setting: N tokens drawn i.i.d. from an unknown prior ρ₀, all independently corrupted by isotropic Gaussian noise N(0, σ²I). The goal is to recover the clean tokens from the noisy batch alone — without ever seeing a clean example.
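To make the setting concrete, here is a minimal sketch of the data-generating process. It assumes a 1-D prior made of two equally weighted point masses at x = ±2 (the function name, the point-mass prior, and the seed are illustrative choices, not taken from the paper).

```python
import numpy as np

def sample_noisy_tokens(n_tokens, sigma2, centers=(-2.0, 2.0), seed=0):
    """Draw clean tokens from an equal-weight mixture of point masses at
    `centers` (an assumed toy prior), then add i.i.d. Gaussian noise of
    variance `sigma2` to every token."""
    rng = np.random.default_rng(seed)
    clean = rng.choice(np.asarray(centers), size=n_tokens)   # x_i ~ rho_0
    noisy = clean + np.sqrt(sigma2) * rng.standard_normal(n_tokens)
    return clean, noisy

clean, noisy = sample_noisy_tokens(n_tokens=24, sigma2=0.5)
print(noisy.round(2))  # the denoiser only ever sees `noisy`, never `clean`
```

The clean array is returned only so that a reconstruction can be scored afterwards.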
The left panel shows 24 noisy tokens (gray dots) drawn from a two-mode prior with clusters at x = −2 and x = +2, corrupted by σ² = 0.5. Without prior refinement, the Stage 2 posterior mean for the amber query token (near x = +1.4) smears across both clusters and lands near the center — the wrong answer.
The right panel shows the same tokens after 40 self-attention steps. The particles have migrated toward the true cluster locations. Now the posterior mean correctly identifies the right cluster and denoises the query to x̂* ≈ +2. Depth earns its keep.
Why depth is necessary
When Stage 1 is skipped (L = 0), the posterior-mean computation has to use the noisy empirical distribution as its prior. For a symmetric mixture, the noise blurs the two modes into each other, so a query between them receives posterior weight from both and the estimate is pulled toward the center. Only after the particles have been iteratively refined toward the true prior can Stage 2 reliably distinguish between modes.
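The effect of skipping Stage 1 can be checked numerically. The sketch below compares the Stage 2 readout for a query at x = 1.4 when the particles are the raw noisy tokens (depth 0) against an oracle that uses the clean tokens as particles. The point-mass prior at ±2, the sample size, and the function name are assumptions made for illustration.

```python
import numpy as np

def posterior_mean(query, particles, sigma2):
    """Softmax-weighted average of `particles` with Gaussian logits
    -(query - z_j)^2 / (2 * sigma2). This is the posterior mean of the
    clean token if `particles` were an exact sample from the prior."""
    logits = -0.5 * (query - particles) ** 2 / sigma2
    w = np.exp(logits - logits.max())
    w /= w.sum()
    return float(w @ particles)

rng = np.random.default_rng(0)
sigma2, n = 0.5, 4000
clean = rng.choice([-2.0, 2.0], size=n)               # assumed toy prior
noisy = clean + np.sqrt(sigma2) * rng.standard_normal(n)

q = 1.4
est_noisy_prior = posterior_mean(q, noisy, sigma2)    # depth 0: blurred prior
est_clean_prior = posterior_mean(q, clean, sigma2)    # oracle: clean particles
print(est_noisy_prior, est_clean_prior)
```

In this toy setup the depth-0 estimate is shrunk toward the middle relative to the oracle, which sits close to the cluster center at +2.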
Attention as Reverse Diffusion
The key insight is that a single self-attention step has exactly the form of one step of reverse diffusion. The Nadaraya–Watson update — a Gaussian-kernel weighted average of neighbor positions — is precisely what you'd compute to estimate the score function ∇ log ρ and take one step up the density gradient.
a_ij = softmax_j( −(β/2) · ‖z_i^(ℓ) − z_j^(ℓ)‖² )
// β = kernel bandwidth, η = step size, L ≈ βσ²/(2η) layers total
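A minimal NumPy sketch of one such layer follows. The attention weights implement the formula above; the residual form of the update, z ← z + η(Az − z), is an assumption made here for illustration, since only the weights and the step size η are spelled out in the text. Function names are hypothetical.

```python
import numpy as np

def attention_weights(z, beta):
    """Row-stochastic Gaussian attention
    a_ij = softmax_j(-beta/2 * |z_i - z_j|^2) for tokens z of shape (N, d)."""
    sq_dist = np.sum((z[:, None, :] - z[None, :, :]) ** 2, axis=-1)
    logits = -0.5 * beta * sq_dist
    logits -= logits.max(axis=1, keepdims=True)   # numerical stability
    a = np.exp(logits)
    return a / a.sum(axis=1, keepdims=True)

def stage1_step(z, beta, eta):
    """ASSUMED update form: move each token a fraction `eta` of the way
    toward its attention-weighted (Nadaraya-Watson) average."""
    a = attention_weights(z, beta)
    return z + eta * (a @ z - z)

rng = np.random.default_rng(0)
z = rng.choice([-2.0, 2.0], size=(200, 1)) + np.sqrt(0.5) * rng.standard_normal((200, 1))
z_next = stage1_step(z, beta=5.0, eta=0.1)
print(z.var(), z_next.var())
```

Because each new token is a convex combination of existing tokens, a step can never push a particle outside the current range of the batch.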
In the limit of N → ∞ particles and continuous depth t = ℓη, this converges to the mean-field flow of reverse diffusion: each particle drifts toward higher-density regions, and the collective empirical distribution anti-diffuses from the noisy distribution ρ_noisy = ρ₀ * N(0, σ²I) (the prior convolved with the noise) toward the clean prior ρ₀.
The bandwidth β controls how "local" the averaging is. It enters the softmax as an inverse kernel width, so larger β means a narrower kernel. Too small and particles average globally, washing out structure. Too large and particles cluster prematurely at finite N, collapsing before reaching the true prior. There's a sweet spot.
Watch the gray dots — initially scattered near x = ±2 with σ = 0.7 noise — flow toward the two true cluster centers (faint circles). With β = 5 the convergence is smooth and steady. With β = 1 the flow is sluggish, never fully denoising within the budget. With β = 20 particles collapse inward too fast, falling into a premature cluster at finite N.
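One way to quantify how local the averaging is: count the effective number of neighbors each token attends to. The sketch below uses the inverse participation ratio 1 / Σ_j a_ij² as that count, which is a diagnostic chosen here for illustration and not something the paper is known to report. The toy data and function name are likewise assumptions.

```python
import numpy as np

def effective_neighbors(z, beta):
    """Mean inverse participation ratio 1 / sum_j a_ij^2 over attention rows:
    about N when rows are uniform, about 1 when rows are one-hot.
    `z` is a 1-D array of scalar tokens."""
    sq_dist = (z[:, None] - z[None, :]) ** 2
    logits = -0.5 * beta * sq_dist
    logits -= logits.max(axis=1, keepdims=True)
    a = np.exp(logits)
    a /= a.sum(axis=1, keepdims=True)
    return float(np.mean(1.0 / np.sum(a ** 2, axis=1)))

rng = np.random.default_rng(0)
z = rng.choice([-2.0, 2.0], size=200) + 0.7 * rng.standard_normal(200)
for beta in (1.0, 5.0, 20.0):
    print(beta, round(effective_neighbors(z, beta), 1))
```

The count shrinks as β grows, which is the mechanism behind both failure modes: very small β averages over nearly the whole batch, very large β lets each token listen only to its closest few neighbors.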
Three equivalent views of Stage 1
The update rule in Eq. 1 admits three equivalent interpretations that each illuminate a different aspect of what's happening:
- Reverse diffusion: each step approximates one deterministic (noise-free) step up the log-density gradient of the empirical kernel density
- Nadaraya–Watson kernel regression: each token is replaced by the kernel-weighted average of its neighbors — a nonparametric density mode seeker
- Gaussian attention: the self-attention mechanism with squared-distance (Gaussian-kernel) logits, which equal scaled dot-product logits up to a per-key norm term, and identity value weights
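The equivalence of these views can be written out in a few lines. The derivation below assumes the "empirical kernel density" is a Gaussian kernel density estimate with variance 1/β; the symbol \hat\rho_\beta is introduced here for that estimate and is not the paper's notation.

```latex
% Assumption: \hat\rho_\beta is a Gaussian KDE with kernel variance 1/\beta.
\hat\rho_\beta(z) = \frac{1}{N}\sum_{j=1}^{N}
  \Big(\frac{\beta}{2\pi}\Big)^{d/2}
  \exp\!\Big(-\frac{\beta}{2}\,\lVert z - z_j\rVert^2\Big)

% Score of the estimate, evaluated at token z_i:
\nabla \log \hat\rho_\beta(z_i)
  = \sum_{j} a_{ij}\,\beta\,(z_j - z_i)
  = \beta\Big(\sum_{j} a_{ij} z_j - z_i\Big),
\qquad
a_{ij} = \operatorname{softmax}_j\!\Big(-\frac{\beta}{2}\lVert z_i - z_j\rVert^2\Big)

% Hence the Nadaraya-Watson average is a score step of size 1/\beta:
\sum_{j} a_{ij} z_j = z_i + \frac{1}{\beta}\,\nabla \log \hat\rho_\beta(z_i)

% Relation to dot-product logits: the \lVert z_i\rVert^2 term is constant in j
% and cancels inside the softmax, leaving a per-key bias.
-\frac{\beta}{2}\lVert z_i - z_j\rVert^2
  = \beta\, z_i^{\top} z_j - \frac{\beta}{2}\lVert z_j\rVert^2
    - \frac{\beta}{2}\lVert z_i\rVert^2
```

The third line is the kernel-regression view and the score-ascent view stated as one identity; the last line shows what has to be added to plain dot-product attention to recover the Gaussian kernel.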
Depth Refines; the Skip Connection Queries
The full architecture (Algorithm 1) decomposes into two stages with architecturally distinct roles. Stage 1 is the self-attention loop across depth — it iteratively refines the particle prior. Stage 2 is a final cross-attention step that uses the noisy input as a query against the refined particle distribution to compute the posterior mean.
The key architectural element enabling Stage 2 is a long-range skip connection — an attention residual that carries the original noisy tokens as queries, bypassing Stage 1 entirely. Without this skip, Stage 1 would overwrite the token identity information needed for posterior inference.
The roles of depth and residuals, derived
This two-stage decomposition gives architecture elements statistically derivable roles rather than merely empirical ones:
- Depth (L): controls how well Stage 1 has refined the particle prior — more layers means a sharper, more accurate approximation to ρ₀
- Bandwidth β: must be matched to noise scale σ² via L ≈ βσ²/(2η) for correct denoising time
- Stage 2 bandwidth β_c = 1/σ²: the Bayesian posterior bandwidth, derived from the Gaussian noise model rather than learned
- Long-range skip (AttnRes): carries the noisy original token x̃ᵢ to the Stage 2 query position, preserving identity for posterior averaging
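Putting the pieces above together, a two-stage pass might look like the following sketch for 1-D tokens. It is not the paper's Algorithm 1: the residual update in Stage 1, the default depth taken from L ≈ βσ²/(2η), the toy prior, and all function names are assumptions made for illustration.

```python
import numpy as np

def softmax_rows(logits):
    logits = logits - logits.max(axis=1, keepdims=True)
    w = np.exp(logits)
    return w / w.sum(axis=1, keepdims=True)

def two_stage_denoise(x_noisy, sigma2, beta, eta, n_layers=None):
    """Stage 1: n_layers self-attention steps refine particles z (assumed
    residual update). Stage 2: the ORIGINAL noisy tokens, carried by the
    skip connection, query the refined particles with beta_c = 1/sigma2."""
    if n_layers is None:
        n_layers = int(round(beta * sigma2 / (2 * eta)))  # L ~ beta*sigma^2/(2*eta)
    z = x_noisy.copy()
    for _ in range(n_layers):                              # Stage 1 loop
        a = softmax_rows(-0.5 * beta * (z[:, None] - z[None, :]) ** 2)
        z = z + eta * (a @ z - z)
    beta_c = 1.0 / sigma2                                  # Stage 2 bandwidth
    cross = softmax_rows(-0.5 * beta_c * (x_noisy[:, None] - z[None, :]) ** 2)
    return cross @ z, z

rng = np.random.default_rng(0)
sigma2 = 0.5
clean = rng.choice([-2.0, 2.0], size=400)                  # assumed toy prior
noisy = clean + np.sqrt(sigma2) * rng.standard_normal(400)
x_hat, particles = two_stage_denoise(noisy, sigma2, beta=5.0, eta=0.1)
print("noisy MSE:    ", np.mean((noisy - clean) ** 2))
print("two-stage MSE:", np.mean((x_hat - clean) ** 2))
```

Note that the queries in Stage 2 are `x_noisy`, not `z`: this is the role the long-range skip plays, since after Stage 1 the particles no longer carry the per-token observation.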
Stage 1 Sculpts the Associative Memory
Stage 2 can be rewritten as one unit-step gradient-descent update, q ← q − ∇_q E(Z; q), on an energy landscape induced by the current particle set Z(ℓ):
E(Z; q) = −(1/β_c) · log ∑_j exp( −(β_c/2) · ‖q − z_j‖² )
// Dense associative memory (Ramsauer 2021) — but memories emerge in-context
This is a dense associative memory (Hopfield network), except the memories aren't stored parameters — they're the refined particles Z(ℓ) from Stage 1. As depth increases, Stage 1 sharpens these particles from a noisy cloud into distinct clusters, dynamically sculpting the energy wells.
At depth 0, the particles are scattered randomly — the energy wells are shallow and wide, unable to reliably attract queries to the correct basin. At depth 40, the particles have clustered, forming two sharp wells at x ≈ ±2. The query x̃* (amber star) falls into the deeper well and is denoised to the correct cluster. Depth sculpts the inference landscape without any learned parameters.
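The link between the energy and the attention readout can be verified numerically: a unit gradient step on E from the query lands on the softmax-weighted average of the particles. The sketch below checks this with a central finite difference on 1-D toy particles (the particle set, step h, and function names are illustrative assumptions).

```python
import numpy as np

def energy(q, z, beta_c):
    """E(Z; q) = -(1/beta_c) * log sum_j exp(-beta_c/2 * (q - z_j)^2), 1-D."""
    logits = -0.5 * beta_c * (q - z) ** 2
    m = logits.max()                                  # stable log-sum-exp
    return -(m + np.log(np.exp(logits - m).sum())) / beta_c

def attention_readout(q, z, beta_c):
    """Softmax-weighted average of the particles as seen from query q."""
    logits = -0.5 * beta_c * (q - z) ** 2
    w = np.exp(logits - logits.max())
    return float((w / w.sum()) @ z)

rng = np.random.default_rng(0)
z = rng.choice([-2.0, 2.0], size=50) + 0.2 * rng.standard_normal(50)
q, beta_c, h = 1.4, 2.0, 1e-5

grad = (energy(q + h, z, beta_c) - energy(q - h, z, beta_c)) / (2 * h)
# q - grad and the readout agree up to finite-difference error
print(q - grad, attention_readout(q, z, beta_c))
```

Analytically, ∇_q E = Σ_j w_j (q − z_j) = q − Σ_j w_j z_j, so the agreement is exact apart from the discretization in h.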
Connection to in-context learning
From this perspective, multilayer attention implements something remarkable: a transformer that has never seen the test distribution can construct an effective associative memory during the forward pass by using the context tokens as both the "training data" for the energy landscape and the queries to be denoised.
A Fixed Bandwidth Replaces the Noise Schedule
Standard diffusion models require a carefully engineered noise schedule — a time-dependent σ(t) that controls how much noise is added/removed at each step. This paper shows that for attention-based denoising, a noise schedule is unnecessary: a single fixed bandwidth β and the right integration time suffice.
In the large-N, continuous-depth limit, the Stage 1 variance evolution follows:
dv/dt = −2 → v(t) = σ² − 2t → T* = σ²/2
// Finite β correction (Theorem F.1 from Appendix F):
dv/dt = −2v / (v + β⁻¹) // slower, curved for small β
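The optimal integration horizon is T* = σ²/2: run Stage 1 for exactly this long and the empirical distribution will have converged to the clean prior. In terms of layers, L* ≈ βσ²/(2η). The extra factor of β is consistent with the ODE time above being the layer time t = ℓη divided by β, so that each layer advances it by about η/β; this also matches the horizon Tβ = βτ/2 quoted in the convergence theorem. No schedule needed, just the right depth.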
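The two variance laws are easy to compare numerically. The sketch below integrates the finite-β ODE with forward Euler and evaluates it halfway to T*, where the linear law predicts v = σ²/2. The choice of integrator, the evaluation point, and the helper `layers_needed` (including its rounding) are assumptions made for illustration.

```python
import numpy as np

def variance_path(v0, beta, t_end, n_steps=20000):
    """Integrate dv/dt = -2 v / (v + 1/beta) with forward Euler."""
    dt = t_end / n_steps
    v = np.empty(n_steps + 1)
    v[0] = v0
    for k in range(n_steps):
        v[k + 1] = v[k] - dt * 2.0 * v[k] / (v[k] + 1.0 / beta)
    return v

def layers_needed(beta, sigma2, eta):
    """Depth rule quoted in the text: L* ~ beta * sigma^2 / (2 * eta)."""
    return int(round(beta * sigma2 / (2.0 * eta)))

sigma2 = 0.25
t_star = sigma2 / 2                                   # T* = sigma^2 / 2
for beta in (1.0, 5.0, 20.0, 100.0):
    v_mid = variance_path(sigma2, beta, t_star / 2)[-1]
    print(f"beta={beta:6.1f}  v(T*/2)={v_mid:.4f}  linear law={sigma2 / 2:.4f}")

print(layers_needed(beta=5.0, sigma2=0.5, eta=0.05))  # eta is a hypothetical step size
```

Since the decay rate 2v/(v + 1/β) is always below 2, the finite-β curve sits above the linear law and approaches it as β grows.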
The chart plots variance of the particle distribution v(t) as a function of integration time t. β = 5 tracks the theoretical linear decay (dashed gray) nearly exactly, reaching v ≈ 0 at T* = 0.125 (σ² = 0.25). β = 1 is governed by the corrected ODE and converges much more slowly. β = 20 at finite N shows premature collapse: fast initial decay, but particles form spurious clusters before reaching T*.
Depth as a noise-informed hyperparameter
The depth–noise relationship L* ≈ βσ²/(2η) is a principled guide for architecture design. Unlike heuristic depth selection, this formula is derived from the denoising dynamics. It implies that noisier data requires more layers — not because of gradient flow considerations, but because there's more diffusion to reverse.
Convergence to Bayes-Optimal
The main theoretical result (Theorem 5.1) proves that the two-stage estimator converges to the Bayes-optimal denoiser — the posterior mean under the true prior — through a sequence of limits: first N → ∞ at fixed (β, R), then the truncation radius R → ∞, then β → ∞.
lim_{β→∞} lim_{R→∞} lim_{N→∞} E[ sup_{|y|≤M} | m_{μ(Tβ)}(y) − m_{P₀}(y) | ] = 0
// m_P(y) = E[X | Y=y] = Bayes-optimal posterior mean under prior P
// μ(Tβ) = empirical distribution of particles at integration time Tβ = βτ/2
Concretely for Gaussian priors P₀ = N(a, Σ₀), the mean-field flow stays Gaussian and the covariance evolves analytically as Γ̇ = −2Γ(I + βΓ)⁻¹, converging to Σ₀ at T = Tβ. The truncation error vanishes with high probability once R > √(log N) · const.
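For intuition, the scalar case of this covariance flow can be solved in closed form. The derivation below assumes one dimension, a prior variance s_0 > 0, and an initial particle variance s_0 + τ, reading τ as the noise variance; these readings are assumptions, since the text does not define τ explicitly.

```latex
% Scalar version of \dot\Gamma = -2\Gamma(I + \beta\Gamma)^{-1},
% with ASSUMED initial condition \gamma(0) = s_0 + \tau:
\dot\gamma = -\frac{2\gamma}{1 + \beta\gamma}
\;\;\Longrightarrow\;\;
\Big(\frac{1}{\gamma} + \beta\Big)\,d\gamma = -2\,dt
\;\;\Longrightarrow\;\;
\ln\gamma(t) + \beta\gamma(t) = \ln\gamma(0) + \beta\gamma(0) - 2t

% Evaluate at the horizon T_\beta = \beta\tau/2:
\ln\gamma(T_\beta) + \beta\gamma(T_\beta) = \ln(s_0 + \tau) + \beta s_0
\;\;\Longrightarrow\;\;
\gamma(T_\beta) = s_0 + \frac{1}{\beta}\,\ln\frac{s_0 + \tau}{\gamma(T_\beta)}

% \gamma is decreasing and \gamma \mapsto \ln\gamma + \beta\gamma is increasing,
% so s_0 < \gamma(T_\beta) < s_0 + \tau, which gives the bound
0 < \gamma(T_\beta) - s_0 < \frac{1}{\beta}\,\ln\Big(1 + \frac{\tau}{s_0}\Big)
```

Under these assumptions the variance at the horizon overshoots the prior variance by at most a term of order 1/β, which is one way to see why the outermost limit in the theorem is β → ∞.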
Figure: normalized MSE / σ² on a 2-mode GMM (µ = ±1, σ² = 0.5); lower is better. Values are from Fig. 4(a) of the paper, with β = 20, σ² = 0.5, L = 200.
More context length (N) consistently closes the gap to Bayes MMSE. At N = 8000 the two-stage estimator reaches within 70% of Bayes-optimal without any learned parameters and without knowing the prior ρ₀ — purely from the context. Stage 1 alone (predicting the noisy input directly, without Stage 2) performs poorly at 1.76, near random.
What's not yet proven
The paper is honest about its limitations. The recovery theorem uses sequential limits — not joint scaling — and doesn't give finite-sample convergence rates. The admissible class Aτ is verified for Gaussians but not for multimodal or heavy-tailed priors. Practical transformers also have multiple heads, MLP layers, learned weights, and positional encoding — none of which are captured here. What the paper establishes is a principled statistical account of what the minimal single-head attention mechanism is computing, as a foundation for richer settings.