Make both torch.amp and apex.amp available as backend for mixed precision training #91

Merged · 4 commits · Jan 26, 2024

Conversation

NaleRaphael (Contributor) commented on Jan 13, 2024:

Hi @davidtvs, here is a summary of the changes made in this PR for issues #67 and #90.

Required changes to integrate torch.amp

  • torch.cuda.amp.GradScaler has to be passed into LRFinder, because it will be used in the following stages:
    • backward pass:
      # without AMP (automatic mixed precision), or using `apex.amp`
      loss.backward()

      # using `torch.amp`
      scaler.scale(loss).backward()
    • optimizer steps:
      # without AMP, or using `apex.amp`
      optimizer.step()

      # using `torch.amp`
      scaler.step(optimizer)
      scaler.update()
  • torch.amp.autocast() needs to be called in the forward pass (a consolidated sketch of a full training step follows this list):
    # without AMP, or using `apex.amp`
    outputs = self.model(inputs)
    loss = self.criterion(outputs, labels)

    # using `torch.amp`
    with torch.amp.autocast(device_type=..., dtype=...):
      outputs = self.model(inputs)
      loss = self.criterion(outputs, labels)
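
Putting the two pieces together, here is a minimal sketch of one full training step with `torch.amp` (assuming `model`, `criterion`, `optimizer`, and `loader` are already defined; the `device_type`/`dtype` values are placeholders, not the exact code in this PR):

    import torch

    scaler = torch.cuda.amp.GradScaler()

    for inputs, labels in loader:
        optimizer.zero_grad()

        # forward pass runs under autocast so eligible ops execute in fp16
        with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        # scale the loss before backward to avoid fp16 gradient underflow
        scaler.scale(loss).backward()

        # unscale gradients, step the optimizer, then update the scale factor
        scaler.step(optimizer)
        scaler.update()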

Proposed changes

Three new keyword arguments are added to LRFinder.__init__():

  • amp_backend: a string to select the AMP backend
  • amp_config: a dict storing the arguments required by torch.amp.autocast()
  • grad_scaler: a torch.cuda.amp.GradScaler instance to be used in LRFinder._train_batch()

This should give users maximal flexibility to control how AMP works with LRFinder. A usage sketch is shown below.

If there is a need to apply advanced tricks with torch.amp (e.g., for multiple GPUs/models/losses), it's still achievable by simply overriding LRFinder._train_batch(). So we can focus on the current implementation of gradient accumulation without worrying about other variants.
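
For illustration, a usage sketch of the new keyword arguments (the backend string "torch", the surrounding LRFinder/range_test calls, and the concrete numbers are assumptions for this example rather than the exact API guaranteed by this PR; `model`, `optimizer`, `criterion`, and `train_loader` are assumed to be defined):

    import torch
    from torch_lr_finder import LRFinder

    # arguments forwarded to torch.amp.autocast()
    amp_config = {"device_type": "cuda", "dtype": torch.float16}
    grad_scaler = torch.cuda.amp.GradScaler()

    lr_finder = LRFinder(
        model, optimizer, criterion, device="cuda",
        amp_backend="torch",      # select torch.amp as the AMP backend
        amp_config=amp_config,    # passed to torch.amp.autocast() in the forward pass
        grad_scaler=grad_scaler,  # used in LRFinder._train_batch()
    )
    lr_finder.range_test(train_loader, end_lr=10, num_iter=100)
    lr_finder.plot()
    lr_finder.reset()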

Note

The new script examples/mnist_with_amp.py can be used to check the results produced by running LRFinder with different AMP backends. Here are the results produced on my machine with the command $ python mnist_with_amp.py --batch_size=32 --tqdm --amp_backend=...:

[Figures: LR range test plots without AMP, with apex.amp, and with torch.amp]

* In these three figures, the suggested LR is the same: 2.42E-01

Package information:

  • python: 3.9.18
  • torch: 1.13.1+cu117
  • apex: 0.1 (compiled from source; I needed to check out revision 2386a912164b0c5cfcd8be7a2b890fbac5607c82 to build it. See also this comment.)

As always, feel free to let me know if there is anything that can be improved.

Users can now choose the backend for mixed precision training by specifying the keyword argument `amp_backend` to `LRFinder`.
…mulation is enabled

Since further advanced tricks for gradient accumulation can be done by overriding `LRFinder._train_batch()`, it doesn't seem necessary to handle them on our own. Also, removing it reduces surprises when an overflow occurs while training in lower precision.

See also this section in the `apex` documentation:
https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
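
For reference, the gradient-accumulation pattern described in that apex doc section looks roughly like the following sketch (not code from this PR; `accumulation_steps` is a placeholder, and `model`/`optimizer` are assumed to have been set up with `amp.initialize()`):

    from apex import amp

    optimizer.zero_grad()
    for i, (inputs, labels) in enumerate(loader):
        outputs = model(inputs)
        loss = criterion(outputs, labels) / accumulation_steps

        # only unscale gradients on iterations where the optimizer will step
        do_step = (i + 1) % accumulation_steps == 0
        with amp.scale_loss(loss, optimizer, delay_unscale=not do_step) as scaled_loss:
            scaled_loss.backward()

        if do_step:
            optimizer.step()
            optimizer.zero_grad()
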
NaleRaphael (Contributor, Author) commented:

According to actions/setup-python#544, it seems support for Python < 3.6 was dropped after the base image was updated to Ubuntu 22.04.
Also, torch.amp was released in PyTorch 1.6, and the newly added unit tests are designed to run it on a GPU, so they might not be runnable with the current CI setup.

davidtvs (Owner) commented:

@NaleRaphael, thanks for the PR, it looks good. I've updated the CI workflow; it should be functional again. You'll probably have to rebase this branch.

NaleRaphael (Contributor, Author) commented:

Sure, wait for a moment and I'll update it.

NaleRaphael (Contributor, Author) commented on Jan 20, 2024:

Updated: fixed reference links.
Fixed an f-string issue (to be compatible with Python 3.5).

By the way, there is one thing I forgot to mention. Since the default value of the new keyword argument amp_backend is None, AMP won't be enabled automatically even if apex is installed. This behavior differs from previous revisions, which would use apex.amp automatically whenever apex was available.

This might be a breaking change for users. However, the current behavior seems more reasonable to me, because the mixed precision training provided by apex only pays off on GPUs with tensor cores (RTX series). Though it can run on some pre-Turing cards (see also here), the performance won't be as good as training in fp32; see also this comment. Therefore, defaulting amp_backend to None should avoid surprising users.
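
In other words, users who previously relied on apex being picked up automatically would now have to opt in explicitly, roughly like this (a sketch; the "apex" backend string is an assumption, and `model`/`optimizer` are presumed to have been prepared with `apex.amp.initialize()` beforehand):

    # apex.amp is no longer enabled implicitly just because apex is installed
    lr_finder = LRFinder(model, optimizer, criterion, device="cuda", amp_backend="apex")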

davidtvs (Owner) left a comment:


LG! Thanks @NaleRaphael

davidtvs merged commit fd3b5c8 into davidtvs:master on Jan 26, 2024 (2 checks passed).