Skip to content

Commit

Permalink
Merge 421cfa8 into 352e8f0
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Mar 1, 2021
2 parents 352e8f0 + 421cfa8 commit b4ac140
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 8 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed the order of `backward`, `step`, `zero_grad` to `zero_grad`, `backward`, `step` ([#6147](https://github.com/PyTorchLightning/pytorch-lightning/pull/6147))


- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))


### Deprecated


Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks.pruning import ModelPruning
from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining
from pytorch_lightning.callbacks.swa import StochasticWeightAveraging
from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging

__all__ = [
'BackboneFinetuning',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,10 @@ def __init__(
if isinstance(swa_epoch_start, float) and not (0 <= swa_epoch_start <= 1):
raise MisconfigurationException(err_msg)

if (
swa_lrs is not None and (
not isinstance(swa_lrs, (float, list)) or isinstance(swa_lrs, float) and swa_lrs <= 0
or isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)
)
):
wrong_type = not isinstance(swa_lrs, (float, list))
wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0
wrong_list = isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)
if (swa_lrs is not None and (wrong_type or wrong_float or wrong_list)):
raise MisconfigurationException("The `swa_lrs` should be a positive float or a list of positive float.")

if avg_fn is not None and not isinstance(avg_fn, Callable):
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/connectors/callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _configure_swa_callbacks(self):
if not self.trainer._stochastic_weight_avg:
return

from pytorch_lightning.callbacks.swa import StochasticWeightAveraging
from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging
existing_swa = [cb for cb in self.trainer.callbacks if isinstance(cb, StochasticWeightAveraging)]
if not existing_swa:
self.trainer.callbacks = [StochasticWeightAveraging()] + self.trainer.callbacks
Expand Down
File renamed without changes.

0 comments on commit b4ac140

Please sign in to comment.