Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ControlFlowCallback ddp support #1341

Merged
merged 3 commits into from
Oct 30, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 103 additions & 133 deletions catalyst/callbacks/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,139 +10,108 @@
from catalyst.core.runner import IRunner


def _filter_fn_from_epochs(
epochs: Union[int, float, Sequence[int]], reverse_condition: bool
) -> FILTER_FN:
"""Build ``filter_fn`` from epochs for ``ControlFlowCallback``

Args:
epochs: epochs description
reverse_condition: indicator to use reversed
condition in filter function

Raises:
ValueError: if passed object with unexpected type

Returns:
filter function which accepts 3 arguments - stage (str),
epoch (int), loader (str) and return ``True`` if
need to disable callback
"""
if isinstance(epochs, (int, float)):
epochs = int(epochs)
if reverse_condition:
filter_fn = lambda stage, epoch, loader: epoch % epochs != 0
else:
filter_fn = lambda stage, epoch, loader: epoch % epochs == 0
elif isinstance(epochs, (list, tuple)):
epochs = sorted(set(epochs))
if reverse_condition:
filter_fn = lambda stage, epoch, loader: epoch not in epochs
else:
filter_fn = lambda stage, epoch, loader: epoch in epochs
else:
raise ValueError("'epochs' should be int/float/Sequence[int]! " f"(got {type(epochs)})")
return filter_fn


def _filter_fn_from_loaders(loaders: LOADERS, reverse_condition: bool) -> FILTER_FN:
"""Build ``filter_fn`` from loaders for ``ControlFlowCallback``.

Args:
loaders (str/Sequence[str]/Mapping[str, int/Sequence[str]]):
loaders description
reverse_condition: indicator to use reversed
condition in filter function

Raises:
ValueError: if can't build filter_fn from mappings
ValueError: if passed object with unexpected type

Returns:
filter function which accepts 3 arguments - stage (str),
epoch (int), loader (str) and return ``True`` if
need to disable callback
"""
if isinstance(loaders, str):
loaders = [loaders]

