[Paper] Scaling State-Space Models on Multiple GPUs with Tensor Parallelism
Source: arXiv - 2602.21144v1
Overview
Selective state‑space models (SSMs) have emerged as a powerful alternative to traditional Transformers for handling very long contexts in large language models (LLMs). The paper “Scaling State‑Space Models on Multiple GPUs with Tensor Parallelism” tackles a practical bottleneck: running SSM‑based LLMs efficiently when a single GPU runs out of memory or bandwidth. By adapting tensor parallelism (TP)—a technique already popular for scaling Transformers—the authors demonstrate how to spread the heavy SSM computations across multiple GPUs while keeping the critical recurrent path fast and communication‑light.
Key Contributions
- TP‑aware SSM design that partitions the large projection matrices but keeps the per‑token recurrent state updates local to each GPU.
- State‑cache mechanism that reuses the SSM hidden state from the prefill phase during subsequent decoding, reducing time‑to‑first‑token (TTFT).
- Quantized All‑Reduce for the TP aggregation step, cutting synchronization traffic and yielding up to ~18 % additional throughput without sacrificing numerical stability.
- Comprehensive evaluation on three real‑world SSM‑based LLM families (Mamba, Falcon‑Mamba, Zamba) across NVIDIA A6000 and A100 clusters, showing 1.6–4.0× throughput gains when scaling from 1 to 4 GPUs.
- Open‑source implementation (or detailed pseudo‑code) that can be plugged into existing inference stacks with minimal changes.
Methodology
- Tensor Partitioning – The authors treat the SSM mixer’s packed weight tensor (which contains both the long‑range projection and the local mixing kernels) as a single “big matrix”. They split this matrix along the feature dimension across GPUs, similar to classic TP for Transformers.
- Local Recurrence – Unlike a Transformer’s self‑attention, an SSM updates a hidden state sequentially for each token. The design ensures that each GPU maintains its own slice of the hidden state, so the recurrent update does not require cross‑GPU communication on the critical path.
- State Cache Across Prefill & Decode – During the initial “prefill” (processing a long prompt) the hidden state is cached. When the model switches to token‑by‑token decoding, the cached state is reused, avoiding recomputation of the expensive projection for already‑processed tokens.
- Quantized All‑Reduce – After each token, the partial results from all GPUs must be summed (All‑Reduce) to form the final hidden representation. The authors quantize the tensors to 8‑bit before the reduction and de‑quantize afterward, dramatically reducing the amount of data moved over the interconnect.
- Benchmark Suite – They run inference on three model families, varying context lengths (2 K–64 K tokens) and batch sizes, measuring both raw throughput (tokens/s) and end‑to‑end request latency.
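The first two points — splitting the packed weights along the feature dimension while each GPU keeps its own slice of the recurrent state — can be illustrated with a toy diagonal SSM. This is a NumPy sketch under simplifying assumptions (a diagonal decay recurrence, "GPUs" simulated as array slices), not the paper's implementation; the key property it demonstrates is that the per‑token recurrence is feature‑wise independent, so the sharded scan matches the full one with no cross‑device communication on the critical path.

```python
# Sketch: feature-dimension tensor-parallel split of a toy diagonal SSM.
# Each simulated "GPU" owns a slice of the projection weights and of the
# recurrent hidden state; the sequential scan never crosses slices.
import numpy as np

rng = np.random.default_rng(0)
d_model, d_state, n_gpus, seq_len = 8, 8, 2, 5  # toy sizes

W_in = rng.standard_normal((d_model, d_state))  # packed projection matrix
a = rng.uniform(0.5, 0.9, d_state)              # diagonal per-feature decay
x = rng.standard_normal((seq_len, d_model))     # input token embeddings

def run_ssm(W, a_vec):
    """Sequential scan h_t = a * h_{t-1} + W^T x_t over one feature slice."""
    h = np.zeros(W.shape[1])
    outs = []
    for t in range(seq_len):
        h = a_vec * h + x[t] @ W  # local update: no other slice needed
        outs.append(h.copy())
    return np.stack(outs)

# Reference: single-device run over the full feature dimension.
ref = run_ssm(W_in, a)

# Tensor-parallel: split W_in and a along the feature dimension; each
# shard carries its own hidden-state slice through the whole scan.
shards = [run_ssm(W, av) for W, av in
          zip(np.split(W_in, n_gpus, axis=1), np.split(a, n_gpus))]
tp = np.concatenate(shards, axis=1)

assert np.allclose(ref, tp)  # sharded recurrence matches the full one
```

Because the recurrence is elementwise in the feature dimension, the split introduces no approximation; only the final output projection would require an aggregation step across GPUs.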
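The state‑cache idea can be sketched the same way: after prefill, the final hidden state summarizes the whole prompt, so decoding can resume from it instead of re‑scanning. Again a toy NumPy sketch (same assumed diagonal recurrence, not the authors' code):

```python
# Sketch: caching the SSM hidden state after prefill so each decode step
# does O(1) work instead of re-scanning the prompt.
import numpy as np

rng = np.random.default_rng(1)
d = 16
a = rng.uniform(0.5, 0.9, d)            # toy per-feature decay
prompt = rng.standard_normal((32, d))   # prefill tokens
new_tok = rng.standard_normal(d)        # first token of the decode phase

def scan(xs, h0=None):
    """Run the recurrence h_t = a * h_{t-1} + x_t from an optional state."""
    h = np.zeros(d) if h0 is None else h0
    for x_t in xs:
        h = a * h + x_t
    return h

# Prefill once and cache the final state.
state_cache = scan(prompt)

# Decode step reuses the cache: one recurrence update per new token...
h_cached = scan([new_tok], h0=state_cache)

# ...instead of re-processing the entire prompt plus the new token.
h_rescan = scan(np.vstack([prompt, new_tok]))

assert np.allclose(h_cached, h_rescan)  # identical result, far less work
```

This is what makes the fixed‑size SSM state attractive versus a growing KV cache: the carry‑over between prefill and decode is a single vector per layer, regardless of prompt length.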
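Finally, the quantized All‑Reduce can be simulated in a few lines. This sketch assumes symmetric per‑tensor int8 quantization with a float scale; the paper's exact scheme may differ. Each simulated GPU quantizes its partial result before "transmission" (4× less data than fp32), and the sum is formed from the dequantized payloads:

```python
# Sketch: simulated quantized all-reduce. Each rank sends an int8 payload
# plus one fp32 scale instead of a full fp32 tensor, then the receiver
# dequantizes and sums.
import numpy as np

rng = np.random.default_rng(2)
partials = [rng.standard_normal(1024).astype(np.float32) for _ in range(4)]

def quantize(x):
    """Symmetric per-tensor int8 quantization: x ~ q * scale."""
    scale = max(np.abs(x).max(), 1e-8) / 127.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

# "Transmit" (int8 payload, fp32 scale) from each rank, then reduce.
received = [quantize(p) for p in partials]
reduced_q = sum(dequantize(q, s) for q, s in received)

# Compare against the exact fp32 all-reduce.
reduced_exact = sum(partials)
rel_err = (np.abs(reduced_q - reduced_exact).max()
           / np.abs(reduced_exact).max())
assert rel_err < 0.05  # quantization error stays small relative to the sum
```

In a real deployment the same quantize/dequantize pair would wrap the collective call (e.g., an NCCL all‑reduce), trading a small, bounded rounding error for a large cut in interconnect traffic.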
Results & Findings
| # GPUs | Model | Context (tokens) | Throughput ↑ vs. 1‑GPU | Quantized All‑Reduce gain |
|---|---|---|---|---|
| 2 | Mamba | 8 K | 1.6–2.1× | +10 % |
| 4 | Mamba | 32 K | 2.6–4.0× | +18 % |
| 2‑4 | Falcon‑Mamba / Zamba | 4 K–64 K | Similar scaling trends | Consistent gains |
- Long‑context advantage: The speed‑up grows with context length because the SSM’s per‑token cost dominates, and the TP split reduces per‑GPU memory pressure.
- TTFT benefit: Caching the state across prefill and decode cuts the prefill latency by ~30 % on average.
- Communication efficiency: Quantized All‑Reduce lowers PCIe/NVLink traffic, making the approach viable even on clusters with modest interconnect bandwidth.
Practical Implications
- Deployers can now run SSM‑based LLMs with 32 K+ context on 2‑4 GPU nodes without hitting memory limits, opening up use‑cases like document‑level QA, code‑base search, or long‑form generation.
- Cost‑effective scaling: Instead of buying a single massive GPU (e.g., H100 80 GB), teams can stitch together more affordable A6000/A100 cards and still achieve near‑linear speed‑up.
- Framework integration: The design maps cleanly onto existing TP libraries (e.g., Megatron‑LM, DeepSpeed), meaning developers can add SSM support with a few configuration changes rather than a full rewrite.
- Lower latency for real‑time apps: The TTFT cache reduces the “warm‑up” penalty when switching from prompt ingestion to token‑by‑token generation, which is crucial for chat‑style assistants.
- Quantized communication provides a template for other memory‑heavy models (e.g., retrieval‑augmented Transformers) where All‑Reduce becomes a bottleneck.
Limitations & Future Work
- Hardware dependence: The biggest gains are observed on GPUs with high‑speed NVLink; on slower interconnects the quantized All‑Reduce may still be a limiting factor.
- Model‑specific tuning: The partitioning strategy assumes a certain shape of the SSM mixer; models with irregular or sparsified mixers might need custom slicing logic.
- Precision trade‑offs: While 8‑bit quantization works for the reduction step, the authors note a small (<0.2 BLEU) degradation on some downstream tasks—future work could explore mixed‑precision or adaptive quantization.
- Extending beyond inference: The paper focuses on inference; applying the same TP scheme to training (especially with gradient accumulation) remains an open challenge.
Bottom line: By marrying tensor parallelism with clever state caching and quantized communication, this work makes large‑scale, long‑context SSM inference practical on today’s multi‑GPU clusters—an advance that could accelerate the next wave of LLM‑powered applications.
Authors
- Anurag Dutt
- Nimit Shah
- Hazem Masarani
- Anshul Gandhi
Paper Information
- arXiv ID: 2602.21144v1
- Categories: cs.DC, cs.LG
- Published: February 24, 2026