Neural Style Transfer using PyTorch


Introduction
Have you ever wondered how apps like Prisma transform your photos into artwork that mimics famous painters like Van Gogh or Picasso? The magic behind this visual transformation is a fascinating deep learning technique called Neural Style Transfer (NST).
In this tutorial, we'll dive into implementing Neural Style Transfer from scratch using PyTorch. By the end, you'll be able to blend the content of one image with the artistic style of another, creating unique and visually stunning compositions.
Understanding Neural Style Transfer
Neural Style Transfer, first introduced by Gatys et al. in their 2015 paper "A Neural Algorithm of Artistic Style," works by extracting content features from one image and style features from another, then creating a new image that combines both.
The technique leverages the power of Convolutional Neural Networks (CNNs) to separate and recombine these features in a way that preserves the content structure while adopting the artistic style.
The Theoretical Foundation
At its core, Neural Style Transfer is an optimization problem. We start with a content image, a style image, and generate a new image that balances two competing objectives:
Content Preservation: The generated image should maintain the structural elements of the content image.
Style Transfer: The generated image should adopt the artistic style (textures, colors, brushstrokes) of the style image.
The Loss Function Concept
To achieve this balance, we define a loss function with two components:
Content Loss: Measures how different the content of the generated image is from the original content image.
Style Loss: Measures how different the style of the generated image is from the style image.
These components are weighted by parameters (alpha and beta) that control the trade-off between content preservation and style transfer.
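In code terms, this is exactly the combination you will see in the training loop later in this tutorial (where the content term is named original_loss):
total_loss = alpha * content_loss + beta * style_loss
with alpha and beta playing the roles described above.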
Content Representation
Content is represented by the activations of specific layers in a pre-trained CNN (typically VGG19). Deeper layers of a CNN capture the high-level content and semantic information of an image while being less sensitive to exact pixel values.
The content loss measures the difference between these feature representations in the content image and the generated image.
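As a minimal sketch (the function name is just for illustration; the full implementation below does this inline, summed over several layers), the content loss at one layer is a plain mean squared error between feature maps:
import torch

def content_loss_at_layer(gen_feature, content_feature):
    # Mean squared error between the generated image's activations and the
    # content image's activations at the same VGG layer
    return torch.mean((gen_feature - content_feature) ** 2)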
Style Representation: The Gram Matrix
Style is captured using a special mathematical construct called the Gram matrix. This matrix represents the correlations between different feature maps in a layer of the CNN.
What is a Gram Matrix?
The Gram matrix captures the statistical relationships between features in an image. When two features tend to activate together, their correlation in the Gram matrix will be high. This effectively captures texture information while discarding spatial arrangement - perfect for representing style!
By comparing the Gram matrices of the style image and the generated image across multiple layers, we can measure how well the generated image adopts the style characteristics.
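Concretely: if you flatten each of a layer's C feature maps into a row of length H * W, giving a matrix F of shape (C, H * W), the Gram matrix is F @ F.T, a C × C matrix whose entry (i, j) is the dot product between feature maps i and j. This is exactly the matrix multiplication you will see in the implementation below.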
Optimization Process
Once we have defined our loss function, we use gradient descent to iteratively update our generated image. Starting from either random noise or the content image, we:
Calculate the content and style losses
Compute the gradients of these losses with respect to the generated image
Update the generated image to minimize the total loss
Repeat until we achieve a satisfactory result
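Before we get to the full implementation, here is a tiny self-contained toy (not part of the NST code; the shapes and target value are arbitrary) that demonstrates the core mechanism: PyTorch can run gradient descent on the pixels of an image tensor directly, just as it normally does on network weights.
import torch

# Toy example: nudge a random "image" until its mean pixel value reaches 0.5.
# NST uses the same mechanism, only with content and style losses instead.
image = torch.randn(1, 3, 8, 8, requires_grad=True)
optimizer = torch.optim.Adam([image], lr=0.1)

for step in range(200):
    loss = (image.mean() - 0.5) ** 2
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(image.mean().item())  # should end up close to 0.5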
Prerequisites
Before we begin, make sure you have the following installed:
Python 3.6+
PyTorch
torchvision
PIL (Python Imaging Library)
You'll also need two images:
A content image (the base structure you want to preserve)
A style image (the artistic style you want to apply)
The Code Breakdown
Let's break down our implementation step by step:
1. Importing Libraries
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image
We import the necessary libraries for image processing, neural networks, and optimization.
2. Creating the VGG Model
class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
        self.chosen_features = ["0", "5", "10", "19", "28"]
        self.model = models.vgg19(pretrained=True).features[:29]

    def forward(self, x):
        features = []
        for layer_num, layer in enumerate(self.model):
            x = layer(x)
            if str(layer_num) in self.chosen_features:
                features.append(x)
        return features
Here, we define a custom VGG19 model class. VGG19 is a deep CNN pre-trained on ImageNet, which we'll use to extract features from our images.
The chosen_features list contains the specific layer indices we're interested in. These layers capture different levels of abstraction:
Lower layers (0, 5) capture basic features like edges and textures
Middle layers (10, 19) capture more complex patterns
Higher layers (28) capture high-level content
The specific layers we've chosen correspond to the first convolution of each of VGG19's five blocks (conv1_1 through conv5_1):
Layer 0: First convolutional layer (low-level features)
Layer 5: First convolution after the first max pooling (textures)
Layer 10: First convolution of the third block (patterns)
Layer 19: First convolution of the fourth block (complex features)
Layer 28: First convolution of the fifth block, the last layer we keep (high-level content)
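If you want to check these indices yourself, an optional sanity check is to print the corresponding layers of torchvision's VGG19 and confirm that each one is a convolution:
import torchvision.models as models

vgg_features = models.vgg19(pretrained=True).features
for idx in [0, 5, 10, 19, 28]:
    # Each of these indices should print a Conv2d layer
    print(idx, vgg_features[idx])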
3. Image Loading Function
def load_image(image_name):
    image = Image.open(image_name)
    image = loader(image).unsqueeze(0)
    return image.to(device)
This function loads an image, applies our transformation pipeline, and moves it to the appropriate device (CPU or GPU).
4. Setting Up the Environment
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
imsize = 356
loader = transforms.Compose(
    [
        transforms.Resize((imsize, imsize)),
        transforms.ToTensor(),
    ]
)
We determine whether to use GPU or CPU and define our image transformation pipeline, which resizes images to 356×356 pixels and converts them to PyTorch tensors.
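One simplification worth noting: VGG19 was trained on ImageNet-normalized inputs, and many NST implementations add that normalization to the pipeline. This tutorial skips it and still gets good results, but a normalized variant of the loader would look like this (loader_normalized is just an illustrative name):
loader_normalized = transforms.Compose(
    [
        transforms.Resize((imsize, imsize)),
        transforms.ToTensor(),
        # ImageNet mean and std, matching how VGG19 was originally trained
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
If you do normalize, remember to undo the normalization before saving the generated image, otherwise the saved colors will be off.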
5. Loading Images
original_img = load_image("orignal.jpeg")
style_img = load_image("style.jpeg")
generated = original_img.clone().requires_grad_(True)
We load our content and style images, then create a copy of the content image as our starting point for the generated image. Calling requires_grad_(True) enables PyTorch to compute gradients for this tensor, so the optimizer can update its pixels directly.
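An alternative starting point (listed later under extensions) is pure random noise rather than a copy of the content image; a one-line sketch of that would be:
# Not used in this tutorial: start from random noise with the same shape as the content image
generated = torch.randn(original_img.shape, device=device, requires_grad=True)
Starting from the content image tends to converge faster; starting from noise can give noticeably different results.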
6. Setting Up the Model and Optimizer
model = VGG().to(device).eval()
total_steps = 6000
learning_rate = 0.001
alpha = 1
beta = 0.01
optimizer = optim.Adam([generated], lr=learning_rate)
We initialize our VGG model, set it to evaluation mode, and configure our optimization parameters:
total_steps: Number of iterations to run
learning_rate: Step size for the optimizer
alpha: Weight for content loss
beta: Weight for style loss
We use the Adam optimizer to update our generated image.
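Note that only the pixels of generated are handed to the optimizer; the VGG weights are never updated. If you want to make that explicit, an optional safeguard (not in the original code) is to freeze the network's parameters:
for param in model.parameters():
    # The feature extractor stays fixed; only the generated image is optimized
    param.requires_grad = False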
7. The Training Loop
for step in range(total_steps):
    # Extract features
    generated_features = model(generated)
    original_img_features = model(original_img)
    style_features = model(style_img)

    style_loss = original_loss = 0

    # Calculate content and style losses
    for gen_feature, orig_feature, style_feature in zip(
        generated_features, original_img_features, style_features
    ):
        batch_size, channel, height, width = gen_feature.shape

        # Content loss - Mean Squared Error between feature maps
        original_loss += torch.mean((gen_feature - orig_feature) ** 2)

        # Style loss - Mean Squared Error between Gram matrices
        G = gen_feature.view(channel, height * width).mm(
            gen_feature.view(channel, height * width).t()
        )
        A = style_feature.view(channel, height * width).mm(
            style_feature.view(channel, height * width).t()
        )
        style_loss += torch.mean((G - A) ** 2)

    # Combined loss
    total_loss = alpha * original_loss + beta * style_loss

    # Backpropagation
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # Save progress
    if step % 10 == 0:
        print(total_loss)
        save_image(generated, "generated.png")
This is where the magic happens:
We extract features from all three images using our VGG model
For each layer, we calculate:
Content loss: Mean squared error between the generated and content image features
Style loss: Mean squared error between the Gram matrices of the generated and style images
We combine these losses with their respective weights
We perform backpropagation to update our generated image
Every 10 steps, we save the current state of the generated image
The Gram Matrix: Capturing Style
A key concept in Neural Style Transfer is the Gram matrix, which captures the correlations between different feature maps. This matrix represents the style of an image by measuring how different features tend to activate together.
Why does this work for style?
Style is largely about textures, patterns, and colors - not about where specific objects are located. The Gram matrix is perfect for this because:
It captures which features activate together, regardless of where they appear in the image
It loses spatial information (where features are located) but retains textural information (how features relate to each other)
It's invariant to translation, meaning the same texture will produce similar Gram matrices regardless of position
The code for calculating the Gram matrix is:
G = gen_feature.view(channel, height * width).mm(
    gen_feature.view(channel, height * width).t()
)
This reshapes our feature maps and performs a matrix multiplication with its transpose, resulting in a channel × channel matrix that represents feature correlations.
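A common variation (not used in the code above) wraps this in a helper and divides by the number of elements, which keeps the style loss on a similar scale regardless of image resolution. A sketch of such a helper, assuming a batch size of 1:
def gram_matrix(feature):
    # feature: tensor of shape (1, channels, height, width)
    _, channels, height, width = feature.shape
    flat = feature.view(channels, height * width)
    gram = flat.mm(flat.t())
    # Normalize so the magnitude does not grow with image size
    return gram / (channels * height * width)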
Why This Works: The Intuition
The effectiveness of Neural Style Transfer relies on how CNNs process images:
Feature Hierarchy: Lower layers capture local patterns (edges, colors), while deeper layers capture more complex structures.
Content Representation: By minimizing the difference between high-level feature activations, we preserve the content structure.
Style Representation: The Gram matrices capture texture information by measuring correlations between features, regardless of their spatial arrangement.
Gradient Descent: By iteratively updating the generated image to minimize both losses, we create an image that satisfies both content and style constraints.
Hyperparameter Tuning
The balance between content and style is controlled by the alpha and beta parameters:
Higher alpha preserves more of the original content
Higher beta incorporates more of the style
In our implementation, we use alpha = 1 and beta = 0.01, which gives a good balance, but feel free to experiment with these values to achieve different effects.
By adjusting the ratio between these parameters, you control how much the optimization process prioritizes content preservation versus style transfer.
Results and Visualization
After running the code, you'll find your stylized image saved as "generated.png". The transformation happens gradually, with early iterations showing subtle style elements that become more pronounced as the optimization progresses.
Content image | Style image | Generated image
Extending the Implementation
Here are some ways you could enhance this basic implementation:
Content-Style Tradeoff: Experiment with different values of alpha and beta
Layer Selection: Try different combinations of VGG layers
Resolution: Increase imsize for higher-quality outputs (requires more memory)
Initialization: Start from noise instead of the content image for different results
Total Variation Loss: Add a regularization term to reduce noise in the output (see the sketch after this list)
Color Preservation: Transfer only the luminance of the style image to preserve the colors of the content image
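As an example of the Total Variation Loss idea, here is a minimal sketch (the helper name and the example weight are illustrative, not part of the implementation above):
import torch

def total_variation_loss(img):
    # img: tensor of shape (1, 3, height, width)
    # Penalize differences between neighbouring pixels to smooth out noise
    diff_h = torch.mean((img[:, :, 1:, :] - img[:, :, :-1, :]) ** 2)
    diff_w = torch.mean((img[:, :, :, 1:] - img[:, :, :, :-1]) ** 2)
    return diff_h + diff_w

# This would be added to the combined loss with its own small weight, for example:
# total_loss = alpha * original_loss + beta * style_loss + 1e-4 * total_variation_loss(generated)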
Conclusion
Neural Style Transfer is a fascinating application of deep learning that bridges the gap between computer vision and art. In well under a hundred lines of PyTorch code, we've implemented a technique that can transform ordinary photos into artistic masterpieces.
The theoretical foundation of NST reveals how deep neural networks can be repurposed from classification tasks to creative applications. By understanding how different layers of a CNN represent different aspects of an image, we can manipulate these representations to achieve stunning visual effects.
The next time you see an app that turns your selfie into a Monet or Picasso, you'll understand the neural magic happening behind the scenes!
Resources for Further Learning
Original NST paper: "A Neural Algorithm of Artistic Style" by Gatys, Ecker, and Bethge (2015), arXiv:1508.06576
Happy styling! Share your creations in the comments below.
Did you find this tutorial helpful? Let me know what other deep learning techniques you'd like to see implemented from scratch!