Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Omniglot Dataset #323

Merged
merged 13 commits into from
Jan 28, 2018
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ torchvision.egg-info/
*/**/*.pyc
*/**/*~
*~
docs/build
docs/build
4 changes: 3 additions & 1 deletion torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
from .phototour import PhotoTour
from .fakedata import FakeData
from .semeion import SEMEION
from .omniglot import Omniglot

__all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'FakeData',
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION')
'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
'Omniglot')
1 change: 0 additions & 1 deletion torchvision/datasets/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from PIL import Image
import os
import os.path
import errno
import numpy as np
import sys
if sys.version_info[0] == 2:
Expand Down
99 changes: 99 additions & 0 deletions torchvision/datasets/omniglot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from __future__ import print_function
from PIL import Image
from os.path import join
import os
import torch.utils.data as data
from .utils import download_url, check_integrity, list_dir, list_files


class Omniglot(data.Dataset):
"""`Omniglot <https://github.com/brendenlake/omniglot>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
``omniglot-py`` exists.
background (bool, optional): If True, creates dataset from the "background" set, otherwise
creates from the "evaluation" set. This terminology is defined by the authors.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset zip files from the internet and
puts it in root directory. If the zip files are already downloaded, they are not
downloaded again.
"""
folder = 'omniglot-py'
download_url_prefix = 'https://github.com/brendenlake/omniglot/raw/master/python'
zips_md5 = {
'images_background': '68d2efa1b9178cc56df9314c21c6e718',
'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811'
}

def __init__(self, root, background=True,
transform=None, target_transform=None,
download=False):
self.root = join(os.path.expanduser(root), self.folder)
self.background = background
self.transform = transform
self.target_transform = target_transform

if download:
self.download()

if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')

self.target_folder = join(self.root, self._get_target_folder())
self._alphabets = list_dir(self.target_folder)
self._characters = sum([[join(a, c) for c in list_dir(join(self.target_folder, a))]
for a in self._alphabets], [])
self._character_images = [[(image, idx) for image in list_files(join(self.target_folder, character), '.png')]
for idx, character in enumerate(self._characters)]
self._flat_character_images = sum(self._character_images, [])

def __len__(self):
return len(self._flat_character_images)

def __getitem__(self, index):
"""
Args:
index (int): Index

Returns:
tuple: (image, target) where target is index of the target character class.
"""
image_name, character_class = self._flat_character_images[index]
image_path = join(self.target_folder, self._characters[character_class], image_name)
image = Image.open(image_path, mode='r').convert('L')

if self.transform:
image = self.transform(image)

if self.target_transform:
character_class = self.target_transform(character_class)

return image, character_class

def _check_integrity(self):
zip_filename = self._get_target_folder()
if not check_integrity(join(self.root, zip_filename + '.zip'), self.zips_md5[zip_filename]):
return False
return True

def download(self):
import zipfile

if self._check_integrity():
print('Files already downloaded and verified')
return

filename = self._get_target_folder()
zip_filename = filename + '.zip'
url = self.download_url_prefix + '/' + zip_filename
download_url(url, self.root, zip_filename, self.zips_md5[filename])
print('Extracting downloaded file: ' + join(self.root, zip_filename))
with zipfile.ZipFile(join(self.root, zip_filename), 'r') as zip_file:
zip_file.extractall(self.root)

def _get_target_folder(self):
return 'images_background' if self.background else 'images_evaluation'
46 changes: 46 additions & 0 deletions torchvision/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,49 @@ def download_url(url, root, filename, md5):
print('Failed download. Trying https -> http instead.'
' Downloading ' + url + ' to ' + fpath)
urllib.request.urlretrieve(url, fpath)


def list_dir(root, prefix=False):
"""List all directories at a given root

Args:
root (str): Path to directory whose folders need to be listed
prefix (bool, optional): If true, prepends the path to each result, otherwise
only returns the name of the directories found
"""
root = os.path.expanduser(root)
directories = list(
filter(
lambda p: os.path.isdir(os.path.join(root, p)),
os.listdir(root)
)
)

if prefix is True:
directories = [os.path.join(root, d) for d in directories]

return directories


def list_files(root, suffix, prefix=False):
"""List all files ending with a suffix at a given root

Args:
root (str): Path to directory whose folders need to be listed
suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
It uses the Python "str.endswith" method and is passed directly
prefix (bool, optional): If true, prepends the path to each result, otherwise
only returns the name of the files found
"""
root = os.path.expanduser(root)
files = list(
filter(
lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
os.listdir(root)
)
)

if prefix is True:
files = [os.path.join(root, d) for d in files]

return files