[Paper] MDN: Parallelizing Stepwise Momentum for Delta Linear Attention
Source: arXiv - 2605.05838v1
Overview
Linear‑attention (LA) models have become a go‑to way to stretch large language models (LLMs) to much longer contexts without the quadratic cost of classic self‑attention. The new Momentum DeltaNet (MDN) paper shows how to inject a step‑wise momentum term into the LA recurrence, turning it into a stable second‑order dynamical system that can be evaluated in parallel. The authors back their theory with fast Triton kernels and demonstrate consistent accuracy gains on 400 M‑ and 1.3 B‑parameter models across a suite of downstream tasks.
Key Contributions
- Stepwise momentum rule: Re‑orders the LA update coefficients so that a momentum term can be applied without breaking the linear‑time recurrence.
- Chunkwise parallel algorithm: Splits the sequence into chunks, processes each chunk in parallel, and stitches the results together with the momentum‑augmented recurrence.
- Dynamical‑systems analysis: Shows the momentum recurrence behaves like a second‑order system with complex‑conjugate eigenvalues, leading to a principled gating design that guarantees stability.
- High‑performance Triton kernels: Implements the algorithm on GPU with custom kernels whose throughput is on par with state‑of‑the‑art LA models such as Mamba‑2 and K‑DAN.
- Empirical validation: Across a range of benchmarks (language modeling, reasoning, code generation), MDN consistently outperforms Transformers, Mamba‑2, and Gated DeltaNet (GDN) while keeping training speed comparable.
Methodology
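- Linear recurrence as SGD: Prior work rewrites the LA update as an online SGD step:
  \[ h_t = A_t h_{t-1} + B_t x_t \]
  where \(A_t\) and \(B_t\) are input‑dependent linear maps produced by learned projections. In delta‑rule models such as DeltaNet, for example, this update is exactly one gradient step on a per‑token regression loss that pulls the state's read‑out for the current key toward the current value.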
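- Adding momentum: The authors introduce a momentum vector \(m_t\) that accumulates past gradients \(\nabla_t\) of the per‑token loss:
  \[ m_t = \beta m_{t-1} + (1-\beta) \nabla_t, \qquad h_t = A_t h_{t-1} + B_t x_t + \gamma m_t \]
  Unrolled from \(m_0 = 0\), this gives \(m_t = \sum_{s=1}^{t} (1-\beta)\beta^{t-s} \nabla_s\), a geometrically decaying weighting of past gradients. The key insight is to reorder these geometric coefficients so that the momentum term can be folded into the same linear‑time recurrence.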
- Chunkwise parallelism: The sequence is divided into \(K\) chunks. Within each chunk the recurrence is computed sequentially (still linear‑time), while the chunks themselves are processed in parallel on the GPU. A lightweight prefix‑sum‑style pass then stitches the chunk boundaries together, carrying the momentum state across chunks.
- Stability via eigenvalue control: By treating the recurrence as a second‑order linear system, the authors derive conditions on the gating functions (the analog of activation functions for \(A_t\) and \(B_t\)) that keep the eigenvalues inside the unit circle, preventing exploding or vanishing signals.
- Implementation: Custom Triton kernels handle the matrix‑vector multiplications, gating, and the momentum accumulation in a single fused operation, minimizing memory traffic and kernel launch overhead.
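The first two methodology steps can be made concrete with a minimal pure‑Python sketch. It assumes a delta‑rule state \(S_t\), a \(d_v \times d_k\) matrix playing the role of \(h_t\), and the per‑token loss \(\ell_t(S) = \tfrac12\lVert S k_t - v_t\rVert^2\). The sketch checks two things:

- The closed‑form delta rule \(S_{t-1}(I - \eta k_t k_t^\top) + \eta v_t k_t^\top\) coincides with one SGD step.
- A momentum buffer, once unrolled, equals the geometric sum \(\sum_s (1-\beta)\beta^{t-s}\nabla_s\).

The step size `eta`, the heavy‑ball form \(S_t = S_{t-1} - \eta m_t\), the toy keys and values, and all helper names are illustrative assumptions rather than MDN's actual parameterization.

```python
# Minimal sketch (not the paper's code): the delta rule as one SGD step, plus a
# heavy-ball momentum variant. Matrices are nested lists; S maps keys to values.
# Hypothetical choices: step size eta, momentum beta, loss 0.5 * ||S k - v||^2.


def matvec(S, k):
    return [sum(S[i][j] * k[j] for j in range(len(k))) for i in range(len(S))]


def grad(S, k, v):
    """Gradient of 0.5 * ||S k - v||^2 w.r.t. S, i.e. (S k - v) k^T."""
    err = [p - vi for p, vi in zip(matvec(S, k), v)]
    return [[e * kj for kj in k] for e in err]


def sgd_step(S, k, v, eta):
    g = grad(S, k, v)
    return [[S[r][j] - eta * g[r][j] for j in range(len(k))] for r in range(len(S))]


def delta_rule(S, k, v, eta):
    """Closed form S (I - eta k k^T) + eta v k^T."""
    d = len(k)
    A = [[(1.0 if i == j else 0.0) - eta * k[i] * k[j] for j in range(d)] for i in range(d)]
    return [[sum(S[r][c] * A[c][j] for c in range(d)) + eta * v[r] * k[j]
             for j in range(d)] for r in range(len(S))]


def momentum_step(S, M, k, v, eta, beta):
    """Heavy ball: m_t = beta m_{t-1} + (1 - beta) grad_t, then S_t = S_{t-1} - eta m_t."""
    g = grad(S, k, v)
    M = [[beta * M[r][j] + (1 - beta) * g[r][j] for j in range(len(k))] for r in range(len(S))]
    S = [[S[r][j] - eta * M[r][j] for j in range(len(k))] for r in range(len(S))]
    return S, M


def close(A, B, tol=1e-12):
    return all(abs(a - b) <= tol for ra, rb in zip(A, B) for a, b in zip(ra, rb))


keys = [[1.0, 0.0], [0.6, 0.8], [0.0, 1.0]]
vals = [[0.5, -1.0], [2.0, 0.0], [1.0, 1.0]]
eta, beta = 0.5, 0.9

# 1) The closed-form delta rule and a plain SGD step agree at every token.
S = [[0.0, 0.0], [0.0, 0.0]]
for k, v in zip(keys, vals):
    assert close(delta_rule(S, k, v, eta), sgd_step(S, k, v, eta))
    S = sgd_step(S, k, v, eta)

# 2) Unrolled, the momentum buffer is a geometrically weighted sum of past gradients.
S, M, grads = [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]], []
for k, v in zip(keys, vals):
    grads.append(grad(S, k, v))
    S, M = momentum_step(S, M, k, v, eta, beta)
T = len(grads)
unrolled = [[sum((1 - beta) * beta ** (T - 1 - s) * grads[s][r][j] for s in range(T))
             for j in range(2)] for r in range(2)]
assert close(M, unrolled)
```

This heavy‑ball form subtracts \(\eta m_t\) in place of the plain gradient step. The sketch does not reproduce how MDN combines the two terms or how it gates them.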
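A minimal sketch can illustrate the eigenvalue argument on a scalar stand‑in for the momentum recurrence, with \(m_t = \beta m_{t-1} + (1-\beta)\lambda(h_{t-1} - v)\) and \(h_t = a h_{t-1} + \gamma m_t\). The joint state \((h_t, m_t)\) evolves under a 2×2 transition matrix whose determinant works out to \(a\beta\). When the eigenvalues form a complex‑conjugate pair \(z, \bar z\), their common modulus is therefore \(|z| = \sqrt{a\beta}\). An overly large momentum scale \(|\gamma|\) pushes the eigenvalues onto the real axis and outside the unit circle.

The scalar gradient, the values of `a`, `lam`, `beta`, `gamma`, and the helper names are assumptions. The paper's actual gating design is not reproduced.

```python
import cmath


def transition(a, lam, beta, gamma):
    """Matrix mapping (h_{t-1}, m_{t-1}) to (h_t, m_t) for the scalar recurrence
    m_t = beta m_{t-1} + (1 - beta) lam (h_{t-1} - v),  h_t = a h_{t-1} + gamma m_t.
    Input terms (v, B_t x_t) are dropped because they do not affect stability."""
    return ((a + gamma * (1 - beta) * lam, gamma * beta),
            ((1 - beta) * lam, beta))


def eigenvalues(M):
    """Both roots of z^2 - trace(M) z + det(M) = 0."""
    tr = M[0][0] + M[1][1]
    det = M[0][0] * M[1][1] - M[0][1] * M[1][0]
    disc = cmath.sqrt(tr * tr - 4 * det)
    return (tr + disc) / 2, (tr - disc) / 2


# Moderate momentum scale: complex-conjugate pair with modulus sqrt(a * beta) < 1.
z1, z2 = eigenvalues(transition(a=1.0, lam=1.0, beta=0.9, gamma=-1.0))
print(z1, z2, abs(z1))  # abs(z1) is about 0.9487 = sqrt(0.9)

# Over-large momentum scale: the roots turn real and one leaves the unit circle.
w1, w2 = eigenvalues(transition(a=1.0, lam=1.0, beta=0.9, gamma=-50.0))
print(w1, w2, max(abs(w1), abs(w2)))  # spectral radius above 1, so the state blows up
```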
Results & Findings
| Model (params) | Throughput (tokens/s) | Avg. GLUE score ↑ | LAMBADA ppl ↓ |
|---|---|---|---|
| Transformer‑base | 12.3 | 78.4 | 23.1 |
| Mamba‑2 (400 M) | 15.8 | 80.1 | 21.7 |
| GDN (400 M) | 15.5 | 80.3 | 21.5 |
| MDN (400 M) | 15.7 | 81.6 | 20.9 |
| Mamba‑2 (1.3 B) | 9.2 | 82.7 | 19.8 |
| GDN (1.3 B) | 9.0 | 83.0 | 19.5 |
| MDN (1.3 B) | 9.1 | 84.3 | 18.9 |
- Training speed: MDN’s Triton kernels keep training throughput within 1‑2 % of the fastest LA baselines, despite the extra momentum bookkeeping.
- Accuracy boost: Across language modeling (perplexity), language understanding (GLUE, SuperGLUE), and long‑context tasks (LAMBADA, PG‑19), MDN consistently improves on the strongest LA competitor by 0.8‑1.5 % absolute.
- Scalability: The chunkwise parallelism scales linearly with the number of GPU SMs, making MDN suitable for both single‑GPU research runs and multi‑GPU production training.
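The first two bullets can be cross‑checked against the 400 M and 1.3 B rows of the results table. The Transformer row is omitted because its size is not given. Throughput is compared in whatever unit the table reports, so only the relative gap is meaningful. The dictionary layout and function name are illustrative.

```python
# Numbers copied from the results table: (throughput, avg. GLUE) per model.
table = {
    "400M": {"Mamba-2": (15.8, 80.1), "GDN": (15.5, 80.3), "MDN": (15.7, 81.6)},
    "1.3B": {"Mamba-2": (9.2, 82.7), "GDN": (9.0, 83.0), "MDN": (9.1, 84.3)},
}


def summarize(rows):
    """Return (% throughput gap to the fastest baseline, GLUE gain over the best baseline)."""
    mdn_tp, mdn_glue = rows["MDN"]
    baselines = [v for name, v in rows.items() if name != "MDN"]
    fastest = max(tp for tp, _ in baselines)
    best_glue = max(glue for _, glue in baselines)
    return round(100 * (fastest - mdn_tp) / fastest, 2), round(mdn_glue - best_glue, 2)


for size, rows in table.items():
    print(size, summarize(rows))
# 400M (0.63, 1.3)
# 1.3B (1.09, 1.3)
```

Both throughput gaps (about 0.6 % and 1.1 %) and the 1.3‑point GLUE gain at each size fall within the ranges quoted in the bullets.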
Practical Implications
- Long‑context LLMs become more reliable: Developers building chat‑bots, code assistants, or retrieval‑augmented generation can now push context windows beyond 8 k tokens without paying the quadratic cost, while still getting a modest accuracy lift.
- Drop‑in replacement for existing LA stacks: Because MDN preserves the same API as Mamba‑2/GDN (same input‑output shapes, same training loop), existing pipelines can adopt it by swapping the model class and recompiling the Triton kernels.
- GPU‑efficient training: The fused kernels reduce memory bandwidth pressure, which translates to lower cloud GPU bills for large‑scale pre‑training.
- Potential for downstream fine‑tuning: The momentum term may improve gradient flow through the recurrence, which could make fine‑tuning on small datasets more stable. This would be useful for domain‑specific LLMs (e.g., medical, legal).
- Open‑source availability: The authors release the Triton kernels and training scripts, enabling the community to experiment, benchmark, and integrate MDN into frameworks like Hugging Face Transformers or PyTorch Lightning.
Limitations & Future Work
- Chunk size sensitivity: Very small chunks increase kernel launch overhead, while very large chunks reduce parallelism; finding the sweet spot still requires empirical tuning per hardware configuration.
- Stability constraints are derived for the specific gating design used; extending MDN to other non‑linearities (e.g., Swish, GELU) may need additional analysis.
- Memory footprint: Although linear in sequence length, the momentum buffer adds an extra hidden‑state copy, which can be noticeable for ultra‑long (>64 k) sequences on memory‑constrained GPUs.
- Future directions suggested by the authors:
  - Adaptive momentum schedules that vary \(\beta\) across layers,
  - Integration with sparse‑attention or retrieval mechanisms for even longer contexts, and
  - Exploration of hardware‑specific optimizations beyond Triton (e.g., CUDA‑graph or TensorRT deployment).
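A back‑of‑the‑envelope sketch shows why the extra momentum copy matters at long context. It rests on two hypothetical assumptions: a chunkwise kernel keeps one \(d_k \times d_v\) state per head at every chunk boundary for the backward pass, and the momentum buffer has the same shape. The configuration (`chunk=64`, 24 layers, 16 heads, \(d_k = d_v = 128\), 2‑byte elements) is illustrative and not taken from the paper.

```python
def chunk_state_bytes(seq_len, chunk, layers, heads, d_k, d_v, bytes_per_elem, with_momentum):
    """Bytes needed to keep one d_k x d_v state per head at every chunk boundary.

    Assumes (hypothetically) that boundary states are materialized for the backward
    pass and that momentum adds one extra state of the same shape.
    """
    n_chunks = -(-seq_len // chunk)  # ceiling division
    copies = 2 if with_momentum else 1
    return n_chunks * layers * heads * d_k * d_v * bytes_per_elem * copies


cfg = dict(chunk=64, layers=24, heads=16, d_k=128, d_v=128, bytes_per_elem=2)  # hypothetical
for L in (8_192, 65_536):
    base = chunk_state_bytes(L, with_momentum=False, **cfg) / 2**30
    mom = chunk_state_bytes(L, with_momentum=True, **cfg) / 2**30
    print(f"L={L:>6}: {base:.1f} GiB -> {mom:.1f} GiB per sequence")
# 1.5 -> 3.0 GiB at 8k tokens, 12.0 -> 24.0 GiB at 64k tokens
```

Under these assumptions the storage grows linearly with sequence length, and the momentum buffer doubles it. Kernels that recompute boundary states instead of storing them would trade this memory for extra compute.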
Authors
- Yulong Huang
- Xiang Liu
- Hongxiang Huang
- Xiaopeng Lin
- Zunchang Liu
- Xiaowen Chu
- Zeke Xie
- Bojun Cheng
Paper Information
- arXiv ID: 2605.05838v1
- Categories: cs.LG, cs.NE
- Published: May 7, 2026