开发者指南:在云 TPU 上调试 JAX 的必备工具与技术
Source: Google Developers Blog
2026年1月5日
作者: Brian Kang – 高级员工,现场解决方案架构师,AI 基础设施
选择合适的工具:核心组件与依赖
在系统的核心有两个主要组件,几乎所有调试工具都依赖于它们:
| 组件 | 描述 |
|---|---|
libtpu(包含 libtpu.so,TPU 运行时) | 每个 Cloud TPU VM 上的共享库。它捆绑了 XLA 编译器、TPU 驱动以及与硬件通信的逻辑。几乎所有调试工具都会通过 libtpu 进行交互或配置。 |
| JAX 和 jaxlib(框架) | jax 是你编写模型代码的 Python 库。jaxlib 是其 C++ 后端,充当 libtpu.so 的桥梁。 |
下面的图示展示了这些组件与调试工具之间的关系。

以下是对具体工具、它们的依赖以及相互关系的详细拆解:

总结: libtpu 是大多数调试工具依赖的核心支柱——无论是用于配置(日志、HLO 转储)还是用于查询实时数据(监控、分析)。诸如 XProf 等工具在 Python 层面直接检查你的 JAX 程序状态。了解这些关系有助于你针对具体问题选择合适的工具。
每个工作负载的关键日志与诊断标志
详细日志
启用详细日志是最关键的调试步骤。没有它,你就像盲目飞行。请在 每个 TPU 切片的工作节点 上使用这些标志,以记录从 TPU 运行时设置到程序执行步骤的所有信息(带时间戳)。

在所有 TPU 工作节点上启用默认标志
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} --zone ${ZONE} --worker=all --node=all \
--command='TPU_VMODULE=slice_configuration=1,real_program_continuator=1 \
TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0 \
python3 -c "import jax; print(f\"Host {jax.process_index()}: Global devices: {jax.device_count()}, Local devices: {jax.local_device_count()}\")"'
libtpu 日志会自动写入每个 TPU VM 上的 /tmp/tpu_logs/tpu_driver.INFO。该文件是了解 TPU 运行时实际行为的最直接依据。
从所有 TPU VM 收集日志
#!/bin/bash
TPU_NAME="YOUR_TPU_NAME"
PROJECT="YOUR_PROJECT"
ZONE="YOUR_ZONE"
BASE_LOG_DIR="PATH/WHERE/LOGS/WILL/BE/DOWNLOADED"
NUM_WORKERS=$(gcloud compute tpus tpu-vm describe $TPU_NAME \
--zone=$ZONE --project=$PROJECT \
| grep tpuVmSelflink \
| awk -F'[:/]' '{print $13}' \
| uniq | wc -l)
echo "Number of workers = $NUM_WORKERS"
for ((i=0; i<$NUM_WORKERS; i++)); do
mkdir -p "${BASE_LOG_DIR}/$i"
echo "Downloading logs from worker=$i"
gcloud compute tpus tpu-vm scp ${TPU_NAME}:/tmp/tpu_logs/* \
"${BASE_LOG_DIR}/$i/" --zone=${ZONE} --project=${PROJECT} --worker=$i
done
在 Google Colab 中,你可以通过 os.environ 设置相同的环境变量,并在左侧的 Files 面板中查看日志。
日志示例片段
I1031 19:02:51.863599 669 b295d63588a.cc:843] Process id 669
I1031 19:02:51.863609 669 b295d63588a.cc:848] Current working directory /content
I1031 19:02:51.863621 669 b295d63588a.cc:866] Build tool: Bazel, release r4rca-2025.05.26-2 (mainline @763214608)
I1031 19:02:51.863621 669 b295d63588a.cc:867] Build target:
I1031 19:02:51.863624 669 b295d63588a.cc:874] Command line arguments:
I1031 19:02:51.863624 669 b295d63588a.cc:876] argv[0]: './tpu_driver'
...
I1031 19:02:51.863784 669 init.cc:78] Remote crash gathering hook installed.
I1031 19:02:51.863807 669 tpu_runtime_type_flags.cc:79] --tpu_use_tfrt not specified. Using default value: true
I1031 19:02:51.873759 669 tpu_hal.cc:448] Registered plugin from module: breakpoint_debugger_server
...
I1031 19:02:51.879890 669 pending_event_logger.cc:896] Enabling PjRt/TPU event dependency logging
I1031 19:02:51.880524 843 device_util.cc:124] Found 1 TPU v5 lite chips.
...
I1031 19:02:53.471830 851 2a886c8_compiler_base.cc:3677] CODE_GENERATION stage duration: 3.610218ms
这些摘录展示了驱动日志中会出现的内容——进程 ID、构建元数据、插件注册、设备发现以及各阶段的耗时细节。
TPU 监控库
TPU 监控库让您能够以编程方式获取 TPU 硬件上工作流性能的洞察(利用率、容量、延迟等)。它是 libtpu 包的一部分,该包会作为 jax[tpu] 的依赖自动安装,因此您可以立即开始使用监控 API。
安装
# 显式安装
pip install "jax[tpu]" libtpu
使用
from libtpu.sdk import tpumonitoring
duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
duty_cycle_data = duty_cycle_metric.data
print("TPU Duty Cycle Data:")
print(f" Description: {duty_cycle_metric.description}")
print(f" Data: {duty_cycle_data}")
直接在您的 JAX 程序中集成 tpumonitoring——在模型训练期间、推理之前等场景使用。
了解更多,请参阅 Cloud TPU 文档。
tpu‑info
tpu‑info 命令行工具提供 TPU 内存和其他利用率指标的实时视图,类似于 GPU 的 nvidia‑smi。
在所有 Worker 和 Node 上安装
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--node=all \
--command='pip install tpu-info'
检查芯片利用率指标
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=0 \
--node=0
tpu-info
当芯片正在使用时,输出会显示进程 ID、内存使用情况和占空比 %:

当没有芯片在使用时,TPU VM 将显示无活动:

了解更多指标和流式模式,请参阅 tpu‑info 文档。
在本文中我们讨论了一些 TPU 日志记录和监控选项。系列的下一篇将探讨如何调试你的 JAX 程序——从生成 HLO 转储和使用 XProf 对代码进行分析开始。