클라우드 TPU에서 JAX 디버깅을 위한 개발자 가이드: 필수 도구와 기법

발행: (2026년 2월 12일 오전 09:43 GMT+9)
8 분 소요

Source: Google Developers Blog

저자

Brian Kang – 시니어 스태프, 필드 솔루션 아키텍트, AI 인프라스트럭처

JAX on Cloud TPU는 머신러닝 워크플로에 강력한 가속을 제공합니다. 분산 클라우드 환경에서 작업할 때는 로그 접근, 하드웨어 메트릭 등 워크플로를 디버깅하기 위한 특수 도구가 필요합니다. 이 블로그 포스트는 다양한 디버깅 및 프로파일링 기법에 대한 실용적인 가이드를 제공합니다.

올바른 도구 선택하기: 핵심 구성 요소 및 종속성

시스템의 핵심에는 거의 모든 디버깅 도구가 의존하는 두 가지 주요 구성 요소가 있습니다:

  • libtpu (libtpu.so, TPU 런타임 포함) – XLA 컴파일러, TPU 드라이버, 하드웨어와 통신하는 로직을 포함하는 모든 Cloud TPU VM에 있는 공유 라이브러리입니다. 거의 모든 디버깅 도구가 libtpu와 상호 작용하거나 이를 통해 구성됩니다.
  • JAXjaxlib – 프레임워크. JAX는 모델 코드를 작성하는 Python 라이브러리이며, jaxlib는 C++ 백엔드로 libtpu.so와 연결되는 역할을 합니다.

이 구성 요소들과 디버깅 도구 간의 관계는 아래 그림에 나와 있습니다.

relationship diagram

다음은 각 도구, 그 종속성 및 서로 간의 관계를 정리한 표입니다.

tool table

요약하면, libtpu는 대부분의 디버깅 도구가 구성(로그, HLO 덤프)이나 실시간 데이터 조회(모니터링, 프로파일링)를 위해 의존하는 중심 축입니다. XProf와 같은 다른 도구는 Python 수준에서 직접 JAX 프로그램의 상태를 검사합니다. 이러한 관계를 이해하면 현재 직면한 특정 문제에 가장 적합한 도구를 보다 효과적으로 선택할 수 있습니다.

모든 워크로드를 위한 필수 로깅 및 진단 플래그

상세 로깅

디버깅에 가장 중요한 단계는 상세 로깅을 활성화하는 것입니다. 이를 하지 않으면 눈을 가리고 작업하는 것과 같습니다. 이 플래그들은 TPU 슬라이스의 모든 워커에 설정되어야 하며, TPU 런타임 설정부터 프로그램 실행 단계까지 모든 정보를 타임스탬프와 함께 기록합니다.

log flags illustration

모든 TPU 워커 노드에 기본 플래그 활성화

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
  --project ${PROJECT_ID} --zone ${ZONE} --worker=all --node=all \
  --command='TPU_VMODULE=slice_configuration=1,real_program_continuator=1 \
            TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0 \
            python3 -c "import jax; print(f\"Host {jax.process_index()}: \
            Global devices: {jax.device_count()}, Local devices: \
            {jax.local_device_count()}\")"'

libtpu 로그는 각 TPU VM의 /tmp/tpu_logs/tpu_driver.INFO에 자동으로 생성됩니다. 이 파일이 TPU 런타임이 수행하고 있는 작업에 대한 최신 사실입니다. 모든 TPU VM에서 로그를 수집하려면 아래 스크립트를 실행하세요:

#!/bin/bash

TPU_NAME="your-tpu-name"
PROJECT="your-project-id"
ZONE="your-tpu-zone"
BASE_LOG_DIR="path/to/local/log/dir"

NUM_WORKERS=$(gcloud compute tpus tpu-vm describe "$TPU_NAME" \
               --zone="$ZONE" --project="$PROJECT" \
               | grep tpuVmSelflink | awk -F'[:/]' '{print $13}' | uniq | wc -l)

echo "Number of workers = $NUM_WORKERS"

for ((i=0; i<NUM_WORKERS; i++)); do
  mkdir -p "${BASE_LOG_DIR}/$i"
  echo "Downloading logs from worker=$i ..."
  gcloud compute tpus tpu-vm scp "${TPU_NAME}:/tmp/tpu_logs/*" \
       "${BASE_LOG_DIR}/$i/" --zone="${ZONE}" --project="${PROJECT}" --worker=$i
