使用 JAX 在 Google Cloud TPU 上构建生产 AI

发布: (2025年12月1日 GMT+8 07:25)
8 min read

Source: Google Developers Blog

2025年11月19日Rakesh Iyer,Google机器学习框架高级软件工程经理

JAX 标志

JAX 已成为在整个 AI 领域开发最先进基础模型的关键框架,而且不仅限于 Google。Anthropic、xAI 和 Apple 等领先的大语言模型提供商正将开源的 JAX 框架作为构建其基础模型的工具之一。

今天,我们很高兴分享 JAX AI Stack 的概览——一个基于 JAX 的核心数值库,面向任何规模机器学习的工业级完整平台。

为了展示该生态系统的强大与设计理念,我们已发布了详细的 技术报告,阐述每个组件。我们鼓励开发者、研究者和基础设施工程师阅读完整报告,了解如何将这些工具用于您的特定需求。

下面,我们概述了构成现代 AI 稳健且灵活平台的架构哲学和关键组件。

架构要义:模块化与性能

JAX AI Stack 基于模块化、松耦合组件的哲学构建,每个库都专注于单一任务。这种方式使用户能够构建定制化的机器学习栈,挑选并组合最适合优化、数据加载或检查点的库,以精准匹配需求。模块化在快速演进的 AI 领域尤为关键,它允许新库和新技术在不修改庞大单体框架的风险和开销下快速研发并集成。

现代机器学习栈必须提供抽象的连续体:面向开发速度的自动化高级优化,以及在每个微秒都至关重要时的细粒度手动控制。JAX AI Stack 正是为此而设计。

核心 “JAX AI Stack”

JAX 生态的核心是由四个关键库组成的 “JAX AI Stack”,它们都基于 JAX 与 XLA 的编译器优先设计,为模型开发提供基础。

  • JAX – 面向加速器的数组计算基础。其纯函数式编程模型使得转换可组合,能够在不同硬件类型和集群规模上高效扩展。
  • Flax – 提供灵活直观的模型编写和 “手术” API,弥合 JAX 函数式核心与许多开发者偏好的面向对象之间的差距。
  • Optax – 可组合的梯度处理与优化转换库。研究者只需几行代码即可声明式地将标准优化器(如 Adam)与梯度裁剪、累积等技术链式组合。
  • Orbax – 支持异步分布式检查点的 “任意规模” 检查点库,确保昂贵的训练过程在硬件故障时仍能保留重要进度。

jax-ai-stack 元包的安装方式如下:

pip install jax-ai-stack

JAX 生态系统
JAX AI Stack 与生态系统组件

扩展的 JAX AI Stack

在这个稳固的核心之上,丰富的专用库生态为整个机器学习生命周期提供端到端能力。

工业规模基础设施

在面向用户的库之下,是支撑 JAX 从单个 TPU/GPU 无缝扩展到成千上万 GPU/TPU 的基础设施。

  • XLA(Accelerated Linear Algebra)——面向特定领域、硬件无关的编译器,通过全程序分析融合算子并优化内存布局,提供开箱即用的强大性能。
  • Pathways——统一的分布式计算运行时,使研究者能够像使用单台强大机器一样编写代码,而 Pathways 则在数万颗芯片上调度计算。

为极致效率而生的高级开发

为实现最高硬件利用率,生态系统提供了专用工具,提供更深层的控制和更高的效率。

  • PallasTokamax——用于在 TPU 和 GPU 上编写自定义内核的扩展,能够精确控制内存层次结构和并行度;Tokamax 提供了包括 FlashAttention 在内的前沿内核库。
  • Qwix——全面且非侵入式的量化库,通过拦截 JAX 函数实现 QLoRA、PTQ 等技术,几乎不需要或根本不需要修改原始模型代码。
  • Grain——高性能、确定性的数据加载库,可与 Orbax 集成,在模型检查点时同步保存数据管道的完整状态,保证重启后位级可复现。

完整的生产路径

其他模块为 JAX AI Stack 增添了成熟的端到端应用层,连接研究与部署。

  • MaxTextMaxDiffusion——面向 LLM 与扩散模型训练的旗舰可扩展框架,开箱即提供良好的吞吐量和模型 FLOPs 利用率(MFU)。
  • Tunix——原生 JAX 的后训练对齐库,提供 SFT + LoRA/QLoRA、GRPO、GSPO、DPO、PPO 等最前沿算法。MaxText 与 Tunix 的结合为 Google Cloud 客户提供了最高效、最可扩展的后训练方案。
  • 推理解决方案——为实现最大兼容性,我们提供流行的 vLLM serving 框架,支持任意模型。

阅读报告,探索栈

JAX AI Stack 不仅是库的集合,更是一个模块化、可投产的平台,已与 Cloud TPUs 共同设计,以应对下一代 AI 挑战。这种软硬件深度融合在性能和总体拥有成本上均提供了显著优势,已在多种应用中得到验证:

  • Kakao 利用该栈突破基础设施瓶颈,使其 LLM 的吞吐量提升 2.7 倍,并优化了性价比。
  • Lightricks 在 130 亿参数的视频扩散模型上突破扩展壁垒,实现线性可扩展并加速了研究进程。
  • Escalante 将十余个模型合并为单一优化流程用于 AI 蛋白质设计,达成 3.65 倍 的每美元性能提升。

我们诚邀您探索生态系统,阅读 技术报告,并在全新中心枢纽 https://jaxstack.ai 开始使用。

入门指南

Back to Blog

相关文章

阅读更多 »

宣布 Data Commons Gemini CLI 扩展

自从我们在十月初推出 Gemini CLI 扩展框架以来,我们已经看到 Google 自有和第三方贡献的扩展在 op... 中呈爆炸式增长。