From 516ba84b938032a0e4ad593c3054412faec1e467 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <jirka.borovec@seznam.cz> Date: Mon, 1 Mar 2021 14:24:34 +0100 Subject: [PATCH 1/6] rename --- pytorch_lightning/callbacks/__init__.py | 2 +- .../callbacks/{swa.py => stochastic_weight_avg.py} | 0 pytorch_lightning/trainer/connectors/callback_connector.py | 2 +- tests/callbacks/{test_swa.py => test_stochastic_weight_avg.py} | 0 4 files changed, 2 insertions(+), 2 deletions(-) rename pytorch_lightning/callbacks/{swa.py => stochastic_weight_avg.py} (100%) rename tests/callbacks/{test_swa.py => test_stochastic_weight_avg.py} (100%) diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py index f3787c1cb2f7f7..fb61ad81aee283 100644 --- a/pytorch_lightning/callbacks/__init__.py +++ b/pytorch_lightning/callbacks/__init__.py @@ -22,7 +22,7 @@ from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase from pytorch_lightning.callbacks.pruning import ModelPruning from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining -from pytorch_lightning.callbacks.swa import StochasticWeightAveraging +from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging __all__ = [ 'BackboneFinetuning', diff --git a/pytorch_lightning/callbacks/swa.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py similarity index 100% rename from pytorch_lightning/callbacks/swa.py rename to pytorch_lightning/callbacks/stochastic_weight_avg.py diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 40ac8f3e698708..8a5289e608c945 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -76,7 +76,7 @@ def _configure_swa_callbacks(self): if not self.trainer._stochastic_weight_avg: return - from pytorch_lightning.callbacks.swa import StochasticWeightAveraging + from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging existing_swa = [cb for cb in self.trainer.callbacks if isinstance(cb, StochasticWeightAveraging)] if not existing_swa: self.trainer.callbacks = [StochasticWeightAveraging()] + self.trainer.callbacks diff --git a/tests/callbacks/test_swa.py b/tests/callbacks/test_stochastic_weight_avg.py similarity index 100% rename from tests/callbacks/test_swa.py rename to tests/callbacks/test_stochastic_weight_avg.py From 7439e769048649d4748226f805d54863389ac382 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <jirka.borovec@seznam.cz> Date: Mon, 1 Mar 2021 14:27:45 +0100 Subject: [PATCH 2/6] if --- pytorch_lightning/callbacks/stochastic_weight_avg.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index c8cf367cb4d5e9..d639c67120e96f 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -102,12 +102,10 @@ def __init__( if isinstance(swa_epoch_start, float) and not (0 <= swa_epoch_start <= 1): raise MisconfigurationException(err_msg) - if ( - swa_lrs is not None and ( - not isinstance(swa_lrs, (float, list)) or isinstance(swa_lrs, float) and swa_lrs <= 0 - or isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs) - ) - ): + _wrong_type = not isinstance(swa_lrs, (float, list)) + _wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0 + _wrong_list = isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs) + if (swa_lrs is not None and (_wrong_type or _wrong_float or _wrong_list)): raise MisconfigurationException("The `swa_lrs` should be a positive float or a list of positive float.") if avg_fn is not None and not isinstance(avg_fn, Callable): From c39ef97e3f2f4c3bca39a226e260ff7dbdcc5c5f Mon Sep 17 00:00:00 2001 From: Jirka Borovec <jirka.borovec@seznam.cz> Date: Mon, 1 Mar 2021 14:33:35 +0100 Subject: [PATCH 3/6] test --- pytorch_lightning/callbacks/swa.py | 8 ++++++++ tests/deprecated_api/test_remove_1-5.py | 7 +++++++ 2 files changed, 15 insertions(+) create mode 100644 pytorch_lightning/callbacks/swa.py diff --git a/pytorch_lightning/callbacks/swa.py b/pytorch_lightning/callbacks/swa.py new file mode 100644 index 00000000000000..5d72fae5a6a7d8 --- /dev/null +++ b/pytorch_lightning/callbacks/swa.py @@ -0,0 +1,8 @@ +from warnings import warn + +warn( + "`swa` package has been renamed to `stochastic_weight_avg` since v1.3 and will be removed in v1.5", + DeprecationWarning +) + +from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging # noqa: F401 E402 diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index cb1d4614146038..f17a2af67e911c 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -19,10 +19,17 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.loggers import WandbLogger +from tests.deprecated_api import _soft_unimport_module from tests.helpers import BoringModel from tests.helpers.utils import no_warning_call +def test_v1_4_0_deprecated_imports(): + _soft_unimport_module('pytorch_lightning.callbacks.swa') + with pytest.deprecated_call(match='will be removed in v1.4'): + from pytorch_lightning.callbacks.swa import StochasticWeightAveraging # noqa: F811 F401 + + @mock.patch('pytorch_lightning.loggers.wandb.wandb') def test_v1_5_0_wandb_unused_sync_step(tmpdir): with pytest.deprecated_call(match=r"v1.2.1 and will be removed in v1.5"): From 8b4231167bcb6a76d562f495e5c956458ffbc23d Mon Sep 17 00:00:00 2001 From: Jirka Borovec <jirka.borovec@seznam.cz> Date: Mon, 1 Mar 2021 14:35:42 +0100 Subject: [PATCH 4/6] chlog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 06a91bf973790b..b3edb21a7104ee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed the order of `backward`, `step`, `zero_grad` to `zero_grad`, `backward`, `step` ([#6147](https://github.com/PyTorchLightning/pytorch-lightning/pull/6147)) +- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259)) + + ### Deprecated From ced1052797c799b6a0388ca59521b89bf268004a Mon Sep 17 00:00:00 2001 From: Jirka Borovec <jirka.borovec@seznam.cz> Date: Mon, 1 Mar 2021 14:50:36 +0100 Subject: [PATCH 5/6] drop swa --- pytorch_lightning/callbacks/swa.py | 8 -------- tests/deprecated_api/test_remove_1-5.py | 7 ------- 2 files changed, 15 deletions(-) delete mode 100644 pytorch_lightning/callbacks/swa.py diff --git a/pytorch_lightning/callbacks/swa.py b/pytorch_lightning/callbacks/swa.py deleted file mode 100644 index 5d72fae5a6a7d8..00000000000000 --- a/pytorch_lightning/callbacks/swa.py +++ /dev/null @@ -1,8 +0,0 @@ -from warnings import warn - -warn( - "`swa` package has been renamed to `stochastic_weight_avg` since v1.3 and will be removed in v1.5", - DeprecationWarning -) - -from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging # noqa: F401 E402 diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index f17a2af67e911c..cb1d4614146038 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -19,17 +19,10 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.loggers import WandbLogger -from tests.deprecated_api import _soft_unimport_module from tests.helpers import BoringModel from tests.helpers.utils import no_warning_call -def test_v1_4_0_deprecated_imports(): - _soft_unimport_module('pytorch_lightning.callbacks.swa') - with pytest.deprecated_call(match='will be removed in v1.4'): - from pytorch_lightning.callbacks.swa import StochasticWeightAveraging # noqa: F811 F401 - - @mock.patch('pytorch_lightning.loggers.wandb.wandb') def test_v1_5_0_wandb_unused_sync_step(tmpdir): with pytest.deprecated_call(match=r"v1.2.1 and will be removed in v1.5"): From 421cfa89d24e855f806f01dd2fef0721664d3db4 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <Borda@users.noreply.github.com> Date: Mon, 1 Mar 2021 17:16:30 +0100 Subject: [PATCH 6/6] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ <carlossmocholi@gmail.com> --- pytorch_lightning/callbacks/stochastic_weight_avg.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/stochastic_weight_avg.py b/pytorch_lightning/callbacks/stochastic_weight_avg.py index d639c67120e96f..bece2ffe9f1b21 100644 --- a/pytorch_lightning/callbacks/stochastic_weight_avg.py +++ b/pytorch_lightning/callbacks/stochastic_weight_avg.py @@ -102,10 +102,10 @@ def __init__( if isinstance(swa_epoch_start, float) and not (0 <= swa_epoch_start <= 1): raise MisconfigurationException(err_msg) - _wrong_type = not isinstance(swa_lrs, (float, list)) - _wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0 - _wrong_list = isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs) - if (swa_lrs is not None and (_wrong_type or _wrong_float or _wrong_list)): + wrong_type = not isinstance(swa_lrs, (float, list)) + wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0 + wrong_list = isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs) + if (swa_lrs is not None and (wrong_type or wrong_float or wrong_list)): raise MisconfigurationException("The `swa_lrs` should be a positive float or a list of positive float.") if avg_fn is not None and not isinstance(avg_fn, Callable):