# utils.py
import re
import string
import numpy as np
from typing import List, Dict, Union
import statistics
from collections import defaultdict
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from bert_score import score as bert_score
import nltk
from nltk.translate.meteor_score import meteor_score
from sentence_transformers import SentenceTransformer
import logging
from dataclasses import dataclass
from pathlib import Path
from openai import OpenAI
from load_dataset import load_locomo_dataset, QA, Turn, Session, Conversation
from sentence_transformers.util import pytorch_cos_sim

# Download required NLTK data
try:
    nltk.download('punkt', quiet=True)
    nltk.download('wordnet', quiet=True)
except Exception as e:
    print(f"Error downloading NLTK data: {e}")

# Initialize SentenceTransformer model (this will be reused)
try:
    sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
except Exception as e:
    print(f"Warning: Could not load SentenceTransformer model: {e}")
    sentence_model = None


def simple_tokenize(text):
    """Simple tokenization: lowercase, strip common punctuation, split on whitespace."""
    # Convert to string if not already
    text = str(text)
    return text.lower().replace('.', ' ').replace(',', ' ').replace('!', ' ').replace('?', ' ').split()
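
# Example (illustrative): the tokenizer lowercases, replaces the common
# sentence punctuation above with spaces, and splits on whitespace, so
#   simple_tokenize("Hello, World!")  ->  ['hello', 'world']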


def calculate_rouge_scores(prediction: str, reference: str) -> Dict[str, float]:
    """Calculate ROUGE scores for prediction against reference."""
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    scores = scorer.score(reference, prediction)
    return {
        'rouge1_f': scores['rouge1'].fmeasure,
        'rouge2_f': scores['rouge2'].fmeasure,
        'rougeL_f': scores['rougeL'].fmeasure
    }


def calculate_bleu_scores(prediction: str, reference: str) -> Dict[str, float]:
    """Calculate BLEU-1 through BLEU-4 with method1 smoothing."""
    pred_tokens = nltk.word_tokenize(prediction.lower())
    ref_tokens = [nltk.word_tokenize(reference.lower())]
    weights_list = [
        (1, 0, 0, 0),
        (0.5, 0.5, 0, 0),
        (1 / 3, 1 / 3, 1 / 3, 0),
        (0.25, 0.25, 0.25, 0.25),
    ]
    smooth = SmoothingFunction().method1
    scores = {}
    for n, weights in enumerate(weights_list, start=1):
        try:
            score = sentence_bleu(ref_tokens, pred_tokens, weights=weights, smoothing_function=smooth)
        except Exception:
            score = 0.0
        scores[f'bleu{n}'] = score
    return scores


def calculate_bert_scores(prediction: str, reference: str) -> Dict[str, float]:
    """Calculate BERTScore for semantic similarity."""
    try:
        P, R, F1 = bert_score([prediction], [reference], lang='en', verbose=False)
        return {
            'bert_precision': P.item(),
            'bert_recall': R.item(),
            'bert_f1': F1.item()
        }
    except Exception as e:
        print(f"Error calculating BERTScore: {e}")
        return {
            'bert_precision': 0.0,
            'bert_recall': 0.0,
            'bert_f1': 0.0
        }


def calculate_meteor_score(prediction: str, reference: str) -> float:
    """Calculate METEOR score for the prediction."""
    try:
        return meteor_score([reference.split()], prediction.split())
    except Exception as e:
        print(f"Error calculating METEOR score: {e}")
        return 0.0


def calculate_sentence_similarity(prediction: str, reference: str) -> float:
    """Calculate sentence embedding similarity using SentenceBERT."""
    if sentence_model is None:
        return 0.0
    try:
        # Encode both sentences as embeddings
        embedding1 = sentence_model.encode([prediction], convert_to_tensor=True)
        embedding2 = sentence_model.encode([reference], convert_to_tensor=True)
        # Calculate cosine similarity between the embeddings
        similarity = pytorch_cos_sim(embedding1, embedding2).item()
        return float(similarity)
    except Exception as e:
        print(f"Error calculating sentence similarity: {e}")
        return 0.0


def calculate_metrics(prediction: str, reference: str) -> Dict[str, float]:
    """Calculate comprehensive evaluation metrics for a prediction."""
    # Handle empty or None values; keep the key set identical to the
    # non-empty path so downstream aggregation sees a consistent schema.
    if not prediction or not reference:
        return {
            "exact_match": 0,
            "f1": 0.0,
            "rouge1_f": 0.0,
            "rouge2_f": 0.0,
            "rougeL_f": 0.0,
            "bleu1": 0.0,
            "bleu2": 0.0,
            "bleu3": 0.0,
            "bleu4": 0.0,
            "bert_precision": 0.0,
            "bert_recall": 0.0,
            "bert_f1": 0.0,
            "meteor": 0.0,
            "sbert_similarity": 0.0
        }
    # Convert to strings if they're not already
    prediction = str(prediction).strip()
    reference = str(reference).strip()
    # Calculate exact match (case-insensitive)
    exact_match = int(prediction.lower() == reference.lower())
    # Calculate token-based F1 score
    pred_tokens = set(simple_tokenize(prediction))
    ref_tokens = set(simple_tokenize(reference))
    common_tokens = pred_tokens & ref_tokens
    if not pred_tokens or not ref_tokens:
        f1 = 0.0
    else:
        precision = len(common_tokens) / len(pred_tokens)
        recall = len(common_tokens) / len(ref_tokens)
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    # Calculate all scores
    rouge_scores = calculate_rouge_scores(prediction, reference)
    bleu_scores = calculate_bleu_scores(prediction, reference)
    bert_scores = calculate_bert_scores(prediction, reference)
    meteor = calculate_meteor_score(prediction, reference)
    sbert_similarity = calculate_sentence_similarity(prediction, reference)
    # Combine all metrics
    metrics = {
        "exact_match": exact_match,
        "f1": f1,
        **rouge_scores,
        **bleu_scores,
        **bert_scores,
        "meteor": meteor,
        "sbert_similarity": sbert_similarity
    }
    return metrics
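
# Example (illustrative): for two identical multi-word strings, the lexical
# metrics saturate; e.g. calculate_metrics("the cat sat on the mat",
# "the cat sat on the mat") yields exact_match=1, f1=1.0, and ROUGE/BLEU
# scores of 1.0, while the embedding-based scores are at or very near 1.0.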


def aggregate_metrics(all_metrics: List[Dict[str, float]], all_categories: List[int]) -> Dict[str, Dict[str, Union[float, Dict[str, float]]]]:
    """Calculate aggregate statistics for all metrics, split by category."""
    if not all_metrics:
        return {}
    # Initialize aggregates for overall and per-category metrics
    aggregates = defaultdict(list)
    category_aggregates = defaultdict(lambda: defaultdict(list))
    # Collect all values for each metric, both overall and per category
    for metrics, category in zip(all_metrics, all_categories):
        for metric_name, value in metrics.items():
            aggregates[metric_name].append(value)
            category_aggregates[category][metric_name].append(value)
    # Calculate statistics for overall metrics
    results = {
        "overall": {}
    }
    for metric_name, values in aggregates.items():
        results["overall"][metric_name] = {
            'mean': statistics.mean(values),
            'std': statistics.stdev(values) if len(values) > 1 else 0.0,
            'median': statistics.median(values),
            'min': min(values),
            'max': max(values),
            'count': len(values)
        }
    # Calculate statistics for each category
    for category in sorted(category_aggregates.keys()):
        results[f"category_{category}"] = {}
        for metric_name, values in category_aggregates[category].items():
            if values:  # Only calculate if we have values for this category
                results[f"category_{category}"][metric_name] = {
                    'mean': statistics.mean(values),
                    'std': statistics.stdev(values) if len(values) > 1 else 0.0,
                    'median': statistics.median(values),
                    'min': min(values),
                    'max': max(values),
                    'count': len(values)
                }
    return results
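

# --- Minimal usage sketch (not part of the original module) ---
# Illustrates the intended flow: score each (prediction, reference) pair
# with calculate_metrics, then summarize per category and overall with
# aggregate_metrics. The sample strings and category ids below are made
# up for illustration; running this will download the BERTScore and
# SentenceTransformer models on first use.
if __name__ == "__main__":
    samples = [
        ("Melanie went hiking last weekend.", "Melanie went hiking last weekend.", 1),
        ("He adopted a cat.", "He adopted a dog named Max.", 2),
    ]
    per_example = [calculate_metrics(pred, ref) for pred, ref, _ in samples]
    categories = [cat for _, _, cat in samples]
    summary = aggregate_metrics(per_example, categories)
    # Print the overall mean of each metric
    for name, stats in summary["overall"].items():
        print(f"{name}: mean={stats['mean']:.3f} (n={stats['count']})")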