[refactor] Add setup to profilers + _run_stage_setup to trainer 2/5 (#6633)

* add setup

* update

* updates on comment

* Minor changes

* Extra import

* Docs

Co-authored-by: Carlos Mocholi <[email protected]>
tchaton and carmocca authored Mar 22, 2021
1 parent e62c7c7 commit 2064ece
Showing 10 changed files with 72 additions and 80 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -42,6 +42,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `teardown` method to `BaseProfiler` to enable subclasses defining post-profiling steps outside of `__del__` ([#6370](https://github.com/PyTorchLightning/pytorch-lightning/pull/6370))


- Added `setup` method to `BaseProfiler` to enable subclasses defining pre-profiling steps for every process ([#6633](https://github.com/PyTorchLightning/pytorch-lightning/pull/6633))


- Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))


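The new `setup` hook pairs with the existing `teardown` hook and can be overridden by custom profilers. The snippet below is only an illustrative sketch (the subclass name and the hook bodies are not part of this commit); it relies on the `BaseProfiler.setup`/`teardown` signatures introduced here:

```python
from typing import Optional

from pytorch_lightning.profiler import SimpleProfiler


class RankAwareProfiler(SimpleProfiler):
    """Hypothetical subclass using the hooks added in #6633."""

    def setup(self, stage: Optional[str] = None, local_rank: Optional[int] = None, log_dir: Optional[str] = None) -> None:
        # BaseProfiler.setup stores stage, local_rank and log_dir as attributes
        super().setup(stage=stage, local_rank=local_rank, log_dir=log_dir)
        print(f"profiling stage={stage} on rank={local_rank}, writing to {log_dir}")

    def teardown(self, stage: Optional[str] = None) -> None:
        # BaseProfiler.teardown closes the output file and resets it to None
        super().teardown(stage=stage)
        print(f"finished profiling stage={stage}")
```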
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/training_type/horovod.py
@@ -96,22 +96,22 @@ def start_training(self, trainer):
stack.enter_context(optimizer.skip_synchronize())

# set up training routine
self._results = trainer.run_train()
self._results = trainer.run_stage()

# Make sure all workers have finished training before returning to the user
hvd.join()

def start_evaluating(self, trainer):
with ExitStack():
self._results = trainer.run_evaluate()
self._results = trainer.run_stage()

# Make sure all workers have finished training before returning to the user
hvd.join()

def start_predicting(self, trainer):
with ExitStack():
# set up training routine
self._results = trainer.run_predict()
self._results = trainer.run_stage()

# Make sure all workers have finished training before returning to the user
hvd.join()
Original file line number Diff line number Diff line change
@@ -132,15 +132,15 @@ def rpc_enabled(self) -> bool:

def start_training(self, trainer: 'Trainer') -> None:
# double dispatch to initiate the training loop
self._results = trainer.run_train()
self._results = trainer.run_stage()

def start_evaluating(self, trainer: 'Trainer') -> None:
# double dispatch to initiate the test loop
self._results = trainer.run_evaluate()
self._results = trainer.run_stage()

def start_predicting(self, trainer: 'Trainer') -> None:
# double dispatch to initiate the predicting loop
self._results = trainer.run_predict()
self._results = trainer.run_stage()

def training_step(self, *args, **kwargs):
return self.lightning_module.training_step(*args, **kwargs)
53 changes: 22 additions & 31 deletions pytorch_lightning/profiler/profilers.py
@@ -55,9 +55,23 @@ def start(self, action_name: str) -> None:
def stop(self, action_name: str) -> None:
"""Defines how to record the duration once an action is complete."""

def teardown(self) -> None:
"""Execute arbitrary post-profiling tear-down steps as defined by subclass."""
pass
def setup(
self,
stage: Optional[str] = None,
local_rank: Optional[int] = None,
log_dir: Optional[str] = None
) -> None:
"""Execute arbitrary pre-profiling set-up steps."""
self.stage = stage
self.local_rank = local_rank
self.log_dir = log_dir

def teardown(self, stage: Optional[str] = None) -> None:
"""Execute arbitrary post-profiling tear-down steps."""
self.stage = stage
if self.output_file:
self.output_file.close()
self.output_file = None

@contextmanager
def profile(self, action_name: str) -> None:
@@ -94,13 +108,15 @@ def describe(self) -> None:
"""Logs a profile report after the conclusion of the training run."""
for write in self.write_streams:
write(self.summary())
if self.output_file is not None:
self.output_file.flush()

@abstractmethod
def summary(self) -> str:
"""Create profiler summary in text format."""

def on_train_start(self, local_rank: Optional[int] = None):
self.local_rank = local_rank
def __del__(self):
self.teardown(None)


class PassThroughProfiler(BaseProfiler):
@@ -110,6 +126,7 @@ class PassThroughProfiler(BaseProfiler):
"""

def __init__(self):
self.output_file = None
super().__init__(output_streams=None)

def start(self, action_name: str) -> None:
@@ -212,19 +229,6 @@ def log_row(action, mean, total):
output_string += os.linesep
return output_string

def describe(self):
"""Logs a profile report after the conclusion of the training run."""
super().describe()
self.teardown()

def teardown(self) -> None:
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()

def __del__(self):
self.teardown()


class AdvancedProfiler(BaseProfiler):
"""
@@ -285,16 +289,3 @@ def summary(self) -> str:
output_string += f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}"

return output_string

def describe(self):
"""Logs a profile report after the conclusion of the training run."""
super().describe()
self.teardown()

def teardown(self) -> None:
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()

def __del__(self):
self.teardown()
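
With `describe`, `teardown` and `__del__` now implemented once in `BaseProfiler`, the manual lifecycle of a profiler used outside the `Trainer` looks roughly as follows (a standalone sketch; within Lightning the `Trainer` drives these calls, and the output path is illustrative):

```python
import os
import tempfile

from pytorch_lightning.profiler import SimpleProfiler

log_dir = tempfile.mkdtemp()
profiler = SimpleProfiler(output_filename=os.path.join(log_dir, "profiler.txt"))

# new in this commit: per-process preparation before any action is recorded
profiler.setup(stage="fit", local_rank=None, log_dir=log_dir)

with profiler.profile("my_action"):  # records the duration of the wrapped block
    sum(range(1_000_000))

profiler.describe()             # writes the summary and flushes the output stream
profiler.teardown(stage="fit")  # closes the file and resets output_file to None
assert profiler.output_file is None
```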
21 changes: 4 additions & 17 deletions pytorch_lightning/profiler/pytorch.py
@@ -162,11 +162,11 @@ def __init__(
self.output_fname = output_filename
self.output_file = None
if local_rank is not None:
self.on_train_start(local_rank=local_rank)
self.on_train_start = super().on_train_start
self.setup(local_rank=local_rank)
self.setup = super().setup

def on_train_start(self, local_rank: Optional[str] = None):
self.local_rank = local_rank
def setup(self, stage: Optional[str] = None, local_rank: Optional[int] = None, log_dir: Optional[str] = None):
super().setup(stage=stage, local_rank=local_rank, log_dir=log_dir)

# when logging to `log.info`, only perform profiling on rank 0
if local_rank != 0 and self.output_fname is None:
@@ -290,16 +290,3 @@ def summary(self) -> str:
output_string += (f"{os.linesep}Profile stats for: {action} rank: {local_rank} {os.linesep}{stats}")

return output_string

def describe(self):
"""Logs a profile report after the conclusion of the training run."""
super().describe()
self.teardown()

def teardown(self) -> None:
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()

def __del__(self):
self.teardown()
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/connectors/profiler_connector.py
@@ -54,6 +54,7 @@ def on_trainer_init(self, profiler: Union[BaseProfiler, str]):
)
self.trainer.profiler = profiler or PassThroughProfiler()

def on_train_start(self, trainer):
def setup(self) -> None:
trainer = self.trainer
local_rank = trainer.local_rank if trainer.world_size > 1 else None
self.trainer.profiler.on_train_start(local_rank)
trainer.profiler.setup(stage=trainer._setup_state, local_rank=local_rank, log_dir=trainer.log_dir)
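
For reference, a minimal single-process run exercises the new connector path (illustrative; with `world_size == 1` no rank is forwarded, and `BoringModel` is the helper model used by the tests further below):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.profiler import SimpleProfiler
from tests.helpers import BoringModel

profiler = SimpleProfiler(output_filename="profiler.txt")
trainer = Trainer(fast_dev_run=True, profiler=profiler)
trainer.fit(BoringModel())

# run_stage() called profile_connector.setup(), which forwarded the stage,
# local_rank (None here) and trainer.log_dir to the profiler
print(profiler.stage, profiler.local_rank, profiler.log_dir)
```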
10 changes: 10 additions & 0 deletions pytorch_lightning/trainer/properties.py
@@ -491,6 +491,16 @@ def sanity_checking(self, val: bool) -> None:
elif self.sanity_checking:
self._running_stage = None

@property
def _setup_state(self) -> TrainerState:
# 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders"
return TrainerState.FITTING if self.state == TrainerState.TUNING else self.state

@property
def _teardown_state(self) -> Optional[TrainerState]:
if self.state.running:
return self._setup_state


# Used to represent the concrete type TrainerProperties class methods are called on.
_T = TypeVar('_T', bound=TrainerProperties)
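
The two new properties determine which stage is reported to the `setup` and `teardown` hooks. A small sketch of the mapping (paraphrased from the diff; the `TrainerState` import path is assumed from the Lightning codebase of this version):

```python
from pytorch_lightning.trainer.states import TrainerState

# Illustrative: how _setup_state resolves the stage passed to setup hooks.
# TUNING is reported as FITTING because `trainer.tune()` runs the fit loop
# and there are no dedicated "tune" dataloaders.
for state in (TrainerState.FITTING, TrainerState.TUNING):
    setup_state = TrainerState.FITTING if state == TrainerState.TUNING else state
    print(f"{state} -> setup stage: {setup_state}")
```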
28 changes: 14 additions & 14 deletions pytorch_lightning/trainer/trainer.py
@@ -445,13 +445,15 @@ def fit(
| ||
{self.dispatch} ||
| || LIGHTNING
{self.accelerator.start_training} or ||
{self.accelerator.start_evaluating} or || FLOW
{self.accelerator.start_predicting} ||
{self.accelerator.start_training} ||
or {self.accelerator.start_evaluating} ||
or {self.accelerator.start_predicting} || FLOW
| ||
{self.run_stage} ||
| || DIRECTION
{self.run_train} or ||
{self.run_evaluation} or ||
{self.run_predict} ||
{self.run_train} ||
or {self.run_evaluation} ||
or {self.run_predict} ||
| ||
results \/
This is used to guide readers to the core loops: train, test, predict.
@@ -518,6 +520,9 @@ def dispatch(self):

def run_stage(self):
results = None

self.profile_connector.setup()

if self.evaluating:
results = self.run_evaluate()
elif self.predicting:
@@ -1060,8 +1065,7 @@ def tune(

def call_setup_hook(self, model: LightningModule) -> None:
assert self.state.running, f"TrainerState: {self.state}"
# 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders"
state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state
state = self._setup_state

if self.datamodule is not None:
called = getattr(self.datamodule, f'has_setup_{state}')
@@ -1072,12 +1076,8 @@ def call_setup_hook(self, model: LightningModule) -> None:
model.setup(stage=state)

def call_teardown_hook(self, model: LightningModule) -> None:
if self.state.running:
state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state
else:
state = None

self.profiler.teardown()
state = self._teardown_state
self.profiler.teardown(stage=state)
self.teardown(stage=state)
model.teardown(stage=state)

3 changes: 0 additions & 3 deletions pytorch_lightning/trainer/training_loop.py
@@ -102,9 +102,6 @@ def on_train_start(self):
# hook
self.trainer.call_hook("on_train_start")

# provide rank to profiler
self.trainer.profile_connector.on_train_start(self.trainer)

def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None):
# clean hparams
if hasattr(model, "hparams"):
17 changes: 10 additions & 7 deletions tests/test_profiler.py
@@ -22,7 +22,8 @@
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler, PyTorchProfiler
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.profiler import AdvancedProfiler, PyTorchProfiler, SimpleProfiler
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf

@@ -323,14 +324,16 @@ def test_profiler_teardown(tmpdir, cls):
"""
This test checks if profiler teardown method is called when trainer is exiting.
"""

class TestCallback(Callback):

def on_fit_end(self, trainer, pl_module) -> None:
assert trainer.profiler.output_file is not None

profiler = cls(output_filename=os.path.join(tmpdir, "profiler.txt"))

model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
profiler=profiler,
)
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, profiler=profiler, callbacks=[TestCallback()])
trainer.fit(model)

assert profiler.output_file.closed
assert profiler.output_file is None
