A Developer's Guide to Debugging JAX on Cloud TPUs: Essential Tools and Techniques
Source: Google Developers Blog
Jan 5, 2026
Brian Kang
Senior Staff – Field Solutions Architect, AI Infrastructure
JAX on Cloud TPUs provides powerful acceleration for machine‑learning workflows. When working in distributed cloud environments, you need specialized tools to debug those workflows: tools for accessing logs, hardware metrics, and more. This blog post serves as a practical guide to various debugging and profiling techniques.
Choosing the Right Tool: Core Components and Dependencies
At the heart of the system are two main components that nearly all debugging tools depend on:
| Component | Description | Why Debuggers Need It |
|---|---|---|
| libtpu | Shared library (libtpu.so) present on every Cloud TPU VM. It bundles the XLA compiler, the TPU driver, and the communication logic for the hardware. | Provides the low‑level runtime that most tools configure (e.g., logging, HLO dumps) or query (e.g., real‑time metrics). |
| JAX + jaxlib | JAX is the Python front‑end where you write model code; jaxlib is the C++ backend that bridges Python to libtpu.so. | Supplies the high‑level API that tools instrument (e.g., tracing, profiling) and the compiled kernels that run on the TPU. |
Visual Overview
- Component relationship diagram
- Tool‑dependency table
Summary
- libtpu is the central pillar: most debugging utilities either configure it (e.g., enable logging or HLO dumps) or query it for live data (e.g., monitoring, profiling).
- Tools that operate at the Python level—such as XProf—inspect the state of your JAX program directly, but they still rely on jaxlib to translate those inspections into TPU‑compatible actions.
Understanding these relationships lets you pick the most appropriate tool for the problem you’re tackling, whether it’s a compilation issue, runtime performance bottleneck, or hardware‑level fault.
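Before reaching for a heavier tool, it helps to confirm exactly which versions of each layer are in play on a worker. A minimal sketch (the exact output details vary across JAX releases):
import jax
import jaxlib

# One-call summary of the stack: JAX, jaxlib, Python, and the detected devices.
jax.print_environment_info()

# Or inspect the pieces individually.
print("jax:", jax.__version__)
print("jaxlib:", jaxlib.__version__)
print("backend:", jax.default_backend())  # expect "tpu" on a Cloud TPU VM
print("devices:", jax.devices())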
Essential Logging and Diagnostic Flags for Every Workload
Verbose Logging
Enabling verbose logging is the most critical step for debugging; without it you’re flying blind. Apply these flags on every worker of your TPU slice to log everything from TPU runtime setup to program‑execution steps with timestamps.

Enable default flags on every TPU worker node
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} --zone ${ZONE} --worker=all --node=all \
--command='TPU_VMODULE=slice_configuration=1,real_program_continuator=1 \
TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0 \
python3 -c "import jax; print(f\"Host {jax.process_index()}: \
Global devices: {jax.device_count()}, Local devices: {jax.local_device_count()}\")"'
libtpu logs are automatically generated in /tmp/tpu_logs/tpu_driver.INFO on each TPU VM. This file is the ground‑truth for what the TPU runtime is doing. To collect logs from all TPU VMs, run the script below:
#!/usr/bin/env bash
TPU_NAME="your-tpu-name"
PROJECT="your-project-id"
ZONE="your-zone"
BASE_LOG_DIR="path/to/where/you/want/logs"
# Number of workers in the slice
NUM_WORKERS=$(gcloud compute tpus tpu-vm describe "$TPU_NAME" \
--zone="$ZONE" --project="$PROJECT" \
| grep tpuVmSelflink | awk -F'[:/]' '{print $13}' | uniq | wc -l)
echo "Number of workers = $NUM_WORKERS"
for ((i=0; i<NUM_WORKERS; i++)); do
mkdir -p "${BASE_LOG_DIR}/${i}"
echo "Downloading logs from worker=$i"
gcloud compute tpus tpu-vm scp "${TPU_NAME}:/tmp/tpu_logs/*" \
"${BASE_LOG_DIR}/${i}/" --zone="${ZONE}" --project="${PROJECT}" --worker=$i
done
On Google Colab you can set the same environment variables via os.environ and access the logs in the Files pane on the left sidebar.
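A minimal sketch of a Colab cell that does this, using the same flags as the gcloud command above (they must be set before jax, and therefore libtpu, is first imported):
import os

# Set the logging flags before importing jax so libtpu picks them up.
os.environ["TPU_VMODULE"] = "slice_configuration=1,real_program_continuator=1"
os.environ["TPU_MIN_LOG_LEVEL"] = "0"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
os.environ["TPU_STDERR_LOG_LEVEL"] = "0"

import jax
print(jax.devices())  # the verbose logs land under /tmp/tpu_logs/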
Example log snippets
...
I1031 19:02:51.863599 669 b295d63588a.cc:843] Process id 669
I1031 19:02:51.863609 669 b295d63588a.cc:848] Current working directory /content
...
I1031 19:02:51.863621 669 b295d63588a.cc:866] Build tool: Bazel, release r4rca-2025.05.26-2 (mainline @763214608)
I1031 19:02:51.863621 669 b295d63588a.cc:867] Build target:
I1031 19:02:51.863624 669 b295d63588a.cc:874] Command line arguments:
I1031 19:02:51.863624 669 b295d63588a.cc:876] argv[0]: './tpu_driver'
...
I1031 19:02:51.863784 669 init.cc:78] Remote crash gathering hook installed.
I1031 19:02:51.863807 669 tpu_runtime_type_flags.cc:79] --tpu_use_tfrt not specified. Using default value: true
I1031 19:02:51.873759 669 tpu_hal.cc:448] Registered plugin from module: breakpoint_debugger_server
...
I1031 19:02:51.879890 669 pending_event_logger.cc:896] Enabling PjRt/TPU event dependency logging
I1031 19:02:51.880524 843 device_util.cc:124] Found 1 TPU v5 lite chips.
...
I1031 19:02:53.471830 851 2a886c8_compiler_base.cc:3677] CODE_GENERATION stage duration: 3.610218ms
I1031 19:02:53.471885 851 isa_program_util_common.cc:486] (HLO module jit_add): Executable fingerprint:0cae8d08bd660ddbee7ef03654ae249ae4122b40da162a3b0ca2cd4bb4b3a19c
...
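Once the logs are downloaded, plain text tools go a long way. As an illustrative sketch (assuming the per‑worker directory layout produced by the collection script above), the following surfaces warning/error/fatal lines and compilation‑stage timings across all workers:
import glob
import re

BASE_LOG_DIR = "path/to/where/you/want/logs"  # same directory used by the download script

for path in sorted(glob.glob(f"{BASE_LOG_DIR}/*/tpu_driver*")):
    with open(path) as f:
        for line in f:
            # glog-style severity prefix: W/E/F followed by the date (MMDD).
            if re.match(r"^[WEF]\d{4}", line) or "stage duration" in line:
                print(path, line.rstrip())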
TPU Monitoring Library
The TPU Monitoring Library provides programmatic insight into workflow performance on TPU hardware (utilization, capacity, latency, and more). It is part of the libtpu package, which is installed automatically as a dependency of jax[tpu].
# Explicit installation (optional)
pip install "jax[tpu]" libtpu
List all supported metrics:
from libtpu.sdk import tpumonitoring
print(tpumonitoring.list_supported_metrics())
Get a specific metric (e.g., duty_cycle_pct) and print its data and description:
from libtpu.sdk import tpumonitoring
duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
print("TPU Duty Cycle Data:")
print(f" Description: {duty_cycle_metric.description}")
print(f" Data: {duty_cycle_metric.data}")
You would typically integrate tpumonitoring directly in your JAX programs—during model training, before inference, etc. See the Cloud TPU documentation for more details.
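For example, a minimal sketch of sampling the duty cycle between training steps (the training-loop names are placeholders for your own code, and the exact metric surface can differ between libtpu releases):
from libtpu.sdk import tpumonitoring

def log_tpu_metrics(step):
    # Snapshot the TensorCore duty cycle; call tpumonitoring.list_supported_metrics()
    # to see what else your libtpu version exposes.
    duty_cycle = tpumonitoring.get_metric("duty_cycle_pct")
    print(f"step {step}: duty_cycle_pct = {duty_cycle.data}")

# Inside your training loop (train_step, state, batch, and num_steps are your own code):
# for step in range(num_steps):
#     state = train_step(state, batch)
#     if step % 100 == 0:
#         log_tpu_metrics(step)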
tpu‑info
The tpu‑info CLI provides a real‑time view of TPU memory and other utilization metrics, similar to nvidia‑smi for GPUs.
Install on all workers and nodes
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} --zone ${ZONE} --worker=all --node=all \
--command='pip install tpu-info'
Check chip utilization on a single worker/node
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} --zone ${ZONE} --worker=0 --node=0
tpu-info
When chips are in use, process IDs, memory usage, and duty‑cycle % will be displayed.

When no chips are in use, the TPU VM will show no activity.

Learn more about additional metrics and streaming mode in the tpu‑info documentation.
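If you want to see tpu‑info report non‑idle values, keep the chips busy from another terminal while it runs. A minimal JAX sketch (the matrix size is arbitrary; any sustained computation works):
import jax
import jax.numpy as jnp

@jax.jit
def busy(x):
    # A large matmul keeps the TensorCores occupied; dividing by the
    # dimension keeps the values finite across iterations.
    return (x @ x) / x.shape[0]

x = jnp.ones((8192, 8192), dtype=jnp.bfloat16)
for _ in range(1000):
    x = busy(x)
x.block_until_ready()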
In this post we discussed several TPU logging and monitoring options. Next in the series we’ll explore how to debug your JAX programs—starting with generating HLO dumps and profiling your code with XProf.