Skip to content

Commit

Permalink
Omniglot Dataset (#323)
Browse files Browse the repository at this point in the history
* Add basic Omniglot dataset loader

* Remove unused import

* Add Omniglot random pair to sample pair of characters

* Precompute random set of pairs, deterministic after object instantiation

* Export OmniglotRandomPair via the datasets module interfact

* Fix naming convention, use sum instead of reduce

* Fix downloading to not download everything, fix Python2 syntax

* Fix end line lint

* Add random_seed, syntax fixes

* Remove randomized pair, take up as a separate generic wrapper

* Fix master conflict
  • Loading branch information
activatedgeek authored and fmassa committed Jan 28, 2018
1 parent 7044049 commit dac9efa
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ torchvision.egg-info/
*/**/*.pyc
*/**/*~
*~
docs/build
docs/build
4 changes: 3 additions & 1 deletion torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
from .phototour import PhotoTour
from .fakedata import FakeData
from .semeion import SEMEION
from .omniglot import Omniglot

__all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'FakeData',
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION')
'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
'Omniglot')
1 change: 0 additions & 1 deletion torchvision/datasets/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from PIL import Image
import os
import os.path
import errno
import numpy as np
import sys
if sys.version_info[0] == 2:
Expand Down
99 changes: 99 additions & 0 deletions torchvision/datasets/omniglot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from __future__ import print_function
from PIL import Image
from os.path import join
import os
import torch.utils.data as data
from .utils import download_url, check_integrity, list_dir, list_files


class Omniglot(data.Dataset):
"""`Omniglot <https://github.com/brendenlake/omniglot>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
``omniglot-py`` exists.
background (bool, optional): If True, creates dataset from the "background" set, otherwise
creates from the "evaluation" set. This terminology is defined by the authors.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset zip files from the internet and
puts it in root directory. If the zip files are already downloaded, they are not
downloaded again.
"""
folder = 'omniglot-py'
download_url_prefix = 'https://github.com/brendenlake/omniglot/raw/master/python'
zips_md5 = {
'images_background': '68d2efa1b9178cc56df9314c21c6e718',
'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811'
}

def __init__(self, root, background=True,
transform=None, target_transform=None,
download=False):
self.root = join(os.path.expanduser(root), self.folder)
self.background = background
self.transform = transform
self.target_transform = target_transform

if download:
self.download()

if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')

self.target_folder = join(self.root, self._get_target_folder())
self._alphabets = list_dir(self.target_folder)
self._characters = sum([[join(a, c) for c in list_dir(join(self.target_folder, a))]
for a in self._alphabets], [])
self._character_images = [[(image, idx) for image in list_files(join(self.target_folder, character), '.png')]
for idx, character in enumerate(self._characters)]
self._flat_character_images = sum(self._character_images, [])

def __len__(self):
return len(self._flat_character_images)

def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target character class.
"""
image_name, character_class = self._flat_character_images[index]
image_path = join(self.target_folder, self._characters[character_class], image_name)
image = Image.open(image_path, mode='r').convert('L')

if self.transform:
image = self.transform(image)

if self.target_transform:
character_class = self.target_transform(character_class)

return image, character_class

def _check_integrity(self):
zip_filename = self._get_target_folder()
if not check_integrity(join(self.root, zip_filename + '.zip'), self.zips_md5[zip_filename]):
return False
return True

def download(self):
import zipfile

if self._check_integrity():
print('Files already downloaded and verified')
return

filename = self._get_target_folder()
zip_filename = filename + '.zip'
url = self.download_url_prefix + '/' + zip_filename
download_url(url, self.root, zip_filename, self.zips_md5[filename])
print('Extracting downloaded file: ' + join(self.root, zip_filename))
with zipfile.ZipFile(join(self.root, zip_filename), 'r') as zip_file:
zip_file.extractall(self.root)

def _get_target_folder(self):
return 'images_background' if self.background else 'images_evaluation'
46 changes: 46 additions & 0 deletions torchvision/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,49 @@ def download_url(url, root, filename, md5):
print('Failed download. Trying https -> http instead.'
' Downloading ' + url + ' to ' + fpath)
urllib.request.urlretrieve(url, fpath)


def list_dir(root, prefix=False):
"""List all directories at a given root
Args:
root (str): Path to directory whose folders need to be listed
prefix (bool, optional): If true, prepends the path to each result, otherwise
only returns the name of the directories found
"""
root = os.path.expanduser(root)
directories = list(
filter(
lambda p: os.path.isdir(os.path.join(root, p)),
os.listdir(root)
)
)

if prefix is True:
directories = [os.path.join(root, d) for d in directories]

return directories


def list_files(root, suffix, prefix=False):
"""List all files ending with a suffix at a given root
Args:
root (str): Path to directory whose folders need to be listed
suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
It uses the Python "str.endswith" method and is passed directly
prefix (bool, optional): If true, prepends the path to each result, otherwise
only returns the name of the files found
"""
root = os.path.expanduser(root)
files = list(
filter(
lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
os.listdir(root)
)
)

if prefix is True:
files = [os.path.join(root, d) for d in files]

return files

0 comments on commit dac9efa

Please sign in to comment.