✨ Transfer Learning with BERT for Text Classification

πŸ” Introduction

In this blog, we’ll walk through a full pipeline for applying transfer learning using BERT (Bidirectional Encoder Representations from Transformers) to a text classification task. We'll use the Hugging Face Transformers library along with PyTorch to build, train, and evaluate the model.


Architecture

Dive into the Code

📦 Step 1: Installing Dependencies

!pip install transformers

We install the Hugging Face transformers library, which provides pretrained BERT models and tokenizers. PyTorch, pandas, NumPy, and scikit-learn are assumed to be available already (as they are on Colab); install them as well if your environment is missing them.


🧠 Step 2: Importing Libraries

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import transformers
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

This includes:

  • Pandas & NumPy for data handling

  • PyTorch for modeling

  • Transformers for BERT

  • scikit-learn for metrics and data splitting


💻 Step 3: Setting Up the Device

device = "cuda" if torch.cuda.is_available() else "cpu"

This runs the model on a GPU if one is available and falls back to the CPU otherwise.


πŸ“ Step 4: Loading the Dataset

df = pd.read_csv('data.csv')
df.head()

You’re expected to have a data.csv file containing:

  • a text column (or similarly named column) containing the input text

  • a label column containing the classification target (0/1 for this binary example)
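If you don't have a dataset handy, here is a minimal sketch of the layout the notebook assumes (the two example rows and the 0/1 labels are made up purely for illustration):

import pandas as pd

# Tiny illustrative file in the expected format: one text column, one label column.
sample = pd.DataFrame({
    'text': [
        "The battery lasts all day, very happy with it.",
        "Stopped working after a week, would not recommend."
    ],
    'label': [1, 0]
})
sample.to_csv('data.csv', index=False)  # produces a data.csv the rest of the pipeline can read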


βœ‚οΈ Step 5: Train-Test Split

train_texts, val_texts, train_labels, val_labels = train_test_split(df['text'], df['label'], test_size=0.2)

We split the data into 80% training and 20% validation. Passing random_state (and stratify=df['label']) to train_test_split makes the split reproducible and keeps the class balance similar in both splits.


✨ Step 6: Tokenizing with BERT

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
tokens_train = tokenizer(list(train_texts), padding=True, truncation=True, return_tensors="pt")
tokens_val = tokenizer(list(val_texts), padding=True, truncation=True, return_tensors="pt")

The tokenizer:

  • Converts text into token IDs (input_ids)

  • Pads sequences to equal length

  • Truncates long text

  • Returns tensors directly usable by PyTorch
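To see exactly what the tokenizer produces, you can run it on a single sentence and inspect the pieces (a throwaway check, not part of the training pipeline; the example sentence is made up):

example = tokenizer("BERT makes transfer learning easy.", return_tensors="pt")
print(example['input_ids'])        # token IDs, with [CLS] at the start and [SEP] at the end
print(example['attention_mask'])   # 1 for real tokens, 0 for padding (none here, single sentence)
print(tokenizer.convert_ids_to_tokens(example['input_ids'][0].tolist()))  # human-readable sub-word tokens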


🧱 Step 7: Creating Attention Masks and Labels

train_seq = tokens_train['input_ids']
train_mask = tokens_train['attention_mask']
train_y = torch.tensor(train_labels.tolist())

val_seq = tokens_val['input_ids']
val_mask = tokens_val['attention_mask']
val_y = torch.tensor(val_labels.tolist())

  • input_ids: the tokenized text (already PyTorch tensors, since we passed return_tensors="pt")

  • attention_mask: tells BERT which tokens are real (1) and which are padding (0)

  • train_y and val_y: tensors of the training and validation labels


📦 Step 8: Dataset & DataLoader

Custom Dataset class:

class CustomDataset(Dataset):
    def __init__(self, input_ids, attention_mask, labels):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_mask[idx],
            'labels': self.labels[idx]
        }

Dataloaders:

train_data = CustomDataset(train_seq, train_mask, train_y)
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)

val_data = CustomDataset(val_seq, val_mask, val_y)
val_loader = DataLoader(val_data, batch_size=16)  # used for evaluation in Step 12

🧠 Step 9: Defining the BERT-based Classifier

class BERTClassifier(nn.Module):
    def __init__(self):
        super(BERTClassifier, self).__init__()
        self.bert = AutoModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        _, pooled_output = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=False)
        x = self.dropout(pooled_output)
        return self.fc(x)

  • Uses pretrained BERT

  • Adds dropout for regularization

  • Final fully connected layer for binary classification


βš™οΈ Step 10: Compiling the Model

model = BERTClassifier().to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

  • Loss function: binary cross-entropy with logits (the model outputs raw logits; the sigmoid is applied inside the loss)

  • Optimizer: Adam with a small learning rate (2e-5), the usual range for fine-tuning BERT
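A quick way to convince yourself that the logits-based loss is doing the right thing is to compare it against an explicit sigmoid followed by plain binary cross-entropy; the numbers below are arbitrary and purely illustrative:

logits = torch.tensor([0.8, -1.2])   # raw model outputs (made-up values)
targets = torch.tensor([1.0, 0.0])   # ground-truth labels as floats

loss_with_logits = nn.BCEWithLogitsLoss()(logits, targets)
loss_manual = nn.BCELoss()(torch.sigmoid(logits), targets)
print(loss_with_logits.item(), loss_manual.item())  # the two values match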


πŸ” Step 11: Training the Model

epochs = 3  # number of passes over the training data; adjust to your dataset

for epoch in range(epochs):
    model.train()
    for batch in train_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].float().to(device)

        outputs = model(input_ids, attention_mask).squeeze()
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Training loop:

  • Gets batches of data

  • Computes forward pass

  • Calculates loss

  • Backpropagates

  • Updates weights
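If you want to monitor training, a slightly extended version of the same loop (a sketch, not part of the original notebook; the gradient-clipping threshold of 1.0 is a common but arbitrary choice) accumulates the loss per epoch and clips gradients, which helps keep BERT fine-tuning stable:

for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    for batch in train_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].float().to(device)

        loss = criterion(model(input_ids, attention_mask).squeeze(), labels)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # bound the gradient norm
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch + 1}/{epochs} - average training loss: {total_loss / len(train_loader):.4f}")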


📊 Step 12: Evaluating the Model

model.eval()
with torch.no_grad():
    predictions, true_labels = [], []
    for batch in val_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids, attention_mask).squeeze()
        preds = torch.round(torch.sigmoid(outputs))

        predictions.extend(preds.cpu().numpy())
        true_labels.extend(labels.cpu().numpy())

print(classification_report(true_labels, predictions))

  • Converts logits to probabilities using sigmoid

  • Rounds to 0 or 1

  • Uses classification_report to show precision, recall, F1-score, and accuracy
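Once training has finished, the same tokenizer and model can score a brand-new sentence. This is a minimal sketch (the helper name predict, the example sentence, and the 0.5 decision threshold are all assumptions, not part of the original notebook):

def predict(text):
    model.eval()
    enc = tokenizer(text, truncation=True, return_tensors="pt")
    with torch.no_grad():
        logit = model(enc['input_ids'].to(device), enc['attention_mask'].to(device))
        prob = torch.sigmoid(logit).item()   # probability of the positive class
    return int(prob >= 0.5), prob            # (predicted label, probability)

print(predict("This product exceeded my expectations!"))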


✅ Conclusion

This walkthrough gives you a solid blueprint for applying transfer learning with BERT to a binary text classification problem, and the same structure carries over to multi-class tasks.

🧪 What You Learned:

  • Tokenizing and preparing data for BERT

  • Using attention masks

  • Building a custom BERT classifier

  • Training and evaluating using PyTorch


🧰 What Next?

  • Try multi-class classification with nn.CrossEntropyLoss (it applies softmax internally, so the classifier head can keep returning raw logits): a sketch follows this list

  • Experiment with learning rate schedulers

  • Explore model checkpointing and early stopping
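Here is a minimal sketch of the multi-class variant mentioned above. Everything except the changes noted in comments mirrors the binary classifier; NUM_CLASSES and the class name BERTMultiClassClassifier are placeholders to adapt to your data:

NUM_CLASSES = 3  # hypothetical number of classes; set this to match your label column

class BERTMultiClassClassifier(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()
        self.bert = AutoModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)  # one logit per class

    def forward(self, input_ids, attention_mask):
        _, pooled_output = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=False)
        return self.fc(self.dropout(pooled_output))

criterion = nn.CrossEntropyLoss()  # expects raw logits (batch, num_classes) and integer labels (batch,)
# In the training loop, keep labels as integers: labels = batch['labels'].long().to(device)
# At evaluation time, take the highest-scoring class instead of rounding a sigmoid:
# preds = torch.argmax(outputs, dim=1)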

Dream.Achieve.Repeat
