클라우드 TPU에서 JAX 디버깅을 위한 개발자 가이드: 핵심 도구 및 기법
Source: Google Developers Blog
날짜: 2026년 1월 5일
작성자: Brian Kang – Senior Staff – Field Solutions Architect, AI Infrastructure
작성자 페이지
올바른 도구 선택: 핵심 구성 요소와 종속성
시스템의 핵심에는 거의 모든 디버깅 도구가 의존하는 두 가지 주요 구성 요소가 있습니다:
| 구성 요소 | 설명 | 디버거가 필요로 하는 이유 |
|---|---|---|
| libtpu | 모든 Cloud TPU VM에 존재하는 공유 라이브러리(libtpu.so). XLA 컴파일러, TPU 드라이버, 하드웨어 통신 로직을 포함합니다. | 대부분의 도구가 구성(예: 로깅, HLO 덤프)하거나 조회(예: 실시간 메트릭)하는 저수준 런타임을 제공합니다. |
| JAX + jaxlib | JAX는 모델 코드를 작성하는 Python 프론트엔드이며, jaxlib는 Python을 libtpu.so와 연결하는 C++ 백엔드입니다. | 도구가 계측하는 고수준 API(예: 트레이싱, 프로파일링)와 TPU에서 실행되는 컴파일된 커널을 제공합니다. |
시각적 개요
-
구성 요소 관계 다이어그램

-
도구‑종속성 표

요약
libtpu는 중심 기둥입니다: 대부분의 디버깅 유틸리티가 이를 구성(예: 로깅 활성화 또는 HLO 덤프)하거나 실시간 데이터(예: 모니터링, 프로파일링)를 조회합니다.- Python 수준에서 동작하는 도구—예를 들어 XProf—는 JAX 프로그램의 상태를 직접 검사하지만, 여전히
jaxlib에 의존해 이러한 검사를 TPU 호환 작업으로 변환합니다.
이러한 관계를 이해하면 컴파일 문제, 런타임 성능 병목, 하드웨어 수준 오류 등 해결하고자 하는 문제에 가장 적합한 도구를 선택할 수 있습니다.
Source:
모든 워크로드를 위한 필수 로깅 및 진단 플래그
자세한 로깅
자세한 로깅을 활성화하는 것은 디버깅에 가장 중요한 단계이며, 이를 하지 않으면 눈이 먼 채로 작업하게 됩니다. TPU 슬라이스의 각 워커에 아래 플래그를 적용하여 TPU 런타임 설정부터 프로그램 실행 단계까지 모든 과정을 타임스탬프와 함께 기록하세요.

