How I Made a BERT Model 96% Smaller and 46x Faster (and Kept ~89% of Its Performance)

Nijat Zeynalov

We often face a classic dilemma: the most powerful models are also the most resource-hungry. A model like bert-base-uncased, with its roughly 109 million parameters, can achieve excellent results but is often a non-starter for applications that demand low latency, a small memory footprint, or on-device deployment.

I recently tackled exactly this problem. I needed to build a classifier that assigns doctors' clinical notes to one of five medical specialties. A fine-tuned BERT model gave me a strong performance baseline, but its size was a roadblock.

This post details my journey using task-specific knowledge distillation to transfer the "intelligence" of our 109M-parameter BERT model into a tiny 4.4M-parameter bert-micro model. The result? We shrank the model by 96%, made it roughly 46x faster, and retained nearly 90% of its original performance.

Understanding Knowledge Distillation

Before diving into the code, let's understand the core concept of distillation with an analogy.

Imagine you're training a new student for a complex exam.

  • Standard Fine-Tuning (The Textbook Method): You give the student a textbook and an answer key. The student studies the material and is graded on whether their answers are right or wrong. The answer key contains "hard labels" (one-hot vectors such as [0, 0, 1, 0, 0]). This is effective, but the student learns only which answer is correct.

  • Knowledge Distillation (The Professor Method): Now, imagine an expert professor teaches the student. The professor not only gives the correct answer but also explains their reasoning. They might say, "The answer is C, but I can see why you might think B is plausible because of this subtlety. However, A and D are completely off-track." This nuanced explanation—the "soft labels" ([0.05, 0.20, 0.70, 0.05, 0])—is incredibly rich. It teaches the student how to think and understand the relationships between different concepts.

In my project, bert-base-uncased is the expert professor, and the tiny bert-micro is the eager student.
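To make the analogy concrete, here is a small pure-Python sketch. The probability vectors are made up for illustration and are not model outputs. Two hypothetical students put the same probability on the correct class, so hard-label cross-entropy cannot tell them apart, while the KL divergence to the professor's soft label (the example vector above) prefers the student whose doubts resemble the teacher's.

```python
import math

def cross_entropy_hard(probs, true_idx):
    """Hard-label cross-entropy: only the true class's probability matters."""
    return -math.log(probs[true_idx])

def kl_divergence(p_target, q_pred):
    """KL(p || q) = sum_i p_i * log(p_i / q_i); terms with p_i == 0 contribute 0."""
    return sum(p * math.log(p / q) for p, q in zip(p_target, q_pred) if p > 0)

# Illustrative soft label from the professor analogy (not a real model output).
teacher_soft = [0.05, 0.20, 0.70, 0.05, 0.0]
true_class = 2  # answer "C"

# Two hypothetical students, both 70% confident in the correct class.
student_a = [0.05, 0.20, 0.70, 0.04, 0.01]  # spreads its doubt like the teacher
student_b = [0.20, 0.01, 0.70, 0.04, 0.05]  # doubts the "completely off-track" answers

for name, s in [("A", student_a), ("B", student_b)]:
    ce = cross_entropy_hard(s, true_class)
    kl = kl_divergence(teacher_soft, s)
    print(f"student {name}: hard-label CE={ce:.4f}, KL to teacher={kl:.4f}")
```

Both students get an identical hard-label loss, but student B pays a much larger KL penalty, which is exactly the extra signal the "professor" provides.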

Our Game Plan and Setup

Our experiment was designed to be a clear, head-to-head comparison.

  1. The Task: A 5-class classification of medical notes.

  2. The Teacher Model: bert-base-uncased (109M parameters). Our powerful but bulky expert.

  3. The Student Model: boltuix/bert-micro (4.4M parameters). An extremely compact model from the same BERT family, making it an ideal distillation candidate.

A crucial decision was to use models from the same family. Both bert-base and bert-micro use the exact same tokenizer, which dramatically simplifies the data preparation pipeline and avoids potential vocabulary mismatch errors.
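If you want to verify the "same tokenizer" assumption rather than take it on faith, one rough check is to compare the two vocabularies token by token. The helper below, `vocab_mismatches`, is hypothetical (it is not part of any library) and works on plain `token -> id` dicts. Matching vocabularies is necessary but not sufficient: normalization settings such as lowercasing also have to agree.

```python
from typing import Dict, List, Optional, Tuple

def vocab_mismatches(
    vocab_a: Dict[str, int], vocab_b: Dict[str, int]
) -> List[Tuple[str, Optional[int], Optional[int]]]:
    """Return (token, id_in_a, id_in_b) for every token whose id differs or is missing on one side."""
    mismatches = []
    for token in sorted(set(vocab_a) | set(vocab_b)):
        id_a, id_b = vocab_a.get(token), vocab_b.get(token)
        if id_a != id_b:
            mismatches.append((token, id_a, id_b))
    return mismatches

# Toy vocabularies with made-up ids, standing in for the teacher's and student's tokenizers.
teacher_vocab = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "skin": 3}
student_vocab = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "skin": 3}
print(vocab_mismatches(teacher_vocab, student_vocab))  # → []
```

With real models, you would pass each tokenizer's `get_vocab()` dict (for example from `AutoTokenizer.from_pretrained("bert-base-uncased")` and `AutoTokenizer.from_pretrained("boltuix/bert-micro")`); an empty result means every token maps to the same id in both.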

Step 1: Data Preparation and Tokenization

First, we loaded our custom dataset and prepared the labels. The key here is that since our models are compatible, we only need a single tokenizer and one tokenization function.

from datasets import load_from_disk
from transformers import AutoTokenizer

# --- 1. Load Data and Define Labels ---
med_dataset = load_from_disk("med_dr_notes")

label_list = ['Dermatology', 'Gastroenterology', 'Endocrinology', 'Oncology', 'Pulmonology']
label2id = {label: idx for idx, label in enumerate(label_list)}
id2label = {idx: label for label, idx in label2id.items()}
# ... other label preparation omitted ...

# --- 2. Initialize a Single, Shared Tokenizer ---
teacher_id = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(teacher_id)

# --- 3. Create a Tokenization Function ---
def tokenize_function(examples):
    # Both models use the same tokenizer, so one function serves both
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)

# --- 4. Apply to All Dataset Splits ---
# train_dataset and eval_dataset are splits of med_dataset (their creation is not shown)
tokenized_train = train_dataset.map(tokenize_function, batched=True)
tokenized_eval = eval_dataset.map(tokenize_function, batched=True)
# ... and so on.

This simple setup ensures that the input_ids we generate are valid for both our teacher and student models.
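The label-preparation step is elided in the snippet above. For the Trainer to compute the cross-entropy loss, each example needs an integer `labels` field. Here is a minimal sketch, assuming (hypothetically) that the raw dataset stores the specialty name in a string column called `specialty`; the function is written for `Dataset.map(..., batched=True)`, which passes in a dict of column lists.

```python
label_list = ['Dermatology', 'Gastroenterology', 'Endocrinology', 'Oncology', 'Pulmonology']
label2id = {label: idx for idx, label in enumerate(label_list)}

def encode_labels(batch):
    """Map the (hypothetical) 'specialty' string column to integer class ids in 'labels'."""
    return {"labels": [label2id[name] for name in batch["specialty"]]}

# Direct call on a toy batch, mimicking what Dataset.map(batched=True) would pass in.
toy_batch = {
    "text": ["itchy rash on forearm", "persistent cough"],
    "specialty": ["Dermatology", "Pulmonology"],
}
print(encode_labels(toy_batch))  # → {'labels': [0, 4]}
```

Applied to the real data, this would look like `med_dataset = med_dataset.map(encode_labels, batched=True)`, which adds a `labels` column next to the existing ones.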

Step 2: Training the Expert Teacher

Before we could start distilling, we needed our expert. We fine-tuned bert-base-uncased on our medical notes dataset using the standard Hugging Face Trainer. This process establishes our performance ceiling: the "gold standard" we want the student to emulate.

from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

# --- 1. Load the Teacher Model ---
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id
)

# --- 2. Define Standard Training Arguments ---
teacher_training_args = TrainingArguments(
    output_dir="models/teacher_bert_med_notes",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    eval_strategy="epoch",  # named evaluation_strategy in transformers < 4.41
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    report_to="none"
)

# --- 3. Train the Model ---
teacher_trainer = Trainer(
    model=teacher_model,
    args=teacher_training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
    # ... compute_metrics function ...
)

teacher_trainer.train()
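The `compute_metrics` function is elided above, but `metric_for_best_model="f1"` only works if it returns a key named `"f1"` (the Trainer reports it as `eval_f1`). Since the results below are reported as macro F1, here is a minimal NumPy-only sketch of such a function; it assumes the Trainer passes in a `(logits, label_ids)` pair. It should agree with scikit-learn's `f1_score(..., average="macro")` whenever every class appears among the labels or predictions.

```python
import numpy as np

def macro_f1(labels: np.ndarray, preds: np.ndarray, num_classes: int) -> float:
    """Unweighted mean of per-class F1 (a class with no support and no predictions scores 0)."""
    scores = []
    for c in range(num_classes):
        tp = np.sum((preds == c) & (labels == c))
        fp = np.sum((preds == c) & (labels != c))
        fn = np.sum((preds != c) & (labels == c))
        denom = 2 * tp + fp + fn
        scores.append(0.0 if denom == 0 else 2 * tp / denom)
    return float(np.mean(scores))

def compute_metrics(eval_pred):
    """Trainer hook: eval_pred unpacks to (logits, label_ids)."""
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {"f1": macro_f1(labels, preds, num_classes=logits.shape[-1])}
```

It would be passed to both trainers as `compute_metrics=compute_metrics`.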

After three epochs, our teacher model achieved a macro F1-score of 0.9056 on the test set. This became our benchmark.

Step 3: The Knowledge Transfer

This is where the core distillation logic resides. We created a custom DistillationTrainer that overrides the default loss calculation.

The compute_loss method is the heart of this process. It computes three quantities and combines the two losses into a single alpha-weighted training loss:

  1. student_loss: The standard cross-entropy loss against the "hard" ground-truth labels.

  2. teacher_outputs: The logits of our expert teacher, which a temperature-scaled softmax turns into a "soft" probability distribution.

  3. distillation_loss: The Kullback-Leibler (KL) divergence between the teacher's and the student's temperature-softened distributions, which measures how closely the student's soft predictions match the teacher's.
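To see how these pieces combine, here is a pure-Python walk-through of the same arithmetic for a single example: cross-entropy against the hard label at T = 1, KL divergence between temperature-softened distributions scaled by T², and the alpha-weighted sum. The logits are made-up numbers; alpha=0.3 and T=3.0 match the values used for training later in this post. With `reduction='batchmean'`, PyTorch averages the summed KL over the batch, so for one example it is just this value.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax (max subtracted for numerical stability)."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_objective(student_logits, teacher_logits, label, alpha=0.3, temperature=3.0):
    # Hard-label cross-entropy at T = 1 (what student_outputs.loss is per example).
    student_loss = -math.log(softmax(student_logits)[label])
    # KL(teacher || student) on softened distributions, scaled by T^2.
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(p * math.log(p / q) for p, q in zip(p_teacher, p_student))
    distillation_loss = kl * temperature ** 2
    # Same weighting as the trainer: alpha for the labels, (1 - alpha) for the teacher.
    loss = alpha * student_loss + (1.0 - alpha) * distillation_loss
    return student_loss, distillation_loss, loss

# Made-up logits for one note whose true class is index 2 (Endocrinology).
teacher_logits = [1.0, 2.5, 4.0, 0.5, -1.0]
student_logits = [0.5, 0.5, 2.0, 1.0, 0.0]
print(distillation_objective(student_logits, teacher_logits, label=2))
```

The T² factor follows the usual convention from Hinton et al.: softening by T shrinks the gradients of the KL term by roughly 1/T², so multiplying by T² keeps the two losses on a comparable scale.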

Here’s the implementation:

import torch
import torch.nn.functional as F

class DistillationTrainer(Trainer):
    def __init__(self, *args, teacher_model=None, alpha=0.5, temperature=2.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model
        # Hyperparameters to balance the two losses
        self.alpha = alpha
        self.temperature = temperature
        if self.teacher_model is not None:
            self.teacher_model.to(self.args.device)
            self.teacher_model.eval()

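    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # num_items_in_batch is passed by newer transformers releases; it is unused here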
        # 1. Get student's output and standard loss (the "textbook" part)
        student_outputs = model(**inputs)
        student_loss = student_outputs.loss

        # 2. Get teacher's soft predictions (the "professor's lecture" part)
        with torch.no_grad():
            teacher_outputs = self.teacher_model(**inputs)

        # 3. Calculate distillation loss using KL Divergence
        # Temperature "softens" the probabilities to emphasize relational knowledge
        distillation_loss = F.kl_div(
            F.log_softmax(student_outputs.logits / self.temperature, dim=-1),
            F.softmax(teacher_outputs.logits / self.temperature, dim=-1),
            reduction='batchmean'
        ) * (self.temperature ** 2)

        # 4. Combine the two losses.
        # `alpha` controls the balance between the textbook and the professor.
        loss = self.alpha * student_loss + (1.0 - self.alpha) * distillation_loss
        return (loss, student_outputs) if return_outputs else loss

We then initialized this trainer, feeding it both our tiny bert-micro student and our expert bert-base teacher.

# --- Load the tiny student model ---
student_model = AutoModelForSequenceClassification.from_pretrained(
    "boltuix/bert-micro",
    num_labels=len(label_list),
    # ... id2label, label2id ...
)

# --- Initialize our custom trainer ---
distillation_trainer = DistillationTrainer(
    model=student_model,
    teacher_model=teacher_model,
    args=student_training_args,  # TrainingArguments for the student (definition not shown)
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
    # ... compute_metrics ...
    alpha=0.3,       # 30% textbook, 70% professor's lecture
    temperature=3.0  # Soften probabilities to extract more knowledge
)

distillation_trainer.train()

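Because bert-micro has so little capacity, we relied heavily on the teacher's guidance: alpha=0.3 puts 70% of the loss weight on the distillation term, and temperature=3.0 softens the teacher's probabilities so that its view of the "wrong but plausible" classes carries more signal.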
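The effect of the temperature is easy to check numerically. This pure-Python sketch uses made-up teacher logits, prints the teacher's distribution at T=1 and T=3, and measures its entropy: a higher temperature flattens the distribution, so the small probabilities on the non-target classes grow large enough to be informative.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax (max subtracted for numerical stability)."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    """Shannon entropy in nats; higher means a flatter distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)

teacher_logits = [1.0, 2.5, 4.0, 0.5, -1.0]  # hypothetical teacher logits for one note
for t in (1.0, 3.0):
    probs = softmax(teacher_logits, t)
    print(f"T={t}: probs={[round(p, 3) for p in probs]}, entropy={entropy(probs):.3f}")
```

The ranking of the classes never changes with temperature; only how sharply the probability mass is concentrated on the top class does.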

The Results

The final comparison revealed a highly practical outcome.

| Model | Parameters | Macro F1-Score | Size Reduction | Performance Retained | Speed |
|---|---|---|---|---|---|
| Teacher (bert-base-uncased) | 109.49 M | 0.9056 | - | 100% | 1x |
| Student (boltuix/bert-micro) | 4.39 M | 0.8065 | 95.99% | 89.06% | 46.38x |
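The derived columns follow directly from the raw numbers in the table. A quick arithmetic check (the speed column depends on a latency measurement that is not reproduced here, so it is left out):

```python
teacher_params, student_params = 109.49e6, 4.39e6
teacher_f1, student_f1 = 0.9056, 0.8065

size_reduction = 1 - student_params / teacher_params  # fraction of parameters removed
compression_ratio = teacher_params / student_params   # "N times smaller"
retained = student_f1 / teacher_f1                    # share of the teacher's macro F1 kept
absolute_drop = teacher_f1 - student_f1               # F1 points lost

print(f"size reduction:    {size_reduction:.2%}")      # → 95.99%
print(f"compression ratio: {compression_ratio:.1f}x")  # → 24.9x
print(f"performance kept:  {retained:.2%}")            # → 89.06%
print(f"absolute F1 drop:  {absolute_drop:.4f}")       # → 0.0991
```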

96% Smaller Model

This is a major win for efficiency: the student model has about 25 times fewer parameters than the teacher (109.49M / 4.39M ≈ 24.9). This transforms the application from a server-heavy service into something that could potentially run in a web browser, on a smartphone, or on a low-cost IoT device.

The potential operational cost savings and latency improvements are substantial.
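To put "could run on a smartphone" in more concrete terms, here is a back-of-the-envelope estimate of the weight storage alone: each parameter takes 4 bytes as a 32-bit float and 2 bytes as a 16-bit float. This ignores activations, the tokenizer, and runtime overhead, so treat it as a lower bound rather than a measured footprint.

```python
def weight_megabytes(num_params: float, bytes_per_param: int) -> float:
    """Approximate size of the raw weights in megabytes (1 MB = 10**6 bytes)."""
    return num_params * bytes_per_param / 1e6

for name, params in [("teacher", 109.49e6), ("student", 4.39e6)]:
    fp32 = weight_megabytes(params, 4)
    fp16 = weight_megabytes(params, 2)
    print(f"{name}: ~{fp32:.0f} MB in fp32, ~{fp16:.0f} MB in fp16")
```

Roughly 438 MB versus under 20 MB in fp32 is the kind of gap that decides whether a model can ship inside a mobile app or a browser bundle.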

A ~10-Point Drop in Performance

As expected, this massive compression wasn't free. The student's macro F1-score of 0.8065 is a noticeable step down from the teacher's 0.9056 (about 10 points). For a life-critical medical application, this might not be acceptable.

But for this use case (a first-pass triage tool), this level of performance could well be sufficient.

Retaining 89% of the Knowledge

This is the true success story. Despite having only 4% of the parameters, the student model retained nearly 90% of the teacher's task-specific performance (0.8065 / 0.9056 ≈ 89%).

A bert-micro model trained on its own, without a teacher, would likely have struggled to reach this level of performance. The teacher's guidance was essential.

~46x speedup

This is arguably the most critical business outcome. A ~46x speedup can be the difference between a real-time interactive tool and a slow batch-processing system. It unlocks new product possibilities and dramatically improves the user experience.
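The post does not show how the speedup was measured. If you want to reproduce a comparison like this, one simple approach is to time many calls of each model on the same inputs and compare median latencies. The sketch below is a generic, framework-agnostic helper; `median_latency` and `speedup` are hypothetical names, and the workloads are stand-ins rather than real models. Warm-up calls are excluded because the first calls often include one-time setup costs, and on a GPU you would also need to synchronize the device before reading the clock.

```python
import statistics
import time
from typing import Callable

def median_latency(fn: Callable[[], object], warmup: int = 3, runs: int = 20) -> float:
    """Median wall-clock seconds per call of fn, after a few untimed warm-up calls."""
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)

def speedup(baseline_fn: Callable[[], object], candidate_fn: Callable[[], object], **kwargs) -> float:
    """How many times faster candidate_fn is than baseline_fn (ratio of median latencies)."""
    return median_latency(baseline_fn, **kwargs) / median_latency(candidate_fn, **kwargs)

# Stand-in workloads; with real models these might wrap a forward pass on a fixed batch.
slow = lambda: sum(i * i for i in range(200_000))
fast = lambda: sum(i * i for i in range(2_000))
print(f"speedup: {speedup(slow, fast):.1f}x")
```

Using the median rather than the mean keeps a few slow outliers (garbage collection, OS scheduling) from distorting the ratio.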

Final Thoughts

Knowledge distillation is a vital tool for the ML engineer. It allows us to bridge the gap between state-of-the-art research models and practical, deployable applications.

In future articles, I will share other KD approaches that I plan to apply to this use case. I intend to test Internal Representation Matching, a powerful technique when we need to squeeze out an extra bit of performance.

I also plan to try a Teacher Assistant approach, which could be the most valuable in our case. It is used when the capacity gap between the teacher and the student is enormous (in our case, 109M -> 4.4M).

You can check out the following repo for the full implementation: NijatZeynalov/bert-distillation (BERT Knowledge Distillation for Medical Note Classification).

Anyway, see you!
