Multi-Class Image Classification with FashionMNIST in PyTorch

Tanayendu Bari
10 min read

Introduction

In this tutorial, we’ll walk through building a multi-class image classification model using PyTorch and the FashionMNIST dataset.

Plot of a Subset of Images From the Fashion-MNIST Dataset

This task involves identifying 10 categories of clothing such as shirts, trousers, and sneakers. PyTorch makes it easy to build, train, and evaluate deep learning models with minimal boilerplate.

Step 1: Import Required Libraries

Before we start building the model, we need to import the essential libraries from PyTorch and Torchvision:

import torch
import torchvision
from torch import nn
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

Explanation:

  • torch: The core PyTorch library for tensor operations and deep learning components.

  • torchvision: A utility library for image datasets and transformations.

  • datasets: Contains ready-to-use datasets like MNIST, CIFAR10, and FashionMNIST.

  • transforms: Provides tools to preprocess and normalize images.

  • ToTensor: Converts images from PIL format to PyTorch tensors.

  • matplotlib.pyplot: A popular library for visualizing images and plots.

This setup is essential for loading, transforming, and visualizing the FashionMNIST dataset, which we’ll use for multi-class classification.

Step 2: Set the Device (CPU or GPU)

To make our model run efficiently, we should check if a GPU (CUDA) is available and use it. Otherwise, we'll fall back to the CPU.

device = "cuda" if torch.cuda.is_available() else "cpu"

Why This Matters:

  • GPU acceleration significantly speeds up training for deep learning models.

  • If you’re working on Google Colab, Kaggle, or a local machine with a CUDA-compatible GPU, this line selects it automatically. The model and data still have to be moved there explicitly with .to(device).

  • This setup ensures that your code works both on CPU-only environments and on systems with CUDA-enabled GPUs.

From this point onward, we’ll make sure to send all our data and models to this device.
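Selecting the device only picks a string. Each tensor and model then has to be placed on that device. Operations that mix tensors on different devices (for example, a CUDA model applied to a CPU batch) raise an error. The short sketch below shows both ways of placing a tensor. The tensor names are illustrative only.

```python
import torch

# Same selection logic as above: prefer CUDA, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tensors are created on the CPU by default; .to(device) returns a copy on the target device
x = torch.zeros(2, 3)
x_on_device = x.to(device)
print(x.device, "->", x_on_device.device)

# Alternatively, allocate directly on the target device and skip the copy
y = torch.ones(2, 3, device=device)
print(y.device)
```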

Step 3: Load the FashionMNIST Dataset

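We'll use the FashionMNIST dataset provided by torchvision.datasets. It contains 70,000 28x28 grayscale images of clothing items (60,000 for training and 10,000 for testing). Each image is labeled with one of 10 classes: T-shirt/top, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag, and Ankle boot.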

Here's how to download and load it into your project:

from torchvision import datasets
from torchvision.transforms import ToTensor

# Download and load the training dataset
train_data = datasets.FashionMNIST(
    root='data',            # Directory to store the dataset
    train=True,             # Load the training set
    download=True,          # Download if it's not already available
    transform=ToTensor(),   # Convert PIL images to PyTorch tensors
    target_transform=None   # We’ll use raw integer labels (0–9)
)

# Download and load the test dataset
test_data = datasets.FashionMNIST(
    root='data',
    train=False,            # Load the test set
    transform=ToTensor(),
    download=True,
    target_transform=None
)

What Happens Here:

  • PyTorch will automatically download the dataset if it’s not already present.

  • Each image is converted into a tensor using ToTensor(), scaling pixel values to the range [0, 1].

  • The data is stored in a folder called data/.

Step 4: Visualize a Sample Image

Let’s take a quick look at one of the training images to better understand the dataset. We'll use Matplotlib for visualization.

import matplotlib.pyplot as plt

# List of human-readable class names, indexed by label (0–9)
class_names = train_data.classes

# Get the second image and its label from the training dataset
image, label = train_data[1]

# Print the shape of the image tensor
print(f"Image shape is: {image.shape}")

# Plot the image
plt.imshow(image.squeeze(), cmap='gray')
plt.title(class_names[label])
plt.axis("off")
plt.show()

Explanation:

  • train_data.classes is the list of the 10 class names that torchvision provides with the dataset; we store it in class_names before plotting.

  • train_data[1] gives us the second image and its label from the dataset.

  • image.shape will be torch.Size([1, 28, 28]), indicating a single-channel (grayscale) image of size 28x28.

  • squeeze() removes the single-channel dimension, leaving a 28x28 tensor that Matplotlib can display directly.

  • cmap='gray' renders the image in grayscale.

  • class_names[label] maps the numeric label (e.g., 0) to its actual class name (e.g., 'T-shirt/top').
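For reference, these are the ten names that train_data.classes holds for FashionMNIST, in label order. They are hard-coded here only so the snippet runs without downloading the dataset. In the tutorial's code you would keep using train_data.classes. torchvision also exposes the reverse mapping as train_data.class_to_idx. The sketch builds the same dictionary by hand.

```python
# The ten FashionMNIST class names in label order (what train_data.classes returns)
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Label index -> name, as used in plt.title(class_names[label])
print(class_names[9])  # Ankle boot

# Name -> label index (equivalent to train_data.class_to_idx)
class_to_idx = {name: idx for idx, name in enumerate(class_names)}
print(class_to_idx["Sneaker"])  # 7
```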

Step 5: Prepare the Data Loaders

Now that we've loaded the dataset, we need to prepare it for training using PyTorch’s DataLoader. A DataLoader serves the data in mini-batches. The model therefore updates its weights after every small group of samples rather than once per full pass, and a GPU can process all images in a batch in parallel.

from torch.utils.data import DataLoader

# Create the DataLoader for training data
train_dataloader = DataLoader(
    dataset=train_data,
    batch_size=32,  # Number of samples per batch
    shuffle=True    # Shuffle data at every epoch
)

# Create the DataLoader for test data
test_dataloader = DataLoader(
    dataset=test_data,
    batch_size=32,
    shuffle=False   # No need to shuffle test data
)

Why Use a DataLoader?

  • Batching allows the model to process multiple images simultaneously, speeding up training.

  • shuffle=True ensures the training data is randomized each epoch, helping the model generalize better.

  • batch_size=32 is a commonly used value that balances performance and memory usage.
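To see what the DataLoader actually yields, here is a small sketch on a synthetic TensorDataset of 100 random "images". The fake data is an assumption standing in for FashionMNIST. len(loader) is the number of batches, and the last batch is smaller when the dataset size is not a multiple of batch_size. The same arithmetic applies to the real data: the 60,000 training images give 1,875 full batches, while the 10,000 test images give 313 batches, the last of which holds 16 images.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# 100 fake 1x28x28 "images" with random integer labels in [0, 10)
images = torch.rand(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset, batch_size=32, shuffle=True)

print(len(loader))  # 4 batches: ceil(100 / 32)
batch_shapes = [X.shape for X, y in loader]
for shape in batch_shapes:
    print(shape)  # three batches of [32, 1, 28, 28], then one of [4, 1, 28, 28]
```

Passing drop_last=True to DataLoader discards the incomplete final batch instead.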

Step 6: Define a Convolutional Neural Network (CNN)

For image classification tasks like FashionMNIST, Convolutional Neural Networks (CNNs) are highly effective because they can capture spatial hierarchies in images using convolutional layers.

Let’s define a small CNN architecture inspired by the VGGNet family, called TinyVGG:

class TinyVGG(nn.Module):
    def __init__(self, input_shape: int, hidden_units: int, output_shape: int):
        super().__init__()

        # Convolutional blocks
        self.block_1 = nn.Sequential(
            nn.Conv2d(input_shape, hidden_units, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        self.block_2 = nn.Sequential(
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        # Infer flattened size
        with torch.no_grad():
            sample_input = torch.randn(1, input_shape, 28, 28)  # assuming 28x28 image
            sample_output = self.block_2(self.block_1(sample_input))
            self.flattened_size = sample_output.view(1, -1).shape[1]

        # Classifier
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.flattened_size, output_shape)
        )

    def forward(self, x):
        x = self.block_1(x)
        x = self.block_2(x)
        x = self.classifier(x)
        return x

Explanation:

  • The model takes 28x28 grayscale images (so input_shape = 1).

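  • It uses two convolutional blocks. Each block has two convolutional layers, each followed by a ReLU activation, and ends with a max-pooling layer.

  • The 3x3 convolutions use padding=1, so they preserve the spatial size, while each MaxPool2d(kernel_size=2) halves it. The feature maps therefore shrink from 28x28 to 14x14 after block_1 and to 7x7 after block_2.

  • The final layer is a fully connected (linear) layer that maps the flattened features (hidden_units x 7 x 7 values) to 10 output classes.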
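The 28 → 14 → 7 shrinkage follows from the standard output-size formula for Conv2d and MaxPool2d (with dilation 1): out = floor((in + 2·padding − kernel_size) / stride) + 1. MaxPool2d's stride defaults to its kernel size. The sketch below applies the formula layer by layer. The helper name conv2d_out is made up for illustration.

```python
def conv2d_out(size: int, kernel_size: int, padding: int = 0, stride: int = 1) -> int:
    """Spatial output size of a Conv2d or MaxPool2d layer (dilation = 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1

size = 28
for block in range(2):
    size = conv2d_out(size, kernel_size=3, padding=1)  # first conv: size unchanged
    size = conv2d_out(size, kernel_size=3, padding=1)  # second conv: size unchanged
    size = conv2d_out(size, kernel_size=2, stride=2)   # MaxPool2d(2): size halved
    print(f"after block {block + 1}: {size}x{size}")

hidden_units = 10
flattened_size = hidden_units * size * size  # input features of the Linear layer
print(flattened_size)  # 490
```

This matches the flattened_size that TinyVGG infers with its dummy forward pass when hidden_units=10.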

Step 7: Instantiate the Model and Move it to Device

Now that we’ve defined the TinyVGG model architecture, let’s create an instance of the model and move it to the appropriate device (CPU or GPU) for training.

# Create an instance of the model
model = TinyVGG(input_shape=1, hidden_units=10, output_shape=10)

# Move the model to the available device (CPU or GPU)
model = model.to(device)

Explanation:

  • input_shape=1 because FashionMNIST images are grayscale (1 channel).

  • hidden_units=10 defines how many filters each convolutional layer learns. Increasing it makes the model wider (more feature maps per layer), which adds capacity at the cost of more computation.

  • output_shape=10 corresponds to the 10 classes in the FashionMNIST dataset.

  • model.to(device) ensures the model runs on GPU if available, otherwise on CPU.
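A quick way to check the model's size is to count its parameters. The helper count_parameters below is a hypothetical utility, not part of PyTorch. To keep the snippet self-contained, it is applied to a layer-for-layer nn.Sequential stand-in for TinyVGG(input_shape=1, hidden_units=10, output_shape=10). Calling count_parameters(model) on the real instance should give the same total, because the layers are identical.

```python
import torch
from torch import nn

def count_parameters(model: nn.Module) -> int:
    """Total number of trainable parameters (weights + biases)."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Layer-for-layer stand-in for TinyVGG(input_shape=1, hidden_units=10, output_shape=10)
hidden = 10
stand_in = nn.Sequential(
    nn.Conv2d(1, hidden, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(hidden, hidden, kernel_size=3, padding=1), nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    nn.Conv2d(hidden, hidden, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(hidden, hidden, kernel_size=3, padding=1), nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    nn.Flatten(),
    nn.Linear(hidden * 7 * 7, 10),
)

# 100 + 910 + 910 + 910 conv parameters, plus 4910 in the linear layer
print(count_parameters(stand_in))  # 7740
```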

Step 8: Define Loss Function, Optimizer, and Accuracy Metric

To train a neural network, we need:

  • A loss function to measure how wrong the model's predictions are.

  • An optimizer to adjust the model’s parameters to minimize that loss.

  • A metric like accuracy to evaluate how well the model is performing.

import torch.nn as nn
import torch

# Define the loss function
loss_fn = nn.CrossEntropyLoss()

# Define the optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

We’re using:

  • CrossEntropyLoss because it's the standard loss function for multi-class classification.

  • SGD (Stochastic Gradient Descent) with a learning rate of 0.1 as the optimizer. You can later experiment with other optimizers such as Adam, which often converges faster.
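nn.CrossEntropyLoss expects raw, unnormalized scores (logits) and integer class indices. It applies log-softmax internally, which is why TinyVGG ends with a plain Linear layer and no softmax. The sketch below uses made-up logits for a batch of two samples and three classes. It checks the loss against a manual computation: the mean negative log-probability of the correct class.

```python
import torch
from torch import nn
import torch.nn.functional as F

loss_fn = nn.CrossEntropyLoss()

# Hypothetical batch: 2 samples, 3 classes, raw scores straight from a model
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])
targets = torch.tensor([0, 2])  # integer class indices, not one-hot vectors

loss = loss_fn(logits, targets)

# Manual equivalent: log-softmax, pick each sample's true-class entry, negate, average
manual = -F.log_softmax(logits, dim=1)[torch.arange(2), targets].mean()
print(loss.item(), manual.item())
```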

Accuracy Function

We'll also define a function to calculate accuracy:

def accuracy_fn(y_true, y_pred):
    """
    Calculates accuracy between true labels and predicted labels.
    Args:
        y_true (torch.Tensor): Ground truth labels.
        y_pred (torch.Tensor): Predicted class labels.
    Returns:
        Float accuracy score (e.g., 78.5)
    """
    correct = torch.eq(y_true, y_pred).sum().item()
    acc = (correct / len(y_pred)) * 100
    return acc

Note:
Do not write accuracy_fn = accuracy_fn(y_true, y_pred) at this stage. y_true and y_pred don’t exist yet, and the assignment would replace the function with a number. We'll call this function inside our training and evaluation loops after the model makes predictions.
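Here is how accuracy_fn is meant to be called once predictions exist. It is shown on a tiny hand-made batch of logits, which is an assumption standing in for real model output. argmax(dim=1) turns each row of logits into a predicted label, and three of the four predictions match.

```python
import torch

def accuracy_fn(y_true, y_pred):
    # Same logic as above: percentage of positions where the labels match
    correct = torch.eq(y_true, y_pred).sum().item()
    return (correct / len(y_pred)) * 100

# Hypothetical logits for 4 samples and 3 classes
logits = torch.tensor([[3.0, 0.1, 0.2],
                       [0.3, 2.5, 0.1],
                       [0.2, 0.1, 1.9],
                       [1.2, 0.4, 0.3]])
y_true = torch.tensor([0, 1, 2, 2])

y_pred = logits.argmax(dim=1)  # tensor([0, 1, 2, 0])
acc = accuracy_fn(y_true=y_true, y_pred=y_pred)
print(acc)  # 75.0
```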

Step 9: Define the Training Loop

Now that we have our model, loss function, optimizer, and accuracy metric ready, it’s time to build the training loop.

This loop goes through the dataset in batches, updates the model’s weights using backpropagation, and tracks performance.

from tqdm.auto import tqdm
torch.manual_seed(41)  # Set a seed for reproducibility

def train_step(model: torch.nn.Module,
               data_loader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               accuracy_fn,
               device: torch.device = device):

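    train_loss, train_acc = 0, 0
    model.to(device)
    model.train()  # Put the model in training mode (matters for dropout/batch norm)

    for batch, (X, y) in enumerate(data_loader):
        # Move data to the target device (CPU or GPU)
        X, y = X.to(device), y.to(device)

        # 1. Forward pass
        y_pred = model(X)

        # 2. Compute the loss
        loss = loss_fn(y_pred, y)
        train_loss += loss.item()  # Accumulate a plain float so each batch's graph can be freed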

        # 3. Calculate accuracy
        train_acc += accuracy_fn(y_true=y, y_pred=y_pred.argmax(dim=1))  # Convert logits to predicted labels

        # 4. Backpropagation
        optimizer.zero_grad()   # Clear gradients
        loss.backward()         # Backpropagate
        optimizer.step()        # Update weights

    # Average loss and accuracy across batches
    train_loss /= len(data_loader)
    train_acc /= len(data_loader)
    print(f"Train loss: {train_loss:.5f} | Train accuracy: {train_acc:.2f}%")

What’s Happening Here:

  • DataLoader feeds the model mini-batches of images and labels.

  • The model makes predictions with a forward pass.

  • The loss is calculated and used to compute gradients via loss.backward().

  • The optimizer updates the model weights using optimizer.step().

  • Loss and accuracy are averaged over all batches for reporting.

Note: We call .argmax(dim=1) to convert raw model outputs (logits) into predicted class labels.

This function is modular and reusable: you can call it once per training epoch.
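The running loss in a training loop is best accumulated with loss.item(), not with the loss tensor itself. Adding the tensor builds a new tensor that is still attached to the autograd graph. That keeps every batch's graph reachable until the epoch ends. .item() copies the value out as a plain Python float. The sketch below uses a tiny nn.Linear model and random data, both assumptions, to make the difference visible.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 3)
loss_fn = nn.CrossEntropyLoss()
X = torch.randn(8, 4)
y = torch.randint(0, 3, (8,))

loss = loss_fn(model(X), y)

# Accumulating the tensor itself keeps a reference to its autograd graph
running_tensor = 0
running_tensor += loss
print(running_tensor.requires_grad, running_tensor.grad_fn is not None)  # True True

# .item() returns a plain Python float with no graph attached
running_float = 0.0
running_float += loss.item()
print(type(running_float).__name__)  # float
```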

Step 10: Define the Evaluation Loop

After training your model, it's essential to evaluate its performance on unseen data. The test_step function helps you do that by calculating the test loss and test accuracy across the validation or test dataset.

Here’s the complete evaluation function:

def test_step(data_loader: torch.utils.data.DataLoader,
              model: torch.nn.Module,
              loss_fn: torch.nn.Module,
              accuracy_fn,
              device: torch.device = device):

    test_loss, test_acc = 0, 0
    model.to(device)
    model.eval()  # Set the model to evaluation mode

    # Disable gradient tracking for faster inference
    with torch.inference_mode():
        for X, y in data_loader:
            # Move data to the target device
            X, y = X.to(device), y.to(device)

            # 1. Forward pass
            test_pred = model(X)

            # 2. Calculate loss and accuracy
            test_loss += loss_fn(test_pred, y).item()  # Accumulate a plain float
            test_acc += accuracy_fn(
                y_true=y,
                y_pred=test_pred.argmax(dim=1)  # Convert logits to predicted class labels
            )

    # Average the results over all batches
    test_loss /= len(data_loader)
    test_acc /= len(data_loader)
    print(f"Test loss: {test_loss:.5f} | Test accuracy: {test_acc:.2f}%\n")

Key Highlights:

  • model.eval() switches the model to evaluation mode. Dropout layers stop dropping activations, and batch normalization uses its running statistics instead of per-batch statistics. TinyVGG has neither, but calling it is good practice.

  • with torch.inference_mode() disables gradient calculations, saving memory and improving speed during testing.

  • We accumulate loss and accuracy for each batch and then average them over the number of batches.

This function is structured almost the same as the training loop, except that we don’t backpropagate or update weights.
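The effect of train()/eval() and inference_mode() is easiest to see on a layer that behaves differently in each mode. The sketch below uses a standalone nn.Dropout layer, which TinyVGG does not contain, purely for illustration. In training mode, dropout zeroes random elements and scales the rest by 1 / (1 − p). In evaluation mode, it passes its input through unchanged. Results computed inside inference_mode() do not require gradients.

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

layer.train()  # training mode: random elements zeroed, survivors scaled by 1 / (1 - 0.5) = 2
train_out = layer(x)

layer.eval()   # evaluation mode: dropout is the identity
eval_out = layer(x)

w = torch.ones(3, requires_grad=True)
with torch.inference_mode():
    inf_out = w * 2  # no autograd graph is recorded here

print(train_out)                 # a mix of 0.0 and 2.0 values
print(torch.equal(eval_out, x))  # True
print(inf_out.requires_grad)     # False
```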

Step 11: Create a General Model Evaluation Function

To make evaluation more modular and reusable, let’s define a utility function that returns the model's name, average loss, and accuracy in a structured format. This is helpful when comparing multiple models later on.

torch.manual_seed(42)  # For reproducibility

def eval_mode(model: torch.nn.Module,
              data_loader: torch.utils.data.DataLoader,
              loss_fn: torch.nn.Module,
              accuracy_fn):

    loss, acc = 0, 0
    model.eval()  # Set the model to evaluation mode

    with torch.inference_mode():  # Disable gradient tracking
        for X, y in data_loader:
            X, y = X.to(device), y.to(device)
            y_pred = model(X)
            loss += loss_fn(y_pred, y).item()  # Accumulate a plain float
            acc += accuracy_fn(y_true=y, y_pred=y_pred.argmax(dim=1))

    loss /= len(data_loader)
    acc /= len(data_loader)

    return {
        "Model_name": model.__class__.__name__,
        "Model_loss": loss,
        "Model_acc": acc
    }

What This Does:

  • This is a general-purpose evaluation function.

  • It calculates average loss and accuracy across a dataset.

  • It returns a dictionary with:

    • The model’s class name

    • Final loss

    • Final accuracy

  • Useful for benchmarking or reporting results.
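With the functions above, training would consist of calling train_step and then test_step once per epoch, optionally wrapping range(epochs) in tqdm for a progress bar, and finishing with eval_mode. The sketch below shows that loop structure end to end so it runs without downloading FashionMNIST. The linearly separable synthetic data, the nn.Linear stand-in for TinyVGG, and the evaluate helper are all assumptions for illustration. evaluate mirrors eval_mode but returns a (loss, accuracy) tuple.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(42)

# Hypothetical stand-in data: 512 linearly separable samples, 20 features, 3 classes
X = torch.randn(512, 20)
true_w = torch.randn(20, 3)
y = (X @ true_w).argmax(dim=1)
train_loader = DataLoader(TensorDataset(X[:384], y[:384]), batch_size=32, shuffle=True)
test_loader = DataLoader(TensorDataset(X[384:], y[384:]), batch_size=32)

model = nn.Linear(20, 3)  # stand-in for TinyVGG
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def evaluate(model, loader):
    """Mean loss and accuracy (%) over the batches of a loader."""
    model.eval()
    total_loss, total_acc = 0.0, 0.0
    with torch.inference_mode():
        for xb, yb in loader:
            logits = model(xb)
            total_loss += loss_fn(logits, yb).item()
            total_acc += (logits.argmax(dim=1) == yb).float().mean().item() * 100
    return total_loss / len(loader), total_acc / len(loader)

history = [evaluate(model, test_loader)]  # untrained baseline
for epoch in range(5):
    model.train()
    for xb, yb in train_loader:
        loss = loss_fn(model(xb), yb)  # forward pass + loss
        optimizer.zero_grad()          # clear old gradients
        loss.backward()                # backpropagate
        optimizer.step()               # update weights
    test_loss, test_acc = evaluate(model, test_loader)
    history.append((test_loss, test_acc))
    print(f"Epoch {epoch + 1} | test loss: {test_loss:.4f} | test acc: {test_acc:.2f}%")
```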

Conclusion

In this tutorial, you learned how to build a multi-class image classification model using PyTorch and the FashionMNIST dataset. We covered:

  • Understanding multi-class classification

  • Loading and preprocessing image data

  • Building a custom CNN model (TinyVGG)

  • Writing modular training and evaluation loops

  • Calculating accuracy and visualizing sample images

This end-to-end pipeline is a strong foundation for many real-world computer vision tasks.


Also read about binary classification here: https://binary-classification-using-pytorch.hashnode.dev/ and about CNNs here: https://convolutional-neural-networks-cnns.hashnode.dev/
