A Developer's Guide to Debugging JAX on Cloud TPUs: Essential Tools and Techniques

Published: January 5, 2026 at 05:19 PM EST
5 min read

Source: Google Developers Blog


Brian Kang
Senior Staff – Field Solutions Architect, AI Infrastructure

JAX on Cloud TPUs provides powerful acceleration for machine-learning workloads, but debugging them in a distributed cloud environment calls for specialized tools to access logs, hardware metrics, and more. This blog post is a practical guide to those debugging and profiling techniques.

Choosing the Right Tool: Core Components and Dependencies

At the heart of the system are two main components that nearly all debugging tools depend on:

  • libtpu
    Description: Shared library (libtpu.so) present on every Cloud TPU VM. It bundles the XLA compiler, the TPU driver, and the communication logic for the hardware.
    Why debuggers need it: Provides the low-level runtime that most tools configure (e.g., logging, HLO dumps) or query (e.g., real-time metrics).

  • JAX + jaxlib
    Description: JAX is the Python front-end where you write model code; jaxlib is the C++ backend that bridges Python to libtpu.so.
    Why debuggers need it: Supplies the high-level API that tools instrument (e.g., tracing, profiling) and the compiled kernels that run on the TPU.
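
You can verify this stack from inside a Python session before reaching for heavier tooling. The snippet below is a minimal sanity check using only public JAX calls; it assumes you are on a machine where libtpu is installed (e.g., a TPU VM).

# Confirm that JAX (Python front-end) and jaxlib/libtpu (TPU backend) are wired up.
import jax

jax.print_environment_info()       # prints jax/jaxlib versions and backend details
print(jax.devices())               # devices exposed by libtpu through jaxlib
print(jax.devices()[0].platform)   # 'tpu' when libtpu.so was loaded successfully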

Visual Overview

  • Component relationship diagram
    Diagram showing how libtpu, JAX, and various debugging tools interact

  • Tool‑dependency table
    Table summarising each debugging tool, its primary purpose, and its dependencies on libtpu/JAX/jaxlib

Summary

  • libtpu is the central pillar: most debugging utilities either configure it (e.g., enable logging or HLO dumps) or query it for live data (e.g., monitoring, profiling).
  • Tools that operate at the Python level—such as XProf—inspect the state of your JAX program directly, but they still rely on jaxlib to translate those inspections into TPU‑compatible actions.

Understanding these relationships lets you pick the most appropriate tool for the problem you’re tackling, whether it’s a compilation issue, runtime performance bottleneck, or hardware‑level fault.

Essential Logging and Diagnostic Flags for Every Workload

Verbose Logging

Enabling verbose logging is the most critical step for debugging; without it you’re flying blind. Apply these flags on every worker of your TPU slice to log everything from TPU runtime setup to program‑execution steps with timestamps.

Log flags diagram

Enable default flags on every TPU worker node

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
  --project ${PROJECT_ID} --zone ${ZONE} --worker=all --node=all \
  --command='TPU_VMODULE=slice_configuration=1,real_program_continuator=1 \
TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0 \
python3 -c "import jax; print(f\"Host {jax.process_index()}: \
Global devices: {jax.device_count()}, Local devices: {jax.local_device_count()}\")"'

libtpu logs are automatically generated in /tmp/tpu_logs/tpu_driver.INFO on each TPU VM. This file is the ground truth for what the TPU runtime is doing. To collect logs from all TPU VMs, run the script below:

#!/usr/bin/env bash

TPU_NAME="your-tpu-name"
PROJECT="your-project-id"
ZONE="your-zone"
BASE_LOG_DIR="path/to/where/you/want/logs"

# Number of workers in the slice
NUM_WORKERS=$(gcloud compute tpus tpu-vm describe "$TPU_NAME" \
  --zone="$ZONE" --project="$PROJECT" \
  | grep tpuVmSelflink | awk -F'[:/]' '{print $13}' | uniq | wc -l)

echo "Number of workers = $NUM_WORKERS"

for ((i=0; i<NUM_WORKERS; i++)); do
  mkdir -p "${BASE_LOG_DIR}/${i}"
  echo "Downloading logs from worker=$i"
  gcloud compute tpus tpu-vm scp "${TPU_NAME}:/tmp/tpu_logs/*" \
    "${BASE_LOG_DIR}/${i}/" --zone="${ZONE}" --project="${PROJECT}" --worker=$i
done

On Google Colab you can set the same environment variables via os.environ and access the logs in the Files pane on the left sidebar.
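
For example, in a notebook you might set the flags programmatically before importing JAX. This is a minimal sketch using the same variables as the gcloud command above:

# Set the logging flags via os.environ BEFORE importing jax so libtpu reads them at initialization.
import os

os.environ["TPU_MIN_LOG_LEVEL"] = "0"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
os.environ["TPU_STDERR_LOG_LEVEL"] = "0"
os.environ["TPU_VMODULE"] = "slice_configuration=1,real_program_continuator=1"

import jax
print(f"Host {jax.process_index()}: local devices = {jax.local_device_count()}")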

Example log snippets

...
I1031 19:02:51.863599     669 b295d63588a.cc:843] Process id 669
I1031 19:02:51.863609     669 b295d63588a.cc:848] Current working directory /content
...
I1031 19:02:51.863621     669 b295d63588a.cc:866] Build tool: Bazel, release r4rca-2025.05.26-2 (mainline @763214608)
I1031 19:02:51.863621     669 b295d63588a.cc:867] Build target:
I1031 19:02:51.863624     669 b295d63588a.cc:874] Command line arguments:
I1031 19:02:51.863624     669 b295d63588a.cc:876] argv[0]: './tpu_driver'
...
I1031 19:02:51.863784     669 init.cc:78] Remote crash gathering hook installed.
I1031 19:02:51.863807     669 tpu_runtime_type_flags.cc:79] --tpu_use_tfrt not specified. Using default value: true
I1031 19:02:51.873759     669 tpu_hal.cc:448] Registered plugin from module: breakpoint_debugger_server
...
I1031 19:02:51.879890     669 pending_event_logger.cc:896] Enabling PjRt/TPU event dependency logging
I1031 19:02:51.880524     843 device_util.cc:124] Found 1 TPU v5 lite chips.
...
I1031 19:02:53.471830     851 2a886c8_compiler_base.cc:3677] CODE_GENERATION stage duration: 3.610218ms
I1031 19:02:53.471885     851 isa_program_util_common.cc:486] (HLO module jit_add): Executable fingerprint:0cae8d08bd660ddbee7ef03654ae249ae4122b40da162a3b0ca2cd4bb4b3a19c
...

TPU Monitoring Library

The TPU Monitoring Library provides programmatic insight into workflow performance on TPU hardware (utilization, capacity, latency, and more). It is part of the libtpu package, which is installed automatically as a dependency of jax[tpu].

# Explicit installation (optional)
pip install "jax[tpu]" libtpu

List all supported metrics:

from libtpu.sdk import tpumonitoring

print(tpumonitoring.list_supported_metrics())

Get a specific metric (e.g., duty_cycle_pct) and print its data and description:

from libtpu.sdk import tpumonitoring

duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
print("TPU Duty Cycle Data:")
print(f"  Description: {duty_cycle_metric.description}")
print(f"  Data: {duty_cycle_metric.data}")

You would typically integrate tpumonitoring directly in your JAX programs—during model training, before inference, etc. See the Cloud TPU documentation for more details.
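
As an illustration, the sketch below samples the duty cycle every few steps of a toy training loop; the jitted train_step and its inputs are stand-ins for your own model code, while the metric access mirrors the example above.

# Hedged sketch: periodically sample TPU duty cycle inside a training loop.
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

@jax.jit
def train_step(params, batch):
    # Toy update standing in for a real optimizer step.
    return params - 0.01 * jnp.mean(batch)

params = jnp.zeros((1024, 1024))
LOG_EVERY = 100

for step in range(1_000):
    batch = jnp.ones((1024, 1024))
    params = train_step(params, batch)
    if step % LOG_EVERY == 0:
        duty_cycle = tpumonitoring.get_metric("duty_cycle_pct")
        print(f"step {step}: TPU duty cycle = {duty_cycle.data}")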

tpu-info

The tpu-info CLI provides a real-time view of TPU memory and other utilization metrics, similar to nvidia-smi for GPUs.

Install on all workers and nodes

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
  --project ${PROJECT_ID} --zone ${ZONE} --worker=all --node=all \
  --command='pip install tpu-info'

Check chip utilization on a single worker/node

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
  --project ${PROJECT_ID} --zone ${ZONE} --worker=0 --node=0

tpu-info

When chips are in use, process IDs, memory usage, and duty‑cycle % will be displayed.

TPU in use

When no chips are in use, the TPU VM will show no activity.

TPU idle

Learn more about additional metrics and streaming mode in the tpu-info documentation.


In this post we discussed several TPU logging and monitoring options. Next in the series we’ll explore how to debug your JAX programs—starting with generating HLO dumps and profiling your code with XProf.
