[refactor] Add setup to profilers + _run_stage_setup to trainer 2/5 #6633

Merged · 6 commits · Mar 22, 2021
Changes from 5 commits
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/training_type/horovod.py
@@ -96,22 +96,22 @@ def start_training(self, trainer):
                     stack.enter_context(optimizer.skip_synchronize())

             # set up training routine
-            self._results = trainer.run_train()
+            self._results = trainer.run_stage()

         # Make sure all workers have finished training before returning to the user
         hvd.join()

     def start_evaluating(self, trainer):
         with ExitStack():
-            self._results = trainer.run_evaluate()
+            self._results = trainer.run_stage()

         # Make sure all workers have finished training before returning to the user
         hvd.join()

     def start_predicting(self, trainer):
         with ExitStack():
             # set up training routine
-            self._results = trainer.run_predict()
+            self._results = trainer.run_stage()

         # Make sure all workers have finished training before returning to the user
         hvd.join()
@@ -132,15 +132,15 @@ def rpc_enabled(self) -> bool:

     def start_training(self, trainer: 'Trainer') -> None:
         # double dispatch to initiate the training loop
-        self._results = trainer.run_train()
+        self._results = trainer.run_stage()

     def start_evaluating(self, trainer: 'Trainer') -> None:
         # double dispatch to initiate the test loop
-        self._results = trainer.run_evaluate()
+        self._results = trainer.run_stage()

     def start_predicting(self, trainer: 'Trainer') -> None:
         # double dispatch to initiate the predicting loop
-        self._results = trainer.run_predict()
+        self._results = trainer.run_stage()

     def training_step(self, *args, **kwargs):
         return self.lightning_module.training_step(*args, **kwargs)
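
Both plugins above make the same change: every `start_*` hook now funnels into a single `trainer.run_stage()` call, and the trainer, not the plugin, decides which loop to run from its own state. A minimal, self-contained sketch of that double-dispatch shape (the class bodies below are illustrative stand-ins, not the real implementations):

class Trainer:

    def __init__(self, evaluating=False, predicting=False):
        self.evaluating = evaluating
        self.predicting = predicting

    def run_stage(self):
        # single entry point: dispatch on trainer state
        if self.evaluating:
            return self.run_evaluate()
        if self.predicting:
            return self.run_predict()
        return self.run_train()

    def run_train(self):
        print("training loop")

    def run_evaluate(self):
        print("evaluation loop")

    def run_predict(self):
        print("prediction loop")


class Plugin:

    def start_training(self, trainer):
        # double dispatch: identical for training, evaluating and predicting
        self._results = trainer.run_stage()


Plugin().start_training(Trainer())  # -> "training loop"
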
51 changes: 21 additions & 30 deletions pytorch_lightning/profiler/profilers.py
@@ -55,9 +55,23 @@ def start(self, action_name: str) -> None:
     def stop(self, action_name: str) -> None:
         """Defines how to record the duration once an action is complete."""

-    def teardown(self) -> None:
+    def setup(
+        self,
+        stage: Optional[str] = None,
+        local_rank: Optional[int] = None,
+        log_dir: Optional[str] = None
+    ) -> None:
         """Execute arbitrary post-profiling tear-down steps as defined by subclass."""
-        pass
+        self.stage = stage
+        self.local_rank = local_rank
+        self.log_dir = log_dir
+
+    def teardown(self, stage: Optional[str] = None) -> None:
+        """Execute arbitrary post-profiling tear-down steps as defined by subclass."""
+        self.stage = stage
+        if self.output_file:
+            self.output_file.close()
+            self.output_file = None

     @contextmanager
     def profile(self, action_name: str) -> None:
@@ -94,13 +108,15 @@ def describe(self) -> None:
         """Logs a profile report after the conclusion of the training run."""
         for write in self.write_streams:
             write(self.summary())
+        if self.output_file is not None:
+            self.output_file.flush()

     @abstractmethod
     def summary(self) -> str:
         """Create profiler summary in text format."""

-    def on_train_start(self, local_rank: Optional[int] = None):
-        self.local_rank = local_rank
+    def __del__(self):
+        self.teardown(None)


 class PassThroughProfiler(BaseProfiler):
@@ -110,6 +126,7 @@ class PassThroughProfiler(BaseProfiler):
     """

     def __init__(self):
+        self.output_file = None
         super().__init__(output_streams=None)

     def start(self, action_name: str) -> None:
@@ -212,19 +229,6 @@ def log_row(action, mean, total):
         output_string += os.linesep
         return output_string
-
-    def describe(self):
-        """Logs a profile report after the conclusion of the training run."""
-        super().describe()
-        self.teardown()
-
-    def teardown(self) -> None:
-        """Close profiler's stream."""
-        if self.output_file:
-            self.output_file.close()
-
-    def __del__(self):
-        self.teardown()


 class AdvancedProfiler(BaseProfiler):
     """
@@ -285,16 +289,3 @@ def summary(self) -> str:
             output_string += f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}"

         return output_string
-
-    def describe(self):
-        """Logs a profile report after the conclusion of the training run."""
-        super().describe()
-        self.teardown()
-
-    def teardown(self) -> None:
-        """Close profiler's stream."""
-        if self.output_file:
-            self.output_file.close()
-
-    def __del__(self):
-        self.teardown()
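
With `setup`/`teardown` hoisted into `BaseProfiler` (and `describe` flushing the stream), the per-subclass `describe`/`teardown`/`__del__` boilerplate deleted above is no longer needed. A hedged sketch of a custom profiler against the new base class — the `WallClockProfiler` name and its timing logic are invented for illustration, and only `start`/`stop`/`summary` have to be provided:

import time
from typing import Optional

from pytorch_lightning.profiler import BaseProfiler


class WallClockProfiler(BaseProfiler):

    def __init__(self, output_filename: Optional[str] = None):
        self.output_file = open(output_filename, "w") if output_filename else None
        streams = [self.output_file.write] if self.output_file else None
        super().__init__(output_streams=streams)
        self._starts = {}
        self._totals = {}

    def start(self, action_name: str) -> None:
        self._starts[action_name] = time.monotonic()

    def stop(self, action_name: str) -> None:
        begin = self._starts.pop(action_name, None)
        if begin is not None:
            self._totals[action_name] = self._totals.get(action_name, 0.0) + time.monotonic() - begin

    def summary(self) -> str:
        return "\n".join(f"{name}: {total:.3f}s" for name, total in self._totals.items())

The base class now records `stage`, `local_rank` and `log_dir` in `setup`, and `teardown` (also reached via `__del__`) closes `output_file` and resets it to `None` — which is exactly what the updated test at the bottom of this diff asserts.
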
21 changes: 4 additions & 17 deletions pytorch_lightning/profiler/pytorch.py
@@ -162,11 +162,11 @@ def __init__(
         self.output_fname = output_filename
         self.output_file = None
         if local_rank is not None:
-            self.on_train_start(local_rank=local_rank)
-            self.on_train_start = super().on_train_start
+            self.setup(local_rank=local_rank)
+            self.setup = super().setup

-    def on_train_start(self, local_rank: Optional[str] = None):
-        self.local_rank = local_rank
+    def setup(self, stage: Optional[str] = None, local_rank: Optional[int] = None, log_dir: Optional[str] = None):
+        super().setup(stage=stage, local_rank=local_rank, log_dir=log_dir)

         # when logging to `log.info`, only perform profiling on rank 0
         if local_rank != 0 and self.output_fname is None:
@@ -290,16 +290,3 @@ def summary(self) -> str:
             output_string += (f"{os.linesep}Profile stats for: {action} rank: {local_rank} {os.linesep}{stats}")

         return output_string
-
-    def describe(self):
-        """Logs a profile report after the conclusion of the training run."""
-        super().describe()
-        self.teardown()
-
-    def teardown(self) -> None:
-        """Close profiler's stream."""
-        if self.output_file:
-            self.output_file.close()
-
-    def __del__(self):
-        self.teardown()
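
The constructor pattern kept above deserves a note: when `local_rank` is already known at construction time, the profiler runs its own `setup` once, then rebinds `self.setup` to the base implementation so a later call from the trainer does not repeat the rank-dependent work. The idiom in isolation, with toy classes rather than the real profiler:

class Base:

    def setup(self, local_rank=None):
        self.local_rank = local_rank


class RankAware(Base):

    def __init__(self, local_rank=None):
        if local_rank is not None:
            self.setup(local_rank=local_rank)  # run the override eagerly, once
            # the instance attribute now shadows the class method, so any
            # later call resolves to Base.setup instead of the override
            self.setup = super().setup

    def setup(self, local_rank=None):
        super().setup(local_rank=local_rank)
        print(f"rank-specific set-up for rank {local_rank}")


p = RankAware(local_rank=0)  # prints once, during construction
p.setup(local_rank=0)        # silent: only Base.setup runs
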
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/connectors/profiler_connector.py
@@ -54,6 +54,7 @@ def on_trainer_init(self, profiler: Union[BaseProfiler, str]):
             )
         self.trainer.profiler = profiler or PassThroughProfiler()

-    def on_train_start(self, trainer):
+    def setup(self) -> None:
+        trainer = self.trainer
         local_rank = trainer.local_rank if trainer.world_size > 1 else None
-        self.trainer.profiler.on_train_start(local_rank)
+        trainer.profiler.setup(stage=trainer._setup_state, local_rank=local_rank, log_dir=trainer.log_dir)
10 changes: 10 additions & 0 deletions pytorch_lightning/trainer/properties.py
@@ -491,6 +491,16 @@ def sanity_checking(self, val: bool) -> None:
         elif self.sanity_checking:
             self._running_stage = None

+    @property
+    def _setup_state(self) -> TrainerState:
+        # 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders"
+        return TrainerState.FITTING if self.state == TrainerState.TUNING else self.state
+
+    @property
+    def _teardown_state(self) -> Optional[TrainerState]:
+        if self.state.running:
+            return self._setup_state
+

 # Used to represent the concrete type TrainerProperties class methods are called on.
 _T = TypeVar('_T', bound=TrainerProperties)
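
A self-contained sketch of what these two properties encode, using a stand-in for the real `TrainerState` enum from `pytorch_lightning.trainer.states` (member names and the `running` property are assumed to match): tuning borrows the `fit` stage because there are no separate tune dataloaders, and a trainer that is no longer running yields no teardown stage.

from enum import Enum
from typing import Optional


class TrainerState(Enum):
    FITTING = 'fit'
    TUNING = 'tune'
    TESTING = 'test'
    FINISHED = 'finished'

    @property
    def running(self) -> bool:
        return self in (TrainerState.FITTING, TrainerState.TUNING, TrainerState.TESTING)


def setup_state(state: TrainerState) -> TrainerState:
    # 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders"
    return TrainerState.FITTING if state == TrainerState.TUNING else state


def teardown_state(state: TrainerState) -> Optional[TrainerState]:
    return setup_state(state) if state.running else None


assert setup_state(TrainerState.TUNING) is TrainerState.FITTING
assert teardown_state(TrainerState.FINISHED) is None
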
28 changes: 14 additions & 14 deletions pytorch_lightning/trainer/trainer.py
@@ -445,13 +445,15 @@ def fit(
                 |                              ||
         {self.dispatch}                        ||
                 |                              ||  LIGHTNING
-        {self.accelerator.start_training} or   ||
-        {self.accelerator.start_evaluating} or ||  FLOW
-        {self.accelerator.start_predicting}    ||
+        {self.accelerator.start_training}      ||
+        or {self.accelerator.start_evaluating} ||
+        or {self.accelerator.start_predicting} ||  FLOW
+                |                              ||
+        {self.run_stage}                       ||
                 |                              ||  DIRECTION
-        {self.run_train} or                    ||
-        {self.run_evaluation} or               ||
-        {self.run_predict}                     ||
+        {self.run_train}                       ||
+        or {self.run_evaluation}               ||
+        or {self.run_predict}                  ||
                 |                              ||
              results                           \/
         This is used to guide readers to the core loops: train, test, predict.
@@ -518,6 +520,9 @@ def dispatch(self):

     def run_stage(self):
         results = None
+
+        self.profile_connector.setup()
+
         if self.evaluating:
             results = self.run_evaluate()
         elif self.predicting:
@@ -1060,8 +1065,7 @@ def tune(

     def call_setup_hook(self, model: LightningModule) -> None:
         assert self.state.running, f"TrainerState: {self.state}"
-        # 'fit' is passed for `trainer.tune()` as there aren't "tune_dataloaders"
-        state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state
+        state = self._setup_state

         if self.datamodule is not None:
             called = getattr(self.datamodule, f'has_setup_{state}')
@@ -1072,12 +1076,8 @@ def call_setup_hook(self, model: LightningModule) -> None:
         model.setup(stage=state)

     def call_teardown_hook(self, model: LightningModule) -> None:
-        if self.state.running:
-            state = TrainerState.FITTING if self.state == TrainerState.TUNING else self.state
-        else:
-            state = None
-
-        self.profiler.teardown()
+        state = self._teardown_state
+        self.profiler.teardown(stage=state)
         self.teardown(stage=state)
         model.teardown(stage=state)
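
From a user's perspective, the observable effect of `_setup_state`/`_teardown_state` is the `stage` value these hooks deliver to the model. A hedged sketch (hook bodies are illustrative only):

from typing import Optional

from pytorch_lightning import LightningModule


class MyModel(LightningModule):

    def setup(self, stage: Optional[str] = None):
        # receives the resolved setup state, e.g. 'fit' for both fit() and tune()
        print(f"setup stage: {stage}")

    def teardown(self, stage: Optional[str] = None):
        # receives the teardown state; None if the trainer had already finished
        print(f"teardown stage: {stage}")
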
3 changes: 0 additions & 3 deletions pytorch_lightning/trainer/training_loop.py
@@ -102,9 +102,6 @@ def on_train_start(self):
         # hook
         self.trainer.call_hook("on_train_start")

-        # provide rank to profiler
-        self.trainer.profile_connector.on_train_start(self.trainer)
-
     def setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None):
         # clean hparams
         if hasattr(model, "hparams"):
17 changes: 10 additions & 7 deletions tests/test_profiler.py
@@ -22,7 +22,8 @@
 import torch

 from pytorch_lightning import Trainer
-from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler, PyTorchProfiler
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.profiler import AdvancedProfiler, PyTorchProfiler, SimpleProfiler
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf

@@ -323,14 +324,16 @@ def test_profiler_teardown(tmpdir, cls):
     """
     This test checks if profiler teardown method is called when trainer is exiting.
     """
+
+    class TestCallback(Callback):
+
+        def on_fit_end(self, trainer, pl_module) -> None:
+            assert trainer.profiler.output_file is not None
+
     profiler = cls(output_filename=os.path.join(tmpdir, "profiler.txt"))

     model = BoringModel()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        fast_dev_run=True,
-        profiler=profiler,
-    )
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, profiler=profiler, callbacks=[TestCallback()])
     trainer.fit(model)

-    assert profiler.output_file.closed
+    assert profiler.output_file is None