
make _setup_fill_arg serializable #6730

Merged
merged 2 commits into pytorch:main on Oct 10, 2022

Conversation

pmeier
Collaborator

@pmeier pmeier commented Oct 10, 2022

Addresses #6728. The only other usages of lambda are happening in the AA transforms:

"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),

This was a perf optimization we did in v2. In v1 we have

"ShearX": (torch.linspace(0.0, 0.3, num_bins), True),

Since we want to use AA transforms for videos as well (we do, don't we?), we probably need to either revert this or use a technique similar to the one I used here. I'll look into this next.

As explained in #6728, we don't have unified testing for transforms yet. Thus, I cannot guarantee that these two instances cover every non-serializable thing in transforms v2. I will work on testing when video training is up and running.
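The fix here swaps the lambda for functools.partial. A minimal sketch of why that matters for pickling, using only the stdlib (dict stands in for the actual fill callable; the names are illustrative, not torchvision's):

```python
import pickle
from functools import partial

# A lambda cannot be pickled: pickle serializes functions by qualified
# name, and a lambda's qualname ("<lambda>") is not importable.
fill_fn = lambda value: {"default": value}
try:
    pickle.dumps(fill_fn)
except (pickle.PicklingError, AttributeError) as exc:
    print("lambda failed:", type(exc).__name__)

# functools.partial pickles fine, as long as the wrapped callable is an
# importable function and the bound arguments are themselves picklable.
fill_fn = partial(dict, default=0)
restored = pickle.loads(pickle.dumps(fill_fn))
print(restored())  # → {'default': 0}
```

This only works in `_setup_fill_arg` because the bound arguments are known up front, which is exactly the constraint discussed below for the AA transforms.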

@datumbox
Contributor

@pmeier Looks good. Let's do the same for AA.

@pmeier
Collaborator Author

pmeier commented Oct 10, 2022

Not sure why, but AA seems to work as is:

import pickle

import torch
from torchvision.prototype import transforms


image = torch.randint(0, 256, (3, 512, 512), dtype=torch.uint8)

for transform in [
    transforms.AutoAugment(),
    transforms.AugMix(),
    transforms.RandAugment(),
    transforms.TrivialAugmentWide(),
]:
    serialized = pickle.dumps(transform)
    deserialized = pickle.loads(serialized)

    deserialized(image)

I suggest we don't make this more complicated until we actually hit a problem. Using the same technique that I used here is not really possible for AA, since the return value depends on the input parameters.

@datumbox
Contributor

Using the same technique that I used here is not really possible for AA, since the return value depends on the input parameters.

Isn't the return type of the lambda always Optional[Tensor]?

@pmeier
Collaborator Author

pmeier commented Oct 10, 2022

Isn't the return type of the lambda always Optional[Tensor]?

Yeah, but the return value differs and depends on the inputs. functools.partial only works in this PR because the return value is static and we know it in advance. For AA we would need to replace

"ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins), True),

with something like

def _auto_augment_shearx(num_bins: int, height: int, width: int) -> torch.Tensor:
    return torch.linspace(0.0, 0.3, num_bins)

If we knew num_bins at class definition time, we could do something like

functools.partial(torch.linspace, 0.0, 0.3, STATIC_NUM_BINS)

This is indeed fixed for some of the AA transforms like

magnitudes = magnitudes_fn(10, height, width)

but RandAugment, for example, takes this as an input parameter:

num_magnitude_bins: int = 31,

Worse, although they are not used often, we don't know anything about height and width at class definition time. Thus, all definitions that use height or width would need a standalone function.

@pmeier pmeier merged commit 019139f into pytorch:main Oct 10, 2022
@pmeier pmeier deleted the fill-serialize branch October 10, 2022 09:16
facebook-github-bot pushed a commit that referenced this pull request Oct 17, 2022
Reviewed By: NicolasHug

Differential Revision: D40427456

fbshipit-source-id: 5b0e98c73906a2ed2a66045e6608ce2aef09c003