Making Sense of PyTorch’s to(device) and map_location.
In a project I am working on, I had to push a model trained on a GPU into production. The trouble started when I wrote an inference script for the model and tried to run it on a CPU-only machine to test it.
Habituated to the `.to(device)` method, I loaded the model using the same pattern:
model = torch.load(model_path).to(device)
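In full, the inference script looked roughly like this (the checkpoint path is a placeholder; on PyTorch >= 2.6, loading a full pickled model may also require `weights_only=False`):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Deserialize the checkpoint first, then move it to the target device.
model = torch.load("model.pth")
model = model.to(device)
model.eval()
```

This works as long as the device the weights were saved on is still available.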
However, it failed when I tried to load the model on the CPU-only machine. That's when I learned we must pass the `map_location` argument to `torch.load()` instead of relying on `.to(device)`:
model = torch.load(model_path, map_location=device)
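With `map_location`, the same checkpoint loads cleanly on a machine with no GPU at all. A minimal sketch, again with a placeholder path:

```python
import torch

device = torch.device("cpu")

# map_location tells torch.load where to place the tensors as they
# are deserialized, so no CUDA device is ever required.
model = torch.load("model.pth", map_location=device)
model.eval()

# Every parameter now lives on the CPU.
print(next(model.parameters()).device)  # cpu
```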
Why can't I keep using `to(device)`? What difference does `map_location` make? When can I safely use `to(device)`? When is `to(device)` more useful than `map_location`? Let's discuss the answers to these questions now.
The “technical” difference between the two.
Calling `torch.load()` without `map_location` first deserializes the tensors onto the device they were saved on, and only afterwards does `to(device)` move them to the specified GPU or CPU. It's a multi-step process, and the first step fails outright on a CPU-only machine because the saved CUDA tensors cannot be restored there. In contrast, `torch.load(model_path, map_location=device)` is a single-step process where the model parameters are deserialized directly onto the specified device. The sketch below, and the illustration that follows it, might help us understand what is happening at a high level.
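To make the single-step behavior concrete, `map_location` accepts several forms, all documented for `torch.load` (the path below is a placeholder):

```python
import torch

# A device object or string: restore every tensor onto that device.
ckpt = torch.load("model.pth", map_location=torch.device("cpu"))

# A dict: remap device tags, e.g. tensors saved on cuda:1 are
# restored onto cuda:0.
ckpt = torch.load("model.pth", map_location={"cuda:1": "cuda:0"})

# A callable taking (storage, location): returning the storage
# unchanged keeps everything on the CPU.
ckpt = torch.load("model.pth", map_location=lambda storage, loc: storage)
```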