-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathdataset.py
121 lines (112 loc) · 4.67 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import datasets
import os
from PIL import Image
import json
import jsonlines
class ImagesConfig(datasets.BuilderConfig):
    """Builder configuration for the image datasets defined in this script.

    Adds no options of its own; every keyword argument is forwarded
    unchanged to ``datasets.BuilderConfig``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
class Images(datasets.GeneratorBasedBuilder):
    """Dataset builder that yields image pairs described by ``meta_data.jsonl``.

    The metadata file is read from ``self.config.data_dir``. Each record is
    read for two keys: ``image_path`` (the reference image) and
    ``similar_images`` (iterated as ``(target_path, similarity)`` pairs).
    All image paths are resolved relative to ``data_dir``.

    Configurations:
        similar_pairs: reference/target images, their paths and similarity.
        reference_only_for_dwpose: reference/target images plus a
            "blueprint" image whose path is the target path with every
            occurrence of ``"data"`` replaced by ``"dwpose"``.
        image_prompt_pairs: a schema is declared, but no loader is
            implemented here, so this configuration generates no examples.
    """

    BUILDER_CONFIGS = [
        ImagesConfig(
            name="similar_pairs",
            description="simliar pair dataset,item is a pair of similar images",
        ),
        ImagesConfig(
            name="image_prompt_pairs",
            description="image prompt pairs",
        ),
        ImagesConfig(
            name="reference_only_for_dwpose",
        ),
    ]

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _info(self):
        """Return the ``DatasetInfo`` (feature schema) for the active config.

        Returns ``None`` implicitly for a config name not listed in
        ``BUILDER_CONFIGS``.
        """
        if self.config.name == "similar_pairs":
            return datasets.DatasetInfo(
                features=datasets.Features(
                    {
                        "reference_image": datasets.features.Image(),
                        "reference_image_path": datasets.Value("string"),
                        "target_image": datasets.features.Image(),
                        "target_image_path": datasets.Value("string"),
                        "similarity": datasets.Value("float32"),
                    }
                )
            )
        elif self.config.name == "image_prompt_pairs":
            return datasets.DatasetInfo(
                features=datasets.Features(
                    {
                        "image": datasets.features.Image(),
                        "image_path": datasets.features.Value("string"),
                        "prompt": datasets.Value("string"),
                    }
                )
            )
        elif self.config.name == "reference_only_for_dwpose":
            return datasets.DatasetInfo(
                features=datasets.Features(
                    {
                        "reference_image": datasets.features.Image(),
                        "target_image": datasets.features.Image(),
                        "blueprint_image": datasets.features.Image(),
                    }
                )
            )

    def _split_generators(self, dl_manager: datasets.DownloadManager):
        """Collect the usable image pairs and expose them as one TRAIN split.

        Args:
            dl_manager: Supplied by ``datasets``; unused because all files
                are read directly from ``self.config.data_dir``.

        Returns:
            A single-element list holding the TRAIN ``SplitGenerator``. Its
            ``data`` kwarg is a list of
            ``(reference_path, target_path, similarity)`` tuples.
        """
        data_dir = self.config.data_dir
        data = []
        with jsonlines.open(os.path.join(data_dir, "meta_data.jsonl"), "r") as meta_data:
            if self.config.name in ("similar_pairs", "reference_only_for_dwpose"):
                for obj in meta_data:
                    reference_image_path = obj["image_path"]
                    if not os.path.exists(os.path.join(data_dir, reference_image_path)):
                        print(reference_image_path + " not exists")
                        # Without the reference image none of this record's
                        # pairs can be opened later, so skip the record.
                        continue
                    for target_image, similarity in obj["similar_images"]:
                        if not os.path.exists(os.path.join(data_dir, target_image)):
                            print(target_image + " not exists")
                            continue
                        # Keep the similarity: the "similar_pairs" schema
                        # declares it and _generate_examples emits it.
                        data.append((reference_image_path, target_image, similarity))
        print("data size:", len(data))
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={"split": datasets.Split.TRAIN, "data": data},
            )
        ]

    def _generate_examples(self, split, data):
        """Yield ``(key, example)`` pairs for the active configuration.

        Args:
            split: Split identifier passed through ``gen_kwargs``; unused.
            data: Sequence of tuples produced by ``_split_generators``.
                ``similar_pairs`` needs
                ``(reference_path, target_path, similarity)``;
                ``reference_only_for_dwpose`` only reads the first two items.

        Yields:
            ``"<reference_path>:<target_path>"`` as the key, and a dict whose
            keys match the features declared by ``_info``.
        """
        data_dir = self.config.data_dir
        if self.config.name == "similar_pairs":
            for reference_path, target_path, similarity in data:
                # Keys must match the feature names declared in _info.
                yield reference_path + ":" + target_path, {
                    "reference_image": Image.open(
                        os.path.join(data_dir, reference_path)
                    ),
                    "reference_image_path": reference_path,
                    "target_image": Image.open(
                        os.path.join(data_dir, target_path)
                    ),
                    "target_image_path": target_path,
                    "similarity": similarity,
                }
        elif self.config.name == "reference_only_for_dwpose":
            for pair in data:
                # Only the two paths are needed; any trailing similarity
                # value is ignored, so 2-tuples are accepted as well.
                reference_path, target_path = pair[:2]
                yield reference_path + ":" + target_path, {
                    "reference_image": Image.open(
                        os.path.join(data_dir, reference_path)
                    ),
                    "target_image": Image.open(
                        os.path.join(data_dir, target_path)
                    ),
                    # NOTE(review): replaces EVERY "data" substring in the
                    # relative path; presumably the dwpose renderings live in
                    # a parallel directory tree - confirm the layout.
                    "blueprint_image": Image.open(
                        os.path.join(data_dir, target_path.replace("data", "dwpose"))
                    ),
                }