[Paper] Flash-KMeans: Fast and Memory-Efficient Exact K-Means
Source: arXiv - 2603.09229v1
Overview
The paper revisits the classic k‑means clustering algorithm and shows how to turn it from an offline, batch‑only tool into a fast, memory‑efficient primitive that can run directly inside modern AI pipelines on GPUs. By redesigning the core kernels to avoid costly memory traffic and atomic‑write contention, the authors achieve order‑of‑magnitude speedups on NVIDIA H200 GPUs, making exact k‑means viable for online services and large‑scale data‑centric workloads.
Key Contributions
- FlashAssign kernel – computes distances and selects the nearest centroid on‑the‑fly, eliminating the need to materialize the full N × K distance matrix in high‑bandwidth memory.
- Sort‑inverse update – builds an inverse mapping of point‑to‑centroid assignments and replaces high‑contention atomic scatters with contiguous, segment‑level reductions.
- System‑level co‑design – introduces chunked‑stream overlapping, cache‑aware compile heuristics, and other GPU‑friendly tricks to keep the kernels saturated.
- Comprehensive evaluation – demonstrates up to 17.9× end‑to‑end speedup over the best existing GPU implementations and 33× / 200× improvements versus cuML and FAISS, respectively.
- Open‑source implementation – the authors release the Flash‑KMeans codebase, enabling easy integration into existing PyTorch/TensorFlow pipelines.
Methodology
1. Problem Identification – The authors profile state‑of‑the‑art GPU k‑means implementations (e.g., cuML, FAISS) and pinpoint two bottlenecks:
   - IO bottleneck: explicitly storing the N × K distance matrix in HBM overwhelms memory bandwidth.
   - Atomic contention: updating centroids via scatter‑style atomic adds leads to severe serialization when many points map to the same centroid.
2. Kernel Redesign
   - FlashAssign fuses the distance computation (essentially a matrix multiplication) with an online argmin reduction. Each thread processes a subset of points, computes their distances to all centroids, and immediately keeps only the smallest distance and its index, discarding the rest. This eliminates the intermediate matrix entirely.
   - The sort‑inverse update first sorts point indices by their assigned centroid, then builds an "inverse map" that groups points belonging to the same centroid. A segmented reduction over each group computes the new centroid sums without any atomic operations.
3. System Optimizations
   - Chunked‑stream overlap: data is streamed in tiles that fit in shared memory, allowing compute and memory transfers to overlap.
   - Cache‑aware compilation: heuristics select thread‑block sizes and memory‑tiling strategies that maximize L2 cache reuse for the distance calculations.
4. Evaluation Setup – Experiments run on NVIDIA H200 GPUs across dataset sizes from 1 M to 100 M points and cluster counts K = 16 … 1024. Baselines include cuML, FAISS, and a naïve CUDA implementation.
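The two kernel ideas can be illustrated on the CPU with a minimal NumPy sketch of one Lloyd iteration: chunked distance computation with a running argmin (so the full N × K matrix is never stored), and a sort‑then‑segmented‑reduce centroid update (no scatter adds). The chunk size and function names here are illustrative, not taken from the paper's codebase:

```python
import numpy as np

def assign_chunked(points, centroids, chunk=4096):
    """FlashAssign-style pass: compute distances one chunk of points at a
    time and keep only the running argmin, so the full N x K distance
    matrix is never materialized."""
    n = points.shape[0]
    labels = np.empty(n, dtype=np.int64)
    for start in range(0, n, chunk):
        block = points[start:start + chunk]                    # (c, d)
        # squared Euclidean distances for this chunk only: (c, K)
        d2 = ((block[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)
        labels[start:start + chunk] = d2.argmin(axis=1)
    return labels

def update_sorted(points, labels, k):
    """Sort-inverse-style update: sort points by assigned centroid, then
    reduce each contiguous segment -- no atomic scatter-adds needed."""
    order = np.argsort(labels, kind="stable")                  # inverse map
    sorted_pts = points[order]
    sorted_lbl = labels[order]
    # segment boundaries: first index of each centroid id in sorted order
    bounds = np.searchsorted(sorted_lbl, np.arange(k + 1))
    centroids = np.empty((k, points.shape[1]))
    for c in range(k):
        seg = sorted_pts[bounds[c]:bounds[c + 1]]
        centroids[c] = seg.mean(axis=0) if len(seg) else 0.0
    return centroids

rng = np.random.default_rng(0)
X = rng.normal(size=(10_000, 8))
C = X[rng.choice(len(X), size=16, replace=False)]
for _ in range(5):                                             # a few Lloyd iterations
    lab = assign_chunked(X, C)
    C = update_sorted(X, lab, k=16)
print(C.shape)
```

On a GPU the same structure maps to a fused kernel (per‑thread running argmin) and a device‑wide sort plus segmented reduction; the Python loop over segments stands in for that parallel reduction.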
Results & Findings
| Metric | Flash‑KMeans | cuML | FAISS | Naïve CUDA |
|---|---|---|---|---|
| End‑to‑end speedup (vs. best existing GPU baseline) | up to 17.9× | – | – | – |
| Speedup vs. cuML | 33× | 1× | – | – |
| Speedup vs. FAISS | >200× | – | 1× | – |
| Memory footprint (assignment step) | O(N) (no distance matrix) | O(N·K) | O(N·K) | O(N·K) |
| Atomic contention (centroid update) | None (segmented reduction) | High | High | High |
Key takeaways:
- IO elimination cuts memory traffic by >95 %, freeing bandwidth for other kernels.
- Contention‑free updates give deterministic performance even when K is small and many points share a centroid.
- The approach scales linearly with both N and K, preserving exact k‑means results (no approximation).
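A back‑of‑the‑envelope calculation (ours, assuming fp32 distances and the paper's largest evaluated configuration) shows why skipping the dense matrix matters:

```python
# Rough memory cost of materializing the full distance matrix in fp32
# at the largest evaluated configuration: 100 M points, K = 1024.
n, k, bytes_fp32 = 100_000_000, 1024, 4

dense_gb = n * k * bytes_fp32 / 1e9          # O(N*K): the full matrix
# The fused path keeps only a running best distance (fp32) and index
# (int64) per point -- this bookkeeping is our assumption, not the paper's.
fused_gb = n * (bytes_fp32 + 8) / 1e9        # O(N)

print(f"dense matrix: {dense_gb:.0f} GB, fused: {fused_gb:.1f} GB")
```

An H200's 141 GB of HBM could not even hold the roughly 410 GB dense matrix for this configuration, whereas the fused path needs on the order of a gigabyte.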
Practical Implications
- Online clustering services – Real‑time recommendation or anomaly‑detection pipelines can now afford exact k‑means on streaming data without a separate batch window.
- Embedding preprocessing – Large‑scale language‑model or vision embedding tables can be re‑clustered on‑the‑fly for quantization or cache partitioning, reducing latency in inference serving.
- GPU‑centric ML frameworks – Flash‑KMeans can be dropped into PyTorch/TensorFlow as a native `torch.kmeans`‑style op, enabling end‑to‑end GPU pipelines without CPU‑GPU data shuffles.
- Cost savings – By avoiding HBM over‑allocation and atomic stalls, cloud GPU instances can handle larger workloads per dollar, especially on newer Hopper‑based hardware.
Developers can call the library just like any other CUDA‑accelerated routine:
```python
import flashkmeans as fk

centroids, assignments = fk.kmeans(data, k=256, max_iters=20)
```
The API returns exact centroids and point‑to‑cluster assignments, ready for downstream tasks.
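For instance, the returned pair can drive vector quantization directly. A small self‑contained sketch, using NumPy stand‑ins for the library's outputs (the array shapes and random data are illustrative):

```python
import numpy as np

# Stand-ins for the (centroids, assignments) pair the API returns.
rng = np.random.default_rng(1)
data = rng.normal(size=(1_000, 16))
centroids = rng.normal(size=(256, 16))
assignments = ((data[:, None, :] - centroids[None, :, :]) ** 2).sum(-1).argmin(1)

# Vector quantization: replace each point by its assigned centroid and
# measure the reconstruction error.
quantized = centroids[assignments]
mse = float(((data - quantized) ** 2).mean())
print(quantized.shape, round(mse, 3))
```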
Limitations & Future Work
- Hardware specificity – The current optimizations target NVIDIA Hopper (H200) GPUs; performance gains on older architectures (e.g., Ampere, Turing) are less dramatic and may require retuning.
- Memory bound for extreme K – When K grows to millions, the per‑centroid reduction buffers can exceed shared memory, re‑introducing some contention.
- Distributed scaling – The paper focuses on a single‑GPU setting; extending the sort‑inverse update across multiple nodes (e.g., via NCCL) is left as future work.
- Dynamic data streams – While the kernels are fast, handling continuously arriving points without re‑running the full iteration loop is an open research direction.
The authors suggest exploring hierarchical clustering hybrids, adaptive chunk sizing for heterogeneous workloads, and integrating Flash‑KMeans into larger auto‑ML pipelines as next steps.
Authors
- Shuo Yang
- Haocheng Xi
- Yilong Zhao
- Muyang Li
- Xiaoze Fan
- Jintao Zhang
- Han Cai
- Yujun Lin
- Xiuyu Li
- Kurt Keutzer
- Song Han
- Chenfeng Xu
- Ion Stoica
Paper Information
- arXiv ID: 2603.09229v1
- Categories: cs.DC
- Published: March 10, 2026