Clone the Abilities of Powerful LLMs into Small Local Models Using Knowledge Distillation

Boost the performance of local LLMs using supervision from larger ones

Photo by matthew Feeney on Unsplash

In the realm of Natural Language Processing (NLP), cutting-edge Large Language Models (LLMs) offer remarkable few-shot learning and reasoning capabilities. However, the computational demands and latency associated with these models can sometimes render them impractical for certain applications. If your goal, for instance, is to develop a translation service, you probably don’t require your back-end LLM to possess the ability to crack jokes or explain quantum physics to a kindergartner. This highlights the demand for specialized, smaller-scale models.

A viable solution to this challenge is to construct tailored LLMs that cater precisely to your specific use case. This involves annotating significant volumes of data and then fine-tuning a more compact model like TinyLlama to suit your requirements. Such an approach not only ensures that the model aligns closely with your needs but also reduces the computational and deployment costs associated with larger LLMs. The downside of this method is that data annotation is often laborious and time-consuming.

Knowledge distillation offers a way around this bottleneck. Instead of relying solely on manual labeling, it uses a very large language model with targeted prompting to generate labeled data automatically. A smaller model is then fine-tuned on this distilled knowledge. This streamlines model development while recovering a good part of the larger model's performance on the target task.

In this post, we apply this approach to build a model for multilingual grammatical error correction.

The Task:

Our goal is to detect and correct grammatical errors within a sentence. For instance:

  • Corrupted sentence: “It is very hard to get rid of bad habit.”
  • Corrected sentence: “It is very hard to get rid of bad habits.”

The Distillation Workflow:

Here’s how we’re going to distill the knowledge from our teacher model to our student model:

  1. Acquire unlabeled in-domain data.
  2. Craft a prompt to extract pseudo-labels from the teacher model through Anyscale’s API.
  3. Fine-tune the student model on these pseudo-labels using LoRA with the PEFT library.
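
Here is a minimal plain-Python sketch of how the three steps fit together. The teacher is a hypothetical stand-in function (toy_teacher), not the Llama 2 70B call used in this post. The function name build_distillation_set and the input/target field names are also our own illustrative choices, not names from the author's code.

```python
from typing import Callable, Dict, List


def build_distillation_set(
    unlabeled_texts: List[str],
    teacher: Callable[[str], str],
) -> List[Dict[str, str]]:
    """Pair each unlabeled input with the teacher's correction (a pseudo-label)."""
    records = []
    for text in unlabeled_texts:
        pseudo_label = teacher(text)  # step 2: query the teacher model
        records.append({"input": text, "target": pseudo_label})
    return records  # step 3: these records become the student's training data


# Hypothetical stand-in teacher for illustration only.
def toy_teacher(text: str) -> str:
    return text.replace("bad habit.", "bad habits.")


train_records = build_distillation_set(
    ["It is very hard to get rid of bad habit."],  # step 1: unlabeled in-domain data
    toy_teacher,
)
print(train_records)
```

In the real workflow, the teacher call is the Anyscale request shown later, and the resulting records are what the student is fine-tuned on.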

The Data:

The data comes from the Hugging Face dataset “juancavallotti/multilingual-gec” [Licensed under Apache 2]. We use its labels only for evaluation, not for training.

This data can be loaded as follows:

from datasets import load_dataset

data = load_dataset("juancavallotti/multilingual-gec", split="train")
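
The labels must stay out of training. One way to guarantee this is to split the rows into two disjoint parts: a training pool whose labels are dropped, and an evaluation set that keeps them. The sketch below works on plain Python dicts, so it runs without downloading anything. The field names “corrupted” and “correct”, the function name, and the seed are hypothetical. Check the real column names of juancavallotti/multilingual-gec before mapping them in.

```python
import random
from typing import Dict, List, Tuple


def split_for_distillation(
    rows: List[Dict[str, str]],
    n_train: int,
    n_eval: int,
    seed: int = 0,
) -> Tuple[List[str], List[Dict[str, str]]]:
    """Return unlabeled training inputs and a disjoint, labeled evaluation set."""
    if n_train + n_eval > len(rows):
        raise ValueError("not enough rows for the requested split")
    shuffled = rows[:]  # copy so the caller's list is not reordered
    random.Random(seed).shuffle(shuffled)
    # Training side: keep only the corrupted input; the label is dropped on purpose.
    train_inputs = [row["corrupted"] for row in shuffled[:n_train]]
    # Evaluation side: keep the input and the reference correction for scoring.
    eval_rows = shuffled[n_train:n_train + n_eval]
    return train_inputs, eval_rows
```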

The Teacher Model:

We use Llama 2 70B (chat version) as our teacher model. The teacher produces the pseudo-labels used for training. This powerful LLM is hosted on Anyscale’s pay-per-use API. Anyscale offers a $10 credit, so you can try the model without any cost at first. Alternatively, you can use OpenAI’s or Anthropic’s API.

Generating pseudo-labels for around 5,000 samples cost $1.20.
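
For budgeting, the reported figures work out as follows. This is back-of-the-envelope arithmetic that assumes exactly 5,000 samples. Pay-per-use pricing is typically per token, so the real cost also depends on prompt length, and the few-shot examples are sent with every request.

```python
TOTAL_COST_USD = 1.20  # reported cost of the labeling run
N_SAMPLES = 5000       # "around 5000" samples, so the result is approximate
FREE_CREDIT_USD = 10.0

cost_per_sample = TOTAL_COST_USD / N_SAMPLES
samples_covered_by_credit = int(FREE_CREDIT_USD / cost_per_sample)

print(f"{cost_per_sample:.5f} USD per sample")
print(samples_covered_by_credit, "samples fit in the free credit")
```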

Anyscale’s endpoint is compatible with the OpenAI API, so you can call it with the openai Python client:

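import os

from openai import OpenAI

BASE_URL = "https://api.endpoints.anyscale.com/v1"
BASE_MODEL = "meta-llama/Llama-2-70b-chat-hf"
# Read the key from an environment variable (the variable name is our choice)
API_KEY = os.environ["ANYSCALE_API_KEY"]

# The endpoint is OpenAI-compatible, so the OpenAI client only needs a base URL
BASE_CLIENT = OpenAI(base_url=BASE_URL, api_key=API_KEY)


def process_call(prompt):
    # temperature=0 requests greedy decoding, keeping labels as reproducible as possible
    completion = BASE_CLIENT.completions.create(
        model=BASE_MODEL,
        prompt=prompt,
        max_tokens=100,
        temperature=0,
    )
    result = completion.model_dump()

    return result["choices"][0]["text"].strip()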
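
Labeling thousands of samples one request at a time is slow, so a thread pool is a natural way to parallelize calls such as process_call. The post does not say how its requests were issued, so the sketch below is an assumption. The name label_all, the worker count, and the exponential-backoff retry policy are all chosen for illustration. Failed samples come back as None so they can be dropped or retried later.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional


def label_all(
    prompts: List[str],
    call: Callable[[str], str],
    max_workers: int = 8,
    retries: int = 3,
    backoff_s: float = 1.0,
) -> List[Optional[str]]:
    """Run `call` on every prompt in parallel, retrying failed requests."""

    def with_retries(prompt: str) -> Optional[str]:
        for attempt in range(retries):
            try:
                return call(prompt)
            except Exception:  # e.g. rate limits or transient network errors
                if attempt < retries - 1:
                    time.sleep(backoff_s * 2 ** attempt)  # exponential backoff
        return None  # give up; the caller can drop or re-queue this sample

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map preserves input order, so results line up with prompts
        return list(pool.map(with_retries, prompts))
```

With the client above, this would be called as label_all(prompts, process_call). Keep max_workers within the rate limits of your API plan.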

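We use simple few-shot prompting with the Llama 2 chat prompt template. The examples show the LLM what output is expected, which generally improves the quality of the result. One of the examples is already correct, which shows the model that correct text should be returned unchanged. The $text placeholder stands for the sentence to correct.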

<s>[INST]
Your role is to correct all grammatical errors in the input text. Only answer with the corrected text and nothing else.

Text: Il est très importante de parler une langue étrangère.
[/INST]
Output: Il est très important de parler une langue étrangère.</s>
[INST]
Text: Nadie dise ezo.
[/INST]
Output: Nadie dice eso.</s>
[INST]
Text: What is your favorite part of being a member of SWE RMS?
[/INST]
Output: What is your favorite part of being a member of SWE RMS?</s>
[INST]
Text: I looked, at the schedule.
[/INST]
Output: I looked at the schedule.</s>
[INST]
Text: $text
[/INST]
Output:
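
The $text placeholder matches Python's string.Template syntax, so the sketch below uses that to fill it. It also adds a hypothetical clean-up step that trims anything the model generates after the correction. The template is shortened to one of the four examples, and build_prompt and clean_completion are illustrative names. Whether the teacher's raw outputs actually need this trimming is an assumption.

```python
from string import Template

# Shortened version of the few-shot prompt above (one example instead of four).
FEW_SHOT_PROMPT = Template(
    "<s>[INST]\n"
    "Your role is to correct all grammatical errors in the input text. "
    "Only answer with the corrected text and nothing else.\n\n"
    "Text: Nadie dise ezo.\n"
    "[/INST]\n"
    "Output: Nadie dice eso.</s>\n"
    "[INST]\n"
    "Text: $text\n"
    "[/INST]\n"
    "Output:"
)


def build_prompt(text: str) -> str:
    # substitute() fills $text; the inserted value is not re-parsed,
    # so inputs that contain "$" are safe.
    return FEW_SHOT_PROMPT.substitute(text=text)


def clean_completion(raw: str) -> str:
    """Keep only the corrected sentence, dropping any continuation."""
    text = raw.strip()
    # If the model starts a new turn, cut at the first turn marker.
    for marker in ("</s>", "[INST]"):
        text = text.split(marker, 1)[0]
    # Drop an echoed "Output:" prefix if present.
    if text.startswith("Output:"):
        text = text[len("Output:"):]
    return text.strip()
```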

The Student Model:

We use TinyLlama as our student model. The student is the model we fine-tune on the grammar correction task using the teacher’s pseudo-labels. With only 1.1 billion parameters, TinyLlama is cheap to run and fits on consumer GPUs with just a few gigabytes of memory.

The model can be run as a Hugging Face pipeline. We use bitsandbytes for 4-bit quantization on the GPU, which reduces the memory needed to run the model.

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)

base_model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

llama_tokenizer = AutoTokenizer.from_pretrained(
    base_model_name, trust_remote_code=True
)
llama_tokenizer.padding_side = "right"

# 4-bit NF4 quantization; computations run in float16
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False,
)
# Model, placed entirely on GPU 0
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=quant_config,
    device_map={"": 0},
)

# Greedy generation that returns only the newly generated text
text_gen = pipeline(
    task="text-generation",
    model=model,
    tokenizer=llama_tokenizer,
    max_new_tokens=256,
    do_sample=False,
    return_full_text=False,
)

print(text_gen("Hello ! Who are you ?"))

You should get something like the following output. It is long and repetitive, and it stops mid-word when it reaches the 256-token limit:

[{'generated_text': ' I am a writer, a poet, a musician, a dancer, a painter, a sculptor, a filmmaker, a photographer, a cartoonist, a journalist, a teacher, a student, a lover, a friend, a stranger, a human being, a cat, a dog, a bird, a tree, a rock, a sandstone, a mineral, a fossil, a plant, a fungus, a bacterium, a virus, a microbe, a parasite, a symbiosis, a symphony, a symmetry, a chaos, a harmony, a balance, a balance of forces, a balance of energies, a balance of opposites, a balance of opposing forces, a balance of opposing principles, a balance of opposing ideas, a balance of opposing emotions, a balance of opposing thoughts, a balance of opposing desires, a balance of opposing needs, a balance of opposing needs, a balance of opposing desires, a balance of opposing emotions, a balance of opposing principles, a balance of opposing forces, a balance of opposing energies, a balance of opposing symb'}]

We fine-tune the student with two Hugging Face libraries, PEFT and TRL. PEFT (“Parameter-Efficient Fine-Tuning”) implements several methods that train only a small number of extra parameters, including low-rank adapters (LoRA). TRL (“Transformer Reinforcement Learning”) provides training workflows, including the SFTTrainer used below for supervised fine-tuning.
The TRL documentation covers this combination in detail: https://huggingface.co/docs/trl/main/en/lora_tuning_peft
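
Supervised fine-tuning works on plain text examples, so each pseudo-labeled pair has to be rendered as a single string. In TRL's SFTTrainer, the dataset_text_field argument names the dataset column that holds such strings. The post does not show its student prompt format, so STUDENT_TEMPLATE and to_sft_text below are hypothetical. "</s>" is the end-of-sequence token of Llama-style tokenizers, but in practice it is safer to read it from llama_tokenizer.eos_token.

```python
from typing import Dict

# Hypothetical prompt format for the student; the post does not show the real one.
STUDENT_TEMPLATE = (
    "Correct all grammatical errors in the input text.\n"
    "Text: {input}\n"
    "Output: {target}"
)


def to_sft_text(record: Dict[str, str], eos_token: str = "</s>") -> Dict[str, str]:
    """Turn one pseudo-labeled pair into a single training string."""
    text = STUDENT_TEMPLATE.format(input=record["input"], target=record["target"])
    # Appending EOS teaches the student to stop after the correction.
    return {"text": text + eos_token}


print(to_sft_text({"input": "Nadie dise ezo.", "target": "Nadie dice eso."})["text"])
```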

The implementation uses QLoRA, which trains adapter weights on top of a quantized version of the full model while the model's own weights stay frozen. This lets us run training with around 3 GB of VRAM at a mini-batch size of 8, so it fits on most consumer-grade GPUs.
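
A back-of-the-envelope estimate shows why roughly 3 GB is enough. The figures are approximations. They assume 1.1 billion parameters and 4 bits per quantized weight, ignoring quantization constants and any layers bitsandbytes leaves in 16-bit. They also assume the roughly one million LoRA parameters mentioned in the post are kept in fp32, each with a gradient and two AdamW moment buffers.

```python
N_PARAMS = 1.1e9  # TinyLlama-1.1B, rounded
BYTES_FP16 = 2    # 16-bit weights
BYTES_NF4 = 0.5   # 4 bits per weight, ignoring quantization constants

fp16_gb = N_PARAMS * BYTES_FP16 / 1e9
nf4_gb = N_PARAMS * BYTES_NF4 / 1e9

# Trainable LoRA parameters (assumed fp32): weight 4 B + gradient 4 B
# + two 32-bit AdamW moment buffers 8 B = 16 bytes per parameter.
N_LORA = 1.1e6
lora_train_mb = N_LORA * (4 + 4 + 8) / 1e6

print(f"fp16 weights: {fp16_gb:.2f} GB, 4-bit weights: {nf4_gb:.2f} GB")
print(f"LoRA weights + grads + AdamW state: {lora_train_mb:.1f} MB")
```

Both terms are small next to 3 GB. Most of the training budget presumably goes to activations for the mini-batch of 8, the CUDA context, and unquantized layers.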

LoRA adapters are additive low-rank weight matrices that are trained while the backbone stays frozen. This makes it possible to build specialized models with a much smaller VRAM and disk footprint. In our case, the adapter weights take only 4.5 MB and contain around one million parameters.
Here is a condensed version of the training code; the full code is linked at the end of the post:

import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import SFTTrainer

if __name__ == "__main__":
    # ... Setup elided (see the full code linked at the end of the post):
    # load the quantized base_model and llama_tokenizer as shown earlier,
    # and define training_data, collator, EPOCHS and BASE_PATH.

    # Rank-8 LoRA adapters; target_modules is left to PEFT's default
    peft_parameters = LoraConfig(
        lora_alpha=8,
        lora_dropout=0.1,
        r=8,
        bias="none",
        task_type="CAUSAL_LM",
        # target_modules=target_modules,
    )

    # Freeze the 4-bit weights and prepare them for training (this also
    # enables gradient checkpointing by default), then attach the adapters
    base_model = prepare_model_for_kbit_training(base_model)
    base_model = get_peft_model(base_model, peft_parameters)

    # Training Params
    train_params = TrainingArguments(
        output_dir=str(BASE_PATH / "results_modified"),
        num_train_epochs=EPOCHS,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=1,
        optim="paged_adamw_32bit",
        save_steps=len(training_data) // 10,
        logging_steps=len(training_data) // 100,
        learning_rate=2e-4,
        lr_scheduler_type="cosine",
        warmup_steps=100,
        weight_decay=0.05,
        fp16=True,
        max_steps=-1,
        group_by_length=False,
        max_grad_norm=0.3,
    )
    # Trainer
    fine_tuning = SFTTrainer(
        model=base_model,
        train_dataset=training_data,
        data_collator=collator,
        peft_config=peft_parameters,
        # Placeholder kept from the original code; in TRL versions that use
        # this argument, it names the text column of training_data
        dataset_text_field="Why is this mandatory ?",
        tokenizer=llama_tokenizer,
        args=train_params,
        max_seq_length=llama_tokenizer.model_max_length,
    )

    # print_trainable_parameters() prints the counts itself and returns None
    fine_tuning.model.print_trainable_parameters()
    # Training
    fine_tuning.train()
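
The adapter size reported by print_trainable_parameters() can be checked by hand. A LoRA layer adds two matrices: A with shape r × d_in and B with shape d_out × r. It computes W x + (lora_alpha / r) · B A x, so with lora_alpha = r = 8 the update is unscaled. The sketch below assumes TinyLlama's configuration (hidden size 2048, 22 layers, 32 attention heads, 4 key/value heads). It also assumes PEFT's default LLaMA target modules, q_proj and v_proj, since target_modules is commented out in the training code.

```python
# TinyLlama-1.1B shapes (taken from its model config; stated here as assumptions).
HIDDEN = 2048
N_LAYERS = 22
N_HEADS = 32
N_KV_HEADS = 4                   # grouped-query attention
HEAD_DIM = HIDDEN // N_HEADS     # 64
KV_DIM = N_KV_HEADS * HEAD_DIM   # 256: output width of k_proj and v_proj

R = 8  # LoRA rank, as in LoraConfig(r=8)


def lora_params(d_in: int, d_out: int, r: int) -> int:
    # A is r x d_in and B is d_out x r, so B @ A has the same shape as W.
    return r * d_in + d_out * r


# Assumes PEFT's default LLaMA targets: q_proj (2048 -> 2048) and v_proj (2048 -> 256).
per_layer = lora_params(HIDDEN, HIDDEN, R) + lora_params(HIDDEN, KV_DIM, R)
total = per_layer * N_LAYERS
size_mb = total * 4 / 1e6  # fp32 storage, 4 bytes per parameter

print(total, f"{size_mb:.1f} MB")
```

Under these assumptions the count is 1,126,400 parameters, or about 4.5 MB in fp32. This is consistent with the roughly one million parameters and 4.5 MB quoted above.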

The Results:

To evaluate whether this workflow works, we can compare a few outputs of the base TinyLlama with those of the version distilled from Llama 2 70B:

Example 1:

Corrupted input:
* We dont live in Australia Were just visiting
Base model output:
* We don’t live in Australia, We’re just visiting.
Distilled model output:
* We don’t live in Australia. We are just visiting.

Here the base model fixed the apostrophes but produced a comma splice with a wrongly capitalized “We’re”. The distilled model split the two clauses into separate sentences.

Example 2:

Corrupted input:
* Je ai été surprise.
Base model output:
* I was surprised.
Distilled model output:
* J’ai été surprise.

Here the base model produced a correct sentence, but in English instead of the original French. The distilled model corrected it in French.

We can also compute the fraction of cases where the model output exactly matches the expected output. This metric is imperfect because a sentence can often be corrected in more than one way. For example, “It is very hard to get rid of bad habit.” can become “It is very hard to get rid of bad habits.” or “It is very hard to get rid of a bad habit.” Still, it is a useful proxy for generation quality. We get the following scores:

  • Llama 2 70B (teacher): 42%
  • Base TinyLlama: 11%
  • Distilled TinyLlama: 31%
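
The exact-match score can be computed with a few lines of Python. This is a minimal sketch: exact_match is a hypothetical helper, and trimming surrounding whitespace is our assumption, since the post does not say how outputs were normalized. The usage example reuses the bad-habit sentence to show why the metric undercounts valid corrections.

```python
from typing import List


def exact_match(predictions: List[str], references: List[str]) -> float:
    """Fraction of predictions identical to their reference (after trimming)."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not references:
        return 0.0
    hits = sum(p.strip() == r.strip() for p, r in zip(predictions, references))
    return hits / len(references)


refs = ["It is very hard to get rid of bad habits."] * 2
preds = [
    "It is very hard to get rid of bad habits.",
    "It is very hard to get rid of a bad habit.",  # valid, but scored as a miss
]
print(exact_match(preds, refs))  # 0.5
```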

The student is still well below the teacher, but distillation raised its score substantially, from 11% to 31%. The remaining gap between 31% and 42% could likely be narrowed with a larger distillation dataset or a bigger student model.
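
To put the numbers in perspective, the distilled student recovers about 65% of the gap between the base model and the teacher, and it reaches about 74% of the teacher's score. The arithmetic below uses only the scores reported above.

```python
teacher, base, distilled = 0.42, 0.11, 0.31  # exact-match scores reported above

# Share of the base-to-teacher gap recovered by distillation.
gap_closed = (distilled - base) / (teacher - base)
# Student score as a fraction of the teacher score.
relative_to_teacher = distilled / teacher

print(f"gap closed: {gap_closed:.0%}, fraction of teacher score: {relative_to_teacher:.0%}")
```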

Conclusion:

By distilling knowledge from a high-capacity teacher model such as Llama 2 70B into a compact student model like TinyLlama, we navigate the trade-offs between computational efficiency and task-specific accuracy. The process involves acquiring unlabeled in-domain data, crafting a prompt, and fine-tuning the student model on the pseudo-labels generated by the teacher. This approach reduces the computational and deployment costs associated with larger LLMs.

The implementation shown here, applied to multilingual grammatical error correction, illustrates the practicality of knowledge distillation. Manual data annotation is laborious and time-consuming, and distillation offers a scalable alternative by generating labeled data automatically through targeted prompting. In addition, QLoRA and the PEFT library make it possible to train such specialized models on consumer-grade GPUs.

The evaluation shows a clear improvement for the student model: its exact-match score rises from 11% to 31%, although it remains below the teacher's 42%. This progress shows that distillation can narrow the gap between computational efficiency and task-specific accuracy.

Code: https://github.com/CVxTz/distill-llm