Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Feb 27, 2024
1 parent 3cdd66f commit b8752e4
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions examples/pytorch/bug_report/bug_report_model.py
@@ -1,18 +1,17 @@
 from contextlib import contextmanager
-from typing import Literal, Optional, Generator
+from typing import Generator, Literal, Optional
 
 import torch
 from lightning.pytorch.plugins.precision import MixedPrecisionPlugin
 
 
-
 class PipelineMixedPrecisionPlugin(MixedPrecisionPlugin):
-    """ Overrides PTL autocasting to not wrap training/val/test_step.
-        We do this because we have the megatron-core fwd/bwd functions in training_step.
-        This means .backward is being called in training_step so we do not want the whole
-        step wrapped in autocast.
-        We instead wrap the fwd_output_and_loss_func that is passed to the megatron-core fwd/bwd functions.
+    """Overrides PTL autocasting to not wrap training/val/test_step. We do this because we have the megatron-core
+    fwd/bwd functions in training_step. This means .backward is being called in training_step so we do not want the
+    whole step wrapped in autocast.
+
+    We instead wrap the fwd_output_and_loss_func that is passed to the megatron-core fwd/bwd functions.
     """
 
     def __init__(
@@ -24,9 +23,9 @@ def __init__(
         super().__init__(precision, device, scaler=scaler)
         dtype = None
         # MixedPrecisionPlugin class in PTL >= 2.0 takes only "16-mixed" or "bf16-mixed" for precision arg
-        if precision == '16-mixed':
+        if precision == "16-mixed":
             dtype = torch.float16
-        elif precision == 'bf16-mixed':
+        elif precision == "bf16-mixed":
             dtype = torch.bfloat16
 
         torch.set_autocast_gpu_dtype(dtype)
@@ -37,4 +36,4 @@ def forward_context(self) -> Generator[None, None, None]:
         yield
 
 
-PipelineMixedPrecisionPlugin(precision="16-mixed", device="cuda:0")
\ No newline at end of file
+PipelineMixedPrecisionPlugin(precision="16-mixed", device="cuda:0")
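
For context, here is a minimal sketch of the wrapping strategy the docstring above describes: rather than autocasting the whole training_step (inside which megatron-core also calls .backward()), only the forward/loss callback handed to the megatron-core fwd/bwd helper runs under torch.autocast. The helper name wrap_forward_loss_func is illustrative and not part of the commit; fwd_output_and_loss_func stands in for the callback the docstring mentions.

import torch


def wrap_forward_loss_func(fwd_output_and_loss_func, dtype=torch.float16):
    """Return a variant of the megatron-core forward/loss callback that runs under autocast."""

    def wrapped(*args, **kwargs):
        # Only the forward pass and loss computation are autocast here;
        # .backward(), which megatron-core invokes inside training_step,
        # stays outside the autocast context.
        with torch.autocast(device_type="cuda", dtype=dtype):
            return fwd_output_and_loss_func(*args, **kwargs)

    return wrapped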

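Hypothetical usage, assuming the Lightning 2.x Trainer, which accepts custom precision plugins through its plugins argument:

import lightning.pytorch as pl

# The plugin replaces Lightning's default mixed-precision handling; since
# forward_context() above is a plain yield, steps are not globally autocast.
trainer = pl.Trainer(plugins=[PipelineMixedPrecisionPlugin(precision="16-mixed", device="cuda:0")])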