Transfer Learning

Gargi Kaushik

1. Introduction

Transfer Learning is a machine learning technique where a model developed for a particular task is reused as the starting point for a model on a second, related task. Instead of training a model from scratch, we transfer knowledge from a pre-trained model to a new problem. This is especially useful when there is limited data available for the target task, as it leverages patterns and features learned in a related domain.

Key Benefits:

  1. Reduced Training Time: Starting from pre-trained weights requires far fewer training iterations, and therefore less compute, than training from scratch.

  2. Improved Performance: Can yield better accuracy for tasks with limited data.

  3. Resource Efficiency: Avoids the need for massive datasets and large compute resources.

  4. Adaptability: Useful for a range of domains, including computer vision, natural language processing, and time-series analysis.


2. Types of Transfer Learning

  1. Domain Adaptation: The source and target domains differ but are related (e.g., training on synthetic images and adapting to real images).

  2. Inductive Transfer Learning: The source and target tasks differ, and labeled data is available for the target task (e.g., reusing a general object-recognition model to distinguish cats from dogs).

  3. Transductive Transfer Learning: The source and target tasks are the same, but the domains differ (e.g., transferring sentiment analysis from English to Spanish text).

  4. Self-taught Learning: A representation is learned from large-scale, unlabeled (and possibly unrelated) data and then transferred to the target task.


3. Approaches in Transfer Learning

3.1 Fine-Tuning

Fine-tuning is the process of adjusting a pre-trained model on a new dataset. It involves:

  • Freezing Layers: Freezing earlier layers to retain low-level feature extraction (e.g., edges in images).

  • Training Final Layers: Adjusting higher-level layers to adapt to the specific features of the new task.

Use Case: Fine-tuning a model pre-trained on ImageNet for identifying specific objects in medical images.
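
As a minimal sketch of this pattern (complete, trainable examples follow in Section 6 and the detailed examples at the end), the freezing step in Keras might look like the snippet below; unfreezing the last 10 layers and the 1e-4 learning rate are illustrative choices, not fixed rules:

import tensorflow as tf
from tensorflow.keras.applications import ResNet50

# Load a pre-trained backbone without its classification head
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Freeze the earlier layers (generic, low-level features) and leave the last few trainable
for layer in base_model.layers[:-10]:
    layer.trainable = False

# A reduced learning rate helps avoid overwriting the pre-trained weights during fine-tuning
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)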

3.2 Feature Extraction

In feature extraction, the pre-trained model is used as a fixed feature extractor.

  • Freeze Entire Model: The entire pre-trained model is frozen, and only a new classifier is trained on the extracted features.

Use Case: Using a convolutional neural network (CNN) pre-trained on ImageNet to extract features from images, then using these features for image classification tasks.
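
A minimal Keras sketch of this approach is shown below; the random `images` array stands in for your own preprocessed image batch, and the small dense classifier on top is an illustrative choice:

import numpy as np
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input

# Frozen pre-trained CNN used purely as a fixed feature extractor
extractor = ResNet50(weights='imagenet', include_top=False, pooling='avg', input_shape=(224, 224, 3))
extractor.trainable = False

# 'images' is a placeholder for your own batch of 224x224 RGB images (pixel values 0-255)
images = np.random.rand(8, 224, 224, 3).astype('float32') * 255.0
features = extractor.predict(preprocess_input(images))  # shape: (8, 2048)

# Only this small classifier is trained, on the extracted features
classifier = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(features.shape[1],)),
    tf.keras.layers.Dense(1, activation='sigmoid')
])
classifier.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# classifier.fit(features, labels, epochs=5)  # 'labels' would come from your own dataset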

3.3 Weight Initialization

Transfer learning can also be implemented by using pre-trained weights as the initialization instead of random initialization, which works best when the source and target tasks are similar.

Use Case: Initializing a language model with weights learned on a large document corpus before training it on a downstream language task.
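
A minimal Keras sketch of this idea is shown below; the 10-class head and the learning rate are illustrative choices, not part of any specific recipe:

import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model

# weights='imagenet' gives a pre-trained initialization; weights=None would start from random weights
base = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

x = GlobalAveragePooling2D()(base.output)
outputs = Dense(10, activation='softmax')(x)  # 10 classes is an illustrative choice
model = Model(inputs=base.input, outputs=outputs)

# No layers are frozen: the pre-trained weights only serve as a starting point
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              loss='sparse_categorical_crossentropy', metrics=['accuracy'])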


4. Practical Applications

  1. Computer Vision

    • Object detection, image classification, and image segmentation.

    • Example: A model trained on large-scale datasets (e.g., ImageNet) is adapted for specific tasks like medical imaging or autonomous driving.

  2. Natural Language Processing (NLP)

    • Text classification, sentiment analysis, translation, and question-answering.

    • Example: BERT or GPT, which are trained on vast text corpora, can be fine-tuned for specific tasks like sentiment analysis on reviews.

  3. Speech Recognition

    • Transfer learning in automatic speech recognition (ASR) to recognize specific accents or new languages by fine-tuning pre-trained ASR models.

  4. Time-Series Forecasting

    • Using pre-trained models on similar time-series datasets for predictions in finance, weather, or traffic.

5. Popular Pre-Trained Models

Computer Vision Models:

  • ResNet: Effective for image classification; residual (skip) connections make very deep networks trainable.

  • VGG: Known for simplicity and high performance in image classification.

  • Inception: Designed to handle a variety of scales with inception modules.

  • YOLO (You Only Look Once): Pre-trained for object detection tasks.

NLP Models:

  • BERT (Bidirectional Encoder Representations from Transformers): Good for sentence understanding tasks.

  • GPT (Generative Pre-trained Transformer): Effective for text generation tasks.

  • RoBERTa: Robustly optimized BERT variant for various NLP tasks.

  • T5 (Text-to-Text Transfer Transformer): Converts all tasks into text-to-text format for versatile use.

Speech Models:

  • Wav2Vec: For automatic speech recognition.

  • DeepSpeech: An open-source ASR model trained on diverse audio datasets.


6. Implementing Transfer Learning

Example: Fine-Tuning a Pre-Trained CNN (ResNet) for Custom Image Classification

Step 1: Import Required Libraries

import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator

Step 2: Load the Pre-Trained Model

# Load pre-trained ResNet50 with ImageNet weights, excluding the top layer
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Freeze the base model
for layer in base_model.layers:
    layer.trainable = False

Step 3: Add Custom Layers

# Add custom layers on top
x = Flatten()(base_model.output)
x = Dense(128, activation='relu')(x)
x = Dense(1, activation='sigmoid')(x)

# Define new model
model = Model(inputs=base_model.input, outputs=x)

Step 4: Compile and Train the Model

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# Train the model on the new data; train_data and val_data are assumed to be image
# data generators prepared beforehand (see the VGG16 example below for a full setup)
model.fit(train_data, validation_data=val_data, epochs=10)

7. Challenges and Limitations

  • Negative Transfer: When the knowledge transferred harms rather than helps performance on the target task.

  • Domain-Specificity: Models may perform poorly if the target task is too different from the source task.

  • Data Scarcity: Though helpful with limited data, transfer learning still requires some amount of labeled data for fine-tuning.

  • Overfitting: Models may overfit if not tuned carefully, especially when the target dataset is small; one common mitigation is sketched below.
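
One common way to reduce the overfitting risk, sketched below with Keras callbacks (the patience value is an illustrative choice), is to pair a small learning rate with early stopping on a validation set:

import tensorflow as tf

# Stop fine-tuning when the validation loss stops improving and keep the best weights
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3,
                                              restore_best_weights=True)

# Used together with a small learning rate during fine-tuning, for example:
# model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5), loss='binary_crossentropy')
# model.fit(train_data, validation_data=val_data, epochs=20, callbacks=[early_stop])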


8. Conclusion

Transfer Learning is a powerful technique that allows models to leverage prior knowledge and improve performance on related tasks with less data. It has proven transformative in fields like computer vision, NLP, and more. However, careful consideration of the domain and task similarity is essential to fully benefit from transfer learning without negative effects.


Below are detailed examples demonstrating different transfer learning approaches:

1. Fine-Tuning a Pre-Trained Model (VGG16) for a Custom Classification Task

In this example, we fine-tune VGG16 to classify images into two categories (e.g., cats vs. dogs).

Step 1: Import Libraries

import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Dense, Flatten, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator

Step 2: Load Pre-Trained Model and Freeze Initial Layers

# Load VGG16 with ImageNet weights, without the top fully-connected layers
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Freeze all layers except the last 4, which will be fine-tuned
for layer in base_model.layers[:-4]:
    layer.trainable = False

Step 3: Add Custom Layers

# Add custom classification layers
x = base_model.output
x = Flatten()(x)
x = Dense(256, activation='relu')(x)
x = Dropout(0.5)(x)  # Dropout to reduce overfitting
x = Dense(1, activation='sigmoid')(x)  # Binary classification

# Define the new model
model = Model(inputs=base_model.input, outputs=x)

Step 4: Compile and Train the Model

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              loss='binary_crossentropy',
              metrics=['accuracy'])

# Set up data generators for training and validation
train_datagen = ImageDataGenerator(rescale=1./255, rotation_range=20, shear_range=0.2, zoom_range=0.2, horizontal_flip=True)
val_datagen = ImageDataGenerator(rescale=1./255)

train_data = train_datagen.flow_from_directory('data/train', target_size=(224, 224), batch_size=32, class_mode='binary')
val_data = val_datagen.flow_from_directory('data/val', target_size=(224, 224), batch_size=32, class_mode='binary')

# Train model
model.fit(train_data, validation_data=val_data, epochs=10)
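
After training, a quick single-image check might look like the following sketch; the image path is a placeholder, and the 1/255 rescaling matches the generators above:

import numpy as np
from tensorflow.keras.preprocessing import image

# Load and preprocess one image the same way the training generator did
img = image.load_img('data/test/cat.jpg', target_size=(224, 224))  # placeholder path
x = image.img_to_array(img) / 255.0
x = np.expand_dims(x, axis=0)

# Sigmoid output > 0.5 maps to class 1 as assigned by flow_from_directory
prob = model.predict(x)[0][0]
print('Predicted class:', int(prob > 0.5))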

2. Transfer Learning for NLP (Fine-Tuning BERT for Sentiment Analysis)

In this example, we fine-tune BERT from the Hugging Face transformers library for a binary sentiment classification task.

Step 1: Install and Import Required Libraries

!pip install transformers torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from torch.utils.data import Dataset, DataLoader
import torch

Step 2: Define Dataset Class

class SentimentDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

Step 3: Prepare Data

# Example data
texts = ["I love this!", "I hate this!"]
labels = [1, 0]  # 1=Positive, 0=Negative

# Initialize tokenizer and dataset
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
dataset = SentimentDataset(texts, labels, tokenizer, max_len=128)

# Split into data loaders
train_loader = DataLoader(dataset, batch_size=2, shuffle=True)

Step 4: Load and Fine-Tune the BERT Model

# Load BERT pre-trained model for binary classification
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',          
    num_train_epochs=3,              
    per_device_train_batch_size=2,   
    per_device_eval_batch_size=2,    
    warmup_steps=10,                 
    weight_decay=0.01,               
    logging_dir='./logs',            
)

# Initialize Trainer
trainer = Trainer(
    model=model,                        
    args=training_args,                 
    train_dataset=dataset              
)

# Train model
trainer.train()
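
Once fine-tuning completes, the model can be used for prediction; the sketch below assumes the fine-tuned model and tokenizer from the steps above, and the example sentence is arbitrary:

# Run the fine-tuned model on a new sentence
model.eval()
inputs = tokenizer("This product exceeded my expectations!", return_tensors='pt',
                   truncation=True, padding=True, max_length=128)
inputs = {k: v.to(model.device) for k, v in inputs.items()}
with torch.no_grad():
    logits = model(**inputs).logits
print(torch.argmax(logits, dim=1).item())  # 1 = positive, 0 = negative in this toy setup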

3. Transfer Learning with YOLOv5 for Object Detection

In this example, we use YOLOv5 pre-trained weights for custom object detection tasks (e.g., detecting specific objects like logos or plants).

Step 1: Install YOLOv5 Repository and Dependencies

!git clone https://github.com/ultralytics/yolov5
%cd yolov5
!pip install -r requirements.txt

Step 2: Prepare Custom Dataset

Create a custom dataset with labeled images in YOLO format (bounding boxes in .txt files alongside images in the dataset folder).
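
For reference, a minimal dataset configuration and label file typically look like the sketch below; the paths, class count, and class names are placeholders for your own data:

# data.yaml (paths and class names are placeholders)
train: ../dataset/images/train
val: ../dataset/images/val
nc: 2
names: ['logo', 'plant']

Each line of a label .txt file holds one bounding box as class_id x_center y_center width height, with coordinates normalized to the 0–1 range, for example: 0 0.512 0.431 0.220 0.310.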

Step 3: Fine-Tune YOLOv5

Use YOLOv5’s command-line tool to train on your dataset, specifying a custom configuration.

!python train.py --img 640 --batch 16 --epochs 50 --data path/to/data.yaml --weights yolov5s.pt

In the above command:

  • --img 640: Image resolution for training.

  • --batch 16: Batch size.

  • --epochs 50: Number of training epochs.

  • --data path/to/data.yaml: Path to your dataset configuration file.

  • --weights yolov5s.pt: Pre-trained YOLOv5 weights.
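
After training, the fine-tuned weights can be used for inference with YOLOv5’s detect.py script; the weights path below assumes the default run directory (runs/train/exp), which may differ on your machine, and the source path is a placeholder:

!python detect.py --weights runs/train/exp/weights/best.pt --source path/to/test_images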


4. Using a Pre-Trained RNN for Transfer Learning in Time-Series Forecasting

In this example, we use an LSTM model trained on a generic dataset for time-series forecasting to adapt to another domain (e.g., stock price prediction).

Step 1: Import Libraries

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
import numpy as np

Step 2: Define and Load Pre-Trained Model

# Define an LSTM model (its architecture must match the saved pre-trained weights)
model = Sequential([
    LSTM(50, activation='relu', input_shape=(10, 1), return_sequences=True),
    LSTM(50, activation='relu'),
    Dense(1)
])

# Load pre-trained weights
model.load_weights('pretrained_lstm_weights.h5')

Step 3: Freeze Some Layers and Re-Train on New Data

# Freeze all layers except the last LSTM layer and Dense layer
for layer in model.layers[:-2]:
    layer.trainable = False

# Compile model
model.compile(optimizer='adam', loss='mse')

# Train on new data
X_train = np.random.rand(100, 10, 1)  # Replace with your actual data
y_train = np.random.rand(100, 1)      # Replace with your actual data
model.fit(X_train, y_train, epochs=10, batch_size=16)

This documentation provides an overview of transfer learning concepts, types, approaches, applications, and implementation techniques.
