From 07bd3c50ead83ec934fcc9eaf425b7dca44add79 Mon Sep 17 00:00:00 2001 From: Timo Moeller Date: Thu, 12 Aug 2021 14:31:48 +0200 Subject: [PATCH] Add new QA eval metric: Semantic Answer Similarity (SAS) (#1338) * init * Add type annotation * Add test case, fix mypy * Add german model to docstring Co-authored-by: Malte Pietsch --- haystack/eval.py | 286 ++++++++++++------------------ test/test_eval.py | 9 +- tutorials/Tutorial5_Evaluation.py | 2 +- 3 files changed, 127 insertions(+), 170 deletions(-) diff --git a/haystack/eval.py b/haystack/eval.py index 42dc65684e..7fe9e3dedb 100644 --- a/haystack/eval.py +++ b/haystack/eval.py @@ -1,5 +1,9 @@ from typing import List, Tuple, Dict, Any, Optional import logging +from transformers import AutoConfig +from sentence_transformers import SentenceTransformer, CrossEncoder +from sklearn.metrics.pairwise import cosine_similarity +import numpy as np from haystack import MultiLabel, Label @@ -148,18 +152,34 @@ class EvalAnswers: open vs closed domain eval (https://haystack.deepset.ai/docs/latest/tutorial5md). """ - def __init__(self, skip_incorrect_retrieval: bool=True, open_domain: bool=True, debug: bool=False): + def __init__(self, + skip_incorrect_retrieval: bool = True, + open_domain: bool = True, + sas_model: str = None, + debug: bool = False, + ): """ :param skip_incorrect_retrieval: When set to True, this eval will ignore the cases where the retriever returned no correct documents :param open_domain: When True, extracted answers are evaluated purely on string similarity rather than the position of the extracted answer + :param sas_model: Name or path of "Semantic Answer Similarity (SAS) model". When set, the model will be used to calculate similarity between predictions and labels and generate the SAS metric. + The SAS metric correlates better with human judgement of correct answers as it does not rely on string overlaps. + Example: Prediction = "30%", Label = "thirty percent", EM and F1 would be overly pessimistic with both being 0, while SAS paints a more realistic picture. + Models: + - You can use Bi Encoders (sentence transformers) or cross encoders trained on Semantic Textual Similarity (STS) data. + Not all cross encoders can be used because of different return types. 
+ If you use custom cross encoders please make sure they work with sentence_transformers.CrossEncoder class + - Good default for multiple languages: "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" + - Large, powerful, but slow model for English only: "cross-encoder/stsb-roberta-large" + - Large model for German only: "deepset/gbert-large-sts" :param debug: When True, a record of each sample and its evaluation will be stored in EvalAnswers.log """ self.outgoing_edges = 1 - self.init_counts() self.log: List = [] self.debug = debug self.skip_incorrect_retrieval = skip_incorrect_retrieval self.open_domain = open_domain + self.sas_model = sas_model + self.init_counts() def init_counts(self): self.query_count = 0 @@ -176,6 +196,11 @@ def init_counts(self): self.top_k_em = 0.0 self.top_1_f1 = 0.0 self.top_k_f1 = 0.0 + if self.sas_model is not None: + self.top_1_sas_sum = 0 + self.top_k_sas_sum = 0 + self.top_1_sas = 0.0 + self.top_k_sas = 0.0 def run(self, labels, answers, **kwargs): """Run this node on one sample and its labels""" @@ -201,12 +226,27 @@ def run(self, labels, answers, **kwargs): self.has_answer_count += 1 predictions = [p for p in predictions if p["answer"]] top_1_em, top_1_f1, top_k_em, top_k_f1 = self.evaluate_extraction(multi_labels, predictions) + + # Compute Semantic Answer Similarity if model is supplied + if self.sas_model is not None: + # sas works on batches, so we pack the labels into a list of lists, and unpack the return values as well + gold_labels = [multi_labels.multiple_answers] + predictions_list = [[p["answer"] for p in predictions]] + top_1_sas, top_k_sas = semantic_answer_similarity( + predictions=predictions_list, + gold_labels=gold_labels, + sas_model_name_or_path=self.sas_model) + self.top_1_sas_sum += top_1_sas[0] + self.top_k_sas_sum += top_k_sas[0] + if self.debug: self.log.append({"predictions": predictions, "gold_labels": multi_labels, "top_k_f1": top_k_f1, "top_k_em": top_k_em }) + if self.sas_model: + self.log[-1].update({"top_k_sas":top_k_sas}) self.top_1_em_count += top_1_em self.top_1_f1_sum += top_1_f1 @@ -233,6 +273,9 @@ def update_has_answer_metrics(self): self.top_k_em = self.top_k_em_count / self.has_answer_count self.top_1_f1 = self.top_1_f1_sum / self.has_answer_count self.top_k_f1 = self.top_k_f1_sum / self.has_answer_count + if self.sas_model is not None: + self.top_1_sas = self.top_1_sas_sum / self.has_answer_count + self.top_k_sas = self.top_k_sas_sum / self.has_answer_count def update_no_answer_metrics(self): self.top_1_no_answer = self.top_1_no_answer_count / self.no_answer_count @@ -248,6 +291,9 @@ def print(self, mode): print(f"top k EM: {self.top_k_em:.4f}") print(f"top 1 F1: {self.top_1_f1:.4f}") print(f"top k F1: {self.top_k_f1:.4f}") + if self.sas_model is not None: + print(f"top 1 SAS: {self.top_1_sas:.4f}") + print(f"top k SAS: {self.top_k_sas:.4f}") if self.no_answer_count: print() print(f"no_answer queries: {self.no_answer_count}") @@ -266,11 +312,17 @@ def print(self, mode): print(f"top k EM: {pipeline_top_k_em:.4f}") print(f"top 1 F1: {pipeline_top_1_f1:.4f}") print(f"top k F1: {pipeline_top_k_f1:.4f}") + if self.sas_model is not None: + pipeline_top_1_sas = (self.top_1_sas_sum + self.top_1_no_answer_count) / self.query_count + pipeline_top_k_sas = (self.top_k_sas_sum + self.no_answer_count) / self.query_count + print(f"top 1 SAS: {pipeline_top_1_sas:.4f}") + print(f"top k SAS: {pipeline_top_k_sas:.4f}") if self.no_answer_count: print( "(top k results are likely inflated since the Reader always returns a 
no_answer prediction in its top k)" ) + def get_label(labels, node_id): if type(labels) in [Label, MultiLabel]: ret = labels @@ -279,6 +331,7 @@ def get_label(labels, node_id): ret = labels[node_id] return ret + def calculate_em_str_multi(gold_labels, prediction): for gold_label in gold_labels: result = calculate_em_str(gold_label, prediction) @@ -295,176 +348,72 @@ def calculate_f1_str_multi(gold_labels, prediction): return max(results) -def calculate_reader_metrics(metric_counts: Dict[str, float], correct_retrievals: int): - number_of_has_answer = correct_retrievals - metric_counts["number_of_no_answer"] - - metrics = { - "reader_top1_accuracy" : metric_counts["correct_readings_top1"] / correct_retrievals, - "reader_top1_accuracy_has_answer" : metric_counts["correct_readings_top1_has_answer"] / number_of_has_answer, - "reader_topk_accuracy" : metric_counts["correct_readings_topk"] / correct_retrievals, - "reader_topk_accuracy_has_answer" : metric_counts["correct_readings_topk_has_answer"] / number_of_has_answer, - "reader_top1_em" : metric_counts["exact_matches_top1"] / correct_retrievals, - "reader_top1_em_has_answer" : metric_counts["exact_matches_top1_has_answer"] / number_of_has_answer, - "reader_topk_em" : metric_counts["exact_matches_topk"] / correct_retrievals, - "reader_topk_em_has_answer" : metric_counts["exact_matches_topk_has_answer"] / number_of_has_answer, - "reader_top1_f1" : metric_counts["summed_f1_top1"] / correct_retrievals, - "reader_top1_f1_has_answer" : metric_counts["summed_f1_top1_has_answer"] / number_of_has_answer, - "reader_topk_f1" : metric_counts["summed_f1_topk"] / correct_retrievals, - "reader_topk_f1_has_answer" : metric_counts["summed_f1_topk_has_answer"] / number_of_has_answer, - } - - if metric_counts["number_of_no_answer"]: - metrics["reader_top1_no_answer_accuracy"] = metric_counts["correct_no_answers_top1"] / metric_counts[ - "number_of_no_answer"] - metrics["reader_topk_no_answer_accuracy"] = metric_counts["correct_no_answers_topk"] / metric_counts[ - "number_of_no_answer"] - else: - metrics["reader_top1_no_answer_accuracy"] = None # type: ignore - metrics["reader_topk_no_answer_accuracy"] = None # type: ignore - - return metrics - - -def calculate_average_precision_and_reciprocal_rank(questions_with_docs: List[dict]): - questions_with_correct_doc = [] - summed_avg_precision_retriever = 0.0 - summed_reciprocal_rank_retriever = 0.0 - - for question in questions_with_docs: - number_relevant_docs = len(set(question["question"].multiple_document_ids)) - found_relevant_doc = False - relevant_docs_found = 0 - current_avg_precision = 0.0 - for doc_idx, doc in enumerate(question["docs"]): - # check if correct doc among retrieved docs - if doc.id in question["question"].multiple_document_ids: - if not found_relevant_doc: - summed_reciprocal_rank_retriever += 1 / (doc_idx + 1) - relevant_docs_found += 1 - found_relevant_doc = True - current_avg_precision += relevant_docs_found / (doc_idx + 1) - if relevant_docs_found == number_relevant_docs: - break - if found_relevant_doc: - all_relevant_docs = len(set(question["question"].multiple_document_ids)) - summed_avg_precision_retriever += current_avg_precision / all_relevant_docs - - if found_relevant_doc: - questions_with_correct_doc.append({ - "question": question["question"], - "docs": question["docs"] - }) - - return questions_with_correct_doc, summed_avg_precision_retriever, summed_reciprocal_rank_retriever - - -def eval_counts_reader(question: MultiLabel, predicted_answers: Dict[str, Any], metric_counts: 
Dict[str, float]): - # Calculates evaluation metrics for one question and adds results to counter. - # check if question is answerable - if not question.no_answer: - found_answer = False - found_em = False - best_f1 = 0 - for answer_idx, answer in enumerate(predicted_answers["answers"]): - if answer["document_id"] in question.multiple_document_ids: - gold_spans = [{"offset_start": question.multiple_offset_start_in_docs[i], - "offset_end": question.multiple_offset_start_in_docs[i] + len(question.multiple_answers[i]), - "doc_id": question.multiple_document_ids[i]} for i in range(len(question.multiple_answers))] # type: ignore - predicted_span = {"offset_start": answer["offset_start_in_doc"], - "offset_end": answer["offset_end_in_doc"], - "doc_id": answer["document_id"]} - best_f1_in_gold_spans = 0 - for gold_span in gold_spans: - if gold_span["doc_id"] == predicted_span["doc_id"]: - # check if overlap between gold answer and predicted answer - if not found_answer: - metric_counts, found_answer = _count_overlap(gold_span, predicted_span, metric_counts, answer_idx) # type: ignore - - # check for exact match - if not found_em: - metric_counts, found_em = _count_exact_match(gold_span, predicted_span, metric_counts, answer_idx) # type: ignore - - # calculate f1 - current_f1 = _calculate_f1(gold_span, predicted_span) # type: ignore - if current_f1 > best_f1_in_gold_spans: - best_f1_in_gold_spans = current_f1 - # top-1 f1 - if answer_idx == 0: - metric_counts["summed_f1_top1"] += best_f1_in_gold_spans - metric_counts["summed_f1_top1_has_answer"] += best_f1_in_gold_spans - if best_f1_in_gold_spans > best_f1: - best_f1 = best_f1_in_gold_spans - - if found_em: - break - # top-k answers: use best f1-score - metric_counts["summed_f1_topk"] += best_f1 - metric_counts["summed_f1_topk_has_answer"] += best_f1 - - # question not answerable - else: - metric_counts["number_of_no_answer"] += 1 - metric_counts = _count_no_answer(predicted_answers["answers"], metric_counts) +def semantic_answer_similarity(predictions: List[List[str]], + gold_labels: List[List[str]], + sas_model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2" + ) -> Tuple[List[float],List[float]]: + """ + Computes Transformer-based similarity of predicted answer to gold labels to derive a more meaningful metric than EM or F1. + Returns per QA pair a) the similarity of the most likely prediction (top 1) to all available gold labels + b) the highest similarity of all predictions to gold labels - return metric_counts + :param predictions: Predicted answers as list of multiple preds per question + :param gold_labels: Labels as list of multiple possible answers per question + :param sas_model_name_or_path: SentenceTransformers semantic textual similarity model, should be path or string + pointing to downloadable models. -def eval_counts_reader_batch(pred: Dict[str, Any], metric_counts: Dict[str, float]): - # Calculates evaluation metrics for one question and adds results to counter. 
- - # check if question is answerable - if not pred["label"].no_answer: - found_answer = False - found_em = False - best_f1 = 0 - for answer_idx, answer in enumerate(pred["answers"]): - # check if correct document: - if answer["document_id"] in pred["label"].multiple_document_ids: - gold_spans = [{"offset_start": pred["label"].multiple_offset_start_in_docs[i], - "offset_end": pred["label"].multiple_offset_start_in_docs[i] + len(pred["label"].multiple_answers[i]), - "doc_id": pred["label"].multiple_document_ids[i]} - for i in range(len(pred["label"].multiple_answers))] # type: ignore - predicted_span = {"offset_start": answer["offset_start_in_doc"], - "offset_end": answer["offset_end_in_doc"], - "doc_id": answer["document_id"]} - - best_f1_in_gold_spans = 0 - for gold_span in gold_spans: - if gold_span["doc_id"] == predicted_span["doc_id"]: - # check if overlap between gold answer and predicted answer - if not found_answer: - metric_counts, found_answer = _count_overlap( - gold_span, predicted_span, metric_counts, answer_idx - ) - # check for exact match - if not found_em: - metric_counts, found_em = _count_exact_match( - gold_span, predicted_span, metric_counts, answer_idx - ) - # calculate f1 - current_f1 = _calculate_f1(gold_span, predicted_span) - if current_f1 > best_f1_in_gold_spans: - best_f1_in_gold_spans = current_f1 - # top-1 f1 - if answer_idx == 0: - metric_counts["summed_f1_top1"] += best_f1_in_gold_spans - metric_counts["summed_f1_top1_has_answer"] += best_f1_in_gold_spans - if best_f1_in_gold_spans > best_f1: - best_f1 = best_f1_in_gold_spans - - if found_em: - break - - # top-k answers: use best f1-score - metric_counts["summed_f1_topk"] += best_f1 - metric_counts["summed_f1_topk_has_answer"] += best_f1 - - # question not answerable + :return top_1_sas, top_k_sas + """ + assert len(predictions) == len(gold_labels) + + config = AutoConfig.from_pretrained(sas_model_name_or_path) + cross_encoder_used = False + if config.architectures is not None: + cross_encoder_used = any([arch.endswith('ForSequenceClassification') for arch in config.architectures]) + + # Compute similarities + top_1_sas = [] + top_k_sas = [] + + # Based on Modelstring we can load either Bi-Encoders or Cross Encoders. 
+ # Similarity computation changes for both approaches + if cross_encoder_used: + model = CrossEncoder(sas_model_name_or_path) + for preds, labels in zip (predictions,gold_labels): + # TODO add efficient batch mode: put all texts and labels into grid and extract scores afterwards + grid = [] + for p in preds: + for l in labels: + grid.append((p,l)) + scores = model.predict(grid) + top_1_sas.append(np.max(scores[:len(labels)])) + top_k_sas.append(np.max(scores)) else: - metric_counts["number_of_no_answer"] += 1 - metric_counts = _count_no_answer(pred["answers"], metric_counts) - - return metric_counts + # For Bi-encoders we can flatten predictions and labels into one list + model = SentenceTransformer(sas_model_name_or_path) + lengths: List[Tuple[int,int]] = [] + all_texts: List[str] = [] + for p, l in zip(predictions, gold_labels): # type: ignore + # TODO potentially exclude (near) exact matches from computations + all_texts.extend(p) + all_texts.extend(l) + lengths.append((len(p), len(l))) + # then compute embeddings + embeddings = model.encode(all_texts) + + # then select which embeddings will be used for similarity computations + current_position = 0 + for i, (len_p, len_l) in enumerate(lengths): + pred_embeddings = embeddings[current_position:current_position + len_p, :] + current_position += len_p + label_embeddings = embeddings[current_position:current_position + len_l, :] + current_position += len_l + sims = cosine_similarity(pred_embeddings, label_embeddings) + top_1_sas.append(np.max(sims[0, :])) + top_k_sas.append(np.max(sims)) + + return top_1_sas, top_k_sas def _count_overlap( @@ -554,3 +503,4 @@ def _count_no_answer(answers: List[dict], metric_counts: Dict[str, float]): break return metric_counts + diff --git a/test/test_eval.py b/test/test_eval.py index 1842c0ae7e..287fd33a9e 100644 --- a/test/test_eval.py +++ b/test/test_eval.py @@ -106,7 +106,9 @@ def test_eval_pipeline(document_store: BaseDocumentStore, reader, retriever): labels = document_store.get_all_labels_aggregated(index="haystack_test_feedback") eval_retriever = EvalDocuments() - eval_reader = EvalAnswers() + eval_reader = EvalAnswers(sas_model="sentence-transformers/paraphrase-MiniLM-L3-v2",debug=True) + eval_reader_cross = EvalAnswers(sas_model="cross-encoder/stsb-TinyBERT-L-4",debug=True) + eval_reader_vanila = EvalAnswers() assert document_store.get_document_count(index="haystack_test_eval_document") == 2 p = Pipeline() @@ -114,6 +116,8 @@ def test_eval_pipeline(document_store: BaseDocumentStore, reader, retriever): p.add_node(component=eval_retriever, name="EvalDocuments", inputs=["ESRetriever"]) p.add_node(component=reader, name="QAReader", inputs=["EvalDocuments"]) p.add_node(component=eval_reader, name="EvalAnswers", inputs=["QAReader"]) + p.add_node(component=eval_reader_cross, name="EvalAnswers_cross", inputs=["QAReader"]) + p.add_node(component=eval_reader_vanila, name="EvalAnswers_vanilla", inputs=["QAReader"]) for l in labels: res = p.run( query=l.question, @@ -125,6 +129,9 @@ def test_eval_pipeline(document_store: BaseDocumentStore, reader, retriever): assert eval_retriever.recall == 1.0 assert round(eval_reader.top_k_f1, 4) == 0.8333 assert eval_reader.top_k_em == 0.5 + assert round(eval_reader.top_k_sas, 3) == 0.800 + assert round(eval_reader_cross.top_k_sas, 3) == 0.671 + assert eval_reader.top_k_em == eval_reader_vanila.top_k_em @pytest.mark.elasticsearch def test_eval_data_split_word(document_store): diff --git a/tutorials/Tutorial5_Evaluation.py b/tutorials/Tutorial5_Evaluation.py index 
1c8f2dcc67..3af9f93fe3 100644 --- a/tutorials/Tutorial5_Evaluation.py +++ b/tutorials/Tutorial5_Evaluation.py @@ -99,7 +99,7 @@ def tutorial5_evaluation(): # Here we initialize the nodes that perform evaluation eval_retriever = EvalDocuments() - eval_reader = EvalAnswers() + eval_reader = EvalAnswers(sas_model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2") ## Evaluate Retriever on its own in closed domain fashion
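
For reference, a minimal sketch of wiring the new SAS option into an evaluation pipeline, mirroring the updated test_eval.py and Tutorial5 above. Import paths, the retriever/reader/labels objects, and the extra run() keyword arguments are assumptions based on the Haystack layout at the time of this commit (they are not shown in full in the diff):

    from haystack.pipeline import Pipeline
    from haystack.eval import EvalDocuments, EvalAnswers

    # Eval nodes: document-level metrics, plus answer-level EM/F1 and SAS (when sas_model is set)
    eval_retriever = EvalDocuments()
    eval_reader = EvalAnswers(sas_model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")

    p = Pipeline()
    p.add_node(component=retriever, name="ESRetriever", inputs=["Query"])            # retriever assumed to exist
    p.add_node(component=eval_retriever, name="EvalDocuments", inputs=["ESRetriever"])
    p.add_node(component=reader, name="QAReader", inputs=["EvalDocuments"])          # reader assumed to exist
    p.add_node(component=eval_reader, name="EvalAnswers", inputs=["QAReader"])

    for label in labels:  # aggregated labels, e.g. from DocumentStore.get_all_labels_aggregated()
        # kwargs besides query/labels are illustrative; see the (partially elided) call in test_eval.py
        p.run(query=label.question, labels=label, top_k_retriever=10, top_k_reader=10)

    # With sas_model set, EvalAnswers exposes SAS alongside EM and F1
    print(eval_reader.top_1_sas, eval_reader.top_k_sas)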
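
The new semantic_answer_similarity() helper can also be called on its own. A small sketch using the "30%" vs. "thirty percent" example from the docstring above: predictions and gold labels are per-question lists of strings, and the two returned lists hold the top-1 and best-of-top-k similarity per question.

    from haystack.eval import semantic_answer_similarity

    predictions = [["30%", "roughly a third"]]   # one question, top-1 prediction first
    gold_labels = [["thirty percent"]]           # one question, one accepted gold answer

    top_1_sas, top_k_sas = semantic_answer_similarity(
        predictions=predictions,
        gold_labels=gold_labels,
        sas_model_name_or_path="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
    )

    # top_1_sas[0]: similarity of the first (top-1) prediction to the closest gold label
    # top_k_sas[0]: best similarity over all predictions for this question
    print(top_1_sas[0], top_k_sas[0])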
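
In the cross-encoder branch, the slice scores[:len(labels)] is easy to misread. A toy illustration with made-up scores standing in for CrossEncoder.predict(): because the grid is built prediction-major, the first len(labels) pairs all involve the top-1 prediction, so their maximum gives the top-1 SAS while the global maximum gives the top-k SAS.

    import numpy as np

    preds = ["30%", "roughly a third"]           # top-1 prediction first
    labels = ["thirty percent", "30 percent"]    # accepted gold answers

    # Prediction-major grid, as built in semantic_answer_similarity()
    grid = [(p, l) for p in preds for l in labels]
    # [("30%", "thirty percent"), ("30%", "30 percent"),
    #  ("roughly a third", "thirty percent"), ("roughly a third", "30 percent")]

    scores = np.array([0.95, 0.90, 0.40, 0.35])  # made-up stand-in for model.predict(grid)

    top_1 = np.max(scores[:len(labels)])         # pairs of the top-1 prediction with every label
    top_k = np.max(scores)                       # best pair over all predictions
    print(top_1, top_k)                          # 0.95 0.95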