Mux up callback #1241

Merged
merged 47 commits on Jul 15, 2021

Changes from 33 commits

Commits (47)
ffb45ff
add DynamicBalanceClassSampler
Oct 7, 2020
2146b89
add DynamicBalanceClassSampler: add usage example
Oct 11, 2020
93a9d92
add DynamicBalanceClassSampler: add tests
Oct 11, 2020
8573676
Update catalyst/data/tests/test_sampler.py
Scitator Oct 13, 2020
f4b21ae
Update catalyst/data/tests/test_sampler.py
Scitator Oct 13, 2020
a12a05c
add DynamicBalanceClassSampler: debag tests
Oct 15, 2020
79332e1
update sampler: add mode
Nov 7, 2020
ef33956
add example notebook
Nov 7, 2020
2ad65c6
Merge remote-tracking branch 'original_C/master'
Nov 7, 2020
d61fc8f
sampler: fixes
Nov 8, 2020
2be40b3
samler: docs
Nov 8, 2020
594328f
Merge remote-tracking branch 'original_C/master'
Nov 8, 2020
7c6a68e
DynamicBalanceClassSampler: fixes
Nov 9, 2020
f5dafe4
change import order
Nov 9, 2020
070a4ad
change import order
Nov 9, 2020
3363054
Merge with master
Mar 28, 2021
458ae51
Merge remote-tracking branch 'original_C/master'
Apr 6, 2021
1815863
Merge remote-tracking branch 'original_C/master'
May 8, 2021
de8e681
Delete sampler
Jun 19, 2021
7ab855a
Return mixup callback
Jun 19, 2021
1fbabd2
CodeStyle
Jun 19, 2021
296d98c
Update MixUp: fix batch_size, keys, example
Jun 20, 2021
3ded77b
Merge branch 'master' into MuxUPCallback
Dokholyan Jun 20, 2021
4e94b35
CodeStyle and docstrings
Jul 2, 2021
00c5a48
Merge branch 'master' into MuxUPCallback
Dokholyan Jul 2, 2021
678a85e
MixUp tests
Jul 4, 2021
e48fb70
Merge branch 'MuxUPCallback' of https://github.com/Dokholyan/catalyst…
Jul 4, 2021
710167f
Fix beta type bag
Jul 4, 2021
6bea00c
Merge branch 'master' into MuxUPCallback
Dokholyan Jul 5, 2021
62a2209
Fix bag with devices
Jul 5, 2021
65ed442
Fix import order
Jul 5, 2021
3e7c358
Move the logic into a function
Jul 5, 2021
19d0d2d
Codestyle: fix imports
Jul 5, 2021
af583b9
Fix imports
Jul 5, 2021
6c53dd1
Add docs
Jul 5, 2021
b2fe666
Merge branch 'master' into MuxUPCallback
Dokholyan Jul 5, 2021
ac106d5
Update catalyst/callbacks/mixup.py
Dokholyan Jul 8, 2021
45597e8
Fix docstrings
Dokholyan Jul 8, 2021
59b4460
Update catalyst/utils/mixup.py
Dokholyan Jul 8, 2021
f264535
Update catalyst/utils/mixup.py
Dokholyan Jul 8, 2021
2769b9f
Codestyle fixes
Jul 8, 2021
66351e3
Merge branch 'MuxUPCallback' of https://github.com/Dokholyan/catalyst…
Jul 8, 2021
44eab9d
Fix Docstrings
Jul 8, 2021
4c5ca2e
Fix indexes generation
Jul 8, 2021
c0f36cc
Merge branch 'master' into MuxUPCallback
Dokholyan Jul 8, 2021
aabed03
Simplification MixUp utils
Jul 12, 2021
a4698ca
Merge branch 'MuxUPCallback' of https://github.com/Dokholyan/catalyst…
Jul 12, 2021
1 change: 1 addition & 0 deletions catalyst/callbacks/__init__.py
@@ -36,6 +36,7 @@
IEpochMetricHandlerCallback,
EarlyStoppingCallback,
)
from catalyst.callbacks.mixup import MixupCallback
from catalyst.callbacks.optimizer import IOptimizerCallback, OptimizerCallback

if SETTINGS.onnx_required:
269 changes: 168 additions & 101 deletions catalyst/callbacks/mixup.py
@@ -1,101 +1,168 @@
# from typing import List
#
# import numpy as np
# import torch
#
# from catalyst.callbacks.criterion import CriterionCallback
# from catalyst.core.runner import IRunner
#
#
# class MixupCallback(CriterionCallback):
# """Callback to do mixup augmentation.
#
# More details about mixin can be found in the paper
# `mixup: Beyond Empirical Risk Minimization`_.
#
# .. warning::
# `catalyst.contrib.callbacks.MixupCallback` is inherited from
# `catalyst.callbacks.CriterionCallback` and does its work.
# You may not use them together.
#
# .. _mixup\: Beyond Empirical Risk Minimization: # noqa: W605
# https://arxiv.org/abs/1710.09412
# """
#
# def __init__(
# self,
# input_key: str = "targets",
# output_key: str = "logits",
# fields: List[str] = ("features"),
# alpha=1.0,
# on_train_only=True,
# **kwargs
# ):
# """
# Args:
# fields: list of features which must be affected.
# alpha: beta distribution a=b parameters.
# Must be >=0. The more alpha closer to zero
# the less effect of the mixup.
# on_train_only: Apply to train only.
# As the mixup use the proxy inputs, the targets are also proxy.
# We are not interested in them, are we?
# So, if on_train_only is True, use a standard output/metric
# for validation.
# """
# assert isinstance(input_key, str) and isinstance(output_key, str)
# assert len(fields) > 0, "At least one field for MixupCallback is required"
# assert alpha >= 0, "alpha must be>=0"
#
# super().__init__(input_key=input_key, input_key=output_key, **kwargs)
#
# self.on_train_only = on_train_only
# self.fields = fields
# self.alpha = alpha
# self.lam = 1
# self.index = None
# self.is_needed = True
#
# def _compute_loss_value(self, runner: "IRunner", criterion):
# if not self.is_needed:
# return super()._compute_loss_value(runner, criterion)
#
# pred = runner.output[self.input_key]
# y_a = runner.input[self.input_key]
# y_b = runner.input[self.input_key][self.index]
#
# loss = self.lam * criterion(pred, y_a) + (1 - self.lam) * criterion(pred, y_b)
# return loss
#
# def on_loader_start(self, runner: "IRunner"):
# """Loader start hook.
#
# Args:
# runner: current runner
# """
# self.is_needed = not self.on_train_only or runner.is_train_loader
#
# def on_batch_start(self, runner: "IRunner") -> None:
# """Batch start hook.
#
# Args:
# runner: current runner
# """
# if not self.is_needed:
# return
#
# if self.alpha > 0:
# self.lam = np.random.beta(self.alpha, self.alpha)
# else:
# self.lam = 1
#
# self.index = torch.randperm(runner.input[self.fields[0]].shape[0])
# self.index.to(runner.device)
#
# for f in self.fields:
# runner.input[f] = (
# self.lam * runner.input[f] + (1 - self.lam) * runner.input[f][self.index]
# )
#
#
# __all__ = ["MixupCallback"]
from typing import List, Union

from catalyst.core import Callback, CallbackOrder, IRunner
from catalyst.utils import mixup_batch


class MixupCallback(Callback):
"""
Callback to do mixup augmentation. More details about mixup can be found in the paper
`mixup: Beyond Empirical Risk Minimization`: https://arxiv.org/abs/1710.09412 .

Examples:

.. code-block:: python

from typing import Any, Dict
import os

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader

from catalyst import dl
from catalyst.callbacks import MixupCallback
from catalyst.contrib.datasets import MNIST
from catalyst.data.transforms import ToTensor


class SimpleNet(nn.Module):
def __init__(self, in_channels, in_hw, out_features):
super().__init__()
self.encoder = nn.Sequential(nn.Conv2d(in_channels,
in_channels, 3, 1, 1), nn.Tanh())
self.clf = nn.Linear(in_channels * in_hw * in_hw, out_features)

def forward(self, x):
features = self.encoder(x)
features = features.view(features.size(0), -1)
logits = self.clf(features)
return logits


class SimpleDataset(torch.utils.data.Dataset):
def __init__(self, train: bool = False):
self.mnist = MNIST(os.getcwd(), train=train, download=True, transform=ToTensor())

def __len__(self) -> int:
return len(self.mnist)

def __getitem__(self, idx: int) -> Dict[str, Any]:
x, y = self.mnist.__getitem__(idx)
y_one_hot = np.zeros(10)
y_one_hot[y] = 1
return {"image": x,
"clf_targets": y,
"clf_targets_one_hot": torch.Tensor(y_one_hot)}


model = SimpleNet(1, 28, 10)
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.02)

loaders = {
"train": DataLoader(SimpleDataset(train=True), batch_size=32),
"valid": DataLoader(SimpleDataset(train=False), batch_size=32),
}


class CustomRunner(dl.Runner):
def handle_batch(self, batch):
image = batch["image"]
clf_logits = self.model(image)
self.batch["clf_logits"] = clf_logits


runner = CustomRunner()
runner.train(
loaders=loaders,
model=model,
criterion=criterion,
optimizer=optimizer,
logdir="./logdir14",
num_epochs=2,
verbose=True,
valid_loader="valid",
valid_metric="loss",
minimize_valid_metric=True,
callbacks={
"mixup": MixupCallback(keys=["image", "clf_targets_one_hot"]),
"criterion": dl.CriterionCallback(
metric_key="loss", input_key="clf_logits", target_key="clf_targets_one_hot",
),
"optimizer": dl.OptimizerCallback(metric_key="loss"),
"classification": dl.ControlFlowCallback(
dl.PrecisionRecallF1SupportCallback(
input_key="clf_logits", target_key="clf_targets", num_classes=10,
),
ignore_loaders="train",
),
},
)

.. note::
By running this callback, many metrics (for example, accuracy) become undefined, so
use ControlFlowCallback in order to evaluate the model (see the example above).
"""

def __init__(
self, keys: Union[str, List[str]], alpha=0.2, mode="replace", on_train_only=True, **kwargs,
):
"""

Args:
keys: batch keys to which you want to apply augmentation
alpha: beta distribution a=b parameters. Must be >= 0. The closer alpha is to zero, the
weaker the effect of the mixup.
mode: method of applying the augmentation. Must be one of ["replace", "add"]. If "replace",
the batch is replaced with its mixed version and the batch size does not change.
If "add", the mixed examples are concatenated to the current ones and the batch size
doubles.
on_train_only: apply to the train loader only. As mixup uses proxy inputs, the targets are
also proxy. We are not interested in them, are we? So, if on_train_only is True,
a standard output/metric is used for validation.
**kwargs:
"""
assert isinstance(keys, (str, list)), f"keys must be str or List[str], got: {type(keys)}"
assert alpha >= 0, "alpha must be >= 0"
assert mode in ["add", "replace"], f"mode must be one of 'add', 'replace', got: {mode}"
I'd appreciate it if we could use tuple or set here instead of a list

super().__init__(order=CallbackOrder.Internal)
if isinstance(keys, str):
keys = [keys]
self.keys = keys
self.on_train_only = on_train_only
self.alpha = alpha
self.mode = mode
self.required = True

def _handle_batch(self, runner: "IRunner") -> None:
"""
Applies mixup augmentation for a batch

Args:
runner: runner for the experiment.
"""
runner.batch = mixup_batch(runner.batch, self.keys, alpha=self.alpha, mode=self.mode)

def on_loader_start(self, runner: "IRunner") -> None:
"""
Loader start hook.

Args:
runner: current runner
"""
self.required = not self.on_train_only or runner.is_train_loader

def on_batch_start(self, runner: "IRunner") -> None:
"""
On batch start action.

Args:
runner: runner for the experiment.
"""
if self.required:
self._handle_batch(runner)


__all__ = ["MixupCallback"]
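
The class docstring above only demonstrates the default mode="replace". Below is a minimal sketch, not part of the diff, of what mode="add" does to the batch; _FakeRunner is a hypothetical stand-in that exposes only the batch attribute the callback touches, and the doubling follows from the mixup_batch utility added further down.

import torch
from catalyst.callbacks import MixupCallback

class _FakeRunner:
    """Bare-minimum stand-in for IRunner: just the attribute MixupCallback reads and writes."""
    def __init__(self, batch):
        self.batch = batch

batch = {
    "image": torch.rand(8, 1, 28, 28),
    "clf_targets_one_hot": torch.eye(10)[torch.randint(0, 10, (8,))],
}
runner = _FakeRunner(batch)

callback = MixupCallback(keys=["image", "clf_targets_one_hot"], alpha=0.2, mode="add")
callback.on_batch_start(runner)  # `required` defaults to True, so no on_loader_start is needed here

assert runner.batch["image"].shape[0] == 16  # "add" doubles the batch size
assert runner.batch["clf_targets_one_hot"].shape == (16, 10)

Because the mixed examples are appended to the originals, every downstream callback (criterion, metrics) will see the doubled batch.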
1 change: 1 addition & 0 deletions catalyst/utils/__init__.py
@@ -37,6 +37,7 @@
get_by_keys,
convert_labels2list,
)
from catalyst.utils.mixup import mixup_batch
from catalyst.utils.numpy import get_one_hot

from catalyst.utils.onnx import onnx_export
47 changes: 47 additions & 0 deletions catalyst/utils/mixup.py
@@ -0,0 +1,47 @@
from typing import Dict, List

import numpy as np
import torch


def mixup_batch(
batch: Dict[str, torch.Tensor], keys: List[str], alpha: float = 0.2, mode: str = "replace"
) -> Dict[str, torch.Tensor]:
"""

Args:
batch: batch to which you want to apply augmentation
keys: batch keys to which you want to apply augmentation
alpha: beta distribution a=b parameters. Must be >= 0. The closer alpha is to zero, the
weaker the effect of the mixup.
mode: method of applying the augmentation. Must be one of ["replace", "add"]. If "replace",
the batch is replaced with its mixed version and the batch size does not change.
If "add", the mixed examples are concatenated to the current ones and the batch size
doubles.

Returns:
augmented batch

"""
assert isinstance(keys, list), f"keys must be List[str], got: {type(keys)}"
I'd appreciate it if we could check not for lists only, but for other iterables too, for example, tuples

assert alpha >= 0, "alpha must be >= 0"
assert mode in ["add", "replace"], f"mode must be one of 'add', 'replace', got: {mode}"
I'd appreciate it if we could use tuple or set here instead of a list


batch_size = batch[keys[0]].shape[0]
beta = np.random.beta(alpha, alpha, batch_size).astype(np.float32)
indexes = np.array(list(range(batch_size)))
# index shift by 1
indexes_2 = (indexes + 1) % batch_size
for key in keys:
targets = batch[key]
device = targets.device
targets_shape = [batch_size] + [1] * len(targets.shape[1:])
key_beta = torch.Tensor(beta.reshape(targets_shape)).to(device)
targets = targets * key_beta + targets[indexes_2] * (1 - key_beta)

if mode == "replace":
batch[key] = targets
else:
# mode == 'add'
batch[key] = torch.cat([batch[key], targets])
return batch
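
A minimal standalone sketch of mixup_batch in both modes (not part of the diff; it only assumes the catalyst.utils.mixup_batch import path registered in the __init__.py change above). Note that mode="replace" writes the mixed tensors back into the dict it receives, hence the defensive copies.

import torch
from catalyst.utils import mixup_batch

batch = {
    "image": torch.rand(4, 3, 8, 8),
    "targets_one_hot": torch.eye(5)[torch.tensor([0, 1, 2, 3])],
}

# "replace": batch size stays the same, each sample is blended with its right neighbour
mixed = mixup_batch(
    {k: v.clone() for k, v in batch.items()},
    keys=["image", "targets_one_hot"], alpha=0.2, mode="replace",
)
assert mixed["image"].shape == (4, 3, 8, 8)
assert mixed["targets_one_hot"].max(1)[0].mean() <= 1  # one-hot rows become soft labels

# "add": mixed examples are concatenated to the originals, so the batch size doubles
doubled = mixup_batch(
    {k: v.clone() for k, v in batch.items()},
    keys=["image", "targets_one_hot"], alpha=0.2, mode="add",
)
assert doubled["image"].shape[0] == 8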
31 changes: 31 additions & 0 deletions tests/catalyst/callbacks/test_mixup.py
@@ -0,0 +1,31 @@
# flake8: noqa
from typing import Tuple

import torch
from torch.utils.data import DataLoader, TensorDataset

from catalyst import dl, utils
from catalyst.callbacks import MixupCallback


class DymmyRunner(dl.Runner):
def handle_batch(self, batch: Tuple[torch.Tensor]):
self.batch = {"image": batch[0], "clf_targets_one_hot": batch[1]}


def test_mixup_1():
utils.set_global_seed(42)
num_samples, num_features, num_classes = int(1e4), int(1e1), 4
X = torch.rand(num_samples, num_features)
y = (torch.rand(num_samples,) * num_classes).to(torch.int64)
y = torch.nn.functional.one_hot(y, num_classes).double()
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loaders = {"train": loader, "valid": loader}
runner = DymmyRunner()
callback = MixupCallback(keys=["image", "clf_targets_one_hot"])
for loader_name in ["train", "valid"]:
for batch in loaders[loader_name]:
runner.handle_batch(batch)
callback.on_batch_start(runner)
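# after mixup each one-hot row becomes a convex combination of two (usually different) labels,
# so the mean of the per-row maxima is expected to drop below 1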
assert runner.batch["clf_targets_one_hot"].max(1)[0].mean() < 1