Training GPT-2 From Scratch: A Beginner-Friendly Step-by-Step Guide

Rafal Jackiewicz
10 min read

Training a GPT-2 model from scratch is a rewarding experience, especially if you want to learn about natural language processing and get hands-on with machine learning models. This guide will walk you through the process step-by-step, with simplified explanations to help you understand each concept and why it's important.

Table of Contents

  1. Setting Up Your Working Environment

  2. Loading the Dataset

  3. Loading the Model and Tokenizer

  4. Configuring GPT-2 with GPT2Config

  5. Training the Model

  6. Running Inference (Making the Model Generate Text)

Let's jump right in!

1. Setting Up Your Working Environment

Before we start training, we need to set up the environment and install some required libraries. Here's a quick overview of the tools you'll use:

  • Transformers: A library for working with transformer models like GPT-2.

  • DeepLake: Helps manage and load large datasets.

  • WandB (Weights and Biases): Tracks experiments and helps visualize model training.

  • Accelerate: Handles device placement, mixed precision, and distributed training; the Trainer uses it under the hood.

Install these packages by running the following command:

!pip install transformers deeplake wandb accelerate

Next, log in to Weights and Biases to track the training progress. You'll need an account and an API key:

!wandb login

For full training, it’s recommended to use a powerful GPU, like an NVIDIA A100, as it speeds up training significantly.
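If you are unsure what hardware is available (for example, in a Colab or notebook environment), a quick check like the following tells you whether PyTorch can see a GPU:

import torch

# Quick check that a GPU is visible before starting a long training run
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
else:
    print("No GPU found; training will be very slow on CPU.")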

2. Loading the Dataset

In this guide, we use the OpenWebText dataset, an open-source recreation of OpenAI's WebText corpus: the text of web pages gathered from links shared on Reddit that received at least three upvotes. This dataset is suitable for creating a foundational model because it contains a wide variety of content.

The dataset is structured into two main parts:

  • Text: This is the raw written content.

  • Tokens: These are pieces of text broken down to make it easier for the model to process, such as words or parts of words.

We can easily load the dataset using DeepLake:

import deeplake

# Load training and validation datasets
ds = deeplake.load('hub://activeloop/openwebtext-train')
ds_val = deeplake.load('hub://activeloop/openwebtext-val')
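To confirm the download worked, you can inspect the dataset. The calls below assume the DeepLake API used above, which exposes the samples through named tensors such as 'text' and 'tokens':

print(len(ds))            # number of training samples
print(ds.tensors.keys())  # available tensors, e.g. 'text' and 'tokens'
ds.summary()              # prints a short overview of the dataset layout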

3. Loading the Model and Tokenizer

To train the model, we need to load GPT-2’s architecture and tokenizer. The tokenizer helps convert raw text into numbers (tokens) that the model can understand. We also need to adjust some settings to make sure the model can handle our specific dataset.

First, we load the tokenizer and assign it a padding token. GPT-2's tokenizer does not define one by default, so we reuse the end-of-sequence token; padding is what later lets us bring all text samples in a batch to the same length:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

The tokenizer is responsible for converting text into tokens, which are numerical representations of words or parts of words. These tokens are easier for the model to understand and work with.
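Before moving on, it helps to see what the tokenizer actually produces. Here is a quick look at a short sentence (the exact IDs depend on GPT-2's vocabulary; the 'Ġ' prefix in the token strings marks a leading space):

sample = tokenizer("The cat is sleeping.")
print(sample["input_ids"])                                   # a short list of integer token IDs
print(tokenizer.convert_ids_to_tokens(sample["input_ids"]))  # tokens like ['The', 'Ġcat', 'Ġis', ...]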

Next, we define a function to tokenize the dataset and create data loaders for training:

def get_tokens_transform(tokenizer):
    def tokens_transform(sample_in):
        # Tokenize the raw text, truncating or padding every sample to 512 tokens
        tokenized_text = tokenizer(
            sample_in["text"],
            truncation=True,
            max_length=512,
            padding='max_length',
            return_tensors="pt"
        )
        # Drop the batch dimension added by return_tensors="pt"
        tokenized_text = tokenized_text["input_ids"][0]
        # For causal language modeling, the labels are the inputs themselves
        return {
            "input_ids": tokenized_text,
            "labels": tokenized_text
        }
    return tokens_transform

# Create data loaders
train_loader = ds.dataloader().batch(32).transform(get_tokens_transform(tokenizer)).pytorch()
val_loader = ds_val.dataloader().batch(32).transform(get_tokens_transform(tokenizer)).pytorch()

In the function above, we create input_ids and labels for each sample:

  • input_ids: These are the tokens that represent the input text. Think of them as the words (or parts of words) turned into numbers, which the model can process.

  • labels: These are the same as the input_ids in our case, because GPT-2 is a language model that tries to predict the next word in a sequence. The labels are used to tell the model what the correct output should be for each input. During training, the model looks at the input tokens (input_ids) and tries to predict the next token. The labels are what we use to check if the model's predictions are correct.

By setting input_ids and labels to the same values, we're effectively training the model to predict the next word in the sequence based on the previous words. This is the core of how language models like GPT-2 learn to generate text. Note that we don't shift the labels ourselves: GPT2LMHeadModel shifts them by one position internally when it computes the loss.
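To make this concrete, here is a minimal sketch of what happens inside a training step, using a freshly initialized (untrained) GPT-2 purely for illustration (the name demo_model is not part of the guide's training code): passing labels equal to input_ids makes the model compute the next-token prediction loss itself.

import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Freshly initialized GPT-2 with default settings, used only to show the loss computation
demo_model = GPT2LMHeadModel(GPT2Config())

inputs = tokenizer("The cat is sleeping.", return_tensors="pt")

with torch.no_grad():
    # labels == input_ids: the model shifts them internally and returns the
    # cross-entropy loss over next-token predictions
    outputs = demo_model(input_ids=inputs["input_ids"], labels=inputs["input_ids"])

print(outputs.loss)
print(outputs.logits.shape)  # (batch_size, sequence_length, vocab_size)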

Padding Explained

Padding is used to ensure that all input sequences have the same length. When we feed the model batches of text, the sequences might have different lengths (e.g., one sentence may have 10 words while another has 20), but every sequence in a batch must be the same length to stack into a single tensor, which is also what enables efficient parallel processing.

To solve this problem, we pad shorter sequences so that they match the length of the longest sequence in the batch. Padding simply involves adding special tokens (in this case, the end-of-sequence token, eos_token) to the end of shorter sequences until they all reach the same length. This ensures that all sequences are the same size, allowing the model to handle them effectively during training.

For example, if we have two sentences:

  • Sentence 1: "The cat is sleeping." (5 tokens)

  • Sentence 2: "The dog is barking loudly in the yard." (9 tokens)

To make both sentences the same length, we pad Sentence 1 with 4 padding tokens to reach 9 tokens in total. This allows both sentences to be processed together without any issues, ensuring consistency in the model's input size.
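Here is a small illustration of this, assuming the tokenizer loaded earlier (with pad_token set to the end-of-sequence token): passing padding=True pads each sequence in the batch to the length of the longest one, and the attention_mask marks real tokens (1) versus padding (0).

batch = tokenizer(
    ["The cat is sleeping.",
     "The dog is barking loudly in the yard."],
    padding=True,
    return_tensors="pt",
)
print(batch["input_ids"].shape)    # both rows now have the same length
print(batch["attention_mask"][0])  # trailing zeros show where padding was added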

Determining Maximum Sequence Length

The maximum sequence length is a predefined limit on how long each sequence of tokens can be. In our example, we set max_length=512, which caps every training sample at 512 tokens. This limit is chosen based on the model's architecture and the computational resources available.

If a text sequence is longer than the maximum length, it needs to be truncated. Truncation means cutting off the extra tokens so that the sequence fits within the allowed length. This helps ensure that all sequences are of a manageable size and that they don't exceed the model's capabilities.

For example, if we have a sequence with 600 tokens and the maximum length is 512, we truncate the sequence by removing the last 88 tokens, keeping only the first 512 tokens. This way, the model can handle the input without running into issues related to sequence length.
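As a quick illustration (again using the tokenizer loaded earlier, and a throwaway long_text string), truncation=True with max_length=512 cuts any longer input down to exactly 512 tokens:

long_text = "word " * 600  # roughly 600 tokens, purely for illustration
truncated = tokenizer(long_text, truncation=True, max_length=512)
print(len(truncated["input_ids"]))  # 512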

The choice of maximum sequence length depends on several factors, including:

  1. Model Capacity: Larger models can handle longer sequences more effectively, while smaller models may struggle.

  2. Computational Resources: Longer sequences require more memory and processing power, so the maximum length should be chosen based on the available hardware.

  3. Nature of the Dataset: If the dataset contains mostly short sentences, a lower maximum sequence length may be sufficient. For more complex and lengthy content, a higher limit might be needed.

Using a balanced maximum sequence length helps ensure that the model can process the input efficiently without running out of memory or computational resources.

n_positions and max_length

The n_positions parameter in GPT-2 defines the maximum number of tokens that the model can handle in a single input sequence. In other words, it represents the maximum context length that the model was pre-trained to understand, typically set to 1024 tokens for standard GPT-2.

When setting max_length during training or inference, it’s crucial that max_length is less than or equal to n_positions. This ensures that the input sequence can be fully processed by the model without exceeding its internal limit. If an input longer than n_positions reaches the model, the position embeddings cannot cover the extra tokens, and the forward pass will typically fail with an indexing error.

In practice, max_length should be set to a value ≤ n_positions to ensure compatibility. If a longer sequence is provided, it must be truncated (for example, by the tokenizer with truncation=True) so that only the first n_positions tokens reach the model. This relationship ensures that the model operates within the context length it was designed for.

The length of the input text also affects the length of the generated text due to the context window size of GPT-2. If the input is long, it leaves less space for the model to generate additional tokens. For example, if the input text is already 800 tokens long, then GPT-2 can only generate up to 224 additional tokens before hitting the 1024-token context window limit. Therefore, keeping the input concise allows the model more room to generate meaningful output.
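A quick sanity check along these lines, reading n_positions from the standard GPT-2 configuration (the variable names below are just illustrative):

from transformers import GPT2Config

pretrained_config = GPT2Config.from_pretrained("gpt2")
print(pretrained_config.n_positions)  # 1024 for standard GPT-2

max_length = 512
assert max_length <= pretrained_config.n_positions, "max_length must fit inside the context window"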

4. Configuring GPT-2 with GPT2Config

If you want to pre-train the GPT-2 model from scratch, you can use the GPT2Config class to create a custom configuration for the model. The GPT2Config class allows you to set various parameters, such as the number of layers, hidden units, attention heads, and the n_positions parameter.

Here’s an example of how to use GPT2Config to create a custom GPT-2 model:

from transformers import GPT2Config, GPT2LMHeadModel

# Define a custom configuration
config = GPT2Config(
    vocab_size=50257,        # Size of the vocabulary
    n_positions=1024,        # Maximum sequence length (context window)
    n_ctx=1024,              # Context size
    n_embd=768,              # Size of the embedding layer
    n_layer=12,              # Number of transformer blocks (layers)
    n_head=12                # Number of attention heads
)

# Initialize a GPT-2 model with the custom configuration
model = GPT2LMHeadModel(config)

In the code above, we define a custom GPT2Config to specify the architecture of our GPT-2 model:

  • vocab_size: The size of the vocabulary used by the tokenizer. For GPT-2, this is typically 50257.

  • n_positions: The maximum number of tokens the model can handle in a single sequence (context window). This should be set based on your dataset and hardware capabilities. Typically, it is set to a large value to ensure flexibility (e.g., 512, 1024, or 2048).

  • n_ctx: The context size used by the causal attention mask, conventionally set to the same value as n_positions. Keeping the two equal ensures the model behaves consistently during both training and inference.

  • n_embd: The size of the embedding layer. This controls the dimensionality of the token embeddings and affects the model's capacity.

  • n_layer: The number of transformer blocks (layers) in the model. More layers generally mean a more powerful model, but also require more computational resources.

  • n_head: The number of attention heads in each transformer block. More attention heads can improve the model's ability to learn complex relationships in the data.

Using GPT2Config is especially useful if you want to customize the model's capacity to suit your dataset and computational resources. You can adjust these parameters to create a smaller or larger version of GPT-2, depending on your needs.
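As an illustration of this flexibility, the sketch below defines a much smaller variant alongside the configuration above; the tiny_config values and the count_params helper are arbitrary choices for demonstration only.

tiny_config = GPT2Config(
    vocab_size=50257,
    n_positions=512,
    n_embd=256,
    n_layer=4,
    n_head=4
)
tiny_model = GPT2LMHeadModel(tiny_config)

def count_params(m):
    # Total number of parameters in the model
    return sum(p.numel() for p in m.parameters())

print(f"GPT-2 small config: {count_params(model):,} parameters")   # roughly 124M
print(f"tiny config:        {count_params(tiny_model):,} parameters")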

5. Training the Model

Now we set up the model for training. Because we're training from scratch, we initialize GPT-2 from the custom configuration defined in Section 4 rather than loading the pretrained weights; that configuration matches the smallest GPT-2 variant, which keeps training manageable. We'll also configure hyperparameters that control how the model learns, like the number of training epochs and the learning rate.

Here's how you can load and set up the model for training:

from transformers import GPT2LMHeadModel, TrainingArguments, Trainer

# Initialize GPT-2 from the custom configuration in Section 4
# (random weights, since we are training from scratch)
model = GPT2LMHeadModel(config)

# Define training arguments
args = TrainingArguments(
    output_dir="GPT2-training-from-scratch",
    evaluation_strategy="steps",
    save_strategy="steps",
    eval_steps=500,
    save_steps=500,
    num_train_epochs=2,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    logging_steps=10,
    learning_rate=5e-4,
    report_to="wandb"
)

# Train the model
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_loader,
    eval_dataset=val_loader
)

trainer.train()

# Save the model after training
model.save_pretrained('./trained_gpt2_model')
tokenizer.save_pretrained('./trained_gpt2_model')

During training, the Trainer handles both evaluation and checkpoint saving automatically. One caveat: Trainer normally expects dataset objects and builds its own data loaders internally, so depending on your versions of transformers and deeplake you may need a small adapter to feed it the DeepLake data loaders directly.
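If that is the case in your setup, one option is to subclass Trainer so it returns the prepared DataLoaders as-is. This is a minimal sketch under that assumption (the class name DataLoaderTrainer is illustrative, not a Transformers API); note that the DataLoader's batch size of 32 is then what actually applies, rather than per_device_train_batch_size.

from transformers import Trainer

class DataLoaderTrainer(Trainer):
    """Trainer variant that uses externally built DataLoaders as-is."""

    def __init__(self, *args, train_dataloader=None, eval_dataloader=None, **kwargs):
        super().__init__(*args, **kwargs)
        self._train_dataloader = train_dataloader
        self._eval_dataloader = eval_dataloader

    def get_train_dataloader(self):
        # Return the DeepLake training DataLoader unchanged
        return self._train_dataloader

    def get_eval_dataloader(self, eval_dataset=None):
        # Return the DeepLake validation DataLoader unchanged
        return self._eval_dataloader

trainer = DataLoaderTrainer(
    model=model,
    args=args,
    train_dataloader=train_loader,
    eval_dataloader=val_loader,
)
trainer.train()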

6. Running Inference (Generating Text)

Once training is complete, you can use the model to generate text. The simplest way to do this is by using the pipeline functionality from the Transformers library:

from transformers import GPT2LMHeadModel, AutoTokenizer, pipeline

# Load the saved model and tokenizer
model = GPT2LMHeadModel.from_pretrained('./trained_gpt2_model')
tokenizer = AutoTokenizer.from_pretrained('./trained_gpt2_model')

# Create a text generation pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Generate some text
input_text = "The future of AI is"
output = pipe(input_text, max_length=50, num_return_sequences=1)
print(output)

This snippet takes the prompt "The future of AI is" and lets the model complete it based on what it learned during training. Keep in mind that a model trained from scratch for only a couple of epochs will produce much rougher text than the pretrained GPT-2 checkpoints.
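For more control over decoding, you can also call model.generate directly. The sampling settings below (do_sample, top_k, temperature) are just reasonable starting points to experiment with, not values prescribed by the guide:

import torch

inputs = tokenizer("The future of AI is", return_tensors="pt")
with torch.no_grad():
    generated = model.generate(
        **inputs,
        max_new_tokens=40,
        do_sample=True,
        top_k=50,
        temperature=0.9,
        pad_token_id=tokenizer.eos_token_id,
    )
print(tokenizer.decode(generated[0], skip_special_tokens=True))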

Summary

In this guide, we covered setting up your environment, loading the dataset, configuring the model, training the GPT-2 model from scratch, and generating text. We used the OpenWebText dataset, which is well-structured with text and tokens for ease of processing. Each step was simplified to help you understand not just what to do, but why each part is important.

Feel free to modify the settings, try out different datasets, and experiment with model configurations to deepen your understanding!


Author Bio

Rafal Jackiewicz is an author of books about programming in C and Java. You can find more information about him and his work on Amazon and at https://www.jackiewicz.org.
