Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Omniglot Dataset #323

Merged
merged 13 commits into from
Jan 28, 2018
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ torchvision.egg-info/
*/**/*.pyc
*/**/*~
*~
docs/build
docs/build
4 changes: 3 additions & 1 deletion torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
from .phototour import PhotoTour
from .fakedata import FakeData
from .semeion import SEMEION
from .omniglot import Omniglot, OmniglotRandomPair

__all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'FakeData',
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'FashionMNIST',
'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION')
'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
'Omniglot', 'OmniglotRandomPair')
1 change: 0 additions & 1 deletion torchvision/datasets/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from PIL import Image
import os
import os.path
import errno
import numpy as np
import sys
if sys.version_info[0] == 2:
Expand Down
195 changes: 195 additions & 0 deletions torchvision/datasets/omniglot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
from __future__ import print_function
from PIL import Image
import os
import random
import torch.utils.data as data
from .utils import download_url, check_integrity, list_dir, list_files


class Omniglot(data.Dataset):
    """`Omniglot <https://github.com/brendenlake/omniglot>`_ Dataset.

    Every sample is a ``(image, target)`` tuple: the image is opened with PIL
    and converted to mode ``'L'``, and the target is the integer index of the
    character class the image belongs to. Class indices follow the order in
    which the alphabet and character directories are listed on disk.

    Args:
        root (string): Root directory of dataset where directory
            ``omniglot-py`` exists.
        background (bool, optional): If True, creates dataset from the "background" set, otherwise
            creates from the "evaluation" set. This terminology is defined by the authors.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset zip files from the internet and
            puts it in root directory. If the zip files are already downloaded, they are not
            downloaded again.

    Raises:
        RuntimeError: If the zip archive for the selected set is missing or
            fails the integrity check after the optional download step.
    """
    folder = 'omniglot-py'
    download_url_prefix = 'https://github.com/brendenlake/omniglot/raw/master/python'
    # Expected MD5 of each archive, keyed by archive name without ".zip".
    zips_md5 = {
        'images_background': '68d2efa1b9178cc56df9314c21c6e718',
        'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811'
    }

    def __init__(self, root, background=True,
                 transform=None, target_transform=None,
                 download=False):
        self.root = os.path.join(os.path.expanduser(root), self.folder)
        self.background = background
        self.transform = transform
        self.target_transform = target_transform

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        self.target_folder = os.path.join(self.root, self._get_target_folder())
        self._alphabets = list_dir(self.target_folder)

        # Character paths relative to target_folder ("<alphabet>/<character>").
        # A nested comprehension keeps the same order as concatenating the
        # per-alphabet lists, without the quadratic cost of sum(lists, []).
        self._characters = [
            os.path.join(alphabet, character)
            for alphabet in self._alphabets
            for character in list_dir(os.path.join(self.target_folder, alphabet))
        ]

        # One list per character class, holding (image file name, class index).
        self._character_images = [
            [
                (image, idx)
                for image in list_files(os.path.join(self.target_folder, character), '.png')
            ]
            for idx, character in enumerate(self._characters)
        ]

        # Flat view over all images; this is what integer indexing addresses.
        self._flat_character_images = [
            entry
            for images in self._character_images
            for entry in images
        ]

    def __len__(self):
        """Return the total number of images across all character classes."""
        return len(self._flat_character_images)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target character class.
        """
        image_name, character_class = self._flat_character_images[index]
        image_path = os.path.join(self.target_folder, self._characters[character_class], image_name)
        image = Image.open(image_path, mode='r').convert('L')

        # Compare against None (not truthiness) so that a callable which
        # happens to be falsy is still applied.
        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            character_class = self.target_transform(character_class)

        return image, character_class

    def _check_integrity(self):
        """Return True if the zip archive of the selected set passes the MD5 check."""
        zip_filename = self._get_target_folder()
        return bool(
            check_integrity(os.path.join(self.root, zip_filename + '.zip'),
                            self.zips_md5[zip_filename])
        )

    def download(self):
        """Download the archive of the selected set into ``self.root`` and extract it there.

        Does nothing beyond printing a message when the archive is already
        present and passes the integrity check.
        """
        import zipfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        filename = self._get_target_folder()
        zip_filename = filename + '.zip'
        url = self.download_url_prefix + '/' + zip_filename
        download_url(url, self.root, zip_filename, self.zips_md5[filename])
        print('Extracting downloaded file: ' + os.path.join(self.root, zip_filename))
        with zipfile.ZipFile(os.path.join(self.root, zip_filename), 'r') as zip_file:
            zip_file.extractall(self.root)

    def _get_target_folder(self):
        """Return the folder (and archive base name) for the selected set."""
        # Truthiness rather than ``is True`` so that e.g. ``background=1``
        # selects the background set instead of silently falling through.
        return 'images_background' if self.background else 'images_evaluation'

This comment was marked as off-topic.



class OmniglotRandomPair(Omniglot):
    """`OmniglotRandomPair <https://github.com/brendenlake/omniglot>`_ Dataset.

    This is a subclass of the Omniglot dataset. Instead of a single labelled
    character image, it returns a randomized pair of images with a similarity
    label (0 or 1).

    The pairs are generated once, at construction time, with the module-level
    ``random`` functions; seed ``random`` beforehand for reproducible pairs.

    Args:
        root (string): Forwarded to ``Omniglot``.
        pair_count (int, optional): The total number of image pairs to generate. Defaults to
            10000
        background (bool, optional): Forwarded to ``Omniglot``.
        transform (callable, optional): Applied to each of the two images.
        target_transform (callable, optional): Applied to the similarity label.
        download (bool, optional): Forwarded to ``Omniglot``.
    """
    def __init__(self, root, pair_count=10000, background=True,
                 transform=None, target_transform=None,
                 download=False):
        # Name the class explicitly: super(self.__class__, self) resolves to
        # the *runtime* class and recurses forever once this class is subclassed.
        super(OmniglotRandomPair, self).__init__(root, background=background,
                                                 transform=transform,
                                                 target_transform=target_transform,
                                                 download=download)

        self.pair_count = pair_count
        self._precompute_pairs()

    def __len__(self):
        """Return the number of precomputed pairs."""
        return len(self.pairs_list)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image0, image1, is_match) a random pair of images from the Omniglot characters
            with corresponding label 1 if it is matching pair and 0 otherwise
        """
        target_pair, is_match = self.pairs_list[index]

        # Each entry of _character_images is (image file name, character id).
        entries = [self._character_images[cid][image_id] for cid, image_id in target_pair]
        image_paths = [
            os.path.join(self.target_folder, self._characters[cid], name)
            for name, cid in entries
        ]
        images = [Image.open(path, mode='r').convert('L') for path in image_paths]

        if self.transform is not None:
            images = [self.transform(image) for image in images]

        if self.target_transform is not None:
            is_match = self.target_transform(is_match)

        return images[0], images[1], is_match

    def _precompute_pairs(self):
        """Randomly generate ``self.pair_count`` image pairs and store them.

        The result is stored in ``self.pairs_list`` (nothing is returned) as a
        list of ``(((cid0, id0), (cid1, id1)), is_match)`` items, where each
        ``(cid, id)`` is a character id with a randomly chosen image id of
        that character, and ``is_match`` is 1 if both images come from the
        same character and 0 otherwise.

        Raises:
            ValueError: If a non-matching pair is requested while fewer than
                two character classes are available.
        """
        character_count = len(self._characters)

        # NOTE: the order of the random draws below is kept stable so that a
        # fixed ``random.seed`` keeps producing the same pairs.
        is_match = [random.randint(0, 1) for _ in range(self.pair_count)]

        cid0_list = [random.randint(0, character_count - 1) for _ in range(self.pair_count)]
        c0_list = [random.randint(0, len(self._character_images[cid]) - 1) for cid in cid0_list]

        cid1_list = [
            cid0 if match == 1 else self._generate_pair(cid0)
            for cid0, match in zip(cid0_list, is_match)
        ]
        c1_list = [random.randint(0, len(self._character_images[cid]) - 1) for cid in cid1_list]

        self.pairs_list = [
            (((cid0, c0), (cid1, c1)), match)
            for cid0, c0, cid1, c1, match in zip(cid0_list, c0_list, cid1_list, c1_list, is_match)
        ]

    def _generate_pair(self, character_id):
        """Return a random character id that differs from ``character_id``.

        Raises:
            ValueError: If there are fewer than two character classes, in
                which case no differing id exists.
        """
        character_count = len(self._characters)
        if character_count < 2:
            # Without this guard the rejection loop below would never end.
            raise ValueError('At least two character classes are required to '
                             'generate a non-matching pair, found ' + str(character_count))

        pair_id = random.randint(0, character_count - 1)
        while pair_id == character_id:
            pair_id = random.randint(0, character_count - 1)
        return pair_id
