[Paper] Improving Deep Learning Library Testing with Machine Learning
Source: arXiv - 2602.03755v1
Overview
Deep learning frameworks such as TensorFlow and PyTorch power countless AI products, yet their massive, highly‑dynamic APIs make them surprisingly error‑prone. This paper shows how a lightweight machine‑learning classifier—trained on the shapes of tensors that flow through API calls—can automatically learn the “legal” input space of a library and dramatically cut down false alarms in automated testing tools.
Key Contributions
- Shape‑based abstraction: Demonstrates that tensor shape information alone is a sufficient, low‑dimensional representation for learning API input constraints.
- ML‑driven input validator: Trains binary classifiers (e.g., Random Forest, Gradient Boosting) on runtime‑labeled examples to predict whether a given API call will succeed.
- Large‑scale empirical study: Evaluates the approach on 183 TensorFlow and PyTorch APIs, achieving > 91 % classification accuracy on unseen inputs.
- Integration with ACETest: Embeds the learned validators into the state‑of‑the‑art bug‑finding tool ACETest, boosting its effective pass rate from ~29 % to ~61 %.
- Open‑source artifact: Provides the data collection pipeline and trained models, enabling other researchers and engineers to replicate or extend the work.
Methodology
- Data collection: Randomly generate concrete inputs for each target API (e.g., tensors of various dimensions, data types, and values). Execute the call and record whether it succeeds or throws an error.
- Shape extraction: For each input, keep only the tensor shapes (e.g., `[32, 64]`, `[None, 128]`) and ancillary metadata (dtype, number of arguments). This reduces the feature space from millions of possible numeric values to a handful of categorical/ordinal features.
- Labeling: The runtime outcome (pass/fail) becomes the ground‑truth label.
- Model training: Train standard supervised classifiers (Random Forest, XGBoost, shallow neural nets) on the shape‑based feature vectors. Hyper‑parameter tuning is performed via cross‑validation.
- Evaluation: Split the dataset into training and hold‑out sets per API, measuring accuracy, precision, recall, and F1‑score.
- Tool integration: Replace ACETest’s heuristic input‑validation step with the trained classifier, allowing the bug‑finder to discard clearly invalid inputs before deeper symbolic execution.
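The collect-execute-label loop above can be sketched in pure Python. Here a hand-written `call_succeeds` stub stands in for actually executing a TensorFlow/PyTorch op (it encodes a matmul-like rule: two 2-D operands with matching inner dimensions); in the paper's pipeline the label comes from running the real API call. All names and the fixed-length padding scheme are illustrative assumptions, not the authors' exact feature encoding.

```python
import random

# Stand-in for executing the real API call. In the actual pipeline the
# op is run and the pass/fail outcome is recorded; here we hard-code a
# matmul-like shape rule purely for illustration.
def call_succeeds(shape_a, shape_b):
    return (len(shape_a) == 2 and len(shape_b) == 2
            and shape_a[1] == shape_b[0])

def random_shape(max_rank=3, max_dim=8):
    # Randomly generated concrete input: a tensor shape of random rank.
    rank = random.randint(1, max_rank)
    return tuple(random.randint(1, max_dim) for _ in range(rank))

def shape_features(shape_a, shape_b):
    # Shape-only feature vector: each operand's rank plus its dimensions
    # padded to a fixed length, so every example has the same size.
    pad = lambda s: (tuple(s) + (0, 0, 0))[:3]
    return (len(shape_a), len(shape_b)) + pad(shape_a) + pad(shape_b)

# Data collection + labeling: the runtime outcome is the ground truth.
random.seed(0)
dataset = []
for _ in range(200):
    a, b = random_shape(), random_shape()
    dataset.append((shape_features(a, b), call_succeeds(a, b)))
```

The resulting `(features, label)` pairs are what the supervised classifiers are trained on.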
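For the training step, a minimal stand-in for the paper's classifiers (Random Forest, XGBoost, shallow nets) is a 1-nearest-neighbour rule over shape features, sketched below with a hypothetical toy training set for a matmul-like op. This is purely illustrative of the interface (features in, valid/invalid out); a real validator would use an off-the-shelf library such as scikit-learn.

```python
# Toy supervised validator: 1-nearest-neighbour over shape features.
# Feature layout (assumed): (rank_a, rank_b, inner_dim_a, inner_dim_b).
def train_1nn(examples):
    # examples: list of (feature_tuple, bool_label) pairs.
    def predict(features):
        nearest = min(
            examples,
            key=lambda ex: sum((a - b) ** 2 for a, b in zip(ex[0], features)),
        )
        return nearest[1]  # label of the closest training example
    return predict

# Hypothetical runtime-labeled examples: valid iff inner dims match.
train = [
    ((2, 2, 3, 3), True),
    ((2, 2, 3, 4), False),
    ((2, 2, 5, 5), True),
    ((1, 2, 0, 4), False),
]
validator = train_1nn(train)
```

The learned `validator` can then replace a heuristic input-validation step, rejecting candidate inputs before any expensive execution.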
Results & Findings
- Classification performance: Across all 183 APIs, the best‑performing model reaches 91.3 % accuracy, with precision/recall above 0.9 for most APIs.
- Generalization: Models trained on a subset of inputs correctly classify > 85 % of completely unseen shape combinations, confirming that the shape abstraction captures the essential constraints.
- Impact on bug‑finding: When the classifier is used to filter out invalid test cases, ACETest’s pass rate (i.e., the proportion of generated tests that are meaningful) jumps from ~29 % to ~61 %, more than doubling its efficiency.
- False‑positive reduction: The number of spurious bug reports (invalid inputs flagged as bugs) drops by roughly 70 %, saving developers time in triage.
Practical Implications
- Faster CI pipelines: Teams can plug the shape‑based classifier into their continuous‑integration testing of custom TensorFlow/PyTorch extensions, cutting down wasted test executions.
- Better fuzzing tools: Existing fuzzers for DL libraries can adopt the same abstraction to focus their mutation engine on valid shape spaces, yielding higher bug‑discovery rates.
- API documentation assistance: The learned constraints can be reverse‑engineered into human‑readable preconditions, helping library maintainers improve docs and static analysis tools.
- Cross‑framework portability: Because the approach relies only on tensor shapes, it can be reused for emerging frameworks (e.g., JAX, MindSpore) with minimal retraining effort.
- Low overhead: Training a classifier for a new API takes seconds to minutes on a commodity laptop, making it feasible to integrate into developer tooling workflows.
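As a concrete illustration of the fuzzing use case, the sketch below filters randomly generated candidates through a predicted-validity check before they reach the (expensive) test oracle. The `predicted_valid` function is a hypothetical stand-in for a trained classifier that has learned an inner-dimension rule; names and feature layout are assumptions for illustration.

```python
import random

def predicted_valid(features):
    # Stand-in for a trained shape classifier. Here it hard-codes the
    # learned rule for a 2-D matmul-like op: inner dimensions must match.
    rank_a, rank_b, inner_a, inner_b = features
    return rank_a == 2 and rank_b == 2 and inner_a == inner_b

def fuzz_filtered(n_candidates=1000):
    # Generate random candidate inputs, keep only those the classifier
    # predicts to be valid; only these go on to deeper testing.
    kept = []
    for _ in range(n_candidates):
        a = (random.randint(1, 4), random.randint(1, 4))
        b = (random.randint(1, 4), random.randint(1, 4))
        features = (len(a), len(b), a[1], b[0])
        if predicted_valid(features):
            kept.append((a, b))
    return kept
```

Filtering this way concentrates the mutation budget on the valid shape space, which is the mechanism behind the reported pass-rate improvement.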
Limitations & Future Work
- Shape‑only abstraction: While effective for many APIs, some functions also depend on tensor contents (e.g., values must be non‑negative). The current model cannot capture such semantic constraints.
- Static vs. dynamic behavior: The approach learns from observed runtime outcomes; if the training set does not include rare edge‑case shapes, the classifier may misclassify them.
- Scalability to custom ops: Extending the method to user‑defined operations with complex control flow may require richer feature sets beyond shapes.
- Future directions: The authors suggest augmenting shape features with lightweight value statistics (min/max, sparsity) and exploring transfer learning to bootstrap classifiers for new libraries with minimal data.
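One way the suggested future direction could look: appending lightweight value statistics (min, max, sparsity) to the shape-only features so content-dependent constraints become learnable. The function below is a hypothetical sketch of such an augmented feature extractor, not part of the paper's artifact.

```python
def augmented_features(shape, values):
    # Hypothetical extension of the shape-only feature vector with the
    # value statistics the authors propose (min/max, sparsity).
    flat = list(values)
    n = len(flat)
    sparsity = sum(1 for v in flat if v == 0) / n if n else 0.0
    return {
        "rank": len(shape),
        "dims": tuple(shape),
        "min": min(flat) if flat else None,
        "max": max(flat) if flat else None,
        "sparsity": sparsity,
    }
```

A constraint such as "values must be non-negative" then shows up as a learnable pattern over the `min` feature.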
Bottom line: By turning the seemingly messy problem of DL library input validation into a tractable machine‑learning task, this work offers a practical, plug‑and‑play component that can make automated testing of TensorFlow, PyTorch, and similar frameworks far more efficient and developer‑friendly.
Authors
- Facundo Molina
- M M Abid Naziri
- Feiran Qin
- Alessandra Gorla
- Marcelo d’Amorim
Paper Information
- arXiv ID: 2602.03755v1
- Categories: cs.SE
- Published: February 3, 2026