Vision Transformers in JAX


Transformers were introduced in 2017 and changed the course of deep learning research for years to come. Naturally, they were also adopted in computer vision, which until then had been dominated by Convolutional Neural Networks and convolution-based approaches in general. Vision Transformers (ViT) came into the spotlight when the largest model in the original paper, ViT-H/14 pre-trained on the JFT-300M dataset, reached 88.55% top-1 accuracy on ImageNet.
Today, we will cover the underlying mechanism of Vision Transformers and implement one in Equinox, a JAX-based framework.
(Note: Since this blog is about Vision Transformers, we will not delve into how JAX and Equinox work by themselves. You also don't need prior knowledge of Transformers to follow this article, as everything relevant is introduced gradually. If you are interested, you can check the references at the end, along with the GitHub repository for this code.)
The Core Ideology
The idea is easy to grasp if you are already familiar with Transformers: a Vision Transformer simply divides the image into patches of a fixed size and arranges them into a sequence. The rest works like the regular Encoder of a Transformer. The original paper used 16 x 16 patches for the ImageNet dataset, while smaller patch sizes like 4 x 4 or 8 x 8 are preferred for datasets with smaller images, since longer sequences are generally more helpful, especially when using linear attention.
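To make the patch-size trade-off concrete, here is a small sketch of how the patch size determines the sequence length. The helper name `num_patches` is ours rather than part of the implementation, and the image sizes are assumptions: 224 x 224 is the usual ImageNet training resolution and 32 x 32 is the CIFAR-10 image size.

```python
def num_patches(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping square patches in a square image."""
    assert image_size % patch_size == 0, "image size must be divisible by patch size"
    return (image_size // patch_size) ** 2


# ImageNet-sized input with the 16 x 16 patches from the original paper
print("224 x 224, patch 16:", num_patches(224, 16))

# CIFAR-10-sized input: large patches leave almost no sequence to attend over
for patch_size in (16, 8, 4):
    print(f"32 x 32, patch {patch_size}:", num_patches(32, patch_size))
```

With 16 x 16 patches a 32 x 32 image collapses to only 4 tokens, which is why the smaller patch sizes are the practical choice for small images.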
Throughout this implementation, we will focus on the following components:
Patch Embeddings
Self Attention
Multi Head Attention
Encoder Block
Classification Head
Vision Transformer
Training Loop
Patch Embeddings
Transformers, in general, are used for words. First they assign each word a number (a token ID), then they convert these IDs into learnable vectors, so a sentence becomes an array of shape (sequence_length, embedding_dim), i.e. each word gets a vector of shape (embedding_dim,). These learnable vectors are called embeddings, and an embedding layer for words is essentially a learnable lookup table. Image patches are not discrete IDs, so instead of a lookup table we use a learnable linear projection, created in the __init__() method from the provided sizes. Each square patch is flattened into a one-dimensional vector of length patch_size * patch_size * channels and then projected to embedding_dim, so the layer maps an array of shape (num_patches, patch_dim) to one of shape (num_patches, embedding_dim).
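# Imports used by the code in this post
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import optax
from einops import rearrange
from jaxtyping import Array, PRNGKeyArray

class PatchEmbedding(eqx.Module):
    linear: eqx.nn.Linear
    patch_size: int

    def __init__(self, patch_size: int, embedding_dim: int, key: PRNGKeyArray):
        self.patch_size = patch_size
        # A flattened patch holds patch_size * patch_size pixels with 3 (RGB) channels
        patch_dim = patch_size * patch_size * 3
        self.linear = eqx.nn.Linear(in_features=patch_dim, out_features=embedding_dim, key=key)

    def __call__(self, x: Array) -> Array:
        # x is a single image; H and W must be divisible by patch_size
        H, W, C = x.shape
        # Tensorized: cut the image into patches and flatten each one
        x_patches = rearrange(x, '(h p1) (w p2) c -> (h w) (p1 p2 c)',
                              p1=self.patch_size, p2=self.patch_size)
        # Apply the same projection to every patch
        embeddings = jax.vmap(self.linear)(x_patches)
        return embeddings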
To speed things up, it is best to use tensorized operations instead of looping over patches in Python. That is why we use einops.rearrange() here: it expresses the whole patch extraction as a single reshape-and-transpose, which JAX can compile efficiently.
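To see what the rearrange pattern does, here is a minimal sketch that performs the same patch extraction with plain reshape and transpose calls. It uses NumPy so that it runs on its own (the same calls exist in jax.numpy), and the function name `patchify` is ours, not part of the implementation.

```python
import numpy as np


def patchify(image: np.ndarray, patch_size: int) -> np.ndarray:
    """Equivalent of '(h p1) (w p2) c -> (h w) (p1 p2 c)' for one (H, W, C) image."""
    H, W, C = image.shape
    h, w = H // patch_size, W // patch_size
    # Split both spatial axes into (number of patches, pixels inside a patch)
    x = image.reshape(h, patch_size, w, patch_size, C)
    # Bring the two patch-index axes to the front: (h, w, p1, p2, c)
    x = x.transpose(0, 2, 1, 3, 4)
    # Flatten the patch grid into a sequence and each patch into a vector
    return x.reshape(h * w, patch_size * patch_size * C)


image = np.arange(32 * 32 * 3, dtype=np.float32).reshape(32, 32, 3)
patches = patchify(image, patch_size=4)
print(patches.shape)  # (64, 48)
```

Patches are ordered row by row, so patch 0 is the top-left 4 x 4 corner of the image and patch 1 is the block directly to its right.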
Self Attention
Transformers rely on the concept of attention: every unit (token or patch) gets access to the information in every other unit. This is done dynamically using three inputs called Query, Key and Value (in self-attention, all three are derived from the same input embeddings). These are the things to keep in mind when designing an attention function:
Similarity metric: We need a way to let each element in a sequence derive context from every other element, which requires a mechanism to measure the similarity or relevance between any two elements. The most common choice is what the original paper did, i.e. the dot product. Other approaches use cosine similarity, additive attention, etc. The output of this stage is a set of raw "attention scores" between elements.
Scaling: The magnitude of the dot product grows with the dimension of the vectors, so we need a way to scale it down before normalizing.
Normalization: The weights should represent the proportion of attention each element pays to every other element, which the raw attention scores don't. Therefore, we apply softmax to turn them into proportions that sum to 1 and can be used as weights.
Differentiability: The entire attention function should be differentiable.
Now it makes more sense to introduce the attention function which the original paper used.
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$
(Where \(d_k\) is the dimension of the query and key vectors. With a single head this equals embedding_dim; with multiple heads it is embedding_dim / num_heads.)
This method is called scaled dot-product attention. An important detail is hidden by this formula, so let's look at the version for a single query \(q_i\):
$$\text{Attention}(q_i, K, V) = \sum_{j=1}^{N} \text{softmax}_j\left(\frac{q_i k_j^{\top}}{\sqrt{d_k}}\right) v_j$$
Here, as we can see, a single query accounts for every key in the sequence, with the softmax normalizing over all \(N\) keys. Doing this for all \(N\) queries means multiplying the \(N \times d_k\) query matrix by the \(d_k \times N\) transposed key matrix, which produces an \(N \times N\) score matrix, so time and memory grow as \(\text{O}(N^2)\) in the sequence length. Later research has managed to approximate this product and bring the cost down to linear in \(N\), with Linformer and Performer being two examples. But we won't go into the specifics for this implementation.
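The two formulas above can be checked against each other with a short sketch. This is a NumPy illustration under our own naming (`softmax`, `scaled_dot_product_attention`), with random inputs standing in for real projections; it is not the code used in the model.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the max before exponentiating for numerical stability
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def scaled_dot_product_attention(Q, K, V):
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # (N, N): one score per query-key pair
    weights = softmax(scores, axis=-1)   # each row sums to 1
    return weights @ V, weights


rng = np.random.default_rng(0)
N, d_k = 5, 8
Q, K, V = rng.normal(size=(3, N, d_k))
out, weights = scaled_dot_product_attention(Q, K, V)

# Single-query form: a weighted sum of all value vectors for query i
i = 2
w_i = softmax(Q[i] @ K.T / np.sqrt(d_k))
out_i = (w_i[:, None] * V).sum(axis=0)

print(weights.shape, out.shape)
print(np.allclose(out[i], out_i))
```

The (N, N) shape of `weights` is where the quadratic cost comes from: doubling the number of patches quadruples the number of scores.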
Multi Head Attention
This self-attention isn't applied as is. Instead, we take the learned projections of the input Query, Key and Value and split each of them along the embedding dimension into multiple parts (controlled by the num_heads parameter). This lets the model jointly attend to different parts of the input in parallel and learn diverse attention patterns. (Notice that every head sees the full sequence length; splitting the embedding dimension just lets each head focus on a different aspect of the input.)
The outputs of all heads are concatenated (H in the image) and multiplied by a learned projection of size (embedding_dim, embedding_dim), as was done earlier for the inputs, which gives the final Multi Head Attention output.
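The head splitting described above can be sketched in a few lines. This is a NumPy illustration with hypothetical names (`multi_head_attention`, `w_q`, `w_k`, `w_v`, `w_o`) and random weights; biases and dropout are left out, and it is not how Equinox lays out its parameters internally.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_attention(x, w_q, w_k, w_v, w_o, num_heads):
    seq_len, dim = x.shape
    head_dim = dim // num_heads

    def split_heads(t):
        # (seq_len, dim) -> (num_heads, seq_len, head_dim)
        return t.reshape(seq_len, num_heads, head_dim).transpose(1, 0, 2)

    # Learned projections, then one slice of the embedding per head
    q, k, v = split_heads(x @ w_q), split_heads(x @ w_k), split_heads(x @ w_v)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)  # (heads, seq, seq)
    weights = softmax(scores, axis=-1)
    heads = weights @ v                                    # (heads, seq, head_dim)
    # Concatenate the heads back to (seq_len, dim) and apply the output projection
    merged = heads.transpose(1, 0, 2).reshape(seq_len, dim)
    return merged @ w_o, weights


rng = np.random.default_rng(0)
seq_len, dim, num_heads = 65, 32, 4
x = rng.normal(size=(seq_len, dim))
w_q, w_k, w_v, w_o = rng.normal(size=(4, dim, dim)) * 0.1
out, weights = multi_head_attention(x, w_q, w_k, w_v, w_o, num_heads)
print(out.shape, weights.shape)
```

Each head gets its own (seq_len, seq_len) attention map over the full sequence, while working with only dim / num_heads features.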
Equinox provides a built-in module for multi-headed attention, and we will use it to get a well-tested and optimized implementation. But if you are skeptical, here is a snippet from the Equinox source code on GitHub which shows that it does use these expressions in its implementation.
def dot_product_attention_weights(
    query: Float[Array, "q_seq qk_size"],
    key: Float[Array, "kv_seq qk_size"],
    mask: Bool[Array, "q_seq kv_seq"] | None = None,
) -> Float[Array, "q_seq kv_seq"]:
    query = query / math.sqrt(query.shape[-1])
    logits = jnp.einsum("sd,Sd->sS", query, key)
    if mask is not None:
        if mask.shape != logits.shape:
            raise ValueError(
                f"mask must have shape (query_seq_length, "
                f"kv_seq_length)=({query.shape[0]}, "
                f"{key.shape[0]}). Got {mask.shape}."
            )
        logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
        logits = cast(Array, logits)
    with jax.numpy_dtype_promotion("standard"):
        dtype = jnp.result_type(logits.dtype, jnp.float32)
    weights = jax.nn.softmax(logits.astype(dtype)).astype(logits.dtype)
    return weights

def dot_product_attention(
    query: Float[Array, "q_seq qk_size"],
    key_: Float[Array, "kv_seq qk_size"],
    value: Float[Array, "kv_seq v_size"],
    mask: Bool[Array, "q_seq kv_seq"] | None = None,
    dropout: Dropout | None = None,
    *,
    key: PRNGKeyArray | None = None,
    inference: bool | None = None,
) -> Float[Array, "q_seq v_size"]:
    weights = dot_product_attention_weights(query, key_, mask)
    if dropout is not None:
        weights = dropout(weights, key=key, inference=inference)
    attn = jnp.einsum("sS,Sd->sd", weights, value)
    return attn
You don't have to focus on every part of this. Just notice that the first function reads \(d_k\) from the last dimension of the query and scales the query by the factor discussed before, \(1/\sqrt{d_k}\). Then it uses einsum to take the dot product between every query and every key. Finally, it applies softmax to obtain the weights, i.e. \(\text{softmax}\left( \frac{QK^\top}{\sqrt{d_k}} \right)\). The second function then multiplies these weights with the values.
Since we are implementing a ViT, which only uses the encoder and lets every patch attend to every other patch, the masking part is not relevant to us, and you are probably familiar with dropout layers already. The rest of the Multi Head implementation can also be verified through the link. We will use the provided module for a cleaner implementation, because it is time to discuss something essential which we deliberately skipped earlier.
class Attention(eqx.Module):
    attention: eqx.nn.MultiheadAttention
    norm: eqx.nn.LayerNorm
    dropout: eqx.nn.Dropout

    def __init__(self, embedding_dim: int, num_heads: int, dropout_rate: float, key: PRNGKeyArray):
        self.attention = eqx.nn.MultiheadAttention(
            num_heads=num_heads,
            query_size=embedding_dim,
            key=key,
            dropout_p=dropout_rate
        )
        self.norm = eqx.nn.LayerNorm(shape=embedding_dim)
        self.dropout = eqx.nn.Dropout(p=dropout_rate)

    def __call__(self, x: Array, *, key: PRNGKeyArray | None = None, inference: bool = False) -> Array:
        # Use separate keys for the attention-weight dropout and the output dropout
        attn_key, drop_key = jr.split(key) if key is not None else (None, None)
        # Pre-norm: normalize every token before attention
        normed = jax.vmap(self.norm)(x)
        # Self-attention: query, key and value all come from the same input
        attended = self.attention(normed, normed, normed, key=attn_key, inference=inference)
        if not inference and key is not None:
            attended = self.dropout(attended, key=drop_key)
        # Residual connection
        return x + attended
Positional Encoding
Generally, this section is introduced before exploring the attention part, but we think covering attention first makes the underlying reasoning easier to follow.
Remember the vector embeddings we created after patchification? They are not used as they are. The problem is that the attention function we just used treats its input as a set, i.e. the order of the input entities does not matter. Consider a regular sentence such as 'this is a cat'. Since attention weighs the words purely by their content, shuffling the sentence into 'cat is this a' or any other permutation produces exactly the same output for each word, only shuffled in the same way. In other words, self-attention is permutation equivariant, and the model has no way of telling one ordering from another.
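This order-blindness can be checked numerically. The sketch below uses a bare NumPy self-attention without learned projections (the function names are ours) and shows that shuffling the input only shuffles the output, until a position-dependent vector is added to each token.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def self_attention(x: np.ndarray) -> np.ndarray:
    # Query, key and value are all the input itself in this simplified version
    scores = x @ x.T / np.sqrt(x.shape[-1])
    return softmax(scores, axis=-1) @ x


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 6))      # 4 tokens, e.g. 'this is a cat'
perm = np.array([3, 1, 0, 2])    # e.g. 'cat is this a'

# Without position information: outputs are identical up to the same shuffle
out = self_attention(x)
out_shuffled = self_attention(x[perm])
equivariant = np.allclose(out_shuffled, out[perm])

# With a vector tied to each position, the ordering changes the result
pos = rng.normal(size=(4, 6))
out_pos = self_attention(x + pos)
out_pos_shuffled = self_attention(x[perm] + pos)
order_matters = not np.allclose(out_pos_shuffled, out_pos[perm])

print(equivariant, order_matters)
```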
To counter this, we need to inject additional information about the positions of the entities, so that the model does not lose the ordering of the sentence (the same logic applies to image patches). This information takes the form of another vector of shape (embedding_dim,) per position. This vector is added to the original embedding vector and the sum is passed on as input. The patch embeddings are learnable, so we just have to initialize them and make sure they are treated as learnable parameters. What about these positional embedding vectors?
In the original Transformer paper, the researchers use sinusoidal waves with different wavelengths for a multi-scale representation. This is similar to how binary encoding works: the waves with longer wavelengths behave like the left-most digit of a binary number, which changes rarely and represents the coarse relationships, while the waves with shorter wavelengths represent the finer relationships.
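For reference, the sinusoidal scheme from the original Transformer paper can be sketched as follows. This is a NumPy illustration with a function name of our choosing; it is not used anywhere in this implementation.

```python
import numpy as np


def sinusoidal_positions(num_positions: int, embedding_dim: int, base: float = 10000.0) -> np.ndarray:
    """PE[pos, 2i] = sin(pos / base^(2i/d)), PE[pos, 2i+1] = cos(pos / base^(2i/d))."""
    assert embedding_dim % 2 == 0, "this sketch assumes an even embedding_dim"
    positions = np.arange(num_positions)[:, None]        # (P, 1)
    dims = np.arange(0, embedding_dim, 2)[None, :]       # (1, D/2), the values 2i
    angles = positions / base ** (dims / embedding_dim)  # (P, D/2)
    pe = np.zeros((num_positions, embedding_dim))
    pe[:, 0::2] = np.sin(angles)  # even columns
    pe[:, 1::2] = np.cos(angles)  # odd columns
    return pe


pe = sinusoidal_positions(num_positions=65, embedding_dim=64)
print(pe.shape)
```

The first columns oscillate quickly from one position to the next (fine detail), while the last columns barely change across the whole sequence (coarse position), which is the binary-counter analogy described above.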
However, you don't have to worry about this if it seems complex.
Initially, positional encodings were kept static and non-learnable, partly because fixed sinusoids can extrapolate to sequences longer than those seen during training. The ViT paper, however, treats them as learnable parameters, and this has become a common choice.
This has its own benefits. It provides flexibility and lets the model find its own way to represent positional information. The trade-off is that a learned table only covers the positions seen during training, which is not a concern here because every image yields the same number of patches. Therefore, we just need a learnable array with one row per position (plus one row for the classification token introduced later), which is easy to implement.
class PositionalEmbedding(eqx.Module):
    pos_embed: Array

    def __init__(self, num_patches: int, embedding_dim: int, *, key: PRNGKeyArray):
        # One learnable row per patch, plus one for the classification token
        self.pos_embed = jr.normal(key, (num_patches + 1, embedding_dim)) * 0.02

    def __call__(self, x: Array) -> Array:
        return x + self.pos_embed
MLP Layer and Skip Connections
MLP is just another name for a feed-forward neural network, which we already know about. The original Transformer paper uses this expression for the minimal architecture it used.
$$\text{FFN}(x) = \max(0,\; xW_1 + b_1)W_2 + b_2$$
As you can see, there isn't much to implement here. Following the ViT paper, our code uses GELU in place of the ReLU, i.e. the \(\max(0, \cdot)\), in this formula. One thing you may have noticed in the provided code (and in the attention code above as well) is the addition of the input to the calculated output. Some of you may also know this as a residual connection.
If you don't know about them already, they are what helped ResNet in 2015 surpass predecessors like AlexNet and unlock much deeper architectures by addressing vanishing gradients (the gradient can propagate through the identity term \(x\) even if the gradient through the layer itself vanishes).
They also preserve low-level features, which often get lost across multiple layers. Residual connections are wrapped around non-linear blocks (such as the MLP or the attention block) because those blocks can distort or suppress information, while the skip path lets the input through unchanged. At least that's the high-level idea.
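The gradient argument can be illustrated with a toy calculation. The sketch below is a deliberately simplified assumption: a one-dimensional "network" in which every layer is the same weak function 0.1 * tanh(x), and the chain rule is applied by hand. It is not part of the ViT code.

```python
import math


def layer(x: float) -> float:
    # A weak non-linear layer whose slope is at most 0.1
    return 0.1 * math.tanh(x)


def layer_grad(x: float) -> float:
    return 0.1 * (1.0 - math.tanh(x) ** 2)


def input_gradient(x: float, depth: int, residual: bool) -> float:
    """d(output)/d(input) of `depth` stacked layers, via the chain rule."""
    grad = 1.0
    for _ in range(depth):
        local = layer_grad(x)
        if residual:
            grad *= 1.0 + local   # derivative of x + f(x)
            x = x + layer(x)
        else:
            grad *= local         # derivative of f(x)
            x = layer(x)
    return grad


plain = input_gradient(0.5, depth=20, residual=False)
with_skip = input_gradient(0.5, depth=20, residual=True)
print(f"plain stack:    {plain:.3e}")
print(f"residual stack: {with_skip:.3e}")
```

In the plain stack every layer multiplies the gradient by at most 0.1, so it shrinks towards zero with depth; with the skip path every factor is at least 1, so the gradient survives.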
Despite all this jargon, residual connections are quite easy to implement, as you can see in the return statement.
class MLP(eqx.Module):
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear
    norm: eqx.nn.LayerNorm
    dropout: eqx.nn.Dropout

    def __init__(self, embedding_dim: int, hidden_dim: int, dropout_rate: float, key: PRNGKeyArray):
        key1, key2 = jr.split(key)
        self.linear1 = eqx.nn.Linear(in_features=embedding_dim, out_features=hidden_dim, key=key1)
        self.linear2 = eqx.nn.Linear(in_features=hidden_dim, out_features=embedding_dim, key=key2)
        self.norm = eqx.nn.LayerNorm(shape=embedding_dim)
        self.dropout = eqx.nn.Dropout(p=dropout_rate)

    def __call__(self, x: Array, *, key: PRNGKeyArray | None = None, inference: bool = False) -> Array:
        # Use a separate key for each dropout call
        key1, key2 = jr.split(key) if key is not None else (None, None)
        # Pre-norm design
        normed = jax.vmap(self.norm)(x)
        hidden = jax.vmap(self.linear1)(normed)
        hidden = jax.nn.gelu(hidden)
        if not inference and key is not None:
            hidden = self.dropout(hidden, key=key1)
        output = jax.vmap(self.linear2)(hidden)
        if not inference and key is not None:
            output = self.dropout(output, key=key2)
        # Residual connection
        return x + output
Encoder Block
Regular Transformers are made up of two parts, Encoders and Decoders. Vision Transformers, however, only use the first part, so the encoder is the core of the model. Both the attention mechanism and the MLP are combined in this block, and by the end we have an encoded representation in which each element carries contextual information about every other element.
This is how the original Transformer and ViT papers define the Encoder block. We just have to follow it and create the class.
(Note: The longer arrows that go around the main line represent residual connections. One more thing to notice is that ViT uses a pre-norm architecture, where normalization is applied before the attention and MLP layers, instead of the original Transformer's post-norm, because it has been found empirically to train more stably.)
class EncoderBlock(eqx.Module):
    attention: Attention
    mlp: MLP

    def __init__(self, embedding_dim: int, hidden_dim: int, num_heads: int,
                 dropout_rate: float, key: PRNGKeyArray):
        key1, key2 = jr.split(key)
        self.attention = Attention(embedding_dim, num_heads, dropout_rate, key1)
        self.mlp = MLP(embedding_dim, hidden_dim, dropout_rate, key2)

    def __call__(self, x: Array, *, key: PRNGKeyArray | None = None, inference: bool = False) -> Array:
        key1, key2 = jr.split(key, 2) if key is not None else (None, None)
        # Each sub-block applies its own pre-norm and residual connection
        x = self.attention(x, key=key1, inference=inference)
        x = self.mlp(x, key=key2, inference=inference)
        return x
Classification Head
Since we are using this ViT for a classification task on the CIFAR-10 dataset, everything we have built so far does not yet produce a prediction. The Encoder blocks only generate better representations of the patches. We need to introduce another learnable parameter that passes through all of these layers, gathers global information and represents the whole image for downstream tasks, i.e. a classification token, or cls_token. It's simple: we keep one learnable vector of shape (1, D) and prepend it to the sequence of patch embeddings of every image.
After this cls_token has gone through all the encoder blocks, we extract it from the output of the final encoder block and feed it to a linear classification layer called the Classification Head.
class ClassificationHead(eqx.Module):
    classifier: eqx.nn.Linear
    dropout: eqx.nn.Dropout

    def __init__(self, embedding_dim: int, num_classes: int, dropout_rate: float, *, key: PRNGKeyArray):
        classifier_key = key
        self.classifier = eqx.nn.Linear(in_features=embedding_dim, out_features=num_classes, key=classifier_key)
        self.dropout = eqx.nn.Dropout(p=dropout_rate)

    def __call__(self, cls_token: Array, *, key: PRNGKeyArray | None = None, inference: bool = False) -> Array:
        if not inference and key is not None:
            cls_token = self.dropout(cls_token, key=key)
        # Map the (embedding_dim,) cls vector to one logit per class
        return self.classifier(cls_token)
(The creation of cls_token and its extraction are included in the VIT class)
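Before looking at the full model, the shape bookkeeping around the cls_token can be sketched in isolation. The snippet below uses NumPy with random arrays standing in for the learned parameters, and the sizes (64 patches, embedding dimension 32) are assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
num_patches, embedding_dim = 64, 32

# Stand-ins for the learned quantities
patch_embeddings = rng.normal(size=(num_patches, embedding_dim))
cls_token = rng.normal(size=(1, embedding_dim)) * 0.02
pos_embed = rng.normal(size=(num_patches + 1, embedding_dim)) * 0.02

# Prepend the cls token, then add one positional vector per sequence element
tokens = np.concatenate([cls_token, patch_embeddings], axis=0)
tokens = tokens + pos_embed

# The encoder stack would transform `tokens` here without changing its shape
encoded = tokens

# The classification head only sees the first row of the encoder output
cls_out = encoded[0]
print(tokens.shape, cls_out.shape)
```

Because the token is prepended, the sequence has num_patches + 1 rows, which is why the positional embedding table needs one extra row.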
Vision Transformer
Now it is time to combine every individual unit. There is not much to say: we follow the architecture shown in the previous image and get our Vision Transformer.
class VIT(eqx.Module):
    patch_embed: PatchEmbedding
    cls_token: Array
    pos_embed: PositionalEmbedding
    encoder_blocks: list
    final_norm: eqx.nn.LayerNorm
    classification_head: ClassificationHead
    dropout: eqx.nn.Dropout

    def __init__(self, patch_size: int, embedding_dim: int, hidden_dim: int,
                 num_heads: int, num_layers: int, num_classes: int,
                 dropout_rate: float, num_patches: int, *, key: PRNGKeyArray):
        keys = jr.split(key, num_layers + 6)
        self.patch_embed = PatchEmbedding(patch_size, embedding_dim, keys[0])
        # Learnable classification token, shared by all images
        self.cls_token = jr.normal(keys[1], (1, embedding_dim)) * 0.02
        self.pos_embed = PositionalEmbedding(num_patches, embedding_dim, key=keys[2])
        self.encoder_blocks = [
            EncoderBlock(embedding_dim, hidden_dim, num_heads, dropout_rate, keys[3 + i])
            for i in range(num_layers)
        ]
        self.final_norm = eqx.nn.LayerNorm(shape=embedding_dim)
        self.classification_head = ClassificationHead(embedding_dim, num_classes, dropout_rate, key=keys[3 + num_layers])
        self.dropout = eqx.nn.Dropout(p=dropout_rate)

    def __call__(self, x: Array, *, key: PRNGKeyArray | None = None, inference: bool = False) -> Array:
        # One key for the embedding dropout, one per encoder block, one for the head
        num_keys = len(self.encoder_blocks) + 2
        keys = jr.split(key, num_keys) if key is not None else [None] * num_keys
        x = self.patch_embed(x)
        # Prepend the cls token so that it ends up at index 0
        x = jnp.concatenate([self.cls_token, x], axis=0)
        x = self.pos_embed(x)
        # Dropout needs a key, so it is skipped when none is given
        if not inference and key is not None:
            x = self.dropout(x, key=keys[0])
        for i, block in enumerate(self.encoder_blocks):
            x = block(x, key=keys[i + 1], inference=inference)
        # Only the cls token is used for classification
        cls_token = self.final_norm(x[0])
        logits = self.classification_head(cls_token, key=keys[-1], inference=inference)
        return logits
Training loop
This part mostly concerns computing gradients and applying them during the training loop. It may look strange if you have not used JAX-based frameworks before. Though we will not dive deep into the nuances of the framework, we will still try to present a high-level idea of how training works in JAX. To compute the loss and handle the optimizer we use Optax, a library that is a core part of the JAX ecosystem. Optax provides optimizers, learning-rate schedules and several loss functions. The one we use here is softmax cross-entropy for integer labels.
This may look simple, but it relies on an important idea. In principle, a classifier should output a probability for every class, yet our model returns raw logits. That is deliberate: computing the softmax and then taking its logarithm as two separate steps can overflow or underflow, whereas a loss that takes logits can fuse both steps into a numerically stable log-softmax. We therefore use a function that applies softmax internally during training. You can find similar functions in both TensorFlow and PyTorch.
@eqx.filter_value_and_grad
def compute_loss(model, images, labels, key):
    def single_forward(image, subkey):
        return model(image, key=subkey, inference=False)

    # The model handles a single image, so vmap it over the batch with one key per image
    keys = jr.split(key, len(images))
    logits = jax.vmap(single_forward)(images, keys)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    return jnp.mean(loss)
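The reason for feeding logits rather than probabilities into the loss can be shown with a small experiment. The sketch below is a NumPy illustration with our own function names; it demonstrates the general log-sum-exp idea and is not a copy of how Optax implements its loss.

```python
import numpy as np


def cross_entropy_from_logits(logits: np.ndarray, label: int) -> float:
    # Stable: shift by the max logit, then compute log-softmax directly
    shifted = logits - logits.max()
    log_probs = shifted - np.log(np.exp(shifted).sum())
    return -log_probs[label]


def cross_entropy_naive(logits: np.ndarray, label: int) -> float:
    # Unstable: explicit softmax followed by a separate log
    with np.errstate(over="ignore", invalid="ignore", divide="ignore"):
        probs = np.exp(logits) / np.exp(logits).sum()
        return -np.log(probs[label])


small = np.array([2.0, 1.0, 0.1])
large = np.array([1000.0, 0.0, -1000.0])

# Both versions agree on moderate logits
agree_on_small = np.isclose(
    cross_entropy_from_logits(small, 0), cross_entropy_naive(small, 0)
)

# On extreme logits the naive version overflows, the stable one does not
stable_large = cross_entropy_from_logits(large, 0)
naive_large = cross_entropy_naive(large, 0)
print(agree_on_small, stable_large, naive_large)
```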
Also, this is the standard way of computing and applying parameter updates with Optax:
@eqx.filter_jit
def train_step(model, opt_state, images, labels, optimizer, key):
    loss, grads = compute_loss(model, images, labels, key)
    # Optimizers with weight decay need the parameters; pass only the array leaves
    params = eqx.filter(model, eqx.is_array)
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_opt_state, loss
One more thing to keep in mind is that this implementation, like many JAX training loops, counts steps rather than epochs. One step processes a single batch, while one epoch processes the whole training set once. The number of steps is therefore usually much larger than the number of epochs you may be used to.
def train_model(model, optimizer, trainloader, num_steps):
    opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
    losses = []

    def infinite_dataloader():
        # Restart the loader whenever it runs out, so we can count in steps
        while True:
            yield from trainloader

    print("Starting training...")
    key = jr.PRNGKey(0)
    for step, (images, labels) in zip(range(num_steps), infinite_dataloader()):
        # CIFAR-10 has channel first data
        images = jnp.array(images.numpy().transpose(0, 2, 3, 1))  # BCHW -> BHWC
        labels = jnp.array(labels.numpy())
        if step == 0:
            print(f"Image shape: {images.shape}, range: [{images.min():.3f}, {images.max():.3f}]")
            print(f"Labels shape: {labels.shape}, unique labels: {jnp.unique(labels)}")
        # Fresh randomness for the dropout layers at every step
        key, subkey = jr.split(key)
        model, opt_state, loss = train_step(model, opt_state, images, labels, optimizer, subkey)
        losses.append(float(loss))
        if step % 2000 == 0:
            print(f"Step {step:5d}: Loss = {loss:.4f}")
    return model, losses
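To relate steps back to epochs, a quick calculation helps. In the sketch below, the training-set size is the standard 50,000 CIFAR-10 training images and 30,000 is the step count of the run reported in this post, while the batch size of 128 is purely an assumption for illustration (it also assumes the last, smaller batch of each epoch is kept).

```python
import math

train_set_size = 50_000   # CIFAR-10 training images
batch_size = 128          # assumption, not taken from the training script
num_steps = 30_000        # step count of the run reported in this post

# One epoch = enough steps to see every training image once
steps_per_epoch = math.ceil(train_set_size / batch_size)
epochs = num_steps / steps_per_epoch

print(f"{steps_per_epoch} steps per epoch")
print(f"{num_steps} steps is about {epochs:.1f} epochs")
```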
Let’s see how this performs on the CIFAR-10 dataset.
As you can see, we were able to obtain 79% accuracy on CIFAR-10 after 30k steps. This can certainly be improved further with some hyperparameter tuning, but it marks the completion of our current project.
Thank you for reading this far. Links for all related code, docs and other references are below.
References
→ Papers :
Attention Is All You Need - The Original transformers paper
An Image is Worth 16 × 16 Words : The Original ViT paper
→ Videos :
Umar Jamil’s Videos on Transformers and its PyTorch implementation.
Yannic Kilcher’s Video on Vision Transformer.
→ Docs :
Equinox’s provided implementation of ViT.
JAX Documentation
→ My Code : GitHub
Written by

Ayush Saraswat
Aspiring Computer Vision engineer, eager to delve into the world of AI/ML, cloud, Computer Vision, TinyML and other programming stuff.