End-to-End Implementation of the Paper “Attention Is All You Need”

Ramazan Turan
15 min read

In this article, I present an end-to-end implementation of the paper “Attention Is All You Need”, along with selected quotes from the paper.

This article focuses only on implementation. For a more explanatory and conceptual guide, I recommend the following YouTube video: https://www.youtube.com/watch?v=KJtZARuO3JY

Detailed Implementation

Components

Architecture

Encoder

Quote: “The encoder is composed of a stack of N = 6 identical layers. Each layer has two sub-layers. The first is a multi-head self-attention mechanism, and the second is a simple, position-wise fully connected feed-forward network. We employ a residual connection around each of the two sub-layers, followed by layer normalization. That is, the output of each sub-layer is LayerNorm(x + Sublayer(x)), where Sublayer(x) is the function implemented by the sub-layer itself. To facilitate these residual connections, all sub-layers in the model, as well as the embedding layers, produce outputs of dimension $d_{\text{model}} = 512$.”

Positional Encoding

class PositionalEncoding(nn.Module):
    """
    Adds sinusoidal positional encodings to token embeddings.
    Paper Reference: Section 3.5 "Positional Encoding"

        PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
        PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

    Inputs are batch-first, [batch_size, seq_len, d_model], which is how the
    Encoder and Decoder call this module. Dropout is then applied to the sum
    of embeddings and encodings (Section 5.4 Regularization).
    """
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        # div_term[i] == 10000^(-2i/d_model), computed via exp/log
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # Sine for even indices
        # With an odd d_model there is one fewer odd column than even column,
        # so trim div_term to the number of odd columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])  # Cosine for odd indices
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]; broadcasts over the batch

        # Buffer: saved in state_dict and moved by .to(), but not a trainable parameter
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, d_model], with seq_len <= max_len
        Returns:
            [batch_size, seq_len, d_model]
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

Multi-Head Self-Attention


class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention.
    Paper Reference: Section 3.2.2 "Multi-Head Attention" (right side of Figure 2),
    built on Scaled Dot-Product Attention, Section 3.2.1 (left side of Figure 2):

        Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
        MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
            where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)

    The per-head projections are packed into single d_model x d_model linear
    layers and split into n_head chunks of size d_k = d_model // n_head.
    """

    def __init__(self, d_model, n_head, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % n_head == 0, "d_model must be divisible by n_head"

        self.d_model = d_model
        self.n_head = n_head
        self.d_k = d_model // n_head

        # Packed Q/K/V projections for all heads, plus the output projection W^O
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, d_model)

        # Dropout applied to the attention weights
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: [batch_size, seq_len_q, d_model]
            key:   [batch_size, seq_len_k, d_model]
            value: [batch_size, seq_len_k, d_model]
            mask:  optional; entries equal to 0/False are excluded from attention.
                   A 3-D mask [batch_size, seq_len_q, seq_len_k] gets a head axis;
                   otherwise any shape broadcastable to
                   [batch_size, n_head, seq_len_q, seq_len_k] works, e.g.
                   [batch_size, 1, seq_len_q, seq_len_k] or [batch_size, 1, 1, seq_len_k].
        Returns:
            [batch_size, seq_len_q, d_model]
        """
        batch_size = query.size(0)

        # Project, then split into heads: [batch_size, n_head, seq_len, d_k]
        q = self.wq(query).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
        k = self.wk(key).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
        v = self.wv(value).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)

        # Scaled dot-product scores: [batch_size, n_head, seq_len_q, seq_len_k]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # add head axis -> [batch_size, 1, seq_len_q, seq_len_k]
            # masked_fill broadcasts the mask over heads. Fill with the most
            # negative finite value of the scores' dtype instead of a hard-coded
            # -1e9, which is out of range for float16 and makes masked_fill raise.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # Weighted sum of values, then merge heads back: [batch_size, seq_len_q, d_model]
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        out = self.fc(out)

        return out

Position-wise Feed-Forward

