LLM 系统的缓存策略(第3部分):Multi-Query Attention 与 Memory-Efficient Decoding
Source: Dev.to
请提供您希望翻译的文章正文内容,我将为您翻译成简体中文。
2️⃣ Recap: KV‑Caching in Autoregressive Decoding
KV‑caching 通过存储每个已生成 token 对应的键 (K) 和值 (V),消除了自回归解码中的 二次 注意力成本。
对于一个具有
- (L) 层数
- (H) 每层的注意力头数
- (T) 序列长度(上下文窗口)
- (d_h) 头维度
的 Transformer,KV‑cache 每层的大小为
[ \mathcal{O}(L; H; T; d_h) ]
KV‑caching 消除了 冗余计算,但 并未 减少随头数 (H) 线性增长的内存占用。
对于现代大语言模型(32–128 个头,长上下文),KV‑cache 很快就会主导推理成本(包括内存和带宽)。
❓ 注意力头真的需要独立的键和值吗?
多查询注意力(MQA)——概念
每个头保留自己的查询投影,但 在所有头之间共享一套键和值。
形式上
[ \begin{aligned} Q_i &= X,W_{Q_i} \quad &&\text{(每个头独立)}\[4pt] K &= X,W_{K} \quad &&\text{(共享)}\[4pt] V &= X,W_{V} \quad &&\text{(共享)} \end{aligned} ]
每个头的注意力随后为
[ \text{Attention}_i ;=; \operatorname{softmax}!\Bigl(\frac{Q_i K^{!\top}}{\sqrt{d_h}}\Bigr),V ]
重要说明
- 键和值是共享的,但它们 不是同一个 矩阵:(W_K \neq W_V)。
- 这一单一设计决定将 KV‑缓存大小缩减 (H) 倍。
📐 参数形状
| Matrix | Shape | Description |
|---|---|---|
| (W_Q) | (\mathbb{R}^{d \times (H d_h)}) | 为每个头单独的查询投影 |
| (W_K) | (\mathbb{R}^{d \times d_h}) | 共享 键投影 |
| (W_V) | (\mathbb{R}^{d \times d_h}) | 共享 值投影 |
(d) 是模型维度。
📊 KV‑缓存内存对比
| 注意力类型 | 每层 KV 缓存 |
|---|---|
| 多头注意力 (MHA) | (H \times T \times d_h) |
| 多查询注意力 (MQA) | (1 \times T \times d_h) |
示例:64 头模型,FP16(每元素 2 字节)
| 参数 | 数值 |
|---|---|
| 层数 (L) | 80 |
| 头数 (H) | 64 |
| 头维度 (d_h) | 128 |
| 上下文长度 (T) | 2048 |
| 精度 | FP16 (2 B) |
| 注意力类型 | KV‑缓存公式 | 每序列近似大小 |
|---|---|---|
| MHA | (2 \times L \times H \times T \times d_h \times 2\text{ B}) | ≈ 1.2 GB |
| MQA | (2 \times L \times 1 \times T \times d_h \times 2\text{ B}) | ≈ 19 MB |
| 缩减 | — | ≈ 64× 更小 |
系数 2 表示同时存储键 和 值。
🧭 什么被失去(以及为什么它通常并不重要)
在标准 MHA 中,每个头都有独立的投影
[ \begin{aligned} Q_i &= X W_{Q_i}\ K_i &= X W_{K_i}\ V_i &= X W_{V_i} \end{aligned} ]
- 这为每个头提供了各自的 相似度度量((K_i))、检索语义((V_i))和 对齐目标((Q_i))。
- 从几何角度看,MHA 跨越 多个低秩注意力算子,使得各头能够专门化(语法、长程依赖、位置偏置、共指等)。
MQA 强制
[ K_1 = K_2 = \dots = K_H = K,\qquad V_1 = V_2 = \dots = V_H = V ]
- 所有头在 相同的键空间 中计算相关性,并从 相同的值流形 中检索;多样性仅来源于查询。
后果
| 方面 | MHA | MQA |
|---|---|---|
| 注意力子空间数量 | (H)(多) | 1(共享) |
| 每头相似度度量 | 是 | 否 |
| 每头语义抽象 | 是 | 否 |
| 独立关系子空间 | 是 | 否 |
| “视角容量” | 高 | 降低 |
秩的降低限制了模型同时表示多个不兼容解释的能力。
为什么这种退化通常可以忽略不计
- MHA 头部的冗余 – 许多头学习到高度相关的模式。
- 深度与宽度的补偿 – 前馈层吸收了失去的表达能力。
- 训练适应 – 从头使用 MQA 训练的模型会学习到稳健的共享 KV 空间。
- 推理瓶颈 – 在部署时,内存带宽而非表示能力主导了延迟。
🚀 推理工作流(解码)
| 步骤 | MHA | MQA |
|---|---|---|
| 1️⃣ 为新 token 重新计算 queries | ✔️ | ✔️ |
| 2️⃣ 从缓存加载 keys & values | (H) KV 张量 每层 | 1 KV 张量 每层 |
| 3️⃣ 计算 attention | ✔️ | ✔️ |
减少每层加载的 KV 张量数量可以显著降低内存流量、缓存压力和 token 延迟。
📌 要点
- 多查询注意力 用适度的每头表示多样性损失,换取 KV 缓存内存和带宽的巨大降低(对 64 头模型可降低至 1/64)。
- 对于大规模推理密集型部署,这种权衡通常是值得的,这也是许多生产级 LLM(例如 PaLM)采用 MQA 的原因。
KV 多样性
| 每头 | 共享 | |
|---|---|---|
| 表达能力 | 更高 | 更低 |
| KV 缓存大小 | (\mathcal{O}(H,T,d_h)) | (\mathcal{O}(T,d_h)) |
| 推理效率 | 更低 | 高得多 |
注意: MQA(多查询注意力)不是一种免费优化。它是一种有意的架构权衡,侧重于推理可扩展性,而非最大化表达能力。