Transformer Example: Hello World with Meaningful Prediction

Mohamad Mahmood

To make the Transformer model predict something meaningful, we need to train it on a dataset. For simplicity, let's use a tiny synthetic dataset with a single example: the input is the sequence "<START> hello world <END>" and the target is its French translation, "<START> bonjour monde <END>".

Below is the code, including a simple training loop. The model learns to map the input sequence ("hello world") to the target sequence ("bonjour monde"). After training, it should predict the target tokens when given this input.


1. Vocabulary and Tokenization

const vocab = ["<PAD>", "<START>", "<END>", "hello", "world", "bonjour", "monde"];
const vocabSize = vocab.length;

function tokenize(sentence) {
    return sentence.split(" ").map(word => vocab.indexOf(word));
}

function detokenize(tokens) {
    return tokens.map(token => vocab[token]).join(" ");
}

Explanation

  • Vocabulary (vocab):

    • A predefined list of words that the model understands. It includes special tokens like <PAD> (for padding), <START> (to mark the start of a sequence), and <END> (to mark the end of a sequence).
  • Tokenization (tokenize):

    • Converts a sentence (e.g., "hello world") into a list of token IDs based on their positions in the vocabulary.

    • Example: "hello world" → [3, 4].

  • Detokenization (detokenize):

    • Converts a list of token IDs back into a human-readable sentence.

    • Example: [3, 4] → "hello world".
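
One caveat with the tokenizer above: vocab.indexOf returns -1 for any word that is not in the vocabulary, and -1 is not a valid index for an embedding lookup. The sketch below shows the round trip and one possible stricter variant. Note that tokenizeStrict is a hypothetical helper name introduced here for illustration; it is not part of the original code.

```javascript
const vocab = ["<PAD>", "<START>", "<END>", "hello", "world", "bonjour", "monde"];

function tokenize(sentence) {
    return sentence.split(" ").map(word => vocab.indexOf(word));
}

function detokenize(tokens) {
    return tokens.map(token => vocab[token]).join(" ");
}

// Hypothetical stricter variant: fail fast on out-of-vocabulary words
function tokenizeStrict(sentence) {
    return sentence.split(" ").map(word => {
        const id = vocab.indexOf(word);
        if (id === -1) {
            throw new Error(`Unknown word: ${word}`);
        }
        return id;
    });
}

const ids = tokenize("<START> hello world <END>");
console.log(ids);                      // [1, 3, 4, 2]
console.log(detokenize(ids));          // "<START> hello world <END>"
console.log(tokenize("hello there"));  // [3, -1] ("there" is not in the vocabulary)
```

For this single-sentence demo the plain tokenize function is enough, since every word is known in advance.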


2. Padding Function

function padSequence(sequence, maxLength, padValue = 0) {
    const padded = [...sequence];
    while (padded.length < maxLength) {
        padded.push(padValue); // Add padding tokens
    }
    return padded.slice(0, maxLength); // Truncate if necessary
}

Explanation

  • Ensures all sequences have the same length (maxLength) by adding padding tokens (<PAD>, represented as 0).

  • Example: If maxLength = 5 and the input sequence is [3, 4], the output will be [3, 4, 0, 0, 0].
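
The function also truncates sequences that are longer than maxLength, which the example above does not show. A small sketch of both cases, using made-up token lists:

```javascript
function padSequence(sequence, maxLength, padValue = 0) {
    const padded = [...sequence];
    while (padded.length < maxLength) {
        padded.push(padValue);
    }
    return padded.slice(0, maxLength);
}

// Shorter than maxLength: padded on the right with 0 (<PAD>)
const shortSeq = padSequence([1, 3, 4, 2], 5);
console.log(shortSeq); // [1, 3, 4, 2, 0]

// Longer than maxLength: cut off on the right, so the <END> token (2) is lost
const longSeq = padSequence([1, 3, 4, 3, 4, 2], 5);
console.log(longSeq);  // [1, 3, 4, 3, 4]
```

Because truncation silently drops the <END> token, it is worth choosing maxLength large enough for the longest sentence in the data.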


3. Positional Encoding

function getAngles(pos, i, dModel) {
    return pos / Math.pow(10000, (2 * Math.floor(i / 2)) / dModel); // floor(i / 2) pairs up dimensions
}

function positionalEncoding(position, dModel) {
    const angleRads = tf.tidy(() => {
        const angles = [];
        for (let pos = 0; pos < position; pos++) {
            for (let i = 0; i < dModel; i++) {
                angles.push(getAngles(pos, i, dModel));
            }
        }
        return tf.tensor(angles).reshape([position, dModel]);
    });

    const encoding = tf.tidy(() => {
        const sinEncoding = angleRads.slice([0, 0], [-1, dModel / 2]).sin();
        const cosEncoding = angleRads.slice([0, dModel / 2], [-1, dModel / 2]).cos();
        return tf.concat([sinEncoding, cosEncoding], 1);
    });

    angleRads.dispose();
    return encoding;
}

Explanation

  • Purpose:

    • Adds information about the position of tokens in the sequence to the embeddings. This is crucial because the Transformer does not inherently understand the order of tokens.
  • Implementation:

    • Computes sine and cosine values for each position and dimension using a formula described in the original Transformer paper.

    • Combines these values into a positional encoding tensor.
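
To make the layout of the encoding easier to see, here is a plain-JavaScript sketch (no tensors) that places the sine values in the first half of each row and the cosine values in the second half. The function name positionalEncodingArray is hypothetical, and the frequency schedule 1 / 10000^(2k / dModel) is an assumption taken from the paper, so the numbers may not match the tensor version above column for column.

```javascript
// Returns an array of shape [maxPosition][dModel]; dModel is assumed to be even
function positionalEncodingArray(maxPosition, dModel) {
    const half = dModel / 2;
    const rows = [];
    for (let pos = 0; pos < maxPosition; pos++) {
        const sinPart = [];
        const cosPart = [];
        for (let k = 0; k < half; k++) {
            // Assumed schedule: frequency k is 1 / 10000^(2k / dModel)
            const angle = pos / Math.pow(10000, (2 * k) / dModel);
            sinPart.push(Math.sin(angle));
            cosPart.push(Math.cos(angle));
        }
        // First half of the row holds sines, second half holds cosines
        rows.push([...sinPart, ...cosPart]);
    }
    return rows;
}

const pe = positionalEncodingArray(5, 4);
console.log(pe[0]); // [0, 0, 1, 1] because sin(0) = 0 and cos(0) = 1
console.log(pe[1]); // [sin(1), sin(0.01), cos(1), cos(0.01)]
```

Every position gets a distinct row, which is what lets the attention layers tell identical tokens at different positions apart.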


4. Masking

function createPaddingMask(seq) {
    // 1 marks padding. Shape [batch, 1, 1, seqLen] broadcasts over heads and query positions.
    return seq.equal(0).cast('float32').expandDims(1).expandDims(1);
}

Explanation

  • Purpose:

    • Masks out padding tokens (<PAD>) so they do not affect attention computations.
  • How It Works:

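    • Creates a binary mask that is 1 (true) at padding positions (token ID 0) and 0 (false) elsewhere. The attention layer multiplies the mask by -1e9 and adds it to the attention logits, which is intended to drive the attention weight on padding towards zero after softmax.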
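
The effect of the mask on a single row of attention scores can be sketched in plain JavaScript. The logits below are made-up numbers for one query position, and the softmax helper is written out here only so the example is self-contained.

```javascript
// Numerically stable softmax over a plain array
function softmax(values) {
    const max = Math.max(...values);
    const exps = values.map(v => Math.exp(v - max));
    const sum = exps.reduce((a, b) => a + b, 0);
    return exps.map(e => e / sum);
}

const tokens = [1, 3, 4, 2, 0];                     // last token is <PAD>
const mask = tokens.map(t => (t === 0 ? 1 : 0));    // [0, 0, 0, 0, 1]

// Made-up attention scores of one query against the five key positions
const logits = [0.5, 1.0, 0.2, 0.1, 0.9];

// Add a large negative number wherever the mask is 1
const maskedLogits = logits.map((l, j) => l + mask[j] * -1e9);
const weights = softmax(maskedLogits);

console.log(weights); // the last weight is effectively 0; the rest sum to 1
```

Without the mask, the padding position would have received a sizeable share of the attention, since its raw score (0.9) is the second highest.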

5. Multi-Head Attention

class MultiHeadAttention {
    constructor(numHeads, dModel) {
        this.numHeads = numHeads;
        this.dModel = dModel;
        this.depth = dModel / numHeads;

        this.wq = tf.layers.dense({ units: dModel });
        this.wk = tf.layers.dense({ units: dModel });
        this.wv = tf.layers.dense({ units: dModel });
        this.dense = tf.layers.dense({ units: dModel });
    }

    splitHeads(x) {
        const batchSize = x.shape[0];
        return tf.reshape(x, [batchSize, -1, this.numHeads, this.depth])
                 .transpose([0, 2, 1, 3]);
    }

    call(q, k, v, mask = null) {
        const batchSize = q.shape[0];

        q = this.wq.apply(q);
        k = this.wk.apply(k);
        v = this.wv.apply(v);

        q = this.splitHeads(q);
        k = this.splitHeads(k);
        v = this.splitHeads(v);

        let logits = tf.matMul(q, k.transpose([0, 1, 3, 2])).div(Math.sqrt(this.depth));
        if (mask) {
            logits = logits.add(mask.mul(-1e9)); // Apply mask
        }
        const weights = tf.softmax(logits);
        let output = tf.matMul(weights, v);

        output = output.transpose([0, 2, 1, 3]).reshape([batchSize, -1, this.dModel]);
        return this.dense.apply(output);
    }
}

Explanation

  • Purpose:

    • Implements multi-head attention, which allows the model to focus on different parts of the input sequence simultaneously.
  • Key Steps:

    • Linear transformations (wq, wk, wv) project the inputs into query, key, and value vectors.

    • Splits the projections into multiple heads.

    • Computes attention scores (scaled dot-product attention) and applies softmax to obtain attention weights.

    • Combines the results from all heads into a single output.
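
The core computation inside each head, scaled dot-product attention, can be sketched for a single head with plain arrays. This is a simplified illustration under stated assumptions: no learned projections, no masking, and made-up Q, K and V matrices with one row per token.

```javascript
function softmax(values) {
    const max = Math.max(...values);
    const exps = values.map(v => Math.exp(v - max));
    const sum = exps.reduce((a, b) => a + b, 0);
    return exps.map(e => e / sum);
}

// Q, K, V: arrays of shape [seqLen][depth]
function scaledDotProductAttention(Q, K, V) {
    const depth = K[0].length;
    return Q.map(q => {
        // Dot product of this query with every key, scaled by sqrt(depth)
        const scores = K.map(k =>
            k.reduce((sum, kj, j) => sum + kj * q[j], 0) / Math.sqrt(depth)
        );
        const weights = softmax(scores);
        // Weighted sum of the value vectors
        return V[0].map((_, d) =>
            weights.reduce((sum, w, i) => sum + w * V[i][d], 0)
        );
    });
}

// Made-up example: two tokens, depth 2
const Q = [[1, 0], [0, 1]];
const K = [[1, 0], [0, 1]];
const V = [[10, 0], [0, 10]];

const attended = scaledDotProductAttention(Q, K, V);
console.log(attended); // each row leans towards the value whose key matches the query
```

In the class above, the same computation runs in parallel for numHeads heads, each working on a depth-sized slice of the projected vectors.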


6. Encoder Layer

class EncoderLayer {
    constructor(dModel, numHeads, dff) {
        this.mha = new MultiHeadAttention(numHeads, dModel);

        this.ffn = tf.sequential({
            layers: [
                tf.layers.dense({ units: dff, activation: 'relu', inputShape: [dModel] }),
                tf.layers.dense({ units: dModel })
            ]
        });

        this.layernorm1 = tf.layers.layerNormalization({ epsilon: 1e-6 });
        this.layernorm2 = tf.layers.layerNormalization({ epsilon: 1e-6 });
    }

    call(x, mask = null) {
        const attnOutput = this.mha.call(x, x, x, mask);
        const out1 = this.layernorm1.apply(x.add(attnOutput));

        const ffnOutput = this.ffn.apply(out1);
        const out2 = this.layernorm2.apply(out1.add(ffnOutput));
        return out2;
    }
}

Explanation

  • Purpose:

    • Processes the input through multi-head attention and a feed-forward network.
  • Key Components:

    • Multi-Head Attention (mha): Captures relationships between tokens.

    • Feed-Forward Network (ffn): Applies non-linear transformations to the attention output.

    • Layer Normalization (layernorm1, layernorm2): Stabilizes training by normalizing activations.

    • Residual Connections: Adds the input to the output of each sub-layer to preserve information.
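
The "add and normalize" step used twice in the encoder layer can be sketched on a single vector. This plain-JavaScript version omits the learnable scale and offset that a layer normalization layer normally has, and the sub-layer output is a made-up array.

```javascript
// Normalize a vector to zero mean and (approximately) unit variance
function layerNorm(vector, epsilon = 1e-6) {
    const mean = vector.reduce((a, b) => a + b, 0) / vector.length;
    const variance =
        vector.reduce((a, b) => a + (b - mean) ** 2, 0) / vector.length;
    return vector.map(v => (v - mean) / Math.sqrt(variance + epsilon));
}

// Residual connection followed by layer normalization
function addAndNorm(x, sublayerOutput) {
    const residual = x.map((v, i) => v + sublayerOutput[i]);
    return layerNorm(residual);
}

const x = [1, 2, 3, 4];
const sublayerOutput = [0.5, -0.5, 0.5, -0.5]; // made-up attention or FFN output

const normalized = addAndNorm(x, sublayerOutput);
console.log(normalized); // roughly [-1, -1, 1, 1]
```

The residual sum here is [1.5, 1.5, 3.5, 3.5], which has mean 2.5 and variance 1, so normalization maps it to values very close to -1 and 1.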


7. Transformer Model

class Transformer {
    constructor(numLayers, numHeads, dModel, dff, inputVocabSize, maxSeqLength) {
        this.numLayers = numLayers;
        this.dModel = dModel;

        this.embedding = tf.layers.embedding({ inputDim: inputVocabSize, outputDim: dModel });
        this.positionalEncoding = positionalEncoding(maxSeqLength, dModel);

        this.encoderLayers = Array.from({ length: numLayers }, () => new EncoderLayer(dModel, numHeads, dff));

        this.finalLayer = tf.layers.dense({ units: inputVocabSize });
    }

    call(input) {
        const paddingMask = createPaddingMask(input);

        let x = this.embedding.apply(input);
        x = x.add(this.positionalEncoding.slice([0, 0], [input.shape[1], this.dModel]));

        for (const layer of this.encoderLayers) {
            x = layer.call(x, paddingMask);
        }

        return this.finalLayer.apply(x);
    }
}

Explanation

  • Purpose:

    • Implements an encoder-only Transformer: an embedding layer, positional encoding, stacked encoder layers, and a final dense layer. There is no decoder, so the model predicts one target token per input position.
  • Key Steps:

    • Embedding: Converts token IDs into dense vectors.

    • Positional Encoding: Adds positional information to the embeddings.

    • Encoder Layers: Processes the input through multiple layers of multi-head attention and feed-forward networks.

    • Final Dense Layer: Produces logits for each token in the vocabulary.
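
To illustrate what "logits for each token in the vocabulary" means in practice, the sketch below turns a made-up logits matrix (one row per position, one column per vocabulary entry) into a sentence by picking the highest-scoring entry in each row. The argMax helper is a hypothetical plain-array stand-in for the tensor operation used later in the article.

```javascript
const vocab = ["<PAD>", "<START>", "<END>", "hello", "world", "bonjour", "monde"];

// Index of the largest value in a plain array
function argMax(values) {
    let best = 0;
    for (let i = 1; i < values.length; i++) {
        if (values[i] > values[best]) {
            best = i;
        }
    }
    return best;
}

// Made-up logits for four positions over the seven vocabulary entries
const logits = [
    [0.1, 4.0, 0.2, 0.0, 0.1, 0.3, 0.2], // highest score at index 1
    [0.0, 0.1, 0.2, 0.3, 0.1, 3.5, 0.4], // highest score at index 5
    [0.2, 0.0, 0.1, 0.2, 0.3, 0.4, 3.9], // highest score at index 6
    [0.1, 0.2, 4.2, 0.0, 0.3, 0.1, 0.2], // highest score at index 2
];

const predictedIds = logits.map(argMax);
const sentence = predictedIds.map(id => vocab[id]).join(" ");

console.log(predictedIds); // [1, 5, 6, 2]
console.log(sentence);     // "<START> bonjour monde <END>"
```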


8. Training Loop

async function trainAndRunTransformer() {
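    // Hyperparameters (same values as in the full code below)
    const numLayers = 2;
    const numHeads = 4;
    const dModel = 16;
    const dff = 64;
    const maxSeqLength = 5;

    const transformer = new Transformer(numLayers, numHeads, dModel, dff, vocabSize, maxSeqLength);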

    // Training data
    const inputSentence = "<START> hello world <END>";
    const targetSentence = "<START> bonjour monde <END>";

    const inputTokens = tokenize(inputSentence);
    const targetTokens = tokenize(targetSentence);

    const paddedInputTokens = padSequence(inputTokens, maxSeqLength, vocab.indexOf("<PAD>"));
    const paddedTargetTokens = padSequence(targetTokens, maxSeqLength, vocab.indexOf("<PAD>"));

    const inputTensor = tf.tensor2d([paddedInputTokens], [1, maxSeqLength]);
    const targetTensor = tf.tensor2d([paddedTargetTokens], [1, maxSeqLength]);

    const targetTensorInt32 = targetTensor.cast('int32');

    function loss(labels, preds) {
        return tf.losses.softmaxCrossEntropy(labels, preds);
    }

    const learningRate = 0.001;
    const optimizer = tf.train.adam(learningRate);

    const epochs = 100;
    for (let epoch = 0; epoch < epochs; epoch++) {
        const currentLoss = tf.tidy(() => {
            const preds = transformer.call(inputTensor);
            const oneHotLabels = tf.oneHot(targetTensorInt32, vocabSize);
            const lossValue = loss(oneHotLabels, preds);
            return lossValue.asScalar();
        });

        optimizer.minimize(() => {
            const preds = transformer.call(inputTensor);
            const oneHotLabels = tf.oneHot(targetTensorInt32, vocabSize);
            return loss(oneHotLabels, preds);
        }); // returnCost defaults to false, so minimize disposes its loss tensor

        if (epoch % 10 === 0) {
            console.log(`Epoch ${epoch}: Loss = ${currentLoss.dataSync()[0].toFixed(4)}`);
        }

        currentLoss.dispose();
    }

    // Inference after training
    const outputs = transformer.call(inputTensor);
    const predictedTokenIDs = outputs.argMax(-1);
    const predictedTokens = await predictedTokenIDs.array();

    const outputDiv = document.getElementById('output');
    outputDiv.textContent = `Input Sentence: "${inputSentence}"\n\nPredicted Tokens: ${detokenize(predictedTokens[0])}`;
}

Explanation

  • Training Data:

    • Defines input-output pairs for training (e.g., "hello world" → "bonjour monde").
  • Loss Function:

    • Uses cross-entropy loss to measure the difference between predicted and target tokens.
  • Optimizer:

    • Updates the model's weights using the Adam optimizer to minimize the loss.
  • Training Loop:

    • Iteratively computes predictions, calculates the loss, and updates the model's weights.

    • Logs the loss every 10 epochs to monitor progress.

  • Inference:

    • After training, the model generates predictions for the input sequence and displays the decoded output.
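
To give a feel for the loss values logged during training, here is a plain-JavaScript sketch of softmax cross-entropy for a single position. It is an illustration only: the logits are made up, and the library call used above additionally reduces the per-position losses to a single scalar.

```javascript
// Cross-entropy between softmax(logits) and a one-hot target at index targetId
function softmaxCrossEntropy(logits, targetId) {
    const max = Math.max(...logits);
    const logSumExp =
        max + Math.log(logits.reduce((s, l) => s + Math.exp(l - max), 0));
    return logSumExp - logits[targetId];
}

const vocabSize = 7;
const targetId = 5; // "bonjour"

// All logits equal: the model spreads probability evenly over the vocabulary
const uniformLogits = new Array(vocabSize).fill(0);
const uniformLoss = softmaxCrossEntropy(uniformLogits, targetId);

// One large logit at the target index: the model is confident and correct
const confidentLogits = [0, 0, 0, 0, 0, 8, 0];
const confidentLoss = softmaxCrossEntropy(confidentLogits, targetId);

console.log(uniformLoss.toFixed(4));   // "1.9459", which is ln(7)
console.log(confidentLoss.toFixed(4)); // "0.0020"
```

A loss near ln(7) ≈ 1.95 therefore corresponds to guessing uniformly among the seven vocabulary entries, and the loss should fall well below that as training progresses.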

Conclusion

This implementation demonstrates how to build and train a small encoder-only Transformer in TensorFlow.js. The model learns to map the input sequence to the target sequence by minimizing the cross-entropy loss. Wrapping the forward pass in tf.tidy and disposing of the loss tensor keeps memory use in check during training, and the modular design makes it easy to extend or modify the model.

Full code

<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Transformer Example: Hello World</title>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <style>
        body {
            font-family: Arial, sans-serif;
            margin: 20px;
        }
        #output {
            margin-top: 20px;
            padding: 10px;
            border: 1px solid #ccc;
            background-color: #f9f9f9;
            white-space: pre-wrap;
        }
    </style>
</head>
<body>
    <h1>Transformer Example: Hello World</h1>
    <p>Output will appear below:</p>
    <div id="output">Loading...</div>

    <script>
        // Vocabulary for tokenization
        const vocab = ["<PAD>", "<START>", "<END>", "hello", "world", "bonjour", "monde"];
        const vocabSize = vocab.length;

        // Tokenizer functions
        function tokenize(sentence) {
            return sentence.split(" ").map(word => vocab.indexOf(word));
        }

        function detokenize(tokens) {
            return tokens.map(token => vocab[token]).join(" ");
        }

        // Padding function
        function padSequence(sequence, maxLength, padValue = 0) {
            const padded = [...sequence];
            while (padded.length < maxLength) {
                padded.push(padValue); // Add padding tokens
            }
            return padded.slice(0, maxLength); // Truncate if necessary
        }

        // Helper Functions
        function getAngles(pos, i, dModel) {
            return pos / Math.pow(10000, (2 * Math.floor(i / 2)) / dModel); // floor(i / 2) pairs up dimensions
        }

        function positionalEncoding(position, dModel) {
            const angleRads = tf.tidy(() => {
                const angles = [];
                for (let pos = 0; pos < position; pos++) {
                    for (let i = 0; i < dModel; i++) {
                        angles.push(getAngles(pos, i, dModel));
                    }
                }
                return tf.tensor(angles).reshape([position, dModel]);
            });

            const encoding = tf.tidy(() => {
                const sinEncoding = angleRads.slice([0, 0], [-1, dModel / 2]).sin();
                const cosEncoding = angleRads.slice([0, dModel / 2], [-1, dModel / 2]).cos();
                return tf.concat([sinEncoding, cosEncoding], 1);
            });

            angleRads.dispose();
            return encoding;
        }

        function createPaddingMask(seq) {
            // 1 marks padding. Shape [batch, 1, 1, seqLen] broadcasts over heads and query positions.
            return seq.equal(0).cast('float32').expandDims(1).expandDims(1);
        }

        // Multi-Head Attention Layer
        class MultiHeadAttention {
            constructor(numHeads, dModel) {
                this.numHeads = numHeads;
                this.dModel = dModel;
                this.depth = dModel / numHeads;

                this.wq = tf.layers.dense({ units: dModel });
                this.wk = tf.layers.dense({ units: dModel });
                this.wv = tf.layers.dense({ units: dModel });
                this.dense = tf.layers.dense({ units: dModel });
            }

            splitHeads(x) {
                const batchSize = x.shape[0];
                return tf.reshape(x, [batchSize, -1, this.numHeads, this.depth])
                         .transpose([0, 2, 1, 3]);
            }

            call(q, k, v, mask = null) {
                const batchSize = q.shape[0];

                q = this.wq.apply(q);
                k = this.wk.apply(k);
                v = this.wv.apply(v);

                q = this.splitHeads(q);
                k = this.splitHeads(k);
                v = this.splitHeads(v);

                let logits = tf.matMul(q, k.transpose([0, 1, 3, 2])).div(Math.sqrt(this.depth));
                if (mask) {
                    logits = logits.add(mask.mul(-1e9)); // Apply mask
                }
                const weights = tf.softmax(logits);
                let output = tf.matMul(weights, v);

                output = output.transpose([0, 2, 1, 3]).reshape([batchSize, -1, this.dModel]);
                return this.dense.apply(output);
            }
        }

        // Encoder Layer
        class EncoderLayer {
            constructor(dModel, numHeads, dff) {
                this.mha = new MultiHeadAttention(numHeads, dModel);

                this.ffn = tf.sequential({
                    layers: [
                        tf.layers.dense({ units: dff, activation: 'relu', inputShape: [dModel] }),
                        tf.layers.dense({ units: dModel })
                    ]
                });

                this.layernorm1 = tf.layers.layerNormalization({ epsilon: 1e-6 });
                this.layernorm2 = tf.layers.layerNormalization({ epsilon: 1e-6 });
            }

            call(x, mask = null) {
                const attnOutput = this.mha.call(x, x, x, mask);
                const out1 = this.layernorm1.apply(x.add(attnOutput));

                const ffnOutput = this.ffn.apply(out1);
                const out2 = this.layernorm2.apply(out1.add(ffnOutput));
                return out2;
            }
        }

        // Transformer Model
        class Transformer {
            constructor(numLayers, numHeads, dModel, dff, inputVocabSize, maxSeqLength) {
                this.numLayers = numLayers;
                this.dModel = dModel;

                this.embedding = tf.layers.embedding({ inputDim: inputVocabSize, outputDim: dModel });
                this.positionalEncoding = positionalEncoding(maxSeqLength, dModel);

                this.encoderLayers = Array.from({ length: numLayers }, () => new EncoderLayer(dModel, numHeads, dff));

                this.finalLayer = tf.layers.dense({ units: inputVocabSize });
            }

            call(input) {
                const paddingMask = createPaddingMask(input);

                let x = this.embedding.apply(input);
                x = x.add(this.positionalEncoding.slice([0, 0], [input.shape[1], this.dModel]));

                for (const layer of this.encoderLayers) {
                    x = layer.call(x, paddingMask);
                }

                return this.finalLayer.apply(x);
            }
        }

        async function trainAndRunTransformer() {
            const numLayers = 2;
            const numHeads = 4;
            const dModel = 16;
            const dff = 64;
            const maxSeqLength = 5;

            const transformer = new Transformer(numLayers, numHeads, dModel, dff, vocabSize, maxSeqLength);

            // Training data
            const inputSentence = "<START> hello world <END>";
            const targetSentence = "<START> bonjour monde <END>";

            const inputTokens = tokenize(inputSentence);
            const targetTokens = tokenize(targetSentence);

            const paddedInputTokens = padSequence(inputTokens, maxSeqLength, vocab.indexOf("<PAD>"));
            const paddedTargetTokens = padSequence(targetTokens, maxSeqLength, vocab.indexOf("<PAD>"));

            const inputTensor = tf.tensor2d([paddedInputTokens], [1, maxSeqLength]);
            const targetTensor = tf.tensor2d([paddedTargetTokens], [1, maxSeqLength]);

            // Ensure targetTensor is int32 for oneHot
            const targetTensorInt32 = targetTensor.cast('int32');

            // Loss function
            function loss(labels, preds) {
                return tf.losses.softmaxCrossEntropy(labels, preds);
            }

            // Optimizer
            const learningRate = 0.001;
            const optimizer = tf.train.adam(learningRate);

            // Training loop
            const epochs = 100;
            for (let epoch = 0; epoch < epochs; epoch++) {
                // Use tf.tidy to manage memory
                const currentLoss = tf.tidy(() => {
                    const preds = transformer.call(inputTensor);
                    const oneHotLabels = tf.oneHot(targetTensorInt32, vocabSize);
                    const lossValue = loss(oneHotLabels, preds);
                    return lossValue.asScalar(); // Ensure loss is a scalar tensor
                });

                // Minimize the loss
                optimizer.minimize(() => {
                    const preds = transformer.call(inputTensor);
                    const oneHotLabels = tf.oneHot(targetTensorInt32, vocabSize);
                    return loss(oneHotLabels, preds);
                }); // returnCost defaults to false, so minimize disposes its loss tensor

                // Log the loss every 10 epochs
                if (epoch % 10 === 0) {
                    console.log(`Epoch ${epoch}: Loss = ${currentLoss.dataSync()[0].toFixed(4)}`);
                }

                // Dispose of the loss tensor to free memory
                currentLoss.dispose();
            }

            // Inference after training
            const outputs = transformer.call(inputTensor);
            const predictedTokenIDs = outputs.argMax(-1); // Get the most likely token at each position
            const predictedTokens = await predictedTokenIDs.array();

            // Decode tokens back to text
            const outputDiv = document.getElementById('output');
            outputDiv.textContent = `Input Sentence: "${inputSentence}"\n\nPredicted Tokens: ${detokenize(predictedTokens[0])}`;
        }

        trainAndRunTransformer();
    </script>
</body>
</html>


Written by

Mohamad Mahmood

Mohamad's interest is in Programming (Mobile, Web, Database and Machine Learning). He studies at the Center For Artificial Intelligence Technology (CAIT), Universiti Kebangsaan Malaysia (UKM).