class PositionwiseFeedForward(nn.Module):
    """
    Position-wise feed-forward network (paper Section 3.3).

    The same two linear maps are applied to every position independently:
        FFN(x) = max(0, x W1 + b1) W2 + b2
    This implementation also applies dropout to the hidden ReLU activations.
    """
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)  # expand: d_model -> d_ff
        self.fc2 = nn.Linear(d_ff, d_model)  # project back: d_ff -> d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, d_model]
        Returns:
            tensor with the same shape as x
        """
        hidden = F.relu(self.fc1(x))
        hidden = self.dropout(hidden)
        return self.fc2(hidden)

Encoder Layer

class EncoderLayer(nn.Module):
    """
    Transformer Encoder Layer.
    Paper Reference: Section 3.1 "Encoder and Decoder Stacks" (left side of Figure 1).

    Two sub-layers, each wrapped as LayerNorm(x + Dropout(Sublayer(x))):
      1. multi-head self-attention
      2. position-wise feed-forward network
    """

    def __init__(self, d_model, n_head, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, n_head, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Section 5.4 Regularization: "We apply dropout to the output of each
        # sub-layer, before it is added to the sub-layer input and normalized"
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: [batch_size, seq_len, d_model]
            mask: self-attention mask, e.g. [batch_size, seq_len, seq_len];
                entries equal to 0 are excluded from attention
        Returns:
            [batch_size, seq_len, d_model]
        """
        # Sub-layer 1: multi-head self-attention, residual + LayerNorm
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout1(attn_output))

        # Sub-layer 2: position-wise feed-forward, residual + LayerNorm
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout2(ff_output))

        # Return the updated representation so layers can be stacked
        # (the Encoder does `x = layer(x, mask)`).
        return x

Encoder

class Encoder(nn.Module):
    """
    Transformer Encoder (paper Section 3.1, "Encoder and Decoder Stacks").

    Pipeline: token embedding scaled by sqrt(d_model) (Section 3.4)
    -> positional encoding + dropout (Section 3.5) -> n_layers EncoderLayers
    -> a final LayerNorm.

    Note: each EncoderLayer already ends in a LayerNorm, so the final norm
    here is an additional normalization on top of the paper's layout.
    """
    def __init__(self, vocab_size, d_model, n_head, d_ff, n_layers, dropout=0.1, max_len=5000):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        stack = [EncoderLayer(d_model, n_head, d_ff, dropout) for _ in range(n_layers)]
        self.layers = nn.ModuleList(stack)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """
        Args:
            x: token ids, [batch_size, src_seq_len]
            mask: [batch_size, src_seq_len, src_seq_len]
        Returns:
            [batch_size, src_seq_len, d_model]
        """
        # Section 3.4: "we multiply those weights by sqrt(d_model)"
        scale = math.sqrt(self.embedding.embedding_dim)
        hidden = self.pos_encoding(self.embedding(x) * scale)

        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)

        return self.norm(hidden)

Decoder

Quote: “The decoder is also composed of a stack of N = 6 identical layers. In addition to the two sub-layers in each encoder layer, the decoder inserts a third sub-layer, which performs multi-head attention over the output of the encoder stack. Similar to the encoder, we employ residual connections around each of the sub-layers, followed by layer normalization. We also modify the self-attention sub-layer in the decoder stack to prevent positions from attending to subsequent positions. This masking, combined with fact that the output embeddings are offset by one position, ensures that the predictions for position i can depend only on the known outputs at positions less than i.”

Decoder Layer

