Caching Strategies for LLM Systems (Part 3): Multi-Query Attention and Memory-Efficient Decoding
Source: Dev.to
2️⃣ Recap: KV‑Caching in Autoregressive Decoding
KV‑caching removes the redundant recomputation in autoregressive decoding by storing the keys (K) and values (V) produced for every previously generated token: each decoding step computes K and V only for the new token instead of re‑deriving them for the whole prefix.
For a transformer with
- (L) layers
- (H) attention heads per layer
- (T) sequence length (context window)
- (d_h) head dimension
the total KV‑cache size (across all layers) is
\[ \mathcal{O}(L \cdot H \cdot T \cdot d_h) \]
KV‑caching removes redundant computation, but it does nothing to curb the memory growth that is linear in the number of heads (H).
For modern LLMs (32–128 heads, long contexts) the KV‑cache quickly dominates inference cost (both memory and bandwidth).
❓ Do attention heads really need independent keys and values?
Multi‑Query Attention (MQA) – the idea
Each head keeps its own query projection, but shares a single set of keys and values across all heads.
Formally
\[
\begin{aligned}
Q_i &= X W_{Q_i} \quad && \text{(independent per head)} \\
K &= X W_K \quad && \text{(shared)} \\
V &= X W_V \quad && \text{(shared)}
\end{aligned}
\]
Each head’s attention is then
\[ \text{Attention}_i = \operatorname{softmax}\!\Bigl(\frac{Q_i K^{\top}}{\sqrt{d_h}}\Bigr) V \]
Important clarifications
- Keys and values are shared, but they are not the same matrix: (W_K \neq W_V).
- This single design decision collapses the KV‑cache size by a factor of (H).
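The equations above can be sketched in a few lines of NumPy. This is a minimal illustration, not a production implementation: the function name, shapes, and sizes are all hypothetical, and it computes full (non‑causal, non‑cached) attention just to show where K and V are shared.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def mqa_attention(X, W_Q, W_K, W_V):
    """Multi-Query Attention: per-head queries, one shared K and V.

    X:   (T, d)         token representations
    W_Q: (H, d, d_h)    one query projection per head
    W_K: (d, d_h)       shared key projection
    W_V: (d, d_h)       shared value projection
    Returns (H, T, d_h) per-head attention outputs.
    """
    H, d, d_h = W_Q.shape
    K = X @ W_K                               # (T, d_h) -- computed once, shared by all heads
    V = X @ W_V                               # (T, d_h) -- computed once, shared by all heads
    outs = []
    for i in range(H):
        Q_i = X @ W_Q[i]                      # (T, d_h) -- head-specific queries
        scores = Q_i @ K.T / np.sqrt(d_h)     # (T, T)
        outs.append(softmax(scores) @ V)      # (T, d_h)
    return np.stack(outs)                     # (H, T, d_h)

# Smoke test with tiny illustrative sizes
rng = np.random.default_rng(0)
T, d, H, d_h = 5, 16, 4, 8
X = rng.standard_normal((T, d))
out = mqa_attention(X,
                    rng.standard_normal((H, d, d_h)),
                    rng.standard_normal((d, d_h)),
                    rng.standard_normal((d, d_h)))
print(out.shape)  # (4, 5, 8)
```

Note that K and V are projected once outside the per‑head loop: that single line is the entire MQA idea, and it is exactly what shrinks the cache by a factor of H.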
📐 Parameter shapes
| Matrix | Shape | Description |
|---|---|---|
| (W_Q) | (\mathbb{R}^{d \times (H d_h)}) | Separate query projections for each head |
| (W_K) | (\mathbb{R}^{d \times d_h}) | Shared key projection |
| (W_V) | (\mathbb{R}^{d \times d_h}) | Shared value projection |
(d) is the model dimension.
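As a quick sanity check on these shapes, the projection‑weight savings can be computed directly. The sizes below are illustrative, not taken from any specific model:

```python
# Illustrative sizes (hypothetical, not a specific model)
d, H, d_h = 8192, 64, 128

# MHA: per-head W_K and W_V, each of shape d x d_h
mha_kv_params = 2 * H * d * d_h
# MQA: one shared W_K and one shared W_V
mqa_kv_params = 2 * d * d_h

# W_Q has shape d x (H * d_h) in both designs, so it is unchanged
print(mha_kv_params // mqa_kv_params)  # 64 -> K/V projection weights shrink by H
```

The weight saving follows the same factor of H, though the more important win at inference time is the runtime KV‑cache, quantified below.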
📊 KV‑Cache Memory Comparison
| Attention type | KV cache per layer |
|---|---|
| Multi‑Head Attention (MHA) | (H \times T \times d_h) |
| Multi‑Query Attention (MQA) | (1 \times T \times d_h) |
Example: 64‑head model, FP16 (2 bytes/element)
| Parameter | Value |
|---|---|
| Layers (L) | 80 |
| Heads (H) | 64 |
| Head dimension (d_h) | 128 |
| Context length (T) | 2048 |
| Precision | FP16 (2 B) |
| Attention type | KV‑cache formula | Approx. size per sequence |
|---|---|---|
| MHA | (2 \times L \times H \times T \times d_h \times 2\text{ B}) | ≈ 5.0 GiB |
| MQA | (2 \times L \times 1 \times T \times d_h \times 2\text{ B}) | ≈ 80 MiB |
| Reduction | — | 64× smaller |
The factor 2 accounts for storing both keys and values.
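The table entries follow directly from the formula and can be verified with a few lines of arithmetic:

```python
# KV-cache size per sequence: 2 (K and V) x L x heads x T x d_h x bytes/element
L, H, T, d_h, bytes_fp16 = 80, 64, 2048, 128, 2

mha = 2 * L * H * T * d_h * bytes_fp16
mqa = 2 * L * 1 * T * d_h * bytes_fp16

print(f"MHA: {mha / 2**30:.1f} GiB")   # 5.0 GiB
print(f"MQA: {mqa / 2**20:.1f} MiB")   # 80.0 MiB
print(f"Reduction: {mha // mqa}x")     # 64x
```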
🧭 What is lost (and why it often doesn’t matter)
In standard MHA each head has independent projections
\[
\begin{aligned}
Q_i &= X W_{Q_i} \\
K_i &= X W_{K_i} \\
V_i &= X W_{V_i}
\end{aligned}
\]
- This gives each head its own similarity metric (K_i), retrieval semantics (V_i), and alignment objective (Q_i).
- Geometrically, MHA spans multiple low‑rank attention operators, allowing heads to specialize (syntax, long‑range dependencies, positional bias, coreference, …).
MQA enforces
\[ K_1 = K_2 = \dots = K_H = K, \qquad V_1 = V_2 = \dots = V_H = V \]
All heads score relevance in the same key space and retrieve from the same value manifold; diversity comes only from the queries.
Consequences
| Aspect | MHA | MQA |
|---|---|---|
| Number of attention subspaces | (H) (many) | 1 (shared) |
| Per‑head similarity metric | Yes | No |
| Per‑head semantic abstraction | Yes | No |
| Independent relational subspaces | Yes | No |
| “Point‑of‑view capacity” | High | Reduced |
The reduction in rank limits the model’s ability to represent multiple incompatible interpretations simultaneously.
Why the degradation is often negligible
- Redundancy in MHA heads – many heads learn highly correlated patterns.
- Depth & width compensation – feed‑forward layers absorb the lost expressivity.
- Training adaptation – models trained from scratch with MQA learn robust shared KV spaces.
- Inference bottleneck – in deployment, memory bandwidth, not representational power, dominates latency.
🚀 Inference workflow (decoding)
| Step | MHA | MQA |
|---|---|---|
| 1️⃣ Re‑compute queries for the new token | ✔️ | ✔️ |
| 2️⃣ Load keys & values from cache | (H) KV tensors per layer | 1 KV tensor per layer |
| 3️⃣ Compute attention | ✔️ | ✔️ |
Reducing the number of KV tensors loaded per layer dramatically cuts memory traffic, cache pressure, and token latency.
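The per‑token workflow above can be sketched for a single layer. This is a simplified illustration (the function name, shapes, and the absence of output projection, batching, and causal masking over a batch are all simplifying assumptions); the key point is that the new token appends one shared K/V row, not H of them:

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def decode_step(x_new, cache_K, cache_V, W_Q, W_K, W_V):
    """One MQA decoding step for a single layer.

    x_new:   (d,)        representation of the newly generated token
    cache_K: (t, d_h)    shared keys for the t previous tokens
    cache_V: (t, d_h)    shared values for the t previous tokens
    W_Q:     (H, d, d_h) per-head query projections
    W_K/W_V: (d, d_h)    shared key/value projections
    Returns per-head outputs (H, d_h) plus the grown caches.
    """
    # Step 2: append this token's K/V once -- one tensor per layer, not H
    cache_K = np.vstack([cache_K, x_new @ W_K])          # (t+1, d_h)
    cache_V = np.vstack([cache_V, x_new @ W_V])          # (t+1, d_h)
    # Step 1: fresh queries for the new token, one per head
    q = x_new @ W_Q                                      # (H, d_h)
    # Step 3: every head attends over the same cached K/V
    scores = q @ cache_K.T / np.sqrt(W_Q.shape[-1])      # (H, t+1)
    out = softmax(scores) @ cache_V                      # (H, d_h)
    return out, cache_K, cache_V

# One step from an empty cache, with tiny illustrative sizes
rng = np.random.default_rng(1)
d, H, d_h = 16, 4, 8
out, K1, V1 = decode_step(rng.standard_normal(d),
                          np.empty((0, d_h)), np.empty((0, d_h)),
                          rng.standard_normal((H, d, d_h)),
                          rng.standard_normal((d, d_h)),
                          rng.standard_normal((d, d_h)))
print(out.shape, K1.shape)  # (4, 8) (1, 8)
```

Because all H heads read the same `cache_K`/`cache_V` tensors, the cached data is loaded from memory once per layer per token, which is exactly where the bandwidth saving comes from.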
📌 Take‑away
- Multi‑Query Attention trades a modest loss of per‑head representational diversity for a massive reduction in KV‑cache memory and bandwidth (up to 64× for a 64‑head model).
- For large‑scale inference‑heavy deployments, this trade‑off is often worthwhile, which is why many production‑grade LLMs (e.g., PaLM) adopt MQA.
KV Diversity
| Aspect | Per‑head KV (MHA) | Shared KV (MQA) |
|---|---|---|
| Expressiveness | Higher | Lower |
| KV cache size | (\mathcal{O}(H \cdot T \cdot d_h)) | (\mathcal{O}(T \cdot d_h)) |
| Inference efficiency | Lower | Much higher |
Note: MQA (Multi‑Query Attention) is not a free optimization. It is a deliberate architectural trade‑off that favors inference scalability over maximal expressiveness.