-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmake_dataset.py
114 lines (83 loc) · 3.94 KB
/
make_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import torchvision.transforms as trn
import torchvision.datasets as dset
import svhn_loader as svhn
# --- Dataset root paths (machine-specific configuration) ---
# *** update this before running on your machine ***
# Alternate machine layout, kept for reference:
# cifar10_path = '/nobackup/my_xfdu/cifarpy'
# cifar100_path = '/nobackup/my_xfdu/cifar-100-python'
# svhn_path = '/nobackup/my_xfdu/svhn/'
# lsun_c_path = '/nobackup/my_xfdu/LSUN_C'
# lsun_r_path = '/nobackup/my_xfdu/LSUN_resize'
# isun_path = '/nobackup/my_xfdu/iSUN'
# dtd_path = '/nobackup/my_xfdu/dtd/images'
# places_path = '/nobackup/my_xfdu/places365/'
# tinyimages_300k_path = '/nobackup/my_xfdu/300K_random_images.npy'
# In-distribution datasets (note: both CIFAR variants share one root here).
cifar10_path = '/nobackup-slow/dataset/my_xfdu/cifarpy/'
cifar100_path = '/nobackup-slow/dataset/my_xfdu/cifarpy/'
svhn_path = '/nobackup-slow/dataset/svhn/'
# Out-of-distribution test sets, consumed by load_ood_dataset below.
lsun_c_path = '/nobackup-slow/dataset/LSUN_C'
lsun_r_path = '/nobackup-slow/dataset/LSUN_resize'
isun_path = '/nobackup-slow/dataset/iSUN'
dtd_path = '/nobackup-slow/dataset/dtd/images'
places_path = '/nobackup-slow/dataset/my_xfdu/places365/'
# Path to a .npy archive of 300K random auxiliary images (not used in this chunk).
tinyimages_300k_path = '/nobackup-slow/dataset/my_xfdu/300K_random_images.npy'
def load_CIFAR(dataset, classes=None):
    """Load the CIFAR-10 or CIFAR-100 train/test splits, normalized.

    Parameters
    ----------
    dataset : str
        Either ``'cifar10'`` or ``'cifar100'``.
    classes : optional
        Unused; kept for interface compatibility with existing callers.
        (Default changed from a mutable ``[]`` to ``None`` — the mutable
        default was a latent bug; since the value is never read, behavior
        is unchanged.)

    Returns
    -------
    tuple
        ``(train_data, test_data)`` torchvision datasets.

    Raises
    ------
    ValueError
        If ``dataset`` is not a recognized name (the original code would
        instead fail later with an ``UnboundLocalError``).
    """
    # Per-channel CIFAR mean/std, rescaled from [0, 255] to [0, 1].
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]
    # NOTE: training augmentation (random flip / padded crop) is intentionally
    # disabled here — both splits use the plain normalize-only transform.
    train_transform = trn.Compose([trn.ToTensor(), trn.Normalize(mean, std)])
    test_transform = trn.Compose([trn.ToTensor(), trn.Normalize(mean, std)])

    if dataset == 'cifar10':
        print('loading CIFAR-10')
        train_data = dset.CIFAR10(
            cifar10_path, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR10(
            cifar10_path, train=False, transform=test_transform, download=True)
    elif dataset == 'cifar100':
        print('loading CIFAR-100')
        train_data = dset.CIFAR100(
            cifar100_path, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR100(
            cifar100_path, train=False, transform=test_transform, download=True)
    else:
        raise ValueError(f"unknown dataset: {dataset!r}")
    return train_data, test_data
def load_SVHN(transform, include_extra=False):
    """Load the SVHN train and test splits via the local svhn_loader module.

    Parameters:
        transform: transform applied to every image.
        include_extra: if True, train on the combined train+extra split.

    Returns:
        (train_data, test_data) with integer (int64) targets.
    """
    print('loading SVHN')
    # Pick the training split name up front instead of branching twice.
    train_split = "train_and_extra" if include_extra else "train"
    train_data = svhn.SVHN(root=svhn_path, split=train_split,
                           transform=transform)
    test_data = svhn.SVHN(root=svhn_path, split="test",
                          transform=transform)
    # Loss functions expect integer class labels, so cast both target arrays.
    for split_data in (train_data, test_data):
        split_data.targets = split_data.targets.astype('int64')
    return train_data, test_data
def load_ood_dataset(dataset, transform):
    """Load an out-of-distribution (OOD) test set by name.

    Parameters
    ----------
    dataset : str
        One of ``'lsun_c'``, ``'lsun_r'``, ``'isun'``, ``'dtd'``, ``'places'``.
    transform : callable
        Transform applied to every image.

    Returns
    -------
    torch.utils.data.Dataset
        The requested image-folder dataset. For ``'places'`` a random
        10,000-image subset is returned.

    Raises
    ------
    ValueError
        If ``dataset`` is not a recognized name (the original code would
        instead fail with an ``UnboundLocalError``).
    """
    import numpy as np
    import torch

    if dataset == 'lsun_c':
        print('loading LSUN_C')
        out_data = dset.ImageFolder(root=lsun_c_path,
                                    transform=transform)
    elif dataset == 'lsun_r':
        print('loading LSUN_R')
        out_data = dset.ImageFolder(root=lsun_r_path,
                                    transform=transform)
    elif dataset == 'isun':
        print('loading iSUN')
        out_data = dset.ImageFolder(root=isun_path,
                                    transform=transform)
    elif dataset == 'dtd':
        print('loading DTD')
        out_data = dset.ImageFolder(root=dtd_path,
                                    transform=transform)
    elif dataset == 'places':
        print('loading Places365')
        out_data = dset.ImageFolder(root=places_path,
                                    transform=transform)
        # Subsample 10k DISTINCT images. The original call sampled with
        # replacement (np.random.choice default), so the subset contained
        # duplicate indices; replace=False fixes that. The follow-up
        # `idx = idx[rng]` in the original was a no-op re-index and is dropped.
        # NOTE(review): selection is unseeded — seed numpy upstream if a
        # reproducible subset is required.
        idx = np.random.choice(len(out_data), 10000, replace=False)
        out_data = torch.utils.data.Subset(out_data, idx)
    else:
        raise ValueError(f"unknown OOD dataset: {dataset!r}")
    return out_data