
skorch doctor: a tool to understand the net #912

Merged: 26 commits merged into master from skorch-doctor on May 4, 2023

Conversation

BenjaminBossan (Collaborator) commented Nov 4, 2022

I recently watched Andrej Karpathy's new series (yes, I still feel like a noob :)) and also remembered his blog post from a few years ago. These sources contain a couple of useful tips for better understanding neural net training and diagnosing potential problems, shining some light into the black box.

I wondered whether we could automate some of these steps to make it super easy for users to attain those insights. So I started working on automating the steps that lend themselves to automation. The results can be seen in the notebook here.

This is obviously still in very rough shape, but before spending more time on it, I wanted to get feedback from @ottonemo and @thomasjpfan on whether they think it's useful to have something like this.

If we want to proceed, I would implement this in a proper module with tests and docs, and also test it with a more complex architecture (BERT or something along those lines).

For now, the new class allows training with some data and automatically records some useful information. As a user, I can just do:

doctor = SkorchDoctor(net)      # wrap the existing skorch net
doctor.fit(X_sample, y_sample)  # train while recording activations, gradients, updates
doctor.plot_loss()
doctor.plot_activations()
doctor.plot_gradients()
# etc.

I also dabbled in creating something similar to sklearn's cv_results_, i.e. a report with some summary statistics. Not sure how useful that would be.

SkorchDoctor adds a bunch of hooks and modifies the net in a few ways, which creates a memory overhead, but I try to pull as much data as possible off the GPU and into numpy. It's also better to create a new net instance afterwards, though I try to clean up the hooks and modifications as well as possible.

Ah yes, I also found a small bug in on_grad_computed, which has training=False as a default arg, but during training it should obviously be set to True. It has no consequences in our code since it's only used by GradientNormClipping, which doesn't check the argument, but still...


BenjaminBossan and others added 12 commits November 14, 2022 16:50
- Complete the docstrings
- Some small refactorings
- Make SkorchDoctor work with more complex architectures. Specifically,
  to make it work with BERT from transformers, module outputs that are
  lists, tuples, or dicts need to be unpacked.
Also, move matplotlib import into method.
This includes fixes for some bugs in skorch itself, which will require adding
tests later.
These bugs are not strictly related to this PR but they were discovered
while working on it.
@BenjaminBossan BenjaminBossan marked this pull request as ready for review November 24, 2022 10:54
BenjaminBossan (Collaborator, Author) commented Nov 24, 2022

@ottonemo @thomasjpfan This feature is now ready for review. Please find below the full description, which can also be used for the squashed commit message.

Description

A helper class to assist in understanding the neural net training

The SkorchDoctor helper class allows users to wrap their neural net before
training and then automatically collect useful data that makes it possible to
better understand what is going on during training and how to possibly improve it.

The class will automatically record the activations of each module, as well as
the gradients and updates of each learnable parameter, all of those for each
training step. Once training is finished, the user can either directly take a
look at the data, which is stored as an attribute on the helper class, or use
one of the provided plotting functions (requires matplotlib) to plot
distributions of the data.
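
To make this concrete, here is a minimal sketch of the workflow (the toy data
and module are placeholders; activation_recs_ is the attribute name that
appears later in this thread):

import numpy as np
from torch import nn
from skorch import NeuralNetClassifier
from skorch.helper import SkorchDoctor

# toy placeholders, just to have something to fit
X = np.random.randn(128, 20).astype(np.float32)
y = np.random.randint(0, 2, size=128).astype(np.int64)
module = nn.Sequential(nn.Linear(20, 10), nn.ReLU(), nn.Linear(10, 2))
net = NeuralNetClassifier(module, criterion=nn.CrossEntropyLoss, max_epochs=3, lr=0.1)

doctor = SkorchDoctor(net)
doctor.fit(X, y)  # trains the net while recording data at every training step

# recorded data is stored as attributes, keyed by module name ('module' here)
first_step_activations = doctor.activation_recs_['module'][0]

# or use the plotting helpers (they require matplotlib)
doctor.plot_loss()
doctor.plot_activations()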

Examples of what conclusions could be drawn from the data:

  • Net is not powerful enough
  • Need for better weight initialization or normalization
  • Need to adjust optimizer
  • Need for gradient clipping

However, the helper class will not suggest any of those solutions itself; I
don't think that's possible. It is only intended to help surface potential
problems; it's up to the user to decide on a solution.

A notebook showing the usage of SkorchDoctor, once for a simple MLP and once for
fine-tuning a BERT model, is provided:

https://github.com/skorch-dev/skorch/blob/skorch-doctor/notebooks/Skorch_Doctor.ipynb

Implementation

Because of the additional data being collected, depending on the use case, a
significant memory overhead is expected. To keep this in check, a few measures
are taken:

  • The collected data is immediately pulled to numpy to avoid clobbering GPU
    memory.
  • It is documented, and shown in examples, that you should use only a small
    amount of data and a low number of epochs, since that's enough to understand
    most problems (see the sketch after this list). Most notably, this helps
    with storing less data about activations.
  • For parameter updates, only a single scalar per weight/bias is stored,
    indicating the relative magnitude of the update.
  • The biggest overhead will most likely come from storing the gradients; I'm
    not sure whether something can be done here without losing too much useful
    data.
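
In practice, a diagnosis run following these recommendations could look like
this (a sketch; the subsample size and epoch count are arbitrary):

net.set_params(max_epochs=3)           # a few epochs suffice for diagnosis
X_sample, y_sample = X[:128], y[:128]  # a small subsample keeps activation storage low
doctor = SkorchDoctor(net)
doctor.fit(X_sample, y_sample)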

For storing activations, some heuristics are in place to deal with the output of
the modules. The problem here is that modules can return arbitrary data from
their forward call. A few assumptions are made here: the output has to be
digestible by to_numpy, and it has to be either a torch tensor, a list, a tuple,
or a mapping of torch tensors. If it's none of those, an error is raised.
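
A rough sketch of what such a heuristic can look like (an illustration of the
idea, not the PR's actual code; skorch.utils.to_numpy is skorch's existing
conversion helper):

import torch
from skorch.utils import to_numpy

def unpack_output(output):
    # convert a module's forward output to numpy, unpacking known containers
    if isinstance(output, torch.Tensor):
        return to_numpy(output)
    if isinstance(output, (list, tuple)):
        return [to_numpy(t) for t in output]
    if isinstance(output, dict):  # mappings of torch tensors
        return {key: to_numpy(t) for key, t in output.items()}
    raise TypeError(f"Cannot handle module output of type {type(output)}")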

Coincidental changes

While working on this PR, I stumbled upon a few minor bugs that are fixed as
part of the PR:

  1. Inferring the predict nonlinearity made the assumption that a net.criterion_
    exists. However, this might not always be the case. Now, this function works
    even if the criterion attribute has a different name. Moreover, when more
    than one criterion is defined, the identity function will be returned.
  2. The second bug is that in check_is_fitted, we made the check dependent on the
    module_ attribute. Again, we should not assume that it always exists, as
    users may define different names. Now, those custom names will be checked,
    and only if those don't exist is it assumed that the module_ attribute should
    exist.
  3. The helper.py module now defines an __all__ attribute. This seems to be
    necessary for sphinx to build documentation for objects that are imported to,
    but not defined in, helper.py. Specifically, this is the case for the new
    SkorchDoctor class, but also the existing parse_args CLI feature.
  4. Fixed a small bug in on_grad_computed, which has training=False as default
    arg but during training, it should be set to True. Thankfully, it had no
    consequences in our code since it's only used by GradientNormClipping, which
    doesn't check that argument.

@BenjaminBossan BenjaminBossan changed the title [WIP] skorch doctor: a tool to understand the net skorch doctor: a tool to understand the net Dec 9, 2022
thomasjpfan (Member) commented:

Can the coincidental changes be split out into their own PRs so they can be merged independently of this PR?

As for including skorch doctor, are there existing libraries that hook onto a module and record the same type of information to visualize?

BenjaminBossan added a commit that referenced this pull request Dec 22, 2022
These bugfixes were originally provided in #912 but are now moved to their own
PR.

1. Inferring the predict nonlinearity made the assumption that a net.criterion_
   exists. However, this might not always be the case. Now, this function works
   even if the criterion attribute has a different name. Moreover, when more
   than one criterion is defined, the identity function will be returned.
2. The second bug is that in check_is_fitted, we made the check dependent on the
   module_ attribute. Again, we should not assume that it always exists, as
   users may define different names. Now, those custom names will be checked,
   and only if those don't exist is it assumed that the module_ attribute should
   exist.
3. The helper.py module now defines an __all__ attribute. This seems to be
   necessary for sphinx to build documentation for objects that are imported to,
   but not defined in, helper.py. Specifically, this is the case for the
   parse_args CLI feature.
4. Fixed a small bug in on_grad_computed, which has training=False as default
   arg but during training, it should be set to True. Thankfully, it had no
   consequences in our code since it's only used by GradientNormClipping, which
   doesn't check that argument.
BenjaminBossan (Collaborator, Author) replied:

> Can the coincidental changes be split out into their own PRs so they can be merged independently of this PR?

I did: #927

> As for including skorch doctor, are there existing libraries that hook onto a module and record the same type of information to visualize?

I am not aware of any library that provides similar functionality. Since recording all this extra information incurs a memory overhead, no library would provide it by default. It only makes sense to log this with the specific purpose of performing the type of analysis provided here.

The sources that inspired me to add the feature are the ones linked, i.e. the video series/blog post by Karpathy, the borealis.ai blog post, and the experience that colleagues and I gathered when trying to diagnose training issues.

@BenjaminBossan BenjaminBossan self-assigned this Dec 22, 2022
BenjaminBossan added a commit that referenced this pull request Jan 3, 2023
…eria (#927)

I think I accidentally added it back in after a merge with the master branch.
thomasjpfan (Member) left a comment:

I recently got a chance to watch the videos. It was very interesting!

This is a first pass over the code itself.

# related to caching and/or memory-mapping.
grad = grad.clone()

log_grad[-1][param_name] = to_numpy(grad)
thomasjpfan (Member):

Can the logs be renamed "records"? I keep thinking that "log_grad" means the mathematical "log(grad)".

_, axes = plt.subplots(nrows, 1, figsize=figsize, squeeze=squeeze)
return axes

def plot_loss(self, ax=None, figsize=None, **kwargs):
thomasjpfan (Member):

Do you think it's useful to have two kwargs to pass into the train and val loss separately?

BenjaminBossan (Collaborator, Author):

Hmm, then we couldn't just pass them as **kwargs but would have to require them to be passed as dicts, which would result in a different API from all the other plotting functions. I see the benefit, but I'm not sure it outweighs the costs. Since the ax is returned to the user at the end, they can still make some changes to the plot afterwards.

bins=None,
density=True,
figsize=None,
**kwargs
thomasjpfan (Member):

kwargs is routed directly to the hist plot. What do you think of the following API?

def plot_activation(...,
    hist_kwargs=None,
):
    hist_kwargs = hist_kwargs or {}

This comment applies to the other plot_* methods.

BenjaminBossan (Collaborator, Author):

I'm not sure I quite understand: you want to avoid having **kwargs at all and instead have foo_kwargs, where foo corresponds to the kind of plot being generated? One reason why I chose to implement kwargs like this is that it is similar to the pandas plotting API, where you can also pass arbitrary kwargs and they are passed to the underlying plotting function.

def plot_param_updates(self, match_fn=None, axes=None, figsize=None, **kwargs):
"""Plot the distribution of relative parameter updates.

Plots the log10 of the standard deviation of the parameter update
thomasjpfan (Member):

I may have missed it, but what was the reason behind log10 compared to natural log (ln)?

BenjaminBossan (Collaborator, Author):

I don't think there is any strict reason, it's just easier to translate log10 in your head I guess. When you look at the corresponding part of the video, you can see that Andrej even starts with log and then changes to log10 :)
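
For context, one plausible way to compute that per-parameter scalar, following the calculation shown in Karpathy's video (an assumption about the exact formula, not necessarily the code in this PR):

import numpy as np

def relative_update_log10(param_before, param_after):
    # log10 of the std of the update, relative to the std of the parameter
    update = param_after - param_before
    return np.log10(update.std() / param_before.std())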


axes = self._get_axes(axes, figsize=figsize, nrows=len(module_names))

for module_name, ax in zip(module_names, axes):
thomasjpfan (Member):

Should the axes be flattened for cases where the user passes in an axes argument that is 2d:

fig, axes = plt.subplots(3, 4)

doctor.plot_activations(..., axes=axes)
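
Something along these lines inside the plot_* methods would handle that case (a sketch of the suggestion, not code from the PR):

import numpy as np

if axes is not None:
    axes = np.asarray(axes).ravel()  # accept 1d or 2d arrays of Axes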


def test_callbacks_cleaned_up_after_fit(self, doctor, net_cls, module_cls):
# make sure that the callbacks are the same before and after, this is
# important because SkorchDoctor will temporarily add a callback
thomasjpfan (Member):

I'm trying to find where SkorchDoctor adds a callback. Could you point me to the code where this happens?

BenjaminBossan (Collaborator, Author):

Ah yes, an earlier implementation added an extra callback, but after a refactoring, this was no longer necessary. Therefore, the test can now be removed.

BenjaminBossan and others added 6 commits February 13, 2023 14:06
To avoid confusion between "log" in the sense of "record" vs the
mathematical function.
This regresses into the bug that this line was supposed to prevent in the
first place.
This test was required for an earlier iteration, but since callbacks are
no longer mutated, it can be removed now.
BenjaminBossan (Collaborator, Author) left a comment:

Thanks for the review Thomas, I made some changes according to your comments or asked some clarifying questions.

Regarding the name, good point about the potentially confusing name of "logs". I have now changed it to "recs"; I didn't choose "records" because it's longer. I hope that works for you.


BenjaminBossan added a commit that referenced this pull request Mar 17, 2023
…eria (#927)

Passing a match_fn will now only record layers/gradients/updates whose
name matches. This is especially helpful if there is otherwise not
enough memory to record everything.
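
For example, restricting the recording to a single layer could look like this (the layer name and data are hypothetical):

# record only modules/parameters whose name contains 'classifier'
doctor = SkorchDoctor(net, match_fn=lambda name: 'classifier' in name)
doctor.fit(X_sample, y_sample)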
thomasjpfan (Member) left a comment:

The update with match_fn makes sense. I left one small testing commit, otherwise I'm okay with merging.

doctor = doctor_cls(net, match_fn=match_fn)
doctor.fit(*data)

for rec in doctor.activation_recs_['module']:
thomasjpfan (Member):

I know that an error would be raised if nothing matched, but can we also assert len(doctor.activation_recs_["module"]) to make sure there is something in the list?

BenjaminBossan (Collaborator, Author):

Done

@BenjaminBossan BenjaminBossan merged commit 785b917 into master May 4, 2023
@BenjaminBossan BenjaminBossan deleted the skorch-doctor branch May 4, 2023 09:14
@BenjaminBossan BenjaminBossan mentioned this pull request May 8, 2023
BenjaminBossan added a commit that referenced this pull request May 17, 2023
Preparation for release of version 0.13.0

Release text:

The new skorch release is here and it has some changes that will be exciting for
some users.

- First of all, you may have heard of the [PyTorch 2.0
  release](https://pytorch.org/get-started/pytorch-2.0/), which includes the
  option to compile the PyTorch module for better runtime performance. This
  skorch release allows you to pass `compile=True` when initializing the net to
  enable compilation.
- Support for training on multiple GPUs with the help of the
  [`accelerate`](https://huggingface.co/docs/accelerate/index) package has been
  improved by fixing some bugs and providing a dedicated [history
  class](https://skorch.readthedocs.io/en/latest/user/history.html#distributed-history).
  Our documentation contains more information on [what to consider when training
  on multiple
  GPUs](https://skorch.readthedocs.io/en/latest/user/huggingface.html#caution-when-using-a-multi-gpu-setup).
- If you have ever been frustrated with your neural net not training properly,
  you know how hard it can be to discover the underlying issue. Using the new
  [`SkorchDoctor`](https://skorch.readthedocs.io/en/latest/helper.html#skorch.helper.SkorchDoctor)
  class will simplify the diagnosis of underlying issues. Take a look at the
  accompanying
  [notebook](https://nbviewer.org/github/skorch-dev/skorch/blob/master/notebooks/Skorch_Doctor.ipynb).

Apart from that, a few bugs have been fixed and the included notebooks have been
updated to properly install requirements on Google Colab.

We are grateful for external contributors, many thanks to:

- Kshiteej K (kshitij12345)
- Muhammad Abdullah (abdulasiraj)
- Royi (RoyiAvital)
- Sawradip Saha (sawradip)
- y10ab1 (y10ab1)

Find below the list of all changes since v0.12.1:

### Added
- Add support for compiled PyTorch modules using the `torch.compile` function,
  introduced in [PyTorch 2.0
  release](https://pytorch.org/get-started/pytorch-2.0/), which can greatly
  improve performance on new GPU architectures; to use it, initialize your net
  with the `compile=True` argument; further compilation arguments can be
  specified using the dunder notation, e.g. `compile__dynamic=True` (see the
  sketch after this list)
- Add a class
  [`DistributedHistory`](https://skorch.readthedocs.io/en/latest/history.html#skorch.history.DistributedHistory)
  which should be used when training in a multi GPU setting (#955)
- `SkorchDoctor`: A helper class that assists in understanding and debugging the
  neural net training, see [this
  notebook](https://nbviewer.org/github/skorch-dev/skorch/blob/master/notebooks/Skorch_Doctor.ipynb)
  (#912)
- When using `AccelerateMixin`, it is now possible to prevent unwrapping of the
  modules by setting `unwrap_after_train=True` (#963)
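
As a minimal sketch of the new `compile` option (`MyModule`, `X`, and `y` are placeholders):

```python
from skorch import NeuralNetClassifier

net = NeuralNetClassifier(
    MyModule,
    compile=True,           # compile the module via torch.compile
    compile__dynamic=True,  # extra torch.compile arguments via dunder notation
)
net.fit(X, y)
```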

### Fixed
- Fixed install command to work with recent changes in Google Colab (#928)
- Fixed a couple of bugs related to using non-default modules and criteria
  (#927)
- Fixed a bug when using `AccelerateMixin` in a multi-GPU setup (#947)
- `_get_param_names` returns a list instead of a generator so that subsequent
  error messages return useful information instead of a generator `repr` string
  (#925)
- Fixed a bug that caused modules to not be sufficiently unwrapped at the end of
  training when using `AccelerateMixin`, which could prevent them from being
  pickleable (#963)