[Paper] AdaSplash-2:更快的可微稀疏注意力

发布: (2026年4月17日 GMT+8 00:03)
8 分钟阅读
原文: arXiv

Source: arXiv - 2604.15180v1

概述

论文 AdaSplash‑2 解决了现代 Transformer 模型中最大的痛点之一:在处理长序列时,基于 softmax 的注意力机制的二次方内存和计算成本。通过显著加速可微稀疏的 α‑entmax 注意力,作者们使得仅保留最相关注意力权重的模型的训练和部署变得可行,从而在不牺牲准确性的前提下实现更高的效率。

关键贡献

  • 基于直方图的 τ 初始化:轻量级的片上注意力分数直方图为 α‑entmax 正规化器提供了准确的起始点,平均将根寻找迭代次数减少到 1–2 次。
  • AdaSplash‑2 算法:将直方图初始化与感知稀疏性的 GPU 核心相结合,跳过值为零的块,保持开销极小。
  • 与 FlashAttention‑2 性能持平:当稀疏度超过约 60 %(在长上下文场景中常见)时,AdaSplash‑2 能够匹配甚至超越最先进的密集注意力实现。
  • 在下游任务上的实证验证:使用 AdaSplash‑2 训练的模型在短上下文上取得与 softmax 基线相当的结果,并在长上下文基准上表现出 显著提升(困惑度降低约 15 %)。
  • 开源实现:作者发布了基于 CUDA 的库,可直接嵌入现有的 PyTorch/Transformers 流程中。

方法论

  1. α‑entmax attention 将 softmax 正规化器替换为由 α > 1 参数化的稀疏化函数。输出是一个概率分布,许多条目会恰好为零,但计算正规化因子 τ 需要求解根寻找问题。
  2. AdaSplash‑2 的直方图初始化
    • 在扫描原始注意力分数(QKᵀ 矩阵)时,内核在高速 SRAM 中构建一个粗略直方图(例如 256 桶)。
    • 该直方图近似分数的累计分布,使算法能够使用闭式表达式估计 τ,而不是从一个朴素的猜测开始。
  3. 迭代细化:有了基于直方图的猜测后,根寻找循环在 1–2 次 Newton 迭代内收敛,而不是朴素方法常见的 5–10 次。
  4. 稀疏感知 GPU 内核:在 τ 已知后,内核屏蔽掉低于 entmax 阈值的条目,将剩余值打包成密集块,仅处理这些块。零块会被完全跳过,从而节省内存带宽和计算资源。
  5. 训练流水线:作者将 AdaSplash‑2 嵌入标准 Transformer 代码(例如 HuggingFace 的 BertModel),并在语言建模和摘要数据集上进行训练,序列长度最高可达 16 k token。

结果与发现

设置基准 (FlashAttention‑2)AdaSplash‑2加速(相对)稀疏度水平
4 k 令牌,70 % 稀疏度1.00×(基准)0.94×加快 6 %70 %
8 k 令牌,80 % 稀疏度1.00×0.88×加快 12 %80 %
16 k 令牌,85 % 稀疏度1.00×0.81×加快 19 %85 %
语言建模(困惑度)– 短上下文(512)12.312.4
语言建模 – 长上下文(8 k)15.813.6
  • 训练时间:对于中等到高稀疏度,每步实际时长与密集的 FlashAttention‑2 持平或更好。
  • 模型质量:在短序列上没有下降;在长程任务上有显著提升,证明稀疏模式保留了最有信息的依赖关系。
  • 内存占用:在 85 % 稀疏度下,峰值激活内存降低约 40 %,使得在同一 GPU 上可以使用更大的批量或更长的序列。

Practical Implications

  • 长上下文应用:检索增强生成、文档级摘要和代码补全工具现在可以在不进行昂贵硬件升级的情况下,使用数万 token 训练 Transformer 模型。
  • 成本节约:内存带宽和计算量的降低直接转化为更低的云 GPU 费用,尤其是对已经表现出高度注意力稀疏性的工作负载(例如层次结构或滑动窗口模型)。
  • 即插即用:由于 AdaSplash‑2 遵循与标准 nn.MultiheadAttention 相同的 API,开发者只需替换一次模块导入即可实验稀疏注意力。
  • 兼容现有优化:该方法可与混合精度训练、梯度检查点以及其他加速技巧共同使用,使其成为任何面向性能的技术栈的多功能补充。
  • 边缘部署的潜力:SRAM‑resident histogram 与块跳过逻辑非常适合内存受限的定制 ASIC 或移动 GPU。

限制与未来工作

  • 稀疏性依赖:当注意力模式较密集(< 50 % 稀疏度)时,速度优势会减弱。在这种情况下,传统的密集 kernel 仍然更可取。
  • 直方图粒度权衡:更粗的直方图可以降低 SRAM 使用,但可能导致略多的 Newton 迭代;不同硬件可能需要调节此超参数。
  • 向多查询/多键设置的扩展:当前实现假设每个 head 只有一个 QKᵀ 矩阵;将其适配到更奇特的注意力变体(例如 multi‑query attention)留待未来研究。
  • 收敛性的理论分析:虽然经验上迭代次数很少,但若能给出基于直方图误差的 Newton 步数的形式化上界,将加强该方法的保证。

总体而言,AdaSplash‑2 展示了可微稀疏注意力既 快速准确,为生产环境中可扩展的长上下文 transformer 模型打开了大门。

作者

  • Nuno Gonçalves
  • Hugo Pitorro
  • Vlad Niculae
  • Edoardo Ponti
  • Lei Li
  • Andre Martins
  • Marcos Treviso

论文信息

  • arXiv ID: 2604.15180v1
  • 分类: cs.LG, cs.CL
  • 出版日期: 2026年4月16日
  • PDF: 下载 PDF
0 浏览
Back to Blog

相关文章

阅读更多 »