This repository was archived by the owner on Feb 21, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathInitDatabase.py
272 lines (233 loc) · 9 KB
/
InitDatabase.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
import json
import os
import shutil
import sys
from datetime import datetime
import yaml
from langchain.retrievers import ParentDocumentRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_milvus import Milvus
from loguru import logger
from tqdm import tqdm
from llm.EmbeddingCore import BgeM3Embeddings
from utils.Decorator import timer
from utils.entities.UserProfile import User, UserGroup
# Logging setup: remove the handlers loguru already has registered, then log
# to stderr at INFO level and additionally to a file under log/.
# The order matters: remove() must run before the new sinks are added.
logger.remove()
# handler_id keeps the id returned for the stderr sink.
handler_id = logger.add(sys.stderr, level="INFO")
# The file sink is added with loguru's default settings (no explicit level).
logger.add('log/init_database.log')
def init_retriever() -> ParentDocumentRetriever:
    """
    Build a ParentDocumentRetriever for the currently selected collection.

    Relies on the module-level ``config`` and ``args`` objects, which are only
    created in the ``__main__`` section of this script, so it must be called
    after they exist.

    The Milvus collection is always re-created (``drop_old=True``); the SQLite
    document store is only asked to drop old data when ``--drop_old`` was given.

    :return: retriever backed by the Milvus vector store and the SQLite doc store
    :raises ValueError: if the collection language is neither 'en' nor 'zh'
    """
    logger.info('start building vector database...')
    milvus_cfg = config.milvus_config
    collection = milvus_cfg.get_collection()
    collection_name = collection.collection_name

    # Resolve the child-splitter separators first: an unsupported language must
    # be rejected BEFORE the existing Milvus collection / doc store is dropped
    # below, otherwise a config typo destroys data and then fails.
    language = collection.language
    separators_by_language = {
        'en': ['.', '\n\n', '\n'],
        'zh': ['。', '?', '\n\n', '\n'],
    }
    if language not in separators_by_language:
        raise ValueError(f'error language {language}')

    embed_cfg = config.embedding_config
    embedding = BgeM3Embeddings(
        model_name=embed_cfg.model,
        model_kwargs={
            'device': 'cuda',
            'normalize_embeddings': embed_cfg.normalize,
            'use_fp16': embed_cfg.fp16
        },
        local_load=embed_cfg.save_local,
        local_path=embed_cfg.local_path
    )
    logger.info(f'load collection [{collection_name}], using model {embed_cfg.model}')

    # Only pass drop_old when requested so the store's own default applies
    # otherwise.
    doc_store_kwargs = {'connection_string': config.get_sqlite_path(collection_name)}
    if args.drop_old:
        doc_store_kwargs['drop_old'] = True
    doc_store = SqliteDocStore(**doc_store_kwargs)

    vector_db = Milvus(
        embedding,
        collection_name=collection_name,
        connection_args=milvus_cfg.get_conn_args(),
        index_params=collection.index_param,
        drop_old=True,
        auto_id=True,
        enable_dynamic_field=True,
    )

    # Insert a placeholder document and delete it again right away.
    # NOTE(review): presumably this forces Milvus to create the collection
    # with these metadata fields before real data arrives - confirm.
    init_doc = Document(
        page_content=f'This is a collection about {collection_name}',
        metadata={
            'title': 'About this collection',
            'section': 'Abstract',
            'author': 'administrator',
            'year': datetime.now().year,
            'type': -1,
            'keywords': 'collection',
            'doi': ''
        }
    )
    init_ids = vector_db.add_documents([init_doc])
    vector_db.delete(init_ids)
    logger.info('done')

    # Parent chunks: up to 450 characters, split on paragraph / line breaks.
    parent_splitter = RecursiveCharacterTextSplitter(
        chunk_size=450,
        chunk_overlap=0,
        separators=['\n\n', '\n'],
        keep_separator=False
    )
    # Child chunks: up to 100 characters, split on the language's sentence
    # terminators first, then on paragraph / line breaks.
    child_splitter = RecursiveCharacterTextSplitter(
        chunk_size=100,
        chunk_overlap=0,
        separators=separators_by_language[language],
        keep_separator=False
    )

    return ParentDocumentRetriever(
        vectorstore=vector_db,
        docstore=doc_store,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
    )
@timer
def load_md(base_path: str) -> None:
    """
    Load every file found under ``base_path`` into the retriever.

    A fresh retriever is built via ``init_retriever()``; each file is parsed
    with ``load_from_md`` and the resulting documents and reference data are
    stored. A failure while storing one file is logged and the loop continues
    with the next file.

    :param base_path: root directory to walk. The name of the directory that
        directly contains a file is used as its "year" label in the progress
        bar and in error messages.
    :return: None
    """
    retriever = init_retriever()
    logger.info('start loading file...')

    for root, _dirs, files in os.walk(base_path):
        # Skip directories that contain no files.
        if not files:
            continue
        # The containing directory name serves as the year label.
        year = os.path.basename(root)
        for _file in tqdm(files, total=len(files), desc=f'load file in ({year})'):
            # Build the path from the directory actually being walked, so the
            # base_path argument is honoured and files at any depth resolve
            # to their real location.
            file_path = os.path.join(root, _file)
            # Parse the file into documents plus reference data.
            md_docs, reference_data = load_from_md(file_path)
            # Storing is best-effort per file: log and move on when it fails.
            try:
                retriever.add_documents(md_docs)
                with ReferenceStore(config.get_reference_path()) as ref_store:
                    ref_store.add_reference(reference_data)
            except Exception as e:
                logger.error(f'loading <{_file}> ({year}) fail')
                logger.error(e)
    logger.info('done')
