Making Pre-Trained Models More Robust

Kiarash Kiani

Introduction

In the ever-evolving landscape of machine learning, robustness is essential for building models that perform reliably in real-world scenarios. As models are deployed in diverse environments, they face many challenges, including noisy data, adversarial attacks, and unexpected input variations. In this blog post, I will walk you through the concept of robustification, focusing on how to help pre-trained models withstand these challenges, with code along the way. But first, let me give you some context.

What is Robustification?

Robustification refers to the process of making machine learning models more resilient to noise or attacks in the input data. It involves techniques designed to improve the model's ability to generalize well across various conditions, ensuring that the performance does not degrade significantly when exposed to unexpected or noisy inputs.

Why is it Important?

Robustification is crucial for several reasons:

  1. Real-World Application: Models deployed in the real world often encounter data that is noisy or differs from the training data. Robust models maintain their performance despite these discrepancies.

  2. Security: Robust models are less susceptible to adversarial attacks where malicious inputs are designed to deceive the model.

  3. Reliability: Ensuring consistent performance across various scenarios builds trust in AI systems, which is essential for their adoption in critical applications such as healthcare and autonomous driving.

See it in a Simple Example

Let's start with a simple example: a feed-forward neural network that classifies movie reviews as positive (1) or negative (0). We'll add a noisy copy of every review to the dataset, a simple form of data augmentation that exposes the model to corrupted inputs while it learns.

import random
import string
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import CountVectorizer
from torch.utils.data import DataLoader, Dataset

# Sample data
texts = ["This is a good movie", "I hated this movie", "It was an awesome film", "Terrible film", "Great movie!", "I will not watch this again", "Fantastic film"]
labels = [1, 0, 1, 0, 1, 0, 1]

# Adding noise to the dataset by replacing about 10% of the characters with random letters
def add_noise(text):
    noisy_text = list(text)
    for _ in range(int(0.1 * len(text))):
        idx = random.randint(0, len(text) - 1)  # random position; the same one may be hit twice
        noisy_text[idx] = random.choice(string.ascii_letters)
    return ''.join(noisy_text)

noisy_texts = [add_noise(text) for text in texts]

# Vectorizing the text data
vectorizer = CountVectorizer(binary=True)
X = vectorizer.fit_transform(texts + noisy_texts).toarray()
y = labels + labels  # Duplicating labels for noisy data

# Train-test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Custom Dataset class
class TextDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels

    def __len__(self):
        return self.texts.shape[0]

    def __getitem__(self, idx):
        return torch.tensor(self.texts[idx], dtype=torch.float32), torch.tensor(self.labels[idx], dtype=torch.long)

train_dataset = TextDataset(X_train, y_train)
test_dataset = TextDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

# Simple feed-forward neural network
class SimpleNN(nn.Module):
    def __init__(self, input_size, num_classes):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size, 50)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(50, num_classes)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

# Model initialization
input_size = X_train.shape[1]
num_classes = 2
model = SimpleNN(input_size, num_classes)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training function
def train_model(model, train_loader, criterion, optimizer, num_epochs=5):
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for texts, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(texts)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_loader):.4f}')

# Train the model
train_model(model, train_loader, criterion, optimizer)
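The script above trains the model but never measures how well it does, so test_loader goes unused. Here is a minimal sketch of an accuracy check, assuming the same (features, label) batches that TextDataset produces; evaluate_accuracy is a hypothetical helper name, not part of the original code, and the toy linear model only exists to exercise it.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate_accuracy(model: nn.Module, loader: DataLoader) -> float:
    """Return the fraction of examples whose argmax prediction matches the label."""
    model.eval()  # switch off training-only behaviour such as dropout
    correct, total = 0, 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for inputs, labels in loader:
            preds = model(inputs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total if total else 0.0


# Toy check: a linear layer with identity weights echoes its input,
# so the predicted class is the index of the larger feature.
toy_model = nn.Linear(2, 2)
with torch.no_grad():
    toy_model.weight.copy_(torch.eye(2))
    toy_model.bias.zero_()

features = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])
targets = torch.tensor([0, 1, 1])
toy_loader = DataLoader(TensorDataset(features, targets), batch_size=2)
print(evaluate_accuracy(toy_model, toy_loader))  # 2 of 3 predictions are correct
```

In the movie-review example, you would call evaluate_accuracy(model, test_loader) after train_model(...). With only a handful of test sentences the number is very noisy, so treat it as a smoke test rather than a benchmark.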

Using Robustification on an LLM

Large language models (LLMs) have been all over the news in recent years. Especially since the release of ChatGPT, everyone has been trying to bring these models into their systems to automate more manual tasks. However, real-world data is messy and rarely works in our favor. We clean, preprocess, and store vast amounts of data, but some noise always remains that we cannot easily get rid of. In text data, noise usually shows up as typos. Typos happen for many reasons, and you cannot expect your users to always write without mistakes. So, in this example, I will walk you through how to make your language model robust against typos.
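The add_noise function in the first example replaces random characters with random letters, which is only a rough stand-in for real typos. Below is a sketch of a hypothetical alternative, add_typos, that mixes a few common typo operations (swapping neighbouring characters, deleting, duplicating, substituting) and takes its own seeded random.Random so the corrupted texts are reproducible. The operation mix and the 10% rate are illustrative assumptions, not tuned values.

```python
import random
import string
from typing import Optional


def add_typos(text: str, rate: float = 0.1, rng: Optional[random.Random] = None) -> str:
    """Apply about `rate * len(text)` typo-like edits to `text`.

    Each edit swaps two neighbouring characters, deletes one, duplicates one,
    or replaces one with a random lowercase letter.
    """
    if rng is None:
        rng = random.Random()
    chars = list(text)
    for _ in range(int(rate * len(text))):
        if len(chars) < 2:  # too short to edit safely
            break
        i = rng.randrange(len(chars) - 1)  # both i and its right neighbour i + 1 exist
        op = rng.choice(("swap", "delete", "duplicate", "substitute"))
        if op == "swap":
            chars[i], chars[i + 1] = chars[i + 1], chars[i]  # "the" -> "teh"
        elif op == "delete":
            del chars[i]  # "the" -> "te"
        elif op == "duplicate":
            chars.insert(i, chars[i])  # "the" -> "tthe"
        else:
            chars[i] = rng.choice(string.ascii_lowercase)  # "the" -> "tge"
    return "".join(chars)


# A dedicated, seeded generator makes the corrupted texts identical across runs.
rng = random.Random(42)
print(add_typos("The pizza was amazing and the staff were friendly", rng=rng))
```

Because each operation changes the length by at most one character, the output length stays within the number of edits of the original length, which keeps tokenized sequences close in size.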

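We will start from BERT, an encoder pre-trained with masked language modeling (MLM): a self-supervised task in which parts of the input text are masked and the model is trained to predict the masked tokens. This pre-training gives the model contextual representations of text, which we will then fine-tune on clean and noisy reviews.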
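To make the masking idea concrete, here is a minimal sketch of BERT-style masking in plain PyTorch: about 15% of positions are selected as targets; 80% of those are replaced by a mask token, 10% by a random token, and 10% are left unchanged. Unselected positions get the label -100, which nn.CrossEntropyLoss ignores by default. The token ids and mask_token_id below are made up for illustration; with Hugging Face, DataCollatorForLanguageModeling handles this for you. Note that the code in this post loads an encoder that was already pre-trained this way and does not run MLM itself.

```python
from typing import Optional, Tuple

import torch


def mask_tokens(input_ids: torch.Tensor, mask_token_id: int, vocab_size: int,
                mlm_probability: float = 0.15,
                generator: Optional[torch.Generator] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """BERT-style masking. Returns (masked_input_ids, labels)."""
    labels = input_ids.clone()
    masked = input_ids.clone()

    # Select about mlm_probability of the positions as prediction targets.
    selected = torch.bernoulli(torch.full(input_ids.shape, mlm_probability), generator=generator).bool()
    labels[~selected] = -100  # default ignore_index of nn.CrossEntropyLoss

    # 80% of the selected positions are replaced by the mask token.
    to_mask = torch.bernoulli(torch.full(input_ids.shape, 0.8), generator=generator).bool() & selected
    masked[to_mask] = mask_token_id

    # Half of the remaining selected positions (10% overall) get a random token.
    to_random = torch.bernoulli(torch.full(input_ids.shape, 0.5), generator=generator).bool() & selected & ~to_mask
    random_tokens = torch.randint(vocab_size, input_ids.shape, generator=generator)
    masked[to_random] = random_tokens[to_random]

    # The last ~10% of selected positions keep their original token.
    return masked, labels


# Toy batch: 2 sequences of 12 made-up token ids (ids below 5 treated as special here).
g = torch.Generator().manual_seed(0)
ids = torch.randint(5, 100, (2, 12), generator=g)
masked_ids, labels = mask_tokens(ids, mask_token_id=4, vocab_size=100, generator=g)
print(masked_ids)
print(labels)
```

The model is then trained to predict labels at the non-ignored positions from masked_ids, which forces it to use the surrounding context.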

Robustification in LLMs involves advanced techniques such as Robust Contrastive Pre-training. This method leverages contrastive learning to enhance the model's ability to handle noisy and perturbed inputs.

1. Dataset

Let's begin by loading the data we will use to fine-tune our pre-trained model. I will use the yelp_review_full dataset from the Hugging Face Hub. It contains Yelp reviews in two columns: text (the review itself) and label (the star rating, stored as an integer from 0 to 4 for one to five stars).

The load_dataset function from the datasets library fetches and loads the data. However, I would also like to save part of the data to a CSV file and load it from there, so I can test the structure of my application without fetching the data every time. Furthermore, I need a new column, corrupted_text, that holds the text column with random character-level noise: about 10% of the characters are replaced with random letters.

import torch
from torch.utils.data import Dataset
import pandas as pd
from datasets import load_dataset
import random
import string


class SentenceDataset(Dataset):
    """Yields (original_text, corrupted_text, label) triples for training."""

    @staticmethod
    def from_huggingface() -> 'SentenceDataset':
        # Full Yelp training split from the Hugging Face Hub (a large download)
        data = load_dataset("yelp_review_full", split='train')
        return SentenceDataset(pd.DataFrame(data))

    @staticmethod
    def from_csv(path) -> 'SentenceDataset':
        # Locally saved subset with 'text' and 'label' columns
        return SentenceDataset(pd.read_csv(path))

    @staticmethod
    def __add_noise(text):
        # Replace about 10% of the characters with random letters (positions may repeat)
        noisy_text = list(text)
        for _ in range(int(0.1 * len(text))):
            idx = random.randint(0, len(text) - 1)
            noisy_text[idx] = random.choice(string.ascii_letters)
        return ''.join(noisy_text)

    def __init__(self, dataframe):
        self.dataframe = dataframe
        # Noise is sampled once here, so each review keeps the same corrupted copy in every epoch
        self.dataframe['corrupted_text'] = self.dataframe['text'].apply(SentenceDataset.__add_noise)
        # Ratings as floats, since the model regresses them with an MSE loss
        self.dataframe['label'] = self.dataframe['label'].astype(float)

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        return row['text'], row['corrupted_text'], row['label']
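SentenceDataset.from_csv expects a CSV with text and label columns. A minimal sketch of producing such a file from a larger DataFrame (for example, one built from load_dataset("yelp_review_full", split="train")), assuming a reproducible random sample is good enough for quick tests; save_sample and its default sample size are hypothetical choices, not part of the original code.

```python
import tempfile
from pathlib import Path

import pandas as pd


def save_sample(df: pd.DataFrame, path: str, n: int = 1000, seed: int = 42) -> pd.DataFrame:
    """Write a reproducible random sample of df's 'text' and 'label' columns to a CSV file."""
    sample = df[["text", "label"]].sample(n=min(n, len(df)), random_state=seed)
    Path(path).parent.mkdir(parents=True, exist_ok=True)  # e.g. create the data/ folder
    sample.to_csv(path, index=False)  # drop the pandas index so only the two columns are saved
    return sample


# Toy stand-in for the Yelp DataFrame, written to a temporary folder for the demo.
toy = pd.DataFrame({
    "text": ["Great food!", "Slow service.", "Okay place.", "Would come back."],
    "label": [4, 0, 2, 3],
})
with tempfile.TemporaryDirectory() as tmp:
    csv_path = str(Path(tmp) / "data" / "review.csv")
    save_sample(toy, csv_path, n=2)
    print(pd.read_csv(csv_path).shape)  # (2, 2)
```

Calling save_sample(df, "data/review.csv") on the full Yelp DataFrame would create the file that DATASET_FILE_PATH points to in the complete script; fixing random_state keeps the subset the same between runs.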

2. Model Definition

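Next, we define our model, which extends a pre-trained BERT model with a regression head for predicting ratings. The forward pass averages the encoder's token embeddings into a single vector per review and returns both that vector (used by the contrastive loss) and the predicted rating.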

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

class RobustContrastivePretrainingModel(nn.Module):
    def __init__(self, model_name='bert-base-multilingual-cased'):
        super(RobustContrastivePretrainingModel, self).__init__()
        self.encoder = AutoModel.from_pretrained(model_name)  # pre-trained BERT encoder
        self.regressor = nn.Linear(self.encoder.config.hidden_size, 1)  # rating head

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids, attention_mask=attention_mask)
        # Mean-pool over real tokens only, so padding does not dilute the review embedding
        mask = attention_mask.unsqueeze(-1).type_as(outputs.last_hidden_state)
        pooled_output = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        rating = self.regressor(pooled_output).squeeze(-1)
        return pooled_output, rating
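Because the collate function pads every batch to its longest review, averaging last_hidden_state over every position would also average the padding tokens, so a review's embedding could change depending on what it is batched with. A minimal sketch on toy tensors (no BERT needed) comparing a plain mean with a mean weighted by attention_mask; masked_mean is a hypothetical helper name.

```python
import torch


def masked_mean(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token vectors over real (mask == 1) positions only.

    hidden: (batch, seq_len, dim); attention_mask: (batch, seq_len) of 0/1.
    """
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)  # (batch, seq_len, 1)
    summed = (hidden * mask).sum(dim=1)                    # (batch, dim)
    counts = mask.sum(dim=1).clamp(min=1e-9)               # avoid division by zero
    return summed / counts


# One sequence with two real tokens followed by one padding position.
hidden = torch.tensor([[[1.0, 1.0], [3.0, 3.0], [100.0, -100.0]]])
mask = torch.tensor([[1, 1, 0]])

print(hidden.mean(dim=1))         # about [[34.67, -32.00]]: the padding vector leaks in
print(masked_mean(hidden, mask))  # [[2.0, 2.0]]: only the two real tokens count
```

The clamp only matters for an all-padding row, which a tokenizer does not normally produce, but it keeps the division safe.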

3. Loss Function

So far, we have built a wrapper for loading the data and a model based on BERT, and we have added noise to our dataset. In the previous example, we trained on a mix of clean and noisy texts, a simple form of data augmentation that can help a model generalize to noisy input, but we want to push our model to be more robust. To do that, we will extend the loss function with a contrastive loss term.

The contrastive loss is used in contrastive learning, a technique where the model learns to distinguish between similar and dissimilar pairs of inputs. It encourages the model to map inputs that should mean the same thing to nearby representations and unrelated inputs to distant ones, even in the presence of noise. In our setup, each original review is the anchor, its noisy copy is the positive, and the other reviews in the batch serve as negatives: the loss pulls every review and its noisy copy together and pushes it away from the other reviews. Note that the negatives are picked without looking at the labels, so a negative can have the same rating as the anchor.

The contrastive loss only shapes the encoder's representations; it does not involve the regression head at all. We still need a loss for the head that predicts the rating. For that, we use the mean squared error (MSE) between the predicted and true ratings, computed for both the original and the noisy texts, and add both terms to the contrastive loss to get the total loss.
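Written out, the objective implemented by the code that follows is shown below. Here B is the batch size, a_i and p_i are the pooled embeddings of review i and of its noisy copy, N_i is the set of negatives for review i (in the training loop, the next review in the batch), s is cosine similarity, τ is the temperature (0.05 by default), ŷ_i and ŷ'_i are the ratings predicted from the clean and noisy texts, and y_i is the true label.

```latex
\begin{aligned}
\mathcal{L}_{\text{con}} &= \frac{1}{B}\sum_{i=1}^{B} -\log \frac{\exp\left(s(a_i, p_i)/\tau\right)}{\exp\left(s(a_i, p_i)/\tau\right) + \sum_{n \in N_i} \exp\left(s(a_i, n)/\tau\right)} \\
\mathcal{L}_{\text{total}} &= \mathcal{L}_{\text{con}} + \frac{1}{B}\sum_{i=1}^{B} \left(\hat{y}_i - y_i\right)^2 + \frac{1}{B}\sum_{i=1}^{B} \left(\hat{y}'_i - y_i\right)^2
\end{aligned}
```

The first term only sees the pooled embeddings, while the two MSE terms reach both the regression head and the encoder.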

import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.05):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, anchor, positive, negatives):
        # anchor, positive: (batch, hidden); negatives: list of (batch, hidden) tensors
        pos_similarity = F.cosine_similarity(anchor, positive) / self.temperature
        neg_similarities = [F.cosine_similarity(anchor, neg) / self.temperature for neg in negatives]
        pos_exp = torch.exp(pos_similarity)
        neg_exp_sum = torch.sum(torch.exp(torch.stack(neg_similarities)), dim=0)
        # InfoNCE-style loss: -log(exp(pos) / (exp(pos) + sum(exp(neg)))), averaged over the batch
        loss = -torch.log(pos_exp / (pos_exp + neg_exp_sum))
        return loss.mean()

class RobustContrastivePretrainLoss(nn.Module):
    def __init__(self, rating_loss_fn=None, contrastive_loss_temperature=0.05):
        super(RobustContrastivePretrainLoss, self).__init__()
        self.contrastive_loss_fn = ContrastiveLoss(contrastive_loss_temperature)
        # Regression loss on the predicted ratings (MSE by default); this is not a masked-LM loss
        self.rating_loss_fn = rating_loss_fn if rating_loss_fn is not None else nn.MSELoss()

    def forward(self, anchor, positive, negatives, anchor_ratings, positive_ratings, labels):
        contrastive_loss = self.contrastive_loss_fn(anchor, positive, negatives)
        rating_loss_original = self.rating_loss_fn(anchor_ratings, labels)
        rating_loss_noisy = self.rating_loss_fn(positive_ratings, labels)
        total_loss = contrastive_loss + rating_loss_original + rating_loss_noisy
        return total_loss
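Exponentiating the similarities and then taking the log, as ContrastiveLoss does, gives the same value as a standard cross-entropy over the logits [positive, negatives] with the positive as target class 0. Writing it that way lets F.cross_entropy work through a log-sum-exp internally, which avoids overflow if you lower the temperature much further (exp(1/0.01) already exceeds the float32 range). A sketch of that equivalent form, checked against the explicit formula on random vectors; contrastive_loss_ce is a hypothetical name, not part of the code above.

```python
import torch
import torch.nn.functional as F


def contrastive_loss_ce(anchor, positive, negatives, temperature=0.05):
    """InfoNCE-style loss written as cross-entropy with the positive at index 0.

    anchor, positive: (batch, dim); negatives: list of (batch, dim) tensors.
    """
    pos = F.cosine_similarity(anchor, positive, dim=-1)                      # (batch,)
    negs = torch.stack([F.cosine_similarity(anchor, n, dim=-1) for n in negatives], dim=1)  # (batch, k)
    logits = torch.cat([pos.unsqueeze(1), negs], dim=1) / temperature        # (batch, 1 + k)
    targets = torch.zeros(anchor.size(0), dtype=torch.long, device=anchor.device)  # positive is class 0
    return F.cross_entropy(logits, targets)


# Compare with the explicit exp/log formula used in ContrastiveLoss.
torch.manual_seed(0)
a, p = torch.randn(4, 8), torch.randn(4, 8)
negatives = [torch.roll(a, shifts=-1, dims=0)]  # same negatives as the trainer: the next anchor in the batch
pos_exp = torch.exp(F.cosine_similarity(a, p) / 0.05)
neg_exp = torch.exp(F.cosine_similarity(a, negatives[0]) / 0.05)
explicit = (-torch.log(pos_exp / (pos_exp + neg_exp))).mean()
print(torch.allclose(contrastive_loss_ce(a, p, negatives), explicit, rtol=1e-4, atol=1e-4))  # True
```

Swapping this in would not change the training objective; it only changes how the same number is computed.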

4. Training Loop

Lastly, we define a collate function that tokenizes the original and corrupted texts of each batch, and a training loop that minimizes the total loss.

def collate_fn(batch, tokenizer):
    original, corrupted, labels = zip(*batch)
    original_encodings = tokenizer(list(original), return_tensors='pt', padding=True, truncation=True)
    corrupted_encodings = tokenizer(list(corrupted), return_tensors='pt', padding=True, truncation=True)
    labels_tensor = torch.tensor(labels, dtype=torch.float32)
    return original_encodings, corrupted_encodings, labels_tensor

def trainer(model, dataloader, optimizer, rcp_loss_fn, device):
    model.train()
    total_loss = 0.0
    for batch in dataloader:
        original_encodings, corrupted_encodings, labels = batch
        original_input_ids = original_encodings['input_ids'].to(device)
        original_attention_mask = original_encodings['attention_mask'].to(device)
        corrupted_input_ids = corrupted_encodings['input_ids'].to(device)
        corrupted_attention_mask = corrupted_encodings['attention_mask'].to(device)
        labels = labels.to(device)

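        # Embeddings and predicted ratings for the clean and the noisy texts
        anchor, anchor_ratings = model(original_input_ids, original_attention_mask)
        positive, positive_ratings = model(corrupted_input_ids, corrupted_attention_mask)

        # In-batch negatives: each anchor is contrasted with the next review in the batch
        # (equivalent to torch.roll(anchor, shifts=-1, dims=0)); with a batch of one, the negative is the anchor itself
        negative_samples = torch.cat([anchor[1:], anchor[:1]], dim=0)

        loss = rcp_loss_fn(anchor, positive, [negative_samples], anchor_ratings, positive_ratings, labels)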

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    return avg_loss
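The training loop reports only the combined loss. To check whether the rating head actually copes with typos, one option is to compare the mean absolute error (MAE) on the clean and on the corrupted text of a held-out split. A minimal sketch, assuming batches in the (original_encodings, corrupted_encodings, labels) format produced by collate_fn and a model that returns (pooled_output, rating) like RobustContrastivePretrainingModel; evaluate_mae and the ToyRater stand-in are hypothetical.

```python
from typing import Tuple

import torch
import torch.nn as nn


def evaluate_mae(model: nn.Module, dataloader, device) -> Tuple[float, float]:
    """Mean absolute rating error on (original texts, corrupted texts)."""
    model.eval()  # the trainer switches back with model.train() on its next call
    clean_err, noisy_err, count = 0.0, 0.0, 0
    with torch.no_grad():
        for original_enc, corrupted_enc, labels in dataloader:
            labels = labels.to(device)
            _, clean_pred = model(original_enc["input_ids"].to(device),
                                  original_enc["attention_mask"].to(device))
            _, noisy_pred = model(corrupted_enc["input_ids"].to(device),
                                  corrupted_enc["attention_mask"].to(device))
            clean_err += (clean_pred - labels).abs().sum().item()
            noisy_err += (noisy_pred - labels).abs().sum().item()
            count += labels.numel()
    if count == 0:
        raise ValueError("dataloader produced no examples")
    return clean_err / count, noisy_err / count


class ToyRater(nn.Module):
    """Hypothetical stand-in with the same interface: rating = number of real tokens."""

    def forward(self, input_ids, attention_mask):
        return None, attention_mask.sum(dim=1).float()


original = {"input_ids": torch.tensor([[5, 6, 0], [7, 8, 9]]),
            "attention_mask": torch.tensor([[1, 1, 0], [1, 1, 1]])}
corrupted = {"input_ids": torch.tensor([[5, 0, 0], [7, 8, 9]]),
             "attention_mask": torch.tensor([[1, 0, 0], [1, 1, 1]])}
labels = torch.tensor([2.0, 3.0])
print(evaluate_mae(ToyRater(), [(original, corrupted, labels)], torch.device("cpu")))  # (0.0, 0.5)
```

For a real check, you could build a second DataLoader over a held-out CSV with shuffle=False and the same collate_fn, then call evaluate_mae(model, eval_loader, DEVICE) after training. A small gap between the two numbers suggests the model is robust to the injected noise.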

5. Putting It All Together

Finally, we bring everything together, including dataset loading, model initialization, and training.

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import pandas as pd
from datasets import load_dataset
import random
import string

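DATASET_FILE_PATH = 'data/review.csv'  # CSV with 'text' and 'label' columns
EPOCHS = 5
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # fine-tuning BERT on CPU is slow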

class SentenceDataset(Dataset):
    @staticmethod
    def from_huggingface() -> 'SentenceDataset':
        data = load_dataset("yelp_review_full", split='train')
        return SentenceDataset(pd.DataFrame(data))

    @staticmethod
    def from_csv(path) -> 'SentenceDataset':
        return SentenceDataset(pd.read_csv(path))

    @staticmethod
    def __add_noise(text):
        noisy_text = list(text)
        for _ in range(int(0.1 * len(text))):
            idx = random.randint(0, len(text) - 1)
            noisy_text[idx] = random.choice(string.ascii_letters)
        return ''.join(noisy_text)

    def __init__(self, dataframe):
        self.dataframe = dataframe
        self.dataframe['corrupted_text'] = self.dataframe['text'].apply(SentenceDataset.__add_noise)
        self.dataframe['label'] = self.dataframe['label'].astype(float)

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        original_text = self.dataframe.iloc[idx]['text']
        corrupted_text = self.dataframe.iloc[idx]['corrupted_text']
        label = self.dataframe.iloc[idx]['label']
        return original_text, corrupted_text, label

class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.05):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, anchor, positive, negatives):
        pos_similarity = F.cosine_similarity(anchor, positive) / self.temperature
        neg_similarities = [F.cosine_similarity(anchor, neg) / self.temperature for neg in negatives]
        pos_exp = torch.exp(pos_similarity)
        neg_exp_sum = torch.sum(torch.exp(torch.stack(neg_similarities)), dim=0)
        loss = -torch.log(pos_exp / (pos_exp + neg_exp_sum))
        return loss.mean()

class RobustContrastivePretrainLoss(nn.Module):
    def __init__(self, rating_loss_fn=None, contrastive_loss_temperature=0.05):
        super(RobustContrastivePretrainLoss, self).__init__()
        self.contrastive_loss_fn = ContrastiveLoss(contrastive_loss_temperature)
        # Regression loss on the predicted ratings (MSE by default); this is not a masked-LM loss
        self.rating_loss_fn = rating_loss_fn if rating_loss_fn is not None else nn.MSELoss()

    def forward(self, anchor, positive, negatives, anchor_ratings, positive_ratings, labels):
        contrastive_loss = self.contrastive_loss_fn(anchor, positive, negatives)
        rating_loss_original = self.rating_loss_fn(anchor_ratings, labels)
        rating_loss_noisy = self.rating_loss_fn(positive_ratings, labels)
        total_loss = contrastive_loss + rating_loss_original + rating_loss_noisy
        return total_loss

class RobustContrastivePretrainingModel(nn.Module):
    def __init__(self, model_name='bert-base-multilingual-cased'):
        super(RobustContrastivePretrainingModel, self).__init__()
        self.encoder = AutoModel.from_pretrained(model_name)  # pre-trained BERT encoder
        self.regressor = nn.Linear(self.encoder.config.hidden_size, 1)  # rating head

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids, attention_mask=attention_mask)
        # Mean-pool over real tokens only, so padding does not dilute the review embedding
        mask = attention_mask.unsqueeze(-1).type_as(outputs.last_hidden_state)
        pooled_output = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        rating = self.regressor(pooled_output).squeeze(-1)
        return pooled_output, rating

def collate_fn(batch, tokenizer):
    original, corrupted, labels = zip(*batch)
    original_encodings = tokenizer(list(original), return_tensors='pt', padding=True, truncation=True)
    corrupted_encodings = tokenizer(list(corrupted), return_tensors='pt', padding=True, truncation=True)
    labels_tensor = torch.tensor(labels, dtype=torch.float32)
    return original_encodings, corrupted_encodings, labels_tensor

def trainer(model, dataloader, optimizer, rcp_loss_fn, device):
    model.train()
    total_loss = 0.0
    for batch in dataloader:
        original_encodings, corrupted_encodings, labels = batch
        original_input_ids = original_encodings['input_ids'].to(device)
        original_attention_mask = original_encodings['attention_mask'].to(device)
        corrupted_input_ids = corrupted_encodings['input_ids'].to(device)
        corrupted_attention_mask = corrupted_encodings['attention_mask'].to(device)
        labels = labels.to(device)

        anchor, anchor_ratings = model(original_input_ids, original_attention_mask)
        positive, positive_ratings = model(corrupted_input_ids, corrupted_attention_mask)

        negative_samples = torch.cat([anchor[1:], anchor[:1]], dim=0)

        loss = rcp_loss_fn(anchor, positive, [negative_samples], anchor_ratings, positive_ratings, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    return avg_loss

# Example usage
model_name = 'bert-base-multilingual-cased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = RobustContrastivePretrainingModel(model_name=model_name).to(DEVICE)
total_loss_fn = RobustContrastivePretrainLoss().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)

# Create dataset and dataloader
dataset = SentenceDataset.from_csv(DATASET_FILE_PATH)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=lambda x: collate_fn(x, tokenizer))

# Training
for epoch in range(EPOCHS):
    avg_loss = trainer(model, dataloader, optimizer, total_loss_fn, DEVICE)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

Next Steps

Robustification is an ongoing process that requires continuous monitoring and updating of models to ensure they remain resilient to new types of noise and adversarial attacks. Future steps include:

  1. Experimentation with Different Robustification Techniques: Try various methods like adversarial training, data augmentation, and noise injection.

  2. Continuous Monitoring: Regularly evaluate the model's performance in real-world scenarios and update it as needed.

  3. Community Collaboration: Engage with the research community to stay updated with the latest advancements in robust machine learning.
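As one concrete starting point for the adversarial-training item, here is a minimal sketch of a fast gradient sign method (FGSM) step applied to input embeddings rather than to discrete tokens, since token ids cannot be nudged by a gradient directly. The tiny embedding classifier, TinyTextClassifier, fgsm_training_step, and the epsilon value are illustrative assumptions. Hugging Face encoders accept an inputs_embeds argument, so the same idea could be wired into the BERT model from this post, but that wiring is not shown here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyTextClassifier(nn.Module):
    """Hypothetical toy model: embed tokens, average them, classify."""

    def __init__(self, vocab_size=50, dim=16, num_classes=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, num_classes)

    def classify(self, embeds):
        return self.head(embeds.mean(dim=1))

    def forward(self, input_ids):
        return self.classify(self.embed(input_ids))


def fgsm_training_step(model, optimizer, input_ids, labels, epsilon=0.01):
    """One update on the clean loss plus the loss on FGSM-perturbed embeddings."""
    embeds = model.embed(input_ids)

    # Gradient of the loss with respect to a detached copy of the embeddings.
    probe = embeds.detach().requires_grad_(True)
    (grad,) = torch.autograd.grad(F.cross_entropy(model.classify(probe), labels), probe)

    # FGSM: move every embedding coordinate by epsilon in the direction that increases the loss.
    adv_embeds = embeds + epsilon * grad.sign()

    optimizer.zero_grad()
    loss = F.cross_entropy(model.classify(embeds), labels) + F.cross_entropy(model.classify(adv_embeds), labels)
    loss.backward()
    optimizer.step()
    return loss.item(), (adv_embeds - embeds).detach()


torch.manual_seed(0)
model = TinyTextClassifier()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
input_ids = torch.randint(0, 50, (4, 6))
labels = torch.tensor([0, 1, 0, 1])
loss, delta = fgsm_training_step(model, optimizer, input_ids, labels)
print(f"loss={loss:.4f}, max |perturbation|={delta.abs().max().item():.4f}")
```

Training on both the clean and the perturbed embeddings is the usual adversarial-training recipe; epsilon controls how far the worst-case perturbation may move each coordinate.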


Written by

Kiarash Kiani

Machine Learning Engineer at datachef.co