Transformer Architectures Across Frameworks: TensorFlow vs. PyTorch

Mohamad Mahmood

This discussion is based on the post at https://medium.com/lexiconia/transformers-tensorflow-vs-pytorch-implementation-3f4e5a7239e3

[1] Similarities Between TensorFlow and PyTorch Implementations

  1. Model Architecture

    • Both implementations follow the standard Transformer architecture:

      • Tokenization and padding

      • Positional encoding

      • Multi-head attention

      • Feed-forward neural networks

      • Layer normalization

      • Encoder/decoder stacking

  2. Functionality

    • Core components such as positional_encoding, create_padding_mask, the multi-head attention layer, and the encoder layer are implemented in both frameworks.

    • Masking is used in both to handle padding tokens properly.

  3. Math and Logic

    • Matrix multiplications, softmax, and scaling in attention are handled similarly.

    • The computation of angle rates and sine/cosine functions for positional encoding is mathematically identical.
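
To make the shared math concrete, here is a small framework-agnostic NumPy sketch of the sinusoidal positional encoding and a padding mask. The function names mirror the positional_encoding and create_padding_mask helpers mentioned above, but the exact signatures used in the original post may differ.

```python
import numpy as np

def positional_encoding(max_len, d_model):
    """Sinusoidal positional encoding; the same math in TensorFlow and PyTorch."""
    positions = np.arange(max_len)[:, np.newaxis]            # (max_len, 1)
    dims = np.arange(d_model)[np.newaxis, :]                 # (1, d_model)
    angle_rates = 1.0 / np.power(10000.0, (2 * (dims // 2)) / d_model)
    angles = positions * angle_rates                         # (max_len, d_model)
    angles[:, 0::2] = np.sin(angles[:, 0::2])                # even indices -> sin
    angles[:, 1::2] = np.cos(angles[:, 1::2])                # odd indices  -> cos
    return angles[np.newaxis, ...]                           # (1, max_len, d_model)

def create_padding_mask(token_ids, pad_id=0):
    """1.0 where the token is padding, 0.0 elsewhere."""
    return (token_ids == pad_id).astype(np.float32)

pe = positional_encoding(max_len=50, d_model=128)
mask = create_padding_mask(np.array([[5, 7, 0, 0]]))
print(pe.shape, mask)   # (1, 50, 128) [[0. 0. 1. 1.]]
```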

[2] Differences Between TensorFlow and PyTorch Implementations

| Aspect | TensorFlow | PyTorch |
| --- | --- | --- |
| Syntax | Uses tf.keras.layers, tf.Tensor, @tf.function (if optimized) | Uses torch.nn.Module, torch.Tensor, and decorators such as @torch.no_grad or @staticmethod |
| Layer Definition | Inherits from tf.keras.layers.Layer | Inherits from torch.nn.Module |
| Sequential Models | tf.keras.Sequential([...]) | torch.nn.Sequential(...) |
| Tensor Operations | tf.reshape, tf.transpose, broadcasting via tf.expand_dims | torch.reshape, torch.permute, unsqueeze |
| Training Paradigm | High-level, using model.compile() and model.fit() | Low-level training loop with manual optimizer.zero_grad(), loss.backward(), and optimizer.step() |
| Data Handling | tf.data.Dataset | torch.utils.data.DataLoader |
| Device Management | Less explicit (eager mode runs on CPU/GPU transparently) | Explicit: models and tensors must be moved with model.to(device) and tensor.to(device) |
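
As a quick illustration of the tensor-operation and device-management rows, the sketch below performs the same reshape/transpose/expand steps in both frameworks. The shapes and the CUDA check are illustrative assumptions, not code from the original post.

```python
import tensorflow as tf
import torch

# TensorFlow: reshape/transpose; eager mode places the ops on CPU/GPU transparently.
x_tf = tf.random.uniform((2, 4, 8))
x_tf = tf.transpose(tf.reshape(x_tf, (2, 8, 4)), perm=[0, 2, 1])
x_tf = tf.expand_dims(x_tf, axis=1)          # broadcasting helper

# PyTorch: the same operations, but the device is managed explicitly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x_pt = torch.rand(2, 4, 8, device=device)
x_pt = torch.reshape(x_pt, (2, 8, 4)).permute(0, 2, 1)
x_pt = x_pt.unsqueeze(1)                     # counterpart of tf.expand_dims
```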

Side-by-side comparison of the EncoderLayer and Transformer classes in TensorFlow and PyTorch

1. EncoderLayer Comparison

| Aspect | TensorFlow (tf.keras.layers.Layer) | PyTorch (torch.nn.Module) |
| --- | --- | --- |
| Inheritance | class EncoderLayer(tf.keras.layers.Layer) | class EncoderLayer(nn.Module) |
| Attention Layer | Custom MultiHeadAttention layer with a call() method | Custom MultiHeadAttention layer with a forward() method |
| Feedforward (FFN) | tf.keras.Sequential([...]) | nn.Sequential(...) |
| Normalization | tf.keras.layers.LayerNormalization(epsilon=1e-6) | nn.LayerNorm(d_model, eps=1e-6) |
| Forward Call | def call(self, x, mask=None) | def forward(self, x, mask=None) |
| Dropout | tf.keras.layers.Dropout(0.2) | nn.Dropout(0.2) |
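
A condensed PyTorch version of such an encoder layer might look like the sketch below. For brevity it substitutes torch.nn.MultiheadAttention for the post's custom MultiHeadAttention layer, so it is an approximation rather than the original implementation; the TensorFlow counterpart mirrors it with tf.keras.layers.Layer and a call() method.

```python
import torch
import torch.nn as nn

class EncoderLayer(nn.Module):
    """Self-attention + feed-forward block with residual connections and layer norm."""
    def __init__(self, d_model, num_heads, dff, dropout=0.2):
        super().__init__()
        # batch_first=True keeps the (batch, seq, d_model) layout used throughout the tables.
        self.mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(d_model, dff), nn.ReLU(), nn.Linear(dff, d_model))
        self.norm1 = nn.LayerNorm(d_model, eps=1e-6)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # mask, if given, is a boolean key_padding_mask: True marks padding positions.
        attn_out, _ = self.mha(x, x, x, key_padding_mask=mask)
        x = self.norm1(x + self.dropout(attn_out))    # residual + layer norm
        x = self.norm2(x + self.dropout(self.ffn(x)))
        return x
```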

2. Transformer Class Comparison

| Aspect | TensorFlow (Transformer(tf.keras.Model)) | PyTorch (Transformer(nn.Module)) |
| --- | --- | --- |
| Embedding | tf.keras.layers.Embedding(...) | nn.Embedding(...) |
| Positional Encoding | x += self.positional_encoding[:, :tf.shape(x)[1], :] | x += self.positional_encoding[:, :x.size(1), :] |
| Stacking Encoder Layers | Python list: [EncoderLayer(...) for _ in range(n)] | nn.ModuleList([EncoderLayer(...) for _ in range(n)]) |
| Output Layer | tf.keras.layers.Dense(input_vocab_size) | nn.Linear(d_model, input_vocab_size) |
| Forward Pass | def call(self, inputs) | def forward(self, inputs) |
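
Putting the rows together, a minimal PyTorch Transformer along these lines could be sketched as follows. It reuses the EncoderLayer sketched above; the hyperparameter defaults and the use of register_buffer for the positional-encoding table are assumptions, not details from the original post.

```python
import math
import torch
import torch.nn as nn

class Transformer(nn.Module):
    """Encoder-only Transformer matching the rows above (hyperparameters illustrative)."""
    def __init__(self, input_vocab_size, d_model=128, num_heads=8, dff=512,
                 num_layers=4, max_len=512):
        super().__init__()
        self.embedding = nn.Embedding(input_vocab_size, d_model)
        # Precomputed sinusoidal table; register_buffer keeps it on the model's device
        # without making it a trainable parameter.
        pe = torch.zeros(1, max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[0, :, 0::2] = torch.sin(pos * div)
        pe[0, :, 1::2] = torch.cos(pos * div)
        self.register_buffer("positional_encoding", pe)
        # nn.ModuleList (not a plain Python list) so the layers' weights are registered.
        self.enc_layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, dff) for _ in range(num_layers)]
        )
        self.final_layer = nn.Linear(d_model, input_vocab_size)

    def forward(self, inputs, mask=None):
        x = self.embedding(inputs)
        x = x + self.positional_encoding[:, :x.size(1), :]
        for layer in self.enc_layers:
            x = layer(x, mask)
        return self.final_layer(x)
```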

Summary of Key Differences

  • Method Names: TensorFlow uses call() for forward logic; PyTorch uses forward().

  • Layer Definition: TensorFlow layers are Keras-based; PyTorch uses torch.nn.

  • Weight Registration: PyTorch needs nn.ModuleList so that submodules and their weights are tracked; see the check after this list.

  • Execution: TensorFlow runs eagerly by default but can compile static graphs via @tf.function; PyTorch is natively eager.
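
The weight-registration point is easy to verify: submodules held in a plain Python list are invisible to PyTorch's parameter tracking, whereas nn.ModuleList registers them. A minimal check (the layer sizes are arbitrary):

```python
import torch.nn as nn

class WithPlainList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(8, 8) for _ in range(3)]                 # NOT registered

class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])  # registered

print(sum(p.numel() for p in WithPlainList().parameters()))    # 0
print(sum(p.numel() for p in WithModuleList().parameters()))   # 216 = 3 * (8*8 + 8)
```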

Training Routine Comparison

| Aspect | TensorFlow | PyTorch |
| --- | --- | --- |
| Dataset Handling | Python list, tokenized and padded, converted with tf.convert_to_tensor() | Not shown in the current code, but torch.tensor(...) would be used for the same data |
| Loss Function | tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none") | Typically nn.CrossEntropyLoss(reduction='none') |
| Masking | Applies a mask to the loss with tf.math.logical_not and scales the losses accordingly | Would use (targets != pad_idx) boolean masks with PyTorch tensor operations |
| Gradient Calculation | with tf.GradientTape() as tape: for automatic differentiation | loss.backward() in the standard autograd flow |
| Optimizer | tf.keras.optimizers.Adam(...) | torch.optim.Adam(...) |
| Weight Update | tape.gradient(...) followed by optimizer.apply_gradients(...) | optimizer.step() after optimizer.zero_grad() |
| Epoch Logging | if epoch % 50 == 0: print(...) | Same idea, implemented manually |
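
To pair with the rows above, here is a skeletal training step in each framework. The models, tensors, and pad_id argument are placeholders, and the masking mirrors the approach described in the table rather than reproducing the post's exact loop.

```python
import tensorflow as tf
import torch

# --- TensorFlow: gradients are collected by GradientTape ---
loss_fn_tf = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
optimizer_tf = tf.keras.optimizers.Adam(1e-4)

def tf_train_step(model, inputs, targets, pad_id=0):
    with tf.GradientTape() as tape:
        logits = model(inputs)
        loss = loss_fn_tf(targets, logits)                      # per-token loss (batch, seq)
        mask = tf.cast(tf.math.logical_not(tf.equal(targets, pad_id)), loss.dtype)
        loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)  # ignore padding positions
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer_tf.apply_gradients(zip(grads, model.trainable_variables))
    return loss

# --- PyTorch: the equivalent manual loop ---
loss_fn_pt = torch.nn.CrossEntropyLoss(reduction="none")

def pt_train_step(model, optimizer, inputs, targets, pad_id=0):
    optimizer.zero_grad()
    logits = model(inputs)                                      # (batch, seq, vocab)
    loss = loss_fn_pt(logits.transpose(1, 2), targets)          # CrossEntropyLoss wants (batch, vocab, seq)
    mask = (targets != pad_id).float()
    loss = (loss * mask).sum() / mask.sum()                     # ignore padding positions
    loss.backward()
    optimizer.step()
    return loss.item()
```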

Inference Routine Comparison

| Aspect | TensorFlow | PyTorch |
| --- | --- | --- |
| Inputs | Sentence string → tokenized → padded → tensor | Same concept; would use torch.tensor(...) |
| Beam Search | Implemented as a custom function using tf.nn.softmax(logits) | Would require similar logic using torch.softmax(...) |
| Model Prediction | outputs = transformer(test_tensor) | with torch.no_grad(): outputs = transformer(inputs) |
| Output Processing | Beam search selects the best sequence of token indices | Same idea would be used in PyTorch |
| Detokenization | detokenize() converts predicted tokens back to a string | Same function reused in both implementations |
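
On the PyTorch side, the main inference-specific detail is wrapping the forward pass in torch.no_grad(). The sketch below uses greedy decoding over softmax probabilities instead of the post's beam search, and it assumes tokenized input from the tokenize/detokenize helpers mentioned above.

```python
import torch

def predict(transformer, token_ids, device="cpu"):
    """token_ids: list of ints from the tokenizer; returns predicted token indices."""
    transformer.eval()
    inputs = torch.tensor([token_ids], dtype=torch.long, device=device)  # add batch dim
    with torch.no_grad():                       # disable gradient tracking for inference
        logits = transformer(inputs)            # (1, seq_len, vocab_size)
        probs = torch.softmax(logits, dim=-1)   # counterpart of tf.nn.softmax(logits)
    return probs.argmax(dim=-1).squeeze(0).tolist()   # greedy choice per position
```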

Summary of Training & Inference Differences

| Feature | TensorFlow | PyTorch |
| --- | --- | --- |
| Autodiff | tf.GradientTape() | loss.backward() |
| Step Control | optimizer.apply_gradients(...) | optimizer.step() |
| Execution | Implicit, graph-based (via @tf.function) | Explicit, eager control |
| Inference Context | No special decorator needed | torch.no_grad() to disable gradient tracking |
