TPU vs GPU:真实世界性能测试——Google Cloud 上的 LLM 训练
Source: Dev.to
请提供您希望翻译的具体文本内容,我将按照要求将其译成简体中文并保留原有的格式。
介绍
随着大规模语言模型(LLMs)规模的不断扩大,用于训练的底层硬件已成为项目成功的最关键因素。业界目前正陷入一场引人入胜的架构之争:NVIDIA GPU 的通用强大性能与 Google TPU(张量处理单元)的专用高效性能之间的对决。
对于在 Google Cloud Platform(GCP)上构建系统的工程师和架构师而言,选择 A100/H100 GPU 集群 还是 TPU v4/v5p Pod 并不仅仅是成本问题——它直接影响软件架构、数据流水线以及收敛速度。本文将通过真实的 LLM 训练性能案例,对这两种架构进行深入的技术分析。
硅层差异
根本的区别在于芯片处理 matrix multiplication 的方式,它是 Transformer 架构的核心操作。
| 方面 | NVIDIA GPUs | Google TPUs |
|---|---|---|
| 设计理念 | 多核通用处理器,具有层次化的 Streaming Multiprocessors (SMs) 和专用 Tensor Cores。 | 围绕 systolic‑array 设计构建的 Domain‑Specific Architecture (DSA)。 |
| 内存层次结构 | 通过 CUDA kernels 协调的复杂层次结构(L1/L2 缓存、shared memory)。 | 通过处理单元网格的简化流动,最小化 register‑file 和 external‑memory 访问。 |
| 核心操作 | 对图形、仿真和神经网络具有灵活性。 | 针对大规模、确定性的 matrix multiplications 进行优化。 |
集群级通信
训练像 Llama‑3 或 GPT‑4 这样的 LLM 从来不是在单个芯片上完成的;它是在集群上进行的。芯片间通信的速度往往比原始 TFLOPS 更为关键。
-
NVIDIA
- NVLink/NVSwitch:节点内部通信。
- InfiniBand:节点间通信。
- H100 支持 NVLink 4,提供 ≈ 900 GB/s 的带宽。
-
Google TPUs
- 光路交换机 (Optical Circuit Switch, OCS) 与专有的 Inter‑Core Interconnect (ICI)。
- TPU v4 和 v5p 利用 OCS 动态重新配置 pod 拓扑,形成巨大的 3‑D 环形结构,实现跨数千颗芯片的低延迟、高带宽通信,且无需传统网络层的额外开销。
功能对比
| 特性 | NVIDIA H100 (SXM5) | Google TPU v5p |
|---|---|---|
| 架构 | Hopper (General Purpose) | Systolic Array (DSA) |
| 内存 | 80 GB HBM3 | 95 GB HBM3 |
| 内存带宽 | 3.35 TB/s | 4.8 TB/s |
| 互连 | NVLink 4.0 / InfiniBand | ICI / Optical Circuit Switch |
| 主要软件 | CUDA, PyTorch | XLA, JAX, PyTorch |
Source: …
实际测试设置
我们在 Google Cloud 上对 7 B 参数的 Transformer 模型(Llama‑2 架构)进行了训练实验。
| 测试配置 | 详情 |
|---|---|
| GPU 集群 | 8 × NVIDIA H100(80 GB)节点,使用 GPUDirect‑TCPX 互联 |
| TPU Pod | TPU v5p‑8(8 核)和 TPU v5p‑32(32 核)切片 |
| 软件栈 | 两个平台均受益于 XLA(Accelerated Linear Algebra)。XLA 原生支持 TPU,OpenXLA 使得 PyTorch 和 JAX 代码能够高效编译到 GPU 与 TPU。TPU 必须 使用 XLA;GPU 也可以在 “eager mode” 下运行。 |
| TPU 上首选框架 | JAX,因为其函数式编程方式天然适配 systolic array。 |
示例 JAX 分片代码(可在 TPU Pod 与多 GPU 环境中运行)
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils
# Detect devices (TPU or GPU)
devices = jax.devices()
print(f"Devices found: {devices}")
# Define a 2‑D mesh for model and data parallelism
# Works identically on TPU pods and multi‑GPU setups
device_mesh = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))
# Create a sharded array
# 'data' axis shards the batch, 'model' axis shards the weights
sharding = NamedSharding(mesh, PartitionSpec('data', 'model'))
def train_step(state, batch):
# XLA handles the communication primitives (all‑reduce)
# during the gradient computation automatically
def loss_fn(params):
logits = model.apply(params, batch['input'])
return jnp.mean(cross_entropy(logits, batch['target']))
grads = jax.grad(loss_fn)(state.params)
return state.apply_gradients(grads=grads)
# JIT‑compile the step for XLA optimization
parallel_train_step = jax.jit(train_step)
性能结果
| 指标 | NVIDIA H100 | Google TPU v5p |
|---|---|---|
| Throughput (tokens / sec / chip) | ~3,800 | ~3,450 |
| Model FLOPs Utilization (MFU) | ~52 % | ~58 % |
| 观察 | 由于更高的时钟频率和多功能缓存,小批量时每芯片的原始吞吐量更高。 | 随着批量规模扩大(≥ 1 M 令牌),更高的 MFU 和内存带宽变得显而易见。 |
TPU 的确定性执行和 ICI 互连最小化空闲时间,尽管每芯片的原始吞吐量略低,但仍实现更高的整体利用率。
分布式训练策略
| 策略 | GPU 实现 | TPU 实现 |
|---|---|---|
| 数据并行 | torch.distributed 与 NCCL | 由 GSPMD 编译器(XLA)自动处理 |
| 模型并行(Tensor、Pipeline、Sequence) | 通过 PyTorch API 手动分片 | GSPMD(General Shard‑Man Parallel Multi‑Device)让开发者编写单设备代码;编译器在整个 pod 中插入所需的分片逻辑。 |
成本考虑
- Google Cloud TPU 定价 通常低于 H100 在相同计算时间下的定价。
- Spot TPUs 的费用可比按需实例低 70 %。
- GPU 也提供 Spot 实例,但价格差异和可用性因地区和需求而异。
要点
- 原始吞吐量 vs. 利用率 – H100 在小批量、单芯片速度上领先;TPU 在大规模持续利用率方面表现出色。
- 互连重要 – TPU 的光学电路交换提供了一种在数千芯片间更平滑扩展的拓扑结构。
- 软件生态系统 – 两个平台现在都支持 XLA;JAX 是 TPU 的自然匹配,而基于 NCCL 的 PyTorch 仍是 GPU 的标准。
- 成本效率 – 在 GCP 上,Spot TPU 通常为大规模 LLM 训练提供最佳的性价比。
选择合适的硬件最终取决于工作负载的批量大小、期望的训练速度以及预算限制。通过了解上述架构细节,您可以做出符合性能目标和成本目标的明智决策。
可用性与成本比较
大块连续的 H100 GPU 的可用性通常低于 TPU 切片的可用性。
示例成本比较(估算的 8 芯节点每小时费用)
| 配置 | 现货 / 预留费用(≈) |
|---|---|
| 8× H100 节点 | $12.00 – $15.00 |
| TPU v5p‑8 切片 | $8.00 – $11.00 |
在计算 每美元代币数 时,TPU v5p 在我们的训练运行中始终比 H100 高出 15–25 %,尽管 H100 的原始吞吐量略高。这使得 TPU 成为预算为主要约束的长期预训练阶段的首选。
当 GPU 仍然闪耀
- 生态系统与灵活性 – 大多数开源机器学习研究首先为 CUDA 编写。小众库或全新注意力机制(例如 FlashAttention‑3)通常先针对 NVIDIA 进行优化。
- Torch‑XLA 允许 PyTorch 在 TPU 上运行,但通常需要进行少量代码修改,以避免 CPU 与 TPU 之间的“上下文切换”,这会严重影响性能。
- 调试 – XLA 代码是编译后的,不能直接在训练循环中放置
print语句。使用jax.debug.print或 Cloud TPU 分析器来识别诸如 HBM 停顿 或 Infeed 队列 等瓶颈。
常见 TPU 瓶颈:Infeed
在使用 TPU 时,一个常见的限制是 Infeed,即 CPU 无法足够快地提供数据,使 TPU 保持忙碌。
# Using the TPU Profiler in a training loop
import jax
with jax.profiler.trace("/tmp/tpu_profile", create_perfetto_link=True):
for i in range(100):
state = parallel_train_step(state, next(data_iter))
# Ensure the TPU doesn't wait for the host
if i % 10 == 0:
print(f"Step {i} completed")
在 Google Cloud 上进行 LLM 训练的决策树
| 场景 | 推荐加速器 | 原因 |
|---|---|---|
| 规模极大 – 从头开始预训练,跨数百或数千块芯片 | TPU v5p | 卓越的芯片间带宽(OCS、ICI)以及线性扩展能力 |
JAX/XLA 兼容性 – 代码基于 JAX 或熟悉 torch_xla | TPU v5p | 原生 XLA 编译 |
| 成本敏感 – 需要最佳的“每美元代币数”,且可以使用 Spot 实例 | TPU v5p | 更低的云端定价,更高的利用率 |
| 标准架构 – 传统的 Transformer 块(Attention、MLP、LayerNorm) | TPU v5p | 在 XLA 编译器中高度优化 |
| 前沿研究 – 自定义 CUDA 核心或缺乏 XLA 支持的非标准层 | GPU H100 | CUDA 为主的生态系统 |
| 快速原型 – 使用 eager 模式的 PyTorch 进行快速调试 | GPU H100 | 更容易、更交互式的开发 |
| 小规模微调 – 单节点(8 GPU)工作负载 | GPU H100 | 更快的设置,更大的灵活性 |
| 多云策略 – 在 AWS、Azure、GCP 之间的可移植性 | GPU H100(或使用抽象层的 TPU) | 更少的后端特定代码更改 |
“TPU 与 GPU” 的争论已不再是关于原始速度——而是关于针对特定工作负载的系统级效率。
优势概览
| 指标 | 获胜者 | 原因 |
|---|---|---|
| 原始吞吐量(单节点) | GPU H100 | 更高的时钟频率和专用的 Transformer 引擎 |
| 可扩展性(多节点) | TPU v5p | 光学电路交换(OCS)和芯片间互连(ICI)提供更优的带宽 |
| 每 Token 成本 | TPU v5p | 云端定价更低,硬件利用率更高 |
| 开发者效率 | GPU H100 | 社区支持庞大,调试更便捷 |
| 框架支持 | 并列 | 两者均支持 PyTorch/JAX(GPU 原生,TPU 通过 XLA) |
| 面向未来 | GPU H100 | CUDA 支持确保兼容新兴研究 |
通过仔细评估模型架构和预算,您可以选择合适的加速器,使您的大语言模型训练项目保持进度并控制成本。
进一步阅读与资源
- 技术指南 – Google Cloud AI Architecture & Implementation
- 关注我们:
- Twitter / X
- GitHub