-
Notifications
You must be signed in to change notification settings - Fork 7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add basic Omniglot dataset loader * Remove unused import * Add Omniglot random pair to sample pair of characters * Precompute random set of pairs, deterministic after object instantiation * Export OmniglotRandomPair via the datasets module interfact * Fix naming convention, use sum instead of reduce * Fix downloading to not download everything, fix Python2 syntax * Fix end line lint * Add random_seed, syntax fixes * Remove randomized pair, take up as a separate generic wrapper * Fix master conflict
- Loading branch information
1 parent
7044049
commit dac9efa
Showing
5 changed files
with
149 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,4 +5,4 @@ torchvision.egg-info/ | |
*/**/*.pyc | ||
*/**/*~ | ||
*~ | ||
docs/build | ||
docs/build |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from __future__ import print_function | ||
from PIL import Image | ||
from os.path import join | ||
import os | ||
import torch.utils.data as data | ||
from .utils import download_url, check_integrity, list_dir, list_files | ||
|
||
|
||
class Omniglot(data.Dataset): | ||
"""`Omniglot <https://github.com/brendenlake/omniglot>`_ Dataset. | ||
Args: | ||
root (string): Root directory of dataset where directory | ||
``omniglot-py`` exists. | ||
background (bool, optional): If True, creates dataset from the "background" set, otherwise | ||
creates from the "evaluation" set. This terminology is defined by the authors. | ||
transform (callable, optional): A function/transform that takes in an PIL image | ||
and returns a transformed version. E.g, ``transforms.RandomCrop`` | ||
target_transform (callable, optional): A function/transform that takes in the | ||
target and transforms it. | ||
download (bool, optional): If true, downloads the dataset zip files from the internet and | ||
puts it in root directory. If the zip files are already downloaded, they are not | ||
downloaded again. | ||
""" | ||
folder = 'omniglot-py' | ||
download_url_prefix = 'https://github.com/brendenlake/omniglot/raw/master/python' | ||
zips_md5 = { | ||
'images_background': '68d2efa1b9178cc56df9314c21c6e718', | ||
'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811' | ||
} | ||
|
||
def __init__(self, root, background=True, | ||
transform=None, target_transform=None, | ||
download=False): | ||
self.root = join(os.path.expanduser(root), self.folder) | ||
self.background = background | ||
self.transform = transform | ||
self.target_transform = target_transform | ||
|
||
if download: | ||
self.download() | ||
|
||
if not self._check_integrity(): | ||
raise RuntimeError('Dataset not found or corrupted.' + | ||
' You can use download=True to download it') | ||
|
||
self.target_folder = join(self.root, self._get_target_folder()) | ||
self._alphabets = list_dir(self.target_folder) | ||
self._characters = sum([[join(a, c) for c in list_dir(join(self.target_folder, a))] | ||
for a in self._alphabets], []) | ||
self._character_images = [[(image, idx) for image in list_files(join(self.target_folder, character), '.png')] | ||
for idx, character in enumerate(self._characters)] | ||
self._flat_character_images = sum(self._character_images, []) | ||
|
||
def __len__(self): | ||
return len(self._flat_character_images) | ||
|
||
def __getitem__(self, index): | ||
""" | ||
Args: | ||
index (int): Index | ||
Returns: | ||
tuple: (image, target) where target is index of the target character class. | ||
""" | ||
image_name, character_class = self._flat_character_images[index] | ||
image_path = join(self.target_folder, self._characters[character_class], image_name) | ||
image = Image.open(image_path, mode='r').convert('L') | ||
|
||
if self.transform: | ||
image = self.transform(image) | ||
|
||
if self.target_transform: | ||
character_class = self.target_transform(character_class) | ||
|
||
return image, character_class | ||
|
||
def _check_integrity(self): | ||
zip_filename = self._get_target_folder() | ||
if not check_integrity(join(self.root, zip_filename + '.zip'), self.zips_md5[zip_filename]): | ||
return False | ||
return True | ||
|
||
def download(self): | ||
import zipfile | ||
|
||
if self._check_integrity(): | ||
print('Files already downloaded and verified') | ||
return | ||
|
||
filename = self._get_target_folder() | ||
zip_filename = filename + '.zip' | ||
url = self.download_url_prefix + '/' + zip_filename | ||
download_url(url, self.root, zip_filename, self.zips_md5[filename]) | ||
print('Extracting downloaded file: ' + join(self.root, zip_filename)) | ||
with zipfile.ZipFile(join(self.root, zip_filename), 'r') as zip_file: | ||
zip_file.extractall(self.root) | ||
|
||
def _get_target_folder(self): | ||
return 'images_background' if self.background else 'images_evaluation' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters