Connect progress tracking dataclasses to loops #8244

Merged · 23 commits · Jul 5, 2021 (changes shown from 21 commits)

Commits:
082068a  Connect progress dataclasses (carmocca, Jul 1, 2021)
71f2dc2  Update CHANGELOG (carmocca, Jul 1, 2021)
62ea5a2  Remove re-definition (carmocca, Jul 1, 2021)
8b053be  Reorder and undo name change (carmocca, Jul 1, 2021)
52bf1ec  whitespace (carmocca, Jul 1, 2021)
d45fc92  Move result teardown to loops (carmocca, Jul 1, 2021)
2519407  Update CHANGELOG (carmocca, Jul 1, 2021)
ddf19f0  Proper result acesss (carmocca, Jul 1, 2021)
ccaa232  Remove teardown from run (carmocca, Jul 1, 2021)
a81b4df  Move previous teardown to on_run_end (carmocca, Jul 1, 2021)
9ea2078  Add comment (carmocca, Jul 1, 2021)
956d42f  Merge branch 'master' into refactor/move-result-teardown-to-loops (carmocca, Jul 1, 2021)
9030119  Merge 8250 (carmocca, Jul 1, 2021)
2befea7  Merge branch 'refactor/move-result-teardown-to-loops' into feat/conne… (carmocca, Jul 1, 2021)
d75b9e2  Remove stage set to None where it shouldnt (carmocca, Jul 2, 2021)
cc85747  Merge branch 'refactor/move-result-teardown-to-loops' into feat/conne… (carmocca, Jul 2, 2021)
011a0a4  Merge branch 'master' into feat/connect-progress-dataclasses (carmocca, Jul 2, 2021)
b807f47  Remove extra getter/setter (carmocca, Jul 2, 2021)
d9ef679  trigger ci (carmocca, Jul 2, 2021)
954a0e1  Test deepcopy for progress tracking dataclasses (carmocca, Jul 2, 2021)
b6bbbcc  Merge branch 'bugfix/progress-tracking-deepcopy' into feat/connect-pr… (carmocca, Jul 2, 2021)
c988bd2  Update pytorch_lightning/loops/batch/training_batch_loop.py (carmocca, Jul 3, 2021)
7ceed96  Merge branch 'master' into feat/connect-progress-dataclasses (carmocca, Jul 5, 2021)

1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -33,6 +33,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Progress tracking
* Added dataclasses for progress tracking ([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603), [#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574), [#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140))
* Add `{,load_}state_dict` to the progress tracking dataclasses ([#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140))
* Connect the progress tracking dataclasses to the loops ([#8244](https://github.com/PyTorchLightning/pytorch-lightning/pull/8244))


- Added support for passing a `LightningDataModule` positionally as the second argument to `trainer.{validate,test,predict}` ([#7431](https://github.com/PyTorchLightning/pytorch-lightning/pull/7431))

10 changes: 1 addition & 9 deletions pytorch_lightning/loops/base.py
@@ -46,15 +46,7 @@ class Loop(ABC):
def __init__(self) -> None:
self.iteration_count: int = 0
self.trainer: Optional['pl.Trainer'] = None
self._restarting = False

@property
def restarting(self) -> bool:
return self._restarting

@restarting.setter
def restarting(self, restarting: bool) -> None:
self._restarting = restarting
self.restarting = False

@property
@abstractmethod

21 changes: 20 additions & 1 deletion pytorch_lightning/loops/batch/training_batch_loop.py
@@ -23,10 +23,12 @@
from torch import Tensor
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.plugins import ParallelPlugin
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import BatchProgress, OptimizationProgress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -48,13 +50,30 @@ def __init__(self) -> None:
self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=20)
self.batch_idx: int = 0
self.split_idx: Optional[int] = None
self._warning_cache: WarningCache = WarningCache()
self.progress = BatchProgress()
self.optim_progress = OptimizationProgress()

self._warning_cache: WarningCache = WarningCache()
self._hiddens: Optional[Tensor] = None
self._optimizer_freq_cumsum: Optional[int] = None
self._remaining_splits: Optional[List[Any]] = None
self._skip_backward: bool = False

def connect(
self,
trainer: 'pl.Trainer',
*args: Any,
progress: Optional[BatchProgress] = None,
optim_progress: Optional[OptimizationProgress] = None,
**kwargs: Any
) -> None:
"""Connects the loop with necessary arguments like the trainer"""
super().connect(trainer, *args, **kwargs)
if progress is not None:
self.progress = progress
if optim_progress:
self.optim_progress = optim_progress

@property
def done(self) -> bool:
"""Returns if all batch splits have been processed already"""

14 changes: 10 additions & 4 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -21,6 +21,7 @@
from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import EpochLoopProgress
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
@@ -31,12 +32,13 @@ class EvaluationLoop(DataLoaderLoop):

def __init__(self):
super().__init__()
self._max_batches: Optional[Union[int, Sequence[int]]] = None
self.outputs = []
self.progress = EpochLoopProgress()

self.epoch_loop = EvaluationEpochLoop()

self._results = ResultCollection(training=False)
self._max_batches: Optional[Union[int, Sequence[int]]] = None
self._has_run: bool = False

@property
@@ -64,10 +66,14 @@ def predictions(self):
"""Returns the predictions from all dataloaders"""
return self.epoch_loop.predictions

def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
"""Connects the loop to everything necessary (like trainer and accelerators)"""
def connect(
self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochLoopProgress] = None, **kwargs: Any
) -> None:
"""Connects the loop with necessary arguments like the trainer"""
super().connect(trainer, *args, **kwargs)
self.epoch_loop.connect(trainer)
if progress is not None:
self.progress = progress
self.epoch_loop.connect(trainer, progress=self.progress.epoch)

@property
def done(self) -> bool:

14 changes: 10 additions & 4 deletions pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -7,6 +7,7 @@
from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop
from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop
from pytorch_lightning.plugins import DDPSpawnPlugin
from pytorch_lightning.trainer.progress import EpochLoopProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _PREDICT_OUTPUT

@@ -18,8 +19,9 @@ def __init__(self):
super().__init__()
self.predictions: Optional[List[List[Any]]] = None
self.epoch_batch_indices: Optional[List[List[int]]] = None
self.progress = EpochLoopProgress()

self.epoch_loop: PredictionEpochLoop = PredictionEpochLoop()
self.epoch_loop = PredictionEpochLoop()

self._results = None # for `trainer._results` access
self._return_predictions: bool = False
@@ -74,10 +76,14 @@ def done(self) -> bool:
def skip(self) -> bool:
return sum(self.max_batches) == 0

def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
"""Connects the loop with all necessary things (like trainer)"""
def connect(
self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochLoopProgress] = None, **kwargs: Any
) -> None:
"""Connects the loop with necessary arguments like the trainer"""
super().connect(trainer, *args, **kwargs)
self.epoch_loop.connect(trainer, *args, **kwargs)
if progress is not None:
self.progress = progress
self.epoch_loop.connect(trainer, progress=self.progress.epoch)

def reset(self) -> None:
"""Resets the internal state of the loop for a new run"""

11 changes: 11 additions & 0 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -18,8 +18,10 @@
from deprecate import void
from torch import Tensor

import pytorch_lightning as pl
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import EpochProgress
from pytorch_lightning.trainer.supporters import PredictionCollection
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -39,6 +41,15 @@ def __init__(self) -> None:
self.dataloader_idx: Optional[int] = None
self.num_dataloaders: Optional[int] = None
self.outputs: List[STEP_OUTPUT] = []
self.progress = EpochProgress()

def connect(
self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochProgress] = None, **kwargs: Any
) -> None:
"""Connects the loop with necessary arguments like the trainer"""
super().connect(trainer, *args, **kwargs)
if progress is not None:
self.progress = progress

@property
def done(self) -> bool:

12 changes: 12 additions & 0 deletions pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -3,8 +3,10 @@

from deprecate import void

import pytorch_lightning as pl
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.trainer.progress import EpochProgress
from pytorch_lightning.utilities.warnings import WarningCache


@@ -16,11 +18,21 @@ def __init__(self) -> None:
self.return_predictions: bool = False
self.predictions: List[Any] = []
self.current_batch_indices: List[int] = []
self.progress = EpochProgress()

self._dl_max_batches: Optional[int] = None
self._num_dataloaders: Optional[int] = None
self._warning_cache = WarningCache()
self._all_batch_indices: List[int] = []

def connect(
self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochProgress] = None, **kwargs: Any
) -> None:
"""Connects the loop with necessary arguments like the trainer"""
super().connect(trainer, *args, **kwargs)
if progress is not None:
self.progress = progress

@property
def done(self) -> bool:
"""Ends prediction when the iteration count exceeds the total number of available batches"""

18 changes: 14 additions & 4 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -20,6 +20,7 @@
from pytorch_lightning import loops # import as loops to avoid circular imports
from pytorch_lightning.loops.batch import TrainingBatchLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import TrainingEpochProgress
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
@@ -44,6 +45,7 @@ def __init__(self, min_steps: int, max_steps: int):
# the number of batches seen this run, updates immediately after batch_loop.run()
self.batches_seen: int = 0
self.is_last_batch: Optional[bool] = None
self.progress = TrainingEpochProgress()

self.batch_loop = TrainingBatchLoop()
self.val_loop = loops.EvaluationLoop()
@@ -67,11 +69,19 @@ def done(self) -> bool:
max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps
return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch)

def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
"""Connects the loop with all necessary parts like trainer and accelerators"""
def connect(
self,
trainer: 'pl.Trainer',
*args: Any,
progress: Optional[TrainingEpochProgress] = None,
**kwargs: Any
) -> None:
"""Connects the loop with necessary arguments like the trainer"""
super().connect(trainer, *args, **kwargs)
self.batch_loop.connect(trainer)
self.val_loop.connect(trainer)
if progress is not None:
self.progress = progress
self.batch_loop.connect(trainer, progress=self.progress.batch, optim_progress=self.progress.optim)
self.val_loop.connect(trainer, progress=self.progress.val)

def reset(self) -> None:
"""Resets the internal state of the loop for a new run"""

10 changes: 8 additions & 2 deletions pytorch_lightning/loops/fit_loop.py
@@ -20,6 +20,7 @@
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.epoch import TrainingEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import FitLoopProgress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import rank_zero_info

@@ -50,6 +51,7 @@ def __init__(
super().__init__()
self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs
self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs
self.progress = FitLoopProgress()

self.epoch_loop = TrainingEpochLoop(min_steps, max_steps)

@@ -167,10 +169,14 @@ def skip(self) -> bool:
"""Whether we should skip the training and immediately return from the call to :meth:`run`."""
return self.done or self.trainer.num_training_batches == 0

def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
def connect(
self, trainer: 'pl.Trainer', *args: Any, progress: Optional[FitLoopProgress] = None, **kwargs: Any
) -> None:
"""Connects the loop with necessary arguments like the trainer"""
super().connect(trainer, *args, **kwargs)
self.epoch_loop.connect(trainer)
if progress is not None:
self.progress = progress
self.epoch_loop.connect(trainer, progress=self.progress.epoch)

def reset(self) -> None:
"""Resets the internal state of this loop"""

47 changes: 22 additions & 25 deletions pytorch_lightning/trainer/progress.py
@@ -18,19 +18,16 @@
@dataclass
class _DataclassStateDictMixin:

def __getstate__(self) -> dict:
def state_dict(self) -> dict:
return asdict(self)

def __setstate__(self, state: dict) -> None:
self.__dict__.update(state)

def state_dict(self) -> dict:
return self.__getstate__()
def load_state_dict(self, state_dict: dict) -> None:
self.__dict__.update(state_dict)

@classmethod
def from_state_dict(cls, state_dict: dict) -> "_DataclassStateDictMixin":
obj = cls()
obj.__setstate__(state_dict)
obj.load_state_dict(state_dict)
return obj


@@ -115,9 +112,9 @@ def increment_completed(self) -> None:
def from_defaults(cls, **kwargs: Optional[int]) -> "Progress":
return cls(total=Tracker(**kwargs), current=Tracker(**kwargs))

def __setstate__(self, state: dict) -> None:
self.total.__setstate__(state["total"])
self.current.__setstate__(state["current"])
def load_state_dict(self, state_dict: dict) -> None:
self.total.load_state_dict(state_dict["total"])
self.current.load_state_dict(state_dict["current"])


class BatchProgress(Progress):
@@ -147,9 +144,9 @@ class EpochProgress(Progress):
def reset_on_epoch(self) -> None:
self.batch.current.reset()

def __setstate__(self, state: dict) -> None:
super().__setstate__(state)
self.batch.__setstate__(state["batch"])
def load_state_dict(self, state_dict: dict) -> None:
super().load_state_dict(state_dict)
self.batch.load_state_dict(state_dict["batch"])


@dataclass
Expand All @@ -169,9 +166,9 @@ def reset_on_epoch(self) -> None:
self.step.current.reset()
self.zero_grad.current.reset()

def __setstate__(self, state: dict) -> None:
self.step.__setstate__(state["step"])
self.zero_grad.__setstate__(state["zero_grad"])
def load_state_dict(self, state_dict: dict) -> None:
self.step.load_state_dict(state_dict["step"])
self.zero_grad.load_state_dict(state_dict["zero_grad"])


@dataclass
@@ -200,9 +197,9 @@ def reset_on_epoch(self) -> None:
self.optimizer.reset_on_epoch()
self.scheduler.current.reset()

def __setstate__(self, state: dict) -> None:
self.optimizer.__setstate__(state["optimizer"])
self.scheduler.__setstate__(state["scheduler"])
def load_state_dict(self, state_dict: dict) -> None:
self.optimizer.load_state_dict(state_dict["optimizer"])
self.scheduler.load_state_dict(state_dict["scheduler"])


@dataclass
@@ -225,8 +222,8 @@ def reset_on_epoch(self) -> None:
self.epoch.reset_on_epoch()
self.epoch.current.reset()

def __setstate__(self, state: dict) -> None:
self.epoch.__setstate__(state["epoch"])
def load_state_dict(self, state_dict: dict) -> None:
self.epoch.load_state_dict(state_dict["epoch"])


@dataclass
@@ -245,10 +242,10 @@ class TrainingEpochProgress(EpochProgress):
optim: OptimizationProgress = field(default_factory=OptimizationProgress)
val: EpochLoopProgress = field(default_factory=EpochLoopProgress)

def __setstate__(self, state: dict) -> None:
super().__setstate__(state)
self.optim.__setstate__(state["optim"])
self.val.__setstate__(state["val"])
def load_state_dict(self, state_dict: dict) -> None:
super().load_state_dict(state_dict)
self.optim.load_state_dict(state_dict["optim"])
self.val.load_state_dict(state_dict["val"])
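
The change to `_DataclassStateDictMixin` in this file replaces the `__getstate__`/`__setstate__` pickle hooks with an explicit `state_dict`/`load_state_dict` API. A small round-trip sketch (illustrative only, not part of the diff; it assumes `Progress.increment_completed` bumps the completed counters as in the surrounding class):

```python
from copy import deepcopy

from pytorch_lightning.trainer.progress import OptimizationProgress

progress = OptimizationProgress()
progress.optimizer.step.increment_completed()  # pretend one optimizer step finished

# `state_dict` returns plain nested dicts (via `asdict`), suitable for a checkpoint.
state = progress.state_dict()

# `from_state_dict` builds a fresh instance and loads the nested state into it.
restored = OptimizationProgress.from_state_dict(deepcopy(state))
assert restored.state_dict() == state

# The dataclasses also stay deepcopy-able (covered by the dedicated test commit).
assert deepcopy(progress).state_dict() == state
```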

