Transformer Example: Hello World with Meaningful Prediction

To make the Transformer model predict something meaningful, we need to train it on a dataset. For simplicity, let's use a tiny synthetic dataset consisting of a single pair: the input sequence "<START> hello world <END>" and its French translation "<START> bonjour monde <END>" as the target.
Below is the code, including a simple training loop. The model will learn to map the input sequence ("hello world") to the target sequence ("bonjour monde"). After training, the model should be able to reproduce the target when given the input.
1. Vocabulary and Tokenization
const vocab = ["<PAD>", "<START>", "<END>", "hello", "world", "bonjour", "monde"];
const vocabSize = vocab.length;
function tokenize(sentence) {
return sentence.split(" ").map(word => vocab.indexOf(word));
}
function detokenize(tokens) {
return tokens.map(token => vocab[token]).join(" ");
}
Explanation
- Vocabulary (vocab): a predefined list of words that the model understands. It includes the special tokens <PAD> (for padding), <START> (marking the start of a sequence), and <END> (marking the end of a sequence).
- Tokenization (tokenize): converts a sentence (e.g., "hello world") into a list of token IDs based on word positions in the vocabulary. Example: "hello world" → [3, 4].
- Detokenization (detokenize): converts a list of token IDs back into a human-readable sentence. Example: [3, 4] → "hello world". A quick check of both helpers follows below.
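As a quick sanity check (not part of the original listing), the two helpers can be exercised like this, with the expected values shown in comments:
// Token IDs follow the vocab order above: "hello" is index 3, "world" is index 4
const demoIds = tokenize("hello world");
console.log(demoIds); // [3, 4]
console.log(detokenize(demoIds)); // "hello world"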
2. Padding Function
function padSequence(sequence, maxLength, padValue = 0) {
const padded = [...sequence];
while (padded.length < maxLength) {
padded.push(padValue); // Add padding tokens
}
return padded.slice(0, maxLength); // Truncate if necessary
}
Explanation
- Ensures all sequences have the same length (maxLength) by appending padding tokens (<PAD>, represented as 0) and truncating anything longer.
- Example: if maxLength = 5 and the input sequence is [3, 4], the output is [3, 4, 0, 0, 0] (see the snippet below).
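The example from the list, run through the function (a small check, not part of the original listing):
// Pad the two-token sequence for "hello world" up to length 5
const demoPadded = padSequence(tokenize("hello world"), 5, vocab.indexOf("<PAD>"));
console.log(demoPadded); // [3, 4, 0, 0, 0]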
3. Positional Encoding
function getAngles(pos, i, dModel) {
return pos / Math.pow(10000, (2 * (i / 2)) / dModel);
}
function positionalEncoding(position, dModel) {
const angleRads = tf.tidy(() => {
const angles = [];
for (let pos = 0; pos < position; pos++) {
for (let i = 0; i < dModel; i++) {
angles.push(getAngles(pos, i, dModel));
}
}
return tf.tensor(angles).reshape([position, dModel]);
});
const encoding = tf.tidy(() => {
const sinEncoding = angleRads.slice([0, 0], [-1, dModel / 2]).sin();
const cosEncoding = angleRads.slice([0, dModel / 2], [-1, dModel / 2]).cos();
return tf.concat([sinEncoding, cosEncoding], 1);
});
angleRads.dispose();
return encoding;
}
Explanation
Purpose:
- Adds information about the position of tokens in the sequence to the embeddings. This is crucial because the Transformer does not inherently understand the order of tokens.
Implementation:
- Computes sinusoidal angle values for each position and dimension, following the scheme introduced in the original Transformer paper ("Attention Is All You Need").
- Applies sine to the first half of the dimensions and cosine to the second half, then concatenates the two halves into a [position, dModel] tensor. (The paper interleaves sine and cosine across dimensions; concatenating the halves is a simpler variant that still gives every position a distinct encoding.)
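A quick shape check (not part of the original listing), using the values from the demo below where maxSeqLength = 5 and dModel = 16:
// One dModel-dimensional encoding vector per position
const demoPE = positionalEncoding(5, 16); // 5 positions, dModel = 16
console.log(demoPE.shape); // [5, 16]
demoPE.dispose(); // free the tensor when done with it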
4. Masking
function createPaddingMask(seq) {
// 1 marks a padding token (ID 0). Shape [batch, 1, 1, seqLength] so the mask
// broadcasts across heads and query positions and masks the padded key positions.
return seq.equal(0).cast('float32').expandDims(1).expandDims(2);
}
Explanation
Purpose:
- Masks out padding tokens (<PAD>) so they do not affect the attention computations.
How It Works:
- Creates a binary mask of shape [batch, 1, 1, seqLength] in which 1 marks padding tokens (ID 0) and 0 marks real tokens.
- Inside the attention layer, the mask is multiplied by -1e9 and added to the attention logits, driving the attention weights for padded key positions to effectively zero. A small example follows below.
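A small check of the mask for the padded "hello world" sequence (not part of the original listing):
// [<START>, hello, world, <END>, <PAD>] → only the last position is padding
const demoSeq = tf.tensor2d([[1, 3, 4, 2, 0]]); // shape [1, 5]
const demoMask = createPaddingMask(demoSeq);
console.log(demoMask.shape); // [1, 1, 1, 5]
demoMask.print(); // values [0, 0, 0, 0, 1] — a 1 at the padded position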
5. Multi-Head Attention
class MultiHeadAttention {
constructor(numHeads, dModel) {
this.numHeads = numHeads;
this.dModel = dModel;
this.depth = dModel / numHeads;
this.wq = tf.layers.dense({ units: dModel });
this.wk = tf.layers.dense({ units: dModel });
this.wv = tf.layers.dense({ units: dModel });
this.dense = tf.layers.dense({ units: dModel });
}
splitHeads(x) {
const batchSize = x.shape[0];
return tf.reshape(x, [batchSize, -1, this.numHeads, this.depth])
.transpose([0, 2, 1, 3]);
}
call(q, k, v, mask = null) {
const batchSize = q.shape[0];
q = this.wq.apply(q);
k = this.wk.apply(k);
v = this.wv.apply(v);
q = this.splitHeads(q);
k = this.splitHeads(k);
v = this.splitHeads(v);
let logits = tf.matMul(q, k.transpose([0, 1, 3, 2])).div(Math.sqrt(this.depth));
if (mask) {
logits = logits.add(mask.mul(-1e9)); // Apply mask
}
const weights = tf.softmax(logits);
let output = tf.matMul(weights, v);
output = output.transpose([0, 2, 1, 3]).reshape([batchSize, -1, this.dModel]);
return this.dense.apply(output);
}
}
Explanation
Purpose:
- Implements multi-head attention, which allows the model to focus on different parts of the input sequence simultaneously.
Key Steps:
- Linear transformations (wq, wk, wv) project the inputs into query, key, and value vectors.
- Splits the projections into multiple heads.
- Computes attention scores (scaled dot-product attention) and applies softmax to obtain attention weights.
- Combines the results from all heads into a single output and passes it through a final dense projection.
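As a smoke test (not in the original listing), the layer can be applied to random data; the hyperparameters match the demo below (dModel = 16, 4 heads):
// Self-attention over one random sequence of 5 tokens
const demoMha = new MultiHeadAttention(4, 16); // numHeads = 4, dModel = 16
const demoX = tf.randomNormal([1, 5, 16]); // [batch, seqLength, dModel]
const demoAttnOut = demoMha.call(demoX, demoX, demoX); // q = k = v, no mask
console.log(demoAttnOut.shape); // [1, 5, 16] — same shape as the input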
6. Encoder Layer
class EncoderLayer {
constructor(dModel, numHeads, dff) {
this.mha = new MultiHeadAttention(numHeads, dModel);
this.ffn = tf.sequential({
layers: [
tf.layers.dense({ units: dff, activation: 'relu', inputShape: [dModel] }),
tf.layers.dense({ units: dModel })
]
});
this.layernorm1 = tf.layers.layerNormalization({ epsilon: 1e-6 });
this.layernorm2 = tf.layers.layerNormalization({ epsilon: 1e-6 });
}
call(x, mask = null) {
const attnOutput = this.mha.call(x, x, x, mask);
const out1 = this.layernorm1.apply(x.add(attnOutput));
const ffnOutput = this.ffn.apply(out1);
const out2 = this.layernorm2.apply(out1.add(ffnOutput));
return out2;
}
}
Explanation
Purpose:
- Processes the input through multi-head attention and a feed-forward network.
Key Components:
- Multi-Head Attention (mha): captures relationships between tokens.
- Feed-Forward Network (ffn): applies non-linear transformations to the attention output.
- Layer Normalization (layernorm1, layernorm2): stabilizes training by normalizing activations.
- Residual Connections: adds the input to the output of each sub-layer to preserve information.
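A shape-preserving check of a single encoder layer (again using the demo's hyperparameters; not part of the original listing):
// One encoder layer keeps the [batch, seqLength, dModel] shape
const demoEnc = new EncoderLayer(16, 4, 64); // dModel = 16, numHeads = 4, dff = 64
const demoEncIn = tf.randomNormal([1, 5, 16]);
const demoEncOut = demoEnc.call(demoEncIn); // no mask
console.log(demoEncOut.shape); // [1, 5, 16]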
7. Transformer Model
class Transformer {
constructor(numLayers, numHeads, dModel, dff, inputVocabSize, maxSeqLength) {
this.numLayers = numLayers;
this.dModel = dModel;
this.embedding = tf.layers.embedding({ inputDim: inputVocabSize, outputDim: dModel });
this.positionalEncoding = positionalEncoding(maxSeqLength, dModel);
this.encoderLayers = Array.from({ length: numLayers }, () => new EncoderLayer(dModel, numHeads, dff));
this.finalLayer = tf.layers.dense({ units: inputVocabSize });
}
call(input) {
const paddingMask = createPaddingMask(input);
let x = this.embedding.apply(input);
x = x.add(this.positionalEncoding.slice([0, 0], [input.shape[1], this.dModel]));
for (const layer of this.encoderLayers) {
x = layer.call(x, paddingMask);
}
return this.finalLayer.apply(x);
}
}
Explanation
Purpose:
- Implements an encoder-only Transformer: an embedding layer, positional encoding, stacked encoder layers, and a final dense layer. (There is no decoder; the model predicts the target token at each position directly from the encoder output.)
Key Steps:
- Embedding: converts token IDs into dense vectors.
- Positional Encoding: adds positional information to the embeddings.
- Encoder Layers: process the input through multiple layers of multi-head attention and feed-forward networks.
- Final Dense Layer: produces logits over the vocabulary for each position.
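A forward-pass check with the demo's hyperparameters (2 layers, 4 heads, dModel = 16, dff = 64, maxSeqLength = 5; not part of the original listing):
const demoModel = new Transformer(2, 4, 16, 64, vocabSize, 5);
const demoInput = tf.tensor2d([[1, 3, 4, 2, 0]]); // "<START> hello world <END>" plus one <PAD>
const demoLogits = demoModel.call(demoInput);
console.log(demoLogits.shape); // [1, 5, 7] — [batch, seqLength, vocabSize]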
8. Training Loop
async function trainAndRunTransformer() {
// Hyperparameters (same values as in the full code below)
const numLayers = 2;
const numHeads = 4;
const dModel = 16;
const dff = 64;
const maxSeqLength = 5;
const learningRate = 0.001;
const epochs = 100;
const transformer = new Transformer(numLayers, numHeads, dModel, dff, vocabSize, maxSeqLength);
// Training data
const inputSentence = "<START> hello world <END>";
const targetSentence = "<START> bonjour monde <END>";
const inputTokens = tokenize(inputSentence);
const targetTokens = tokenize(targetSentence);
const paddedInputTokens = padSequence(inputTokens, maxSeqLength, vocab.indexOf("<PAD>"));
const paddedTargetTokens = padSequence(targetTokens, maxSeqLength, vocab.indexOf("<PAD>"));
const inputTensor = tf.tensor2d([paddedInputTokens], [1, maxSeqLength]);
const targetTensor = tf.tensor2d([paddedTargetTokens], [1, maxSeqLength]);
const targetTensorInt32 = targetTensor.cast('int32');
function loss(labels, preds) {
return tf.losses.softmaxCrossEntropy(labels, preds);
}
const optimizer = tf.train.adam(learningRate);
for (let epoch = 0; epoch < epochs; epoch++) {
const currentLoss = tf.tidy(() => {
const preds = transformer.call(inputTensor);
const oneHotLabels = tf.oneHot(targetTensorInt32, vocabSize);
const lossValue = loss(oneHotLabels, preds);
return lossValue.asScalar();
});
optimizer.minimize(() => {
const preds = transformer.call(inputTensor);
const oneHotLabels = tf.oneHot(targetTensorInt32, vocabSize);
return loss(oneHotLabels, preds);
}, true);
if (epoch % 10 === 0) {
console.log(`Epoch ${epoch}: Loss = ${currentLoss.dataSync()[0].toFixed(4)}`);
}
currentLoss.dispose();
}
// Inference after training
const outputs = transformer.call(inputTensor);
const predictedTokenIDs = outputs.argMax(-1);
const predictedTokens = await predictedTokenIDs.array();
const outputDiv = document.getElementById('output');
outputDiv.textContent = `Input Sentence: "${inputSentence}"\n\nPredicted Tokens: ${detokenize(predictedTokens[0])}`;
}
Explanation
Training Data:
- Defines the input-output pair for training ("hello world" → "bonjour monde").
Loss Function:
- Uses softmax cross-entropy loss to measure the difference between the predicted and target tokens.
Optimizer:
- Updates the model's weights with the Adam optimizer to minimize the loss.
Training Loop:
- Iteratively computes predictions, calculates the loss, and updates the model's weights.
- Logs the loss every 10 epochs to monitor progress.
Inference:
- After training, the model predicts a token for every position of the input sequence; the argMax over the vocabulary dimension is decoded back into words and displayed.
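One thing worth noting is that the loop above runs the forward pass twice per epoch: once inside tf.tidy purely for logging, and once inside optimizer.minimize for the actual update. A leaner variant (a sketch, not part of the original listing) reuses the cost that minimize returns when its second argument is true:
for (let epoch = 0; epoch < epochs; epoch++) {
const cost = optimizer.minimize(() => {
const preds = transformer.call(inputTensor);
const oneHotLabels = tf.oneHot(targetTensorInt32, vocabSize);
return loss(oneHotLabels, preds);
}, true); // returnCost = true, so the loss scalar is returned
if (epoch % 10 === 0) {
console.log(`Epoch ${epoch}: Loss = ${cost.dataSync()[0].toFixed(4)}`);
}
cost.dispose(); // release the scalar each epoch
}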
Conclusion
This implementation demonstrates how to build and train a small Transformer model in TensorFlow.js. The model learns to map the input sequence to the target sequence by minimizing the cross-entropy loss. Memory management with tf.tidy and explicit dispose calls keeps tensor memory bounded, and the modular design makes it easy to extend or modify the model.
Full code
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Transformer Example: Hello World</title>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
<style>
body {
font-family: Arial, sans-serif;
margin: 20px;
}
#output {
margin-top: 20px;
padding: 10px;
border: 1px solid #ccc;
background-color: #f9f9f9;
white-space: pre-wrap;
}
</style>
</head>
<body>
<h1>Transformer Example: Hello World</h1>
<p>Output will appear below:</p>
<div id="output">Loading...</div>
<script>
// Vocabulary for tokenization
const vocab = ["<PAD>", "<START>", "<END>", "hello", "world", "bonjour", "monde"];
const vocabSize = vocab.length;
// Tokenizer functions
function tokenize(sentence) {
return sentence.split(" ").map(word => vocab.indexOf(word));
}
function detokenize(tokens) {
return tokens.map(token => vocab[token]).join(" ");
}
// Padding function
function padSequence(sequence, maxLength, padValue = 0) {
const padded = [...sequence];
while (padded.length < maxLength) {
padded.push(padValue); // Add padding tokens
}
return padded.slice(0, maxLength); // Truncate if necessary
}
// Helper Functions
function getAngles(pos, i, dModel) {
return pos / Math.pow(10000, (2 * (i / 2)) / dModel);
}
function positionalEncoding(position, dModel) {
const angleRads = tf.tidy(() => {
const angles = [];
for (let pos = 0; pos < position; pos++) {
for (let i = 0; i < dModel; i++) {
angles.push(getAngles(pos, i, dModel));
}
}
return tf.tensor(angles).reshape([position, dModel]);
});
const encoding = tf.tidy(() => {
const sinEncoding = angleRads.slice([0, 0], [-1, dModel / 2]).sin();
const cosEncoding = angleRads.slice([0, dModel / 2], [-1, dModel / 2]).cos();
return tf.concat([sinEncoding, cosEncoding], 1);
});
angleRads.dispose();
return encoding;
}
function createPaddingMask(seq) {
// 1 marks a padding token (ID 0). Shape [batch, 1, 1, seqLength] so the mask
// broadcasts across heads and query positions and masks the padded key positions.
return seq.equal(0).cast('float32').expandDims(1).expandDims(2);
}
// Multi-Head Attention Layer
class MultiHeadAttention {
constructor(numHeads, dModel) {
this.numHeads = numHeads;
this.dModel = dModel;
this.depth = dModel / numHeads;
this.wq = tf.layers.dense({ units: dModel });
this.wk = tf.layers.dense({ units: dModel });
this.wv = tf.layers.dense({ units: dModel });
this.dense = tf.layers.dense({ units: dModel });
}
splitHeads(x) {
const batchSize = x.shape[0];
return tf.reshape(x, [batchSize, -1, this.numHeads, this.depth])
.transpose([0, 2, 1, 3]);
}
call(q, k, v, mask = null) {
const batchSize = q.shape[0];
q = this.wq.apply(q);
k = this.wk.apply(k);
v = this.wv.apply(v);
q = this.splitHeads(q);
k = this.splitHeads(k);
v = this.splitHeads(v);
let logits = tf.matMul(q, k.transpose([0, 1, 3, 2])).div(Math.sqrt(this.depth));
if (mask) {
logits = logits.add(mask.mul(-1e9)); // Apply mask
}
const weights = tf.softmax(logits);
let output = tf.matMul(weights, v);
output = output.transpose([0, 2, 1, 3]).reshape([batchSize, -1, this.dModel]);
return this.dense.apply(output);
}
}
// Encoder Layer
class EncoderLayer {
constructor(dModel, numHeads, dff) {
this.mha = new MultiHeadAttention(numHeads, dModel);
this.ffn = tf.sequential({
layers: [
tf.layers.dense({ units: dff, activation: 'relu', inputShape: [dModel] }),
tf.layers.dense({ units: dModel })
]
});
this.layernorm1 = tf.layers.layerNormalization({ epsilon: 1e-6 });
this.layernorm2 = tf.layers.layerNormalization({ epsilon: 1e-6 });
}
call(x, mask = null) {
const attnOutput = this.mha.call(x, x, x, mask);
const out1 = this.layernorm1.apply(x.add(attnOutput));
const ffnOutput = this.ffn.apply(out1);
const out2 = this.layernorm2.apply(out1.add(ffnOutput));
return out2;
}
}
// Transformer Model
class Transformer {
constructor(numLayers, numHeads, dModel, dff, inputVocabSize, maxSeqLength) {
this.numLayers = numLayers;
this.dModel = dModel;
this.embedding = tf.layers.embedding({ inputDim: inputVocabSize, outputDim: dModel });
this.positionalEncoding = positionalEncoding(maxSeqLength, dModel);
this.encoderLayers = Array.from({ length: numLayers }, () => new EncoderLayer(dModel, numHeads, dff));
this.finalLayer = tf.layers.dense({ units: inputVocabSize });
}
call(input) {
const paddingMask = createPaddingMask(input);
let x = this.embedding.apply(input);
x = x.add(this.positionalEncoding.slice([0, 0], [input.shape[1], this.dModel]));
for (const layer of this.encoderLayers) {
x = layer.call(x, paddingMask);
}
return this.finalLayer.apply(x);
}
}
async function trainAndRunTransformer() {
const numLayers = 2;
const numHeads = 4;
const dModel = 16;
const dff = 64;
const maxSeqLength = 5;
const transformer = new Transformer(numLayers, numHeads, dModel, dff, vocabSize, maxSeqLength);
// Training data
const inputSentence = "<START> hello world <END>";
const targetSentence = "<START> bonjour monde <END>";
const inputTokens = tokenize(inputSentence);
const targetTokens = tokenize(targetSentence);
const paddedInputTokens = padSequence(inputTokens, maxSeqLength, vocab.indexOf("<PAD>"));
const paddedTargetTokens = padSequence(targetTokens, maxSeqLength, vocab.indexOf("<PAD>"));
const inputTensor = tf.tensor2d([paddedInputTokens], [1, maxSeqLength]);
const targetTensor = tf.tensor2d([paddedTargetTokens], [1, maxSeqLength]);
// Ensure targetTensor is int32 for oneHot
const targetTensorInt32 = targetTensor.cast('int32');
// Loss function
function loss(labels, preds) {
return tf.losses.softmaxCrossEntropy(labels, preds);
}
// Optimizer
const learningRate = 0.001;
const optimizer = tf.train.adam(learningRate);
// Training loop
const epochs = 100;
for (let epoch = 0; epoch < epochs; epoch++) {
// Use tf.tidy to manage memory
const currentLoss = tf.tidy(() => {
const preds = transformer.call(inputTensor);
const oneHotLabels = tf.oneHot(targetTensorInt32, vocabSize);
const lossValue = loss(oneHotLabels, preds);
return lossValue.asScalar(); // Ensure loss is a scalar tensor
});
// Minimize the loss
optimizer.minimize(() => {
const preds = transformer.call(inputTensor);
const oneHotLabels = tf.oneHot(targetTensorInt32, vocabSize);
return loss(oneHotLabels, preds);
}, true);
// Log the loss every 10 epochs
if (epoch % 10 === 0) {
console.log(`Epoch ${epoch}: Loss = ${currentLoss.dataSync()[0].toFixed(4)}`);
}
// Dispose of the loss tensor to free memory
currentLoss.dispose();
}
// Inference after training
const outputs = transformer.call(inputTensor);
const predictedTokenIDs = outputs.argMax(-1); // Get the most likely token at each position
const predictedTokens = await predictedTokenIDs.array();
// Decode tokens back to text
const outputDiv = document.getElementById('output');
outputDiv.textContent = `Input Sentence: "${inputSentence}"\n\nPredicted Tokens: ${detokenize(predictedTokens[0])}`;
}
trainAndRunTransformer();
</script>
</body>
</html>