ref: organize args 3/n #3447

Merged · 6 commits · Sep 10, 2020
2 changes: 2 additions & 0 deletions pytorch_lightning/accelerators/accelerator_connector.py
@@ -82,6 +82,8 @@ def on_trainer_init(
# NVIDIA setup
self.trainer.set_nvidia_flags(self.trainer.is_slurm_managing_tasks, self.trainer.data_parallel_device_ids)

self.trainer.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')

def select_accelerator(self):
# SLURM ddp
use_slurm_ddp = self.trainer.use_ddp and self.trainer.is_slurm_managing_tasks
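Note: the `on_colab_kaggle` flag that moves here is just the truthiness of two environment variables, so it is easy to check in isolation. A minimal sketch (the standalone variable name and the simulated env var value are illustrative, not part of the PR):

```python
import os

# Simulate a Colab runtime; on Kaggle, KAGGLE_URL_BASE would be set instead.
os.environ["COLAB_GPU"] = "1"

# Same expression the accelerator connector now evaluates in on_trainer_init.
on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')
print(bool(on_colab_kaggle))  # True on Colab/Kaggle, False elsewhere
```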
7 changes: 0 additions & 7 deletions pytorch_lightning/trainer/data_loading.py
@@ -345,10 +345,3 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
hvd.join()

return dataloader

def determine_data_use_amount(self, overfit_batches: float) -> None:
"""Use less data for debugging purposes"""
if overfit_batches > 0:
self.limit_train_batches = overfit_batches
self.limit_val_batches = overfit_batches
self.limit_test_batches = overfit_batches
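Note: `determine_data_use_amount` is not dropped here; it moves verbatim into the new `DebuggingConnector` below, with the assignments rewritten to target `self.trainer.limit_*` instead of `self.limit_*`.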
109 changes: 109 additions & 0 deletions pytorch_lightning/trainer/debugging_connector.py
@@ -0,0 +1,109 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.utilities.exceptions import MisconfigurationException
from typing import Union
from pytorch_lightning.utilities import rank_zero_warn, rank_zero_info


class DebuggingConnector:

def __init__(self, trainer):
self.trainer = trainer

def on_init_start(
self,
overfit_pct,
val_percent_check,
test_percent_check,
train_percent_check,
limit_train_batches,
limit_val_batches,
limit_test_batches,
val_check_interval,
overfit_batches,
fast_dev_run
):
self.trainer.fast_dev_run = fast_dev_run
if self.trainer.fast_dev_run:
limit_train_batches = 1
limit_val_batches = 1
limit_test_batches = 1
self.trainer.num_sanity_val_steps = 0
self.trainer.max_epochs = 1
rank_zero_info(
'Running in fast_dev_run mode: will run a full train,' ' val and test loop using a single batch'
)

# how much of the data to use
# TODO: remove in 0.10.0
if overfit_pct is not None:
rank_zero_warn(
"Argument `overfit_pct` is now set by `overfit_batches` since v0.8.0"
" and this argument will be removed in v0.10.0",
DeprecationWarning,
)
overfit_batches = overfit_pct

# TODO: remove in 0.10.0
if val_percent_check is not None:
rank_zero_warn(
"Argument `val_percent_check` is now set by `limit_val_batches` since v0.8.0"
" and this argument will be removed in v0.10.0",
DeprecationWarning,
)
limit_val_batches = val_percent_check

# TODO: remove in 0.10.0
if test_percent_check is not None:
rank_zero_warn(
"Argument `test_percent_check` is now set by `limit_test_batches` since v0.8.0"
" and this argument will be removed in v0.10.0",
DeprecationWarning,
)
limit_test_batches = test_percent_check

# TODO: remove in 0.10.0
if train_percent_check is not None:
rank_zero_warn(
"Argument `train_percent_check` is now set by `limit_train_batches` since v0.8.0"
" and this argument will be removed in v0.10.0",
DeprecationWarning,
)
limit_train_batches = train_percent_check

self.trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
self.trainer.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches')
self.trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches')
self.trainer.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval')
self.trainer.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches')
self.determine_data_use_amount(self.trainer.overfit_batches)

def determine_data_use_amount(self, overfit_batches: float) -> None:
"""Use less data for debugging purposes"""
if overfit_batches > 0:
self.trainer.limit_train_batches = overfit_batches
self.trainer.limit_val_batches = overfit_batches
self.trainer.limit_test_batches = overfit_batches


def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
if 0 <= batches <= 1:
return batches
elif batches > 1 and batches % 1.0 == 0:
return int(batches)
else:
raise MisconfigurationException(
f'You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int.'
)
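To see how the new connector resolves overlapping flags, here is a minimal sketch that drives it with a bare stub instead of a full `Trainer`. The stub and the chosen argument values are illustrative; only the import path and the `on_init_start` signature come from the file above:

```python
from types import SimpleNamespace

from pytorch_lightning.trainer.debugging_connector import DebuggingConnector

# Bare stand-in for the Trainer: the connector only writes attributes onto it.
stub_trainer = SimpleNamespace()
connector = DebuggingConnector(stub_trainer)

# fast_dev_run wins over everything else: one batch per loop, one epoch.
connector.on_init_start(
    overfit_pct=None, val_percent_check=None, test_percent_check=None,
    train_percent_check=None, limit_train_batches=1.0, limit_val_batches=1.0,
    limit_test_batches=1.0, val_check_interval=1.0, overfit_batches=0.0,
    fast_dev_run=True,
)
print(stub_trainer.limit_train_batches, stub_trainer.max_epochs)  # 1 1

# A deprecated percent flag is forwarded to its limit_* replacement (with a
# DeprecationWarning), and whole numbers > 1 are read as absolute batch counts.
connector.on_init_start(
    overfit_pct=None, val_percent_check=0.25, test_percent_check=None,
    train_percent_check=None, limit_train_batches=10.0, limit_val_batches=1.0,
    limit_test_batches=1.0, val_check_interval=1.0, overfit_batches=0.0,
    fast_dev_run=False,
)
print(stub_trainer.limit_val_batches, stub_trainer.limit_train_batches)  # 0.25 10
```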
pytorch_lightning/trainer/{initializer.py → precision_connector.py}
@@ -11,16 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE, rank_zero_warn, AMPType


class Initializer:
class PrecisionConnector:

def __init__(self, trainer):
self.trainer = trainer

def on_trainer_init(self, precision, amp_level, amp_backend):
# AMP init
# These are the only lines needed after v0.8.0
# we wrap the user's forward with autocast and give it back at the end of fit
self.trainer.autocast_original_forward = None
self.trainer.precision = precision
self.trainer.scaler = None

self.trainer.amp_level = amp_level
self.init_amp(amp_backend)

def init_amp(self, amp_type: str):
assert self.trainer.precision in (16, 32), 'only 32 or 16 bit precision supported'
self.trainer.amp_backend = None
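The `Initializer` → `PrecisionConnector` rename follows the same connector pattern this PR applies elsewhere: a thin object holding a back-reference to the trainer that initializes one slice of its state. A generic, self-contained sketch of that pattern (class and argument names mirror the diff, but the code is illustrative, not the library's implementation):

```python
class PrecisionConnector:
    def __init__(self, trainer):
        self.trainer = trainer  # back-reference; all state lands on the trainer

    def on_trainer_init(self, precision, amp_level, amp_backend):
        self.trainer.precision = precision
        self.trainer.amp_level = amp_level
        self.trainer.amp_backend = amp_backend  # the real connector resolves this via init_amp()


class Trainer:
    def __init__(self, precision=32, amp_level='O2', amp_backend='native'):
        # each connector owns exactly one concern of trainer initialization
        self.precision_connector = PrecisionConnector(self)
        self.precision_connector.on_trainer_init(precision, amp_level, amp_backend)


print(Trainer(precision=16).precision)  # 16
```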
112 changes: 23 additions & 89 deletions pytorch_lightning/trainer/trainer.py
@@ -43,15 +43,16 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.trainer.data_connector import DataConnector
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.logger_connector import LoggerConnector
from pytorch_lightning.trainer.lr_scheduler_connector import LRSchedulerConnector
from pytorch_lightning.trainer.callback_connector import CallbackConnector
from pytorch_lightning.trainer.model_connector import ModelConnector
from pytorch_lightning.trainer.debugging_connector import DebuggingConnector
from pytorch_lightning import _logger as log
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.trainer.initializer import Initializer
from pytorch_lightning.trainer.precision_connector import PrecisionConnector
from pytorch_lightning.trainer.data_connector import DataConnector
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.trainer import docstrings
from pytorch_lightning.trainer.properties import TrainerProperties
@@ -173,8 +174,10 @@ def __init__(
self.accelerator_connector = AcceleratorConnector(self)
self.logger_connector = LoggerConnector(self)
self.model_connector = ModelConnector(self)
self.initializer = Initializer(self)
self.precision_connector = PrecisionConnector(self)
self.callback_connector = CallbackConnector(self)
self.debugging_connector = DebuggingConnector(self)

self.tuner = Tuner(self)
self.accelerator_backend = None

@@ -253,15 +256,8 @@ def __init__(
# -------------------
self.weights_summary = weights_summary

self.max_epochs = max_epochs
self.min_epochs = min_epochs
self.max_steps = max_steps
self.min_steps = min_steps

if num_sanity_val_steps == -1:
self.num_sanity_val_steps = float('inf')
else:
self.num_sanity_val_steps = num_sanity_val_steps
# init train loop related flags
self.train_loop.on_init_start(max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps)

self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

@@ -275,17 +271,6 @@ def __init__(
self.terminate_on_nan = terminate_on_nan
self.shown_warnings = set()

self.fast_dev_run = fast_dev_run
if self.fast_dev_run:
limit_train_batches = 1
limit_val_batches = 1
limit_test_batches = 1
self.num_sanity_val_steps = 0
self.max_epochs = 1
rank_zero_info(
'Running in fast_dev_run mode: will run a full train,' ' val and test loop using a single batch'
)

# configure profiler
if profiler is True:
profiler = SimpleProfiler()
@@ -300,61 +285,22 @@ def __init__(
self.log_save_interval = log_save_interval
self.row_log_interval = row_log_interval

# how much of the data to use
# TODO: remove in 0.10.0
if overfit_pct is not None:
rank_zero_warn(
"Argument `overfit_pct` is now set by `overfit_batches` since v0.8.0"
" and this argument will be removed in v0.10.0",
DeprecationWarning,
)
overfit_batches = overfit_pct

# TODO: remove in 0.10.0
if val_percent_check is not None:
rank_zero_warn(
"Argument `val_percent_check` is now set by `limit_val_batches` since v0.8.0"
" and this argument will be removed in v0.10.0",
DeprecationWarning,
)
limit_val_batches = val_percent_check

# TODO: remove in 0.10.0
if test_percent_check is not None:
rank_zero_warn(
"Argument `test_percent_check` is now set by `limit_test_batches` since v0.8.0"
" and this argument will be removed in v0.10.0",
DeprecationWarning,
)
limit_test_batches = test_percent_check

# TODO: remove in 0.10.0
if train_percent_check is not None:
rank_zero_warn(
"Argument `train_percent_check` is now set by `limit_train_batches` since v0.8.0"
" and this argument will be removed in v0.10.0",
DeprecationWarning,
)
limit_train_batches = train_percent_check

self.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
self.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches')
self.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches')
self.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval')
self.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches')
self.determine_data_use_amount(self.overfit_batches)

# AMP init
# These are the only lines needed after v0.8.0
# we wrap the user's forward with autocast and give it back at the end of fit
self.autocast_original_forward = None
self.precision = precision
self.scaler = None

self.amp_level = amp_level
self.initializer.init_amp(amp_backend)
# init debugging flags
self.debugging_connector.on_init_start(
overfit_pct,
val_percent_check,
test_percent_check,
train_percent_check,
limit_train_batches,
limit_val_batches,
limit_test_batches,
val_check_interval,
overfit_batches,
fast_dev_run
)

self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')
# set precision
self.precision_connector.on_trainer_init(precision, amp_level, amp_backend)

# Callback system
self.on_init_end()
@@ -862,18 +808,6 @@ def call_hook(self, hook_name, *args, **kwargs):

return output


def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
if 0 <= batches <= 1:
return batches
elif batches > 1 and batches % 1.0 == 0:
return int(batches)
else:
raise MisconfigurationException(
f'You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int.'
)


# add docstrings
Trainer.__init__.__doc__ = docstrings.trainer.init
Trainer.fit.__doc__ = docstrings.trainer.fit
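Net effect of the trainer.py hunks: `Trainer.__init__` becomes a dispatcher. The epoch/step flags and `num_sanity_val_steps` go to `train_loop.on_init_start`, the batch-limit, overfit and `fast_dev_run` flags (including the deprecated `*_percent_check` and `overfit_pct` arguments) go to `debugging_connector.on_init_start`, `precision`, `amp_level` and `amp_backend` go to `precision_connector.on_trainer_init`, and the Colab/Kaggle detection moves to the accelerator connector in the first hunk of this PR. The module-level `_determine_batch_limits` helper is deleted here because it now lives in debugging_connector.py.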
11 changes: 11 additions & 0 deletions pytorch_lightning/trainer/training_loop.py
@@ -38,6 +38,17 @@ def __init__(self, trainer):
self._teardown_already_run = False
self.running_loss = TensorRunningAccum(window_length=20)

def on_init_start(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps):
self.trainer.max_epochs = max_epochs
self.trainer.min_epochs = min_epochs
self.trainer.max_steps = max_steps
self.trainer.min_steps = min_steps

if num_sanity_val_steps == -1:
self.trainer.num_sanity_val_steps = float('inf')
else:
self.trainer.num_sanity_val_steps = num_sanity_val_steps

@property
def num_optimizers(self):
num_optimizers = len(self.get_optimizers_iterable())
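The one behavioural subtlety carried into `TrainLoop.on_init_start` is the `-1` sentinel for `num_sanity_val_steps`. A standalone restatement of that branch, as a hypothetical helper for illustration only:

```python
def resolve_num_sanity_val_steps(num_sanity_val_steps: int) -> float:
    """Mirrors the branch in TrainLoop.on_init_start (illustrative helper, not part of the PR)."""
    # -1 means "sanity-check the entire validation set before training starts"
    return float('inf') if num_sanity_val_steps == -1 else num_sanity_val_steps


print(resolve_num_sanity_val_steps(-1))  # inf
print(resolve_num_sanity_val_steps(2))   # 2
```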