Building production AI on Google Cloud TPUs with JAX

Source: Google Developers Blog

Nov 19, 2025 · Rakesh Iyer, Senior Software Engineering Manager, Google ML Frameworks

JAX has become a key framework for developing state‑of‑the‑art foundation models across the AI landscape, and not just at Google. Leading LLM providers such as Anthropic, xAI, and Apple use the open‑source JAX framework as one of the tools for building their foundation models.

Today, we are excited to share an overview of the JAX AI Stack, a robust, end‑to‑end platform that extends JAX, the core numerical library, into an industrial‑grade solution for machine learning at any scale.

To showcase the power and design of this ecosystem, we have published a detailed technical report explaining every component. We encourage developers, researchers, and infrastructure engineers to read the full report to understand how these tools can be applied to their specific needs.

Below, we outline the architectural philosophy and key components that form a robust and flexible platform for modern AI.

The Architectural Imperative: Modularity and Performance

The JAX AI Stack is built on a philosophy of modular, loosely coupled components, where each library is designed to excel at a single task. This approach empowers users to build a bespoke ML stack, selecting and combining the best libraries for optimization, data loading, or checkpointing to precisely fit their requirements. Crucially, this modularity is vital in the rapidly evolving field of AI. It allows for rapid innovation, as new libraries and techniques can be developed and integrated without the risk and overhead of modifying a large, monolithic framework.

A modern ML stack must provide a continuum of abstraction: automated high‑level optimizations for speed of development, and fine‑grained, manual control for when every microsecond counts. The JAX AI Stack is designed to offer this continuum.

The Core “JAX AI Stack”

At the heart of the JAX ecosystem is the “JAX AI Stack”: four key libraries that provide the foundation for model development, all built on the compiler‑first design of JAX and XLA.

  • JAX – The foundation for accelerator‑oriented array computation. Its pure functional programming model makes transformations composable, allowing workloads to scale effectively across hardware types and cluster sizes.
  • Flax – Provides a flexible, intuitive API for model authoring and “surgery,” bridging the gap between JAX’s functional core and the object‑oriented preferences of many developers.
  • Optax – A library of composable gradient‑processing and optimization transformations. It lets researchers declaratively chain standard optimizers (e.g., Adam) with techniques like gradient clipping or accumulation in just a few lines of code, as in the sketch after this list.
  • Orbax – An “any‑scale” checkpointing library that supports asynchronous distributed checkpointing, ensuring that expensive training runs can withstand hardware failures without losing significant progress.
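
As a minimal sketch of how these libraries divide the work, the snippet below defines a model in Flax, chains gradient clipping with Adam in Optax, and compiles the training step with JAX's jit and value_and_grad (Orbax, which would checkpoint params and opt_state across restarts, is omitted for brevity):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class MLP(nn.Module):
    hidden: int

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(self.hidden)(x))
        return nn.Dense(1)(x)

model = MLP(hidden=32)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))

# Optax: declaratively chain gradient clipping with Adam.
tx = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3))
opt_state = tx.init(params)

@jax.jit  # the whole step is compiled by XLA
def train_step(params, opt_state, x, y):
    def loss_fn(p):
        pred = model.apply(p, x)
        return jnp.mean((pred - y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = tx.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state, loss
```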

The jax-ai-stack metapackage can be installed with:

pip install jax-ai-stack
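
After installation, a quick sanity check confirms that JAX is importable and can see your accelerators:

```python
import jax

print(jax.__version__)  # installed JAX version
print(jax.devices())    # e.g. [CpuDevice(id=0)] or a list of TPU/GPU devices
```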

[Figure: The JAX AI Stack and Ecosystem Components]

The Extended JAX AI Stack

Building on this stable core, a rich ecosystem of specialized libraries provides the end‑to‑end capabilities needed for the entire ML lifecycle.

Industrial‑Scale Infrastructure

Beneath the user‑facing libraries lies the infrastructure that enables JAX to scale seamlessly from a single TPU or GPU to thousands of chips.

  • XLA (Accelerated Linear Algebra) – A domain‑specific, hardware‑agnostic compiler that delivers strong out‑of‑the‑box performance by using whole‑program analysis to fuse operators and optimize memory layouts.
  • Pathways – The unified runtime for massive‑scale distributed computation, allowing researchers to code as if using a single powerful machine while Pathways orchestrates computation across tens of thousands of chips. (A small illustration of this single‑machine programming model follows this list.)
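
As a small illustration of that single‑machine programming model (using only standard jax.sharding APIs, not Pathways itself), the following sketch shards a batch across whatever devices are present and lets XLA insert the cross‑device communication:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange all visible devices (TPU, GPU, or CPU) into a 1-D mesh.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Shard the batch dimension across the "data" axis of the mesh.
n = jax.device_count()
x = jnp.arange(n * 4.0).reshape(n, 4)
x = jax.device_put(x, NamedSharding(mesh, P("data", None)))

@jax.jit
def normalize(batch):
    # One SPMD program runs on every device; the global mean/std
    # reductions become cross-device communication, inserted by XLA.
    return (batch - batch.mean()) / (batch.std() + 1e-6)

y = normalize(x)
print(y.sharding)  # the output stays sharded across the mesh
```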

Advanced Development for Peak Efficiency

To achieve the highest levels of hardware utilization, the ecosystem provides specialized tools that offer deeper control and higher efficiency.

  • Pallas & Tokamax – Extensions for writing custom kernels for TPUs and GPUs with precise control over memory hierarchy and parallelism; Tokamax supplies a curated library of state‑of‑the‑art kernels (e.g., FlashAttention). A minimal Pallas kernel sketch follows this list.
  • Qwix – A comprehensive, non‑intrusive quantization library that enables techniques like QLoRA or PTQ by intercepting JAX functions, requiring minimal or no changes to the original model code.
  • Grain – A performant, deterministic data‑loading library that integrates with Orbax to checkpoint the exact state of the data pipeline alongside the model, guaranteeing bit‑for‑bit reproducibility after restarts.
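
As a taste of this level of control, here is a minimal Pallas kernel that adds two arrays, a sketch rather than a production kernel; the interpret=True flag emulates the kernel in pure JAX so it also runs on CPU, and dropping it compiles the kernel for TPU or GPU:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Refs point at on-chip memory blocks; the kernel reads, computes, writes.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # emulate in pure JAX; remove to compile for TPU/GPU
    )(x, y)

x = jnp.arange(8.0)
print(add(x, x))  # [ 0.  2.  4.  6.  8. 10. 12. 14.]
```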

The Full Path to Production

Other modules augment the JAX AI Stack with a mature, end‑to‑end application layer that bridges research and deployment.

  • MaxText & MaxDiffusion – Flagship, scalable frameworks for LLM and diffusion model training, optimized for goodput and Model FLOPs Utilization (MFU) out of the box.
  • Tunix – A JAX‑native library for post‑training alignment, offering state‑of‑the‑art algorithms such as SFT with LoRA/QLoRA, GRPO, GSPO, DPO, and PPO. Its MaxText integration delivers highly performant, scalable post‑training for Google Cloud customers.
  • Inference Solutions – For maximum compatibility, we support the popular vLLM serving framework for any model, as sketched below.
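
As an illustrative sketch (the checkpoint name below is a placeholder, not a recommendation), offline serving with vLLM takes only a few lines:

```python
from vllm import LLM, SamplingParams

# Placeholder checkpoint; substitute any model vLLM supports on your hardware.
llm = LLM(model="my-org/my-model")
params = SamplingParams(temperature=0.7, max_tokens=64)

outputs = llm.generate(["Explain JAX in one sentence."], params)
print(outputs[0].outputs[0].text)
```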

Read the Report, Explore the Stack

The JAX AI Stack is more than just a collection of libraries; it is a modular, production‑ready platform, co‑designed with Cloud TPUs to tackle the next generation of AI challenges. This deep integration of software and hardware delivers a compelling advantage in both performance and total cost of ownership, as seen across a diverse range of applications:

  • Kakao leveraged the stack to overcome infrastructure limits, achieving a 2.7× throughput increase for their LLMs while optimizing cost‑performance.
  • Lightricks broke through a scaling wall with a 13‑billion‑parameter video diffusion model, unlocking linear scalability and accelerating research.
  • Escalante combined a dozen models into a single optimization for AI‑driven protein design, achieving 3.65× better performance per dollar.

We invite you to explore the ecosystem, read the technical report, and get started at the new central hub: https://jaxstack.ai.
