[Paper] Controlling changes to attention logits

Published: November 26, 2025 at 08:24 AM EST
4 min read
Source: arXiv - 2511.21377v1

Overview

Training large transformer models can be surprisingly fragile: the query and key matrices in the attention mechanism often explode in magnitude, leading to unstable training and poor performance. This paper by Anson and Aitchison shows that the root cause is uncontrolled drift in the attention logits (the raw scores before the softmax). By regulating how much these logits can change—through a simple, parameter‑dependent learning‑rate scheme for the query and key weights—the authors restore stability without the heavy‑handed QK norm trick, even in settings where that trick cannot be used.

Key Contributions

  • Identifies logits drift as a primary source of instability in transformer training, especially for query/key weights.
  • Proposes a lightweight intervention: per‑parameter learning‑rate scaling that directly caps the magnitude of logit updates.
  • Demonstrates compatibility with Multi‑head Latent Attention (MLA), a memory‑efficient attention variant to which standard QK normalization cannot be applied.
  • Shows empirical gains: the method enables higher base learning rates, outperforms existing stabilizers in MLA, and matches QK norm performance for standard multi‑head attention.
  • Provides a practical recipe that requires only a few lines of code and no extra forward‑pass computation.

Methodology

  1. Problem framing – The authors start from the observation that the attention logits $L = QK^\top / \sqrt{d}$ can change dramatically across training steps, causing the softmax distribution to become overly sharp or overly flat.
  2. Parameter‑dependent learning rates – Instead of a uniform learning rate $\eta$ for all weights, they assign a scaled learning rate $\eta_{Q,K} = \alpha \cdot \eta$ to the query and key matrices, where $\alpha$ is a small constant (e.g., 0.1). This directly limits how much the logits can move in a single update.
  3. Implementation details – The scaling is applied at the optimizer level (e.g., via a custom parameter group in Adam); a minimal sketch appears after this list. No extra forward or backward passes are needed, and the approach works with any optimizer that supports per‑parameter learning rates.
  4. Experimental setup – They evaluate on two fronts:
    • (a) standard multi‑head attention (MHA) on language modeling benchmarks,
    • (b) Multi‑head Latent Attention (MLA), which avoids materializing full query/key tensors during inference.
      Baselines include vanilla training, QK norm, and other recent stabilizers.
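
The recipe can be expressed directly as optimizer parameter groups. The sketch below is a minimal PyTorch illustration, not the authors' code: the toy attention module, the q_proj/k_proj names, and the 0.1 scaling factor are assumptions made for the example.

```python
import torch
from torch import nn

# Toy single-head attention LM; only the split between query/key weights and
# everything else matters for the recipe. All names here (TinyAttentionLM,
# q_proj, k_proj) are illustrative, not taken from the paper.
class TinyAttentionLM(nn.Module):
    def __init__(self, d_model=64, vocab=100):
        super().__init__()
        self.embed = nn.Embedding(vocab, d_model)
        self.q_proj = nn.Linear(d_model, d_model, bias=False)  # query weights
        self.k_proj = nn.Linear(d_model, d_model, bias=False)  # key weights
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out = nn.Linear(d_model, vocab)

    def forward(self, tokens):
        x = self.embed(tokens)
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        logits = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)  # attention logits
        return self.out(logits.softmax(dim=-1) @ v)

model = TinyAttentionLM()
base_lr = 3e-4   # the larger base LR that becomes safe to use
alpha = 0.1      # illustrative scaling factor for the query/key weights

qk_params, other_params = [], []
for name, param in model.named_parameters():
    (qk_params if ("q_proj" in name or "k_proj" in name) else other_params).append(param)

# Per-parameter learning rates via optimizer parameter groups:
# query/key weights get alpha * base_lr, everything else the full base_lr.
optimizer = torch.optim.AdamW([
    {"params": other_params, "lr": base_lr},
    {"params": qk_params, "lr": alpha * base_lr},
])
```

Because the change lives entirely in the optimizer configuration, the forward and backward passes are untouched, which is where the "no extra forward‑pass computation" claim comes from.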

Results & Findings

| Setting | Baseline LR | Max stable LR (this work) | Test perplexity / accuracy |
|---|---|---|---|
| MHA (LM) | 1e‑4 | 3× higher (≈3e‑4) | Comparable to QK norm (≈+0.2 pp) |
| MLA (Vision) | 5e‑5 | 2× higher (≈1e‑4) | +1.5 % top‑1 accuracy over QK norm |
| Training stability (measured by logit variance) | Explodes after ~10k steps | Remains bounded throughout | |
  • The per‑parameter LR scheme keeps the variance of attention logits low, preventing the softmax from saturating (see the diagnostic sketch after this list).
  • In MLA, where QK norm cannot be applied, the new method outperforms all prior stabilizers and enables faster convergence.
  • Across both settings, the approach does not degrade final model quality; it merely allows the optimizer to use a larger learning rate safely.
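
As a companion to the table above, the snippet below is one way to track logit variance during training. The function name and the assumption that the attention layer exposes separate q_proj/k_proj modules are hypothetical, but the quantity it computes is the one the paper monitors: the raw $QK^\top / \sqrt{d}$ scores before the softmax.

```python
import torch

@torch.no_grad()
def attention_logit_variance(q_proj, k_proj, x):
    """Variance of the pre-softmax attention logits for one batch.

    q_proj / k_proj: the query/key projection modules of an attention layer
    (names are illustrative); x: the layer input of shape (batch, seq, d_model).
    A variance that keeps growing over training is the drift symptom the
    paper associates with instability.
    """
    q, k = q_proj(x), k_proj(x)
    logits = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    return logits.var().item()

# Example: log the statistic every few hundred steps and watch for growth.
# if step % 500 == 0:
#     print(step, attention_logit_variance(layer.q_proj, layer.k_proj, x))
```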

Practical Implications

  • Faster training cycles – Developers can bump the learning rate by 2–3× without risking divergence, cutting wall‑clock time for large transformer pre‑training.
  • Memory‑efficient attention – For models that rely on MLA or other low‑memory attention tricks (e.g., streaming or on‑device inference), this method provides a stability fix where QK norm cannot be applied.
  • Drop‑in replacement – Since the technique is just a learning‑rate tweak, it can be added to existing codebases (PyTorch, TensorFlow, JAX) with minimal refactoring; a retrofit sketch follows this list.
  • Better hyper‑parameter robustness – The method reduces the need for painstaking LR‑schedule tuning, which is especially valuable in production pipelines where training runs are expensive.
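
To illustrate the drop‑in nature of the change, here is one way to retrofit the scaled learning rate onto an already‑built optimizer using PyTorch's `add_param_group`. The stand‑in model and the q_proj/k_proj name matching are assumptions about the codebase, not part of the paper.

```python
import torch
from torch import nn

# Stand-in for an existing model and optimizer, as in a typical training script;
# only the parameter names matter for the retrofit.
model = nn.ModuleDict({
    "q_proj": nn.Linear(64, 64, bias=False),
    "k_proj": nn.Linear(64, 64, bias=False),
    "v_proj": nn.Linear(64, 64, bias=False),
    "mlp":    nn.Linear(64, 64),
})
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Pull the query/key weights out of the existing parameter groups ...
qk_params = [p for n, p in model.named_parameters()
             if "q_proj" in n or "k_proj" in n]
qk_ids = {id(p) for p in qk_params}
for group in optimizer.param_groups:
    group["params"] = [p for p in group["params"] if id(p) not in qk_ids]

# ... and re-add them as their own group with a scaled learning rate.
optimizer.add_param_group({"params": qk_params,
                           "lr": 0.1 * optimizer.defaults["lr"]})
```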

Limitations & Future Work

  • The approach relies on a manually chosen scaling factor $\alpha$; while the authors report a fairly robust default, optimal values may still vary across tasks or model sizes.
  • It does not address other sources of instability, such as exploding gradients in feed‑forward layers or layer‑norm scaling issues.
  • The paper focuses on language modeling and vision classification; extending the analysis to multimodal or reinforcement‑learning transformers remains open.
  • Future work could explore adaptive schemes that automatically adjust $\alpha$ based on observed logit drift, or combine this method with other normalization tricks for even greater robustness.

Authors

  • Ben Anson
  • Laurence Aitchison

Paper Information

  • arXiv ID: 2511.21377v1
  • Categories: cs.LG
  • Published: November 26, 2025
  • PDF: Download PDF