Building Convolutional Neural Networks in JAX


Introduction
Deep learning has revolutionized the field of artificial intelligence, and at the heart of this revolution are Convolutional Neural Networks (CNNs). CNNs have become the go-to architectures for tasks involving images, such as object detection, facial recognition, medical imaging, and self-driving cars.
Traditionally, frameworks like TensorFlow and PyTorch have dominated the deep learning landscape. However, JAX has emerged as a powerful alternative, especially for research and high-performance computing. Developed by Google, JAX provides automatic differentiation and Just-In-Time (JIT) compilation, making it highly efficient for numerical computing and deep learning applications.
Why Use JAX for CNNs?
JAX stands out because of its ability to seamlessly run code on CPUs, GPUs, and TPUs while maintaining a NumPy-like API. This means you can develop models using familiar syntax while benefiting from:
Automatic Vectorization – With functions like vmap, JAX makes it easy to apply operations over large batches of data without writing explicit loops.
Efficient Autograd – JAX provides automatic differentiation using grad, which simplifies training deep learning models.
XLA Compilation – Just-In-Time (JIT) compilation speeds up execution by compiling computation graphs for efficient hardware utilization.
Functional Programming Paradigm – Unlike traditional deep learning frameworks, JAX encourages pure functions, which improves reproducibility and debugging.
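To make these ideas concrete, here is a minimal, self-contained sketch that shows grad, vmap, and jit working together on a toy quadratic function (the function and names are illustrative, not part of the tutorial's model):
import jax
import jax.numpy as jnp

def toy_loss(w, x):
    # A toy scalar function: mean squared value of a linear map
    return jnp.mean((x * w) ** 2)

grad_fn = jax.grad(toy_loss)                          # d(loss)/dw for a single example
batched_grad = jax.vmap(grad_fn, in_axes=(None, 0))   # same gradient applied over a batch of x
fast_batched_grad = jax.jit(batched_grad)             # compile the whole pipeline with XLA

x_batch = jnp.linspace(-1.0, 1.0, 8)
print(fast_batched_grad(2.0, x_batch))                # 8 gradients, one per batch element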
Prerequisites
Before proceeding, ensure you are familiar with:
JAX fundamentals (the official JAX documentation is a good refresher).
Building CNNs in TensorFlow or PyTorch
JAX optimizers and loss functions
Install Dependencies
Install the required libraries:
!pip install jax jaxlib flax optax tensorflow tensorflow_datasets dm-pix tqdm matplotlib
jax and jaxlib – The core JAX library and its hardware acceleration backend.
flax – A neural network library for JAX, similar to PyTorch's torch.nn.
optax – A library for optimization algorithms in JAX.
dm_pix – A lightweight image processing library for JAX.
tensorflow and tensorflow_datasets – Used here for loading and preprocessing the image data.
tqdm – Progress bars for the training loop.
matplotlib – For visualizing images.
Import Packages
Load necessary libraries:
import os
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import optax
from tqdm.auto import tqdm
from flax import linen as nn
from flax.training import train_state
import dm_pix as pix # Image processing in JAX
Verify GPU Access
JAX runs computations on CPUs, GPUs, and TPUs seamlessly. To check if your machine has a GPU or TPU available:
# Get available devices
print("Available Devices:", jax.devices())
# Check if GPU is available
if jax.default_backend() == "gpu":
    print("Using GPU:", jax.devices("gpu"))
elif jax.default_backend() == "tpu":
    print("Using TPU:", jax.devices("tpu"))
else:
    print("Using CPU")
Data Preprocessing
Raw image data comes in various sizes, orientations, and quality levels. Preprocessing is crucial for:
Ensuring uniform input dimensions.
Normalizing pixel values for stable training.
Augmenting data to improve model generalization.
Converting images into JAX-compatible tensors.
Loading the Dataset
We will use the Cats vs. Dogs dataset available on Kaggle. Download and unzip the dataset using:
!kaggle datasets download -d chetankv/dogs-cats-images
!unzip dogs-cats-images.zip
This dataset gives us:
A training dataset (train_data)
A test dataset (test_data)
Define the path to the images and the batch size:
base_dir = "/content/dog vs cat/dataset/training_set"
batch_size = 64
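The later snippets reference training_set, validation_set, and eval_set, but the article does not show how they are created. Below is a minimal sketch under the assumption that tf.keras.utils.image_dataset_from_directory is used, with an illustrative validation split and an assumed test_set folder next to the training one (both are assumptions, not from the original):
# Assumed loading step: one subfolder per class under each directory.
training_set, validation_set = tf.keras.utils.image_dataset_from_directory(
    base_dir,
    validation_split=0.2,   # illustrative split
    subset="both",
    seed=42,
    batch_size=batch_size,
)
eval_set = tf.keras.utils.image_dataset_from_directory(
    "/content/dog vs cat/dataset/test_set",  # assumed path mirroring base_dir
    batch_size=batch_size,
)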
Resizing and Normalizing Images
Since images in the dataset have varying sizes, we must resize them to a fixed size (e.g., 128×128 pixels). Additionally, we normalize pixel values from [0, 255] → [0, 1] for stable training.
IMG_SIZE = 128
resize_and_rescale = tf.keras.Sequential(
resize_and_rescale = tf.keras.Sequential(
    [
        tf.keras.layers.Resizing(IMG_SIZE, IMG_SIZE),
        tf.keras.layers.Rescaling(1.0 / 255),
    ]
)
Data Augmentation for Better Generalization
Data augmentation helps improve model generalization by applying transformations like flipping, rotation, brightness adjustments, and cropping.
rng = jax.random.PRNGKey(0)
rng, inp_rng, init_rng = jax.random.split(rng, 3)
delta = 0.42
factor = 0.42
@jax.jit
def data_augmentation(image):
    new_image = pix.adjust_brightness(image=image, delta=delta)
    new_image = pix.random_brightness(image=new_image, max_delta=delta, key=inp_rng)
    new_image = pix.flip_up_down(image=new_image)  # chain from new_image so earlier steps are kept
    new_image = pix.flip_left_right(image=new_image)
    new_image = pix.rot90(k=1, image=new_image)  # k = number of times the rotation is applied
    return new_image
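Because data_augmentation is a pure, jitted function that operates on a single image of shape (height, width, channels), it can also be applied to a whole batch at once with jax.vmap. A small usage sketch (the batch here is random dummy data, not the real dataset):
# Apply the augmentation to every image in a batch at once.
batched_augmentation = jax.vmap(data_augmentation)

dummy_batch = jax.random.uniform(rng, shape=(8, 128, 128, 3))  # 8 fake RGB images
augmented_batch = batched_augmentation(dummy_batch)
print(augmented_batch.shape)  # (8, 128, 128, 3)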
Converting Data to JAX-Compatible Tensors
JAX primarily operates on NumPy-like arrays (jnp.array). TensorFlow uses tf.Tensor, so we must convert our dataset into a JAX-friendly format.
AUTOTUNE = tf.data.AUTOTUNE
def prepare(ds, shuffle=False):
    # Rescale and resize all datasets.
    ds = ds.map(lambda x, y: (resize_and_rescale(x), y), num_parallel_calls=AUTOTUNE)
    if shuffle:
        ds = ds.shuffle(1000)
    # Use buffered prefetching on all datasets.
    return ds.prefetch(buffer_size=AUTOTUNE)

train_ds = prepare(training_set, shuffle=True)
val_ds = prepare(validation_set)
evaluation_set = prepare(eval_set)
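The prepare function above still yields tf.Tensor batches. One way (a sketch, not the only option) to turn them into the dictionary-of-jnp.array batches expected by the training and evaluation code later in this article is to iterate with tfds.as_numpy; the train_loader and test_loader names below are the ones used by those loops:
def to_jax_batches(ds):
    # Convert each (images, labels) pair from tf.Tensor to jnp.array
    # and package it the way the train/eval steps below expect.
    # Note: this materializes every batch in memory, which keeps the sketch simple.
    batches = []
    for images, labels in tfds.as_numpy(ds):
        batches.append({
            "images": jnp.asarray(images),
            "labels": jnp.asarray(labels),
        })
    return batches

train_loader = to_jax_batches(train_ds)
test_loader = to_jax_batches(evaluation_set)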
Visualizing Preprocessed Images
Let’s check if preprocessing is working as expected.
plt.figure(figsize=(10, 10))
augmented_images = []

for images, _ in training_set.take(1):
    for i in range(9):
        augmented_image = data_augmentation(np.array(images[i], dtype=jnp.float32))
        augmented_images.append(augmented_image)
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(augmented_images[i].astype("uint8"))
        plt.axis("off")
Defining a CNN in JAX
In JAX, neural networks are often implemented using Flax, a high-level neural network library designed to work seamlessly with JAX’s functional paradigm. Flax provides an intuitive way to define models using Module classes.
Below is a simple implementation of a Convolutional Neural Network (CNN) in JAX using Flax:
import jax.numpy as jnp
import flax.linen as nn
class CNN(nn.Module):
    num_classes: int  # Number of output classes

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3), strides=(1, 1), padding="SAME")(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1), padding="SAME")(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=128, kernel_size=(3, 3), strides=(1, 1), padding="SAME")(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # Flatten feature maps
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.num_classes)(x)  # Output layer
        return x
First Convolutional Layer
Applies a 32-channel convolution with a 3×3 kernel.
Uses ReLU activation to introduce non-linearity.
Applies max pooling with a 2×2 window and a stride of 2, reducing the spatial dimensions by half.
Second Convolutional Layer
Uses a 64-channel convolution with a 3×3 kernel.
Again applies ReLU activation.
Another max pooling operation further reduces spatial dimensions.
Third Convolutional Layer
Increases the number of channels to 128 while keeping the 3×3 kernel size.
Applies ReLU activation and another max pooling step.
Flattening and Fully Connected Layers
The feature maps from the final convolutional layer are flattened into a 1D vector.
A dense layer with 128 neurons applies a ReLU activation.
The final output layer produces logits corresponding to the number of classes.
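To see how the spatial dimensions shrink layer by layer, Flax's Module.tabulate can print a per-layer summary for a given input shape. A quick sketch (the 128×128 input and 2 classes match the cats-vs-dogs setup, but any shape works):
# Print a per-layer summary, including output shapes and parameter counts.
print(
    CNN(num_classes=2).tabulate(
        jax.random.PRNGKey(0),
        jnp.ones((1, 128, 128, 3)),
    )
)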
Why Use @nn.compact?
Flax provides two ways to define models:
Using @nn.compact, which allows direct instantiation of layers within the __call__ method.
Using setup(), where layers are explicitly defined as attributes.
The compact approach is cleaner and more intuitive for simple models, avoiding the need to define layer attributes separately.
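For comparison, here is what part of the same network could look like with setup() (a sketch; a full model would mirror all the layers shown above):
class CNNWithSetup(nn.Module):
    num_classes: int

    def setup(self):
        # Layers are declared once as attributes...
        self.conv1 = nn.Conv(features=32, kernel_size=(3, 3), padding="SAME")
        self.dense_out = nn.Dense(features=self.num_classes)

    def __call__(self, x):
        # ...and reused in the forward pass.
        x = nn.relu(self.conv1(x))
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))
        return self.dense_out(x)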
Initializing the Model
JAX does not use implicit state, so model parameters must be explicitly initialized. The init function from Flax generates the model's parameters from a random key and a sample input shape.
from flax.core import freeze, unfreeze
# Set up PRNG key
key = jax.random.PRNGKey(0)
# Define input shape (batch_size, height, width, channels)
input_shape = (1, 32, 32, 3) # Example for a 32x32 RGB image
# Initialize model
model = CNN(num_classes=10) # Assuming 10 output classes
params = model.init(key, jnp.ones(input_shape))["params"]
A random key is generated using jax.random.PRNGKey(0). JAX requires explicit control over random number generation for reproducibility.
A dummy input tensor of shape (1, 32, 32, 3) is created to initialize the network. (For the cats-vs-dogs setup in this tutorial, you would instead use num_classes=2 and an input shape of (1, 128, 128, 3).)
The model is instantiated and the init function generates model parameters using the random key.
The "params" field is extracted from the initialization output, as Flax's init method returns a dictionary containing additional information (e.g., batch statistics if using BatchNorm).
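As an optional sanity check, the parameter PyTree can be inspected with jax.tree_util to confirm that the layer shapes match expectations:
# Print the shape of every parameter array in the PyTree.
shapes = jax.tree_util.tree_map(lambda p: p.shape, params)
print(shapes)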
Defining the Training State
Flax provides a train_state abstraction to manage model parameters, optimizer state, and other training-related information. The optax library is used to define the optimizer.
import optax
from flax.training import train_state
class TrainState(train_state.TrainState):
    pass  # No additional attributes needed for now
# Define the optimizer
learning_rate = 0.001
optimizer = optax.adam(learning_rate)
# Initialize the training state
state = TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)
TrainState is a dataclass that stores the model's parameters, optimizer state, and apply_fn (the function used for forward passes).
Optax's Adam optimizer is set up with a learning rate of 0.001.
The TrainState.create() method initializes the model's training state with:
apply_fn: The forward pass function from the model.
params: The initialized parameters from the previous step.
tx: The optimizer (Adam in this case).
Defining Loss and Accuracy Metrics
A loss function is required to guide training, while an accuracy function evaluates model performance.
Loss Function
import jax.nn as jnn
def cross_entropy_loss(params, state, batch):
    logits = state.apply_fn({'params': params}, batch['images'])
    labels = jnn.one_hot(batch['labels'], num_classes=10)
    return -jnp.sum(labels * jnn.log_softmax(logits)) / batch['labels'].shape[0]
The function takes model parameters, the current training state, and a batch of input data.
The forward pass is performed using apply_fn, producing logits (raw model predictions).
The labels are one-hot encoded to match the logits' shape.
The cross-entropy loss is computed using log_softmax(logits), ensuring numerical stability.
The loss is averaged over the batch size for proper optimization.
Accuracy Function
def compute_accuracy(params, state, batch):
    logits = state.apply_fn({'params': params}, batch['images'])
    predicted_labels = jnp.argmax(logits, axis=1)
    return jnp.mean(predicted_labels == batch['labels'])
The function takes model parameters, the training state, and a batch of data.
The forward pass is executed to obtain logits.
The highest-scoring class is selected using argmax(), determining the model's predicted label.
Accuracy is computed by comparing predictions with actual labels and averaging the correct classifications.
Training and Evaluating a CNN in JAX
Training in JAX is based on functional transformations, meaning explicit gradient computation and parameter updates are required. The jax.grad function is used to compute gradients efficiently.
Training Step Function
@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['images'])
        labels = jax.nn.one_hot(batch['labels'], num_classes=10)
        loss = -jnp.sum(labels * jax.nn.log_softmax(logits)) / batch['labels'].shape[0]
        return loss

    # Compute gradients
    grads = jax.grad(loss_fn)(state.params)
    # Update model state
    state = state.apply_gradients(grads=grads)
    return state
JIT Compilation (@jax.jit): JAX's just-in-time compilation speeds up training by optimizing computation.
loss_fn Function: Defines the cross-entropy loss to be minimized.
jax.grad(loss_fn): Computes gradients with respect to model parameters.
state.apply_gradients(grads=grads): Updates the training state using computed gradients.
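If the loss value itself is needed (for example, for logging), jax.value_and_grad returns both the loss and the gradients in one pass. A sketch of that variant:
@jax.jit
def train_step_with_loss(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['images'])
        labels = jax.nn.one_hot(batch['labels'], num_classes=10)
        return -jnp.sum(labels * jax.nn.log_softmax(logits)) / batch['labels'].shape[0]

    # Loss and gradients in a single backward pass.
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss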
Training Loop
def train_model(state, train_loader, num_epochs=10):
    for epoch in range(num_epochs):
        for batch in train_loader:
            state = train_step(state, batch)
        print(f"Epoch {epoch + 1} completed")
    return state
Iterates through multiple epochs, training the model for num_epochs.
Processes each batch, updating the model parameters.
Logs progress at the end of each epoch.
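The tqdm import from earlier can wrap the batch loop to show progress. Combined with the train_step_with_loss sketch above, and assuming train_loader is a list of batches (as in the to_jax_batches sketch), a logging variant might look like this:
def train_model_verbose(state, train_loader, num_epochs=10):
    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}"):
            state, loss = train_step_with_loss(state, batch)
            running_loss += float(loss)
        print(f"Epoch {epoch + 1} - mean loss: {running_loss / len(train_loader):.4f}")
    return state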
Evaluation Step Function
def evaluate_model(state, test_loader):
    accuracies = []
    for batch in test_loader:
        acc = compute_accuracy(state.params, state, batch)
        accuracies.append(acc)
    final_accuracy = jnp.mean(jnp.array(accuracies))
    print(f"Test Accuracy: {final_accuracy * 100:.2f}%")
    return final_accuracy
Iterates through the test dataset, computing accuracy for each batch.
Aggregates accuracy scores across batches to compute the final accuracy.
Prints the test accuracy, indicating how well the model generalizes to unseen data.
Conclusion
JAX’s functional and hardware-accelerated approach allows for efficient model training, particularly on GPUs and TPUs. The explicit handling of gradients and optimizers ensures flexibility while maintaining high performance.
Future work could explore advanced techniques such as richer data augmentation pipelines, model regularization, and hyperparameter tuning to improve performance. Additionally, integrating JAX with frameworks like TensorFlow or PyTorch could provide hybrid workflows for deep learning research and production applications.