Transformer Example: Meaningful Training

Mohamad Mahmood

To make the Transformer implementation more meaningful, we will extend the example to handle a larger dataset, use a bigger vocabulary, and evaluate the model on unseen data. This will demonstrate the model's ability to generalize rather than simply memorize mappings.

Below is the updated code that incorporates these enhancements:


Updated Code: Meaningful Transformer Implementation

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Transformer Example: Meaningful Training</title>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <style>
        body {
            font-family: Arial, sans-serif;
            margin: 20px;
        }
        #output {
            margin-top: 20px;
            padding: 10px;
            border: 1px solid #ccc;
            background-color: #f9f9f9;
            white-space: pre-wrap;
        }
    </style>
</head>
<body>
    <h1>Transformer Example: Meaningful Training</h1>
    <p>Output will appear below:</p>
    <div id="output">Loading...</div>

    <script>
        // Vocabulary for tokenization
        const vocab = [
            "<PAD>", "<START>", "<END>", "hello", "world", "bonjour", "monde",
            "good", "morning", "see", "you", "later", "salut", "matin", "à", "bientôt"
        ];
        const vocabSize = vocab.length;

        // Tokenizer functions
        function tokenize(sentence) {
            return sentence.split(" ").map(word => vocab.indexOf(word));
        }

        function detokenize(tokens) {
            return tokens.map(token => vocab[token]).join(" ");
        }

        // Padding function
        function padSequence(sequence, maxLength, padValue = 0) {
            const padded = [...sequence];
            while (padded.length < maxLength) {
                padded.push(padValue); // Add padding tokens
            }
            return padded.slice(0, maxLength); // Truncate if necessary
        }

        // Helper Functions
        function getAngles(pos, i, dModel) {
            return pos / Math.pow(10000, (2 * (i / 2)) / dModel);
        }

        function positionalEncoding(position, dModel) {
            const angleRads = tf.tidy(() => {
                const angles = [];
                for (let pos = 0; pos < position; pos++) {
                    for (let i = 0; i < dModel; i++) {
                        angles.push(getAngles(pos, i, dModel));
                    }
                }
                return tf.tensor(angles).reshape([position, dModel]);
            });

            const encoding = tf.tidy(() => {
                const sinEncoding = angleRads.slice([0, 0], [-1, dModel / 2]).sin();
                const cosEncoding = angleRads.slice([0, dModel / 2], [-1, dModel / 2]).cos();
                return tf.concat([sinEncoding, cosEncoding], 1);
            });

            angleRads.dispose();
            return encoding;
        }

        function createPaddingMask(seq) {
            // 1.0 where the token is <PAD> (id 0); shape [batch, 1, 1, seqLen] so it
            // broadcasts across the attention logits [batch, heads, seqLenQ, seqLenK]
            return seq.equal(0).cast('float32').expandDims(1).expandDims(1);
        }

        // Multi-Head Attention Layer
        class MultiHeadAttention {
            constructor(numHeads, dModel) {
                this.numHeads = numHeads;
                this.dModel = dModel;
                this.depth = dModel / numHeads;

                this.wq = tf.layers.dense({ units: dModel });
                this.wk = tf.layers.dense({ units: dModel });
                this.wv = tf.layers.dense({ units: dModel });
                this.dense = tf.layers.dense({ units: dModel });
            }

            splitHeads(x) {
                const batchSize = x.shape[0];
                return tf.reshape(x, [batchSize, -1, this.numHeads, this.depth])
                         .transpose([0, 2, 1, 3]);
            }

            call(q, k, v, mask = null) {
                const batchSize = q.shape[0];

                q = this.wq.apply(q);
                k = this.wk.apply(k);
                v = this.wv.apply(v);

                q = this.splitHeads(q);
                k = this.splitHeads(k);
                v = this.splitHeads(v);

                let logits = tf.matMul(q, k.transpose([0, 1, 3, 2])).div(Math.sqrt(this.depth));
                if (mask) {
                    logits = logits.add(mask.mul(-1e9)); // Apply mask
                }
                const weights = tf.softmax(logits);
                let output = tf.matMul(weights, v);

                output = output.transpose([0, 2, 1, 3]).reshape([batchSize, -1, this.dModel]);
                return this.dense.apply(output);
            }
        }

        // Encoder Layer
        class EncoderLayer {
            constructor(dModel, numHeads, dff) {
                this.mha = new MultiHeadAttention(numHeads, dModel);

                this.ffn = tf.sequential({
                    layers: [
                        tf.layers.dense({ units: dff, activation: 'relu', inputShape: [dModel] }),
                        tf.layers.dense({ units: dModel })
                    ]
                });

                this.layernorm1 = tf.layers.layerNormalization({ epsilon: 1e-6 });
                this.layernorm2 = tf.layers.layerNormalization({ epsilon: 1e-6 });
            }

            call(x, mask = null) {
                const attnOutput = this.mha.call(x, x, x, mask);
                const out1 = this.layernorm1.apply(x.add(attnOutput));

                const ffnOutput = this.ffn.apply(out1);
                const out2 = this.layernorm2.apply(out1.add(ffnOutput));
                return out2;
            }
        }

        // Transformer Model
        class Transformer {
            constructor(numLayers, numHeads, dModel, dff, inputVocabSize, maxSeqLength) {
                this.numLayers = numLayers;
                this.dModel = dModel;

                this.embedding = tf.layers.embedding({ inputDim: inputVocabSize, outputDim: dModel });
                this.positionalEncoding = positionalEncoding(maxSeqLength, dModel);

                this.encoderLayers = Array.from({ length: numLayers }, () => new EncoderLayer(dModel, numHeads, dff));

                this.finalLayer = tf.layers.dense({ units: inputVocabSize });
            }

            call(input) {
                const paddingMask = createPaddingMask(input);

                let x = this.embedding.apply(input);
                x = x.add(this.positionalEncoding.slice([0, 0], [input.shape[1], this.dModel]));

                for (const layer of this.encoderLayers) {
                    x = layer.call(x, paddingMask);
                }

                return this.finalLayer.apply(x);
            }
        }

        async function trainAndRunTransformer() {
            const numLayers = 2;
            const numHeads = 4;
            const dModel = 16;
            const dff = 64;
            const maxSeqLength = 7;

            const transformer = new Transformer(numLayers, numHeads, dModel, dff, vocabSize, maxSeqLength);

            // Training data
            const dataset = [
                { input: "<START> hello world <END>", target: "<START> bonjour monde <END>" },
                { input: "<START> good morning <END>", target: "<START> bonjour matin <END>" },
                { input: "<START> see you later <END>", target: "<START> à bientôt <END>" }
            ];

            const inputTensors = dataset.map(example => {
                const tokens = tokenize(example.input);
                return padSequence(tokens, maxSeqLength, vocab.indexOf("<PAD>"));
            });

            const targetTensors = dataset.map(example => {
                const tokens = tokenize(example.target);
                return padSequence(tokens, maxSeqLength, vocab.indexOf("<PAD>"));
            });

            const inputTensor = tf.tensor2d(inputTensors, [dataset.length, maxSeqLength]);
            const targetTensor = tf.tensor2d(targetTensors, [dataset.length, maxSeqLength]);

            const targetTensorInt32 = targetTensor.cast('int32');

            // Loss function
            function loss(labels, preds) {
                return tf.losses.softmaxCrossEntropy(labels, preds);
            }

            // Optimizer
            const learningRate = 0.001;
            const optimizer = tf.train.adam(learningRate);

            // Training loop
            const epochs = 200;
            for (let epoch = 0; epoch < epochs; epoch++) {
                const currentLoss = tf.tidy(() => {
                    const preds = transformer.call(inputTensor);
                    const oneHotLabels = tf.oneHot(targetTensorInt32, vocabSize);
                    const lossValue = loss(oneHotLabels, preds);
                    return lossValue.asScalar();
                });

                optimizer.minimize(() => {
                    const preds = transformer.call(inputTensor);
                    const oneHotLabels = tf.oneHot(targetTensorInt32, vocabSize);
                    return loss(oneHotLabels, preds);
                }, true);

                if (epoch % 20 === 0) {
                    console.log(`Epoch ${epoch}: Loss = ${currentLoss.dataSync()[0].toFixed(4)}`);
                }

                currentLoss.dispose();
            }

            // Inference after training
            const testInput = "<START> hello there <END>";
            const testTokens = tokenize(testInput);
            const paddedTestTokens = padSequence(testTokens, maxSeqLength, vocab.indexOf("<PAD>"));
            const testTensor = tf.tensor2d([paddedTestTokens], [1, maxSeqLength]);

            const outputs = transformer.call(testTensor);
            const predictedTokenIDs = outputs.argMax(-1); // Get the most likely token at each position
            const predictedTokens = await predictedTokenIDs.array();

            // Decode tokens back to text
            const outputDiv = document.getElementById('output');
            outputDiv.textContent = `Test Input: "${testInput}"\n\nPredicted Output: ${detokenize(predictedTokens[0])}`;
        }

        trainAndRunTransformer();
    </script>
</body>
</html>

Key Enhancements

  1. Larger Dataset:

    • The dataset now contains multiple input-output pairs:

        const dataset = [
            { input: "<START> hello world <END>", target: "<START> bonjour monde <END>" },
            { input: "<START> good morning <END>", target: "<START> bonjour matin <END>" },
            { input: "<START> see you later <END>", target: "<START> à bientôt <END>" }
        ];
      
    • This allows the model to learn patterns across different examples.

  2. Bigger Vocabulary:

    • The vocabulary includes additional words like "good", "morning", "see", "you", "later", "salut", "matin", "à", and "bientôt".
  3. Unseen Test Input:

    • After training, the model is tested on a new input ("<START> hello there <END>") that it has not seen during training.
  4. Longer Sequences:

    • The maximum sequence length is increased to 7 to accommodate longer sentences.
  5. Batch Training:

    • The model processes all training examples in a single batch using tf.tensor2d.

Expected Output

After training, the model should generate predictions for the unseen input. Ideally, the output would look something like this (note that "là-bas" is not in the vocabulary defined above, so this exact output is aspirational):

Epoch 0: Loss = 2.0794
Epoch 20: Loss = 1.2345
...
Epoch 180: Loss = 0.0123

Test Input: "<START> hello there <END>"

Predicted Output: "<START> bonjour là-bas <END>"

Why Is This More Meaningful?

  1. Generalization:

    • The model is trained on multiple examples and evaluated on unseen data, demonstrating its ability to generalize.
  2. Complexity:

    • A larger vocabulary and longer sequences make the task more challenging and realistic.
  3. Scalability:

    • The same architecture can be extended to real-world datasets with thousands of examples and much larger vocabularies.


When the code above is actually run, however, the output may look like this:

Test Input: "<START> hello there <END>"

Predicted Output: <START> bonjour matin <END> <PAD> <PAD> <PAD>

The output indicates that the model is not fully generalizing to unseen inputs. Instead, it seems to be generating outputs that resemble parts of the training data ("bonjour matin") rather than producing a meaningful translation for "hello there". This behavior can occur for several reasons, such as insufficient training, a small dataset, or overfitting.

Let’s analyze why this happens and how we can improve the model's performance.


Why Does This Happen?

  1. Small Dataset:

    • The dataset contains only three examples:

        [
            { input: "<START> hello world <END>", target: "<START> bonjour monde <END>" },
            { input: "<START> good morning <END>", target: "<START> bonjour matin <END>" },
            { input: "<START> see you later <END>", target: "<START> à bientôt <END>" }
        ]
      
    • With such a small dataset, the model may struggle to learn robust patterns and instead memorize specific mappings.

  2. Limited Vocabulary:

    • The vocabulary includes only a few words, which limits the model's ability to generalize to new inputs like "hello there", where "there" is not in the vocabulary at all (a quick check follows this list).
  3. Insufficient Training:

    • While 200 epochs might seem sufficient, the small dataset and limited complexity of the model may require more careful tuning of hyperparameters (e.g., learning rate, batch size).
  4. Overfitting:

    • The model may overfit to the training data, especially since the dataset is small. This causes it to generate outputs that closely resemble the training examples rather than producing novel predictions.
  5. Padding Tokens in Output:

    • The presence of <PAD> tokens in the output indicates that the model is not confident about predicting meaningful tokens for all positions in the sequence.
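
As a quick, illustrative check of point 2: with the vocabulary and helper functions defined above, tokenizing the unseen test sentence yields -1 for the out-of-vocabulary word "there", which is not a valid row in the embedding matrix:

    // Illustration only: run in the browser console after the page loads
    const demoTokens = tokenize("<START> hello there <END>");
    console.log(demoTokens); // [1, 3, -1, 2] -- "there" is not in the vocabulary
    console.log(padSequence(demoTokens, 7, vocab.indexOf("<PAD>")));
    // [1, 3, -1, 2, 0, 0, 0] -- the -1 is not a valid embedding index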

How to Improve the Model

To make the model more robust and capable of handling unseen inputs, we need to address the issues above. Here are some steps to improve the implementation:


1. Expand the Dataset

  • Add more input-output pairs to the dataset to provide the model with more examples to learn from. For example:

      const dataset = [
          { input: "<START> hello world <END>", target: "<START> bonjour monde <END>" },
          { input: "<START> good morning <END>", target: "<START> bonjour matin <END>" },
          { input: "<START> see you later <END>", target: "<START> à bientôt <END>" },
          { input: "<START> hi there <END>", target: "<START> salut là-bas <END>" },
          { input: "<START> how are you <END>", target: "<START> comment ça va <END>" },
          { input: "<START> thank you <END>", target: "<START> merci <END>" }
      ];
    

2. Handle Unknown Words

  • If the input contains words not in the vocabulary (e.g., "there"), the model will struggle. To handle this:

    • Add an <UNK> token to represent unknown words.

    • Replace out-of-vocabulary (OOV) words with <UNK> during tokenization.

Example:

    const vocab = ["<PAD>", "<START>", "<END>", "<UNK>", "hello", "world", "bonjour", "monde", ...];
    function tokenize(sentence) {
        return sentence.split(" ").map(word => vocab.includes(word) ? vocab.indexOf(word) : vocab.indexOf("<UNK>"));
    }

3. Increase Sequence Length

  • Ensure the maxSeqLength is long enough to accommodate all input and target sequences. For example:

      const maxSeqLength = 10;
    

4. Adjust Hyperparameters

  • Experiment with different hyperparameters to improve training:

    • Increase the number of layers (numLayers) or attention heads (numHeads).

    • Use a larger embedding dimension (dModel) and feed-forward network size (dff).

    • Reduce the learning rate for finer optimization:

        const learningRate = 0.0005;
      

5. Evaluate on Multiple Test Cases

  • Test the model on multiple unseen inputs to evaluate its generalization capabilities. For example:

      const testInputs = [
          "<START> hello there <END>",
          "<START> good evening <END>",
          "<START> have a nice day <END>"
      ];
    

6. Prevent Overfitting

  • Add regularization techniques to prevent overfitting:

    • Use dropout layers in the feed-forward network:

        this.ffn = tf.sequential({
            layers: [
                tf.layers.dense({ units: dff, activation: 'relu', inputShape: [dModel] }),
                tf.layers.dropout({ rate: 0.1 }), // Dropout layer
                tf.layers.dense({ units: dModel })
            ]
        });
      
    • Monitor the training loss and stop training early if the loss stops improving (a minimal sketch follows this list).
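
A minimal early-stopping sketch, assuming the same training-loop variables as in the code above (transformer, inputTensor, targetTensorInt32, vocabSize, loss, optimizer, epochs); the patience value of 20 epochs and the 1e-4 improvement threshold are hypothetical choices:

      let bestLoss = Infinity;
      let epochsWithoutImprovement = 0;
      const patience = 20; // hypothetical: stop after 20 epochs with no improvement

      for (let epoch = 0; epoch < epochs; epoch++) {
          // returnCost = true makes minimize() return the scalar loss
          const lossScalar = optimizer.minimize(() => {
              const preds = transformer.call(inputTensor);
              const oneHotLabels = tf.oneHot(targetTensorInt32, vocabSize);
              return loss(oneHotLabels, preds);
          }, true);

          const current = lossScalar.dataSync()[0];
          lossScalar.dispose();

          if (current < bestLoss - 1e-4) { // meaningful improvement
              bestLoss = current;
              epochsWithoutImprovement = 0;
          } else if (++epochsWithoutImprovement >= patience) {
              console.log(`Early stopping at epoch ${epoch}, loss ${current.toFixed(4)}`);
              break;
          }
      }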


7. Use Beam Search for Decoding

  • Instead of using argMax to select the most likely token at each position, use beam search to explore multiple candidate outputs and choose the best one. This improves the quality of generated sequences (see the illustration below).
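
As an illustration of the difference (not a full beam search; a sketch of one appears further down in this article), tf.topk can be used to inspect the candidate tokens a beam would explore at each position, instead of committing to the single argMax choice. This assumes outputs is the [1, seqLen, vocabSize] tensor returned by transformer.call(testTensor) in the code above:

      // Greedy decoding: commit to one token per position
      const greedyIds = outputs.argMax(-1);        // shape [1, seqLen]

      // Beam-search candidates: keep the 3 most likely tokens per position
      const topCandidates = tf.topk(outputs, 3);   // values/indices: [1, seqLen, 3]
      topCandidates.indices.print();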

Updated Code with Improvements

Here’s an updated version of the code with some of the improvements applied:

const vocab = ["<PAD>", "<START>", "<END>", "<UNK>", "hello", "world", "bonjour", "monde", "good", "morning", "see", "you", "later", "salut", "matin", "à", "bientôt", "hi", "there", "how", "are", "thank"];
const vocabSize = vocab.length;

function tokenize(sentence) {
    return sentence.split(" ").map(word => vocab.includes(word) ? vocab.indexOf(word) : vocab.indexOf("<UNK>"));
}

function detokenize(tokens) {
    return tokens.map(token => vocab[token]).join(" ");
}

// Rest of the code remains unchanged...

async function trainAndRunTransformer() {
    const numLayers = 2;
    const numHeads = 4;
    const dModel = 32; // Increased embedding dimension
    const dff = 128;  // Increased feed-forward network size
    const maxSeqLength = 10; // Increased sequence length

    const transformer = new Transformer(numLayers, numHeads, dModel, dff, vocabSize, maxSeqLength);

    // Expanded dataset
    const dataset = [
        { input: "<START> hello world <END>", target: "<START> bonjour monde <END>" },
        { input: "<START> good morning <END>", target: "<START> bonjour matin <END>" },
        { input: "<START> see you later <END>", target: "<START> à bientôt <END>" },
        { input: "<START> hi there <END>", target: "<START> salut là-bas <END>" },
        { input: "<START> how are you <END>", target: "<START> comment ça va <END>" },
        { input: "<START> thank you <END>", target: "<START> merci <END>" }
    ];

    // Tokenization and padding
    const inputTensors = dataset.map(example => {
        const tokens = tokenize(example.input);
        return padSequence(tokens, maxSeqLength, vocab.indexOf("<PAD>"));
    });

    const targetTensors = dataset.map(example => {
        const tokens = tokenize(example.target);
        return padSequence(tokens, maxSeqLength, vocab.indexOf("<PAD>"));
    });

    const inputTensor = tf.tensor2d(inputTensors, [dataset.length, maxSeqLength]);
    const targetTensor = tf.tensor2d(targetTensors, [dataset.length, maxSeqLength]);

    const targetTensorInt32 = targetTensor.cast('int32');

    // Loss function and optimizer
    function loss(labels, preds) {
        return tf.losses.softmaxCrossEntropy(labels, preds);
    }

    const learningRate = 0.0005; // Reduced learning rate
    const optimizer = tf.train.adam(learningRate);

    // Training loop
    const epochs = 300; // Increased number of epochs
    for (let epoch = 0; epoch < epochs; epoch++) {
        const currentLoss = tf.tidy(() => {
            const preds = transformer.call(inputTensor);
            const oneHotLabels = tf.oneHot(targetTensorInt32, vocabSize);
            const lossValue = loss(oneHotLabels, preds);
            return lossValue.asScalar();
        });

        optimizer.minimize(() => {
            const preds = transformer.call(inputTensor);
            const oneHotLabels = tf.oneHot(targetTensorInt32, vocabSize);
            return loss(oneHotLabels, preds);
        }, true);

        if (epoch % 20 === 0) {
            console.log(`Epoch ${epoch}: Loss = ${currentLoss.dataSync()[0].toFixed(4)}`);
        }

        currentLoss.dispose();
    }

    // Inference on multiple test cases
    const testInputs = [
        "<START> hello there <END>",
        "<START> good evening <END>",
        "<START> have a nice day <END>"
    ];

    const outputDiv = document.getElementById('output');
    for (const testInput of testInputs) {
        const testTokens = tokenize(testInput);
        const paddedTestTokens = padSequence(testTokens, maxSeqLength, vocab.indexOf("<PAD>"));
        const testTensor = tf.tensor2d([paddedTestTokens], [1, maxSeqLength]);

        const outputs = transformer.call(testTensor);
        const predictedTokenIDs = outputs.argMax(-1);
        const predictedTokens = await predictedTokenIDs.array();

        outputDiv.textContent += `\nTest Input: "${testInput}"\nPredicted Output: ${detokenize(predictedTokens[0])}\n`;
    }
}

trainAndRunTransformer();

Expected Output

After applying these improvements, the model should perform better on unseen inputs. Ideally, the output would look something like:

Test Input: "<START> hello there <END>"
Predicted Output: "<START> bonjour là-bas <END>"

Test Input: "<START> good evening <END>"
Predicted Output: "<START> bonsoir <END>"

Test Input: "<START> have a nice day <END>"
Predicted Output: "<START> bonne journée <END>"


The model's output seems to follow the order of words in the vocabulary. This is likely due to how the model learns and generalizes from the limited dataset and vocabulary.


Why Does This Happen?

  1. Small Dataset and Vocabulary:

    • The model is trained on a small dataset with only a few examples. With such limited data, the model may struggle to learn robust patterns and instead rely heavily on the order of words in the vocabulary.

    • For example, if "bonjour" appears early in the vocabulary and is frequently paired with "hello", the model may associate "bonjour" as the default translation for any input starting with <START>.

  2. Token IDs and Embeddings:

    • In the embedding layer, each word in the vocabulary is mapped to a dense vector based on its token ID. If the model hasn't been trained sufficiently or the dataset is too small, the embeddings may not capture meaningful relationships between words. Instead, the model might rely on the token IDs themselves, which are ordered by their position in the vocabulary (a small lookup example follows this list).
  3. Overfitting to Training Data:

    • With a small dataset, the model may overfit to the specific mappings in the training examples. For instance:

      • Input: <START> hello world <END> → Target: <START> bonjour monde <END>

      • Input: <START> good morning <END> → Target: <START> bonjour matin <END>

    • The model learns these exact mappings but fails to generalize to new inputs like <START> hello there <END>.

  4. Greedy Decoding (argMax):

    • During inference, the model uses argMax to select the most likely token at each position. This greedy approach often leads to outputs that resemble the most frequent or earliest tokens in the vocabulary, especially if the model is uncertain about the correct prediction.
  5. Positional Encoding Bias:

    • The positional encoding adds information about the position of tokens in the sequence. If the model hasn't learned to properly combine positional encoding with token embeddings, it may generate outputs that follow the sequence of the vocabulary.
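
To make point 2 concrete, here is a small, illustrative lookup with tf.layers.embedding. The layer is untrained, so the vectors carry no meaning yet; token IDs 1, 4 and 6 correspond to <START>, "hello" and "bonjour" in the vocabulary shown below:

      // Illustration only: an untrained embedding layer mapping token IDs to vectors
      const demoEmbedding = tf.layers.embedding({ inputDim: vocabSize, outputDim: 8 });
      const demoIds = tf.tensor2d([[1, 4, 6]], [1, 3], 'int32'); // <START> hello bonjour
      const demoVectors = demoEmbedding.apply(demoIds);          // shape [1, 3, 8]
      demoVectors.print(); // random values until the layer is trained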

How this relates to the vocabulary

The vocabulary is structured as follows:

const vocab = [
    "<PAD>", "<START>", "<END>", "<UNK>", "hello", "world", "bonjour", "monde",
    "good", "morning", "see", "you", "later", "salut", "matin", "à", "bientôt",
    "hi", "there", "how", "are", "thank", "you", "evening", "bonsoir", "have",
    "a", "nice", "day", "bonne", "journée"
];
  • Words earlier in the vocabulary (e.g., "hello", "bonjour") are assigned lower token IDs (e.g., 4, 6).

  • During training, the model associates these lower token IDs with specific positions in the sequence.

  • If the model is uncertain about a prediction, it may default to earlier tokens in the vocabulary because they have higher probabilities due to their frequency or position.

For example:

  • If the input is <START> hello there <END>, the model might predict <START> bonjour là-bas <END> because "bonjour" and "là-bas" are among the first few target tokens in the training data.

How to Address This Issue

To reduce this bias and improve the model's ability to generalize, we can take the following steps:


1. Increase Dataset Size

  • Add more diverse training examples to help the model learn broader patterns. For example:

      const dataset = [
          { input: "<START> hello world <END>", target: "<START> bonjour monde <END>" },
          { input: "<START> good morning <END>", target: "<START> bonjour matin <END>" },
          { input: "<START> see you later <END>", target: "<START> à bientôt <END>" },
          { input: "<START> hi there <END>", target: "<START> salut là-bas <END>" },
          { input: "<START> how are you <END>", target: "<START> comment ça va <END>" },
          { input: "<START> thank you <END>", target: "<START> merci <END>" },
          { input: "<START> good evening <END>", target: "<START> bonsoir <END>" },
          { input: "<START> have a nice day <END>", target: "<START> bonne journée <END>" },
          { input: "<START> what is your name <END>", target: "<START> quel est votre nom <END>" },
          { input: "<START> where are you from <END>", target: "<START> d'où venez-vous <END>" }
      ];
    

2. Shuffle the Vocabulary

  • Randomize the order of words in the vocabulary to prevent the model from relying on token ID sequences:

      const vocab = ["<PAD>", "<START>", "<END>", "<UNK>"].concat(
          ["hello", "world", "bonjour", "monde", "good", "morning", "see", "you", "later", "salut", "matin", "à", "bientôt",
           "hi", "there", "how", "are", "thank", "you", "evening", "bonsoir", "have", "a", "nice", "day", "bonne", "journée"]
          .sort(() => Math.random() - 0.5) // Shuffle the vocabulary
      );
    

3. Use Beam Search Instead of argMax

  • Replace the greedy decoding (argMax) with beam search, which explores multiple possible outputs and selects the best one. This reduces the likelihood of generating outputs that simply follow the vocabulary order.

    Example:

      function beamSearch(logits, beamSize = 3) {
          // logits: the [1, seqLen, vocabSize] output of transformer.call for one example
          const probs = tf.softmax(logits).arraySync()[0]; // [seqLen][vocabSize]
          let sequences = [{ tokens: [], score: 1 }];      // must be reassignable
          for (let step = 0; step < probs.length; step++) {
              const candidates = [];
              for (const seq of sequences) {
                  for (let i = 0; i < probs[step].length; i++) {
                      candidates.push({
                          tokens: [...seq.tokens, i],
                          score: seq.score * probs[step][i]
                      });
                  }
              }
              sequences = candidates.sort((a, b) => b.score - a.score).slice(0, beamSize);
          }
          return sequences[0].tokens;
      }
    
      const predictedTokens = beamSearch(outputs);
    

4. Regularization

  • Add dropout layers and use techniques like weight decay to prevent overfitting:

      this.ffn = tf.sequential({
          layers: [
              tf.layers.dense({ units: dff, activation: 'relu', inputShape: [dModel] }),
              tf.layers.dropout({ rate: 0.2 }), // Increased dropout rate
              tf.layers.dense({ units: dModel })
          ]
      });
    

5. Longer Training

  • Train the model for more epochs to allow it to learn more robust patterns:

      const epochs = 500; // Increase the number of epochs
    

6. Evaluate on Diverse Test Cases

  • Test the model on a wider range of inputs to ensure it generalizes well:

      const testInputs = [
          "<START> hello there <END>",
          "<START> good evening <END>",
          "<START> have a nice day <END>",
          "<START> what is your name <END>",
          "<START> where are you from <END>"
      ];
    


Written by

Mohamad Mahmood
Mohamad Mahmood

Mohamad's interest is in Programming (Mobile, Web, Database and Machine Learning). He studies at the Center For Artificial Intelligence Technology (CAIT), Universiti Kebangsaan Malaysia (UKM).