从理论到实践:揭开现代 LLM 中 Key-Value Cache 的神秘面纱
Source: Dev.to
引言 — 什么是键值缓存(Key‑Value Cache),以及我们为何需要它?

我在大语言模型领域的探索之路
虽然我并非传统的数据科学或深度学习背景,但过去几年与 AI 和生成模型的工作让我能够以务实的方式学习 KV 缓存的概念:通过阅读博客和技术书籍,并动手实验示例代码。这种实践方法帮助把数学思想转化为具体、可运行的组件。

键值缓存(Key‑Value Cache)是一个专用的内存空间,用来存储自注意力机制为已处理的 token 产生的中间 键 (K) 和 值 (V) 向量。通过在后续推理步骤中复用这些向量,缓存能够显著加速生成过程。
为什么 KV 缓存是必需的
Transformer 模型(例如 GPT 系列)以自回归方式生成文本:每个新 token 的预测都基于所有先前生成的 token。若没有 KV 缓存,模型必须在每一步重新计算整个序列的 K 和 V,导致计算成本随序列长度 n 成二次增长 O(n²)。这使得长文本生成变得极其缓慢且昂贵。
KV 缓存的工作原理
-
预填充阶段(首个 Token / Prompt)
模型处理输入提示,计算每个 token 的 Q、K、V。K 与 V 向量被存入 KV 缓存。 -
解码阶段(后续 Token)
- 仅为新生成的 token 计算 查询 (Q)。
- 之前计算好的 K 与 V 向量直接从缓存中读取。
- 计算新 token 自身的 K 与 V 向量并追加到缓存中。
-
结果
每个新 token 的注意力计算从二次 O(n²) 降至线性 O(n),从而实现更快的推理。
权衡
主要的权衡在于 GPU 内存使用的增加:对于非常长的序列或大的批量大小,缓存的 K 与 V 张量可能会占据大量显存。

Source: (Sebastian Raschka, PhD)
通过示例概念代码进行说明
下面是一个最小化的 Python 示例,演示在使用 KV 缓存的情况下如何模拟注意力步骤。
# kv_cache_demo.py
KV_CACHE = {
"keys": [], # stored K vectors
"values": [] # stored V vectors
}
def generate_next_token(new_token, sequence_so_far):
"""
Simulates the attention step for a new token, using/updating the KV cache.
"""
print(f"\n--- Processing Token: '{new_token}' ---")
# 1️⃣ Compute Query for the new token
Q_new = f"Q_vec({new_token})"
print(f"1. Computed Query (Q): {Q_new}")
# 2️⃣ Compute Key and Value for the new token only
K_new = f"K_vec({new_token})"
V_new = f"V_vec({new_token})"
print(f"2. Computed Key (K) and Value (V): {K_new}, {V_new}")
# 3️⃣ Build full attention matrices using cached + new vectors
K_full = KV_CACHE["keys"] + [K_new]
V_full = KV_CACHE["values"] + [V_new]
print(f"3. Full Attention Keys (cached + new): {K_full}")
# 4️⃣ Perform (conceptual) attention
attention_output = f"Attention({Q_new}, {K_full}, {V_full})"
print(f"4. Attention Calculation: {attention_output}")
# 5️⃣ Update the cache
KV_CACHE["keys"].append(K_new)
KV_CACHE["values"].append(V_new)
print(f"5. KV Cache updated – size now: {len(KV_CACHE['keys'])} tokens")
return "Predicted_Token"
# ---- Demo ----
print("=== Initial Prompt Phase: 'Hello, world' ===")
prompt_tokens = ["Hello,", "world"]
# Process prompt tokens
generate_next_token(prompt_tokens[0], [])
generate_next_token(prompt_tokens[1], prompt_tokens[:1])
print("\n=== Generation Phase: Predicting the 3rd token ===")
next_token = "(Model predicts 'how')"
generate_next_token(next_token, prompt_tokens)
代码中的关键要点
- 缓存利用:在处理新 token 时,模型复用
KV_CACHE['keys']与KV_CACHE['values'],其中已包含所有先前 token 的向量。 - 最小计算量:每一步仅计算最新 token 的查询、键和值。
- 效率提升:若没有缓存,模型必须在每一步重新计算整个历史的 K 与 V,导致大量冗余工作。
这里展示的概念可以扩展到完整的 PyTorch 多头注意力实现中,届时缓存会以张量形式(self.cache_k、self.cache_v)管理,而查询的尺寸保持不变。