def create_userdb() -> None:
    """
    Create the user database and register the initial administrator account.

    Reads the database location from ``config.get_user_db()`` and the admin
    credentials from ``config.yml['user_login_config']['admin_user']``.
    Relies on the module-level ``config`` created in the ``__main__`` section.

    :return: None
    """
    connect_str = config.get_user_db()
    # os.path.dirname returns '' for a bare file name and os.makedirs('')
    # raises FileNotFoundError, so only create the directory when there is one.
    db_dir = os.path.dirname(connect_str)
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)

    with ProfileStore(connection_string=connect_str) as profile_store:
        profile_store.init_tables()
        admin_cfg = config.yml['user_login_config']['admin_user']
        admin = User(
            name=admin_cfg['username'],
            password=admin_cfg['password'],
            user_group=UserGroup.ADMIN.value,
        )
        profile_store.create_user(admin)
if __name__ == '__main__':
    # Command-line entry point. Steps, in order:
    #   1. parse arguments (stored in the module-level ``args``),
    #   2. with --auto_create, generate collections.json from the data directory,
    #   3. load the project configuration (module-level ``config``),
    #   4. with --drop_old, drop the reference database,
    #   5. with --collection, initialise one or all collections,
    #   6. with --user, create the user database and admin account.
    # ``args`` and ``config`` must stay module-level names: the functions
    # defined in this file read them as globals.
    import argparse

    parser = argparse.ArgumentParser(description='init vector database')
    parser.add_argument(
        '--collection',
        '-C',
        nargs='?',
        const=-1,
        type=int,
        help='Initialize a specific collection, starting from 0.'
    )
    parser.add_argument(
        '--auto_create',
        '-A',
        action='store_true',
        help='Automatic database initialization based on directory structure'
    )
    parser.add_argument(
        '--force',
        '-F',
        action='store_true',
        help='Force override of existing configurations'
    )
    parser.add_argument(
        '--drop_old',
        '-D',
        action='store_true',
        help='Whether to delete the original reference database'
    )
    parser.add_argument(
        '--user',
        '-U',
        action='store_true',
        help='Initialize user-related databases, '
             'including creating SQLite database files that hold user information and initializing administrator user accounts.'
    )
    args = parser.parse_args()

    if args.auto_create:
        yml_path = 'config.yml'
        # Bootstrap config.yml from the example file when it is missing.
        if not os.path.exists(yml_path):
            logger.info('config dose not exits')
            shutil.copy('config.example.yml', yml_path)

        # NOTE(review): FullLoader can construct more than plain data types.
        # That is acceptable only because config.yml is a local, trusted file;
        # consider yaml.safe_load if the config never needs YAML tags.
        with open(file=yml_path, mode='r', encoding='utf-8') as yml_file:
            yml = yaml.load(yml_file, Loader=yaml.FullLoader)

        DATA_ROOT = yml['paper_directory']['data_root']
        cfg_path = os.path.join(DATA_ROOT, 'collections.json')
        if not args.force and os.path.exists(cfg_path):
            logger.info('config file exists, use existing config file')
        else:
            # One collection entry per sub-directory of the data root.
            cols = [
                {
                    "collection_name": collection,
                    "language": 'en',
                    "title": collection,
                    "description": f'This is a collection about {collection}',
                    "index_param": {
                        "metric_type": 'L2',
                        "index_type": 'HNSW',
                        "params": {"M": 8, "efConstruction": 64},
                    },
                    "visitor_visible": True,
                }
                for collection in os.listdir(DATA_ROOT)
                if os.path.isdir(os.path.join(DATA_ROOT, collection))
            ]
            # Use a context manager so the file is flushed and closed before
            # the configuration is loaded below.
            with open(cfg_path, 'w', encoding='utf-8') as cfg_file:
                json.dump({"collections": cols}, cfg_file)
            logger.info(f'auto create config file {cfg_path}')

    # These imports are deliberately placed after the auto-create step above;
    # keep them here so the configuration files exist before Config() runs.
    from Config import Config

    config = Config()

    from storage.SqliteStore import SqliteDocStore, ReferenceStore, ProfileStore
    from utils.MarkdownPraser import load_from_md

    if args.drop_old:
        with ReferenceStore(config.get_reference_path()) as _store:
            _store.drop_old()
        logger.info('drop old database.')

    if args.collection is not None:
        collection_count = len(config.milvus_config.collections)
        if args.collection == -1:
            # "-C" without a value (const=-1): initialise every collection.
            for i in range(collection_count):
                logger.info(f'Start init collection {i}')
                config.set_collection(i)
                load_md(config.get_md_path(config.milvus_config.get_collection().collection_name))
        elif args.collection >= collection_count or args.collection < -1:
            logger.error(f'collection index {args.collection} out of range')
            # sys.exit instead of the site-provided exit(), which is not
            # guaranteed to exist (e.g. python -S or frozen builds).
            sys.exit(1)
        else:
            config.set_collection(args.collection)
            logger.info(f'Only init collection {args.collection}')
            load_md(config.get_md_path(config.milvus_config.get_collection().collection_name))

    if args.user:
        logger.info('Create admin profile...')
        create_userdb()