[Paper] Multi-Head Low-Rank Attention
Source: arXiv - 2603.02188v1
Overview
The paper introduces Multi-Head Low-Rank Attention (MLRA), a new attention mechanism designed to speed up long‑context decoding in large language models (LLMs). By making the latent representation partitionable across multiple GPUs, MLRA cuts the memory traffic that normally bottlenecks token‑by‑token generation, delivering up to a 2.8× speed boost while preserving (or even improving) model quality.
Key Contributions
- Partitionable latent states: Unlike Multi‑Head Latent Attention (MLA), MLRA’s low‑rank latent vectors can be split across devices, enabling efficient Tensor Parallel (TP) decoding.
- 4‑way TP friendly design: The architecture allows each GPU to load only its slice of the KV cache, dramatically reducing off‑chip memory bandwidth usage.
- State‑of‑the‑art performance: Experiments show MLRA matches or exceeds MLA on perplexity and downstream benchmarks (e.g., QA, summarization).
- Speedup on decoding: A measured 2.8× decoding acceleration over MLA on the same hardware configuration.
- Open‑source release: Code, pretrained weights, and training/evaluation scripts are publicly available, facilitating reproducibility and community adoption.
Methodology
1. Low-rank factorisation of attention
- Traditional self‑attention stores a full key‑value (KV) matrix for every token, which must be fetched from high‑bandwidth memory (HBM) at each generation step.
- MLRA factorises the KV cache into a latent low‑dimensional representation (rank‑r) and a set of projection matrices, reducing the total cache size.
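The cache-size saving from this factorisation can be sketched with back-of-the-envelope arithmetic. The dimensions below (sequence length, head count, head dimension, and rank) are illustrative assumptions, not the paper's configuration:

```python
# Hypothetical dimensions -- assumptions for illustration only.
seq_len, n_heads, d_head, r = 4096, 32, 128, 32
d_model = n_heads * d_head

# Full KV cache: keys and values stored for every token.
full_kv_floats = 2 * seq_len * n_heads * d_head

# Low-rank variant: one rank-r latent vector per token, plus
# up-projection matrices (assumed W_k_up, W_v_up) that are
# shared across all tokens rather than cached per token.
latent_floats = seq_len * r
proj_floats = 2 * r * d_model

reduction = full_kv_floats / (latent_floats + proj_floats)
print(f"cache reduced ~{reduction:.0f}x")
```

Because the projection matrices are token-independent, their cost is amortised over the whole sequence, so the per-step memory traffic is dominated by the much smaller latent cache.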
2. Multi-head design with sharding support
- Each attention head now contains its own low‑rank latent state. Because the latent dimension is split across heads, the state can be sharded across GPUs in a TP setup.
- During decoding, each device only loads the slice of the latent cache it owns, while the projection matrices remain on‑chip.
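A minimal sketch of the head-wise sharding idea, using NumPy slicing to stand in for per-device placement (head count, sequence length, rank, and TP degree are all assumed values):

```python
import numpy as np

# Assumed dimensions for illustration: 8 heads, rank-4 latents,
# sharded across a 4-way tensor-parallel group.
n_heads, seq_len, r, tp = 8, 16, 4, 4
heads_per_device = n_heads // tp

# Latent cache: one rank-r latent state per head per token.
latent = np.arange(n_heads * seq_len * r, dtype=np.float32).reshape(
    n_heads, seq_len, r
)

# Each device owns a contiguous slice of heads and, during
# decoding, loads only its own shard from memory.
shards = [
    latent[d * heads_per_device : (d + 1) * heads_per_device]
    for d in range(tp)
]

# Concatenating the shards recovers the full latent cache.
assert np.array_equal(np.concatenate(shards, axis=0), latent)
```

The key property is that the shards partition the cache exactly, so no latent state is duplicated across devices and each GPU's memory traffic shrinks by the TP degree.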
3. Training pipeline
- The authors pre‑train a standard decoder‑only transformer (similar to LLaMA) with the MLRA module inserted.
- They use a combination of causal language modelling loss and a regularisation term that encourages the low‑rank factors to capture most of the attention information.
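One plausible form of this combined objective, sketched below; the paper does not specify the exact regulariser, so the Frobenius-norm reconstruction penalty and the weighting `lam` are assumptions:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def lm_loss(logits, targets):
    # Standard causal language-modelling cross-entropy.
    probs = softmax(logits)
    n = targets.shape[0]
    return -np.log(probs[np.arange(n), targets]).mean()

def lowrank_reg(kv_full, kv_lowrank):
    # Assumed regulariser: mean squared reconstruction error
    # between the full KV projection and its low-rank version.
    return np.mean((kv_full - kv_lowrank) ** 2)

def total_loss(logits, targets, kv_full, kv_lowrank, lam=0.1):
    return lm_loss(logits, targets) + lam * lowrank_reg(kv_full, kv_lowrank)
```

Under this reading, the regulariser pressures the rank-r factors to reproduce the information the full KV cache would have carried, which is consistent with the paper's claim that quality is preserved.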
4. Evaluation
- Benchmarks include language modelling perplexity on standard corpora, as well as downstream tasks from the MMLU, GSM‑8K, and summarisation suites.
- Decoding speed is measured on a 4‑GPU node (NVIDIA A100, 80 GB) using both greedy and beam search.
Results & Findings
| Model | Perplexity (WikiText‑103) | MMLU (5‑shot) | Decoding Speed (tokens/s) |
|---|---|---|---|
| Baseline Transformer (full KV) | 13.2 | 45.1% | 120 |
| MLA (single latent head) | 12.9 | 46.0% | 85 |
| MLRA (4‑way TP) | 12.7 | 46.5% | 236 |
- Quality: MLRA slightly improves perplexity and downstream accuracy over both the baseline and MLA, confirming that the low‑rank factorisation does not sacrifice expressive power.
- Speed: The 4‑way TP implementation reduces KV cache traffic by ~65%, translating into a 2.8× decoding speedup compared with MLA and a ~2× gain over the full‑KV baseline.
- Scalability: Experiments scaling from 2 to 8 GPUs show near‑linear throughput gains, demonstrating that the sharding design works as intended.
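The reported speedups can be checked directly against the throughput column of the table above:

```python
# Tokens/s from the results table.
mlra, mla, baseline = 236, 85, 120

speedup_vs_mla = mlra / mla            # reported as 2.8x
speedup_vs_baseline = mlra / baseline  # reported as ~2x

print(round(speedup_vs_mla, 2), round(speedup_vs_baseline, 2))
```

The ratios (≈2.78× and ≈1.97×) match the paper's rounded figures of 2.8× and ~2×.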
Practical Implications
- Faster inference for long‑context apps: Chatbots, code assistants, and document‑analysis tools that need to keep thousands of tokens in context can now generate responses noticeably quicker, reducing latency for end‑users.
- Lower hardware cost per token: Because each GPU only needs to fetch a fraction of the KV cache, the same inference throughput can be achieved on cheaper GPU clusters or even on a single high‑memory GPU with reduced off‑chip traffic.
- Improved TP utilisation: Existing Tensor‑Parallel pipelines (e.g., DeepSpeed, Megatron‑LM) can adopt MLRA with minimal changes, gaining both memory‑efficiency and speed without sacrificing model parallelism benefits like weight sharding.
- Easier deployment of LLMs on edge‑like servers: The reduced memory bandwidth requirement makes it feasible to run large models on servers with limited HBM, opening the door for on‑premise or private‑cloud LLM services.
Limitations & Future Work
- Rank selection sensitivity: The low‑rank dimension (r) must be carefully tuned; too low hurts accuracy, while too high diminishes the memory‑traffic gains. Automated rank‑selection strategies are not explored.
- Focus on decoder‑only transformers: The paper evaluates only causal language models. Extending MLRA to encoder‑decoder architectures (e.g., T5) or vision‑language models remains an open question.
- Hardware‑specific optimisations: The reported speedups are measured on A100 GPUs; performance on other accelerators (e.g., AMD GPUs, TPUs) may vary and would need dedicated kernel tuning.
- Training overhead: Introducing the low‑rank factorisation adds extra projection layers, slightly increasing training compute. Future work could investigate more efficient training tricks or mixed‑precision schemes.
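The rank-sensitivity point above can be illustrated numerically: the memory-traffic gain shrinks quickly as the rank grows, which is why choosing r too high erodes the benefit. Head count and head dimension here are assumptions, not the paper's configuration:

```python
# Per-token cache floats for a full-KV baseline (keys + values).
n_heads, d_head = 32, 128
full_per_token = 2 * n_heads * d_head

# Traffic-reduction factor for several candidate ranks r:
# caching a rank-r latent instead of full keys and values.
gains = {r: full_per_token / r for r in (16, 64, 256, 1024, 4096)}
for r, g in gains.items():
    print(f"r={r:4d}  traffic reduced {g:g}x")
```

At r = 4096 the latent is as large as the full cache's per-token footprint would allow only a 2× saving, while very small ranks (which save the most traffic) risk discarding attention information, hence the accuracy/bandwidth tension the authors note.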
Overall, Multi‑Head Low‑Rank Attention offers a compelling recipe for developers who need high‑throughput, long‑context LLM inference without sacrificing model quality. With the code and pretrained weights already released, the community can start experimenting right away.
Authors
- Songtao Liu
- Hongwu Peng
- Zhiwei Zhang
- Zhengyu Chen
- Yue Guo
Paper Information
- arXiv ID: 2603.02188v1
- Categories: cs.LG
- Published: March 2, 2026