-
Notifications
You must be signed in to change notification settings - Fork 23
/
Copy pathload_dataset.py
261 lines (224 loc) · 8.85 KB
/
load_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
import json
from typing import Dict, List, Optional, Union
import os
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
@dataclass
class QA:
    """One question/answer pair together with the evidence ids that back it.

    ``category`` and ``adversarial_answer`` are optional and default to
    ``None``; ``final_answer`` picks which stored answer applies.
    """

    question: str
    answer: Optional[str]
    evidence: List[str]
    category: Optional[int] = None
    adversarial_answer: Optional[str] = None

    @property
    def final_answer(self) -> Optional[str]:
        """Return ``adversarial_answer`` for category 5, otherwise ``answer``."""
        is_adversarial = self.category == 5
        return self.adversarial_answer if is_adversarial else self.answer
@dataclass
class Turn:
    """A single utterance within a session."""

    speaker: str  # name of whoever produced the turn
    dia_id: str  # dialogue id of the turn
    text: str  # turn text
@dataclass
class Session:
    """The ordered turns of one session plus its id and date/time string."""

    session_id: int  # numeric id of the session
    date_time: str  # date/time string stored alongside the session
    turns: List[Turn]  # parsed turns of the session
@dataclass
class Conversation:
    """A two-speaker conversation split into sessions keyed by session id."""

    speaker_a: str  # first speaker's name
    speaker_b: str  # second speaker's name
    sessions: Dict[int, Session]  # session id -> parsed session
@dataclass
class EventSummary:
    """Wrapper around the raw per-session, per-speaker event lists of a sample."""

    # session -> speaker -> events
    events: Dict[str, Dict[str, List[str]]]
@dataclass
class Observation:
    """Wrapper around the raw per-session, per-speaker observations of a sample."""

    # session -> speaker -> [observation, evidence]
    observations: Dict[str, Dict[str, List[List[str]]]]
@dataclass
class LoCoMoSample:
    """One LoComo dataset sample: its QA pairs, conversation and summaries."""

    sample_id: str  # identifier of the sample
    qa: List[QA]  # question/answer pairs attached to the sample
    conversation: Conversation  # parsed conversation
    event_summary: EventSummary  # raw event summary
    observation: Observation  # raw observations
    session_summary: Dict[str, str]  # raw session summaries
def parse_session(session_data: List[dict], session_id: int, date_time: str) -> Session:
    """Build a ``Session`` from a list of raw turn dicts.

    A turn carrying both ``img_url`` and ``blip_caption`` gets
    ``[Image: <caption>]`` placed in front of its text, or used as the whole
    text when the turn has none. Every turn must provide ``speaker`` and
    ``dia_id``.
    """
    parsed_turns = []
    for raw_turn in session_data:
        body = raw_turn.get("text", "")
        has_captioned_image = "img_url" in raw_turn and "blip_caption" in raw_turn
        if has_captioned_image:
            image_tag = f"[Image: {raw_turn['blip_caption']}]"
            body = f"{image_tag} {body}" if body else image_tag
        parsed_turns.append(
            Turn(speaker=raw_turn["speaker"], dia_id=raw_turn["dia_id"], text=body)
        )
    return Session(session_id=session_id, date_time=date_time, turns=parsed_turns)
def parse_conversation(conv_data: dict) -> Conversation:
    """Build a ``Conversation`` from the raw conversation dict.

    Every list-valued ``session_<n>`` entry is parsed as a session. A session
    is kept only when a non-empty ``session_<n>_date_time`` entry exists and
    the parsed session has at least one turn.
    """
    parsed_sessions = {}
    for key, value in conv_data.items():
        # Skip everything that is not a list-valued "session_<n>" entry.
        if not (key.startswith("session_") and isinstance(value, list)):
            continue
        session_id = int(key.split("_")[1])
        date_time = conv_data.get(f"{key}_date_time")
        if not date_time:
            continue
        session = parse_session(value, session_id, date_time)
        # Sessions without any turns are left out.
        if session.turns:
            parsed_sessions[session_id] = session
    return Conversation(
        speaker_a=conv_data["speaker_a"],
        speaker_b=conv_data["speaker_b"],
        sessions=parsed_sessions,
    )
def _image_turn_ids(conversation_data: dict) -> set:
    """Collect the ``dia_id`` of every raw turn that has an image URL or caption.

    Only list-valued entries of ``conversation_data`` are scanned; all other
    entries are ignored.
    """
    return {
        turn["dia_id"]
        for session in conversation_data.values()
        if isinstance(session, list)
        for turn in session
        if "dia_id" in turn and ("img_url" in turn or "blip_caption" in turn)
    }