class DecoderLayer(nn.Module):
    """
    Transformer Decoder Layer (paper Section 3.1, right side of Figure 1).

    Three sub-layers, each wrapped as LayerNorm(x + Dropout(Sublayer(x))):
      1. masked multi-head self-attention over the decoder input
      2. multi-head cross-attention: queries come from the decoder,
         keys and values from the encoder output
      3. position-wise feed-forward network
    """
    def __init__(self, d_model, n_head, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()
        # Sub-layers
        self.self_attn = MultiHeadAttention(d_model, n_head, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_head, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        # One LayerNorm per sub-layer
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        # Residual dropout on each sub-layer output (Section 5.4)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, enc_output, self_mask=None, cross_mask=None):
        """
        Args:
            x: [batch_size, tgt_seq_len, d_model] decoder input
            enc_output: [batch_size, src_seq_len, d_model] encoder output
            self_mask: self-attention mask, e.g. [batch_size, tgt_seq_len, tgt_seq_len];
                the caller is responsible for hiding subsequent positions
            cross_mask: cross-attention mask, e.g. [batch_size, tgt_seq_len, src_seq_len]
        Returns:
            [batch_size, tgt_seq_len, d_model]
        """
        # 1) Masked self-attention
        self_attended = self.self_attn(x, x, x, self_mask)
        x = self.norm1(x + self.dropout1(self_attended))

        # 2) Cross-attention over the encoder output
        cross_attended = self.cross_attn(x, enc_output, enc_output, cross_mask)
        x = self.norm2(x + self.dropout2(cross_attended))

        # 3) Feed-forward network
        x = self.norm3(x + self.dropout3(self.feed_forward(x)))
        return x

Decoder

class Decoder(nn.Module):
    """
    Transformer Decoder (paper Section 3.1, "Encoder and Decoder Stacks").

    Pipeline: token embedding scaled by sqrt(d_model) (Section 3.4)
    -> positional encoding + dropout (Section 3.5) -> n_layers DecoderLayers,
    each attending to the encoder output -> a final LayerNorm.
    Masking of subsequent positions (Section 3.2.3) is supplied by the caller
    through self_mask; this module does not build masks itself.
    """

    def __init__(self, vocab_size, d_model, n_head, d_ff, n_layers, dropout=0.1, max_len=5000):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        stack = [DecoderLayer(d_model, n_head, d_ff, dropout) for _ in range(n_layers)]
        self.layers = nn.ModuleList(stack)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, enc_output, self_mask=None, cross_mask=None):
        """
        Args:
            x: target token ids, [batch_size, tgt_seq_len]
            enc_output: [batch_size, src_seq_len, d_model]
            self_mask: [batch_size, tgt_seq_len, tgt_seq_len] or [batch_size, 1, tgt_seq_len, tgt_seq_len]
            cross_mask: [batch_size, tgt_seq_len, src_seq_len] or [batch_size, 1, tgt_seq_len, src_seq_len]
        Returns:
            [batch_size, tgt_seq_len, d_model]
        """
        scale = math.sqrt(self.embedding.embedding_dim)
        hidden = self.pos_encoding(self.embedding(x) * scale)

        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, enc_output, self_mask, cross_mask)

        return self.norm(hidden)

Transformer Model

class Transformer(nn.Module):
    """
    Encoder-decoder Transformer from "Attention Is All You Need"
    (Section 3, Figure 1).

    The defaults are the paper's base model: d_model=512, n_head=8,
    d_ff=2048, n_layers=6, dropout=0.1 (Table 3). A final linear layer maps
    decoder states to target-vocabulary logits; no softmax is applied here.
    Section 3.4 of the paper shares weights between the embeddings and this
    projection; this implementation keeps them separate.
    """
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, n_head=8,
                 d_ff=2048, n_layers=6, dropout=0.1, max_len=5000):
        super(Transformer, self).__init__()
        self.encoder = Encoder(src_vocab_size, d_model, n_head, d_ff, n_layers, dropout, max_len)
        self.decoder = Decoder(tgt_vocab_size, d_model, n_head, d_ff, n_layers, dropout, max_len)
        self.projection = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
        """
        Args:
            src: source token ids, [batch_size, src_seq_len]
            tgt: target token ids, [batch_size, tgt_seq_len]
            src_mask: encoder self-attention mask, [batch_size, src_seq_len, src_seq_len]
            tgt_mask: decoder self-attention mask, [batch_size, tgt_seq_len, tgt_seq_len]
            memory_mask: decoder->encoder mask, [batch_size, tgt_seq_len, src_seq_len]
        Returns:
            logits of shape [batch_size, tgt_seq_len, tgt_vocab_size]
        """
        memory = self.encoder(src, src_mask)
        decoded = self.decoder(tgt, memory, self_mask=tgt_mask, cross_mask=memory_mask)
        return self.projection(decoded)

