From 9bd67de49e4bc2261fe239ad677dded7ba17dde5 Mon Sep 17 00:00:00 2001 From: Sanyam Kapoor <1sanyamkapoor@gmail.com> Date: Mon, 6 Nov 2017 03:40:46 -0500 Subject: [PATCH 01/11] Add basic Omniglot dataset loader --- .gitignore | 3 +- torchvision/datasets/__init__.py | 4 +- torchvision/datasets/omniglot.py | 123 +++++++++++++++++++++++++++++++ torchvision/datasets/utils.py | 46 ++++++++++++ 4 files changed, 174 insertions(+), 2 deletions(-) create mode 100644 torchvision/datasets/omniglot.py diff --git a/.gitignore b/.gitignore index c02a6ab80e3..19d64575c3e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ torchvision.egg-info/ */**/*.pyc */**/*~ *~ -docs/build \ No newline at end of file +docs/build +.idea/ diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 9fab55190cc..64e1ac5a44b 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -8,9 +8,11 @@ from .phototour import PhotoTour from .fakedata import FakeData from .semeion import SEMEION +from .omniglot import Omniglot __all__ = ('LSUN', 'LSUNClass', 'ImageFolder', 'FakeData', 'CocoCaptions', 'CocoDetection', 'CIFAR10', 'CIFAR100', 'FashionMNIST', - 'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION') + 'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION', + 'Omniglot') diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py new file mode 100644 index 00000000000..29ec6ab5681 --- /dev/null +++ b/torchvision/datasets/omniglot.py @@ -0,0 +1,123 @@ +from __future__ import print_function +from PIL import Image +from functools import reduce +import os +import torch.utils.data as data +from .utils import download_url, check_integrity, list_dir, list_files + + +class Omniglot(data.Dataset): + """`Omniglot `_ Dataset. + Args: + root (string): Root directory of dataset where directory + ``omniglot-py`` exists. + background (bool, optional): If True, creates dataset from the "background" set, otherwise + creates from the "evaluation" set. This terminology is defined by the authors. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If true, downloads the dataset zip files from the internet and + puts it in root directory. If the zip files are already downloaded, they are not + downloaded again. + force_extract (bool, optional): If true, extracts the downloaded zip file irrespective + of the existence of an extracted folder with the same name + """ + folder = 'omniglot-py' + download_url_prefix = 'https://github.com/brendenlake/omniglot/raw/master/python' + zips_md5 = [ + ['images_background', '68d2efa1b9178cc56df9314c21c6e718'], + ['images_evaluation', '6b91aef0f799c5bb55b94e3f2daec811'], + # Kept for provisional purposes + ['images_background_small1', 'e704a628b5459e08445c13499850abc4'], + ['images_background_small2', 'b75a71a51d3b13f821f212756fe481fd'], + ] + + def __init__(self, root, background=True, + transform=None, target_transform=None, + download=False, + force_extract=False): + self.root = os.path.join(os.path.expanduser(root), self.folder) + self.background = background + self.transform = transform + self.target_transform = target_transform + + if download: + self.download(force_extract) + + if not self._check_integrity(): + raise RuntimeError('Dataset not found or corrupted.' 
+ + ' You can use download=True to download it') + + self.target_folder = os.path.join(self.root, self._get_target_folder()) + self.alphabets_ = list_dir(self.target_folder) + self.characters_ = list(reduce(lambda x, y: x + y, + [ + [ + os.path.join(alphabet, character) + for character in + list_dir(os.path.join(self.target_folder, alphabet)) + ] + for alphabet in self.alphabets_ + ] + )) + self.character_images_ = [ + [ + tuple([image, idx]) + for image in list_files(os.path.join(self.target_folder, character), '.png') + ] + for idx, character in enumerate(self.characters_) + ] + self.flat_character_images_ = list(reduce(lambda x, y: x + y, self.character_images_)) + + def __len__(self): + return len(self.flat_character_images_) + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is index of the target character class. + """ + image_name, character_class = self.flat_character_images_[index] + image_path = os.path.join(self.target_folder, self.characters_[character_class], image_name) + image = Image.open(image_path, mode='r').convert('L') + + if self.transform: + image = self.transform(image) + + if self.target_transform: + character_class = self.target_transform(character_class) + + return image, character_class + + def _check_integrity(self): + for fzip in self.zips_md5: + filename, md5 = fzip[0] + '.zip', fzip[1] + fpath = os.path.join(self.root, filename) + if not check_integrity(fpath, md5): + return False + return True + + def download(self, force_extract=False): + import zipfile + + if self._check_integrity(): + print('Files already downloaded and verified') + return + + for fzip in self.zips_md5: + filename, md5 = fzip[0], fzip[1] + zip_filename = filename + '.zip' + url = self.download_url_prefix + '/' + zip_filename + download_url(url, self.root, zip_filename, md5) + + if not os.path.isdir(os.path.join(self.root, filename)) or force_extract is True: + print('Extracting downloaded file: ' + os.path.join(self.root, zip_filename)) + with zipfile.ZipFile(os.path.join(self.root, zip_filename), 'r') as zip_file: + zip_file.extractall(self.root) + + def _get_target_folder(self): + return 'images_background' if self.background is True else 'images_evaluation' diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index 466be647252..9fa3b0b8c9b 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -45,3 +45,49 @@ def download_url(url, root, filename, md5): print('Failed download. Trying https -> http instead.' ' Downloading ' + url + ' to ' + fpath) urllib.request.urlretrieve(url, fpath) + + +def list_dir(root, prefix=False): + """List all directories at a given root + + Args: + root (str): Path to directory whose folders need to be listed + prefix (bool, optional): If true, prepends the path to each result, otherwise + only returns the name of the directories found + """ + root = os.path.expanduser(root) + directories = list( + filter( + lambda p: os.path.isdir(os.path.join(root, p)), + os.listdir(root) + ) + ) + + if prefix is True: + directories = [os.path.join(root, d) for d in directories] + + return directories + + +def list_files(root, suffix, prefix=False): + """List all files ending with a suffix at a given root + + Args: + root (str): Path to directory whose folders need to be listed + suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). 
+ It uses the Python "str.endswith" method and is passed directly + prefix (bool, optional): If true, prepends the path to each result, otherwise + only returns the name of the files found + """ + root = os.path.expanduser(root) + files = list( + filter( + lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix), + os.listdir(root) + ) + ) + + if prefix is True: + files = [os.path.join(root, d) for d in files] + + return files From 019b16d38ab6eb0d84efe6f8923afd3ca7df0ac9 Mon Sep 17 00:00:00 2001 From: Sanyam Kapoor <1sanyamkapoor@gmail.com> Date: Mon, 6 Nov 2017 03:50:42 -0500 Subject: [PATCH 02/11] Remove unused import --- torchvision/datasets/cifar.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/datasets/cifar.py b/torchvision/datasets/cifar.py index c3b7b8ef4f2..d39bb8902d7 100644 --- a/torchvision/datasets/cifar.py +++ b/torchvision/datasets/cifar.py @@ -2,7 +2,6 @@ from PIL import Image import os import os.path -import errno import numpy as np import sys if sys.version_info[0] == 2: From 29b27cfbe9b2be4ca6c105014375afbd6e7dcdc4 Mon Sep 17 00:00:00 2001 From: Sanyam Kapoor <1sanyamkapoor@gmail.com> Date: Mon, 6 Nov 2017 04:16:22 -0500 Subject: [PATCH 03/11] Add Omniglot random pair to sample pair of characters --- torchvision/datasets/omniglot.py | 76 ++++++++++++++++++++++++++------ 1 file changed, 63 insertions(+), 13 deletions(-) diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py index 29ec6ab5681..f0bbfb9c20d 100644 --- a/torchvision/datasets/omniglot.py +++ b/torchvision/datasets/omniglot.py @@ -2,6 +2,7 @@ from PIL import Image from functools import reduce import os +import random import torch.utils.data as data from .utils import download_url, check_integrity, list_dir, list_files @@ -28,9 +29,9 @@ class Omniglot(data.Dataset): zips_md5 = [ ['images_background', '68d2efa1b9178cc56df9314c21c6e718'], ['images_evaluation', '6b91aef0f799c5bb55b94e3f2daec811'], - # Kept for provisional purposes - ['images_background_small1', 'e704a628b5459e08445c13499850abc4'], - ['images_background_small2', 'b75a71a51d3b13f821f212756fe481fd'], + # Provision in future + # ['images_background_small1', 'e704a628b5459e08445c13499850abc4'], + # ['images_background_small2', 'b75a71a51d3b13f821f212756fe481fd'], ] def __init__(self, root, background=True, @@ -51,16 +52,18 @@ def __init__(self, root, background=True, self.target_folder = os.path.join(self.root, self._get_target_folder()) self.alphabets_ = list_dir(self.target_folder) - self.characters_ = list(reduce(lambda x, y: x + y, - [ - [ - os.path.join(alphabet, character) - for character in - list_dir(os.path.join(self.target_folder, alphabet)) - ] - for alphabet in self.alphabets_ - ] - )) + self.characters_ = list( + reduce( + lambda x, y: x + y, + [ + [ + os.path.join(alphabet, character) + for character in list_dir(os.path.join(self.target_folder, alphabet)) + ] + for alphabet in self.alphabets_ + ] + ) + ) self.character_images_ = [ [ tuple([image, idx]) @@ -121,3 +124,50 @@ def download(self, force_extract=False): def _get_target_folder(self): return 'images_background' if self.background is True else 'images_evaluation' + + +class OmniglotRandomPair(Omniglot): + """`OmniglotRandomPair `_ Dataset. + + This is a subclass of the Omniglot dataset. 
This instead it returns + a randomized pair of images with similarity label (0 or 1) + """ + def __init__(self, *args, **kwargs): + super(self.__class__, self).__init__(*args, **kwargs) + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: ((image0, image1), is_match) a random pair of images from the Omniglot characters + with corresponding label 1 if it is matching pair and 0 otherwise + """ + + character_classes = [random.randint(0, len(self.characters_) - 1) for _ in range(2)] + + # Choose to return a matching/non-matching pair with probability 1/2 + is_match = random.randint(0, 1) + if is_match == 1: + character_classes = [character_classes[0], character_classes[0]] + else: + while character_classes[0] == character_classes[1]: + character_classes[1] = random.randint(0, len(self.characters_) - 1) + + image_names = [random.choice(self.character_images_[cls]) for cls in character_classes] + + image_paths = [ + os.path.join(self.target_folder, self.characters_[character_classes[idx]], image_name[0]) + for idx, image_name in enumerate(image_names) + ] + + images = [Image.open(image_path, mode='r').convert('L') for image_path in image_paths] + + if self.transform is not None: + images = [self.transform(image) for image in images] + + if self.target_transform is not None: + is_match = self.target_transform(is_match) + + return images, is_match From 0e2e37b343a2c0f8eb090902c74d4d223d5de2e4 Mon Sep 17 00:00:00 2001 From: Sanyam Kapoor <1sanyamkapoor@gmail.com> Date: Mon, 6 Nov 2017 18:33:53 -0500 Subject: [PATCH 04/11] Precompute random set of pairs, deterministic after object instantiation --- torchvision/datasets/omniglot.py | 74 +++++++++++++++++++++++--------- 1 file changed, 54 insertions(+), 20 deletions(-) diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py index f0bbfb9c20d..f98a25a583c 100644 --- a/torchvision/datasets/omniglot.py +++ b/torchvision/datasets/omniglot.py @@ -131,38 +131,37 @@ class OmniglotRandomPair(Omniglot): This is a subclass of the Omniglot dataset. This instead it returns a randomized pair of images with similarity label (0 or 1) + + Args: + pair_count (int, optional): The total number of image pairs to generate. 
Defaults to + 10000 """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, pair_count=10000, **kwargs): super(self.__class__, self).__init__(*args, **kwargs) + self.pair_count = pair_count + self._precompute_pairs() + + def __len__(self): + return len(self.pairs_list) + def __getitem__(self, index): """ Args: index (int): Index Returns: - tuple: ((image0, image1), is_match) a random pair of images from the Omniglot characters + tuple: (image0, image1, is_match) a random pair of images from the Omniglot characters with corresponding label 1 if it is matching pair and 0 otherwise """ - character_classes = [random.randint(0, len(self.characters_) - 1) for _ in range(2)] - - # Choose to return a matching/non-matching pair with probability 1/2 - is_match = random.randint(0, 1) - if is_match == 1: - character_classes = [character_classes[0], character_classes[0]] - else: - while character_classes[0] == character_classes[1]: - character_classes[1] = random.randint(0, len(self.characters_) - 1) - - image_names = [random.choice(self.character_images_[cls]) for cls in character_classes] - - image_paths = [ - os.path.join(self.target_folder, self.characters_[character_classes[idx]], image_name[0]) - for idx, image_name in enumerate(image_names) + target_pair, is_match = self.pairs_list[index] + target_image_names = [self.character_images_[i][j] for i, j in target_pair] + target_image_paths = [ + os.path.join(self.target_folder, self.characters_[cid], name) + for name, cid in target_image_names ] - - images = [Image.open(image_path, mode='r').convert('L') for image_path in image_paths] + images = [Image.open(path, mode='r').convert('L') for path in target_image_paths] if self.transform is not None: images = [self.transform(image) for image in images] @@ -170,4 +169,39 @@ def __getitem__(self, index): if self.target_transform is not None: is_match = self.target_transform(is_match) - return images, is_match + return images[0], images[1], is_match + + def _precompute_pairs(self): + """A utility wrapper to randomly generate pairs of images + + Args: + + Returns: + list(tuple((cid0, id0), (cid1, id1), is_match)), a list of 3-tuples where the first two + items of the tuple contains a character id and corresponding randomly chose image id + and the last item is 1 or 0 based on whether the image pair is from the same character + or not respectively + """ + is_match = [random.randint(0, 1) for _ in range(self.pair_count)] + + cid0_list = [random.randint(0, len(self.characters_) - 1) for _ in range(self.pair_count)] + c0_list = [random.randint(0, len(self.character_images_[cid]) - 1) for cid in cid0_list] + + cid1_list = [ + cid0_list[idx] if is_match[idx] == 1 else self._generate_pair(cid0_list[idx]) + for idx in range(self.pair_count) + ] + c1_list = [random.randint(0, len(self.character_images_[cid]) - 1) for cid in cid1_list] + + self.pairs_list = [ + (((cid0_list[idx], c0_list[idx]), (cid1_list[idx], c1_list[idx])), is_match[idx]) + for idx in range(self.pair_count) + ] + + def _generate_pair(self, character_id): + pair_id = random.randint(0, len(self.characters_) - 1) + while pair_id == character_id: + pair_id = random.randint(0, len(self.characters_) - 1) + return pair_id + + From dbc796ae21e17d5ca425ef987cdd6529ab668686 Mon Sep 17 00:00:00 2001 From: Sanyam Kapoor <1sanyamkapoor@gmail.com> Date: Mon, 6 Nov 2017 18:36:36 -0500 Subject: [PATCH 05/11] Export OmniglotRandomPair via the datasets module interfact --- torchvision/datasets/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 64e1ac5a44b..2ae4ce0d281 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -8,11 +8,11 @@ from .phototour import PhotoTour from .fakedata import FakeData from .semeion import SEMEION -from .omniglot import Omniglot +from .omniglot import Omniglot, OmniglotRandomPair __all__ = ('LSUN', 'LSUNClass', 'ImageFolder', 'FakeData', 'CocoCaptions', 'CocoDetection', 'CIFAR10', 'CIFAR100', 'FashionMNIST', 'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION', - 'Omniglot') + 'Omniglot', 'OmniglotRandomPair') From 5ccc7ae0195848ebe32c29334316174b419377e8 Mon Sep 17 00:00:00 2001 From: Sanyam Kapoor <1sanyamkapoor@gmail.com> Date: Wed, 8 Nov 2017 12:11:24 -0500 Subject: [PATCH 06/11] Fix naming convention, use sum instead of reduce --- .gitignore | 1 - torchvision/datasets/omniglot.py | 49 ++++++++++++++------------------ 2 files changed, 22 insertions(+), 28 deletions(-) diff --git a/.gitignore b/.gitignore index 19d64575c3e..2abd33a8556 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,3 @@ torchvision.egg-info/ */**/*~ *~ docs/build -.idea/ diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py index f98a25a583c..b149c41b7e8 100644 --- a/torchvision/datasets/omniglot.py +++ b/torchvision/datasets/omniglot.py @@ -29,9 +29,6 @@ class Omniglot(data.Dataset): zips_md5 = [ ['images_background', '68d2efa1b9178cc56df9314c21c6e718'], ['images_evaluation', '6b91aef0f799c5bb55b94e3f2daec811'], - # Provision in future - # ['images_background_small1', 'e704a628b5459e08445c13499850abc4'], - # ['images_background_small2', 'b75a71a51d3b13f821f212756fe481fd'], ] def __init__(self, root, background=True, @@ -51,30 +48,28 @@ def __init__(self, root, background=True, ' You can use download=True to download it') self.target_folder = os.path.join(self.root, self._get_target_folder()) - self.alphabets_ = list_dir(self.target_folder) - self.characters_ = list( - reduce( - lambda x, y: x + y, + self._alphabets = list_dir(self.target_folder) + self._characters = sum( + [ [ - [ - os.path.join(alphabet, character) - for character in list_dir(os.path.join(self.target_folder, alphabet)) - ] - for alphabet in self.alphabets_ + os.path.join(alphabet, character) + for character in list_dir(os.path.join(self.target_folder, alphabet)) ] - ) + for alphabet in self._alphabets + ], + [] ) - self.character_images_ = [ + self._character_images = [ [ - tuple([image, idx]) + (image, idx) for image in list_files(os.path.join(self.target_folder, character), '.png') ] - for idx, character in enumerate(self.characters_) + for idx, character in enumerate(self._characters) ] - self.flat_character_images_ = list(reduce(lambda x, y: x + y, self.character_images_)) + self._flat_character_images = sum(self._character_images, []) def __len__(self): - return len(self.flat_character_images_) + return len(self._flat_character_images) def __getitem__(self, index): """ @@ -84,8 +79,8 @@ def __getitem__(self, index): Returns: tuple: (image, target) where target is index of the target character class. 
""" - image_name, character_class = self.flat_character_images_[index] - image_path = os.path.join(self.target_folder, self.characters_[character_class], image_name) + image_name, character_class = self._flat_character_images[index] + image_path = os.path.join(self.target_folder, self._characters[character_class], image_name) image = Image.open(image_path, mode='r').convert('L') if self.transform: @@ -156,9 +151,9 @@ def __getitem__(self, index): """ target_pair, is_match = self.pairs_list[index] - target_image_names = [self.character_images_[i][j] for i, j in target_pair] + target_image_names = [self._character_images[i][j] for i, j in target_pair] target_image_paths = [ - os.path.join(self.target_folder, self.characters_[cid], name) + os.path.join(self.target_folder, self._characters[cid], name) for name, cid in target_image_names ] images = [Image.open(path, mode='r').convert('L') for path in target_image_paths] @@ -184,14 +179,14 @@ def _precompute_pairs(self): """ is_match = [random.randint(0, 1) for _ in range(self.pair_count)] - cid0_list = [random.randint(0, len(self.characters_) - 1) for _ in range(self.pair_count)] - c0_list = [random.randint(0, len(self.character_images_[cid]) - 1) for cid in cid0_list] + cid0_list = [random.randint(0, len(self._characters) - 1) for _ in range(self.pair_count)] + c0_list = [random.randint(0, len(self._character_images[cid]) - 1) for cid in cid0_list] cid1_list = [ cid0_list[idx] if is_match[idx] == 1 else self._generate_pair(cid0_list[idx]) for idx in range(self.pair_count) ] - c1_list = [random.randint(0, len(self.character_images_[cid]) - 1) for cid in cid1_list] + c1_list = [random.randint(0, len(self._character_images[cid]) - 1) for cid in cid1_list] self.pairs_list = [ (((cid0_list[idx], c0_list[idx]), (cid1_list[idx], c1_list[idx])), is_match[idx]) @@ -199,9 +194,9 @@ def _precompute_pairs(self): ] def _generate_pair(self, character_id): - pair_id = random.randint(0, len(self.characters_) - 1) + pair_id = random.randint(0, len(self._characters) - 1) while pair_id == character_id: - pair_id = random.randint(0, len(self.characters_) - 1) + pair_id = random.randint(0, len(self._characters) - 1) return pair_id From 594247d3ec76e931a0f8b42217e6e6d92acbedec Mon Sep 17 00:00:00 2001 From: Sanyam Kapoor <1sanyamkapoor@gmail.com> Date: Wed, 8 Nov 2017 12:43:52 -0500 Subject: [PATCH 07/11] Fix downloading to not download everything, fix Python2 syntax --- torchvision/datasets/omniglot.py | 51 ++++++++++++++------------------ 1 file changed, 23 insertions(+), 28 deletions(-) diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py index b149c41b7e8..531ef1b9ff7 100644 --- a/torchvision/datasets/omniglot.py +++ b/torchvision/datasets/omniglot.py @@ -1,6 +1,5 @@ from __future__ import print_function from PIL import Image -from functools import reduce import os import random import torch.utils.data as data @@ -21,27 +20,24 @@ class Omniglot(data.Dataset): download (bool, optional): If true, downloads the dataset zip files from the internet and puts it in root directory. If the zip files are already downloaded, they are not downloaded again. 
- force_extract (bool, optional): If true, extracts the downloaded zip file irrespective - of the existence of an extracted folder with the same name """ folder = 'omniglot-py' download_url_prefix = 'https://github.com/brendenlake/omniglot/raw/master/python' - zips_md5 = [ - ['images_background', '68d2efa1b9178cc56df9314c21c6e718'], - ['images_evaluation', '6b91aef0f799c5bb55b94e3f2daec811'], - ] + zips_md5 = { + 'images_background': '68d2efa1b9178cc56df9314c21c6e718', + 'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811' + } def __init__(self, root, background=True, transform=None, target_transform=None, - download=False, - force_extract=False): + download=False): self.root = os.path.join(os.path.expanduser(root), self.folder) self.background = background self.transform = transform self.target_transform = target_transform if download: - self.download(force_extract) + self.download() if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.' + @@ -92,30 +88,25 @@ def __getitem__(self, index): return image, character_class def _check_integrity(self): - for fzip in self.zips_md5: - filename, md5 = fzip[0] + '.zip', fzip[1] - fpath = os.path.join(self.root, filename) - if not check_integrity(fpath, md5): - return False + zip_filename = self._get_target_folder() + if not check_integrity(os.path.join(self.root, zip_filename + '.zip'), self.zips_md5[zip_filename]): + return False return True - def download(self, force_extract=False): + def download(self): import zipfile if self._check_integrity(): print('Files already downloaded and verified') return - for fzip in self.zips_md5: - filename, md5 = fzip[0], fzip[1] - zip_filename = filename + '.zip' - url = self.download_url_prefix + '/' + zip_filename - download_url(url, self.root, zip_filename, md5) - - if not os.path.isdir(os.path.join(self.root, filename)) or force_extract is True: - print('Extracting downloaded file: ' + os.path.join(self.root, zip_filename)) - with zipfile.ZipFile(os.path.join(self.root, zip_filename), 'r') as zip_file: - zip_file.extractall(self.root) + filename = self._get_target_folder() + zip_filename = filename + '.zip' + url = self.download_url_prefix + '/' + zip_filename + download_url(url, self.root, zip_filename, self.zips_md5[filename]) + print('Extracting downloaded file: ' + os.path.join(self.root, zip_filename)) + with zipfile.ZipFile(os.path.join(self.root, zip_filename), 'r') as zip_file: + zip_file.extractall(self.root) def _get_target_folder(self): return 'images_background' if self.background is True else 'images_evaluation' @@ -131,8 +122,12 @@ class OmniglotRandomPair(Omniglot): pair_count (int, optional): The total number of image pairs to generate. 
Defaults to 10000 """ - def __init__(self, *args, pair_count=10000, **kwargs): - super(self.__class__, self).__init__(*args, **kwargs) + def __init__(self, root, pair_count=10000, background=True, + transform=None, target_transform=None, + download=False): + super(self.__class__, self).__init__(root, background=background, + transform=transform, target_transform=target_transform, + download=download) self.pair_count = pair_count self._precompute_pairs() From b2e0b185113d2a4e095eb8b1519b733f4a980063 Mon Sep 17 00:00:00 2001 From: Sanyam Kapoor <1sanyamkapoor@gmail.com> Date: Wed, 8 Nov 2017 13:46:26 -0500 Subject: [PATCH 08/11] Fix end line lint --- torchvision/datasets/omniglot.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py index 531ef1b9ff7..033ec322a85 100644 --- a/torchvision/datasets/omniglot.py +++ b/torchvision/datasets/omniglot.py @@ -193,5 +193,3 @@ def _generate_pair(self, character_id): while pair_id == character_id: pair_id = random.randint(0, len(self._characters) - 1) return pair_id - - From 7985f476de720314e1bf50a09ea8710e6750e7a0 Mon Sep 17 00:00:00 2001 From: Sanyam Kapoor <1sanyamkapoor@gmail.com> Date: Thu, 9 Nov 2017 00:20:34 -0500 Subject: [PATCH 09/11] Add random_seed, syntax fixes --- torchvision/datasets/omniglot.py | 51 ++++++++++++++------------------ 1 file changed, 22 insertions(+), 29 deletions(-) diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py index 033ec322a85..3060c4d9a77 100644 --- a/torchvision/datasets/omniglot.py +++ b/torchvision/datasets/omniglot.py @@ -1,5 +1,6 @@ from __future__ import print_function from PIL import Image +from os.path import join import os import random import torch.utils.data as data @@ -31,7 +32,7 @@ class Omniglot(data.Dataset): def __init__(self, root, background=True, transform=None, target_transform=None, download=False): - self.root = os.path.join(os.path.expanduser(root), self.folder) + self.root = join(os.path.expanduser(root), self.folder) self.background = background self.transform = transform self.target_transform = target_transform @@ -43,25 +44,12 @@ def __init__(self, root, background=True, raise RuntimeError('Dataset not found or corrupted.' + ' You can use download=True to download it') - self.target_folder = os.path.join(self.root, self._get_target_folder()) + self.target_folder = join(self.root, self._get_target_folder()) self._alphabets = list_dir(self.target_folder) - self._characters = sum( - [ - [ - os.path.join(alphabet, character) - for character in list_dir(os.path.join(self.target_folder, alphabet)) - ] - for alphabet in self._alphabets - ], - [] - ) - self._character_images = [ - [ - (image, idx) - for image in list_files(os.path.join(self.target_folder, character), '.png') - ] - for idx, character in enumerate(self._characters) - ] + self._characters = sum([[join(a, c) for c in list_dir(join(self.target_folder, a))] + for a in self._alphabets], []) + self._character_images = [[(image, idx) for image in list_files(join(self.target_folder, character), '.png')] + for idx, character in enumerate(self._characters)] self._flat_character_images = sum(self._character_images, []) def __len__(self): @@ -76,7 +64,7 @@ def __getitem__(self, index): tuple: (image, target) where target is index of the target character class. 
""" image_name, character_class = self._flat_character_images[index] - image_path = os.path.join(self.target_folder, self._characters[character_class], image_name) + image_path = join(self.target_folder, self._characters[character_class], image_name) image = Image.open(image_path, mode='r').convert('L') if self.transform: @@ -89,7 +77,7 @@ def __getitem__(self, index): def _check_integrity(self): zip_filename = self._get_target_folder() - if not check_integrity(os.path.join(self.root, zip_filename + '.zip'), self.zips_md5[zip_filename]): + if not check_integrity(join(self.root, zip_filename + '.zip'), self.zips_md5[zip_filename]): return False return True @@ -104,12 +92,12 @@ def download(self): zip_filename = filename + '.zip' url = self.download_url_prefix + '/' + zip_filename download_url(url, self.root, zip_filename, self.zips_md5[filename]) - print('Extracting downloaded file: ' + os.path.join(self.root, zip_filename)) - with zipfile.ZipFile(os.path.join(self.root, zip_filename), 'r') as zip_file: + print('Extracting downloaded file: ' + join(self.root, zip_filename)) + with zipfile.ZipFile(join(self.root, zip_filename), 'r') as zip_file: zip_file.extractall(self.root) def _get_target_folder(self): - return 'images_background' if self.background is True else 'images_evaluation' + return 'images_background' if self.background else 'images_evaluation' class OmniglotRandomPair(Omniglot): @@ -121,14 +109,17 @@ class OmniglotRandomPair(Omniglot): Args: pair_count (int, optional): The total number of image pairs to generate. Defaults to 10000 + random_seed (int, optional): The value to pass to "random.seed" to allow reproducibility + of randomized pair generation """ def __init__(self, root, pair_count=10000, background=True, transform=None, target_transform=None, - download=False): - super(self.__class__, self).__init__(root, background=background, - transform=transform, target_transform=target_transform, - download=download) + download=False, random_seed=None): + super(OmniglotRandomPair, self).__init__(root, background=background, + transform=transform, target_transform=target_transform, + download=download) + self.random_seed = random_seed self.pair_count = pair_count self._precompute_pairs() @@ -148,7 +139,7 @@ def __getitem__(self, index): target_pair, is_match = self.pairs_list[index] target_image_names = [self._character_images[i][j] for i, j in target_pair] target_image_paths = [ - os.path.join(self.target_folder, self._characters[cid], name) + join(self.target_folder, self._characters[cid], name) for name, cid in target_image_names ] images = [Image.open(path, mode='r').convert('L') for path in target_image_paths] @@ -172,6 +163,8 @@ def _precompute_pairs(self): and the last item is 1 or 0 based on whether the image pair is from the same character or not respectively """ + random.seed(self.random_seed) + is_match = [random.randint(0, 1) for _ in range(self.pair_count)] cid0_list = [random.randint(0, len(self._characters) - 1) for _ in range(self.pair_count)] From f0d4664b7655f1fdf3e14b3b41078c7393eb0871 Mon Sep 17 00:00:00 2001 From: Sanyam Kapoor <1sanyamkapoor@gmail.com> Date: Fri, 17 Nov 2017 11:50:51 -0500 Subject: [PATCH 10/11] Remove randomized pair, take up as a separate generic wrapper --- torchvision/datasets/__init__.py | 4 +- torchvision/datasets/omniglot.py | 89 -------------------------------- 2 files changed, 2 insertions(+), 91 deletions(-) diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 2ae4ce0d281..64e1ac5a44b 100644 --- 
a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -8,11 +8,11 @@ from .phototour import PhotoTour from .fakedata import FakeData from .semeion import SEMEION -from .omniglot import Omniglot, OmniglotRandomPair +from .omniglot import Omniglot __all__ = ('LSUN', 'LSUNClass', 'ImageFolder', 'FakeData', 'CocoCaptions', 'CocoDetection', 'CIFAR10', 'CIFAR100', 'FashionMNIST', 'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION', - 'Omniglot', 'OmniglotRandomPair') + 'Omniglot') diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py index 3060c4d9a77..6fff770b165 100644 --- a/torchvision/datasets/omniglot.py +++ b/torchvision/datasets/omniglot.py @@ -2,7 +2,6 @@ from PIL import Image from os.path import join import os -import random import torch.utils.data as data from .utils import download_url, check_integrity, list_dir, list_files @@ -98,91 +97,3 @@ def download(self): def _get_target_folder(self): return 'images_background' if self.background else 'images_evaluation' - - -class OmniglotRandomPair(Omniglot): - """`OmniglotRandomPair `_ Dataset. - - This is a subclass of the Omniglot dataset. This instead it returns - a randomized pair of images with similarity label (0 or 1) - - Args: - pair_count (int, optional): The total number of image pairs to generate. Defaults to - 10000 - random_seed (int, optional): The value to pass to "random.seed" to allow reproducibility - of randomized pair generation - """ - def __init__(self, root, pair_count=10000, background=True, - transform=None, target_transform=None, - download=False, random_seed=None): - super(OmniglotRandomPair, self).__init__(root, background=background, - transform=transform, target_transform=target_transform, - download=download) - - self.random_seed = random_seed - self.pair_count = pair_count - self._precompute_pairs() - - def __len__(self): - return len(self.pairs_list) - - def __getitem__(self, index): - """ - Args: - index (int): Index - - Returns: - tuple: (image0, image1, is_match) a random pair of images from the Omniglot characters - with corresponding label 1 if it is matching pair and 0 otherwise - """ - - target_pair, is_match = self.pairs_list[index] - target_image_names = [self._character_images[i][j] for i, j in target_pair] - target_image_paths = [ - join(self.target_folder, self._characters[cid], name) - for name, cid in target_image_names - ] - images = [Image.open(path, mode='r').convert('L') for path in target_image_paths] - - if self.transform is not None: - images = [self.transform(image) for image in images] - - if self.target_transform is not None: - is_match = self.target_transform(is_match) - - return images[0], images[1], is_match - - def _precompute_pairs(self): - """A utility wrapper to randomly generate pairs of images - - Args: - - Returns: - list(tuple((cid0, id0), (cid1, id1), is_match)), a list of 3-tuples where the first two - items of the tuple contains a character id and corresponding randomly chose image id - and the last item is 1 or 0 based on whether the image pair is from the same character - or not respectively - """ - random.seed(self.random_seed) - - is_match = [random.randint(0, 1) for _ in range(self.pair_count)] - - cid0_list = [random.randint(0, len(self._characters) - 1) for _ in range(self.pair_count)] - c0_list = [random.randint(0, len(self._character_images[cid]) - 1) for cid in cid0_list] - - cid1_list = [ - cid0_list[idx] if is_match[idx] == 1 else self._generate_pair(cid0_list[idx]) - for idx in range(self.pair_count) - ] - 
c1_list = [random.randint(0, len(self._character_images[cid]) - 1) for cid in cid1_list] - - self.pairs_list = [ - (((cid0_list[idx], c0_list[idx]), (cid1_list[idx], c1_list[idx])), is_match[idx]) - for idx in range(self.pair_count) - ] - - def _generate_pair(self, character_id): - pair_id = random.randint(0, len(self._characters) - 1) - while pair_id == character_id: - pair_id = random.randint(0, len(self._characters) - 1) - return pair_id From 5a93b5f3f498921dfe8c518407d4b472d22e8e45 Mon Sep 17 00:00:00 2001 From: Sanyam Kapoor <1sanyamkapoor@gmail.com> Date: Fri, 22 Dec 2017 17:31:16 -0500 Subject: [PATCH 11/11] Fix master conflict --- torchvision/datasets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 64e1ac5a44b..4e6dc40ec06 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -13,6 +13,6 @@ __all__ = ('LSUN', 'LSUNClass', 'ImageFolder', 'FakeData', 'CocoCaptions', 'CocoDetection', - 'CIFAR10', 'CIFAR100', 'FashionMNIST', + 'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION', 'Omniglot')
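
A minimal usage sketch of the loader added by this patch series (not part of the patches themselves): it instantiates the final Omniglot class as it stands after patch 11, wraps it in a standard DataLoader, and pulls one batch. The "./data" root, the batch size, and the ToTensor transform are illustrative assumptions, not values taken from the patches.

    import torch.utils.data
    import torchvision.datasets as datasets
    import torchvision.transforms as transforms

    # Loads the "background" split; images_background.zip is downloaded and
    # extracted under <root>/omniglot-py/ on first use ('./data' is a placeholder).
    omniglot = datasets.Omniglot(
        root='./data',
        background=True,
        transform=transforms.ToTensor(),
        download=True,
    )

    # __getitem__ returns (image, character_class): the image is a single-channel
    # ('L' mode) PIL image and the target is the index of the character in the
    # flattened character list, so the default collate function works as-is.
    loader = torch.utils.data.DataLoader(omniglot, batch_size=32, shuffle=True)
    images, targets = next(iter(loader))
    print(images.shape)    # e.g. torch.Size([32, 1, 105, 105])
    print(targets[:5])     # tensor of character-class indices

Note that the OmniglotRandomPair variant introduced in patch 3 is removed again in patch 10 in favour of a separate generic pair-sampling wrapper, so only the plain Omniglot class shown above is exported at the end of the series.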