An Interactive Reading of

Make Attention
Sub-Quadratic Again

Learned Search Projections for Attention Candidate Retrieval
The paper, in plain English

Every time a large language model generates a new word, it looks back at every previous word it has written and asks: "which of these should I pay attention to?" That look-back is called self-attention, and its cost grows with the square of the sequence length. At 4,000 tokens the scoring matrix has 16 million entries. At 1 million tokens it has a trillion. The model pays the full scoring cost for every one of those entries, yet a surprisingly large fraction of the attention probability mass lands on just a handful of keys.

The paper's insight is simple but subtle: you cannot just throw a nearest-neighbor search index at the model's internal vectors and hope it works. Those vectors were never trained to be good neighbors to each other. Instead, the paper attaches a tiny trainable search layer to selected attention layers. This layer learns to project hidden states into a shared low-dimensional space where "nearest neighbor" actually means "the keys the teacher model would have attended to most strongly." Two training losses — a contrastive objective that rewards pulling teacher-preferred keys closer, and a KL divergence that matches the full attention distribution — teach these projections to retrieve the right candidates.

The first headline result, from a clean 6-layer pilot in Qwen3-4B, was that learned search projections preserve full-attention perplexity within +0.01% at K=256 while scoring only 256 keys instead of all prior tokens. Learned retrieval captures more teacher-attention mass than Quest (a strong page-based baseline) at equal token budgets, and the learned vectors are compatible with off-the-shelf FAISS/HNSW approximate nearest-neighbor search.

The new broad-layer experiments push the question further: how much of the model can actually run this way? Substituting all 36 layers is feasible but costs quality (+3.23% relative perplexity gap). Per-layer diagnostics identify layers 0–2 as the weakest contributors, so reserving those — along with the final layer — yields a 32-of-36-layer reserved-edge configuration with only +1.746% PPL gap, Recall@K 0.825, and 20.97M trainable parameters. A post-hoc K-sweep on this checkpoint actually matches full attention at K=256 (−0.062% gap on a 2-batch slice). The picture that emerges is that coverage is a Pareto knob, not a binary choice: trading a small amount of model quality for a much larger reduction in candidate-scoring cost is something you can dial in.

The honest framing stays the same: this shows ANN-compatible retrieval can work — now at near-full-model scale, not just a pilot — but it does not yet beat existing methods in wall-clock speed on today's hardware. Decode-mode KV-cache integration, multi-seed confidence intervals, and long-context task validation are explicitly listed as next experiments.

I
Learned Search Projections
Small trainable projections map hidden states into a shared $d_{\text{search}} = 128$-dimensional retrieval space where nearest-neighbor search actually recovers attention-relevant keys.
II
Dual Distillation
A contrastive InfoNCE loss pulls teacher-topK keys closer, while a KL distillation loss aligns the full search distribution with the teacher's attention pattern.
III
ANN Compatibility
The learned vectors plug directly into off-the-shelf FAISS/HNSW, tracking exact retrieval quality within +0.03% PPL on the clean evaluation slice.
Chapter 1

The Quadratic Wall

Self-attention is the engine that gives transformer language models their power. For every new token a model generates, it computes a score against every previous token in the sequence, passes those scores through a softmax, and uses the resulting weights to blend value vectors into an output. The scoring step is where the cost hides.

Scaled Dot-Product Attention
$$A = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_h}} + M\right), \quad O = AV$$

Here $Q, K, V$ are the query, key, and value tensors, $d_h$ is the head dimension, and $M$ is a causal or block-causal mask. The matrix $QK^\top$ contains $\mathcal{O}(N^2)$ pairwise scores — one for every (query, key) pair in the eligible set.
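
For concreteness, here is a minimal single-head PyTorch sketch of this computation (the shapes and the plain causal mask are illustrative; this is not the paper's code):

```python
import math
import torch

def full_attention(Q, K, V):
    """Dense causal attention: scores every (query, key) pair, quadratic in sequence length."""
    N, d_h = Q.shape
    scores = Q @ K.T / math.sqrt(d_h)                     # (N, N) pairwise scores
    future = torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))    # M: forbid attending to future keys
    A = torch.softmax(scores, dim=-1)                     # rows sum to 1
    return A @ V                                          # blend value vectors

Q, K, V = (torch.randn(32, 128) for _ in range(3))
O = full_attention(Q, K, V)   # the (N, N) score matrix is where the quadratic cost lives
```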

Interactive heatmap (N = 32; the chart updates as you drag the slider) showing relative attention weight intensity. Readouts at N = 32: scoring entries 1,024; top-5% attention mass 87%; FLOPs per query 4,096.

The attention matrix is almost always sparse in probability mass — a few keys receive most of the weight. The entire paper rests on one question: can we find those few keys without scoring all of them?

Chapter 2

Why Native Vectors Fail

The obvious idea: build an approximate-nearest-neighbor (ANN) index over the transformer's own key vectors. For each query, retrieve the nearest keys in Euclidean space, and attend only to those. The problem is that native query and key vectors are not trained to be mutual nearest neighbors.

Per-query scoring cost
$$s_{tj} = \frac{q_t^\top k_j}{\sqrt{d_h}}, \qquad C_{\text{full}}(N) = N \cdot d_h$$

RetrievalAttention handles the mismatch by making the index attention-aware. This paper takes the opposite route: instead of fixing the index, it fixes the vectors.

Interactive demo (slider settings 60% and 30°). Left: native Q/K distributions, whose offset means ANN retrieval misses. Right: the learned space, where aligned distributions enable ANN. Readouts: native ANN recall 0.41; learned ANN recall 0.89.

RetrievalAttention adapts the index to native Q/K vectors. This paper adapts the vectors to the index. The trade-off: RetrievalAttention is training-free; this method requires a small training phase. The benefit: standard ANN machinery works out of the box.
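
To make the mismatch concrete, here is a toy numpy experiment (random vectors stand in for untrained native Q/K, and every number is illustrative): when key norms vary, the keys closest in Euclidean distance are often not the keys with the largest dot-product scores, which is exactly what an off-the-shelf ANN index would get wrong.

```python
import numpy as np

rng = np.random.default_rng(0)
d_h, N, topk = 128, 1024, 16

# Stand-ins for native vectors: nothing ties query and key geometry together,
# and key norms vary, so max-inner-product and nearest-Euclidean disagree.
q = rng.normal(size=d_h) + 2.0
keys = rng.normal(size=(N, d_h)) * rng.uniform(0.5, 2.0, size=(N, 1))

dot_top = np.argsort(-(keys @ q))[:topk]                     # what attention actually scores highly
l2_top = np.argsort(((keys - q) ** 2).sum(axis=1))[:topk]    # what a Euclidean ANN index returns

overlap = len(set(dot_top) & set(l2_top)) / topk
print(f"top-{topk} overlap between dot-product and L2 retrieval: {overlap:.2f}")
```

L2-normalizing both sides, as the learned search space does, removes the norm term and makes nearest-neighbor and highest-similarity agree.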

Chapter 3

Learning a Shared Search Space

The core idea: attach lightweight trainable projections to selected attention layers. Each projection maps the hidden state into a low-dimensional "search space" where queries and keys are trained to be mutually retrievable.

Search Projections
$$Q_i^s = h_i \, W_i^{Q_s}, \qquad K_i^s = h_i \, W_i^{K_s}$$

where $W_i^{Q_s}, W_i^{K_s} \in \mathbb{R}^{d_{\text{model}} \times d_{\text{search}}}$ are per-layer trainable projections, and $h_i$ is the hidden state entering layer $i$'s self-attention module.
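
A sketch of what one such projection pair could look like in PyTorch (the class name, the bias-free choice, and the L2 normalization before retrieval are my assumptions; only these small matrices would be trained, and the base model stays frozen):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SearchProjection(nn.Module):
    """Per-layer trainable maps W_i^{Q_s}, W_i^{K_s} into the shared search space."""
    def __init__(self, d_model: int = 2560, d_search: int = 128):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_search, bias=False)   # W_i^{Q_s}
        self.k_proj = nn.Linear(d_model, d_search, bias=False)   # W_i^{K_s}

    def forward(self, h: torch.Tensor):
        # h: (batch, seq, d_model) hidden state entering layer i's self-attention module.
        q_s = F.normalize(self.q_proj(h), dim=-1)   # L2-normalize, matching the contrastive loss
        k_s = F.normalize(self.k_proj(h), dim=-1)
        return q_s, k_s

layer = SearchProjection()
print(sum(p.numel() for p in layer.parameters()))   # 655,360 per layer; six layers ≈ 3.93M
```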

Interactive calculator (d_search = 128, 6 substituted layers): drag to see how d_search and the number of substituted layers affect parameter count and retrieval capacity. At these settings: 3.93M trainable parameters, 0.098% of the base model.

The base model weights are never modified. The approximation is in candidate selection only — once K candidates are found, the model's own native Q, K, and V are used for the actual attention computation.

Chapter 4

Dual Distillation

The search projections are trained by two complementary losses. The first is a contrastive InfoNCE loss that teaches the projections to rank teacher-preferred keys above distractors. The second is a KL-divergence loss that aligns the full search distribution with the teacher's attention pattern.

Contrastive Teacher-TopK Objective
$$\mathcal{L}_{\text{NCE}}(t) = -\log \frac{\sum_{j \in P_t} \exp(z_{tj})}{\sum_{j \in \mathcal{C}_t} \exp(z_{tj})}$$

where $P_t$ is the set of teacher top-$K_{\text{pos}}$ keys, $\mathcal{C}_t$ is the valid causal key set, and $z_{tj} = \frac{(\tilde{q}_t^s)^\top \tilde{k}_j^s}{\tau}$ is the search similarity between L2-normalized vectors scaled by temperature $\tau$.

Distribution-Level KL Distillation
$$\mathcal{L}_{\text{KL}}(t) = D_{\text{KL}}\!\left(A_i^T[t, \cdot] \;\|\; A_i^S[t, \cdot]\right) = \sum_{j \in \mathcal{C}_t} A_i^T[t,j] \log \frac{A_i^T[t,j]}{A_i^S[t,j]}$$

The total layer-averaged objective combines both with $\alpha = \beta = 1$:

$$\mathcal{L} = \frac{\alpha}{|\mathcal{I}|} \sum_{i \in \mathcal{I}} \mathcal{L}^i_{\text{NCE}} + \frac{\beta}{|\mathcal{I}|} \sum_{i \in \mathcal{I}} \mathcal{L}^i_{\text{KL}}$$
Interactive demo (temperature 0.50, positive set size 16, weight 1.0): adjust temperature and positive set size to see how the contrastive loss landscape changes.

The teacher distribution is reconstructed outside the model forward pass — by capturing native post-RoPE Q and K tensors and recomputing softmax. This avoids forcing the model onto slower eager attention paths during training.
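
Putting the chapter together, here is a hedged single-layer, single-head sketch of the training objective, assuming the post-RoPE native q and k for that layer have already been captured and that the search vectors are L2-normalized (function and variable names are mine, not the paper's):

```python
import math
import torch

def dual_distillation_loss(q_s, k_s, q_native, k_native, K_pos=16, tau=0.5):
    """q_s, k_s: (N, d_search) L2-normalized search vectors (trainable path).
    q_native, k_native: (N, d_h) captured post-RoPE teacher vectors (frozen)."""
    N, d_h = q_native.shape
    causal = torch.triu(torch.full((N, N), float("-inf")), diagonal=1)

    # Reconstruct the teacher attention rows A_i^T[t, :] outside the model forward pass.
    with torch.no_grad():
        A_T = torch.softmax(q_native @ k_native.T / math.sqrt(d_h) + causal, dim=-1)
        pos = A_T.topk(K_pos, dim=-1).indices   # teacher top-K_pos keys P_t (very short
                                                # prefixes pick up zero-probability slots)

    # Student search logits z_{tj} over the valid causal key set C_t.
    z = (q_s @ k_s.T) / tau + causal
    log_A_S = torch.log_softmax(z, dim=-1)

    # InfoNCE: probability mass the student places on the teacher-preferred keys.
    nce = torch.logsumexp(z, dim=-1) - torch.logsumexp(z.gather(-1, pos), dim=-1)

    # KL(A^T || A^S), summed over valid keys; masked slots contribute exactly zero.
    log_A_S_valid = log_A_S.masked_fill(torch.isinf(log_A_S), 0.0)
    kl = (torch.xlogy(A_T, A_T) - A_T * log_A_S_valid).sum(dim=-1)

    return nce.mean() + kl.mean()   # alpha = beta = 1; layer averaging happens outside
```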

Chapter 5

Retrieve Then Attend

At inference time, selected attention layers are replaced with a retrieve-then-attend pipeline. The model's native Q, K, V are computed as usual. Then the search projections find the top-K candidates. Attention is computed over only those candidates.

Substituted Sparse Attention
$$\hat{o}_t = \sum_{j \in S_t} \frac{\exp(q_t^\top k_j / \sqrt{d_h})}{\sum_{\ell \in S_t} \exp(q_t^\top k_\ell / \sqrt{d_h})} \, v_j$$

where $S_t$ is the retrieved candidate set for query $t$. The native $q_t$, $k_j$, $v_j$ are the frozen model's own vectors — the approximation is in the candidate set, not in the attention values.
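
A sketch of the substituted attention for a single query position, assuming the candidate set S_t has already been produced by the search projections (exactly or via an ANN index); names are illustrative:

```python
import math
import torch

def sparse_attention_row(q_t, K_native, V_native, S_t):
    """q_t: (d_h,) native query; K_native, V_native: (N, d_h); S_t: retrieved key indices."""
    k_sel, v_sel = K_native[S_t], V_native[S_t]        # gather only the retrieved keys/values
    scores = k_sel @ q_t / math.sqrt(q_t.shape[-1])    # native QK scores, only |S_t| of them
    weights = torch.softmax(scores, dim=-1)            # softmax renormalized over S_t
    return weights @ v_sel                             # \hat{o}_t

# Exact candidate selection in the learned search space for query position t:
#   S_t = torch.topk(q_s[t] @ k_s[: t + 1].T, k=min(K, t + 1)).indices
```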

Block-Causal Mask
$$M_{tj} = 0 \;\;\text{iff}\;\; \text{segment}(t) = \text{segment}(j) \;\text{and}\; j \le t$$

and $M_{tj} = -\infty$ otherwise. This prevents cross-document attention leakage in packed sequences.
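
Given a per-token segment id for a packed sequence, the mask is a few lines (a sketch; `segment_ids` is my name for the packing metadata):

```python
import torch

def block_causal_mask(segment_ids: torch.Tensor) -> torch.Tensor:
    """segment_ids: (N,) document id per token. Returns an additive mask:
    0 where segment(t) == segment(j) and j <= t, -inf everywhere else."""
    pos = torch.arange(len(segment_ids))
    allowed = (segment_ids[:, None] == segment_ids[None, :]) & (pos[None, :] <= pos[:, None])
    mask = torch.full(allowed.shape, float("-inf"))
    mask[allowed] = 0.0
    return mask

print(block_causal_mask(torch.tensor([0, 0, 0, 1, 1])))   # tokens 3-4 never see tokens 0-2
```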

Pipeline diagram (click any stage for details): the blue path shows the native attention flow; the orange path shows the search projection layer.

The sparse attention is not an approximation in value space. The model's own Q, K, V, RoPE, and output projection are all used exactly as in full attention. The only approximation is in which keys participate.

Chapter 6

The Complexity Payoff

The method replaces a linear scan over all N keys with sub-linear HNSW retrieval. The per-query scoring-cost proxy tells a clear story: full attention grows as $\mathcal{O}(N)$, Quest-style page selection also grows as $\mathcal{O}(N)$ (with a smaller constant), and learned HNSW retrieval grows as $\mathcal{O}(\log N)$.

Candidate-Scoring Proxies
$$C_{\text{full}}(N) = N \cdot d_h = 128N$$ $$C_{\text{Quest}}(N) = \frac{N}{P} \cdot 2d_h = 16N$$ $$C_{\text{HNSW}}(N) = M \cdot \text{ef}_{\text{search}} \cdot \log_2(N) \cdot d_{\text{search}} = 262{,}144 \cdot \log_2(N)$$
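
These proxies are plain arithmetic, so the crossover can be reproduced in a few lines under the constants used here (d_h = 128, Quest page size P = 16, HNSW M = 32, ef_search = 64, d_search = 128); the printed numbers line up with the readouts below.

```python
import math

d_h, P = 128, 16                      # head dim, Quest page size
M, ef_search, d_search = 32, 64, 128  # HNSW graph degree, search width, search dim

def c_full(n):  return n * d_h                                    # score every prior key
def c_quest(n): return (n / P) * 2 * d_h                          # min/max metadata per page
def c_hnsw(n):  return M * ef_search * math.log2(n) * d_search    # graph-traversal proxy

n = 1024
while c_hnsw(n) >= c_quest(n):        # first length where the HNSW proxy undercuts Quest
    n += 1024
print(f"Quest / HNSW crossover near N = {n:,}")                                # ~300K tokens
print(f"HNSW advantage at 1M tokens: {c_quest(1_000_000) / c_hnsw(1_000_000):.2f}x")  # roughly 3x
```
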
Interactive cost plot (N = 2²² ≈ 4.19M tokens, ef_search = 64): the crossover point shifts with ef_search and context length; zoom into the long-context regime to see where HNSW wins. Readouts: Quest / HNSW crossover ≈ 300K tokens; HNSW advantage at 1M tokens ≈ 3.0×.

This is a candidate-scoring proxy, not measured GPU runtime. The paper is honest: it does not yet prove wall-clock speedup. But the asymptotic shape — $\mathcal{O}(\log N)$ vs $\mathcal{O}(N)$ — means the advantage must materialize at sufficiently long contexts.

Chapter 7

Near-Parity Perplexity

The clean block-causal experiment substitutes 6 layers of Qwen3-4B-Instruct-2507 with learned sparse attention. On WikiText-103 with 4096-token sequences, the method preserves full-attention perplexity within a razor-thin margin.

Interactive K slider (K = 128): drag K to see the trade-off between retrieval budget and quality metrics. At K = 128: Recall@K 0.744, Mass@K 0.787, PPL gap +0.07%.

The learned projection matches or slightly exceeds raw-QK oracle retrieval at every tested layer. This means the search space is not just adequate — it is genuinely capturing attention-relevant geometry that raw native vectors miss.
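
Recall@K and Mass@K are cheap to compute once a teacher attention row and the retrieved candidate set are in hand. The definitions below are the usual ones for these names and match how the numbers are used here, but the evaluation code itself is mine, not the paper's:

```python
import torch

def recall_and_mass_at_k(teacher_row: torch.Tensor, retrieved: torch.Tensor):
    """teacher_row: (N,) teacher attention probabilities for one query.
    retrieved: (K,) key indices returned by the learned search (or an ANN index)."""
    k = min(len(retrieved), int((teacher_row > 0).sum()))     # at most the valid causal keys
    teacher_topk = teacher_row.topk(k).indices
    recall_at_k = torch.isin(retrieved, teacher_topk).sum().item() / k
    mass_at_k = teacher_row[retrieved].sum().item()           # teacher attention mass covered
    return recall_at_k, mass_at_k
```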

Chapter 8

Quest vs. Learned

Quest is a strong baseline: it selects KV pages using query-aware min/max metadata, is training-free, and directly targets the KV-cache memory bottleneck. How does learned search compare?

FAISS/HNSW Compatibility

A CPU FAISS/HNSW prototype tracks exact learned retrieval on the clean evaluation slice:

Method               K     PPL     Rel. PPL gap   Filler rate
Learned exact        128   30.47   +0.07%         n/a
Learned FAISS/HNSW   128   30.47   +0.09%         0.447
Learned exact        256   30.45   +0.01%         n/a
Learned FAISS/HNSW   256   30.46   +0.04%         0.683

The filler rate is expected for short same-segment prefixes where fewer than K valid causal keys exist. Filler slots are masked out of the sparse-attention softmax.
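
A minimal FAISS/HNSW sketch of this kind of retrieval step, assuming the learned, L2-normalized search keys for one layer and head are available as a float32 array (index parameters mirror the proxy constants above; this is an illustration, not the paper's prototype code):

```python
import faiss
import numpy as np

d_search, M, ef_search, K = 128, 32, 64, 256

# Learned search keys for one layer/head, L2-normalized float32.
keys_s = np.random.randn(4096, d_search).astype("float32")
keys_s /= np.linalg.norm(keys_s, axis=1, keepdims=True)

index = faiss.IndexHNSWFlat(d_search, M)   # L2 metric; equivalent to cosine on unit vectors
index.hnsw.efConstruction = 200
index.hnsw.efSearch = ef_search
index.add(keys_s)

q_s = keys_s[:8].copy()                    # a few (normalized) search queries
_, candidates = index.search(q_s, K)       # candidate ids feed the sparse attention step
```

Causally invalid or filler candidates are then masked out of the sparse-attention softmax, as described above.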

Learned search captures more teacher mass at equal K, but perplexity does not currently show a clean advantage over Quest. The contribution is retrieval fidelity and ANN compatibility, not a PPL win. The paper earns points for honesty.

Chapter 9

Broad-Layer Substitution

The 6-layer pilot shows near-parity. But what happens when we substitute nearly every layer? The paper now reports two broader experiments: an all-36-layer substitution and a 32-layer "reserved-edge" configuration that holds back the weakest layers.

All-32 Reserved-Edge Configuration
$$\mathcal{I}_{\text{sub}} = \{3, 4, \ldots, 34\}, \quad \text{reserved} = \{0, 1, 2, 35\}$$ $$\text{Params} = 32 \times 2 \times d_{\text{model}} \times d_{\text{search}} = 32 \times 2 \times 2560 \times 128 = 20.97\text{M}$$
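
The parameter cost of any coverage choice follows directly from this formula; the 36-layer figure below is the same arithmetic extended to every layer (a quick check, not a number quoted from the paper):

```python
d_model, d_search = 2560, 128

def search_params_millions(n_layers: int) -> float:
    """Trainable search-projection parameters for n substituted layers, two matrices each."""
    return n_layers * 2 * d_model * d_search / 1e6

for n_layers in (6, 32, 36):
    print(f"{n_layers:>2} layers: {search_params_millions(n_layers):.2f}M trainable params")
# -> 6 layers: 3.93M, 32 layers: 20.97M, 36 layers: 23.59M
```
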
All-32 best PPL gap: +1.746%
All-36 PPL gap: +3.227%
Trainable params (All-32): 20.97M

Left: All-32 training trajectory over 1000 steps (Table 3). Recall@K plateaus at ~0.825; PPL gap stabilizes near +1.75%. Right: Coverage vs quality showing the three tested configurations (Table 5).

Per-Layer Diagnostics

Layer-wise retrieval analysis on the all-36 experiment reveals that layers 0, 1, and 2 have substantially lower Mass@K than the interior layers. The reserved-edge strategy directly addresses this: hold back the weakest layers, substitute the strong ones. Layer 35 (the final layer) is also reserved as a conservative choice.

Coverage is not a binary switch — it is a Pareto knob. Six layers gives near-parity (+0.07%); 32 layers gives 89% coverage at +1.75%; all 36 pushes to +3.23%. The practitioner chooses where on this frontier to operate.

Chapter 10

K-Sweep Diagnostics

After training the All-32 reserved-edge model for 1000 steps, the paper performs a post-hoc exact K-sweep on a 2-batch clean block-causal evaluation slice. The sweep varies K from 16 to 256, revealing the retrieval budget vs quality trade-off for the broad-layer configuration.

All-32 K-Sweep Summary (Table 4)
$$\text{K=256: Mass@K} = 0.902, \;\text{PPL gap} = -0.062\% \quad\text{(near-parity across 32 layers)}$$
Interactive K slider for the All-32 configuration (K = 128): drag K to see how retrieval quality and perplexity change. At K = 128: Mass@K 0.807, Recall@K 0.746, PPL 20.66, relative PPL gap +0.590%.

At K=256, the All-32 configuration achieves exact parity or better on the 2-batch evaluation slice (−0.062%). This is measured on a small slice and should be interpreted cautiously, but it demonstrates that broad-layer substitution is feasible when the retrieval budget is sufficient.

Chapter 11

The Road Ahead

The pilot result is encouraging but narrow. The paper is explicit about what it does not prove — and equally explicit about what must happen next for the claim to grow from "promising prototype" to "practical system."

What the result proves

Learned search projections can select attention candidates through standard ANN machinery while preserving perplexity on the evaluated slices, from the 6-layer pilot at near-parity to the 32-layer reserved-edge configuration within roughly +1.7%.

What the result does not prove

Wall-clock speedup on current hardware, quality on long-context tasks, decode-mode behavior with an incrementally updated index, and generality beyond Qwen3-4B.

Required next steps

  1. Multi-seed confidence intervals for all reported results.
  2. Full 36-layer substitution with layer-wise training strategies to improve weak edge layers.
  3. Coverage Pareto sweep: 12-layer, 18-layer, 20-layer configurations to map the full frontier.
  4. Long-context evaluation on LongBench, RULER, passkey retrieval, and needle-in-haystack.
  5. Decode-mode KV-cache integration with incremental index updates.
  6. GPU-resident retrieval and fused sparse gather-attention kernels.
  7. Measured wall-clock latency and memory footprint.
  8. Cross-model validation beyond Qwen3-4B.

A radar chart summarizing the current state of evidence across six dimensions. The broad-layer experiments have pushed full-layer coverage to 89% (32/36 layers), but wall-clock latency, long-context quality, and cross-model generality remain unproven.

The correct framing is aspirational but increasingly supported: learned search projections make attention-relevant key selection compatible with standard ANN retrieval while largely preserving model quality, from near-parity in the 6-layer pilot to a small gap under 32-layer broad substitution. Coverage is a quality knob, not a binary switch. The path to a stronger claim is clear, and the paper names every missing piece.