Full Code

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class PositionalEncoding(nn.Module):
    """
    Adds sinusoidal positional encodings to token embeddings.
    Paper Reference: Section 3.5 "Positional Encoding"

        PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
        PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

    Inputs are batch-first, [batch_size, seq_len, d_model]. Dropout is then
    applied to the sum of embeddings and encodings (Section 5.4 Regularization).
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        # div_term[i] == 10000^(-2i/d_model), computed via exp/log
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # Even indices use sine
        # With an odd d_model there is one fewer odd column than even column,
        # so trim div_term to the number of odd columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])  # Odd indices use cosine
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]; broadcasts over the batch

        # Buffer: saved in state_dict and moved by .to(), but not a trainable parameter
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, d_model], with seq_len <= max_len
        Returns:
            [batch_size, seq_len, d_model]
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention.
    Paper Reference: Section 3.2.2 "Multi-Head Attention" (right side of Figure 2),
    built on Scaled Dot-Product Attention, Section 3.2.1 (left side of Figure 2):

        Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
        MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
            where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)

    The per-head projections are packed into single d_model x d_model linear
    layers and split into n_head chunks of size d_k = d_model // n_head.
    """

    def __init__(self, d_model, n_head, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % n_head == 0, "d_model must be divisible by n_head"

        self.d_model = d_model
        self.n_head = n_head
        self.d_k = d_model // n_head

        # Packed Q/K/V projections for all heads, plus the output projection W^O
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, d_model)

        # Dropout applied to the attention weights
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: [batch_size, seq_len_q, d_model]
            key:   [batch_size, seq_len_k, d_model]
            value: [batch_size, seq_len_k, d_model]
            mask:  optional; entries equal to 0/False are excluded from attention.
                   A 3-D mask [batch_size, seq_len_q, seq_len_k] gets a head axis;
                   otherwise any shape broadcastable to
                   [batch_size, n_head, seq_len_q, seq_len_k] works, e.g.
                   [batch_size, 1, seq_len_q, seq_len_k] or [batch_size, 1, 1, seq_len_k].
        Returns:
            [batch_size, seq_len_q, d_model]
        """
        batch_size = query.size(0)

        # Project, then split into heads: [batch_size, n_head, seq_len, d_k]
        q = self.wq(query).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
        k = self.wk(key).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
        v = self.wv(value).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)

        # Scaled dot-product scores: [batch_size, n_head, seq_len_q, seq_len_k]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # add head axis -> [batch_size, 1, seq_len_q, seq_len_k]
            # masked_fill broadcasts the mask over heads. Fill with the most
            # negative finite value of the scores' dtype instead of a hard-coded
            # -1e9, which is out of range for float16 and makes masked_fill raise.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # Weighted sum of values, then merge heads back: [batch_size, seq_len_q, d_model]
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        out = self.fc(out)

        return out


class PositionwiseFeedForward(nn.Module):
    """
    Position-wise feed-forward network (paper Section 3.3).

    Applies FFN(x) = max(0, x W1 + b1) W2 + b2 to every position
    independently, with dropout on the hidden ReLU activations.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)  # inner layer, width d_ff
        self.fc2 = nn.Linear(d_ff, d_model)  # back to the model width
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, d_model]
        Returns:
            tensor with the same shape as x
        """
        activated = F.relu(self.fc1(x))
        return self.fc2(self.dropout(activated))


class EncoderLayer(nn.Module):
    """
    Transformer Encoder Layer.
    Paper Reference: Section 3.1 "Encoder and Decoder Stacks" (left side of Figure 1).

    Two sub-layers, each wrapped as LayerNorm(x + Dropout(Sublayer(x))):
      1. multi-head self-attention
      2. position-wise feed-forward network
    """

    def __init__(self, d_model, n_head, d_ff, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, n_head, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Section 5.4 Regularization: "We apply dropout to the output of each
        # sub-layer, before it is added to the sub-layer input and normalized"
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: [batch_size, seq_len, d_model]
            mask: self-attention mask, e.g. [batch_size, seq_len, seq_len] or
                [batch_size, 1, seq_len, seq_len]; entries equal to 0 are excluded
        Returns:
            [batch_size, seq_len, d_model]
        """
        # Sub-layer 1: multi-head self-attention, residual + LayerNorm
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout1(attn_output))

        # Sub-layer 2: position-wise feed-forward, residual + LayerNorm
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout2(ff_output))

        # Return the updated representation so layers can be stacked
        # (the Encoder does `x = layer(x, mask)`).
        return x