46 changes: 46 additions & 0 deletions torchvision/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,49 @@ def download_url(url, root, filename, md5):
print('Failed download. Trying https -> http instead.'
' Downloading ' + url + ' to ' + fpath)
urllib.request.urlretrieve(url, fpath)


def list_dir(root, prefix=False):
    """List all directories at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found

    Returns:
        list: Names (or paths, when ``prefix`` is true) of the directories
        directly under ``root``, in the order ``os.listdir`` reports them.
    """
    root = os.path.expanduser(root)
    directories = [
        name
        for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    ]

    # Truthiness rather than ``is True`` so truthy non-bool flags are honoured.
    if prefix:
        directories = [os.path.join(root, name) for name in directories]

    return directories


def list_files(root, suffix, prefix=False):
    """List all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose files need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found

    Returns:
        list: Names (or paths, when ``prefix`` is true) of the matching files
        directly under ``root``, in the order ``os.listdir`` reports them.
    """
    root = os.path.expanduser(root)
    files = [
        name
        for name in os.listdir(root)
        if os.path.isfile(os.path.join(root, name)) and name.endswith(suffix)
    ]

    # Truthiness rather than ``is True`` so truthy non-bool flags are honoured.
    if prefix:
        files = [os.path.join(root, name) for name in files]

    return files