done

Google Colab에서는 os.environ을 통해 동일한 환경 변수를 설정하고, 왼쪽 Files 패널에서 로그에 접근할 수 있습니다.

로그 스니펫 예시

...
I1031 19:02:51.863599     669 b295d63588a.cc:843] Process id 669
I1031 19:02:51.863609     669 b295d63588a.cc:848] Current working directory /content
...
I1031 19:02:51.863621     669 b295d63588a.cc:866] Build tool: Bazel, release r4rca-2025.05.26-2 (mainline @763214608)
I1031 19:02:51.863621     669 b295d63588a.cc:867] Build target:
I1031 19:02:51.863624     669 b295d63588a.cc:874] Command line arguments:
I1031 19:02:51.863624     669 b295d63588a.cc:876] argv[0]: './tpu_driver'
...
I1031 19:02:51.863784     669 init.cc:78] Remote crash gathering hook installed.
I1031 19:02:51.863807     669 tpu_runtime_type_flags.cc:79] --tpu_use_tfrt not specified. Using default value: true
I1031 19:02:51.873759     669 tpu_hal.cc:448] Registered plugin from module: breakpoint_debugger_server
...
I1031 19:02:51.879890     669 pending_event_logger.cc:896] Enabling PjRt/TPU event dependency logging
I1031 19:02:51.880524     843 device_util.cc:124] Found 1 TPU v5 lite chips.
...
I1031 19:02:53.471830     851 2a886c8_compiler_base.cc:3677] CODE_GENERATION stage duration: 3.610218ms
I1031 19:02:53.471885
851 isa_program_util_common.cc:486] (HLO module jit_add): Executable fingerprint:0cae8d08bd660ddbee7ef03654ae249ae4122b40da162a3b0ca2cd4bb4b3a19c

TPU 모니터링 라이브러리

TPU 모니터링 라이브러리는 TPU 하드웨어에서 워크플로 성능(활용도, 용량, 지연 시간 등)을 프로그래밍 방식으로 인사이트를 제공하는 도구입니다. 이는 libtpu 패키지의 일부이며, jax[tpu]를 설치하면 자동으로 종속성으로 포함됩니다.

# 명시적 설치
pip install "jax[tpu]" libtpu

tpumonitoring.list_supported_metrics() 로 지원되는 모든 메트릭을 확인할 수 있고, tpumonitoring.get_metric 으로 특정 메트릭을 가져올 수 있습니다. 아래 예시는 duty_cycle_pct 데이터를 출력하고 설명을 표시합니다:

from libtpu.sdk import tpumonitoring

duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
duty_cycle_data = duty_cycle_metric.data

print("TPU D

> **Source:** ...

```python
uty Cycle Data:")
print(f"  Description: {duty_cycle_metric.description}")
print(f"  Data: {duty_cycle_data}")

Monitoring Library에 대해 자세히 알아보려면 Cloud TPU 문서를 참조하세요.

tpu‑info

tpu‑info CLI는 GPU용 nvidia‑smi와 유사하게 TPU 메모리 및 활용 메트릭을 실시간으로 보여줍니다.

모든 워커와 노드에 설치

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --worker=all \
  --node=all \
  --command='pip install tpu-info'

단일 워커와 노드에서 칩 활용도 확인

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --worker=0 \
  --node=0

tpu-info

칩이 사용 중일 때는 프로세스 ID, 메모리 사용량, 듀티 사이클 %가 표시됩니다:

TPU utilization

칩이 사용되지 않을 경우, TPU VM에 활동이 표시되지 않습니다:

No activity

추가 메트릭 및 스트리밍 모드에 대해서는 문서를 참고하세요.

다음 단계

이번 포스트에서는 TPU 로깅 및 모니터링 옵션을 다루었습니다. 다음 편에서는 JAX 프로그램을 디버깅하는 방법을 살펴볼 예정이며, HLO 덤프 생성 및 XProf를 이용한 코드 프로파일링부터 시작합니다.

네비게이션

0 조회
Back to Blog

관련 글

더 보기 »

Developer Knowledge API 및 MCP Server 소개

AI 기반 개발자 도구의 생태계가—Antigravity와 같은 에이전시 플랫폼을 포함하여—https://developers.google.com/blog/build-with-google-antigravity-our-new-a... 에서 확대되고 있습니다.