
extract optimizer loop #9191

Merged: 45 commits from refactor/optimizer-loop into master on Sep 2, 2021
Changes from 44 commits

Commits (45)

fea744f  wip (awaelchli, Jul 24, 2021)
81d3797  extract optimizer loop (awaelchli, Aug 27, 2021)
b39ddbf  handle restart (awaelchli, Aug 29, 2021)
97eabcf  update running loss (awaelchli, Aug 29, 2021)
69a29d9  add changelog (awaelchli, Aug 29, 2021)
bec5341  update tests (awaelchli, Aug 29, 2021)
58404aa  refactor block parallel sync behavior (awaelchli, Aug 29, 2021)
9226271  remove automatic opt specific logic (awaelchli, Aug 29, 2021)
d376c91  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 29, 2021)
38acf9a  fix circular import (awaelchli, Aug 30, 2021)
36c8692  fix swa tests (awaelchli, Aug 30, 2021)
5682a6b  fix state dict test (awaelchli, Aug 30, 2021)
0b59c37  add connect (awaelchli, Aug 30, 2021)
574c11e  fix reset (awaelchli, Aug 30, 2021)
ce4f08a  fix imports (awaelchli, Aug 30, 2021)
39fb458  add license (awaelchli, Aug 30, 2021)
9f1880e  fix test_loops.py (awaelchli, Aug 30, 2021)
d4dfd54  remove commented code (awaelchli, Aug 30, 2021)
0b68a11  add docstrings (awaelchli, Aug 30, 2021)
624394f  fix typing in constructor (awaelchli, Aug 30, 2021)
07fbafe  update hidden state management (awaelchli, Aug 30, 2021)
7572deb  extract build_kwargs method (awaelchli, Aug 30, 2021)
9d4f459  remove todo (awaelchli, Aug 30, 2021)
b4962ef  isort (awaelchli, Aug 30, 2021)
42cdc14  update init files (awaelchli, Aug 30, 2021)
fde9e9e  Merge branch 'master' into refactor/optimizer-loop (awaelchli, Aug 30, 2021)
bd128f0  fix loop state dict test (awaelchli, Aug 30, 2021)
a304ac5  fix tbtt tests (awaelchli, Aug 30, 2021)
eb00a4c  fix imports (awaelchli, Aug 30, 2021)
d346be2  no longer duplicated (awaelchli, Aug 30, 2021)
b0c997e  remove unused optimizer arguments for the manual opt path (awaelchli, Aug 30, 2021)
86743aa  update typehint (awaelchli, Aug 30, 2021)
e130da5  update docs (awaelchli, Aug 30, 2021)
586ead7  Merge branch 'master' into refactor/optimizer-loop (awaelchli, Aug 30, 2021)
6b123be  Merge branch 'master' into refactor/optimizer-loop (awaelchli, Aug 30, 2021)
fa4c788  remove unused argument (awaelchli, Aug 30, 2021)
9701c89  Merge branch 'master' into refactor/optimizer-loop (awaelchli, Aug 30, 2021)
ef05def  update typing (awaelchli, Aug 31, 2021)
87f3002  Merge branch 'master' into refactor/optimizer-loop (awaelchli, Aug 31, 2021)
73c60d3  remove redundant process closure result (awaelchli, Aug 31, 2021)
fc03c90  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 31, 2021)
6ca1933  add todo (awaelchli, Aug 31, 2021)
e5454ff  remove list copy for optimizers (awaelchli, Aug 31, 2021)
2fadaac  undo skip_backward changes in swa (awaelchli, Aug 31, 2021)
8475526  Merge branch 'master' into refactor/optimizer-loop (awaelchli, Sep 1, 2021)
CHANGELOG.md (2 changes: 1 addition & 1 deletion)

@@ -65,7 +65,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Loop customization:
     * Added `Closure` and `AbstractClosure` classes ([#8642](https://github.com/PyTorchLightning/pytorch-lightning/pull/8642))
-
+    * Refactored `TrainingBatchLoop` and extracted `OptimizerLoop`, splitting off automatic optimization into its own loop ([#9191](https://github.com/PyTorchLightning/pytorch-lightning/pull/9191))
 
 - Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187))
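For orientation, the attribute chains that appear in the diffs below (for example `self.epoch_loop.batch_loop.optimizer_loop`) imply the following loop nesting after this PR; a sketch for the reader, not an official diagram:

```python
# Loop nesting after this PR, as implied by the attribute paths in the diffs:
#
#   FitLoop
#     .epoch_loop         -> TrainingEpochLoop
#       .batch_loop       -> TrainingBatchLoop
#         .optimizer_loop -> OptimizerLoop   (extracted in this PR)
```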
pytorch_lightning/loops/__init__.py (1 change: 1 addition & 0 deletions)

@@ -17,3 +17,4 @@
 from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop  # noqa: F401
 from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop  # noqa: F401
 from pytorch_lightning.loops.fit_loop import FitLoop  # noqa: F401
+from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop  # noqa: F401
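With this re-export, `OptimizerLoop` becomes importable from the package root alongside the other loops. A minimal usage sketch, assuming a pytorch-lightning version that includes this PR:

```python
from pytorch_lightning.loops import OptimizerLoop

# The short path and the fully qualified path resolve to the same class.
from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop as _OptimizerLoop

assert OptimizerLoop is _OptimizerLoop
```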
pytorch_lightning/loops/batch/training_batch_loop.py (247 changes: 26 additions & 221 deletions)

Large diffs are not rendered by default.
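The 221 removed lines are the automatic-optimization logic that moved out of `TrainingBatchLoop` into the new `OptimizerLoop`. Since the diff is not rendered, here is a rough, self-contained sketch of the resulting delegation structure; class internals and method signatures are simplified assumptions for illustration, not Lightning's actual code:

```python
from typing import Any, List


class OptimizerLoop:
    """Owns automatic optimization for one batch: build the closure,
    run backward (unless skipped), and step each optimizer in turn."""

    def __init__(self) -> None:
        self._skip_backward = False  # toggled externally, e.g. via FitLoop

    def run(self, batch: Any, optimizers: List[Any]) -> List[Any]:
        outputs = []
        for opt_idx, optimizer in enumerate(optimizers):
            # In Lightning, a Closure wrapping training_step + backward
            # is built here and handed to the optimizer step.
            outputs.append(self._run_optimization(batch, optimizer, opt_idx))
        return outputs

    def _run_optimization(self, batch: Any, optimizer: Any, opt_idx: int) -> Any:
        ...  # closure construction, conditional backward, optimizer step


class TrainingBatchLoop:
    """After the refactor, the batch loop no longer inlines ~200 lines of
    optimizer handling; it delegates to its optimizer_loop child."""

    def __init__(self) -> None:
        self.optimizer_loop = OptimizerLoop()

    def run(self, batch: Any, optimizers: List[Any]) -> List[Any]:
        return self.optimizer_loop.run(batch, optimizers)
```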

pytorch_lightning/loops/epoch/training_epoch_loop.py (2 changes: 1 addition & 1 deletion)

@@ -95,7 +95,7 @@ def reset(self) -> None:
         if not self.restarting:
             self.batch_progress.current.reset()
             self.scheduler_progress.current.reset()
-            self.batch_loop.optim_progress.reset_on_epoch()
+            self.batch_loop.optimizer_loop.optim_progress.reset_on_epoch()
 
     def on_run_start(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
         # hook
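The progress tracker itself is unchanged; only its owner moved one level down, so external code that reaches into loop state needs one extra attribute hop. The paths below are assembled from the attribute chains in these diffs and should be read as a sketch:

```python
# Before the refactor, the tracker lived directly on the batch loop:
#   trainer.fit_loop.epoch_loop.batch_loop.optim_progress

# After the refactor, it lives on the extracted optimizer loop:
#   trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress
```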
pytorch_lightning/loops/fit_loop.py (4 changes: 2 additions & 2 deletions)

@@ -109,12 +109,12 @@ def running_loss(self) -> TensorRunningAccum:
     @property
     def _skip_backward(self) -> bool:
         """Determines whether the loop will skip backward during automatic optimization."""
-        return self.epoch_loop.batch_loop._skip_backward
+        return self.epoch_loop.batch_loop.optimizer_loop._skip_backward
 
     @_skip_backward.setter
     def _skip_backward(self, value: bool) -> None:
         """Determines whether the loop will skip backward during automatic optimization."""
-        self.epoch_loop.batch_loop._skip_backward = value
+        self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value
 
     @property
     def _results(self) -> ResultCollection:
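The `_skip_backward` flag now lives on `OptimizerLoop`, and `FitLoop` keeps a pass-through property so callers such as the SWA callback do not need to know about the deeper nesting. A minimal runnable sketch of this forwarding pattern (stand-in classes, not Lightning's implementation):

```python
class OptimizerLoop:
    """Stand-in for the extracted loop; the real flag lives here now."""

    def __init__(self) -> None:
        self._skip_backward = False


class TrainingBatchLoop:
    def __init__(self) -> None:
        self.optimizer_loop = OptimizerLoop()


class TrainingEpochLoop:
    def __init__(self) -> None:
        self.batch_loop = TrainingBatchLoop()


class FitLoop:
    def __init__(self) -> None:
        self.epoch_loop = TrainingEpochLoop()

    @property
    def _skip_backward(self) -> bool:
        # read through to the nested owner of the flag
        return self.epoch_loop.batch_loop.optimizer_loop._skip_backward

    @_skip_backward.setter
    def _skip_backward(self, value: bool) -> None:
        # write through the same path, keeping a single source of truth
        self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value


fit_loop = FitLoop()
fit_loop._skip_backward = True  # e.g. what a callback toggles temporarily
assert fit_loop.epoch_loop.batch_loop.optimizer_loop._skip_backward is True
```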
pytorch_lightning/loops/optimizer/__init__.py (15 changes: 15 additions & 0 deletions)

@@ -0,0 +1,15 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop  # noqa: F401