Getting Started with JAX: Modern Deep Learning Framework


Machine learning frameworks make our lives easier by offering built-in methods for faster prototyping and optimized GPU training. They also provide robust, scalable building blocks that help engineers implement state-of-the-art models efficiently, using a number of programming tricks under the hood.
TensorFlow and PyTorch have been the standard for programmers and researchers alike for the past few years (with researchers leaning especially towards the latter). But a new framework is slowly becoming the talk of the town.
JAX, a modern deep learning framework developed by Google (which also developed TensorFlow), is optimized for handling tensor operations. This makes it a strong choice for training on TPUs (Tensor Processing Units), which offer great computation speed and resource efficiency. This advantage has given it an edge over PyTorch, which had long enjoyed its reign.
Though PyTorch is still highly optimized for multi-GPU training, JAX is being rapidly adopted by researchers for implementing cutting-edge algorithms on TPUs, especially for 3D computer vision tasks like differentiable rendering and 3D Gaussian Splatting.
So today, we will take a look at JAX and do what one usually does when encountering a new ML framework: implement a basic neural network.
(Note: You can use Google Colab to follow this tutorial, since JAX comes preinstalled in Colab.)
Features of JAX
Before diving into a full code implementation, it is important to first understand some fundamental differences in the framework. JAX has three parts:
- J stands for JIT, or Just-In-Time compilation. If you are new to this term, it broadly means that a function is compiled at runtime, when it is first called, instead of being re-executed operation by operation every time you run it. JAX traces the function and compiles it with XLA on the first call; subsequent calls with inputs of the same shape and dtype reuse the cached compiled version, which results in increased speed. The JAX documentation provides a standard way of testing this:
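```python
import jax.numpy as jnp
from jax import random
from jax import jit

# SELU activation, defined here because the snippet uses it
def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# In first cell
key = random.key(1712)
x = random.normal(key, (1_000_000,))
%timeit selu(x).block_until_ready()

# In second cell
selu_jit = jit(selu)
_ = selu_jit(x)  # warm-up call: traces and compiles the function
%timeit selu_jit(x).block_until_ready()
```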
You will notice that the second cell reports a much faster time. If you run it again, the first few trials may show even lower times before the timing plateaus.
(Note: JAX uses asynchronous dispatch: instead of waiting for a computation to finish, it schedules the work and immediately returns an array whose value may still be computing, much like a promise. To time the computation itself rather than just its scheduling, we call `block_until_ready()`, which blocks until the execution is complete.)
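To make the "trace on the first call, then reuse the cached version" behaviour visible, here is a small sketch (the function `scaled_sum` and the `trace_log` list are made up for illustration). A plain Python side effect inside a jitted function runs only while JAX traces it, so it fires once per new input shape/dtype rather than on every call:

```python
import jax
import jax.numpy as jnp

trace_log = []  # filled by Python only while JAX is tracing the function

@jax.jit
def scaled_sum(x):
    # Ordinary Python side effect: runs during tracing, not on cached calls
    trace_log.append(x.shape)
    return jnp.sum(2.0 * x)

a = scaled_sum(jnp.ones(3))   # first float32 (3,) input: trace + compile
b = scaled_sum(jnp.zeros(3))  # same shape and dtype: reuses the cached compiled version
c = scaled_sum(jnp.ones(5))   # new shape (5,): traced and compiled again

print(trace_log)  # [(3,), (5,)]
print(a, b, c)    # 6.0 0.0 10.0
```

This is also why the timing snippet above makes a warm-up call `_ = selu_jit(x)` before `%timeit`: that call pays the tracing and compilation cost, so the timed runs measure only the cached version.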
PyTorch is not devoid of JIT compilation either, but it handles it very differently from JAX.
When you decorate a function with `@jit`, JAX traces the function, builds an XLA graph, and compiles it. This is less flexible (for example, a Python `if` cannot branch on the values of traced arrays), but it results in very fast computations.
PyTorch is eager and imperative by default, which is flexible but slower to execute.
To make models faster or exportable, PyTorch can convert them into static graphs (TorchScript). But this conversion is not as deep or aggressive as JAX's whole-function XLA compilation, so JAX can be faster in many purely numerical tasks.
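The "less flexible" part is easiest to see with Python control flow. A minimal sketch (all function names are hypothetical): a Python `if` on a traced value fails under `jit`, `jnp.where` keeps the branch inside the compiled graph, and `static_argnames` marks an argument as a plain Python value, at the cost of recompiling for each distinct value:

```python
from functools import partial

import jax
import jax.numpy as jnp

def relu_python_if(x):
    # Python `if` needs a concrete bool, but under jit `x` is an abstract tracer
    if x > 0:
        return x
    return 0.0

def relu_where(x):
    # jnp.where expresses the branch as an array operation XLA can compile
    return jnp.where(x > 0, x, 0.0)

@partial(jax.jit, static_argnames="use_relu")
def activation(x, use_relu):
    # use_relu is static: an ordinary Python bool at trace time, so `if` is fine
    if use_relu:
        return jnp.maximum(x, 0.0)
    return x

try:
    jax.jit(relu_python_if)(jnp.float32(3.0))
    python_if_failed = False
except jax.errors.ConcretizationTypeError:
    python_if_failed = True

print(python_if_failed)                                  # True
print(jax.jit(relu_where)(jnp.float32(3.0)))             # 3.0
print(activation(jnp.array([-1.0, 2.0]), use_relu=True))  # [0. 2.]
```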
- A stands for Autodiff. Automatic differentiation is a technique to compute exact derivatives of functions expressed as code. Autodiff works by tracing operations and applying the chain rule to compute derivatives automatically. We won’t go into details since it is a more fundamental concept, but you can check the referenced video for a deeper understanding.
In JAX, autodiff is built in and follows functional programming principles. The core idea is: if you can write it as a function (using JAX's own operations), JAX can differentiate it. JAX uses an arsenal of tools to make this work. The first of them, which you may have noticed above as well, is `jax.numpy`.
So JAX has its own implementation of NumPy, which differs from the regular one in important ways. One of the main differences is that `jnp` (JAX NumPy) arrays are immutable, so you must use a provided method to make a change: instead of `x[idx] = y`, use `x = x.at[idx].set(y)`. Another is that they are designed to be compatible with transformations like `grad` and `jit`, since their operations are traceable.
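A short sketch of the immutability rule and the `.at[...]` update syntax described above:

```python
import jax.numpy as jnp

x = jnp.zeros(5)

try:
    x[1] = 10.0  # NumPy-style in-place assignment
    in_place_worked = True
except TypeError:
    # JAX arrays are immutable, so item assignment raises a TypeError
    in_place_worked = False

y = x.at[1].set(10.0)  # returns a NEW array with index 1 replaced; x is unchanged
z = y.at[1].add(5.0)   # other functional updates include .mul, .min and .max

print(in_place_worked)                        # False
print(float(x[1]), float(y[1]), float(z[1]))  # 0.0 10.0 15.0
```

The copy is usually not a performance worry: according to the JAX documentation, inside `jit`-compiled code the compiler can perform such an update in place when the original array is not reused.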
Moreover, traced JAX code does not play well with Python iterators:
```python
import jax.numpy as jnp
from jax import lax

array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i, x: x + array[i], 0))  # expected result: 45

iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i, x: x + next(iterator), 0))  # unexpected result: 0
```
This example comes from the official documentation. The first loop (compatible) returns the expected result, 45, while the second one (which uses an iterator) returns 0.
This is because the loop body is traced only once: `next(iterator)` is called a single time during tracing, its first value (0) is frozen into the static graph, and every iteration then adds that same 0. So any external state included in the function is only accounted for once, during tracing. You can further verify this by running this deliberately incorrect code:
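```python
from jax import lax

counter = {"count": 0}  # external state: an impure pattern, on purpose

def body(i, x):
    counter["count"] += 1  # Python side effect inside the loop body
    return x + 1

lax.fori_loop(0, 5, body, 0)
print(counter["count"])
```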
The counter here is external state, so if you follow the logic above, you can guess what the counter will print.
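The pure alternative is to carry any state the loop needs through the loop itself. A minimal sketch of the counter example rewritten that way (the tuple carry is my own choice, not taken from the documentation):

```python
from jax import lax

def body(i, carry):
    x, count = carry
    # Instead of mutating an outside dict, return the updated state
    return x + 1, count + 1

x_final, count = lax.fori_loop(0, 5, body, (0, 0))
print(int(x_final), int(count))  # 5 5
```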
Another important tool is `vmap`. If you write a function that operates on a single entity, say an image, `vmap` will vectorize it behind the scenes so that it works with batched input. Awesome, right? You just need to write functions with a single input in mind, pass the correct values to the `in_axes` parameter, and you get a batch-ready function that is vectorized under the hood.
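A minimal sketch of that workflow, using a hypothetical single-example linear layer `predict` (the weights and inputs are made-up numbers so the result can be checked by hand):

```python
import jax
import jax.numpy as jnp

def predict(w, b, x):
    # Written for ONE example: x has shape (3,), the output has shape (2,)
    return jnp.dot(x, w) + b

w = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # (3, 2)
b = jnp.array([0.5, -0.5])                           # (2,)
xs = jnp.array([[1.0, 2.0, 3.0], [0.0, 0.0, 1.0]])   # batch of 2 examples, (2, 3)

# in_axes: do not map over w or b (None); map over axis 0 of xs
batched_predict = jax.vmap(predict, in_axes=(None, None, 0))
out = batched_predict(w, b, xs)

# Same result as a Python loop over the batch, but as a single vectorized call
looped = jnp.stack([predict(w, b, x) for x in xs])
print(out.shape)  # (2, 2)
```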
There are other important transformations as well, but rather than covering each one, we will move on to the next point and let this table summarize them along with the ones discussed above.
| Tool | Purpose |
| --- | --- |
| `grad` | Computes the gradient of a scalar-output function |
| `vmap` | Vectorizes a function so it applies across a batch axis |
| `jax.numpy` | NumPy-like API whose operations can be traced by `grad`, `jit`, `vmap`, etc. |
| `jacfwd`, `jacrev` | Compute full Jacobians (multi-input/multi-output derivatives) with forward- or reverse-mode autodiff |
| `value_and_grad` | Returns both a function's value and its gradient |
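A small sketch exercising the tools from the table on two toy functions (`f` and `g` are made up for illustration), with the expected values worked out by hand in the comments:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Scalar output: f(x) = sum(x_i^2), so df/dx_i = 2 * x_i
    return jnp.sum(x ** 2)

def g(x):
    # Vector output: g(x)_i = x_i^2, so the Jacobian is diag(2 * x)
    return x ** 2

x = jnp.array([1.0, 2.0, 3.0])

grad_f = jax.grad(f)(x)                       # [2. 4. 6.]
value, grad_again = jax.value_and_grad(f)(x)  # value = 14.0, same gradient
J_fwd = jax.jacfwd(g)(x)                      # forward-mode Jacobian, shape (3, 3)
J_rev = jax.jacrev(g)(x)                      # reverse-mode Jacobian, same values

print(grad_f, value)
print(J_fwd)
```

`jacfwd` and `jacrev` give the same numbers; forward mode tends to be cheaper for "tall" Jacobians (more outputs than inputs) and reverse mode for "wide" ones, which is why `grad` on a scalar loss uses reverse mode.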
- X stands for XLA, or Accelerated Linear Algebra. It is a domain-specific compiler for linear algebra: it takes numerical code (like matrix multiplications) and compiles it into highly optimized, machine-specific code for CPUs, GPUs, and TPUs. Just keep in mind that it is what JAX relies on to run on TPUs.
PyTorch can also be used with TPUs through `torch_xla`, but it is not as straightforward.
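To peek at what actually gets handed to XLA, `jax.jit(...).lower(...)` exposes the staged-out program before compilation. A minimal sketch (the `affine` function is just an example, and the exact IR text depends on your JAX version):

```python
import jax
import jax.numpy as jnp

def affine(w, x, b):
    return jnp.dot(w, x) + b

w = jnp.ones((2, 3))
x = jnp.ones(3)
b = jnp.zeros(2)

lowered = jax.jit(affine).lower(w, x, b)  # trace and lower to the IR that XLA consumes
ir_text = lowered.as_text()               # human-readable IR (StableHLO in recent versions)
compiled = lowered.compile()              # XLA compiles it for the current backend

print(ir_text[:300])
print(compiled(w, x, b))      # [3. 3.]
print(jax.default_backend())  # e.g. 'cpu', 'gpu' or 'tpu', depending on the machine
```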
State Handling in JAX
As discussed before, JAX's soul lies in functional programming, and it works best with pure functions. This means that we cannot use NumPy-like global random states. Instead, JAX requires you to pass an explicit PRNG key to every random function and to derive a new subkey every time you need fresh randomness (or separately for every method). For example:
```python
from jax import random

# Initial random key generated using a PRNG and a seed
key = random.PRNGKey(0)
# When we need randomness again for some function (say, when initializing weights)
key, subkey = random.split(key)  # Notice that the initial key is also updated here
```
We will look at a better example in a while, but you get the idea. The `random.split` function generates two new keys from a provided key (so it is not technically a split). We update the initial key as well because it will be used later to generate subkeys again; reusing the same key for that would generate the same subkey repeatedly.
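A quick sketch of why the key handling matters: the same key always reproduces the same numbers, while a freshly split subkey gives new ones. `split` can also produce several subkeys in one call (the variable names here are illustrative):

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)

a = random.normal(key, (3,))
b = random.normal(key, (3,))  # reusing the same key: exactly the same numbers

key, subkey = random.split(key)  # new key for later + subkey to consume now
c = random.normal(subkey, (3,))  # fresh numbers

key, *layer_keys = random.split(key, num=4)  # one call, three extra subkeys

print(bool(jnp.array_equal(a, b)))  # True
print(bool(jnp.array_equal(a, c)))  # False
print(len(layer_keys))              # 3
```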
Implementing a Neural Network
Now that the fundamentals are covered, we can get started with the code. If you are new to neural networks, or want to know the mathematics behind them, you can check out one of our previous blogs; we won't go into the mathematical details here.
In frameworks like PyTorch or JAX, we define the architecture, initialize the parameters, and write the forward pass. The backward pass is handled by the framework itself, unless you want to override it. So we will implement functions for these four steps:
- Initializing parameters
- Forward pass
- Loss function
- Gradient update
Initializing parameters
Initializing weights properly is important because it helps the network converge faster and avoids vanishing or exploding gradients. The two main initialization schemes are Xavier Glorot and Kaiming He initialization. Kaiming He has been empirically shown to work better with `relu` (and ReLU-like) activation functions, while Xavier Glorot works better with the `tanh` activation function. Since our hidden layers use ReLU, we use Kaiming He initialization, whose formulae go like this:
$$\begin{align} \text{For Normal Distribution : } \hspace{5mm} \sigma &= \sqrt{\frac{2}{n_{in}}} \\ \\ \text{For Uniform Distribution : } \hspace{5mm} \text{bound} &= \sqrt{\frac{6}{n_{in}}} \end{align}$$
where \(n_{in}\) represents the number of incoming connections to the node (the fan-in). For the uniform distribution (our case), weights are sampled uniformly from \([-\text{bound}, \text{bound}]\) using the generated subkey. Since a uniform distribution on \([-a, a]\) has variance \(a^2/3\), this gives the same variance \(2/n_{in}\) as the normal version. (As you can see, we use the same state handling concepts discussed before.)
```python
import jax.numpy as jnp
from jax import random

# Method of the network class (self.layer_sizes and self.key are assumed to be set in __init__)
def _init_params(self) -> dict:
    params = {}
    for i in range(len(self.layer_sizes) - 1):
        self.key, subkey = random.split(self.key)
        fan_in = self.layer_sizes[i]
        bound = jnp.sqrt(6.0 / fan_in)  # He-uniform bound
        params[f'W{i}'] = random.uniform(subkey, minval=-bound, maxval=bound,
                                         shape=(self.layer_sizes[i], self.layer_sizes[i + 1]))
        params[f'b{i}'] = jnp.zeros(self.layer_sizes[i + 1])
    return params
```
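If you prefer not to keep the key on `self`, the same He-uniform initialization can be written as a pure function that takes the key as an argument and returns the parameters. A sketch under that assumption (the name `init_params` and the layer sizes `[8, 16, 1]` are hypothetical; 8 matches the number of feature columns in the California Housing CSV):

```python
import jax.numpy as jnp
from jax import random

def init_params(key, layer_sizes):
    # He (Kaiming) uniform init: W ~ U(-bound, bound) with bound = sqrt(6 / fan_in)
    params = {}
    keys = random.split(key, len(layer_sizes) - 1)  # one subkey per weight matrix
    for i, subkey in enumerate(keys):
        fan_in, fan_out = layer_sizes[i], layer_sizes[i + 1]
        bound = jnp.sqrt(6.0 / fan_in)
        params[f'W{i}'] = random.uniform(subkey, (fan_in, fan_out), minval=-bound, maxval=bound)
        params[f'b{i}'] = jnp.zeros(fan_out)
    return params

params = init_params(random.PRNGKey(0), [8, 16, 1])
print({k: v.shape for k, v in params.items()})
# {'W0': (8, 16), 'b0': (16,), 'W1': (16, 1), 'b1': (1,)}
```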
Forward Pass
This is simple; just follow the maths: \(z = xW + b\) for each layer, followed by the ReLU activation \(\max(0, z)\) on the hidden layers.
```python
def _forward_pass(self, params: dict, x: jnp.ndarray) -> jnp.ndarray:
    # Hidden layers: affine transform followed by ReLU
    for i in range(len(self.layer_sizes) - 2):
        x = jnp.dot(x, params[f'W{i}']) + params[f'b{i}']
        x = jnp.maximum(0, x)  # ReLU function
    # Output layer: affine transform only
    output_idx = len(self.layer_sizes) - 2
    output = jnp.dot(x, params[f'W{output_idx}']) + params[f'b{output_idx}']
    return output
```
We have applied the ReLU function inside the forward pass itself, but you can also implement it as a separate method. It is applied to every layer except the final one, since for regression the output should be an unbounded real value.
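Written as a standalone function (without `self`), the same forward pass can be checked by hand on a tiny network. This is a sketch: `num_layers` stands in for `len(self.layer_sizes) - 1` from the class version, and the weights are made-up values:

```python
import jax.numpy as jnp

def forward(params, x, num_layers):
    # Hidden layers: affine transform followed by ReLU
    for i in range(num_layers - 1):
        x = jnp.maximum(0, jnp.dot(x, params[f'W{i}']) + params[f'b{i}'])
    # Output layer: affine only (no ReLU), so regression outputs can be any real value
    last = num_layers - 1
    return jnp.dot(x, params[f'W{last}']) + params[f'b{last}']

# Tiny hand-checkable network: 2 inputs -> 2 hidden units -> 1 output (hypothetical values)
params = {
    'W0': jnp.array([[1.0, -1.0], [1.0, -1.0]]), 'b0': jnp.zeros(2),
    'W1': jnp.array([[1.0], [1.0]]),             'b1': jnp.array([0.5]),
}
x = jnp.array([1.0, 2.0])
print(forward(params, x, num_layers=2))  # [3.5]  (hidden: relu([3, -3]) = [3, 0])

# jnp.dot also handles a leading batch axis, so a (N, 2) input works unchanged
xs = jnp.array([[1.0, 2.0], [-1.0, 0.0]])
print(forward(params, xs, num_layers=2).shape)  # (2, 1)
```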
Batch Loss Function
The loss function depends on the task at hand, which we have not addressed yet. Since we are just doing an implementation for understanding (and we are on Colab), we will perform a simple regression on the California Housing dataset, which ships with Colab's sample data. For that, we will use the mean squared error (MSE) loss.
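```python
import jax.numpy as jnp
from jax import vmap

def _compute_loss(self, params: dict, x: jnp.ndarray, y: jnp.ndarray) -> float:
    pred = self._forward_pass(params, x)
    return jnp.mean((pred - y) ** 2)  # MSE for a single example

def _batch_loss_function(self, params: dict, x: jnp.ndarray, y: jnp.ndarray) -> float:
    # Map over axis 0 of x and y; params are shared across the batch (None)
    return jnp.mean(vmap(self._compute_loss, in_axes=(None, 0, 0))(params, x, y))
```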
Here is another practical application of `vmap`: we use it to turn a simple per-example loss function into a batched loss function. The `in_axes` argument specifies which parameters of the function are batched.
(Note: More precisely, each `in_axes` entry specifies which axis of the corresponding argument is mapped over, and `None` means the argument is not batched but shared. If the entry is 1 instead of 0, the function is mapped over axis 1, i.e. the second dimension, of that argument. This is less common in neural network code, where the batch is usually the leading axis, but it is important to note. You can see more of this in the provided notebooks.)
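To illustrate the note, here is a sketch where the same single-vector function is mapped once over rows (`in_axes=0`) and once over columns (`in_axes=1`); the data is made up so the sums can be checked by hand:

```python
import jax
import jax.numpy as jnp

def squared_norm(v):
    # Defined for a single vector
    return jnp.sum(v ** 2)

data = jnp.array([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])  # shape (2, 3)

per_row = jax.vmap(squared_norm, in_axes=0)(data)     # maps over the 2 rows
per_column = jax.vmap(squared_norm, in_axes=1)(data)  # maps over the 3 columns

print(per_row)     # [14. 77.]
print(per_column)  # [17. 29. 45.]
```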
Gradient Update
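This too just follows the standard gradient descent update, where \(\eta\) is the learning rate and \(\mathcal{L}\) is the loss:
$$\begin{align} W &= W - \eta \cdot \nabla_W \mathcal{L} \\ b &= b - \eta \cdot \nabla_b \mathcal{L} \end{align}$$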
```python
import jax

def _gradient_update(self, params: dict, X: jnp.ndarray, y: jnp.ndarray, learning_rate: float) -> dict:
    grads = jax.grad(self._batch_loss_function)(params, X, y)  # same keys as params
    # Plain gradient descent step for every W{i} and b{i}
    return {k: params[k] - learning_rate * grads[k] for k in params}
```
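Putting the pieces together, a compiled training step can compute the loss and gradients in one pass with `value_and_grad` and apply the update with `jax.tree_util.tree_map`, which works for any nesting of parameters rather than only a flat dict. This is a sketch on synthetic linear data (the single-layer model, data, and hyperparameters are assumptions for illustration, not the California Housing setup):

```python
import jax
import jax.numpy as jnp
from jax import random

def loss_fn(params, x, y):
    # Mean squared error of a linear model
    pred = jnp.dot(x, params['W0']) + params['b0']
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, x, y, learning_rate):
    # One pass returns the loss value and the gradient w.r.t. every leaf of params
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Gradient descent applied leaf by leaf: p <- p - lr * dL/dp
    new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return new_params, loss

x = random.normal(random.PRNGKey(0), (256, 3))
true_w = jnp.array([[2.0], [-1.0], [0.5]])
y = jnp.dot(x, true_w) + 0.3  # noiseless synthetic targets

params = {'W0': jnp.zeros((3, 1)), 'b0': jnp.zeros(1)}
losses = []
for step in range(200):
    params, loss = train_step(params, x, y, 0.1)
    losses.append(float(loss))

print(losses[0], losses[-1])  # the loss should fall close to 0
```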
The rest of the implementation is intuitive and has nothing noteworthy, except for the data loading step. JAX has no component for data loading like PyTorch's `torch.utils.data` or `torchvision` (an important thing to note). So for now, we will use classic NumPy and pandas to load our CSV files and then convert them into `jnp` arrays. For the same reason, we will use metrics and `StandardScaler` from the `scikit-learn` library.
```python
import jax.numpy as jnp
import pandas as pd
from sklearn.preprocessing import StandardScaler

def load_california_housing_data():
    # Colab ships these CSV files in its sample_data folder
    train_df = pd.read_csv('/content/sample_data/california_housing_train.csv')
    test_df = pd.read_csv('/content/sample_data/california_housing_test.csv')
    print(f"Training data shape: {train_df.shape}")
    print(f"Test data shape: {test_df.shape}")
    print(f"Features: {list(train_df.columns[:-1])}")
    print(f"Target: {train_df.columns[-1]}")

    # All columns except the last are features; the last column is the target
    X_train = train_df.iloc[:, :-1].values
    y_train = train_df.iloc[:, -1].values.reshape(-1, 1)
    X_test = test_df.iloc[:, :-1].values
    y_test = test_df.iloc[:, -1].values.reshape(-1, 1)

    # Standardize features (fit on train only, then apply to test)
    scaler_X = StandardScaler()
    X_train_scaled = scaler_X.fit_transform(X_train)
    X_test_scaled = scaler_X.transform(X_test)
    scaler_y = StandardScaler()
    y_train_scaled = scaler_y.fit_transform(y_train)
    y_test_scaled = scaler_y.transform(y_test)

    # Convert to JAX arrays
    X_train_jax = jnp.array(X_train_scaled)
    y_train_jax = jnp.array(y_train_scaled)
    X_test_jax = jnp.array(X_test_scaled)
    y_test_jax = jnp.array(y_test_scaled)
    print("\nData preprocessing completed")
    return X_train_jax, y_train_jax, X_test_jax, y_test_jax, scaler_X, scaler_y
```
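Since JAX has no data loader of its own, a common pattern is a small generator that shuffles indices with an explicit PRNG key and yields minibatches. A minimal sketch (the name `iterate_minibatches` and the toy data are hypothetical; the last partial batch is dropped so every batch has the same shape and a jitted training step does not recompile):

```python
import jax.numpy as jnp
from jax import random

def iterate_minibatches(key, X, y, batch_size):
    # Shuffle example indices with an explicit key (no global random state)
    perm = random.permutation(key, X.shape[0])
    num_full = X.shape[0] // batch_size
    for i in range(num_full):
        idx = perm[i * batch_size:(i + 1) * batch_size]
        yield X[idx], y[idx]

X = jnp.arange(20.0).reshape(10, 2)  # row i is [2i, 2i + 1]
y = jnp.arange(10.0).reshape(10, 1)  # target for row i is i
batches = list(iterate_minibatches(random.PRNGKey(0), X, y, batch_size=4))
print(len(batches), batches[0][0].shape)  # 2 (4, 2)
```

In an epoch loop you would split a fresh subkey each epoch (`key, subkey = random.split(key)`) and pass `subkey` in, so every epoch gets a different shuffle.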
The whole code and an additional file are provided below, so you can verify and review how JAX works differently.
Is it truly a Framework?
Now it's time for an important discussion: is JAX really a framework? A framework, by a loose definition, is supposed to do two things. The first is inversion of control: you are not the one calling the code and its methods; instead, the framework imposes structure and rules and asks you to fit your code into its defined form, i.e. the framework calls your code. PyTorch is only partly like this: you explicitly call functions like `forward()`, `backward()`, and `optimizer.step()`, so you remain in control.
The second, closely related to the first, is imposition of structure, i.e. less flexibility. Frameworks like Caffe, which were widely used before, required very little user-written code and heavily imposed structure. Since Caffe is obsolete now (not entirely: Caffe2 became part of the PyTorch backend), I looked up how things were done in it:
```protobuf
layer {
  name: "conv1"
  type: "Convolution"
  bottom: "data"
  top: "conv1"
  convolution_param { num_output: 32 kernel_size: 5 }
}
layer {
  name: "relu1"
  type: "ReLU"
  bottom: "conv1"
  top: "conv1"
}
```
As you can see, we are just defining what we want to use, not writing the methods and algorithms.
A true modern-day framework would be TensorFlow (through its Keras API) or PyTorch Lightning. In TensorFlow, you define the model architecture, loss computation, data pipelines, etc., but when you call `model.fit()`, it decides on its own when to run backpropagation, parameter updates, and validation. Unfortunately, I have never used either of those, so I can't say much more on this, but it is clear that JAX is not a framework at all. It is more of a transformation library.
Though JAX does impose some rules (pure functions), you are still in charge of when to call its transformations like `grad`, `jit`, and `vmap`. As we saw first-hand, we had to write the whole definition ourselves. The problem it addresses further clarifies why it is a library: JAX's main idea is to heavily optimize numerical computing using low-level primitives. In fact, it is not even just a NumPy wrapper, even though it may seem like one to some. The techniques discussed throughout this post (tracing, function transformations, and XLA compilation) are what set it apart. There are frameworks built on top of it, though; Flax and Haiku are two examples.
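The "transformation library" view is easiest to see when transformations are composed. A sketch (the linear model and data are made up): `grad` gives the gradient for one example, `vmap` turns that into per-example gradients, and `jit` compiles the whole pipeline:

```python
import jax
import jax.numpy as jnp

def example_loss(w, x, y):
    # Squared error of a linear model on ONE example
    return (jnp.dot(x, w) - y) ** 2

# grad w.r.t. w for one example -> vmap over examples -> jit-compile the result
per_example_grads = jax.jit(jax.vmap(jax.grad(example_loss), in_axes=(None, 0, 0)))

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 0.0]])
ys = jnp.array([0.0, 1.0])

grads = per_example_grads(w, xs, ys)
print(grads.shape)  # (2, 2): one gradient row per example

# Analytic check: d/dw (x.w - y)^2 = 2 (x.w - y) x
expected = 2 * (xs @ w - ys)[:, None] * xs
```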
For instance, the whole model definition we wrote can be written like this in Flax:
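```python
import flax.linen as nn
import jax
import jax.numpy as jnp

class FlaxNet(nn.Module):
    layer_sizes: list[int]  # e.g. [8, 64, 32, 1]; the first entry is the input size
    @nn.compact
    def __call__(self, x):
        for size in self.layer_sizes[1:-1]:  # Dense infers its input size, so skip entry 0
            x = nn.Dense(features=size)(x)
            x = nn.relu(x)
        x = nn.Dense(features=self.layer_sizes[-1])(x)
        return x

model = FlaxNet(layer_sizes=[8, 64, 32, 1])
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))  # shapes inferred from a dummy input
preds = model.apply(params, jnp.ones((4, 8)))                # shape (4, 1)
```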
(This covers just the model definition, not the training and evaluation functions, though those would be heavily shortened as well.)
This makes me realize that the title of the blog and the intro paragraphs are not correct at all (basically every instance where I called JAX a framework is a mistake). But I'm leaving them as they are for the sake of SEO and the narrative payoff of this revelation. If you clicked on this because of the title and have read this far, drop a comment.
Provided Colab Notebooks
References
(Side note (not important): The Sharp Bits documentation must have been written by some cool guy, because the inclusion of emojis is quite funny. Since it is titled Sharp Bits, a knife emoji is used to mark each new section. And not just that, even the URL slug says 'Common Gotchas in JAX'.)
Written by

Ayush Saraswat
Aspiring Computer Vision engineer, eager to delve into the world of AI/ML, cloud, Computer Vision, TinyML and other programming stuff.