forked from stanford-futuredata/ColBERT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathindexer.py
94 lines (67 loc) · 3.35 KB
/
indexer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import time
import torch.multiprocessing as mp
from colbert.infra.run import Run
from colbert.infra.config import ColBERTConfig, RunConfig
from colbert.infra.launcher import Launcher
from colbert.utils.utils import create_directory, print_message
from colbert.indexing.collection_indexer import encode
class Indexer:
    """Drive the encoding of a collection into an on-disk index.

    The indexer merges the configuration stored with the checkpoint, the
    optional user-supplied ``config`` and the active ``Run().config`` into a
    single ``ColBERTConfig``, and then hands the work to ``Launcher(encode)``.
    """

    def __init__(self, checkpoint, config=None, verbose: int = 3):
        """
        Use Run().context() to choose the run's configuration. They are NOT extracted from `config`.

        Args:
            checkpoint: Checkpoint identifier passed to
                ``ColBERTConfig.load_from_checkpoint`` and recorded in the
                merged configuration.
            config: Optional configuration merged on top of the checkpoint's
                configuration via ``ColBERTConfig.from_existing``.
            verbose: Verbosity level forwarded to the launcher.
        """
        # Populated by index(); erase() and get_index() depend on it.
        self.index_path = None
        self.verbose = verbose

        self.checkpoint = checkpoint
        self.checkpoint_config = ColBERTConfig.load_from_checkpoint(checkpoint)

        self.config = ColBERTConfig.from_existing(self.checkpoint_config, config, Run().config)
        self.configure(checkpoint=checkpoint)

    def configure(self, **kw_args):
        """Forward keyword overrides to the underlying ``ColBERTConfig``."""
        self.config.configure(**kw_args)

    def get_index(self):
        """Return the index path set by the last ``index()`` call, or None."""
        return self.index_path

    def erase(self, force_silent: bool = False):
        """Delete the index artefacts found directly inside ``self.index_path``.

        A file is removed when its name ends with ``.pt``, or when it ends
        with ``.json`` and its name contains ``metadata``, ``doclen`` or
        ``plan``. Matching is done on the file name only, so a parent
        directory that happens to contain one of those words does not cause
        unrelated ``.json`` files to be deleted.

        Args:
            force_silent: Accepted for interface compatibility; it is not
                consulted by this implementation.

        Returns:
            The full paths of the removed files, in sorted file-name order.

        Raises:
            AssertionError: If no index path has been set yet.
        """
        assert self.index_path is not None

        directory = self.index_path
        json_markers = ("metadata", "doclen", "plan")

        deleted = []
        for name in sorted(os.listdir(directory)):
            # Test the bare file name, not the joined path: the directory
            # part must never influence what gets deleted.
            is_index_json = name.endswith(".json") and any(marker in name for marker in json_markers)
            if is_index_json or name.endswith(".pt"):
                deleted.append(os.path.join(directory, name))

        for path in deleted:
            os.remove(path)

        return deleted

    def index(self, name, collection, overwrite=False):
        """Index ``collection`` under the index name ``name``.

        Args:
            name: Index name, stored in the configuration as ``index_name``.
            collection: The collection to encode; stored in the configuration
                and passed to the launcher.
            overwrite: What to do when the index directory already exists:
                ``False`` fails the existence assertion; ``True`` and
                ``"force_silent_overwrite"`` erase the existing artefacts and
                re-index; ``"resume"`` sets ``resume=True`` in the
                configuration and launches without erasing; ``"reuse"``
                launches only if the index directory did not exist.

        Returns:
            The index path.

        Raises:
            AssertionError: If ``overwrite`` is not one of the accepted
                values, or if it is ``False`` and the index path exists.
        """
        assert overwrite in [True, False, 'reuse', 'resume', "force_silent_overwrite"]

        self.configure(collection=collection, index_name=name, resume=overwrite == 'resume')
        # Note: The bsize value set here is ignored internally. Users are encouraged
        # to supply their own batch size for indexing by using the index_bsize parameter in the ColBERTConfig.
        self.configure(bsize=64, partitions=None)

        self.index_path = self.config.index_path_
        # Must be evaluated before create_directory() below.
        index_does_not_exist = (not os.path.exists(self.config.index_path_))

        assert (overwrite in [True, 'reuse', 'resume', "force_silent_overwrite"]) or index_does_not_exist, self.config.index_path_
        create_directory(self.config.index_path_)

        if overwrite == 'force_silent_overwrite':
            self.erase(force_silent=True)
        elif overwrite is True:
            self.erase()

        if index_does_not_exist or overwrite != 'reuse':
            self.__launch(collection)

        return self.index_path

    def __launch(self, collection):
        """Run ``encode`` over ``collection`` through a ``Launcher``.

        With a single rank and ``avoid_fork_if_possible`` set, the launcher is
        invoked without forking and with empty shared containers. Otherwise a
        multiprocessing manager provides one shared list and one size-1 queue
        per rank.
        """
        launcher = Launcher(encode)

        if self.config.nranks == 1 and self.config.avoid_fork_if_possible:
            shared_queues = []
            shared_lists = []
            launcher.launch_without_fork(self.config, collection, shared_lists, shared_queues, self.verbose)
            return

        manager = mp.Manager()
        shared_lists = [manager.list() for _ in range(self.config.nranks)]
        shared_queues = [manager.Queue(maxsize=1) for _ in range(self.config.nranks)]

        # Encodes collection into index using the CollectionIndexer class
        launcher.launch(self.config, collection, shared_lists, shared_queues, self.verbose)