From 4ddc0b114d3d19c9f8c3e22cb72a37515535babf Mon Sep 17 00:00:00 2001 From: Sergey Kolesnikov Date: Sat, 30 Oct 2021 08:46:28 +0300 Subject: [PATCH 1/3] fix --- catalyst/callbacks/control_flow.py | 236 +++++++++++++---------------- 1 file changed, 103 insertions(+), 133 deletions(-) diff --git a/catalyst/callbacks/control_flow.py b/catalyst/callbacks/control_flow.py index a5ebe31d46..f3f4848660 100644 --- a/catalyst/callbacks/control_flow.py +++ b/catalyst/callbacks/control_flow.py @@ -10,139 +10,108 @@ from catalyst.core.runner import IRunner -def _filter_fn_from_epochs( - epochs: Union[int, float, Sequence[int]], reverse_condition: bool -) -> FILTER_FN: - """Build ``filter_fn`` from epochs for ``ControlFlowCallback`` - - Args: - epochs: epochs description - reverse_condition: indicator to use reversed - condition in filter function - - Raises: - ValueError: if passed object with unexpected type - - Returns: - filter function which accepts 3 arguments - stage (str), - epoch (int), loader (str) and return ``True`` if - need to disable callback - """ - if isinstance(epochs, (int, float)): - epochs = int(epochs) - if reverse_condition: - filter_fn = lambda stage, epoch, loader: epoch % epochs != 0 - else: - filter_fn = lambda stage, epoch, loader: epoch % epochs == 0 - elif isinstance(epochs, (list, tuple)): - epochs = sorted(set(epochs)) - if reverse_condition: - filter_fn = lambda stage, epoch, loader: epoch not in epochs - else: - filter_fn = lambda stage, epoch, loader: epoch in epochs - else: - raise ValueError("'epochs' should be int/float/Sequence[int]! " f"(got {type(epochs)})") - return filter_fn - - -def _filter_fn_from_loaders(loaders: LOADERS, reverse_condition: bool) -> FILTER_FN: - """Build ``filter_fn`` from loaders for ``ControlFlowCallback``. - - Args: - loaders (str/Sequence[str]/Mapping[str, int/Sequence[str]]): - loaders description - reverse_condition: indicator to use reversed - condition in filter function - - Raises: - ValueError: if can't build filter_fn from mappings - ValueError: if passed object with unexpected type - - Returns: - filter function which accepts 3 arguments - stage (str), - epoch (int), loader (str) and return ``True`` if - need to disable callback - """ - if isinstance(loaders, str): - loaders = [loaders] - - # sequence of loaders - if isinstance(loaders, (list, tuple)): - loaders = sorted(set(loaders)) # ignore duplicates - if reverse_condition: - filter_fn = lambda stage, epoch, loader: loader not in loaders - else: - filter_fn = lambda stage, epoch, loader: loader in loaders - # loader: ignore epoch or epochs - elif isinstance(loaders, (dict, OrderedDict)): - ignore_list = {} - for loader, epochs in loaders.items(): - if isinstance(epochs, (int, float)): - ignore_list[loader] = [int(epochs)] - else: - try: - ignore_list[loader] = [] - for num in sorted(set(epochs)): - to_add = int(num) - ignore_list[loader].append(to_add) - except (ValueError, TypeError): - raise ValueError( - "'ignore_list' should be a dict where " - "keys is a int/float/List[int]/Tuple[int]!" - ) - if reverse_condition: - filter_fn = lambda stage, epoch, loader: epoch not in ( - ignore_list.get(loader) or {} # {loader: [epoch]}.get(loader) +class _EpochFilterFn: + def __init__(self, epochs: Union[int, float, Sequence[int]], reverse_condition: bool): + if not isinstance(epochs, (int, float, list, tuple)): + raise ValueError( + "'epochs' should be int/float/Sequence[int]! " f"(got {type(epochs)})" ) - else: - filter_fn = lambda stage, epoch, loader: epoch in (ignore_list.get(loader) or {}) - else: - raise ValueError( - "'loaders' type should be one of - str, " - "Sequence[str], Mapping[str, int] or " - "Mapping[str, Sequence[int]]! " - f"(got {type(loaders)})" - ) - return filter_fn - + self.epochs = epochs + self.reverse_condition = reverse_condition + + # extra conditions precomputing + if isinstance(self.epochs, (int, float)): + self.epochs = int(self.epochs) + elif isinstance(self.epochs, (list, tuple)): + self.epochs = sorted(set(self.epochs)) + + def __call__(self, stage, epoch, loader): + if isinstance(self.epochs, (int, float)): + if self.reverse_condition: + return epoch % self.epochs != 0 + else: + return epoch % self.epochs == 0 + elif isinstance(self.epochs, (list, tuple)): + if self.reverse_condition: + return epoch not in self.epochs + else: + return epoch in self.epochs -def _filter_fn_from_arg(filter_fn: Union[str, FILTER_FN]) -> FILTER_FN: - """Check if filter function from argumets - can be used with ``ControlFlowCallback``. - Args: - filter_fn (str or Callable): filter function to check - - Raises: - ValueError: if ``filter_fn`` is a string and can not be - interpreted as python code then an error will be raised - ValueError: if passed not callable object then will be - raised an error - ValueError: will be raised error if filter function do not - have three arguments - - Returns: - filter function which accepts 3 arguments - stage (str), - epoch (int), loader (str) and return ``True`` if - need to disable callback - """ - if isinstance(filter_fn, str): - # lambda function from string - try: - filter_fn = eval(filter_fn) - except (ValueError, SyntaxError): +class _LoaderFilterFn: + def __init__(self, loaders: LOADERS, reverse_condition: bool): + if not isinstance(loaders, (list, tuple, dict, OrderedDict)): raise ValueError( - "'filter_fn' should be a valid " - "python lambda function with " - "three arguments - 'stage', 'epoch' and 'loader'!" + "'loaders' type should be one of - str, " + "Sequence[str], Mapping[str, int] or " + "Mapping[str, Sequence[int]]! " + f"(got {type(loaders)})" ) - if not callable(filter_fn): - raise ValueError("'filter_fn' should be a callable!") - if filter_fn.__code__.co_argcount != 3: - raise ValueError( - "Filter function should have three arguments - " "'stage', 'epoch' and 'loader'!" - ) - return filter_fn + if isinstance(loaders, str): + loaders = [loaders] + self.loaders = loaders + self.reverse_condition = reverse_condition + + # extra conditions precomputing + if isinstance(self.loaders, (list, tuple)): + self.loaders = sorted(set(self.loaders)) # ignore duplicates + elif isinstance(self.loaders, (dict, OrderedDict)): + ignore_list = {} + for loader, epochs in self.loaders.items(): + if isinstance(epochs, (int, float)): + ignore_list[loader] = [int(epochs)] + else: + try: + ignore_list[loader] = [] + for num in sorted(set(epochs)): + to_add = int(num) + ignore_list[loader].append(to_add) + except (ValueError, TypeError): + raise ValueError( + "'ignore_list' should be a dict where " + "keys is a int/float/List[int]/Tuple[int]!" + ) + self._ignore_list = ignore_list + + def __call__(self, stage, epoch, loader): + # sequence of loaders + if isinstance(self.loaders, (list, tuple)): + if self.reverse_condition: + return loader not in self.loaders + else: + return loader in self.loaders + # loader: ignore epoch or epochs + elif isinstance(self.loaders, (dict, OrderedDict)): + if self.reverse_condition: + return epoch not in ( + self._ignore_list.get(loader) or {} # {loader: [epoch]}.get(loader) + ) + else: + return epoch in (self._ignore_list.get(loader) or {}) + + +class _ArgsFilterFn: + def __init__(self, filter_fn: Union[str, FILTER_FN]): + if isinstance(filter_fn, str): + # lambda function from string + try: + filter_fn = eval(filter_fn) + except (ValueError, SyntaxError): + raise ValueError( + "'filter_fn' should be a valid " + "python lambda function with " + "three arguments - 'stage', 'epoch' and 'loader'!" + ) + if not callable(filter_fn): + raise ValueError("'filter_fn' should be a callable!") + if filter_fn.__code__.co_argcount != 3: + raise ValueError( + "Filter function should have three arguments - " "'stage', 'epoch' and 'loader'!" + ) + self.filter_fn = filter_fn + + def __call__(self, stage, epoch, loader): + return self.filter_fn(stage, epoch, loader) class ControlFlowCallback(CallbackWrapper): @@ -352,16 +321,17 @@ def __init__( # loader parameters self.filter_fn = None + # due to ddp-setup, we have to wrap everything with classes if epochs is not None: - self.filter_fn = _filter_fn_from_epochs(epochs, False) + self.filter_fn = _EpochFilterFn(epochs, False) elif ignore_epochs is not None: - self.filter_fn = _filter_fn_from_epochs(ignore_epochs, True) + self.filter_fn = _EpochFilterFn(ignore_epochs, True) elif loaders is not None: - self.filter_fn = _filter_fn_from_loaders(loaders, False) + self.filter_fn = _LoaderFilterFn(loaders, False) elif ignore_loaders is not None: - self.filter_fn = _filter_fn_from_loaders(ignore_loaders, True) + self.filter_fn = _LoaderFilterFn(ignore_loaders, True) elif filter_fn is not None: - self.filter_fn = _filter_fn_from_arg(filter_fn) + self.filter_fn = _ArgsFilterFn(filter_fn) def on_loader_start(self, runner: "IRunner") -> None: """ From d9e6a05ffa20d101f47a2bf4905a17550385fd6c Mon Sep 17 00:00:00 2001 From: Sergey Kolesnikov Date: Sat, 30 Oct 2021 08:59:00 +0300 Subject: [PATCH 2/3] codestyle --- catalyst/callbacks/control_flow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/catalyst/callbacks/control_flow.py b/catalyst/callbacks/control_flow.py index f3f4848660..686195dc30 100644 --- a/catalyst/callbacks/control_flow.py +++ b/catalyst/callbacks/control_flow.py @@ -18,7 +18,7 @@ def __init__(self, epochs: Union[int, float, Sequence[int]], reverse_condition: ) self.epochs = epochs self.reverse_condition = reverse_condition - + # extra conditions precomputing if isinstance(self.epochs, (int, float)): self.epochs = int(self.epochs) @@ -51,7 +51,7 @@ def __init__(self, loaders: LOADERS, reverse_condition: bool): loaders = [loaders] self.loaders = loaders self.reverse_condition = reverse_condition - + # extra conditions precomputing if isinstance(self.loaders, (list, tuple)): self.loaders = sorted(set(self.loaders)) # ignore duplicates From 728f0beca9be9cede6a5a44f10cd30209478d3c8 Mon Sep 17 00:00:00 2001 From: Sergey Kolesnikov Date: Sat, 30 Oct 2021 09:41:03 +0300 Subject: [PATCH 3/3] fix --- catalyst/callbacks/control_flow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/catalyst/callbacks/control_flow.py b/catalyst/callbacks/control_flow.py index 686195dc30..3206709c09 100644 --- a/catalyst/callbacks/control_flow.py +++ b/catalyst/callbacks/control_flow.py @@ -40,6 +40,8 @@ def __call__(self, stage, epoch, loader): class _LoaderFilterFn: def __init__(self, loaders: LOADERS, reverse_condition: bool): + if isinstance(loaders, str): + loaders = [loaders] if not isinstance(loaders, (list, tuple, dict, OrderedDict)): raise ValueError( "'loaders' type should be one of - str, " @@ -47,8 +49,6 @@ def __init__(self, loaders: LOADERS, reverse_condition: bool): "Mapping[str, Sequence[int]]! " f"(got {type(loaders)})" ) - if isinstance(loaders, str): - loaders = [loaders] self.loaders = loaders self.reverse_condition = reverse_condition