class DecoderLayer(nn.Module):
    """
    Transformer Decoder Layer (paper Section 3.1, right side of Figure 1).

    Three sub-layers, each wrapped as LayerNorm(x + Dropout(Sublayer(x))):
      1. masked multi-head self-attention over the decoder input
      2. multi-head cross-attention: queries come from the decoder,
         keys and values from the encoder output
      3. position-wise feed-forward network
    """

    def __init__(self, d_model, n_head, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()
        # Sub-layers
        self.self_attn = MultiHeadAttention(d_model, n_head, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_head, dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        # One LayerNorm per sub-layer
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        # Residual dropout on each sub-layer output (Section 5.4)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, enc_output, self_mask=None, cross_mask=None):
        """
        Args:
            x: [batch_size, tgt_seq_len, d_model] decoder input
            enc_output: [batch_size, src_seq_len, d_model] encoder output
            self_mask: self-attention mask, e.g. [batch_size, tgt_seq_len, tgt_seq_len];
                the caller is responsible for hiding subsequent positions
            cross_mask: cross-attention mask, e.g. [batch_size, tgt_seq_len, src_seq_len]
        Returns:
            [batch_size, tgt_seq_len, d_model]
        """
        # 1) Masked self-attention
        self_attended = self.self_attn(x, x, x, self_mask)
        x = self.norm1(x + self.dropout1(self_attended))

        # 2) Cross-attention over the encoder output
        cross_attended = self.cross_attn(x, enc_output, enc_output, cross_mask)
        x = self.norm2(x + self.dropout2(cross_attended))

        # 3) Feed-forward network
        x = self.norm3(x + self.dropout3(self.feed_forward(x)))
        return x


class Encoder(nn.Module):
    """
    Transformer Encoder (paper Section 3.1, "Encoder and Decoder Stacks").

    Pipeline: token embedding scaled by sqrt(d_model) (Section 3.4)
    -> positional encoding + dropout (Section 3.5) -> n_layers EncoderLayers
    -> a final LayerNorm.

    Note: each EncoderLayer already ends in a LayerNorm, so the final norm
    here is an additional normalization on top of the paper's layout.
    """

    def __init__(self, vocab_size, d_model, n_head, d_ff, n_layers, dropout=0.1, max_len=5000):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        stack = [EncoderLayer(d_model, n_head, d_ff, dropout) for _ in range(n_layers)]
        self.layers = nn.ModuleList(stack)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """
        Args:
            x: token ids, [batch_size, src_seq_len]
            mask: [batch_size, src_seq_len, src_seq_len] or [batch_size, 1, src_seq_len, src_seq_len]
        Returns:
            [batch_size, src_seq_len, d_model]
        """
        # Section 3.4: "we multiply those weights by sqrt(d_model)"
        scale = math.sqrt(self.embedding.embedding_dim)
        hidden = self.pos_encoding(self.embedding(x) * scale)

        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)

        return self.norm(hidden)


class Decoder(nn.Module):
    """
    Transformer Decoder (paper Section 3.1, "Encoder and Decoder Stacks").

    Pipeline: token embedding scaled by sqrt(d_model) (Section 3.4)
    -> positional encoding + dropout (Section 3.5) -> n_layers DecoderLayers,
    each attending to the encoder output -> a final LayerNorm.
    Masking of subsequent positions (Section 3.2.3) is supplied by the caller
    through self_mask; this module does not build masks itself.
    """

    def __init__(self, vocab_size, d_model, n_head, d_ff, n_layers, dropout=0.1, max_len=5000):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        stack = [DecoderLayer(d_model, n_head, d_ff, dropout) for _ in range(n_layers)]
        self.layers = nn.ModuleList(stack)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, enc_output, self_mask=None, cross_mask=None):
        """
        Args:
            x: target token ids, [batch_size, tgt_seq_len]
            enc_output: [batch_size, src_seq_len, d_model]
            self_mask: [batch_size, tgt_seq_len, tgt_seq_len] or [batch_size, 1, tgt_seq_len, tgt_seq_len]
            cross_mask: [batch_size, tgt_seq_len, src_seq_len] or [batch_size, 1, tgt_seq_len, src_seq_len]
        Returns:
            [batch_size, tgt_seq_len, d_model]
        """
        scale = math.sqrt(self.embedding.embedding_dim)
        hidden = self.pos_encoding(self.embedding(x) * scale)

        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, enc_output, self_mask, cross_mask)

        return self.norm(hidden)


