Support one document to multiple DocNodes (#413)
wzh1994 authored Jan 7, 2025
1 parent e327a3f commit a3d0685
Showing 7 changed files with 30 additions and 27 deletions.
29 changes: 15 additions & 14 deletions lazyllm/tools/rag/dataReader.py
@@ -18,6 +18,8 @@
 from .readers import (ReaderBase, PDFReader, DocxReader, HWPReader, PPTXReader, ImageReader, IPYNBReader,
                       EpubReader, MarkdownReader, MboxReader, PandasCSVReader, PandasExcelReader, VideoAudioReader,
                       get_default_fs, is_default_fs)
+from .global_metadata import (RAG_DOC_FILE_NAME, RAG_DOC_FILE_TYPE, RAG_DOC_FILE_SIZE,
+                              RAG_DOC_CREATION_DATE, RAG_DOC_LAST_MODIFIED_DATE, RAG_DOC_LAST_ACCESSED_DATE)

 def _file_timestamp_format(timestamp: float, include_time: bool = False) -> Optional[str]:
     try:
@@ -43,13 +45,12 @@ def __call__(self, file_path: str) -> Dict:
         last_modified_date = _file_timestamp_format(stat_result.get("mtime"))
         last_accessed_date = _file_timestamp_format(stat_result.get("atime"))
         default_meta = {
-            "file_path": file_path,
-            "file_name": file_name,
-            "file_type": mimetypes.guess_type(file_path)[0],
-            "file_size": stat_result.get("size"),
-            "creation_date": creation_date,
-            "last_modified_date": last_modified_date,
-            "last_accessed_date": last_accessed_date,
+            RAG_DOC_FILE_NAME: file_name,
+            RAG_DOC_FILE_TYPE: mimetypes.guess_type(file_path)[0],
+            RAG_DOC_FILE_SIZE: stat_result.get("size"),
+            RAG_DOC_CREATION_DATE: creation_date,
+            RAG_DOC_LAST_MODIFIED_DATE: last_modified_date,
+            RAG_DOC_LAST_ACCESSED_DATE: last_accessed_date,
         }

         return {meta_key: meta_value for meta_key, meta_value in default_meta.items() if meta_value is not None}
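
With the constants swapped in, the default metadata is keyed by the shared RAG_DOC_* names from global_metadata.py, and file_path is no longer recorded here (doc_impl.py below stamps the path via RAG_DOC_PATH instead). Assuming those constants resolve to the plain strings shown, the output for a hypothetical notes.txt would look like:

    {
        "file_name": "notes.txt",
        "file_type": "text/plain",
        "file_size": 1284,
        "creation_date": "2025-01-07",
        "last_modified_date": "2025-01-07",
        "last_accessed_date": "2025-01-07",
    }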
@@ -83,7 +84,7 @@ def __init__(self, input_dir: Optional[str] = None, input_files: Optional[List]
                  encoding: str = "utf-8", filename_as_id: bool = False, required_exts: Optional[List[str]] = None,
                  file_extractor: Optional[Dict[str, Callable]] = None, fs: Optional[AbstractFileSystem] = None,
                  metadata_genf: Optional[Callable[[str], Dict]] = None, num_files_limit: Optional[int] = None,
-                 return_trace: bool = False) -> None:
+                 return_trace: bool = False, metadatas: Optional[Dict] = None) -> None:
         super().__init__(return_trace=return_trace)

         if (not input_dir and not input_files) or (input_dir and input_files):
@@ -98,6 +99,7 @@ def __init__(self, input_dir: Optional[str] = None, input_files: Optional[List]
         self._required_exts = required_exts
         self._num_files_limit = num_files_limit
         self._Path = Path if is_default_fs(self._fs) else PurePosixPath
+        self._metadatas = metadatas

         if input_files:
             self._input_files = []
@@ -191,12 +193,11 @@ def _exclude_metadata(self, documents: List[DocNode]) -> List[DocNode]:
     @staticmethod
     def load_file(input_file: Path, metadata_genf: Callable[[str], Dict], file_extractor: Dict[str, Callable],
                   filename_as_id: bool = False, encoding: str = "utf-8", pathm: PurePath = Path,
-                  fs: Optional[AbstractFileSystem] = None) -> List[DocNode]:
-        metadata: Optional[dict] = None
+                  fs: Optional[AbstractFileSystem] = None, metadata: Optional[Dict] = None) -> List[DocNode]:
+        metadata: dict = metadata or {}
         documents: List[DocNode] = []

-        if metadata_genf is not None: metadata = metadata_genf(str(input_file))
-
+        if metadata_genf is not None: metadata.update(metadata_genf(str(input_file)))
         file_reader_patterns = list(file_extractor.keys())

         for pattern in file_reader_patterns:
@@ -220,7 +221,7 @@ def load_file(input_file: Path, metadata_genf: Callable[[str], Dict], file_extractor: Dict[str, Callable],
                 with fs.open(input_file, encoding=encoding) as f:
                     data = f.read().decode(encoding)

-                doc = DocNode(text=data, metadata=metadata or {})
+                doc = DocNode(text=data, global_metadata=metadata or {})
                 doc.docpath = str(input_file)
                 if filename_as_id: doc._uid = str(input_file)
                 documents.append(doc)
@@ -245,7 +246,7 @@ def _load_data(self, show_progress: bool = False, num_workers: Optional[int] = None,
                 results = p.starmap(SimpleDirectoryReader.load_file,
                                     zip(process_file, repeat(self._metadata_genf), repeat(file_readers),
                                         repeat(self._filename_as_id), repeat(self._encoding), repeat(self._Path),
-                                        repeat(self._fs)))
+                                        repeat(self._fs), self._metadatas or repeat(None)))
             documents = reduce(lambda x, y: x + y, results)
         else:
             if show_progress:
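
Taken together, the dataReader.py changes let a caller pass one metadata dict per input file; load_file seeds each DocNode's global_metadata from that dict and then merges in whatever metadata_genf produces. A minimal usage sketch, with hypothetical file names and metadata values, invoking the reader as a callable as the repo's tests do:

    from lazyllm.tools.rag.dataReader import SimpleDirectoryReader

    # One metadata dict per file, aligned by position with input_files.
    reader = SimpleDirectoryReader(
        input_files=["a.pdf", "b.txt"],
        metadatas=[{"source": "crm"}, {"source": "wiki"}],
    )
    docs = reader()  # List[DocNode]; each node carries its file's merged global_metadata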
2 changes: 1 addition & 1 deletion lazyllm/tools/rag/data_loaders.py
@@ -11,7 +11,7 @@ def __init__(self, input_files: Optional[List[str]], local_readers: Optional[Dict]
         self._local_readers = local_readers
         self._global_readers = global_readers

-    def load_data(self, input_files: Optional[List[str]] = None) -> List[DocNode]:
+    def load_data(self, input_files: Optional[List[str]] = None, metadatas: Optional[Dict] = None) -> List[DocNode]:
         input_files = input_files or self._input_files
         file_readers = self._local_readers.copy()
         for key, func in self._global_readers.items():
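
DirectoryReader.load_data now forwards a per-file metadata list to SimpleDirectoryReader alongside the paths. A call-shape sketch, with hypothetical files and values and empty reader maps:

    from lazyllm.tools.rag.data_loaders import DirectoryReader

    loader = DirectoryReader(input_files=None, local_readers={}, global_readers={})
    nodes = loader.load_data(["a.md", "b.md"], [{"team": "docs"}, {"team": "eng"}])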
10 changes: 5 additions & 5 deletions lazyllm/tools/rag/doc_impl.py
@@ -89,11 +89,11 @@ def _lazy_init(self) -> None:
         if not self.store.is_group_active(LAZY_ROOT_NAME):
             ids, paths, metadatas = self._list_files()
             if paths:
-                root_nodes = self._reader.load_data(paths)
-                for idx, node in enumerate(root_nodes):
-                    node.global_metadata.update(metadatas[idx].copy() if metadatas else {})
-                    node.global_metadata[RAG_DOC_ID] = ids[idx] if ids else gen_docid(paths[idx])
-                    node.global_metadata[RAG_DOC_PATH] = paths[idx]
+                if not metadatas: metadatas = [{} for _ in paths]
+                for idx, meta in enumerate(metadatas):
+                    meta[RAG_DOC_ID] = ids[idx] if ids else gen_docid(paths[idx])
+                    meta[RAG_DOC_PATH] = paths[idx]
+                root_nodes = self._reader.load_data(paths, metadatas)
                 self.store.update_nodes(root_nodes)
             if self._dlm:
                 self._dlm.update_kb_group(cond_file_ids=ids, cond_group=self._kb_group_name,
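
The reordered _lazy_init stamps RAG_DOC_ID and RAG_DOC_PATH into each file's metadata dict before the reader runs, so every DocNode is created with complete global_metadata rather than patched after loading. A standalone sketch of the same pattern, with plain-string keys and a hashing helper standing in for the repo's constants and gen_docid:

    import hashlib

    def gen_docid(path: str) -> str:  # hypothetical stand-in
        return hashlib.sha256(path.encode()).hexdigest()[:16]

    ids, paths, metadatas = [], ["a.md", "b.md"], None  # as if from _list_files()
    if not metadatas:
        metadatas = [{} for _ in paths]  # one independent dict per file
    for idx, meta in enumerate(metadatas):
        meta["docid"] = ids[idx] if ids else gen_docid(paths[idx])  # RAG_DOC_ID stand-in
        meta["doc_path"] = paths[idx]                               # RAG_DOC_PATH stand-in
    # metadatas is now ready for reader.load_data(paths, metadatas)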
2 changes: 2 additions & 0 deletions lazyllm/tools/rag/document.py
@@ -31,6 +31,8 @@ def __init__(self, dataset_path: str, embed: Optional[Union[Callable, Dict[str, Callable]]]
         default_path = os.path.join(lazyllm.config["data_path"], dataset_path)
         if os.path.exists(default_path):
             dataset_path = default_path
+        else:
+            dataset_path = os.path.join(os.getcwd(), dataset_path)
         self._launcher: Launcher = launcher if launcher else lazyllm.launchers.remote(sync=False)
         self._dataset_path = dataset_path
         self._embed = embed if isinstance(embed, dict) else {EMBED_DEFAULT_KEY: embed} if embed else {}
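
With the new else branch, a dataset_path missing from lazyllm.config["data_path"] resolves against the current working directory instead of staying a bare relative string. A quick sketch of the resolution order, with a hypothetical data root:

    import os

    data_root = "/data/lazyllm"  # stand-in for lazyllm.config["data_path"]
    dataset_path = "rag_master"
    default_path = os.path.join(data_root, dataset_path)
    dataset_path = (default_path if os.path.exists(default_path)
                    else os.path.join(os.getcwd(), dataset_path))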
2 changes: 1 addition & 1 deletion lazyllm/tools/rag/readers/pdfReader.py
@@ -10,7 +10,7 @@
 RETRY_TIMES = 3

 class PDFReader(LazyLLMReaderBase):
-    def __init__(self, return_full_document: bool = True, return_trace: bool = True) -> None:
+    def __init__(self, return_full_document: bool = False, return_trace: bool = True) -> None:
         super().__init__(return_trace=return_trace)
         self._return_full_document = return_full_document

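
Defaulting return_full_document to False makes PDFReader emit one DocNode per page rather than a single node for the whole file, matching the commit's one-document-to-many-DocNodes goal. A usage sketch, assuming readers are invoked with a file path and that sample.pdf exists locally:

    from lazyllm.tools.rag.readers.pdfReader import PDFReader

    page_nodes = PDFReader()("sample.pdf")  # new default: one DocNode per page
    whole_doc = PDFReader(return_full_document=True)("sample.pdf")  # old single-node behavior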
10 changes: 5 additions & 5 deletions tests/basic_tests/test_document.py
@@ -147,16 +147,16 @@ def test_multi_embedding_with_document(self):

         document2 = Document(dataset_path="rag_master", embed={"m1": self.embed_model1, "m2": self.embed_model2})
         document2.create_node_group(name="sentences", transform=SentenceSplitter, chunk_size=1024, chunk_overlap=100)
-        retriever2 = Retriever(document2, group_name="sentences", similarity="cosine", topk=10)
+        retriever2 = Retriever(document2, group_name="sentences", similarity="cosine", topk=3)
         nodes2 = retriever2("何为天道?")
-        assert len(nodes2) == 11
+        assert len(nodes2) >= 3

         document3 = Document(dataset_path="rag_master", embed={"m1": self.embed_model1, "m2": self.embed_model2})
         document3.create_node_group(name="sentences", transform=SentenceSplitter, chunk_size=1024, chunk_overlap=100)
         retriever3 = Retriever(document3, group_name="sentences", similarity="cosine",
-                               similarity_cut_off={"m1": 0.5, "m2": 0.55}, topk=10)
-        nodes3 = retriever3("何为天道?")
-        assert len(nodes3) == 3
+                               similarity_cut_off={"m1": 0.5, "m2": 0.55}, topk=3, output_format='content', join=True)
+        nodes3_text = retriever3("何为天道?")
+        assert '观天之道' in nodes3_text or '天命之谓性' in nodes3_text

     def test_doc_web_module(self):
         import time
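
The reworked third assertion checks retrieved text instead of a brittle exact node count: with output_format='content' and join=True, the Retriever returns one joined string. A sketch of that call shape, assuming a Document with the "sentences" node group as in the test:

    retriever = Retriever(document, group_name="sentences", similarity="cosine",
                          topk=3, output_format='content', join=True)
    text = retriever("何为天道?")  # one concatenated string of the top-3 chunks
    assert isinstance(text, str)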
2 changes: 1 addition & 1 deletion tests/basic_tests/test_rag_reader.py
@@ -35,7 +35,7 @@ def test_reader_file(self):
         docs = []
         for doc in reader():
             docs.append(doc)
-        assert len(docs) == 2
+        assert len(docs) == 3

     def test_reader_dir(self):
         input_dir = self.datasets
