The universal principle of knowledge distillation, model compression, and rule extraction
Machine learning (ML) model training typically follows a familiar pipeline: start with data collection, clean and prepare it, then move on to model fitting. But what if we could take this process further? Some insects undergo dramatic transformations before reaching maturity, and ML models can evolve in a similar way (see Hinton et al. [1]) — what I will call the ML metamorphosis. This process involves chaining different models together, resulting in a final model of significantly better quality than if it had been trained directly from the start.
Here’s how it works (a minimal code sketch follows this list):
- Start with some initial knowledge, Data 1.
- Train an ML model, Model A (say, a neural network), on this data.
- Generate new data, Data 2, using Model A.
- Finally, use Data 2 to fit your target model, Model B.
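To make the chain concrete, here is a minimal sketch in scikit-learn. The synthetic data, the MLPClassifier standing in for Model A, and the decision tree standing in for Model B are arbitrary choices for illustration, not a fixed recipe:

```python
import numpy as np
from sklearn.neural_network import MLPClassifier
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)

# Data 1: a small labelled set (synthetic stand-in for real data)
X1 = rng.normal(size=(1000, 10))
y1 = (X1.sum(axis=1) > 0).astype(int)

# Model A: any capable model trained on Data 1
model_a = MLPClassifier(max_iter=500, random_state=0).fit(X1, y1)

# Data 2: fresh inputs labelled by Model A
X2 = rng.normal(size=(20000, 10))
y2 = model_a.predict(X2)

# Model B: the target model, fitted to Data 2
model_b = DecisionTreeClassifier(max_depth=5, random_state=0).fit(X2, y2)
```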
You may already be familiar with this concept from knowledge distillation, where a smaller neural network replaces a larger one. But ML metamorphosis goes beyond this, and neither the initial model (Model A) nor the final one (Model B) need be neural networks at all.
Example: ML metamorphosis on the MNIST dataset
Imagine you’re tasked with training a multi-class decision tree on the MNIST dataset of handwritten digit images, but only 1,000 images are labelled. You could train the tree directly on this limited data, but the accuracy would be capped at around 0.67. Not great, right? Alternatively, you could use ML metamorphosis to improve your results.
But before we dive into the solution, let’s take a quick look at the techniques and research behind this approach.
1. Knowledge distillation (2015)
Even if you haven’t used knowledge distillation, you’ve probably seen it in action. For example, Meta suggests distilling its Llama 3.2 model to adapt it to specific tasks [2]. Or take DistilBERT — a distilled version of BERT [3] — or the DMD framework, which distills Stable Diffusion to speed up image generation by a factor of 30 [4].
At its core, knowledge distillation transfers knowledge from a large, complex model (the teacher) to a smaller, more efficient model (the student). The process involves creating a transfer set that includes both the original training data and additional data (either original or synthesized) pseudo-labeled by the teacher model. The pseudo-labels are known as soft labels — derived from the probabilities predicted by the teacher across multiple classes. These soft labels provide richer information than hard labels (simple class indicators) because they reflect the teacher’s confidence and capture subtle similarities between classes. For instance, they might show that a particular “1” is more similar to a “7” than to a “5.”
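To make soft labels concrete, here is a minimal sketch of the temperature-scaled softmax from [1]; the logits and the temperature T = 4 below are made up for illustration:

```python
import numpy as np

def soft_labels(teacher_logits, T=4.0):
    """Temperature-scaled softmax: larger T spreads probability mass
    across classes, exposing the teacher's view of class similarities."""
    z = teacher_logits / T
    z = z - z.max(axis=1, keepdims=True)  # shift for numerical stability
    p = np.exp(z)
    return p / p.sum(axis=1, keepdims=True)

# Made-up teacher logits for one digit image over classes 0-9:
# confident it is a "1", with a noticeable affinity to "7"
logits = np.array([[0.5, 9.0, 0.3, 0.2, 0.1, 0.2, 0.4, 6.0, 0.6, 0.2]])
print(soft_labels(logits, T=4.0).round(2))
# the soft label preserves the 1-vs-7 similarity that a hard label discards
```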
By training on this enriched transfer set, the student model can effectively mimic the teacher’s performance while being much lighter, faster, and easier to use.
The student model obtained in this way is also more accurate than it would have been had it been trained solely on the original training set.
2. Model compression (2006)
Model compression [5] is often seen as a precursor to knowledge distillation, but there are important differences. Unlike knowledge distillation, model compression does not appear to use soft labels: despite some claims in the literature [1,6], I haven’t found any evidence that they are part of the process. In fact, the method in the original paper doesn’t even rely on artificial neural networks (ANNs) as Model A. Instead, it uses an ensemble of models such as SVMs, decision trees, and random forests.
Model compression works by approximating the feature distribution p(x) to create a transfer set. This set is then labelled by Model A, which provides the conditional distribution p(y∣x). The key innovation in the original work is a technique called MUNGE to approximate p(x). As with knowledge distillation, the goal is to train a smaller, more efficient Model B that retains the performance of the larger Model A.
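The sketch below captures the spirit of MUNGE for continuous features: each example is nudged towards its nearest neighbour with probability p, with noise that shrinks as the two points get closer. Treat it as a simplified illustration, not a faithful reimplementation of [5]; the parameters p and s echo the paper, but the details here are my own simplification:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def munge_pass(X, p=0.5, s=1.0, seed=0):
    """One simplified MUNGE pass over continuous features: with probability
    p, resample each attribute from a Gaussian centred on the value of the
    example's nearest neighbour, with spread shrinking as they get closer."""
    rng = np.random.default_rng(seed)
    # nearest neighbour of each point (index 0 is the point itself)
    _, idx = NearestNeighbors(n_neighbors=2).fit(X).kneighbors(X)
    neighbours = X[idx[:, 1]]
    sd = np.abs(X - neighbours) / s          # per-attribute noise scale
    mask = rng.random(X.shape) < p           # which attributes to perturb
    X_new = X.copy()
    X_new[mask] = rng.normal(neighbours, sd)[mask]
    return X_new

# Repeated passes grow a synthetic pool approximating p(x);
# Model A then labels it to produce the transfer set (Data 2).
```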
As in knowledge distillation, the compressed model trained in this way can often outperform a similar model trained directly on the original data, thanks to the rich information embedded in the transfer set [5].
Often, “model compression” is used more broadly to refer to any technique that reduces the size of Model A [7,8]. This includes methods like knowledge distillation but also techniques that don’t rely on a transfer set, such as pruning, quantization, or low-rank approximation for neural networks.
3. Rule extraction (1995)
When the problem isn’t computational complexity or memory, but the opacity of a model’s decision-making, pedagogical rule extraction offers a solution [9]. In this approach, a simpler, more interpretable model (Model B) is trained to replicate the behavior of the opaque teacher model (Model A), with the goal of deriving a set of human-readable rules. The process typically starts by feeding unlabelled examples — often randomly generated — into Model A, which labels them to create a transfer set. This transfer set is then used to train the transparent student model. For example, in a classification task, the student model might be a decision tree that outputs rules such as: “If feature X1 is above threshold T1 and feature X2 is below threshold T2, then classify as positive”.
The main goal of pedagogical rule extraction is to closely mimic the teacher model’s behavior, with fidelity — the accuracy of the student model relative to the teacher model — serving as the primary quality measure.
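Here is a compact sketch of that loop with scikit-learn; the random-forest teacher and the synthetic data are placeholders for whatever opaque Model A you actually have:

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier, export_text

rng = np.random.default_rng(0)

# A stand-in opaque teacher (Model A); any fitted black box would do
X1 = rng.normal(size=(1000, 4))
y1 = (X1[:, 0] + X1[:, 1] ** 2 > 1).astype(int)
teacher = RandomForestClassifier(random_state=0).fit(X1, y1)

# Transfer set: randomly generated inputs labelled by the teacher
X2 = rng.uniform(X1.min(axis=0), X1.max(axis=0), size=(20000, 4))
y2 = teacher.predict(X2)

# Model B: a shallow decision tree whose paths read as rules
student = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X2, y2)
print(export_text(student, feature_names=["x1", "x2", "x3", "x4"]))

# Fidelity: how closely the student mimics the teacher on fresh inputs
X_test = rng.uniform(X1.min(axis=0), X1.max(axis=0), size=(5000, 4))
print("fidelity:", (student.predict(X_test) == teacher.predict(X_test)).mean())
```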
Interestingly, research has shown that transparent models created through this method can sometimes reach higher accuracy than similar models trained directly on the original data used to build Model A [10,11].
Pedagogical rule extraction belongs to a broader family of techniques known as “global” model explanation methods, which also include decompositional and eclectic rule extraction. See [12] for more details.
4. Simulations as Model A
Model A doesn’t have to be an ML model — it could just as easily be a computer simulation of an economic or physical process, such as the simulation of airflow around an airplane wing. In this case, Data 1 consists of the differential or difference equations that define the process. For any given input, the simulation makes predictions by solving these equations numerically. However, when these simulations become computationally expensive, a faster alternative is needed: a surrogate model (Model B), which can accelerate tasks like optimization [13]. When the goal is to identify important regions in the input space, such as zones of system stability, an interpretable Model B is developed through a process known as scenario discovery [14]. To generate the transfer set (Data 2) for both surrogate modelling and scenario discovery, Model A is run on a diverse set of inputs.
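As a toy illustration, the sketch below treats a small ODE solve as the “expensive” simulation and fits a fast regressor as its surrogate. The damped-oscillator system and the choice of GradientBoostingRegressor are assumptions made purely for the example:

```python
import numpy as np
from scipy.integrate import solve_ivp
from sklearn.ensemble import GradientBoostingRegressor

def simulate(c):
    """Model A: a damped oscillator x'' + c*x' + x = 0 solved numerically;
    returns the displacement after 10 seconds for damping coefficient c."""
    sol = solve_ivp(lambda t, y: [y[1], -c * y[1] - y[0]], (0, 10), [1.0, 0.0])
    return sol.y[0, -1]

# Data 2: run the simulation on a diverse set of inputs
c_grid = np.linspace(0.0, 2.0, 200)
outputs = np.array([simulate(c) for c in c_grid])

# Model B: a surrogate that answers in microseconds instead of a full solve
surrogate = GradientBoostingRegressor().fit(c_grid.reshape(-1, 1), outputs)
print(surrogate.predict([[0.5]])[0], simulate(0.5))  # surrogate vs. simulation
```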
Back to our MNIST example
In an insightful article on TDS [15], Niklas von Moers shows how semi-supervised learning can improve the performance of a convolutional neural network (CNN) on the same input data. This result fits into the first stage of the ML metamorphosis pipeline, where Model A is a trained CNN classifier. The transfer set, Data 2, then contains the originally labelled 1,000 training examples plus about 55,000 examples pseudo-labelled by Model A with high-confidence predictions. I then train the target Model B, a decision tree classifier, on Data 2 and achieve an accuracy of 0.86 — much higher than 0.67 when training on the labelled part of Data 1 alone. This means that chaining the decision tree to the CNN solution reduces the decision tree’s error rate from 0.33 to 0.14. Quite an improvement, wouldn’t you say?
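In condensed form, the second stage looks roughly like this. This is a sketch only: the function name, the predict_proba interface, and the 0.99 confidence threshold are my assumptions; the repository linked below has the actual code.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

def metamorphose(model_a, X_lab, y_lab, X_unlab, threshold=0.99):
    """Build Data 2 from confident pseudo-labels and fit Model B on it.
    model_a is any fitted classifier exposing predict_proba (here, the CNN);
    the threshold is an illustrative choice, not the article's exact value."""
    proba = model_a.predict_proba(X_unlab)
    keep = proba.max(axis=1) > threshold            # high-confidence rows only
    X2 = np.vstack([X_lab, X_unlab[keep]])          # Data 2: labelled + pseudo
    y2 = np.concatenate([y_lab, proba[keep].argmax(axis=1)])
    return DecisionTreeClassifier(random_state=0).fit(X2, y2)
```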
For the full experimental code, check out the GitHub repository.
Conclusion
In summary, ML metamorphosis isn’t always necessary — especially if accuracy is your only concern and there’s no need for interpretability, faster inference, or reduced storage requirements. But in other cases, chaining models may yield significantly better results than training the target model directly on the original data.
For a classification task, the process involves:
- Data 1: The original, fully or partially labeled data.
- Model A: A model trained on Data 1.
- Data 2: A transfer set that includes pseudo-labeled data.
- Model B: The final model, designed to meet additional requirements, such as interpretability or efficiency.
So why don’t we always use ML metamorphosis? The challenge often lies in finding the right transfer set, Data 2 [9]. But that’s a topic for another story.
References
[1] Hinton, Geoffrey, Oriol Vinyals, and Jeff Dean. “Distilling the Knowledge in a Neural Network.” arXiv preprint arXiv:1503.02531 (2015).
[3] Sanh, Victor, et al. “DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter.” arXiv preprint arXiv:1910.01108 (2019).
[4] Yin, Tianwei, et al. “One-step diffusion with distribution matching distillation.” Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2024.
[5] Buciluǎ, Cristian, Rich Caruana, and Alexandru Niculescu-Mizil. “Model compression.” Proceedings of the 12th ACM SIGKDD international conference on Knowledge discovery and data mining. 2006.
[6] Knowledge distillation, Wikipedia
[7] An Overview of Model Compression Techniques for Deep Learning in Space, on Medium
[8] Distilling BERT Using an Unlabeled Question-Answering Dataset, on Towards Data Science
[9] Arzamasov, Vadim, Benjamin Jochum, and Klemens Böhm. “Pedagogical Rule Extraction to Learn Interpretable Models — an Empirical Study.” arXiv preprint arXiv:2112.13285 (2021).
[10] Domingos, Pedro. “Knowledge acquisition from examples via multiple models.” Proceedings of the Fourteenth International Conference on Machine Learning (ICML). Morgan Kaufmann, 1997.
[11] De Fortuny, Enric Junque, and David Martens. “Active learning-based pedagogical rule extraction.” IEEE Transactions on Neural Networks and Learning Systems 26.11 (2015): 2664–2677.
[12] Guidotti, Riccardo, et al. “A survey of methods for explaining black box models.” ACM Computing Surveys (CSUR) 51.5 (2018): 1–42.
[13] Surrogate model, Wikipedia
[14] Scenario discovery in Python, blog post on Water Programming
[15] Teaching Your Model to Learn from Itself, on Towards Data Science