Skip to content

Commit

Permalink
Fix Pyserini compatibility issues (#71)
Browse files Browse the repository at this point in the history
  • Loading branch information
justinborromeo authored Sep 8, 2020
1 parent ae2dfc5 commit 96a7e8d
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 8 deletions.
5 changes: 3 additions & 2 deletions docs/experiments-CovidQA.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,10 @@ mrr 0.37988285486956513
mrr@10 0.3671336788683727
```

It takes about 17 minutes to re-rank this subset on CovidQA using a P100.
It takes about 17 minutes to re-rank this subset on CovidQA using a P100. Note again that you may need to adjust the batch size to best fit your GPU (`--batch-size={BATCH_SIZE}`).

If you were able to replicate these results, please submit a PR adding to the replication log!


## Replication Log
## Replication Log

2 changes: 1 addition & 1 deletion pygaggle/data/relevance.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def unfold(entries):

class MsMarcoPassageLoader:
def __init__(self, index_path: str):
self.searcher = pysearch.SimpleSearcher(index_path)
self.searcher = SimpleSearcher(index_path)

def load_passage(self, id: str) -> MsMarcoPassage:
try:
Expand Down
2 changes: 1 addition & 1 deletion pygaggle/data/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def segment(self, documents: List[Text], seg_size: int, stride: int) -> SegmentG
sentences = [sent.string.strip() for sent in doc.sents]
for i in range(0, len(sentences), stride):
segment_text = ' '.join(sentences[i:i + seg_size])
segmented_doc.append(Text(segment_text, dict(docid=document.raw["docid"])))
segmented_doc.append(Text(segment_text, dict(docid=document.metadata["docid"])))
if i + seg_size >= len(sentences):
end_idx += i/stride + 1
doc_end_indexes.append(int(end_idx))
Expand Down
8 changes: 4 additions & 4 deletions pygaggle/rerank/bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from typing import List
import math

from pyserini.analysis.pyanalysis import get_lucene_analyzer, Analyzer
from pyserini.index.pyutils import IndexReaderUtils
from pyserini.analysis import get_lucene_analyzer, Analyzer
from pyserini.index import IndexReader
import numpy as np

from .base import Reranker, Query, Text
Expand All @@ -24,7 +24,7 @@ def __init__(self,
self.analyzer = Analyzer(get_lucene_analyzer())
if index_path:
self.use_corpus_estimator = True
self.index_utils = IndexReaderUtils(index_path)
self.index_utils = IndexReader(index_path)

def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
query_words = self.analyzer.analyze(query.text)
Expand All @@ -45,7 +45,7 @@ def rerank(self, query: Query, texts: List[Text]) -> List[Text]:
if self.use_corpus_estimator:
idfs = {w:
self.index_utils.compute_bm25_term_weight(
text.raw['docid'], w) for w in tf}
text.metadata['docid'], w) for w in tf}
score = sum(idfs[w] * tf[w] * (self.k1 + 1) /
(tf[w] + self.k1 * (1 - self.b + self.b *
(d_len / mean_len))) for w in tf)
Expand Down

0 comments on commit 96a7e8d

Please sign in to comment.