开发者指南:在云 TPU 上调试 JAX 的必备工具与技术

发布: (2026年1月6日 GMT+8 08:35)
6 min read

Source: Google Developers Blog

2026年1月5日
作者: Brian Kang – 高级员工,现场解决方案架构师,AI 基础设施

选择合适的工具:核心组件与依赖

在系统的核心有两个主要组件,几乎所有调试工具都依赖于它们

组件描述
libtpu(包含 libtpu.so,TPU 运行时)每个 Cloud TPU VM 上的共享库。它捆绑了 XLA 编译器、TPU 驱动以及与硬件通信的逻辑。几乎所有调试工具都会通过 libtpu 进行交互或配置。
JAXjaxlib(框架)jax 是你编写模型代码的 Python 库。jaxlib 是其 C++ 后端,充当 libtpu.so 的桥梁。

下面的图示展示了这些组件与调试工具之间的关系。

Relationship diagram

以下是对具体工具、它们的依赖以及相互关系的详细拆解:

Tool table

总结: libtpu 是大多数调试工具依赖的核心支柱——无论是用于配置(日志、HLO 转储)还是用于查询实时数据(监控、分析)。诸如 XProf 等工具在 Python 层面直接检查你的 JAX 程序状态。了解这些关系有助于你针对具体问题选择合适的工具。

每个工作负载的关键日志与诊断标志

详细日志

启用详细日志是最关键的调试步骤。没有它,你就像盲目飞行。请在 每个 TPU 切片的工作节点 上使用这些标志,以记录从 TPU 运行时设置到程序执行步骤的所有信息(带时间戳)。

Logging flags diagram

在所有 TPU 工作节点上启用默认标志

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
  --project ${PROJECT_ID} --zone ${ZONE} --worker=all --node=all \
  --command='TPU_VMODULE=slice_configuration=1,real_program_continuator=1 \
            TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0 \
            python3 -c "import jax; print(f\"Host {jax.process_index()}: Global devices: {jax.device_count()}, Local devices: {jax.local_device_count()}\")"'

libtpu 日志会自动写入每个 TPU VM 上的 /tmp/tpu_logs/tpu_driver.INFO。该文件是了解 TPU 运行时实际行为的最直接依据。

从所有 TPU VM 收集日志

#!/bin/bash

TPU_NAME="YOUR_TPU_NAME"
PROJECT="YOUR_PROJECT"
ZONE="YOUR_ZONE"
BASE_LOG_DIR="PATH/WHERE/LOGS/WILL/BE/DOWNLOADED"

NUM_WORKERS=$(gcloud compute tpus tpu-vm describe $TPU_NAME \
               --zone=$ZONE --project=$PROJECT \
               | grep tpuVmSelflink \
               | awk -F'[:/]' '{print $13}' \
               | uniq | wc -l)

echo "Number of workers = $NUM_WORKERS"

for ((i=0; i<$NUM_WORKERS; i++)); do
  mkdir -p "${BASE_LOG_DIR}/$i"
  echo "Downloading logs from worker=$i"
  gcloud compute tpus tpu-vm scp ${TPU_NAME}:/tmp/tpu_logs/* \
        "${BASE_LOG_DIR}/$i/" --zone=${ZONE} --project=${PROJECT} --worker=$i
done

在 Google Colab 中,你可以通过 os.environ 设置相同的环境变量,并在左侧的 Files 面板中查看日志。

日志示例片段

I1031 19:02:51.863599     669 b295d63588a.cc:843] Process id 669
I1031 19:02:51.863609     669 b295d63588a.cc:848] Current working directory /content
I1031 19:02:51.863621     669 b295d63588a.cc:866] Build tool: Bazel, release r4rca-2025.05.26-2 (mainline @763214608)
I1031 19:02:51.863621     669 b295d63588a.cc:867] Build target:
I1031 19:02:51.863624     669 b295d63588a.cc:874] Command line arguments:
I1031 19:02:51.863624     669 b295d63588a.cc:876] argv[0]: './tpu_driver'
...
I1031 19:02:51.863784     669 init.cc:78] Remote crash gathering hook installed.
I1031 19:02:51.863807     669 tpu_runtime_type_flags.cc:79] --tpu_use_tfrt not specified. Using default value: true
I1031 19:02:51.873759     669 tpu_hal.cc:448] Registered plugin from module: breakpoint_debugger_server
...
I1031 19:02:51.879890     669 pending_event_logger.cc:896] Enabling PjRt/TPU event dependency logging
I1031 19:02:51.880524     843 device_util.cc:124] Found 1 TPU v5 lite chips.
...
I1031 19:02:53.471830     851 2a886c8_compiler_base.cc:3677] CODE_GENERATION stage duration: 3.610218ms

这些摘录展示了驱动日志中会出现的内容——进程 ID、构建元数据、插件注册、设备发现以及各阶段的耗时细节。

TPU 监控库

TPU 监控库让您能够以编程方式获取 TPU 硬件上工作流性能的洞察(利用率、容量、延迟等)。它是 libtpu 包的一部分,该包会作为 jax[tpu] 的依赖自动安装,因此您可以立即开始使用监控 API。

安装

# 显式安装
pip install "jax[tpu]" libtpu

使用

from libtpu.sdk import tpumonitoring

duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
duty_cycle_data = duty_cycle_metric.data

print("TPU Duty Cycle Data:")
print(f"  Description: {duty_cycle_metric.description}")
print(f"  Data: {duty_cycle_data}")

直接在您的 JAX 程序中集成 tpumonitoring——在模型训练期间、推理之前等场景使用。
了解更多,请参阅 Cloud TPU 文档

tpu‑info

tpu‑info 命令行工具提供 TPU 内存和其他利用率指标的实时视图,类似于 GPU 的 nvidia‑smi

在所有 Worker 和 Node 上安装

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --worker=all \
  --node=all \
  --command='pip install tpu-info'

检查芯片利用率指标

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --worker=0 \
  --node=0

tpu-info

当芯片正在使用时,输出会显示进程 ID、内存使用情况和占空比 %:

TPU 利用率(活跃)

当没有芯片在使用时,TPU VM 将显示无活动:

TPU 利用率(空闲)

了解更多指标和流式模式,请参阅 tpu‑info 文档


在本文中我们讨论了一些 TPU 日志记录和监控选项。系列的下一篇将探讨如何调试你的 JAX 程序——从生成 HLO 转储和使用 XProf 对代码进行分析开始。

Back to Blog

相关文章

阅读更多 »