diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 94d4b420fc076..7a113e69dc119 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -33,8 +33,8 @@ /src/pytorch_lightning/loops @tchaton @awaelchli @justusschock @carmocca /src/pytorch_lightning/overrides @tchaton @SeanNaren @borda /src/pytorch_lightning/plugins @tchaton @SeanNaren @awaelchli @justusschock -/src/pytorch_lightning/profiler @williamfalcon @tchaton @borda @carmocca -/src/pytorch_lightning/profiler/pytorch.py @nbcsm @guotuofeng +/src/pytorch_lightning/profilers @williamfalcon @tchaton @borda @carmocca +/src/pytorch_lightning/profilers/pytorch.py @nbcsm @guotuofeng /src/pytorch_lightning/strategies @tchaton @SeanNaren @awaelchli @justusschock @kaushikb11 /src/pytorch_lightning/trainer @williamfalcon @borda @tchaton @SeanNaren @carmocca @awaelchli @justusschock @kaushikb11 /src/pytorch_lightning/trainer/connectors @tchaton @SeanNaren @carmocca @borda diff --git a/CHANGELOG.md b/CHANGELOG.md index df4bb6e286600..4b01be8bfad8d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -143,6 +143,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Deprecated LightningCLI's registries in favor of importing the respective package ([#13221](https://github.com/PyTorchLightning/pytorch-lightning/pull/13221)) + +- Deprecated `pytorch_lightning.profiler` in favor of `pytorch_lightning.profilers` ([#12308](https://github.com/PyTorchLightning/pytorch-lightning/pull/12308)) + + ### Removed - Removed the deprecated `Logger.close` method ([#13149](https://github.com/PyTorchLightning/pytorch-lightning/pull/13149)) diff --git a/dockers/tpu-tests/tpu_test_cases.jsonnet b/dockers/tpu-tests/tpu_test_cases.jsonnet index 5ec110729bd93..338d09d0b6612 100644 --- a/dockers/tpu-tests/tpu_test_cases.jsonnet +++ b/dockers/tpu-tests/tpu_test_cases.jsonnet @@ -36,7 +36,7 @@ local tputests = base.BaseTest { # TODO (@kaushikb11): Add device stats tests here coverage run --source=pytorch_lightning -m pytest -v --capture=no \ strategies/test_tpu_spawn.py \ - profiler/test_xla_profiler.py \ + profilers/test_xla_profiler.py \ accelerators/test_tpu.py \ models/test_tpu.py \ plugins/environments/test_xla_environment.py diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst index a147340d36df4..15640bc3ca81f 100644 --- a/docs/source-pytorch/api_references.rst +++ b/docs/source-pytorch/api_references.rst @@ -232,7 +232,7 @@ others profiler -------- -.. currentmodule:: pytorch_lightning.profiler +.. currentmodule:: pytorch_lightning.profilers .. autosummary:: :toctree: api diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst index 68fca5e7ba30e..3e3734ceb35e3 100644 --- a/docs/source-pytorch/common/trainer.rst +++ b/docs/source-pytorch/common/trainer.rst @@ -1213,7 +1213,7 @@ See the :doc:`profiler documentation <../tuning/profiler>`. for more details. .. 
testcode:: - from pytorch_lightning.profiler import SimpleProfiler, AdvancedProfiler + from pytorch_lightning.profilers import SimpleProfiler, AdvancedProfiler # default used by the Trainer trainer = Trainer(profiler=None) diff --git a/docs/source-pytorch/tuning/profiler_advanced.rst b/docs/source-pytorch/tuning/profiler_advanced.rst index ad2ab9e2020a1..1a1794f35f0a0 100644 --- a/docs/source-pytorch/tuning/profiler_advanced.rst +++ b/docs/source-pytorch/tuning/profiler_advanced.rst @@ -12,11 +12,11 @@ Find bottlenecks in your code (advanced) ************************ Profile cloud TPU models ************************ -To profile TPU models use the :class:`~pytorch_lightning.profiler.xla.XLAProfiler` +To profile TPU models use the :class:`~pytorch_lightning.profilers.xla.XLAProfiler` .. code-block:: python - from pytorch_lightning.profiler import XLAProfiler + from pytorch_lightning.profilers import XLAProfiler profiler = XLAProfiler(port=9001) trainer = Trainer(profiler=profiler) diff --git a/docs/source-pytorch/tuning/profiler_basic.rst b/docs/source-pytorch/tuning/profiler_basic.rst index c3ddc114dce9a..02954e287c586 100644 --- a/docs/source-pytorch/tuning/profiler_basic.rst +++ b/docs/source-pytorch/tuning/profiler_basic.rst @@ -68,7 +68,7 @@ The simple profiler measures all the standard methods used in the training loop ************************************** Profile the time within every function ************************************** -To profile the time within every function, use the :class:`~pytorch_lightning.profiler.advanced.AdvancedProfiler` built on top of Python's `cProfiler `_. +To profile the time within every function, use the :class:`~pytorch_lightning.profilers.advanced.AdvancedProfiler` built on top of Python's `cProfiler `_. .. code-block:: python @@ -101,7 +101,7 @@ If the profiler report becomes too long, you can stream the report to a file: .. 
code-block:: python - from pytorch_lightning.profiler import AdvancedProfiler + from pytorch_lightning.profilers import AdvancedProfiler profiler = AdvancedProfiler(dirpath=".", filename="perf_logs") trainer = Trainer(profiler=profiler) diff --git a/docs/source-pytorch/tuning/profiler_expert.rst b/docs/source-pytorch/tuning/profiler_expert.rst index 64ff784ed6c0d..fe864536e4b03 100644 --- a/docs/source-pytorch/tuning/profiler_expert.rst +++ b/docs/source-pytorch/tuning/profiler_expert.rst @@ -12,12 +12,12 @@ Find bottlenecks in your code (expert) *********************** Build your own profiler *********************** -To build your own profiler, subclass :class:`~pytorch_lightning.profiler.base.Profiler` +To build your own profiler, subclass :class:`~pytorch_lightning.profilers.profiler.Profiler` and override some of its methods. Here is a simple example that profiles the first occurrence and total calls of each action: .. code-block:: python - from pytorch_lightning.profiler import Profiler + from pytorch_lightning.profilers import Profiler from collections import defaultdict import time @@ -69,7 +69,7 @@ To profile a specific action of interest, reference a profiler in the LightningM .. code-block:: python - from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler + from pytorch_lightning.profilers import SimpleProfiler, PassThroughProfiler class MyModel(LightningModule): @@ -90,7 +90,7 @@ Here's the full code: .. 
code-block:: python - from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler + from pytorch_lightning.profilers import SimpleProfiler, PassThroughProfiler class MyModel(LightningModule): diff --git a/docs/source-pytorch/tuning/profiler_intermediate.rst b/docs/source-pytorch/tuning/profiler_intermediate.rst index d2b64b5d54743..a3356cb3e2f8a 100644 --- a/docs/source-pytorch/tuning/profiler_intermediate.rst +++ b/docs/source-pytorch/tuning/profiler_intermediate.rst @@ -12,11 +12,11 @@ Find bottlenecks in your code (intermediate) ************************** Profile pytorch operations ************************** -To understand the cost of each PyTorch operation, use the :class:`~pytorch_lightning.profiler.pytorch.PyTorchProfiler` built on top of the `PyTorch profiler `__. +To understand the cost of each PyTorch operation, use the :class:`~pytorch_lightning.profilers.pytorch.PyTorchProfiler` built on top of the `PyTorch profiler `__. .. code-block:: python - from pytorch_lightning.profiler import PyTorchProfiler + from pytorch_lightning.profilers import PyTorchProfiler profiler = PyTorchProfiler() trainer = Trainer(profiler=profiler) @@ -65,11 +65,11 @@ The profiler will generate an output like this: *************************** Profile a distributed model *************************** -To profile a distributed model, use the :class:`~pytorch_lightning.profiler.pytorch.PyTorchProfiler` with the *filename* argument which will save a report per rank. +To profile a distributed model, use the :class:`~pytorch_lightning.profilers.pytorch.PyTorchProfiler` with the *filename* argument which will save a report per rank. .. code-block:: python - from pytorch_lightning.profiler import PyTorchProfiler + from pytorch_lightning.profilers import PyTorchProfiler profiler = PyTorchProfiler(filename="perf-logs") trainer = Trainer(profiler=profiler) @@ -153,11 +153,11 @@ to extend the scope of profiled functions. 
***************************** Visualize profiled operations ***************************** -To visualize the profiled operations, enable **emit_nvtx** in the :class:`~pytorch_lightning.profiler.pytorch.PyTorchProfiler`. +To visualize the profiled operations, enable **emit_nvtx** in the :class:`~pytorch_lightning.profilers.pytorch.PyTorchProfiler`. .. code-block:: python - from pytorch_lightning.profiler import PyTorchProfiler + from pytorch_lightning.profilers import PyTorchProfiler profiler = PyTorchProfiler(emit_nvtx=True) trainer = Trainer(profiler=profiler) diff --git a/examples/pl_basics/profiler_example.py b/examples/pl_basics/profiler_example.py index 89074548683c2..050740e3ce314 100644 --- a/examples/pl_basics/profiler_example.py +++ b/examples/pl_basics/profiler_example.py @@ -31,7 +31,7 @@ import torchvision.transforms as T from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule -from pytorch_lightning.profiler.pytorch import PyTorchProfiler +from pytorch_lightning.profilers.pytorch import PyTorchProfiler from pytorch_lightning.utilities.cli import LightningCLI DEFAULT_CMD_LINE = ( diff --git a/pyproject.toml b/pyproject.toml index 4b9f45068e089..7802f9e7dca8a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,10 +78,10 @@ module = [ "pytorch_lightning.strategies.single_tpu", "pytorch_lightning.strategies.tpu_spawn", "pytorch_lightning.strategies.strategy", - "pytorch_lightning.profiler.advanced", - "pytorch_lightning.profiler.base", - "pytorch_lightning.profiler.pytorch", - "pytorch_lightning.profiler.simple", + "pytorch_lightning.profilers.advanced", + "pytorch_lightning.profilers.base", + "pytorch_lightning.profilers.pytorch", + "pytorch_lightning.profilers.simple", "pytorch_lightning.trainer.callback_hook", "pytorch_lightning.trainer.connectors.callback_connector", "pytorch_lightning.trainer.connectors.data_connector", diff --git a/src/pytorch_lightning/profiler/__init__.py 
b/src/pytorch_lightning/profiler/__init__.py index 60a1fa7ed869e..f55d7223e4051 100644 --- a/src/pytorch_lightning/profiler/__init__.py +++ b/src/pytorch_lightning/profiler/__init__.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.profiler.advanced import AdvancedProfiler -from pytorch_lightning.profiler.base import AbstractProfiler, BaseProfiler, PassThroughProfiler -from pytorch_lightning.profiler.profiler import Profiler -from pytorch_lightning.profiler.pytorch import PyTorchProfiler -from pytorch_lightning.profiler.simple import SimpleProfiler -from pytorch_lightning.profiler.xla import XLAProfiler +from pytorch_lightning.profiler.base import AbstractProfiler, BaseProfiler +from pytorch_lightning.profilers.advanced import AdvancedProfiler +from pytorch_lightning.profilers.base import PassThroughProfiler +from pytorch_lightning.profilers.profiler import Profiler +from pytorch_lightning.profilers.pytorch import PyTorchProfiler +from pytorch_lightning.profilers.simple import SimpleProfiler +from pytorch_lightning.profilers.xla import XLAProfiler __all__ = [ "AbstractProfiler", diff --git a/src/pytorch_lightning/profiler/advanced.py b/src/pytorch_lightning/profiler/advanced.py index a776f50764589..1d2bbed5d96f6 100644 --- a/src/pytorch_lightning/profiler/advanced.py +++ b/src/pytorch_lightning/profiler/advanced.py @@ -11,81 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Profiler to check if there are any bottlenecks in your code.""" -import cProfile -import io -import logging -import pstats -from pathlib import Path -from typing import Dict, Optional, Union +from pytorch_lightning.profilers.advanced import AdvancedProfiler as NewAdvancedProfiler +from pytorch_lightning.utilities import rank_zero_deprecation -from pytorch_lightning.profiler.profiler import Profiler -log = logging.getLogger(__name__) - - -class AdvancedProfiler(Profiler): - """This profiler uses Python's cProfiler to record more detailed information about time spent in each function - call recorded during a given action. - - The output is quite verbose and you should only use this if you want very detailed reports. - """ - - def __init__( - self, - dirpath: Optional[Union[str, Path]] = None, - filename: Optional[str] = None, - line_count_restriction: float = 1.0, - ) -> None: - """ - Args: - dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the - ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`) - will be used. - - filename: If present, filename where the profiler results will be saved instead of printing to stdout. - The ``.txt`` extension will be used automatically. - - line_count_restriction: this can be used to limit the number of functions - reported for each action. either an integer (to select a count of lines), - or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines) - - Raises: - ValueError: - If you attempt to stop recording an action which was never started. 
- """ - super().__init__(dirpath=dirpath, filename=filename) - self.profiled_actions: Dict[str, cProfile.Profile] = {} - self.line_count_restriction = line_count_restriction - - def start(self, action_name: str) -> None: - if action_name not in self.profiled_actions: - self.profiled_actions[action_name] = cProfile.Profile() - self.profiled_actions[action_name].enable() - - def stop(self, action_name: str) -> None: - pr = self.profiled_actions.get(action_name) - if pr is None: - raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.") - pr.disable() - - def summary(self) -> str: - recorded_stats = {} - for action_name, pr in self.profiled_actions.items(): - s = io.StringIO() - ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats("cumulative") - ps.print_stats(self.line_count_restriction) - recorded_stats[action_name] = s.getvalue() - return self._stats_to_str(recorded_stats) - - def teardown(self, stage: Optional[str] = None) -> None: - super().teardown(stage=stage) - self.profiled_actions = {} - - def __reduce__(self): - # avoids `TypeError: cannot pickle 'cProfile.Profile' object` - return ( - self.__class__, - (), - dict(dirpath=self.dirpath, filename=self.filename, line_count_restriction=self.line_count_restriction), +class AdvancedProfiler(NewAdvancedProfiler): + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + rank_zero_deprecation( + "`pytorch_lightning.profiler.AdvancedProfiler` is deprecated in v1.7 and will be removed in v1.9." + " Use the equivalent `pytorch_lightning.profilers.AdvancedProfiler` class instead." 
) + super().__init__(*args, **kwargs) diff --git a/src/pytorch_lightning/profiler/base.py b/src/pytorch_lightning/profiler/base.py index b4eae688ebf80..f2e0ad5276f2e 100644 --- a/src/pytorch_lightning/profiler/base.py +++ b/src/pytorch_lightning/profiler/base.py @@ -15,7 +15,8 @@ from abc import ABC, abstractmethod from typing import Any -from pytorch_lightning.profiler.profiler import Profiler +from pytorch_lightning.profilers.base import PassThroughProfiler as NewPassThroughProfiler +from pytorch_lightning.profilers.profiler import Profiler from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation @@ -57,21 +58,17 @@ class BaseProfiler(Profiler): Please use `Profiler` instead. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs): # type: ignore[no-untyped-def] rank_zero_deprecation( "`BaseProfiler` was deprecated in v1.6 and will be removed in v1.8. Please use `Profiler` instead." ) super().__init__(*args, **kwargs) -class PassThroughProfiler(Profiler): - """This class should be used when you don't want the (small) overhead of profiling. - - The Trainer uses this class by default. - """ - - def start(self, action_name: str) -> None: - pass - - def stop(self, action_name: str) -> None: - pass +class PassThroughProfiler(NewPassThroughProfiler): + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + rank_zero_deprecation( + "`pytorch_lightning.profiler.PassThroughProfiler` is deprecated in v1.7 and will be removed in v1.9." + " Use the equivalent `pytorch_lightning.profilers.PassThroughProfiler` class instead." + ) + super().__init__(*args, **kwargs) diff --git a/src/pytorch_lightning/profiler/profiler.py b/src/pytorch_lightning/profiler/profiler.py index 1b36159837523..84bea3ecae238 100644 --- a/src/pytorch_lightning/profiler/profiler.py +++ b/src/pytorch_lightning/profiler/profiler.py @@ -11,164 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -"""Profiler to check if there are any bottlenecks in your code.""" -import logging -import os -from abc import ABC, abstractmethod -from contextlib import contextmanager -from pathlib import Path -from typing import Any, Callable, Dict, Generator, Iterable, Optional, TextIO, Union +from pytorch_lightning.profilers.profiler import Profiler as NewProfiler +from pytorch_lightning.utilities import rank_zero_deprecation -from pytorch_lightning.utilities.cloud_io import get_filesystem -from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation -log = logging.getLogger(__name__) +class Profiler(NewProfiler): + """ + .. deprecated:: v1.7 + `pytorch_lightning.profiler.Profiler` is deprecated in v1.7 and will be removed in v1.9. + Use the equivalent `pytorch_lightning.profilers.Profiler` class instead. + """ - -class Profiler(ABC): - """If you wish to write a custom profiler, you should inherit from this class.""" - - def __init__( - self, - dirpath: Optional[Union[str, Path]] = None, - filename: Optional[str] = None, - ) -> None: - self.dirpath = dirpath - self.filename = filename - - self._output_file: Optional[TextIO] = None - self._write_stream: Optional[Callable] = None - self._local_rank: Optional[int] = None - self._stage: Optional[str] = None - - @abstractmethod - def start(self, action_name: str) -> None: - """Defines how to start recording an action.""" - - @abstractmethod - def stop(self, action_name: str) -> None: - """Defines how to record the duration once an action is complete.""" - - def summary(self) -> str: - return "" - - @contextmanager - def profile(self, action_name: str) -> Generator: - """Yields a context manager to encapsulate the scope of a profiled action. 
- - Example:: - - with self.profile('load training data'): - # load training data code - - The profiler will start once you've entered the context and will automatically - stop once you exit the code block. - """ - try: - self.start(action_name) - yield action_name - finally: - self.stop(action_name) - - def profile_iterable(self, iterable: Iterable, action_name: str) -> Generator: - """Profiles over each value of an iterable. - - See deprecation message below. - - .. deprecated:: v1.6 - `Profiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8. - """ + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] rank_zero_deprecation( - f"`{self.__class__.__name__}.profile_iterable` is deprecated in v1.6 and will be removed in v1.8." + "`pytorch_lightning.profiler.Profiler` is deprecated in v1.7 and will be removed in v1.9." + " Use the equivalent `pytorch_lightning.profilers.Profiler` class instead." ) - iterator = iter(iterable) - while True: - try: - self.start(action_name) - value = next(iterator) - self.stop(action_name) - yield value - except StopIteration: - self.stop(action_name) - break - - def _rank_zero_info(self, *args: Any, **kwargs: Any) -> None: - if self._local_rank in (None, 0): - log.info(*args, **kwargs) - - def _prepare_filename( - self, action_name: Optional[str] = None, extension: str = ".txt", split_token: str = "-" - ) -> str: - args = [] - if self._stage is not None: - args.append(self._stage) - if self.filename: - args.append(self.filename) - if self._local_rank is not None: - args.append(str(self._local_rank)) - if action_name is not None: - args.append(action_name) - filename = split_token.join(args) + extension - return filename - - def _prepare_streams(self) -> None: - if self._write_stream is not None: - return - if self.filename and self.dirpath: - filepath = os.path.join(self.dirpath, self._prepare_filename()) - fs = get_filesystem(filepath) - fs.mkdirs(self.dirpath, exist_ok=True) - file = 
fs.open(filepath, "a") - self._output_file = file - self._write_stream = file.write - else: - self._write_stream = self._rank_zero_info - - def describe(self) -> None: - """Logs a profile report after the conclusion of run.""" - # users might call `describe` directly as the profilers can be used by themselves. - # to allow this, we open and close the files within this function by calling `_prepare_streams` and `teardown` - # manually instead of letting the `Trainer` do it through `setup` and `teardown` - self._prepare_streams() - summary = self.summary() - if summary and self._write_stream is not None: - self._write_stream(summary) - if self._output_file is not None: - self._output_file.flush() - self.teardown(stage=self._stage) - - def _stats_to_str(self, stats: Dict[str, str]) -> str: - stage = f"{self._stage.upper()} " if self._stage is not None else "" - output = [stage + "Profiler Report"] - for action, value in stats.items(): - header = f"Profile stats for: {action}" - if self._local_rank is not None: - header += f" rank: {self._local_rank}" - output.append(header) - output.append(value) - return os.linesep.join(output) - - def setup( - self, stage: Optional[str] = None, local_rank: Optional[int] = None, log_dir: Optional[str] = None - ) -> None: - """Execute arbitrary pre-profiling set-up steps.""" - self._stage = stage - self._local_rank = local_rank - self.dirpath = self.dirpath or log_dir - - def teardown(self, stage: Optional[str] = None) -> None: - """Execute arbitrary post-profiling tear-down steps. - - Closes the currently open file and stream. 
- """ - self._write_stream = None - if self._output_file is not None: - self._output_file.close() - self._output_file = None # can't pickle TextIOWrapper - - def __del__(self) -> None: - self.teardown(stage=self._stage) - - @property - def local_rank(self) -> int: - return 0 if self._local_rank is None else self._local_rank + super().__init__(*args, **kwargs) diff --git a/src/pytorch_lightning/profiler/pytorch.py b/src/pytorch_lightning/profiler/pytorch.py index 062031bafa2a6..d443059912602 100644 --- a/src/pytorch_lightning/profiler/pytorch.py +++ b/src/pytorch_lightning/profiler/pytorch.py @@ -11,504 +11,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Profiler to check if there are any bottlenecks in your code.""" -import inspect -import logging -import os -from functools import lru_cache, partial -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING, Union +from pytorch_lightning.profilers.pytorch import PyTorchProfiler as NewPyTorchProfiler +from pytorch_lightning.profilers.pytorch import RegisterRecordFunction as NewRegisterRecordFuncion +from pytorch_lightning.profilers.pytorch import ScheduleWrapper as NewScheduleWrapper +from pytorch_lightning.utilities import rank_zero_deprecation -import torch -from torch import nn, Tensor -from torch.autograd.profiler import record_function -from pytorch_lightning.profiler.profiler import Profiler -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE -from pytorch_lightning.utilities.rank_zero import rank_zero_warn -from pytorch_lightning.utilities.warnings import WarningCache - -if TYPE_CHECKING: - from torch.autograd.profiler import EventList - from torch.utils.hooks import RemovableHandle - - from pytorch_lightning.core.module import 
LightningModule - -if _KINETO_AVAILABLE: - from torch.profiler import ProfilerAction, ProfilerActivity, tensorboard_trace_handler - -log = logging.getLogger(__name__) -warning_cache = WarningCache() - -_PROFILER = Union[torch.autograd.profiler.profile, torch.cuda.profiler.profile, torch.autograd.profiler.emit_nvtx] - - -class RegisterRecordFunction: - """While profiling autograd operations, this class will add labels for module names around the forward - function. - - The Lightning PyTorch Profiler will activate this feature automatically. It can be deactivated as follows: - - Example:: - from pytorch_lightning.profilers import PyTorchProfiler - profiler = PyTorchProfiler(record_module_names=False) - Trainer(profiler=profiler) - - It can be used outside of Lightning as follows: - - Example:: - from pytorch_lightning import Trainer, seed_everything - with RegisterRecordFunction(model): - out = model(batch) - """ - - def __init__(self, model: nn.Module) -> None: - self._model = model - self._records: Dict[str, record_function] = {} - self._handles: Dict[str, List["RemovableHandle"]] = {} - - def _start_recording_forward(self, _: nn.Module, input: Tensor, record_name: str) -> Tensor: - # Add [pl][module] in name for pytorch profiler to recognize - record = record_function("[pl][module]" + record_name) - record.__enter__() - self._records[record_name] = record - return input - - def _stop_recording_forward(self, _: nn.Module, __: Tensor, output: Tensor, record_name: str) -> Tensor: - self._records[record_name].__exit__(None, None, None) - return output - - def __enter__(self) -> None: - for module_name, module in self._model.named_modules(): - if module_name: - full_name = f"{type(module).__module__}.{type(module).__name__}" - record_name = f"{full_name}: {module_name}" - pre_forward_handle = module.register_forward_pre_hook( - partial(self._start_recording_forward, record_name=record_name) - ) - post_forward_handle = module.register_forward_hook( - 
partial(self._stop_recording_forward, record_name=record_name) - ) - - self._handles[module_name] = [pre_forward_handle, post_forward_handle] - - def __exit__(self, type: Any, value: Any, traceback: Any) -> None: - for handles in self._handles.values(): - for h in handles: - h.remove() - self._handles = {} - - -class ScheduleWrapper: - """This class is used to override the schedule logic from the profiler and perform recording for both - `training_step`, `validation_step`.""" - - def __init__(self, schedule: Callable) -> None: - if not _KINETO_AVAILABLE: - raise ModuleNotFoundError("You are trying to use `ScheduleWrapper` which require kineto install.") - self._schedule = schedule - self.reset() - - def setup(self, start_action_name: str) -> None: - self._start_action_name = start_action_name - - def pre_step(self, current_action: str) -> None: - self._current_action = current_action - - def reset(self): - # handle properly `fast_dev_run`. PyTorch Profiler will fail otherwise. - self._num_training_step = 0 - self._num_validation_step = 0 - self._num_test_step = 0 - self._num_predict_step = 0 - self._training_step_reached_end = False - self._validation_step_reached_end = False - self._test_step_reached_end = False - self._predict_step_reached_end = False - # used to stop profiler when `ProfilerAction.RECORD_AND_SAVE` is reached. 
- self._current_action: Optional[str] = None - self._prev_schedule_action: Optional[ProfilerAction] = None - self._start_action_name: Optional[str] = None - - @property - def is_training(self): - return self._current_action.endswith("training_step") - - @property - def is_validating(self): - return self._current_action.endswith("validation_step") - - @property - def is_testing(self): - return self._current_action.endswith("test_step") - - @property - def is_predicting(self): - return self._current_action.endswith("predict_step") - - @property - def num_step(self) -> int: - if self.is_training: - return self._num_training_step - if self.is_validating: - return self._num_validation_step - if self.is_testing: - return self._num_test_step - if self.is_predicting: - return self._num_predict_step - return 0 - - def _step(self) -> None: - if self.is_training: - self._num_training_step += 1 - elif self.is_validating: - if self._start_action_name.endswith("on_fit_start"): - if self._num_training_step > 0: - self._num_validation_step += 1 - else: - self._num_validation_step += 1 - elif self.is_testing: - self._num_test_step += 1 - elif self.is_predicting: - self._num_predict_step += 1 - - @property - def has_finished(self) -> bool: - if self.is_training: - return self._training_step_reached_end - if self.is_validating: - return self._validation_step_reached_end - if self.is_testing: - return self._test_step_reached_end - if self.is_predicting: - return self._predict_step_reached_end - return False - - def __call__(self, num_step: int) -> "ProfilerAction": - # ignore the provided input. Keep internal state instead. - if self._current_action is None or self.has_finished: - return ProfilerAction.NONE - - self._step() - action = self._schedule(max(self.num_step, 0)) - if self._prev_schedule_action == ProfilerAction.RECORD and action == ProfilerAction.WARMUP: - # Work around the corner case when validation starts before train. 
- # In this case, the action is RECORD in validation loop, and then call into the train - # and the action is still WARMUP in train and pytorch will recognize this as error. - action = ProfilerAction.RECORD - if action == ProfilerAction.RECORD_AND_SAVE: - if self.is_training: - self._training_step_reached_end = True - elif self.is_validating: - self._validation_step_reached_end = True - elif self.is_testing: - self._test_step_reached_end = True - elif self.is_predicting: - self._predict_step_reached_end = True - self._prev_schedule_action = action - return action - - -class PyTorchProfiler(Profiler): - - STEP_FUNCTIONS = {"training_step", "validation_step", "test_step", "predict_step"} - AVAILABLE_SORT_KEYS = { - "cpu_time", - "cuda_time", - "cpu_time_total", - "cuda_time_total", - "cpu_memory_usage", - "cuda_memory_usage", - "self_cpu_memory_usage", - "self_cuda_memory_usage", - "count", - } - - def __init__( - self, - dirpath: Optional[Union[str, Path]] = None, - filename: Optional[str] = None, - group_by_input_shapes: bool = False, - emit_nvtx: bool = False, - export_to_chrome: bool = True, - row_limit: int = 20, - sort_by_key: Optional[str] = None, - record_module_names: bool = True, - **profiler_kwargs: Any, - ) -> None: - """This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of. - - different operators inside your model - both on the CPU and GPU - - Args: - dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the - ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`) - will be used. - - filename: If present, filename where the profiler results will be saved instead of printing to stdout. - The ``.txt`` extension will be used automatically. - - group_by_input_shapes: Include operator input shapes and group calls by shape. 
- - emit_nvtx: Context manager that makes every autograd operation emit an NVTX range - Run:: - - nvprof --profile-from-start off -o trace_name.prof -- - - To visualize, you can either use:: - - nvvp trace_name.prof - torch.autograd.profiler.load_nvprof(path) - - export_to_chrome: Whether to export the sequence of profiled operators for Chrome. - It will generate a ``.json`` file which can be read by Chrome. - - row_limit: Limit the number of rows in a table, ``-1`` is a special value that - removes the limit completely. - - sort_by_key: Attribute used to sort entries. By default - they are printed in the same order as they were registered. - Valid keys include: ``cpu_time``, ``cuda_time``, ``cpu_time_total``, - ``cuda_time_total``, ``cpu_memory_usage``, ``cuda_memory_usage``, - ``self_cpu_memory_usage``, ``self_cuda_memory_usage``, ``count``. - - record_module_names: Whether to add module names while recording autograd operation. - - profiler_kwargs: Keyword arguments for the PyTorch profiler. This depends on your PyTorch version - - Raises: - MisconfigurationException: - If arg ``sort_by_key`` is not present in ``AVAILABLE_SORT_KEYS``. - If arg ``schedule`` is not a ``Callable``. - If arg ``schedule`` does not return a ``torch.profiler.ProfilerAction``. 
- """ - super().__init__(dirpath=dirpath, filename=filename) - - self._group_by_input_shapes = group_by_input_shapes and profiler_kwargs.get("record_shapes", False) - self._emit_nvtx = emit_nvtx - self._export_to_chrome = export_to_chrome - self._row_limit = row_limit - self._sort_by_key = sort_by_key or f"{'cuda' if profiler_kwargs.get('use_cuda', False) else 'cpu'}_time_total" - self._record_module_names = record_module_names - self._profiler_kwargs = profiler_kwargs - - self.profiler: Optional[_PROFILER] = None - self.function_events: Optional["EventList"] = None - self._lightning_module: Optional["LightningModule"] = None # set by ProfilerConnector - self._register: Optional[RegisterRecordFunction] = None - self._parent_profiler: Optional[_PROFILER] = None - self._recording_map: Dict[str, record_function] = {} - self._start_action_name: Optional[str] = None - self._schedule: Optional[ScheduleWrapper] = None - - if _KINETO_AVAILABLE: - self._init_kineto(profiler_kwargs) - - if self._sort_by_key not in self.AVAILABLE_SORT_KEYS: - raise MisconfigurationException( - f"Found sort_by_key: {self._sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. " - ) - - def _init_kineto(self, profiler_kwargs: Any) -> None: - has_schedule = "schedule" in profiler_kwargs - self._has_on_trace_ready = "on_trace_ready" in profiler_kwargs - - schedule = profiler_kwargs.get("schedule", None) - if schedule is not None: - if not isinstance(schedule, Callable): - raise MisconfigurationException(f"Schedule should be a callable. Found: {schedule}") - action = schedule(0) - if not isinstance(action, ProfilerAction): - raise MisconfigurationException( - f"Schedule should return a `torch.profiler.ProfilerAction`. 
class RegisterRecordFunction(NewRegisterRecordFuncion):
    """Deprecated alias kept for backward compatibility with the old ``pytorch_lightning.profiler`` package.

    Instantiating this class emits a deprecation warning via ``rank_zero_deprecation`` and then defers
    entirely to the parent initializer; no behavior is added or changed.

    NOTE(review): the base class is referenced through the import alias ``NewRegisterRecordFuncion``
    (spelled without the second "t"); the alias is declared outside this block and is kept as-is here.
    """

    def __init__(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
        # The message previously read "will be removed in in v1.9" because the word "in" appeared both at
        # the end of the first fragment and at the start of the second one.
        rank_zero_deprecation(
            "`pytorch_lightning.profiler.pytorch.RegisterRecordFunction` is deprecated in v1.7 and will be removed in"
            " v1.9. Use the equivalent `pytorch_lightning.profilers.pytorch.RegisterRecordFunction` class instead."
        )
        # Forward every positional and keyword argument untouched to the relocated implementation.
        super().__init__(*args, **kwargs)
- return torch.profiler.schedule(wait=1, warmup=1, active=3) - - def _default_activities(self) -> List["ProfilerActivity"]: - activities = [] - if not _KINETO_AVAILABLE: - return activities - if self._profiler_kwargs.get("use_cpu", True): - activities.append(ProfilerActivity.CPU) - if self._profiler_kwargs.get("use_cuda", torch.cuda.is_available()): - activities.append(ProfilerActivity.CUDA) - return activities - - def start(self, action_name: str) -> None: - if self.profiler is None: - # close profiler if it is already opened. might happen if 2 profilers - # are created and the first one did not call `describe` - if torch.autograd._profiler_enabled(): - torch.autograd._disable_profiler() - - if self._schedule is not None: - self._schedule.setup(action_name) - - self._create_profilers() - - profiler = self.profiler.__enter__() - if profiler is not None: - self.profiler = profiler - - if self._parent_profiler is not None: - self._parent_profiler.__enter__() - - if self._lightning_module is not None and self._register is None and self._record_module_names: - self._register = RegisterRecordFunction(self._lightning_module) - self._register.__enter__() - - if self.profiler is not None and action_name not in self._recording_map: - # Add [pl][profile] in name for pytorch profiler to recognize - recording = record_function("[pl][profile]" + action_name) - recording.__enter__() - self._recording_map[action_name] = recording - - def stop(self, action_name: str) -> None: - if action_name in self._recording_map: - self._recording_map[action_name].__exit__(None, None, None) - del self._recording_map[action_name] - if not _KINETO_AVAILABLE or self._emit_nvtx: - return - - if self.profiler is not None and any(action_name.endswith(func) for func in self.STEP_FUNCTIONS): - if self._schedule is not None: - self._schedule.pre_step(action_name) - - # the default schedule requires a minimum of 5 steps to properly work: `wait=1, warmup=1, active=3`. 
- # otherwise, this will raise a `segmentation fault`. - if self._should_override_schedule(): - warning_cache.warn( - "The PyTorch Profiler default schedule will be overridden as there is not enough " - "steps to properly record traces." - ) - self._schedule = None - self.profiler.schedule = torch.profiler.profiler._default_schedule_fn - - def on_trace_ready(profiler): - if self.dirpath is not None: - if self._export_to_chrome: - handler = tensorboard_trace_handler( - self.dirpath, self._prepare_filename(action_name=action_name, extension="") - ) - handler(profiler) - - if self._export_to_flame_graph: - path = os.path.join( - self.dirpath, self._prepare_filename(action_name=action_name, extension=".stack") - ) - profiler.export_stacks(path, metric=self._metric) - else: - rank_zero_warn("The PyTorchProfiler failed to export trace as `dirpath` is None") - - if not self._has_on_trace_ready: - self.profiler.on_trace_ready = on_trace_ready - - if self._schedule is not None: - self.profiler.step_num = self._schedule.num_step - self.profiler.step() - self.profiler.add_metadata("Framework", "pytorch-lightning") - - def summary(self) -> str: - if not self._profiler_kwargs.get("enabled", True) or self._emit_nvtx: - return "" - - self._delete_profilers() - - if not self.function_events: - return "" - - if self._export_to_chrome and not _KINETO_AVAILABLE: - filename = f"{self.local_rank}_trace.json" - path_to_trace = filename if self.dirpath is None else os.path.join(self.dirpath, filename) - self.function_events.export_chrome_trace(path_to_trace) - - data = self.function_events.key_averages(group_by_input_shapes=self._group_by_input_shapes) - table = data.table(sort_by=self._sort_by_key, row_limit=self._row_limit) - - recorded_stats = {"records": table} - return self._stats_to_str(recorded_stats) - - def _create_profilers(self) -> None: - if self._emit_nvtx: - self._parent_profiler = self._create_profiler(torch.cuda.profiler.profile) - self.profiler = 
class ScheduleWrapper(NewScheduleWrapper):
    """Deprecated alias of the ``ScheduleWrapper`` that now lives in ``pytorch_lightning.profilers.pytorch``.

    Construction warns through ``rank_zero_deprecation`` and otherwise behaves exactly like the parent class.
    """

    def __init__(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
        deprecation_message = (
            "`pytorch_lightning.profiler.pytorch.ScheduleWrapper` is deprecated in v1.7 and will be removed in v1.9."
            " Use the equivalent `pytorch_lightning.profilers.pytorch.ScheduleWrapper` class instead."
        )
        rank_zero_deprecation(deprecation_message)
        # All arguments are passed through unchanged to the relocated implementation.
        super().__init__(*args, **kwargs)


class PyTorchProfiler(NewPyTorchProfiler):
    """Deprecated alias of the ``PyTorchProfiler`` that now lives in ``pytorch_lightning.profilers``.

    Construction warns through ``rank_zero_deprecation`` and otherwise behaves exactly like the parent class.
    """

    def __init__(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
        deprecation_message = (
            "`pytorch_lightning.profiler.PyTorchProfiler` is deprecated in v1.7 and will be removed in v1.9."
            " Use the equivalent `pytorch_lightning.profilers.PyTorchProfiler` class instead."
        )
        rank_zero_deprecation(deprecation_message)
        # All arguments are passed through unchanged to the relocated implementation.
        super().__init__(*args, **kwargs)
- - extended: If ``True``, adds extra columns representing number of calls and percentage of total time spent on - respective action. - - Raises: - ValueError: - If you attempt to start an action which has already started, or - if you attempt to stop recording an action which was never started. - """ - super().__init__(dirpath=dirpath, filename=filename) - self.current_actions: Dict[str, float] = {} - self.recorded_durations = defaultdict(list) - self.extended = extended - self.start_time = time.monotonic() - - def start(self, action_name: str) -> None: - if action_name in self.current_actions: - raise ValueError(f"Attempted to start {action_name} which has already started.") - self.current_actions[action_name] = time.monotonic() - - def stop(self, action_name: str) -> None: - end_time = time.monotonic() - if action_name not in self.current_actions: - raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.") - start_time = self.current_actions.pop(action_name) - duration = end_time - start_time - self.recorded_durations[action_name].append(duration) - - def _make_report_extended(self) -> Tuple[_TABLE_DATA_EXTENDED, float, float]: - total_duration = time.monotonic() - self.start_time - report = [ - (a, np.mean(d), len(d), np.sum(d), 100.0 * np.sum(d) / total_duration) - for a, d in self.recorded_durations.items() - ] - report.sort(key=lambda x: x[4], reverse=True) - total_calls = sum(x[2] for x in report) - return report, total_calls, total_duration - - def _make_report(self) -> _TABLE_DATA: - report = [(action, np.mean(d), np.sum(d)) for action, d in self.recorded_durations.items()] - report.sort(key=lambda x: x[1], reverse=True) - return report - - def summary(self) -> str: - sep = os.linesep - output_string = "" - if self._stage is not None: - output_string += f"{self._stage.upper()} " - output_string += f"Profiler Report{sep}" - - if self.extended: - - if len(self.recorded_durations) > 0: - max_key = max(len(k) for k in 
class SimpleProfiler(NewSimpleProfiler):
    """Deprecated alias of the ``SimpleProfiler`` that now lives in ``pytorch_lightning.profilers``.

    Construction warns through ``rank_zero_deprecation`` and otherwise behaves exactly like the parent class.
    """

    def __init__(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
        deprecation_message = (
            "`pytorch_lightning.profiler.SimpleProfiler` is deprecated in v1.7 and will be removed in v1.9."
            " Use the equivalent `pytorch_lightning.profilers.SimpleProfiler` class instead."
        )
        rank_zero_deprecation(deprecation_message)
        # All arguments are passed through unchanged to the relocated implementation.
        super().__init__(*args, **kwargs)
class XLAProfiler(NewXLAProfiler):
    """Deprecated alias of the ``XLAProfiler`` that now lives in ``pytorch_lightning.profilers``.

    Construction warns through ``rank_zero_deprecation`` and otherwise behaves exactly like the parent class.
    """

    def __init__(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
        deprecation_message = (
            "`pytorch_lightning.profiler.XLAProfiler` is deprecated in v1.7 and will be removed in v1.9."
            " Use the equivalent `pytorch_lightning.profilers.XLAProfiler` class instead."
        )
        rank_zero_deprecation(deprecation_message)
        # All arguments are passed through unchanged to the relocated implementation.
        super().__init__(*args, **kwargs)
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.profilers.advanced import AdvancedProfiler +from pytorch_lightning.profilers.base import AbstractProfiler, BaseProfiler, PassThroughProfiler +from pytorch_lightning.profilers.profiler import Profiler +from pytorch_lightning.profilers.pytorch import PyTorchProfiler +from pytorch_lightning.profilers.simple import SimpleProfiler +from pytorch_lightning.profilers.xla import XLAProfiler + +__all__ = [ + "AbstractProfiler", + "BaseProfiler", + "Profiler", + "AdvancedProfiler", + "PassThroughProfiler", + "PyTorchProfiler", + "SimpleProfiler", + "XLAProfiler", +] diff --git a/src/pytorch_lightning/profilers/advanced.py b/src/pytorch_lightning/profilers/advanced.py new file mode 100644 index 0000000000000..214d67e52eb4c --- /dev/null +++ b/src/pytorch_lightning/profilers/advanced.py @@ -0,0 +1,91 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class AdvancedProfiler(Profiler):
    """Profiler backed by :mod:`cProfile` that keeps one ``cProfile.Profile`` per action name.

    Every recorded action is reported through :mod:`pstats`, sorted by cumulative time. The report is
    verbose, so this profiler is best reserved for cases where a detailed breakdown is wanted.
    """

    def __init__(
        self,
        dirpath: Optional[Union[str, Path]] = None,
        filename: Optional[str] = None,
        line_count_restriction: float = 1.0,
    ) -> None:
        """
        Args:
            dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the
                ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`)
                will be used.

            filename: If present, filename where the profiler results will be saved instead of printing to stdout.
                The ``.txt`` extension will be used automatically.

            line_count_restriction: Limits how many functions are reported for each action. Either an integer
                (a number of lines) or a decimal fraction between 0.0 and 1.0 inclusive (a percentage of lines).
                The value is handed to ``pstats.Stats.print_stats`` unchanged.

        Raises:
            ValueError:
                If you attempt to stop recording an action which was never started.
        """
        super().__init__(dirpath=dirpath, filename=filename)
        self.profiled_actions: Dict[str, cProfile.Profile] = {}
        self.line_count_restriction = line_count_restriction

    def start(self, action_name: str) -> None:
        """Enable the profile registered for ``action_name``, creating it on first use."""
        actions = self.profiled_actions
        if action_name not in actions:
            actions[action_name] = cProfile.Profile()
        actions[action_name].enable()

    def stop(self, action_name: str) -> None:
        """Disable the profile registered for ``action_name``.

        Raises:
            ValueError: If no profile was ever started under ``action_name``.
        """
        profile = self.profiled_actions.get(action_name)
        if profile is not None:
            profile.disable()
            return
        raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.")

    def summary(self) -> str:
        """Render the statistics of every profiled action into a single text report."""

        def _render(profile: cProfile.Profile) -> str:
            buffer = io.StringIO()
            stats = pstats.Stats(profile, stream=buffer)
            # `strip_dirs` and `sort_stats` both modify `stats` in place.
            stats.strip_dirs()
            stats.sort_stats("cumulative")
            stats.print_stats(self.line_count_restriction)
            return buffer.getvalue()

        report = {name: _render(profile) for name, profile in self.profiled_actions.items()}
        return self._stats_to_str(report)

    def teardown(self, stage: Optional[str] = None) -> None:
        """Run the base-class teardown, then drop all collected profiles."""
        super().teardown(stage=stage)
        self.profiled_actions = {}

    def __reduce__(self):
        # avoids `TypeError: cannot pickle 'cProfile.Profile' object`
        # Only the constructor-level configuration is carried over; collected profiles are not pickled.
        state = {
            "dirpath": self.dirpath,
            "filename": self.filename,
            "line_count_restriction": self.line_count_restriction,
        }
        return (self.__class__, (), state)
class AbstractProfiler(ABC):
    """Abstract specification of the profiler interface.

    Concrete implementations must provide ``start``, ``stop``, ``summary``, ``setup`` and ``teardown``.

    .. deprecated:: v1.6
        `AbstractProfiler` was deprecated in v1.6 and will be removed in v1.8.
        Please use `Profiler` instead.
    """

    @abstractmethod
    def start(self, action_name: str) -> None:
        """Begin recording the action identified by ``action_name``."""

    @abstractmethod
    def stop(self, action_name: str) -> None:
        """Finish recording the action identified by ``action_name`` and store its duration."""

    @abstractmethod
    def summary(self) -> str:
        """Return the profiler summary as text."""

    @abstractmethod
    def setup(self, **kwargs: Any) -> None:
        """Run whatever pre-profiling set-up the subclass defines."""

    @abstractmethod
    def teardown(self, **kwargs: Any) -> None:
        """Run whatever post-profiling tear-down the subclass defines."""


class BaseProfiler(Profiler):
    """Deprecated name for ``Profiler``; warns on construction and otherwise defers to the parent class.

    .. deprecated:: v1.6
        `BaseProfiler` was deprecated in v1.6 and will be removed in v1.8.
        Please use `Profiler` instead.
    """

    def __init__(self, *args, **kwargs):
        deprecation_message = (
            "`BaseProfiler` was deprecated in v1.6 and will be removed in v1.8. Please use `Profiler` instead."
        )
        rank_zero_deprecation(deprecation_message)
        # Arguments are forwarded unchanged to `Profiler.__init__`.
        super().__init__(*args, **kwargs)


class PassThroughProfiler(Profiler):
    """No-op profiler: ``start`` and ``stop`` do nothing, so profiling adds no (small) overhead.

    The Trainer uses this class by default.
    """

    def start(self, action_name: str) -> None:
        """Ignore the request to start recording ``action_name``."""

    def stop(self, action_name: str) -> None:
        """Ignore the request to stop recording ``action_name``."""
class Profiler(ABC):
    """If you wish to write a custom profiler, you should inherit from this class.

    Subclasses must implement :meth:`start` and :meth:`stop`. The base class supplies the ``profile`` context
    manager, output-stream handling (a file under ``dirpath`` or the module logger), filename construction,
    and report formatting.
    """

    def __init__(
        self,
        dirpath: Optional[Union[str, Path]] = None,
        filename: Optional[str] = None,
    ) -> None:
        """
        Args:
            dirpath: Directory the report file is written to. May be filled in later by :meth:`setup`.
            filename: Base name of the report file. A file is only opened when both ``dirpath`` and
                ``filename`` are set; otherwise the summary goes to the module logger.
        """
        self.dirpath = dirpath
        self.filename = filename

        self._output_file: Optional[TextIO] = None
        self._write_stream: Optional[Callable] = None
        self._local_rank: Optional[int] = None
        # NOTE: `_stage` is deliberately the last attribute assigned; `__del__` uses its presence to tell
        # whether this initializer ran to completion.
        self._stage: Optional[str] = None

    @abstractmethod
    def start(self, action_name: str) -> None:
        """Defines how to start recording an action."""

    @abstractmethod
    def stop(self, action_name: str) -> None:
        """Defines how to record the duration once an action is complete."""

    def summary(self) -> str:
        """Return the textual report; the base implementation returns an empty string."""
        return ""

    @contextmanager
    def profile(self, action_name: str) -> Generator:
        """Yields a context manager to encapsulate the scope of a profiled action.

        Example::

            with self.profile('load training data'):
                # load training data code

        The profiler will start once you've entered the context and will automatically
        stop once you exit the code block.
        """
        try:
            self.start(action_name)
            yield action_name
        finally:
            # `stop` runs even when the profiled body (or `start` itself) raised.
            self.stop(action_name)

    def profile_iterable(self, iterable: Iterable, action_name: str) -> Generator:
        """Profiles over each value of an iterable.

        See deprecation message below.

        .. deprecated:: v1.6
            `Profiler.profile_iterable` is deprecated in v1.6 and will be removed in v1.8.
        """
        rank_zero_deprecation(
            f"`{self.__class__.__name__}.profile_iterable` is deprecated in v1.6 and will be removed in v1.8."
        )
        iterator = iter(iterable)
        while True:
            try:
                # Only the time spent fetching the next value is recorded, not the consumer's work.
                self.start(action_name)
                value = next(iterator)
                self.stop(action_name)
                yield value
            except StopIteration:
                # Balance the `start` issued right before the exhausted `next` call.
                self.stop(action_name)
                break

    def _rank_zero_info(self, *args: Any, **kwargs: Any) -> None:
        """Log at INFO level, but only when the local rank is unset or zero."""
        if self._local_rank in (None, 0):
            log.info(*args, **kwargs)

    def _prepare_filename(
        self, action_name: Optional[str] = None, extension: str = ".txt", split_token: str = "-"
    ) -> str:
        """Build ``<stage>-<filename>-<local_rank>-<action_name><extension>``, skipping unset parts."""
        args = []
        if self._stage is not None:
            args.append(self._stage)
        if self.filename:
            args.append(self.filename)
        if self._local_rank is not None:
            args.append(str(self._local_rank))
        if action_name is not None:
            args.append(action_name)
        filename = split_token.join(args) + extension
        return filename

    def _prepare_streams(self) -> None:
        """Select where the summary is written; does nothing if a stream is already set."""
        if self._write_stream is not None:
            return
        if self.filename and self.dirpath:
            filepath = os.path.join(self.dirpath, self._prepare_filename())
            fs = get_filesystem(filepath)
            fs.mkdirs(self.dirpath, exist_ok=True)
            # Opened in append mode, so an existing report file is extended rather than truncated.
            file = fs.open(filepath, "a")
            self._output_file = file
            self._write_stream = file.write
        else:
            self._write_stream = self._rank_zero_info

    def describe(self) -> None:
        """Logs a profile report after the conclusion of run."""
        # users might call `describe` directly as the profilers can be used by themselves.
        # to allow this, we open and close the files within this function by calling `_prepare_streams` and `teardown`
        # manually instead of letting the `Trainer` do it through `setup` and `teardown`
        self._prepare_streams()
        summary = self.summary()
        if summary and self._write_stream is not None:
            self._write_stream(summary)
        if self._output_file is not None:
            self._output_file.flush()
        self.teardown(stage=self._stage)

    def _stats_to_str(self, stats: Dict[str, str]) -> str:
        """Format ``{action: stats_text}`` into a report with a title line and one header per action."""
        stage = f"{self._stage.upper()} " if self._stage is not None else ""
        output = [stage + "Profiler Report"]
        for action, value in stats.items():
            header = f"Profile stats for: {action}"
            if self._local_rank is not None:
                header += f" rank: {self._local_rank}"
            output.append(header)
            output.append(value)
        return os.linesep.join(output)

    def setup(
        self, stage: Optional[str] = None, local_rank: Optional[int] = None, log_dir: Optional[str] = None
    ) -> None:
        """Execute arbitrary pre-profiling set-up steps."""
        self._stage = stage
        self._local_rank = local_rank
        # An explicitly configured `dirpath` takes precedence over `log_dir`.
        self.dirpath = self.dirpath or log_dir

    def teardown(self, stage: Optional[str] = None) -> None:
        """Execute arbitrary post-profiling tear-down steps.

        Closes the currently open file and stream.
        """
        self._write_stream = None
        if self._output_file is not None:
            self._output_file.close()
            self._output_file = None  # can't pickle TextIOWrapper

    def __del__(self) -> None:
        # Python calls `__del__` even when `__init__` raised or never ran (for example a subclass that
        # validates its environment before calling `super().__init__()`). In that case none of the attributes
        # used by `teardown` exist and there is nothing to close, so skip instead of raising `AttributeError`.
        if not hasattr(self, "_stage"):
            return
        self.teardown(stage=self._stage)

    @property
    def local_rank(self) -> int:
        """The local rank given to :meth:`setup`, or ``0`` when none was set."""
        return 0 if self._local_rank is None else self._local_rank
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Profiler to check if there are any bottlenecks in your code.""" +import inspect +import logging +import os +from functools import lru_cache, partial +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING, Union + +import torch +from torch import nn, Tensor +from torch.autograd.profiler import record_function + +from pytorch_lightning.profilers.profiler import Profiler +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE +from pytorch_lightning.utilities.rank_zero import rank_zero_warn +from pytorch_lightning.utilities.warnings import WarningCache + +if TYPE_CHECKING: + from torch.autograd.profiler import EventList + from torch.utils.hooks import RemovableHandle + + from pytorch_lightning.core.module import LightningModule + +if _KINETO_AVAILABLE: + from torch.profiler import ProfilerAction, ProfilerActivity, tensorboard_trace_handler + +log = logging.getLogger(__name__) +warning_cache = WarningCache() + +_PROFILER = Union[torch.autograd.profiler.profile, torch.cuda.profiler.profile, torch.autograd.profiler.emit_nvtx] + + +class RegisterRecordFunction: + """While profiling autograd operations, this class will add labels for module names around the forward + function. + + The Lightning PyTorch Profiler will activate this feature automatically. 
It can be deactivated as follows: + + Example:: + from pytorch_lightning.profilers import PyTorchProfiler + profiler = PyTorchProfiler(record_module_names=False) + Trainer(profiler=profiler) + + It can be used outside of Lightning as follows: + + Example:: + from pytorch_lightning import Trainer, seed_everything + with RegisterRecordFunction(model): + out = model(batch) + """ + + def __init__(self, model: nn.Module) -> None: + self._model = model + self._records: Dict[str, record_function] = {} + self._handles: Dict[str, List["RemovableHandle"]] = {} + + def _start_recording_forward(self, _: nn.Module, input: Tensor, record_name: str) -> Tensor: + # Add [pl][module] in name for pytorch profiler to recognize + record = record_function("[pl][module]" + record_name) + record.__enter__() + self._records[record_name] = record + return input + + def _stop_recording_forward(self, _: nn.Module, __: Tensor, output: Tensor, record_name: str) -> Tensor: + self._records[record_name].__exit__(None, None, None) + return output + + def __enter__(self) -> None: + for module_name, module in self._model.named_modules(): + if module_name: + full_name = f"{type(module).__module__}.{type(module).__name__}" + record_name = f"{full_name}: {module_name}" + pre_forward_handle = module.register_forward_pre_hook( + partial(self._start_recording_forward, record_name=record_name) + ) + post_forward_handle = module.register_forward_hook( + partial(self._stop_recording_forward, record_name=record_name) + ) + + self._handles[module_name] = [pre_forward_handle, post_forward_handle] + + def __exit__(self, type: Any, value: Any, traceback: Any) -> None: + for handles in self._handles.values(): + for h in handles: + h.remove() + self._handles = {} + + +class ScheduleWrapper: + """This class is used to override the schedule logic from the profiler and perform recording for both + `training_step`, `validation_step`.""" + + def __init__(self, schedule: Callable) -> None: + if not _KINETO_AVAILABLE: + 
raise ModuleNotFoundError("You are trying to use `ScheduleWrapper` which require kineto install.") + self._schedule = schedule + self.reset() + + def setup(self, start_action_name: str) -> None: + self._start_action_name = start_action_name + + def pre_step(self, current_action: str) -> None: + self._current_action = current_action + + def reset(self): + # handle properly `fast_dev_run`. PyTorch Profiler will fail otherwise. + self._num_training_step = 0 + self._num_validation_step = 0 + self._num_test_step = 0 + self._num_predict_step = 0 + self._training_step_reached_end = False + self._validation_step_reached_end = False + self._test_step_reached_end = False + self._predict_step_reached_end = False + # used to stop profiler when `ProfilerAction.RECORD_AND_SAVE` is reached. + self._current_action: Optional[str] = None + self._prev_schedule_action: Optional[ProfilerAction] = None + self._start_action_name: Optional[str] = None + + @property + def is_training(self): + return self._current_action.endswith("training_step") + + @property + def is_validating(self): + return self._current_action.endswith("validation_step") + + @property + def is_testing(self): + return self._current_action.endswith("test_step") + + @property + def is_predicting(self): + return self._current_action.endswith("predict_step") + + @property + def num_step(self) -> int: + if self.is_training: + return self._num_training_step + if self.is_validating: + return self._num_validation_step + if self.is_testing: + return self._num_test_step + if self.is_predicting: + return self._num_predict_step + return 0 + + def _step(self) -> None: + if self.is_training: + self._num_training_step += 1 + elif self.is_validating: + if self._start_action_name.endswith("on_fit_start"): + if self._num_training_step > 0: + self._num_validation_step += 1 + else: + self._num_validation_step += 1 + elif self.is_testing: + self._num_test_step += 1 + elif self.is_predicting: + self._num_predict_step += 1 + + @property + 
def has_finished(self) -> bool: + if self.is_training: + return self._training_step_reached_end + if self.is_validating: + return self._validation_step_reached_end + if self.is_testing: + return self._test_step_reached_end + if self.is_predicting: + return self._predict_step_reached_end + return False + + def __call__(self, num_step: int) -> "ProfilerAction": + # ignore the provided input. Keep internal state instead. + if self._current_action is None or self.has_finished: + return ProfilerAction.NONE + + self._step() + action = self._schedule(max(self.num_step, 0)) + if self._prev_schedule_action == ProfilerAction.RECORD and action == ProfilerAction.WARMUP: + # Work around the corner case when validation starts before train. + # In this case, the action is RECORD in validation loop, and then call into the train + # and the action is still WARMUP in train and pytorch will recognize this as error. + action = ProfilerAction.RECORD + if action == ProfilerAction.RECORD_AND_SAVE: + if self.is_training: + self._training_step_reached_end = True + elif self.is_validating: + self._validation_step_reached_end = True + elif self.is_testing: + self._test_step_reached_end = True + elif self.is_predicting: + self._predict_step_reached_end = True + self._prev_schedule_action = action + return action + + +class PyTorchProfiler(Profiler): + + STEP_FUNCTIONS = {"training_step", "validation_step", "test_step", "predict_step"} + AVAILABLE_SORT_KEYS = { + "cpu_time", + "cuda_time", + "cpu_time_total", + "cuda_time_total", + "cpu_memory_usage", + "cuda_memory_usage", + "self_cpu_memory_usage", + "self_cuda_memory_usage", + "count", + } + + def __init__( + self, + dirpath: Optional[Union[str, Path]] = None, + filename: Optional[str] = None, + group_by_input_shapes: bool = False, + emit_nvtx: bool = False, + export_to_chrome: bool = True, + row_limit: int = 20, + sort_by_key: Optional[str] = None, + record_module_names: bool = True, + **profiler_kwargs: Any, + ) -> None: + """This 
profiler uses PyTorch's Autograd Profiler to inspect the cost of the operators inside your model.
+
+        The cost is reported for the different operators on both the CPU and the GPU.
+
+        Args:
+            dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the
+                ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`)
+                will be used.
+
+            filename: If present, filename where the profiler results will be saved instead of printing to stdout.
+                The ``.txt`` extension will be used automatically.
+
+            group_by_input_shapes: Include operator input shapes and group calls by shape.
+
+            emit_nvtx: Context manager that makes every autograd operation emit an NVTX range.
+                Run::
+
+                    nvprof --profile-from-start off -o trace_name.prof -- <regular command here>
+
+                To visualize, you can either use::
+
+                    nvvp trace_name.prof
+                    torch.autograd.profiler.load_nvprof(path)
+
+            export_to_chrome: Whether to export the sequence of profiled operators for Chrome.
+                It will generate a ``.json`` file which can be read by Chrome.
+
+            row_limit: Limit the number of rows in a table, ``-1`` is a special value that
+                removes the limit completely.
+
+            sort_by_key: Attribute used to sort entries. By default
+                they are printed in the same order as they were registered.
+                Valid keys include: ``cpu_time``, ``cuda_time``, ``cpu_time_total``,
+                ``cuda_time_total``, ``cpu_memory_usage``, ``cuda_memory_usage``,
+                ``self_cpu_memory_usage``, ``self_cuda_memory_usage``, ``count``.
+
+            record_module_names: Whether to add module names while recording autograd operation.
+
+            profiler_kwargs: Keyword arguments for the PyTorch profiler. This depends on your PyTorch version.
+
+        Raises:
+            MisconfigurationException:
+                If arg ``sort_by_key`` is not present in ``AVAILABLE_SORT_KEYS``.
+                If arg ``schedule`` is not a ``Callable``.
+                If arg ``schedule`` does not return a ``torch.profiler.ProfilerAction``.
+        """
+        super().__init__(dirpath=dirpath, filename=filename)
+
+        self._group_by_input_shapes = group_by_input_shapes and profiler_kwargs.get("record_shapes", False)
+        self._emit_nvtx = emit_nvtx
+        self._export_to_chrome = export_to_chrome
+        self._row_limit = row_limit
+        self._sort_by_key = sort_by_key or f"{'cuda' if profiler_kwargs.get('use_cuda', False) else 'cpu'}_time_total"
+        self._record_module_names = record_module_names
+        self._profiler_kwargs = profiler_kwargs
+
+        self.profiler: Optional[_PROFILER] = None
+        self.function_events: Optional["EventList"] = None
+        self._lightning_module: Optional["LightningModule"] = None  # set by ProfilerConnector
+        self._register: Optional[RegisterRecordFunction] = None
+        self._parent_profiler: Optional[_PROFILER] = None
+        self._recording_map: Dict[str, record_function] = {}
+        self._start_action_name: Optional[str] = None
+        self._schedule: Optional[ScheduleWrapper] = None
+
+        if _KINETO_AVAILABLE:
+            self._init_kineto(profiler_kwargs)
+
+        if self._sort_by_key not in self.AVAILABLE_SORT_KEYS:
+            raise MisconfigurationException(
+                f"Found sort_by_key: {self._sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. "
+            )
+
+    def _init_kineto(self, profiler_kwargs: Any) -> None:
+        has_schedule = "schedule" in profiler_kwargs
+        self._has_on_trace_ready = "on_trace_ready" in profiler_kwargs
+
+        schedule = profiler_kwargs.get("schedule", None)
+        if schedule is not None:
+            if not callable(schedule):
+                raise MisconfigurationException(f"Schedule should be a callable. Found: {schedule}")
+            action = schedule(0)
+            if not isinstance(action, ProfilerAction):
+                raise MisconfigurationException(
+                    f"Schedule should return a `torch.profiler.ProfilerAction`. Found: {action}"
+                )
+        self._default_schedule()  # NOTE(review): result unused (only warms the lru_cache?) - confirm
+        schedule = schedule if has_schedule else self._default_schedule()
+        self._schedule = ScheduleWrapper(schedule) if schedule is not None else schedule
+        self._profiler_kwargs["schedule"] = self._schedule
+
+        activities = profiler_kwargs.get("activities", None)
+        self._profiler_kwargs["activities"] = activities or self._default_activities()
+        self._export_to_flame_graph = profiler_kwargs.get("export_to_flame_graph", False)
+        self._metric = profiler_kwargs.get("metric", "self_cpu_time_total")
+        with_stack = profiler_kwargs.get("with_stack", False) or self._export_to_flame_graph
+        self._profiler_kwargs["with_stack"] = with_stack
+
+    @property
+    def _total_steps(self) -> int:
+        trainer = self._lightning_module.trainer
+        if self._schedule.is_training:
+            return trainer.num_training_batches
+        if self._schedule.is_validating:
+            return sum(trainer.num_val_batches) + sum(trainer.num_sanity_val_batches)
+        if self._schedule.is_testing:
+            return sum(trainer.num_test_batches)
+        if self._schedule.is_predicting:
+            return sum(trainer.num_predict_batches)
+
+    def _should_override_schedule(self) -> bool:
+        return (
+            self._lightning_module is not None
+            and self._schedule is not None
+            and self._total_steps < 5
+            and self._schedule._schedule == self._default_schedule()
+        )
+
+    @staticmethod
+    @lru_cache(1)
+    def _default_schedule() -> Optional[Callable]:
+        if _KINETO_AVAILABLE:
+            # Those schedule defaults allow the profiling overhead to be negligible over training time.
+ return torch.profiler.schedule(wait=1, warmup=1, active=3) + + def _default_activities(self) -> List["ProfilerActivity"]: + activities = [] + if not _KINETO_AVAILABLE: + return activities + if self._profiler_kwargs.get("use_cpu", True): + activities.append(ProfilerActivity.CPU) + if self._profiler_kwargs.get("use_cuda", torch.cuda.is_available()): + activities.append(ProfilerActivity.CUDA) + return activities + + def start(self, action_name: str) -> None: + if self.profiler is None: + # close profiler if it is already opened. might happen if 2 profilers + # are created and the first one did not call `describe` + if torch.autograd._profiler_enabled(): + torch.autograd._disable_profiler() + + if self._schedule is not None: + self._schedule.setup(action_name) + + self._create_profilers() + + profiler = self.profiler.__enter__() + if profiler is not None: + self.profiler = profiler + + if self._parent_profiler is not None: + self._parent_profiler.__enter__() + + if self._lightning_module is not None and self._register is None and self._record_module_names: + self._register = RegisterRecordFunction(self._lightning_module) + self._register.__enter__() + + if self.profiler is not None and action_name not in self._recording_map: + # Add [pl][profile] in name for pytorch profiler to recognize + recording = record_function("[pl][profile]" + action_name) + recording.__enter__() + self._recording_map[action_name] = recording + + def stop(self, action_name: str) -> None: + if action_name in self._recording_map: + self._recording_map[action_name].__exit__(None, None, None) + del self._recording_map[action_name] + + if not _KINETO_AVAILABLE or self._emit_nvtx: + return + + if self.profiler is not None and any(action_name.endswith(func) for func in self.STEP_FUNCTIONS): + if self._schedule is not None: + self._schedule.pre_step(action_name) + + # the default schedule requires a minimum of 5 steps to properly work: `wait=1, warmup=1, active=3`. 
+ # otherwise, this will raise a `segmentation fault`. + if self._should_override_schedule(): + warning_cache.warn( + "The PyTorch Profiler default schedule will be overridden as there is not enough " + "steps to properly record traces." + ) + self._schedule = None + self.profiler.schedule = torch.profiler.profiler._default_schedule_fn + + def on_trace_ready(profiler): + if self.dirpath is not None: + if self._export_to_chrome: + handler = tensorboard_trace_handler( + self.dirpath, self._prepare_filename(action_name=action_name, extension="") + ) + handler(profiler) + + if self._export_to_flame_graph: + path = os.path.join( + self.dirpath, self._prepare_filename(action_name=action_name, extension=".stack") + ) + profiler.export_stacks(path, metric=self._metric) + else: + rank_zero_warn("The PyTorchProfiler failed to export trace as `dirpath` is None") + + if not self._has_on_trace_ready: + self.profiler.on_trace_ready = on_trace_ready + + if self._schedule is not None: + self.profiler.step_num = self._schedule.num_step + self.profiler.step() + self.profiler.add_metadata("Framework", "pytorch-lightning") + + def summary(self) -> str: + if not self._profiler_kwargs.get("enabled", True) or self._emit_nvtx: + return "" + + self._delete_profilers() + + if not self.function_events: + return "" + + if self._export_to_chrome and not _KINETO_AVAILABLE: + filename = f"{self.local_rank}_trace.json" + path_to_trace = filename if self.dirpath is None else os.path.join(self.dirpath, filename) + self.function_events.export_chrome_trace(path_to_trace) + + data = self.function_events.key_averages(group_by_input_shapes=self._group_by_input_shapes) + table = data.table(sort_by=self._sort_by_key, row_limit=self._row_limit) + + recorded_stats = {"records": table} + return self._stats_to_str(recorded_stats) + + def _create_profilers(self) -> None: + if self._emit_nvtx: + self._parent_profiler = self._create_profiler(torch.cuda.profiler.profile) + self.profiler = 
self._create_profiler(torch.autograd.profiler.emit_nvtx) + else: + self._parent_profiler = None + self.profiler = self._create_profiler( + torch.profiler.profile if _KINETO_AVAILABLE else torch.autograd.profiler.profile + ) + + def _create_profiler(self, profiler: Type[_PROFILER]) -> _PROFILER: + init_parameters = inspect.signature(profiler.__init__).parameters + kwargs = {k: v for k, v in self._profiler_kwargs.items() if k in init_parameters} + return profiler(**kwargs) + + def _cache_functions_events(self) -> None: + if self._emit_nvtx: + return + self.function_events = self.profiler.events() if _KINETO_AVAILABLE else self.profiler.function_events + + def _delete_profilers(self) -> None: + if self.profiler is not None: + self.profiler.__exit__(None, None, None) + self._cache_functions_events() + self.profiler = None + + if self._schedule is not None: + self._schedule.reset() + + if self._parent_profiler is not None: + self._parent_profiler.__exit__(None, None, None) + self._parent_profiler = None + + if self._register is not None: + self._register.__exit__(None, None, None) + self._register = None + + def teardown(self, stage: Optional[str] = None) -> None: + self._delete_profilers() + + for k in list(self._recording_map): + self.stop(k) + self._recording_map = {} + + super().teardown(stage=stage) diff --git a/src/pytorch_lightning/profilers/simple.py b/src/pytorch_lightning/profilers/simple.py new file mode 100644 index 0000000000000..20d76f9b2d378 --- /dev/null +++ b/src/pytorch_lightning/profilers/simple.py @@ -0,0 +1,144 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Profiler to check if there are any bottlenecks in your code."""
+import logging
+import os
+import time
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+
+from pytorch_lightning.profilers.profiler import Profiler
+
+log = logging.getLogger(__name__)
+
+_TABLE_ROW_EXTENDED = Tuple[str, float, int, float, float]
+_TABLE_DATA_EXTENDED = List[_TABLE_ROW_EXTENDED]
+_TABLE_ROW = Tuple[str, float, float]
+_TABLE_DATA = List[_TABLE_ROW]
+
+
+class SimpleProfiler(Profiler):
+    """This profiler simply records the duration of actions (in seconds) and reports the mean duration of each
+    action and the total time spent over the entire training run."""
+
+    def __init__(
+        self,
+        dirpath: Optional[Union[str, Path]] = None,
+        filename: Optional[str] = None,
+        extended: bool = True,
+    ) -> None:
+        """
+        Args:
+            dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the
+                ``trainer.log_dir`` (from :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`)
+                will be used.
+
+            filename: If present, filename where the profiler results will be saved instead of printing to stdout.
+                The ``.txt`` extension will be used automatically.
+
+            extended: If ``True``, adds extra columns representing number of calls and percentage of total time spent on
+                respective action.
+
+        Raises:
+            ValueError:
+                If you attempt to start an action which has already started, or
+                if you attempt to stop recording an action which was never started.
+        """
+        super().__init__(dirpath=dirpath, filename=filename)
+        self.current_actions: Dict[str, float] = {}
+        self.recorded_durations = defaultdict(list)
+        self.extended = extended
+        self.start_time = time.monotonic()
+
+    def start(self, action_name: str) -> None:
+        if action_name in self.current_actions:
+            raise ValueError(f"Attempted to start {action_name} which has already started.")
+        self.current_actions[action_name] = time.monotonic()
+
+    def stop(self, action_name: str) -> None:
+        end_time = time.monotonic()
+        if action_name not in self.current_actions:
+            raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.")
+        start_time = self.current_actions.pop(action_name)
+        duration = end_time - start_time
+        self.recorded_durations[action_name].append(duration)
+
+    def _make_report_extended(self) -> Tuple[_TABLE_DATA_EXTENDED, float, float]:
+        total_duration = time.monotonic() - self.start_time
+        report = [
+            (a, np.mean(d), len(d), np.sum(d), 100.0 * np.sum(d) / total_duration)
+            for a, d in self.recorded_durations.items()
+        ]
+        report.sort(key=lambda x: x[4], reverse=True)
+        total_calls = sum(x[2] for x in report)
+        return report, total_calls, total_duration
+
+    def _make_report(self) -> _TABLE_DATA:
+        report = [(action, np.mean(d), np.sum(d)) for action, d in self.recorded_durations.items()]
+        report.sort(key=lambda x: x[1], reverse=True)
+        return report
+
+    def summary(self) -> str:
+        sep = os.linesep
+        output_string = ""
+        if self._stage is not None:
+            output_string += f"{self._stage.upper()} "
+        output_string += f"Profiler Report{sep}"
+
+        if self.extended:
+
+            if len(self.recorded_durations) > 0:
+                max_key = max(len(k) for k in self.recorded_durations.keys())
+
+                def log_row(action, mean, num_calls, total, per):
+                    row = f"{sep}| {action:<{max_key}s}\t| {mean:<15}\t|"
+                    row += f" {num_calls:<15}\t| {total:<15}\t| {per:<15}\t|"
+                    return row
+
+                header_string = log_row("Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %")
+                output_string_len = len(header_string.expandtabs())
+                sep_lines = f"{sep}{'-' * output_string_len}"
+                output_string += sep_lines + header_string + sep_lines
+                report, total_calls, total_duration = self._make_report_extended()
+                output_string += log_row("Total", "-", f"{total_calls:}", f"{total_duration:.5}", "100 %")
+                output_string += sep_lines
+                for action, mean_duration, num_calls, total_duration, duration_per in report:
+                    output_string += log_row(
+                        action,
+                        f"{mean_duration:.5}",
+                        f"{num_calls}",
+                        f"{total_duration:.5}",
+                        f"{duration_per:.5}",
+                    )
+                output_string += sep_lines
+        elif self.recorded_durations:  # nothing recorded -> title only; `max()` of an empty dict would raise
+            max_key = max(len(k) for k in self.recorded_durations)
+
+            def log_row(action, mean, total):
+                return f"{sep}| {action:<{max_key}s}\t| {mean:<15}\t| {total:<15}\t|"
+
+            header_string = log_row("Action", "Mean duration (s)", "Total time (s)")
+            output_string_len = len(header_string.expandtabs())
+            sep_lines = f"{sep}{'-' * output_string_len}"
+            output_string += sep_lines + header_string + sep_lines
+            report = self._make_report()
+
+            for action, mean_duration, total_duration in report:
+                output_string += log_row(action, f"{mean_duration:.5}", f"{total_duration:.5}")
+            output_string += sep_lines
+        output_string += sep
+        return output_string
diff --git a/src/pytorch_lightning/profilers/xla.py b/src/pytorch_lightning/profilers/xla.py
new file mode 100644
index 0000000000000..0f86d63b546eb
--- /dev/null
+++ b/src/pytorch_lightning/profilers/xla.py
@@ -0,0 +1,78 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import Dict + +from pytorch_lightning.profilers.profiler import Profiler +from pytorch_lightning.utilities import _TPU_AVAILABLE +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +if _TPU_AVAILABLE: + import torch_xla.debug.profiler as xp + +log = logging.getLogger(__name__) + + +class XLAProfiler(Profiler): + + STEP_FUNCTIONS = {"validation_step", "test_step", "predict_step"} + RECORD_FUNCTIONS = { + "training_step", + "backward", + "validation_step", + "test_step", + "predict_step", + } + + def __init__(self, port: int = 9012) -> None: + """XLA Profiler will help you debug and optimize training workload performance for your models using Cloud + TPU performance tools. + + Args: + port: the port to start the profiler server on. An exception is + raised if the provided port is invalid or busy. + """ + if not _TPU_AVAILABLE: + raise MisconfigurationException("`XLAProfiler` is only supported on TPUs") + super().__init__(dirpath=None, filename=None) + self.port = port + self._recording_map: Dict = {} + self._step_recoding_map: Dict = {} + self._start_trace: bool = False + + def start(self, action_name: str) -> None: + if action_name in self.RECORD_FUNCTIONS: + if not self._start_trace: + self.server = xp.start_server(self.port) + self._start_trace = True + + if action_name in self.STEP_FUNCTIONS: + step = self._get_step_num(action_name) + recording = xp.StepTrace(action_name, step_num=step) + else: + recording = xp.Trace(action_name) + recording.__enter__() + self._recording_map[action_name] = recording + + def stop(self, action_name: str) -> None: + if action_name in self._recording_map: + self._recording_map[action_name].__exit__(None, None, None) + del self._recording_map[action_name] + + def _get_step_num(self, action_name: str) -> int: + if action_name not in self._step_recoding_map: + 
self._step_recoding_map[action_name] = 1 + else: + self._step_recoding_map[action_name] += 1 + return self._step_recoding_map[action_name] diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 85de7acbe352a..f64afddca79fa 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -55,7 +55,7 @@ PrecisionPlugin, ) from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment -from pytorch_lightning.profiler import ( +from pytorch_lightning.profilers import ( AdvancedProfiler, PassThroughProfiler, Profiler, diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index 2c6625ac0772d..bfcb49a306d88 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -24,7 +24,7 @@ from pytorch_lightning import LightningDataModule, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel -from pytorch_lightning.profiler.simple import SimpleProfiler +from pytorch_lightning.profilers.simple import SimpleProfiler from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-8.py b/tests/tests_pytorch/deprecated_api/test_remove_1-8.py index 54292a7d32013..758367d1dd40b 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-8.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-8.py @@ -40,7 +40,8 @@ from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin -from 
pytorch_lightning.profiler import AbstractProfiler, AdvancedProfiler, BaseProfiler, Profiler, SimpleProfiler +from pytorch_lightning.profiler import AbstractProfiler, BaseProfiler +from pytorch_lightning.profilers import AdvancedProfiler, Profiler, SimpleProfiler from pytorch_lightning.strategies import ParallelStrategy from pytorch_lightning.trainer.configuration_validator import _check_datamodule_checkpoint_hooks from pytorch_lightning.trainer.states import RunningStage diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py index 74d509bd8df4b..66bbf80d4e3ea 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py @@ -20,6 +20,12 @@ from pytorch_lightning import Trainer from pytorch_lightning.core.module import LightningModule from pytorch_lightning.demos.boring_classes import BoringModel +from pytorch_lightning.profiler.advanced import AdvancedProfiler +from pytorch_lightning.profiler.base import PassThroughProfiler +from pytorch_lightning.profiler.profiler import Profiler +from pytorch_lightning.profiler.pytorch import PyTorchProfiler, RegisterRecordFunction, ScheduleWrapper +from pytorch_lightning.profiler.simple import SimpleProfiler +from pytorch_lightning.profiler.xla import XLAProfiler from pytorch_lightning.utilities.cli import ( _deprecate_auto_registry_message, _deprecate_registry_message, @@ -27,7 +33,9 @@ LightningCLI, SaveConfigCallback, ) +from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE from pytorch_lightning.utilities.rank_zero import rank_zero_only +from tests_pytorch.helpers.runif import RunIf def test_lightning_logger_base_deprecation_warning(): @@ -154,3 +162,36 @@ def test_lightningCLI_registries_register_automatically(): with pytest.deprecated_call(match=_deprecate_auto_registry_message): with mock.patch("sys.argv", ["any.py"]): LightningCLI(BoringModel, run=False, auto_registry=True) + + 
+def test_profiler_deprecation_warning(): + assert "Profiler` is deprecated in v1.7" in Profiler.__doc__ + + +@pytest.mark.parametrize( + "cls", + [ + AdvancedProfiler, + PassThroughProfiler, + PyTorchProfiler, + SimpleProfiler, + pytest.param(XLAProfiler, marks=RunIf(tpu=True)), + ], +) +def test_profiler_classes_deprecated_warning(cls): + with pytest.deprecated_call( + match=f"profiler.{cls.__name__}` is deprecated in v1.7 and will be removed in v1.9." + f" Use .*profilers.{cls.__name__}` class instead." + ): + cls() + + +@pytest.mark.skipif(not _KINETO_AVAILABLE, reason="Requires PyTorch Profiler Kineto") +def test_pytorch_profiler_schedule_wrapper_deprecation_warning(): + with pytest.deprecated_call(match="ScheduleWrapper` is deprecated in v1.7 and will be removed in v1.9."): + _ = ScheduleWrapper(None) + + +def test_pytorch_profiler_register_record_function_deprecation_warning(): + with pytest.deprecated_call(match="RegisterRecordFunction` is deprecated in v1.7 and will be removed in in v1.9."): + _ = RegisterRecordFunction(None) diff --git a/tests/tests_pytorch/profiler/__init__.py b/tests/tests_pytorch/profilers/__init__.py similarity index 100% rename from tests/tests_pytorch/profiler/__init__.py rename to tests/tests_pytorch/profilers/__init__.py diff --git a/tests/tests_pytorch/profiler/test_profiler.py b/tests/tests_pytorch/profilers/test_profiler.py similarity index 99% rename from tests/tests_pytorch/profiler/test_profiler.py rename to tests/tests_pytorch/profilers/test_profiler.py index a74c2cde222a6..127d373f3e574 100644 --- a/tests/tests_pytorch/profiler/test_profiler.py +++ b/tests/tests_pytorch/profilers/test_profiler.py @@ -26,8 +26,8 @@ from pytorch_lightning.callbacks import EarlyStopping, StochasticWeightAveraging from pytorch_lightning.demos.boring_classes import BoringModel, ManualOptimBoringModel from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger -from pytorch_lightning.profiler import AdvancedProfiler, PassThroughProfiler, 
PyTorchProfiler, SimpleProfiler -from pytorch_lightning.profiler.pytorch import RegisterRecordFunction, warning_cache +from pytorch_lightning.profilers import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler +from pytorch_lightning.profilers.pytorch import RegisterRecordFunction, warning_cache from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _KINETO_AVAILABLE from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/profiler/test_xla_profiler.py b/tests/tests_pytorch/profilers/test_xla_profiler.py similarity index 97% rename from tests/tests_pytorch/profiler/test_xla_profiler.py rename to tests/tests_pytorch/profilers/test_xla_profiler.py index c28b829535b4c..7f5b0ecdd7740 100644 --- a/tests/tests_pytorch/profiler/test_xla_profiler.py +++ b/tests/tests_pytorch/profilers/test_xla_profiler.py @@ -18,7 +18,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.demos.boring_classes import BoringModel -from pytorch_lightning.profiler import XLAProfiler +from pytorch_lightning.profilers import XLAProfiler from pytorch_lightning.utilities import _TPU_AVAILABLE from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/run_standalone_tests.sh b/tests/tests_pytorch/run_standalone_tests.sh index 3dd80324539ce..713b7ab08d5fa 100644 --- a/tests/tests_pytorch/run_standalone_tests.sh +++ b/tests/tests_pytorch/run_standalone_tests.sh @@ -38,7 +38,7 @@ parametrizations=${parametrizations//"tests/tests_pytorch/"/} parametrizations_arr=($parametrizations) # tests to skip - space separated -blocklist='profiler/test_profiler.py::test_pytorch_profiler_nested_emit_nvtx utilities/test_warnings.py' +blocklist='profilers/test_profiler.py::test_pytorch_profiler_nested_emit_nvtx utilities/test_warnings.py' report='' for i in "${!parametrizations_arr[@]}"; do @@ -58,7 +58,7 @@ for i in "${!parametrizations_arr[@]}"; do done if nvcc --version; 
then - nvprof --profile-from-start off -o trace_name.prof -- python ${defaults} profiler/test_profiler.py::test_pytorch_profiler_nested_emit_nvtx + nvprof --profile-from-start off -o trace_name.prof -- python ${defaults} profilers/test_profiler.py::test_pytorch_profiler_nested_emit_nvtx fi # needs to run outside of `pytest` diff --git a/tests/tests_pytorch/utilities/test_cli.py b/tests/tests_pytorch/utilities/test_cli.py index f499acf41e6d0..8e801299aa23c 100644 --- a/tests/tests_pytorch/utilities/test_cli.py +++ b/tests/tests_pytorch/utilities/test_cli.py @@ -36,7 +36,6 @@ from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel from pytorch_lightning.loggers import _COMET_AVAILABLE, _NEPTUNE_AVAILABLE, _WANDB_AVAILABLE, TensorBoardLogger from pytorch_lightning.plugins.environments import SLURMEnvironment -from pytorch_lightning.profiler import PyTorchProfiler from pytorch_lightning.strategies import DDPStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE @@ -1497,6 +1496,8 @@ def __init__(self, a_func: Callable = torch.softmax): def test_pytorch_profiler_init_args(): + from pytorch_lightning.profilers import Profiler, PyTorchProfiler + init = { "dirpath": "profiler", "row_limit": 10, @@ -1510,7 +1511,7 @@ def test_pytorch_profiler_init_args(): cli_args += [f"--trainer.profiler.{k}={v}" for k, v in init.items()] cli_args += [f"--trainer.profiler.dict_kwargs.{k}={v}" for k, v in unresolved.items()] - with mock.patch("sys.argv", ["any.py"] + cli_args): + with mock.patch("sys.argv", ["any.py"] + cli_args), mock_subclasses(Profiler, PyTorchProfiler): cli = LightningCLI(TestModel, run=False) assert isinstance(cli.config_init.trainer.profiler, PyTorchProfiler) diff --git a/tests/tests_pytorch/utilities/test_fetching.py b/tests/tests_pytorch/utilities/test_fetching.py index a362521bc9c52..50c9b85970bf0 100644 --- a/tests/tests_pytorch/utilities/test_fetching.py +++ 
b/tests/tests_pytorch/utilities/test_fetching.py @@ -22,7 +22,7 @@ from pytorch_lightning import Callback, LightningDataModule, Trainer from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset -from pytorch_lightning.profiler import SimpleProfiler +from pytorch_lightning.profilers import SimpleProfiler from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import DataFetcher, DataLoaderIterDataFetcher, InterBatchParallelDataFetcher