skorch doctor: a tool to understand the net #912
Branch updated from 2e55171 to 089d5f1:
- Complete the docstrings
- Some small refactorings
- Make SkorchDoctor work with more complex architectures. Specifically, to make it work with BERT from transformers, module outputs that are lists, tuples, or dicts need to be unpacked. Also, move the matplotlib import into the method.

This includes fixes for some bugs in skorch itself, which will require adding tests later. These bugs are not strictly related to this PR but were discovered while working on it.
@ottonemo @thomasjpfan This feature is now ready for review. Please find below the full description, which can also be used for the squashed commit message.

### Description

A helper class to assist in understanding the neural net training. The class will automatically record activations of each module, gradients, and parameter updates for each training step.

From this data, conclusions can be drawn, e.g. whether the learning rate needs adjusting or whether certain layers show dead activations or vanishing gradients. However, the helper class will not suggest any of those solutions itself; it only collects and presents the data.

A notebook showing the usage of SkorchDoctor: https://github.com/skorch-dev/skorch/blob/skorch-doctor/notebooks/Skorch_Doctor.ipynb

### Implementation

Because of the additional data being collected, depending on the use case, a considerable memory overhead should be expected. For storing activations, some heuristics are in place to deal with the output of modules that return lists, tuples, or dicts.

### Coincidental changes

While working on this PR, I stumbled upon a few minor bugs that are fixed as part of it.
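To make that heuristic concrete, here is a minimal sketch of how such unpacking could look; this is an illustrative helper, not the actual implementation in this PR:

```python
import torch
from skorch.utils import to_numpy

def unpack_output(output):
    """Recursively convert a module's output to numpy for recording."""
    if isinstance(output, torch.Tensor):
        return to_numpy(output)
    if isinstance(output, (list, tuple)):
        return [unpack_output(item) for item in output]
    if isinstance(output, dict):
        return {key: unpack_output(val) for key, val in output.items()}
    # unknown types are stored as-is
    return output
```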
Can the coincidental changes be split out into their own PRs so they can be merged independently of this PR? As for including skorch doctor: are there existing libraries that hook onto a module and record the same type of information to visualize?
These bugfixes were originally provided in #912 but are now moved to their own PR.

1. Inferring the predict nonlinearity made the assumption that a `net.criterion_` exists. However, this might not always be the case. Now, this function works even if the criterion attribute has a different name. Moreover, when more than one criterion is defined, the identity function will be returned.
2. In `check_is_fitted`, we made the check dependent on the `module_` attribute. Again, we should not assume that it always exists, as users may define different names. Now, those custom names will be checked, and only if those don't exist is it assumed that the `module_` attribute should exist.
3. The `helper.py` module now defines an `__all__` attribute. This seems to be necessary for sphinx to build documentation for objects that are imported to, but not defined in, `helper.py`. Specifically, this is the case for the `parse_args` CLI feature.
4. Fixed a small bug in `on_grad_computed`, which has `training=False` as default arg, but during training it should be set to `True`. Thankfully, it had no consequences in our code since it's only used by `GradientNormClipping`, which doesn't check that argument.
I did: #927
I am not aware of any library that would provide similar functionality. As recording all this extra information is a memory overhead, no library would provide this by default. It only makes sense to log this with the specific purpose of performing the type of analysis provided here. The sources that inspired me to add the feature are the ones linked, i.e. the video series/blog post by Karpathy, the borealis.ai blog post, and the experience that colleagues and I gathered when trying to diagnose training issues.
Fix a couple of bugs related to using non-default modules and criteria (#927); see the list of fixes above.
I think I accidentally added it back in after a merge with the master branch.
I recently got a chance to watch the videos. It was very interesting!
This is a first pass over the code itself.
skorch/_doctor.py
Outdated
```python
        # related to caching and/or memory-mapping.
        grad = grad.clone()

        log_grad[-1][param_name] = to_numpy(grad)
```
Can the logs be renamed "records"? I keep thinking that "log_grad" means the mathematical "log(grad)".
```python
        _, axes = plt.subplots(nrows, 1, figsize=figsize, squeeze=squeeze)
        return axes

    def plot_loss(self, ax=None, figsize=None, **kwargs):
```
Do you think it's useful to have two kwargs to pass into the train and val loss separately?
Hmm, then we couldn't just pass them as `**kwargs` but would require them to be passed as dicts, which would result in a different API from all the other plotting functions. I see the benefit but I'm not sure it outweighs the costs. Since the `ax` is returned to the user at the end, they can still make some changes to the plot afterwards.
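To illustrate that point, a minimal sketch of such post-hoc tweaking, assuming `plot_loss` returns a matplotlib `Axes`; which line corresponds to train vs. valid loss is an assumption here:

```python
import matplotlib.pyplot as plt

ax = doctor.plot_loss()
# style the two loss curves separately after the fact
ax.lines[0].set_linestyle("--")  # assumed: train loss
ax.lines[1].set_linewidth(2)     # assumed: valid loss
plt.show()
```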
```python
        bins=None,
        density=True,
        figsize=None,
        **kwargs
```
`kwargs` is routed directly to the `hist` plot. What do you think of the following API?
```python
def plot_activation(...,
    hist_kwargs=None,
):
    hist_kwargs = hist_kwargs or {}
```
This comment applies to the other `plot_*` methods.
Not sure if I quite understand: you want to avoid having `**kwargs` at all and instead have `foo_kwargs`, where `foo` corresponds to the kind of plot being generated? One reason why I chose to implement `kwargs` like this is that it is similar to the pandas plotting API, where you can also pass arbitrary kwargs and they are passed to the underlying plotting function.
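For comparison, a sketch of the pass-through behavior being referenced, using the `bins` and `density` parameters from the diff above (the method name follows the usage elsewhere in this thread):

```python
# bins and density go straight to the underlying hist call,
# analogous to how pandas forwards plotting kwargs to matplotlib
doctor.plot_activations(bins=50, density=True)
```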
```python
    def plot_param_updates(self, match_fn=None, axes=None, figsize=None, **kwargs):
        """Plot the distribution of relative parameter updates.

        Plots the log10 of the standard deviation of the parameter update
```
I may have missed it, but what was the reason behind `log10` compared to the natural log (`ln`)?
I don't think there is any strict reason, it's just easier to translate `log10` in your head, I guess. When you look at the corresponding part of the video, you can see that Andrej even starts with `log` and then changes to `log10` :)
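For reference, the two scales differ only by a constant factor, so nothing changes mathematically:

$$\log_{10} x = \frac{\ln x}{\ln 10} \approx 0.434 \cdot \ln x$$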
skorch/_doctor.py
Outdated
```python
        axes = self._get_axes(axes, figsize=figsize, nrows=len(module_names))

        for module_name, ax in zip(module_names, axes):
```
Should the axes be flattened for cases where the user passes in axes that are 2d?

```python
fig, axes = plt.subplots(3, 4)
doctor.plot_activations(..., axes=axes)
```
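A minimal sketch of the flattening that would handle this case, assuming the method normalizes its input with numpy:

```python
import numpy as np

# accept scalar, 1d, or 2d axes input by flattening to a 1d array
axes = np.asarray(axes).ravel()
```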
skorch/tests/test_doctor.py
Outdated
```python
    def test_callbacks_cleaned_up_after_fit(self, doctor, net_cls, module_cls):
        # make sure that the callbacks are the same before and after, this is
        # important because SkorchDoctor will temporarily add a callback
```
I'm trying to find where `SkorchDoctor` adds a callback. Could you point me to the code where this happens?
Ah yes, an earlier implementation added an extra callback, but after a refactoring, this was no longer necessary. Therefore, the test can now be removed.
Co-authored-by: Thomas J. Fan <[email protected]>
To avoid confusion between "log" in the sense of "record" vs the mathematical function.
This regresses into the bug that this line was supposed to prevent in the first place.
This test was required for an earlier iteration, but since callbacks are no longer mutated, it can be removed now.
Thanks for the review Thomas, I made some changes according to your comments or asked some clarifying questions.
Regarding the name, good point about the potentially confusing name of "logs". I changed it now to "recs". I didn't choose "records" because it's longer, I hope that works for you.
Passing a match_fn will now only record layers/gradients/updates whose name matches. This is especially helpful if there is otherwise not enough memory to record everything.
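A minimal usage sketch of `match_fn`, assuming it receives the layer/parameter name and returns a bool (the substring `"dense"` and the net are placeholders):

```python
from skorch.helper import SkorchDoctor

# record only layers/parameters whose name contains "dense",
# skipping everything else to save memory
match_fn = lambda name: "dense" in name

doctor = SkorchDoctor(net, match_fn=match_fn)
doctor.fit(X, y)
```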
The update with `match_fn` makes sense. I left one small testing commit, otherwise I'm okay with merging.
```python
        doctor = doctor_cls(net, match_fn=match_fn)
        doctor.fit(*data)

        for rec in doctor.activation_recs_['module']:
```
I know that an error would be raised if nothing matched, but can we also assert `len(doctor.activation_recs_["module"])` to make sure there is something in the list?
Done
As suggested by the reviewer.
Preparation for release of version 0.13.0

Release text: The new skorch release is here and it has some changes that will be exciting for some users.

- First of all, you may have heard of the [PyTorch 2.0 release](https://pytorch.org/get-started/pytorch-2.0/), which includes the option to compile the PyTorch module for better runtime performance. This skorch release allows you to pass `compile=True` when initializing the net to enable compilation.
- Support for training on multiple GPUs with the help of the [`accelerate`](https://huggingface.co/docs/accelerate/index) package has been improved by fixing some bugs and providing a dedicated [history class](https://skorch.readthedocs.io/en/latest/user/history.html#distributed-history). Our documentation contains more information on [what to consider when training on multiple GPUs](https://skorch.readthedocs.io/en/latest/user/huggingface.html#caution-when-using-a-multi-gpu-setup).
- If you have ever been frustrated with your neural net not training properly, you know how hard it can be to discover the underlying issue. Using the new [`SkorchDoctor`](https://skorch.readthedocs.io/en/latest/helper.html#skorch.helper.SkorchDoctor) class will simplify the diagnosis of underlying issues. Take a look at the accompanying [notebook](https://nbviewer.org/github/skorch-dev/skorch/blob/master/notebooks/Skorch_Doctor.ipynb).

Apart from that, a few bugs have been fixed and the included notebooks have been updated to properly install requirements on Google Colab.

We are grateful for external contributors, many thanks to:

- Kshiteej K (kshitij12345)
- Muhammad Abdullah (abdulasiraj)
- Royi (RoyiAvital)
- Sawradip Saha (sawradip)
- y10ab1 (y10ab1)

Find below the list of all changes since v0.12.1:

### Added

- Add support for compiled PyTorch modules using the `torch.compile` function, introduced in the [PyTorch 2.0 release](https://pytorch.org/get-started/pytorch-2.0/), which can greatly improve performance on new GPU architectures; to use it, initialize your net with the `compile=True` argument; further compilation arguments can be specified using the dunder notation, e.g. `compile__dynamic=True`
- Add a class [`DistributedHistory`](https://skorch.readthedocs.io/en/latest/history.html#skorch.history.DistributedHistory) which should be used when training in a multi-GPU setting (#955)
- `SkorchDoctor`: A helper class that assists in understanding and debugging the neural net training, see [this notebook](https://nbviewer.org/github/skorch-dev/skorch/blob/master/notebooks/Skorch_Doctor.ipynb) (#912)
- When using `AccelerateMixin`, it is now possible to prevent unwrapping of the modules by setting `unwrap_after_train=True` (#963)

### Fixed

- Fixed install command to work with recent changes in Google Colab (#928)
- Fixed a couple of bugs related to using non-default modules and criteria (#927)
- Fixed a bug when using `AccelerateMixin` in a multi-GPU setup (#947)
- `_get_param_names` returns a list instead of a generator so that subsequent error messages return useful information instead of a generator `repr` string (#925)
- Fixed a bug that caused modules to not be sufficiently unwrapped at the end of training when using `AccelerateMixin`, which could prevent them from being pickleable (#963)
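As a quick illustration of the new `compile` option mentioned above, a minimal sketch (the module is a placeholder):

```python
from torch import nn
from skorch import NeuralNetClassifier

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(20, 2)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, X):
        return self.softmax(self.dense(X))

# torch.compile is applied to the module when compile=True;
# further compilation arguments use dunder notation,
# e.g. compile__dynamic=True
net = NeuralNetClassifier(MyModule, compile=True)
```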
I recently watched Andrej Karpathy's new series (yes, I still feel like a noob :)) and also remembered his blog post from a few years ago. These sources contain a couple of useful tips for better understanding the neural net training and diagnosing potential problems, shining some light into the black box.
I wondered whether we couldn't automate some of these steps to make it super easy for users to attain those insights. So I started working on automating the steps that make sense to automate. The results can be seen in the notebook here.
This is obviously still in a very rough shape, but before spending more time on it, I wanted to get feedback from @ottonemo and @thomasjpfan if they think it's useful to have something like this.
If we want to proceed, I would implement this in a proper module with tests and docs, and also test it with a more complex architecture (BERT or something along those lines).
For now, the new class allows you to train with some data and automatically records some useful data. As a user, I can just do:
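A minimal sketch of that usage, based on the API shown elsewhere in this thread (`net`, `X`, and `y` are placeholders):

```python
from skorch.helper import SkorchDoctor

doctor = SkorchDoctor(net)  # wrap an existing skorch net
doctor.fit(X, y)            # fit as usual; data is recorded along the way

# visualize the recorded information
doctor.plot_loss()
doctor.plot_activations()
doctor.plot_param_updates()
```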
I also dabbled in creating something similar to sklearn's `cv_results_`, i.e. a report with some summary statistics. Not sure how useful that would be.

SkorchDoctor adds a bunch of hooks and modifies the net in a few ways. This creates a memory overhead, but I try to pull as much down from GPU to numpy as possible. It's also better to create a new net instance afterwards, though I try to clean up the hooks and modifications as well as possible.

Ah yes, I also found a small bug in `on_grad_computed`, which has `training=False` as default arg, but obviously during training it should be set to `True`. It has no consequences in our code since it's only used by `GradientNormClipping`, which doesn't check the argument, but still...