[Paper] A Machine Learning Approach Towards Runtime Optimisation of Matrix Multiplication
Source: arXiv - 2601.09114v1
Overview
The paper proposes a machine‑learning‑driven technique for automatically picking the best thread count when running the General Matrix Multiplication (GEMM) routine on modern multi‑core CPUs. By training a lightweight model on the fly, the authors achieve 25–40 % speed‑ups on two contemporary HPC nodes compared with the default thread‑selection heuristics used in popular BLAS libraries.
Key Contributions
- ADSALA library prototype – an “Architecture and Data‑Structure Aware Linear Algebra” framework that plugs an ML model into BLAS calls.
- On‑the‑fly model training – gathers runtime characteristics (matrix sizes, cache usage, core topology) during early GEMM executions and continuously refines a regression model.
- Thread‑count prediction – the model outputs the optimal number of OpenMP threads for each GEMM invocation, superseding static or naïve heuristics.
- Cross‑architecture validation – experiments on Intel Cascade Lake (2 × 18 cores) and AMD Zen 3 (2 × 16 cores) demonstrate consistent gains for matrix workloads under ~100 MiB.
- Open‑source proof‑of‑concept – the authors release the code and training data, enabling reproducibility and further extension.
Methodology
- Feature extraction – For each GEMM call the library records:
  - Matrix dimensions (M, N, K)
  - Estimated working set size (bytes)
  - CPU topology (socket, core count, hyper‑threading state)
  - Current system load (optional)
- Model choice – A simple decision‑tree regressor (or gradient‑boosted tree) is trained to map these features to the runtime observed for each thread count tested during a short “exploration phase”.
- Exploration phase – The first few GEMM executions on a new problem size are run with a range of thread counts (e.g., one thread, half the cores, all cores); their runtimes populate the training set.
- Prediction & deployment – Once the model reaches a predefined confidence threshold, it predicts the optimal thread count for each subsequent GEMM call, which is then executed with that setting (the explore‑train‑predict loop is sketched below).
- Continuous adaptation – If a new matrix size or system state deviates significantly from the training distribution, the library falls back to a brief re‑exploration to update the model.
The whole pipeline adds only a few milliseconds of overhead, which is negligible compared with the multi‑second GEMM kernels targeted.
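To make the data flow concrete, here is a minimal sketch of the explore‑train‑predict loop in Python with scikit‑learn and threadpoolctl. It is not the ADSALA implementation: the candidate thread counts, the feature set, and helpers such as `features` and `timed_gemm` are illustrative assumptions based on the description above.

```python
# Minimal sketch of an explore-train-predict loop for GEMM thread selection.
# NOT the ADSALA code: thread-count candidates, feature set, and helpers
# are illustrative assumptions based on the paper's description.
import time

import numpy as np
from sklearn.tree import DecisionTreeRegressor
from threadpoolctl import threadpool_limits

MAX_THREADS = 36                                   # e.g., a 2 x 18-core node
CANDIDATES = [1, MAX_THREADS // 2, MAX_THREADS]    # exploration thread counts


def features(m, n, k, threads, dtype_bytes=8):
    """Per-call features: dimensions, working-set estimate (A+B+C), thread count."""
    working_set = (m * k + k * n + m * n) * dtype_bytes
    return [m, n, k, working_set, threads]


def timed_gemm(a, b, threads):
    """Run one GEMM with the BLAS thread pool capped and return its runtime."""
    with threadpool_limits(limits=threads, user_api="blas"):
        start = time.perf_counter()
        _ = a @ b
        return time.perf_counter() - start


# Exploration phase: try a few thread counts for each new problem size.
X, y = [], []
for m, n, k in [(2048, 2048, 2048), (4096, 1024, 4096)]:   # example sizes
    a, b = np.random.rand(m, k), np.random.rand(k, n)
    for t in CANDIDATES:
        X.append(features(m, n, k, t))
        y.append(timed_gemm(a, b, t))

# Model training: regress observed runtime from the features.
model = DecisionTreeRegressor(max_depth=6).fit(X, y)


def predicted_best_threads(m, n, k):
    """Pick the thread count with the lowest predicted runtime."""
    candidates = list(range(1, MAX_THREADS + 1))
    preds = model.predict([features(m, n, k, t) for t in candidates])
    return candidates[int(np.argmin(preds))]


# Prediction & deployment: run subsequent GEMMs with the predicted setting.
m, n, k = 3000, 3000, 3000
best = predicted_best_threads(m, n, k)
a, b = np.random.rand(m, k), np.random.rand(k, n)
print(f"predicted best thread count: {best}, runtime: {timed_gemm(a, b, best):.3f} s")
```

A decision tree is used here for the same reason the paper gives: training and inference are cheap enough to run alongside every GEMM call.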
Results & Findings
| Architecture | Matrix footprint | Baseline BLAS (static threads) | ADSALA (ML‑selected threads) | Speed‑up |
|---|---|---|---|---|
| Intel Cascade Lake (2 × 18 cores) | 20 MiB – 100 MiB | 1.0× (reference) | 1.25× – 1.40× | 25 % – 40 % |
| AMD Zen 3 (2 × 16 cores) | 20 MiB – 100 MiB | 1.0× (reference) | 1.28× – 1.38× | 28 % – 38 % |
- The optimal thread count often differed from the naïve “use all cores” rule; for many medium‑sized matrices, using only half the cores reduced contention on shared caches and memory bandwidth.
- Prediction accuracy (selected thread count vs. brute‑force optimum) was > 90 % after the initial exploration phase.
- Overhead of model training and inference was < 2 % of total runtime for the tested workloads.
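For orientation on the footprint column: assuming double precision and counting all three operands A, B, and C (an assumption for illustration, not a figure quoted from the paper), 100 MiB corresponds to square matrices of roughly n ≈ 2090.

```python
# Working-set footprint of a square double-precision GEMM:
# A, B, and C each occupy n*n*8 bytes.
n = 2090
footprint_mib = 3 * n * n * 8 / 2**20
print(f"n = {n}: {footprint_mib:.1f} MiB")   # ~100 MiB
```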
Practical Implications
- Performance‑critical libraries – Developers of scientific Python (NumPy, SciPy), machine‑learning frameworks (TensorFlow, PyTorch), or custom HPC kernels can embed the ADSALA approach to get automatic, architecture‑aware thread tuning without manual benchmarking.
- Cloud & container environments – In virtualized or containerized settings where the number of physical cores visible to a process can change at runtime, a self‑tuning BLAS can adapt on the fly, delivering consistent throughput.
- Energy efficiency – Running fewer threads when they would otherwise compete for memory bandwidth reduces power draw, an attractive side‑effect for green‑computing initiatives.
- Ease of deployment – Because the model is lightweight (decision tree) and trained online, there is no need for large offline autotuning databases; the library can be shipped as a drop‑in replacement for existing BLAS calls.
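As an illustration of the drop‑in idea (again a hedged sketch, not the ADSALA API), a Python application could route matrix products through a thin wrapper that caps the BLAS thread pool per call, reusing a predictor like `predicted_best_threads` from the methodology sketch above.

```python
import numpy as np
from threadpoolctl import threadpool_limits


def tuned_matmul(a, b, predict_threads):
    """Drop-in matmul wrapper: cap the BLAS thread pool at the predicted count."""
    m, k = a.shape
    n = b.shape[1]
    with threadpool_limits(limits=predict_threads(m, n, k), user_api="blas"):
        return a @ b


# Usage (with the predictor from the methodology sketch):
#   c = tuned_matmul(a, b, predicted_best_threads)
```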
Limitations & Future Work
- Scope limited to GEMM ≤ 100 MiB – Larger matrices that saturate memory bandwidth may exhibit different scaling behavior; extending the approach to the full GEMM spectrum is pending.
- Single‑node focus – Multi‑node distributed GEMM (e.g., ScaLAPACK) introduces network latency and process placement factors not covered here.
- Model simplicity – While decision trees are fast, more expressive models (e.g., neural nets or Gaussian processes) could capture subtler interactions between cache hierarchies and thread affinity.
- Exploration overhead – The initial profiling runs add latency for the first few calls; future work aims to reuse historical data across applications or share models via a central repository.
Overall, the paper demonstrates that a modest amount of machine‑learning intelligence can bridge the gap between hand‑tuned BLAS libraries and the ever‑changing landscape of multi‑core CPU architectures, offering developers a practical path to faster linear‑algebra kernels.
Authors
- Yufan Xia
- Marco De La Pierre
- Amanda S. Barnard
- Giuseppe Maria Junior Barca
Paper Information
- arXiv ID: 2601.09114v1
- Categories: cs.DC, cs.LG
- Published: January 14, 2026