大步迈向 Transformer(第2部分):构建 Transformer
Source: Dev.to
Naive 方法
让我们具体一点:在每个时间步,我们希望看到我们之前的每个字符来做决定。
一种简单的做法是通过加权求和把之前字符的数据带在一起;在最初的情况我们只取均值。
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)
xbow = torch.zeros((B, T, C))
for b in range(B):
for t in range(T):
xprev = x[b, :t+1] # (t, C)
xbow[b, t] = torch.mean(xprev, 0)
所有这些操作都可以用矩阵乘法来表达。下面是使用下三角掩码来强制因果性(没有 token 能关注未来的 token)并对每一行进行归一化以计算均值的紧凑写法。
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True) # 行均值
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)
下三角矩阵(torch.tril)确保每个 token 只看过去的内容。把每行除以其和就得到聚合值的均值。
将同样的思路应用到真实输入 x:
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True) # 归一化行
xbow2 = wei @ x # (B, T, T) @ (B, T, C) -> (B, T, C)
torch.allclose(xbow, xbow2) # 应该为 True
在实践中我们用 softmax 替代显式的均值,对掩码后的分数进行归一化:
wei = torch.zeros((T, T))
wei = wei.masked_fill(tril == 0, float('-inf')) # 掩码未来位置
wei = F.softmax(wei, dim=-1) # 将分数转为概率
xbow3 = wei @ x
torch.allclose(xbow, xbow3) # 验证等价性
我们对被掩码的条目使用 -inf,这样 softmax 会把它们的概率设为零(因为 exp(-inf) = 0)。
位置嵌入
仅靠自注意力是置换不变的,因此需要注入 token 位置信息。常见做法是把 位置嵌入 加到 token 嵌入上:
# 示例(伪代码)
pos_emb = posemb_matrix(torch.arange(T))
x = token_emb + pos_emb
自注意力的关键
Naive 方法把所有之前的 token 当作等权贡献者。实际上,有些 token 更相关,于是我们用 加权求和 取代均匀平均。
每个 token 被投射成三个向量:
- Query (Q) – token 想要寻找的内容。
- Key (K) – 用来与查询匹配的方式。
- Value (V) – 将被聚合的信息。
查询和键之间的注意力权重通过点积(相似度)计算,然后通过 softmax 得到概率分布。对值向量的加权求和产生该 token 的输出。
一个直观的可视化解释可以在 3Blue1Brown 的视频中找到:Attention in transformers.
实现
head_size = 16
# 线性投射(无偏置)
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
# 投射输入 x(形状: B, T, C)
k = key(x) # (B, T, head_size)
q = query(x) # (B, T, head_size)
# 计算原始注意力分数
wei = q @ k.transpose(-2, -1) # (B, T, T)
# 因果掩码:防止关注未来 token
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
# 在最后一个维度上做 softmax,得到注意力概率
wei = F.softmax(wei, dim=-1)
# 投射到值并聚合
v = value(x) # (B, T, head_size)
out = wei @ v # (B, T, head_size)
关键点
k.transpose(-2, -1)使键向量与查询的点积对齐。- 因果掩码 (
tril) 确保每个位置只关注更早的位置。 - softmax 将原始分数转为合适的加权分布。
- 最终输出
out是值向量的加权求和。
备注
- Naive 的等权平均是注意力的一种特例,此时所有注意力分数相同。
- 实际上会使用多个注意力头,每个头都有自己的 Q、K、V 投射,随后将它们的输出拼接。
- 在注意力块周围会加入层归一化、残差连接和前馈网络,形成完整的 Transformer 层(参见 Andrej Karpathy 的 “nanoGPT” 实现,获取最小可运行示例)。