Skip to content

Commit

Permalink
Merge b66f1f6 into e848542
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Mar 5, 2021
2 parents e848542 + b66f1f6 commit 9d17fb6
Show file tree
Hide file tree
Showing 8 changed files with 322 additions and 313 deletions.
3 changes: 1 addition & 2 deletions pytorch_lightning/callbacks/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@

class tqdm(_tqdm):
"""
Custom tqdm progressbar where we append 0 to floating points/strings to
prevent the progress bar from flickering
Custom tqdm progressbar where we append 0 to floating points/strings to prevent the progress bar from flickering
"""

@staticmethod
Expand Down
152 changes: 66 additions & 86 deletions pytorch_lightning/profiler/profilers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Profiler to check if there are any bottlenecks in your code."""

import cProfile
import io
import logging
Expand All @@ -22,7 +21,7 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from contextlib import contextmanager
from typing import Optional, Union
from typing import Optional

import numpy as np

Expand All @@ -31,22 +30,8 @@
log = logging.getLogger(__name__)


class BaseProfiler(ABC):
"""
If you wish to write a custom profiler, you should inhereit from this class.
"""

def __init__(self, output_streams: Optional[Union[list, tuple]] = None):
"""
Args:
output_streams: callable
"""
if output_streams:
if not isinstance(output_streams, (list, tuple)):
output_streams = [output_streams]
else:
output_streams = []
self.write_streams = output_streams
class AbstractProfiler(ABC):
"""Specification of a profiler."""

@abstractmethod
def start(self, action_name: str) -> None:
Expand All @@ -56,6 +41,38 @@ def start(self, action_name: str) -> None:
def stop(self, action_name: str) -> None:
"""Defines how to record the duration once an action is complete."""

@abstractmethod
def summary(self) -> str:
"""Create profiler summary in text format."""


class BaseProfiler(AbstractProfiler, ABC):
"""
If you wish to write a custom profiler, you should inherit from this class.
"""

def __init__(self, local_rank: Optional[int] = None, log_dir: Optional[str] = None) -> None:
self.output_fname = getattr(self, "output_fname", None)
# the profiler can be used outside of lightning
# that's why we call `on_train_start` manually
self.on_train_start(local_rank=local_rank, log_dir=log_dir)

def on_train_start(self, local_rank: Optional[int] = None, log_dir: Optional[str] = None):
"""
This function is used by the Trainer to inject local_rank with `DDP`
and `TensorBoardLogger` log_dir in the profiler.
"""
self.local_rank = local_rank
self.log_dir = log_dir
self.prepare_file()

def prepare_file(self) -> None:
self.output_file = None
if self.output_fname:
fs = get_filesystem(self.output_fname)
self.output_file = fs.open(self.output_fname, "w")
self.write_streams = [self.output_file.write] if self.output_file else [log.info]

@contextmanager
def profile(self, action_name: str) -> None:
"""
Expand Down Expand Up @@ -91,24 +108,30 @@ def describe(self) -> None:
"""Logs a profile report after the conclusion of the training run."""
for write in self.write_streams:
write(self.summary())
if self.output_file:
self.output_file.flush()

@abstractmethod
def summary(self) -> str:
"""Create profiler summary in text format."""

def on_train_start(self, local_rank: Optional[int] = None):
self.local_rank = local_rank
def stats_to_str(self, stats: dict) -> str:
output = ["Profiler Report"]
for action, value in stats.items():
header = f"Profile stats for: {action}"
if getattr(self, "local_rank", None) is not None:
header += f" rank: {self.local_rank}"
output.append(header)
output.append(value)
return os.linesep.join(output)

def __del__(self) -> None:
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()


class PassThroughProfiler(BaseProfiler):
"""
This class should be used when you don't want the (small) overhead of profiling.
The Trainer uses this class by default.
"""

def __init__(self):
super().__init__(output_streams=None)

def start(self, action_name: str) -> None:
pass

