TPU vs GPU:真实世界性能测试——Google Cloud 上的 LLM 训练

发布: (2025年12月31日 GMT+8 10:16)
11 min read
原文: Dev.to

Source: Dev.to

请提供您希望翻译的具体文本内容,我将按照要求将其译成简体中文并保留原有的格式。

介绍

随着大规模语言模型(LLMs)规模的不断扩大,用于训练的底层硬件已成为项目成功的最关键因素。业界目前正陷入一场引人入胜的架构之争:NVIDIA GPU 的通用强大性能与 Google TPU(张量处理单元)的专用高效性能之间的对决。

对于在 Google Cloud Platform(GCP)上构建系统的工程师和架构师而言,选择 A100/H100 GPU 集群 还是 TPU v4/v5p Pod 并不仅仅是成本问题——它直接影响软件架构、数据流水线以及收敛速度。本文将通过真实的 LLM 训练性能案例,对这两种架构进行深入的技术分析。

硅层差异

根本的区别在于芯片处理 matrix multiplication 的方式,它是 Transformer 架构的核心操作。

方面NVIDIA GPUsGoogle TPUs
设计理念多核通用处理器,具有层次化的 Streaming Multiprocessors (SMs) 和专用 Tensor Cores。围绕 systolic‑array 设计构建的 Domain‑Specific Architecture (DSA)。
内存层次结构通过 CUDA kernels 协调的复杂层次结构(L1/L2 缓存、shared memory)。通过处理单元网格的简化流动,最小化 register‑file 和 external‑memory 访问。
核心操作对图形、仿真和神经网络具有灵活性。针对大规模、确定性的 matrix multiplications 进行优化。

集群级通信

训练像 Llama‑3 或 GPT‑4 这样的 LLM 从来不是在单个芯片上完成的;它是在集群上进行的。芯片间通信的速度往往比原始 TFLOPS 更为关键。

  • NVIDIA

    • NVLink/NVSwitch:节点内部通信。
    • InfiniBand:节点间通信。
    • H100 支持 NVLink 4,提供 ≈ 900 GB/s 的带宽。
  • Google TPUs

    • 光路交换机 (Optical Circuit Switch, OCS) 与专有的 Inter‑Core Interconnect (ICI)
    • TPU v4 和 v5p 利用 OCS 动态重新配置 pod 拓扑,形成巨大的 3‑D 环形结构,实现跨数千颗芯片的低延迟、高带宽通信,且无需传统网络层的额外开销。

功能对比

特性NVIDIA H100 (SXM5)Google TPU v5p
架构Hopper (General Purpose)Systolic Array (DSA)
内存80 GB HBM395 GB HBM3
内存带宽3.35 TB/s4.8 TB/s
互连NVLink 4.0 / InfiniBandICI / Optical Circuit Switch
主要软件CUDA, PyTorchXLA, JAX, PyTorch

Source:

实际测试设置

我们在 Google Cloud 上对 7 B 参数的 Transformer 模型(Llama‑2 架构)进行了训练实验。

测试配置详情
GPU 集群8 × NVIDIA H100(80 GB)节点,使用 GPUDirect‑TCPX 互联
TPU PodTPU v5p‑8(8 核)和 TPU v5p‑32(32 核)切片
软件栈两个平台均受益于 XLA(Accelerated Linear Algebra)。XLA 原生支持 TPU,OpenXLA 使得 PyTorch 和 JAX 代码能够高效编译到 GPU 与 TPU。TPU 必须 使用 XLA;GPU 也可以在 “eager mode” 下运行。
TPU 上首选框架JAX,因为其函数式编程方式天然适配 systolic array。

示例 JAX 分片代码(可在 TPU Pod 与多 GPU 环境中运行)

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils

# Detect devices (TPU or GPU)
devices = jax.devices()
print(f"Devices found: {devices}")

# Define a 2‑D mesh for model and data parallelism
# Works identically on TPU pods and multi‑GPU setups
device_mesh = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))

# Create a sharded array
# 'data' axis shards the batch, 'model' axis shards the weights
sharding = NamedSharding(mesh, PartitionSpec('data', 'model'))

def train_step(state, batch):
    # XLA handles the communication primitives (all‑reduce)
    # during the gradient computation automatically
    def loss_fn(params):
        logits = model.apply(params, batch['input'])
        return jnp.mean(cross_entropy(logits, batch['target']))

    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)

# JIT‑compile the step for XLA optimization
parallel_train_step = jax.jit(train_step)

性能结果

指标NVIDIA H100Google TPU v5p
Throughput (tokens / sec / chip)~3,800~3,450
Model FLOPs Utilization (MFU)~52 %~58 %
观察由于更高的时钟频率和多功能缓存,小批量时每芯片的原始吞吐量更高。随着批量规模扩大(≥ 1 M 令牌),更高的 MFU 和内存带宽变得显而易见。

TPU 的确定性执行和 ICI 互连最小化空闲时间,尽管每芯片的原始吞吐量略低,但仍实现更高的整体利用率。

分布式训练策略

