Enhanced Transformer Implementation in TensorFlow.js

The "enhanced Transformer implementation" refers to an improved version of the Transformer model that incorporates key features missing from a basic implementation, such as positional encoding (to account for token order), masking (to handle padding and ensure causality in autoregressive tasks), and stacked layers (to enable deeper processing). Additionally, this implementation ensures proper handling of input shapes, avoids errors like exceeding string length limits by summarizing large tensor outputs, and dynamically displays results in a webpage using JavaScript and TensorFlow.js. These enhancements make the model more robust, functional, and aligned with the original Transformer architecture described in the "Attention Is All You Need" paper.
Below is the code that displays the output of the Transformer model in a <div> element.
1. HTML Structure
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Enhanced Transformer in TensorFlow.js</title>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
  <style>
    body {
      font-family: Arial, sans-serif;
      margin: 20px;
    }
    #output {
      margin-top: 20px;
      padding: 10px;
      border: 1px solid #ccc;
      background-color: #f9f9f9;
      white-space: pre-wrap; /* Preserve whitespace for tensor output */
    }
  </style>
</head>
<body>
  <h1>Enhanced Transformer Model Example</h1>
  <p>Output will appear below:</p>
  <div id="output">Loading...</div>
  <script>
    // JavaScript code for the Transformer model goes here
  </script>
</body>
</html>
Explanation
HTML Tags:
- The <html> tag defines the document structure.
- The <head> section includes metadata (e.g., character encoding, viewport settings) and loads the TensorFlow.js library via a <script> tag.
- The <body> contains:
  - A heading (<h1>) titled "Enhanced Transformer Model Example."
  - A paragraph (<p>) indicating where the output will appear.
  - A <div> with the ID output to display the results dynamically.
  - A <script> block where the Transformer model is implemented.
CSS Styling:
- The #output element is styled with padding, a border, and a light background color for better readability.
- white-space: pre-wrap ensures that the output preserves line breaks and formatting.
2. Helper Functions
// Helper Functions
function getAngles(pos, i, dModel) {
  // Each pair of dimensions (2i, 2i + 1) shares one angle rate,
  // hence Math.floor(i / 2).
  return pos / Math.pow(10000, (2 * Math.floor(i / 2)) / dModel);
}

function positionalEncoding(position, dModel) {
  return tf.tidy(() => {
    const angles = [];
    for (let pos = 0; pos < position; pos++) {
      for (let i = 0; i < dModel; i++) {
        angles.push(getAngles(pos, i, dModel));
      }
    }
    const angleRads = tf.tensor(angles).reshape([position, dModel]);
    // Sines in the first half of the dimensions, cosines in the second half.
    const sinEncoding = angleRads.slice([0, 0], [-1, dModel / 2]).sin();
    const cosEncoding = angleRads.slice([0, dModel / 2], [-1, dModel / 2]).cos();
    return tf.concat([sinEncoding, cosEncoding], 1);
  });
}

function createPaddingMask(seq) {
  // 1 where the token is padding (value 0). Shape [batch, 1, 1, seqLength]
  // so it broadcasts over heads and query positions in the attention logits.
  return seq.equal(0).cast('float32').expandDims(1).expandDims(1);
}

function createCausalMask(seqLength) {
  // 1 above the diagonal (future positions), 0 on and below it. The attention
  // layer adds mask * -1e9 to the logits, so 1 marks blocked positions.
  const lower = tf.linalg.bandPart(tf.ones([seqLength, seqLength]), -1, 0);
  return tf.sub(1, lower).expandDims(0).expandDims(1); // Add batch and head dims
}
Explanation
Positional Encoding:
- Positional encoding adds information about the position of each token in the sequence.
- getAngles computes the angles used in the encoding formula.
- positionalEncoding generates sine and cosine values for each position and dimension, combining them into a single tensor.
Masking:
- createPaddingMask creates a binary mask to ignore padded tokens (tokens with value 0).
- createCausalMask creates a triangular mask so the decoder only attends to previous tokens during training (the autoregressive property).
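To see what these helpers compute without running TensorFlow.js, here is a plain-JavaScript sketch of the same positional-encoding math for a tiny case (position = 4, dModel = 4), using the concatenated sin/cos layout from the tf.concat call above:

```javascript
// Plain-JS sketch of the positional-encoding formula (no TensorFlow.js).
function getAngle(pos, i, dModel) {
  // Each pair of dimensions shares one angle rate, hence Math.floor(i / 2).
  return pos / Math.pow(10000, (2 * Math.floor(i / 2)) / dModel);
}

function positionalEncodingJS(position, dModel) {
  const enc = [];
  for (let pos = 0; pos < position; pos++) {
    const row = [];
    // First half of each row: sines; second half: cosines.
    for (let i = 0; i < dModel / 2; i++) row.push(Math.sin(getAngle(pos, 2 * i, dModel)));
    for (let i = 0; i < dModel / 2; i++) row.push(Math.cos(getAngle(pos, 2 * i, dModel)));
    enc.push(row);
  }
  return enc;
}

const enc = positionalEncodingJS(4, 4);
console.log(enc[0]); // position 0: sin(0) = 0 entries, then cos(0) = 1 entries
```

Note that every row is distinct, which is exactly what lets the model distinguish token positions after the encoding is added to the embeddings.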
3. Multi-Head Attention Layer
class MultiHeadAttention {
  constructor(numHeads, dModel) {
    this.numHeads = numHeads;
    this.dModel = dModel;
    this.depth = dModel / numHeads;
    this.wq = tf.layers.dense({ units: dModel });
    this.wk = tf.layers.dense({ units: dModel });
    this.wv = tf.layers.dense({ units: dModel });
    this.dense = tf.layers.dense({ units: dModel });
  }

  splitHeads(x) {
    const batchSize = x.shape[0];
    return tf.reshape(x, [batchSize, -1, this.numHeads, this.depth])
      .transpose([0, 2, 1, 3]);
  }

  call(q, k, v, mask = null) {
    const batchSize = q.shape[0];
    q = this.wq.apply(q);
    k = this.wk.apply(k);
    v = this.wv.apply(v);
    q = this.splitHeads(q);
    k = this.splitHeads(k);
    v = this.splitHeads(v);
    // Scaled dot-product attention
    let logits = tf.matMul(q, k.transpose([0, 1, 3, 2])).div(Math.sqrt(this.depth));
    if (mask) {
      logits = logits.add(mask.mul(-1e9)); // Push masked positions toward -infinity
    }
    const weights = tf.softmax(logits);
    let output = tf.matMul(weights, v);
    output = output.transpose([0, 2, 1, 3]).reshape([batchSize, -1, this.dModel]);
    return this.dense.apply(output);
  }
}
Explanation
Multi-Head Attention:
- This layer splits the input into multiple "heads," computes attention for each head, and combines the results.
- The splitHeads function reshapes the input tensor to separate the heads.
- The call method computes scaled dot-product attention:
  - logits: Dot product of queries and keys, scaled by sqrt(depth).
  - weights: Softmax of logits to compute attention weights.
  - output: Weighted sum of values.
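To make the logits → softmax → weighted-sum steps concrete, here is a plain-JavaScript sketch of single-head scaled dot-product attention on tiny 2x2 matrices (no TensorFlow.js required; matmul, softmaxRows, and transpose are small helpers written just for this illustration):

```javascript
// Minimal plain-JS helpers for the sketch.
function matmul(a, b) {
  return a.map(row => b[0].map((_, j) => row.reduce((s, v, k) => s + v * b[k][j], 0)));
}
function softmaxRows(m) {
  return m.map(row => {
    const max = Math.max(...row);                  // subtract max for stability
    const exps = row.map(v => Math.exp(v - max));
    const sum = exps.reduce((a, b) => a + b, 0);
    return exps.map(e => e / sum);
  });
}
function transpose(m) { return m[0].map((_, j) => m.map(row => row[j])); }

// Single-head scaled dot-product attention, mirroring the call() method.
function attention(q, k, v) {
  const depth = q[0].length;
  const logits = matmul(q, transpose(k)).map(row => row.map(x => x / Math.sqrt(depth)));
  const weights = softmaxRows(logits);             // each row sums to 1
  return { weights, output: matmul(weights, v) };  // convex mix of value rows
}

const q = [[1, 0], [0, 1]];
const k = [[1, 0], [0, 1]];
const v = [[1, 2], [3, 4]];
const { weights, output } = attention(q, k, v);
```

With these inputs, query 0 aligns with key 0, so weights[0][0] is the largest entry in its row, and each output row is a weighted average of the rows of v.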
4. Encoder and Decoder Layers
class EncoderLayer {
  constructor(dModel, numHeads, dff) {
    this.mha = new MultiHeadAttention(numHeads, dModel);
    this.ffn = tf.sequential({
      layers: [
        // null allows a variable sequence length: input is [seqLength, dModel]
        tf.layers.dense({ units: dff, activation: 'relu', inputShape: [null, dModel] }),
        tf.layers.dense({ units: dModel })
      ]
    });
    this.layernorm1 = tf.layers.layerNormalization({ epsilon: 1e-6 });
    this.layernorm2 = tf.layers.layerNormalization({ epsilon: 1e-6 });
  }

  call(x, mask = null) {
    const attnOutput = this.mha.call(x, x, x, mask);
    const out1 = this.layernorm1.apply(x.add(attnOutput));   // residual + norm
    const ffnOutput = this.ffn.apply(out1);
    const out2 = this.layernorm2.apply(out1.add(ffnOutput)); // residual + norm
    return out2;
  }
}

class DecoderLayer {
  constructor(dModel, numHeads, dff) {
    this.mha1 = new MultiHeadAttention(numHeads, dModel); // masked self-attention
    this.mha2 = new MultiHeadAttention(numHeads, dModel); // encoder-decoder attention
    this.ffn = tf.sequential({
      layers: [
        tf.layers.dense({ units: dff, activation: 'relu', inputShape: [null, dModel] }),
        tf.layers.dense({ units: dModel })
      ]
    });
    this.layernorm1 = tf.layers.layerNormalization({ epsilon: 1e-6 });
    this.layernorm2 = tf.layers.layerNormalization({ epsilon: 1e-6 });
    this.layernorm3 = tf.layers.layerNormalization({ epsilon: 1e-6 });
  }

  call(x, encOutput, lookAheadMask = null, paddingMask = null) {
    const attn1 = this.mha1.call(x, x, x, lookAheadMask);
    const out1 = this.layernorm1.apply(x.add(attn1));
    const attn2 = this.mha2.call(out1, encOutput, encOutput, paddingMask);
    const out2 = this.layernorm2.apply(out1.add(attn2));
    const ffnOutput = this.ffn.apply(out2);
    const out3 = this.layernorm3.apply(out2.add(ffnOutput));
    return out3;
  }
}
Explanation
Encoder Layer:
- Processes the input sequence through multi-head attention and a feed-forward network.
- Residual connections and layer normalization stabilize training.
Decoder Layer:
- Similar to the encoder but includes two attention mechanisms:
  - Self-attention (masked to prevent attending to future tokens).
  - Encoder-decoder attention (combines encoder output with decoder state).
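The residual-plus-LayerNorm step that both layers repeat, out = LayerNorm(x + sublayerOutput), can be sketched in plain JavaScript for a single vector (the epsilon matches the 1e-6 used above; the learned scale and shift parameters of a real LayerNorm are omitted for simplicity):

```javascript
// Normalize a vector to (approximately) zero mean and unit variance.
function layerNorm(vec, epsilon = 1e-6) {
  const mean = vec.reduce((a, b) => a + b, 0) / vec.length;
  const variance = vec.reduce((a, b) => a + (b - mean) ** 2, 0) / vec.length;
  return vec.map(v => (v - mean) / Math.sqrt(variance + epsilon));
}

const x = [1, 2, 3, 4];              // layer input
const sublayer = [0.5, -0.5, 1, 0];  // e.g. an attention output
const residual = x.map((v, i) => v + sublayer[i]); // the "x.add(attnOutput)" step
const out = layerNorm(residual);
// out has near-zero mean and near-unit variance regardless of input scale.
```

Because the sublayer output is added to the unchanged input before normalizing, gradients can flow straight through the residual path, which is what stabilizes deep stacks of these layers.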
5. Transformer Model
class Transformer {
  constructor(numLayers, numHeads, dModel, dff, inputVocabSize, targetVocabSize, maxSeqLength) {
    this.numLayers = numLayers;
    this.dModel = dModel;
    this.encoderEmbedding = tf.layers.embedding({ inputDim: inputVocabSize, outputDim: dModel });
    this.decoderEmbedding = tf.layers.embedding({ inputDim: targetVocabSize, outputDim: dModel });
    this.positionalEncoding = positionalEncoding(maxSeqLength, dModel);
    this.encoderLayers = Array.from({ length: numLayers }, () => new EncoderLayer(dModel, numHeads, dff));
    this.decoderLayers = Array.from({ length: numLayers }, () => new DecoderLayer(dModel, numHeads, dff));
    this.finalLayer = tf.layers.dense({ units: targetVocabSize });
  }

  call(encInput, decInput) {
    const paddingMask = createPaddingMask(encInput);
    const lookAheadMask = createCausalMask(decInput.shape[1]);

    let encOutput = this.encoderEmbedding.apply(encInput);
    encOutput = encOutput.add(this.positionalEncoding.slice([0, 0], [encInput.shape[1], this.dModel]));
    for (const layer of this.encoderLayers) {
      encOutput = layer.call(encOutput, paddingMask);
    }

    let decOutput = this.decoderEmbedding.apply(decInput);
    decOutput = decOutput.add(this.positionalEncoding.slice([0, 0], [decInput.shape[1], this.dModel]));
    for (const layer of this.decoderLayers) {
      decOutput = layer.call(decOutput, encOutput, lookAheadMask, paddingMask);
    }

    return this.finalLayer.apply(decOutput);
  }
}
Explanation
Transformer Architecture:
- The encoder processes the input sequence, and the decoder generates the target sequence.
- Positional encoding is added to embeddings to preserve token order.
- Multiple encoder and decoder layers are stacked for deeper processing.
- The final dense layer produces logits for the target vocabulary.
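To get a rough sense of the model's scale, the embedding and output layers alone already account for millions of weights. A back-of-envelope count in plain JavaScript, using the hyperparameters from the example below and ignoring the attention, feed-forward, and LayerNorm weights:

```javascript
// Back-of-envelope parameter count for the three vocabulary-sized layers.
const dModel = 128;
const inputVocabSize = 10000;
const targetVocabSize = 10000;

const encoderEmbeddingParams = inputVocabSize * dModel;              // 1,280,000
const decoderEmbeddingParams = targetVocabSize * dModel;             // 1,280,000
const finalLayerParams = dModel * targetVocabSize + targetVocabSize; // weights + bias

const totalParams = encoderEmbeddingParams + decoderEmbeddingParams + finalLayerParams;
console.log(totalParams); // 3850000
```

This is one reason the vocabulary size dominates memory use in small Transformers: the per-layer attention and feed-forward weights grow with dModel and dff, not with the vocabulary.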
6. Running the Transformer
async function runTransformer() {
  // Hyperparameters (declared as constants so they are also in scope below,
  // rather than assigned inside the constructor call, which would create
  // implicit globals).
  const numLayers = 4;
  const numHeads = 8;
  const dModel = 128;
  const dff = 512;
  const inputVocabSize = 10000;
  const targetVocabSize = 10000;
  const maxSeqLength = 38;

  const transformer = new Transformer(
    numLayers, numHeads, dModel, dff,
    inputVocabSize, targetVocabSize, maxSeqLength
  );

  const batchSize = 64;
  const encInput = tf.randomUniform([batchSize, maxSeqLength], 0, inputVocabSize, 'int32');
  const decInput = tf.randomUniform([batchSize, maxSeqLength], 0, targetVocabSize, 'int32');

  const outputs = transformer.call(encInput, decInput);
  const outputArray = await outputs.array();

  const outputDiv = document.getElementById('output');
  // Keep only a small slice of the [batch, seqLength, vocabSize] output so
  // JSON.stringify does not exceed the string length limit.
  const summarizedOutput = outputArray
    .slice(0, 2)                                   // first 2 batch items
    .map(seq => seq.slice(0, 5)                    // first 5 positions
      .map(logits => logits.slice(0, 5)));         // first 5 logits each
  outputDiv.textContent =
    `Output Shape: ${outputs.shape}\n\n` +
    `First Few Rows of Output Tensor:\n${JSON.stringify(summarizedOutput, null, 2)}`;
}

runTransformer();
Explanation
Transformer Initialization:
- Creates a Transformer instance with the specified hyperparameters.
Random Input Data:
- Generates random input and target sequences using tf.randomUniform.
Forward Pass:
- Calls the transformer.call method to perform a forward pass.
Summarized Output:
- Extracts only a small slice of the output tensor to avoid exceeding string length limits.
- Updates the <div> element with the tensor shape and the summarized values.
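The summarization step is ordinary nested-array slicing and can be tested on its own, without TensorFlow.js. A sketch with a hypothetical summarize helper:

```javascript
// Take the first `rows` rows and first `cols` columns of a 2-D nested array
// so the stringified result stays small.
function summarize(arr, rows = 2, cols = 5) {
  return arr.slice(0, rows).map(row => row.slice(0, cols));
}

// A stand-in for a large output: 10 rows x 100 columns.
const big = Array.from({ length: 10 }, (_, r) =>
  Array.from({ length: 100 }, (_, c) => r * 100 + c));

const small = summarize(big);
console.log(JSON.stringify(small)); // [[0,1,2,3,4],[100,101,102,103,104]]
```

Array.prototype.slice copies references rather than the full data, so this is cheap even for large arrays; only the final JSON.stringify touches the retained values.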
Conclusion
This implementation demonstrates how to build and run a Transformer model in TensorFlow.js while handling large tensors gracefully. The output is displayed dynamically in a <div> element, making it easy to visualize the results in a web browser.
Codepen
Written by

Mohamad Mahmood
Mohamad's interest is in Programming (Mobile, Web, Database and Machine Learning). He studies at the Center For Artificial Intelligence Technology (CAIT), Universiti Kebangsaan Malaysia (UKM).