Multi-Class Image Classification with FashionMNIST in PyTorch


Introduction
In this tutorial, we’ll walk through building a multi-class image classification model using PyTorch and the FashionMNIST dataset.
This task involves identifying 10 categories of clothing such as shirts, trousers, and sneakers. PyTorch makes it easy to build, train, and evaluate deep learning models with minimal boilerplate.
Step 1: Import Required Libraries
Before we start building the model, we need to import the essential libraries from PyTorch and Torchvision:
import torch
import torchvision
from torch import nn
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
Explanation:
torch: The core PyTorch library for tensor operations and deep learning components.
nn: PyTorch's neural network module, which provides layers, loss functions, and the nn.Module base class.
torchvision: A utility library for image datasets and transformations.
datasets: Contains ready-to-use datasets like MNIST, CIFAR10, and FashionMNIST.
transforms: Provides tools to preprocess and normalize images.
ToTensor: Converts images from PIL format to PyTorch tensors.
matplotlib.pyplot: A popular library for visualizing images and plots.
This setup is essential for loading, transforming, and visualizing the FashionMNIST dataset, which we’ll use for multi-class classification.
Step 2: Set the Device (CPU or GPU)
To make our model run efficiently, we should check if a GPU (CUDA) is available and use it. Otherwise, we'll fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
Why This Matters:
GPU acceleration significantly speeds up training for deep learning models.
If you’re working on Google Colab, Kaggle, or a local machine with a compatible GPU, this line selects it automatically.
This setup ensures that your code works both on CPU-only environments and on systems with CUDA-enabled GPUs.
Selecting the device is not enough on its own, though. From this point onward, we’ll explicitly send all our data and models to this device with .to(device).
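As a quick check of the device logic above, the following minimal sketch moves a tensor and a small layer to the selected device and confirms where they ended up. The tensor and the nn.Linear layer are stand-ins for illustration only, not part of the tutorial's model.

```python
import torch
from torch import nn

# Same device selection as above
device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in tensor and layer (hypothetical, not the tutorial's model)
x = torch.ones(4, 3).to(device)
layer = nn.Linear(3, 2).to(device)

# Inputs and parameters must live on the same device before a forward pass
out = layer(x)
print(x.device, next(layer.parameters()).device, out.shape)
```

If the input and the parameters end up on different devices, PyTorch raises a RuntimeError. That error is the most common symptom of a missing .to(device) call.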
Step 3: Load the FashionMNIST Dataset
We'll use the FashionMNIST dataset provided by torchvision.datasets. It contains 70,000 grayscale 28x28 images of clothing items, split into 60,000 training images and 10,000 test images. Each image is labeled with one of 10 classes, such as shirts, trousers, and shoes.
Here's how to download and load it into your project:
from torchvision import datasets
from torchvision.transforms import ToTensor
# Download and load the training dataset
train_data = datasets.FashionMNIST(
    root='data',           # Directory to store the dataset
    train=True,            # Load the training set
    download=True,         # Download if it's not already available
    transform=ToTensor(),  # Convert PIL images to PyTorch tensors
    target_transform=None  # We’ll use raw integer labels (0–9)
)
# Download and load the test dataset
test_data = datasets.FashionMNIST(
    root='data',
    train=False,           # Load the test set
    transform=ToTensor(),
    download=True,
    target_transform=None
)
What Happens Here:
PyTorch will automatically download the dataset if it’s not already present.
Each image is converted into a tensor using ToTensor(), which also scales pixel values from [0, 255] to [0, 1].
The data is stored in a folder called data/.
Step 4: Visualize a Sample Image
Let’s take a quick look at one of the training images to better understand the dataset. We'll use Matplotlib for visualization.
import matplotlib.pyplot as plt
# Map integer labels to readable class names (must exist before plotting)
class_names = train_data.classes
# Get the second image and its label from the training dataset
image, label = train_data[1]
# Print the shape of the image tensor
print(f"Image shape is: {image.shape}")
# Plot the image
plt.imshow(image.squeeze(), cmap='gray')
plt.title(class_names[label])
plt.axis(False)
plt.show()
Explanation:
train_data[1] gives us the second image and its label from the dataset.
image.shape will be torch.Size([1, 28, 28]), indicating a single-channel (grayscale) image of size 28x28.
squeeze() removes the single-channel dimension so Matplotlib can display it properly.
cmap='gray' renders the image in grayscale.
class_names[label] maps the numeric label (e.g., 0) to its actual class name (e.g., 'T-shirt/top').
Note that class_names must be defined before the plotting code runs. FashionMNIST provides its label names through the classes attribute:
class_names = train_data.classes
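You can check the indexing and squeeze() steps described above without Matplotlib. This minimal sketch uses a random tensor as a stand-in image. It hard-codes FashionMNIST's ten label names in index order, which is the list train_data.classes returns.

```python
import torch

# FashionMNIST's label names in index order (the list train_data.classes returns)
class_names = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

image = torch.rand(1, 28, 28)  # random stand-in for a real [1, 28, 28] image tensor
label = 9                      # integer label, as returned by train_data[i]

squeezed = image.squeeze()     # drop the size-1 channel dimension for plt.imshow
print(image.shape, "->", squeezed.shape)  # torch.Size([1, 28, 28]) -> torch.Size([28, 28])
print(class_names[label])      # Ankle boot
```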
Step 5: Prepare the Data Loaders
Now that we've loaded the dataset, we need to prepare it for training using PyTorch’s DataLoader. A DataLoader serves the data in mini-batches, which improves training efficiency and lets the GPU process many images in parallel.
from torch.utils.data import DataLoader
# Create the DataLoader for training data
train_dataloader = DataLoader(
    dataset=train_data,
    batch_size=32,  # Number of samples per batch
    shuffle=True    # Shuffle data at every epoch
)
# Create the DataLoader for test data
test_dataloader = DataLoader(
    dataset=test_data,
    batch_size=32,
    shuffle=False   # No need to shuffle test data
)
Why Use a DataLoader?
Batching allows the model to process multiple images simultaneously, speeding up training.
shuffle=True reshuffles the training data at every epoch, which helps the model generalize better.
batch_size=32 is a commonly used value that balances performance and memory usage.
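To make the batching concrete, here is a minimal sketch. It wraps 100 random stand-in images in a TensorDataset (used here in place of FashionMNIST) and inspects what the DataLoader yields. With batch_size=32, the last batch holds the 4 leftover samples. By the same arithmetic, FashionMNIST's 60,000 training images give 1,875 batches and its 10,000 test images give 313.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 100 random stand-in "images" with integer labels 0-9
images = torch.rand(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
loader = DataLoader(TensorDataset(images, labels), batch_size=32, shuffle=True)

X, y = next(iter(loader))
print(X.shape, y.shape)  # torch.Size([32, 1, 28, 28]) torch.Size([32])

print(len(loader))  # 4 batches: ceil(100 / 32)
batch_sizes = [xb.shape[0] for xb, _ in loader]
print(batch_sizes)  # [32, 32, 32, 4]
```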
Step 6: Define a Convolutional Neural Network (CNN)
For image classification tasks like FashionMNIST, Convolutional Neural Networks (CNNs) are highly effective because they can capture spatial hierarchies in images using convolutional layers.
Let’s define a small CNN architecture inspired by the VGGNet family, called TinyVGG:
class TinyVGG(nn.Module):
    def __init__(self, input_shape: int, hidden_units: int, output_shape: int):
        super().__init__()
        # Convolutional blocks
        self.block_1 = nn.Sequential(
            nn.Conv2d(input_shape, hidden_units, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.block_2 = nn.Sequential(
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        # Infer flattened size
        with torch.no_grad():
            sample_input = torch.randn(1, input_shape, 28, 28)  # assuming 28x28 image
            sample_output = self.block_2(self.block_1(sample_input))
            self.flattened_size = sample_output.view(1, -1).shape[1]
        # Classifier
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.flattened_size, output_shape)
        )

    def forward(self, x):
        x = self.block_1(x)
        x = self.block_2(x)
        x = self.classifier(x)
        return x
Explanation:
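The model takes 28x28 grayscale images (so input_shape = 1).
It uses two convolutional blocks:
- Each block has two convolutional layers, each followed by a ReLU activation, and ends with a MaxPooling layer.
The 3x3 convolutions use padding=1, so they keep the spatial size unchanged. Each 2x2 max-pooling layer halves it, reducing the feature maps from 28x28 → 14x14 → 7x7.
The flattened size (hidden_units × 7 × 7, i.e. 490 when hidden_units=10) is inferred in __init__ by passing a dummy 28x28 input through both blocks.
The final layer is a fully connected (linear) layer that maps the extracted features to the 10 output classes.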
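The 28x28 → 14x14 → 7x7 progression follows from the output-size formula for a convolution or pooling layer: out = floor((in + 2*padding - kernel) / stride) + 1. With kernel_size=3, padding=1, and stride=1, the size stays at 28. MaxPool2d(kernel_size=2) defaults to stride 2, which gives floor((28 - 2) / 2) + 1 = 14. The sketch below traces a dummy batch through layers configured like TinyVGG's, assuming hidden_units=10, to confirm the shapes and the flattened size. ReLU and the second convolution of each block are omitted because they do not change the shape.

```python
import torch
from torch import nn

hidden_units = 10  # assumed value, matching the instantiation in Step 7
x = torch.randn(32, 1, 28, 28)  # dummy batch of 32 grayscale 28x28 images

conv_a = nn.Conv2d(1, hidden_units, kernel_size=3, padding=1)
conv_b = nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=1)
pool = nn.MaxPool2d(kernel_size=2)  # stride defaults to kernel_size

h1 = conv_a(x)         # padding=1 keeps 28x28
p1 = pool(h1)          # halves to 14x14
p2 = pool(conv_b(p1))  # halves again to 7x7
flat = nn.Flatten()(p2)

print(h1.shape)    # torch.Size([32, 10, 28, 28])
print(p1.shape)    # torch.Size([32, 10, 14, 14])
print(p2.shape)    # torch.Size([32, 10, 7, 7])
print(flat.shape)  # torch.Size([32, 490]): 10 * 7 * 7 features per image
```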
Step 7: Instantiate the Model and Move it to Device
Now that we’ve defined the TinyVGG model architecture, let’s create an instance of the model and move it to the appropriate device (CPU or GPU) for training.
# Create an instance of the model
model = TinyVGG(input_shape=1, hidden_units=10, output_shape=10)
# Move the model to the available device (CPU or GPU)
model = model.to(device)
Explanation:
input_shape=1 because FashionMNIST images are grayscale (1 channel).
hidden_units=10 defines how many filters each convolutional layer will learn. You can increase it for a wider, higher-capacity model.
output_shape=10 corresponds to the 10 classes in the FashionMNIST dataset.
model.to(device) ensures the model runs on the GPU if one is available, and on the CPU otherwise.
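To get a feel for the model's size, the sketch below rebuilds the same layer stack as TinyVGG(input_shape=1, hidden_units=10, output_shape=10) with nn.Sequential. This is a compact restatement for a self-contained check, not the class itself. It runs a dummy batch through the model and counts the trainable parameters. Each 3x3 convolution has in·out·9 weights plus out biases, and the linear layer has 490·10 + 10, for 7,740 parameters in total.

```python
import torch
from torch import nn

def conv_relu(c_in, c_out):
    # 3x3 convolution with padding=1, followed by ReLU (as in TinyVGG)
    return [nn.Conv2d(c_in, c_out, kernel_size=3, padding=1), nn.ReLU()]

hidden_units, num_classes = 10, 10
model = nn.Sequential(
    *conv_relu(1, hidden_units), *conv_relu(hidden_units, hidden_units), nn.MaxPool2d(2),
    *conv_relu(hidden_units, hidden_units), *conv_relu(hidden_units, hidden_units), nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(hidden_units * 7 * 7, num_classes),
)

logits = model(torch.randn(8, 1, 28, 28))  # dummy batch of 8 images
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(logits.shape)  # torch.Size([8, 10]): one logit per class per image
print(num_params)    # 7740
```

The convolution weights scale with the square of hidden_units, so raising it grows the parameter count quickly.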
Step 8: Define Loss Function, Optimizer, and Accuracy Metric
To train a neural network, we need:
A loss function to measure how wrong the model's predictions are.
An optimizer to adjust the model’s parameters to minimize that loss.
A metric like accuracy to evaluate how well the model is performing.
import torch.nn as nn
import torch
# Define the loss function
loss_fn = nn.CrossEntropyLoss()
# Define the optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
We’re using:
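CrossEntropyLoss, because it's the standard loss function for multi-class classification. It expects raw logits (no softmax needed) and integer class labels, and it internally combines log-softmax with negative log-likelihood.
SGD (Stochastic Gradient Descent) with a learning rate of 0.1 as the optimizer. You can later experiment with optimizers like Adam, which often converge faster.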
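Here is a minimal sketch of what these two choices do on a toy example: a hypothetical 4-feature, 3-class linear layer and a single sample. It checks that CrossEntropyLoss equals the negative log-softmax probability of the true class. It then takes one SGD step and confirms the plain SGD update rule (no momentum), w_new = w_old - lr * grad.

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Linear(4, 3)         # hypothetical toy model producing 3 logits
x = torch.randn(1, 4)
target = torch.tensor([2])      # integer class index, not one-hot

loss_fn = nn.CrossEntropyLoss()
logits = layer(x)
loss = loss_fn(logits, target)

# Cross-entropy = negative log of the softmax probability of the true class
manual = -torch.log_softmax(logits, dim=1)[0, target.item()]
print(loss.item(), manual.item())

# One plain SGD step: w_new = w_old - lr * grad
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)
optimizer.zero_grad()
loss.backward()
w_old = layer.weight.detach().clone()
grad = layer.weight.grad.detach().clone()
optimizer.step()
expected = w_old - 0.1 * grad
print(torch.allclose(layer.weight, expected))  # True
```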
Accuracy Function
We'll also define a function to calculate accuracy:
def accuracy_fn(y_true, y_pred):
    """
    Calculates accuracy between true labels and predicted labels.
    Args:
        y_true (torch.Tensor): Ground truth labels.
        y_pred (torch.Tensor): Predicted class labels.
    Returns:
        Float accuracy score (e.g., 78.5)
    """
    correct = torch.eq(y_true, y_pred).sum().item()
    acc = (correct / len(y_pred)) * 100
    return acc
Note:
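Do not call accuracy_fn(y_true, y_pred) at this stage, because y_true and y_pred don’t exist yet. In particular, avoid writing accuracy_fn = accuracy_fn(y_true, y_pred), which would replace the function with a number. We'll use this function inside our training and evaluation loops after the model makes predictions.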
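Once predictions exist, using the function looks like this. This minimal sketch repeats accuracy_fn compactly so it runs on its own, and it uses hand-made logits for four samples and three classes.

```python
import torch

def accuracy_fn(y_true, y_pred):
    # Same logic as above: percentage of matching labels
    correct = torch.eq(y_true, y_pred).sum().item()
    return (correct / len(y_pred)) * 100

# Hand-made logits: 4 samples, 3 classes
logits = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 1.5, 0.1],
                       [0.1, 0.2, 3.0],
                       [1.0, 0.5, 0.2]])
y_true = torch.tensor([0, 1, 2, 2])
y_pred = logits.argmax(dim=1)  # tensor([0, 1, 2, 0])
acc = accuracy_fn(y_true=y_true, y_pred=y_pred)
print(acc)  # 75.0 (3 of 4 correct)
```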
Step 9: Define the Training Loop
Now that we have our model, loss function, optimizer, and accuracy metric ready, it’s time to build the training loop.
This loop goes through the dataset in batches, updates the model’s weights using backpropagation, and tracks performance.
from tqdm.auto import tqdm
torch.manual_seed(41)  # Set a seed for reproducibility

def train_step(model: torch.nn.Module,
               data_loader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               accuracy_fn,
               device: torch.device = device):
    train_loss, train_acc = 0, 0
    model.to(device)
    model.train()  # Put the model in training mode
    for batch, (X, y) in enumerate(data_loader):
        # Move data to the target device (CPU or GPU)
        X, y = X.to(device), y.to(device)
        # 1. Forward pass
        y_pred = model(X)
        # 2. Compute the loss
        loss = loss_fn(y_pred, y)
        train_loss += loss.item()  # .item() stores a float, so the autograd graph isn't kept alive
        # 3. Calculate accuracy
        train_acc += accuracy_fn(y_true=y, y_pred=y_pred.argmax(dim=1))  # Convert logits to predicted labels
        # 4. Backpropagation
        optimizer.zero_grad()  # Clear gradients
        loss.backward()        # Backpropagate
        optimizer.step()       # Update weights
    # Average loss and accuracy across batches
    train_loss /= len(data_loader)
    train_acc /= len(data_loader)
    print(f"Train loss: {train_loss:.5f} | Train accuracy: {train_acc:.2f}%")
What’s Happening Here:
The DataLoader feeds the model mini-batches of images and labels.
The model makes predictions with a forward pass.
The loss is calculated and used to compute gradients via loss.backward().
The optimizer updates the model weights using optimizer.step().
Loss and accuracy are summed per batch and averaged over the number of batches for reporting.
Note: We call .argmax(dim=1) to convert raw model outputs (logits) into predicted class labels.
This function is modular and reusable: you can call it once per training epoch.
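In practice, train_step is called once per epoch, usually alongside test_step. The sketch below shows that outer loop in a self-contained form. It uses hypothetical random data and a tiny linear stand-in model instead of FashionMNIST and TinyVGG, and it inlines the same per-batch steps so it runs on its own. With the tutorial's functions, the inner loop would become a call to train_step(...) followed by test_step(...).

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# Hypothetical stand-in data: 128 random "images" with random labels 0-9
X_all = torch.rand(128, 1, 28, 28)
y_all = torch.randint(0, 10, (128,))
train_loader = DataLoader(TensorDataset(X_all, y_all), batch_size=32, shuffle=True)

# Tiny stand-in model (not TinyVGG) so the sketch runs quickly
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

epochs = 3
history = []  # average training loss per epoch
for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    for X, y in train_loader:
        y_pred = model(X)          # forward pass
        loss = loss_fn(y_pred, y)  # compute the loss
        optimizer.zero_grad()      # clear old gradients
        loss.backward()            # backpropagate
        optimizer.step()           # update weights
        epoch_loss += loss.item()  # accumulate a plain float, not the graph
    epoch_loss /= len(train_loader)
    history.append(epoch_loss)
    print(f"Epoch {epoch} | train loss: {epoch_loss:.4f}")
```

Because the labels here are random, the printed losses say nothing about real performance. The point of the sketch is the loop structure.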
Step 10: Define the Evaluation Loop
After training your model, it's essential to evaluate its performance on unseen data. The test_step function does that by calculating the test loss and test accuracy across the validation or test dataset.
Here’s the complete evaluation function:
def test_step(data_loader: torch.utils.data.DataLoader,
              model: torch.nn.Module,
              loss_fn: torch.nn.Module,
              accuracy_fn,
              device: torch.device = device):
    test_loss, test_acc = 0, 0
    model.to(device)
    model.eval()  # Set the model to evaluation mode
    # Disable gradient tracking for faster inference
    with torch.inference_mode():
        for X, y in data_loader:
            # Move data to the target device
            X, y = X.to(device), y.to(device)
            # 1. Forward pass
            test_pred = model(X)
            # 2. Calculate loss and accuracy
            test_loss += loss_fn(test_pred, y)
            test_acc += accuracy_fn(
                y_true=y,
                y_pred=test_pred.argmax(dim=1)  # Convert logits to predicted class labels
            )
        # Average the results over all batches (inside inference_mode, because
        # in-place ops on the inference tensor test_loss are not allowed outside it)
        test_loss /= len(data_loader)
        test_acc /= len(data_loader)
    print(f"Test loss: {test_loss:.5f} | Test accuracy: {test_acc:.2f}%\n")
Key Highlights:
model.eval() puts the model in evaluation mode. Dropout layers are disabled, and batch normalization layers use their running statistics instead of per-batch statistics. TinyVGG has neither layer type, but calling model.eval() is still good practice.
with torch.inference_mode() disables gradient tracking, saving memory and improving speed during testing.
We calculate loss and accuracy for each batch and then average them over the number of batches.
This function is structured almost the same as the training loop, except that we don’t backpropagate or update weights.
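The effect of the two evaluation switches can be seen directly. This minimal sketch uses a stand-in nn.Dropout layer, since TinyVGG itself has none, to show the difference between training and evaluation mode. It then checks that outputs computed under torch.inference_mode() carry no autograd history.

```python
import torch
from torch import nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)  # stand-in layer; TinyVGG has no dropout
x = torch.ones(1, 8)

dropout.train()
train_out = dropout(x)  # random zeros; survivors scaled by 1 / (1 - p) = 2.0
dropout.eval()
eval_out = dropout(x)   # evaluation mode: dropout passes inputs through unchanged

layer = nn.Linear(8, 2)
with torch.inference_mode():
    pred = layer(eval_out)

print(train_out)
print(eval_out)
print(pred.requires_grad)  # False: no autograd graph is recorded
```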
Step 11: Create a General Model Evaluation Function
To make evaluation more modular and reusable, let’s define a utility function that returns the model's name, average loss, and accuracy in a structured format. This is helpful when comparing multiple models later on.
torch.manual_seed(42)  # For reproducibility

def eval_mode(model: torch.nn.Module,
              data_loader: torch.utils.data.DataLoader,
              loss_fn: torch.nn.Module,
              accuracy_fn):
    loss, acc = 0, 0
    model.eval()  # Set the model to evaluation mode
    with torch.inference_mode():  # Disable gradient tracking
        for X, y in data_loader:
            X, y = X.to(device), y.to(device)
            y_pred = model(X)
            loss += loss_fn(y_pred, y)
            acc += accuracy_fn(y_true=y, y_pred=y_pred.argmax(dim=1))
        # Average inside inference_mode: in-place ops on the inference tensor
        # `loss` are not allowed outside it
        loss /= len(data_loader)
        acc /= len(data_loader)
    return {
        "Model_name": model.__class__.__name__,
        "Model_loss": loss.item(),
        "Model_acc": acc
    }
What This Does:
This is a general-purpose evaluation function.
It calculates average loss and accuracy across a dataset.
It returns a dictionary with:
The model’s class name
Average loss over all batches
Average accuracy (in percent)
Useful for benchmarking or reporting results.
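Here is one way to use such result dictionaries to compare models. This minimal sketch includes a compact, CPU-only restatement of eval_mode called evaluate, a hypothetical LinearBaseline class, and a small MLP, all scored on a random stand-in test set. The models are untrained, so the numbers are meaningless. Only the comparison pattern matters.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def accuracy_fn(y_true, y_pred):
    # Same logic as the tutorial's accuracy_fn
    return torch.eq(y_true, y_pred).sum().item() / len(y_pred) * 100

def evaluate(model, data_loader, loss_fn, accuracy_fn):
    # Compact, CPU-only restatement of eval_mode above
    loss, acc = 0.0, 0.0
    model.eval()
    with torch.inference_mode():
        for X, y in data_loader:
            y_pred = model(X)
            loss += loss_fn(y_pred, y).item()
            acc += accuracy_fn(y_true=y, y_pred=y_pred.argmax(dim=1))
    return {"Model_name": model.__class__.__name__,
            "Model_loss": loss / len(data_loader),
            "Model_acc": acc / len(data_loader)}

class LinearBaseline(nn.Module):
    # Hypothetical baseline: flatten the image and apply one linear layer
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

    def forward(self, x):
        return self.net(x)

torch.manual_seed(0)
# Random stand-in test set (not FashionMNIST)
test_set = TensorDataset(torch.rand(64, 1, 28, 28), torch.randint(0, 10, (64,)))
loader = DataLoader(test_set, batch_size=32)

mlp = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 32), nn.ReLU(), nn.Linear(32, 10))
results = [evaluate(m, loader, nn.CrossEntropyLoss(), accuracy_fn)
           for m in (LinearBaseline(), mlp)]
for r in results:
    print(r)
best = max(results, key=lambda r: r["Model_acc"])
print("Best:", best["Model_name"])
```

If pandas is installed, pd.DataFrame(results) turns the list of dictionaries into a comparison table.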
Conclusion
In this tutorial, you learned how to build a multi-class image classification model using PyTorch and the FashionMNIST dataset. We covered:
Understanding multi-class classification
Loading and preprocessing image data
Building a custom CNN model (TinyVGG)
Writing modular training and evaluation loops
Calculating accuracy and visualizing sample images
This end-to-end pipeline is a strong foundation for many real-world computer vision tasks.
References
Further reading: binary classification with PyTorch (https://binary-classification-using-pytorch.hashnode.dev/) and convolutional neural networks (https://convolutional-neural-networks-cnns.hashnode.dev/).
Written by Tanayendu Bari