1. Introduction — Why Model Saving and Loading Matter
In the lifecycle of a machine learning project, model serialization — saving and loading trained models — is a critical operation. It allows you to:
- Resume training from a checkpoint after interruption.
- Deploy models for inference in production environments.
- Share and distribute models across teams or hardware setups.
- Reproduce experiments reliably.
Without robust save/load mechanisms, retraining from scratch after every modification would be a nightmare. PyTorch simplifies this process through two powerful utilities: `torch.save()` and `torch.load()`.
In this guide, we’ll explore how these functions work, why saving the `state_dict` is the preferred method, how to implement robust checkpointing, and how to avoid common pitfalls that can derail your workflow.
2. Deep Dive on torch.save()
Function Signature

```python
torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)
```
Purpose
`torch.save()` serializes Python objects (usually PyTorch models, tensors, or dictionaries) to disk using Python’s pickle module under the hood. The `obj` can be almost any Python object, but in practice it’s typically one of:
- A model’s entire object (`model`)
- A model’s state dictionary (`model.state_dict()`)
- An optimizer’s state dictionary (`optimizer.state_dict()`)
- A checkpoint dictionary containing multiple items (e.g., model, optimizer, epoch)
Example: Saving a Model
```python
import torch

# Suppose we have a model
model = MyModel()
torch.save(model.state_dict(), "model_state.pth")
```
This saves only the model’s parameters — not the model class definition itself.
How `torch.save()` Works Internally
Under the hood, PyTorch uses Python’s pickle module to serialize objects into byte streams, which are then written to disk. Because of this, you should be careful when loading files from untrusted sources — pickle can execute arbitrary code.
What You Can Save
- Entire model: Includes structure + parameters
- State dictionary: Weights and buffers only
- Optimizers: State for optimizers (e.g., momentum, learning rates)
- Checkpoints: Dictionaries combining multiple items
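To make these options concrete, here is a minimal sketch (assuming a `MyModel` class is defined elsewhere) that saves each kind of object:

```python
import torch

model = MyModel()  # assumed model class
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

torch.save(model, "entire_model.pth")            # entire model: structure + parameters
torch.save(model.state_dict(), "state.pth")      # state dictionary: weights and buffers
torch.save(optimizer.state_dict(), "optim.pth")  # optimizer state
torch.save({                                     # checkpoint combining multiple items
    "epoch": 3,
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
}, "checkpoint.pth")
```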
3. Deep Dive on torch.load()
Function Signature

```python
torch.load(f, map_location=None, pickle_module=pickle, weights_only=False)
```
Purpose
`torch.load()` deserializes a file saved with `torch.save()` and reconstructs the Python object.

```python
model = MyModel()
model.load_state_dict(torch.load("model_state.pth"))
```
Understanding `map_location`
The `map_location` parameter controls how tensors are mapped across devices during loading. This is vital when models are trained on GPUs but loaded on CPUs or different GPUs.
Examples:
- Load a GPU-trained model on CPU:

```python
torch.load("gpu_model.pth", map_location=torch.device("cpu"))
```

- Load onto a specific GPU:

```python
torch.load("gpu_model.pth", map_location="cuda:1")
```

- Automatic mapping based on availability:

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load("model.pth", map_location=device))
```
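Beyond these forms, `map_location` also accepts a dict or a callable; both are part of the documented API. A brief sketch:

```python
# Dict form: remap every tensor stored on cuda:0 to cuda:1
state = torch.load("gpu_model.pth", map_location={"cuda:0": "cuda:1"})

# Callable form: receives (storage, location) and returns the storage to use;
# returning it unchanged keeps everything on the CPU
state = torch.load("gpu_model.pth", map_location=lambda storage, loc: storage)
```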
Security Considerations
Because `torch.load()` relies on pickle, never load models from untrusted sources: pickle can execute arbitrary code on your system. Passing `weights_only=True` restricts unpickling to tensors and other primitive types (and is the default as of PyTorch 2.6). When sharing models, prefer TorchScript or `safetensors` for secure serialization, as sketched below.
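Here is a minimal sketch of the `safetensors` route, assuming the package is installed (`pip install safetensors`):

```python
from safetensors.torch import save_file, load_file

# Save and load a state_dict without pickle; only raw tensors are (de)serialized
save_file(model.state_dict(), "model.safetensors")
state_dict = load_file("model.safetensors")
model.load_state_dict(state_dict)
```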
4. The Recommended Way — Save and Load `state_dict`
While `torch.save(model, PATH)` seems simpler, it has hidden drawbacks.
| Approach | Pros | Cons |
|---|---|---|
| `torch.save(model, PATH)` | Saves everything | Can break if class code changes |
| `torch.save(model.state_dict(), PATH)` | Lightweight, flexible | Requires model class definition to load |
Best Practice
Always save only the `state_dict` and recreate the model class when loading.
Example: Save and Load Using `state_dict`

```python
# Save
torch.save(model.state_dict(), "model_state.pth")

# Load
model = MyModel()
model.load_state_dict(torch.load("model_state.pth"))
model.eval()  # switch dropout/batch-norm layers to inference behavior
```
Why This Is Better
- You can refactor your model code later without breaking compatibility, as the sketch below illustrates.
- The files are smaller and faster to load.
- It’s the approach recommended by the official PyTorch documentation.
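A minimal illustration of the first point, using a hypothetical `TinyNet`: the `state_dict` records only parameter names and tensors, so any class exposing the same names and shapes can load it, even after the forward logic changes:

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

torch.save(TinyNet().state_dict(), "tiny.pth")

# Later: a refactored class with different internals but the same parameter names
class TinyNetV2(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.fc(x))  # forward changed; weights still match

model = TinyNetV2()
model.load_state_dict(torch.load("tiny.pth"))  # loads cleanly: keys and shapes align
```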
5. Checkpointing for Training — A Full Example
During long training runs, saving checkpoints helps recover progress after interruptions or crashes.
Here’s a reusable pattern:
```python
import torch
import os

def save_checkpoint(state, filename="checkpoint.pth.tar"):
    torch.save(state, filename)
    print(f"Checkpoint saved at {filename}")

def load_checkpoint(model, optimizer, filename="checkpoint.pth.tar"):
    if os.path.isfile(filename):
        print("=> Loading checkpoint...")
        checkpoint = torch.load(filename, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        epoch = checkpoint["epoch"]
        loss = checkpoint["loss"]
        print(f"=> Loaded checkpoint (epoch {epoch})")
        return epoch, loss
    else:
        print("=> No checkpoint found.")
        return 0, None
```
During training:

```python
for epoch in range(num_epochs):
    # training loop
    loss = train(...)

    # Save every 5 epochs
    if epoch % 5 == 0:
        checkpoint = {
            "epoch": epoch,
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "loss": loss,
        }
        save_checkpoint(checkpoint)
```
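To resume after a crash, call `load_checkpoint` before the loop and continue from the returned epoch (a sketch reusing the functions above):

```python
# Restores weights and optimizer state; returns (0, None) if no checkpoint exists
start_epoch, last_loss = load_checkpoint(model, optimizer)

for epoch in range(start_epoch, num_epochs):
    loss = train(...)  # same placeholder training step as above
```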
6. Common Pitfalls (and How to Avoid Them)
| Problem | Cause | Solution |
|---|---|---|
| `RuntimeError: Error(s) in loading state_dict` | Model architecture mismatch | Ensure identical model definition when saving/loading |
| Device mismatch (CUDA vs CPU) | Loading GPU tensors on CPU | Use `map_location=torch.device("cpu")` |
| Version mismatch (old vs new PyTorch) | Serialization format changes | Re-save models with the latest PyTorch |
| Pickle security risk | Loading untrusted files | Use `safetensors` or TorchScript |
| Missing keys | Layers renamed or modified | Align layer names or use `strict=False` in `load_state_dict()` |
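For the last row, `load_state_dict(strict=False)` returns the keys it could not match, which helps diagnose partial loads:

```python
result = model.load_state_dict(torch.load("model_state.pth"), strict=False)

# A named tuple listing keys the model expects but the file lacks, and vice versa
print("Missing keys:   ", result.missing_keys)
print("Unexpected keys:", result.unexpected_keys)
```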
7. Five Best Practices for Model Serialization
- Always save the `state_dict()` instead of the entire model.
- Include optimizer states if you plan to resume training.
- Use descriptive filenames — include epoch or accuracy (`model_epoch10_acc90.pth`).
- Handle devices gracefully using `map_location`.
- Verify loads with a simple forward pass after loading (see the sketch below).
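A quick way to apply that last practice, assuming a model that takes 10-dimensional inputs (like the `SimpleNN` defined in the next section):

```python
model.load_state_dict(torch.load("model_state.pth", map_location="cpu"))
model.eval()

# Sanity check: one forward pass on dummy data should run without errors
with torch.no_grad():
    dummy = torch.randn(1, 10)  # batch of one, matching the expected input size
    output = model(dummy)
print("Output shape:", output.shape)
```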
8. Full Working Example — End-to-End
```python
import torch
import torch.nn as nn
import torch.optim as optim

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

# Initialize model and optimizer
model = SimpleNN()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train briefly (simulated)
loss = torch.tensor(0.5)

# Save checkpoint
torch.save({
    'epoch': 10,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, "checkpoint.pth")

# Load checkpoint
checkpoint = torch.load("checkpoint.pth", map_location=torch.device("cpu"))
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

print(f"Resumed from epoch {epoch}, loss {loss}")
```
9. Frequently Asked Questions (FAQ)
Q1: What is a `state_dict` in PyTorch?
A `state_dict` is a Python dictionary mapping each parameter and buffer name to its tensor. Both models and optimizers have their own state_dicts.
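You can inspect one directly; with the `SimpleNN` model from section 8, for example:

```python
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
# fc.weight (2, 10)
# fc.bias (2,)
```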
Q2: How do I load a model on CPU if it was saved on GPU?
Use:

```python
torch.load("model_gpu.pth", map_location=torch.device("cpu"))
```
Q3: Is `torch.save` cross-platform?
Yes, as long as you use compatible PyTorch versions and avoid platform-specific pickle objects.
Q4: Can I save multiple models in one file?
Yes. Use a dictionary:

```python
torch.save({'model1': model1.state_dict(), 'model2': model2.state_dict()}, "multi.pth")
```
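Loading them back follows the same pattern (assuming `model1` and `model2` have been re-instantiated):

```python
checkpoint = torch.load("multi.pth", map_location="cpu")
model1.load_state_dict(checkpoint['model1'])
model2.load_state_dict(checkpoint['model2'])
```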
10. Conclusion
Model serialization is not just a convenience — it’s a cornerstone of reproducible, scalable, and deployable deep learning workflows.
By mastering `torch.save()` and `torch.load()`, you can:
- Seamlessly save and resume training
- Deploy models across CPUs, GPUs, and servers
- Implement robust checkpointing systems
- Avoid the most common pitfalls that plague PyTorch beginners
Stick with the `state_dict` approach, manage your devices carefully, and version your checkpoints intelligently — and you’ll never lose a trained model again.
🔑 Key Takeaways
- `torch.save()` and `torch.load()` are the backbone of model persistence.
- Prefer saving the `state_dict` instead of the entire model.
- Use `map_location` for smooth CPU/GPU transitions.
- Always checkpoint models during long training.
- Load responsibly — security matters.