大步迈向 Transformer(第2部分):构建 Transformer

发布: (2025年12月6日 GMT+8 18:02)
5 min read
原文: Dev.to

Source: Dev.to

Naive 方法

让我们具体一点:在每个时间步,我们希望看到我们之前的每个字符来做决定。
一种简单的做法是通过加权求和把之前字符的数据带在一起;在最初的情况我们只取均值。

B, T, C = 4, 8, 2
x = torch.randn(B, T, C)
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1]          # (t, C)
        xbow[b, t] = torch.mean(xprev, 0)

所有这些操作都可以用矩阵乘法来表达。下面是使用下三角掩码来强制因果性(没有 token 能关注未来的 token)并对每一行进行归一化以计算均值的紧凑写法。

torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True)   # 行均值
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)

下三角矩阵(torch.tril)确保每个 token 只看过去的内容。把每行除以其和就得到聚合值的均值。

将同样的思路应用到真实输入 x

wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)   # 归一化行
xbow2 = wei @ x                         # (B, T, T) @ (B, T, C) -> (B, T, C)
torch.allclose(xbow, xbow2)             # 应该为 True

在实践中我们用 softmax 替代显式的均值,对掩码后的分数进行归一化:

wei = torch.zeros((T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))   # 掩码未来位置
wei = F.softmax(wei, dim=-1)                      # 将分数转为概率
xbow3 = wei @ x
torch.allclose(xbow, xbow3)                       # 验证等价性

我们对被掩码的条目使用 -inf,这样 softmax 会把它们的概率设为零(因为 exp(-inf) = 0)。

位置嵌入

仅靠自注意力是置换不变的,因此需要注入 token 位置信息。常见做法是把 位置嵌入 加到 token 嵌入上:

# 示例(伪代码)
pos_emb = posemb_matrix(torch.arange(T))
x = token_emb + pos_emb

自注意力的关键

Naive 方法把所有之前的 token 当作等权贡献者。实际上,有些 token 更相关,于是我们用 加权求和 取代均匀平均。

每个 token 被投射成三个向量:

  • Query (Q) – token 想要寻找的内容。
  • Key (K) – 用来与查询匹配的方式。
  • Value (V) – 将被聚合的信息。

查询和键之间的注意力权重通过点积(相似度)计算,然后通过 softmax 得到概率分布。对值向量的加权求和产生该 token 的输出。

一个直观的可视化解释可以在 3Blue1Brown 的视频中找到:Attention in transformers.

实现

head_size = 16

# 线性投射(无偏置)
key   = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# 投射输入 x(形状: B, T, C)
k = key(x)    # (B, T, head_size)
q = query(x)  # (B, T, head_size)

# 计算原始注意力分数
wei = q @ k.transpose(-2, -1)               # (B, T, T)

# 因果掩码:防止关注未来 token
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))

# 在最后一个维度上做 softmax,得到注意力概率
wei = F.softmax(wei, dim=-1)

# 投射到值并聚合
v = value(x)                               # (B, T, head_size)
out = wei @ v                               # (B, T, head_size)

关键点

  • k.transpose(-2, -1) 使键向量与查询的点积对齐。
  • 因果掩码 (tril) 确保每个位置只关注更早的位置。
  • softmax 将原始分数转为合适的加权分布。
  • 最终输出 out 是值向量的加权求和。

备注

  • Naive 的等权平均是注意力的一种特例,此时所有注意力分数相同。
  • 实际上会使用多个注意力头,每个头都有自己的 Q、K、V 投射,随后将它们的输出拼接。
  • 在注意力块周围会加入层归一化、残差连接和前馈网络,形成完整的 Transformer 层(参见 Andrej Karpathy 的 “nanoGPT” 实现,获取最小可运行示例)。
Back to Blog

相关文章

阅读更多 »