
MAINT: make sure copy links weights and optimizer #318

Merged: 4 commits, Aug 23, 2018

Conversation

@stsievert (Contributor)

Closes #317.

This needs review, hence the WIP label. I don't use state_dict as mentioned in #317 (comment), and I coded to the test.
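
For context, the failure mode behind #317 can be shown in plain PyTorch: if the module and the optimizer are serialized and restored independently of each other (roughly what a naive __getstate__ might do), the restored optimizer ends up holding copies of the parameters rather than the tensors owned by the restored module. A minimal sketch, not skorch's actual code:

    import io

    import torch

    module = torch.nn.Linear(2, 1)
    optimizer = torch.optim.SGD(module.parameters(), lr=0.1)

    def roundtrip(obj):
        # serialize and restore in isolation, losing shared identity
        buf = io.BytesIO()
        torch.save(obj, buf)
        buf.seek(0)
        return torch.load(buf)

    module2 = roundtrip(module)
    optimizer2 = roundtrip(optimizer)

    # the restored optimizer no longer references the restored module's
    # parameters, so its steps would not update the module
    for p_mod, p_opt in zip(
        module2.parameters(), optimizer2.param_groups[0]['params']
    ):
        assert p_mod is not p_opt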

@benjamin-work (Contributor)

Thank you for taking on this issue.

The fix looks sound. The use of state_dict could help with portability, but that is another issue.

There are some adjustments required for the unit test:

  • As it is right now, it succeeds even without the fix. The reason is that you call fit the second time instead of partial_fit (always check that the test fails before writing the bug fix).
  • The test name should be more explanatory.
  • The unit test definitely needs more explanation in the form of comments, e.g. note that the problem is not restricted to deep copying but would also occur with pickling. You can also link to the issue in a comment.
  • Finally, it is not quite clear whether all of the asserts are needed. Something like this directly after the deepcopy already reveals the bug:
        # check that the optimizer references the same parameters as used by 
        # the module
        for p0, p1 in zip(
            n2.module_.parameters(), n2.optimizer_.param_groups[0]['params']
        ):
            assert p0 is p1

But I also see the value in comparing the parameters after the fit call, so leave that in as well. The comparison of train_loss, on the other hand, seems pointless -- if the module parameters are updated, can the loss ever be the same (and if it were, the issue would certainly lie somewhere else)?
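
(For illustration, pulling the review's points together: a rough shape the test could take, using skorch's public API with a toy module and made-up data. This is a sketch, not the PR's actual test.)

    import copy

    import numpy as np
    import torch
    from skorch import NeuralNetRegressor

    X = np.random.rand(64, 10).astype('float32')
    y = np.random.rand(64, 1).astype('float32')

    net = NeuralNetRegressor(
        torch.nn.Linear,
        module__in_features=10,
        module__out_features=1,
        max_epochs=2,
    )
    net.fit(X, y)
    n2 = copy.deepcopy(net)

    # check that the optimizer references the same parameters as used by
    # the module; this alone reveals the bug right after the deepcopy
    for p0, p1 in zip(
        n2.module_.parameters(), n2.optimizer_.param_groups[0]['params']
    ):
        assert p0 is p1

    # use partial_fit, not fit, so the copied module and optimizer are
    # reused instead of being re-initialized, then check the weights move
    params_before = [p.clone() for p in n2.module_.parameters()]
    n2.partial_fit(X, y)
    params_after = list(n2.module_.parameters())
    assert not all(
        torch.allclose(b, a) for b, a in zip(params_before, params_after)
    )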

@stsievert force-pushed the frozen-fit-after-clone branch from c3cdfbd to a8291b6 on August 21, 2018, 16:10
@stsievert (Contributor, Author)

Comments addressed. I've added a lot more comments and this seems like a better test.

> The comparison of train_loss, on the other hand, seems pointless -- if the module parameters are updated, can the loss ever be the same (and if it were, the issue would certainly lie somewhere else)?

This is a user-facing test, and the reason I filed this issue and PR. I agree that it's redundant, but if that assert fails users will complain, and it's a good sanity check.
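
Concretely, that user-facing sanity check amounts to something like this, continuing the names from the sketch above (history indexing as in skorch's History API):

    # from the user's perspective: training after the copy should
    # actually move the loss, not leave it frozen
    loss_before = n2.history[-1, 'train_loss']
    n2.partial_fit(X, y)
    loss_after = n2.history[-1, 'train_loss']
    assert loss_before != loss_after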

@stsievert changed the title from "WIP: MAINT: make sure copy links weights and optimizer" to "MAINT: make sure copy links weights and optimizer" on Aug 22, 2018
@benjamin-work (Contributor) left a comment

Thanks for the changes, the test is now much easier to understand. There are only 2 minor comments left.

assert all(close)

# make sure the parameters change
# cal this twice to make sure history updates below
@benjamin-work (Contributor)

Typo "cal"; also, why are 2 calls needed? When I comment one out, the test still passes.

@stsievert (Contributor, Author)

Thanks for this. I can also pass epochs=2 to partial_fit, and I've explained why in more detail in the comments.
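
That is, reusing the names from the earlier sketch, the two back-to-back calls collapse into one:

    # a single call with epochs=2 produces two history entries, which
    # the later history check relies on
    n2.partial_fit(X, y, epochs=2)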

skorch/net.py (outdated diff)
@@ -1323,24 +1325,26 @@ def uses_cuda(device):
device = device.type
return device.startswith('cuda')

-disable_cuda = False
+disable_cuda = uses_cuda(state['device']) and not torch.cuda.is_available()
+load_kwargs = {} if not disable_cuda else {'map_location': lambda store, loc: store}
@benjamin-work (Contributor)
This line is too long, could you please break it?
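
For example, one way the line could be wrapped (illustrative; the formatting that was finally committed may differ):

    load_kwargs = {}
    if disable_cuda:
        load_kwargs = {'map_location': lambda store, loc: store}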

@stsievert force-pushed the frozen-fit-after-clone branch from 8151c09 to 2942939 on August 22, 2018, 16:02
@benjamin-work (Contributor)

Thank you very much for this contribution.

@benjamin-work merged commit 0f7823c into skorch-dev:master on Aug 23, 2018
spott pushed a commit to spott/skorch that referenced this pull request Oct 12, 2018