Making Sense of PyTorch’s to(device) and map_location

Akhil Theerthala
2 min read · Aug 26, 2023


In a project I am working on, I had to push a model trained on a GPU into production. The problem started when I wrote an inference script for the model and tried to execute it on a CPU-only machine to test it.

Accustomed to the .to(device) method, I loaded the model the same way:

model = torch.load(model_path).to(device)

However, it failed when I tried to load the model on the CPU-only machine. That’s when I learned we must pass the map_location argument to torch.load() instead of relying on to(device):

model = torch.load(model_path, map_location=device)
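
For context, here is a minimal sketch of the working inference-side load ("model.pt" is a placeholder checkpoint path, not from the original project):

import torch

# Pick whatever device the current machine actually has.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location remaps the saved tensors onto `device` during loading,
# so this works on GPU and CPU-only machines alike. (On recent PyTorch
# versions, loading a whole pickled module may also require passing
# weights_only=False.)
model_path = "model.pt"  # placeholder checkpoint path
model = torch.load(model_path, map_location=device)
model.eval()  # switch to inference mode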

Why can’t I keep using to(device)? What difference does map_location make? When can I safely use to(device)? When is to(device) more useful than map_location? Let’s discuss the answers to these questions.

The “technical” difference between the two.

Without map_location, torch.load() first tries to restore the model onto the device it was saved from, and only afterwards does to(device) move it to the GPU or CPU you asked for. It’s a multi-step process. In contrast, torch.load(model_path, map_location=device) is a single-step process: the model parameters are deserialized directly onto the specified device. The following illustration might help us understand what is happening at a high level.

High-level understanding of map_location vs to(device). Image generated by the author.
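
In code, the two paths from the illustration look roughly like this (again with "model.pt" as a placeholder):

import torch

# Multi-step path: torch.load() first restores every tensor onto the
# device it was saved on (e.g. cuda:0), and only then does .to() copy
# the whole model over to the target device.
model = torch.load("model.pt")   # fails here if cuda:0 is unavailable
model = model.to("cpu")          # extra device-to-device copy

# Single-step path: map_location remaps the storages while they are
# being deserialized, so the parameters land directly on the target.
model = torch.load("model.pt", map_location="cpu")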

Why does this difference matter?

Imagine we removed the GPU from the picture above. Which of the two paths would fail? The torch.load(...).to("cpu") path would, because torch.load() tries to restore the parameters onto a GPU that no longer exists before .to("cpu") ever runs. And even when a suitable GPU is available, map_location spares you the cost of materializing the model on one device and then transferring it to another.
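
As a concrete sketch, this is the failure you would hit on a CPU-only machine, assuming a hypothetical checkpoint gpu_trained_model.pt saved from a GPU:

import torch

try:
    # On a CPU-only machine this raises a RuntimeError: the checkpoint's
    # tensors are tagged as CUDA storages, and torch.load() tries to
    # reconstruct them on a GPU that does not exist.
    model = torch.load("gpu_trained_model.pt")
except RuntimeError as err:
    print(f"Plain torch.load() failed: {err}")
    # Remapping the storages to the CPU at load time succeeds instead.
    model = torch.load("gpu_trained_model.pt", map_location="cpu")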

How can this difference be utilized?

There are multiple areas where we can make the most of this difference. For example, map_location lets us sidestep CUDA version or architecture mismatches by remapping tensors at load time, while .to(device) lets us dynamically choose which parts of a model are pushed to which device, as sketched below.
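
Here is a rough sketch of that second pattern; encoder and head are hypothetical submodule names, not part of any real API:

import torch

# Load everything onto the CPU first; map_location makes this safe
# no matter which device the checkpoint was saved from.
model = torch.load("model.pt", map_location="cpu")

# Then place individual submodules wherever we want them, e.g. a heavy
# encoder on the GPU and a lightweight head on the CPU.
if torch.cuda.is_available():
    model.encoder = model.encoder.to("cuda:0")
model.head = model.head.to("cpu")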

To summarize the major differences side by side: map_location is the better choice when you are loading a GPU-trained checkpoint on a CPU-only machine, when the saving and loading environments have mismatched CUDA versions or architectures, or when you want to avoid an extra device-to-device copy. .to(device) is the better choice once a model is already in memory and you want to move it, or individual submodules of it, between the devices you actually have.

To avoid these issues altogether, one of the best and most commonly recommended approaches is to move the model to the CPU before saving it with torch.save(): a system might not have a GPU or TPU, but it can’t function without a CPU.
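
A minimal sketch of that save-side habit, using a toy nn.Linear as a stand-in for a trained model:

import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # stand-in for a fully trained model

# Moving the parameters to the CPU before saving makes the checkpoint
# loadable on any machine, with or without a GPU.
model = model.to("cpu")
torch.save(model, "model_cpu.pt")

(For even more portability, PyTorch’s serialization docs generally recommend saving model.state_dict() rather than the whole pickled module.)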

I hope you enjoyed reading this article. Let’s meet again with more detailed articles. Till then, stay tuned! ✌️

P.S. In the meantime, you can check out my other articles, where I discuss RegEx, or learn more about the different phases involved in a machine learning project lifecycle.

