Fix for incorrect usage of detach(), cpu(), to() #6216

Merged · 12 commits · Mar 1, 2021
CHANGELOG.md — 3 additions, 0 deletions

```diff
@@ -68,6 +68,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))
 
 
+- Fixed incorrect usage of `detach()`, `cpu()`, `to()` ([#6216](https://github.com/PyTorchLightning/pytorch-lightning/pull/6216))
+
+
 - Fixed LBFGS optimizer support which didn't converge in automatic optimization ([#6147](https://github.com/PyTorchLightning/pytorch-lightning/pull/6147))
```
pytorch_lightning/core/step_result.py — 6 additions, 4 deletions

```diff
@@ -416,20 +416,22 @@ def get_batch_pbar_metrics(self, include_forked_originals=True, add_dataloader_i
 
         return result
 
-    def detach(self):
+    def detach(self) -> 'Result':
         for k, v in self.items():
             if isinstance(v, torch.Tensor):
                 self.__setitem__(k, v.detach())
+        return self
 
-    def to(self, *args, **kwargs):
+    def to(self, *args, **kwargs) -> 'Result':
         """Move all self attributes to the given device."""
         for k, v in self.items():
             if isinstance(v, torch.Tensor):
                 self.__setitem__(k, v.to(*args, **kwargs))
+        return self
 
-    def cpu(self):
+    def cpu(self) -> 'Result':
         """Move all self attributes to CPU."""
-        self.to(torch.device("cpu"))
+        return self.to(torch.device("cpu"))
 
     def __repr__(self):
         self_copy = self.copy()
```
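
The return-type change is what makes the call-site fixes below possible: with `detach()`, `to()`, and `cpu()` all returning `self`, a `Result` can be used with the same reassignment idiom as a plain `torch.Tensor`, even though `Result` mutates in place while `Tensor.detach()` returns a new object. A minimal sketch of that contract (`MiniResult` is an illustrative stand-in, not Lightning code):

```python
import torch


class MiniResult(dict):
    """Toy stand-in for Lightning's `Result`: detach() mutates in place."""

    def detach(self) -> "MiniResult":
        for k, v in self.items():
            if isinstance(v, torch.Tensor):
                self[k] = v.detach()
        return self  # returning self enables the tensor-style reassignment idiom


loss = torch.ones(1, requires_grad=True) * 2.0

result = MiniResult(loss=loss)
result = result.detach()  # Result-like: in-place mutation, returns self
plain = loss.detach()     # torch.Tensor: returns a *new* detached tensor

assert not result["loss"].requires_grad
assert not plain.requires_grad
```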
```diff
@@ -281,11 +281,11 @@ def cache_result(self) -> None:
             # attach capture batch_size
             Result.attach_batch_size(self._batch_size, hook_result)
 
-            hook_result.detach()
+            hook_result = hook_result.detach()
             if self.trainer.move_metrics_to_cpu:
-                hook_result.cpu()
+                hook_result = hook_result.cpu()
            elif self.trainer._distrib_type == DistributedType.DP:
-                hook_result.to(torch.device("cuda", self.trainer.root_gpu))
+                hook_result = hook_result.to(torch.device("cuda", self.trainer.root_gpu))
 
             self._internals[fx_name].append(hook_result, info)
```
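
Note that before the `cpu()` fix above, `hook_result = hook_result.cpu()` would have silently set `hook_result` to `None`, since the old method returned nothing. The branching logic, pulled into a standalone sketch (`move_to_cpu` and `dp_root_gpu` are hypothetical stand-ins for the trainer attributes):

```python
from typing import Optional

import torch


def move_hook_result(hook_result, move_to_cpu: bool, dp_root_gpu: Optional[int] = None):
    """Detach the result, then move it to wherever the config asks.

    `hook_result` is anything exposing the Result API: detach()/cpu()/to()
    that mutate in place and return self.
    """
    hook_result = hook_result.detach()
    if move_to_cpu:
        hook_result = hook_result.cpu()
    elif dp_root_gpu is not None:
        hook_result = hook_result.to(torch.device("cuda", dp_root_gpu))
    return hook_result
```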
pytorch_lightning/trainer/trainer.py — 2 additions, 2 deletions

```diff
@@ -736,9 +736,9 @@ def run_evaluation(self, max_batches=None, on_epoch=False):
     def track_output_for_epoch_end(self, outputs, output):
         if output is not None:
             if isinstance(output, Result):
-                output.detach()
+                output = output.detach()
                 if self.move_metrics_to_cpu:
-                    output.cpu()
+                    output = output.cpu()
             elif isinstance(output, dict):
                 output = recursive_detach(output, to_cpu=self.move_metrics_to_cpu)
             elif isinstance(output, torch.Tensor) and output.is_cuda and self.move_metrics_to_cpu:
```
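
For the plain-dict branch, Lightning's `recursive_detach` utility walks the output and detaches any nested tensors. Roughly what such a helper does, as a sketch rather than the library's actual implementation:

```python
import torch


def recursive_detach_sketch(value, to_cpu: bool = False):
    """Detach every tensor nested inside dicts, optionally moving it to CPU."""
    if isinstance(value, torch.Tensor):
        detached = value.detach()
        return detached.cpu() if to_cpu else detached
    if isinstance(value, dict):
        return {k: recursive_detach_sketch(v, to_cpu=to_cpu) for k, v in value.items()}
    return value  # non-tensor leaves (ints, strings, ...) pass through unchanged
```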
pytorch_lightning/trainer/training_loop.py — 3 additions, 3 deletions

```diff
@@ -261,7 +261,7 @@ def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
         is_result_obj = isinstance(training_step_output, Result)
 
         if is_result_obj:
-            training_step_output.detach()
+            training_step_output = training_step_output.detach()
         else:
             training_step_output.batch_loss = training_step_output.batch_loss.detach()
```
```diff
@@ -395,9 +395,9 @@ def _process_training_step_output_1_0(self, training_step_output, split_batch):
 
         # track metrics without grads for epoch reduction
         training_step_output_for_epoch_end = copy(result)
-        training_step_output_for_epoch_end.detach()
+        training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
         if self.trainer.move_metrics_to_cpu:
-            training_step_output_for_epoch_end.cpu()
+            training_step_output_for_epoch_end = training_step_output_for_epoch_end.cpu()
 
         # what flows back into the system
         training_step_output = result
```
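
The `copy(result)` before `detach()` is load-bearing: the shallow copy made for epoch-end aggregation must drop its autograd references, while the original `result` keeps the graph alive for the backward pass. A small illustration of why a shallow copy suffices (illustrative, not Lightning code):

```python
from copy import copy

import torch

loss = (torch.ones(2, requires_grad=True) ** 2).sum()
result = {"loss": loss}

# Shallow-copy, then detach only the copy's tensors for epoch-end bookkeeping.
for_epoch_end = copy(result)
for_epoch_end["loss"] = for_epoch_end["loss"].detach()

assert result["loss"].requires_grad           # graph intact for backward()
assert not for_epoch_end["loss"].requires_grad
```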
tests/overrides/test_data_parallel.py — 1 addition, 2 deletions

```diff
@@ -145,8 +145,7 @@ def training_step(self, batch, batch_idx):
             output.update({"python scalar": 12.3})
             return output
 
-    model = TestModel()
-    model.to(device)
+    model = TestModel().to(device)
     model.trainer = Mock()
     model.trainer._running_stage = RunningStage.TRAINING
     batch = torch.rand(2, 32).to(device)
```
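
The test simplification works because `nn.Module.to()` moves parameters in place and returns the module itself, so construction and device placement chain into one expression (unlike `torch.Tensor.to()`, which returns a new tensor). A quick sanity check of that contract:

```python
import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(32, 2).to(device)  # .to() on a Module returns the module itself
assert next(model.parameters()).device.type == device.type
```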