Google Cloud TPUs와 JAX를 사용한 프로덕션 AI 구축

발행: (2025년 12월 1일 오전 08:25 GMT+9)
10 min read

Source: Google Developers Blog

Nov 19, 2025Rakesh Iyer, Senior Software Engineering Manager, Google ML Frameworks

JAX logo

JAX는 AI 전반에 걸쳐 최첨단 파운데이션 모델을 개발하기 위한 핵심 프레임워크가 되었으며, 구글에만 국한되지 않습니다. Anthropic, xAI, Apple 등 주요 LLM 제공업체들이 오픈소스 JAX 프레임워크를 파운데이션 모델 구축 도구 중 하나로 활용하고 있습니다.

오늘은 JAX AI Stack에 대한 개요를 공유하게 되어 기쁩니다 — JAX를 핵심 수치 라이브러리로 삼아, 모든 규모의 머신러닝을 위한 산업 수준 솔루션으로 확장한 견고하고 엔드‑투‑엔드 플랫폼입니다.

이 생태계의 힘과 설계를 보여주기 위해, 각 구성 요소를 상세히 설명한 기술 보고서를 공개했습니다. 개발자, 연구자, 인프라 엔지니어 여러분께 전체 보고서를 읽고, 필요에 맞게 도구를 활용하는 방법을 이해하시길 권장합니다.

아래에서는 현대 AI를 위한 견고하고 유연한 플랫폼을 구성하는 아키텍처 철학과 핵심 구성 요소를 정리했습니다.

The Architectural Imperative: Modularity and Performance

JAX AI Stack은 모듈식이며 느슨하게 결합된 구성 요소라는 철학 위에 구축되었습니다. 각 라이브러리는 하나의 작업에 최적화되도록 설계되었습니다. 이 접근 방식은 사용자가 최적화, 데이터 로딩, 체크포인팅 등 요구 사항에 정확히 맞는 최고의 라이브러리를 선택·조합해 맞춤형 ML 스택을 만들 수 있게 합니다. 특히 AI 분야는 급변하기 때문에 이러한 모듈성은 필수적입니다. 새로운 라이브러리와 기술을 기존 대규모 단일 프레임워크를 수정하는 위험과 비용 없이 빠르게 개발·통합할 수 있습니다.

현대 ML 스택은 추상화의 연속성을 제공해야 합니다. 개발 속도를 높이는 자동화된 고수준 최적화와, 마이크로초 단위까지 제어가 필요한 경우를 위한 세밀하고 수동적인 제어가 공존해야 합니다. JAX AI Stack은 이러한 연속성을 제공하도록 설계되었습니다.

The Core “JAX AI Stack”

JAX 생태계의 핵심은 네 개의 주요 라이브러리로 구성된 “JAX AI Stack”이며, 모두 JAX와 XLA의 컴파일러‑우선 설계를 기반으로 합니다.

  • JAX – 가속기‑지향 배열 연산의 기반. 순수 함수형 프로그래밍 모델 덕분에 변환을 조합할 수 있어, 워크로드를 다양한 하드웨어와 클러스터 규모에 효과적으로 확장할 수 있습니다.
  • Flax – 모델 작성 및 “수술”을 위한 유연하고 직관적인 API를 제공, JAX의 함수형 코어와 객체‑지향을 선호하는 개발자 사이의 격차를 메워줍니다.
  • Optax – 조합 가능한 그래디언트 처리 및 최적화 변환 라이브러리. 연구자가 표준 옵티마이저(예: Adam)를 그래디언트 클리핑이나 누적과 같은 기법과 몇 줄의 코드만으로 선언적으로 연결할 수 있습니다.
  • Orbax – “any‑scale” 체크포인팅 라이브러리로, 비동기 분산 체크포인팅을 지원해 비용이 많이 드는 학습이 하드웨어 장애에도 큰 진행 상황을 잃지 않도록 합니다.

jax-ai-stack 메타패키지는 다음과 같이 설치할 수 있습니다:

pip install jax-ai-stack

JAX ecosystem
JAX AI Stack 및 생태계 구성 요소

The Extended JAX AI Stack

이 안정적인 코어 위에, 전체 ML 라이프사이클을 지원하는 풍부한 특화 라이브러리 생태계가 구축됩니다.

