[Paper] AdaSplash-2: Faster Differentiable Sparse Attention

Published: April 16, 2026 at 12:03 PM EDT
Source: arXiv - 2604.15180v1

Overview

AdaSplash‑2 targets a central bottleneck in transformer models: the quadratic compute and memory cost of attention over long sequences. The paper speeds up differentiable sparse α‑entmax attention, which assigns exactly zero weight to irrelevant tokens. This makes it practical to train and deploy models that keep only the most relevant attention weights, improving efficiency without sacrificing accuracy.

Key Contributions

  • Histogram‑based τ initialization: A lightweight on‑chip histogram of attention scores yields an accurate starting point for the α‑entmax normalizer, cutting the number of root‑finding iterations to 1–2 on average.
  • AdaSplash‑2 algorithm: Integrates the histogram init with a sparsity‑aware GPU kernel that skips zero‑valued blocks, keeping overhead minimal.
  • Performance parity with FlashAttention‑2: When sparsity exceeds ~60 % (common in long‑context regimes), AdaSplash‑2 matches or even outperforms the state‑of‑the‑art dense attention implementation.
  • Empirical validation on downstream tasks: Models trained with AdaSplash‑2 achieve comparable results to softmax baselines on short contexts and show significant gains (up to ~15 % lower perplexity) on long‑context benchmarks.
  • Open‑source implementation: The authors release a CUDA‑based library that can be dropped into existing PyTorch/Transformers pipelines.
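For readers unfamiliar with the normalizer τ mentioned in the first bullet, the standard definition of α‑entmax is restated here. The symbols z (a row of attention scores), p (the resulting weights), and n (the row length) are notation chosen for this note, not taken from the post.

```latex
% Standard alpha-entmax mapping for a score vector z in R^n, with alpha > 1
p_i = \bigl[(\alpha - 1)\, z_i - \tau\bigr]_+^{\frac{1}{\alpha - 1}},
\qquad
\text{where } \tau \text{ is the unique value such that } \sum_{i=1}^{n} p_i = 1,
\qquad
[x]_+ = \max(x, 0).
```

Every score with (α − 1) z_i ≤ τ receives exactly zero weight, which is where the sparsity comes from. The limit α → 1 recovers softmax, and α = 2 gives sparsemax. Finding τ is the root‑finding problem that the histogram initialization is meant to shorten.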

Methodology

  1. α‑entmax attention replaces softmax with a sparsity‑inducing mapping parameterized by α > 1. The output is still a probability distribution, but many entries are exactly zero. The cost is that the threshold τ, which plays the role of the normalizer, must be found by solving a root‑finding problem.
  2. AdaSplash‑2’s histogram init:
    • While scanning the raw attention scores (the QKᵀ matrix), the kernel builds a coarse histogram (e.g., 256 bins) stored in fast SRAM.
    • The histogram approximates the cumulative distribution of scores, allowing the algorithm to estimate τ with a closed‑form expression rather than starting from a naïve guess.
  3. Iterative refinement: With the histogram‑based guess, the root‑finding loop converges in 1–2 Newton iterations instead of the 5–10 typical for naïve methods.
  4. Sparse‑aware GPU kernel: After τ is known, the kernel masks out entries below the entmax threshold, packs the remaining values into dense blocks, and processes only those blocks. Zero blocks are skipped entirely, saving memory bandwidth and compute.
  5. Training pipeline: The authors plug AdaSplash‑2 into standard transformer code (e.g., Hugging Face’s BertModel) and train on language modeling and summarization datasets with sequence lengths up to 16 k tokens.
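Steps 2 and 3 can be illustrated with a minimal NumPy sketch for a single row of scores. This is not the authors' kernel: it runs on the CPU, the function name `entmax_threshold` is hypothetical, and it assumes 1 < α ≤ 2. The post does not spell out the paper's closed‑form estimate, so the sketch substitutes a simple rule: pick the largest histogram bin edge that is provably at or below the true τ, then refine with Newton steps.

```python
import numpy as np


def entmax_threshold(z, alpha=1.5, n_bins=256, use_histogram=True,
                     tol=1e-6, max_iter=50):
    """Return (p, tau, newton_steps) for alpha-entmax of a 1-D score vector.

    Illustrative sketch only; assumes 1 < alpha <= 2.
    """
    assert 1.0 < alpha <= 2.0, "sketch assumes 1 < alpha <= 2"
    s = (alpha - 1.0) * np.asarray(z, dtype=np.float64)
    q = 1.0 / (alpha - 1.0)  # exponent applied to [s - tau]_+
    s_max = s.max()

    # Naive but safe start: at tau = s_max - 1 the largest entry alone has
    # weight 1, so the total mass is >= 1 and tau is at or below the root.
    tau = s_max - 1.0

    if use_histogram:
        # Coarse histogram over the only range that matters: scores below
        # s_max - 1 receive zero weight for every feasible tau.
        counts, edges = np.histogram(s, bins=n_bins, range=(tau, s_max))
        left = edges[:-1]
        # Pretend every score sits on its bin's left edge. That can only
        # underestimate the mass, so any edge whose estimated mass is still
        # >= 1 is guaranteed to lie at or below the true tau.
        gaps = np.clip(left[None, :] - left[:, None], 0.0, None)
        mass_lower_bound = (counts[None, :] * gaps ** q).sum(axis=1)
        safe = np.nonzero(mass_lower_bound >= 1.0)[0]
        if safe.size:
            tau = left[safe[-1]]

    # Newton refinement on f(tau) = sum_i [s_i - tau]_+^q - 1.
    steps = 0
    for _ in range(max_iter):
        r = np.clip(s - tau, 0.0, None)
        support = r[r > 0.0]
        f = (support ** q).sum() - 1.0
        if abs(f) <= tol:
            break
        df = -q * (support ** (q - 1.0)).sum()
        tau -= f / df
        steps += 1

    p = np.clip(s - tau, 0.0, None) ** q
    return p, tau, steps


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    scores = rng.normal(scale=3.0, size=4096)
    for flag in (False, True):
        p, tau, steps = entmax_threshold(scores, use_histogram=flag)
        print(f"histogram={flag}: steps={steps}, nonzeros={(p > 0).sum()}")
```

Because f is convex and decreasing for this range of α, Newton steps started at or below the root stay below it, so a closer starting point never needs more steps than the naive one. The iteration counts this sketch prints are not the paper's measurements.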

Results & Findings

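Setting                   | Baseline (FlashAttention‑2) | AdaSplash‑2 | Speed‑up (relative) | Sparsity level
--------------------------|-----------------------------|-------------|---------------------|---------------
4 k tokens, 70 % sparsity | 1.00× (baseline)            | 0.94×       | 6 % faster          | 70 %
8 k tokens, 80 % sparsity | 1.00×                       | 0.88×       | 12 % faster         | 80 %
16 k tokens, 85 % sparsity| 1.00×                       | 0.81×       | 19 % faster         | 85 %

Language modeling (perplexity) | Baseline (FlashAttention‑2) | AdaSplash‑2
-------------------------------|-----------------------------|------------
Short context (512)            | 12.3                        | 12.4
Long context (8 k)             | 15.8                        | 13.6
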
  • Training time: For moderate‑to‑high sparsity, per‑step wall‑clock time is on par with or better than dense FlashAttention‑2.
  • Model quality: No degradation on short sequences; noticeable improvements on long‑range tasks, confirming that the sparsity pattern preserves the most informative dependencies.
  • Memory footprint: Peak activation memory drops by ~40 % at 85 % sparsity, enabling larger batch sizes or longer sequences on the same GPU.
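The "faster" percentages reported above follow from the relative figures if those figures are read as runtimes normalized to FlashAttention‑2. That reading is an assumption, but it is consistent with the numbers. The short sketch below redoes the arithmetic and also shows the equivalent throughput ratio, which is slightly larger than the time saved. The helper names are made up for illustration.

```python
def time_saved_pct(relative_runtime):
    """Percent of wall-clock time saved versus a baseline normalized to 1.0."""
    return (1.0 - relative_runtime) * 100.0


def throughput_ratio(relative_runtime):
    """How many times more work per unit time than the baseline."""
    return 1.0 / relative_runtime


def perplexity_reduction_pct(baseline_ppl, new_ppl):
    """Relative perplexity reduction in percent (positive means improvement)."""
    return (baseline_ppl - new_ppl) / baseline_ppl * 100.0


if __name__ == "__main__":
    # Relative runtimes as reported for 4 k, 8 k, and 16 k tokens.
    for label, rel in [("4k", 0.94), ("8k", 0.88), ("16k", 0.81)]:
        print(label, round(time_saved_pct(rel)), round(throughput_ratio(rel), 3))
    # Long-context perplexity: 15.8 (baseline) versus 13.6 (AdaSplash-2).
    print(round(perplexity_reduction_pct(15.8, 13.6), 1))
```

On the long‑context row, the reduction computed this way is a little under 14 %, in line with the "up to ~15 %" figure quoted in the key contributions.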

Practical Implications

  • Long‑context applications: Retrieval‑augmented generation, document‑level summarization, and code‑completion tools can train transformer models on sequences of tens of thousands of tokens (the reported experiments reach 16 k) without prohibitive hardware upgrades.
  • Cost savings: Reduced memory bandwidth and compute translate directly into lower cloud GPU bills, especially for workloads that already exhibit high attention sparsity (e.g., hierarchical or sliding‑window models).
  • Drop‑in replacement: Because AdaSplash‑2 follows the same API as standard nn.MultiheadAttention, developers can experiment with sparse attention by swapping a single module import.
  • Compatibility with existing optimizations: The method works alongside mixed‑precision training, gradient checkpointing, and other speed‑up tricks, making it a versatile addition to any performance‑focused stack.
  • Potential for edge deployment: The SRAM‑resident histogram and block‑skipping logic are well‑suited for custom ASICs or mobile GPUs where memory is at a premium.

Limitations & Future Work

  • Sparsity dependence: The speed advantage diminishes when the attention pattern is dense (< 50 % sparsity). In such regimes, traditional dense kernels remain preferable.
  • Histogram granularity trade‑off: A coarser histogram reduces SRAM usage but can lead to slightly more Newton iterations; tuning this hyper‑parameter may be required for different hardware.
  • Extension to multi‑query/multi‑key setups: The current implementation assumes a single QKᵀ matrix per head; adapting to more exotic attention variants (e.g., multi‑query attention) is left for future research.
  • Theoretical analysis of convergence: While empirical iteration counts are low, a formal bound on the number of Newton steps given histogram error would strengthen the method’s guarantees.
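The sparsity‑dependence caveat is easier to reason about once element‑level sparsity is separated from block‑level sparsity, since a block‑skipping kernel only saves work on tiles that are entirely zero. The sketch below measures both for a given attention matrix. The function names and the default tile size of 64 are illustrative assumptions, and the post does not say which notion of sparsity its percentages refer to.

```python
import numpy as np


def element_sparsity(p):
    """Fraction of individual attention weights that are exactly zero."""
    p = np.asarray(p)
    return float((p == 0).mean())


def zero_block_fraction(p, block=64):
    """Fraction of (block x block) tiles that contain only zeros.

    Ragged edges are zero-padded so every real entry belongs to some tile.
    """
    p = np.asarray(p)
    n_q, n_k = p.shape
    padded = np.pad(p, ((0, (-n_q) % block), (0, (-n_k) % block)))
    tiles = padded.reshape(padded.shape[0] // block, block,
                           padded.shape[1] // block, block)
    # A tile can be skipped only if none of its entries is nonzero.
    tile_has_nonzero = (tiles != 0).any(axis=(1, 3))
    return float(1.0 - tile_has_nonzero.mean())


if __name__ == "__main__":
    # A diagonal pattern: most entries are zero, but fewer tiles are.
    attn = np.eye(4)
    print(element_sparsity(attn), zero_block_fraction(attn, block=2))
```

A pattern can be highly sparse element‑wise yet leave few tiles empty, in which case a dense kernel may still be the better choice.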

Overall, AdaSplash‑2 demonstrates that differentiable sparse attention can be both fast and accurate, opening the door for scalable, long‑context transformer models in production environments.

Authors

  • Nuno Gonçalves
  • Hugo Pitorro
  • Vlad Niculae
  • Edoardo Ponti
  • Lei Li
  • Andre Martins
  • Marcos Treviso

Paper Information

  • arXiv ID: 2604.15180v1
  • Categories: cs.LG, cs.CL
  • Published: April 16, 2026