Remove evaluation loop legacy dict returns for *_epoch_end hooks #6973

Merged
Changes from 6 commits
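For context, the user-facing pattern this PR enforces: metrics from `validation_epoch_end` / `test_epoch_end` must be logged with `self.log(...)` rather than returned. A minimal sketch, not part of this diff — the model and the `'loss'` key in the step outputs are illustrative:

```python
import torch
from pytorch_lightning import LightningModule


class MyModel(LightningModule):
    def validation_epoch_end(self, outputs):
        # Before this PR, a returned dict was flattened into callback_metrics
        # and a bare tensor return was auto-named 'val_loss'. After it, any
        # return value is ignored, so log explicitly instead.
        avg_loss = torch.stack([out['loss'] for out in outputs]).mean()
        self.log('val_loss', avg_loss)
```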
pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -25,7 +25,7 @@
 from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore
 from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
 from pytorch_lightning.trainer.states import RunningStage, TrainerState
-from pytorch_lightning.utilities import DeviceType, flatten_dict
+from pytorch_lightning.utilities import DeviceType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden

@@ -297,33 +297,6 @@ def get_evaluate_epoch_results(self):
         self.eval_loop_results = []
         return results
 
-    def _track_callback_metrics(self, eval_results):
-        if len(eval_results) > 0 and (eval_results[0] is None or not isinstance(eval_results[0], Result)):
-            return
-
-        flat = {}
-        if isinstance(eval_results, list):
-            for eval_result in eval_results:
-                # with a scalar return, auto set it to "val_loss" for callbacks
-                if isinstance(eval_result, torch.Tensor):
-                    flat = {'val_loss': eval_result}
-                elif isinstance(eval_result, dict):
-                    flat = flatten_dict(eval_result)
-
-                self.trainer.logger_connector.callback_metrics.update(flat)
-                if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING):
-                    self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
-        else:
-            # with a scalar return, auto set it to "val_loss" for callbacks
-            if isinstance(eval_results, torch.Tensor):
-                flat = {'val_loss': eval_results}
-            else:
-                flat = flatten_dict(eval_results)
-
-            self.trainer.logger_connector.callback_metrics.update(flat)
-            if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING):
-                self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
-
     def on_train_epoch_end(self):
         # inform cached logger connector epoch finished
         self.cached_results.has_batch_loop_finished = True
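The deleted `_track_callback_metrics` above was the legacy bridge from `*_epoch_end` return values into `callback_metrics`. A rough, self-contained sketch of the flattening it relied on — `flatten_dict` here is a simplified stand-in for `pytorch_lightning.utilities.flatten_dict`, and the exact semantics may differ:

```python
import torch


def flatten_dict(source, result=None):
    # Simplified stand-in: collapse nested dicts so only leaf key/value
    # pairs remain, mirroring how legacy dict returns fed callback_metrics.
    if result is None:
        result = {}
    for key, value in source.items():
        if isinstance(value, dict):
            flatten_dict(value, result)
        else:
            result[key] = value
    return result


# Legacy behavior being removed: a dict return was flattened,
# a bare tensor return was auto-named 'val_loss'.
eval_result = {'progress_bar': {'val_acc': torch.tensor(0.9)}, 'val_loss': torch.tensor(0.1)}
print(flatten_dict(eval_result))  # {'val_acc': tensor(0.9000), 'val_loss': tensor(0.1000)}
```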
48 changes: 6 additions & 42 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -190,61 +190,25 @@ def evaluation_epoch_end(self):
         # unset dataloder_idx in model
         self.trainer.logger_connector.evaluation_epoch_end()
 
         # call the model epoch end
-        deprecated_results = self.__run_eval_epoch_end(self.num_dataloaders)
-
-        # enable returning anything
-        for i, r in enumerate(deprecated_results):
-            if not isinstance(r, (dict, Result, torch.Tensor)):
-                deprecated_results[i] = []
-
-        return deprecated_results
-
-    def log_epoch_metrics_on_evaluation_end(self):
-        # get the final loop results
-        eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results()
-        return eval_loop_results
+        self.__run_eval_epoch_end(self.num_dataloaders)
 
     def __run_eval_epoch_end(self, num_dataloaders):
-        model = self.trainer.lightning_module
-
-        # with a single dataloader don't pass an array
         outputs = self.outputs
+
+        # with a single dataloader don't pass an array
+        eval_results = outputs[0] if self.num_dataloaders == 1 else outputs
 
-        eval_results = outputs
-        if num_dataloaders == 1:
-            eval_results = outputs[0]
-
-        user_reduced = False
+        # call the model epoch end
+        model = self.trainer.lightning_module
 
         if self.trainer.testing:
             if is_overridden('test_epoch_end', model=model):
                 model._current_fx_name = 'test_epoch_end'
-                eval_results = model.test_epoch_end(eval_results)
-                user_reduced = True
+                model.test_epoch_end(eval_results)
 
         else:
             if is_overridden('validation_epoch_end', model=model):
                 model._current_fx_name = 'validation_epoch_end'
-                eval_results = model.validation_epoch_end(eval_results)
-                user_reduced = True
+                model.validation_epoch_end(eval_results)
 
         # capture logging
         self.trainer.logger_connector.cache_logged_metrics()
-        # depre warning
-        if eval_results is not None and user_reduced:
-            step = 'testing_epoch_end' if self.trainer.testing else 'validation_epoch_end'
-            self.warning_cache.warn(
-                f'The {step} should not return anything as of 9.1.'
-                ' To log, use self.log(...) or self.write(...) directly in the LightningModule'
-            )
-
-        if not isinstance(eval_results, list):
-            eval_results = [eval_results]
-
-        self.trainer.logger_connector._track_callback_metrics(eval_results)
-
-        return eval_results
 
     def __gather_epoch_end_eval_results(self, outputs):
         eval_results = []
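The one behavioral subtlety kept in `__run_eval_epoch_end` is the shape of what the hooks receive. A standalone illustration of that branch (the dicts and values are made up):

```python
# With one dataloader the *_epoch_end hooks receive a flat list of step
# outputs; with several, a list of per-dataloader lists.
outputs = [[{'loss': 0.1}, {'loss': 0.2}]]  # one dataloader, two batches
num_dataloaders = len(outputs)
eval_results = outputs[0] if num_dataloaders == 1 else outputs
assert eval_results == [{'loss': 0.1}, {'loss': 0.2}]

outputs = [[{'loss': 0.1}], [{'loss': 0.3}]]  # two dataloaders, one batch each
num_dataloaders = len(outputs)
eval_results = outputs[0] if num_dataloaders == 1 else outputs
assert eval_results == outputs
```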
10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -712,7 +712,7 @@ def run_evaluation(self, on_epoch=False):
             self.evaluation_loop.outputs.append(dl_outputs)
 
         # lightning module method
-        deprecated_eval_results = self.evaluation_loop.evaluation_epoch_end()
+        self.evaluation_loop.evaluation_epoch_end()
 
         # hook
         self.evaluation_loop.on_evaluation_epoch_end()
@@ -725,7 +725,7 @@
         self.evaluation_loop.on_evaluation_end()
 
         # log epoch metrics
-        eval_loop_results = self.evaluation_loop.log_epoch_metrics_on_evaluation_end()
+        eval_loop_results = self.logger_connector.get_evaluate_epoch_results()
 
         # save predictions to disk
         self.evaluation_loop.predictions.to_disk()
@@ -735,7 +735,7 @@
 
         torch.set_grad_enabled(True)
 
-        return eval_loop_results, deprecated_eval_results
+        return eval_loop_results
 
     def track_output_for_epoch_end(self, outputs, output):
         if output is not None:
@@ -757,7 +757,7 @@ def run_evaluate(self):
         assert self.evaluating
 
         with self.profiler.profile(f"run_{self._running_stage}_evaluation"):
-            eval_loop_results, _ = self.run_evaluation()
+            eval_loop_results = self.run_evaluation()
 
         if len(eval_loop_results) == 0:
             return 1
@@ -831,7 +831,7 @@ def run_sanity_check(self, ref_model):
         self.on_sanity_check_start()
 
         # run eval step
-        _, eval_results = self.run_evaluation()
+        self.run_evaluation()
 
         self.on_sanity_check_end()
 
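The trainer changes reduce `run_evaluation()` to a single return value. A toy mock of the before/after calling convention — trainer internals are stubbed out, names follow this diff, and the metric value is made up:

```python
def run_evaluation():
    # After this PR: only the logged epoch results are returned,
    # one dict per evaluation dataloader.
    eval_loop_results = [{'val_loss': 0.12}]
    return eval_loop_results  # previously also returned deprecated_eval_results


# run_evaluate() now unpacks a single value...
eval_loop_results = run_evaluation()  # previously: eval_loop_results, _ = run_evaluation()
assert len(eval_loop_results) == 1

# ...and run_sanity_check() ignores the return entirely.
run_evaluation()  # previously: _, eval_results = run_evaluation()
```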
13 changes: 3 additions & 10 deletions tests/trainer/data_flow/test_eval_loop_flow.py
@@ -12,10 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-Tests to ensure that the training loop works with a dict (1.0)
+Tests the evaluation loop
 """
 
-import pytest
 import torch
 
 from pytorch_lightning import Trainer
@@ -189,8 +188,6 @@ def validation_epoch_end(self, outputs):
         assert out_a == self.out_a
         assert out_b == self.out_b
 
-        return {'no returns needed'}
-
     def backward(self, loss, optimizer, optimizer_idx):
         return LightningModule.backward(self, loss, optimizer, optimizer_idx)
 
@@ -206,8 +203,7 @@ def backward(self, loss, optimizer, optimizer_idx):
         weights_summary=None,
     )
 
-    with pytest.warns(UserWarning, match=r".*should not return anything as of 9.1.*"):
-        trainer.fit(model)
+    trainer.fit(model)
 
     # make sure correct steps were called
     assert model.validation_step_called
@@ -254,8 +250,6 @@ def validation_epoch_end(self, outputs):
         assert out_a == self.out_a
         assert out_b == self.out_b
 
-        return {'no returns needed'}
-
     def backward(self, loss, optimizer, optimizer_idx):
         return LightningModule.backward(self, loss, optimizer, optimizer_idx)
 
@@ -270,8 +264,7 @@ def backward(self, loss, optimizer, optimizer_idx):
         weights_summary=None,
     )
 
-    with pytest.warns(UserWarning, match=r".*should not return anything as of 9.1.*"):
-        trainer.fit(model)
+    trainer.fit(model)
 
     # make sure correct steps were called
     assert model.validation_step_called
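The updated tests still assert that the hooks actually ran; inside the loop that is gated by `is_overridden`. A toy sketch of that check — simplified, and only a stand-in for the real helper `pytorch_lightning.utilities.model_helpers.is_overridden`:

```python
class Base:
    def validation_epoch_end(self, outputs):
        pass


class Child(Base):
    def validation_epoch_end(self, outputs):
        self.called = True


def is_overridden(method_name, model, parent=Base):
    # Simplified: the hook counts as user-defined when the subclass
    # attribute is not the parent class implementation.
    return getattr(type(model), method_name) is not getattr(parent, method_name)


assert is_overridden('validation_epoch_end', Child())
assert not is_overridden('validation_epoch_end', Base())
```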