Industrial‑Scale Infrastructure

사용자‑대면 라이브러리 아래에는 JAX가 단일 TPU/GPU부터 수천 개의 GPU/TPU까지 원활히 확장될 수 있게 하는 인프라가 자리합니다.

  • XLA (Accelerated Linear Algebra) – 도메인‑특화, 하드웨어‑독립 컴파일러로, 전체 프로그램 분석을 통해 연산을 융합하고 메모리 레이아웃을 최적화해 뛰어난 즉시 성능을 제공합니다.
  • Pathways – 수만 개의 칩에 걸친 대규모 분산 연산을 오케스트레이션하면서, 연구자가 마치 단일 강력한 머신을 사용하는 것처럼 코딩할 수 있게 하는 통합 런타임입니다.

Advanced Development for Peak Efficiency

최고 수준의 하드웨어 활용도를 달성하기 위해, 생태계는 보다 깊은 제어와 높은 효율성을 제공하는 특화 도구들을 제공합니다.

  • Pallas & Tokamax – 메모리 계층 구조와 병렬성에 대한 정밀 제어를 통해 TPU와 GPU용 커스텀 커널을 작성할 수 있는 확장 기능; Tokamax는 최신 커널(예: FlashAttention) 라이브러리를 큐레이션합니다.
  • Qwix – 비침해적 양자화 라이브러리로, JAX 함수를 가로채 QLoRA나 PTQ와 같은 기법을 적용합니다. 원본 모델 코드를 거의 혹은 전혀 수정하지 않아도 됩니다.
  • Grain – 고성능, 결정론적 데이터 로딩 라이브러리로, Orbax와 통합해 데이터 파이프라인의 정확한 상태를 모델과 함께 체크포인팅하여 재시작 후 비트‑대‑비트 재현성을 보장합니다.

The Full Path to Production

다른 모듈들은 연구와 배포를 연결하는 성숙한 엔드‑투‑엔드 애플리케이션 레이어를 추가합니다.

  • MaxText & MaxDiffusion – LLM 및 디퓨전 모델 학습을 위한 대표적인 확장 가능한 프레임워크로, 기본 제공되는 좋은 처리량과 Model FLOPs Utilization (MFU) 최적화를 특징으로 합니다.
  • Tunix – JAX‑네이티브 사후 훈련 정렬 라이브러리로, LoRA/QLoRA 기반 SFT, GRPO, GSPO, DPO, PPO 등 최신 알고리즘을 제공합니다. MaxText와 Tunix의 통합은 Google Cloud 고객에게 가장 성능이 뛰어나고 확장 가능한 사후 훈련 환경을 제공합니다.
  • Inference Solutions – 최대 호환성을 위해, 모든 모델에 사용할 수 있는 인기 vLLM serving 프레임워크를 제공합니다.

Read the Report, Explore the Stack

JAX AI Stack은 단순히 라이브러리 모음이 아니라, Cloud TPUs와 공동 설계된 모듈식, 프로덕션‑레디 플랫폼입니다. 소프트웨어와 하드웨어의 깊은 통합은 성능과 총소유비용(TCO) 모두에서 강력한 이점을 제공하며, 다양한 적용 사례에서 입증되었습니다:

  • Kakao는 스택을 활용해 인프라 한계를 극복하고 LLM의 처리량을 2.7배 향상시키면서 비용‑성능을 최적화했습니다.
  • Lightricks는 130억 파라미터 규모의 비디오 디퓨전 모델을 구축하면서 확장성 장벽을 깨고 선형 확장성을 달성해 연구 속도를 가속화했습니다.
  • Escalante는 12개의 모델을 하나의 최적화 파이프라인으로 결합해 AI‑기반 단백질 설계에서 3.65배 더 나은 달러당 성능을 얻었습니다.

생태계를 직접 탐험하고, 기술 보고서를 읽어보시고, 새로운 중앙 허브 https://jaxstack.ai 에서 시작해 보세요.

Getting Started

Back to Blog

관련 글

더 보기 »

Data Commons Gemini CLI 확장 발표

우리가 10월 초에 Gemini CLI extensions framework를 출시한 이후, Google이 소유한 확장과 제3자 기여 확장이 폭발적으로 증가하는 것을 보았습니다.