ref: add eval loop object to streamline eval loop (#3138)
* added eval loop
williamFalcon authored Aug 25, 2020
1 parent 82d1128 commit 229b876
Showing 3 changed files with 108 additions and 55 deletions.
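The gist of the change: the `if test_mode:` branching that was repeated throughout `TrainerEvaluationLoopMixin._evaluate` now lives once inside a dedicated `EvaluationLoop` object, which also owns the per-epoch bookkeeping (outputs, predictions, max_batches). A standalone toy distillation of that pattern, with purely hypothetical names, might look like this:

# Toy sketch of the dispatch pattern this commit centralizes; the names here
# (ToyEvalLoop, the call_hook argument) are invented and only mirror the shape
# of the real EvaluationLoop, they are not the commit's code.
class ToyEvalLoop:
    def __init__(self, call_hook, testing=False):
        self.call_hook = call_hook   # trainer-style hook dispatcher
        self.testing = testing       # True -> test hooks, False -> validation hooks
        self.outputs = []            # per-dataloader outputs collected here

    def on_epoch_start(self):
        # the val/test branch lives in one place instead of at every call site
        name = 'on_test_epoch_start' if self.testing else 'on_validation_epoch_start'
        self.call_hook(name)


loop = ToyEvalLoop(call_hook=print, testing=False)
loop.on_epoch_start()   # prints: on_validation_epoch_start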
76 changes: 76 additions & 0 deletions pytorch_lightning/trainer/evaluate_loop.py
@@ -0,0 +1,76 @@
import torch
from pytorch_lightning.trainer.supporters import PredictionCollection
from pytorch_lightning.core.step_result import EvalResult


class EvaluationLoop(object):
    def __init__(self, trainer):
        self.trainer = trainer
        self.testing = False
        self.outputs = []
        self.predictions = None
        self.max_batches = None

    def is_using_eval_results(self):
        outputs = self.outputs
        using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
        return using_eval_result

    def setup(self, model, max_batches, dataloaders):
        # enable eval mode
        model.zero_grad()
        model.eval()

        # copy properties for forward overrides
        self.trainer.copy_trainer_model_properties(model)

        # disable gradients to save memory
        torch.set_grad_enabled(False)

        # bookkeeping
        self.outputs = []
        self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)

        # convert max_batches to list
        if isinstance(max_batches, int):
            max_batches = [max_batches] * len(dataloaders)

        self.max_batches = max_batches

    def on_evaluation_epoch_start(self, *args, **kwargs):
        if self.testing:
            self.trainer.call_hook('on_test_epoch_start', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs)

    def evaluation_step(self, *args, **kwargs):
        if self.testing:
            output = self.trainer.accelerator_backend.test_step(*args, **kwargs)
        else:
            output = self.trainer.accelerator_backend.validation_step(*args, **kwargs)
        return output

    def evaluation_step_end(self, *args, **kwargs):
        if self.testing:
            output = self.trainer.call_hook('test_step_end', *args, **kwargs)
        else:
            output = self.trainer.call_hook('validation_step_end', *args, **kwargs)
        return output

    def on_evaluation_batch_start(self, *args, **kwargs):
        if self.testing:
            self.trainer.call_hook('on_test_batch_start', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_batch_start', *args, **kwargs)

    def on_evaluation_batch_end(self, *args, **kwargs):
        if self.testing:
            self.trainer.call_hook('on_test_batch_end', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_batch_end', *args, **kwargs)

    def on_evaluation_epoch_end(self, *args, **kwargs):
        if self.testing:
            self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
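For orientation, here is a minimal sketch that drives the new class by hand. It assumes a PyTorch Lightning source tree that already contains this commit (so `pytorch_lightning.trainer.evaluate_loop` is importable) plus `torch`; `StubTrainer` and `StubBackend` are invented stand-ins that provide only what `EvaluationLoop` touches, not real Lightning classes.

# Minimal sketch: exercise EvaluationLoop's call order with invented stubs.
import torch
from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop


class StubBackend:
    # stands in for trainer.accelerator_backend; args is [batch, batch_idx, dataloader_idx]
    def validation_step(self, args):
        batch, batch_idx, dataloader_idx = args
        return {'val_loss': batch.mean()}


class StubTrainer:
    global_rank = 0
    world_size = 1
    accelerator_backend = StubBackend()

    def copy_trainer_model_properties(self, model):
        pass  # the real trainer copies device/precision flags onto the model here

    def call_hook(self, name, *args, **kwargs):
        print('hook:', name)
        return args[0] if args else None  # step_end hooks pass the output through


model = torch.nn.Linear(4, 2)
dataloaders = [[torch.randn(8, 4) for _ in range(3)]]

loop = EvaluationLoop(StubTrainer())
loop.testing = False  # validation; the trainer sets this from test_mode
loop.setup(model, max_batches=2, dataloaders=dataloaders)

loop.on_evaluation_epoch_start()
for dl_idx, dl in enumerate(dataloaders):
    dl_outputs = []
    for batch_idx, batch in enumerate(dl):
        if batch_idx >= loop.max_batches[dl_idx]:
            break
        loop.on_evaluation_batch_start(batch, batch_idx, dl_idx)
        output = loop.evaluation_step([batch, batch_idx, dl_idx])
        output = loop.evaluation_step_end(output)
        loop.on_evaluation_batch_end(batch, batch_idx, dl_idx)
        dl_outputs.append(output)
    loop.outputs.append(dl_outputs)
loop.on_evaluation_epoch_end()

torch.set_grad_enabled(True)  # setup() disabled grads globally; restore them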
83 changes: 28 additions & 55 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -134,7 +134,7 @@
 from pytorch_lightning.utilities import rank_zero_warn, flatten_dict, AMPType
 from pytorch_lightning.core.step_result import Result, EvalResult
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.trainer.supporters import PredictionCollection
+from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop

 try:
     import torch_xla.distributed.parallel_loader as xla_pl
@@ -192,6 +192,7 @@ class TrainerEvaluationLoopMixin(ABC):
     on_test_start: Callable
     on_test_end: Callable
     accelerator_backend: ...
+    evaluation_loop: EvaluationLoop

     @abstractmethod
     def copy_trainer_model_properties(self, *args):
@@ -245,31 +246,14 @@ def _evaluate(
                 entry is the number of batches to process in the corresponding dataloader.
             test_mode:
         """
-        # enable eval mode
-        model.zero_grad()
-        model.eval()
+        # set up the loop for val/test
+        self.evaluation_loop.testing = test_mode

-        # copy properties for forward overrides
-        self.copy_trainer_model_properties(model)
+        # set up the eval loop
+        self.evaluation_loop.setup(model, max_batches, dataloaders)

-        # disable gradients to save memory
-        torch.set_grad_enabled(False)
-
-        # bookkeeping
-        outputs = []
-        predictions = PredictionCollection(self.global_rank, self.world_size)
-
-        # convert max_batches to list
-        if isinstance(max_batches, int):
-            max_batches = [max_batches] * len(dataloaders)
-
-        # --------------------------
-        # ON_EVAL_EPOCH_START hook
-        # --------------------------
-        if test_mode:
-            self.call_hook('on_test_epoch_start')
-        else:
-            self.call_hook('on_validation_epoch_start')
+        # hook
+        self.evaluation_loop.on_evaluation_epoch_start()

         # run validation
         for dataloader_idx, dataloader in enumerate(dataloaders):
@@ -282,7 +266,7 @@ def _evaluate(
                 dataloader = dataloader.per_device_loader(device)

             # each dataloader has a max num batches
-            dl_max_batches = max_batches[dataloader_idx]
+            dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]

             for batch_idx, batch in enumerate(dataloader):
                 if batch is None:
@@ -292,25 +276,19 @@ def _evaluate(
                 if batch_idx >= dl_max_batches:
                     break

-                # callbacks
-                if test_mode:
-                    self.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx)
-                else:
-                    self.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx)
+                # -----------------
+                # eval_batch_start
+                # -----------------
+                self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)

                 # -----------------
                 # RUN EVALUATION STEP
                 # -----------------
                 args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)
-
-                if test_mode:
-                    output = self.accelerator_backend.test_step(args)
-                else:
-                    output = self.accelerator_backend.validation_step(args)
-
-                is_result_obj = isinstance(output, Result)
+                output = self.evaluation_loop.evaluation_step(args)

                 # track batch size for weighted average
+                is_result_obj = isinstance(output, Result)
                 if is_result_obj:
                     output.track_batch_size(len(batch))
@@ -322,19 +300,12 @@ def _evaluate(
                 # ------------------
                 # EVAL STEP END
                 # ------------------
-                if test_mode:
-                    output = self.call_hook('test_step_end', output)
-                else:
-                    output = self.call_hook('validation_step_end', output)
+                output = self.evaluation_loop.evaluation_step_end(output)

                 # ------------------
                 # Hook: on_eval_batch_end
                 # ------------------
-                # callbacks (on __batch_end)
-                if test_mode:
-                    self.call_hook('on_test_batch_end', batch, batch_idx, dataloader_idx)
-                else:
-                    self.call_hook('on_validation_batch_end', batch, batch_idx, dataloader_idx)
+                self.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)

                 # ----------------------
                 # Post processing
@@ -345,7 +316,7 @@ def _evaluate(
                 # Add step predictions to prediction collection to write later
                 do_write_predictions = is_result_obj and test_mode
                 if do_write_predictions:
-                    predictions.add(output.pop('predictions', None))
+                    self.evaluation_loop.predictions.add(output.pop('predictions', None))

                 dl_outputs.append(output)
@@ -354,19 +325,24 @@ def _evaluate(
                 # track debug metrics
                 self.dev_debugger.track_eval_loss_history(test_mode, batch_idx, dataloader_idx, output)

-            outputs.append(dl_outputs)
+            self.evaluation_loop.outputs.append(dl_outputs)

         # ---------------------
         # EVAL_EPOCH_END
         # ---------------------
-        using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
-        eval_results = self.__run_eval_epoch_end(test_mode, outputs, dataloaders, using_eval_result)
+        using_eval_result = self.evaluation_loop.is_using_eval_results()
+        eval_results = self.__run_eval_epoch_end(
+            test_mode,
+            self.evaluation_loop.outputs,
+            dataloaders,
+            using_eval_result
+        )

         # log callback metrics
         self.__update_callback_metrics(eval_results, using_eval_result)

         # Write predictions to disk if they're available.
-        predictions.to_disk()
+        self.evaluation_loop.predictions.to_disk()

         # enable train mode again
         model.train()
@@ -377,10 +353,7 @@ def _evaluate(
         # --------------------------
         # ON_EVAL_EPOCH_END hook
         # --------------------------
-        if test_mode:
-            self.call_hook('on_test_epoch_end')
-        else:
-            self.call_hook('on_validation_epoch_end')
+        self.evaluation_loop.on_evaluation_epoch_end()

         return eval_results
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -54,6 +54,7 @@
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.cloud_io import is_remote_path
+from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop

 # warnings to ignore in trainer
 warnings.filterwarnings(
@@ -608,6 +609,9 @@ def __init__(
         self.config_validator = ConfigValidator(self)
         self.accelerator_backend = None

+        # loops
+        self.evaluation_loop = EvaluationLoop(self)
+
         # Callback system
         self.on_init_end()
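With the wiring above, every Trainer instance now carries the loop object, so (again assuming a checkout that includes this commit) it can be inspected directly:

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop

trainer = Trainer()  # default construction is enough to wire up the loop
assert isinstance(trainer.evaluation_loop, EvaluationLoop)
assert trainer.evaluation_loop.testing is False  # _evaluate() sets this from test_mode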
