PyTorch Model Saving with torch.save & torch.load

1. Introduction — Why Model Saving and Loading Matter

In the lifecycle of a machine learning project, model serialization — saving and loading trained models — is a critical operation. It allows you to:

  • Resume training from a checkpoint after interruption.
  • Deploy models for inference in production environments.
  • Share and distribute models across teams or hardware setups.
  • Reproduce experiments reliably.

Without robust save/load mechanisms, retraining from scratch after every modification would be a nightmare. PyTorch simplifies this process through two powerful utilities: torch.save() and torch.load().

In this guide, we’ll explore how these functions work, why saving state_dict is the preferred method, how to implement robust checkpointing, and how to avoid common pitfalls that can derail your workflow.


2. Deep Dive on torch.save()

Function Signature

torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

Purpose

torch.save() serializes Python objects (usually PyTorch models, tensors, or dictionaries) to disk using Python’s pickle module under the hood. The obj can be almost any Python object, but in practice, it’s typically one of:

  • A model’s entire object (model)
  • A model’s state dictionary (model.state_dict())
  • An optimizer’s state dictionary (optimizer.state_dict())
  • A checkpoint containing multiple items (e.g., model, optimizer, epoch)

Example: Saving a Model

import torch

# Suppose we have a model
model = MyModel()
torch.save(model.state_dict(), "model_state.pth")

This saves only the model's learnable parameters and registered buffers, not the model class definition itself.
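
To make the contents of a state_dict concrete, here is a minimal sketch that prints the entries of a small model. TinyNet is a hypothetical stand-in for MyModel, and the layer sizes are arbitrary.

```python
import torch
import torch.nn as nn

# TinyNet is a hypothetical stand-in for MyModel; sizes are arbitrary.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        self.bn = nn.BatchNorm1d(2)

    def forward(self, x):
        return self.bn(self.fc(x))

model = TinyNet()
state = model.state_dict()

# Each entry maps a dotted name (submodule.attribute) to a tensor.
for name, tensor in state.items():
    print(name, tuple(tensor.shape))

keys = list(state.keys())
```

The BatchNorm entries (running_mean, running_var, num_batches_tracked) are buffers rather than learnable parameters, yet they are saved alongside the weights.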

How torch.save() Works Internally

Under the hood, PyTorch uses Python’s pickle module to serialize objects into byte streams, which are then written to disk. Because of this, you should be careful when loading files from untrusted sources — pickle can execute arbitrary code.
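
Because the result is just a byte stream, the target f does not have to be a path on disk. The sketch below round-trips a small dictionary through an in-memory io.BytesIO buffer; the payload contents are made up for illustration.

```python
import io
import torch

# Illustrative payload: one tensor plus a plain Python int.
payload = {"weights": torch.arange(6.0).reshape(2, 3), "step": 3}

# torch.save accepts a path or any writable file-like object.
buffer = io.BytesIO()
torch.save(payload, buffer)
num_bytes = len(buffer.getvalue())

# Rewind before reading the same buffer back.
buffer.seek(0)
restored = torch.load(buffer)

print(num_bytes > 0, restored["step"])
```

The same pattern is handy when a checkpoint has to be sent over a network or stored in a database rather than written to a local file.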

What You Can Save

  • Entire model: Includes structure + parameters
  • State dictionary: Weights and buffers only
  • Optimizers: State for optimizers (e.g., momentum, learning rates)
  • Checkpoints: Dictionaries combining multiple items

3. Deep Dive on torch.load()

Function Signature

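torch.load(f, map_location=None, pickle_module=pickle, weights_only=False)

Note: weights_only=False was the default before PyTorch 2.6. From 2.6 onward, torch.load() defaults to weights_only=True, which restricts unpickling to tensors, primitive types, and other allowlisted classes.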

Purpose

torch.load() deserializes a file saved with torch.save() and reconstructs the Python object.

model = MyModel()
model.load_state_dict(torch.load("model_state.pth"))

Understanding map_location

The map_location parameter controls how tensors are mapped across devices during loading. This is vital when models are trained on GPUs but loaded on CPUs or different GPUs.

Examples:

  1. Load a GPU-saved model on the CPU:

torch.load("gpu_model.pth", map_location=torch.device("cpu"))

  2. Load onto a specific GPU:

torch.load("gpu_model.pth", map_location="cuda:1")

  3. Automatic mapping:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.load_state_dict(torch.load("model.pth", map_location=device))
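
The effect of map_location can be checked directly on the loaded tensor's device attribute. This is a small sketch that runs with or without a GPU; the file name tensor.pth and the temporary directory are only for illustration.

```python
import os
import tempfile
import torch

# Save from the GPU when one is present, otherwise from the CPU.
tensor = torch.ones(3)
if torch.cuda.is_available():
    tensor = tensor.cuda()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tensor.pth")  # illustrative file name
    torch.save(tensor, path)
    # map_location remaps every stored tensor to the requested device.
    loaded = torch.load(path, map_location="cpu")

print(loaded.device)
```

Whatever device the tensor was saved from, the loaded copy lives on the CPU.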

Security Considerations

Because torch.load() relies on pickle, never fully unpickle model files from untrusted sources: pickle can execute arbitrary code on your system. Passing weights_only=True restricts loading to tensors and other allowlisted types. When sharing models, prefer TorchScript or safetensors for secure serialization.
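
As one way to reduce the pickle risk when all you need is weights, the sketch below loads a state_dict with weights_only=True. The nn.Linear model and the file name are illustrative assumptions, not part of the original examples.

```python
import os
import tempfile
import torch
import torch.nn as nn

model = nn.Linear(3, 1)  # illustrative model

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "linear_state.pth")  # illustrative file name
    torch.save(model.state_dict(), path)
    # weights_only=True limits unpickling to tensors and allowlisted types,
    # so arbitrary Python objects embedded in the file are rejected.
    state = torch.load(path, map_location="cpu", weights_only=True)

fresh = nn.Linear(3, 1)
fresh.load_state_dict(state)
weights_match = torch.equal(fresh.weight, model.weight)
print(weights_match)
```

A plain state_dict loads cleanly this way because it contains only tensors keyed by strings.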


4. The Recommended Way — Save and Load state_dict

While torch.save(model, PATH) seems simpler, it has hidden drawbacks.

| Approach | Pros | Cons |
| --- | --- | --- |
| torch.save(model, PATH) | Saves everything | Can break if class code changes |
| torch.save(model.state_dict(), PATH) | Lightweight, flexible | Requires model class definition to load |

Best Practice

Save only the state_dict, then instantiate the model class and load the weights into it.

Example: Save and Load Using state_dict

# Save
torch.save(model.state_dict(), "model_state.pth")

# Load
model = MyModel()
model.load_state_dict(torch.load("model_state.pth"))
model.eval()

Why This Is Better

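  • You can move or rename the model class later without breaking old files, as long as parameter names and shapes stay the same.
  • The file contains only tensors and their names, so loading it does not depend on the original class or module path being importable.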
  • It’s the recommended approach by the official PyTorch documentation.

5. Checkpointing for Training — A Full Example

During long training runs, saving checkpoints helps recover progress after interruptions or crashes.

Here’s a reusable pattern:

import torch
import os

def save_checkpoint(state, filename="checkpoint.pth.tar"):
    torch.save(state, filename)
    print(f"Checkpoint saved at {filename}")

def load_checkpoint(model, optimizer, filename="checkpoint.pth.tar"):
    if os.path.isfile(filename):
        print("=> Loading checkpoint...")
        checkpoint = torch.load(filename, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        epoch = checkpoint["epoch"]
        loss = checkpoint["loss"]
        print(f"=> Loaded checkpoint (epoch {epoch})")
        return epoch, loss
    else:
        print("=> No checkpoint found.")
        return 0, None

During training:

for epoch in range(num_epochs):
    # train(...) is a placeholder for your own one-epoch training routine
    loss = train(...)
    
    # Save every 5 epochs
    if epoch % 5 == 0:
        checkpoint = {
            "epoch": epoch,
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "loss": loss
        }
        save_checkpoint(checkpoint)
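
The pattern above can be exercised end to end with a toy model. The following is a sketch under stated assumptions: run_training is a hypothetical helper, the data is random, and a checkpoint is written every epoch so the resume point is easy to see. It reuses the checkpoint keys from the functions above and restarts at the epoch after the last one saved.

```python
import os
import tempfile
import torch
import torch.nn as nn
import torch.optim as optim

def run_training(path, num_epochs):
    """Hypothetical helper: train a toy model, resuming from `path` if it exists."""
    torch.manual_seed(0)
    model = nn.Linear(4, 1)
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    start_epoch = 0
    if os.path.isfile(path):
        checkpoint = torch.load(path, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        # The stored epoch already finished, so continue with the next one.
        start_epoch = checkpoint["epoch"] + 1

    inputs, targets = torch.randn(16, 4), torch.randn(16, 1)  # random toy data
    epochs_run = []
    for epoch in range(start_epoch, num_epochs):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()
        epochs_run.append(epoch)
        torch.save({
            "epoch": epoch,
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "loss": loss.item(),
        }, path)
    return epochs_run

with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = os.path.join(tmp, "checkpoint.pth.tar")
    first_run = run_training(ckpt_path, num_epochs=3)   # runs epochs 0, 1, 2
    second_run = run_training(ckpt_path, num_epochs=5)  # resumes at epoch 3

print(first_run, second_run)
```

Storing loss.item() rather than the loss tensor keeps the checkpoint free of autograd history and device-specific data.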

6. Common Pitfalls (and How to Avoid Them)

| Problem | Cause | Solution |
| --- | --- | --- |
| RuntimeError: Error(s) in loading state_dict | Model architecture mismatch | Ensure identical model definition when saving/loading |
| Device mismatch (CUDA vs CPU) | Loading GPU tensors on CPU | Use map_location=torch.device("cpu") |
| Version mismatch (old vs new PyTorch) | Serialization format changes | Re-save models with the latest PyTorch |
| Pickle security risk | Loading untrusted files | Use safetensors or TorchScript |
| Missing keys | Layers renamed or modified | Align layer names or use strict=False in load_state_dict() |
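
The missing-keys case is easy to reproduce. In this sketch, OldNet and NewNet are hypothetical classes in which the newer version gained a layer; strict loading fails, while strict=False loads what matches and reports the rest.

```python
import torch.nn as nn

# Hypothetical "before" and "after" versions of the same architecture.
class OldNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

class NewNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        self.head = nn.Linear(2, 1)  # layer added after the checkpoint was saved

    def forward(self, x):
        return self.head(self.fc(x))

old_state = OldNet().state_dict()

# Default strict loading raises because head.* is absent from the old state.
try:
    NewNet().load_state_dict(old_state)
    strict_failed = False
except RuntimeError:
    strict_failed = True

# strict=False loads matching keys and returns a report of the mismatches.
result = NewNet().load_state_dict(old_state, strict=False)
print(result.missing_keys)     # ['head.weight', 'head.bias']
print(result.unexpected_keys)  # []
```

Inspect the returned report rather than ignoring it: any layer listed under missing_keys keeps its random initialization.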

7. Five Best Practices for Model Serialization

  1. Always save state_dict() instead of the entire model.
  2. Include optimizer states if you plan to resume training.
  3. Use descriptive filenames — include epoch or accuracy (model_epoch10_acc90.pth).
  4. Handle devices gracefully using map_location.
  5. Verify loads with a simple forward pass after loading.
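
The last practice, verifying a load with a forward pass, might look like the sketch below. The build_model helper, the sample batch, and the file name are illustrative assumptions.

```python
import os
import tempfile
import torch
import torch.nn as nn

def build_model():
    # Hypothetical architecture used only for this check.
    return nn.Sequential(nn.Linear(10, 8), nn.ReLU(), nn.Linear(8, 2))

torch.manual_seed(0)
model = build_model().eval()
sample = torch.randn(4, 10)

# Record the reference output before saving.
with torch.no_grad():
    expected = model(sample)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_epoch10.pth")  # illustrative file name
    torch.save(model.state_dict(), path)
    reloaded = build_model()
    reloaded.load_state_dict(torch.load(path, map_location="cpu"))

# The reloaded model should reproduce the reference output on the same input.
reloaded.eval()
with torch.no_grad():
    actual = reloaded(sample)

outputs_match = torch.allclose(expected, actual)
print(outputs_match)
```

Calling eval() on both models matters for architectures with dropout or batch normalization, where training-mode outputs would differ between runs.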

8. Full Working Example — End-to-End

import torch
import torch.nn as nn
import torch.optim as optim

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc = nn.Linear(10, 2)
    def forward(self, x):
        return self.fc(x)

# Initialize model and optimizer
model = SimpleNN()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train briefly (simulated)
loss = torch.tensor(0.5)

# Save checkpoint
torch.save({
    'epoch': 10,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, "checkpoint.pth")

# Load checkpoint
checkpoint = torch.load("checkpoint.pth", map_location=torch.device("cpu"))
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
print(f"Resumed from epoch {epoch}, loss {loss}")

9. Frequently Asked Questions (FAQ)

Q1: What is a state_dict in PyTorch?
A state_dict is a Python dictionary (an OrderedDict for models) that maps each parameter and buffer name to its tensor. Optimizers have their own state_dict as well, holding per-parameter state and the hyperparameter groups.
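
For the optimizer side, here is a short sketch that inspects an Adam state_dict before and after a single step. The model and the input batch are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(3, 1)  # arbitrary model
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Before any step, the per-parameter state is still empty.
before = optimizer.state_dict()

# One backward pass and step populates Adam's running averages.
model(torch.randn(5, 3)).sum().backward()
optimizer.step()
after = optimizer.state_dict()

print(sorted(after.keys()))              # ['param_groups', 'state']
print(sorted(after["state"][0].keys()))  # ['exp_avg', 'exp_avg_sq', 'step']
print(after["param_groups"][0]["lr"])    # 0.001
```

This is why a checkpoint saved without the optimizer state resumes with reset moment estimates.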

Q2: How do I load a model on CPU if it was saved on GPU?
Use:

torch.load("model_gpu.pth", map_location=torch.device("cpu"))

Q3: Is torch.save cross-platform?
Yes, as long as you use compatible PyTorch versions and avoid platform-specific pickle objects.

Q4: Can I save multiple models in one file?
Yes. Use a dictionary:

torch.save({'model1': model1.state_dict(), 'model2': model2.state_dict()}, "multi.pth")
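
Loading mirrors saving: read the dictionary once, then pass each entry to the matching model. A minimal sketch, assuming two small nn.Linear models and an illustrative file name:

```python
import os
import tempfile
import torch
import torch.nn as nn

# Two arbitrary models standing in for model1 and model2.
model1, model2 = nn.Linear(4, 2), nn.Linear(2, 1)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "multi.pth")
    torch.save({"model1": model1.state_dict(),
                "model2": model2.state_dict()}, path)
    bundle = torch.load(path, map_location="cpu")

# Recreate each architecture, then load its own entry from the bundle.
restored1, restored2 = nn.Linear(4, 2), nn.Linear(2, 1)
restored1.load_state_dict(bundle["model1"])
restored2.load_state_dict(bundle["model2"])

print(sorted(bundle.keys()))  # ['model1', 'model2']
```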

10. Conclusion

Model serialization is not just a convenience — it’s a cornerstone of reproducible, scalable, and deployable deep learning workflows.

By mastering torch.save() and torch.load(), you can:

  • Seamlessly save and resume training
  • Deploy models across CPUs, GPUs, and servers
  • Implement robust checkpointing systems
  • Avoid the most common pitfalls that plague PyTorch beginners

Stick with the state_dict approach, manage your devices carefully, and version your checkpoints intelligently — and you’ll never lose a trained model again.


🔑 Key Takeaways

  • torch.save() and torch.load() are the backbone of model persistence.
  • Prefer saving state_dict instead of the entire model.
  • Use map_location for smooth CPU/GPU transitions.
  • Always checkpoint models during long training.
  • Load responsibly — security matters.