策略GPU 实现TPU 实现
数据并行torch.distributedNCCLGSPMD 编译器(XLA)自动处理
模型并行(Tensor、Pipeline、Sequence)通过 PyTorch API 手动分片GSPMD(General Shard‑Man Parallel Multi‑Device)让开发者编写单设备代码;编译器在整个 pod 中插入所需的分片逻辑。

成本考虑

  • Google Cloud TPU 定价 通常低于 H100 在相同计算时间下的定价。
  • Spot TPUs 的费用可比按需实例低 70 %
  • GPU 也提供 Spot 实例,但价格差异和可用性因地区和需求而异。

要点

  1. 原始吞吐量 vs. 利用率 – H100 在小批量、单芯片速度上领先;TPU 在大规模持续利用率方面表现出色。
  2. 互连重要 – TPU 的光学电路交换提供了一种在数千芯片间更平滑扩展的拓扑结构。
  3. 软件生态系统 – 两个平台现在都支持 XLA;JAX 是 TPU 的自然匹配,而基于 NCCL 的 PyTorch 仍是 GPU 的标准。
  4. 成本效率 – 在 GCP 上,Spot TPU 通常为大规模 LLM 训练提供最佳的性价比。

选择合适的硬件最终取决于工作负载的批量大小、期望的训练速度以及预算限制。通过了解上述架构细节,您可以做出符合性能目标和成本目标的明智决策。

可用性与成本比较

大块连续的 H100 GPU 的可用性通常低于 TPU 切片的可用性。

示例成本比较(估算的 8 芯节点每小时费用)

配置现货 / 预留费用(≈)
8× H100 节点$12.00 – $15.00
TPU v5p‑8 切片$8.00 – $11.00

在计算 每美元代币数 时,TPU v5p 在我们的训练运行中始终比 H100 高出 15–25 %,尽管 H100 的原始吞吐量略高。这使得 TPU 成为预算为主要约束的长期预训练阶段的首选。

当 GPU 仍然闪耀

  • 生态系统与灵活性 – 大多数开源机器学习研究首先为 CUDA 编写。小众库或全新注意力机制(例如 FlashAttention‑3)通常先针对 NVIDIA 进行优化。
  • Torch‑XLA 允许 PyTorch 在 TPU 上运行,但通常需要进行少量代码修改,以避免 CPU 与 TPU 之间的“上下文切换”,这会严重影响性能。
  • 调试 – XLA 代码是编译后的,不能直接在训练循环中放置 print 语句。使用 jax.debug.print 或 Cloud TPU 分析器来识别诸如 HBM 停顿Infeed 队列 等瓶颈。

常见 TPU 瓶颈:Infeed

在使用 TPU 时,一个常见的限制是 Infeed,即 CPU 无法足够快地提供数据,使 TPU 保持忙碌。

# Using the TPU Profiler in a training loop
import jax

with jax.profiler.trace("/tmp/tpu_profile", create_perfetto_link=True):
    for i in range(100):
        state = parallel_train_step(state, next(data_iter))
        # Ensure the TPU doesn't wait for the host
        if i % 10 == 0:
            print(f"Step {i} completed")

在 Google Cloud 上进行 LLM 训练的决策树

场景推荐加速器原因
规模极大 – 从头开始预训练,跨数百或数千块芯片TPU v5p卓越的芯片间带宽(OCS、ICI)以及线性扩展能力
JAX/XLA 兼容性 – 代码基于 JAX 或熟悉 torch_xlaTPU v5p原生 XLA 编译
成本敏感 – 需要最佳的“每美元代币数”,且可以使用 Spot 实例TPU v5p更低的云端定价,更高的利用率
标准架构 – 传统的 Transformer 块(Attention、MLP、LayerNorm)TPU v5p在 XLA 编译器中高度优化
前沿研究 – 自定义 CUDA 核心或缺乏 XLA 支持的非标准层GPU H100CUDA 为主的生态系统
快速原型 – 使用 eager 模式的 PyTorch 进行快速调试GPU H100更容易、更交互式的开发
小规模微调 – 单节点(8 GPU)工作负载GPU H100更快的设置,更大的灵活性
多云策略 – 在 AWS、Azure、GCP 之间的可移植性GPU H100(或使用抽象层的 TPU)更少的后端特定代码更改

“TPU 与 GPU” 的争论已不再是关于原始速度——而是关于针对特定工作负载的系统级效率。

优势概览

指标获胜者原因
原始吞吐量(单节点)GPU H100更高的时钟频率和专用的 Transformer 引擎
可扩展性(多节点)TPU v5p光学电路交换(OCS)和芯片间互连(ICI)提供更优的带宽
每 Token 成本TPU v5p云端定价更低,硬件利用率更高
开发者效率GPU H100社区支持庞大,调试更便捷
框架支持并列两者均支持 PyTorch/JAX(GPU 原生,TPU 通过 XLA)
面向未来GPU H100CUDA 支持确保兼容新兴研究

通过仔细评估模型架构和预算,您可以选择合适的加速器,使您的大语言模型训练项目保持进度并控制成本。

进一步阅读与资源

  • 技术指南 – Google Cloud AI Architecture & Implementation
  • 关注我们:
    • Twitter / X
    • LinkedIn
    • GitHub
Back to Blog

相关文章

阅读更多 »