Skip to content

Commit

Permalink
Add rescore (#186)
Browse files Browse the repository at this point in the history
* Add rescore

* Add rescore

* Call rescore from rerank

* Sort in reverse order

* Address comments
  • Loading branch information
saileshnankani authored May 5, 2021
1 parent 0dcabff commit 95b3da7
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 27 deletions.
12 changes: 6 additions & 6 deletions pygaggle/model/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ def evaluate(self,
examples: List[RelevanceExample]) -> List[MetricAccumulator]:
metrics = [cls() for cls in self.metrics]
for example in tqdm(examples, disable=not self.use_tqdm):
scores = [x.score for x in self.reranker.rerank(example.query,
example.documents)]
scores = [x.score for x in self.reranker.rescore(example.query,
example.documents)]
if self.writer is not None:
self.writer.write(scores, example)
for metric in metrics:
Expand All @@ -178,7 +178,7 @@ def evaluate_by_segments(self,
segment_processor = SegmentProcessor()
for example in tqdm(examples, disable=not self.use_tqdm):
segment_group = segment_processor.segment(example.documents, seg_size, stride)
segment_group.segments = self.reranker.rerank(example.query, segment_group.segments)
segment_group.segments = self.reranker.rescore(example.query, segment_group.segments)
doc_scores = [x.score for x in segment_processor.aggregate(example.documents,
segment_group,
aggregate_method)]
Expand Down Expand Up @@ -210,12 +210,12 @@ def evaluate(self,
mono_texts = []
scores = []
for ct, example in tqdm(enumerate(examples), total=len(examples), disable=not self.use_tqdm):
mono_out = self.mono_reranker.rerank(example.query, example.documents)
mono_out = self.mono_reranker.rescore(example.query, example.documents)
mono_texts.append(sorted(enumerate(mono_out), key=lambda x: x[1].score, reverse=True)[:self.mono_hits])
scores.append(np.array([x.score for x in mono_out]))
for ct, texts in tqdm(enumerate(mono_texts), total=len(mono_texts), disable=not self.use_tqdm):
duo_in = list(map(lambda x: x[1], texts))
duo_scores = [x.score for x in self.duo_reranker.rerank(examples[ct].query, duo_in)]
duo_scores = [x.score for x in self.duo_reranker.rescore(examples[ct].query, duo_in)]

scores[ct][list(map(lambda x: x[0], texts))] = duo_scores
if self.writer is not None:
Expand All @@ -233,7 +233,7 @@ def evaluate_by_segments(self,
segment_processor = SegmentProcessor()
for example in tqdm(examples, disable=not self.use_tqdm):
segment_group = segment_processor.segment(example.documents, seg_size, stride)
segment_group.segments = self.reranker.rerank(example.query, segment_group.segments)
segment_group.segments = self.reranker.rescore(example.query, segment_group.segments)
doc_scores = [x.score for x in segment_processor.aggregate(example.documents,
segment_group,
aggregate_method)]
Expand Down
10 changes: 9 additions & 1 deletion pygaggle/rerank/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Union, Optional, Mapping, Any
from copy import deepcopy
import abc

from pyserini.search import JSimpleSearcherResult
Expand All @@ -21,6 +22,7 @@ class Query:
id : Optional[str]
The query id.
"""

def __init__(self, text: str, id: Optional[str] = None):
self.text = text
self.id = id
Expand Down Expand Up @@ -63,8 +65,14 @@ class Reranker:
A reranker takes a list of texts and returns a list of texts non-destructively
(i.e., does not alter the original input list of texts).
"""
@abc.abstractmethod

def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
"""Sorts a list of texts by their rescored relevance to the query, highest score first.
"""
return sorted(self.rescore(query, texts), key=lambda x: x.score, reverse=True)

@abc.abstractmethod
def rescore(self, query: Query, texts: List[Text]) -> List[Text]:
"""Scores a list of texts with respect to a query (sorting by score is done by rerank).
Parameters
Expand Down
4 changes: 2 additions & 2 deletions pygaggle/rerank/bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self,
self.use_corpus_estimator = True
self.index_utils = IndexReader(index_path)

def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
def rescore(self, query: Query, texts: List[Text]) -> List[Text]:
query_words = self.analyzer.analyze(query.text)
sentences = list(map(self.analyzer.analyze, (t.text for t in texts)))

Expand All @@ -45,7 +45,7 @@ def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
if self.use_corpus_estimator:
idfs = {w:
self.index_utils.compute_bm25_term_weight(
text.metadata['docid'], w) for w in tf}
text.metadata['docid'], w) for w in tf}
score = sum(idfs[w] * tf[w] * (self.k1 + 1) /
(tf[w] + self.k1 * (1 - self.b + self.b *
(d_len / mean_len))) for w in tf)
Expand Down
2 changes: 1 addition & 1 deletion pygaggle/rerank/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@


class IdentityReranker(Reranker):
def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
def rescore(self, query: Query, texts: List[Text]) -> List[Text]:
return texts
2 changes: 1 addition & 1 deletion pygaggle/rerank/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class RandomReranker(Reranker):
def __init__(self, seed: int = 0):
self.rand = random.Random(seed)

def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
def rescore(self, query: Query, texts: List[Text]) -> List[Text]:
texts = deepcopy(texts)
for text in texts:
text.score = self.rand.random()
Expand Down
21 changes: 6 additions & 15 deletions pygaggle/rerank/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_tokenizer(pretrained_model_name_or_path: str = 't5-base',
batch_size=batch_size
)

def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
def rescore(self, query: Query, texts: List[Text]) -> List[Text]:
texts = deepcopy(texts)
batch_input = QueryDocumentBatch(query=query, documents=texts)
for batch in self.tokenizer.traverse_query_document(batch_input):
Expand All @@ -72,7 +72,6 @@ def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
for doc, score in zip(batch.documents, batch_log_probs):
doc.score = score

texts.sort(key=lambda x: x.score, reverse=True)
return texts


Expand Down Expand Up @@ -100,7 +99,7 @@ def get_tokenizer(pretrained_model_name_or_path: str = 't5-base',
batch_size=batch_size
)

def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
def rescore(self, query: Query, texts: List[Text]) -> List[Text]:
texts = deepcopy(texts)
doc_pairs = list(permutations(texts, 2))
scores = defaultdict(float)
Expand All @@ -125,8 +124,6 @@ def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
for text in texts:
text.score = scores[text.metadata['docid']]

texts.sort(key=lambda x: x.score, reverse=True)

return texts


Expand Down Expand Up @@ -155,7 +152,7 @@ def __init__(self,
self.argmax_only = argmax_only

@torch.no_grad()
def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
def rescore(self, query: Query, texts: List[Text]) -> List[Text]:
encoded_query = self.encoder.encode_single(query)
encoded_documents = self.encoder.encode(texts)
texts = deepcopy(texts)
Expand All @@ -174,8 +171,6 @@ def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
if text.score != max_score:
text.score = max_score - 10000

texts.sort(key=lambda x: x.score, reverse=True)

return texts


Expand All @@ -201,7 +196,7 @@ def get_tokenizer(pretrained_model_name_or_path: str = 'bert-large-uncased',
return AutoTokenizer.from_pretrained(pretrained_model_name_or_path, use_fast=False, *args, **kwargs)

@torch.no_grad()
def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
def rescore(self, query: Query, texts: List[Text]) -> List[Text]:
texts = deepcopy(texts)
for text in texts:
ret = self.tokenizer.encode_plus(query.text,
Expand All @@ -215,12 +210,10 @@ def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
output, = self.model(input_ids, token_type_ids=tt_ids, return_dict=False)
if output.size(1) > 1:
text.score = torch.nn.functional.log_softmax(
output, 1)[0, -1].item()
output, 1)[0, -1].item()
else:
text.score = output.item()

texts.sort(key=lambda x: x.score, reverse=True)

return texts


Expand All @@ -231,7 +224,7 @@ def __init__(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer):
self.device = next(model.parameters()).device

@torch.no_grad()
def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
def rescore(self, query: Query, texts: List[Text]) -> List[Text]:
texts = deepcopy(texts)
for text in texts:
ret = self.tokenizer.encode_plus(query.text,
Expand All @@ -253,6 +246,4 @@ def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
emax_val, emax_idx = end_scores.max(0)
text.score = max(smax_val.item(), emax_val.item())

texts.sort(key=lambda x: x.score, reverse=True)

return texts
2 changes: 1 addition & 1 deletion tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_basic(self):
identity_reranker = IdentityReranker()
self.assertTrue(isinstance(identity_reranker, Reranker))

output = identity_reranker.rerank(query, texts)
output = identity_reranker.rescore(query, texts)

# Check that reranked output is indeed the same as the input
for i in range(0, len(hits)):
Expand Down

0 comments on commit 95b3da7

Please sign in to comment.