Transformer Example: Hello World

Mohamad Mahmood

To make the Transformer model more concrete, let's walk through a practical scenario: representing and processing a simple sentence like "hello world" in TensorFlow.js. This involves tokenizing the input text, embedding the tokens as numerical vectors, and passing them through the Transformer model.
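As a warm-up, here is the tokenization step in isolation, a minimal sketch using the same five-word vocabulary the full listing defines below (it can be run on its own in a browser console or Node):

const vocab = ["<PAD>", "<START>", "<END>", "hello", "world"];

function tokenize(sentence) {
    return sentence.split(" ").map(word => vocab.indexOf(word));
}

console.log(tokenize("<START> hello world <END>")); // [1, 3, 4, 2]

Each word becomes its index in the vocabulary; the special <START>, <END>, and <PAD> tokens mark where a sequence begins, ends, and is padded.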

Below is an implementation that demonstrates how to encode a sentence like "hello world," process it through the Transformer, and decode the output back into human-readable text.

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Transformer Example: Hello World</title>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <style>
        body {
            font-family: Arial, sans-serif;
            margin: 20px;
        }
        #output {
            margin-top: 20px;
            padding: 10px;
            border: 1px solid #ccc;
            background-color: #f9f9f9;
            white-space: pre-wrap;
        }
    </style>
</head>
<body>
    <h1>Transformer Example: Hello World</h1>
    <p>Output will appear below:</p>
    <div id="output">Loading...</div>
    <p>
        Note: The output comes from an untrained Transformer. The model's weights
        are initialized randomly, so its predictions are meaningless until it is
        trained on a dataset.
    </p>
    <script>
        // Vocabulary for tokenization
        const vocab = ["<PAD>", "<START>", "<END>", "hello", "world"];
        const vocabSize = vocab.length;

        // Tokenizer functions (words missing from the vocab fall back to <PAD>)
        function tokenize(sentence) {
            return sentence.split(" ").map(word => {
                const id = vocab.indexOf(word);
                return id >= 0 ? id : 0; // unknown word -> <PAD>
            });
        }

        function detokenize(tokens) {
            return tokens.map(token => vocab[token]).join(" ");
        }
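
        // Example round trip with the vocab above:
        //   tokenize("<START> hello world <END>")  -> [1, 3, 4, 2]
        //   detokenize([1, 3, 4, 2])               -> "<START> hello world <END>"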

        // Padding function
        function padSequence(sequence, maxLength, padValue = 0) {
            const padded = [...sequence];
            while (padded.length < maxLength) {
                padded.push(padValue); // Add padding tokens
            }
            return padded.slice(0, maxLength); // Truncate if necessary
        }
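
        // Example: padSequence([1, 3, 4, 2], 5) -> [1, 3, 4, 2, 0]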

        // Helper Functions
        function getAngles(pos, i, dModel) {
            // Frequency for dimension pair i, as in "Attention Is All You Need":
            // pos / 10000^(2i / dModel). (Note: the common Python form
            // 2 * (i // 2) would be a no-op with JavaScript's float division.)
            return pos / Math.pow(10000, (2 * i) / dModel);
        }

        function positionalEncoding(position, dModel) {
            // Compute one angle per (position, frequency) pair, then apply sine
            // to the first half of the channels and cosine to the second half
            // (a common concatenated variant of the interleaved encoding)
            return tf.tidy(() => {
                const angles = [];
                for (let pos = 0; pos < position; pos++) {
                    for (let i = 0; i < dModel / 2; i++) {
                        angles.push(getAngles(pos, i, dModel));
                    }
                }
                const angleRads = tf.tensor(angles).reshape([position, dModel / 2]);
                return tf.concat([angleRads.sin(), angleRads.cos()], 1);
            });
        }
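
        // Example: positionalEncoding(5, 16) returns a [5, 16] tensor; it is
        // added to the [batch, 5, 16] embeddings later via broadcasting.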

        function createPaddingMask(seq) {
            // 1 marks padding positions; shape [batch, 1, 1, seqLength] so the
            // mask broadcasts across heads and query positions in the logits
            return seq.equal(0).cast('float32').expandDims(1).expandDims(1);
        }

        function createCausalMask(seqLength) {
            // 1 marks future positions to hide, matching the padding-mask
            // convention above (masked positions receive a large negative logit)
            const lower = tf.linalg.bandPart(tf.ones([seqLength, seqLength]), -1, 0);
            return tf.ones([seqLength, seqLength]).sub(lower)
                     .expandDims(0).expandDims(1); // Add batch and head dimensions
        }
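
        // Example masks for the padded input [[1, 3, 4, 2, 0]]:
        //   createPaddingMask -> shape [1, 1, 1, 5] holding [0, 0, 0, 0, 1],
        //   i.e. only the trailing <PAD> position is masked.
        //   createCausalMask(5) -> shape [1, 1, 5, 5], hiding future positions
        //   (included for completeness; this encoder-only demo never calls it).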

        // Multi-Head Attention Layer
        class MultiHeadAttention {
            constructor(numHeads, dModel) {
                this.numHeads = numHeads;
                this.dModel = dModel;
                this.depth = dModel / numHeads;

                this.wq = tf.layers.dense({ units: dModel });
                this.wk = tf.layers.dense({ units: dModel });
                this.wv = tf.layers.dense({ units: dModel });
                this.dense = tf.layers.dense({ units: dModel });
            }

            splitHeads(x) {
                const batchSize = x.shape[0];
                return tf.reshape(x, [batchSize, -1, this.numHeads, this.depth])
                         .transpose([0, 2, 1, 3]);
            }

            call(q, k, v, mask = null) {
                const batchSize = q.shape[0];

                q = this.wq.apply(q);
                k = this.wk.apply(k);
                v = this.wv.apply(v);

                q = this.splitHeads(q);
                k = this.splitHeads(k);
                v = this.splitHeads(v);

                let logits = tf.matMul(q, k.transpose([0, 1, 3, 2])).div(Math.sqrt(this.depth));
                if (mask) {
                    logits = logits.add(mask.mul(-1e9)); // Apply mask
                }
                const weights = tf.softmax(logits);
                let output = tf.matMul(weights, v);

                output = output.transpose([0, 2, 1, 3]).reshape([batchSize, -1, this.dModel]);
                return this.dense.apply(output);
            }
        }
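
        // Shape walk-through for one call with this demo's settings
        // (batch=1, seq=5, dModel=16, numHeads=4, depth=4):
        //   q, k, v after the dense projections: [1, 5, 16]
        //   after splitHeads:                    [1, 4, 5, 4]
        //   logits = matMul(q, kT) / sqrt(4):    [1, 4, 5, 5]
        //   weights applied to v:                [1, 4, 5, 4]
        //   after merging heads + final dense:   [1, 5, 16]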

        // Encoder Layer
        class EncoderLayer {
            constructor(dModel, numHeads, dff) {
                this.mha = new MultiHeadAttention(numHeads, dModel);

                // Position-wise feed-forward network, built as two standalone
                // dense layers so it applies cleanly to [batch, seq, dModel]
                // inputs (matching the projections in MultiHeadAttention)
                this.ffn1 = tf.layers.dense({ units: dff, activation: 'relu' });
                this.ffn2 = tf.layers.dense({ units: dModel });

                this.layernorm1 = tf.layers.layerNormalization({ epsilon: 1e-6 });
                this.layernorm2 = tf.layers.layerNormalization({ epsilon: 1e-6 });
            }

            call(x, mask = null) {
                const attnOutput = this.mha.call(x, x, x, mask);
                const out1 = this.layernorm1.apply(x.add(attnOutput));

                const ffnOutput = this.ffn2.apply(this.ffn1.apply(out1));
                const out2 = this.layernorm2.apply(out1.add(ffnOutput));
                return out2;
            }
        }
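
        // Each EncoderLayer is shape-preserving ([batch, seq, dModel] in and
        // out), which is what lets the Transformer below stack numLayers of them.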

        // Transformer Model
        class Transformer {
            constructor(numLayers, numHeads, dModel, dff, inputVocabSize, maxSeqLength) {
                this.numLayers = numLayers;
                this.dModel = dModel;

                this.embedding = tf.layers.embedding({ inputDim: inputVocabSize, outputDim: dModel });
                this.positionalEncoding = positionalEncoding(maxSeqLength, dModel);

                this.encoderLayers = Array.from({ length: numLayers }, () => new EncoderLayer(dModel, numHeads, dff));

                this.finalLayer = tf.layers.dense({ units: inputVocabSize });
            }

            call(input) {
                const paddingMask = createPaddingMask(input);

                let x = this.embedding.apply(input);
                x = x.add(this.positionalEncoding.slice([0, 0], [input.shape[1], this.dModel]));

                for (const layer of this.encoderLayers) {
                    x = layer.call(x, paddingMask);
                }

                return this.finalLayer.apply(x);
            }
        }
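
        // End-to-end: a [1, 5] token tensor is embedded to [1, 5, dModel],
        // run through the encoder stack, and projected to [1, 5, vocabSize]
        // logits, one score per vocabulary word at each position.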

        async function runTransformer() {
            const numLayers = 2;
            const numHeads = 4;
            const dModel = 16;
            const dff = 64;
            const maxSeqLength = 5;

            const transformer = new Transformer(numLayers, numHeads, dModel, dff, vocabSize, maxSeqLength);

            // Input sentence: "hello world"
            const inputSentence = "<START> hello world <END>";
            const inputTokens = tokenize(inputSentence);

            // Pad the sequence to match maxSeqLength
            const paddedInputTokens = padSequence(inputTokens, maxSeqLength, vocab.indexOf("<PAD>"));
            const inputTensor = tf.tensor2d([paddedInputTokens], [1, maxSeqLength], 'int32'); // Shape: [batch_size, seq_length]

            // Forward pass
            const outputs = transformer.call(inputTensor);

            // Convert logits to token IDs
            const predictedTokenIDs = outputs.argMax(-1); // Get the most likely token at each position
            const predictedTokens = await predictedTokenIDs.array();

            // Decode tokens back to text
            const outputDiv = document.getElementById('output');
            outputDiv.textContent = `Input Sentence: "${inputSentence}"\n\nPredicted Tokens: ${detokenize(predictedTokens[0])}`;
        }

        runTransformer().catch(err => {
            document.getElementById('output').textContent = `Error: ${err.message}`;
        });
    </script>
</body>
</html>
