In a previous guide, we introduced a conceptual framework for a representational approach to information retrieval that integrates dense and sparse representations into the same underlying (bi-encoder) architecture. This guide offers a deeper dive that connects the high-level concepts with the actual code implementation.
If you're a Waterloo student traversing the onboarding path (which starts here), make sure you've first done the previous step, reproducing a dense retrieval baseline for NFCorpus. In general, don't try to rush through this guide by just blindly copying and pasting commands into a shell; that's what I call cargo culting. Instead, really try to understand what's going on.
Following the onboarding path, this lesson does not introduce any new concepts. Rather, the focus is to solidify previously introduced concepts and to connect the bi-encoder architecture to implementations in Pyserini. Informally, we're "peeling back the covers".
Learning outcomes for this guide, building on previous steps in the onboarding path, are divided into three parts. With respect to dense retrieval models:
- Be able to materialize and inspect dense vectors stored in Faiss.
- Be able to encode documents and queries with the BGE-base model and manipulate the resulting vector representations.
- Be able to compute query-document scores (i.e., retrieval scores) "by hand" for dense retrieval, by directly manipulating the vectors.
- Be able to perform retrieval "by hand" given a query, by directly manipulating the document vectors stored in the index.
With respect to sparse (i.e., bag-of-words) retrieval models:
- Be able to materialize and inspect BM25 document vectors from a Lucene inverted index.
- Be able to compute query-document scores (i.e., retrieval scores) "by hand" for bag-of-words retrieval, by directly manipulating the vectors.
- Be able to perform retrieval "by hand" given a query, by directly manipulating the document vectors materialized from the inverted index.
And putting the two together:
- Understand how dense retrieval and sparse (bag-of-words) retrieval are different realizations of the same bi-encoder architecture.
- Be able to connect key concepts in the bi-encoder architecture to Pyserini implementations.
- Be able to "trace" retrieval with dense and sparse representations through the encoding and top-k retrieval phases.
As a recap from here, this is the "core retrieval" problem that we're trying to solve:
Given an information need expressed as a query $q$, the text retrieval task is to return a ranked list of $k$ texts $\{d_1, d_2, \ldots, d_k\}$ from an arbitrarily large but finite collection of texts $C = \{d_i\}$ that maximizes a metric of interest, for example, nDCG, AP, etc.
And this is the bi-encoder architecture for tackling the above challenge:
It's all about representations! BM25 generates bag-of-words sparse lexical vectors where the terms are assigned BM25 weights in an unsupervised manner. Contriever and BGE-base, which are examples of dense retrieval models, use transformer-based encoders, trained on large amounts of supervised data, that generate dense vectors.
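The notation below is ours, not the guide's. Write $\eta_q$ for the query encoder and $\eta_d$ for the document encoder. A bi-encoder scores a query-document pair by the dot product of the two encodings:

$$
s(q, d) = \eta_q(q) \cdot \eta_d(d) = \sum_{j} \eta_q(q)_j \, \eta_d(d)_j
$$

Top-k retrieval then returns the $k$ documents $d \in C$ with the largest $s(q, d)$. For BM25, $\eta_d$ produces BM25-weighted term vectors and $\eta_q$ produces a multi-hot vector over the vocabulary, so $j$ ranges over vocabulary terms. For BGE-base or Contriever, both encoders produce dense vectors from a transformer, so $j$ ranges over the dense dimensions.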
Let's start by peeking inside the Faiss index we built in the previous step:
```python
import faiss

index = faiss.read_index('indexes/nfcorpus.bge-base-en-v1.5/index')
num_vectors = index.ntotal
```
Try it with Contriever:
```python
import faiss

index_c = faiss.read_index('indexes/faiss.nfcorpus.contriever-msmacro/index')
num_vectors_c = index_c.ntotal
```
We see, from `num_vectors`, that there are 3633 vectors in this index.
That's a vector (or alternatively, embedding) for each document.
We can print out the first 10 vectors:

```python
for i in range(10):
    vector = index.reconstruct(i)
    print(f"Vector {i}: {vector}")
```
Contriever:
```python
for i in range(10):
    vector_c = index_c.reconstruct(i)
    print(f"Vector {i}: {vector_c}")
```
Pyserini stores the `docid` corresponding to each vector separately. In the code snippet below, we load in the mapping data and then look up the vector corresponding to `MED-4555`.
```python
docids = []
with open('indexes/nfcorpus.bge-base-en-v1.5/docid', 'r') as fin:
    docids = [line.rstrip() for line in fin.readlines()]

v1 = index.reconstruct(docids.index('MED-4555'))
```
Contriever:
```python
docids_c = []
with open('indexes/faiss.nfcorpus.contriever-msmacro/docid', 'r') as fin:
    docids_c = [line.rstrip() for line in fin.readlines()]

v1_c = index_c.reconstruct(docids_c.index('MED-4555'))
```
So, `v1` now holds the dense vector representation (i.e., embedding) of document `MED-4555`.
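`docids.index('MED-4555')` performs a linear scan over the whole list every time it is called. For repeated lookups, you can build a reverse dictionary once so that each lookup is constant time. The sketch below does this in pure Python. The docids and two-dimensional vectors are made up and stand in for the `docid` file and the Faiss index. The names `toy_docids`, `toy_vectors`, and `docid_to_pos` are hypothetical and chosen so they don't clobber the guide's variables.

```python
# Hypothetical stand-ins: position i in toy_docids corresponds to vector i in toy_vectors,
# mirroring how the `docid` file lines up with the vectors in the Faiss index.
toy_docids = ['MED-10', 'MED-14', 'MED-4555', 'MED-118']
toy_vectors = [[0.1, 0.9], [0.7, 0.3], [0.6, 0.8], [0.2, 0.4]]

# Build the reverse mapping (docid -> position) once, in a single pass.
docid_to_pos = {docid: pos for pos, docid in enumerate(toy_docids)}

def lookup_vector(docid):
    """Return the stored vector for `docid` (analogous to index.reconstruct(pos))."""
    return toy_vectors[docid_to_pos[docid]]

print(docid_to_pos['MED-4555'])   # 2
print(lookup_vector('MED-4555'))  # [0.6, 0.8]
```

Against the real index, the same pattern would be `docid_to_pos = {d: i for i, d in enumerate(docids)}` followed by `index.reconstruct(docid_to_pos['MED-4555'])`, which should return the same vector as `v1`.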
Now, where did this vector come from? Well, it's the output of the encoder. Let's verify this by first encoding the contents of the document, which is in `doc_text`:
```python
# This is the string contents of doc MED-4555
doc_text = 'Analysis of risk factors for abdominal aortic aneurysm in a cohort of more than 3 million individuals. BACKGROUND: Abdominal aortic aneurysm (AAA) disease is an insidious condition with an 85% chance of death after rupture. Ultrasound screening can reduce mortality, but its use is advocated only for a limited subset of the population at risk. METHODS: We used data from a retrospective cohort of 3.1 million patients who completed a medical and lifestyle questionnaire and were evaluated by ultrasound imaging for the presence of AAA by Life Line Screening in 2003 to 2008. Risk factors associated with AAA were identified using multivariable logistic regression analysis. RESULTS: We observed a positive association with increasing years of smoking and cigarettes smoked and a negative association with smoking cessation. Excess weight was associated with increased risk, whereas exercise and consumption of nuts, vegetables, and fruits were associated with reduced risk. Blacks, Hispanics, and Asians had lower risk of AAA than whites and Native Americans. Well-known risk factors were reaffirmed, including male gender, age, family history, and cardiovascular disease. A predictive scoring system was created that identifies aneurysms more efficiently than current criteria and includes women, nonsmokers, and individuals aged <65 years. Using this model on national statistics of risk factors prevalence, we estimated 1.1 million AAAs in the United States, of which 569,000 are among women, nonsmokers, and individuals aged <65 years. CONCLUSIONS: Smoking cessation and a healthy lifestyle are associated with lower risk of AAA. We estimated that about half of the patients with AAA disease are not eligible for screening under current guidelines. We have created a high-yield screening algorithm that expands the target population for screening by including at-risk individuals not identified with existing screening criteria.'

from pyserini.encode import AutoDocumentEncoder

encoder = AutoDocumentEncoder('BAAI/bge-base-en-v1.5', device='cpu', pooling='mean', l2_norm=True)
v2 = encoder.encode(doc_text)
```
Contriever:
```python
from pyserini.encode import AutoDocumentEncoder

encoder_c = AutoDocumentEncoder('facebook/contriever-msmarco', device='cpu', pooling='mean')
v2_c = encoder_c.encode(doc_text)
```
Minor detail here: the encoder is designed to work on batches of input, so the actual vector representation is `v2[0]`.
We can verify that the vector we generated using the encoder is identical to the vector stored in the index by computing the L2 norm of their difference, which should be almost zero:
```python
import numpy as np

np.linalg.norm(v2[0] - v1)
```
Contriever:
```python
import numpy as np

np.linalg.norm(v2_c[0] - v1_c)
```
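Floating-point results can differ in the last few digits depending on hardware and batching. So "identical" here really means "equal up to a small tolerance." The check can be sketched in pure Python with `math.dist`, which computes the Euclidean (L2) distance between two points. The vectors and the tolerance of `1e-4` are illustrative assumptions, not values from the guide.

```python
import math

def same_vector(a, b, tol=1e-4):
    """True if the L2 norm of (a - b) is below `tol`."""
    return math.dist(a, b) < tol

stored = [0.0312, -0.0457, 0.0821]              # e.g., vector reconstructed from the index
re_encoded = [0.03120001, -0.0457, 0.08209999]  # e.g., the same text freshly encoded
other = [0.0500, -0.0100, 0.0700]               # e.g., a different document

print(same_vector(stored, re_encoded))  # True
print(same_vector(stored, other))       # False
```

With NumPy, `np.allclose(v2[0], v1, atol=1e-4)` expresses a similar idea element-wise rather than through the norm.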
Let's push this further and work through a query. Consider the query "How to Help Prevent Abdominal Aortic Aneurysms", which is `PLAIN-3074`.
We can perform interactive retrieval as follows:
```python
from pyserini.search.faiss import FaissSearcher
from pyserini.encode import AutoQueryEncoder

encoder = AutoQueryEncoder('BAAI/bge-base-en-v1.5', device='cpu', pooling='mean', l2_norm=True)
searcher = FaissSearcher('indexes/nfcorpus.bge-base-en-v1.5', encoder)
hits = searcher.search('How to Help Prevent Abdominal Aortic Aneurysms')

for i in range(0, 10):
    print(f'{i+1:2} {hits[i].docid:7} {hits[i].score:.6f}')
```
And the result will be:
```
 1 MED-4555 0.791379
 2 MED-4560 0.710725
 3 MED-4421 0.688938
 4 MED-4993 0.686238
 5 MED-4424 0.686214
 6 MED-1663 0.682199
 7 MED-3436 0.680585
 8 MED-2750 0.677033
 9 MED-4324 0.675772
10 MED-2939 0.674646
```
Contriever:
```python
from pyserini.search.faiss import FaissSearcher
from pyserini.encode import AutoQueryEncoder

encoder_c = AutoQueryEncoder('facebook/contriever-msmarco', device='cpu', pooling='mean')
searcher_c = FaissSearcher('indexes/faiss.nfcorpus.contriever-msmacro', encoder_c)
hits_c = searcher_c.search('How to Help Prevent Abdominal Aortic Aneurysms')

for i in range(0, 10):
    print(f'{i+1:2} {hits_c[i].docid:7} {hits_c[i].score:.6f}')
```
And the result will be:
```
 1 MED-4555 1.472201
 2 MED-3180 1.125014
 3 MED-1309 1.067153
 4 MED-2224 1.059536
 5 MED-4423 1.038440
 6 MED-4887 1.032622
 7 MED-2530 1.020758
 8 MED-2372 1.016142
 9 MED-1006 1.013599
10 MED-2587 1.010811
```
Let's go ahead and encode the query, producing the query vector `q_vec`:
```python
from pyserini.encode import AutoQueryEncoder

q_encoder = AutoQueryEncoder('BAAI/bge-base-en-v1.5', device='cpu', pooling='mean', l2_norm=True)
q_vec = q_encoder.encode('How to Help Prevent Abdominal Aortic Aneurysms')
```
Then, we compute the dot product between the query vector `q_vec` and the document vector `v1` (which is the representation of document `MED-4555` generated by the document encoder):
```python
np.dot(q_vec, v1)
```
We should arrive at the same score as above (0.7913785).
In other words, the query-document score (i.e., the relevance score of the document with respect to the query) is exactly the dot product of the two vector representations.
This is as expected!
Contriever:
```python
from pyserini.encode import AutoQueryEncoder

q_encoder_c = AutoQueryEncoder('facebook/contriever-msmarco', device='cpu', pooling='mean')
q_vec_c = q_encoder_c.encode('How to Help Prevent Abdominal Aortic Aneurysms')
```
Then, we compute the dot product between the query vector `q_vec_c` and the document vector `v1_c` (which is the representation of document `MED-4555` generated by the document encoder):
```python
np.dot(q_vec_c, v1_c)
```
We should arrive at the same score as above (1.472201).
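Why do the Contriever scores exceed 1 while the BGE-base scores do not? The BGE-base encoders above are created with `l2_norm=True`, which scales every vector to unit length. As a result, their dot product equals the cosine similarity and is at most 1. The Contriever encoders are created without that flag, so the raw inner product has no such bound. The pure-Python sketch below illustrates this with toy vectors, not actual model outputs.

```python
import math

def inner(u, v):
    """Inner (dot) product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def l2_normalize(v):
    """Scale v to unit Euclidean length."""
    norm = math.sqrt(inner(v, v))
    return [x / norm for x in v]

def cosine(u, v):
    """Cosine similarity: inner product divided by the product of the lengths."""
    return inner(u, v) / (math.sqrt(inner(u, u)) * math.sqrt(inner(v, v)))

# Toy, unnormalized "query" and "document" vectors.
q = [1.2, 0.5, 0.3]
d = [0.9, 0.8, 0.1]

raw = inner(q, d)                                     # unbounded, as with Contriever here
normalized = inner(l2_normalize(q), l2_normalize(d))  # as with l2_norm=True

print(f'{raw:.4f}')                 # 1.5100
print(f'{normalized:.4f}', f'{cosine(q, d):.4f}')  # the two values agree
```

Scaling the query vector by a positive constant never changes the ranking for that query. Normalizing document vectors, by contrast, generally can change the ranking, because each document is rescaled by a different factor.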
We can take this a step further and manually perform retrieval by computing the dot product between the query vector and all document vectors. The corpus is small enough that this is practical:
```python
from tqdm import tqdm

scores = []
# Iterate through all document vectors and compute dot product.
for i in tqdm(range(num_vectors)):
    vector = index.reconstruct(i)
    score = np.dot(q_vec, vector)
    scores.append([docids[i], score])

# Sort by score descending.
scores.sort(key=lambda x: -x[1])

for s in scores[:10]:
    print(f'{s[0]} {s[1]:.6f}')
```
In a bit more detail, we iterate through all document vectors in the index, compute each one's dot product with the query vector, and append the results to `scores`. After going through the entire corpus in this manner, we sort the results and print out the top-10. This sorting operation corresponds to top-k retrieval. We can see that the output is the same as search with `FaissSearcher` above. This is exactly as expected.
Contriever:
```python
from tqdm import tqdm

scores_c = []
# Iterate through all document vectors and compute dot product.
for i in tqdm(range(num_vectors_c)):
    vector_c = index_c.reconstruct(i)
    score_c = np.dot(q_vec_c, vector_c)
    scores_c.append([docids_c[i], score_c])

# Sort by score descending.
scores_c.sort(key=lambda x: -x[1])

for s in scores_c[:10]:
    print(f'{s[0]} {s[1]:.6f}')
```
Again, the output is the same as search with `FaissSearcher` above.
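Sorting every score costs O(n log n) for n documents, even though we only need the top 10. A common alternative keeps only the k best candidates in a heap using `heapq.nlargest`. This gives the same top-k (up to tie-breaking) in O(n log k). The sketch below is pure Python with toy data, and the function name `brute_force_topk` is hypothetical. It illustrates brute-force top-k in general, not how Faiss implements search internally.

```python
import heapq

def brute_force_topk(query_vec, doc_vecs, doc_ids, k=10):
    """Score every document by dot product and return the k best as (docid, score)."""
    scored = (
        (doc_id, sum(q * d for q, d in zip(query_vec, vec)))
        for doc_id, vec in zip(doc_ids, doc_vecs)
    )
    # nlargest keeps at most k items in a heap instead of sorting everything.
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])

toy_ids = ['D1', 'D2', 'D3', 'D4']
toy_vecs = [[0.1, 0.9], [0.8, 0.2], [0.5, 0.6], [0.0, 0.1]]
toy_query = [1.0, 0.5]

for doc_id, score in brute_force_topk(toy_query, toy_vecs, toy_ids, k=2):
    print(f'{doc_id} {score:.6f}')
```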
Now, we're going to basically do the same thing, but with BM25. The point here is to illustrate how dense and sparse retrieval are conceptually identical — they're both instantiations of the bi-encoder architecture. The primary difference is the encoder representation, i.e., the vectors that the encoders generate.
We have to start with a bit of data munging, since the Lucene indexer expects the documents in a slightly different format. Start by creating a new sub-directory:
```bash
mkdir collections/nfcorpus/pyserini-corpus
```
Now run the following Python script to munge the data into the right format:
```python
import json

with open('collections/nfcorpus/pyserini-corpus/corpus.jsonl', 'w') as out:
    with open('collections/nfcorpus/corpus.jsonl', 'r') as f:
        for line in f:
            doc = json.loads(line)
            s = json.dumps({'id': doc['_id'], 'contents': doc['title'] + ' ' + doc['text']})
            out.write(s + '\n')
```
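To see what the script does to a single record, the same transformation can be factored into a function and applied to one invented BEIR-style line. The `_id`, `title`, and `text` fields follow the script. The sample record and the function name `to_pyserini` are hypothetical.

```python
import json

def to_pyserini(beir_line):
    """Convert one BEIR corpus line into the JsonCollection format used above."""
    doc = json.loads(beir_line)
    return json.dumps({'id': doc['_id'], 'contents': doc['title'] + ' ' + doc['text']})

sample = json.dumps({'_id': 'MED-0000', 'title': 'A title.', 'text': 'Some body text.', 'metadata': {}})
print(to_pyserini(sample))
# {"id": "MED-0000", "contents": "A title. Some body text."}
```

Any fields other than `_id`, `title`, and `text` (such as `metadata` here) are simply dropped.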
We can now index these documents as a `JsonCollection` using Pyserini:
```bash
python -m pyserini.index.lucene \
  --collection JsonCollection \
  --input collections/nfcorpus/pyserini-corpus/ \
  --index indexes/lucene.nfcorpus \
  --generator DefaultLuceneDocumentGenerator \
  --storePositions --storeDocvectors --storeRaw
```
Perform retrieval:
```bash
python -m pyserini.search.lucene \
  --index indexes/lucene.nfcorpus \
  --topics collections/nfcorpus/queries.tsv \
  --output runs/run.beir-bm25.nfcorpus.txt \
  --hits 1000 --bm25 \
  --threads 4 --batch-size 16
```
And evaluate the retrieval run:
```bash
python -m pyserini.eval.trec_eval \
  -c -m ndcg_cut.10 collections/nfcorpus/qrels/test.qrels \
  runs/run.beir-bm25.nfcorpus.txt
```
The expected results are:
```
Results:
ndcg_cut_10 all 0.3218
```
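For intuition about what `ndcg_cut.10` measures, here is a minimal pure-Python sketch of nDCG@10 for a single query. It assumes linear gains (the graded relevance label itself) and a log2(rank + 1) discount. This is our understanding of trec_eval's default, so treat the sketch as an approximation rather than a drop-in replacement; for example, it does not reproduce trec_eval's tie-breaking. The reported number is the mean of this per-query value over all test queries. The toy judgments and ranking are invented.

```python
import math

def ndcg_at_k(ranked_docids, qrels, k=10):
    """nDCG@k for one query.

    ranked_docids: docids in rank order (as in a run file).
    qrels: dict mapping docid -> graded relevance label for this query.
    """
    dcg = sum(
        qrels.get(docid, 0) / math.log2(rank + 1)
        for rank, docid in enumerate(ranked_docids[:k], start=1)
    )
    # Ideal DCG: all judged-relevant documents for the query, best labels first.
    ideal = sorted((rel for rel in qrels.values() if rel > 0), reverse=True)[:k]
    idcg = sum(rel / math.log2(rank + 1) for rank, rel in enumerate(ideal, start=1))
    return dcg / idcg if idcg > 0 else 0.0

toy_qrels = {'D1': 2, 'D3': 1, 'D7': 0}
toy_ranking = ['D1', 'D2', 'D3']
print(f'{ndcg_at_k(toy_ranking, toy_qrels):.4f}')  # 0.9502
```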
We can also perform retrieval interactively:
```python
from pyserini.search.lucene import LuceneSearcher

searcher = LuceneSearcher('indexes/lucene.nfcorpus')
hits = searcher.search('How to Help Prevent Abdominal Aortic Aneurysms')

for i in range(0, 10):
    print(f'{i+1:2} {hits[i].docid:7} {hits[i].score:.4f}')
```
The results should be as follows:
```
 1 MED-4555 11.9305
 2 MED-4423 8.4771
 3 MED-3180 7.1896
 4 MED-2718 6.0102
 5 MED-1309 5.8181
 6 MED-4424 5.7448
 7 MED-1705 5.6101
 8 MED-4902 5.3639
 9 MED-1009 5.2533
10 MED-1512 5.2068
```
So far, none of this is new: We did exactly the same thing for the MS MARCO passage ranking test collection, but now we're doing it for NFCorpus.
Next, let's generate the BM25 document vector for doc `MED-4555`, the same document we examined above.
```python
from pyserini.index.lucene import LuceneIndexReader
import json

index_reader = LuceneIndexReader('indexes/lucene.nfcorpus')
tf = index_reader.get_document_vector('MED-4555')
bm25_weights = {
    term: index_reader.compute_bm25_term_weight('MED-4555', term, analyzer=None)
    for term in tf.keys()
}

print(json.dumps(bm25_weights, indent=4, sort_keys=True))
```
The variable `bm25_weights` is a Python dictionary holding the BM25 weights for the document.
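Where do these weights come from? The sketch below is a pure-Python version of the textbook BM25 term weight, with a Lucene-style IDF. The defaults `k1=0.9` and `b=0.4` are, to our knowledge, Pyserini's defaults. Lucene's implementation differs in details: it stores document lengths in a lossy compressed form, and recent versions drop the constant `(k1 + 1)` factor, which rescales weights without changing rankings. So this sketch is not expected to reproduce `compute_bm25_term_weight` digit for digit. The function name and the corpus statistics in the example are invented.

```python
import math

def bm25_term_weight(tf, df, num_docs, doc_len, avg_doc_len, k1=0.9, b=0.4):
    """Textbook BM25 weight of one term in one document.

    tf: term frequency in the document; df: number of documents containing the term;
    num_docs: collection size; doc_len / avg_doc_len: document length normalization.
    """
    idf = math.log(1 + (num_docs - df + 0.5) / (df + 0.5))  # Lucene-style IDF, always > 0
    length_norm = k1 * (1 - b + b * doc_len / avg_doc_len)
    return idf * tf * (k1 + 1) / (tf + length_norm)

# Invented statistics: a rare term vs. a common term in a 3,633-document corpus.
rare = bm25_term_weight(tf=3, df=12, num_docs=3633, doc_len=250, avg_doc_len=230)
common = bm25_term_weight(tf=3, df=2000, num_docs=3633, doc_len=250, avg_doc_len=230)
print(f'rare={rare:.4f} common={common:.4f}')
```

Two properties are worth noticing. Rarer terms (smaller `df`) get larger weights. And the weight grows with `tf` but saturates, never exceeding `idf * (k1 + 1)`.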
We're going to now perform retrieval "by hand" with BM25, similar to what we did above with the dense retrieval model. Let's start by encoding the query, which is a multi-hot vector where the non-zero items correspond to the query terms:
```python
from pyserini.analysis import Analyzer, get_lucene_analyzer

analyzer = Analyzer(get_lucene_analyzer())
query_tokens = analyzer.analyze('How to Help Prevent Abdominal Aortic Aneurysms')
multihot_query_weights = {k: 1 for k in query_tokens}
```
The variable `multihot_query_weights` is a Python dictionary where the keys correspond to the query tokens, each with a value of one.
Now let's compute the dot product of the two vectors.
```python
sum({term: bm25_weights[term]
     for term in bm25_weights.keys() & multihot_query_weights.keys()}.values())
```
The dot product is 11.9305 (after rounding), the same score that `LuceneSearcher` returned for `MED-4555` above.
Again, this isn't anything new. We did all of this in the conceptual framework guide with MS MARCO passage; we're just now doing it on NFCorpus.
Let's wrap the above expression for computing a dot product in a Python function, and then verify that it gives the same output:
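```python
def dot(q_weights, d_weights):
    # Sparse dot product over the terms the two vectors share.
    # With a multi-hot query (all weights 1), this is the sum of the shared BM25 weights.
    return sum(q_weights[term] * d_weights[term]
               for term in q_weights.keys() & d_weights.keys())

dot(multihot_query_weights, bm25_weights)
```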
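The query vector built with `{k: 1 for k in query_tokens}` is multi-hot: it records only whether a term is present. A term that appears twice in the query therefore still gets weight 1. The pure-Python sketch below contrasts this with weighting query terms by their counts. `toy_analyze` is a hypothetical stand-in for the Lucene analyzer that only lowercases and splits on whitespace; the real analyzer also stems and removes stopwords. The document weights are invented.

```python
from collections import Counter

def toy_analyze(text):
    """Hypothetical analyzer: lowercase + whitespace split (no stemming, no stopwords)."""
    return text.lower().split()

def sparse_inner(q, d):
    """Dot product of two sparse vectors stored as {term: weight} dicts."""
    return sum(w * d[t] for t, w in q.items() if t in d)

tokens = toy_analyze('aneurysm risk and aneurysm screening')
multihot = {t: 1 for t in tokens}   # presence only, as in the guide
tf_weights = dict(Counter(tokens))  # counts repeated terms

toy_doc = {'aneurysm': 2.0, 'screening': 1.0}
print(multihot)                         # {'aneurysm': 1, 'risk': 1, 'and': 1, 'screening': 1}
print(tf_weights)                       # {'aneurysm': 2, 'risk': 1, 'and': 1, 'screening': 1}
print(sparse_inner(multihot, toy_doc))  # 3.0
print(sparse_inner(tf_weights, toy_doc))  # 5.0
```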
With this setup, we can now perform end-to-end retrieval for a query "by hand", by directly manipulating the index structures:
```python
from pyserini.search.lucene import LuceneSearcher
from pyserini.index.lucene import LuceneIndexReader
from tqdm import tqdm

searcher = LuceneSearcher('indexes/lucene.nfcorpus')
index_reader = LuceneIndexReader('indexes/lucene.nfcorpus')

scores = []
# Iterate through all docids in the index.
for i in tqdm(range(0, searcher.num_docs)):
    docid = searcher.doc(i).get('id')
    # Reconstruct the BM25 document vector.
    tf = index_reader.get_document_vector(docid)
    bm25_weights = {
        term: index_reader.compute_bm25_term_weight(docid, term, analyzer=None)
        for term in tf.keys()
    }
    # Compute and retain the query-document score.
    score = dot(multihot_query_weights, bm25_weights)
    scores.append([docid, score])

# Sort by score descending.
scores.sort(key=lambda x: -x[1])

for s in scores[:10]:
    print(f'{s[0]} {s[1]:.4f}')
```
The code snippet above should be self-explanatory. We iterate through all documents, reconstruct the BM25 document vectors (as weights in a Python dictionary), compute the dot product with the query vector, and retain the scores. Once we've gone through all documents in the corpus in this manner, we sort the scores and print out the top-k.
The output should match the results from `LuceneSearcher` above.
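The brute-force loop above scores every document, including documents that share no terms with the query (their dot product is simply zero). An inverted index avoids that work. For each term, it lists the documents containing that term, so only the postings for the query terms need to be visited. The pure-Python sketch below accumulates scores term-at-a-time and checks that the result matches brute-force scoring. The weights are invented, not real BM25 values. Lucene's actual query evaluation is considerably more sophisticated; this only illustrates the idea.

```python
from collections import defaultdict
import heapq

# Toy "document vectors": docid -> {term: weight}. Weights are invented.
toy_doc_vectors = {
    'D1': {'aortic': 2.1, 'aneurysm': 3.0, 'risk': 0.7},
    'D2': {'diet': 1.5, 'fruit': 1.2},
    'D3': {'aneurysm': 2.4, 'smoke': 1.1},
    'D4': {'prevent': 1.8, 'risk': 0.9},
}

# Invert: term -> list of (docid, weight) postings.
postings = defaultdict(list)
for docid, vec in toy_doc_vectors.items():
    for term, weight in vec.items():
        postings[term].append((docid, weight))

def search_inverted(query_weights, k=10):
    """Term-at-a-time scoring: only documents in the query terms' postings are touched."""
    acc = defaultdict(float)
    for term, q_weight in query_weights.items():
        for docid, d_weight in postings.get(term, []):
            acc[docid] += q_weight * d_weight
    return heapq.nlargest(k, acc.items(), key=lambda item: item[1])

def search_brute_force(query_weights, k=10):
    """Score every document by sparse dot product, as in the loop above."""
    scored = []
    for docid, vec in toy_doc_vectors.items():
        score = sum(w * vec[t] for t, w in query_weights.items() if t in vec)
        scored.append((docid, score))
    # Zero-score documents never appear in postings, so drop them for a fair comparison.
    return heapq.nlargest(k, [s for s in scored if s[1] > 0], key=lambda item: item[1])

toy_sparse_query = {'prevent': 1, 'aortic': 1, 'aneurysm': 1}  # multi-hot query
print(search_inverted(toy_sparse_query))
print(search_brute_force(toy_sparse_query))
```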
To recap, what's the point of this exercise?
- We see that dense retrieval and sparse retrieval are both instantiations of a bi-encoder architecture. The only difference lies in the representations that the encoders produce.
- For both a dense index (Faiss) and a sparse index (Lucene), you now know how to reconstruct the document vector representations.
- For both a dense retrieval model and a sparse retrieval model, you now know how to encode a query into a query vector.
- For both a dense retrieval model and a sparse retrieval model, you know how to compute query-document scores: they're just dot products.
- Finally, for both a dense retrieval model and a sparse retrieval model, you can perform retrieval "by hand". This can be accomplished by iterating through all document vectors in the index and computing each one's dot product with the query vector in a brute-force manner. By sorting the scores, you're performing top-k retrieval, which gives exactly the same output as `FaissSearcher` and `LuceneSearcher` (although less efficiently).
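The recap can be condensed into a single function. In the pure-Python sketch below, retrieval is written once and parameterized by a query encoder and a scoring function. It is then instantiated with a dense representation (a list of floats) and a sparse one (a dict of term weights). The two toy encoders and all document representations are hypothetical stand-ins, not Pyserini classes. Only the representations change between the two runs; the encode, score, and top-k steps are shared.

```python
import heapq

def retrieve(query, doc_reprs, encode_query, score, k=10):
    """Generic bi-encoder retrieval: encode the query, score every document, keep the top k."""
    q = encode_query(query)
    scored = ((docid, score(q, d)) for docid, d in doc_reprs.items())
    return heapq.nlargest(k, scored, key=lambda item: item[1])

def dense_score(q, d):
    return sum(a * b for a, b in zip(q, d))

def sparse_score(q, d):
    return sum(w * d[t] for t, w in q.items() if t in d)

# Hypothetical toy encoders (stand-ins for BGE-base and for the multi-hot BM25 query encoder).
def toy_dense_encode(text):
    words = text.lower().split()
    return [1.0 if 'aneurysm' in words else 0.0, 1.0 if 'diet' in words else 0.0]

def toy_multihot_encode(text):
    return {t: 1 for t in text.lower().split()}

# Precomputed toy document representations ("the index").
dense_docs = {'D1': [0.9, 0.1], 'D2': [0.2, 0.8], 'D3': [0.6, 0.3]}
sparse_docs = {'D1': {'aneurysm': 3.0, 'aortic': 2.1}, 'D2': {'diet': 1.5}, 'D3': {'aneurysm': 2.4}}

query = 'Prevent aortic aneurysm'
print(retrieve(query, dense_docs, toy_dense_encode, dense_score))
print(retrieve(query, sparse_docs, toy_multihot_encode, sparse_score))
```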
Okay, that's it for this lesson.
Before you move on, however, add an entry in the "Reproduction Log" at the bottom of this page, following the same format: use `yyyy-mm-dd`, make sure you're using a commit id that's on the main trunk of Pyserini, and use its 7-character hexadecimal prefix for the link anchor text.
Reproduction Log

- Results reproduced by @sahel-sh on 2023-08-07 (commit 9dab30f)
- Results reproduced by @Andrwyl on 2023-08-26 (commit d9da49e)
- Results reproduced by @yilinjz on 2023-08-30 (commit 42b3549)
- Results reproduced by @UShivani3 on 2023-09-02 (commit 42b3549)
- Results reproduced by @Edward-J-Xu on 2023-09-05 (commit 8063322)
- Results reproduced by @mchlp on 2023-09-09 (commit d8dc5b3)
- Results reproduced by @lucedes27 on 2023-09-10 (commit 54014af)
- Results reproduced by @MojTabaa4 on 2023-09-14 (commit d4a829d)
- Results reproduced by @Kshama on 2023-09-24 (commit 7d18f4b)
- Results reproduced by @MelvinMo on 2023-09-24 (commit 7d18f4b)
- Results reproduced by @ksunisth on 2023-09-28 (commit 142c774)
- Results reproduced by @maizerrr on 2023-10-01 (commit bdb9504)
- Results reproduced by @Mofetoluwa on 2023-10-02 (commit 88f1f5b)
- Results reproduced by @Stefan824 on 2023-10-04 (commit 4f3da10)
- Results reproduced by @shayanbali on 2023-10-16 (commit f1d623c)
- Results reproduced by @gituserbs on 2023-10-19 (commit e0a0d35)
- Results reproduced by @shakibaam on 2023-11-04 (commit 01889cc)
- Results reproduced by @gitHubAndyLee2020 on 2023-11-05 (commit 01889cc)
- Results reproduced by @Melissa1412 on 2023-11-05 (commit acd969f)
- Results reproduced by @oscarbelda86 on 2023-11-13 (commit 086e16b)
- Results reproduced by @salinaria on 2023-11-14 (commit 086e16b)
- Results reproduced by @aliranjbari on 2023-11-15 (commit b931e52)
- Results reproduced by @Seun-Ajayi on 2023-11-21 (commit 5d63bc5)
- Results reproduced by @AndreSlavescu on 2023-11-28 (commit 1219cdb)
- Results reproduced by @tudou0002 on 2023-11-28 (commit 723e06c)
- Results reproduced by @alimt1992 on 2023-11-29 (commit e6700f6)
- Results reproduced by @golnooshasefi on 2023-11-29 (commit 1219cdb)
- Results reproduced by @sueszli on 2023-12-01 (commit 170e271)
- Results reproduced by @kdricci on 2023-12-01 (commit a2049c4)
- Results reproduced by @ljk423 on 2023-12-04 (commit 35002ad)
- Results reproduced by @saharsamr on 2023-12-14 (commit 039c137)
- Results reproduced by @Panizghi on 2023-12-17 (commit 0f5db95)
- Results reproduced by @AreelKhan on 2023-12-22 (commit f75adca)
- Results reproduced by @wu-ming233 on 2023-12-31 (commit 38a571f)
- Results reproduced by @Yuan-Hou on 2024-01-02 (commit 38a571f)
- Results reproduced by @himasheth on 2024-01-10 (commit a6ed27e)
- Results reproduced by @Tanngent on 2024-01-13 (commit 57a00cf)
- Results reproduced by @BeginningGradeMaker on 2024-01-15 (commit d4ea011)
- Results reproduced by @ia03 on 2024-01-18 (commit 05ee8ef)
- Results reproduced by @AlexStan0 on 2024-01-20 (commit 833ee19)
- Results reproduced by @charlie-liuu on 2024-01-23 (commit 87a120e)
- Results reproduced by @dannychn11 on 2024-01-28 (commit 2f7702f)
- Results reproduced by @ru5h16h on 2024-02-20 (commit 758eaaa)
- Results reproduced by @ASChampOmega on 2024-02-23 (commit 442e7e1)
- Results reproduced by @16BitNarwhal on 2024-02-26 (commit 19fcd3b)
- Results reproduced by @HaeriAmin on 2024-02-27 (commit 19fcd3b)
- Results reproduced by @17Melissa on 2024-03-03 (commit a9f295f)
- Results reproduced by @devesh-002 on 2024-03-05 (commit 84c6742)
- Results reproduced by @chloeqxq on 2024-03-07 (commit 19fcd3b)
- Results reproduced by @xpbowler on 2024-03-11 (commit 19fcd3b)
- Results reproduced by @jodyz0203 on 2024-03-12 (commit 280e009)
- Results reproduced by @kxwtan on 2024-03-12 (commit 2bb342a)
- Results reproduced by @syedhuq28 on 2024-03-28 (commit 2bb342a)
- Results reproduced by @khufia on 2024-03-29 (commit 2bb342a)
- Results reproduced by @Lindaaa8 on 2024-04-02 (commit 7dda9f3)
- Results reproduced by @th13nd4n0 on 2024-04-05 (commit df3bc6c)
- Results reproduced by @a68lin on 2024-04-12 (commit 7dda9f3)
- Results reproduced by @DanielKohn1208 on 2024-04-22 (commit 184a212)
- Results reproduced by @emadahmed19 on 2024-04-28 (commit 9db2584)
- Results reproduced by @CheranMahalingam on 2024-05-05 (commit f817186)
- Results reproduced by @billycz8 on 2024-05-08 (commit c945c50)
- Results reproduced by @KenWuqianhao on 2024-05-11 (commit c945c50)
- Results reproduced by @hrouzegar on 2024-05-13 (commit bf68fc5)
- Results reproduced by @Yuv-sue1005 on 2024-05-15 (commit 9df4015)
- Results reproduced by @RohanNankani on 2024-05-17 (commit a91ef1d)
- Results reproduced by @IR3KT4FUNZ on 2024-05-26 (commit a6f4d6)
- Results reproduced by @bilet-13 on 2024-06-01 (commit b0c53f3)
- Results reproduced by @SeanSong25 on 2024-06-05 (commit b7e1da3)
- Results reproduced by @alireza-taban on 2024-06-11 (commit d814290)
- Results reproduced by @hosnahoseini on 2024-06-18 (commit 49d8c43)
- Results reproduced by @FaizanFaisal25 on 2024-07-07 (commit 3b9d541)
- Results reproduced by @Feng-12138 on 2024-07-11 (commit 3b9d541)
- Results reproduced by @XKTZ on 2024-07-13 (commit 544046e)
- Results reproduced by @MehrnazSadeghieh on 2024-07-19 (commit 26a2538)
- Results reproduced by @alireza-nasirian on 2024-07-19 (commit 544046e)
- Results reproduced by @MariaPonomarenko38 on 2024-07-19 (commit d4509dc)
- Results reproduced by @valamuri2020 on 2024-08-02 (commit 3f81997)
- Results reproduced by @daisyyedda on 2024-08-06 (commit d814290)
- Results reproduced by @emily-emily on 2024-08-16 (commit 1bbf7a7)
- Results reproduced by @nicoella on 2024-08-20 (commit e65dd95)
- Results reproduced by @natek-1 on 2024-08-19 (commit e65dd95)
- Results reproduced by @setarehbabajani on 2024-09-01 (commit 0dd5fa7)
- Results reproduced by @anshulsc on 2024-09-07 (commit 2e4fa5d)
- Results reproduced by @r-aya on 2024-09-08 (commit 2e4fa5d)
- Results reproduced by @Amirkia1998 on 2024-09-20 (commit 83537a3)
- Results reproduced by @pjyi2147 on 2024-09-20 (commit f511655)
- Results reproduced by @krishh-p on 2024-09-21 (commit f511655)
- Results reproduced by @andrewxucs on 2024-09-22 (commit dd57b7d)
- Results reproduced by @Hossein-Molaeian on 2024-09-22 (commit bc13901)
- Results reproduced by @AhmedEssam19 on 2024-09-30 (commit 07f04d4)
- Results reproduced by @sisixili on 2024-10-01 (commit 07f04d4)
- Results reproduced by @alirezaJvh on 2024-10-05 (commit 3f76099)
- Results reproduced by @Raghav0005 on 2024-10-09 (commit 7ed8369)
- Results reproduced by @Pxlin-09 on 2024-10-26 (commit af2d3c5)
- Results reproduced by @Samantha-Zhan on 2024-11-17 (commit a95b0e0)
- Results reproduced by @Divyajyoti02 on 2024-11-24 (commit f6f8ecc)
- Results reproduced by @b8zhong on 2024-11-24 (commit 778968f)
- Results reproduced by @vincent-4 on 2024-11-28 (commit 576fdaf)
- Results reproduced by @ShreyasP20 on 2024-11-28 (commit 576fdaf)
- Results reproduced by @nihalmenon on 2024-12-01 (commit 94492de)
- Results reproduced by @zdann15 on 2024-12-04 (commit 5e66e98)
- Results reproduced by @sherloc512 on 2024-12-05 (commit 5e66e98)
- Results reproduced by @Alireza-Zwolf on 2024-12-18 (commit 6cc23d5)
- Results reproduced by @Linsen-gao-457 on 2024-12-20 (commit 10606f0)
- Results reproduced by @robro612 on 2025-01-05 (commit 9268591)
- Results reproduced by @nourj98 on 2025-01-07 (commit 6ac07cc)
- Results reproduced by @mithildamani256 on 2025-01-13 (commit ad41512)
- Results reproduced by @ezafar on 2025-01-15 (commit e1a3386)
- Results reproduced by @ErfanSadraiye on 2025-01-16 (commit cb14c93)