Skip to content

Commit

Permalink
Debug the backend API for the 'search' tab (infiniflow#2389)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?
Fixes issue infiniflow#2247.

### Type of change

- [x] New Feature (non-breaking change that adds functionality)
  • Loading branch information
KevinHuSh authored Sep 12, 2024
1 parent acb7d25 commit d152c37
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 20 deletions.
2 changes: 1 addition & 1 deletion api/apps/chunk_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def retrieval_test():
kb_id = req["kb_id"]
if isinstance(kb_id, str): kb_id = [kb_id]
doc_ids = req.get("doc_ids", [])
similarity_threshold = float(req.get("similarity_threshold", 0.2))
similarity_threshold = float(req.get("similarity_threshold", 0.0))
vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
top = int(req.get("top_k", 1024))

Expand Down
4 changes: 3 additions & 1 deletion api/apps/conversation_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
#
import json
import re
import traceback
from copy import deepcopy

from api.db.services.user_service import UserTenantService
from flask import request, Response
from flask_login import login_required, current_user
Expand Down Expand Up @@ -333,6 +333,8 @@ def mindmap():
0.3, 0.3, aggs=False)
mindmap = MindMapExtractor(chat_mdl)
mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
if "error" in mind_map:
return server_error_response(Exception(mind_map["error"]))
return get_json_result(data=mind_map)


Expand Down
3 changes: 1 addition & 2 deletions api/db/services/dialog_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def decorate_answer(answer):
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
answer = ans
delta_ans = ans[len(last_ans):]
if num_tokens_from_string(delta_ans) < 12:
if num_tokens_from_string(delta_ans) < 16:
continue
last_ans = answer
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
Expand Down Expand Up @@ -404,7 +404,6 @@ def rewrite(tenant_id, llm_id, question):


def tts(tts_mdl, text):
return
if not tts_mdl or not text: return
bin = b""
for chunk in tts_mdl.tts(text):
Expand Down
2 changes: 1 addition & 1 deletion graphrag/mind_map_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __call__(
res.append(_.result())

if not res:
return MindMapResult(output={"root":{}})
return MindMapResult(output={"id": "root", "children": []})

merge_json = reduce(self._merge, res)
if len(merge_json.keys()) > 1:
Expand Down
2 changes: 1 addition & 1 deletion rag/llm/embedding_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#
import re
from typing import Optional
import threading
import threading
import requests
from huggingface_hub import snapshot_download
from openai.lib.azure import AzureOpenAI
Expand Down
37 changes: 23 additions & 14 deletions rag/nlp/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ def trans2floats(txt):
def insert_citations(self, answer, chunks, chunk_v,
embd_mdl, tkweight=0.1, vtweight=0.9):
assert len(chunks) == len(chunk_v)
if not chunks:
return answer, set([])
pieces = re.split(r"(```)", answer)
if len(pieces) >= 3:
i = 0
Expand Down Expand Up @@ -263,7 +265,7 @@ def insert_citations(self, answer, chunks, chunk_v,

ans_v, _ = embd_mdl.encode(pieces_)
assert len(ans_v[0]) == len(chunk_v[0]), "The dimension of query and chunk do not match: {} vs. {}".format(
len(ans_v[0]), len(chunk_v[0]))
len(ans_v[0]), len(chunk_v[0]))

chunks_tks = [rag_tokenizer.tokenize(self.qryr.rmWWW(ck)).split(" ")
for ck in chunks]
Expand Down Expand Up @@ -360,29 +362,33 @@ def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, simi
ranks = {"total": 0, "chunks": [], "doc_aggs": {}}
if not question:
return ranks
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": page_size,
RERANK_PAGE_LIMIT = 3
req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "size": page_size*RERANK_PAGE_LIMIT,
"question": question, "vector": True, "topk": top,
"similarity": similarity_threshold,
"available_int": 1}
if page > RERANK_PAGE_LIMIT:
req["page"] = page
req["size"] = page_size
sres = self.search(req, index_name(tenant_id), embd_mdl, highlight)
ranks["total"] = sres.total

if rerank_mdl:
sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
if page <= RERANK_PAGE_LIMIT:
if rerank_mdl:
sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
else:
sim, tsim, vsim = self.rerank(
sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
idx = np.argsort(sim * -1)[(page-1)*page_size:page*page_size]
else:
sim, tsim, vsim = self.rerank(
sres, question, 1 - vector_similarity_weight, vector_similarity_weight)
idx = np.argsort(sim * -1)
sim = tsim = vsim = [1]*len(sres.ids)
idx = list(range(len(sres.ids)))

dim = len(sres.query_vector)
start_idx = (page - 1) * page_size
for i in idx:
if sim[i] < similarity_threshold:
break
ranks["total"] += 1
start_idx -= 1
if start_idx >= 0:
continue
if len(ranks["chunks"]) >= page_size:
if aggs:
continue
Expand All @@ -406,7 +412,10 @@ def retrieval(self, question, embd_mdl, tenant_id, kb_ids, page, page_size, simi
"positions": sres.field[id].get("position_int", "").split("\t")
}
if highlight:
d["highlight"] = rmSpace(sres.highlight[id])
if id in sres.highlight:
d["highlight"] = rmSpace(sres.highlight[id])
else:
d["highlight"] = d["content_with_weight"]
if len(d["positions"]) % 5 == 0:
poss = []
for i in range(0, len(d["positions"]), 5):
Expand Down

0 comments on commit d152c37

Please sign in to comment.