
Attention as In-Context Empirical Bayes

Transformer self-attention isn't just a computational primitive — it's performing principled statistical inference. A minimal attention-only transformer implements a two-stage empirical Bayes denoiser: depth refines a particle approximation to the unknown data prior via reverse diffusion, while a long-range skip connection queries it for Bayes-optimal posterior averaging.

June 1, 2026 · Paper: arXiv:2605.29351
01 — The Problem

When Every Token Is Noisy

The standard story of attention-as-memory goes like this: the context tokens act as clean "memories," and attention retrieves the one that best matches a query. Smart et al. (2025) formalized this — a single attention layer provably achieves Bayes-optimal denoising when the context tokens are clean.

But what happens when every token in the context is corrupted by noise? Now the "memories" are blurry. The attention layer can't retrieve the right answer because it doesn't know the true prior distribution — and a single step can't fix that.

This paper studies the all-token corruption setting: N tokens drawn i.i.d. from an unknown prior ρ₀, all independently corrupted by isotropic Gaussian noise N(0, σ²I). The goal is to recover the clean tokens from the noisy batch alone — without ever seeing a clean example.
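A minimal sketch of this data model, assuming 1-D tokens and a prior made of two equal-weight point masses at ±2 (a hypothetical stand-in for ρ₀; the real prior is unknown to the denoiser and need not be discrete). The helper name `sample_noisy_tokens` is ours.

```python
import random


def sample_noisy_tokens(n, means=(-2.0, 2.0), sigma2=0.5, seed=0):
    """Hypothetical helper: draw n clean 1-D tokens from an equal-weight
    mixture of point masses at `means` (a stand-in for rho_0), then corrupt
    every token with independent N(0, sigma2) noise."""
    rng = random.Random(seed)
    sigma = sigma2 ** 0.5
    clean = [rng.choice(means) for _ in range(n)]
    noisy = [x + rng.gauss(0.0, sigma) for x in clean]
    return clean, noisy


# Only `noisy` is handed to the denoiser; `clean` is kept aside for scoring.
clean, noisy = sample_noisy_tokens(24)
```

Every token is corrupted, so no clean example is ever available to the model.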

Figure: Before vs. after Stage 1, posterior mean for query x̃* (a single step fails; depth fixes it). Left: without Stage 1 (L = 0), the posterior mean lands at the wrong point, near the center between the clusters. Right: after Stage 1, the posterior mean is sharp and correct.

The left panel shows 24 noisy tokens (gray dots) drawn from a two-mode prior with clusters at x = −2 and x = +2, corrupted by σ² = 0.5. Without prior refinement, the Stage 2 posterior mean for the amber query token (near x = +1.4) smears across both clusters and lands near the center — the wrong answer.

The right panel shows the same tokens after 40 self-attention steps. The particles have migrated toward the true cluster locations. Now the posterior mean correctly identifies the right cluster and denoises the query to x̂* ≈ +2. Depth earns its keep.

Why depth is necessary

When Stage 1 is skipped (L = 0), the posterior-mean computation has to use the noisy empirical distribution as its prior, which is blurred by the same noise it is trying to remove. Queries are therefore shrunk too little toward their cluster, and for a symmetric mixture a query lying between the modes receives comparable weight from both clusters and is pulled toward the center. Only after the particles have been iteratively refined toward the true prior can Stage 2 reliably distinguish between modes.

02 — Stage 1: Particle Dynamics

Attention as Reverse Diffusion

The key insight is that a single self-attention step has exactly the form of one step of reverse diffusion. The Nadaraya–Watson update — a Gaussian-kernel weighted average of neighbor positions — is precisely what you'd compute to estimate the score function ∇ log ρ and take one step up the density gradient.
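The score identity can be checked numerically. For the unnormalized Gaussian kernel density ρ̂(z) ∝ Σⱼ exp(−β/2 · (z − zⱼ)²), differentiating gives ∇ log ρ̂(z) = β(Σⱼ aⱼ zⱼ − z), so the attention step z + η(Σⱼ aⱼ zⱼ − z) equals z + (η/β)∇ log ρ̂(z). The 1-D sketch below uses hand-picked particles and our own helper names, and compares the two steps using a finite-difference score.

```python
import math


def attention_mean(z, particles, beta):
    """Nadaraya-Watson / attention average with weights softmax_j(-beta/2 (z - z_j)^2)."""
    logits = [-0.5 * beta * (z - p) ** 2 for p in particles]
    top = max(logits)  # subtract the max for numerical stability
    weights = [math.exp(l - top) for l in logits]
    return sum(w * p for w, p in zip(weights, particles)) / sum(weights)


def kde_log_density(z, particles, beta):
    """log sum_j exp(-beta/2 (z - z_j)^2): the Gaussian KDE up to an additive constant."""
    logits = [-0.5 * beta * (z - p) ** 2 for p in particles]
    top = max(logits)
    return top + math.log(sum(math.exp(l - top) for l in logits))


def kde_score(z, particles, beta, h=1e-5):
    """Central finite difference of the log-density, an estimate of grad log rho_hat."""
    up = kde_log_density(z + h, particles, beta)
    down = kde_log_density(z - h, particles, beta)
    return (up - down) / (2 * h)


particles = [-2.3, -1.8, -1.1, 0.4, 1.6, 2.2, 2.9]
beta, eta, z = 5.0, 0.1, 1.3

attention_step = z + eta * (attention_mean(z, particles, beta) - z)
score_step = z + (eta / beta) * kde_score(z, particles, beta)
print(attention_step, score_step)  # agree up to finite-difference error
```

The factor 1/β is why larger β needs proportionally more layers to cover the same diffusion time.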

Stage 1 attention update (Eq. 1)
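$$z_i^{(\ell+1)} = (1-\eta)\, z_i^{(\ell)} + \eta \sum_j a_{ij}\, z_j^{(\ell)}$$
$$a_{ij} = \operatorname{softmax}_j\!\left(-\frac{\beta}{2}\,\bigl\|z_i^{(\ell)} - z_j^{(\ell)}\bigr\|^2\right)$$
// β = kernel bandwidth (inverse width: the kernel variance is 1/β), η = step size, L ≈ βσ²/(2η) layers total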
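Eq. 1 translates almost line for line into code. This is a minimal 1-D sketch, assuming scalar tokens and hand-picked inputs (the helper names are ours, not the paper's). It runs 40 layers as in the opening figure. Because every update is a convex combination of the previous layer's particles, the particles never leave the span of the inputs while each cluster contracts.

```python
import math


def attention_weights(zi, zs, beta):
    """Row i of the Stage 1 attention matrix: softmax_j(-beta/2 (z_i - z_j)^2)."""
    logits = [-0.5 * beta * (zi - zj) ** 2 for zj in zs]
    top = max(logits)
    w = [math.exp(l - top) for l in logits]
    total = sum(w)
    return [x / total for x in w]


def stage1(tokens, beta, eta, n_layers):
    """Apply Eq. 1 n_layers times; every token updates from the previous layer's particles."""
    z = list(tokens)
    for _ in range(n_layers):
        # The comprehension reads the old list `z` throughout, so the update is synchronous.
        z = [
            (1 - eta) * zi + eta * sum(a * zj for a, zj in zip(attention_weights(zi, z, beta), z))
            for zi in z
        ]
    return z


# Hand-picked noisy tokens around two clusters at -2 and +2 (illustrative only).
tokens = [-2.9, -2.4, -2.0, -1.6, -1.1, 1.1, 1.6, 2.0, 2.4, 2.9]
refined = stage1(tokens, beta=5.0, eta=0.1, n_layers=40)
print([round(v, 2) for v in refined])
```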

In the limit of N → ∞ particles and continuous depth t = ℓη, this converges to the mean-field flow of reverse diffusion: each particle drifts toward higher-density regions, and the collective empirical distribution is anti-diffused from the noisy marginal ρ₀ ∗ N(0, σ²I) back toward the clean prior ρ₀.

The bandwidth β controls how "local" the averaging is. Too small and particles average globally, washing out structure. Too large and particles cluster prematurely at finite N, collapsing before reaching the true prior. There's a sweet spot.
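The two extremes are easy to see in a single attention row. In this sketch (hand-picked tokens; `attention_row` is our name), a tiny β gives every token nearly equal weight, so each update pulls toward the global mean and washes out the clusters. A huge β puts almost all weight on the token itself and its nearest neighbors, the regime in which finite-N particles can lock into small local clumps.

```python
import math


def attention_row(i, zs, beta):
    """Attention weights of token i over all tokens (itself included)."""
    logits = [-0.5 * beta * (zs[i] - zj) ** 2 for zj in zs]
    top = max(logits)
    w = [math.exp(l - top) for l in logits]
    total = sum(w)
    return [x / total for x in w]


tokens = [-2.1, -1.7, 1.8, 2.3]
wide = attention_row(2, tokens, beta=1e-4)    # tiny beta: near-uniform, global averaging
narrow = attention_row(2, tokens, beta=1e3)   # huge beta: token attends almost only to itself
print([round(w, 3) for w in wide])
print([round(w, 3) for w in narrow])
```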

Figure: Stage 1 particle dynamics, 24 noisy tokens converging to the clean prior over 60 steps (β adjustable).

Watch the gray dots, initially scattered near x = ±2 with σ = 0.7 noise, flow toward the two true cluster centers (faint circles). With β = 5, the convergence is smooth and steady. With β = 1, the flow is sluggish and never fully denoises within the budget. With β = 20, particles collapse inward too fast and fall into premature clusters at finite N.

Three equivalent views of Stage 1

The update rule in Eq. 1 admits three equivalent interpretations that each illuminate a different aspect of what's happening:

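  • Reverse diffusion: each step is one explicit Euler step of a deterministic flow up the log-density gradient of the empirical kernel density (a noise-free analogue of a Langevin step)
  • Nadaraya–Watson kernel regression: each token moves a fraction η toward the kernel-weighted average of its neighbors, a damped form of the blurring mean-shift algorithm, which seeks density modes nonparametrically
  • Gaussian attention: self-attention with identity query, key, and value maps and a negative squared-distance score; because −β/2 · ‖zi − zj‖² = β⟨zi, zj⟩ − β/2 · ‖zj‖² − β/2 · ‖zi‖², and the last term cancels in the softmax over j, this is scaled dot-product attention plus a per-key bias −β/2 · ‖zj‖²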
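The third view rests on expanding the squared distance. The sketch below (arbitrary 2-D vectors; helper names are ours) checks that Gaussian attention weights match scaled dot-product attention once a per-key bias −β/2 · ‖kⱼ‖² is added, and that without the bias they differ when key norms differ.

```python
import math


def softmax(xs):
    top = max(xs)
    e = [math.exp(x - top) for x in xs]
    total = sum(e)
    return [v / total for v in e]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def sq_dist(u, v):
    return sum((a - b) ** 2 for a, b in zip(u, v))


def gaussian_attention(q, keys, beta):
    """Weights softmax_j(-beta/2 ||q - k_j||^2), as in Eq. 1."""
    return softmax([-0.5 * beta * sq_dist(q, k) for k in keys])


def biased_dot_product_attention(q, keys, beta):
    """Scaled dot-product scores beta <q, k_j> plus a per-key bias -beta/2 ||k_j||^2."""
    return softmax([beta * dot(q, k) - 0.5 * beta * dot(k, k) for k in keys])


def plain_dot_product_attention(q, keys, beta):
    """Scaled dot-product scores with no key-norm bias."""
    return softmax([beta * dot(q, k) for k in keys])


q = [0.3, -1.2]
keys = [[1.0, 0.5], [-0.7, 2.0], [0.3, -1.1]]
beta = 2.5
g = gaussian_attention(q, keys, beta)
b = biased_dot_product_attention(q, keys, beta)
p = plain_dot_product_attention(q, keys, beta)
```

When all keys share the same norm, the bias is constant across j and cancels, so Gaussian attention coincides with plain scaled dot-product attention.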
03 — Two-Stage Architecture

Depth Refines; the Skip Connection Queries

The full architecture (Algorithm 1) decomposes into two stages with architecturally distinct roles. Stage 1 is the self-attention loop across depth — it iteratively refines the particle prior. Stage 2 is a final cross-attention step that uses the noisy input as a query against the refined particle distribution to compute the posterior mean.

The key architectural element enabling Stage 2 is a long-range skip connection — an attention residual that carries the original noisy tokens as queries, bypassing Stage 1 entirely. Without this skip, Stage 1 would overwrite the token identity information needed for posterior inference.
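A minimal 1-D sketch of the two-stage pipeline, assuming scalar tokens, a single head, and hand-picked inputs; `two_stage_denoise` and its helpers are hypothetical names, not the paper's code. The line that matters is the skip: Stage 2 queries with the original noisy tokens, not with the Stage 1 outputs.

```python
import math


def softmax_weights(q, particles, beta):
    logits = [-0.5 * beta * (q - p) ** 2 for p in particles]
    top = max(logits)
    w = [math.exp(l - top) for l in logits]
    total = sum(w)
    return [x / total for x in w]


def attend(q, particles, beta):
    """Gaussian-kernel attention readout: softmax-weighted average of the particles."""
    return sum(a * p for a, p in zip(softmax_weights(q, particles, beta), particles))


def two_stage_denoise(noisy, sigma2, beta, eta, n_layers):
    # Stage 1: self-attention across depth refines the particle prior (Eq. 1).
    z = list(noisy)
    for _ in range(n_layers):
        z = [(1 - eta) * zi + eta * attend(zi, z, beta) for zi in z]
    # Long-range skip: the ORIGINAL noisy tokens become the Stage 2 queries.
    beta_c = 1.0 / sigma2
    # Stage 2: one cross-attention step against the refined particles.
    return [attend(x, z, beta_c) for x in noisy]


tokens = [-2.9, -2.4, -2.0, -1.6, -1.1, 1.1, 1.6, 2.0, 2.4, 2.9]
with_depth = two_stage_denoise(tokens, sigma2=0.5, beta=5.0, eta=0.1, n_layers=40)
no_depth = two_stage_denoise(tokens, sigma2=0.5, beta=5.0, eta=0.1, n_layers=0)
print(round(no_depth[5], 3), round(with_depth[5], 3))  # query 1.1, true cluster at +2
```

Setting `n_layers=0` reproduces the "no Stage 1" case: the query 1.1 is pulled only part of the way toward +2, while the refined particles move it much closer.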

Figure: Two-stage architecture (depth = prior refinement; skip = posterior query). Noisy tokens x̃₁, …, x̃ₙ enter Stage 1, self-attention × L with z(ℓ+1) = (1 − η)z(ℓ) + η · Attn(z(ℓ), z(ℓ)) for ℓ = 0…L−1, producing the particle prior Z(L). A long-range skip (AttnRes) carries each noisy token x̃ᵢ to Stage 2 as the query. Stage 2 is one cross-attention step with βc = 1/σ², x̂ᵢ = Σⱼ bᵢⱼ zⱼ(L), which outputs the denoised tokens x̂₁…x̂ₙ.

The roles of depth and residuals, derived

This two-stage decomposition gives each architectural element a role that is derived statistically rather than found empirically:

  • Depth (L): controls how well Stage 1 has refined the particle prior — more layers means a sharper, more accurate approximation to ρ₀
  • Bandwidth β: must be matched to the noise scale σ² through the depth, L ≈ βσ²/(2η), so that Stage 1 integrates for the correct denoising time
  • Stage 2 bandwidth βc = 1/σ²: the Bayesian posterior bandwidth — derived from the Gaussian noise model, not learned
  • Long-range skip (AttnRes): carries the noisy original token x̃ᵢ to the Stage 2 query position, preserving identity for posterior averaging
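Why βc = 1/σ²: if the prior is uniform over the particles and y = x + N(0, σ²), Bayes' rule gives posterior weights proportional to exp(−(y − zⱼ)²/(2σ²)), which is exactly the Stage 2 softmax with βc/2 = 1/(2σ²). The sketch below (1-D, hand-picked particles, our own function names) compares the brute-force Bayes posterior mean with the attention readout, and shows that a mismatched bandwidth breaks the equality.

```python
import math


def gaussian_pdf(y, mean, var):
    return math.exp(-((y - mean) ** 2) / (2.0 * var)) / math.sqrt(2.0 * math.pi * var)


def bayes_posterior_mean(y, particles, sigma2):
    """E[X | Y = y] when X is uniform over `particles` and Y = X + N(0, sigma2)."""
    likelihoods = [gaussian_pdf(y, z, sigma2) for z in particles]
    return sum(l * z for l, z in zip(likelihoods, particles)) / sum(likelihoods)


def stage2(y, particles, beta_c):
    """One cross-attention step: softmax_j(-beta_c/2 (y - z_j)^2) average of the particles."""
    logits = [-0.5 * beta_c * (y - z) ** 2 for z in particles]
    top = max(logits)
    w = [math.exp(l - top) for l in logits]
    return sum(wi * z for wi, z in zip(w, particles)) / sum(w)


particles = [-2.05, -1.98, -1.93, 1.95, 2.01, 2.06]  # e.g. a refined Stage 1 output
sigma2, y = 0.5, 1.4
exact = bayes_posterior_mean(y, particles, sigma2)
attn = stage2(y, particles, beta_c=1.0 / sigma2)
mismatched = stage2(y, particles, beta_c=1.0)  # wrong bandwidth for this noise level
```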
04 — Dynamic Energy Landscape

Stage 1 Sculpts the Associative Memory

Stage 2 can be rewritten as gradient descent on an energy landscape induced by the current particle set Z(ℓ):

Stage 2 as energy minimization (Eq. 2)
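$$\hat{x}_i = \tilde{x}_i - \nabla_q E\bigl(Z^{(\ell)}; q\bigr)\Big|_{q=\tilde{x}_i}$$
$$E(Z; q) = -\frac{1}{\beta_c}\log \sum_j \exp\!\left(-\frac{\beta_c}{2}\,\|q - z_j\|^2\right)$$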
// Dense associative memory (Ramsauer 2021) — but memories emerge in-context

This is a dense associative memory (a modern Hopfield network), except the memories aren't stored parameters: they're the refined particles Z(ℓ) from Stage 1. As depth increases, Stage 1 sharpens these particles from a noisy cloud into distinct clusters, dynamically sculpting the energy wells.
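The energy view can be verified directly. Differentiating E gives ∇_q E = q − Σⱼ bⱼ zⱼ, with bⱼ the Stage 2 softmax weights, so a single gradient-descent step of size 1 from q lands exactly on the attention readout. A 1-D sketch with hand-picked particles and our own helper names, using a finite-difference gradient:

```python
import math


def energy(q, particles, beta_c):
    """E(Z; q) = -(1/beta_c) log sum_j exp(-beta_c/2 (q - z_j)^2), computed stably."""
    logits = [-0.5 * beta_c * (q - z) ** 2 for z in particles]
    top = max(logits)
    return -(top + math.log(sum(math.exp(l - top) for l in logits))) / beta_c


def stage2(q, particles, beta_c):
    """Cross-attention readout: softmax-weighted average of the particles."""
    logits = [-0.5 * beta_c * (q - z) ** 2 for z in particles]
    top = max(logits)
    w = [math.exp(l - top) for l in logits]
    return sum(wi * z for wi, z in zip(w, particles)) / sum(w)


def energy_grad(q, particles, beta_c, h=1e-5):
    """Central finite-difference estimate of dE/dq."""
    return (energy(q + h, particles, beta_c) - energy(q - h, particles, beta_c)) / (2.0 * h)


particles = [-2.1, -1.9, 1.8, 2.0, 2.2]
q, beta_c = 1.4, 2.0
gradient_step = q - energy_grad(q, particles, beta_c)  # one unit-step descent on E
readout = stage2(q, particles, beta_c)
print(gradient_step, readout)
```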

Figure: Energy landscape E(Z⁽ˡ⁾; q) at a selectable depth ℓ; darker = lower energy (stronger attraction). Depth sculpts the posterior-inference surface.

At depth 0, the particles are still spread out by the noise, so the energy wells are shallow and wide and cannot reliably attract queries to the correct basin. At depth 40, the particles have clustered into two sharp wells at x ≈ ±2. The query x̃* (amber star) falls into the deeper well and is denoised to the correct cluster. Depth sculpts the inference landscape without any learned parameters.

Connection to in-context learning

From this perspective, multilayer attention implements something remarkable: a transformer that has never seen the test distribution can construct an effective associative memory during the forward pass by using the context tokens as both the "training data" for the energy landscape and the queries to be denoised.

05 — Optimal Depth T* = σ²/2

A Fixed Bandwidth Replaces the Noise Schedule

Standard diffusion models require a carefully engineered noise schedule — a time-dependent σ(t) that controls how much noise is added/removed at each step. This paper shows that for attention-based denoising, a noise schedule is unnecessary: a single fixed bandwidth β and the right integration time suffice.

In the large-N, continuous-depth limit, the Stage 1 variance evolution follows:

Variance dynamics (Proposition 3.2)
// Exact reverse diffusion regime (β → ∞, N → ∞):
dv/dt = −2  →  v(t) = σ² − 2t  →  T* = σ²/2

// Finite β correction (Theorem F.1 from Appendix F):
dv/dt = −2v / (v + β⁻¹)   // slower, curved for small β

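The optimal integration horizon is T* = σ²/2, with time measured in units of ℓη/β (each layer moves a particle by η/β times the kernel score, so this is the clock on which the ODE above is written). Run Stage 1 for this long and, in the large-N, large-β limit, the empirical distribution converges to the clean prior. In terms of layers, L* ≈ βT*/η = βσ²/(2η). No schedule needed, just the right depth.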
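A small numerical sketch of the finite-β ODE for a Gaussian particle cloud (the setting in which it holds), integrated with forward Euler; the function names are ours. Separating variables gives the implicit solution v − σ² + β⁻¹ log(v/σ²) = −2t, which the test uses as a check. As β grows, v(T*) approaches 0, recovering the linear β → ∞ solution. The helper `optimal_depth` simply evaluates L* ≈ βσ²/(2η) and rounds up.

```python
import math


def variance_at(t_end, sigma2, beta, n_steps=20_000):
    """Forward-Euler solution of dv/dt = -2 v / (v + 1/beta) with v(0) = sigma2."""
    dt = t_end / n_steps
    v = sigma2
    for _ in range(n_steps):
        v -= dt * 2.0 * v / (v + 1.0 / beta)
    return v


def optimal_depth(sigma2, beta, eta):
    """Layer count L* = beta * sigma2 / (2 * eta), rounded up to a whole layer."""
    return math.ceil(beta * sigma2 / (2.0 * eta))


sigma2 = 0.25
t_star = sigma2 / 2.0  # where the beta -> infinity line v = sigma2 - 2t reaches 0
for beta in (1.0, 5.0, 20.0, 1000.0):
    print(beta, variance_at(t_star, sigma2, beta))
```

Doubling σ² doubles L* at fixed β and η, which is the "noisier data needs more layers" rule in code.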

Figure: Variance decay vs. depth for selectable β; the vertical line marks T* = σ²/2. β determines the convergence regime.

The chart plots the variance v(t) of the particle distribution against integration time t. β = 5 tracks the theoretical linear decay (dashed gray) nearly exactly, reaching v ≈ 0 at T* = 0.125 (σ² = 0.25). β = 1 follows the curved finite-β ODE and converges much more slowly. β = 20 at finite N shows premature collapse: fast initial decay, but particles form spurious clusters before reaching T*.

Depth as a noise-informed hyperparameter

The depth–noise relationship L* ≈ βσ²/(2η) is a principled guide for architecture design. Unlike heuristic depth selection, this formula is derived from the denoising dynamics. It implies that noisier data requires more layers — not because of gradient flow considerations, but because there's more diffusion to reverse.

06 — Recovery Guarantee

Convergence to Bayes-Optimal

The main theoretical result (Theorem 5.1) proves that the two-stage estimator converges to the Bayes-optimal denoiser — the posterior mean under the true prior — through a sequence of limits: first N → ∞ at fixed (β, R), then the truncation radius R → ∞, then β → ∞.

Theorem 5.1 — sequential posterior-mean recovery (informal)
// Let P₀ ∈ Aτ (admissible class; includes all Gaussians)
$$\lim_{\beta\to\infty}\,\lim_{R\to\infty}\,\lim_{N\to\infty}\;\mathbb{E}\Bigl[\,\sup_{|y|\le M}\bigl|m_{\mu(T_\beta)}(y) - m_{P_0}(y)\bigr|\Bigr] = 0$$

// m_P(y) = E[X | Y=y] = Bayes-optimal posterior mean under prior P
// μ(Tβ) = empirical distribution of particles at integration time Tβ = βτ/2

Concretely for Gaussian priors P₀ = N(a, Σ₀), the mean-field flow stays Gaussian and the covariance evolves analytically as Γ̇ = −2Γ(I + βΓ)⁻¹, converging to Σ₀ at T = Tβ. The truncation error vanishes with high probability once R > √(log N) · const.
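In one dimension the covariance ODE can be solved by separating variables: (1/Γ + β) dΓ = −2 dt. Starting from the noisy variance s₀ + σ² (prior variance s₀ plus noise σ²), the flow reaches s₀ at T = βσ²/2 + ½ log(1 + σ²/s₀). The leading term grows with β while the correction does not, which is the sense in which the hitting time approaches Tβ = βτ/2 (reading τ as the noise variance is our assumption). The sketch checks the closed form against forward-Euler integration; the function names are hypothetical.

```python
import math


def hitting_time_numeric(prior_var, noise_var, beta, dt=1e-4):
    """Integrate dG/dt = -2 G / (1 + beta G) from G(0) = prior_var + noise_var
    with forward Euler and return the first time G drops to prior_var."""
    g, t = prior_var + noise_var, 0.0
    while g > prior_var:
        g -= dt * 2.0 * g / (1.0 + beta * g)
        t += dt
    return t


def hitting_time_closed_form(prior_var, noise_var, beta):
    """T = beta * noise_var / 2 + 0.5 * log(1 + noise_var / prior_var)."""
    return 0.5 * beta * noise_var + 0.5 * math.log1p(noise_var / prior_var)


prior_var, noise_var = 0.3, 0.5
for beta in (5.0, 20.0, 80.0):
    t_hit = hitting_time_closed_form(prior_var, noise_var, beta)
    print(beta, t_hit, t_hit / (beta * noise_var / 2.0))  # ratio tends to 1 as beta grows
```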

Figure: MSE vs. depth and context length, approaching the Bayes MMSE. Normalized MSE / σ² on a 2-mode GMM (µ = ±1, σ² = 0.5); lower is better.

Method                                    Normalized MSE / σ²
Stage 1 only (no Stage 2), L = 200        1.76
Stage 2 with Stage 1, N = 1000            1.04
Stage 2 with Stage 1, N = 3000            0.76
Stage 2 with Stage 1, N = 5000            0.56
Stage 2 with Stage 1, N = 8000            0.44
Bayes MMSE (oracle prior), lower bound    0.26

Values from Fig. 4(a): β = 20, σ² = 0.5, L = 200.

More context (larger N) consistently closes the gap to the Bayes MMSE. At N = 8000 the two-stage estimator reaches 0.44 against the oracle's 0.26, about 1.7× the Bayes-optimal error, without any learned parameters and without knowing the prior ρ₀: it works purely from the context. Stage 1 alone (without the Stage 2 posterior query) scores 1.76, worse than simply returning the noisy input, which scores 1 by construction of the normalization.
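For calibration, the oracle row can be reproduced in miniature under a simplifying assumption: a prior of two equal-weight point masses at ±µ (the figure's GMM components may have nonzero width, so this will not reproduce 0.26). Under that prior the Bayes posterior mean has the closed form E[X | Y = y] = µ tanh(µy/σ²). The Monte Carlo sketch below (our own function name) compares it with the trivial estimator x̂ = y, whose normalized MSE is 1 in expectation.

```python
import math
import random


def simulate(n, mu=1.0, sigma2=0.5, seed=0):
    """Return (identity, oracle) normalized MSE / sigma2 for a +/-mu point-mass prior."""
    rng = random.Random(seed)
    sigma = math.sqrt(sigma2)
    sq_identity = sq_oracle = 0.0
    for _ in range(n):
        x = mu if rng.random() < 0.5 else -mu
        y = x + rng.gauss(0.0, sigma)
        x_hat = mu * math.tanh(mu * y / sigma2)  # E[X | Y = y] under the +/-mu prior
        sq_identity += (y - x) ** 2
        sq_oracle += (x_hat - x) ** 2
    return sq_identity / (n * sigma2), sq_oracle / (n * sigma2)


identity_nmse, oracle_nmse = simulate(20_000)
print(identity_nmse, oracle_nmse)
```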

What's not yet proven

The paper is honest about its limitations. The recovery theorem uses sequential limits — not joint scaling — and doesn't give finite-sample convergence rates. The admissible class Aτ is verified for Gaussians but not for multimodal or heavy-tailed priors. Practical transformers also have multiple heads, MLP layers, learned weights, and positional encoding — none of which are captured here. What the paper establishes is a principled statistical account of what the minimal single-head attention mechanism is computing, as a foundation for richer settings.