在 Google Cloud TPUs 上使用 JAX 构建生产级 AI
Source: Google Developers Blog
架构必然性:模块化与性能
JAX AI Stack 基于模块化、松耦合组件的理念构建,每个库都专注于单一任务。这种方式使用户能够构建定制化的机器学习栈,挑选并组合最适合优化、数据加载或检查点管理的库,以精准满足其需求。关键是,这种模块化在快速演进的 AI 领域尤为重要。它允许快速创新,因为新库和新技术可以在不修改大型单体框架的风险和开销的前提下开发并集成。
现代机器学习栈必须提供抽象的连续体:面向开发速度的自动化高级优化,以及在每个微秒都至关重要时的细粒度手动控制。JAX AI Stack 旨在提供这一连续体。
核心 “JAX AI Stack”
在 JAX 生态系统的核心是由四个关键库组成的 “JAX AI Stack”,它们为模型开发提供基础,全部基于 JAX 与 XLA 的编译器优先设计。
- JAX – 面向加速器的数组计算基础。其纯函数式编程模型使得转换可组合,能够在不同硬件类型和集群规模上有效扩展工作负载。
- Flax – 提供灵活直观的模型编写和 “手术” API,在不牺牲性能的前提下,为 JAX 增加面向对象的层。
- Optax – 可组合的梯度处理与优化转换库。只需几行代码即可将标准优化器(如 Adam)与梯度裁剪或累积等技术链式组合。
- Orbax – 支持异步分布式检查点的 “任意规模” 检查点库,确保训练过程在硬件故障时仍能保持显著进度。
jax-ai-stack 元包的安装方式如下:
pip install jax-ai-stack

JAX AI Stack 与生态系统组件
扩展的 JAX AI Stack
在这一稳固核心之上,丰富的专用库生态提供了完整机器学习生命周期所需的端到端能力。
工业规模基础设施
在面向用户的库之下,是支撑 JAX 从单个 TPU/GPU 无缝扩展到数千个 GPU/TPU 的基础设施。
- XLA(Accelerated Linear Algebra)—— 面向领域的、硬件无关的编译器,通过全程序分析、算子融合和内存布局优化,提供开箱即用的强大性能。
- Pathways—— 用于大规模分布式计算的统一运行时,开发者可以像使用单台强大机器一样编写代码,而 Pathways 则在数万颗芯片上调度执行并内置容错机制。
为极致效率而生的高级开发
为实现最高硬件利用率,生态系统提供了专用工具,赋予更深层次的控制和更高的效率。
- Pallas 与 Tokamax—— 用于在 TPU 与 GPU 上编写自定义内核的扩展,能够精确控制内存层次结构和并行度。Tokamax 提供了精选的前沿内核库(如 FlashAttention),实现即插即用的性能提升。
- Qwix—— 完整且非侵入式的量化库,通过拦截 JAX 函数实现 QLoRA、PTQ 等技术,几乎不需要或根本不需要修改原始模型代码。
- Grain—— 高性能、确定性的数据加载库,可与 Orbax 集成,对数据管道的完整状态进行检查点保存,保证重启后位级可复现。
完整的生产路径
其他模块为 JAX AI Stack 增添了成熟的端到端应用层,弥合了从研究到部署的鸿沟。
- MaxText 与 MaxDiffusion—— 用于大语言模型和扩散模型训练的可扩展框架,提供经过优化的开箱即用基准,提升吞吐量和模型 FLOPs 利用率(MFU)。
- Tunix—— 原生 JAX 的后训练对齐库,提供最先进的算法,如监督微调(SFT)。