Skip to content

Commit

Permalink
Merge pull request #316 from jonasrauber/generator-batching3
Browse files Browse the repository at this point in the history
batch support
  • Loading branch information
jonasrauber authored May 21, 2019
2 parents d9ef908 + e01cfc7 commit 562f21e
Show file tree
Hide file tree
Showing 59 changed files with 1,807 additions and 12 deletions.
5 changes: 3 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ install:
script:
- pytest --collect-only
# tf eager cannot be run in the same process as standard tf
- pytest --ignore=foolbox/tests/test_models_tensorflow_eager.py --ignore=foolbox/tests/test_models_caffe.py
- pytest --cov-append foolbox/tests/test_models_tensorflow_eager.py
- if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pytest --ignore=foolbox/tests/models/test_models_tensorflow_eager.py --ignore=foolbox/tests/models/test_models_caffe.py --ignore=foolbox/tests/batch_attacks/; fi
- if [[ $TRAVIS_PYTHON_VERSION != 2.7 ]]; then pytest --ignore=foolbox/tests/models/test_models_tensorflow_eager.py --ignore=foolbox/tests/models/test_models_caffe.py; fi
- pytest --cov-append foolbox/tests/models/test_models_tensorflow_eager.py
- if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then docker run -v `pwd`:`pwd` bvlc/caffe:cpu sh -c "cd `pwd` && bash foolbox/tests/run_caffe_test.sh"; fi
- flake8 --ignore E402,E741,W503,W504 .
after_success:
Expand Down
5 changes: 5 additions & 0 deletions foolbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
from . import criteria # noqa: F401
from . import distances # noqa: F401
from . import attacks # noqa: F401
from . import batch_attacks # noqa: F401
from . import utils # noqa: F401
from . import gradient_estimators # noqa: F401

from .adversarial import Adversarial # noqa: F401
from .yielding_adversarial import YieldingAdversarial # noqa: F401

from .batching import run_parallel # noqa: F401
from .batching import run_sequential # noqa: F401
10 changes: 7 additions & 3 deletions foolbox/adversarial.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class Adversarial(object):
if the threshold has been reached.
"""

def __init__(
self,
model,
Expand Down Expand Up @@ -82,10 +83,13 @@ def __init__(
self._best_gradient_calls = 0

# check if the original input is already adversarial
self._check_unperturbed()

def _check_unperturbed(self):
try:
self.forward_one(unperturbed)
self.forward_one(self.__unperturbed)
except StopAttack:
# if a threshold is specified and the original input is
# if a threshold is specified and the unperturbed input is
# misclassified, this can already cause a StopAttack
# exception
assert self.distance.value == 0.
Expand All @@ -98,7 +102,7 @@ def _reset(self):
self._best_prediction_calls = 0
self._best_gradient_calls = 0

self.forward_one(self.__unperturbed)
self._check_unperturbed()

@property
def perturbed(self):
Expand Down
6 changes: 5 additions & 1 deletion foolbox/attacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
ABC = abc.ABCMeta('ABC', (), {})

from ..adversarial import Adversarial
from ..yielding_adversarial import YieldingAdversarial
from ..adversarial import StopAttack
from ..criteria import Misclassification
from ..distances import MSE
Expand Down Expand Up @@ -101,7 +102,10 @@ def call_decorator(call_fn):
def wrapper(self, input_or_adv, label=None, unpack=True, **kwargs):
assert input_or_adv is not None

if isinstance(input_or_adv, Adversarial):
if isinstance(input_or_adv, YieldingAdversarial):
raise ValueError('If you pass an Adversarial instance, it must not be a YieldingAdversarial instance'
' when calling non-batch-supporting attacks like this one (check foolbox.batch_attacks).')
elif isinstance(input_or_adv, Adversarial):
a = input_or_adv
if label is not None:
raise ValueError('Label must not be passed when input_or_adv'
Expand Down
5 changes: 2 additions & 3 deletions foolbox/attacks/carlini_wagner.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def to_model_space(x):
binary_search_steps >= 10:
# in the last binary search step, use the upper_bound instead
# TODO: find out why... it's not obvious why this is useful
const = upper_bound
const = min(1e10, upper_bound)

logging.info('starting optimization with const = {}'.format(const))

Expand All @@ -148,8 +148,7 @@ def to_model_space(x):
const, a, x, logits, reconstructed_original,
confidence, min_, max_)

logging.info('loss: {}; best overall distance: {}'.format(
loss, a.distance))
logging.info('loss: {}; best overall distance: {}'.format(loss, a.distance))

# backprop the gradient of the loss w.r.t. x further
# to get the gradient of the loss w.r.t. att_perturbation
Expand Down
11 changes: 11 additions & 0 deletions foolbox/batch_attacks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# flake8: noqa

from .gradient import GradientAttack, GradientSignAttack, FGSM
from .carlini_wagner import CarliniWagnerL2Attack

from .iterative_projected_gradient import LinfinityBasicIterativeAttack, BasicIterativeMethod, BIM
from .iterative_projected_gradient import L1BasicIterativeAttack
from .iterative_projected_gradient import L2BasicIterativeAttack
from .iterative_projected_gradient import ProjectedGradientDescentAttack, ProjectedGradientDescent, PGD
from .iterative_projected_gradient import RandomStartProjectedGradientDescentAttack, RandomProjectedGradientDescent, RandomPGD
from .iterative_projected_gradient import MomentumIterativeAttack, MomentumIterativeMethod
71 changes: 71 additions & 0 deletions foolbox/batch_attacks/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import warnings
import logging
import functools
import numpy as np

from ..attacks.base import Attack
from ..yielding_adversarial import YieldingAdversarial
from ..adversarial import StopAttack
from ..batching import run_parallel


class BatchAttack(Attack):
def __call__(self, inputs, labels, unpack=True, **kwargs):
assert isinstance(inputs, np.ndarray)
assert isinstance(labels, np.ndarray)

if len(inputs) != len(labels):
raise ValueError('The number of inputs and labels needs to be equal')

model = self._default_model
criterion = self._default_criterion
distance = self._default_distance
threshold = self._default_threshold

if model is None:
raise ValueError('The attack needs to be initialized with a model')
if criterion is None:
raise ValueError('The attack needs to be initialized with a criterion')
if distance is None:
raise ValueError('The attack needs to be initialized with a distance')

create_attack_fn = self.__class__
advs = run_parallel(create_attack_fn, model, criterion, inputs, labels,
distance=distance, threshold=threshold, **kwargs)

if unpack:
advs = [a.perturbed for a in advs]
advs = [p if p is not None else np.full_like(u, np.nan) for p, u in zip(advs, inputs)]
advs = np.stack(advs)
return advs


def generator_decorator(generator):
@functools.wraps(generator)
def wrapper(self, a, **kwargs):
assert isinstance(a, YieldingAdversarial)

if a.distance.value == 0.:
warnings.warn('Not running the attack because the original input'
' is already misclassified and the adversarial thus'
' has a distance of 0.')
elif a.reached_threshold():
warnings.warn('Not running the attack because the given treshold'
' is already reached')
else:
try:
_ = yield from generator(self, a, **kwargs)
assert _ is None, 'decorated __call__ method must return None'
except StopAttack:
# if a threshold is specified, StopAttack will be thrown
# when the treshold is reached; thus we can do early
# stopping of the attack
logging.info('threshold reached, stopping attack')

if a.perturbed is None:
warnings.warn('{} did not find an adversarial, maybe the model'
' or the criterion is not supported by this'
' attack.'.format(self.name()))
return a

return wrapper
Loading

0 comments on commit 562f21e

Please sign in to comment.