LLM 系统的缓存策略(第3部分):Multi-Query Attention 与 Memory-Efficient Decoding

发布: (2026年2月8日 GMT+8 23:51)
6 分钟阅读
原文: Dev.to

Source: Dev.to

请提供您希望翻译的文章正文内容,我将为您翻译成简体中文。

2️⃣ Recap: KV‑Caching in Autoregressive Decoding

KV‑caching 通过存储每个已生成 token 对应的键 (K) 和值 (V),消除了自回归解码中的 二次 注意力成本。
对于一个具有

  • (L) 层数
  • (H) 每层的注意力头数
  • (T) 序列长度(上下文窗口)
  • (d_h) 头维度

的 Transformer,KV‑cache 每层的大小为

[ \mathcal{O}(L; H; T; d_h) ]

KV‑caching 消除了 冗余计算,但 并未 减少随头数 (H) 线性增长的内存占用。
对于现代大语言模型(32–128 个头,长上下文),KV‑cache 很快就会主导推理成本(包括内存和带宽)。

❓ 注意力头真的需要独立的键和值吗?

多查询注意力(MQA)——概念

每个头保留自己的查询投影,但 在所有头之间共享一套键和值

形式上

[ \begin{aligned} Q_i &= X,W_{Q_i} \quad &&\text{(每个头独立)}\[4pt] K &= X,W_{K} \quad &&\text{(共享)}\[4pt] V &= X,W_{V} \quad &&\text{(共享)} \end{aligned} ]

每个头的注意力随后为

[ \text{Attention}_i ;=; \operatorname{softmax}!\Bigl(\frac{Q_i K^{!\top}}{\sqrt{d_h}}\Bigr),V ]

重要说明

  • 键和值是共享的,但它们 不是同一个 矩阵:(W_K \neq W_V)。
  • 这一单一设计决定将 KV‑缓存大小缩减 (H) 倍。

📐 参数形状

MatrixShapeDescription
(W_Q)(\mathbb{R}^{d \times (H d_h)})为每个头单独的查询投影
(W_K)(\mathbb{R}^{d \times d_h})共享 键投影
(W_V)(\mathbb{R}^{d \times d_h})共享 值投影

(d) 是模型维度。

📊 KV‑缓存内存对比

注意力类型每层 KV 缓存
多头注意力 (MHA)(H \times T \times d_h)
多查询注意力 (MQA)(1 \times T \times d_h)

示例:64 头模型,FP16(每元素 2 字节)

参数数值
层数 (L)80
头数 (H)64
头维度 (d_h)128
上下文长度 (T)2048
精度FP16 (2 B)
注意力类型KV‑缓存公式每序列近似大小
MHA(2 \times L \times H \times T \times d_h \times 2\text{ B})1.2 GB
MQA(2 \times L \times 1 \times T \times d_h \times 2\text{ B})19 MB
缩减≈ 64× 更小

系数 2 表示同时存储键 值。

🧭 什么被失去(以及为什么它通常并不重要)

在标准 MHA 中,每个头都有独立的投影

[ \begin{aligned} Q_i &= X W_{Q_i}\ K_i &= X W_{K_i}\ V_i &= X W_{V_i} \end{aligned} ]

  • 这为每个头提供了各自的 相似度度量((K_i))、检索语义((V_i))和 对齐目标((Q_i))。
  • 从几何角度看,MHA 跨越 多个低秩注意力算子,使得各头能够专门化(语法、长程依赖、位置偏置、共指等)。

MQA 强制

[ K_1 = K_2 = \dots = K_H = K,\qquad V_1 = V_2 = \dots = V_H = V ]

  • 所有头在 相同的键空间 中计算相关性,并从 相同的值流形 中检索;多样性仅来源于查询。

后果

方面MHAMQA
注意力子空间数量(H)(多)1(共享)
每头相似度度量
每头语义抽象
独立关系子空间
“视角容量”降低

秩的降低限制了模型同时表示多个不兼容解释的能力。

为什么这种退化通常可以忽略不计

  1. MHA 头部的冗余 – 许多头学习到高度相关的模式。
  2. 深度与宽度的补偿 – 前馈层吸收了失去的表达能力。
  3. 训练适应 – 从头使用 MQA 训练的模型会学习到稳健的共享 KV 空间。
  4. 推理瓶颈 – 在部署时,内存带宽而非表示能力主导了延迟。

🚀 推理工作流(解码)

步骤MHAMQA
1️⃣ 为新 token 重新计算 queries✔️✔️
2️⃣ 从缓存加载 keys & values(H) KV 张量 每层1 KV 张量 每层
3️⃣ 计算 attention✔️✔️

减少每层加载的 KV 张量数量可以显著降低内存流量、缓存压力和 token 延迟。

📌 要点

  • 多查询注意力 用适度的每头表示多样性损失,换取 KV 缓存内存和带宽的巨大降低(对 64 头模型可降低至 1/64)。
  • 对于大规模推理密集型部署,这种权衡通常是值得的,这也是许多生产级 LLM(例如 PaLM)采用 MQA 的原因。

KV 多样性

每头共享
表达能力更高更低
KV 缓存大小(\mathcal{O}(H,T,d_h))(\mathcal{O}(T,d_h))
推理效率更低高得多

注意: MQA(多查询注意力)不是一种免费优化。它是一种有意的架构权衡,侧重于推理可扩展性,而非最大化表达能力。

0 浏览
Back to Blog

相关文章

阅读更多 »

并非所有 RecSys 问题都相同

并非所有 RecSys 工作都是相同的。行业中的异常案例扭曲了我们对推荐系统的定义。TikTok、Spotify 和 Netflix 采用 hybrid……