Skip to content

Commit

Permalink
Add description for Dataset objects (#384)
Browse files Browse the repository at this point in the history
* add __repr__ for datasets

* fix lint
  • Loading branch information
vishwakftw authored and soumith committed Jan 4, 2018
1 parent a8071d5 commit 7044049
Show file tree
Hide file tree
Showing 10 changed files with 103 additions and 1 deletion.
12 changes: 12 additions & 0 deletions torchvision/datasets/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,18 @@ def download(self):
tar.close()
os.chdir(cwd)

def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
tmp = 'train' if self.train is True else 'test'
fmt_str += ' Split: {}\n'.format(tmp)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str


class CIFAR100(CIFAR10):
"""`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
Expand Down
10 changes: 10 additions & 0 deletions torchvision/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,13 @@ def __getitem__(self, index):

def __len__(self):
return len(self.ids)

def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
9 changes: 9 additions & 0 deletions torchvision/datasets/fakedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,12 @@ def __getitem__(self, index):

def __len__(self):
return self.size

def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
10 changes: 10 additions & 0 deletions torchvision/datasets/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,13 @@ def __getitem__(self, index):

def __len__(self):
return len(self.imgs)

def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
9 changes: 8 additions & 1 deletion torchvision/datasets/lsun.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,11 @@ def __len__(self):
return self.length

def __repr__(self):
return self.__class__.__name__ + ' (' + self.db_path + ')'
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
12 changes: 12 additions & 0 deletions torchvision/datasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,18 @@ def download(self):

print('Done!')

def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
tmp = 'train' if self.train is True else 'test'
fmt_str += ' Split: {}\n'.format(tmp)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str


class FashionMNIST(MNIST):
"""`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.
Expand Down
10 changes: 10 additions & 0 deletions torchvision/datasets/phototour.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,16 @@ def download(self):
with open(self.data_file, 'wb') as f:
torch.save(dataset, f)

def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
tmp = 'train' if self.train is True else 'test'
fmt_str += ' Split: {}\n'.format(tmp)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str


def read_image_file(data_dir, image_ext, n):
"""Return a Tensor containing the patches
Expand Down
10 changes: 10 additions & 0 deletions torchvision/datasets/semeion.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,13 @@ def download(self):

root = self.root
download_url(self.url, root, self.filename, self.md5_checksum)

def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
11 changes: 11 additions & 0 deletions torchvision/datasets/stl10.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,14 @@ def __loadfile(self, data_file, labels_file=None):
images = np.transpose(images, (0, 1, 3, 2))

return images, labels

def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Split: {}\n'.format(self.split)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
11 changes: 11 additions & 0 deletions torchvision/datasets/svhn.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,14 @@ def _check_integrity(self):
def download(self):
md5 = self.split_list[self.split][2]
download_url(self.url, self.root, self.filename, md5)

def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Split: {}\n'.format(self.split)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str

0 comments on commit 7044049

Please sign in to comment.