diff --git a/autogen/agentchat/contrib/rag/mongodb.py b/autogen/agentchat/contrib/rag/mongodb.py index 42fd5c9dda..878dec0c27 100644 --- a/autogen/agentchat/contrib/rag/mongodb.py +++ b/autogen/agentchat/contrib/rag/mongodb.py @@ -4,7 +4,7 @@ import logging from pathlib import Path -from typing import Any, Callable, Optional +from typing import Any, Callable, List, Optional, Union from autogen.agentchat.contrib.rag.query_engine import VectorDbQueryEngine from autogen.agentchat.contrib.vectordb.base import VectorDBFactory @@ -12,12 +12,10 @@ from autogen.import_utils import optional_import_block, require_optional_import with optional_import_block(): - from llama_index.core import StorageContext, VectorStoreIndex - from llama_index.node_parser.docling import DoclingNodeParser - from llama_index.readers.docling import DoclingReader + from llama_index.core import Document, SimpleDirectoryReader, StorageContext, VectorStoreIndex + from llama_index.core.node_parser import SentenceSplitter from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch - DEFAULT_COLLECTION_NAME = "docling-parsed-docs" logging.basicConfig(level=logging.INFO) @@ -26,6 +24,19 @@ @require_optional_import(["pymongo", "llama_index"], "rag") class MongoDBQueryEngine(VectorDbQueryEngine): + """ + A query engine backed by MongoDB Atlas that supports document insertion and querying. + + This engine initializes a vector database, builds an index from input documents, + and allows querying using the chat engine interface. + + Attributes: + vector_db (MongoDBAtlasVectorDB): The MongoDB vector database instance. + vector_search_engine (MongoDBAtlasVectorSearch): The vector search engine. + storage_context (StorageContext): The storage context for the vector store. + indexer (Optional[VectorStoreIndex]): The index built from the documents. + """ + def __init__( self, connection_string: str = "", @@ -34,8 +45,17 @@ def __init__( collection_name: str = DEFAULT_COLLECTION_NAME, index_name: str = "vector_index", ): - super().__init__() + """ + Initialize the MongoDBQueryEngine. + Args: + connection_string (str): The MongoDB connection string. + database_name (str): The name of the database to use. + embedding_function (Optional[Callable[..., Any]]): The function to compute embeddings. + collection_name (str): The name of the collection to use. + index_name (str): The name of the vector index. + """ + super().__init__() self.vector_db: MongoDBAtlasVectorDB = VectorDBFactory.create_vector_db( # type: ignore[assignment] db_type="mongodb", connection_string=connection_string, @@ -45,100 +65,151 @@ def __init__( collection_name=collection_name, ) self.vector_search_engine = MongoDBAtlasVectorSearch( - mongodb_client=self.vector_db.client, db_name=database_name, collection_name=collection_name + mongodb_client=self.vector_db.client, + db_name=database_name, + collection_name=collection_name, ) self.storage_context = StorageContext.from_defaults(vector_store=self.vector_search_engine) - self.indexer = None + self.indexer: Optional[VectorStoreIndex] = None # type: ignore[no-any-unimported] + + def connect_db(self, *args, **kwargs) -> bool: # type: ignore[no-untyped-def] + """ + Connect to the MongoDB database by issuing a ping. + + Returns: + bool: True if the connection is successful; False otherwise. + """ + try: + self.vector_db.client.admin.command("ping") + logger.info("Connected to MongoDB successfully.") + return True + except Exception as error: + logger.error("Failed to connect to MongoDB: %s", error) + return False def init_db( # type: ignore[no-untyped-def] self, - new_doc_dir: Optional[str | Path] = None, - new_doc_paths: Optional[list[str | Path]] = None, + new_doc_dir: Optional[Union[str, Path]] = None, + new_doc_paths: Optional[List[Union[str, Path]]] = None, *args, **kwargs, - ) -> bool: # type: ignore[no-untyped-def] - """Initialize the database with the input documents or records. - - This method initializes database with the input documents or records. - Usually, it takes the following steps, - 1. connecting to a database. - 2. insert records - 3. build indexes etc. + ) -> bool: + """ + Initialize the database by loading documents from the given directory or file paths, + then building an index. Args: - new_doc_dir: a dir of input documents that are used to create the records in database. - new_doc_paths: - a list of input documents that are used to create the records in database. - a document can be a path to a file or a url. - *args: Any additional arguments - **kwargs: Any additional keyword arguments + new_doc_dir (Optional[Union[str, Path]]): Directory containing input documents. + new_doc_paths (Optional[List[Union[str, Path]]]): List of document paths or URLs. Returns: - bool: True if initialization is successful, False otherwise + bool: True if initialization is successful; False otherwise. """ if not self.connect_db(): return False - if new_doc_dir or new_doc_paths: - self.add_records(new_doc_dir, new_doc_paths) # type: ignore[no-untyped-call] - - self.indexer = VectorStoreIndex.from_vector_store( - self.vector_search_engine, storage_context=self.storage_context - ) - return True + # Gather document paths. + document_list: List[Union[str, Path]] = [] + if new_doc_dir: + document_list.extend(Path(new_doc_dir).glob("**/*")) + if new_doc_paths: + document_list.extend(new_doc_paths) - def connect_db(self, *args, **kwargs) -> bool: # type: ignore[no-untyped-def] - """ - Connect to the MongoDB database by issuing a ping. + if not document_list: + logger.warning("No input documents provided to initialize the database.") + return False - Returns: - True if the connection is successful; False otherwise. - """ try: - self.vector_db.client.admin.command("ping") + documents = SimpleDirectoryReader(input_files=document_list).load_data() + self.indexer = VectorStoreIndex.from_documents(documents, storage_context=self.storage_context) + logger.info("Database initialized with %d documents.", len(documents)) return True - except Exception as error: - logger.error("Failed to connect to MongoDB: %s", error) + except Exception as e: + logger.error("Failed to initialize the database: %s", e) return False - def add_records(self, new_doc_dir=None, new_doc_paths_or_urls=None, *args, **kwargs): # type: ignore[no-untyped-def] - document_list = [] # type: ignore[var-annotated] + def add_records( # type: ignore[no-untyped-def, override] + self, + new_doc_dir: Optional[Union[str, Path]] = None, + new_doc_paths_or_urls: Optional[Union[List[Union[str, Path]], Union[str, Path]]] = None, + chunk_size: int = 500, + chunk_overlap: int = 300, + *args, + **kwargs, + ) -> None: + """ + Load, parse, and insert documents into the index. + + This method uses a SentenceSplitter to break documents into chunks before insertion. + + Args: + new_doc_dir (Optional[Union[str, Path]]): Directory containing input documents. + new_doc_paths_or_urls (Optional[Union[List[Union[str, Path]], Union[str, Path]]]): + List of document paths or a single document path/URL. + """ + # Collect document paths. + document_list: List[Union[str, Path]] = [] if new_doc_dir: document_list.extend(Path(new_doc_dir).glob("**/*")) if new_doc_paths_or_urls: - document_list.append(new_doc_paths_or_urls) - - reader = DoclingReader(export_type=DoclingReader.ExportType.JSON) - node_parser = DoclingNodeParser() - - documents = reader.load_data(document_list) - parser_nodes = node_parser._parse_nodes(documents) - - # print("Parser data: ", parser) - - # # document_reader = SimpleDirectoryReader(input_files=document_list).load_data() - # self.vector_db.insert_docs([ - # Document( # type: ignore[typeddict-item, typeddict-unknown-key] - # id=document.id_, - # content=document.get_content(), - # metadata=document.metadata, - # ) - # for document in parser - # ]) - - docs_to_insert = [] - for node in parser_nodes: - doc_dict = { - "id": node.id_, # Ensure the key is 'id' - "content": node.get_content(), - "metadata": node.metadata, - } - docs_to_insert.append(doc_dict) - - # Insert documents into vector DB. - self.vector_db.insert_docs(docs_to_insert) # type: ignore[arg-type] - - def query(self, question, *args, **kwargs): # type: ignore[no-untyped-def] - response = self.indexer.as_chat_engine().query(question) # type: ignore[attr-defined] - - return response + if isinstance(new_doc_paths_or_urls, (list, tuple)): + document_list.extend(new_doc_paths_or_urls) + else: + document_list.append(new_doc_paths_or_urls) + + if not document_list: + logger.warning("No documents found for adding records.") + return + + try: + raw_documents = SimpleDirectoryReader(input_files=document_list).load_data() + except Exception as e: + logger.error("Error loading documents: %s", e) + return + + node_parser = SentenceSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + doc_chunks: List[Document] = [] # type: ignore[no-any-unimported] + + for document in raw_documents: + try: + # Split the document text into chunks. + chunks = node_parser.split_text(document.text) + metadata = document.metadata + doc_id = document.id_ # Ensure this is not a tuple. + for chunk in chunks: + new_doc = Document(text=chunk, id_=doc_id, metadata=metadata) + doc_chunks.append(new_doc) + except Exception as e: + logger.error("Error parsing document %s: %s", document.id_, e) + + if not doc_chunks: + logger.warning("No document chunks created for insertion.") + return + + # Insert document chunks using the indexer. + try: + for doc in doc_chunks: + self.indexer.insert(doc) # type: ignore[union-attr] + logger.info("Inserted %d document chunks successfully.", len(doc_chunks)) + except Exception as e: + logger.error("Error inserting documents into the index: %s", e) + + def query(self, question: str, *args, **kwargs) -> Any: # type: ignore[no-untyped-def] + """ + Query the index using the given question. + + Args: + question (str): The query string. + + Returns: + Any: The response from the chat engine, or None if an error occurs. + """ + try: + response = self.indexer.as_chat_engine().query(question) # type: ignore[union-attr] + return response + except Exception as e: + logger.error("Query failed: %s", e) + return None diff --git a/notebook/docling_md_query_engine.ipynb b/notebook/docling_md_query_engine.ipynb index 6daf3d6a3d..2bfbe59ec0 100644 --- a/notebook/docling_md_query_engine.ipynb +++ b/notebook/docling_md_query_engine.ipynb @@ -174,17 +174,6 @@ "# Docling MD Query Engine MongoDB" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from chromadb.utils import embedding_functions\n", - "\n", - "openai_ef = embedding_functions.OpenAIEmbeddingFunction(api_key=\"\", model_name=\"text-embedding-ada-002\")" - ] - }, { "cell_type": "code", "execution_count": null, @@ -193,7 +182,13 @@ "source": [ "import os\n", "\n", - "os.environ[\"OPENAI_API_KEY\"] = \"\"" + "from chromadb.utils import embedding_functions\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = \"\"\n", + "openai_ef = embedding_functions.OpenAIEmbeddingFunction(\n", + " api_key=\"\",\n", + " model_name=\"text-embedding-ada-002\",\n", + ")" ] }, { @@ -202,7 +197,7 @@ "metadata": {}, "outputs": [], "source": [ - "filename = \"/root/ag2/test/agents/experimental/document_agent/pdf_parsed/nvidia_10k_2024.md\"" + "input_dir = \"/root/ag2/test/agents/experimental/document_agent/pdf_parsed/\"" ] }, { @@ -213,7 +208,11 @@ "source": [ "from autogen.agentchat.contrib.rag.mongodb import MongoDBQueryEngine\n", "\n", - "query_engine = MongoDBQueryEngine(connection_string=\"\", embedding_function=openai_ef, database_name=\"vector_db_4\")" + "query_engine = MongoDBQueryEngine(\n", + " connection_string=\"\",\n", + " embedding_function=openai_ef,\n", + " database_name=\"vector_db_1\",\n", + ")" ] }, { @@ -222,7 +221,7 @@ "metadata": {}, "outputs": [], "source": [ - "query_engine.init_db(new_doc_paths=filename)" + "query_engine.init_db(new_doc_paths=[input_dir + \"nvidia_10k_2024.md\"])" ] }, { @@ -240,7 +239,7 @@ "metadata": {}, "outputs": [], "source": [ - "query_engine.add_records(new_doc_paths_or_urls=filename)" + "query_engine.add_records(new_doc_paths_or_urls=[input_dir + \"Toast_financial_report.md\"])" ] }, { @@ -248,7 +247,11 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "question = \"What is the trading symbol for Toast\"\n", + "answer = query_engine.query(question)\n", + "print(answer)" + ] } ], "metadata": {