class Transformer(nn.Module):
    """
    Encoder-decoder Transformer from "Attention Is All You Need"
    (Section 3, Figure 1).

    The defaults are the paper's base model: d_model=512, n_head=8,
    d_ff=2048, n_layers=6, dropout=0.1 (Table 3). A final linear layer maps
    decoder states to target-vocabulary logits; no softmax is applied here.
    Section 3.4 of the paper shares weights between the embeddings and this
    projection; this implementation keeps them separate.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, n_head=8,
                 d_ff=2048, n_layers=6, dropout=0.1, max_len=5000):
        super(Transformer, self).__init__()
        self.encoder = Encoder(src_vocab_size, d_model, n_head, d_ff, n_layers, dropout, max_len)
        self.decoder = Decoder(tgt_vocab_size, d_model, n_head, d_ff, n_layers, dropout, max_len)
        self.projection = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
        """
        Args:
            src: source token ids, [batch_size, src_seq_len]
            tgt: target token ids, [batch_size, tgt_seq_len]
            src_mask: encoder self-attention mask (3-D or 4-D, see create_masks)
            tgt_mask: decoder self-attention mask (3-D or 4-D)
            memory_mask: decoder->encoder attention mask (3-D or 4-D)
        Returns:
            logits of shape [batch_size, tgt_seq_len, tgt_vocab_size]
        """
        memory = self.encoder(src, src_mask)
        decoded = self.decoder(tgt, memory, self_mask=tgt_mask, cross_mask=memory_mask)
        return self.projection(decoded)


def create_masks(src, tgt, pad_idx):
    """
    Build boolean attention masks for the Transformer (True = may attend).

    Paper Reference: Section 3.2.3 "Applications of Attention in our Model":
    decoder self-attention must not attend to subsequent positions. In
    addition, padding tokens are excluded as attention keys everywhere.

    Args:
        src: [batch_size, src_len] source token ids
        tgt: [batch_size, tgt_len] target token ids
        pad_idx: id of the padding token

    Returns:
        src_mask:    [batch_size, 1, src_len, src_len]  encoder self-attention
        tgt_mask:    [batch_size, 1, tgt_len, tgt_len]  padding AND causal mask
        memory_mask: [batch_size, 1, tgt_len, src_len]  decoder -> encoder attention
        The size-1 axis broadcasts over attention heads.
    """
    src_len = src.size(1)
    tgt_len = tgt.size(1)

    # Key-padding masks, [B, 1, 1, len]. Only the key (last) axis is masked;
    # padded query rows still attend to the real tokens.
    src_keep = (src != pad_idx).unsqueeze(1).unsqueeze(1)
    tgt_keep = (tgt != pad_idx).unsqueeze(1).unsqueeze(1)

    # Encoder self-attention and decoder->encoder attention share the same
    # source padding mask, just repeated for different numbers of queries.
    src_mask = src_keep.expand(-1, -1, src_len, -1)      # [B, 1, src_len, src_len]
    memory_mask = src_keep.expand(-1, -1, tgt_len, -1)   # [B, 1, tgt_len, src_len]

    # Causal (lower-triangular) mask: query i may attend to keys 0..i.
    causal = torch.tril(torch.ones((tgt_len, tgt_len), device=tgt.device)).bool()
    tgt_mask = tgt_keep & causal.unsqueeze(0).unsqueeze(0)  # [B, 1, tgt_len, tgt_len]

    return src_mask, tgt_mask, memory_mask


# Example of model usage
def example_usage():
    """
    Demo: build a base-size Transformer, run one forward pass on random
    token ids, and print the parameter count and tensor shapes.

    Returns:
        The output logits, shape [64, 20, tgt_vocab_size], or None if any
        step raised (the message and full traceback are printed instead).
    """
    import traceback  # local import: only this demo's error report needs it

    try:
        # Base model hyperparameters (paper Table 3):
        # d_model=512, n_head=8, d_ff=2048, n_layers=6, dropout=0.1
        src_vocab_size = 10000
        tgt_vocab_size = 10000
        d_model = 512
        n_head = 8
        d_ff = 2048
        n_layers = 6
        dropout = 0.1
        pad_idx = 0

        # Create model
        transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, n_head, d_ff, n_layers, dropout)

        # Print model parameters
        total_params = sum(p.numel() for p in transformer.parameters())
        print(f"Total parameters: {total_params:,}")

        # Random ids start at 1, so no token equals pad_idx (0)
        src = torch.randint(1, src_vocab_size, (64, 10))  # [batch_size=64, src_seq_len=10]
        tgt = torch.randint(1, tgt_vocab_size, (64, 20))  # [batch_size=64, tgt_seq_len=20]

        # Create masks
        src_mask, tgt_mask, memory_mask = create_masks(src, tgt, pad_idx)

        print(f"Source mask shape: {src_mask.shape}")
        print(f"Target mask shape: {tgt_mask.shape}")
        print(f"Memory mask shape: {memory_mask.shape}")

        # Forward pass
        output = transformer(src, tgt, src_mask, tgt_mask, memory_mask)
        print(f"Output shape: {output.shape}")  # [64, 20, tgt_vocab_size]
        print("Forward pass successful!")

        return output

    except Exception as e:
        # Demo boundary: report and return None rather than raising, but keep
        # the traceback, because str(e) alone does not show where the failure occurred.
        print(f"Error occurred: {str(e)}")
        traceback.print_exc()
        return None


if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    example_usage()
Subscribe to my newsletter

Read articles from Ramazan Turan directly inside your inbox. Subscribe to the newsletter, and don't miss out.

Written by

Ramazan Turan