From 7926dafb506b1e0952fc7588213c1a1518cb7bed Mon Sep 17 00:00:00 2001 From: Yauheni Kachan <19803638+bagxi@users.noreply.github.com> Date: Sun, 28 Mar 2021 16:49:58 +0300 Subject: [PATCH] fix: `_key_value` for schedulers in case of multiple optimizers fixed --- CHANGELOG.md | 1 + catalyst/runners/config.py | 4 ++-- catalyst/runners/hydra.py | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d599875e2..3307fe18d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added gradient clipping function to optimizer callback ([1124](https://github.com/catalyst-team/catalyst/pull/1124)) - FactorizedLinear to contrib ([1142](https://github.com/catalyst-team/catalyst/pull/1142)) - Extra init params for ``ConsoleLogger`` ([1142](https://github.com/catalyst-team/catalyst/pull/1142)) +- `_key_value` for schedulers in case of multiple optimizers fixed ([#1146](https://github.com/catalyst-team/catalyst/pull/1146)) ### Changed diff --git a/catalyst/runners/config.py b/catalyst/runners/config.py index 0bfe352a52..5a3f94f8f1 100644 --- a/catalyst/runners/config.py +++ b/catalyst/runners/config.py @@ -310,9 +310,9 @@ def _get_scheduler_from_params(*, optimizer: RunnerOptimizer, **params) -> Runne for key, scheduler_params in params.items(): scheduler_params = deepcopy(scheduler_params) optimizer_key = scheduler_params.pop("_optimizer", None) - optimizer = optimizer[optimizer_key] if optimizer_key else optimizer + optim = optimizer[optimizer_key] if optimizer_key else optimizer scheduler[key] = ConfigRunner._get_scheduler_from_params( - **scheduler_params, optimizer=optimizer + **scheduler_params, optimizer=optim ) # noqa: WPS437 else: optimizer_key = params.pop("_optimizer", None) diff --git a/catalyst/runners/hydra.py b/catalyst/runners/hydra.py index 96e7817016..7fc0610add 100644 --- a/catalyst/runners/hydra.py +++ b/catalyst/runners/hydra.py @@ -311,9 +311,9 @@ def _get_scheduler_from_params( for key, scheduler_params in params.items(): scheduler_params = deepcopy(scheduler_params) optimizer_key = scheduler_params._optimizer or None - optimizer = optimizer[optimizer_key] if optimizer_key else optimizer + optim = optimizer[optimizer_key] if optimizer_key else optimizer scheduler[key] = HydraRunner._get_scheduler_from_params( # noqa: WPS437 - optimizer=optimizer, params=scheduler_params + optimizer=optim, params=scheduler_params ) else: optimizer_key = params._optimizer or None