Persisting a Model in PyTorch: A Comprehensive Guide
In the evolving landscape of machine learning and deep learning, the ability to persist a model is a crucial step.
Imagine training a state-of-the-art model for days or even weeks, only to realize you cannot reliably save it, reuse it, or deploy it to production.
This scenario is not just inconvenient; it's a significant setback.
PyTorch, a leading framework in the field, offers robust solutions for model persistence, making it straightforward to migrate models to production or reuse them later.
The Importance of Persisting a Model
Why is persisting a model important?
The answer lies in the essence of machine learning itself. The weights (or parameters) of the model are learned during training and are essential for the model to make predictions.
These need to be saved after training and loaded into the model architecture in the production environment, or used to resume the training process where we left off.
Understanding Model and Optimizer State Dictionaries
Before delving into persistence, let's clarify what makes up a PyTorch model.
At its core, a model in PyTorch is encapsulated within an nn.Module.
This module comprises not just the architecture but also the learnable parameters (weights and biases) vital for the model's predictions.
These parameters are accessible through the model's state dictionary (state_dict), a Python dictionary mapping each layer to its parameters.
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())
Similarly, optimizers in PyTorch, responsible for adjusting these parameters based on gradients, have their own state_dict.
This dictionary contains all hyper-parameters and state information necessary for the optimizer's operation.
The synchronization of model and optimizer state dictionaries is crucial for seamless training continuation and model deployment.
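For instance, assuming the SGD optimizer defined in the next section, you can inspect its state_dict the same way:
# Print the optimizer's state_dict: hyperparameters plus internal state
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])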
The Workflow of Saving and Loading Models
Define the Model, Optimizer and Cost function
# Define a minimal model, optimizer, and loss function
import torch
import torch.nn as nn
import torch.optim as optim

# Model
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

model = Model()

# Optimizer
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Loss function
criterion = nn.MSELoss()
Saving and Loading Model Architecture
Save the Model Architecture
# Save the entire model object (architecture + weights) via pickle
model = Model()
torch.save(model, 'model_architecture.pth')
Load the Model in Production
# The class definition must be importable when unpickling the full model
from model_architecture import Model

model = torch.load('model_architecture.pth')
# On newer PyTorch releases, you may need torch.load(..., weights_only=False)
Saving And Loading Model Parameters
Saving Model Parameters
These are the learned parameters of the model necessary for making predictions. Without them, your model would not perform better than a randomly initialized model.
You can read more about state_dict in the PyTorch documentation.
Save the parameters using the code below.
torch.save(model.state_dict(), 'model_weights.pth')
Loading Model Parameters in Production
Loading the model's parameters for inference or further training involves initializing the model structure and loading the saved state dictionary.
model = Model()
model.load_state_dict(torch.load('model_weights.pth'))
# Remember that you must call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference.
# Failing to do this will yield inconsistent inference results.
model.eval()
This process reinstates the model to its trained state, ready for inference or continued training.
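As a minimal sketch (with a hypothetical single-feature input matching the Model above), inference then looks like:
# Run a prediction with gradient tracking disabled
with torch.no_grad():
    x = torch.tensor([[1.0]])  # hypothetical input
    print(model(x))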
General Checkpoints for Inference and/or Resuming Training
For more comprehensive persistence, especially in long-running training jobs, saving just the model's parameters might not suffice.
You might also want to preserve the optimizer state, current epoch, and loss history to resume training seamlessly from where it left off.
PyTorch facilitates this through extended checkpointing:
Saving Training State
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, 'checkpoint.pth')
This approach encapsulates a more holistic snapshot of the training process, enabling a more flexible training lifecycle management.
To resume training in another environment, or to run inference in production:
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
# Continue training or evaluation from here
model.eval()
# - or -
model.train()
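One practical wrinkle when resuming in a different environment: a checkpoint saved on a GPU will not load on a CPU-only machine unless the tensors are remapped. A short sketch using torch.load's map_location argument:
# Remap GPU tensors to the CPU when the target machine has no GPU
checkpoint = torch.load('checkpoint.pth', map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])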
The Role of Pickle in PyTorch Persistence
PyTorch relies on Python's pickle module for serialization, allowing for the conversion of complex Python objects into byte streams for storage or transmission.
While pickle is powerful, it introduces potential security concerns due to its ability to execute arbitrary code upon loading. This characteristic of pickle necessitates caution, especially when dealing with untrusted sources.
See the following note about the security flaw:
The insecurity is not because pickles contain code, but because they create objects by calling constructors named in the pickle. Any callable can be used in place of your class name to construct objects. Malicious pickles will use other Python callables as the "constructors." For example, instead of executing "models.MyObject(17)", a dangerous pickle might execute "os.system('rm -rf /')". The unpickler can't tell the difference between "models.MyObject" and "os.system". Both are names it can resolve, producing something it can call. The unpickler executes either of them as directed by the pickle.
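To make this concrete, here is a minimal sketch (a hypothetical, harmless payload) of how an object can name an arbitrary callable as its "constructor" via __reduce__:
import os
import pickle

class Exploit:
    # __reduce__ tells pickle which callable to invoke on load;
    # nothing stops it from naming os.system instead of a class
    def __reduce__(self):
        return (os.system, ("echo arbitrary code executed on load",))

payload = pickle.dumps(Exploit())
pickle.loads(payload)  # runs the shell command; never unpickle untrusted data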
Emerging Solutions: SafeTensors
Recognizing the limitations of pickle, the community has been exploring alternatives.
As an alternative, engineers at Hugging Face developed safetensors. This library offers a safer and more efficient way to persist tensors, promoting type safety and cross-language compatibility, which are crucial for modern machine learning workflows.
See an example of persisting a model using safetensors below.
import torch
from safetensors import safe_open
from safetensors.torch import save_file

tensors = {
    "weight1": torch.zeros((1024, 1024)),
    "weight2": torch.zeros((1024, 1024))
}
save_file(tensors, "model.safetensors")
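And loading the tensors back, using safe_open from the same library:
# Load each tensor by name without unpickling anything
loaded = {}
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    for key in f.keys():
        loaded[key] = f.get_tensor(key)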
The Future of Model Persistence: GGML and GGUF
The evolution of model persistence is not stopping at safetensors.
The GGML and GGUF formats represent the next frontier, particularly in bringing machine learning models to edge devices and personal computers.
These formats emphasize efficiency, backward compatibility, and the ability to run models locally, underscoring the shifting paradigm towards decentralized and accessible AI.
Practical Example: Running GGUF Models Locally
Consider the process of downloading and running a GGUF model on a local machine.
The simplicity of the process, involving downloading the model using Hugging Face's CLI and running it with tools like ollama, highlights the accessibility and practicality of modern AI technologies.
This ease of use is instrumental in democratizing AI, allowing a broader range of users to leverage powerful models without requiring extensive resources.
Steps to download the TheBloke/MistralLite-7B-GGUF model and run it on a Mac M2:
pip install huggingface-hub
huggingface-cli download \
TheBloke/MistralLite-7B-GGUF \
mistrallite.Q4_K_M.gguf \
--local-dir ~/Downloads \
--local-dir-use-symlinks False
Create a Modelfile for ollama with the following content:
FROM ~/Downloads/mistrallite.Q4_K_M.gguf
Build the model and run it locally.
ollama create mistrallite -f Modelfile
ollama run mistrallite "What is ollama?"
Conclusion
In conclusion, the process of migrating PyTorch models from training to a production environment, or reusing a model in a different environment, encompasses several critical steps.
Understanding these components and their roles in model migration ensures that practitioners can leverage PyTorch's capabilities to their fullest, enabling a smooth transition to production environments.
This knowledge base serves as a cornerstone for anyone looking to navigate the complexities of migrating deep learning models effectively.