[Paper] PRISM: 분포 자유 적응형 행렬 함수 계산을 통한 신경망 학습 가속화
Source: arXiv - 2601.22137v1
Overview
논문에서는 PRISM이라는 새로운 프레임워크를 소개한다. 이 프레임워크는 Shampoo와 Muon과 같은 최신 사전조건 최적화기에서 많이 사용되는 행렬 함수(예: 제곱근, 역제곱근, 직교화)의 계산을 가속화한다. 적응형 다항식 근사와 가벼운 무작위 스케치 기법을 결합함으로써 PRISM은 비용이 많이 드는 행렬‑곱 연산 반복 횟수를 줄여, GPU에서 사전 지식 없이도 신경망 훈련을 더 빠르게 수행할 수 있다.
Key Contributions
- Distribution‑free adaptive approximation – PRISM은 대상 행렬 함수의 다항식 대리 모델을 실시간으로 구축하며, 저비용 스케치된 최소제곱 피팅만을 사용해 어떠한 스펙트럼 형태에도 적용 가능합니다.
- Randomized iterative sketching – 각 반복에서는 전체 문제의 저차원 스케치를 풀어, 정확도를 유지하면서 반복당 비용을 크게 줄입니다.
- Plug‑and‑play acceleration – 이 프레임워크는 기존의 Newton‑Schulz 방식 행렬 제곱근 및 직교화 반복에 그대로 삽입할 수 있어, 기본 최적화기를 재설계할 필요가 없습니다.
- No spectral bounds required – 기존 방법과 달리 PRISM은 사전 계산된 고유값이나 특이값 추정이 필요 없으며, 흔히 발생하는 하이퍼파라미터 튜닝 부담을 없앱니다.
- Empirical validation on real workloads – Shampoo와 Muon에 통합된 PRISM은 대규모 언어 모델 및 비전 모델 학습에서 실질적인 실행 시간 감소를 입증합니다.
방법론
-
Iterative baseline – 많은 사전조건 최적화기들은 뉴턴‑슐츠 유형 업데이트를 반복 적용하여 행렬 함수를 계산하는데, 이는 2차 수렴하지만 여전히 많은 행렬‑곱셈 단계가 필요합니다.
-
Polynomial surrogate – 반복 k에서 PRISM은 소수의 랜덤 벡터를 샘플링하고 현재 행렬 (A_k)의 스케치를 구성합니다. 그런 다음 아주 작은 최소제곱 문제를 풀어 저차 다항식 (p_k(\lambda))을 맞추는데, 이는 스케치의 관측된 스펙트럼에 대해 원하는 함수 (f(\lambda)) (예: (\sqrt{\lambda}))를 근사합니다.
-
Adaptive degree selection – 알고리즘은 스케치의 잔차를 모니터링하고 필요할 때만 자동으로 다항식 차수를 높여 작업량을 최소화합니다.
-
Sketch‑based update – 다항식 대리 모델은 몇 번의 추가 행렬‑곱셈 패스를 통해 전체 행렬에 적용됩니다(이는 GPU가 뛰어나게 수행하는 연산과 동일합니다). 다항식 계수가 이미 현재 스펙트럼에 맞춰져 있기 때문에, 업데이트는 일반적인 뉴턴‑슐츠 루프보다 훨씬 적은 패스에서 수렴합니다.
-
Integration – PRISM은 기존 최적화기의 행렬‑함수 루틴을 감싸며, 나머지 학습 파이프라인(손실, 역전파, 데이터 로딩)은 그대로 유지됩니다.
결과 및 발견
| 실험 | 기본 옵티마이저 | PRISM‑증강 옵티마이저 | 속도 향상 (실제 시간) | 최종 검증 손실 |
|---|---|---|---|---|
| BERT‑large 사전 학습 (8 GPU) | Shampoo | Shampoo + PRISM | ≈ 1.6× 빠름 | 동일 (±0.1 %) |
| ImageNet에서 ResNet‑50 (16 GPU) | Muon | Muon + PRISM | ≈ 1.4× 빠름 | 동일 |
| 합성 대형 행렬 제곱근 (10⁴ × 10⁴) | Newton‑Schulz | PRISM‑Newton‑Schulz | ≈ 2.2× 적은 곱셈 | 오류 ≤ 1e‑6 |
핵심: PRISM은 비용이 많이 드는 행렬 곱셈 반복 횟수를 지속적으로 줄이면서 수치 정확성을 유지하여 실제 학습 작업에서 30‑60 % 실시간 절감 효과를 제공합니다.
Practical Implications
- Faster model iteration cycles – 팀은 추가 하드웨어를 구입하지 않고도 더 큰 모델을 훈련하거나 더 빠르게 실험할 수 있습니다.
- Lower GPU utilization – PRISM이 조밀한 행렬 곱셈 수를 줄이기 때문에 GPU 메모리 대역폭과 전력 소비가 감소하며, 이는 비용에 민감한 클라우드 훈련에 유용합니다.
- Zero‑tuning integration – 개발자는 이미 Shampoo, Muon 또는 Newton‑Schulz 스타일 행렬‑함수 루틴을 사용하는 기존 코드베이스에 PRISM을 바로 적용할 수 있으며, 스펙트럼 경계를 직접 설계하거나 하이퍼파라미터를 조정할 필요가 없습니다.
- Broader applicability – 행렬 제곱근, 역제곱근, 또는 직교화에 의존하는 모든 알고리즘(예: 자연 그래디언트, 2차 방법, 공분산 추정)은 PRISM의 스케치 기반 가속의 혜택을 받을 수 있습니다.
- GPU‑friendly design – 모든 연산이 배치된 GEMM(일반 행렬‑행렬 곱)으로 표현되어 CUDA/cuBLAS 및 최신 텐서‑코어 파이프라인과 완벽하게 일치합니다.
제한 사항 및 향후 작업
- 스케치 크기 민감도 – 저자들은 견고성을 보여주지만, 스케치 차원의 선택은 오버헤드와 근사 품질 사이의 트레이드오프를 의미합니다; 매우 조건이 나쁜 행렬은 여전히 더 큰 스케치를 필요로 할 수 있습니다.
- 매우 큰 모델에 대한 메모리 오버헤드 – 행렬 차원이 GPU 메모리 한계에 근접할 때 추가 스케치 벡터를 저장하는 비용이 무시할 수 없을 정도로 커질 수 있습니다.
- 비정방형 함수에 대한 확장 – PRISM은 현재 제곱근형 함수에 초점을 맞추고 있으며, 행렬 로그와 같은 보다 이색적인 행렬 함수에 적용하려면 추가 연구가 필요합니다.
- 이론적 수렴 보장 – 논문은 빠른 수렴에 대한 실증적 증거를 제공하지만, 임의 스펙트럼 하에서의 최악 경우 반복 횟수에 대한 완전한 경계는 향후 분석 과제로 남겨두었습니다.
향후 방향에는 자동 스케치 크기 선택, 분산 학습 프레임워크(예: ZeRO, DeepSpeed)와의 통합, 그리고 적응형 다항식 아이디어를 다른 2차 최적화 원시 연산에 확장하는 것이 포함됩니다.
저자
- Shenghao Yang
- Zhichao Wang
- Oleg Balabanov
- N. Benjamin Erichson
- Michael W. Mahoney
논문 정보
- arXiv ID: 2601.22137v1
- 분류: cs.LG, cs.AI, math.NA, math.OC
- 출판일: 2026년 1월 29일
- PDF: Download PDF