Skip to content

Commit

Permalink
upgrade milvus parallel embed usage (#408)
Browse files Browse the repository at this point in the history
  • Loading branch information
dorren002 authored Jan 10, 2025
1 parent a3d0685 commit 5690af5
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 12 deletions.
4 changes: 2 additions & 2 deletions docs/en/Best Practice/rag.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ docs = Document()

# (1)
docs.create_node_group(name='block',
transform=lambda d: '\n'.split(d))
transform=lambda d: d.split('\n'))

# (2)
docs.create_node_group(name='doc-summary',
transform=lambda d: summary_llm(d))

# (3)
docs.create_node_group(name='sentence',
transform=lambda b: ''.split(b),
transform=lambda b: b.split(''),
parent='block')

# (4)
Expand Down
4 changes: 2 additions & 2 deletions docs/zh/Best Practice/rag.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ docs = Document()

# (1)
docs.create_node_group(name='block',
transform=lambda d: '\n'.split(d))
transform=lambda d: d.split('\n'))

# (2)
docs.create_node_group(name='doc-summary',
transform=lambda d: summary_llm(d))

# (3)
docs.create_node_group(name='sentence',
transform=lambda b: ''.split(b),
transform=lambda b: b.split(''),
parent='block')

# (4)
Expand Down
2 changes: 1 addition & 1 deletion lazyllm/tools/rag/data_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ def load_data(self, input_files: Optional[List[str]] = None, metadates: Optional
nodes.append(doc)
if not nodes:
LOG.warning(
f"No nodes load from path {self.input_files}, please check your data path."
f"No nodes load from path {input_files}, please check your data path."
)
return nodes
14 changes: 9 additions & 5 deletions lazyllm/tools/rag/milvus_store.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
from collections import defaultdict
from typing import Dict, List, Optional, Union, Callable, Set
from lazyllm.thirdparty import pymilvus
from .doc_node import DocNode
Expand All @@ -12,6 +13,8 @@
from .data_type import DataType
from lazyllm.common import override, obj2str, str2obj

# Width of each slice of serialized node rows passed to a single
# `MilvusClient.upsert` call by `MilvusStore.update_nodes`.
MILVUS_UPSERT_BATCH_SIZE = 500

class MilvusStore(StoreBase):
# we define these variables as members so that pymilvus is not imported until MilvusStore is instantiated.
def _def_constants(self) -> None:
Expand Down Expand Up @@ -156,13 +159,14 @@ def __init__(self, group_embed_keys: Dict[str, Set[str]], embed: Dict[str, Calla

@override
def update_nodes(self, nodes: List[DocNode]) -> None:
    """Embed `nodes`, upsert them into Milvus in batches, and update the map store.

    Args:
        nodes: The nodes to write. Each node is upserted into the Milvus
            collection named after its `_group`.

    The steps are:
      1. Run `parallel_do_embedding` once over all nodes, passing
         `self._group_embed_keys` so the embed keys are looked up per node
         group.
      2. Serialize every node and bucket the rows by group name.
      3. Upsert each bucket in slices of at most `MILVUS_UPSERT_BATCH_SIZE`
         rows, so one request never carries more than that many rows.
      4. Forward the same nodes to `self._map_store`.
    """
    # The explicit key list is left empty on purpose: the per-group mapping
    # passed as the last argument decides which embeddings each node needs.
    parallel_do_embedding(self._embed, [], nodes, self._group_embed_keys)

    rows_by_group = defaultdict(list)
    for node in nodes:
        rows_by_group[node._group].append(self._serialize_node_partial(node))

    for group_name, rows in rows_by_group.items():
        # range(start, stop, step): iterate over the WHOLE row list in steps
        # of the batch size. Swapping stop/step would drop every row past the
        # first batch for large groups and issue redundant/empty upserts for
        # small ones.
        for start in range(0, len(rows), MILVUS_UPSERT_BATCH_SIZE):
            batch = rows[start:start + MILVUS_UPSERT_BATCH_SIZE]
            self._client.upsert(collection_name=group_name, data=batch)

    self._map_store.update_nodes(nodes)

@override
Expand Down
2 changes: 1 addition & 1 deletion lazyllm/tools/rag/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def transform(self, node: DocNode, **kwargs) -> List[Union[str, DocNode]]:
You should not have any unnecessary output. Lets begin:
"""),
cn=dict(summary="""
zh=dict(summary="""
## 角色:文本摘要
你是一个文本摘要引擎,负责分析用户输入的文本,并根据请求任务提供简洁的摘要。
Expand Down
6 changes: 5 additions & 1 deletion lazyllm/tools/rag/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,11 +681,15 @@ def save_files_in_threads(

# returns a list of modified nodes
def parallel_do_embedding(embed: Dict[str, Callable], embed_keys: Optional[Union[List[str], Set[str]]],
nodes: List[DocNode]) -> List[DocNode]:
nodes: List[DocNode], group_embed_keys: Dict[str, List[str]] = None) -> List[DocNode]:
modified_nodes = []
with ThreadPoolExecutor(config["max_embedding_workers"]) as executor:
futures = []
for node in nodes:
if group_embed_keys:
embed_keys = group_embed_keys.get(node._group)
if not embed_keys:
continue
miss_keys = node.has_missing_embedding(embed_keys)
if not miss_keys:
continue
Expand Down

0 comments on commit 5690af5

Please sign in to comment.