Make your own Generative Model
What are GANs?
Deep neural networks are used mainly for supervised learning: classification or regression. Generative Adversarial Networks, or GANs, use neural networks for a very different purpose: generative modeling.
Generative modeling means learning a dataset well enough to generate new examples that could plausibly have been drawn from it.
A GAN trains two networks against each other: a generator that produces images from random noise, and a discriminator that tries to tell generated images from real ones. The working of GANs can be explained by the following flowchart:
Dataset:
We will be using a dump of all images from the famous Bored Ape Yacht Club (BAYC) NFT collection.
Making the dataset ready for training
Let’s load this dataset using the ImageFolder class from torchvision. We will also resize and crop the images to 64x64 px, and normalize the pixel values with a mean and standard deviation of 0.5 for each channel. Since ToTensor scales pixel values to [0, 1], this normalization maps them to the range [-1, 1], which is more convenient for training the discriminator.
import os
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
# Jupyter magic: remove this line when running as a plain script
%matplotlib inline

DATA_DIR = '../input/bored-apes-yacht-club/'
print(os.listdir(DATA_DIR))

image_size = 64
batch_size = 128
# Per-channel (mean, std) used for normalization
stats = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)

# ImageFolder expects the images to sit inside at least one subfolder of DATA_DIR
train_ds = ImageFolder(DATA_DIR, transform=T.Compose([
    T.Resize(image_size),
    T.CenterCrop(image_size),
    T.ToTensor(),
    T.Normalize(*stats)]))

train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=2, pin_memory=True)

def denorm(img_tensors):
    # Undo the normalization: map [-1, 1] back to [0, 1] for display
    return img_tensors * stats[1][0] + stats[0][0]

def show_images(images, nmax=64):
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xticks([]); ax.set_yticks([])
    # Move to the CPU first so GPU tensors can be plotted as well
    ax.imshow(make_grid(denorm(images.cpu().detach()[:nmax]), nrow=8).permute(1, 2, 0))

def show_batch(dl, nmax=64):
    # Display only the first batch from the data loader
    for images, _ in dl:
        show_images(images, nmax)
        break
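As a quick check on the normalization described above, here is a minimal sketch of the arithmetic in plain Python. The helper names `normalize_pixel` and `denorm_pixel` are made up for this illustration; they mirror what `T.Normalize` and the `denorm` helper do to a single channel value, assuming the mean and standard deviation of 0.5 used in this article.

```python
def normalize_pixel(x, mean=0.5, std=0.5):
    """Per-channel arithmetic of T.Normalize: (x - mean) / std."""
    return (x - mean) / std


def denorm_pixel(z, mean=0.5, std=0.5):
    """Inverse mapping, the same arithmetic as the denorm helper."""
    return z * std + mean


# ToTensor produces values in [0, 1]; normalization stretches them to [-1, 1].
for pixel in (0.0, 0.25, 0.5, 1.0):
    z = normalize_pixel(pixel)
    print(pixel, '->', z, '->', denorm_pixel(z))
```

The darkest pixel (0.0) lands on -1 and the brightest (1.0) on 1, and `denorm_pixel` recovers the original value, which is why `denorm` is applied before displaying images.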
Let’s look at a sample from the dataset:
Discriminator Network
The discriminator takes an image as input and tries to classify it as “real” or “generated”. In this sense, it’s like any other classification network. We’ll use a convolutional neural network (CNN) that outputs a single number for every image. We’ll use a stride of 2 in each convolution to progressively halve the size of the output feature map.
import torch.nn as nn

discriminator = nn.Sequential(
    # in: 3 x 64 x 64
    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.LeakyReLU(0.2, inplace=True),
    # out: 64 x 32 x 32

    nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(0.2, inplace=True),
    # out: 128 x 16 x 16

    nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(256),
    nn.LeakyReLU(0.2, inplace=True),
    # out: 256 x 8 x 8

    nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(512),
    nn.LeakyReLU(0.2, inplace=True),
    # out: 512 x 4 x 4

    nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
    # out: 1 x 1 x 1

    nn.Flatten(),
    nn.Sigmoid())
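The shape comments in the discriminator can be verified with the standard convolution output-size formula, out = floor((in + 2 * padding - kernel) / stride) + 1. The sketch below is plain Python and does not touch the model itself; `conv_out` and `conv_layers` are names introduced only for this illustration.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a Conv2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) of the five Conv2d layers in the discriminator
conv_layers = [(4, 2, 1)] * 4 + [(4, 1, 0)]

size = 64
disc_sizes = [size]
for kernel, stride, padding in conv_layers:
    size = conv_out(size, kernel, stride, padding)
    disc_sizes.append(size)

print(disc_sizes)  # [64, 32, 16, 8, 4, 1]
```

Each stride-2 layer halves the feature map, and the final 4x4 convolution with no padding collapses the 4x4 map into a single value per image.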
Generator Network
The input to the generator is typically a vector or a matrix of random numbers (referred to as a latent tensor), which is used as a seed for generating an image. The generator will convert a latent tensor of shape (128, 1, 1) into an image tensor of shape 3 x 64 x 64. The final Tanh keeps the output values in [-1, 1], the same range as the normalized training images.
import torch.nn as nn

# Length of the latent vector; matches the (128, 1, 1) latent tensor described above
latent_size = 128

generator = nn.Sequential(
    # in: latent_size x 1 x 1
    nn.ConvTranspose2d(latent_size, 512, kernel_size=4, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(512),
    nn.ReLU(True),
    # out: 512 x 4 x 4

    nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(256),
    nn.ReLU(True),
    # out: 256 x 8 x 8

    nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(True),
    # out: 128 x 16 x 16

    nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(True),
    # out: 64 x 32 x 32

    nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
    nn.Tanh()
    # out: 3 x 64 x 64
)
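The generator runs the same arithmetic in reverse. For a transposed convolution with no dilation or output padding, the output size is out = (in - 1) * stride - 2 * padding + kernel. A minimal sketch in plain Python, with `deconv_out` and `deconv_layers` as illustrative names that are not part of the model code:

```python
def deconv_out(size, kernel, stride, padding):
    """Spatial output size of a ConvTranspose2d layer
    (dilation 1, output_padding 0)."""
    return (size - 1) * stride - 2 * padding + kernel


# (kernel, stride, padding) of the five ConvTranspose2d layers in the generator
deconv_layers = [(4, 1, 0)] + [(4, 2, 1)] * 4

size = 1
gen_sizes = [size]
for kernel, stride, padding in deconv_layers:
    size = deconv_out(size, kernel, stride, padding)
    gen_sizes.append(size)

print(gen_sizes)  # [1, 4, 8, 16, 32, 64]
```

The first layer expands the 1x1 latent tensor to 4x4, and each stride-2 layer then doubles the resolution until it reaches the 64x64 size of the training images.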
Network training
Discriminator training:
The steps involved in training the discriminator are:
We expect the discriminator to output 1 if the image was picked from the real dataset, and 0 if it was generated using the generator network.
We first pass a batch of real images, and compute the loss, setting the target labels to 1.
Then we generate a batch of fake images using the generator, pass them into the discriminator, and compute the loss, setting the target labels to 0.
Finally, we add the two losses and use the overall loss to perform gradient descent to adjust the weights of the discriminator.
import torch.nn.functional as F

# Assumption: `device` is not defined anywhere in this article, so we pick the
# GPU when one is available and move both networks onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
discriminator.to(device)
generator.to(device)

def train_discriminator(real_images, opt_d):
    # Clear discriminator gradients
    opt_d.zero_grad()

    # Pass real images through discriminator
    real_images = real_images.to(device)
    real_preds = discriminator(real_images)
    real_targets = torch.ones(real_images.size(0), 1, device=device)
    real_loss = F.binary_cross_entropy(real_preds, real_targets)
    real_score = torch.mean(real_preds).item()

    # Generate fake images
    latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
    fake_images = generator(latent)

    # Pass fake images through discriminator
    fake_targets = torch.zeros(fake_images.size(0), 1, device=device)
    fake_preds = discriminator(fake_images)
    fake_loss = F.binary_cross_entropy(fake_preds, fake_targets)
    fake_score = torch.mean(fake_preds).item()

    # Update discriminator weights (opt_d only holds the discriminator's parameters)
    loss = real_loss + fake_loss
    loss.backward()
    opt_d.step()
    return loss.item(), real_score, fake_score
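To make the two loss terms concrete, here is a small sketch of binary cross-entropy for a single prediction, written in plain Python. The discriminator outputs 0.9 and 0.2 are hypothetical numbers chosen for illustration, not results from this training run; `F.binary_cross_entropy` computes the same quantity and averages it over the batch.

```python
import math


def bce(pred, target):
    """Binary cross-entropy for one prediction in (0, 1)."""
    return -(target * math.log(pred) + (1 - target) * math.log(1 - pred))


real_pred = 0.9  # hypothetical discriminator output for a real image
fake_pred = 0.2  # hypothetical discriminator output for a generated image

real_loss = bce(real_pred, 1.0)  # -log(0.9), about 0.105
fake_loss = bce(fake_pred, 0.0)  # -log(1 - 0.2), about 0.223
loss_d = real_loss + fake_loss   # about 0.329

print(round(real_loss, 3), round(fake_loss, 3), round(loss_d, 3))
```

The discriminator's loss shrinks as its output for real images approaches 1 and its output for generated images approaches 0, which is exactly the behavior the target labels ask for.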
Generator training:
Since the outputs of the generator are images, it’s not obvious how we can train the generator. This is where we employ a rather elegant trick, which is to use the discriminator as a part of the loss function. Here’s how it works:
We generate a batch of images using the generator and pass them into the discriminator.
We calculate the loss by setting the target labels to 1 i.e. real. We do this because the generator’s objective is to “fool” the discriminator.
We use the loss to perform gradient descent, i.e. change the weights of the generator, so that it gets better at generating realistic images that “fool” the discriminator.
def train_generator(opt_g):
    # Clear generator gradients
    opt_g.zero_grad()

    # Generate fake images
    latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
    fake_images = generator(latent)

    # Try to fool the discriminator
    preds = discriminator(fake_images)
    targets = torch.ones(batch_size, 1, device=device)
    loss = F.binary_cross_entropy(preds, targets)

    # Update generator weights
    loss.backward()
    opt_g.step()
    return loss.item()
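With the target fixed at 1, binary cross-entropy for a single prediction reduces to -log(D(G(z))), where D(G(z)) is the discriminator's output for a generated image. The sketch below, in plain Python with hypothetical discriminator outputs, shows how this loss falls as the generator gets better at fooling the discriminator. The function name `generator_loss` is introduced only for this illustration.

```python
import math


def generator_loss(d_out):
    """BCE with target 1 for one generated image: -log(D(G(z)))."""
    return -math.log(d_out)


# Hypothetical discriminator outputs for generated images, from
# "clearly fake" (0.1) to "looks real" (0.9).
d_outputs = [0.1, 0.5, 0.9]
g_losses = [generator_loss(p) for p in d_outputs]

for p, loss in zip(d_outputs, g_losses):
    print(p, round(loss, 3))
```

A generated image that the discriminator scores at 0.1 costs the generator about 2.303, while one scored at 0.9 costs only about 0.105, so minimizing this loss pushes the generator toward images the discriminator accepts as real.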
Full Training
Let’s define a fit function to train the discriminator and generator for each batch of training data. We’ll use the Adam optimizer with some custom parameters (betas) that are known to work well for GANs. We will also save a sample of generated images after every epoch for inspection.
from tqdm import tqdm

def fit(epochs, lr, start_idx=1):
    torch.cuda.empty_cache()

    # Losses & scores
    losses_g = []
    losses_d = []
    real_scores = []
    fake_scores = []

    # Create optimizers
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))

    for epoch in range(epochs):
        for real_images, _ in tqdm(train_dl):
            # Train discriminator
            loss_d, real_score, fake_score = train_discriminator(real_images, opt_d)
            # Train generator
            loss_g = train_generator(opt_g)

        # Record losses & scores (values from the last batch of the epoch)
        losses_g.append(loss_g)
        losses_d.append(loss_d)
        real_scores.append(real_score)
        fake_scores.append(fake_score)

        # Log losses & scores (last batch)
        print("Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}".format(
            epoch+1, epochs, loss_g, loss_d, real_score, fake_score))

        # Save generated images (save_samples and fixed_latent must be defined before calling fit)
        save_samples(epoch+start_idx, fixed_latent, show=False)

    return losses_g, losses_d, real_scores, fake_scores
Results
We can see the network progress, and a few images are convincing, but we are still far from our initial goal.
We struggle to avoid mode collapse, a state in which the generator produces the same, or nearly the same, result regardless of its input.
Some of our limitations are due to our network, which is quite simple and small by current standards.
Link to Code: https://bit.ly/3T3EHgz
More generated examples: https://bit.ly/3dIZuGg
If you liked this article, I would be super excited if you hit the like button or shared it with your curious friends. Anyway, thanks for reading, and have a great day.
Read articles from Harshal Rudra directly inside your inbox. Subscribe to the newsletter, and don't miss out.