[Paper] Compiler-First State Space Duality and Portable $O(1)$ Autoregressive Caching for Inference

Published: March 10, 2026 at 08:03 AM EDT

Source: arXiv - 2603.09555v1

Overview

Cosmo Santoni’s paper shows that the inference engine for Mamba‑2’s state‑space models (SSMs) can be built entirely with standard XLA primitives—no hand‑crafted CUDA or Triton kernels are required. By exploiting the model’s “state‑space duality” (diagonal state, chunkable recurrence, and einsum‑heavy compute), the author achieves the theoretical $O(1)$ cache‑friendly autoregressive decoding on CPUs, NVIDIA GPUs, and Google Cloud TPUs from a single JAX codebase.

Key Contributions

  • Compiler‑first design: Demonstrates that XLA’s fusion and tiling passes naturally optimise the Mamba‑2 recurrence, eliminating the need for custom kernels.
  • Portable $O(1)$ autoregressive cache: Implements an on‑device cache that updates state without any host‑CPU synchronization, preserving constant‑time per‑token cost.
  • Unified JAX implementation: A single source file runs unmodified on three very different hardware back‑ends (CPU, CUDA GPU, TPU v6e).
  • Performance parity with hand‑tuned CUDA: Achieves ~140 TFLOPS (≈15 % of peak) on pre‑fill and up to 64 % memory‑bandwidth utilisation on decode, matching PyTorch/CUDA token‑for‑token results within float‑32 rounding.
  • Open‑source release: Code is publicly available and merged into the Bonsai JAX model library, enabling immediate experimentation.

Methodology

  1. Identify structural invariants – The author isolates three properties of the Mamba‑2 recurrence that make it amenable to XLA:

    • Diagonal state matrix → simplifies per‑channel operations.
    • Chunkable recurrence → the recurrence can be broken into independent blocks that XLA can tile.
    • Einsum‑dominated compute → the bulk of work is expressed as tensor contractions, which XLA fuses aggressively.
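
The chunkable‑recurrence invariant can be illustrated with a short JAX sketch (illustrative code, not the paper's implementation): a diagonal recurrence h_t = a_t ⊙ h_{t−1} + u_t composes associatively, so `jax.lax.associative_scan` can evaluate it in parallel blocks while matching the sequential loop exactly.

```python
import jax
import jax.numpy as jnp

def diag_ssm_scan(a, u):
    """Parallel (chunkable) evaluation of h_t = a_t * h_{t-1} + u_t
    for a *diagonal* state matrix; a and u have shape (T, d)."""
    def combine(x, y):
        a1, u1 = x
        a2, u2 = y
        # Composing two recurrence steps is again one step:
        # a2*(a1*h + u1) + u2 = (a1*a2)*h + (a2*u1 + u2)
        return a1 * a2, a2 * u1 + u2
    _, h = jax.lax.associative_scan(combine, (a, u))
    return h  # h_t under h_0 = 0

T, d = 8, 4
a = jax.random.uniform(jax.random.PRNGKey(0), (T, d), minval=0.5, maxval=0.99)
u = jax.random.normal(jax.random.PRNGKey(1), (T, d))

h_par = diag_ssm_scan(a, u)

# Reference: the plain sequential recurrence.
h, hs = jnp.zeros(d), []
for t in range(T):
    h = a[t] * h + u[t]
    hs.append(h)
h_seq = jnp.stack(hs)
```

Because the combine operator is associative, XLA is free to tile the scan into independent chunks, which is exactly the property the author exploits.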
  2. Express the whole inference pipeline as XLA primitives – Prefill (processing the prompt) and autoregressive decode (generating one token at a time) are written using jax.lax.scan, jax.lax.dot_general, and jax.lax.reduce_window. No custom_call or external kernel is introduced.
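
As a minimal sketch of this step (shapes and names below are illustrative, not the paper's), the autoregressive decode loop can be written with `jax.lax.scan` carrying the SSM state, so the whole pipeline stays inside XLA's static control flow:

```python
import jax
import jax.numpy as jnp

def decode(h0, a_seq, b_seq, c_seq, x_seq):
    """Token-by-token decode as jax.lax.scan: the carry is the SSM state,
    each step is a diagonal update plus an einsum readout."""
    def step(h, inputs):
        a, b, c, x = inputs
        h = a * h + b * x              # diagonal state update, O(1) per token
        y = jnp.einsum('d,d->', c, h)  # readout as a tensor contraction
        return h, y
    return jax.lax.scan(step, h0, (a_seq, b_seq, c_seq, x_seq))

T, d = 5, 3
h0 = jnp.zeros(d)
a = jnp.full((T, d), 0.5)
b = jnp.ones((T, d))
c = jnp.ones((T, d))
x = jnp.arange(1.0, T + 1.0)
h_final, ys = decode(h0, a, b, c, x)
```

Since `scan` fixes the loop body at trace time, XLA can fuse the update and readout into a single on-device program with no per-token Python overhead.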

  3. Compile‑time cache construction – The on‑device cache is allocated as a static buffer that XLA treats as a mutable variable. During each decode step the cache is read, updated, and written back entirely on‑device, guaranteeing true $O(1)$ per‑token cost.
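
A minimal sketch of keeping such a cache on-device (illustrative names, not the released code): under `jax.jit`, donating the cache argument invites XLA to overwrite the old state buffer in place rather than allocating a new one. Note that some backends (e.g. CPU) may ignore donation and fall back to copying.

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnums=(0,))
def decode_step(cache, a, b, c, x):
    """One token: read, update, and re-emit the state cache on-device.
    Donating `cache` lets XLA reuse its buffer, keeping per-token
    memory traffic constant."""
    cache = a * cache + b * x
    y = jnp.dot(c, cache)
    return cache, y

d = 16
cache = jnp.zeros(d)
a = jnp.full(d, 0.9)
b = jnp.ones(d)
c = jnp.ones(d)
for x in [1.0, 2.0, 3.0]:
    cache, y = decode_step(cache, a, b, c, x)
```

The new cache returned by each step feeds the next call, so the state never leaves the accelerator and no host‑CPU synchronization is needed between tokens.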

  4. Cross‑platform validation – The same JAX program is compiled for three back‑ends:

    • CPU (LLVM XLA) – baseline correctness and memory‑footprint checks.
    • NVIDIA GPU (CUDA XLA) – comparison against the reference PyTorch implementation.
    • TPU v6e (TPU XLA) – the primary performance target.
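
The portability claim can be seen in miniature (a sketch, not the paper's validation harness): the same jitted function is lowered by whichever XLA backend hosts its input buffers, so iterating over `jax.devices()` exercises every available backend with zero code changes.

```python
import jax
import jax.numpy as jnp

@jax.jit
def prefill(x):
    # Stand-in for the real prefill: any einsum-heavy XLA program works.
    return jnp.tanh(x @ x.T)

x = jnp.ones((4, 4))
outputs = []
for dev in jax.devices():  # CPU here; GPU/TPU devices where available
    outputs.append(prefill(jax.device_put(x, dev)))
```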
  5. Benchmarking & correctness – The author measures FLOPs, memory‑bandwidth utilisation, and token‑level numerical agreement (float‑32 tolerance) across model sizes from 130 M to 2.7 B parameters.
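
The token‑level agreement metric can be checked with a small helper (an illustrative sketch; the paper's exact tolerance procedure is not spelled out here). For finite float32 values of the same sign, the gap between the integer bit patterns counts units in the last place (ULPs):

```python
import numpy as np

def max_ulp_diff(x, y):
    """Maximum units-in-the-last-place gap between two float32 arrays.
    Valid for finite values of the same sign, whose bit patterns are
    monotonically ordered."""
    xi = np.ascontiguousarray(x, dtype=np.float32).view(np.int32)
    yi = np.ascontiguousarray(y, dtype=np.float32).view(np.int32)
    return int(np.max(np.abs(xi - yi)))

ref = np.ones(8, dtype=np.float32)
# Perturb by exactly one representable float32 step:
probe = np.nextafter(ref, np.full(8, 2.0, dtype=np.float32))
```

Under this metric, the reported "hidden‑state diff ≤ 1 ULP" means every hidden‑state element is at most one representable float32 value away from the reference.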

Results & Findings

| Platform | Model Size | Prefill Throughput | Decode Bandwidth Utilisation | Numerical Agreement |
| --- | --- | --- | --- | --- |
| TPU v6e | 130 M – 2.7 B | ~140 TFLOPS (≈15 % MFU) | up to 64 % of peak | Token‑for‑token match, hidden‑state diff ≤ 1 ULP |
| NVIDIA GPU (CUDA) | 2.7 B | Same as reference PyTorch (within 2 % latency) | n/a | Identical outputs |
| CPU | 130 M | Baseline (no speed claim) | n/a | Correctness verified |

Key take‑aways:

  • XLA can fuse the entire SSM recurrence without any bespoke kernel, delivering performance close to hand‑tuned CUDA.
  • The on‑device cache truly removes host‑CPU stalls, confirming the theoretical $O(1)$ cost per generated token.
  • Portability is achieved with zero code changes, proving that the duality properties are hardware‑agnostic as long as an XLA backend exists.

Practical Implications

  • Simplified deployment pipelines – Teams can ship Mamba‑2 inference containers that run on any cloud provider (AWS, GCP, Azure) without worrying about GPU‑specific kernels.
  • Lower engineering overhead – No need for separate CUDA/Triton code paths; a single JAX source suffices for research prototypes and production services.
  • Cost‑effective scaling – Since TPUs can now run SSMs at high utilisation, inference‑heavy workloads (e.g., real‑time transcription, code completion) can be moved to cheaper TPU‑based instances.
  • Future‑proofing – As XLA expands to new accelerators (e.g., AMD GPUs, custom ASICs), the same implementation will automatically benefit, protecting investment in model code.
  • Easier integration with existing JAX ecosystems – The Bonsai JAX library already supports many transformer‑style models; adding Mamba‑2 becomes a matter of a single import.

Limitations & Future Work

  • Peak utilisation still modest – 15 % of theoretical FLOP capacity on prefill indicates room for further fusion or layout optimisations.
  • TPU‑centric evaluation – While GPU correctness is shown, the performance gains are demonstrated only on TPU v6e; broader GPU benchmarks would strengthen the claim.
  • Static control flow assumption – The approach relies on the recurrence being fully static; models with dynamic branching or variable‑length state may need additional handling.
  • Memory footprint on very large models – The on‑device cache grows linearly with state size; future work could explore hierarchical or compressed caching strategies.

Overall, the paper paves the way for truly portable, high‑performance SSM inference, turning what was once a niche CUDA‑only capability into a compiler‑driven, cross‑hardware primitive.

Authors

  • Cosmo Santoni

Paper Information

  • arXiv ID: 2603.09555v1
  • Categories: cs.LG, cs.AI, cs.DC, cs.PF
  • Published: March 10, 2026