MAINT: make sure copy links weights and optimizer #318

Merged: 4 commits, Aug 23, 2018
Changes from 3 commits
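For context, issue #317 (linked from the new test below) is rooted in pickle semantics: object identity is only preserved within a single dump. Saving the module and the optimizer in separate torch.save calls therefore severs the link between the module's parameters and the tensors in the optimizer's param_groups, which is what this PR fixes by batching all CUDA-dependent attributes into one save. A minimal sketch of the underlying behavior, assuming only PyTorch (the roundtrip helper is illustrative, and on PyTorch 2.6+ torch.load may additionally need weights_only=False):

    import io

    import torch
    from torch import nn

    def roundtrip(obj):
        # Serialize and deserialize in memory, the way pickling a net does.
        buf = io.BytesIO()
        torch.save(obj, buf)
        buf.seek(0)
        return torch.load(buf)

    module = nn.Linear(2, 1)
    optimizer = torch.optim.SGD(module.parameters(), lr=0.1)

    # Two separate dumps: the copies no longer share parameter tensors, so
    # stepping the copied optimizer would never train the copied module.
    m2, o2 = roundtrip(module), roundtrip(optimizer)
    assert m2.weight is not o2.param_groups[0]['params'][0]

    # One dump holding both objects: pickle's memo keeps the sharing intact.
    pair = roundtrip({'module': module, 'optimizer': optimizer})
    assert pair['module'].weight is pair['optimizer'].param_groups[0]['params'][0]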
skorch/net.py — 44 changes: 24 additions & 20 deletions
@@ -1307,13 +1307,15 @@ def _replace_callback(self, name, new_val):
 
     def __getstate__(self):
         state = self.__dict__.copy()
+        cuda_attrs = {}
         for key in self.cuda_dependent_attributes_:
             if key in state:
                 val = state.pop(key)
-                with tempfile.SpooledTemporaryFile() as f:
-                    torch.save(val, f)
-                    f.seek(0)
-                    state[key] = f.read()
+                cuda_attrs[key] = val
+        with tempfile.SpooledTemporaryFile() as f:
+            torch.save(cuda_attrs, f)
+            f.seek(0)
+            state['cuda_dependent_attributes_'] = f.read()
 
         return state
 
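For readers less familiar with the machinery being overridden above: pickle calls __getstate__ to obtain a picklable dict and __setstate__ to restore it, and copy.deepcopy falls back to the same pair of hooks when they are defined, which is why the test added below exercises both copy methods. A generic sketch of the protocol (the Box class is illustrative, not skorch code):

    import copy
    import pickle

    class Box:
        def __init__(self):
            # Stand-in for an attribute that needs special treatment, as the
            # CUDA-dependent attributes do in NeuralNet.
            self.payload = bytearray(b"raw")

        def __getstate__(self):
            # Substitute a serialization-friendly form.
            state = self.__dict__.copy()
            state["payload"] = bytes(state["payload"])
            return state

        def __setstate__(self, state):
            # Undo the substitution before restoring the instance dict.
            state["payload"] = bytearray(state["payload"])
            self.__dict__.update(state)

    for clone in (pickle.loads(pickle.dumps(Box())), copy.deepcopy(Box())):
        assert clone.payload == bytearray(b"raw")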
@@ -1323,24 +1325,26 @@ def uses_cuda(device):
                 device = device.type
             return device.startswith('cuda')
 
-        disable_cuda = False
+        disable_cuda = uses_cuda(state['device']) and not torch.cuda.is_available()
+        load_kwargs = {} if not disable_cuda else {'map_location': lambda store, loc: store}
[Review comment, Contributor] This line is too long, could you please break it?
+        with tempfile.SpooledTemporaryFile() as f:
+            f.write(state['cuda_dependent_attributes_'])
+            f.seek(0)
+            cuda_attrs = torch.load(f, **load_kwargs)
+
+        set_cuda_attrs = {}
+        state.update(cuda_attrs)
         for key in self.cuda_dependent_attributes_:
-            if key not in state:
+            if key not in cuda_attrs:
                 continue
-            dump = state.pop(key)
-            with tempfile.SpooledTemporaryFile() as f:
-                f.write(dump)
-                f.seek(0)
-                if (
-                        uses_cuda(state['device']) and
-                        not torch.cuda.is_available()
-                ):
-                    disable_cuda = True
-                    val = torch.load(
-                        f, map_location=lambda storage, loc: storage)
-                else:
-                    val = torch.load(f)
-                state[key] = val
+            set_cuda_attrs[key] = state.pop(key)
+        with tempfile.SpooledTemporaryFile() as f:
+            torch.save(cuda_attrs, f)
+            f.seek(0)
+            cuda_attrs = torch.load(f, **load_kwargs)
 
+        state.update(cuda_attrs)
         if disable_cuda:
             warnings.warn(
                 "Model configured to use CUDA but no CUDA devices "
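The load_kwargs fallback above uses the standard torch.load idiom for machines without a GPU: a map_location that returns each storage unchanged keeps it on the CPU it was deserialized to, instead of trying to move it back to a CUDA device that is not there. A standalone sketch, assuming only PyTorch:

    import io

    import torch

    buf = io.BytesIO()
    torch.save(torch.zeros(3), buf)  # imagine this was written on a CUDA box
    buf.seek(0)

    # Returning the storage unchanged pins it to CPU, so state pickled with
    # device='cuda' still loads on a CPU-only machine.
    tensor = torch.load(buf, map_location=lambda storage, loc: storage)
    assert tensor.device.type == 'cpu'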
skorch/tests/test_net.py — 46 changes: 46 additions & 0 deletions
@@ -10,6 +10,7 @@
 from unittest.mock import Mock
 from unittest.mock import patch
 from pathlib import Path
+import copy
 
 import numpy as np
 import pytest
@@ -125,6 +126,51 @@ def net_pickleable(self, net_fit):
         net_fit.callbacks_ = callbacks_
         return net_clone
 
+    @pytest.mark.parametrize("copy_method", ["pickle", "copy.deepcopy"])
+    def test_train_net_after_copy(self, net_cls, module_cls, data,
+                                  copy_method):
+        # This test comes from [issue #317], and makes sure that models
+        # can be trained after copying (which is really pickling).
+        #
+        # [issue #317]: https://github.com/dnouri/skorch/issues/317
+        X, y = data
+        n1 = net_cls(module_cls)
+        n1.partial_fit(X, y, epochs=1)
+        if copy_method == "copy.deepcopy":
+            n2 = copy.deepcopy(n1)
+        elif copy_method == "pickle":
+            n2 = pickle.loads(pickle.dumps(n1))
+        else:
+            raise ValueError
+
+        # Test to make sure the parameters got copied correctly
+        close = [torch.allclose(p1, p2)
+                 for p1, p2 in zip(n1.module_.parameters(),
+                                   n2.module_.parameters())]
+        assert all(close)
+
+        # make sure the parameters change
+        # cal this twice to make sure history updates below
[Review comment, Contributor] Typo "cal"; also, why are 2 calls needed? When I comment one out, the test still passes.

[Reply, Contributor (author)] Thanks for this. I can also pass epochs=2 to partial_fit, and I explained more why in the comments.

+        n2.partial_fit(X, y, epochs=1)
+        n2.partial_fit(X, y, epochs=1)
+        far = [not torch.allclose(p1, p2)
+               for p1, p2 in zip(n1.module_.parameters(),
+                                 n2.module_.parameters())]
+        assert all(far)
+
+        # Make sure the model is being trained, and the loss actually changes
+        # (and hopefully decreases, but no test for that)
+        # If copied incorrectly, the optimizer can't see the gradients
+        # calculated by loss.backward(), so the loss stays *exactly* the same
+        assert n2.history[-1]['train_loss'] != n2.history[-2]['train_loss']
+
+        # Make sure the optimizer params and module params point to the same
+        # memory
+        for opt_param, param in zip(
+                n2.module_.parameters(),
+                n2.optimizer_.param_groups[0]['params']):
+            assert param is opt_param
+
     def test_net_init_one_unknown_argument(self, net_cls, module_cls):
         with pytest.raises(TypeError) as e:
             net_cls(module_cls, unknown_arg=123)
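The final assertion in the new test doubles as a reusable invariant: after any copy, every tensor in optimizer_.param_groups must be the very same object as one of module_'s parameters. A hypothetical helper along those lines (assert_params_linked is not part of skorch; net stands for any fitted NeuralNet):

    def assert_params_linked(net):
        # Identities of the module's parameter tensors ...
        module_params = {id(p) for p in net.module_.parameters()}
        # ... must cover everything the optimizer holds; equal values in
        # different objects would mean the copy severed the link.
        for group in net.optimizer_.param_groups:
            for param in group['params']:
                assert id(param) in module_params, (
                    "optimizer params are detached from module params")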