Mux up callback #1241 (Merged)

Commits (47):
ffb45ff  add DynamicBalanceClassSampler
2146b89  add DynamicBalanceClassSampler: add usage example
93a9d92  add DynamicBalanceClassSampler: add tests
8573676  Update catalyst/data/tests/test_sampler.py
f4b21ae  Update catalyst/data/tests/test_sampler.py (Scitator)
a12a05c  add DynamicBalanceClassSampler: debag tests (Scitator)
79332e1  update sampler: add mode
ef33956  add example notebook
2ad65c6  Merge remote-tracking branch 'original_C/master'
d61fc8f  sampler: fixes
2be40b3  samler: docs
594328f  Merge remote-tracking branch 'original_C/master'
7c6a68e  DynamicBalanceClassSampler: fixes
f5dafe4  change import order
070a4ad  change import order
3363054  Merge with master
458ae51  Merge remote-tracking branch 'original_C/master'
1815863  Merge remote-tracking branch 'original_C/master'
de8e681  Delete sampler
7ab855a  Return mixup callback
1fbabd2  CodeStyle
296d98c  Update MixUp: fix batch_size, keys, example
3ded77b  Merge branch 'master' into MuxUPCallback
4e94b35  CodeStyle and docstrings (Dokholyan)
00c5a48  Merge branch 'master' into MuxUPCallback
678a85e  MixUp tests (Dokholyan)
e48fb70  Merge branch 'MuxUPCallback' of https://github.com/Dokholyan/catalyst…
710167f  Fix beta type bag
6bea00c  Merge branch 'master' into MuxUPCallback
62a2209  Fix bag with devices (Dokholyan)
65ed442  Fix import order
3e7c358  Move the logic into a function
19d0d2d  Codestyle: fix imports
af583b9  Fix imports
6c53dd1  Add docs
b2fe666  Merge branch 'master' into MuxUPCallback
ac106d5  Update catalyst/callbacks/mixup.py (Dokholyan)
45597e8  Fix docstrings (Dokholyan)
59b4460  Update catalyst/utils/mixup.py (Dokholyan)
f264535  Update catalyst/utils/mixup.py (Dokholyan)
2769b9f  Codestyle fixes (Dokholyan)
66351e3  Merge branch 'MuxUPCallback' of https://github.com/Dokholyan/catalyst…
44eab9d  Fix Docstrings
4c5ca2e  Fix indexes generation
c0f36cc  Merge branch 'master' into MuxUPCallback
aabed03  Simplification MixUp utils (Dokholyan)
a4698ca  Merge branch 'MuxUPCallback' of https://github.com/Dokholyan/catalyst…
Diff (@@ -1,101 +1,190 @@):
Removed code (the previous implementation, which was kept commented out in the file):

```python
from typing import List

import numpy as np
import torch

from catalyst.callbacks.criterion import CriterionCallback
from catalyst.core.runner import IRunner


class MixupCallback(CriterionCallback):
    """Callback to do mixup augmentation.

    More details about mixup can be found in the paper
    `mixup: Beyond Empirical Risk Minimization`_.

    .. warning::
        `catalyst.contrib.callbacks.MixupCallback` is inherited from
        `catalyst.callbacks.CriterionCallback` and does its work.
        You should not use them together.

    .. _mixup\: Beyond Empirical Risk Minimization:  # noqa: W605
        https://arxiv.org/abs/1710.09412
    """

    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        fields: List[str] = ("features",),
        alpha=1.0,
        on_train_only=True,
        **kwargs
    ):
        """
        Args:
            fields: list of features which must be affected.
            alpha: beta distribution a=b parameters. Must be >= 0.
                The closer alpha is to zero, the weaker the mixup effect.
            on_train_only: apply to train only. As mixup uses proxy inputs,
                the targets are also proxy. We are not interested in them,
                are we? So, if on_train_only is True, a standard
                output/metric is used for validation.
        """
        assert isinstance(input_key, str) and isinstance(output_key, str)
        assert len(fields) > 0, "At least one field for MixupCallback is required"
        assert alpha >= 0, "alpha must be >= 0"

        super().__init__(input_key=input_key, output_key=output_key, **kwargs)

        self.on_train_only = on_train_only
        self.fields = fields
        self.alpha = alpha
        self.lam = 1
        self.index = None
        self.is_needed = True

    def _compute_loss_value(self, runner: "IRunner", criterion):
        if not self.is_needed:
            return super()._compute_loss_value(runner, criterion)

        pred = runner.output[self.output_key]
        y_a = runner.input[self.input_key]
        y_b = runner.input[self.input_key][self.index]

        loss = self.lam * criterion(pred, y_a) + (1 - self.lam) * criterion(pred, y_b)
        return loss

    def on_loader_start(self, runner: "IRunner"):
        """Loader start hook.

        Args:
            runner: current runner
        """
        self.is_needed = not self.on_train_only or runner.is_train_loader

    def on_batch_start(self, runner: "IRunner") -> None:
        """Batch start hook.

        Args:
            runner: current runner
        """
        if not self.is_needed:
            return

        if self.alpha > 0:
            self.lam = np.random.beta(self.alpha, self.alpha)
        else:
            self.lam = 1

        self.index = torch.randperm(runner.input[self.fields[0]].shape[0])
        self.index = self.index.to(runner.device)

        for f in self.fields:
            runner.input[f] = (
                self.lam * runner.input[f] + (1 - self.lam) * runner.input[f][self.index]
            )


__all__ = ["MixupCallback"]
```
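The removed callback implements classic mixup: a single weight lam is drawn from Beta(alpha, alpha), the whole batch is mixed with a random permutation of itself, and the loss is mixed with the same weight. A minimal NumPy sketch of that scheme (function and variable names here are illustrative, not catalyst API):

```python
import numpy as np


def classic_mixup(batch, alpha=1.0, seed=0):
    """One shared lambda per batch; the mixing partner is a random permutation.

    Illustrative sketch of the scheme above, not a catalyst function.
    """
    rng = np.random.default_rng(seed)
    lam = rng.beta(alpha, alpha) if alpha > 0 else 1.0
    index = rng.permutation(len(batch))  # mixing partner for each sample
    mixed = lam * batch + (1 - lam) * batch[index]
    # the loss is mixed the same way:
    # loss = lam * criterion(pred, y) + (1 - lam) * criterion(pred, y[index])
    return mixed, index, lam


x = np.eye(4, dtype=np.float32)  # 4 one-hot "samples"
mixed, index, lam = classic_mixup(x)
```

Because each input row sums to 1, every mixed row also sums to lam + (1 - lam) = 1, which is an easy sanity check on the interpolation.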
New code:

```python
from typing import List, Union

import numpy as np
import torch

from catalyst.core import Callback, CallbackOrder, IRunner


class MixupCallback(Callback):
    """
    Callback to do mixup augmentation. More details about mixup can be found
    in the paper `mixup: Beyond Empirical Risk Minimization`_.

    .. _mixup\: Beyond Empirical Risk Minimization:
        https://arxiv.org/abs/1710.09412

    Examples:

    .. code-block:: python

        import os
        from typing import Any, Dict

        import numpy as np
        import torch
        from torch import nn
        from torch.utils.data import DataLoader

        from catalyst import dl
        from catalyst.callbacks import MixupCallback
        from catalyst.contrib.datasets import MNIST
        from catalyst.data.transforms import ToTensor


        class SimpleNet(nn.Module):
            def __init__(self, in_channels, in_hw, out_features):
                super().__init__()
                self.encoder = nn.Sequential(
                    nn.Conv2d(in_channels, in_channels, 3, 1, 1), nn.Tanh()
                )
                self.clf = nn.Linear(in_channels * in_hw * in_hw, out_features)

            def forward(self, x):
                z = self.encoder(x)
                z_ = z.view(z.size(0), -1)
                y_hat = self.clf(z_)
                return y_hat


        class SimpleDataset(torch.utils.data.Dataset):
            def __init__(self, train: bool = False):
                self.mnist = MNIST(
                    os.getcwd(), train=train, download=True, transform=ToTensor()
                )

            def __len__(self) -> int:
                return len(self.mnist)

            def __getitem__(self, idx: int) -> Dict[str, Any]:
                x, y = self.mnist.__getitem__(idx)
                y_one_hot = np.zeros(10)
                y_one_hot[y] = 1
                return {
                    "image": x,
                    "clf_targets": y,
                    "clf_targets_one_hot": torch.Tensor(y_one_hot),
                }


        model = SimpleNet(1, 28, 10)
        criterion = torch.nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.02)

        loaders = {
            "train": DataLoader(SimpleDataset(train=True), batch_size=32),
            "valid": DataLoader(SimpleDataset(train=False), batch_size=32),
        }


        class CustomRunner(dl.Runner):
            def handle_batch(self, batch):
                image = batch["image"]
                clf_logits = self.model(image)
                self.batch["clf_logits"] = clf_logits


        runner = CustomRunner()
        runner.train(
            loaders=loaders,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            logdir="./logdir14",
            num_epochs=2,
            verbose=True,
            valid_loader="valid",
            valid_metric="loss",
            minimize_valid_metric=True,
            callbacks={
                "mixup": MixupCallback(keys=["image", "clf_targets_one_hot"]),
                "criterion": dl.CriterionCallback(
                    metric_key="loss",
                    input_key="clf_logits",
                    target_key="clf_targets_one_hot",
                ),
                "optimizer": dl.OptimizerCallback(metric_key="loss"),
                "classification": dl.ControlFlowCallback(
                    dl.PrecisionRecallF1SupportCallback(
                        input_key="clf_logits",
                        target_key="clf_targets",
                        num_classes=10,
                    ),
                    ignore_loaders="train",
                ),
            },
        )

    .. note::
        This callback can only be used with an even batch size.

        By running this callback, many metrics (for example, accuracy)
        become undefined, so use ControlFlowCallback in order to evaluate
        the model (see the example above).
    """

    def __init__(
        self,
        keys: Union[str, List[str]],
        alpha=0.2,
        mode="replace",
        on_train_only=True,
        **kwargs,
    ):
        """
        Args:
            keys: batch keys to which you want to apply augmentation
            alpha: beta distribution a=b parameters. Must be >= 0.
                The closer alpha is to zero, the weaker the mixup effect.
            mode: method of use; must be in ["replace", "add"]. If "replace",
                the batch is replaced with a mixed one and the batch size
                does not change. If "add", the mixed examples are
                concatenated to the current ones and the batch size
                increases by 2 times.
            on_train_only: apply to train only. As mixup uses proxy inputs,
                the targets are also proxy. We are not interested in them,
                are we? So, if on_train_only is True, the standard
                output/metric is used for validation.
            **kwargs: extra arguments
        """
        assert isinstance(
            keys, (str, list)
        ), f"keys must be str or list[str], got: {type(keys)}"
        assert alpha >= 0, "alpha must be >= 0"
        assert mode in ["add", "replace"], (
            f"mode must be in ['add', 'replace'], got: {mode}"
        )

        super().__init__(order=CallbackOrder.Internal)
        if isinstance(keys, str):
            keys = [keys]
        self.keys = keys
        self.on_train_only = on_train_only
        self.alpha = alpha
        self.mode = mode
        self.is_needed = True

    def _handle_batch(self, runner: "IRunner") -> None:
        """
        Applies mixup augmentation for a batch.

        Args:
            runner: runner for the experiment.
        """
        batch_size = runner.batch[self.keys[0]].shape[0]
        beta = np.random.beta(self.alpha, self.alpha, batch_size).astype(np.float32)
        indexes = np.arange(batch_size)
        # index shift by 1
        indexes_2 = (indexes + 1) % batch_size
        for key in self.keys:
            targets = runner.batch[key]
            targets_shape = [batch_size] + [1] * len(targets.shape[1:])
            key_beta = beta.reshape(targets_shape)
            targets = targets * key_beta + targets[indexes_2] * (1 - key_beta)

            if self.mode == "replace":
                runner.batch[key] = targets
            else:
                # self.mode == "add"
                runner.batch[key] = torch.cat([runner.batch[key], targets])

    def on_loader_start(self, runner: "IRunner") -> None:
        """
        Loader start hook.

        Args:
            runner: current runner
        """
        self.is_needed = not self.on_train_only or runner.is_train_loader

    def on_batch_start(self, runner: "IRunner") -> None:
        """
        On batch start action.

        Args:
            runner: runner for the experiment.
        """
        if self.is_needed:
            self._handle_batch(runner)


__all__ = ["MixupCallback"]
```

Review comment (on the assert checks): I'd appreciate it if we could use tuple or set here instead of a list.
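The `_handle_batch` logic above can be sketched outside the runner as a standalone NumPy function (a hypothetical helper for illustration, mirroring the torch code): each sample gets its own beta-distributed weight and is mixed with its neighbor `(i + 1) % batch_size`; `mode="add"` concatenates instead of replacing, doubling the batch.

```python
import numpy as np


def mixup_batch(batch: np.ndarray, alpha: float = 0.2, mode: str = "replace",
                seed: int = 0) -> np.ndarray:
    """NumPy sketch of MixupCallback._handle_batch for a single key.

    Illustrative helper, not part of the catalyst API.
    """
    batch_size = batch.shape[0]
    rng = np.random.default_rng(seed)
    beta = rng.beta(alpha, alpha, batch_size).astype(np.float32)
    # reshape to (batch_size, 1, ..., 1) so each weight broadcasts over a sample
    beta = beta.reshape([batch_size] + [1] * (batch.ndim - 1))
    indexes_2 = (np.arange(batch_size) + 1) % batch_size  # index shift by 1
    mixed = batch * beta + batch[indexes_2] * (1 - beta)
    if mode == "replace":
        return mixed  # batch size unchanged
    return np.concatenate([batch, mixed])  # "add": batch size doubles


x = np.eye(4, dtype=np.float32)  # 4 one-hot "samples"
replaced = mixup_batch(x, mode="replace")
added = mixup_batch(x, mode="add")
```

Note the contrast with the removed implementation: per-sample weights and a deterministic shift-by-1 partner instead of one shared lambda and a random permutation.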
Review comment: could we make imports in lexicographical order?