Building production AI on Google Cloud TPUs with JAX

Published: December 9, 2025 at 02:18 AM EST
3 min read

Source: Google Developers Blog

The Architectural Imperative: Modularity and Performance

The JAX AI Stack is built on a philosophy of modular, loosely coupled components, where each library is designed to excel at a single task. This approach empowers users to build a bespoke ML stack, selecting and combining the best libraries for optimization, data loading, or checkpointing to fit their requirements precisely. This modularity is also crucial in the rapidly evolving field of AI: it allows for rapid innovation, as new libraries and techniques can be developed and integrated without the risk and overhead of modifying a large, monolithic framework.

A modern ML stack must provide a continuum of abstraction: automated high‑level optimizations for speed of development, and fine‑grained, manual control for when every microsecond counts. The JAX AI Stack is designed to offer this continuum.

The Core “JAX AI Stack”

At the heart of the JAX ecosystem is the “JAX AI Stack,” four key libraries that provide the foundation for model development, all built on the compiler‑first design of JAX and XLA; a minimal sketch combining all four follows the list below.

  • JAX – The foundation for accelerator‑oriented array computation. Its pure functional programming model makes transformations composable, allowing workloads to scale effectively across hardware types and cluster sizes.
  • Flax – Provides a flexible, intuitive API for model authoring and “surgery,” offering an object‑oriented layer on top of JAX without sacrificing performance.
  • Optax – A library of composable gradient‑processing and optimization transformations. Chain standard optimizers (e.g., Adam) with techniques like gradient clipping or accumulation in just a few lines of code.
  • Orbax – An “any‑scale” checkpointing library that supports asynchronous distributed checkpointing, ensuring training runs can withstand hardware failures without losing significant progress.
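
To make the division of labor concrete, here is a minimal sketch wiring all four libraries together: a small Flax model, an Optax chain composing gradient clipping with Adam, a jitted JAX training step, and an Orbax checkpoint of the result. The model shape, hyperparameters, and checkpoint path are illustrative assumptions, not part of the original post.

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
import orbax.checkpoint as ocp

# Flax: a small two-layer model defined as an object-oriented Module.
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(nn.relu(nn.Dense(32)(x)))

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))

# Optax: chain gradient clipping with Adam in two lines.
tx = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3))
opt_state = tx.init(params)

# JAX: a pure training step, compiled end to end by XLA via jit.
@jax.jit
def train_step(params, opt_state, x, y):
    def loss_fn(p):
        return jnp.mean((model.apply(p, x) - y) ** 2)
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = tx.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state, loss

x, y = jnp.ones((16, 8)), jnp.zeros((16, 1))
params, opt_state, loss = train_step(params, opt_state, x, y)

# Orbax: save the parameters. StandardCheckpointer writes asynchronously,
# so training could continue while the checkpoint is flushed to storage.
# The target directory must not already exist.
ckptr = ocp.StandardCheckpointer()
ckptr.save('/tmp/jax-ai-stack-demo/step_1', params)
ckptr.wait_until_finished()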

The jax-ai-stack metapackage can be installed with:

pip install jax-ai-stack
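
Once installed, the libraries are imported under their own names; a quick sanity check might look like this (the version attributes are standard Python conventions, not specific to the stack):

import jax, flax, optax, orbax.checkpoint

print(jax.__version__, flax.__version__, optax.__version__)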

[Figure: The JAX AI Stack and Ecosystem Components]

The Extended JAX AI Stack

Building on this stable core, a rich ecosystem of specialized libraries provides the end‑to‑end capabilities needed for the entire ML lifecycle.

Industrial‑Scale Infrastructure

Beneath the user‑facing libraries lies the infrastructure that lets JAX scale seamlessly from a single TPU or GPU to thousands of chips; a small illustration of XLA’s whole‑program compilation follows the list.

  • XLA (Accelerated Linear Algebra) – A domain‑specific, hardware‑agnostic compiler that delivers strong out‑of‑the‑box performance through whole‑program analysis, operator fusion, and memory‑layout optimization.
  • Pathways – A unified runtime for massive‑scale distributed computation, allowing developers to code as if using a single powerful machine while Pathways orchestrates execution across tens of thousands of chips with built‑in fault tolerance.
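
Pathways is exposed through managed runtimes rather than a local API, but the XLA side is easy to see from user code: any jitted JAX function is handed to XLA as a whole program, where chains of elementwise operations like the one below can be fused rather than executed op by op. This is an illustrative sketch, not from the original post.

import jax
import jax.numpy as jnp

# XLA sees the whole function and can fuse the multiply, add, tanh,
# and reduction into far fewer kernels than eager execution would use.
@jax.jit
def fused(x):
    return jnp.tanh(x * 2.0 + 1.0).sum()

x = jnp.arange(1024.0)
print(fused(x))  # first call compiles; later calls reuse the binary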

Advanced Development for Peak Efficiency

To achieve the highest levels of hardware utilization, the ecosystem provides specialized tools that offer deeper control and higher efficiency; a minimal Pallas kernel sketch follows the list.

  • Pallas & Tokamax – Extensions for writing custom kernels for TPUs and GPUs with precise control over memory hierarchy and parallelism. Tokamax supplies a curated library of state‑of‑the‑art kernels (e.g., FlashAttention) for plug‑and‑play performance.
  • Qwix – A comprehensive, non‑intrusive quantization library that enables techniques like QLoRA or PTQ by intercepting JAX functions, requiring minimal or no changes to the original model code.
  • Grain – A performant, deterministic data‑loading library that integrates with Orbax to checkpoint the exact state of the data pipeline, guaranteeing bit‑for‑bit reproducibility after restarts.
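
For a flavor of what kernel authoring looks like, here is a minimal, hypothetical Pallas kernel: each program instance reads its input blocks, adds them, and writes the result. Real kernels such as Tokamax’s FlashAttention add grids, blocking, and explicit memory‑space control on top of this pattern.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Refs point at blocks staged in fast on-chip memory (VMEM on TPU).
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        # Pass interpret=True to run the kernel in pure JAX for debugging.
    )(x, y)

x = y = jnp.ones((8, 128), jnp.float32)
print(add(x, y)[0, 0])  # 2.0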

The Full Path to Production

Other modules augment the JAX AI Stack with a mature, end‑to‑end application layer that bridges the gap from research to deployment.

  • MaxText & MaxDiffusion – Scalable frameworks for LLM and diffusion model training, offering reliable starting points optimized for goodput and Model FLOPs Utilization (MFU) out of the box.
  • Tunix – A JAX‑native library for post‑training alignment, providing state‑of‑the‑art algorithms such as supervised fine‑tuning (SFT).