모든 TPU 워커 노드에 기본 플래그 활성화
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} --zone ${ZONE} --worker=all --node=all \
--command='TPU_VMODULE=slice_configuration=1,real_program_continuator=1 \
TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0 \
python3 -c "import jax; print(f\"Host {jax.process_index()}: \
Global devices: {jax.device_count()}, Local devices: {jax.local_device_count()}\")"'
libtpu 로그는 각 TPU VM의 /tmp/tpu_logs/tpu_driver.INFO에 자동으로 생성됩니다. 이 파일이 TPU 런타임이 수행하는 작업에 대한 최종 증거입니다. 모든 TPU VM에서 로그를 수집하려면 아래 스크립트를 실행하세요:
#!/usr/bin/env bash
TPU_NAME="your-tpu-name"
PROJECT="your-project-id"
ZONE="your-zone"
BASE_LOG_DIR="path/to/where/you/want/logs"
# 슬라이스 내 워커 수
NUM_WORKERS=$(gcloud compute tpus tpu-vm describe "$TPU_NAME" \
--zone="$ZONE" --project="$PROJECT" \
| grep tpuVmSelflink | awk -F'[:/]' '{print $13}' | uniq | wc -l)
echo "Number of workers = $NUM_WORKERS"
for ((i=0; i<NUM_WORKERS; i++)); do
mkdir -p "${BASE_LOG_DIR}/${i}"
echo "Downloading logs from worker=$i"
gcloud compute tpus tpu-vm scp "${TPU_NAME}:/tmp/tpu_logs/*" \
"${BASE_LOG_DIR}/${i}/" --zone="${ZONE}" --project="${PROJECT}" --worker=$i
done
Google Colab에서는 os.environ을 통해 동일한 환경 변수를 설정하고, 왼쪽 사이드바의 Files 패널에서 로그에 접근할 수 있습니다.
로그 스니펫 예시
...
I1031 19:02:51.863599 669 b295d63588a.cc:843] Process id 669
I1031 19:02:51.863609 669 b295d63588a.cc:848] Current working directory /content
...
I1031 19:02:51.863621 669 b295d63588a.cc:866] Build tool: Bazel, release r4rca-2025.05.26-2 (mainline @763214608)
I1031 19:02:51.863621 669 b295d63588a.cc:867] Build target:
I1031 19:02:51.863624 669 b295d63588a.cc:874] Command line arguments:
I1031 19:02:51.863624 669 b295d63588a.cc:876] argv[0]: './tpu_driver'
...
I1031 19:02:51.863784 669 init.cc:78] Remote crash gathering hook installed.
I1031 19:02:51.863807 669 tpu_runtime_type_flags.cc:79] --tpu_use_tfrt not specified. Using default value: true
I1031 19:02:51.873759 669 tpu_hal.cc:448] Registered plugin from module: breakpoint_debugger_server
...
I1031 19:02:51.879890 669 pending_event_logger.cc:896] Enabling PjRt/TPU event dependency logging
I1031 19:02:51.880524 843 device_util.cc:124] Found 1 TPU v5 lite chips.
...
I1031 19:02:53.471830 851 2a886c8_compiler_base.cc:3677] CODE_GENERATION stage duration: 3.610218ms
I1031 19:02:53.471885
...
851 isa_program_util_common.cc:486] (HLO module jit_add): Executable fingerprint:0cae8d08bd660ddbee7ef03654ae249ae4122b40da162a3b0ca2cd4bb4b3a19c
TPU 모니터링 라이브러리
TPU 모니터링 라이브러리는 TPU 하드웨어에서 워크플로 성능(활용도, 용량, 지연 시간 등)에 대한 프로그래밍 방식의 인사이트를 제공합니다. 이는 jax[tpu] 의존성으로 자동 설치되는 libtpu 패키지의 일부입니다.
# 명시적 설치 (선택 사항)
pip install "jax[tpu]" libtpu
지원되는 모든 메트릭을 나열하려면:
from libtpu.sdk import tpumonitoring
print(tpumonitoring.list_supported_metrics())
특정 메트릭(예: duty_cycle_pct)을 가져와 데이터와 설명을 출력하려면:
from libtpu.sdk import tpumonitoring
duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
print("TPU Duty Cycle Data:")
print(f" Description: {duty_cycle_metri")
```python
c.description}")
print(f" Data: {duty_cycle_metric.data}")
보통 tpumonitoring은 JAX 프로그램에 직접 통합하여 모델 학습 중이나 추론 전 등에서 사용합니다. 자세한 내용은 Cloud TPU 문서를 참고하세요.
tpu‑info
tpu‑info CLI는 GPU용 nvidia‑smi와 유사하게 TPU 메모리 및 기타 활용 지표를 실시간으로 보여줍니다.
모든 워커와 노드에 설치
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} --zone ${ZONE} --worker=all --node=all \
--command='pip install tpu-info'
단일 워커/노드에서 칩 활용도 확인
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} --zone ${ZONE} --worker=0 --node=0
tpu-info
칩이 사용 중이면 프로세스 ID, 메모리 사용량, duty‑cycle %가 표시됩니다.

칩이 사용되지 않으면 TPU VM에 활동이 표시되지 않습니다.

추가 메트릭 및 스트리밍 모드에 대해서는 tpu‑info 문서를 참고하세요.
이번 글에서는 여러 TPU 로깅 및 모니터링 옵션을 살펴보았습니다. 시리즈의 다음 편에서는 JAX 프로그램 디버깅 방법을 다룰 예정이며, HLO 덤프 생성 및 XProf를 이용한 프로파일링부터 시작합니다.