Commit
ref: organize args 2/n (#3448)
* ref: organize args 2/n

* ref: organize args 2/n

* ref: organize args 2/n
williamFalcon authored Sep 10, 2020
1 parent 1255a8f commit a208d6d
Showing 9 changed files with 107 additions and 55 deletions.
5 changes: 4 additions & 1 deletion pytorch_lightning/accelerators/accelerator_connector.py
@@ -20,7 +20,8 @@ def on_trainer_init(
num_nodes,
log_gpu_memory,
sync_batchnorm,
benchmark
benchmark,
replace_sampler_ddp
):
# benchmarking
self.trainer.benchmark = benchmark
@@ -84,6 +85,8 @@ def on_trainer_init(

self.trainer.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')

self.trainer.replace_sampler_ddp = replace_sampler_ddp

def select_accelerator(self):
# SLURM ddp
use_slurm_ddp = self.trainer.use_ddp and self.trainer.is_slurm_managing_tasks
7 changes: 7 additions & 0 deletions pytorch_lightning/trainer/callback_connector.py
@@ -1,3 +1,4 @@
import os
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_utils import is_overridden
@@ -15,7 +16,13 @@ def on_trainer_init(
checkpoint_callback,
progress_bar_refresh_rate,
process_position,
default_root_dir,
weights_save_path
):
# init folder paths for checkpoint + weights save callbacks
self.trainer._default_root_dir = default_root_dir or os.getcwd()
self.trainer._weights_save_path = weights_save_path or self.trainer._default_root_dir

# init callbacks
self.trainer.callbacks = callbacks or []

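For orientation, the default-path resolution that used to sit directly in Trainer.__init__ now happens inside CallbackConnector.on_trainer_init. A minimal sketch of that fallback chain, using a stand-in object instead of a real Trainer (the SimpleNamespace and helper name are only for illustration):

import os
from types import SimpleNamespace

def init_folder_paths(trainer, default_root_dir=None, weights_save_path=None):
    # Mirrors the fallback chain above: a missing root dir falls back to the
    # current working directory, and weights_save_path falls back to the root dir.
    trainer._default_root_dir = default_root_dir or os.getcwd()
    trainer._weights_save_path = weights_save_path or trainer._default_root_dir

trainer = SimpleNamespace()
init_folder_paths(trainer)
assert trainer._weights_save_path == os.getcwd()  # both default to the working directory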
5 changes: 5 additions & 0 deletions pytorch_lightning/trainer/data_connector.py
@@ -24,6 +24,11 @@ class DataConnector(object):
def __init__(self, trainer):
self.trainer = trainer

def on_trainer_init(self, check_val_every_n_epoch, reload_dataloaders_every_epoch):
self.trainer.check_val_every_n_epoch = check_val_every_n_epoch
self.trainer.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch
self.trainer._is_data_prepared = False

def get_profiled_train_dataloader(self, train_dataloader):
profiled_dl = self.trainer.profiler.profile_iterable(
enumerate(self._with_is_last(train_dataloader)),
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/debugging_connector.py
@@ -35,6 +35,7 @@ def on_init_start(
overfit_batches,
fast_dev_run
):

self.trainer.fast_dev_run = fast_dev_run
if self.trainer.fast_dev_run:
limit_train_batches = 1
24 changes: 24 additions & 0 deletions pytorch_lightning/trainer/logger_connector.py
@@ -13,10 +13,12 @@
# limitations under the License.
import torch
from pytorch_lightning.core import memory
from pytorch_lightning.loggers import TensorBoardLogger, LoggerCollection
from pytorch_lightning.utilities import flatten_dict
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.core.step_result import EvalResult, Result
from pprint import pprint
from typing import Iterable


class LoggerConnector:
@@ -27,6 +29,28 @@ def __init__(self, trainer):
self.logged_metrics = {}
self.progress_bar_metrics = {}

def on_trainer_init(self, logger, log_save_interval, row_log_interval):
# logging
self.configure_logger(logger)
self.trainer.log_save_interval = log_save_interval
self.trainer.row_log_interval = row_log_interval

def configure_logger(self, logger):
if logger is True:
# default logger
self.trainer.logger = TensorBoardLogger(
save_dir=self.trainer.default_root_dir,
version=self.trainer.slurm_job_id,
name='lightning_logs'
)
elif logger is False:
self.trainer.logger = None
else:
if isinstance(logger, Iterable):
self.trainer.logger = LoggerCollection(logger)
else:
self.trainer.logger = logger

def log_metrics(self, metrics, grad_norm_dic, step=None):
"""Logs the metric dict passed in.
If `step` parameter is None and `step` key is presented is metrics,
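The configure_logger branching above means that logger=True, logger=False, a single logger, or an iterable of loggers all collapse into one trainer.logger attribute. A short usage sketch, assuming the 0.9-era pytorch_lightning.loggers API shown in the imports above (paths and names here are illustrative):

from pytorch_lightning.loggers import TensorBoardLogger, LoggerCollection

# logger=True: the connector builds a default TensorBoardLogger much like this one
default_logger = TensorBoardLogger(save_dir='.', name='lightning_logs')

# logger=<iterable>: the connector wraps the loggers in a LoggerCollection
combined = LoggerCollection([TensorBoardLogger('logs/a'), TensorBoardLogger('logs/b')])

# logger=False: the connector simply sets trainer.logger = None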
16 changes: 0 additions & 16 deletions pytorch_lightning/trainer/logging.py
@@ -39,22 +39,6 @@ class TrainerLoggingMixin(ABC):
num_gpus: int
logged_metrics: ...

def configure_logger(self, logger):
if logger is True:
# default logger
self.logger = TensorBoardLogger(
save_dir=self.default_root_dir,
version=self.slurm_job_id,
name='lightning_logs'
)
elif logger is False:
self.logger = None
else:
if isinstance(logger, Iterable):
self.logger = LoggerCollection(logger)
else:
self.logger = logger

def metrics_to_scalars(self, metrics):
new_metrics = {}
for k, v in metrics.items():
50 changes: 21 additions & 29 deletions pytorch_lightning/trainer/trainer.py
@@ -46,6 +46,7 @@
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.logger_connector import LoggerConnector
from pytorch_lightning.trainer.lr_scheduler_connector import LRSchedulerConnector
from pytorch_lightning.trainer.training_trick_connector import TrainingTricksConnector
from pytorch_lightning.trainer.callback_connector import CallbackConnector
from pytorch_lightning.trainer.model_connector import ModelConnector
from pytorch_lightning.trainer.debugging_connector import DebuggingConnector
@@ -177,6 +178,7 @@ def __init__(
self.precision_connector = PrecisionConnector(self)
self.callback_connector = CallbackConnector(self)
self.debugging_connector = DebuggingConnector(self)
self.training_tricks_connector = TrainingTricksConnector(self)

self.tuner = Tuner(self)
self.accelerator_backend = None
@@ -203,6 +205,7 @@ def __init__(
self.tested_ckpt_path = None

# training state
self.weights_summary = weights_summary
self.model = None
self.datamodule = None
self.testing = False
@@ -217,26 +220,30 @@ def __init__(
self.running_sanity_check = False
self._state = TrainerState.INITIALIZING

self._default_root_dir = default_root_dir or os.getcwd()
self._weights_save_path = weights_save_path or self._default_root_dir

# init callbacks
self.callback_connector.on_trainer_init(
callbacks,
early_stop_callback,
checkpoint_callback,
progress_bar_refresh_rate,
process_position
process_position,
default_root_dir,
weights_save_path,
)

self.on_init_start()
# init data flags
self.data_connector.on_trainer_init(check_val_every_n_epoch, reload_dataloaders_every_epoch)

self.gradient_clip_val = gradient_clip_val
self.check_val_every_n_epoch = check_val_every_n_epoch
# hook
self.on_init_start()

if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
raise MisconfigurationException("track_grad_norm can be an int, a float or 'inf' (infinity norm).")
self.track_grad_norm = float(track_grad_norm)
# init training tricks
self.training_tricks_connector.on_trainer_init(
gradient_clip_val,
track_grad_norm,
accumulate_grad_batches,
truncated_bptt_steps
)

# init accelerator related flags
self.accelerator_connector.on_trainer_init(
@@ -248,25 +255,16 @@ def __init__(
num_nodes,
log_gpu_memory,
sync_batchnorm,
benchmark
benchmark,
replace_sampler_ddp
)

# -------------------
# CONTINUE
# -------------------
self.weights_summary = weights_summary

# init train loop related flags
self.train_loop.on_init_start(max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps)

self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

self.auto_lr_find = auto_lr_find
self.auto_scale_batch_size = auto_scale_batch_size
self._is_data_prepared = False
self.replace_sampler_ddp = replace_sampler_ddp

self.truncated_bptt_steps = truncated_bptt_steps
self.resume_from_checkpoint = resume_from_checkpoint
self.terminate_on_nan = terminate_on_nan
self.shown_warnings = set()
@@ -276,14 +274,8 @@ def __init__(
profiler = SimpleProfiler()
self.profiler = profiler or PassThroughProfiler()

# accumulated grads
self.accumulate_grad_batches = accumulate_grad_batches
self.configure_accumulated_gradients(accumulate_grad_batches)

# logging
self.configure_logger(logger)
self.log_save_interval = log_save_interval
self.row_log_interval = row_log_interval
# init logger flags
self.logger_connector.on_trainer_init(logger, log_save_interval, row_log_interval)

# init debugging flags
self.debugging_connector.on_init_start(
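The pattern behind this refactor: Trainer.__init__ becomes a thin router, each connector owns one group of constructor arguments, and its on_trainer_init writes the resulting flags back onto the trainer. A stripped-down sketch of that delegation (simplified names, not the actual Lightning classes):

class _TricksConnector:
    # Illustrative only: each connector keeps a back-reference to the trainer
    # and sets its flags on it, mirroring the connectors in this diff.
    def __init__(self, trainer):
        self.trainer = trainer

    def on_trainer_init(self, gradient_clip_val, accumulate_grad_batches):
        self.trainer.gradient_clip_val = gradient_clip_val
        self.trainer.accumulate_grad_batches = accumulate_grad_batches


class _TrainerSketch:
    def __init__(self, gradient_clip_val=0.0, accumulate_grad_batches=1):
        # __init__ only routes argument groups to their connectors.
        self.training_tricks_connector = _TricksConnector(self)
        self.training_tricks_connector.on_trainer_init(gradient_clip_val, accumulate_grad_batches)


t = _TrainerSketch(gradient_clip_val=0.5)
assert t.gradient_clip_val == 0.5  # the flag ends up on the trainer, set via the connector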
45 changes: 45 additions & 0 deletions pytorch_lightning/trainer/training_trick_connector.py
@@ -0,0 +1,45 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.callbacks import GradientAccumulationScheduler


class TrainingTricksConnector:

def __init__(self, trainer):
self.trainer = trainer

def on_trainer_init(self, gradient_clip_val, track_grad_norm, accumulate_grad_batches, truncated_bptt_steps):
# gradient clipping
self.trainer.gradient_clip_val = gradient_clip_val

# gradient norm tracking
if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
raise MisconfigurationException("track_grad_norm can be an int, a float or 'inf' (infinity norm).")
self.trainer.track_grad_norm = float(track_grad_norm)

# accumulated grads
self.trainer.accumulate_grad_batches = accumulate_grad_batches
self.configure_accumulated_gradients(accumulate_grad_batches)

self.trainer.truncated_bptt_steps = truncated_bptt_steps

def configure_accumulated_gradients(self, accumulate_grad_batches):
if isinstance(accumulate_grad_batches, dict):
self.trainer.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
elif isinstance(accumulate_grad_batches, int):
schedule = {0: accumulate_grad_batches}
self.trainer.accumulation_scheduler = GradientAccumulationScheduler(schedule)
else:
raise TypeError("Gradient accumulation supports only int and dict types")
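configure_accumulated_gradients accepts either an int (one factor applied from epoch 0) or a dict mapping epoch to factor. A short example of both forms with the GradientAccumulationScheduler callback, assuming it behaves as it is constructed in the connector above:

from pytorch_lightning.callbacks import GradientAccumulationScheduler

# int form: accumulate_grad_batches=4 becomes the schedule {0: 4}
from_int = GradientAccumulationScheduler({0: 4})

# dict form: accumulate 8 batches from epoch 0, then 2 batches from epoch 5 onward
from_dict = GradientAccumulationScheduler({0: 8, 5: 2})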
9 changes: 0 additions & 9 deletions pytorch_lightning/trainer/training_tricks.py
@@ -76,12 +76,3 @@ def detect_nan_tensors(self, loss: Tensor) -> None:
f'Detected nan and/or inf values in `{name}`.'
' Check your forward pass for numerically unstable operations.'
)

def configure_accumulated_gradients(self, accumulate_grad_batches):
if isinstance(accumulate_grad_batches, dict):
self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
elif isinstance(accumulate_grad_batches, int):
schedule = {0: accumulate_grad_batches}
self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
else:
raise TypeError("Gradient accumulation supports only int and dict types")
