[Paper] Multi-Token Prediction via Self-Distillation
Source: arXiv - 2602.06019v1
Overview
A team of researchers proposes a surprisingly simple way to turn any pretrained autoregressive language model (e.g., GPT‑2, LLaMA) into a multi‑token predictor that runs over three times faster at inference time. Instead of adding a separate “speculator” model or building a complex decoding pipeline, they train the original model to predict several future tokens in one forward pass using an online self‑distillation objective. The result is a drop of less than 5 % in relative accuracy on a math‑reasoning benchmark (GSM8K) alongside a substantial speed boost, all without changing the model’s architecture or deployment code.
Key Contributions
- Self‑distillation for multi‑token prediction: Introduces an online distillation loss that teaches a pretrained model to output a short sequence of tokens rather than a single next token.
- Zero‑change deployment: The final model uses the exact same checkpoint and inference code as the original single‑token model—no auxiliary verifier, speculator, or custom runtime needed.
- Empirical speed‑accuracy trade‑off: Demonstrates > 3× decoding speed on GSM8K with < 5 % relative accuracy loss, closing the gap between speculative decoding and naïve single‑token generation.
- Broad applicability: Works with any autoregressive LM regardless of size or pretraining data, making it a drop‑in upgrade for existing services.
Methodology
- Baseline model – Start with a frozen pretrained autoregressive LM (e.g., a decoder‑only transformer).
- Online teacher – During training, the same model runs in its usual single‑token mode to generate “teacher” predictions for the next k tokens (k is a small integer like 4 or 8).
- Student head – A lightweight additional head is attached to the final hidden state, trained to directly output the k tokens in one shot.
- Distillation loss – The student’s logits are penalized for deviating from the teacher’s logits (cross‑entropy) across all k positions, while also preserving the original language modeling loss for the first token.
- Curriculum – The value of k is gradually increased as training progresses, allowing the model to adapt to longer horizons without destabilizing.
- Inference – At runtime, the model simply calls the new head to emit k tokens, then shifts the context forward and repeats, eliminating the need for a separate verifier that checks whether the multi‑token guess was correct.
Because the teacher and student are the same network, the process is self‑distilling and can be performed online (no separate teacher model or dataset generation step).
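The steps above can be sketched as one training iteration. This is a minimal toy in PyTorch, not the paper's implementation: the embedding "backbone", the linear `k`-token head, and the loss weighting (and using the teacher's greedy token as the first-position hard target, where the paper would use real data) are all illustrative assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
VOCAB, HIDDEN, K = 100, 32, 4

# Toy stand-ins for the pretrained LM and its heads.
backbone = torch.nn.Embedding(VOCAB, HIDDEN)     # "transformer" placeholder
lm_head = torch.nn.Linear(HIDDEN, VOCAB)         # original next-token head
multi_head = torch.nn.Linear(HIDDEN, K * VOCAB)  # lightweight k-token student head

def teacher_rollout(h_last):
    """Greedy single-token decoding for K steps (no grad) -> teacher logits."""
    logits_per_step = []
    with torch.no_grad():
        h = h_last
        for _ in range(K):
            step_logits = lm_head(h)
            logits_per_step.append(step_logits)
            nxt = step_logits.argmax(-1)
            h = backbone(nxt)  # toy recurrence standing in for re-running the LM
    return torch.stack(logits_per_step, dim=1)  # (batch, K, vocab)

prompt = torch.randint(0, VOCAB, (2, 5))
h_last = backbone(prompt[:, -1])  # final hidden state of the prompt

teacher_logits = teacher_rollout(h_last)               # teacher: usual 1-token mode
student_logits = multi_head(h_last).view(2, K, VOCAB)  # student: K tokens in one shot

# Distillation: soft cross-entropy against the teacher at all K positions,
# plus an LM-style loss on the first token (hard target taken from the teacher
# here; the paper preserves the original language modeling loss instead).
soft_targets = teacher_logits.softmax(-1)
distill = -(soft_targets * student_logits.log_softmax(-1)).sum(-1).mean()
lm_loss = F.cross_entropy(student_logits[:, 0], teacher_logits[:, 0].argmax(-1))
loss = distill + lm_loss
loss.backward()
```

Because the teacher rollout runs under `no_grad`, gradients flow only through the student head and the shared backbone, which is what makes the online self-distillation cheap.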
Results & Findings
| Metric | Single‑token decoding | Multi‑token (k=4) | Multi‑token (k=8) |
|---|---|---|---|
| GSM8K accuracy (relative) | 100 % | 96 % | 93 % |
| Decoding throughput (relative tokens/s) | 1× (baseline) | 2.8× | 3.2× |
| Latency reduction | – | 65 % | 70 % |
- Speed: Throughput grows with k but with diminishing returns (2.8× at k = 4 vs 3.2× at k = 8); beyond k = 8 the accuracy drop becomes more pronounced.
- Quality: The modest accuracy loss is largely due to occasional “drift” where early token errors propagate through the multi‑token block, but most errors are recoverable in subsequent blocks.
- Compatibility: The same approach was tested on GPT‑2‑medium and a 1.3 B LLaMA checkpoint, showing consistent speedups without any architecture changes.
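The throughput and latency rows in the table are two views of the same quantity: if decoding is s× faster, per-token latency falls by 1 − 1/s. A quick arithmetic check:

```python
def latency_reduction(speedup: float) -> float:
    """Fraction of per-token latency saved by an s-times throughput gain."""
    return 1.0 - 1.0 / speedup

# Throughput multipliers from the table:
print(round(latency_reduction(2.8) * 100))  # 64 (table reports ~65 %)
print(round(latency_reduction(3.2) * 100))  # 69 (table reports ~70 %)
```

The small mismatch with the table's 65 %/70 % is consistent with rounding in the reported figures.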
Practical Implications
- Faster APIs: Cloud providers can boost throughput of existing LLM endpoints simply by fine‑tuning the model with the self‑distillation loss—no extra servers or custom inference kernels required.
- Cost savings: Reducing the number of forward passes per generated token translates directly into lower GPU/TPU utilization, cutting inference bills for high‑volume applications (chatbots, code completion, etc.).
- Edge deployment: Devices with limited compute (mobile, IoT) can run larger models more responsively by emitting several tokens per inference step, extending the feasible model size on‑device.
- Simplified pipelines: Unlike speculative decoding, there is no need to maintain a separate “verifier” model or orchestrate speculative‑then‑fallback logic, lowering engineering overhead and potential bugs.
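The simplified pipeline amounts to a plain block-wise decoding loop: emit k tokens per forward pass, shift the context, repeat, with no speculative-then-fallback branch. A sketch, where `model_k_step` is a hypothetical stand-in for calling the fine-tuned model's multi-token head:

```python
def generate(model_k_step, prompt, max_new, k=4):
    """Block-wise decoding: each forward pass emits k tokens, no verifier pass.

    `model_k_step(context) -> list of k token ids` is an assumed interface;
    end-of-sequence handling is omitted for brevity.
    """
    out = list(prompt)
    while len(out) - len(prompt) < max_new:
        block = model_k_step(out)  # one forward pass, k tokens at once
        out.extend(block[: max_new - (len(out) - len(prompt))])
    return out

# Toy stand-in model that "predicts" an arithmetic continuation of the context.
toy = lambda ctx: [ctx[-1] + i + 1 for i in range(4)]
print(generate(toy, [0], max_new=10))  # [0, 1, 2, ..., 10]
```

Compare this with speculative decoding, where each drafted block must additionally be scored by the base model and partially rolled back on rejection.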
Limitations & Future Work
- Error propagation: When the model mis‑predicts early tokens in a block, the mistake can affect the rest of the block, leading to occasional bursts of low‑quality output.
- Fixed block size: The current method uses a static k; adaptive block lengths based on confidence could improve the accuracy‑speed trade‑off.
- Benchmark scope: Experiments focus on GSM8K (math reasoning); broader evaluations (dialogue, code generation, long‑form text) are needed to confirm generality.
- Training overhead: The self‑distillation fine‑tuning adds extra compute compared to a plain inference‑only deployment, though it is modest relative to pretraining costs.
Future research directions include integrating confidence‑based early stopping within a block, combining self‑distillation with quantization or pruning for even tighter latency budgets, and exploring multi‑modal extensions (e.g., vision‑language models).
Authors
- John Kirchenbauer
- Abhimanyu Hans
- Brian Bartoldson
- Micah Goldblum
- Ashwinee Panda
- Tom Goldstein
Paper Information
- arXiv ID: 2602.06019v1
- Categories: cs.CL, cs.LG
- Published: February 5, 2026