diff --git a/CHANGELOG.md b/CHANGELOG.md
index 66722586deebb..0fd3eed24e0d1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -510,6 +510,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed DeepSpeed Windows support ([#8488](https://github.com/PyTorchLightning/pytorch-lightning/pull/8488))
 
+- Enabled manual optimization for TPUs ([#8458](https://github.com/PyTorchLightning/pytorch-lightning/pull/8458))
+
+
 - Fixed `accumulate_grad_batches` not been recomputed during model reload ([#5334](https://github.com/PyTorchLightning/pytorch-lightning/pull/5334))
 
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 64f10f30ea5a8..072c69c1f8aec 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -337,7 +337,7 @@ def model_to_device(self):
 
     def pre_backward(self, closure_loss: torch.Tensor) -> None:
         """Run before precision plugin executes backward"""
-        if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync:
+        if not self.lightning_module.automatic_optimization:
             prepare_for_backward(self.model, closure_loss)
 
     def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor:
diff --git a/tests/accelerators/test_tpu_backend.py b/tests/accelerators/test_tpu_backend.py
index 175344f897803..c54d9bd905fb2 100644
--- a/tests/accelerators/test_tpu_backend.py
+++ b/tests/accelerators/test_tpu_backend.py
@@ -11,6 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License
+import collections
+from copy import deepcopy
+
 import pytest
 import torch
 from torch import nn
@@ -18,6 +21,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
 from pytorch_lightning.accelerators.tpu import TPUAccelerator
+from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.plugins import TPUSpawnPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
@@ -186,3 +190,77 @@ def test_set_devices_if_none_tpu():
 
     trainer = Trainer(accelerator="tpu", tpu_cores=8)
     assert trainer.devices == 8
+
+
+@RunIf(tpu=True)
+def test_manual_optimization_tpus(tmpdir):
+
+    class ManualOptimizationModel(BoringModel):
+
+        count = 0
+        called = collections.defaultdict(int)
+
+        def __init__(self):
+            super().__init__()
+            self.automatic_optimization = False
+
+        @property
+        def should_update(self):
+            return self.count % 2 == 0
+
+        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+            self.called["on_train_batch_start"] += 1
+            self.weight_before = self.layer.weight.clone()
+
+        def training_step(self, batch, batch_idx):
+            self.called["training_step"] += 1
+            opt = self.optimizers()
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+
+            if self.should_update:
+                self.manual_backward(loss)
+                opt.step()
+                opt.zero_grad()
+            return loss
+
+        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+            self.called["on_train_batch_end"] += 1
+            after_before = self.layer.weight.clone()
+            if self.should_update:
+                assert not torch.equal(self.weight_before, after_before), self.count
+            else:
+                assert torch.equal(self.weight_before, after_before)
+                assert torch.all(self.layer.weight.grad == 0)
+            self.count += 1
+
+        def on_train_end(self):
+            assert self.called["training_step"] == 5
+            assert self.called["on_train_batch_start"] == 5
+            assert self.called["on_train_batch_end"] == 5
+
+    class TestManualOptimizationCallback(Callback):
+
+        def on_train_end(self, trainer, pl_module):
+
+            opt = pl_module.optimizers()
+            assert opt._total_optimizer_step_calls == 3
+
+    model = ManualOptimizationModel()
+    model_copy = deepcopy(model)
+    model.training_step_end = None
+    model.training_epoch_end = None
+
+    trainer = Trainer(
+        max_epochs=1,
+        default_root_dir=tmpdir,
+        limit_train_batches=5,
+        limit_test_batches=0,
+        limit_val_batches=0,
+        tpu_cores=8,
+        callbacks=[TestManualOptimizationCallback()],
+    )
+    trainer.fit(model)
+
+    for param, param_copy in zip(model.parameters(), model_copy.parameters()):
+        assert not torch.equal(param.cpu().data, param_copy.data)
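For reference, this is the user-facing pattern the new test exercises: a LightningModule that opts out of automatic optimization and drives backward/step itself, run on TPU. The following is a minimal sketch; the model, optimizer, and data are illustrative and not part of the diff above.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer


class ManualTPUModel(LightningModule):
    def __init__(self):
        super().__init__()
        # Opt out of automatic optimization: Lightning will no longer call
        # backward()/step() for us, so training_step must do both.
        self.automatic_optimization = False
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        opt = self.optimizers()
        loss = self.layer(x).sum()
        # manual_backward() routes the backward pass through the
        # training-type plugin; the ddp_spawn.py change above is what lets
        # this path run when that plugin is TPUSpawnPlugin.
        self.manual_backward(loss)
        opt.step()
        opt.zero_grad()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
# tpu_cores=8 selects the TPU accelerator with the TPUSpawnPlugin.
trainer = Trainer(max_epochs=1, tpu_cores=8)
trainer.fit(ManualTPUModel(), train_data)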
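A note on the ddp_spawn.py hunk: it relaxes `pre_backward` so that `prepare_for_backward` runs on every manual-optimization backward, not only when `require_backward_grad_sync` is set. That flag is an attribute of `torch.nn.parallel.DistributedDataParallel` wrappers, which, presumably, the model run by the TPU spawn plugin does not carry, hence dropping the check so `TPUSpawnPlugin` can inherit this code path.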