[Paper] PRISM: Distribution-free Adaptive Computation of Matrix Functions for Accelerating Neural Network Training

Published: January 29, 2026 at 01:55 PM EST
Source: arXiv - 2601.22137v1

Overview

The paper introduces PRISM, a new framework that speeds up the computation of matrix functions (e.g., square roots, inverse roots, orthogonalization) that are heavily used in modern preconditioned optimizers such as Shampoo and Muon. By marrying adaptive polynomial approximation with lightweight randomized sketching, PRISM cuts the number of expensive matrix‑multiply iterations needed, delivering faster neural‑network training on GPUs without requiring any prior knowledge of the matrix spectrum.
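For orientation, the loop below is the kind of iteration PRISM is designed to accelerate. It is the textbook cubic Newton-Schulz iteration for orthogonalization, written in NumPy. This is a baseline sketch, not PRISM: the function name and the fixed iteration count are illustrative assumptions, and Muon's published implementation uses a tuned quintic polynomial rather than this cubic one.

```python
import numpy as np

def newton_schulz_orthogonalize(G, num_iters=30):
    """Textbook cubic Newton-Schulz iteration for the orthogonal polar factor of G.

    Returns (X, matmuls): X approximates U @ Vt from the thin SVD G = U S Vt,
    and matmuls counts the dense matrix products performed.
    """
    # Dividing by the Frobenius norm (an upper bound on the largest singular value)
    # puts every nonzero singular value in (0, 1], inside the convergence region (0, sqrt(3)).
    X = G / np.linalg.norm(G)
    matmuls = 0
    for _ in range(num_iters):
        gram = X.T @ X                  # 1 matmul
        X = 1.5 * X - 0.5 * (X @ gram)  # 1 matmul; maps each sigma to 1.5*sigma - 0.5*sigma**3
        matmuls += 2
    return X, matmuls

rng = np.random.default_rng(0)
G = rng.standard_normal((64, 32))
Q, count = newton_schulz_orthogonalize(G)
print(count, np.linalg.norm(Q.T @ Q - np.eye(32)))
```

Every step costs two dense matrix multiplications. The fixed polynomial 1.5σ - 0.5σ³ also grows small singular values by only about 1.5× per step. That slow early phase is where a polynomial tuned to the current spectrum can save steps.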

Key Contributions

  • Distribution-free adaptive approximation – PRISM builds a polynomial surrogate of the target matrix function on the fly from a cheap sketched least-squares fit. It therefore adapts to whatever spectrum the current matrix has, instead of assuming a particular eigenvalue distribution.
  • Randomized iterative sketching – At each iteration the fitting problem is solved on a low-dimensional sketch rather than on the full matrix. This keeps the cost of adaptation small while preserving accuracy.
  • Plug‑and‑play acceleration – The framework can be dropped into existing Newton‑Schulz‑style iterations for matrix square roots and orthogonalization without redesigning the underlying optimizer.
  • No spectral bounds required – Unlike prior methods, PRISM does not need pre‑computed eigenvalue or singular‑value estimates, eliminating a common source of hyper‑parameter tuning.
  • Empirical validation on real workloads – Integrated into Shampoo and Muon, PRISM yields measurable wall‑time reductions on large‑scale language‑model and vision‑model training runs.
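The summary does not spell out PRISM's exact sketching scheme. As a generic illustration of the randomized-sketching ingredient, the snippet below shows textbook Gaussian sketch-and-solve least squares. It compresses a tall problem with a random matrix S and solves the small problem instead of the full one. The function name `sketched_lstsq` and the sketch size are assumptions; this is not the paper's algorithm.

```python
import numpy as np

def sketched_lstsq(A, b, sketch_rows, rng):
    """Sketch-and-solve least squares: minimize ||S(Ax - b)|| instead of ||Ax - b||.

    S is a dense Gaussian sketch with sketch_rows << A.shape[0] rows, so the solve
    only touches a (sketch_rows x d) problem.
    """
    m = A.shape[0]
    S = rng.standard_normal((sketch_rows, m)) / np.sqrt(sketch_rows)
    x, *_ = np.linalg.lstsq(S @ A, S @ b, rcond=None)
    return x

rng = np.random.default_rng(0)
A = rng.standard_normal((5000, 6))
b = A @ np.arange(1.0, 7.0) + 0.1 * rng.standard_normal(5000)

x_full, *_ = np.linalg.lstsq(A, b, rcond=None)   # exact solution, for comparison
x_sk = sketched_lstsq(A, b, sketch_rows=200, rng=rng)
ratio = np.linalg.norm(A @ x_sk - b) / np.linalg.norm(A @ x_full - b)
print(f"residual ratio (sketched / exact): {ratio:.3f}")
```

A dense Gaussian S is used here only because it is the simplest sketch to write down. Forming S @ A this way actually costs more than solving the original problem directly. Practical codes use structured or sparse sketches so that the compression itself is cheap.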

Methodology

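  1. Iterative baseline – Many preconditioned optimizers compute a matrix function by repeatedly applying a Newton-Schulz-type update. The update converges quadratically once the iterate is close to the solution. Reaching that regime, however, can take many matrix-multiply steps when the spectrum is spread out: with the classic cubic Newton-Schulz polynomial for orthogonalization, a small singular value grows by only about 1.5× per step.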
  2. Polynomial surrogate – At iteration k, PRISM samples a small set of random vectors and forms a sketch of the current matrix $A_k$. It then solves a tiny least-squares problem to fit a low-degree polynomial $p_k(\lambda)$ that approximates the desired function $f(\lambda)$ (e.g., $f(\lambda) = \sqrt{\lambda}$) over the spectrum observed through the sketch.
  3. Adaptive degree selection – The algorithm monitors the residual of the sketch and automatically raises the polynomial degree only when needed, keeping the work minimal.
  4. Sketch‑based update – The polynomial surrogate is applied to the full matrix via a few additional matrix‑multiply passes (the same operations that GPUs excel at). Because the polynomial coefficients are already tuned to the current spectrum, the update converges in far fewer passes than the vanilla Newton‑Schulz loop.
  5. Integration – PRISM wraps around the existing optimizer’s matrix‑function routine; the rest of the training pipeline (loss, back‑prop, data loading) stays untouched.
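Steps 2 and 3 can be illustrated with a small NumPy sketch under assumptions of our own, not the paper's:

- the spectrum seen by the sketch is taken to be the Rayleigh-Ritz values of a Gaussian sketch;
- the surrogate is a plain monomial least-squares fit;
- the degree grows until the maximum relative error on those samples drops below a tolerance.

The function names `sketch_spectrum` and `fit_adaptive_poly` are hypothetical.

```python
import numpy as np

def sketch_spectrum(A, s, rng):
    """Rayleigh-Ritz eigenvalue estimates for symmetric A from an s-column Gaussian sketch.

    Each returned value lies in [lambda_min(A), lambda_max(A)] (Cauchy interlacing).
    """
    omega = rng.standard_normal((A.shape[0], s))
    Q, _ = np.linalg.qr(A @ omega)          # orthonormal basis for range(A @ omega)
    return np.linalg.eigvalsh(Q.T @ A @ Q)  # eigenvalues of the small s x s projection

def fit_adaptive_poly(samples, f, tol=1e-3, max_degree=8):
    """Fit p(lam) ~ f(lam) by least squares on the samples, raising the degree until the
    max relative error on the samples is below tol (or max_degree is reached).
    Assumes f is nonzero on the samples; coefficients are returned lowest degree first."""
    target = f(samples)
    for degree in range(1, max_degree + 1):
        V = np.vander(samples, degree + 1, increasing=True)  # columns 1, lam, ..., lam**degree
        coeffs, *_ = np.linalg.lstsq(V, target, rcond=None)
        rel_err = np.max(np.abs(V @ coeffs - target) / np.abs(target))
        if rel_err < tol:
            break
    return coeffs, degree, rel_err

rng = np.random.default_rng(0)
n = 200
U, _ = np.linalg.qr(rng.standard_normal((n, n)))
A = (U * np.linspace(0.1, 1.0, n)) @ U.T  # symmetric test matrix with spectrum in [0.1, 1]
lam = sketch_spectrum(A, s=16, rng=rng)
coeffs, degree, err = fit_adaptive_poly(lam, np.sqrt)
print(f"degree {degree}, max relative error on samples {err:.1e}")
```

Because the fit only sees the sampled eigenvalues, a low-degree polynomial can look accurate on the samples while missing parts of the spectrum the sketch did not resolve. This is one reason the sketch size matters.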

Results & Findings

| Experiment | Baseline | PRISM-augmented | Speed-up | Final quality |
|---|---|---|---|---|
| BERT-large pre-training (8 GPUs) | Shampoo | Shampoo + PRISM | ≈ 1.6× faster (wall time) | Same validation loss (±0.1 %) |
| ResNet-50 on ImageNet (16 GPUs) | Muon | Muon + PRISM | ≈ 1.4× faster (wall time) | Same validation loss |
| Synthetic large-matrix square root (10⁴ × 10⁴) | Newton-Schulz | PRISM-Newton-Schulz | ≈ 2.2× fewer matrix multiplications | Error ≤ 1e-6 |

Takeaway: PRISM consistently reduces the number of expensive matrix-multiply iterations while preserving numerical accuracy. On the two real training workloads, the 1.4-1.6× speed-ups correspond to about 29-38 % less wall time.
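A speed-up factor S and the fraction of wall time saved are related by 1 - 1/S. Applied to the two training rows of the table:

```latex
\text{wall-time saving} = 1 - \frac{1}{S},
\qquad
S = 1.6:\; 1 - \tfrac{1}{1.6} = 0.375 = 37.5\%,
\qquad
S = 1.4:\; 1 - \tfrac{1}{1.4} \approx 0.286 = 28.6\%
```

The 2.2× figure for the synthetic benchmark counts matrix multiplications rather than wall time, so it is not converted here.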

Practical Implications

  • Faster model iteration cycles – Teams can train larger models or experiment more quickly without buying extra hardware.
  • Lower compute cost – Because PRISM cuts the number of dense matrix multiplications in the optimizer step, less GPU time (and, in turn, less energy) goes into preconditioning. This is valuable for cost-sensitive cloud training.
  • Minimal-tuning integration – Developers can drop PRISM into existing codebases that already use Shampoo, Muon, or any Newton-Schulz-style matrix-function routine without hand-crafting spectral bounds. The main remaining setting is the sketch size (see Limitations & Future Work).
  • Broader applicability – Other algorithms that rely on matrix square roots, inverse roots, or orthogonalization (e.g., natural gradient, second-order methods, covariance estimation) are natural candidates for PRISM's sketch-based acceleration.
  • GPU-friendly design – Apart from the small sketched fit, the work consists of batched GEMMs (general matrix-matrix multiplies), which map directly onto cuBLAS and tensor-core hardware.
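The GEMM-only pattern can be illustrated with a hypothetical helper, `batched_matrix_poly`, that applies a fitted polynomial to a whole stack of matrices via Horner's rule. It uses NumPy for portability. On a GPU the same loop would run as batched matmuls in a framework such as PyTorch or JAX. This is a sketch under those assumptions, not PRISM's kernel.

```python
import numpy as np

def batched_matrix_poly(coeffs, A):
    """Evaluate p(A) = c0*I + c1*A + ... + cd*A**d for every matrix in a stack A of
    shape (batch, n, n) with Horner's rule. The only heavy operation is the batched
    product A @ P, which np.matmul broadcasts over the leading batch axis.
    Coefficients are given lowest degree first."""
    eye = np.eye(A.shape[-1])
    P = coeffs[-1] * np.broadcast_to(eye, A.shape)  # start from the leading coefficient
    for c in reversed(coeffs[:-1]):
        P = A @ P + c * eye                         # one batched GEMM per remaining coefficient
    return P

rng = np.random.default_rng(0)
B = rng.standard_normal((4, 8, 8))
A = B @ B.transpose(0, 2, 1) / 8  # a batch of four symmetric PSD 8 x 8 matrices
coeffs = [0.5, -1.0, 0.25]        # p(x) = 0.5 - x + 0.25 x^2, chosen arbitrarily
P = batched_matrix_poly(coeffs, A)
print(P.shape)                    # (4, 8, 8)
```

A degree-d polynomial costs d batched matrix products this way. Schemes such as Paterson-Stockmeyer can lower that count for higher degrees.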

Limitations & Future Work

  • Sketch size sensitivity – While the authors show robustness, the choice of sketch dimension trades off overhead vs. approximation quality; extremely ill‑conditioned matrices may still need larger sketches.
  • Memory overhead for very large models – Storing the additional sketch vectors can be non‑trivial when the matrix dimensions approach GPU memory limits.
  • Extension to other matrix functions – PRISM currently targets square-root-type functions (square roots, inverse roots, orthogonalization). Adapting it to more exotic matrix functions (e.g., the matrix logarithm) remains an open question.
  • Theoretical convergence guarantees – The paper provides empirical evidence of rapid convergence but a full worst‑case bound on iteration count under arbitrary spectra is left for future analysis.
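The sketch-size trade-off can be seen on a toy problem. The snippet below uses a generic randomized Rayleigh-Ritz estimate of the largest eigenvalue over nested Gaussian sketches of growing dimension s. This is our stand-in, not PRISM's own sketch. Because each larger sketch contains the smaller one, the estimate can only improve with s and never exceeds the true λ_max. Meanwhile, the storage for the n × s sketch grows linearly in s. The matrix size and spectrum are arbitrary assumptions.

```python
import numpy as np

def ritz_max(A, omega):
    """Largest Rayleigh-Ritz value of symmetric A on range(A @ omega).
    It can never exceed lambda_max(A)."""
    Q, _ = np.linalg.qr(A @ omega)
    return np.linalg.eigvalsh(Q.T @ A @ Q)[-1]

rng = np.random.default_rng(0)
n = 500
U, _ = np.linalg.qr(rng.standard_normal((n, n)))
A = (U / np.sqrt(np.arange(1, n + 1))) @ U.T  # slowly decaying spectrum 1/sqrt(k); lambda_max = 1
omega = rng.standard_normal((n, 64))

sizes = (2, 4, 8, 16, 32, 64)
estimates = []
for s in sizes:
    est = ritz_max(A, omega[:, :s])  # nested sketches: reuse the first s columns
    estimates.append(est)
    kib = n * s * 8 / 1024           # float64 storage for one n x s sketch
    print(f"s={s:2d}  lambda_max estimate={est:.6f}  sketch storage={kib:.1f} KiB")
```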

Future directions include automated sketch-size selection, integration with distributed training frameworks (e.g., DeepSpeed with ZeRO), and extending the adaptive polynomial idea to other second-order optimization primitives.

Authors

  • Shenghao Yang
  • Zhichao Wang
  • Oleg Balabanov
  • N. Benjamin Erichson
  • Michael W. Mahoney

Paper Information

  • arXiv ID: 2601.22137v1
  • Categories: cs.LG, cs.AI, math.NA, math.OC
  • Published: January 29, 2026