Model Verification in Trainer #1237

Open
TylerYep opened this issue Mar 25, 2020 · 19 comments
Labels: feature (Is an improvement or enhancement), help wanted (Open to be worked on), let's do it! (approved to implement)

@TylerYep
Contributor

🚀 Feature

Verify that the provided model code does not mix data across the batch dimension. We do this by setting the loss to something trivial (e.g. the sum of all outputs of example i), running the backward pass all the way to the input, and checking that only the i-th input receives a non-zero gradient.
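The reasoning behind the check, written out: let $X = (x_1, \dots, x_B)$ be a batch and $f(X)_{i,k}$ the $k$-th output of sample $i$. If the model processes samples independently, $f(X)_{i,\cdot}$ is a function of $x_i$ alone, so the trivial loss has zero gradient with respect to every other sample:

$$
\mathcal{L} = \sum_{k} f(X)_{i,k}, \qquad
\frac{\partial \mathcal{L}}{\partial x_j} = \sum_{k} \frac{\partial f(X)_{i,k}}{\partial x_j} = 0 \quad \text{for all } j \neq i
$$

A nonzero $\partial \mathcal{L} / \partial x_j$ for any $j \neq i$ therefore shows that output $i$ depends on sample $j$. The converse is weaker: a zero gradient only rules out dependence at this particular input (inactive ReLUs, for example, can hide it), and layers such as BatchNorm in training mode couple samples by design, which is presumably why the check runs the model in eval mode.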

Motivation

First of all, I would like to say thank you for the fantastic work being done on this project. Recently, I was working on a side project that has almost the exact same goal as this one, which I used as motivation to learn more about PyTorch and how to make Deep Learning easier. Clearly, this project is a lot more thought-out than mine :^), but I wanted to see if there were any ideas I developed independently that might be useful in this project.

One of the most useful utils I've implemented is a verification step before the model runs. In my project, this verification step performs checks such as:

  • ensuring data is not mixed across the batch dimension
  • ensuring the model can overfit a single example
  • ensuring that all layers of the model are training (or selected layers are properly frozen)
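The third check can be approximated by snapshotting every parameter before and after a single optimizer step and listing the ones that did not move; that list should match exactly the layers you meant to freeze. This is a minimal sketch, assuming the snapshots are plain dicts mapping parameter names to flattened float values (in PyTorch they could be built from `model.named_parameters()`); `stale_parameters` is a hypothetical helper, not a Lightning API.

```python
from typing import Dict, List


def stale_parameters(before: Dict[str, List[float]],
                     after: Dict[str, List[float]],
                     tol: float = 0.0) -> List[str]:
    """Return the names of parameters whose values changed by at most `tol`."""
    stale = []
    for name, old_values in before.items():
        new_values = after[name]
        if all(abs(new - old) <= tol for old, new in zip(old_values, new_values)):
            stale.append(name)
    return sorted(stale)


# Hypothetical snapshots taken before and after one optimizer step
before = {"layer_1.weight": [0.5, -0.2], "layer_2.weight": [1.0, 0.3]}
after = {"layer_1.weight": [0.48, -0.19], "layer_2.weight": [1.0, 0.3]}
print(stale_parameters(before, after))  # ['layer_2.weight']
```

If `layer_2` was not meant to be frozen, this would flag it as not training. One step can miss parameters whose gradient happens to be zero for that batch, so a real check might look at several steps.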

Since I am very new to this project, I thought that the first bullet point might be a good place to start.

Pitch

Given the introductory example in the documentation, assume we had written some poor tensor operations in our model like so:

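import torch
import pytorch_lightning as pl


class BadModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        # x: (batch_size, 1, 28, 28) MNIST images
        batch_size, channels, height, width = x.size()

        ###
        # Correct: flatten each sample independently
        # x = x.view(batch_size, -1)
        ###
        # Buggy: packs 4 consecutive images into one 56x56 grid, swaps axes,
        # then splits the result into rows that each hold pixels from 4 samples
        # (this also requires batch_size to be divisible by 4)
        x = x.view(-1, 1, 56, 56)
        x = x.permute(1, 0, 3, 2)
        x = x.reshape((batch_size, -1))
        ###

        x = self.layer_1(x)
        x = torch.relu(x)
        x = self.layer_2(x)
        x = torch.relu(x)
        x = self.layer_3(x)
        x = torch.log_softmax(x, dim=1)
        return x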

When we train this model, training starts and runs without any errors. However, this code is clearly wrong: the view/permute/reshape sequence mixes image data from separate datapoints in the batch.
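To see exactly how those three lines mix samples, one can trace where each element ends up without running the model. The sketch below assumes MNIST-shaped input `(batch_size, 1, 28, 28)` with `batch_size` divisible by 4 (otherwise `view(-1, 1, 56, 56)` fails) and uses row-major index arithmetic: before the permute, element `(n, h, w)` of the 56x56 view sits at flat offset `n*3136 + h*56 + w`; after `permute(1, 0, 3, 2)` and `reshape`, it sits at `n*3136 + w*56 + h`. Dividing each offset by 784 gives the sample it came from and the output row it lands in. `rows_to_sources` is a hypothetical helper.

```python
from typing import List


def rows_to_sources(batch_size: int) -> List[List[int]]:
    """For each output row of BadModel's reshape, list the input samples it contains."""
    assert batch_size % 4 == 0, "view(-1, 1, 56, 56) needs 4 * 784 elements per grid"
    pixels = 28 * 28        # elements per input sample
    grid = 56 * 56          # elements per 56x56 grid created by view()
    sources = [set() for _ in range(batch_size)]
    for n in range(batch_size // 4):
        for h in range(56):
            for w in range(56):
                src = n * grid + h * 56 + w   # offset before permute(1, 0, 3, 2)
                dst = n * grid + w * 56 + h   # offset after permute + reshape (H/W swapped)
                sources[dst // pixels].add(src // pixels)
    return [sorted(s) for s in sources]


print(rows_to_sources(8))
# [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3],
#  [4, 5, 6, 7], [4, 5, 6, 7], [4, 5, 6, 7], [4, 5, 6, 7]]
```

Every output row contains pixels from four different samples, whereas with the correct `x.view(batch_size, -1)` row i would contain only sample i.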

It would be helpful if Lightning warned us when this happens, for example with a check like this:

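import torch


def check_batch_dimension(model, loader, optimizer, test_val=2):
    # eval mode: BatchNorm uses running stats and dropout is off, so neither
    # legitimately mixes samples; remember the previous mode to restore it
    was_training = model.training
    model.eval()
    data, _ = next(iter(loader))  # batch size must be larger than test_val
    optimizer.zero_grad()
    data.requires_grad_()  # we need the gradient w.r.t. the input itself

    try:
        with torch.enable_grad():
            output = model(data)
            # trivial loss that depends only on the output of sample test_val
            loss = output[test_val].sum()
            loss.backward()

        error_msg = "Your model is mixing up data across the batch dimension!"
        assert loss != 0
        assert (data.grad[test_val] != 0).any(), error_msg
        assert (data.grad[:test_val] == 0.).all() and (data.grad[test_val+1:] == 0.).all(), error_msg
    finally:
        # discard the parameter gradients from this check and restore the mode
        optimizer.zero_grad()
        model.train(was_training)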

This function verifies that only the selected datapoint (index test_val) in the batch receives a nonzero gradient. This check has saved me countless times from running a poorly written model. :)
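The same property can be illustrated without autograd by swapping in a perturbation test: change one input sample and check that only the matching output row changes. This is a different technique from the gradient check above (a finite perturbation rather than backpropagation), sketched in plain Python under the assumption that the "model" is any function mapping a list of samples to a list of per-sample outputs; `mixes_batch`, `good_model`, and `bad_model` are hypothetical.

```python
from typing import Callable, List

Batch = List[List[float]]


def mixes_batch(model: Callable[[Batch], Batch], batch: Batch,
                test_val: int = 2, eps: float = 1e-3) -> bool:
    """Return True if perturbing sample `test_val` changes any other output row."""
    baseline = model(batch)
    perturbed = [list(sample) for sample in batch]  # copy so `batch` is untouched
    perturbed[test_val] = [v + eps for v in perturbed[test_val]]
    changed = model(perturbed)
    return any(changed[j] != baseline[j] for j in range(len(batch)) if j != test_val)


def good_model(batch: Batch) -> Batch:
    # Each output depends only on its own sample
    return [[2.0 * v for v in sample] for sample in batch]


def bad_model(batch: Batch) -> Batch:
    # Each output also sees the next sample: a cross-batch bug
    n = len(batch)
    return [[a + b for a, b in zip(batch[i], batch[(i + 1) % n])] for i in range(n)]


batch = [[float(i), float(i) + 0.5] for i in range(4)]
print(mixes_batch(good_model, batch))  # False
print(mixes_batch(bad_model, batch))   # True
```

Like the gradient version, a single probe can miss a dependence that happens to cancel out, so it is a cheap sanity check rather than a proof of independence.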

Implementation-wise, I am looking for advice on whether this is a useful effort, whether it fits the intended goals of Lightning, and what difficulties may arise.

Alternatives

It is clear that the feature as it stands will not work for all models, as some (e.g. LSTMs with batch_first=False) use a different dimension as their batch dimension (maybe this can be a parameter). There might also be issues if the batch is split up somewhere; I'm not quite certain how everything in this project works, particularly around gradient accumulation.
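Making the batch dimension a parameter mostly comes down to building the right index. A minimal sketch, assuming a hypothetical `batch_dim` argument: the helper builds an indexing tuple that selects sample `i` along `batch_dim`, which works with standard PyTorch (and NumPy) basic indexing. For `nn.LSTM` with the default `batch_first=False`, inputs are `(seq_len, batch, features)`, so `batch_dim=1`.

```python
from typing import Tuple


def sample_index(ndim: int, batch_dim: int, i: int) -> Tuple:
    """Index tuple selecting sample `i` along `batch_dim` of an `ndim`-dimensional tensor."""
    if not -ndim <= batch_dim < ndim:
        raise ValueError(f"batch_dim {batch_dim} out of range for {ndim} dims")
    index = [slice(None)] * ndim   # ':' on every dimension ...
    index[batch_dim] = i           # ... except the batch dimension
    return tuple(index)


# Batch-first image input (B, C, H, W): equivalent to grad[2]
print(sample_index(4, 0, 2))
# LSTM input with batch_first=False, (seq_len, batch, features): equivalent to grad[:, 2]
print(sample_index(3, 1, 2))
```

In the check itself, `data.grad[sample_index(data.grad.dim(), batch_dim, test_val)]` would then replace `data.grad[test_val]`.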

However, I would expect this to be useful for almost all models. I advocate making this a default warning, while allowing well-intentioned users to simply pass a flag that disables the verification step.

I also realize there needs to be some cleanup after this step to reset the model to its previous state. Any insights here would be great as well.

Additional context

None

@TylerYep TylerYep added feature Is an improvement or enhancement help wanted Open to be worked on labels Mar 25, 2020
@github-actions
Contributor

Hi! Thanks for your contribution! Great first issue!

@awaelchli
Contributor

I like the idea. There is a fast_dev_run flag in the Trainer; maybe the proposed check could be done when fast_dev_run is turned on?

Btw, your second bullet point can already be done in PL with the overfit_pct flag in the Trainer :)

@TylerYep
Contributor Author

That sounds like a great place for it, thank you for pointing me there! I will see if I can start by integrating it there first.

My only concern is that since fast_dev_run isn't on by default, many people who are not aware of fast_dev_run may continue running code that doesn't respect the batch dimension. Would it be better to add another flag to Trainer, e.g. check_batch_dimension: Optional[int] = 0 (the default batch dimension is 0 by PyTorch convention; None disables the check and its warnings entirely)?

If I understand correctly, to avoid breaking changes this should not be an assertion, but rather a loud warning?
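The flag semantics proposed above (an integer selects the batch dimension, `None` disables the check) together with warn-instead-of-assert could look roughly like this. It is a sketch only: `check_batch_dimension` is the hypothetical Trainer argument from this comment, and `mixing_detected` stands in for whatever runs the gradient check; neither exists in Lightning.

```python
import warnings
from typing import Callable, Optional


def run_batch_check(mixing_detected: Callable[[int], bool],
                    check_batch_dimension: Optional[int] = 0) -> Optional[bool]:
    """Run the batch-mixing check, warning (not raising) on failure.

    Returns None if the check is disabled, otherwise whether the model passed.
    """
    if check_batch_dimension is None:
        return None  # user opted out of the check entirely
    if mixing_detected(check_batch_dimension):
        # A warning keeps existing training scripts running (no breaking change)
        warnings.warn(
            f"Your model is mixing up data across batch dimension {check_batch_dimension}!",
            UserWarning,
        )
        return False
    return True


print(run_batch_check(lambda dim: False))        # True
print(run_batch_check(lambda dim: True, None))   # None
```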

@Borda
Member

Borda commented Mar 27, 2020

@TylerYep good point about fast_dev_run being a binary on/off state... would it be a solution for you to have three levels:

  • fast run
  • full run
  • complete (fast + full)
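One possible reading of these three levels, as a hypothetical enum (`RunMode` and `stages_for` are illustrative names, not a Lightning API): the complete mode runs the fast sanity pass, where checks such as the batch-dimension test could live, before the full run.

```python
from enum import Enum
from typing import List


class RunMode(Enum):
    FAST = "fast"          # quick sanity pass only (like fast_dev_run)
    FULL = "full"          # normal training only
    COMPLETE = "complete"  # sanity pass first, then normal training


def stages_for(mode: RunMode) -> List[str]:
    """Return the stages to execute, in order, for a given run mode."""
    stages = []
    if mode in (RunMode.FAST, RunMode.COMPLETE):
        stages.append("sanity_checks")
    if mode in (RunMode.FULL, RunMode.COMPLETE):
        stages.append("full_training")
    return stages


print(stages_for(RunMode.COMPLETE))  # ['sanity_checks', 'full_training']
```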

cc: @PyTorchLightning/core-contributors and discussion in #1081 #1087

@williamFalcon
Contributor

@TylerYep would love a PR for this!

@williamFalcon williamFalcon added this to the 0.7.3 milestone Apr 4, 2020
@williamFalcon williamFalcon added the let's do it! approved to implement label Apr 4, 2020
@Borda Borda modified the milestones: 0.7.4, 0.7.5 Apr 24, 2020
@Borda Borda modified the milestones: 0.7.6, 0.8.0, 0.7.7 May 12, 2020
@Borda Borda modified the milestones: 0.7.7, 0.8.0 May 26, 2020
@Borda Borda modified the milestones: 0.8.0, 0.9.0 Jun 9, 2020
@Borda
Member

Borda commented Jun 11, 2020

@TylerYep how is it going here? 🐰

@TylerYep
Contributor Author

TylerYep commented Jun 12, 2020

I'm struggling a lot to understand the codebase and figure out how to fit this feature in.

I tried to fit it into the _evaluate() function in evaluation_loop.py, but I wasn't sure how to proceed: the result of evaluation_forward() doesn't seem to contain the model outputs for the batch, and as written I'm not sure how to set requires_grad on the batch and disable it afterwards without creating a completely separate copy of _evaluate().

If you would like, I can make an in-progress PR, but I haven't made it very far, unfortunately.

@stale

stale bot commented Aug 11, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the won't fix This will not be worked on label Aug 11, 2020
@Borda
Member

Borda commented Aug 11, 2020

@TylerYep can we help you? Or should we wait until we finish the refactoring...

@stale stale bot removed the won't fix This will not be worked on label Aug 11, 2020
@awaelchli
Contributor

I have actually implemented this in a separate class myself to verify my models and used it many times. It is a great sanity test. Maybe I can send a PR or Google Colab and @TylerYep can help me test it. We can also come up with more verification tests.

@williamFalcon
Contributor

This is a prime candidate for a callback.
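A callback would keep the check out of the training and evaluation loops entirely, which sidesteps the _evaluate() difficulty described earlier. A minimal sketch, assuming a duck-typed class whose `on_train_start(self, trainer, pl_module)` method mirrors the hook of the same name on Lightning's `Callback`; with Lightning installed it would subclass `pytorch_lightning.Callback` and be passed via `Trainer(callbacks=[...])`. `BatchMixingCallback` and `check_fn` are hypothetical.

```python
import warnings
from typing import Any, Callable, Optional


class BatchMixingCallback:
    """Hypothetical callback: runs a batch-mixing check when training starts."""

    def __init__(self, check_fn: Callable[[Any, int], bool], batch_dim: int = 0):
        self.check_fn = check_fn   # returns True if the module mixes samples
        self.batch_dim = batch_dim
        self.passed: Optional[bool] = None  # unknown until the check has run

    def on_train_start(self, trainer: Any, pl_module: Any) -> None:
        mixing = self.check_fn(pl_module, self.batch_dim)
        self.passed = not mixing
        if mixing:
            warnings.warn("Your model is mixing up data across the batch dimension!", UserWarning)


callback = BatchMixingCallback(check_fn=lambda module, dim: False)
callback.on_train_start(trainer=None, pl_module=object())
print(callback.passed)  # True
```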

@awaelchli awaelchli self-assigned this Aug 17, 2020
@TylerYep
Contributor Author

Yeah, I would love to help test it! I haven't had the chance to work on this for a while, but if someone with more experience can lead the effort, that would be great :)

@edenlightning edenlightning added this to the 0.9.x milestone Aug 18, 2020
@awaelchli
Contributor

awaelchli commented Aug 24, 2020

Draft is in this repo: https://github.com/awaelchli/pytorch-lightning-snippets
I polished my code and made sure it also works well with models that have multiple inputs and outputs.
I put examples in the README. There is a regular class that works with any nn.Module, and there is also a Callback that integrates these checks easily with the PL Trainer.

@TylerYep
Contributor Author

@awaelchli looks like the repo is private

@awaelchli
Contributor

Thanks, I changed it to public now!

@edenlightning
Contributor

@awaelchli keep this open? Do we want to include your callback in Lightning?

@awaelchli
Contributor

@edenlightning It would be great. I opened an issue in bolts, Lightning-Universe/lightning-bolts#194, but haven't had time to make a PR.

@TylerYep
Contributor Author

What's the distinction between callbacks in bolts and callbacks in the main repo?

Optimistically, a lot of these checks (e.g. batch verification) will fit well in the majority of existing lightning workflows, whereas bolts seems like a better fit for utilities that are a bit more niche or application-specific.

Thoughts?

@stale

stale bot commented Nov 19, 2020

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix This will not be worked on label Nov 19, 2020
@Borda Borda added Hacktoberfest and removed won't fix This will not be worked on labels Nov 19, 2020
@edenlightning edenlightning modified the milestones: 1.2, 1.3 Feb 8, 2021
@edenlightning edenlightning added good first issue Good for newcomers and removed Hacktoberfest labels Feb 16, 2021
@edenlightning edenlightning modified the milestones: v1.3, v1.4 Apr 27, 2021
@edenlightning edenlightning modified the milestones: v1.4, v1.5 Jun 30, 2021
@awaelchli awaelchli modified the milestones: v1.5, v1.6 Nov 4, 2021
@carmocca carmocca removed the good first issue Good for newcomers label Feb 1, 2022
@carmocca carmocca modified the milestones: 1.6, None Feb 1, 2022
Projects
None yet
Development

No branches or pull requests

6 participants