The Transformer Implementation in TensorFlow.js

Mohamad Mahmood

What is the Transformer Model?

The Transformer model is a groundbreaking neural network architecture introduced in the 2017 paper "Attention Is All You Need" by Vaswani et al. It revolutionized natural language processing (NLP) and has since become the foundation for many state-of-the-art models, such as BERT, GPT, T5, and others.

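Unlike earlier sequence models such as Recurrent Neural Networks (RNNs) or Convolutional Neural Networks (CNNs), the Transformer dispenses with recurrence and convolution and relies on self-attention to relate the positions of an input sequence to one another. Every position can attend directly to every other position, and all positions are processed in parallel rather than one step at a time. As a result, the Transformer handles long-range dependencies more effectively, scales better with large datasets, and trains faster on modern hardware.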


Key Features of the Transformer

1. Self-Attention Mechanism

2. Multi-Head Attention

3. Positional Encoding

4. Encoder-Decoder Architecture

5. Feed-Forward Networks

6. Residual Connections and Layer Normalization
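Feature 6 is the one most often dropped from simplified implementations, including the demo later in this article. The following is a minimal sketch of the "Add & Norm" step. It makes several simplifying assumptions: it uses plain JavaScript arrays instead of TensorFlow.js tensors, it has no learned gain or bias, and its epsilon of 1e-6 is illustrative. The step adds a sub-layer's output back to its input, then normalizes each position's vector to zero mean and unit variance. The function names `layerNorm` and `addAndNorm` are hypothetical helpers.

```javascript
// Minimal "Add & Norm" sketch in plain JavaScript (no TensorFlow.js).
// A sequence is an array of position vectors (plain arrays).

// Normalize one vector to zero mean and unit variance.
// Real layer normalization also applies a learned gain and bias; omitted here.
function layerNorm(vec, eps = 1e-6) {
  const mean = vec.reduce((s, v) => s + v, 0) / vec.length;
  const variance = vec.reduce((s, v) => s + (v - mean) ** 2, 0) / vec.length;
  const std = Math.sqrt(variance + eps);
  return vec.map((v) => (v - mean) / std);
}

// Residual connection followed by layer normalization, per position:
// output[i] = LayerNorm(x[i] + sublayerOut[i])
function addAndNorm(x, sublayerOut) {
  return x.map((row, i) => layerNorm(row.map((v, j) => v + sublayerOut[i][j])));
}

// Example: two positions with dModel = 4.
const x = [
  [1, 2, 3, 4],
  [0, 0, 1, 1],
];
const sublayerOut = [
  [0.5, -0.5, 0.5, -0.5],
  [1, 0, 0, 1],
];
const normed = addAndNorm(x, sublayerOut);
console.log(normed);
```

Normalization is applied per position, so each row of `normed` ends up with mean 0 and variance close to 1, whatever the scale of the inputs.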


How Does the Transformer Work?

1. Input Representation

2. Encoder

3. Decoder

4. Output
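Positional encoding enters at the input representation step. Self-attention by itself ignores token order, so each token's embedding is summed with a position-dependent vector. Below is a minimal plain-JavaScript sketch of the sinusoidal encoding from the original paper. The `embeddingTable` and the helper names are hypothetical; a real model would use a learned embedding layer such as `tf.layers.embedding`.

```javascript
// Sinusoidal positional encoding from "Attention Is All You Need":
//   PE(pos, 2i)   = sin(pos / 10000^(2i / dModel))
//   PE(pos, 2i+1) = cos(pos / 10000^(2i / dModel))
function positionalEncoding(seqLen, dModel) {
  const pe = [];
  for (let pos = 0; pos < seqLen; pos++) {
    const row = new Array(dModel);
    for (let j = 0; j < dModel; j++) {
      const i = Math.floor(j / 2); // index shared by each sin/cos pair
      const angle = pos / Math.pow(10000, (2 * i) / dModel);
      row[j] = j % 2 === 0 ? Math.sin(angle) : Math.cos(angle);
    }
    pe.push(row);
  }
  return pe;
}

// Input representation: look up each token's embedding, then add the
// positional encoding for its position. The tiny table below is a
// hypothetical stand-in for a learned embedding layer.
function embedWithPositions(tokenIds, embeddingTable) {
  const dModel = embeddingTable[0].length;
  const pe = positionalEncoding(tokenIds.length, dModel);
  return tokenIds.map((id, pos) => embeddingTable[id].map((v, j) => v + pe[pos][j]));
}

const embeddingTable = [
  [0.1, 0.2, 0.3, 0.4], // token 0
  [0.5, 0.6, 0.7, 0.8], // token 1
  [0.9, 1.0, 1.1, 1.2], // token 2
];
const inputs = embedWithPositions([2, 0, 1], embeddingTable);
console.log(inputs);
```

At position 0 every angle is 0, so the encoding is [0, 1, 0, 1]. Later positions rotate each sin/cos pair at a different frequency.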


Advantages of the Transformer

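  1. Parallelization:

    • RNNs process tokens one step at a time. Transformers instead process all positions of a sequence in parallel during training, which makes efficient use of GPUs and TPUs.

  2. Long-Range Dependencies:

    • Self-attention connects every pair of positions directly, so information does not have to pass through many intermediate steps to link distant tokens.

  3. Scalability:

    • Transformers can be scaled to very large sizes (e.g., billions of parameters) and trained on massive datasets, leading to state-of-the-art performance on a wide range of tasks.

  4. Versatility:

    • The same architecture has been adapted beyond text to images, audio, and combinations of modalities.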


Applications of the Transformer

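  1. Natural Language Processing (NLP):

    • Machine translation, text summarization, question answering, and text classification.

  2. Computer Vision:

    • Image classification and other vision tasks with Vision Transformers (ViT), which treat image patches as tokens.

  3. Speech Processing:

    • Speech recognition and other audio tasks.

  4. Generative Models:

    • Text generation with decoder-only models such as GPT.

  5. Multimodal Models:

    • Models that process several modalities together, such as images and text.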


Demonstration

The following code demonstrates a simplified implementation of a Transformer model using TensorFlow.js, a JavaScript library for machine learning. Below, I will break down the code and explain its components in detail.


1. HTML Structure

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Transformer in TensorFlow.js</title>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
</head>
<body>
    <h1>Transformer Model Example</h1>
    <script>
        // JavaScript code for the Transformer model
    </script>
</body>
</html>

Key Points

  • The <html> tag defines the document structure.

  • The <head> section includes:

    • Metadata such as character encoding (UTF-8) and viewport settings for responsiveness.

    • A <script> tag to load the TensorFlow.js library from a CDN (Content Delivery Network).

  • The <body> contains:

    • A heading (<h1>) titled "Transformer Model Example".

    • A <script> block where the Transformer model is implemented.


2. Multi-Head Attention Layer

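class MultiHeadAttention {
    constructor(numHeads, dModel) {
        if (dModel % numHeads !== 0) {
            throw new Error('dModel must be divisible by numHeads');
        }
        this.numHeads = numHeads;
        this.dModel = dModel;
        this.depth = dModel / numHeads; // size of each head

        // Learned projections for queries, keys, values, and the combined heads
        this.wq = tf.layers.dense({ units: dModel });
        this.wk = tf.layers.dense({ units: dModel });
        this.wv = tf.layers.dense({ units: dModel });
        this.dense = tf.layers.dense({ units: dModel });
    }

    // [batch, seqLen, dModel] -> [batch, numHeads, seqLen, depth]
    splitHeads(x) {
        const batchSize = x.shape[0];
        return tf.reshape(x, [batchSize, -1, this.numHeads, this.depth])
                 .transpose([0, 2, 1, 3]);
    }

    call(q, k, v) {
        const batchSize = q.shape[0];

        // Linear projections; shapes stay [batch, seqLen, dModel]
        q = this.wq.apply(q);
        k = this.wk.apply(k);
        v = this.wv.apply(v);

        q = this.splitHeads(q);
        k = this.splitHeads(k);
        v = this.splitHeads(v);

        // Scaled dot-product attention for all heads at once:
        // logits and weights have shape [batch, numHeads, seqLenQ, seqLenK]
        const scale = Math.sqrt(this.depth);
        const logits = tf.matMul(q, k.transpose([0, 1, 3, 2])).div(scale);
        const weights = tf.softmax(logits); // normalizes over the last (key) axis
        let output = tf.matMul(weights, v); // [batch, numHeads, seqLenQ, depth]

        // Move heads next to depth, then merge them: [batch, seqLenQ, dModel]
        output = output.transpose([0, 2, 1, 3]).reshape([batchSize, -1, this.dModel]);
        return this.dense.apply(output);
    }
}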

Key Concepts

  1. Multi-Head Attention:

    • This layer implements the core mechanism of the Transformer model, allowing it to focus on different parts of the input sequence simultaneously.

    • It splits the input into multiple "heads," computes attention for each head, and then combines the results.

  2. Constructor:

    • numHeads: Number of attention heads.

    • dModel: Dimensionality of the input embeddings.

    • depth: Each head's dimensionality (dModel / numHeads). dModel must be divisible by numHeads; with the demo's values, 128 / 8 = 16.

    • Dense layers (wq, wk, wv) project the queries, keys, and values to dModel dimensions before they are split into heads.

    • A final dense layer (dense) combines the outputs of all heads.

  3. splitHeads:

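    • Reshapes a tensor of shape [batch, seqLen, dModel] into [batch, seqLen, numHeads, depth], then transposes it to [batch, numHeads, seqLen, depth] so that attention for every head is computed in a single batched matrix multiplication.

  4. call: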

    • Applies the query (wq), key (wk), and value (wv) transformations.

    • Splits the tensors into multiple heads.

    • Computes scaled dot-product attention:

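      • logits: Dot product of queries and keys, divided by sqrt(depth). Without this scaling, the dot products grow with the head size and push the softmax into regions with very small gradients.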

      • weights: Softmax of logits to compute attention weights.

      • output: Weighted sum of values.

    • Combines the outputs of all heads and applies the final dense layer.
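A dependency-free sketch in plain JavaScript makes the arithmetic in `call` concrete. It follows the same steps: split into heads, compute softmax(QKᵀ / √depth)·V per head, then concatenate. As a simplifying assumption, it omits the learned projections `wq`, `wk`, `wv`, and `dense` as well as the batch dimension. The helper names are illustrative, not TensorFlow.js APIs.

```javascript
// Plain-JavaScript sketch of multi-head scaled dot-product attention.
// Matrices are arrays of rows. The learned projections are omitted,
// so each head is simply a contiguous slice of the columns.

function matMul(a, b) {
  return a.map((row) => b[0].map((_, j) => row.reduce((s, v, k) => s + v * b[k][j], 0)));
}

function transpose(m) {
  return m[0].map((_, j) => m.map((row) => row[j]));
}

// Row-wise softmax; subtracting the row max keeps Math.exp from overflowing.
function softmaxRows(m) {
  return m.map((row) => {
    const max = Math.max(...row);
    const exps = row.map((v) => Math.exp(v - max));
    const sum = exps.reduce((s, v) => s + v, 0);
    return exps.map((v) => v / sum);
  });
}

// Attention(Q, K, V) = softmax(Q K^T / sqrt(depth)) V
function scaledDotProductAttention(q, k, v) {
  const depth = q[0].length;
  const logits = matMul(q, transpose(k)).map((row) => row.map((val) => val / Math.sqrt(depth)));
  const weights = softmaxRows(logits);
  return { output: matMul(weights, v), weights };
}

// [seqLen, dModel] -> numHeads slices of [seqLen, depth], like splitHeads
function splitHeads(x, numHeads) {
  const depth = x[0].length / numHeads;
  return Array.from({ length: numHeads }, (_, h) =>
    x.map((row) => row.slice(h * depth, (h + 1) * depth))
  );
}

// Concatenate per-head outputs back to [seqLen, dModel].
function combineHeads(heads) {
  return heads[0].map((_, i) => heads.flatMap((head) => head[i]));
}

function multiHeadAttention(q, k, v, numHeads) {
  const qs = splitHeads(q, numHeads);
  const ks = splitHeads(k, numHeads);
  const vs = splitHeads(v, numHeads);
  const heads = qs.map((qh, h) => scaledDotProductAttention(qh, ks[h], vs[h]).output);
  return combineHeads(heads);
}

// Self-attention over 3 positions with dModel = 4 and 2 heads (depth = 2).
const x = [
  [1, 0, 0, 1],
  [0, 1, 1, 0],
  [1, 1, 0, 0],
];
const out = multiHeadAttention(x, x, x, 2);
console.log(out.length, out[0].length); // 3 4
```

Each row of `weights` sums to 1, so every output position is a weighted average of the value vectors. This is the "weighted sum of values" described above.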


3. Transformer Model

class Transformer {
    constructor(numLayers, numHeads, dModel, dff, inputVocabSize, targetVocabSize) {
        // numLayers is accepted but unused: this simplified model stacks no layers
        this.encoderEmbedding = tf.layers.embedding({ inputDim: inputVocabSize, outputDim: dModel });
        this.decoderEmbedding = tf.layers.embedding({ inputDim: targetVocabSize, outputDim: dModel });
        this.attention = new MultiHeadAttention(numHeads, dModel);
        this.ffn = tf.layers.dense({ units: dff, activation: 'relu' });
        this.finalLayer = tf.layers.dense({ units: targetVocabSize });
    }

    call(encInput, decInput) {
        // Token IDs [batch, seqLen] -> embeddings [batch, seqLen, dModel]
        const encOutput = this.encoderEmbedding.apply(encInput);
        const decOutput = this.decoderEmbedding.apply(decInput);
        // Cross-attention: queries come from the decoder, keys and values from
        // the encoder, so the result has one vector per decoder position
        const attentionOutput = this.attention.call(decOutput, encOutput, encOutput);
        const ffnOutput = this.ffn.apply(attentionOutput);
        // Logits over the target vocabulary: [batch, decLen, targetVocabSize]
        return this.finalLayer.apply(ffnOutput);
    }
}

Key Concepts

  1. Transformer Architecture:

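    • A full Transformer consists of an encoder and a decoder, each a stack of layers built from multi-head attention and feed-forward sub-layers. This simplified class keeps only the two embedding layers, a single attention layer, and a single feed-forward layer.

  2. Constructor: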

    • numLayers: Number of layers in the encoder and decoder. It is accepted but not used here, because no layers are stacked.

    • numHeads, dModel, dff: Parameters for the multi-head attention and feed-forward network.

    • inputVocabSize, targetVocabSize: Vocabulary sizes for the input and target sequences.

    • Embedding layers (encoderEmbedding, decoderEmbedding) convert token indices into dense vectors.

    • attention: Multi-head attention layer.

    • ffn: Feed-forward network with ReLU activation.

    • finalLayer: Produces logits for the target vocabulary.

  3. call:

    • Converts input tokens (encInput, decInput) into embeddings.

    • Applies multi-head attention to combine encoder and decoder outputs.

    • Passes the result through a feed-forward network and produces final logits.
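In the original architecture, the feed-forward sub-layer consists of two linear transformations with a ReLU between them. It projects from dModel up to dff and back down to dModel, applying the same weights to every position independently. The `ffn` in the class above keeps only the first, expanding layer, and `finalLayer` then maps directly to the vocabulary. Below is a plain-JavaScript sketch of the full two-layer version. Its toy weights are hypothetical values chosen only to make the arithmetic easy to check.

```javascript
// Position-wise feed-forward network, as described in the original paper:
//   FFN(x) = max(0, x W1 + b1) W2 + b2
// The same weights are applied to every position independently.

// Multiply a row vector by an [inDim][outDim] matrix and add a bias.
function linear(vec, weights, bias) {
  return bias.map((b, j) => vec.reduce((s, v, k) => s + v * weights[k][j], b));
}

function relu(vec) {
  return vec.map((v) => Math.max(0, v));
}

// x: [seqLen][dModel]; w1: [dModel][dff]; w2: [dff][dModel]
function positionwiseFFN(x, w1, b1, w2, b2) {
  return x.map((row) => linear(relu(linear(row, w1, b1)), w2, b2));
}

// Toy sizes: dModel = 2, dff = 3 (the demo uses 128 and 512).
const w1 = [
  [1, -1, 0.5],
  [0, 1, -1],
];
const b1 = [0, 0, 0];
const w2 = [
  [1, 0],
  [0, 1],
  [1, 1],
];
const b2 = [0, 0];
const y = positionwiseFFN([[1, 2], [3, -1]], w1, b1, w2, b2);
console.log(y);
```

In the TensorFlow.js class, the equivalent change would be a second `tf.layers.dense({ units: dModel })` applied to the output of `ffn` before `finalLayer`.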


4. Running the Transformer

async function runTransformer() {
    const numLayers = 4;
    const numHeads = 8;
    const dModel = 128;
    const dff = 512;
    const inputVocabSize = 10000;
    const targetVocabSize = 10000;

    const transformer = new Transformer(numLayers, numHeads, dModel, dff, inputVocabSize, targetVocabSize);

    // Sample data with consistent lengths
    const batchSize = 64;
    const encInput = tf.randomUniform([batchSize, 38], 0, inputVocabSize, 'int32'); // (batch_size, seq_length)
    const decInput = tf.randomUniform([batchSize, 38], 0, targetVocabSize, 'int32'); // (batch_size, seq_length)

    // Forward pass
    const outputs = transformer.call(encInput, decInput);
    outputs.print(); // Outputs the result to the console
}

runTransformer();

Key Concepts

  1. Parameters:

    • Defines the architecture of the Transformer (numLayers, numHeads, dModel, etc.).

    • Specifies the vocabulary sizes and batch size.

  2. Sample Data:

    • Generates random integer token IDs for the encoder and decoder inputs (encInput, decInput) with tf.randomUniform and dtype 'int32', each of shape [64, 38].

  3. Forward Pass:

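    • Calls the transformer.call method to perform a forward pass through the model, which returns logits of shape [64, 38, 10000] (batch size × sequence length × target vocabulary size).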

    • Prints the output logits to the console.
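The printed tensor holds raw, unnormalized logits. The usual next step is not part of the original demo; the helpers below sketch it in plain JavaScript. They convert one sequence's logits into probabilities with a softmax and pick the highest-scoring token ID at each position (greedy selection). In TensorFlow.js itself, `tf.softmax(outputs)` and `tf.argMax(outputs, 2)` perform the same operations on the whole batch. The model is untrained and the inputs are random, so the IDs predicted by the demo carry no meaning.

```javascript
// Turn one position's logits into probabilities (numerically stable softmax).
function softmax(logits) {
  const max = Math.max(...logits);
  const exps = logits.map((v) => Math.exp(v - max));
  const sum = exps.reduce((s, v) => s + v, 0);
  return exps.map((v) => v / sum);
}

// Index of the largest value (first one wins on ties).
function argMax(values) {
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) best = i;
  }
  return best;
}

// logits: [seqLen][vocabSize] for one sequence in the batch.
function greedyTokens(logits) {
  return logits.map((row) => argMax(softmax(row)));
}

// Toy logits for 2 positions over a 4-token vocabulary.
const logits = [
  [0.1, 2.0, -1.0, 0.3],
  [1.5, 0.2, 0.1, 3.0],
];
console.log(greedyTokens(logits)); // [ 1, 3 ]
```

Softmax preserves ordering, so taking argMax of the raw logits gives the same IDs. The softmax is only needed when the probabilities themselves are wanted, for example for sampling or computing a loss.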


Summary

This code implements a simplified version of the Transformer model in TensorFlow.js. It includes:

  1. A multi-head attention layer for computing attention over multiple heads.

  2. A Transformer class that integrates embeddings, attention, and feed-forward networks.

  3. A test function (runTransformer) to demonstrate the model's forward pass with random input data.

Limitations

While this implementation is functional, it omits some details of a full Transformer model, such as:

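  • Positional encoding.

  • Residual connections and layer normalization.

  • Separate self-attention sub-layers in the encoder and decoder.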

  • Masking (e.g., padding masks, causal masks).

  • Stacking multiple layers.
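Of these omissions, masking is easy to add. A mask is added to the attention logits before the softmax, so blocked positions receive effectively zero weight. The minimal plain-JavaScript sketch below assumes that padding uses token ID 0 and uses -1e9 as a stand-in for negative infinity. Both are common conventions, not anything defined in the demo. In the MultiHeadAttention class above, the equivalent step would be adding a mask tensor to `logits` before `tf.softmax`.

```javascript
// Masking sketch: blocked positions get a large negative logit before the
// softmax, so their attention weight becomes 0.

const NEG_INF = -1e9; // large negative stand-in for -Infinity

// Causal (look-ahead) mask: position i may attend only to positions j <= i.
function causalMask(seqLen) {
  return Array.from({ length: seqLen }, (_, i) =>
    Array.from({ length: seqLen }, (_, j) => (j <= i ? 0 : NEG_INF))
  );
}

// Padding mask: key positions holding the padding ID are blocked for every query.
// padId = 0 is an assumed convention.
function paddingMask(tokenIds, padId = 0) {
  const row = tokenIds.map((id) => (id === padId ? NEG_INF : 0));
  return tokenIds.map(() => row.slice());
}

function softmaxRows(m) {
  return m.map((r) => {
    const max = Math.max(...r);
    const exps = r.map((v) => Math.exp(v - max));
    const sum = exps.reduce((s, v) => s + v, 0);
    return exps.map((v) => v / sum);
  });
}

// Add a mask to raw attention logits, then normalize each row.
function maskedWeights(logits, mask) {
  return softmaxRows(logits.map((r, i) => r.map((v, j) => v + mask[i][j])));
}

const logits = [
  [1, 2, 3],
  [1, 2, 3],
  [1, 2, 3],
];
const causal = maskedWeights(logits, causalMask(3));
const padded = maskedWeights(logits, paddingMask([5, 7, 0]));
console.log(causal[0]); // [ 1, 0, 0 ]
```

With the causal mask, the first position can attend only to itself. With the padding mask, no position attends to the trailing padding token.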

However, it serves as a good starting point for understanding the core mechanisms of Transformers in JavaScript.


Written by

Mohamad Mahmood

Mohamad's interest is in Programming (Mobile, Web, Database and Machine Learning). He studies at the Center For Artificial Intelligence Technology (CAIT), Universiti Kebangsaan Malaysia (UKM).