diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..a2d32f1e --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "foolbox/tests/data/model_repo"] + path = foolbox/tests/data/model_repo + url = https://github.com/bveliqi/foolbox-zoo-dummy.git diff --git a/docs/index.rst b/docs/index.rst index 91d90d15..266254b3 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -35,6 +35,7 @@ You might want to have a look at our recently announced `Robust Vision Benchmark user/tutorial user/examples user/adversarial + user/zoo user/development user/faq @@ -44,6 +45,7 @@ You might want to have a look at our recently announced `Robust Vision Benchmark modules/models modules/criteria + modules/zoo modules/distances modules/attacks modules/adversarial diff --git a/docs/modules/zoo.rst b/docs/modules/zoo.rst new file mode 100644 index 00000000..ad4411f4 --- /dev/null +++ b/docs/modules/zoo.rst @@ -0,0 +1,17 @@ +:mod:`foolbox.zoo` +================================= + +.. automodule:: foolbox.zoo + + +Get Model +---------------- + +.. autofunction:: get_model + + + +Fetch Weights +---------------- + +.. autofunction:: fetch_weights diff --git a/docs/user/zoo.rst b/docs/user/zoo.rst new file mode 100644 index 00000000..9185a55d --- /dev/null +++ b/docs/user/zoo.rst @@ -0,0 +1,26 @@ +========= +Model Zoo +========= + +This tutorial will show you how the model zoo can be used to run your attack against a robust model. + +Downloading a model +=================== + +For this tutorial, we will download the `Madry et al. CIFAR10 challenge` robust model implemented in `TensorFlow` +and run a `FGSM (GradientSignAttack)` against it. + +.. code-block:: python3 + + from foolbox import zoo + + # download the model + model = zoo.get_model(url="https://github.com/bethgelab/cifar10_challenge.git") + + # read image and label + image = ... + label = ... 
+ + # apply attack on source image + attack = foolbox.attacks.FGSM(model) + adversarial = attack(image[:,:,::-1], label) diff --git a/foolbox/tests/data/model_repo b/foolbox/tests/data/model_repo new file mode 160000 index 00000000..e1932bbc --- /dev/null +++ b/foolbox/tests/data/model_repo @@ -0,0 +1 @@ +Subproject commit e1932bbc04ed82a3b7593e3352d8931a7ab544e1 diff --git a/foolbox/tests/test_fetch_weights.py b/foolbox/tests/test_fetch_weights.py new file mode 100644 index 00000000..5b0fb057 --- /dev/null +++ b/foolbox/tests/test_fetch_weights.py @@ -0,0 +1,93 @@ +from foolbox.zoo import fetch_weights +from foolbox.zoo.common import path_exists, home_directory_path, sha256_hash +from foolbox.zoo.weights_fetcher import FOLDER + +import os +import pytest +import shutil + +import responses +import io +import zipfile + + +@responses.activate +def test_fetch_weights_unzipped(): + weights_uri = 'http://localhost:8080/weights.zip' + raw_body = _random_body(zipped=False) + + # mock server + responses.add(responses.GET, weights_uri, + body=raw_body, status=200, stream=True) + + expected_path = _expected_path(weights_uri) + + if path_exists(expected_path): + shutil.rmtree(expected_path) # make sure path does not exist already + + file_path = fetch_weights(weights_uri) + + exists_locally = path_exists(expected_path) + assert exists_locally + assert expected_path in file_path + + +@responses.activate +def test_fetch_weights_zipped(): + weights_uri = 'http://localhost:8080/weights.zip' + + # mock server + raw_body = _random_body(zipped=True) + responses.add(responses.GET, weights_uri, + body=raw_body, status=200, stream=True, + content_type='application/zip', + headers={'Accept-Encoding': 'gzip, deflate'}) + + expected_path = _expected_path(weights_uri) + + if path_exists(expected_path): + shutil.rmtree(expected_path) # make sure path does not exist already + + file_path = fetch_weights(weights_uri, unzip=True) + + exists_locally = path_exists(expected_path) + assert 
exists_locally + assert expected_path in file_path + + +@responses.activate +def test_fetch_weights_returns_404(): + weights_uri = 'http://down:8080/weights.zip' + + # mock server + responses.add(responses.GET, weights_uri, status=404) + + expected_path = _expected_path(weights_uri) + + if path_exists(expected_path): + shutil.rmtree(expected_path) # make sure path does not exist already + + with pytest.raises(RuntimeError): + fetch_weights(weights_uri, unzip=False) + + +def test_no_uri_given(): + assert fetch_weights(None) is None + + +def _random_body(zipped=False): + if zipped: + data = io.BytesIO() + with zipfile.ZipFile(data, mode='w') as z: + z.writestr('test.txt', 'no real weights in here :)') + data.seek(0) + return data.getvalue() + else: + raw_body = os.urandom(1024) + return raw_body + + +def _expected_path(weights_uri): + hash_digest = sha256_hash(weights_uri) + local_path = home_directory_path(FOLDER, hash_digest) + return local_path diff --git a/foolbox/tests/test_git_cloner.py b/foolbox/tests/test_git_cloner.py new file mode 100644 index 00000000..401dd05f --- /dev/null +++ b/foolbox/tests/test_git_cloner.py @@ -0,0 +1,32 @@ +from foolbox.zoo import git_cloner +import os +import hashlib +import pytest +from foolbox.zoo.git_cloner import GitCloneError + + +def test_git_clone(): + # given + git_uri = "https://github.com/bethgelab/convex_adversarial.git" + expected_path = _expected_path(git_uri) + + # when + path = git_cloner.clone(git_uri) + + # then + assert path == expected_path + + +def test_wrong_git_uri(): + git_uri = "git@github.com:bethgelab/non-existing-repo.git" + with pytest.raises(GitCloneError): + git_cloner.clone(git_uri) + + +def _expected_path(git_uri): + home = os.path.expanduser('~') + m = hashlib.sha256() + m.update(git_uri.encode()) + hash = m.hexdigest() + expected_path = os.path.join(home, '.foolbox_zoo', hash) + return expected_path diff --git a/foolbox/tests/test_model_zoo.py b/foolbox/tests/test_model_zoo.py new file mode 100644 
index 00000000..84629f57 --- /dev/null +++ b/foolbox/tests/test_model_zoo.py @@ -0,0 +1,52 @@ +from foolbox import zoo +import numpy as np +import foolbox +import sys +import pytest +from foolbox.zoo.model_loader import ModelLoader +from os.path import join, dirname + + +@pytest.fixture(autouse=True) +def unload_foolbox_model_module(): + # reload foolbox_model from scratch for every run + # to ensure atomic tests without side effects + module_names = ['foolbox_model', 'model'] + for module_name in module_names: + if module_name in sys.modules: + del sys.modules[module_name] + + +test_data = [ + # private repo won't work on travis + # ('https://github.com/bethgelab/AnalysisBySynthesis.git', (1, 28, 28)), + # ('https://github.com/bethgelab/convex_adversarial.git', (1, 28, 28)), + # ('https://github.com/bethgelab/mnist_challenge.git', 784) + (join('file://', dirname(__file__), 'data/model_repo'), (3, 224, 224)) +] + + +@pytest.mark.parametrize("url, dim", test_data) +def test_loading_model(url, dim): + # download model + model = zoo.get_model(url) + + # create a dummy image + x = np.zeros(dim, dtype=np.float32) + x[:] = np.random.randn(*x.shape) + + # run the model + logits = model.predictions(x) + probabilities = foolbox.utils.softmax(logits) + predicted_class = np.argmax(logits) + + # sanity check + assert predicted_class >= 0 + assert np.sum(probabilities) >= 0.9999 + + # TODO: delete fmodel + + +def test_non_default_module_throws_error(): + with pytest.raises(RuntimeError): + ModelLoader.get(key='other') diff --git a/foolbox/zoo/__init__.py b/foolbox/zoo/__init__.py new file mode 100644 index 00000000..b4c2d028 --- /dev/null +++ b/foolbox/zoo/__init__.py @@ -0,0 +1,2 @@ +from .zoo import get_model # noqa: F401 +from .weights_fetcher import fetch_weights # noqa: F401 diff --git a/foolbox/zoo/common.py b/foolbox/zoo/common.py new file mode 100644 index 00000000..72644d33 --- /dev/null +++ b/foolbox/zoo/common.py @@ -0,0 +1,18 @@ +import hashlib +import os + + +def 
sha256_hash(git_uri): + m = hashlib.sha256() + m.update(git_uri.encode()) + return m.hexdigest() + + +def home_directory_path(folder, hash_digest): + # does this work on all operating systems? + home = os.path.expanduser('~') + return os.path.join(home, folder, hash_digest) + + +def path_exists(local_path): + return os.path.exists(local_path) diff --git a/foolbox/zoo/git_cloner.py b/foolbox/zoo/git_cloner.py new file mode 100644 index 00000000..efb203d2 --- /dev/null +++ b/foolbox/zoo/git_cloner.py @@ -0,0 +1,39 @@ +from git import Repo +import logging +from .common import sha256_hash, home_directory_path, path_exists + +FOLDER = '.foolbox_zoo' + + +class GitCloneError(RuntimeError): + pass + + +def clone(git_uri): + """ + Clone a remote git repository to a local path. + + :param git_uri: the URI to the git repository to be cloned + :return: the generated local path where the repository has been cloned to + """ + hash_digest = sha256_hash(git_uri) + local_path = home_directory_path(FOLDER, hash_digest) + exists_locally = path_exists(local_path) + + if not exists_locally: + _clone_repo(git_uri, local_path) + else: + logging.info( # pragma: no cover + "Git repository already exists locally.") # pragma: no cover + + return local_path + + +def _clone_repo(git_uri, local_path): + logging.info("Cloning repo %s to %s", git_uri, local_path) + try: + Repo.clone_from(git_uri, local_path) + except Exception as e: + logging.exception("Failed to clone repository: %s", e) + raise GitCloneError("Failed to clone repository") + logging.info("Cloned repo successfully.") diff --git a/foolbox/zoo/model_loader.py b/foolbox/zoo/model_loader.py new file mode 100644 index 00000000..38256ee9 --- /dev/null +++ b/foolbox/zoo/model_loader.py @@ -0,0 +1,45 @@ +import sys +import importlib + +import abc +abstractmethod = abc.abstractmethod +if sys.version_info >= (3, 4): + ABC = abc.ABC +else: # pragma: no cover + ABC = abc.ABCMeta('ABC', (), {}) + + +class ModelLoader(ABC): + + @abstractmethod + 
def load(self, path): + """ + Load a model from a local path, to which a git repository + has been previously cloned to. + + :param path: the path to the local repository containing the code + :return: a foolbox-wrapped model + """ + pass # pragma: no cover + + @staticmethod + def get(key='default'): + if key == 'default': + return DefaultLoader() + else: + raise RuntimeError("No model loader for: {}".format(key)) + + @staticmethod + def _import_module(path, module_name='foolbox_model'): + sys.path.insert(0, path) + module = importlib.import_module(module_name) + print('imported module: {}'.format(module)) + return module + + +class DefaultLoader(ModelLoader): + + def load(self, path, module_name='foolbox_model'): + module = ModelLoader._import_module(path, module_name) + model = module.create() + return model diff --git a/foolbox/zoo/weights_fetcher.py b/foolbox/zoo/weights_fetcher.py new file mode 100644 index 00000000..9d33278e --- /dev/null +++ b/foolbox/zoo/weights_fetcher.py @@ -0,0 +1,94 @@ +import requests +import shutil +import zipfile +import os +import logging + +from .common import sha256_hash, home_directory_path, path_exists + +FOLDER = '.foolbox_zoo/weights' + + +def fetch_weights(weights_uri, unzip=False): + """ + + Provides utilities to download and extract packages + containing model weights when creating foolbox-zoo compatible + repositories, if the weights are not part of the repository itself. 
+ + Examples + -------- + + Download and unzip weights: + + >>> from foolbox import zoo + >>> url = 'https://github.com/MadryLab/mnist_challenge_models/raw/master/secret.zip' # noqa F501 + >>> weights_path = zoo.fetch_weights(url, unzip=True) + + :param weights_uri: the URI to fetch the weights from + :param unzip: should be `True` if the file to be downloaded is + a zipped package + :return: local path where the weights have been downloaded + and potentially unzipped to + """ + if weights_uri is None: + logging.info("No weights to be fetched for this model.") + return + + hash_digest = sha256_hash(weights_uri) + local_path = home_directory_path(FOLDER, hash_digest) + exists_locally = path_exists(local_path) + + filename = _filename_from_uri(weights_uri) + file_path = os.path.join(local_path, filename) + + if exists_locally: + logging.info("Weights already stored locally.") # pragma: no cover + else: + _download(file_path, weights_uri, local_path) + + if unzip: + file_path = _extract(local_path, filename) + + return file_path + + +def _filename_from_uri(url): + # get last part of the URI, i.e. file-name + filename = url.split('/')[-1] + # remove query params if exist + filename = filename.split('?')[0] + return filename + + +def _download(file_path, url, directory): + logging.info("Downloading weights: %s to %s", url, file_path) + if not os.path.exists(directory): + os.makedirs(directory) + # first check ETag or If-Modified-Since header or similar + # to check whether updated weights are available? 
+ r = requests.get(url, stream=True) + if r.status_code == 200: + with open(file_path, 'wb') as f: + r.raw.decode_content = True + shutil.copyfileobj(r.raw, f) + else: + raise RuntimeError("Failed to fetch weights from %s" % url) + + +def _extract(directory, filename): + file_path = os.path.join(directory, filename) + extracted_folder = filename.rsplit('.', 1)[0] + extracted_folder = os.path.join(directory, extracted_folder) + + if not os.path.exists(extracted_folder): + logging.info("Extracting weights package to %s", extracted_folder) + os.makedirs(extracted_folder) + zip_ref = zipfile.ZipFile(file_path, 'r') + zip_ref.extractall(extracted_folder) + zip_ref.close() + else: + logging.info("Extracted folder already exists: %s", + extracted_folder) # pragma: no cover + + return extracted_folder diff --git a/foolbox/zoo/zoo.py b/foolbox/zoo/zoo.py new file mode 100644 index 00000000..9949c9bf --- /dev/null +++ b/foolbox/zoo/zoo.py @@ -0,0 +1,37 @@ +from .git_cloner import clone +from .model_loader import ModelLoader + + +def get_model(url): + """ + + Provides utilities to download foolbox-compatible robust models + to easily test attacks against them by simply providing a git-URL. + + Examples + -------- + + Instantiate a model: + + >>> from foolbox import zoo + >>> url = "https://github.com/bveliqi/foolbox-zoo-dummy.git" + >>> model = zoo.get_model(url) # doctest: +SKIP + + Only works with a foolbox-zoo compatible repository. + I.e. models need to have a `foolbox_model.py` file + with a `create()`-function, which returns a foolbox-wrapped model. 
+ + Example repositories: + + - https://github.com/bethgelab/mnist_challenge + - https://github.com/bethgelab/cifar10_challenge + - https://github.com/bethgelab/convex_adversarial + + :param url: URL to the git repository + :return: a foolbox-wrapped model instance + """ + repo_path = clone(url) + loader = ModelLoader.get() + model = loader.load(repo_path) + + return model diff --git a/requirements-dev.txt b/requirements-dev.txt index 8020b85d..a78b05b2 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -9,3 +9,4 @@ flake8 >= 3.3.0 python-coveralls >= 2.9.1 pillow >= 4.1.1 randomgen >= 1.14.4 +responses >= 0.9.0 \ No newline at end of file diff --git a/setup.py b/setup.py index b6d88c19..4fa7873f 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,8 @@ 'numpy', 'scipy', 'setuptools', + 'requests', + 'GitPython' ] tests_require = [