-
Notifications
You must be signed in to change notification settings - Fork 394
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
MAINT: make sure copy links weights and optimizer #318
MAINT: make sure copy links weights and optimizer #318
Conversation
Thank you for taking on this issue. The fix looks sound. The use of There are some adjustments required for the unit test:
But I also see the value in comparing the parameters after the fit call, so leave that in as well. The comparison of |
c3cdfbd
to
a8291b6
Compare
Comments addressed. I've added a lot more comments and this seems like a better test.
This is a user-facing test, and the reason I filed this issue and PR. I agree that it's redundant, but if that |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the changes, the test is now much better to understand. There are only 2 minor comments left.
skorch/tests/test_net.py
Outdated
assert all(close) | ||
|
||
# make sure the parameters change | ||
# cal this twice to make sure history updates below |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo "cal"; also, why are 2 calls needed? When I comment one out, the test still passes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this. I can also pass epochs=2
to partial_fit
, and I explained more why in the comments.
skorch/net.py
Outdated
@@ -1323,24 +1325,26 @@ def uses_cuda(device): | |||
device = device.type | |||
return device.startswith('cuda') | |||
|
|||
disable_cuda = False | |||
disable_cuda = uses_cuda(state['device']) and not torch.cuda.is_available() | |||
load_kwargs = {} if not disable_cuda else {'map_location': lambda store, loc: store} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line is too long, could you please break it?
8151c09
to
2942939
Compare
Thank you very much for this contribution. |
Closes #317.
This needs review, hence the WIP label. I don't use
state_dict
as mentioned in #317 (comment), and I coded to the test.