Remove legacy Result parameters #6016

Merged (20 commits) on Mar 28, 2021
Changes from 1 commit
Resolve hiddens
carmocca committed Mar 26, 2021
commit 82df1254fe37e31fa0a3194a712e12eb57001867
8 changes: 1 addition & 7 deletions docs/source/common/trainer.rst
@@ -1478,15 +1478,9 @@ with the hidden
def training_step(self, batch, batch_idx, hiddens):
# hiddens are the hiddens from the previous truncated backprop step
out, hiddens = self.lstm(data, hiddens)

# remember to detach() hiddens.
# If you don't, you will get a RuntimeError: Trying to backward through
# the graph a second time...
# Using hiddens.detach() allows each split to be disconnected.

return {
"loss": ...,
"hiddens": hiddens # remember to detach() this
"hiddens": hiddens
}

To modify how the batch is split,
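For orientation, here is a minimal sketch of a module written against the documented contract after this change: the trainer now detaches the returned hiddens itself, so the step no longer needs to. The class name, layer sizes, and loss below are illustrative assumptions, not code from this PR.

    import torch
    import pytorch_lightning as pl

    class TBPTTExample(pl.LightningModule):
        """Hypothetical module; TBPTT is enabled via Trainer(truncated_bptt_steps=...)."""

        def __init__(self):
            super().__init__()
            self.lstm = torch.nn.LSTM(input_size=8, hidden_size=8, batch_first=True)

        def training_step(self, batch, batch_idx, hiddens):
            # hiddens come from the previous truncated backprop split (None on the first split)
            x, y = batch
            out, hiddens = self.lstm(x, hiddens)
            loss = torch.nn.functional.mse_loss(out, y)
            # no hiddens.detach() here anymore: the trainer detaches them for us
            return {"loss": loss, "hiddens": hiddens}

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)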
14 changes: 10 additions & 4 deletions pytorch_lightning/trainer/logging.py
@@ -14,6 +14,7 @@

import inspect
from abc import ABC
from collections.abc import Mapping

import torch

@@ -75,9 +76,7 @@ def process_dict_result(self, output, train=False):
# --------------------------
# single scalar returned from a xx_step
if isinstance(output, torch.Tensor):
progress_bar_metrics = {}
log_metrics = {}
return output, progress_bar_metrics, log_metrics
return output, {}, {}, None

# ---------------
# EXTRACT PROGRESS BAR KEYS
@@ -134,12 +133,19 @@ def process_dict_result(self, output, train=False):
if self._distrib_type in (DistributedType.DP, DistributedType.DDP2):
loss = self.reduce_distributed_output(loss, self.num_gpus)

# ---------------
# EXTRACT HIDDEN
# ---------------
hiddens = output.get('hiddens', None) if isinstance(output, Mapping) else None
if hiddens is not None:
hiddens = hiddens.detach()

# detach all metrics for callbacks to prevent memory leaks
# no .item() because it will slow things down
progress_bar_metrics = recursive_detach(progress_bar_metrics)
log_metrics = recursive_detach(log_metrics)

return loss, progress_bar_metrics, log_metrics
return loss, progress_bar_metrics, log_metrics, hiddens

def reduce_distributed_output(self, output, num_gpus):
if num_gpus <= 1:
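To make the new hiddens handling in process_dict_result concrete, here is a self-contained sketch of the same extract-and-detach pattern (the helper name is made up). Detaching is what prevents the next TBPTT split from trying to backpropagate through the previous split's graph a second time.

    from collections.abc import Mapping

    import torch

    def extract_hiddens(output):
        # Pull 'hiddens' out of a dict-like step output; anything else has none.
        hiddens = output.get("hiddens", None) if isinstance(output, Mapping) else None
        # Detach so the tensor is disconnected from the previous split's graph.
        return hiddens.detach() if hiddens is not None else None

    h = torch.zeros(1, 4, requires_grad=True)
    step_output = {"loss": h.sum(), "hiddens": h * 2}
    detached = extract_hiddens(step_output)
    assert detached is not None and not detached.requires_grad
    assert extract_hiddens(torch.tensor(1.0)) is None  # scalar returns carry no hiddens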
5 changes: 4 additions & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -354,19 +354,22 @@ def _process_training_step_output_1_0(self, training_step_output, split_batch):
result = self.trainer.lightning_module._results

loss = None
hiddens = None
result["extra"] = {}

# handle dict return
if isinstance(training_step_output, dict):
loss = training_step_output.pop("loss", None)
hiddens = training_step_output.pop("hiddens", None)
result["extra"] = training_step_output

# handle scalar return
elif isinstance(training_step_output, torch.Tensor):
loss = training_step_output
result["extra"] = {}

# map to results under the hood
result.minimize = loss
self.trainer.hiddens = hiddens

# track batch for manual reduction with result
result.track_batch_size(len(split_batch))
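A reduced sketch of the dict-versus-scalar branching that _process_training_step_output_1_0 performs above, pulled out of the trainer internals for clarity; the function name and example values are hypothetical.

    import torch

    def split_step_output(training_step_output):
        # Returns (loss, hiddens, extra) the same way the loop code above does.
        loss, hiddens, extra = None, None, {}
        if isinstance(training_step_output, dict):
            training_step_output = dict(training_step_output)  # avoid mutating the caller's dict
            loss = training_step_output.pop("loss", None)
            hiddens = training_step_output.pop("hiddens", None)
            extra = training_step_output
        elif isinstance(training_step_output, torch.Tensor):
            loss = training_step_output
        return loss, hiddens, extra

    loss, hiddens, extra = split_step_output(
        {"loss": torch.tensor(0.5), "hiddens": torch.zeros(2), "my_metric": 3}
    )
    assert extra == {"my_metric": 3} and hiddens is not None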
3 changes: 2 additions & 1 deletion tests/trainer/logging_/test_logger_connector.py
@@ -111,7 +111,7 @@ def training_step_end(self, *_):
assert generated == excepted


def test__logger_connector__epoch_result_store__train__ttbt(tmpdir):
def test__logger_connector__epoch_result_store__train__tbptt(tmpdir):
"""
Tests that LoggerConnector will properly capture logged information with tbptt
and reduce it
@@ -142,6 +142,7 @@ def __init__(self):

@decorator_with_arguments(fx_name="training_step")
def training_step(self, batch, batch_idx, hiddens):
assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
self.test_hidden = torch.rand(1)

x_tensor, y_list = batch
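The assertion added above relies on the trainer handing the same (detached) hiddens back on the next split. A standalone sketch of that round trip, with a made-up stand-in for the split loop:

    import torch

    class FakeSplitLoop:
        # Hypothetical stand-in for the trainer's TBPTT split loop, only to show the pattern.
        def __init__(self, n_splits):
            self.n_splits = n_splits
            self.hiddens = None      # what the "trainer" carries between splits
            self.test_hidden = None  # what the "module" expects to receive back

        def training_step(self, hiddens):
            assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
            self.test_hidden = torch.rand(1)
            return {"loss": torch.rand(1), "hiddens": self.test_hidden}

        def run(self):
            for _ in range(self.n_splits):
                out = self.training_step(self.hiddens)
                self.hiddens = out["hiddens"].detach()  # mirrors what the trainer now does

    FakeSplitLoop(n_splits=3).run()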
7 changes: 1 addition & 6 deletions tests/trainer/logging_/test_train_loop_logging_1_0.py
@@ -318,12 +318,7 @@ def __init__(self):
self.layer = torch.nn.Linear(2, 2)

def training_step(self, batch, batch_idx, hiddens):
try:
assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
# todo: specify the possible exception
except Exception as ex:
print(ex)

assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
self.test_hidden = torch.rand(1)

x_tensor, y_list = batch
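The diff above drops a try/except that only printed the assertion error, which meant a broken hidden-state chain could never actually fail the test. A tiny illustration of the difference (the tensors here are made up):

    import torch

    expected = torch.zeros(1)
    received = torch.ones(1)  # deliberately wrong to trigger the assertion

    # Old pattern: the failure is printed and swallowed, so pytest still reports a pass.
    try:
        assert received == expected, "Hidden state not persistent between tbptt steps"
    except AssertionError as ex:
        print(ex)

    # New pattern (left commented so this sketch runs to completion): the failure
    # propagates and the test fails as intended.
    # assert received == expected, "Hidden state not persistent between tbptt steps"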