diff --git a/.travis.yml b/.travis.yml index 45b1ce3a..7f2f948c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -68,8 +68,9 @@ install: script: - pytest --collect-only # tf eager cannot be run in the same process as standard tf - - pytest --ignore=foolbox/tests/test_models_tensorflow_eager.py --ignore=foolbox/tests/test_models_caffe.py - - pytest --cov-append foolbox/tests/test_models_tensorflow_eager.py + - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pytest --ignore=foolbox/tests/models/test_models_tensorflow_eager.py --ignore=foolbox/tests/models/test_models_caffe.py --ignore=foolbox/tests/batch_attacks/; fi + - if [[ $TRAVIS_PYTHON_VERSION != 2.7 ]]; then pytest --ignore=foolbox/tests/models/test_models_tensorflow_eager.py --ignore=foolbox/tests/models/test_models_caffe.py; fi + - pytest --cov-append foolbox/tests/models/test_models_tensorflow_eager.py - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then docker run -v `pwd`:`pwd` bvlc/caffe:cpu sh -c "cd `pwd` && bash foolbox/tests/run_caffe_test.sh"; fi - flake8 --ignore E402,E741,W503,W504 . after_success: diff --git a/foolbox/__init__.py b/foolbox/__init__.py index 88923fca..e893ec91 100755 --- a/foolbox/__init__.py +++ b/foolbox/__init__.py @@ -11,7 +11,12 @@ from . import criteria # noqa: F401 from . import distances # noqa: F401 from . import attacks # noqa: F401 +from . import batch_attacks # noqa: F401 from . import utils # noqa: F401 from . import gradient_estimators # noqa: F401 from .adversarial import Adversarial # noqa: F401 +from .yielding_adversarial import YieldingAdversarial # noqa: F401 + +from .batching import run_parallel # noqa: F401 +from .batching import run_sequential # noqa: F401 diff --git a/foolbox/adversarial.py b/foolbox/adversarial.py index 8374ba1b..11a6bce8 100644 --- a/foolbox/adversarial.py +++ b/foolbox/adversarial.py @@ -48,6 +48,7 @@ class Adversarial(object): if the threshold has been reached. """ + def __init__( self, model, @@ -82,10 +83,13 @@ def __init__( self._best_gradient_calls = 0 # check if the original input is already adversarial + self._check_unperturbed() + + def _check_unperturbed(self): try: - self.forward_one(unperturbed) + self.forward_one(self.__unperturbed) except StopAttack: - # if a threshold is specified and the original input is + # if a threshold is specified and the unperturbed input is # misclassified, this can already cause a StopAttack # exception assert self.distance.value == 0. 
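Note on the new top-level API exported in the `foolbox/__init__.py` hunk above: `run_parallel` and `run_sequential` (defined in `foolbox/batching.py` later in this diff) drive several attack generators at once and batch the model calls. The following is only a minimal usage sketch based on the `run_parallel` signature in this patch; `fmodel`, `images`, and `labels` are placeholders for an already-wrapped foolbox model and matching numpy arrays, and are not part of the patch itself.

import foolbox

# Placeholders (not defined by this patch): an already-wrapped model and a
# batch of inputs with integer labels of the same length.
# fmodel = ...          # e.g. any foolbox model wrapper
# images, labels = ...  # numpy arrays, len(images) == len(labels)

criterion = foolbox.criteria.Misclassification()

# run_parallel creates one YieldingAdversarial per input and steps all attack
# generators together, batching forward/gradient/backward calls on the model.
advs = foolbox.run_parallel(
    foolbox.batch_attacks.GradientSignAttack,  # create_attack_fn
    fmodel, criterion, images, labels)

# Each entry is a YieldingAdversarial; .perturbed is None if no adversarial
# was found for that input.
perturbed = [a.perturbed for a in advs]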
@@ -98,7 +102,7 @@ def _reset(self): self._best_prediction_calls = 0 self._best_gradient_calls = 0 - self.forward_one(self.__unperturbed) + self._check_unperturbed() @property def perturbed(self): diff --git a/foolbox/attacks/base.py b/foolbox/attacks/base.py index 0a88916e..62c7f9fa 100644 --- a/foolbox/attacks/base.py +++ b/foolbox/attacks/base.py @@ -11,6 +11,7 @@ ABC = abc.ABCMeta('ABC', (), {}) from ..adversarial import Adversarial +from ..yielding_adversarial import YieldingAdversarial from ..adversarial import StopAttack from ..criteria import Misclassification from ..distances import MSE @@ -101,7 +102,10 @@ def call_decorator(call_fn): def wrapper(self, input_or_adv, label=None, unpack=True, **kwargs): assert input_or_adv is not None - if isinstance(input_or_adv, Adversarial): + if isinstance(input_or_adv, YieldingAdversarial): + raise ValueError('If you pass an Adversarial instance, it must not be a YieldingAdversarial instance' + ' when calling non-batch-supporting attacks like this one (check foolbox.batch_attacks).') + elif isinstance(input_or_adv, Adversarial): a = input_or_adv if label is not None: raise ValueError('Label must not be passed when input_or_adv' diff --git a/foolbox/attacks/carlini_wagner.py b/foolbox/attacks/carlini_wagner.py index 6c5a42a8..d76cb1f8 100644 --- a/foolbox/attacks/carlini_wagner.py +++ b/foolbox/attacks/carlini_wagner.py @@ -129,7 +129,7 @@ def to_model_space(x): binary_search_steps >= 10: # in the last binary search step, use the upper_bound instead # TODO: find out why... it's not obvious why this is useful - const = upper_bound + const = min(1e10, upper_bound) logging.info('starting optimization with const = {}'.format(const)) @@ -148,8 +148,7 @@ def to_model_space(x): const, a, x, logits, reconstructed_original, confidence, min_, max_) - logging.info('loss: {}; best overall distance: {}'.format( - loss, a.distance)) + logging.info('loss: {}; best overall distance: {}'.format(loss, a.distance)) # backprop the gradient of the loss w.r.t. x further # to get the gradient of the loss w.r.t. 
att_perturbation diff --git a/foolbox/batch_attacks/__init__.py b/foolbox/batch_attacks/__init__.py new file mode 100644 index 00000000..882583bf --- /dev/null +++ b/foolbox/batch_attacks/__init__.py @@ -0,0 +1,11 @@ +# flake8: noqa + +from .gradient import GradientAttack, GradientSignAttack, FGSM +from .carlini_wagner import CarliniWagnerL2Attack + +from .iterative_projected_gradient import LinfinityBasicIterativeAttack, BasicIterativeMethod, BIM +from .iterative_projected_gradient import L1BasicIterativeAttack +from .iterative_projected_gradient import L2BasicIterativeAttack +from .iterative_projected_gradient import ProjectedGradientDescentAttack, ProjectedGradientDescent, PGD +from .iterative_projected_gradient import RandomStartProjectedGradientDescentAttack, RandomProjectedGradientDescent, RandomPGD +from .iterative_projected_gradient import MomentumIterativeAttack, MomentumIterativeMethod diff --git a/foolbox/batch_attacks/base.py b/foolbox/batch_attacks/base.py new file mode 100644 index 00000000..63a0b90c --- /dev/null +++ b/foolbox/batch_attacks/base.py @@ -0,0 +1,71 @@ +import warnings +import logging +import functools +import numpy as np + +from ..attacks.base import Attack +from ..yielding_adversarial import YieldingAdversarial +from ..adversarial import StopAttack +from ..batching import run_parallel + + +class BatchAttack(Attack): + def __call__(self, inputs, labels, unpack=True, **kwargs): + assert isinstance(inputs, np.ndarray) + assert isinstance(labels, np.ndarray) + + if len(inputs) != len(labels): + raise ValueError('The number of inputs and labels needs to be equal') + + model = self._default_model + criterion = self._default_criterion + distance = self._default_distance + threshold = self._default_threshold + + if model is None: + raise ValueError('The attack needs to be initialized with a model') + if criterion is None: + raise ValueError('The attack needs to be initialized with a criterion') + if distance is None: + raise ValueError('The attack needs to be initialized with a distance') + + create_attack_fn = self.__class__ + advs = run_parallel(create_attack_fn, model, criterion, inputs, labels, + distance=distance, threshold=threshold, **kwargs) + + if unpack: + advs = [a.perturbed for a in advs] + advs = [p if p is not None else np.full_like(u, np.nan) for p, u in zip(advs, inputs)] + advs = np.stack(advs) + return advs + + +def generator_decorator(generator): + @functools.wraps(generator) + def wrapper(self, a, **kwargs): + assert isinstance(a, YieldingAdversarial) + + if a.distance.value == 0.: + warnings.warn('Not running the attack because the original input' + ' is already misclassified and the adversarial thus' + ' has a distance of 0.') + elif a.reached_threshold(): + warnings.warn('Not running the attack because the given treshold' + ' is already reached') + else: + try: + _ = yield from generator(self, a, **kwargs) + assert _ is None, 'decorated __call__ method must return None' + except StopAttack: + # if a threshold is specified, StopAttack will be thrown + # when the treshold is reached; thus we can do early + # stopping of the attack + logging.info('threshold reached, stopping attack') + + if a.perturbed is None: + warnings.warn('{} did not find an adversarial, maybe the model' + ' or the criterion is not supported by this' + ' attack.'.format(self.name())) + return a + + return wrapper diff --git a/foolbox/batch_attacks/carlini_wagner.py b/foolbox/batch_attacks/carlini_wagner.py new file mode 100644 index 00000000..49f78237 --- /dev/null +++ 
b/foolbox/batch_attacks/carlini_wagner.py @@ -0,0 +1,282 @@ +# -*- coding: utf-8 -*- +from __future__ import division + +import numpy as np +import logging + +from .base import BatchAttack +from .base import generator_decorator +from ..utils import onehot_like + + +class CarliniWagnerL2Attack(BatchAttack): + """The L2 version of the Carlini & Wagner attack. + + This attack is described in [1]_. This implementation + is based on the reference implementation by Carlini [2]_. + For bounds ≠ (0, 1), it differs from [2]_ because we + normalize the squared L2 loss with the bounds. + + References + ---------- + .. [1] Nicholas Carlini, David Wagner: "Towards Evaluating the + Robustness of Neural Networks", https://arxiv.org/abs/1608.04644 + .. [2] https://github.com/carlini/nn_robust_attacks + + """ + + @generator_decorator + def as_generator(self, a, + binary_search_steps=5, max_iterations=1000, + confidence=0, learning_rate=5e-3, + initial_const=1e-2, abort_early=True): + + """The L2 version of the Carlini & Wagner attack. + + Parameters + ---------- + inputs : `numpy.ndarray` + Batch of inputs with shape as expected by the underlying model. + labels : `numpy.ndarray` + Class labels of the inputs as a vector of integers in [0, number of classes). + unpack : bool + If true, returns the adversarial inputs as an array, otherwise returns Adversarial objects. + binary_search_steps : int + The number of steps for the binary search used to + find the optimal tradeoff-constant between distance and confidence. + max_iterations : int + The maximum number of iterations. Larger values are more + accurate; setting it too small will require a large learning rate + and will produce poor results. + confidence : int or float + Confidence of adversarial examples: a higher value produces + adversarials that are further away, but more strongly classified + as adversarial. + learning_rate : float + The learning rate for the attack algorithm. Smaller values + produce better results but take longer to converge. + initial_const : float + The initial tradeoff-constant to use to tune the relative + importance of distance and confidence. If `binary_search_steps` + is large, the initial constant is not important. + abort_early : bool + If True, Adam will be aborted if the loss hasn't decreased + for some time (a tenth of max_iterations). + + """ + + if not a.has_gradient(): + logging.fatal('Applied gradient-based attack to model that ' + 'does not provide gradients.') + return + + min_, max_ = a.bounds() + + def to_attack_space(x): + # map from [min_, max_] to [-1, +1] + a = (min_ + max_) / 2 + b = (max_ - min_) / 2 + x = (x - a) / b + + # from [-1, +1] to approx. (-1, +1) + x = x * 0.999999 + + # from (-1, +1) to (-inf, +inf) + return np.arctanh(x) + + def to_model_space(x): + """Transforms an input from the attack space + to the model space. 
This transformation and + the returned gradient are elementwise.""" + + # from (-inf, +inf) to (-1, +1) + x = np.tanh(x) + + grad = 1 - np.square(x) + + # map from (-1, +1) to (min_, max_) + a = (min_ + max_) / 2 + b = (max_ - min_) / 2 + x = x * b + a + + grad = grad * b + return x, grad + + # variables representing inputs in attack space will be + # prefixed with att_ + att_original = to_attack_space(a.unperturbed) + + # will be close but not identical to a.unperturbed + reconstructed_original, _ = to_model_space(att_original) + + # the binary search finds the smallest const for which we + # find an adversarial + const = initial_const + lower_bound = 0 + upper_bound = np.inf + + for binary_search_step in range(binary_search_steps): + if binary_search_step == binary_search_steps - 1 and \ + binary_search_steps >= 10: + # in the last binary search step, use the upper_bound instead + # TODO: find out why... it's not obvious why this is useful + const = min(1e10, upper_bound) + + logging.info('starting optimization with const = {}'.format(const)) + + att_perturbation = np.zeros_like(att_original) + + # create a new optimizer to minimize the perturbation + optimizer = AdamOptimizer(att_perturbation.shape) + + found_adv = False # found adv with the current const + loss_at_previous_check = np.inf + + for iteration in range(max_iterations): + x, dxdp = to_model_space(att_original + att_perturbation) + logits, is_adv = yield from a.forward_one(x) + loss, dldx = yield from self.loss_function( + const, a, x, logits, reconstructed_original, + confidence, min_, max_) + + logging.info('loss: {}; best overall distance: {}'.format(loss, a.distance)) + + # backprop the gradient of the loss w.r.t. x further + # to get the gradient of the loss w.r.t. att_perturbation + assert dldx.shape == x.shape + assert dxdp.shape == x.shape + # we can do a simple elementwise multiplication, because + # grad_x_wrt_p is a matrix of elementwise derivatives + # (i.e. each x[i] w.r.t. p[i] only, for all i) and + # grad_loss_wrt_x is a real gradient reshaped as a matrix + gradient = dldx * dxdp + + att_perturbation += optimizer(gradient, learning_rate) + + if is_adv: + # this binary search step can be considered a success + # but optimization continues to minimize perturbation size + found_adv = True + + if abort_early and \ + iteration % (np.ceil(max_iterations / 10)) == 0: + # after each tenth of the iterations, check progress + if not (loss <= .9999 * loss_at_previous_check): + break # stop Adam if there has not been progress + loss_at_previous_check = loss + + if found_adv: + logging.info('found adversarial with const = {}'.format(const)) + upper_bound = const + else: + logging.info('failed to find adversarial ' + 'with const = {}'.format(const)) + lower_bound = const + + if upper_bound == np.inf: + # exponential search + const *= 10 + else: + # binary search + const = (lower_bound + upper_bound) / 2 + + @classmethod + def loss_function(cls, const, a, x, logits, reconstructed_original, + confidence, min_, max_): + """Returns the loss and the gradient of the loss w.r.t. 
x, + assuming that logits = model(x).""" + + targeted = a.target_class() is not None + if targeted: + c_minimize = cls.best_other_class(logits, a.target_class()) + c_maximize = a.target_class() + else: + c_minimize = a.original_class + c_maximize = cls.best_other_class(logits, a.original_class) + + is_adv_loss = logits[c_minimize] - logits[c_maximize] + + # is_adv is True as soon as the is_adv_loss goes below 0 + # but sometimes we want additional confidence + is_adv_loss += confidence + is_adv_loss = max(0, is_adv_loss) + + s = max_ - min_ + squared_l2_distance = np.sum((x - reconstructed_original)**2) / s**2 + total_loss = squared_l2_distance + const * is_adv_loss + + # calculate the gradient of total_loss w.r.t. x + logits_diff_grad = np.zeros_like(logits) + logits_diff_grad[c_minimize] = 1 + logits_diff_grad[c_maximize] = -1 + is_adv_loss_grad = yield from a.backward_one(logits_diff_grad, x) + assert is_adv_loss >= 0 + if is_adv_loss == 0: + is_adv_loss_grad = 0 + + squared_l2_distance_grad = (2 / s**2) * (x - reconstructed_original) + + total_loss_grad = squared_l2_distance_grad + const * is_adv_loss_grad + return total_loss, total_loss_grad + + @staticmethod + def best_other_class(logits, exclude): + """Returns the index of the largest logit, ignoring the class that + is passed as `exclude`.""" + other_logits = logits - onehot_like(logits, exclude, value=np.inf) + return np.argmax(other_logits) + + +CarliniWagnerL2Attack.__call__.__doc__ = CarliniWagnerL2Attack.as_generator.__doc__ + + +class AdamOptimizer: + """Basic Adam optimizer implementation that can minimize w.r.t. + a single variable. + + Parameters + ---------- + shape : tuple + shape of the variable w.r.t. which the loss should be minimized + + """ + + def __init__(self, shape): + self.m = np.zeros(shape) + self.v = np.zeros(shape) + self.t = 0 + + def __call__(self, gradient, learning_rate, + beta1=0.9, beta2=0.999, epsilon=10e-8): + """Updates internal parameters of the optimizer and returns + the change that should be applied to the variable. + + Parameters + ---------- + gradient : `np.ndarray` + the gradient of the loss w.r.t. 
to the variable + learning_rate: float + the learning rate in the current iteration + beta1: float + decay rate for calculating the exponentially + decaying average of past gradients + beta2: float + decay rate for calculating the exponentially + decaying average of past squared gradients + epsilon: float + small value to avoid division by zero + + """ + + self.t += 1 + + self.m = beta1 * self.m + (1 - beta1) * gradient + self.v = beta2 * self.v + (1 - beta2) * gradient**2 + + bias_correction_1 = 1 - beta1**self.t + bias_correction_2 = 1 - beta2**self.t + + m_hat = self.m / bias_correction_1 + v_hat = self.v / bias_correction_2 + + return -learning_rate * m_hat / (np.sqrt(v_hat) + epsilon) diff --git a/foolbox/batch_attacks/gradient.py b/foolbox/batch_attacks/gradient.py new file mode 100644 index 00000000..57bc23a2 --- /dev/null +++ b/foolbox/batch_attacks/gradient.py @@ -0,0 +1,141 @@ +from __future__ import division +import numpy as np +from collections import Iterable +import logging +import abc + +from .base import BatchAttack +from .base import generator_decorator + + +class SingleStepGradientBaseAttack(BatchAttack): + """Common base class for single step gradient attacks.""" + + @abc.abstractmethod + def _gradient(self, a): + raise NotImplementedError + + def _run(self, a, epsilons, max_epsilon): + if not a.has_gradient(): + return + + x = a.unperturbed + min_, max_ = a.bounds() + + gradient = yield from self._gradient(a) + + if not isinstance(epsilons, Iterable): + epsilons = np.linspace(0, max_epsilon, num=epsilons + 1)[1:] + decrease_if_first = True + else: + decrease_if_first = False + + for _ in range(2): # to repeat with decreased epsilons if necessary + for i, epsilon in enumerate(epsilons): + perturbed = x + gradient * epsilon + perturbed = np.clip(perturbed, min_, max_) + + _, is_adversarial = yield from a.forward_one(perturbed) + if is_adversarial: + if decrease_if_first and i < 20: + logging.info('repeating attack with smaller epsilons') + break + return + + max_epsilon = epsilons[i] + epsilons = np.linspace(0, max_epsilon, num=20 + 1)[1:] + + +class GradientAttack(SingleStepGradientBaseAttack): + """Perturbs the input with the gradient of the loss w.r.t. the input, + gradually increasing the magnitude until the input is misclassified. + + Does not do anything if the model does not have a gradient. + + """ + + @generator_decorator + def as_generator(self, a, epsilons=1000, max_epsilon=1): + """Perturbs the input with the gradient of the loss w.r.t. the input, + gradually increasing the magnitude until the input is misclassified. + + Parameters + ---------- + inputs : `numpy.ndarray` + Batch of inputs with shape as expected by the underlying model. + labels : `numpy.ndarray` + Class labels of the inputs as a vector of integers in [0, number of classes). + unpack : bool + If true, returns the adversarial inputs as an array, otherwise returns Adversarial objects. + epsilons : int or Iterable[float] + Either Iterable of step sizes in the gradient direction + or number of step sizes between 0 and max_epsilon that should + be tried. + max_epsilon : float + Largest step size if epsilons is not an iterable. 
+ + """ + + yield from self._run(a, epsilons=epsilons, max_epsilon=max_epsilon) + + def _gradient(self, a): + min_, max_ = a.bounds() + gradient = yield from a.gradient_one() + gradient_norm = np.sqrt(np.mean(np.square(gradient))) + gradient = gradient / (gradient_norm + 1e-8) * (max_ - min_) + return gradient + + +GradientAttack.__call__.__doc__ = GradientAttack.as_generator.__doc__ + + +class GradientSignAttack(SingleStepGradientBaseAttack): + """Adds the sign of the gradient to the input, gradually increasing + the magnitude until the input is misclassified. This attack is + often referred to as Fast Gradient Sign Method and was introduced + in [1]_. + + Does not do anything if the model does not have a gradient. + + References + ---------- + .. [1] Ian J. Goodfellow, Jonathon Shlens, Christian Szegedy, + "Explaining and Harnessing Adversarial Examples", + https://arxiv.org/abs/1412.6572 + """ + + @generator_decorator + def as_generator(self, a, epsilons=1000, max_epsilon=1): + """Adds the sign of the gradient to the input, gradually increasing + the magnitude until the input is misclassified. + + Parameters + ---------- + inputs : `numpy.ndarray` + Batch of inputs with shape as expected by the underlying model. + labels : `numpy.ndarray` + Class labels of the inputs as a vector of integers in [0, number of classes). + unpack : bool + If true, returns the adversarial inputs as an array, otherwise returns Adversarial objects. + epsilons : int or Iterable[float] + Either Iterable of step sizes in the direction of the sign of + the gradient or number of step sizes between 0 and max_epsilon + that should be tried. + max_epsilon : float + Largest step size if epsilons is not an iterable. + + """ + + yield from self._run(a, epsilons=epsilons, max_epsilon=max_epsilon) + + def _gradient(self, a): + min_, max_ = a.bounds() + gradient = yield from a.gradient_one() + gradient = np.sign(gradient) * (max_ - min_) + return gradient + + +GradientSignAttack.__call__.__doc__ = GradientSignAttack.as_generator.__doc__ + + +FGSM = GradientSignAttack diff --git a/foolbox/batch_attacks/iterative_projected_gradient.py b/foolbox/batch_attacks/iterative_projected_gradient.py new file mode 100644 index 00000000..ec7f6d9c --- /dev/null +++ b/foolbox/batch_attacks/iterative_projected_gradient.py @@ -0,0 +1,732 @@ +from __future__ import division +import numpy as np +from abc import abstractmethod +import logging +import warnings + +from .base import BatchAttack +from .base import generator_decorator +from .. import distances +from ..utils import crossentropy +from .. import nprng + + +class IterativeProjectedGradientBaseAttack(BatchAttack): + """Base class for iterative (projected) gradient attacks. + + Concrete subclasses should implement as_generator, _gradient + and _clip_perturbation. + + TODO: add support for other loss-functions, e.g. 
the CW loss function, + see https://github.com/MadryLab/mnist_challenge/blob/master/pgd_attack.py + """ + + @abstractmethod + def _gradient(self, a, x, class_, strict=True): + raise NotImplementedError + + @abstractmethod + def _clip_perturbation(self, a, noise, epsilon): + raise NotImplementedError + + @abstractmethod + def _check_distance(self, a): + raise NotImplementedError + + def _get_mode_and_class(self, a): + # determine if the attack is targeted or not + target_class = a.target_class() + targeted = target_class is not None + + if targeted: + class_ = target_class + else: + class_ = a.original_class + return targeted, class_ + + def _run(self, a, binary_search, + epsilon, stepsize, iterations, + random_start, return_early): + if not a.has_gradient(): + warnings.warn('applied gradient-based attack to model that' + ' does not provide gradients') + return + + self._check_distance(a) + + targeted, class_ = self._get_mode_and_class(a) + + if binary_search: + if isinstance(binary_search, bool): + k = 20 + else: + k = int(binary_search) + yield from self._run_binary_search( + a, epsilon, stepsize, iterations, + random_start, targeted, class_, return_early, k=k) + return + else: + success = yield from self._run_one( + a, epsilon, stepsize, iterations, + random_start, targeted, class_, return_early) + return success + + def _run_binary_search(self, a, epsilon, stepsize, iterations, + random_start, targeted, class_, return_early, k): + + factor = stepsize / epsilon + + def try_epsilon(epsilon): + stepsize = factor * epsilon + success = yield from self._run_one( + a, epsilon, stepsize, iterations, + random_start, targeted, class_, return_early) + return success + + for i in range(k): + success = yield from try_epsilon(epsilon) + if success: + logging.info('successful for eps = {}'.format(epsilon)) + break + logging.info('not successful for eps = {}'.format(epsilon)) + epsilon = epsilon * 1.5 + else: + logging.warning('exponential search failed') + return + + bad = 0 + good = epsilon + + for i in range(k): + epsilon = (good + bad) / 2 + success = yield from try_epsilon(epsilon) + if success: + good = epsilon + logging.info('successful for eps = {}'.format(epsilon)) + else: + bad = epsilon + logging.info('not successful for eps = {}'.format(epsilon)) + + def _run_one(self, a, epsilon, stepsize, iterations, + random_start, targeted, class_, return_early): + min_, max_ = a.bounds() + s = max_ - min_ + + original = a.unperturbed.copy() + + if random_start: + # using uniform noise even if the perturbation clipping uses + # a different norm because cleverhans does it the same way + noise = nprng.uniform( + -epsilon * s, epsilon * s, original.shape).astype( + original.dtype) + x = original + self._clip_perturbation(a, noise, epsilon) + strict = False # because we don't enforce the bounds here + else: + x = original + strict = True + + success = False + for _ in range(iterations): + gradient = yield from self._gradient(a, x, class_, strict=strict) + # non-strict only for the first call and + # only if random_start is True + strict = True + if targeted: + gradient = -gradient + + # untargeted: gradient ascent on cross-entropy to original class + # targeted: gradient descent on cross-entropy to target class + x = x + stepsize * gradient + + x = original + self._clip_perturbation(a, x - original, epsilon) + + x = np.clip(x, min_, max_) + + logits, is_adversarial = yield from a.forward_one(x) + if logging.getLogger().isEnabledFor(logging.DEBUG): + if targeted: + ce = crossentropy(a.original_class, logits) 
+ logging.debug('crossentropy to {} is {}'.format( + a.original_class, ce)) + ce = crossentropy(class_, logits) + logging.debug('crossentropy to {} is {}'.format(class_, ce)) + if is_adversarial: + if return_early: + return True + else: + success = True + return success + + +class LinfinityGradientMixin(object): + def _gradient(self, a, x, class_, strict=True): + gradient = yield from a.gradient_one(x, class_, strict=strict) + gradient = np.sign(gradient) + min_, max_ = a.bounds() + gradient = (max_ - min_) * gradient + return gradient + + +class L1GradientMixin(object): + def _gradient(self, a, x, class_, strict=True): + gradient = yield from a.gradient_one(x, class_, strict=strict) + # using mean to make range of epsilons comparable to Linf + gradient = gradient / np.mean(np.abs(gradient)) + min_, max_ = a.bounds() + gradient = (max_ - min_) * gradient + return gradient + + +class L2GradientMixin(object): + def _gradient(self, a, x, class_, strict=True): + gradient = yield from a.gradient_one(x, class_, strict=strict) + # using mean to make range of epsilons comparable to Linf + gradient = gradient / np.sqrt(np.mean(np.square(gradient))) + min_, max_ = a.bounds() + gradient = (max_ - min_) * gradient + return gradient + + +class LinfinityClippingMixin(object): + def _clip_perturbation(self, a, perturbation, epsilon): + min_, max_ = a.bounds() + s = max_ - min_ + clipped = np.clip(perturbation, -epsilon * s, epsilon * s) + return clipped + + +class L1ClippingMixin(object): + def _clip_perturbation(self, a, perturbation, epsilon): + # using mean to make range of epsilons comparable to Linf + norm = np.mean(np.abs(perturbation)) + norm = max(1e-12, norm) # avoid divsion by zero + min_, max_ = a.bounds() + s = max_ - min_ + # clipping, i.e. only decreasing norm + factor = min(1, epsilon * s / norm) + return perturbation * factor + + +class L2ClippingMixin(object): + def _clip_perturbation(self, a, perturbation, epsilon): + # using mean to make range of epsilons comparable to Linf + norm = np.sqrt(np.mean(np.square(perturbation))) + norm = max(1e-12, norm) # avoid divsion by zero + min_, max_ = a.bounds() + s = max_ - min_ + # clipping, i.e. only decreasing norm + factor = min(1, epsilon * s / norm) + return perturbation * factor + + +class LinfinityDistanceCheckMixin(object): + def _check_distance(self, a): + if not isinstance(a.distance, distances.Linfinity): + logging.warning('Running an attack that tries to minimize the' + ' Linfinity norm of the perturbation without' + ' specifying foolbox.distances.Linfinity as' + ' the distance metric might lead to suboptimal' + ' results.') + + +class L1DistanceCheckMixin(object): + def _check_distance(self, a): + if not isinstance(a.distance, distances.MAE): + logging.warning('Running an attack that tries to minimize the' + ' L1 norm of the perturbation without' + ' specifying foolbox.distances.MAE as' + ' the distance metric might lead to suboptimal' + ' results.') + + +class L2DistanceCheckMixin(object): + def _check_distance(self, a): + if not isinstance(a.distance, distances.MSE): + logging.warning('Running an attack that tries to minimize the' + ' L2 norm of the perturbation without' + ' specifying foolbox.distances.MSE as' + ' the distance metric might lead to suboptimal' + ' results.') + + +class LinfinityBasicIterativeAttack( + LinfinityGradientMixin, + LinfinityClippingMixin, + LinfinityDistanceCheckMixin, + IterativeProjectedGradientBaseAttack): + + """The Basic Iterative Method introduced in [1]_. 
+ + This attack is also known as Projected Gradient + Descent (PGD) (without random start) or FGMS^k. + + References + ---------- + .. [1] Alexey Kurakin, Ian Goodfellow, Samy Bengio, + "Adversarial examples in the physical world", + https://arxiv.org/abs/1607.02533 + + .. seealso:: :class:`ProjectedGradientDescentAttack` + + """ + + @generator_decorator + def as_generator(self, a, + binary_search=True, + epsilon=0.3, + stepsize=0.05, + iterations=10, + random_start=False, + return_early=True): + + """Simple iterative gradient-based attack known as + Basic Iterative Method, Projected Gradient Descent or FGSM^k. + + Parameters + ---------- + inputs : `numpy.ndarray` + Batch of inputs with shape as expected by the underlying model. + labels : `numpy.ndarray` + Class labels of the inputs as a vector of integers in [0, number of classes). + unpack : bool + If true, returns the adversarial inputs as an array, otherwise returns Adversarial objects. + binary_search : bool or int + Whether to perform a binary search over epsilon and stepsize, + keeping their ratio constant and using their values to start + the search. If False, hyperparameters are not optimized. + Can also be an integer, specifying the number of binary + search steps (default 20). + epsilon : float + Limit on the perturbation size; if binary_search is True, + this value is only for initialization and automatically + adapted. + stepsize : float + Step size for gradient descent; if binary_search is True, + this value is only for initialization and automatically + adapted. + iterations : int + Number of iterations for each gradient descent run. + random_start : bool + Start the attack from a random point rather than from the + original input. + return_early : bool + Whether an individual gradient descent run should stop as + soon as an adversarial is found. + """ + + assert epsilon > 0 + + yield from self._run(a, binary_search, + epsilon, stepsize, iterations, + random_start, return_early) + + +LinfinityBasicIterativeAttack.__call__.__doc__ = LinfinityBasicIterativeAttack.as_generator.__doc__ + + +BasicIterativeMethod = LinfinityBasicIterativeAttack +BIM = BasicIterativeMethod + + +class L1BasicIterativeAttack( + L1GradientMixin, + L1ClippingMixin, + L1DistanceCheckMixin, + IterativeProjectedGradientBaseAttack): + + """Modified version of the Basic Iterative Method + that minimizes the L1 distance. + + .. seealso:: :class:`LinfinityBasicIterativeAttack` + + """ + + @generator_decorator + def as_generator(self, a, + binary_search=True, + epsilon=0.3, + stepsize=0.05, + iterations=10, + random_start=False, + return_early=True): + + """Simple iterative gradient-based attack known as + Basic Iterative Method, Projected Gradient Descent or FGSM^k. + + Parameters + ---------- + inputs : `numpy.ndarray` + Batch of inputs with shape as expected by the underlying model. + labels : `numpy.ndarray` + Class labels of the inputs as a vector of integers in [0, number of classes). + unpack : bool + If true, returns the adversarial inputs as an array, otherwise returns Adversarial objects. + binary_search : bool or int + Whether to perform a binary search over epsilon and stepsize, + keeping their ratio constant and using their values to start + the search. If False, hyperparameters are not optimized. + Can also be an integer, specifying the number of binary + search steps (default 20). + epsilon : float + Limit on the perturbation size; if binary_search is True, + this value is only for initialization and automatically + adapted. 
+ stepsize : float + Step size for gradient descent; if binary_search is True, + this value is only for initialization and automatically + adapted. + iterations : int + Number of iterations for each gradient descent run. + random_start : bool + Start the attack from a random point rather than from the + original input. + return_early : bool + Whether an individual gradient descent run should stop as + soon as an adversarial is found. + """ + + assert epsilon > 0 + + yield from self._run(a, binary_search, + epsilon, stepsize, iterations, + random_start, return_early) + + +L1BasicIterativeAttack.__call__.__doc__ = L1BasicIterativeAttack.as_generator.__doc__ + + +class L2BasicIterativeAttack( + L2GradientMixin, + L2ClippingMixin, + L2DistanceCheckMixin, + IterativeProjectedGradientBaseAttack): + + """Modified version of the Basic Iterative Method + that minimizes the L2 distance. + + .. seealso:: :class:`LinfinityBasicIterativeAttack` + + """ + + @generator_decorator + def as_generator(self, a, + binary_search=True, + epsilon=0.3, + stepsize=0.05, + iterations=10, + random_start=False, + return_early=True): + + """Simple iterative gradient-based attack known as + Basic Iterative Method, Projected Gradient Descent or FGSM^k. + + Parameters + ---------- + inputs : `numpy.ndarray` + Batch of inputs with shape as expected by the underlying model. + labels : `numpy.ndarray` + Class labels of the inputs as a vector of integers in [0, number of classes). + unpack : bool + If true, returns the adversarial inputs as an array, otherwise returns Adversarial objects. + binary_search : bool or int + Whether to perform a binary search over epsilon and stepsize, + keeping their ratio constant and using their values to start + the search. If False, hyperparameters are not optimized. + Can also be an integer, specifying the number of binary + search steps (default 20). + epsilon : float + Limit on the perturbation size; if binary_search is True, + this value is only for initialization and automatically + adapted. + stepsize : float + Step size for gradient descent; if binary_search is True, + this value is only for initialization and automatically + adapted. + iterations : int + Number of iterations for each gradient descent run. + random_start : bool + Start the attack from a random point rather than from the + original input. + return_early : bool + Whether an individual gradient descent run should stop as + soon as an adversarial is found. + """ + + assert epsilon > 0 + + yield from self._run(a, binary_search, + epsilon, stepsize, iterations, + random_start, return_early) + + +L2BasicIterativeAttack.__call__.__doc__ = L2BasicIterativeAttack.as_generator.__doc__ + + +class ProjectedGradientDescentAttack( + LinfinityGradientMixin, + LinfinityClippingMixin, + LinfinityDistanceCheckMixin, + IterativeProjectedGradientBaseAttack): + + """The Projected Gradient Descent Attack + introduced in [1]_ without random start. + + When used without a random start, this attack + is also known as Basic Iterative Method (BIM) + or FGSM^k. + + References + ---------- + .. [1] Aleksander Madry, Aleksandar Makelov, Ludwig Schmidt, + Dimitris Tsipras, Adrian Vladu, "Towards Deep Learning + Models Resistant to Adversarial Attacks", + https://arxiv.org/abs/1706.06083 + + .. 
seealso:: + + :class:`LinfinityBasicIterativeAttack` and + :class:`RandomStartProjectedGradientDescentAttack` + + """ + + @generator_decorator + def as_generator(self, a, + binary_search=True, + epsilon=0.3, + stepsize=0.01, + iterations=40, + random_start=False, + return_early=True): + + """Simple iterative gradient-based attack known as + Basic Iterative Method, Projected Gradient Descent or FGSM^k. + + Parameters + ---------- + inputs : `numpy.ndarray` + Batch of inputs with shape as expected by the underlying model. + labels : `numpy.ndarray` + Class labels of the inputs as a vector of integers in [0, number of classes). + unpack : bool + If true, returns the adversarial inputs as an array, otherwise returns Adversarial objects. + binary_search : bool or int + Whether to perform a binary search over epsilon and stepsize, + keeping their ratio constant and using their values to start + the search. If False, hyperparameters are not optimized. + Can also be an integer, specifying the number of binary + search steps (default 20). + epsilon : float + Limit on the perturbation size; if binary_search is True, + this value is only for initialization and automatically + adapted. + stepsize : float + Step size for gradient descent; if binary_search is True, + this value is only for initialization and automatically + adapted. + iterations : int + Number of iterations for each gradient descent run. + random_start : bool + Start the attack from a random point rather than from the + original input. + return_early : bool + Whether an individual gradient descent run should stop as + soon as an adversarial is found. + """ + + assert epsilon > 0 + + yield from self._run(a, binary_search, + epsilon, stepsize, iterations, + random_start, return_early) + + +ProjectedGradientDescentAttack.__call__.__doc__ = ProjectedGradientDescentAttack.as_generator.__doc__ + + +ProjectedGradientDescent = ProjectedGradientDescentAttack +PGD = ProjectedGradientDescent + + +class RandomStartProjectedGradientDescentAttack( + LinfinityGradientMixin, + LinfinityClippingMixin, + LinfinityDistanceCheckMixin, + IterativeProjectedGradientBaseAttack): + + """The Projected Gradient Descent Attack + introduced in [1]_ with random start. + + References + ---------- + .. [1] Aleksander Madry, Aleksandar Makelov, Ludwig Schmidt, + Dimitris Tsipras, Adrian Vladu, "Towards Deep Learning + Models Resistant to Adversarial Attacks", + https://arxiv.org/abs/1706.06083 + + .. seealso:: :class:`ProjectedGradientDescentAttack` + + """ + + @generator_decorator + def as_generator(self, a, + binary_search=True, + epsilon=0.3, + stepsize=0.01, + iterations=40, + random_start=True, + return_early=True): + + """Simple iterative gradient-based attack known as + Basic Iterative Method, Projected Gradient Descent or FGSM^k. + + Parameters + ---------- + inputs : `numpy.ndarray` + Batch of inputs with shape as expected by the underlying model. + labels : `numpy.ndarray` + Class labels of the inputs as a vector of integers in [0, number of classes). + unpack : bool + If true, returns the adversarial inputs as an array, otherwise returns Adversarial objects. + binary_search : bool or int + Whether to perform a binary search over epsilon and stepsize, + keeping their ratio constant and using their values to start + the search. If False, hyperparameters are not optimized. + Can also be an integer, specifying the number of binary + search steps (default 20). 
+ epsilon : float + Limit on the perturbation size; if binary_search is True, + this value is only for initialization and automatically + adapted. + stepsize : float + Step size for gradient descent; if binary_search is True, + this value is only for initialization and automatically + adapted. + iterations : int + Number of iterations for each gradient descent run. + random_start : bool + Start the attack from a random point rather than from the + original input. + return_early : bool + Whether an individual gradient descent run should stop as + soon as an adversarial is found. + """ + + assert epsilon > 0 + + yield from self._run(a, binary_search, + epsilon, stepsize, iterations, + random_start, return_early) + + +RandomStartProjectedGradientDescentAttack.__call__.__doc__ = \ + RandomStartProjectedGradientDescentAttack.as_generator.__doc__ + + +RandomProjectedGradientDescent = RandomStartProjectedGradientDescentAttack +RandomPGD = RandomProjectedGradientDescent + + +class MomentumIterativeAttack( + LinfinityClippingMixin, + LinfinityDistanceCheckMixin, + IterativeProjectedGradientBaseAttack): + + """The Momentum Iterative Method attack + introduced in [1]_. It's like the Basic + Iterative Method or Projected Gradient + Descent except that it uses momentum. + + References + ---------- + .. [1] Yinpeng Dong, Fangzhou Liao, Tianyu Pang, Hang Su, + Jun Zhu, Xiaolin Hu, Jianguo Li, "Boosting Adversarial + Attacks with Momentum", + https://arxiv.org/abs/1710.06081 + + """ + + def _gradient(self, a, x, class_, strict=True): + # get current gradient + gradient = yield from a.gradient_one(x, class_, strict=strict) + gradient = gradient / max(1e-12, np.mean(np.abs(gradient))) + + # combine with history of gradient as new history + self._momentum_history = \ + self._decay_factor * self._momentum_history + gradient + + # use history + gradient = self._momentum_history + gradient = np.sign(gradient) + min_, max_ = a.bounds() + gradient = (max_ - min_) * gradient + return gradient + + def _run_one(self, *args, **kwargs): + # reset momentum history every time we restart + # gradient descent + self._momentum_history = 0 + success = yield from super(MomentumIterativeAttack, self)._run_one( + *args, **kwargs) + return success + + @generator_decorator + def as_generator(self, a, + binary_search=True, + epsilon=0.3, + stepsize=0.06, + iterations=10, + decay_factor=1.0, + random_start=False, + return_early=True): + + """Momentum-based iterative gradient attack known as + Momentum Iterative Method. + + Parameters + ---------- + inputs : `numpy.ndarray` + Batch of inputs with shape as expected by the underlying model. + labels : `numpy.ndarray` + Class labels of the inputs as a vector of integers in [0, number of classes). + unpack : bool + If true, returns the adversarial inputs as an array, otherwise returns Adversarial objects. + binary_search : bool + Whether to perform a binary search over epsilon and stepsize, + keeping their ratio constant and using their values to start + the search. If False, hyperparameters are not optimized. + Can also be an integer, specifying the number of binary + search steps (default 20). + epsilon : float + Limit on the perturbation size; if binary_search is True, + this value is only for initialization and automatically + adapted. + stepsize : float + Step size for gradient descent; if binary_search is True, + this value is only for initialization and automatically + adapted. + iterations : int + Number of iterations for each gradient descent run. 
+ decay_factor : float + Decay factor used by the momentum term. + random_start : bool + Start the attack from a random point rather than from the + original input. + return_early : bool + Whether an individual gradient descent run should stop as + soon as an adversarial is found. + """ + + assert epsilon > 0 + + self._decay_factor = decay_factor + + yield from self._run(a, binary_search, + epsilon, stepsize, iterations, + random_start, return_early) + + +MomentumIterativeAttack.__call__.__doc__ = MomentumIterativeAttack.as_generator.__doc__ + + +MomentumIterativeMethod = MomentumIterativeAttack diff --git a/foolbox/batching.py b/foolbox/batching.py new file mode 100644 index 00000000..c3aefa38 --- /dev/null +++ b/foolbox/batching.py @@ -0,0 +1,108 @@ +import logging +import numpy as np +import itertools +from .distances import MSE +from .yielding_adversarial import YieldingAdversarial + + +def run_sequential(create_attack_fn, model, criterion, inputs, labels, + distance=MSE, threshold=None, verbose=False, **kwargs): + advs = [YieldingAdversarial(model, criterion, x, label, + distance=distance, threshold=threshold, verbose=verbose) + for x, label in zip(inputs, labels)] + attacks = [create_attack_fn().as_generator(adv, **kwargs) for adv in advs] + + supported_methods = { + 'forward_one': model.forward_one, + 'gradient_one': model.gradient_one, + 'backward_one': model.backward_one, + 'forward_and_gradient_one': model.forward_and_gradient_one, + } + + for i, attack in enumerate(attacks): + result = None + while True: + try: + x = attack.send(result) + except StopIteration: + break + method, args = x[0], x[1:] + method = supported_methods[method] + result = method(*args) + assert result is not None + logging.info('{} of {} attacks completed'.format(i + 1, len(advs))) + return advs + + +def run_parallel(create_attack_fn, model, criterion, inputs, labels, + distance=MSE, threshold=None, verbose=False, **kwargs): + advs = [YieldingAdversarial(model, criterion, x, label, + distance=distance, threshold=threshold, verbose=verbose) + for x, label in zip(inputs, labels)] + attacks = [create_attack_fn().as_generator(adv, **kwargs) for adv in advs] + + predictions = [None for _ in attacks] + gradients = [] + backwards = [] + results = itertools.chain(predictions, gradients, backwards) + + while True: + attacks_requesting_predictions = [] + predictions_args = [] + attacks_requesting_gradients = [] + gradients_args = [] + attacks_requesting_backwards = [] + backwards_args = [] + for attack, result in zip(attacks, results): + try: + x = attack.send(result) + except StopIteration: + continue + method, args = x[0], x[1:] + if method == 'forward_one': + attacks_requesting_predictions.append(attack) + predictions_args.append(args) + elif method == 'gradient_one': + attacks_requesting_gradients.append(attack) + gradients_args.append(args) + elif method == 'backward_one': + attacks_requesting_backwards.append(attack) + backwards_args.append(args) + elif method == 'forward_and_gradient_one': + raise NotImplementedError('batching support for forward_and_gradient_one' + ' not yet implemented; please open an issue') + else: + assert False + N_active_attacks = len(attacks_requesting_predictions) \ + + len(attacks_requesting_gradients) \ + + len(attacks_requesting_backwards) + if N_active_attacks < len(predictions) + len(gradients) + len(backwards): # noqa: E501 + # an attack completed in this iteration + logging.info('{} of {} attacks completed'.format(len(advs) - N_active_attacks, len(advs))) # noqa: E501 + if 
N_active_attacks == 0: + break + + if len(attacks_requesting_predictions) > 0: + logging.debug('calling forward with', len(attacks_requesting_predictions)) # noqa: E501 + predictions_args = map(np.stack, zip(*predictions_args)) + predictions = model.forward(*predictions_args) + else: + predictions = [] + + if len(attacks_requesting_gradients) > 0: + logging.debug('calling gradient with', len(attacks_requesting_gradients)) # noqa: E501 + gradients_args = map(np.stack, zip(*gradients_args)) + gradients = model.gradient(*gradients_args) + else: + gradients = [] + + if len(attacks_requesting_backwards) > 0: + logging.debug('calling backward with', len(attacks_requesting_backwards)) # noqa: E501 + backwards_args = map(np.stack, zip(*backwards_args)) + backwards = model.backward(*backwards_args) + else: + backwards = [] + + attacks = itertools.chain(attacks_requesting_predictions, attacks_requesting_gradients, attacks_requesting_backwards) # noqa: E501 + results = itertools.chain(predictions, gradients, backwards) + return advs diff --git a/foolbox/models/base.py b/foolbox/models/base.py index 77b97343..5e5a56b5 100644 --- a/foolbox/models/base.py +++ b/foolbox/models/base.py @@ -302,4 +302,4 @@ def forward_and_gradient_one(self, x, label): :meth:`gradient_one` """ - return self.forward_one(x), self.gradient_one(x, label) + return self.forward_one(x), self.gradient_one(x, label) # pragma: no cover diff --git a/foolbox/tests/attacks/__init__.py b/foolbox/tests/attacks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/foolbox/tests/test_attacks_adef_attack.py b/foolbox/tests/attacks/test_attacks_adef_attack.py similarity index 100% rename from foolbox/tests/test_attacks_adef_attack.py rename to foolbox/tests/attacks/test_attacks_adef_attack.py diff --git a/foolbox/tests/test_attacks_approx_lbfgs.py b/foolbox/tests/attacks/test_attacks_approx_lbfgs.py similarity index 100% rename from foolbox/tests/test_attacks_approx_lbfgs.py rename to foolbox/tests/attacks/test_attacks_approx_lbfgs.py diff --git a/foolbox/tests/test_attacks_bapp.py b/foolbox/tests/attacks/test_attacks_bapp.py similarity index 100% rename from foolbox/tests/test_attacks_bapp.py rename to foolbox/tests/attacks/test_attacks_bapp.py diff --git a/foolbox/tests/test_attacks_binarization.py b/foolbox/tests/attacks/test_attacks_binarization.py similarity index 100% rename from foolbox/tests/test_attacks_binarization.py rename to foolbox/tests/attacks/test_attacks_binarization.py diff --git a/foolbox/tests/test_attacks_blur.py b/foolbox/tests/attacks/test_attacks_blur.py similarity index 100% rename from foolbox/tests/test_attacks_blur.py rename to foolbox/tests/attacks/test_attacks_blur.py diff --git a/foolbox/tests/test_attacks_boundary.py b/foolbox/tests/attacks/test_attacks_boundary.py similarity index 100% rename from foolbox/tests/test_attacks_boundary.py rename to foolbox/tests/attacks/test_attacks_boundary.py diff --git a/foolbox/tests/test_attacks_carlini_wagner.py b/foolbox/tests/attacks/test_attacks_carlini_wagner.py similarity index 100% rename from foolbox/tests/test_attacks_carlini_wagner.py rename to foolbox/tests/attacks/test_attacks_carlini_wagner.py diff --git a/foolbox/tests/test_attacks_contrast.py b/foolbox/tests/attacks/test_attacks_contrast.py similarity index 100% rename from foolbox/tests/test_attacks_contrast.py rename to foolbox/tests/attacks/test_attacks_contrast.py diff --git a/foolbox/tests/test_attacks_decoupled_direction_norm.py 
b/foolbox/tests/attacks/test_attacks_decoupled_direction_norm.py similarity index 100% rename from foolbox/tests/test_attacks_decoupled_direction_norm.py rename to foolbox/tests/attacks/test_attacks_decoupled_direction_norm.py diff --git a/foolbox/tests/test_attacks_deepfool.py b/foolbox/tests/attacks/test_attacks_deepfool.py similarity index 100% rename from foolbox/tests/test_attacks_deepfool.py rename to foolbox/tests/attacks/test_attacks_deepfool.py diff --git a/foolbox/tests/test_attacks_gradient.py b/foolbox/tests/attacks/test_attacks_gradient.py similarity index 100% rename from foolbox/tests/test_attacks_gradient.py rename to foolbox/tests/attacks/test_attacks_gradient.py diff --git a/foolbox/tests/test_attacks_gradient_sign.py b/foolbox/tests/attacks/test_attacks_gradient_sign.py similarity index 100% rename from foolbox/tests/test_attacks_gradient_sign.py rename to foolbox/tests/attacks/test_attacks_gradient_sign.py diff --git a/foolbox/tests/test_attacks_iterative_gradient.py b/foolbox/tests/attacks/test_attacks_iterative_gradient.py similarity index 100% rename from foolbox/tests/test_attacks_iterative_gradient.py rename to foolbox/tests/attacks/test_attacks_iterative_gradient.py diff --git a/foolbox/tests/test_attacks_iterative_gradient_sign.py b/foolbox/tests/attacks/test_attacks_iterative_gradient_sign.py similarity index 100% rename from foolbox/tests/test_attacks_iterative_gradient_sign.py rename to foolbox/tests/attacks/test_attacks_iterative_gradient_sign.py diff --git a/foolbox/tests/test_attacks_iterative_projected_gradient.py b/foolbox/tests/attacks/test_attacks_iterative_projected_gradient.py similarity index 100% rename from foolbox/tests/test_attacks_iterative_projected_gradient.py rename to foolbox/tests/attacks/test_attacks_iterative_projected_gradient.py diff --git a/foolbox/tests/test_attacks_lbfgs.py b/foolbox/tests/attacks/test_attacks_lbfgs.py similarity index 100% rename from foolbox/tests/test_attacks_lbfgs.py rename to foolbox/tests/attacks/test_attacks_lbfgs.py diff --git a/foolbox/tests/test_attacks_localsearch.py b/foolbox/tests/attacks/test_attacks_localsearch.py similarity index 100% rename from foolbox/tests/test_attacks_localsearch.py rename to foolbox/tests/attacks/test_attacks_localsearch.py diff --git a/foolbox/tests/test_attacks_newtonfool.py b/foolbox/tests/attacks/test_attacks_newtonfool.py similarity index 100% rename from foolbox/tests/test_attacks_newtonfool.py rename to foolbox/tests/attacks/test_attacks_newtonfool.py diff --git a/foolbox/tests/test_attacks_noise.py b/foolbox/tests/attacks/test_attacks_noise.py similarity index 100% rename from foolbox/tests/test_attacks_noise.py rename to foolbox/tests/attacks/test_attacks_noise.py diff --git a/foolbox/tests/test_attacks_pointwise.py b/foolbox/tests/attacks/test_attacks_pointwise.py similarity index 100% rename from foolbox/tests/test_attacks_pointwise.py rename to foolbox/tests/attacks/test_attacks_pointwise.py diff --git a/foolbox/tests/test_attacks_precomputed.py b/foolbox/tests/attacks/test_attacks_precomputed.py similarity index 100% rename from foolbox/tests/test_attacks_precomputed.py rename to foolbox/tests/attacks/test_attacks_precomputed.py diff --git a/foolbox/tests/test_attacks_saliency.py b/foolbox/tests/attacks/test_attacks_saliency.py similarity index 100% rename from foolbox/tests/test_attacks_saliency.py rename to foolbox/tests/attacks/test_attacks_saliency.py diff --git a/foolbox/tests/test_attacks_singlepixel.py b/foolbox/tests/attacks/test_attacks_singlepixel.py 
similarity index 100% rename from foolbox/tests/test_attacks_singlepixel.py rename to foolbox/tests/attacks/test_attacks_singlepixel.py diff --git a/foolbox/tests/test_attacks_slsqp.py b/foolbox/tests/attacks/test_attacks_slsqp.py similarity index 100% rename from foolbox/tests/test_attacks_slsqp.py rename to foolbox/tests/attacks/test_attacks_slsqp.py diff --git a/foolbox/tests/test_attacks_sparsefool.py b/foolbox/tests/attacks/test_attacks_sparsefool.py similarity index 100% rename from foolbox/tests/test_attacks_sparsefool.py rename to foolbox/tests/attacks/test_attacks_sparsefool.py diff --git a/foolbox/tests/test_attacks_spatial.py b/foolbox/tests/attacks/test_attacks_spatial.py similarity index 100% rename from foolbox/tests/test_attacks_spatial.py rename to foolbox/tests/attacks/test_attacks_spatial.py diff --git a/foolbox/tests/batch_attacks/__init__.py b/foolbox/tests/batch_attacks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/foolbox/tests/batch_attacks/test_batch_carlini_wagner.py b/foolbox/tests/batch_attacks/test_batch_carlini_wagner.py new file mode 100644 index 00000000..a5ab8cd8 --- /dev/null +++ b/foolbox/tests/batch_attacks/test_batch_carlini_wagner.py @@ -0,0 +1,35 @@ +import numpy as np + +from foolbox.batch_attacks import CarliniWagnerL2Attack as Attack + + +def test_untargeted_attack(bn_model, bn_criterion, bn_images, bn_labels): + attack = Attack(bn_model, bn_criterion) + advs = attack(bn_images, bn_labels, unpack=False, max_iterations=100) + for adv in advs: + assert adv.perturbed is not None + assert adv.distance.value < np.inf + + +def test_targeted_attack(bn_model, bn_targeted_criterion, bn_images, bn_labels): + attack = Attack(bn_model, bn_targeted_criterion) + advs = attack(bn_images, bn_labels, unpack=False, max_iterations=100, binary_search_steps=20) + for adv in advs: + assert adv.perturbed is not None + assert adv.distance.value < np.inf + + +def test_attack_impossible(bn_model, bn_impossible_criterion, bn_images, bn_labels): + attack = Attack(bn_model, bn_impossible_criterion) + advs = attack(bn_images, bn_labels, unpack=False, max_iterations=100, binary_search_steps=20) + for adv in advs: + assert adv.perturbed is None + assert adv.distance.value == np.inf + + +def test_attack_gl(gl_bn_model, bn_criterion, bn_images, bn_labels): + attack = Attack(gl_bn_model, bn_criterion) + advs = attack(bn_images, bn_labels, unpack=False, max_iterations=100) + for adv in advs: + assert adv.perturbed is None + assert adv.distance.value == np.inf diff --git a/foolbox/tests/batch_attacks/test_batch_gradient.py b/foolbox/tests/batch_attacks/test_batch_gradient.py new file mode 100644 index 00000000..40f09237 --- /dev/null +++ b/foolbox/tests/batch_attacks/test_batch_gradient.py @@ -0,0 +1,46 @@ +import pytest +import numpy as np + +from foolbox.gradient_estimators import CoordinateWiseGradientEstimator +from foolbox.gradient_estimators import EvolutionaryStrategiesGradientEstimator + +from foolbox.models import ModelWithEstimatedGradients + +from foolbox.batch_attacks import GradientAttack as Attack + + +def test_untargeted_attack(bn_model, bn_criterion, bn_images, bn_labels): + attack = Attack(bn_model, bn_criterion) + advs = attack(bn_images, bn_labels, unpack=False) + for adv in advs: + assert adv.perturbed is not None + assert adv.distance.value < np.inf + + +def test_attack_eps(bn_model, bn_criterion, bn_images, bn_labels): + attack = Attack(bn_model, bn_criterion) + advs = attack(bn_images, bn_labels, unpack=False, epsilons=np.linspace(0., 1., 
diff --git a/foolbox/tests/batch_attacks/test_batch_gradient.py b/foolbox/tests/batch_attacks/test_batch_gradient.py
new file mode 100644
index 00000000..40f09237
--- /dev/null
+++ b/foolbox/tests/batch_attacks/test_batch_gradient.py
@@ -0,0 +1,46 @@
+import pytest
+import numpy as np
+
+from foolbox.gradient_estimators import CoordinateWiseGradientEstimator
+from foolbox.gradient_estimators import EvolutionaryStrategiesGradientEstimator
+
+from foolbox.models import ModelWithEstimatedGradients
+
+from foolbox.batch_attacks import GradientAttack as Attack
+
+
+def test_untargeted_attack(bn_model, bn_criterion, bn_images, bn_labels):
+    attack = Attack(bn_model, bn_criterion)
+    advs = attack(bn_images, bn_labels, unpack=False)
+    for adv in advs:
+        assert adv.perturbed is not None
+        assert adv.distance.value < np.inf
+
+
+def test_attack_eps(bn_model, bn_criterion, bn_images, bn_labels):
+    attack = Attack(bn_model, bn_criterion)
+    advs = attack(bn_images, bn_labels, unpack=False, epsilons=np.linspace(0., 1., 100)[1:])
+    for adv in advs:
+        assert adv.perturbed is not None
+        assert adv.distance.value < np.inf
+
+
+def test_attack_gl(gl_bn_model, bn_criterion, bn_images, bn_labels):
+    attack = Attack(gl_bn_model, bn_criterion)
+    advs = attack(bn_images, bn_labels, unpack=False)
+    for adv in advs:
+        assert adv.perturbed is None
+        assert adv.distance.value == np.inf
+
+
+@pytest.mark.parametrize('GradientEstimator', [
+    CoordinateWiseGradientEstimator,
+    EvolutionaryStrategiesGradientEstimator])
+def test_attack_eg(GradientEstimator, bn_model, bn_criterion, bn_images, bn_labels):
+    gradient_estimator = GradientEstimator(epsilon=0.01)
+    model = ModelWithEstimatedGradients(bn_model, gradient_estimator)
+    attack = Attack(model, bn_criterion)
+    advs = attack(bn_images, bn_labels, unpack=False)
+    for adv in advs:
+        assert adv.perturbed is not None
+        assert adv.distance.value < np.inf
diff --git a/foolbox/tests/batch_attacks/test_batch_gradient_sign.py b/foolbox/tests/batch_attacks/test_batch_gradient_sign.py
new file mode 100644
index 00000000..bf803bb3
--- /dev/null
+++ b/foolbox/tests/batch_attacks/test_batch_gradient_sign.py
@@ -0,0 +1,27 @@
+import numpy as np
+
+from foolbox.batch_attacks import GradientSignAttack as Attack
+
+
+def test_untargeted_attack(bn_model, bn_criterion, bn_images, bn_labels):
+    attack = Attack(bn_model, bn_criterion)
+    advs = attack(bn_images, bn_labels, unpack=False)
+    for adv in advs:
+        assert adv.perturbed is not None
+        assert adv.distance.value < np.inf
+
+
+def test_attack_eps(bn_model, bn_criterion, bn_images, bn_labels):
+    attack = Attack(bn_model, bn_criterion)
+    advs = attack(bn_images, bn_labels, unpack=False, epsilons=np.linspace(0., 1., 100)[1:])
+    for adv in advs:
+        assert adv.perturbed is not None
+        assert adv.distance.value < np.inf
+
+
+def test_attack_gl(gl_bn_model, bn_criterion, bn_images, bn_labels):
+    attack = Attack(gl_bn_model, bn_criterion)
+    advs = attack(bn_images, bn_labels, unpack=False)
+    for adv in advs:
+        assert adv.perturbed is None
+        assert adv.distance.value == np.inf
diff --git a/foolbox/tests/batch_attacks/test_batch_iterative_projected_gradient.py b/foolbox/tests/batch_attacks/test_batch_iterative_projected_gradient.py
new file mode 100644
index 00000000..e9f0c236
--- /dev/null
+++ b/foolbox/tests/batch_attacks/test_batch_iterative_projected_gradient.py
@@ -0,0 +1,83 @@
+import pytest
+import numpy as np
+
+from foolbox.batch_attacks import LinfinityBasicIterativeAttack
+from foolbox.batch_attacks import L1BasicIterativeAttack
+from foolbox.batch_attacks import L2BasicIterativeAttack
+from foolbox.batch_attacks import ProjectedGradientDescentAttack
+from foolbox.batch_attacks import RandomStartProjectedGradientDescentAttack
+from foolbox.batch_attacks import MomentumIterativeAttack
+
+from foolbox.distances import Linfinity
+from foolbox.distances import MAE
+
+Attacks = [
+    LinfinityBasicIterativeAttack,
+    L1BasicIterativeAttack,
+    L2BasicIterativeAttack,
+    ProjectedGradientDescentAttack,
+    RandomStartProjectedGradientDescentAttack,
+    MomentumIterativeAttack,
+]
+
+
+def test_attack_no_binary_search_and_no_return_early(bn_model, bn_criterion, bn_images, bn_labels):
+    attack = LinfinityBasicIterativeAttack(bn_model, bn_criterion, distance=Linfinity)
+    advs = attack(bn_images, bn_labels, unpack=False, binary_search=False, return_early=False)
+    for adv in advs:
+        assert adv.perturbed is not None
+        assert adv.distance.value < np.inf
+
+
+@pytest.mark.parametrize('Attack', Attacks)
+def test_attack_linf(Attack, bn_model, bn_criterion, bn_images, bn_labels):
+    attack = Attack(bn_model, bn_criterion)
+    advs = attack(bn_images, bn_labels, unpack=False, binary_search=10)
+    for adv in advs:
+        assert adv.perturbed is not None
+        assert adv.distance.value < np.inf
+
+
+@pytest.mark.parametrize('Attack', Attacks)
+def test_attack_l2(Attack, bn_model, bn_criterion, bn_images, bn_labels):
+    attack = Attack(bn_model, bn_criterion)
+    advs = attack(bn_images, bn_labels, unpack=False)
+    for adv in advs:
+        assert adv.perturbed is not None
+        assert adv.distance.value < np.inf
+
+
+@pytest.mark.parametrize('Attack', Attacks)
+def test_attack_l1(Attack, bn_model, bn_criterion, bn_images, bn_labels):
+    attack = Attack(bn_model, bn_criterion, distance=MAE)
+    advs = attack(bn_images, bn_labels, unpack=False)
+    for adv in advs:
+        assert adv.perturbed is not None
+        assert adv.distance.value < np.inf
+
+
+@pytest.mark.parametrize('Attack', Attacks)
+def test_targeted_attack(Attack, bn_model, bn_targeted_criterion, bn_images, bn_labels):
+    attack = Attack(bn_model, bn_targeted_criterion)
+    advs = attack(bn_images, bn_labels, unpack=False)
+    for adv in advs:
+        assert adv.perturbed is not None
+        assert adv.distance.value < np.inf
+
+
+@pytest.mark.parametrize('Attack', Attacks)
+def test_attack_gl(Attack, gl_bn_model, bn_criterion, bn_images, bn_labels):
+    attack = Attack(gl_bn_model, bn_criterion)
+    advs = attack(bn_images, bn_labels, unpack=False)
+    for adv in advs:
+        assert adv.perturbed is None
+        assert adv.distance.value == np.inf
+
+
+@pytest.mark.parametrize('Attack', Attacks)
+def test_attack_impossible(Attack, bn_model, bn_impossible_criterion, bn_images, bn_labels):
+    attack = Attack(bn_model, bn_impossible_criterion)
+    advs = attack(bn_images, bn_labels, unpack=False)
+    for adv in advs:
+        assert adv.perturbed is None
+        assert adv.distance.value == np.inf
diff --git a/foolbox/tests/conftest.py b/foolbox/tests/conftest.py
index e517c788..ee46bf04 100644
--- a/foolbox/tests/conftest.py
+++ b/foolbox/tests/conftest.py
@@ -229,6 +229,13 @@ def bn_image():
     return image
 
 
+@pytest.fixture
+def bn_images():
+    np.random.seed(22)
+    image = np.random.uniform(size=(7, 5, 5, 10)).astype(np.float32)
+    return image
+
+
 @pytest.fixture
 def bn_image_pytorch():
     np.random.seed(22)
@@ -245,6 +252,14 @@ def bn_label(bn_image):
     return label
 
 
+@pytest.fixture
+def bn_labels(bn_images):
+    images = bn_images
+    mean = np.mean(images, axis=(1, 2))
+    labels = np.argmax(mean, axis=-1)
+    return labels
+
+
 @pytest.fixture
 def bn_label_pytorch(bn_image_pytorch):
     image = bn_image_pytorch
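The new bn_images and bn_labels fixtures above mirror the existing single-input bn_image and bn_label pair: labels are derived deterministically from the batch by averaging each image over its spatial dimensions and taking the argmax over channels. A quick stand-alone shape check (illustrative only, not part of the patch):

    import numpy as np

    np.random.seed(22)
    images = np.random.uniform(size=(7, 5, 5, 10)).astype(np.float32)  # (batch, height, width, channels)
    labels = np.argmax(np.mean(images, axis=(1, 2)), axis=-1)          # average over height/width, argmax over channels
    assert images.shape == (7, 5, 5, 10)
    assert labels.shape == (7,)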
diff --git a/foolbox/tests/models/__init__.py b/foolbox/tests/models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/foolbox/tests/test_models_caffe.py b/foolbox/tests/models/test_models_caffe.py
similarity index 100%
rename from foolbox/tests/test_models_caffe.py
rename to foolbox/tests/models/test_models_caffe.py
diff --git a/foolbox/tests/test_models_keras.py b/foolbox/tests/models/test_models_keras.py
similarity index 100%
rename from foolbox/tests/test_models_keras.py
rename to foolbox/tests/models/test_models_keras.py
diff --git a/foolbox/tests/test_models_lasagne.py b/foolbox/tests/models/test_models_lasagne.py
similarity index 100%
rename from foolbox/tests/test_models_lasagne.py
rename to foolbox/tests/models/test_models_lasagne.py
diff --git a/foolbox/tests/test_models_mxnet.py b/foolbox/tests/models/test_models_mxnet.py
similarity index 100%
rename from foolbox/tests/test_models_mxnet.py
rename to foolbox/tests/models/test_models_mxnet.py
diff --git a/foolbox/tests/test_models_mxnet_gluon.py b/foolbox/tests/models/test_models_mxnet_gluon.py
similarity index 100%
rename from foolbox/tests/test_models_mxnet_gluon.py
rename to foolbox/tests/models/test_models_mxnet_gluon.py
diff --git a/foolbox/tests/test_models_pytorch.py b/foolbox/tests/models/test_models_pytorch.py
similarity index 100%
rename from foolbox/tests/test_models_pytorch.py
rename to foolbox/tests/models/test_models_pytorch.py
diff --git a/foolbox/tests/test_models_tensorflow.py b/foolbox/tests/models/test_models_tensorflow.py
similarity index 100%
rename from foolbox/tests/test_models_tensorflow.py
rename to foolbox/tests/models/test_models_tensorflow.py
diff --git a/foolbox/tests/test_models_tensorflow_eager.py b/foolbox/tests/models/test_models_tensorflow_eager.py
similarity index 100%
rename from foolbox/tests/test_models_tensorflow_eager.py
rename to foolbox/tests/models/test_models_tensorflow_eager.py
diff --git a/foolbox/tests/test_models_theano.py b/foolbox/tests/models/test_models_theano.py
similarity index 100%
rename from foolbox/tests/test_models_theano.py
rename to foolbox/tests/models/test_models_theano.py
diff --git a/foolbox/tests/run_caffe_test.sh b/foolbox/tests/run_caffe_test.sh
index 1f15fd62..d2dd2678 100644
--- a/foolbox/tests/run_caffe_test.sh
+++ b/foolbox/tests/run_caffe_test.sh
@@ -18,4 +18,4 @@ EOF
 
 # clear cache for importing mock modules without conflicts
 find . -type d -name __pycache__ -o \( -type f -name '*.py[co]' \) -print | xargs rm -rf
-PYTHONPATH="/mock:${PYTHONPATH}" pytest --cov-append foolbox/tests/test_models_caffe.py
+PYTHONPATH="/mock:${PYTHONPATH}" pytest --cov-append foolbox/tests/models/test_models_caffe.py
diff --git a/foolbox/tests/test_batching.py b/foolbox/tests/test_batching.py
new file mode 100644
index 00000000..d7b03c0b
--- /dev/null
+++ b/foolbox/tests/test_batching.py
@@ -0,0 +1,26 @@
+import numpy as np
+
+from foolbox.batch_attacks import GradientAttack as Attack
+
+from foolbox import run_parallel
+from foolbox import run_sequential
+
+
+def test_run_parallel(bn_model, bn_criterion, bn_images, bn_labels):
+    attack = Attack(bn_model, bn_criterion)
+    advs1 = attack(bn_images, bn_labels)
+
+    advs2 = run_parallel(Attack, bn_model, bn_criterion, bn_images, bn_labels)
+    advs2 = np.stack([a.perturbed for a in advs2])
+
+    assert np.all(advs1 == advs2)
+
+
+def test_run_sequential(bn_model, bn_criterion, bn_images, bn_labels):
+    attack = Attack(bn_model, bn_criterion)
+    advs1 = attack(bn_images, bn_labels)
+
+    advs2 = run_sequential(Attack, bn_model, bn_criterion, bn_images, bn_labels)
+    advs2 = np.stack([a.perturbed for a in advs2])
+
+    assert np.all(advs1 == advs2)
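The test above exercises the two new top-level helpers the way a user would: the attack class itself (not an instance) is passed to run_parallel or run_sequential together with a model, a criterion, and a batch of inputs and labels, and a list with one Adversarial object per input comes back. A minimal usage sketch based on that test code; fmodel, images and labels are placeholders for a wrapped model and your own data, not names defined in this patch:

    import numpy as np
    import foolbox
    from foolbox.batch_attacks import GradientAttack

    criterion = foolbox.criteria.Misclassification()
    advs = foolbox.run_parallel(GradientAttack, fmodel, criterion, images, labels)
    perturbed = np.stack([adv.perturbed for adv in advs])  # one perturbed input per original input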
diff --git a/foolbox/yielding_adversarial.py b/foolbox/yielding_adversarial.py
new file mode 100644
index 00000000..d2292351
--- /dev/null
+++ b/foolbox/yielding_adversarial.py
@@ -0,0 +1,205 @@
+"""
+Provides a class that represents an adversarial example.
+
+"""
+
+import numpy as np
+
+from .adversarial import Adversarial
+from .adversarial import StopAttack
+
+
+class YieldingAdversarial(Adversarial):
+    def _check_unperturbed(self):
+        try:
+            # for now, we use the non-yielding implementation in the super-class
+            # TODO: add support for batching this first call as well
+            super(YieldingAdversarial, self).forward_one(self._Adversarial__unperturbed)
+        except StopAttack:
+            # if a threshold is specified and the unperturbed input is
+            # misclassified, this can already cause a StopAttack
+            # exception
+            assert self.distance.value == 0.
+
+    def forward_one(self, x, strict=True, return_details=False):
+        """Interface to model.forward_one for attacks.
+
+        Parameters
+        ----------
+        x : `numpy.ndarray`
+            Single input with shape as expected by the model
+            (without the batch dimension).
+        strict : bool
+            Controls if the bounds for the pixel values should be checked.
+
+        """
+        in_bounds = self.in_bounds(x)
+        assert not strict or in_bounds
+
+        self._total_prediction_calls += 1
+        predictions = yield ('forward_one', x)
+        is_adversarial, is_best, distance = self._Adversarial__is_adversarial(
+            x, predictions, in_bounds)
+
+        assert predictions.ndim == 1
+        if return_details:
+            return predictions, is_adversarial, is_best, distance
+        else:
+            return predictions, is_adversarial
+
+    def forward(self, inputs, greedy=False, strict=True, return_details=False):
+        """Interface to model.forward for attacks.
+
+        Parameters
+        ----------
+        inputs : `numpy.ndarray`
+            Batch of inputs with shape as expected by the model.
+        greedy : bool
+            Whether the first adversarial should be returned.
+        strict : bool
+            Controls if the bounds for the pixel values should be checked.
+
+        """
+        if strict:
+            in_bounds = self.in_bounds(inputs)
+            assert in_bounds
+
+        self._total_prediction_calls += len(inputs)
+        predictions = yield ('forward', inputs)
+
+        assert predictions.ndim == 2
+        assert predictions.shape[0] == inputs.shape[0]
+
+        if return_details:
+            assert greedy
+
+        adversarials = []
+        for i in range(len(predictions)):
+            if strict:
+                in_bounds_i = True
+            else:
+                in_bounds_i = self.in_bounds(inputs[i])
+            is_adversarial, is_best, distance = self._Adversarial__is_adversarial(
+                inputs[i], predictions[i], in_bounds_i)
+            if is_adversarial and greedy:
+                if return_details:
+                    return predictions, is_adversarial, i, is_best, distance
+                else:
+                    return predictions, is_adversarial, i
+            adversarials.append(is_adversarial)
+
+        if greedy:  # pragma: no cover
+            # no adversarial found
+            if return_details:
+                return predictions, False, None, False, None
+            else:
+                return predictions, False, None
+
+        is_adversarial = np.array(adversarials)
+        assert is_adversarial.ndim == 1
+        assert is_adversarial.shape[0] == inputs.shape[0]
+        return predictions, is_adversarial
+
+    def gradient_one(self, x=None, label=None, strict=True):
+        """Interface to model.gradient_one for attacks.
+
+        Parameters
+        ----------
+        x : `numpy.ndarray`
+            Single input with shape as expected by the model
+            (without the batch dimension).
+            Defaults to the original input.
+        label : int
+            Label used to calculate the loss that is differentiated.
+            Defaults to the original label.
+        strict : bool
+            Controls if the bounds for the pixel values should be checked.
+
+        """
+        assert self.has_gradient()
+
+        if x is None:
+            x = self._Adversarial__unperturbed
+        if label is None:
+            label = self._Adversarial__original_class
+
+        assert not strict or self.in_bounds(x)
+
+        self._total_gradient_calls += 1
+        gradient = yield ('gradient_one', x, label)
+
+        assert gradient.shape == x.shape
+        return gradient
+
+    def forward_and_gradient_one(self, x=None, label=None, strict=True, return_details=False):
+        """Interface to model.forward_and_gradient_one for attacks.
+
+        Parameters
+        ----------
+        x : `numpy.ndarray`
+            Single input with shape as expected by the model
+            (without the batch dimension).
+            Defaults to the original input.
+        label : int
+            Label used to calculate the loss that is differentiated.
+            Defaults to the original label.
+        strict : bool
+            Controls if the bounds for the pixel values should be checked.
+
+        """
+        assert self.has_gradient()
+
+        if x is None:
+            x = self._Adversarial__unperturbed
+        if label is None:
+            label = self._Adversarial__original_class
+
+        in_bounds = self.in_bounds(x)
+        assert not strict or in_bounds
+
+        self._total_prediction_calls += 1
+        self._total_gradient_calls += 1
+        predictions, gradient = yield ('forward_and_gradient_one', x, label)
+        is_adversarial, is_best, distance = self._Adversarial__is_adversarial(x, predictions, in_bounds)
+
+        assert predictions.ndim == 1
+        assert gradient.shape == x.shape
+        if return_details:
+            return predictions, gradient, is_adversarial, is_best, distance
+        else:
+            return predictions, gradient, is_adversarial
+
+    def backward_one(self, gradient, x=None, strict=True):
+        """Interface to model.backward_one for attacks.
+
+        Parameters
+        ----------
+        gradient : `numpy.ndarray`
+            Gradient of some loss w.r.t. the logits.
+        x : `numpy.ndarray`
+            Single input with shape as expected by the model
+            (without the batch dimension).
+
+        Returns
+        -------
+        gradient : `numpy.ndarray`
+            The gradient w.r.t. the input.
+
+        See Also
+        --------
+        :meth:`gradient`
+
+        """
+        assert self.has_gradient()
+        assert gradient.ndim == 1
+
+        if x is None:
+            x = self._Adversarial__unperturbed
+
+        assert not strict or self.in_bounds(x)
+
+        self._total_gradient_calls += 1
+        gradient = yield ('backward_one', gradient, x)
+
+        assert gradient.shape == x.shape
+        return gradient
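The methods of YieldingAdversarial are generators: instead of calling the model directly, each one yields a request tuple such as ('forward_one', x) or ('gradient_one', x, label), receives the model's answer via send(), and delivers its own result as the StopIteration value. The real scheduling lives in foolbox.batching (run_parallel/run_sequential), which collects requests from many such coroutines and batches them; the sketch below only illustrates the protocol for a single coroutine driven against a single model and is not the library implementation:

    def drive(coroutine, model):
        """Answer one YieldingAdversarial-style coroutine sequentially (illustration only)."""
        handlers = {
            'forward_one': model.forward_one,
            'forward': model.forward,
            'gradient_one': model.gradient_one,
            'forward_and_gradient_one': model.forward_and_gradient_one,
            'backward_one': model.backward_one,
        }
        try:
            request = next(coroutine)             # first yielded request, e.g. ('forward_one', x)
            while True:
                name, args = request[0], request[1:]
                result = handlers[name](*args)    # query the model on behalf of the attack
                request = coroutine.send(result)  # send the answer back, receive the next request
        except StopIteration as stop:
            return stop.value                     # whatever the generator method returned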
+ + """ + assert self.has_gradient() + + if x is None: + x = self._Adversarial__unperturbed + if label is None: + label = self._Adversarial__original_class + + in_bounds = self.in_bounds(x) + assert not strict or in_bounds + + self._total_prediction_calls += 1 + self._total_gradient_calls += 1 + predictions, gradient = yield ('forward_and_gradient_one', x, label) + is_adversarial, is_best, distance = self._Adversarial__is_adversarial(x, predictions, in_bounds) + + assert predictions.ndim == 1 + assert gradient.shape == x.shape + if return_details: + return predictions, gradient, is_adversarial, is_best, distance + else: + return predictions, gradient, is_adversarial + + def backward_one(self, gradient, x=None, strict=True): + """Interface to model.backward_one for attacks. + + Parameters + ---------- + gradient : `numpy.ndarray` + Gradient of some loss w.r.t. the logits. + x : `numpy.ndarray` + Single input with shape as expected by the model + (without the batch dimension). + + Returns + ------- + gradient : `numpy.ndarray` + The gradient w.r.t the input. + + See Also + -------- + :meth:`gradient` + + """ + assert self.has_gradient() + assert gradient.ndim == 1 + + if x is None: + x = self._Adversarial__unperturbed + + assert not strict or self.in_bounds(x) + + self._total_gradient_calls += 1 + gradient = yield ('backward_one', gradient, x) + + assert gradient.shape == x.shape + return gradient diff --git a/requirements-dev.txt b/requirements-dev.txt index 2c57f302..af6b64a6 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,7 +3,7 @@ sphinx>=1.6.2 sphinx-autobuild>=0.6.0 sphinx_rtd_theme>=0.2.4 twine>=1.9.1 -pytest>=3.6.0 +pytest>=4.5.0 pytest-cov>=2.5.1 flake8>=3.3.0 python-coveralls>=2.9.1