Expand All @@ -125,7 +148,7 @@ class SimpleProfiler(BaseProfiler):
the mean duration of each action and the total time spent over the entire training run.
"""

def __init__(self, output_filename: Optional[str] = None, extended=True):
def __init__(self, output_filename: Optional[str] = None, extended: bool = True):
"""
Args:
output_filename: optionally save profile results to file instead of printing
Expand All @@ -136,19 +159,12 @@ def __init__(self, output_filename: Optional[str] = None, extended=True):
If you attempt to start an action which has already started, or
if you attempt to stop recording an action which was never started.
"""
self.output_fname = output_filename
self.current_actions = {}
self.recorded_durations = defaultdict(list)
self.extended = extended

self.output_fname = output_filename
self.output_file = None
if self.output_fname:
fs = get_filesystem(self.output_fname)
self.output_file = fs.open(self.output_fname, "w")

streaming_out = [self.output_file.write] if self.output_file else [log.info]
super().__init__()
self.start_time = time.monotonic()
super().__init__(output_streams=streaming_out)

def start(self, action_name: str) -> None:
if action_name in self.current_actions:
Expand All @@ -170,24 +186,25 @@ def make_report(self):
return report, total_duration

def summary(self) -> str:
output_string = "\n\nProfiler Report\n"
sep = os.linesep
output_string = f"Profiler Report{sep}"

if self.extended:

if len(self.recorded_durations) > 0:
max_key = np.max([len(k) for k in self.recorded_durations.keys()])

def log_row(action, mean, num_calls, total, per):
row = f"{os.linesep}{action:<{max_key}s}\t| {mean:<15}\t|"
row = f"{sep}{action:<{max_key}s}\t| {mean:<15}\t|"
row += f"{num_calls:<15}\t| {total:<15}\t| {per:<15}\t|"
return row

output_string += log_row("Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %")
output_string_len = len(output_string)
output_string += f"{os.linesep}{'-' * output_string_len}"
output_string += f"{sep}{'-' * output_string_len}"
report, total_duration = self.make_report()
output_string += log_row("Total", "-", "_", f"{total_duration:.5}", "100 %")
output_string += f"{os.linesep}{'-' * output_string_len}"
output_string += f"{sep}{'-' * output_string_len}"
for action, durations, duration_per in report:
output_string += log_row(
action,
Expand All @@ -199,27 +216,16 @@ def log_row(action, mean, num_calls, total, per):
else:

def log_row(action, mean, total):
return f"{os.linesep}{action:<20s}\t| {mean:<15}\t| {total:<15}"
return f"{sep}{action:<20s}\t| {mean:<15}\t| {total:<15}"

output_string += log_row("Action", "Mean duration (s)", "Total time (s)")
output_string += f"{os.linesep}{'-' * 65}"
output_string += f"{sep}{'-' * 65}"

for action, durations in self.recorded_durations.items():
output_string += log_row(action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}")
output_string += os.linesep
output_string += sep
return output_string

def describe(self):
"""Logs a profile report after the conclusion of the training run."""
super().describe()
if self.output_file:
self.output_file.flush()

def __del__(self):
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()


class AdvancedProfiler(BaseProfiler):
"""
Expand All @@ -241,17 +247,10 @@ def __init__(self, output_filename: Optional[str] = None, line_count_restriction
ValueError:
If you attempt to stop recording an action which was never started.
"""
self.output_fname = output_filename
self.profiled_actions = {}
self.line_count_restriction = line_count_restriction

self.output_fname = output_filename
self.output_file = None
if self.output_fname:
fs = get_filesystem(self.output_fname)
self.output_file = fs.open(self.output_fname, "w")

streaming_out = [self.output_file.write] if self.output_file else [log.info]
super().__init__(output_streams=streaming_out)
super().__init__()

def start(self, action_name: str) -> None:
if action_name not in self.profiled_actions:
Expand All @@ -261,9 +260,7 @@ def start(self, action_name: str) -> None:
def stop(self, action_name: str) -> None:
pr = self.profiled_actions.get(action_name)
if pr is None:
raise ValueError( # pragma: no-cover
f"Attempting to stop recording an action ({action_name}) which was never started."
)
raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.")
pr.disable()

def summary(self) -> str:
Expand All @@ -273,21 +270,4 @@ def summary(self) -> str:
ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats('cumulative')
ps.print_stats(self.line_count_restriction)
recorded_stats[action_name] = s.getvalue()

# log to standard out
output_string = f"{os.linesep}Profiler Report{os.linesep}"
for action, stats in recorded_stats.items():
output_string += f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}"

return output_string

def describe(self):
"""Logs a profile report after the conclusion of the training run."""
super().describe()
if self.output_file:
self.output_file.flush()

def __del__(self):
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()
return self.stats_to_str(recorded_stats)
Loading

0 comments on commit 9d17fb6

Please sign in to comment.