Enhanced Transformer Implementation in TensorFlow.js

Mohamad Mahmood

The "enhanced Transformer implementation" refers to an improved version of the Transformer model that incorporates key features missing from a basic implementation, such as positional encoding (to account for token order), masking (to handle padding and ensure causality in autoregressive tasks), and stacked layers (to enable deeper processing). Additionally, this implementation ensures proper handling of input shapes, avoids errors like exceeding string length limits by summarizing large tensor outputs, and dynamically displays results in a webpage using JavaScript and TensorFlow.js. These enhancements make the model more robust, functional, and aligned with the original Transformer architecture described in the "Attention Is All You Need" paper.

Below is the complete code, presented section by section, for building the Transformer model and displaying its output in a <div> element.


1. HTML Structure

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Enhanced Transformer in TensorFlow.js</title>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <style>
        body {
            font-family: Arial, sans-serif;
            margin: 20px;
        }
        #output {
            margin-top: 20px;
            padding: 10px;
            border: 1px solid #ccc;
            background-color: #f9f9f9;
            white-space: pre-wrap; /* Preserve whitespace for tensor output */
        }
    </style>
</head>
<body>
    <h1>Enhanced Transformer Model Example</h1>
    <p>Output will appear below:</p>
    <div id="output">Loading...</div>

    <script>
        // JavaScript code for the Transformer model
    </script>
</body>
</html>

Explanation

  • HTML Tags:

    • The <html> tag defines the document structure.

    • The <head> section includes metadata (e.g., character encoding, viewport settings) and loads the TensorFlow.js library via a <script> tag.

    • The <body> contains:

      • A heading (<h1>) titled "Enhanced Transformer Model Example."

      • A paragraph (<p>) indicating where the output will appear.

      • A <div> with the ID output to display the results dynamically.

      • A <script> block where the Transformer model is implemented.

  • CSS Styling:

    • The #output element is styled with padding, a border, and a light background color for better readability.

    • The white-space: pre-wrap ensures that the output preserves line breaks and formatting.


2. Helper Functions

// Helper Functions
function getAngles(pos, i, dModel) {
    // Angle for position `pos` and dimension `i`, following the paper's formula:
    // pos / 10000^(2 * floor(i / 2) / dModel)
    return pos / Math.pow(10000, (2 * Math.floor(i / 2)) / dModel);
}

function positionalEncoding(position, dModel) {
    const angleRads = tf.tidy(() => {
        const angles = [];
        for (let pos = 0; pos < position; pos++) {
            for (let i = 0; i < dModel; i++) {
                angles.push(getAngles(pos, i, dModel));
            }
        }
        return tf.tensor(angles).reshape([position, dModel]);
    });

    const encoding = tf.tidy(() => {
        const sinEncoding = angleRads.slice([0, 0], [-1, dModel / 2]).sin();
        const cosEncoding = angleRads.slice([0, dModel / 2], [-1, dModel / 2]).cos();
        return tf.concat([sinEncoding, cosEncoding], 1);
    });

    angleRads.dispose();
    return encoding;
}

function createPaddingMask(seq) {
    // 1 where the token id is 0 (padding); shape [batch, 1, 1, seqLength]
    // so it broadcasts over the attention heads and query positions.
    return seq.equal(0).cast('float32').expandDims(1).expandDims(1);
}

function createCausalMask(seqLength) {
    // 1 above the diagonal (future positions), 0 elsewhere.
    const lowerTriangle = tf.linalg.bandPart(tf.ones([seqLength, seqLength]), -1, 0);
    const mask = tf.ones([seqLength, seqLength]).sub(lowerTriangle);
    return mask.expandDims(0).expandDims(1); // Add batch and head dimensions: [1, 1, seqLength, seqLength]
}

Explanation

  1. Positional Encoding:

    • Positional encoding adds information about the position of tokens in the sequence.

    • getAngles computes the angles used in the encoding formula.

    • positionalEncoding generates sine and cosine values for each position and dimension, combining them into a single tensor.

  2. Masking:

    • createPaddingMask creates a binary mask to ignore padded tokens (tokens with value 0).

    • createCausalMask creates a triangular mask so the decoder only attends to previous tokens (the autoregressive property). A quick sanity check of both masks follows this list.
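
As a quick sanity check (assuming the helper functions above have already been loaded), the sketch below builds both masks for a toy batch and prints them. The positions marked 1 are the ones the attention layer will later push towards -1e9.

// Toy batch of two sequences of length 4, where token id 0 marks padding
const toySeq = tf.tensor2d([[5, 7, 2, 0], [3, 0, 0, 0]], [2, 4], 'int32');

const padMask = createPaddingMask(toySeq);   // shape [2, 1, 1, 4], 1 where the token is padding
padMask.print();

const causalMask = createCausalMask(4);      // shape [1, 1, 4, 4], 1 above the diagonal (future tokens)
causalMask.print();

// Positional encoding for 4 positions with dModel = 8
const pe = positionalEncoding(4, 8);
console.log('Positional encoding shape:', pe.shape); // [4, 8]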


3. Multi-Head Attention Layer

class MultiHeadAttention {
    constructor(numHeads, dModel) {
        this.numHeads = numHeads;
        this.dModel = dModel;
        this.depth = dModel / numHeads;

        this.wq = tf.layers.dense({ units: dModel });
        this.wk = tf.layers.dense({ units: dModel });
        this.wv = tf.layers.dense({ units: dModel });
        this.dense = tf.layers.dense({ units: dModel });
    }

    splitHeads(x) {
        const batchSize = x.shape[0];
        return tf.reshape(x, [batchSize, -1, this.numHeads, this.depth])
                 .transpose([0, 2, 1, 3]);
    }

    call(q, k, v, mask = null) {
        const batchSize = q.shape[0];

        q = this.wq.apply(q);
        k = this.wk.apply(k);
        v = this.wv.apply(v);

        q = this.splitHeads(q);
        k = this.splitHeads(k);
        v = this.splitHeads(v);

        let logits = tf.matMul(q, k.transpose([0, 1, 3, 2])).div(Math.sqrt(this.depth));
        if (mask) {
            logits = logits.add(mask.mul(-1e9)); // Apply mask
        }
        const weights = tf.softmax(logits);
        let output = tf.matMul(weights, v);

        output = output.transpose([0, 2, 1, 3]).reshape([batchSize, -1, this.dModel]);
        return this.dense.apply(output);
    }
}

Explanation

  • Multi-Head Attention:

    • This layer splits the input into multiple "heads," computes attention for each head, and combines the results.

    • The splitHeads function reshapes the input tensor to separate the heads.

    • The call method computes scaled dot-product attention:

      • logits: Dot product of queries and keys, scaled by sqrt(depth).

      • weights: Softmax of logits to compute attention weights.

      • output: Weighted sum of the values, which is then recombined across heads and projected back to dModel (a minimal usage sketch follows this list).
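
As a minimal usage sketch (with hypothetical toy dimensions), the snippet below runs the layer on a random tensor as self-attention and confirms that the output keeps the shape of the query input, since the heads are recombined and projected back to dModel.

// Hypothetical toy dimensions: batch of 2, sequence length 5, dModel = 16, 4 heads
const mha = new MultiHeadAttention(4, 16);
const x = tf.randomNormal([2, 5, 16]);

// Self-attention: queries, keys and values are all the same tensor
const attnOut = mha.call(x, x, x);
console.log('Attention output shape:', attnOut.shape); // [2, 5, 16]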


4. Encoder and Decoder Layers

class EncoderLayer {
    constructor(dModel, numHeads, dff) {
        this.mha = new MultiHeadAttention(numHeads, dModel);

        // Position-wise feed-forward network. The two dense layers are applied
        // directly (not wrapped in tf.sequential) so they accept the rank-3
        // [batch, seqLength, dModel] tensors produced by the attention block.
        this.ffn1 = tf.layers.dense({ units: dff, activation: 'relu' });
        this.ffn2 = tf.layers.dense({ units: dModel });

        this.layernorm1 = tf.layers.layerNormalization({ epsilon: 1e-6 });
        this.layernorm2 = tf.layers.layerNormalization({ epsilon: 1e-6 });
    }

    call(x, mask = null) {
        const attnOutput = this.mha.call(x, x, x, mask);
        const out1 = this.layernorm1.apply(x.add(attnOutput));

        const ffnOutput = this.ffn2.apply(this.ffn1.apply(out1));
        const out2 = this.layernorm2.apply(out1.add(ffnOutput));
        return out2;
    }
}

class DecoderLayer {
    constructor(dModel, numHeads, dff) {
        this.mha1 = new MultiHeadAttention(numHeads, dModel);
        this.mha2 = new MultiHeadAttention(numHeads, dModel);

        // Position-wise feed-forward network (two dense layers applied directly,
        // as in the encoder layer, to support rank-3 inputs).
        this.ffn1 = tf.layers.dense({ units: dff, activation: 'relu' });
        this.ffn2 = tf.layers.dense({ units: dModel });

        this.layernorm1 = tf.layers.layerNormalization({ epsilon: 1e-6 });
        this.layernorm2 = tf.layers.layerNormalization({ epsilon: 1e-6 });
        this.layernorm3 = tf.layers.layerNormalization({ epsilon: 1e-6 });
    }

    call(x, encOutput, lookAheadMask = null, paddingMask = null) {
        const attn1 = this.mha1.call(x, x, x, lookAheadMask);
        const out1 = this.layernorm1.apply(x.add(attn1));

        const attn2 = this.mha2.call(out1, encOutput, encOutput, paddingMask);
        const out2 = this.layernorm2.apply(out1.add(attn2));

        const ffnOutput = this.ffn2.apply(this.ffn1.apply(out2));
        const out3 = this.layernorm3.apply(out2.add(ffnOutput));
        return out3;
    }
}

Explanation

  • Encoder Layer:

    • Processes the input sequence through multi-head attention and a feed-forward network.

    • Residual connections and layer normalization stabilize training.

  • Decoder Layer:

    • Similar to the encoder but includes two attention mechanisms:

      • Self-attention (masked to prevent attending to future tokens).

      • Encoder-decoder attention (combines the encoder output with the decoder state). A shape check for both layers follows this list.
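
As a small shape check (toy dimensions, no masks), the sketch below confirms that both layers preserve the [batch, seqLength, dModel] shape of their inputs, which is what allows them to be stacked.

// Toy dimensions: dModel = 16, 4 heads, dff = 32, batch of 2, sequence length 6
const enc = new EncoderLayer(16, 4, 32);
const dec = new DecoderLayer(16, 4, 32);

const src = tf.randomNormal([2, 6, 16]);
const tgt = tf.randomNormal([2, 6, 16]);

const encOut = enc.call(src);          // [2, 6, 16]
const decOut = dec.call(tgt, encOut);  // [2, 6, 16]
console.log(encOut.shape, decOut.shape);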


5. Transformer Model

class Transformer {
    constructor(numLayers, numHeads, dModel, dff, inputVocabSize, targetVocabSize, maxSeqLength) {
        this.numLayers = numLayers;
        this.dModel = dModel;

        this.encoderEmbedding = tf.layers.embedding({ inputDim: inputVocabSize, outputDim: dModel });
        this.decoderEmbedding = tf.layers.embedding({ inputDim: targetVocabSize, outputDim: dModel });

        this.positionalEncoding = positionalEncoding(maxSeqLength, dModel);

        this.encoderLayers = Array.from({ length: numLayers }, () => new EncoderLayer(dModel, numHeads, dff));
        this.decoderLayers = Array.from({ length: numLayers }, () => new DecoderLayer(dModel, numHeads, dff));

        this.finalLayer = tf.layers.dense({ units: targetVocabSize });
    }

    call(encInput, decInput) {
        const paddingMask = createPaddingMask(encInput);
        const lookAheadMask = createCausalMask(decInput.shape[1]);

        let encOutput = this.encoderEmbedding.apply(encInput);
        encOutput = encOutput.add(this.positionalEncoding.slice([0, 0], [encInput.shape[1], this.dModel]));
        for (const layer of this.encoderLayers) {
            encOutput = layer.call(encOutput, paddingMask);
        }

        let decOutput = this.decoderEmbedding.apply(decInput);
        decOutput = decOutput.add(this.positionalEncoding.slice([0, 0], [decInput.shape[1], this.dModel]));
        for (const layer of this.decoderLayers) {
            decOutput = layer.call(decOutput, encOutput, lookAheadMask, paddingMask);
        }

        return this.finalLayer.apply(decOutput);
    }
}

Explanation

  • Transformer Architecture:

    • The encoder processes the input sequence, and the decoder generates the target sequence.

    • Positional encoding is added to embeddings to preserve token order.

    • Multiple encoder and decoder layers are stacked for deeper processing.

    • The final dense layer produces logits over the target vocabulary (a toy configuration for a quick shape check follows this list).
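
The next section runs a full-sized configuration; for a quick check that the stacked layers wire together, a toy configuration such as the one below (hypothetical hyperparameters) is enough.

// Toy hyperparameters: 2 layers, 2 heads, dModel = 16, dff = 32,
// vocabularies of 100 tokens, sequences up to 10 tokens long
const tinyModel = new Transformer(2, 2, 16, 32, 100, 100, 10);

const srcIds = tf.randomUniform([2, 10], 0, 100, 'int32');
const tgtIds = tf.randomUniform([2, 10], 0, 100, 'int32');

const logits = tinyModel.call(srcIds, tgtIds);
console.log('Logits shape:', logits.shape); // [2, 10, 100]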


6. Running the Transformer

async function runTransformer() {
    // Hyperparameters (declared as constants rather than passed as
    // `name = value` expressions, which would create implicit globals)
    const numLayers = 4;
    const numHeads = 8;
    const dModel = 128;
    const dff = 512;
    const inputVocabSize = 10000;
    const targetVocabSize = 10000;
    const maxSeqLength = 38;

    const transformer = new Transformer(
        numLayers, numHeads, dModel, dff,
        inputVocabSize, targetVocabSize, maxSeqLength
    );

    const batchSize = 64;
    const encInput = tf.randomUniform([batchSize, maxSeqLength], 0, inputVocabSize, 'int32');
    const decInput = tf.randomUniform([batchSize, maxSeqLength], 0, targetVocabSize, 'int32');

    const outputs = transformer.call(encInput, decInput);

    const outputArray = await outputs.array();
    const outputDiv = document.getElementById('output');

    const summarizedOutput = outputArray.slice(0, 2).map(row => row.slice(0, 5));
    outputDiv.textContent = `Output Shape: ${outputs.shape}\n\nFirst Few Rows of Output Tensor:\n${JSON.stringify(summarizedOutput, null, 2)}`;
}

runTransformer();

Explanation

  • Transformer Initialization:

    • Creates a Transformer instance with the specified hyperparameters.

  • Random Input Data:

    • Generates random input and target sequences using tf.randomUniform.

  • Forward Pass:

    • Calls transformer.call on the random batch, producing logits of shape [batchSize, maxSeqLength, targetVocabSize].

  • Summarized Output:

    • Extracts only the first few rows and columns of the output tensor to avoid exceeding string length limits.

    • Updates the <div> element with the tensor shape and the summarized values. A short note on memory management follows this list.
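
One caveat worth noting: the forward pass above creates many intermediate tensors outside of tf.tidy, and they stay allocated until the page is reloaded. A minimal sketch of one way to contain this inside runTransformer (wrapping the forward pass in tf.tidy and disposing the inputs once the values have been read) is shown below; the layer weights are variables and are not affected by tf.tidy.

    // Sketch: run the forward pass inside tf.tidy so intermediate tensors are
    // released automatically; tensors returned from the callback are kept.
    const outputs = tf.tidy(() => transformer.call(encInput, decInput));
    const outputShape = outputs.shape;
    const outputArray = await outputs.array();

    // Release tensors that are no longer needed.
    encInput.dispose();
    decInput.dispose();
    outputs.dispose();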


Conclusion

This implementation demonstrates how to build and run a Transformer model in TensorFlow.js while handling large tensors gracefully. The output is displayed dynamically in a <div> element, making it easy to visualize the results in a web browser.

Written by

Mohamad Mahmood

Mohamad's interest is in Programming (Mobile, Web, Database and Machine Learning). He studies at the Center For Artificial Intelligence Technology (CAIT), Universiti Kebangsaan Malaysia (UKM).