从理论到实践:揭开现代 LLM 中 Key-Value Cache 的神秘面纱

发布: (2025年12月5日 GMT+8 20:51)
5 min read
原文: Dev.to

Source: Dev.to

引言 — 什么是键值缓存(Key‑Value Cache),以及我们为何需要它?

KV Cache illustration

我在大语言模型领域的探索之路

虽然我并非传统的数据科学或深度学习背景,但过去几年与 AI 和生成模型的工作让我能够以务实的方式学习 KV 缓存的概念:通过阅读博客和技术书籍,并动手实验示例代码。这种实践方法帮助把数学思想转化为具体、可运行的组件。

Source: NVIDIA blog

键值缓存(Key‑Value Cache)是一个专用的内存空间,用来存储自注意力机制为已处理的 token 产生的中间 键 (K)值 (V) 向量。通过在后续推理步骤中复用这些向量,缓存能够显著加速生成过程。

为什么 KV 缓存是必需的

Transformer 模型(例如 GPT 系列)以自回归方式生成文本:每个新 token 的预测都基于所有先前生成的 token。若没有 KV 缓存,模型必须在每一步重新计算整个序列的 K 和 V,导致计算成本随序列长度 n 成二次增长 O(n²)。这使得长文本生成变得极其缓慢且昂贵。

KV 缓存的工作原理

  1. 预填充阶段(首个 Token / Prompt)
    模型处理输入提示,计算每个 token 的 Q、K、V。K 与 V 向量被存入 KV 缓存。

  2. 解码阶段(后续 Token)

    • 仅为新生成的 token 计算 查询 (Q)
    • 之前计算好的 K 与 V 向量直接从缓存中读取。
    • 计算新 token 自身的 K 与 V 向量并追加到缓存中。
  3. 结果
    每个新 token 的注意力计算从二次 O(n²) 降至线性 O(n),从而实现更快的推理。

权衡

主要的权衡在于 GPU 内存使用的增加:对于非常长的序列或大的批量大小,缓存的 K 与 V 张量可能会占据大量显存。

KV Cache memory trade‑off
Source: (Sebastian Raschka, PhD)

通过示例概念代码进行说明

下面是一个最小化的 Python 示例,演示在使用 KV 缓存的情况下如何模拟注意力步骤。

# kv_cache_demo.py
KV_CACHE = {
    "keys": [],   # stored K vectors
    "values": []  # stored V vectors
}

def generate_next_token(new_token, sequence_so_far):
    """
    Simulates the attention step for a new token, using/updating the KV cache.
    """
    print(f"\n--- Processing Token: '{new_token}' ---")

    # 1️⃣ Compute Query for the new token
    Q_new = f"Q_vec({new_token})"
    print(f"1. Computed Query (Q): {Q_new}")

    # 2️⃣ Compute Key and Value for the new token only
    K_new = f"K_vec({new_token})"
    V_new = f"V_vec({new_token})"
    print(f"2. Computed Key (K) and Value (V): {K_new}, {V_new}")

    # 3️⃣ Build full attention matrices using cached + new vectors
    K_full = KV_CACHE["keys"] + [K_new]
    V_full = KV_CACHE["values"] + [V_new]
    print(f"3. Full Attention Keys (cached + new): {K_full}")

    # 4️⃣ Perform (conceptual) attention
    attention_output = f"Attention({Q_new}, {K_full}, {V_full})"
    print(f"4. Attention Calculation: {attention_output}")

    # 5️⃣ Update the cache
    KV_CACHE["keys"].append(K_new)
    KV_CACHE["values"].append(V_new)
    print(f"5. KV Cache updated – size now: {len(KV_CACHE['keys'])} tokens")

    return "Predicted_Token"

# ---- Demo ----
print("=== Initial Prompt Phase: 'Hello, world' ===")
prompt_tokens = ["Hello,", "world"]

# Process prompt tokens
generate_next_token(prompt_tokens[0], [])
generate_next_token(prompt_tokens[1], prompt_tokens[:1])

print("\n=== Generation Phase: Predicting the 3rd token ===")
next_token = "(Model predicts 'how')"
generate_next_token(next_token, prompt_tokens)

代码中的关键要点

  • 缓存利用:在处理新 token 时,模型复用 KV_CACHE['keys']KV_CACHE['values'],其中已包含所有先前 token 的向量。
  • 最小计算量:每一步仅计算最新 token 的查询、键和值。
  • 效率提升:若没有缓存,模型必须在每一步重新计算整个历史的 K 与 V,导致大量冗余工作。

这里展示的概念可以扩展到完整的 PyTorch 多头注意力实现中,届时缓存会以张量形式(self.cache_kself.cache_v)管理,而查询的尺寸保持不变。

Back to Blog

相关文章

阅读更多 »

准备迎接AI乡村音乐爆发

当词曲作者Patrick Irwin去年搬到纳什维尔时,他正踏入一场彩票。每天都有数百场会话进行,词作者们创作song demo……