MAINT: make sure copy links weights and optimizer #318
@@ -10,6 +10,7 @@
from unittest.mock import Mock
from unittest.mock import patch
from pathlib import Path
import copy

import numpy as np
import pytest

@@ -125,6 +126,51 @@ def net_pickleable(self, net_fit):
        net_fit.callbacks_ = callbacks_
        return net_clone

    @pytest.mark.parametrize("copy_method", ["pickle", "copy.deepcopy"])
    def test_train_net_after_copy(self, net_cls, module_cls, data,
                                  copy_method):
        # This test comes from [issue #317], and makes sure that models
        # can be trained after copying (which is really pickling).
        #
        # [issue #317]: https://github.com/dnouri/skorch/issues/317
        X, y = data
        n1 = net_cls(module_cls)
        n1.partial_fit(X, y, epochs=1)
        if copy_method == "copy.deepcopy":
            n2 = copy.deepcopy(n1)
        elif copy_method == "pickle":
            n2 = pickle.loads(pickle.dumps(n1))
        else:
            raise ValueError

        # Test to make sure the parameters got copied correctly
        close = [torch.allclose(p1, p2)
                 for p1, p2 in zip(n1.module_.parameters(),
                                   n2.module_.parameters())]
        assert all(close)

        # make sure the parameters change
        # cal this twice to make sure history updates below
[Review comment] Typo "cal"; also, why are 2 calls needed? When I comment one out, the test still passes.

[Reply] Thanks for this. I can also pass
        n2.partial_fit(X, y, epochs=1)
        n2.partial_fit(X, y, epochs=1)
        far = [not torch.allclose(p1, p2)
               for p1, p2 in zip(n1.module_.parameters(),
                                 n2.module_.parameters())]
        assert all(far)

        # Make sure the model is being trained, and the loss actually changes
        # (and hopefully decreases, but no test for that).
        # If copied incorrectly, the optimizer can't see the gradients
        # calculated by loss.backward(), so the loss stays *exactly* the same.
        assert n2.history[-1]['train_loss'] != n2.history[-2]['train_loss']

        # Make sure the optimizer params and module params point to the same
        # memory
        for opt_param, param in zip(
                n2.optimizer_.param_groups[0]['params'],
                n2.module_.parameters()):
            assert param is opt_param

    def test_net_init_one_unknown_argument(self, net_cls, module_cls):
        with pytest.raises(TypeError) as e:
            net_cls(module_cls, unknown_arg=123)
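To make the failure mode concrete, here is a minimal sketch in plain PyTorch (not skorch; the Linear module, learning rate, and variable names are illustrative only) of what the identity check in the test guards against: when the module and the optimizer are deep-copied independently, each deepcopy creates its own fresh parameter tensors, so the copied optimizer no longer shares memory with the copied module and stepping it would silently leave the module unchanged.

import copy

import torch
from torch import nn

module = nn.Linear(2, 1)
optimizer = torch.optim.SGD(module.parameters(), lr=0.1)

# Copying the two objects separately breaks the link between them:
# each deepcopy materializes its own parameter tensors.
module_copy = copy.deepcopy(module)
optimizer_copy = copy.deepcopy(optimizer)

linked = all(
    opt_param is mod_param
    for opt_param, mod_param in zip(
        optimizer_copy.param_groups[0]['params'],
        module_copy.parameters())
)
print(linked)  # False: optimizer_copy would step tensors the module never uses

The `assert param is opt_param` loop in the test fails in exactly this situation, which is why the net must be copied as a whole (via pickle) rather than piecewise.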
[Review comment] This line is too long, could you please break it?
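Assuming the flagged line is the train_loss assertion (the scraped page does not record which line the comment is attached to), one way to break it while keeping the semantics would be:

        # hypothetical reformatting of the assertion, wrapped in parentheses
        assert (n2.history[-1]['train_loss']
                != n2.history[-2]['train_loss'])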