From 20b49e530a80ee30a2e8b5fd1fd9ba26dac9d562 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= <aedu.waelchli@gmail.com>
Date: Fri, 19 Feb 2021 15:42:51 +0100
Subject: [PATCH 1/9] precision fixes

---
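The functional change is in train_step_context: since it is a generator-based @contextmanager, the old body "yield torch.cuda.amp.autocast()" only handed the caller an un-entered autocast object, so __enter__ never ran and autocast stayed off for the entire training step. The fix opens the context around the yield instead. The other hunk mixes DeviceDtypeModuleMixin into _LightningModuleWrapperBase so dtype/device changes made on the wrapper are tracked on the wrapped LightningModule as well (dropped again in PATCH 8 as unrelated).

Minimal CPU-only sketch of the yield pitfall; Flag is a made-up stand-in for torch.cuda.amp.autocast, not Lightning or PyTorch code:

    from contextlib import contextmanager

    class Flag:
        # toy context manager that records whether __enter__ ever ran
        active = False

        def __enter__(self):
            Flag.active = True
            return self

        def __exit__(self, *exc):
            Flag.active = False
            return False

    @contextmanager
    def broken_ctx():
        yield Flag()        # the object is created but never entered

    @contextmanager
    def fixed_ctx():
        with Flag():        # entered before yielding, exited on cleanup
            yield

    with broken_ctx():
        print(Flag.active)  # False: the "autocast" region was never activated
    with fixed_ctx():
        print(Flag.active)  # True: what the patched train_step_context now does
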
 pytorch_lightning/overrides/base.py               |  3 ++-
 pytorch_lightning/plugins/precision/native_amp.py |  3 ++-
 tests/models/test_amp.py                          |  5 +++--
 tests/overrides/test_data_parallel.py             | 13 +++++++++++++
 4 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py
index 2fcb4b11a0b7f9..b5f932926f3890 100644
--- a/pytorch_lightning/overrides/base.py
+++ b/pytorch_lightning/overrides/base.py
@@ -19,12 +19,13 @@
 
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.trainer.states import RunningStage
+from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
 from pytorch_lightning.utilities.warnings import WarningCache
 
 warning_cache = WarningCache()
 
 
-class _LightningModuleWrapperBase(torch.nn.Module):
+class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
 
     def __init__(self, pl_module: LightningModule):
         """
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index 60c0f5f84626f5..94e6cf376b03a5 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -91,4 +91,5 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
     @contextmanager
     def train_step_context(self) -> Generator[autocast, None, None]:
         """Enable autocast context"""
-        yield torch.cuda.amp.autocast()
+        with torch.cuda.amp.autocast():
+            yield
diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index 2dd6c9d997dbf9..6b98d8f2c47cf7 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -63,7 +63,7 @@ def test_amp_single_gpu_ddp_spawn(tmpdir):
     model = BoringModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
-
+    assert torch.is_autocast_enabled()
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
 
@@ -103,7 +103,7 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir):
     model = BoringModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
-
+    assert torch.is_autocast_enabled()
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
 
@@ -152,6 +152,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
     assert trainer.training_type_plugin.cluster_environment.resolve_root_node_address('abc[23-24]') == 'abc23'
     generated = trainer.training_type_plugin.cluster_environment.resolve_root_node_address('abc[23-24, 45-40, 40]')
     assert generated == 'abc23'
+    assert torch.is_autocast_enabled()
 
 
 @pytest.mark.skipif(torch.cuda.is_available(), reason="test is restricted only on CPU")
diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py
index 64481bd70390d4..9b46ce847b80d3 100644
--- a/tests/overrides/test_data_parallel.py
+++ b/tests/overrides/test_data_parallel.py
@@ -153,3 +153,16 @@ def training_step(self, batch, batch_idx):
     wrapped_model = LightningParallelModule(model)
     output = wrapped_model(batch, batch_idx)
     assert output["python scalar"] == torch.tensor([12.3], device=device)
+
+
+@pytest.mark.parametrize("wrapper_class", [
+    LightningParallelModule,
+    LightningDistributedModule,
+])
+def test_dtype_device_access(wrapper_class):
+    """ Test that device and dtype attributes are accessible through the wrapper. """
+    model = BoringModel()
+    assert model.dtype == torch.float32
+    wrapped_model = wrapper_class(model)
+    wrapped_model.half()
+    assert model.dtype == torch.float16

From 1dd56de9968872ec2f6a1d8a558d907a662cf460 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= <aedu.waelchli@gmail.com>
Date: Fri, 19 Feb 2021 15:58:15 +0100
Subject: [PATCH 2/9] add amp test model

---
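AMPTestModel bakes the precision checks into the model itself, so they run inside the actual training step (and, under ddp_spawn, inside the spawned worker processes) rather than in the test process after fit() returns. A standalone sketch of the behaviour being asserted, assuming a CUDA device is available; the layer and tensors are illustrative only:

    import torch

    if torch.cuda.is_available():
        layer = torch.nn.Linear(32, 2).cuda()       # parameters stay float32
        x = torch.randn(4, 32, device="cuda")       # float32 input
        with torch.cuda.amp.autocast():
            assert torch.is_autocast_enabled()
            out = layer(x)                          # matmul runs in half precision
        assert out.dtype == torch.float16
        assert layer.weight.dtype == torch.float32  # autocast never casts parameters
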
 tests/models/test_amp.py | 23 ++++++++++++++++++-----
 1 file changed, 18 insertions(+), 5 deletions(-)

diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index 6b98d8f2c47cf7..68a1ae27c044ae 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -27,6 +27,19 @@
 from tests.helpers import BoringModel
 
 
+class AMPTestModel(BoringModel):
+
+    def forward(*args, **kwargs):
+        assert torch.is_autocast_enabled()
+        super().forward(*args, **kwargs)
+
+    def training_step(self, batch, batch_idx):
+        output = super().training_step(batch, batch_idx)
+        loss = output["loss"]
+        assert loss.dtype == torch.float16
+        return output
+
+
 @pytest.mark.skip(reason='dp + amp not supported currently')  # TODO
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_amp_single_gpu_dp(tmpdir):
@@ -41,7 +54,7 @@ def test_amp_single_gpu_dp(tmpdir):
         precision=16,
     )
 
-    model = BoringModel()
+    model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
 
@@ -60,7 +73,7 @@ def test_amp_single_gpu_ddp_spawn(tmpdir):
         precision=16,
     )
 
-    model = BoringModel()
+    model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
     assert torch.is_autocast_enabled()
@@ -81,7 +94,7 @@ def test_amp_multi_gpu_dp(tmpdir):
         precision=16,
     )
 
-    model = BoringModel()
+    model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
 
@@ -100,7 +113,7 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir):
         precision=16,
     )
 
-    model = BoringModel()
+    model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
     assert torch.is_autocast_enabled()
@@ -122,7 +135,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
     # simulate setting slurm flags
     tutils.set_random_master_port()
 
-    model = BoringModel()
+    model = AMPTestModel()
 
     # exp file to get meta
     logger = tutils.get_default_logger(tmpdir)

From 3d69769ce86892ce4a1009f727b7afb1b3ea40ea Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= <aedu.waelchli@gmail.com>
Date: Fri, 19 Feb 2021 16:07:06 +0100
Subject: [PATCH 3/9] fix test

---
 tests/models/test_amp.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index 68a1ae27c044ae..eb21cd7652f772 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -29,9 +29,9 @@
 
 class AMPTestModel(BoringModel):
 
-    def forward(*args, **kwargs):
+    def forward(self, *args, **kwargs):
         assert torch.is_autocast_enabled()
-        super().forward(*args, **kwargs)
+        return super().forward(*args, **kwargs)
 
     def training_step(self, batch, batch_idx):
         output = super().training_step(batch, batch_idx)

From f44c91ada922fd99ef9e6e6a4db98709fe1aa278 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= <aedu.waelchli@gmail.com>
Date: Fri, 19 Feb 2021 16:09:27 +0100
Subject: [PATCH 4/9] revert

---
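The assertions removed here ran in the test process after trainer.fit() had returned, where no autocast context is active (and, for ddp_spawn, training happened in child processes anyway), so torch.is_autocast_enabled() cannot be True at that point. The equivalent checks now live inside AMPTestModel. Quick illustration, no GPU needed:

    import torch

    # Outside any autocast region, e.g. in the test process after trainer.fit(),
    # autocast reports disabled regardless of the precision the Trainer used.
    print(torch.is_autocast_enabled())  # False
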
 tests/models/test_amp.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index eb21cd7652f772..2f0ff664f12c63 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -76,7 +76,6 @@ def test_amp_single_gpu_ddp_spawn(tmpdir):
     model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
-    assert torch.is_autocast_enabled()
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
 
@@ -116,7 +115,6 @@ def test_amp_multi_gpu_ddp_spawn(tmpdir):
     model = AMPTestModel()
     # tutils.run_model_test(trainer_options, model)
     trainer.fit(model)
-    assert torch.is_autocast_enabled()
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
 
@@ -165,7 +163,6 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
     assert trainer.training_type_plugin.cluster_environment.resolve_root_node_address('abc[23-24]') == 'abc23'
     generated = trainer.training_type_plugin.cluster_environment.resolve_root_node_address('abc[23-24, 45-40, 40]')
     assert generated == 'abc23'
-    assert torch.is_autocast_enabled()
 
 
 @pytest.mark.skipif(torch.cuda.is_available(), reason="test is restricted only on CPU")

From 3935b6418d6c85ec191f9e9d88cc27a97cb40480 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= <aedu.waelchli@gmail.com>
Date: Fri, 19 Feb 2021 16:47:52 +0100
Subject: [PATCH 5/9] move assert to training step

---
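The autocast assertion moves out of forward(), which is also called outside the training-step context (for example from validation_step during the sanity check), and into training_step, which is exactly the call the native AMP plugin wraps in train_step_context. A simplified, hypothetical sketch of that wrapping (not Lightning's real loop; run_one_step is an invented name):

    from contextlib import contextmanager

    import torch

    @contextmanager
    def train_step_context():
        with torch.cuda.amp.autocast():
            yield

    def run_one_step(model, batch, batch_idx):
        with train_step_context():                        # autocast entered here ...
            return model.training_step(batch, batch_idx)  # ... so it is enabled in here
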
 tests/models/test_amp.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index 2f0ff664f12c63..477072048ea3ae 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -29,11 +29,8 @@
 
 class AMPTestModel(BoringModel):
 
-    def forward(self, *args, **kwargs):
-        assert torch.is_autocast_enabled()
-        return super().forward(*args, **kwargs)
-
     def training_step(self, batch, batch_idx):
+        assert torch.is_autocast_enabled()
         output = super().training_step(batch, batch_idx)
         loss = output["loss"]
         assert loss.dtype == torch.float16

From f3cb7a3ecfe18711add97e18b6ee42e9b40765aa Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= <aedu.waelchli@gmail.com>
Date: Fri, 19 Feb 2021 16:50:52 +0100
Subject: [PATCH 6/9] fix test

---
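The dtype check moves from the loss to the raw network output: under autocast, matmul-based ops such as Linear produce float16, but many loss functions (including mse_loss, used by BoringModel's loss helper) are autocast to float32, so asserting loss.dtype == torch.float16 could not hold. The ddp_spawn accelerator flag commented out in the same hunk looks like a local-debugging leftover and is restored in the next patch. Hedged sketch of the dtype behaviour, assuming a CUDA device:

    import torch
    import torch.nn.functional as F

    if torch.cuda.is_available():
        layer = torch.nn.Linear(32, 2).cuda()
        x = torch.randn(4, 32, device="cuda")
        with torch.cuda.amp.autocast():
            out = layer(x)                                 # float16 under autocast
            loss = F.mse_loss(out, torch.zeros_like(out))  # mse_loss is autocast to float32
        print(out.dtype, loss.dtype)                       # torch.float16 torch.float32
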
 tests/models/test_amp.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index 477072048ea3ae..7058f50ed0052e 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -31,10 +31,10 @@ class AMPTestModel(BoringModel):
 
     def training_step(self, batch, batch_idx):
         assert torch.is_autocast_enabled()
-        output = super().training_step(batch, batch_idx)
-        loss = output["loss"]
-        assert loss.dtype == torch.float16
-        return output
+        output = self(batch)
+        assert output.dtype == torch.float16
+        loss = self.loss(batch, output)
+        return {"loss": loss}
 
 
 @pytest.mark.skip(reason='dp + amp not supported currently')  # TODO
@@ -66,7 +66,7 @@ def test_amp_single_gpu_ddp_spawn(tmpdir):
         default_root_dir=tmpdir,
         max_epochs=1,
         gpus=1,
-        accelerator='ddp_spawn',
+        # accelerator='ddp_spawn',
         precision=16,
     )
 

From 7ca2d0b498c5bd4c5abbec244e2e52767670e04a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= <aedu.waelchli@gmail.com>
Date: Fri, 19 Feb 2021 16:51:07 +0100
Subject: [PATCH 7/9] fix test

---
 tests/models/test_amp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index 7058f50ed0052e..53ec32764f3ed8 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -66,7 +66,7 @@ def test_amp_single_gpu_ddp_spawn(tmpdir):
         default_root_dir=tmpdir,
         max_epochs=1,
         gpus=1,
-        # accelerator='ddp_spawn',
+        accelerator='ddp_spawn',
         precision=16,
     )
 

From 452eb2613e43cc1b13487b01a94102aa40926205 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= <aedu.waelchli@gmail.com>
Date: Fri, 19 Feb 2021 17:18:18 +0100
Subject: [PATCH 8/9] remove unrelated changes

---
 pytorch_lightning/overrides/base.py   |  2 +-
 tests/overrides/test_data_parallel.py | 13 -------------
 2 files changed, 1 insertion(+), 14 deletions(-)

diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py
index b5f932926f3890..535ebde4c169e8 100644
--- a/pytorch_lightning/overrides/base.py
+++ b/pytorch_lightning/overrides/base.py
@@ -25,7 +25,7 @@
 warning_cache = WarningCache()
 
 
-class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
+class _LightningModuleWrapperBase(torch.nn.Module):
 
     def __init__(self, pl_module: LightningModule):
         """
diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py
index 9b46ce847b80d3..64481bd70390d4 100644
--- a/tests/overrides/test_data_parallel.py
+++ b/tests/overrides/test_data_parallel.py
@@ -153,16 +153,3 @@ def training_step(self, batch, batch_idx):
     wrapped_model = LightningParallelModule(model)
     output = wrapped_model(batch, batch_idx)
     assert output["python scalar"] == torch.tensor([12.3], device=device)
-
-
-@pytest.mark.parametrize("wrapper_class", [
-    LightningParallelModule,
-    LightningDistributedModule,
-])
-def test_dtype_device_access(wrapper_class):
-    """ Test that device and dtype attributes are accessible through the wrapper. """
-    model = BoringModel()
-    assert model.dtype == torch.float32
-    wrapped_model = wrapper_class(model)
-    wrapped_model.half()
-    assert model.dtype == torch.float16

From 1074aff8f7aa1b294610b9311568ddf5c5f2c999 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= <aedu.waelchli@gmail.com>
Date: Fri, 19 Feb 2021 17:20:11 +0100
Subject: [PATCH 9/9] add changelog

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2ad54381a082b1..7dad863d412933 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed incorrect yield logic for the AMP autocast context manager ([#6080](https://github.com/PyTorchLightning/pytorch-lightning/pull/6080))
+
 
 ## [1.2.0] - 2021-02-18