def load_locomo_dataset(file_path: Union[str, Path]) -> "List[LoCoMoSample]":
    """
    Load the LoComo dataset from a JSON file, including image-based content by using captions.

    Per-sample and overall QA statistics are printed to stdout while loading.

    Args:
        file_path: Path to the JSON file containing the dataset

    Returns:
        List of LoCoMoSample objects containing the parsed data. An empty
        dataset yields an empty list.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        KeyError: If a sample or QA entry lacks a required key; the failing
            sample / QA pair is printed before the error is re-raised.
    """
    file_path = Path(file_path)
    if not file_path.exists():
        raise FileNotFoundError(f"Dataset file not found at {file_path}")

    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    samples = []
    total_image_qa = 0
    qa_counts_per_sample = []

    for sample_idx, sample in enumerate(data):
        try:
            raw_qas = sample["qa"]
            # Index the image-bearing turns once per sample so each evidence
            # id is a set lookup instead of a rescan of every turn.
            image_turn_ids = _image_turn_ids(sample["conversation"])

            # Parse QA data
            qa_list = []
            sample_image_qa_count = 0
            for qa_idx, qa in enumerate(raw_qas):
                try:
                    evidence = qa.get("evidence", [])
                    # Compare the full evidence id with the turn's dia_id.
                    # Matching only on the trailing turn number would also
                    # hit turns such as "D1:13" or "D2:3" for evidence "D1:3".
                    if any(evidence_id in image_turn_ids for evidence_id in evidence):
                        sample_image_qa_count += 1

                    qa_list.append(QA(
                        question=qa["question"],
                        answer=qa.get("answer"),
                        evidence=evidence,
                        category=qa.get("category"),
                        adversarial_answer=qa.get("adversarial_answer"),
                    ))
                except KeyError:
                    print(f"Error in sample {sample_idx}, QA pair {qa_idx}:")
                    print(f"QA data: {qa}")
                    raise
                except Exception:
                    print(f"Unexpected error in sample {sample_idx}, QA pair {qa_idx}:")
                    print(f"QA data: {qa}")
                    raise

            sample_qa_count = len(qa_list)

            # Create sample object
            sample_obj = LoCoMoSample(
                sample_id=str(sample_idx),
                qa=qa_list,
                conversation=parse_conversation(sample["conversation"]),
                event_summary=EventSummary(events=sample["event_summary"]),
                observation=Observation(observations=sample["observation"]),
                session_summary=sample.get("session_summary", {}),
            )
            samples.append(sample_obj)

            total_image_qa += sample_image_qa_count
            qa_counts_per_sample.append(sample_qa_count)

            # Print statistics for this sample
            print(f"\nSample {sample_idx}:")
            print(f" Total QAs: {sample_qa_count}")
            print(f" QAs with image evidence: {sample_image_qa_count}")
        except Exception as e:
            print(f"Error processing sample {sample_idx}:")
            print(str(e))
            raise

    # Print overall statistics
    total_qa = sum(qa_counts_per_sample)
    print("\nOverall Statistics:")
    print(f"Total QAs: {total_qa}")
    print(f"Total QAs with image evidence: {total_image_qa}")
    # Average/min/max are undefined for an empty dataset; skip them rather
    # than fail with ZeroDivisionError / ValueError.
    if samples:
        print(f"Average QAs per sample: {total_qa / len(samples):.2f}")
        print(f"Min QAs in a sample: {min(qa_counts_per_sample)}")
        print(f"Max QAs in a sample: {max(qa_counts_per_sample)}")

    return samples
def get_dataset_statistics(samples: List[LoCoMoSample]) -> Dict:
    """
    Compute basic counts over a list of loaded samples.

    Args:
        samples: List of LoCoMoSample objects

    Returns:
        Dictionary with the keys ``num_samples``, ``total_qa_pairs``,
        ``total_sessions``, ``total_turns`` and ``qa_with_adversarial``
        (QA pairs whose ``adversarial_answer`` is not ``None``)
    """
    num_samples = len(samples)
    qa_pairs = 0
    session_count = 0
    turn_count = 0
    adversarial_count = 0

    for sample in samples:
        qa_pairs += len(sample.qa)
        sessions = sample.conversation.sessions
        session_count += len(sessions)
        for session in sessions.values():
            turn_count += len(session.turns)
        for qa in sample.qa:
            if qa.adversarial_answer is not None:
                adversarial_count += 1

    return {
        "num_samples": num_samples,
        "total_qa_pairs": qa_pairs,
        "total_sessions": session_count,
        "total_turns": turn_count,
        "qa_with_adversarial": adversarial_count,
    }
# Script entry point: loads data/locomo10.json from the directory containing this
# file and prints parsed turns as an example of using the loader.
# NOTE(review): indentation was lost in this copy of the file, so the nesting of the
# statements below -- in particular which loop the `break` exits -- cannot be
# determined from here. Restore it from the original source before running.
if __name__ == "__main__":
# Example usage
dataset_path = Path(__file__).parent / "data" / "locomo10.json"
try:
print(f"Loading dataset from: {dataset_path}")
samples = load_locomo_dataset(dataset_path)
for sample_idx, sample in enumerate(samples):
print(f"\nSample {sample_idx}:")
# Despite its name, `turns` is a Session object (a value of the sessions dict);
# the individual Turn objects are in `turns.turns`.
for _,turns in sample.conversation.sessions.items():
for turn in turns.turns:
print(turn)
break
# Disabled alternative demos (dataset statistics, dumping whole samples):
# stats = get_dataset_statistics(samples)
# print("\nDataset Statistics (Text-only content):")
# for key, value in stats.items():
# print(f"{key}: {value}")
# print(len(samples))
# for sample in samples:
# print(sample)
# break
except Exception as e:
print(f"Error loading dataset: {e}")
# Re-raise after printing so the failure still surfaces with its traceback.
raise