# sequence of loaders
if isinstance(loaders, (list, tuple)):
loaders = sorted(set(loaders)) # ignore duplicates
if reverse_condition:
filter_fn = lambda stage, epoch, loader: loader not in loaders
else:
filter_fn = lambda stage, epoch, loader: loader in loaders
# loader: ignore epoch or epochs
elif isinstance(loaders, (dict, OrderedDict)):
ignore_list = {}
for loader, epochs in loaders.items():
if isinstance(epochs, (int, float)):
ignore_list[loader] = [int(epochs)]
else:
try:
ignore_list[loader] = []
for num in sorted(set(epochs)):
to_add = int(num)
ignore_list[loader].append(to_add)
except (ValueError, TypeError):
raise ValueError(
"'ignore_list' should be a dict where "
"keys is a int/float/List[int]/Tuple[int]!"
)
if reverse_condition:
filter_fn = lambda stage, epoch, loader: epoch not in (
ignore_list.get(loader) or {} # {loader: [epoch]}.get(loader)
class _EpochFilterFn:
def __init__(self, epochs: Union[int, float, Sequence[int]], reverse_condition: bool):
if not isinstance(epochs, (int, float, list, tuple)):
raise ValueError(
"'epochs' should be int/float/Sequence[int]! " f"(got {type(epochs)})"
)
else:
filter_fn = lambda stage, epoch, loader: epoch in (ignore_list.get(loader) or {})
else:
raise ValueError(
"'loaders' type should be one of - str, "
"Sequence[str], Mapping[str, int] or "
"Mapping[str, Sequence[int]]! "
f"(got {type(loaders)})"
)
return filter_fn

self.epochs = epochs
self.reverse_condition = reverse_condition

# extra conditions precomputing
if isinstance(self.epochs, (int, float)):
self.epochs = int(self.epochs)
elif isinstance(self.epochs, (list, tuple)):
self.epochs = sorted(set(self.epochs))

def __call__(self, stage, epoch, loader):
if isinstance(self.epochs, (int, float)):
if self.reverse_condition:
return epoch % self.epochs != 0
else:
return epoch % self.epochs == 0
elif isinstance(self.epochs, (list, tuple)):
if self.reverse_condition:
return epoch not in self.epochs
else:
return epoch in self.epochs

def _filter_fn_from_arg(filter_fn: Union[str, FILTER_FN]) -> FILTER_FN:
"""Check if filter function from argumets
can be used with ``ControlFlowCallback``.

Args:
filter_fn (str or Callable): filter function to check

Raises:
ValueError: if ``filter_fn`` is a string and can not be
interpreted as python code then an error will be raised
ValueError: if passed not callable object then will be
raised an error
ValueError: will be raised error if filter function do not
have three arguments

Returns:
filter function which accepts 3 arguments - stage (str),
epoch (int), loader (str) and return ``True`` if
need to disable callback
"""
if isinstance(filter_fn, str):
# lambda function from string
try:
filter_fn = eval(filter_fn)
except (ValueError, SyntaxError):
class _LoaderFilterFn:
def __init__(self, loaders: LOADERS, reverse_condition: bool):
if isinstance(loaders, str):
loaders = [loaders]
if not isinstance(loaders, (list, tuple, dict, OrderedDict)):
raise ValueError(
"'filter_fn' should be a valid "
"python lambda function with "
"three arguments - 'stage', 'epoch' and 'loader'!"
"'loaders' type should be one of - str, "
"Sequence[str], Mapping[str, int] or "
"Mapping[str, Sequence[int]]! "
f"(got {type(loaders)})"
)
if not callable(filter_fn):
raise ValueError("'filter_fn' should be a callable!")
if filter_fn.__code__.co_argcount != 3:
raise ValueError(
"Filter function should have three arguments - " "'stage', 'epoch' and 'loader'!"
)
return filter_fn
self.loaders = loaders
self.reverse_condition = reverse_condition

# extra conditions precomputing
if isinstance(self.loaders, (list, tuple)):
self.loaders = sorted(set(self.loaders)) # ignore duplicates
elif isinstance(self.loaders, (dict, OrderedDict)):
ignore_list = {}
for loader, epochs in self.loaders.items():
if isinstance(epochs, (int, float)):
ignore_list[loader] = [int(epochs)]
else:
try:
ignore_list[loader] = []
for num in sorted(set(epochs)):
to_add = int(num)
ignore_list[loader].append(to_add)
except (ValueError, TypeError):
raise ValueError(
"'ignore_list' should be a dict where "
"keys is a int/float/List[int]/Tuple[int]!"
)
self._ignore_list = ignore_list

def __call__(self, stage, epoch, loader):
# sequence of loaders
if isinstance(self.loaders, (list, tuple)):
if self.reverse_condition:
return loader not in self.loaders
else:
return loader in self.loaders
# loader: ignore epoch or epochs
elif isinstance(self.loaders, (dict, OrderedDict)):
if self.reverse_condition:
return epoch not in (
self._ignore_list.get(loader) or {} # {loader: [epoch]}.get(loader)
)
else:
return epoch in (self._ignore_list.get(loader) or {})


class _ArgsFilterFn:
def __init__(self, filter_fn: Union[str, FILTER_FN]):
if isinstance(filter_fn, str):
# lambda function from string
try:
filter_fn = eval(filter_fn)
except (ValueError, SyntaxError):
raise ValueError(
"'filter_fn' should be a valid "
"python lambda function with "
"three arguments - 'stage', 'epoch' and 'loader'!"
)
if not callable(filter_fn):
raise ValueError("'filter_fn' should be a callable!")
if filter_fn.__code__.co_argcount != 3:
raise ValueError(
"Filter function should have three arguments - " "'stage', 'epoch' and 'loader'!"
)
self.filter_fn = filter_fn

def __call__(self, stage, epoch, loader):
return self.filter_fn(stage, epoch, loader)


class ControlFlowCallback(CallbackWrapper):
Expand Down Expand Up @@ -352,16 +321,17 @@ def __init__(
# loader parameters
self.filter_fn = None

# due to ddp-setup, we have to wrap everything with classes
if epochs is not None:
self.filter_fn = _filter_fn_from_epochs(epochs, False)
self.filter_fn = _EpochFilterFn(epochs, False)
elif ignore_epochs is not None:
self.filter_fn = _filter_fn_from_epochs(ignore_epochs, True)
self.filter_fn = _EpochFilterFn(ignore_epochs, True)
elif loaders is not None:
self.filter_fn = _filter_fn_from_loaders(loaders, False)
self.filter_fn = _LoaderFilterFn(loaders, False)
elif ignore_loaders is not None:
self.filter_fn = _filter_fn_from_loaders(ignore_loaders, True)
self.filter_fn = _LoaderFilterFn(ignore_loaders, True)
elif filter_fn is not None:
self.filter_fn = _filter_fn_from_arg(filter_fn)
self.filter_fn = _ArgsFilterFn(filter_fn)

def on_loader_start(self, runner: "IRunner") -> None:
"""
Expand Down