Train a Neural Network on CIFAR-10

Anix Lynch · 6 min read

The CIFAR-10 dataset consists of 60,000 32x32 color images across 10 classes, including airplanes, cars, birds, and more. It's a perfect starting point for anyone looking to get hands-on experience with image classification using neural networks.

In this tutorial, you'll learn how to:

  • Load and preprocess image data

  • Build a neural network from scratch

  • Apply a loss function and optimizer

  • Train your model and evaluate its performance


1. Importing Necessary Libraries

import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import utils, transforms
from Cifar10Dataloader import CIFAR10
  • os, numpy, matplotlib: These are standard libraries for file management (os), numerical operations (numpy), and plotting (matplotlib).

  • torch: PyTorch's main library.

  • torch.nn: Provides modules for building neural networks.

  • torch.optim: Includes optimization algorithms like SGD.

  • torchvision.transforms: Used for data transformations like converting images to tensors.

  • Cifar10Dataloader: A custom loader for the CIFAR-10 dataset; a rough sketch of what such a loader might look like follows below.
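The post doesn't show the contents of Cifar10Dataloader. As a sketch only, assuming the standard CIFAR-10 "python version" pickle batches (data_batch_1 through data_batch_5), a minimal torch.utils.data.Dataset could look like this:

import os
import pickle
import numpy as np
from torch.utils.data import Dataset

class CIFAR10(Dataset):
    def __init__(self, root, transform=None):
        self.transform = transform
        data, labels = [], []
        for i in range(1, 6):  # the five standard training batches
            with open(os.path.join(root, f'data_batch_{i}'), 'rb') as f:
                batch = pickle.load(f, encoding='bytes')
            data.append(batch[b'data'])
            labels.extend(batch[b'labels'])
        # Each row is 3072 bytes (R, G, B planes); reshape to HWC so transforms can consume it
        self.data = np.vstack(data).reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, label = self.data[idx], self.labels[idx]
        if self.transform is not None:
            img = self.transform(img)  # ToTensor accepts HWC uint8 arrays
        return img, label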


2. Loading the Dataset and Applying Transformations

batch_size = 4

def load_data():
    transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = CIFAR10(root='../cifar10',  transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False, num_workers=1)
    return trainloader
  • Transforms: Images are converted to tensors and normalized so the pixel values are in the range [-1, 1].

  • CIFAR10 DataLoader: Loads the CIFAR-10 dataset using a custom loader and applies the transforms.

  • DataLoader: This splits the dataset into small batches (in this case, size 4) for efficient training. Note that shuffle=False keeps the batch order deterministic here; in practice you would usually shuffle training data.

    Expected Output:

  • There won't be any printed output, but the dataset is now loaded, converted to tensors, and normalized to [-1, 1]. The data loader will return batches of 4 images at a time; the quick check below confirms the value range.
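To verify the transform, here is a quick sanity check (my addition, not part of the original code) that pulls one batch and inspects its shape and value range:

batch, _ = next(iter(load_data()))
print(batch.shape)                             # torch.Size([4, 3, 32, 32])
print(batch.min().item(), batch.max().item())  # roughly -1.0 and 1.0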


3. Displaying Images from the Dataset

def show_image(img):
    img = img / 2 + 0.5  # Unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
  • show_image(): This function takes an image and displays it using matplotlib.

  • Unnormalizing: Normalize mapped each pixel to (x - 0.5) / 0.5, so img / 2 + 0.5 inverts that, bringing values back into [0, 1] for display.

    Expected Output:

  • No output yet, but this function will be used to visualize the images in the next step.


4. Previewing Images and Labels

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

dataiter = iter(load_data())
images, labels = next(dataiter)  # use the built-in next(); the .next() method was removed in newer PyTorch

# Show images
show_image(utils.make_grid(images))
# Print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
  • Classes: Labels for each image category (e.g., plane, car, dog, etc.).

  • Previewing: Here, we load a small batch of images and show them with their respective labels.

    Expected Output:

  • Image Display: show_image() will plot the 4 images in a single row (make_grid tiles the batch into one image, and its default nrow is 8, so all 4 fit on one row). You'll see 4 images from the CIFAR-10 dataset displayed using matplotlib.

  • Printed Labels: The print statement will display the class labels for those 4 images. For example:

 plane   car  frog  ship

This shows the labels corresponding to each image in the grid.
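As an optional variation (my addition, not in the original post), you can put each label directly above its image using matplotlib subplots instead of a single grid:

fig, axes = plt.subplots(1, 4, figsize=(8, 2))
for ax, img, label in zip(axes, images, labels):
    ax.imshow(np.transpose((img / 2 + 0.5).numpy(), (1, 2, 0)))  # unnormalize, CHW -> HWC
    ax.set_title(classes[label])
    ax.axis('off')
plt.show()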


5. Defining the Neural Network (Model)

model = nn.Sequential(
    nn.Linear(3072, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10)
)
  • 3072: Each CIFAR-10 image is 32x32 pixels with 3 color channels (32 × 32 × 3 = 3,072 input features).

  • 3 Layers:

    • First layer: Input size 3072, output size 128.

    • Second layer: Input size 128, output size 64.

    • Final layer: Input size 64, output size 10 (one for each class).

  • ReLU Activation: Used to add non-linearity between the linear layers.

    Expected Output:

  • There will be no output here, but this defines the architecture of the model. The model consists of 3 linear layers with ReLU activations in between.
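If you want to confirm the architecture, a quick parameter count (my addition) shows how the weights break down across the three layers:

for name, p in model.named_parameters():
    print(name, tuple(p.shape))        # e.g. 0.weight (128, 3072), 0.bias (128,), ...
total = sum(p.numel() for p in model.parameters())
print('Total parameters:', total)      # 393,344 + 8,256 + 650 = 402,250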


6. Loss Function and Optimizer

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
  • CrossEntropyLoss: This is used for classification tasks. It calculates how far off the predicted outputs are from the true labels.

  • SGD (Stochastic Gradient Descent): This optimizer updates the model’s weights based on the gradients from backpropagation.

    • lr=0.001: Learning rate for how fast weights are updated.

    • momentum=0.9: Helps accelerate gradient descent by reducing oscillations.

    • Oscillation here means the weight updates zig-zagging back and forth across the loss surface instead of heading steadily toward a minimum; momentum damps this by averaging recent gradients.

      Expected Output:

  • No output, but this step sets up the loss function (CrossEntropyLoss) and the optimizer (SGD with learning rate = 0.001 and momentum = 0.9). This prepares the model for training.
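To see what CrossEntropyLoss expects, here is a standalone illustration (not part of the training pipeline): it takes raw, unnormalized logits of shape (batch, num_classes) plus integer class labels, and applies log-softmax internally.

dummy_logits = torch.randn(4, 10)             # fake raw scores for a batch of 4
dummy_labels = torch.tensor([3, 8, 0, 6])     # fake ground-truth class indices
print(criterion(dummy_logits, dummy_labels))  # a single scalar loss tensor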


7. Training the Model

def train():
    training_data = load_data()

    for epoch in range(10):  # Train for 10 epochs
        running_loss = 0.0  # Reset each epoch so the printed averages stay accurate
        for i, data in enumerate(training_data, 0):
            inputs, labels = data
            inputs = inputs.view(inputs.size(0), -1)  # Flatten images for linear layers

            optimizer.zero_grad()  # Reset gradients

            outputs = model(inputs)  # Forward pass
            loss = criterion(outputs, labels)  # Calculate loss
            loss.backward()  # Backward pass (calculate gradients)
            optimizer.step()  # Update weights

            running_loss += loss.item()
            if i % 500 == 499:  # Print loss every 500 mini-batches
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 500))
                running_loss = 0.0

    print('Training finished')
  • Epochs: Loops through the dataset 10 times.

  • Flattening: The images are flattened into vectors of size 3072 (32x32x3) before being fed into the linear layers.

  • Zero Gradients: Gradients from the previous batch are reset.

  • Forward Pass: Pass the inputs through the model to get outputs.

  • Loss Calculation: Compute how far the outputs are from the actual labels.

  • Backward Pass: Calculate the gradients based on the loss.

  • Weight Update: Update the model weights using the optimizer.

  • Print Loss: Loss is printed every 500 mini-batches to track how well the model is learning.

    Expected Output:

  • During training, the loss will be printed every 500 batches, showing how the loss decreases over time as the model learns. Here’s an example of what you might see:

[1,   500] loss: 2.302
[1,  1000] loss: 2.150
...
[2,   500] loss: 1.950
[2,  1000] loss: 1.800
...
Training finished

The loss values will keep decreasing as the model improves.
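One detail the walkthrough doesn't show: train() is only defined above, so nothing runs until you call it. Since the DataLoader uses num_workers=1, it's safest to call it under a main guard (required on platforms that spawn worker processes), e.g.:

if __name__ == '__main__':
    train()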


8. Evaluating the Model

def evaluate():
    dataiter = iter(load_data())
    images, labels = next(dataiter)  # use the built-in next(); the .next() method was removed in newer PyTorch

    show_image(utils.make_grid(images))
    print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))

    images = images.view(images.size(0), -1)
    outputs = model(images)

    _, predicted = torch.max(outputs, 1)
    print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
  • Evaluation: After training, the model is sanity-checked on a batch of images. Note that evaluate() reuses load_data(), so these are actually training images; a proper evaluation would load a separate, held-out test set.

  • Prediction: The model makes a prediction for each input image. The torch.max function is used to select the class with the highest score.

  • Predicted Output: The predicted class is printed alongside the ground truth to compare how well the model performed.

    Expected Output:

  1. Display Images: You will see another grid of 4 images from the loader.

  2. GroundTruth: The true labels for the displayed images will be printed. Example:

GroundTruth:  plane   car  frog  ship

  3. Predicted: After feeding the images into the model, the predicted class labels will be printed. For example:

Predicted:  plane   car   dog  ship

In this case, the model made 3 correct predictions (plane, car, and ship), but incorrectly classified the "frog" image as "dog."
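If you want a number rather than a spot check, a small extension (my addition, not in the original post) computes accuracy over the whole loader. As noted above, this is the training loader, so treat the figure as a sanity check rather than a true test score:

def accuracy():
    correct, total = 0, 0
    with torch.no_grad():  # no gradients needed during evaluation
        for images, labels in load_data():
            outputs = model(images.view(images.size(0), -1))  # flatten, then forward pass
            _, predicted = torch.max(outputs, 1)              # class with the highest score
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    print('Accuracy: %.1f%%' % (100 * correct / total))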


In Summary:

  • Images and Labels: You’ll see the images and their labels printed and displayed.

  • Loss During Training: Loss values will decrease over time as the model trains.

  • Predictions: You will be able to see the ground truth labels versus the predicted labels after evaluating the model.
