-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdataloaders.py
102 lines (87 loc) · 3.83 KB
/
dataloaders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from torch.utils.data import random_split, DataLoader, Dataset
import torch
from torchvision.datasets import ImageFolder
from torchvision import transforms
from pytorch_lightning import LightningDataModule
from torchvision.transforms import ToTensor
import sys
def get_dataloader(args):
    """Return the Lightning data module selected by ``args.dataset``.

    Args:
        args: Namespace-like object with ``dataset`` (one of ``'CLIC2020'``,
            ``'CLIC2021'``, ``'Kodak'``, ``'DIV2K'``) and ``data_root`` (the
            directory that contains the dataset folders).

    Returns:
        A ``CLIC``, ``Kodak`` or ``DIV2K`` data module rooted under
        ``args.data_root``.

    Raises:
        SystemExit: with a non-zero status if ``args.dataset`` is not recognised.
    """
    # Lambdas defer both the class lookup and the directory scan performed by
    # the data-module constructors until the selected entry is actually called.
    factories = {
        'CLIC2020': lambda root: CLIC(root=f'{root}/CLIC/2020'),
        'CLIC2021': lambda root: CLIC(root=f'{root}/CLIC/2021'),
        'Kodak': lambda root: Kodak(root=f'{root}/Kodak'),
        'DIV2K': lambda root: DIV2K(root=f'{root}/DIV2K'),
    }
    factory = factories.get(args.dataset)
    if factory is None:
        # Exit non-zero so shells and job schedulers see a failure;
        # sys.exit(str) writes the message to stderr and exits with status 1.
        sys.exit(
            f"Invalid dataset: {args.dataset!r} "
            f"(expected one of {', '.join(factories)})"
        )
    return factory(args.data_root)
class CLIC(LightningDataModule):
    """Lightning data module for the CLIC image dataset.

    ``root`` must contain ``train``, ``valid`` and ``test`` sub-directories,
    each laid out as a torchvision ``ImageFolder``. Images are converted to
    tensors with ``ToTensor`` and no other transform is applied.

    Args:
        root: Dataset root directory.
        batch_size: Batch size used by every loader.
        num_workers: Worker processes per loader; 0 loads in the main process.
    """

    def __init__(self, root, batch_size=1, num_workers=4):
        super().__init__()
        self.root = root
        self.batch_size = batch_size
        self.num_workers = num_workers
        transform = transforms.Compose([transforms.ToTensor()])
        self.train_dset = ImageFolder(root=self.root + '/train', transform=transform)
        self.val_dset = ImageFolder(root=self.root + '/valid', transform=transform)
        self.test_dset = ImageFolder(root=self.root + '/test', transform=transform)

    def _make_loader(self, dataset, *, train=False):
        """Build a DataLoader; only training shuffles and drops a partial last batch."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=train,
            num_workers=self.num_workers,
            pin_memory=True,
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=self.num_workers > 0,
            # Dropping the final partial batch is harmless for training but
            # would silently skip images during validation/testing.
            drop_last=train,
        )

    def train_dataloader(self):
        """Shuffled training loader."""
        return self._make_loader(self.train_dset, train=True)

    def val_dataloader(self):
        """Validation loader over ``root/valid`` that visits every image in order."""
        return self._make_loader(self.val_dset)

    def test_dataloader(self):
        """Test loader over ``root/test`` that visits every image in order."""
        return self._make_loader(self.test_dset)
class DIV2K(LightningDataModule):
    """Lightning data module for the DIV2K image dataset.

    ``root`` must contain ``train`` and ``val`` sub-directories, each laid out
    as a torchvision ``ImageFolder``. The ``val`` split is exposed through
    ``test_dataloader``; no validation loader is defined. Images are converted
    to tensors with ``ToTensor`` and no other transform is applied.

    Args:
        root: Dataset root directory.
        batch_size: Batch size used by every loader.
        num_workers: Worker processes per loader; 0 loads in the main process.
    """

    def __init__(self, root, batch_size=1, num_workers=4):
        super().__init__()
        self.root = root
        self.batch_size = batch_size
        self.num_workers = num_workers
        transform = transforms.Compose([transforms.ToTensor()])
        self.train_dset = ImageFolder(root=self.root + '/train', transform=transform)
        # The DIV2K 'val' split doubles as the test set.
        self.test_dset = ImageFolder(root=self.root + '/val', transform=transform)

    def _make_loader(self, dataset, *, train=False):
        """Build a DataLoader; only training shuffles and drops a partial last batch."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=train,
            num_workers=self.num_workers,
            pin_memory=True,
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=self.num_workers > 0,
            # Dropping the final partial batch is harmless for training but
            # would silently skip images during testing.
            drop_last=train,
        )

    def train_dataloader(self):
        """Shuffled training loader over ``root/train``."""
        return self._make_loader(self.train_dset, train=True)

    def test_dataloader(self):
        """Test loader over ``root/val`` that visits every image in order."""
        return self._make_loader(self.test_dset)
class Kodak(LightningDataModule):
    """Evaluation-only Lightning data module for the Kodak image set.

    ``root`` itself is treated as a torchvision ``ImageFolder`` root. Images
    are converted to tensors with ``ToTensor`` and no other transform is
    applied. Only ``test_dataloader`` is provided.

    Args:
        root: Dataset root directory.
        batch_size: Batch size used by the test loader.
        num_workers: Worker processes for the loader; 0 loads in the main process.
    """

    def __init__(self, root, batch_size=1, num_workers=4):
        super().__init__()
        self.root = root
        self.batch_size = batch_size
        self.num_workers = num_workers
        transform = transforms.Compose([transforms.ToTensor()])
        self.test_dset = ImageFolder(root=self.root, transform=transform)

    def test_dataloader(self):
        """Test loader that visits every image in order.

        ``drop_last`` is off so that a ``batch_size`` which does not divide the
        number of images does not silently skip the remainder.
        """
        return DataLoader(
            self.test_dset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=self.num_workers > 0,
            drop_last=False,
        )
# class ImageDataset(Dataset):
# def __init__(self, image_paths, transform=None):
# self.image_paths = image_paths
# self.transform = transform
# def get_class_label(self, image_name):
# # your method here
# y = ...
# return y
# def __getitem__(self, index):
# image_path = self.image_paths[index]
# x = Image.open(image_path)
# y = self.get_class_label(image_path.split('/')[-1])
# if self.transform is not None:
# x = self.transform(x)
# return x, y
# def __len__(self):
# return len(self.image_paths)