A Dig into PyTorch Model Loading

Machine-Learning PyTorch

Saving and loading PyTorch models

Models in PyTorch are subclasses of torch.nn.Module. To save a model's parameters, call model.state_dict(), which returns a dictionary mapping each parameter (and persistent buffer) name to its tensor:

state = model.state_dict()

Then save the model parameters using torch.save():

torch.save(state, 'my-model.pth')
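As a minimal sketch of the full save-and-restore round trip, the snippet below uses a small nn.Linear layer as a stand-in model and writes to a temporary directory. Both choices are assumptions made for illustration, not part of the original workflow.

```python
import os
import tempfile

import torch
import torch.nn as nn

# A tiny stand-in model; any nn.Module subclass is handled the same way.
model = nn.Linear(4, 2)

# state_dict() maps parameter names to tensors.
state = model.state_dict()
print(list(state.keys()))  # ['weight', 'bias']

# Write to a temporary directory so the example cleans up after itself.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "my-model.pth")
    torch.save(state, path)

    # Restore the parameters into a fresh model with the same architecture.
    restored = nn.Linear(4, 2)
    restored.load_state_dict(torch.load(path))

# The restored model carries exactly the saved weights.
same_weights = torch.equal(model.weight, restored.weight)
print(same_weights)
```

Note that only the parameters are stored, so the code that defines the model class must still be available when loading.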

Loading error when using torch.load to load model trained on GPU

When we use torch.load(model_path) to load a model that was trained on a GPU on a machine that has no GPU, we often get the following error:

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

The cause

This happens because when we use torch.save() to save an object, torch also stores the original location of the data (called the location tag in the serialization code). torch.save() also keeps the view relationships between tensors unchanged.
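The view-preserving behavior can be checked with a small experiment. This is a sketch that saves to an in-memory buffer instead of a file; the tensor names are made up for illustration.

```python
import io

import torch

base = torch.arange(6, dtype=torch.float32)
view = base[:3]  # a view that shares storage with `base`

# Save both tensors together so their shared storage is serialized once.
buffer = io.BytesIO()
torch.save([base, view], buffer)
buffer.seek(0)

loaded_base, loaded_view = torch.load(buffer)

# Writing through the loaded view is visible in the loaded base tensor,
# which shows that the view relationship survived the round trip.
loaded_view[0] = 100.0
still_shared = loaded_base[0].item() == 100.0
print(still_shared)
```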

Judging from the serialization code, it seems that PyTorch copies the data of a GPU tensor to the CPU before writing it to the file.

When we call torch.load(), since the tensor location has been recorded, torch first loads the tensor to the CPU and then moves it to the GPU indicated by the location tag. If that GPU is missing, or the machine has no GPU at all, the error above occurs.
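One way to observe the recorded location tag is to pass a callable as map_location: torch.load() calls it with each deserialized storage and its tag. The sketch below saves a CPU tensor to an in-memory buffer, so the tag seen here is the CPU one; the helper name record_and_keep_on_cpu is hypothetical.

```python
import io

import torch

buffer = io.BytesIO()
torch.save(torch.ones(3), buffer)
buffer.seek(0)

seen_tags = []

def record_and_keep_on_cpu(storage, location):
    # `location` is the tag written at save time, such as 'cpu' or 'cuda:0'.
    seen_tags.append(location)
    # The storage is already on the CPU at this point; returning it
    # unchanged skips the move to the original device.
    return storage

tensor = torch.load(buffer, map_location=record_and_keep_on_cpu)
print(seen_tags, tensor.device)
```

For a checkpoint saved from a GPU, the same callable would receive a CUDA tag instead, and returning the storage unchanged keeps the tensor on the CPU.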

Load the model correctly

A better way to load a model is to map its parameters to the CPU using the map_location parameter of torch.load(). Load the parameters to the CPU, then load them into the model, and finally move the model to a GPU:

# load the model parameters onto the CPU
state = torch.load('my-model.pth', map_location=torch.device('cpu'))
model.load_state_dict(state)

# now move the model to a GPU device of your choice
model.to(torch.device('cuda:0'))
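The same steps can be wrapped so that they also work on a CPU-only machine. The helper below, load_state_portably, is a hypothetical name used for this sketch, and the nn.Linear model and in-memory buffer are stand-ins for a real model and checkpoint file.

```python
import io

import torch
import torch.nn as nn

def load_state_portably(model, source):
    """Hypothetical helper: load weights on the CPU, then pick a device."""
    # Always deserialize to the CPU, whatever device the checkpoint came from.
    state = torch.load(source, map_location=torch.device("cpu"))
    model.load_state_dict(state)
    # Fall back to the CPU when no GPU is available.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    return model.to(device), device

# Usage: save a stand-in model to memory, then load it into a fresh one.
buffer = io.BytesIO()
torch.save(nn.Linear(4, 2).state_dict(), buffer)
buffer.seek(0)

model, device = load_state_portably(nn.Linear(4, 2), buffer)
print(device)
```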


