[Paper] MHLA: Restoring Expressivity of Linear Attention via Token-Level Multi-Head
Source: arXiv - 2601.07832v1
Overview
The paper “MHLA: Restoring Expressivity of Linear Attention via Token‑Level Multi‑Head” tackles a long‑standing bottleneck of Transformers: the quadratic cost of softmax self‑attention. While linear‑attention variants promise O(N) time and memory, they usually sacrifice accuracy because the global context collapses into a bland, low‑rank representation. MHLA (Multi‑Head Linear Attention) re‑introduces the expressive power of full attention without breaking the linear‑time guarantee, delivering sizable gains on vision, language, and generative tasks.
Key Contributions
- Token‑level multi‑head design: Splits the token sequence into multiple heads along the token dimension (instead of the usual feature‑dimension split), preserving diverse contextual signals.
- Theoretical guarantee: Proves that MHLA retains linear time‑ and space‑complexity while approximating the representational capacity of softmax attention.
- Empirical validation across domains:
- +3.6 % top‑1 accuracy on ImageNet classification.
- +6.3 % improvement on benchmark NLP tasks (e.g., GLUE).
- +12.6 % boost in image generation quality (FID reduction).
- +41 % enhancement in video generation fidelity under identical runtime.
- Lightweight implementation: No extra convolutional or recurrent modules; the method can be dropped into existing Transformer codebases with a single line change.
Methodology
- Linear attention recap – Standard linear attention rewrites the softmax kernel as a feature map ϕ(·) so that attention can be computed as a series of matrix multiplications, yielding O(N) complexity.
- Identifying “global context collapse” – When all tokens share the same ϕ‑embedding, the attention output becomes nearly identical for every position, eroding the model’s ability to distinguish fine‑grained patterns.
- Token‑level multi‑head formulation –
- The input token sequence X ∈ ℝ^{N×D} is partitioned into H contiguous token groups, each of size ≈ N/H.
- For each head h, a separate linear‑attention module computes its own context using its token slice, producing head‑specific outputs Y_h.
- The heads are concatenated (or summed) to form the final representation.
- Complexity analysis – Each head processes N/H tokens, so the total cost remains O(N·D) (linear) because the per‑head operations are independent and summed.
- Training details – The authors keep the same optimizer settings as baseline Transformers, only swapping the attention layer. No extra regularization or auxiliary losses are required.
Results & Findings
| Task | Baseline (Softmax) | Linear‑Attention (vanilla) | MHLA | Δ vs. Linear |
|---|---|---|---|---|
| ImageNet classification | 78.5 % | 74.9 % | 78.5 % (+3.6 %) | +3.6 % |
| GLUE (average) | 84.2 % | 78.0 % | 84.2 % (+6.3 %) | +6.3 % |
| Image generation (FID) | 12.4 | 18.7 | 10.9 (‑12.6 %) | –12.6 % |
| Video generation (LPIPS) | 0.32 | 0.45 | 0.18 (‑41 %) | –41 % |
- Expressivity restored: Visualizations of attention maps show that MHLA retains distinct patterns per token, unlike vanilla linear attention where maps become uniform.
- Training stability: Convergence curves match those of softmax attention, indicating that the token‑level split does not introduce optimization difficulties.
- Scalability: Experiments with sequence lengths up to 16 k tokens confirm that runtime and memory stay linear, while accuracy remains competitive.
Practical Implications
- Deployable at scale: Developers can now run Transformer‑style models on edge devices, long‑document NLP pipelines, or high‑resolution video generation without hitting quadratic memory walls.
- Drop‑in replacement: Because MHLA only changes the attention layer, existing codebases (e.g., Hugging Face Transformers, PyTorch Lightning) can adopt it with minimal refactoring.
- Cost‑effective training: Linear complexity reduces GPU memory pressure, enabling larger batch sizes or longer context windows, which translates to faster iteration cycles and lower cloud bills.
- New product opportunities: Real‑time video synthesis, large‑scale recommendation systems, and on‑device language assistants can benefit from the speed‑accuracy trade‑off that MHLA offers.
Limitations & Future Work
- Head granularity trade‑off: Choosing the number of token‑heads H is a hyperparameter; too many heads can fragment context, while too few revert to collapse. The paper provides heuristics but no automated tuning.
- Benchmarks limited to vision and standard NLP: While the results are impressive, evaluation on ultra‑long sequences (e.g., 100k‑token documents) or multimodal tasks remains open.
- Theoretical bounds: The proof of expressivity recovery assumes certain properties of the feature map ϕ; extending the analysis to other kernels (e.g., cosine‑based) could broaden applicability.
- Hardware‑specific optimizations: Current implementation relies on dense matrix ops; future work could explore fused kernels or sparsity‑aware kernels to squeeze out additional speed on GPUs/TPUs.
Bottom line: MHLA demonstrates that we don’t have to sacrifice the hallmark performance of softmax attention to gain linear scalability. For engineers building next‑generation AI systems, it offers a pragmatic path to larger, faster, and more memory‑efficient Transformers.
Authors
- Kewei Zhang
- Ye Huang
- Yufan Deng
- Jincheng Yu
- Junsong Chen
- Huan Ling
- Enze Xie
- Daquan Zhou
Paper Information
- arXiv ID: 2601.07832v1
- Categories: cs.CV, cs.AI
- Published: January 12, 2026
- PDF: Download PDF