
A Dig into PyTorch Model Loading


Saving and loading PyTorch models

PyTorch models are subclasses of torch.nn.Module. To save a model's parameters, we call model.state_dict(), which returns a dictionary mapping each parameter name to its tensor:

state = model.state_dict()

Then save the parameters to disk using torch.save():

torch.save(state, 'my-model.pth')
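The two steps above can be put together into a small round trip. This is a minimal sketch, assuming a hypothetical toy model (a single nn.Linear layer) and a temporary directory for the file; neither comes from the original post.

```python
import os
import tempfile

import torch
from torch import nn

# Hypothetical toy model, used only for illustration.
model = nn.Linear(4, 2)

# state_dict() maps parameter names to tensors.
state = model.state_dict()
param_shapes = {name: tuple(t.shape) for name, t in state.items()}
print(param_shapes)

# Save into a temporary directory (assumed location, any path works).
save_path = os.path.join(tempfile.mkdtemp(), "my-model.pth")
torch.save(state, save_path)

# Reload the file and check that every tensor round-trips unchanged.
reloaded = torch.load(save_path)
round_trip_ok = all(torch.equal(state[k], reloaded[k]) for k in state)
print(round_trip_ok)
```

Saving only the state dict, rather than the whole model object, keeps the file independent of the Python class definition: to use it, you rebuild the model in code and then load the parameters into it.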

Error when using torch.load to load a model trained on GPU

When we call torch.load(model_path) on a machine with no GPU to load a model that was trained and saved on a GPU, we often get the following error:

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

The cause

This happens because torch.save() stores, along with each tensor's data, the location of the original data (called the location tag in the serialization code), for example cpu or cuda:0. torch.save() also keeps the view relationships between tensors unchanged.
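The view relationship mentioned above can be checked directly. The following is a small sketch, assuming two hypothetical tensors `base` and `view` and an in-memory buffer instead of a file: if views are preserved, writing through the loaded view should also change the loaded base tensor.

```python
import io

import torch

base = torch.arange(6, dtype=torch.float32)
view = base[:3]  # shares storage with `base`

# In-memory file, so nothing is written to disk.
buffer = io.BytesIO()
torch.save([base, view], buffer)
buffer.seek(0)
loaded_base, loaded_view = torch.load(buffer)

# Write through the loaded view and check whether the loaded base sees it.
loaded_view[0] = 100.0
views_preserved = loaded_base[0].item() == 100.0
print(views_preserved)
```

Note that the two tensors have to be saved in the same torch.save() call for the relationship to be kept; saving them to separate files would produce two independent copies.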

Judging from the serialization code, it seems that PyTorch copies a GPU tensor's data to the CPU before writing it to disk.

Because the tensor location was recorded at save time, torch.load() first loads the tensor to the CPU and then moves it to the GPU indicated by the location tag. If that GPU is missing, or the machine has no GPU at all, the error above occurs.

Load the model correctly

A more robust way to load a model is to map it to the CPU with the map_location parameter of torch.load(). First load the saved parameters to the CPU, then load them into the model, and finally move the model to the GPU:

# load the saved parameters onto the CPU, wherever they were saved from
state = torch.load('my-model.pth', map_location=torch.device('cpu'))

# copy the parameters into an already constructed model of the same architecture
model.load_state_dict(state)

# now move the model to a GPU device of your choice
model.to(torch.device('cuda:0'))
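The snippet above still assumes that cuda:0 exists. A variant that also runs on CPU-only machines might look like the sketch below. The toy nn.Linear model and the in-memory checkpoint are hypothetical stand-ins for a real architecture and a real .pth file; the fallback to CPU is an addition, not part of the original recipe.

```python
import io

import torch
from torch import nn

# Hypothetical toy model standing in for the real architecture.
model = nn.Linear(4, 2)

# Stand-in for a checkpoint file: an in-memory buffer keeps the sketch self-contained.
checkpoint = io.BytesIO()
torch.save(model.state_dict(), checkpoint)
checkpoint.seek(0)

# Use the GPU only when one is actually available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Always deserialize onto the CPU first.
state = torch.load(checkpoint, map_location=torch.device("cpu"))
loaded_on_cpu = all(t.device.type == "cpu" for t in state.values())

# Load the parameters into a freshly built model, then move it to the chosen device.
fresh_model = nn.Linear(4, 2)
fresh_model.load_state_dict(state)
fresh_model.to(device)
print(loaded_on_cpu, next(fresh_model.parameters()).device)
```

Choosing the device at runtime means the same loading code works on a training server with GPUs and on a laptop without one.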
