
Fix a load_params bug when loading a CUDA model to CPU. #358

Merged: 4 commits merged into master on Oct 23, 2018

Conversation

benjamin-work (Contributor):

This bug occurred when a model was trained on GPU and saved using save_params, then loaded using load_params with device='cpu' on a machine without a CUDA device.

Fixes #354.
Should supersede #356.

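To illustrate the failure mode, here is a minimal reproduction sketch (the checkpoint path and module are made up for the example; the first branch has to run on a machine with a CUDA device, the second on a CPU-only one):

    import os
    import torch

    # Hypothetical reproduction of the bug; path and module are made up.
    CKPT = 'net_cuda.pt'

    if torch.cuda.is_available():
        # Step 1: save a state dict whose storages live on the GPU.
        module = torch.nn.Linear(20, 2).to('cuda')
        torch.save(module.state_dict(), CKPT)
    elif os.path.exists(CKPT):
        # Step 2: on a CPU-only machine, torch.load(CKPT) raises
        #   RuntimeError: Attempting to deserialize object on a CUDA device ...
        # whereas mapping the storages to the CPU works:
        state_dict = torch.load(CKPT, map_location=torch.device('cpu'))
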
@thomasjpfan (Member) left a comment:

The use of early exits makes the code more understandable. There seems to be no avoiding adding skorch/tests/net_cuda.pt to test loading a CUDA model onto a CPU.

skorch/net.py Outdated
else:
# use CPU
if not self.device.startswith('cuda'):
model = torch.load(f, map_location=lambda storage, loc: storage)
Member:

PyTorch supports map_location=torch.device('cpu'), which may make the intent of loading onto the CPU clearer.

benjamin-work (Author):

Yes, it makes sense to be more explicit.
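For reference, a quick self-contained comparison of the two forms (the buffer and module are made up for the example):

    import io
    import torch

    # Comparison of the two map_location styles discussed above.
    buf = io.BytesIO()
    torch.save(torch.nn.Linear(2, 2).state_dict(), buf)

    buf.seek(0)
    # Lambda form: keep each storage wherever it was deserialized (CPU here).
    state_a = torch.load(buf, map_location=lambda storage, loc: storage)

    buf.seek(0)
    # Explicit form suggested in the review: the intent (load onto the CPU)
    # is visible at the call site.
    state_b = torch.load(buf, map_location=torch.device('cpu'))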

skorch/net.py Outdated
"available. Loading on CPU instead.",
ResourceWarning)
self.device = 'cpu'
model = torch.load(f, lambda storage, loc: storage)
Member:

PyTorch supports map_location=torch.device('cpu'), which may make the intent of loading onto the CPU clearer.

benjamin-work (Author):

Same as above.

def test_load_cuda_params_to_cpu(self, net_cls, module_cls, data):
# Note: This test will pass trivially when CUDA is available
# but triggered a bug when CUDA is not available.
net = net_cls(module_cls).initialize()
Member:

I am unable to see where the bug is in this test. Can you expand on this point?

benjamin-work (Author):

With the current implementation, when you want to load a model trained with CUDA onto the CPU, you get:

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location='cpu' to map your storages to the CPU.

This is the bug from the initial issue. The reason is that currently, we only explicitly load to CPU when the user indicates they want to use CUDA but no CUDA device is detected. The more trivial case, where the user indicates that they want to load to CPU, is not covered.
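Schematically, the fixed control flow looks like this (a sketch only, not skorch's actual implementation):

    import warnings
    import torch

    # Schematic sketch of the control flow described above;
    # `f` is a file path or file-like object.
    def load_model(f, device):
        if device.startswith('cuda') and not torch.cuda.is_available():
            warnings.warn("Model configured to use CUDA but no CUDA devices "
                          "available. Loading on CPU instead.", ResourceWarning)
            device = 'cpu'
        if not device.startswith('cuda'):
            # Covers both the fallback above and the case the old code missed:
            # the user explicitly passed device='cpu'.
            return torch.load(f, map_location=torch.device('cpu'))
        return torch.load(f)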

@benjamin-work (Author):

I will wait for #360 before merging this; it's probably easier that way than the other way around.

@thomasjpfan (Member) left a comment:

Minor suggestion to avoid a DeprecationWarning. Otherwise LGTM.

X, y = data
net = net_cls(module_cls, device='cuda', max_epochs=1).fit(X, y)
p = tmpdir.mkdir('skorch').join('testmodel.pkl')
net.save_params(str(p))
Member:

Suggested change
net.save_params(str(p))
net.save_params(f_params=str(p))

def test_load_cuda_params_to_cuda(self, net_cls, module_cls, data):
net = net_cls(module_cls, device='cuda').initialize()
# net_cuda.pt is a net trained on CUDA
net.load_params(os.path.join('skorch', 'tests', 'net_cuda.pt'))
Member:

Suggested change
net.load_params(os.path.join('skorch', 'tests', 'net_cuda.pt'))
net.load_params(f_params=os.path.join('skorch', 'tests', 'net_cuda.pt'))

net.save_params(str(p))

net2 = net_cls(module_cls, device='cpu').initialize()
net2.load_params(str(p))
Member:

Suggested change
net2.load_params(str(p))
net2.load_params(f_params=str(p))

@ottonemo (Member) left a comment:

LGTM. I think the things I addressed should (if at all) be done in a new PR.

map_location = torch.device('cpu')

return torch.load(f, map_location=map_location)

Member:

Two things:

  1. The warning should include that the self.device parameter is now set to 'cpu'.
  2. I think we should refactor _get_state_dict into _get_map_location so it can be used in __setstate__ as well, since the code there is doing basically the same thing (a rough sketch follows below).
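For concreteness, one possible shape of such a helper (a sketch only, with an assumed signature; not the code merged in this PR):

    import warnings
    import torch

    # Sketch of the suggested _get_map_location helper (signature assumed);
    # load_params and __setstate__ could both call it.
    def _get_map_location(device):
        map_location = torch.device(device)
        if device.startswith('cuda') and not torch.cuda.is_available():
            warnings.warn(
                "Model configured to use CUDA but no CUDA devices available. "
                "Loading on CPU instead.", ResourceWarning)
            map_location = torch.device('cpu')
        return map_location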

@ottonemo merged commit 3ae8120 into master on Oct 23, 2018
@ottonemo deleted the bugfix/load-save-cuda branch on October 25, 2018 07:01