From 149760d41aab24907357b7516ef15c294c4757ac Mon Sep 17 00:00:00 2001 From: Ludovic Denoyer Date: Wed, 25 Jan 2017 17:05:13 +0100 Subject: [PATCH 1/9] Add Omniglot dataset loader --- torchvision/datasets/__init__.py | 3 +- torchvision/datasets/omniglot.py | 104 +++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 1 deletion(-) create mode 100644 torchvision/datasets/omniglot.py diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index e9c4b0e7184..56622b679ca 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -3,9 +3,10 @@ from .coco import CocoCaptions, CocoDetection from .cifar import CIFAR10, CIFAR100 from .mnist import MNIST +from .omniglot import OMNIGLOT __all__ = ('LSUN', 'LSUNClass', 'ImageFolder', 'CocoCaptions', 'CocoDetection', 'CIFAR10', 'CIFAR100', - 'MNIST') + 'MNIST','OMNIGLOT') diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py new file mode 100644 index 00000000000..cb7ccdce104 --- /dev/null +++ b/torchvision/datasets/omniglot.py @@ -0,0 +1,104 @@ +from __future__ import print_function +import torch.utils.data as data +from PIL import Image +import os +import os.path +import errno +import torch +import json +import codecs +import numpy as np +from PIL import Image + +class OMNIGLOT(data.Dataset): + urls = [ + 'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip', + 'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip' + ] + raw_folder = 'raw' + processed_folder = 'processed' + training_file = 'training.pt' + test_file = 'test.pt' + + def __init__(self, root, train=True, transform=None, target_transform=None, download=False): + self.root = root + self.transform = transform + self.target_transform = target_transform + if download: + self.download() + + if not self._check_exists(): + raise RuntimeError('Dataset not found.' 
+ + ' You can use download=True to download it') + + self.all_items=find_classes(os.path.join(self.root, self.processed_folder)) + self.idx_classes=index_classes(self.all_items) + + def __getitem__(self, index): + filename=self.all_items[index][0] + path=self.all_items[index][2]+"/"+filename + img=Image.open(path).convert('RGB') + target=self.idx_classes[self.all_items[index][1]] + if self.transform is not None: + img = self.transform(img) + if self.target_transform is not None: + target = self.target_transform(target) + + return img,target + + def __len__(self): + return len(self.all_items) + + def _check_exists(self): + return os.path.exists(os.path.join(self.root, self.processed_folder, "images_evaluation")) and \ + os.path.exists(os.path.join(self.root, self.processed_folder, "images_background")) + + def download(self): + from six.moves import urllib + import zipfile + + if self._check_exists(): + return + + # download files + try: + os.makedirs(os.path.join(self.root, self.raw_folder)) + os.makedirs(os.path.join(self.root, self.processed_folder)) + except OSError as e: + if e.errno == errno.EEXIST: + pass + else: + raise + + for url in self.urls: + print('== Downloading ' + url) + data = urllib.request.urlopen(url) + filename = url.rpartition('/')[2] + file_path = os.path.join(self.root, self.raw_folder, filename) + with open(file_path, 'wb') as f: + f.write(data.read()) + file_processed = os.path.join(self.root, self.processed_folder) + print("== Unzip from "+file_path+" to "+file_processed) + zip_ref = zipfile.ZipFile(file_path, 'r') + zip_ref.extractall(file_processed) + zip_ref.close() + print("Download finished.") + +def find_classes(root_dir): + retour=[] + for (root,dirs,files) in os.walk(root_dir): + for f in files: + if (f.endswith("png")): + r=root.split('/') + lr=len(r) + retour.append((f,r[lr-2]+"/"+r[lr-1],root)) + print("Found %d items "%len(retour)) + return retour + +def index_classes(items): + idx={} + for i in items: + if (not i[1] in idx): + idx[i[1]]=len(idx) + print("Found %d classes"% len(idx)) + return idx \ No newline at end of file From 206c74ec9c5bc91124cb7b6fc17fa35ee4b2c371 Mon Sep 17 00:00:00 2001 From: Ludovic Denoyer Date: Wed, 25 Jan 2017 17:11:57 +0100 Subject: [PATCH 2/9] Add OMNIGLOT documentation --- README.rst | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 3da007a5d2e..5e7030c10f0 100644 --- a/README.rst +++ b/README.rst @@ -46,7 +46,7 @@ The following dataset loaders are available: - `ImageFolder <#imagefolder>`__ - `Imagenet-12 <#imagenet-12>`__ - `CIFAR10 and CIFAR100 <#cifar>`__ - +- `OMNIGLOT`__ Datasets have the API: - ``__getitem__`` - ``__len__`` They all subclass from ``torch.utils.data.Dataset`` Hence, they can all be multi-threaded (python multiprocessing) using standard torch.utils.data.DataLoader. @@ -187,6 +187,13 @@ here `__. +OMNIGLOT +~~~~~~~~ + +dset.OMNIGLOT(root_dir, [transform=None, target_transform=None])` + +From: `Lake, B. M., Salakhutdinov, R., and Tenenbaum, J. B. (2015). Human-level concept learning through probabilistic program induction. 
Science, 350(6266), 1332-1338.` + Models ====== From a2954d2d128b5e2a3bcc6210bf8029d037448a48 Mon Sep 17 00:00:00 2001 From: Ludovic Denoyer Date: Wed, 25 Jan 2017 17:13:52 +0100 Subject: [PATCH 3/9] Add OMNIGLOT documentation --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 5e7030c10f0..6650ef72811 100644 --- a/README.rst +++ b/README.rst @@ -46,7 +46,7 @@ The following dataset loaders are available: - `ImageFolder <#imagefolder>`__ - `Imagenet-12 <#imagenet-12>`__ - `CIFAR10 and CIFAR100 <#cifar>`__ -- `OMNIGLOT`__ +- `OMNIGLOT <#omniglot>`__ Datasets have the API: - ``__getitem__`` - ``__len__`` They all subclass from ``torch.utils.data.Dataset`` Hence, they can all be multi-threaded (python multiprocessing) using standard torch.utils.data.DataLoader. From a6f165a691e4d2677df79dc2e257a4703d6850bf Mon Sep 17 00:00:00 2001 From: Ludovic Denoyer Date: Wed, 25 Jan 2017 17:15:04 +0100 Subject: [PATCH 4/9] Add OMNIGLOT documentation --- README.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 6650ef72811..56cf988a318 100644 --- a/README.rst +++ b/README.rst @@ -46,7 +46,7 @@ The following dataset loaders are available: - `ImageFolder <#imagefolder>`__ - `Imagenet-12 <#imagenet-12>`__ - `CIFAR10 and CIFAR100 <#cifar>`__ -- `OMNIGLOT <#omniglot>`__ +- OMNIGLOT Datasets have the API: - ``__getitem__`` - ``__len__`` They all subclass from ``torch.utils.data.Dataset`` Hence, they can all be multi-threaded (python multiprocessing) using standard torch.utils.data.DataLoader. @@ -190,7 +190,7 @@ example Date: Wed, 25 Jan 2017 17:16:07 +0100 Subject: [PATCH 5/9] Add OMNIGLOT documentation --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 56cf988a318..ef60fa9c98d 100644 --- a/README.rst +++ b/README.rst @@ -190,7 +190,7 @@ example Date: Wed, 25 Jan 2017 17:18:16 +0100 Subject: [PATCH 6/9] Add OMNIGLOT documentation --- README.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.rst b/README.rst index ef60fa9c98d..5934fd0cb77 100644 --- a/README.rst +++ b/README.rst @@ -192,6 +192,9 @@ OMNIGLOT `dset.OMNIGLOT(root_dir, [transform=None, target_transform=None])` +The dataset is composed of pairs: `(Image,Category idx)`. Each category corresponds to one character in one alphabet. Matching between classes indexes and real classes can be accessed through: `dataset.idx_classes` + + From: `Lake, B. M., Salakhutdinov, R., and Tenenbaum, J. B. (2015). Human-level concept learning through probabilistic program induction. 
Science, 350(6266), 1332-1338.`
 
 Models
 ======

From e472eaa4aaa359a93cf8a1a83850b95bd527ea85 Mon Sep 17 00:00:00 2001
From: "ludovic.denoyer"
Date: Wed, 25 Jan 2017 22:13:21 +0100
Subject: [PATCH 7/9] omniglot fixes

---
 .idea/vcs.xml                    |  6 ++++++
 torchvision/datasets/omniglot.py | 21 ++++++++++++++++++---
 2 files changed, 24 insertions(+), 3 deletions(-)
 create mode 100644 .idea/vcs.xml

diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 00000000000..94a25f7f4cb
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py
index cb7ccdce104..145884b2a77 100644
--- a/torchvision/datasets/omniglot.py
+++ b/torchvision/datasets/omniglot.py
@@ -20,10 +20,20 @@ class OMNIGLOT(data.Dataset):
     training_file = 'training.pt'
     test_file = 'test.pt'
 
-    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
+    '''
+    Args:
+
+    - root: the directory where the dataset will be stored
+    - transform: how to transform the input
+    - target_transform: how to transform the target
+    - download: whether to download the dataset
+    - input_is_filename: if True, the returned data is (filename, target); otherwise it is a pair (PIL.Image, target)
+    '''
+    def __init__(self, root, transform=None, target_transform=None, download=False,input_is_filename=False):
         self.root = root
         self.transform = transform
         self.target_transform = target_transform
+        self.input_is_filename=input_is_filename
         if download:
             self.download()
 
@@ -36,8 +46,13 @@ def __init__(self, root, train=True, transform=None, target_transform=None, down
 
     def __getitem__(self, index):
         filename=self.all_items[index][0]
-        path=self.all_items[index][2]+"/"+filename
-        img=Image.open(path).convert('RGB')
+        path=str.join('/',[self.all_items[index][2],filename])
+
+        if (not self.input_is_filename):
+            img=Image.open(path).convert('RGB')
+        else:
+            img=path
+
         target=self.idx_classes[self.all_items[index][1]]
         if self.transform is not None:
             img = self.transform(img)

From 6d4e78f50e9d69d17c80661d0d398ca502c29e5f Mon Sep 17 00:00:00 2001
From: "ludovic.denoyer"
Date: Wed, 25 Jan 2017 22:15:42 +0100
Subject: [PATCH 8/9] omniglot fixes

---
 .idea/vcs.xml | 6 ------
 1 file changed, 6 deletions(-)
 delete mode 100644 .idea/vcs.xml

diff --git a/.idea/vcs.xml b/.idea/vcs.xml
deleted file mode 100644
index 94a25f7f4cb..00000000000
--- a/.idea/vcs.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-
-
-
-
-
-
\ No newline at end of file

From b0cf26776c9bbbec1a8d8702b76a597d4deac200 Mon Sep 17 00:00:00 2001
From: "ludovic.denoyer"
Date: Wed, 25 Jan 2017 22:33:03 +0100
Subject: [PATCH 9/9] Omniglot is simplified. Add the FilenameToPILImage transformer

---
 README.rst                       |  4 ++--
 test/test_omniglot.py            | 13 +++++++++++++
 torchvision/datasets/omniglot.py | 18 +++++++-----------
 torchvision/transforms.py        |  7 +++++++
 4 files changed, 29 insertions(+), 13 deletions(-)
 create mode 100644 test/test_omniglot.py

diff --git a/README.rst b/README.rst
index 5934fd0cb77..9193e8f6827 100644
--- a/README.rst
+++ b/README.rst
@@ -192,8 +192,8 @@ OMNIGLOT
 
 `dset.OMNIGLOT(root_dir, [transform=None, target_transform=None])`
 
-The dataset is composed of pairs: `(Image,Category idx)`. Each category corresponds to one character in one alphabet. Matching between classes indexes and real classes can be accessed through: `dataset.idx_classes`
-
+The dataset is composed of pairs: ``(Filename, Category idx)``. Each category corresponds to one character in one alphabet. The mapping between class indexes and class names can be accessed through ``dataset.idx_classes``.
+The dataset can be used with ``transform=transforms.FilenameToPILImage()`` to obtain pairs of ``(PIL Image, Category idx)``.
 
 From: `Lake, B. M., Salakhutdinov, R., and Tenenbaum, J. B. (2015). Human-level concept learning through probabilistic program induction. Science, 350(6266), 1332-1338.`
 
diff --git a/test/test_omniglot.py b/test/test_omniglot.py
new file mode 100644
index 00000000000..82e2e6a4bf7
--- /dev/null
+++ b/test/test_omniglot.py
@@ -0,0 +1,13 @@
+import torch
+import torchvision.datasets as dset
+import torchvision.transforms as transforms
+
+print('Omniglot')
+a = dset.OMNIGLOT("../data", download=True,transform=transforms.Compose([transforms.FilenameToPILImage(),transforms.ToTensor()]))
+
+print(a.idx_classes)
+print(a[3])
+# print('\n\nCifar 100')
+# a = dset.CIFAR100(root="abc/def/ghi", download=True)
+
+# print(a[3])
diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py
index 145884b2a77..e6999cea832 100644
--- a/torchvision/datasets/omniglot.py
+++ b/torchvision/datasets/omniglot.py
@@ -21,19 +21,20 @@ class OMNIGLOT(data.Dataset):
     test_file = 'test.pt'
 
     '''
+    The items are (filename, category). The index of all the categories can be found in self.idx_classes
+
     Args:
 
     - root: the directory where the dataset will be stored
    - transform: how to transform the input
     - target_transform: how to transform the target
     - download: whether to download the dataset
-    - input_is_filename: if True, the returned data is (filename, target); otherwise it is a pair (PIL.Image, target)
     '''
-    def __init__(self, root, transform=None, target_transform=None, download=False,input_is_filename=False):
+    def __init__(self, root, transform=None, target_transform=None, download=False):
         self.root = root
         self.transform = transform
         self.target_transform = target_transform
-        self.input_is_filename=input_is_filename
+
         if download:
             self.download()
 
@@ -46,12 +47,7 @@ def __getitem__(self, index):
         filename=self.all_items[index][0]
-        path=str.join('/',[self.all_items[index][2],filename])
-
-        if (not self.input_is_filename):
-            img=Image.open(path).convert('RGB')
-        else:
-            img=path
+        img=str.join('/',[self.all_items[index][2],filename])
 
         target=self.idx_classes[self.all_items[index][1]]
         if self.transform is not None:
             img = self.transform(img)
@@ -107,7 +103,7 @@ def find_classes(root_dir):
                 r=root.split('/')
                 lr=len(r)
                 retour.append((f,r[lr-2]+"/"+r[lr-1],root))
-    print("Found %d items "%len(retour))
+    print("== Found %d items "%len(retour))
     return retour
 
 def index_classes(items):
     idx={}
     for i in items:
         if (not i[1] in idx):
             idx[i[1]]=len(idx)
-    print("Found %d classes"% len(idx))
+    print("== Found %d classes"% len(idx))
     return idx
\ No newline at end of file
diff --git a/torchvision/transforms.py b/torchvision/transforms.py
index 68ce23a1b1a..62af39903c9 100644
--- a/torchvision/transforms.py
+++ b/torchvision/transforms.py
@@ -27,6 +27,13 @@ def __call__(self, img):
             img = t(img)
         return img
 
+class FilenameToPILImage(object):
+    """
+    Load a PIL RGB Image from a filename.
+    """
+    def __call__(self,filename):
+        img=Image.open(filename).convert('RGB')
+        return img
 
 class ToTensor(object):
     """Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range
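A minimal usage sketch of the loader added by this series (not part of the patches themselves; it mirrors ``test/test_omniglot.py`` from the last commit, and the ``./data`` root directory is just a placeholder)::

    import torchvision.datasets as dset
    import torchvision.transforms as transforms

    # Each raw item is a (filename, category idx) pair; FilenameToPILImage opens the
    # file as an RGB PIL image and ToTensor converts it to a 3 x H x W float tensor.
    omniglot = dset.OMNIGLOT('./data', download=True,
                             transform=transforms.Compose([transforms.FilenameToPILImage(),
                                                           transforms.ToTensor()]))

    print(len(omniglot))              # number of character images found on disk
    img, target = omniglot[0]         # image tensor and integer class index
    print(img.size(), target)
    print(len(omniglot.idx_classes))  # one entry per (alphabet, character) class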