From 67147212a4e69521c50cc7aedfe29c32dc35acc6 Mon Sep 17 00:00:00 2001 From: sitloboi2012 Date: Fri, 21 Feb 2025 22:22:30 +0700 Subject: [PATCH 01/12] initial setup for mongodb query engine and notebook usage Signed-off-by: sitloboi2012 --- autogen/agentchat/contrib/rag/mongodb.py | 233 +++++++++++++++++++++++ notebook/docling_md_query_engine.ipynb | 107 ++++++++++- pyproject.toml | 1 + 3 files changed, 339 insertions(+), 2 deletions(-) create mode 100644 autogen/agentchat/contrib/rag/mongodb.py diff --git a/autogen/agentchat/contrib/rag/mongodb.py b/autogen/agentchat/contrib/rag/mongodb.py new file mode 100644 index 0000000000..e7c86205fd --- /dev/null +++ b/autogen/agentchat/contrib/rag/mongodb.py @@ -0,0 +1,233 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +import logging +from pathlib import Path +from typing import Any, List, Optional, Union + +from autogen.agentchat.contrib.rag.query_engine import VectorDbQueryEngine +from autogen.agentchat.contrib.vectordb.base import VectorDBFactory +from autogen.import_utils import optional_import_block, require_optional_import + +with optional_import_block(): + from llama_index.core.llms import LLM + from llama_index.llms.openai import OpenAI + from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch + from pymongo import MongoClient + + +DEFAULT_COLLECTION_NAME = "docling-parsed-docs" + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +@require_optional_import(["pymongo", "llama_index"], "rag") +class MongoDBQueryEngine(VectorDbQueryEngine): + """ + MongoDBQueryEngine is a production-ready implementation of the VectorDbQueryEngine for MongoDB. 
+ + This engine leverages the VectorDBFactory to instantiate a MongoDB vector database + (MongoDBAtlasVectorDB) and wraps its collection with a LlamaIndex vector store + (MongoDBAtlasVectorSearch) to enable document indexing and retrieval. + + Conceptually, it mirrors the approach used in DoclingMdQueryEngine for ChromaDB. + It provides methods to: + - Connect to the database. + - Initialize (or reinitialize) the collection with documents. + - Add new documents. + - Execute natural language queries against the vector index. + """ + + def __init__( # type: ignore[misc, no-any-unimported, return] + self, + connection_string: str, + db_name: str = "vector_db", + collection_name: str = "default_collection", + vector_index_name: str = "vector_index", + embedding_function: Optional[Any] = None, + llm: Optional[LLM] = None, + **kwargs: Any, + ) -> Any | None: + """ + Initializes the MongoDBQueryEngine. + + Args: + connection_string: MongoDB connection string. + db_name: Name of the MongoDB database. + collection_name: Name of the collection to use (default is DEFAULT_COLLECTION_NAME). + vector_index_name: Name of the vector search index. + embedding_function: Function to compute embeddings (if needed by the underlying vector DB). + llm: LLM for query processing (default uses OpenAI's GPT-4 variant). + **kwargs: Additional keyword arguments. + """ + self.connection_string = connection_string + self.db_name = db_name + self.collection_name = collection_name or DEFAULT_COLLECTION_NAME + self.vector_index_name = vector_index_name + + # Set up the LLM; if not provided, use a default OpenAI model. + self.llm = llm or OpenAI(model="gpt-4o", temperature=0.0) # type: ignore + + # Create a MongoDB client for use by the vector search wrapper. + self.mongodb_client = MongoClient(connection_string) + + # Initialize the LlamaIndex-style vector store wrapper for advanced query pipelines. 
+ self.vector_search = MongoDBAtlasVectorSearch( + mongodb_client=self.mongodb_client, + db_name=db_name, + collection_name=collection_name, + vector_index_name=vector_index_name, + **kwargs, + ) + + # Create the full vector database instance via the VectorDBFactory. + self.vector_db = VectorDBFactory.create_vector_db( + "mongodb", + connection_string=connection_string, + database_name=db_name, + collection_name=collection_name, + index_name=vector_index_name, + embedding_function=embedding_function, + **kwargs, + ) + + def connect_db(self, *args: Any, **kwargs: Any) -> bool: + """ + Connect to the MongoDB database by issuing a ping. + + Returns: + True if the connection is successful; False otherwise. + """ + try: + self.mongodb_client.admin.command("ping") + return True + except Exception as error: + logger.error("Failed to connect to MongoDB: %s", error) + return False + + def init_db( + self, + new_doc_dir: Optional[Union[Path, str]] = None, + new_doc_paths: Optional[List[Union[Path, str]]] = None, + overwrite: bool = True, + *args: Any, + **kwargs: Any, + ) -> bool: + """ + Initialize the database with documents. + + This method: + 1. Connects to MongoDB. + 2. Creates (or overwrites) the target collection via the vector DB interface. + 3. Loads documents from a directory and/or file paths. + 4. Inserts the documents into the collection. + 5. Creates the vector search index. + + Args: + new_doc_dir: Directory containing document files. + new_doc_paths: List of file paths to individual documents. + overwrite: If True, the existing collection is overwritten. + *args, **kwargs: Additional arguments. + + Returns: + True if initialization is successful; False otherwise. 
+ """ + if not self.connect_db(): + return False + + try: + self.vector_db.create_collection( + collection_name=self.collection_name, overwrite=overwrite, get_or_create=True + ) + except Exception as e: + logger.error("Error creating collection: %s", e) + return False + + # Load documents from file paths and/or a directory. + docs = [] + if new_doc_paths: + for i, doc_path in enumerate(new_doc_paths): + path_obj = Path(doc_path) + if path_obj.is_file(): + content = path_obj.read_text(encoding="utf-8") + docs.append({ + "id": f"doc_{i}", + "content": content, + "metadata": None, + "embedding": None, # Embeddings will be computed internally. + }) + if new_doc_dir: + dir_path = Path(new_doc_dir) + for i, file in enumerate(dir_path.glob("*")): + if file.is_file(): + content = file.read_text(encoding="utf-8") + docs.append({"id": f"doc_dir_{i}", "content": content, "metadata": None, "embedding": None}) + + if docs: + self.vector_db.insert_docs(docs, collection_name=self.collection_name) # type: ignore[arg-type] + + # Create the vector search index. + try: + dimensions = getattr(self.vector_db, "dimensions", 1536) + self.vector_search.create_vector_search_index(dimensions=dimensions, path="embedding", similarity="cosine") + except Exception as e: + logger.error("Error creating vector search index: %s", e) + + return True + + def add_records( + self, + new_doc_dir: Optional[Union[Path, str]] = None, + new_doc_paths_or_urls: Optional[List[Union[Path, str]]] = None, + *args: Any, + **kwargs: Any, + ) -> bool: + """ + Add new documents to the existing collection. + + Loads documents from the specified directory and/or file paths and inserts them into the vector DB. + + Returns: + True if records were added successfully; False otherwise. 
+ """ + docs = [] + if new_doc_paths_or_urls: + for i, doc_path in enumerate(new_doc_paths_or_urls): + path_obj = Path(doc_path) + if path_obj.is_file(): + content = path_obj.read_text(encoding="utf-8") + docs.append({"id": f"new_doc_{i}", "content": content, "metadata": None, "embedding": None}) + if new_doc_dir: + dir_path = Path(new_doc_dir) + for i, file in enumerate(dir_path.glob("*")): + if file.is_file(): + content = file.read_text(encoding="utf-8") + docs.append({"id": f"new_doc_dir_{i}", "content": content, "metadata": None, "embedding": None}) + if docs: + self.vector_db.insert_docs(docs, collection_name=self.collection_name) # type: ignore[arg-type] + return True + return False + + def query(self, question: str, *args: Any, **kwargs: Any) -> str: + """ + Execute a natural language query against the vector database. + + Converts the query string into a vector search, retrieves the most relevant document, + and returns its content. + + Args: + question: The natural language query. + *args, **kwargs: Additional query parameters (e.g., n_results, distance_threshold). + + Returns: + The content of the top matching document, or an empty string if no match is found. 
+ """ + results = self.vector_db.retrieve_docs( + queries=[question], collection_name=self.collection_name, n_results=1, **kwargs + ) + if results and results[0]: + best_doc, _ = results[0][0] + return best_doc.get("content", "") + return "" diff --git a/notebook/docling_md_query_engine.ipynb b/notebook/docling_md_query_engine.ipynb index c40bca9f42..53b9d6e984 100644 --- a/notebook/docling_md_query_engine.ipynb +++ b/notebook/docling_md_query_engine.ipynb @@ -21,6 +21,15 @@ "%pip install llama-index==0.12.16" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install sentence-transformers" + ] + }, { "cell_type": "code", "execution_count": null, @@ -157,11 +166,105 @@ "answer = query_engine.query(question)\n", "print(answer)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Docling MD Query Engine MongoDB" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from pathlib import Path\n", + "\n", + "from autogen.agentchat.contrib.rag.mongodb import MongoDBQueryEngine\n", + "\n", + "\n", + "def main():\n", + " # Set up the MongoDB connection string (adjust if necessary)\n", + " connection_string = \"\"\n", + "\n", + " # Create an instance of MongoDBQueryEngine\n", + " engine = MongoDBQueryEngine(\n", + " connection_string=connection_string,\n", + " db_name=\"my_vector_db\",\n", + " collection_name=\"my_collection\",\n", + " vector_index_name=\"vector_index\",\n", + " )\n", + "\n", + " # Test the database connection\n", + " if engine.connect_db():\n", + " print(\"Connected to MongoDB successfully.\")\n", + " else:\n", + " print(\"Failed to connect to MongoDB.\")\n", + " return\n", + "\n", + " # -------------------------------\n", + " # Initialize the database with documents\n", + " # -------------------------------\n", + " # Create a sample directory with text files to ingest\n", + " sample_dir = 
Path(\"sample_docs\")\n", + " sample_dir.mkdir(exist_ok=True)\n", + "\n", + " sample_texts = [\n", + " \"This is a sample document about MongoDB vector search.\",\n", + " \"Another document discussing RAG and vector databases.\",\n", + " ]\n", + "\n", + " for i, text in enumerate(sample_texts):\n", + " file_path = sample_dir / f\"doc_{i}.txt\"\n", + " with open(file_path, \"w\", encoding=\"utf-8\") as f:\n", + " f.write(text)\n", + "\n", + " # Initialize the database (creates/overwrites collection, ingests docs, and builds index)\n", + " if engine.init_db(new_doc_dir=str(sample_dir), overwrite=True):\n", + " print(\"Database initialized with sample documents.\")\n", + " else:\n", + " print(\"Database initialization failed.\")\n", + "\n", + " # -------------------------------\n", + " # Add additional records\n", + " # -------------------------------\n", + " add_dir = Path(\"additional_docs\")\n", + " add_dir.mkdir(exist_ok=True)\n", + "\n", + " additional_texts = [\n", + " \"Additional record about MongoDBAtlasVectorSearch integration.\",\n", + " \"Additional record on document ingestion for RAG.\",\n", + " ]\n", + "\n", + " for i, text in enumerate(additional_texts):\n", + " file_path = add_dir / f\"additional_doc_{i}.txt\"\n", + " with open(file_path, \"w\", encoding=\"utf-8\") as f:\n", + " f.write(text)\n", + "\n", + " if engine.add_records(new_doc_dir=str(add_dir)):\n", + " print(\"Additional records added.\")\n", + " else:\n", + " print(\"No additional records were added.\")\n", + "\n", + " # -------------------------------\n", + " # Execute a query\n", + " # -------------------------------\n", + " query_text = \"Explain MongoDB vector search.\"\n", + " result = engine.query(query_text)\n", + " print(\"Query result:\", result)\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + " main()" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "ag2", "language": "python", "name": "python3" }, @@ -175,7 +278,7 @@ "name": 
"python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.21" + "version": "3.11.11" } }, "nbformat": 4, diff --git a/pyproject.toml b/pyproject.toml index 9b07b10f62..0044a5e33e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,6 +139,7 @@ rag = [ "chromadb>=0.5,<1", "llama-index>=0.12,<1", "llama-index-vector-stores-chroma==0.4.1", + "llama-index-vector-stores-mongodb==0.6.0", ] From 463614ee0e9b5877a7c319d7586fab8ce7cede90 Mon Sep 17 00:00:00 2001 From: sitloboi2012 Date: Sat, 22 Feb 2025 14:52:22 +0700 Subject: [PATCH 02/12] update mongodb query engine class again Signed-off-by: sitloboi2012 --- autogen/agentchat/contrib/rag/mongodb.py | 261 +++++++---------------- notebook/docling_md_query_engine.ipynb | 139 ++++++------ 2 files changed, 138 insertions(+), 262 deletions(-) diff --git a/autogen/agentchat/contrib/rag/mongodb.py b/autogen/agentchat/contrib/rag/mongodb.py index e7c86205fd..b10fa6564e 100644 --- a/autogen/agentchat/contrib/rag/mongodb.py +++ b/autogen/agentchat/contrib/rag/mongodb.py @@ -4,17 +4,16 @@ import logging from pathlib import Path -from typing import Any, List, Optional, Union +from typing import Any, Callable, Optional from autogen.agentchat.contrib.rag.query_engine import VectorDbQueryEngine -from autogen.agentchat.contrib.vectordb.base import VectorDBFactory +from autogen.agentchat.contrib.vectordb.base import Document, VectorDBFactory +from autogen.agentchat.contrib.vectordb.mongodb import MongoDBAtlasVectorDB from autogen.import_utils import optional_import_block, require_optional_import with optional_import_block(): - from llama_index.core.llms import LLM - from llama_index.llms.openai import OpenAI + from llama_index.core import SimpleDirectoryReader, StorageContext, VectorStoreIndex from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch - from pymongo import MongoClient DEFAULT_COLLECTION_NAME = "docling-parsed-docs" @@ -25,209 +24,97 @@ @require_optional_import(["pymongo", 
"llama_index"], "rag") class MongoDBQueryEngine(VectorDbQueryEngine): - """ - MongoDBQueryEngine is a production-ready implementation of the VectorDbQueryEngine for MongoDB. - - This engine leverages the VectorDBFactory to instantiate a MongoDB vector database - (MongoDBAtlasVectorDB) and wraps its collection with a LlamaIndex vector store - (MongoDBAtlasVectorSearch) to enable document indexing and retrieval. - - Conceptually, it mirrors the approach used in DoclingMdQueryEngine for ChromaDB. - It provides methods to: - - Connect to the database. - - Initialize (or reinitialize) the collection with documents. - - Add new documents. - - Execute natural language queries against the vector index. - """ - - def __init__( # type: ignore[misc, no-any-unimported, return] + def __init__( self, - connection_string: str, - db_name: str = "vector_db", - collection_name: str = "default_collection", - vector_index_name: str = "vector_index", - embedding_function: Optional[Any] = None, - llm: Optional[LLM] = None, - **kwargs: Any, - ) -> Any | None: - """ - Initializes the MongoDBQueryEngine. - - Args: - connection_string: MongoDB connection string. - db_name: Name of the MongoDB database. - collection_name: Name of the collection to use (default is DEFAULT_COLLECTION_NAME). - vector_index_name: Name of the vector search index. - embedding_function: Function to compute embeddings (if needed by the underlying vector DB). - llm: LLM for query processing (default uses OpenAI's GPT-4 variant). - **kwargs: Additional keyword arguments. - """ - self.connection_string = connection_string - self.db_name = db_name - self.collection_name = collection_name or DEFAULT_COLLECTION_NAME - self.vector_index_name = vector_index_name - - # Set up the LLM; if not provided, use a default OpenAI model. - self.llm = llm or OpenAI(model="gpt-4o", temperature=0.0) # type: ignore - - # Create a MongoDB client for use by the vector search wrapper. 
- self.mongodb_client = MongoClient(connection_string) - - # Initialize the LlamaIndex-style vector store wrapper for advanced query pipelines. - self.vector_search = MongoDBAtlasVectorSearch( - mongodb_client=self.mongodb_client, - db_name=db_name, - collection_name=collection_name, - vector_index_name=vector_index_name, - **kwargs, - ) - - # Create the full vector database instance via the VectorDBFactory. - self.vector_db = VectorDBFactory.create_vector_db( - "mongodb", + connection_string: str = "", + database_name: str = "vector_db", + embedding_function: Optional[Callable[..., Any]] = None, + collection_name: str = DEFAULT_COLLECTION_NAME, + index_name: str = "vector_index", + ): + super().__init__() + + self.vector_db: MongoDBAtlasVectorDB = VectorDBFactory.create_vector_db( # type: ignore[assignment] + db_type="mongodb", connection_string=connection_string, - database_name=db_name, - collection_name=collection_name, - index_name=vector_index_name, + database_name=database_name, + index_name=index_name, embedding_function=embedding_function, - **kwargs, + collection_name=collection_name, ) + self.vector_search_engine = MongoDBAtlasVectorSearch( + mongodb_client=self.vector_db.client, db_name=database_name, collection_name=collection_name + ) + self.storage_context = StorageContext.from_defaults(vector_store=self.vector_search_engine) + self.indexer = None - def connect_db(self, *args: Any, **kwargs: Any) -> bool: - """ - Connect to the MongoDB database by issuing a ping. - - Returns: - True if the connection is successful; False otherwise. 
- """ - try: - self.mongodb_client.admin.command("ping") - return True - except Exception as error: - logger.error("Failed to connect to MongoDB: %s", error) - return False - - def init_db( + def init_db( # type: ignore[no-untyped-def] self, - new_doc_dir: Optional[Union[Path, str]] = None, - new_doc_paths: Optional[List[Union[Path, str]]] = None, - overwrite: bool = True, - *args: Any, - **kwargs: Any, - ) -> bool: - """ - Initialize the database with documents. - - This method: - 1. Connects to MongoDB. - 2. Creates (or overwrites) the target collection via the vector DB interface. - 3. Loads documents from a directory and/or file paths. - 4. Inserts the documents into the collection. - 5. Creates the vector search index. + new_doc_dir: Optional[str | Path] = None, + new_doc_paths: Optional[list[str | Path]] = None, + *args, + **kwargs, + ) -> bool: # type: ignore[no-untyped-def] + """Initialize the database with the input documents or records. + + This method initializes database with the input documents or records. + Usually, it takes the following steps, + 1. connecting to a database. + 2. insert records + 3. build indexes etc. Args: - new_doc_dir: Directory containing document files. - new_doc_paths: List of file paths to individual documents. - overwrite: If True, the existing collection is overwritten. - *args, **kwargs: Additional arguments. + new_doc_dir: a dir of input documents that are used to create the records in database. + new_doc_paths: + a list of input documents that are used to create the records in database. + a document can be a path to a file or a url. + *args: Any additional arguments + **kwargs: Any additional keyword arguments Returns: - True if initialization is successful; False otherwise. 
+ bool: True if initialization is successful, False otherwise """ if not self.connect_db(): return False - try: - self.vector_db.create_collection( - collection_name=self.collection_name, overwrite=overwrite, get_or_create=True - ) - except Exception as e: - logger.error("Error creating collection: %s", e) - return False - - # Load documents from file paths and/or a directory. - docs = [] - if new_doc_paths: - for i, doc_path in enumerate(new_doc_paths): - path_obj = Path(doc_path) - if path_obj.is_file(): - content = path_obj.read_text(encoding="utf-8") - docs.append({ - "id": f"doc_{i}", - "content": content, - "metadata": None, - "embedding": None, # Embeddings will be computed internally. - }) - if new_doc_dir: - dir_path = Path(new_doc_dir) - for i, file in enumerate(dir_path.glob("*")): - if file.is_file(): - content = file.read_text(encoding="utf-8") - docs.append({"id": f"doc_dir_{i}", "content": content, "metadata": None, "embedding": None}) - - if docs: - self.vector_db.insert_docs(docs, collection_name=self.collection_name) # type: ignore[arg-type] - - # Create the vector search index. - try: - dimensions = getattr(self.vector_db, "dimensions", 1536) - self.vector_search.create_vector_search_index(dimensions=dimensions, path="embedding", similarity="cosine") - except Exception as e: - logger.error("Error creating vector search index: %s", e) - + self.add_records(new_doc_dir, new_doc_paths) # type: ignore[no-untyped-call] + self.indexer = VectorStoreIndex.from_vector_store( + self.vector_search_engine, storage_context=self.storage_context + ) return True - def add_records( - self, - new_doc_dir: Optional[Union[Path, str]] = None, - new_doc_paths_or_urls: Optional[List[Union[Path, str]]] = None, - *args: Any, - **kwargs: Any, - ) -> bool: + def connect_db(self, *args, **kwargs) -> bool: # type: ignore[no-untyped-def] """ - Add new documents to the existing collection. 
- - Loads documents from the specified directory and/or file paths and inserts them into the vector DB. + Connect to the MongoDB database by issuing a ping. Returns: - True if records were added successfully; False otherwise. + True if the connection is successful; False otherwise. """ - docs = [] - if new_doc_paths_or_urls: - for i, doc_path in enumerate(new_doc_paths_or_urls): - path_obj = Path(doc_path) - if path_obj.is_file(): - content = path_obj.read_text(encoding="utf-8") - docs.append({"id": f"new_doc_{i}", "content": content, "metadata": None, "embedding": None}) - if new_doc_dir: - dir_path = Path(new_doc_dir) - for i, file in enumerate(dir_path.glob("*")): - if file.is_file(): - content = file.read_text(encoding="utf-8") - docs.append({"id": f"new_doc_dir_{i}", "content": content, "metadata": None, "embedding": None}) - if docs: - self.vector_db.insert_docs(docs, collection_name=self.collection_name) # type: ignore[arg-type] + try: + self.vector_db.client.admin.command("ping") return True - return False - - def query(self, question: str, *args: Any, **kwargs: Any) -> str: - """ - Execute a natural language query against the vector database. + except Exception as error: + logger.error("Failed to connect to MongoDB: %s", error) + return False - Converts the query string into a vector search, retrieves the most relevant document, - and returns its content. 
+ def add_records(self, new_doc_dir=None, new_doc_paths_or_urls=None, *args, **kwargs): # type: ignore[no-untyped-def] + document_list = [] # type: ignore[var-annotated] + if new_doc_dir: + document_list.extend(Path(new_doc_dir).glob("**/*")) + if new_doc_paths_or_urls: + document_list.append(new_doc_paths_or_urls) + + document_reader = SimpleDirectoryReader(input_files=document_list).load_data() + self.vector_db.insert_docs([ + Document( # type: ignore[typeddict-item, typeddict-unknown-key] + id=document.id_, + content=document.text, + metadata=document.metadata, + ) + for document in document_reader + ]) - Args: - question: The natural language query. - *args, **kwargs: Additional query parameters (e.g., n_results, distance_threshold). + def query(self, question, *args, **kwargs): # type: ignore[no-untyped-def] + response = self.indexer.as_chat_engine().query(question) # type: ignore[attr-defined] - Returns: - The content of the top matching document, or an empty string if no match is found. 
- """ - results = self.vector_db.retrieve_docs( - queries=[question], collection_name=self.collection_name, n_results=1, **kwargs - ) - if results and results[0]: - best_doc, _ = results[0][0] - return best_doc.get("content", "") - return "" + return response diff --git a/notebook/docling_md_query_engine.ipynb b/notebook/docling_md_query_engine.ipynb index 53b9d6e984..6daf3d6a3d 100644 --- a/notebook/docling_md_query_engine.ipynb +++ b/notebook/docling_md_query_engine.ipynb @@ -174,6 +174,17 @@ "# Docling MD Query Engine MongoDB" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from chromadb.utils import embedding_functions\n", + "\n", + "openai_ef = embedding_functions.OpenAIEmbeddingFunction(api_key=\"\", model_name=\"text-embedding-ada-002\")" + ] + }, { "cell_type": "code", "execution_count": null, @@ -181,85 +192,63 @@ "outputs": [], "source": [ "import os\n", - "from pathlib import Path\n", "\n", + "os.environ[\"OPENAI_API_KEY\"] = \"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "filename = \"/root/ag2/test/agents/experimental/document_agent/pdf_parsed/nvidia_10k_2024.md\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ "from autogen.agentchat.contrib.rag.mongodb import MongoDBQueryEngine\n", "\n", - "\n", - "def main():\n", - " # Set up the MongoDB connection string (adjust if necessary)\n", - " connection_string = \"\"\n", - "\n", - " # Create an instance of MongoDBQueryEngine\n", - " engine = MongoDBQueryEngine(\n", - " connection_string=connection_string,\n", - " db_name=\"my_vector_db\",\n", - " collection_name=\"my_collection\",\n", - " vector_index_name=\"vector_index\",\n", - " )\n", - "\n", - " # Test the database connection\n", - " if engine.connect_db():\n", - " print(\"Connected to MongoDB successfully.\")\n", - " else:\n", - " print(\"Failed to 
connect to MongoDB.\")\n", - " return\n", - "\n", - " # -------------------------------\n", - " # Initialize the database with documents\n", - " # -------------------------------\n", - " # Create a sample directory with text files to ingest\n", - " sample_dir = Path(\"sample_docs\")\n", - " sample_dir.mkdir(exist_ok=True)\n", - "\n", - " sample_texts = [\n", - " \"This is a sample document about MongoDB vector search.\",\n", - " \"Another document discussing RAG and vector databases.\",\n", - " ]\n", - "\n", - " for i, text in enumerate(sample_texts):\n", - " file_path = sample_dir / f\"doc_{i}.txt\"\n", - " with open(file_path, \"w\", encoding=\"utf-8\") as f:\n", - " f.write(text)\n", - "\n", - " # Initialize the database (creates/overwrites collection, ingests docs, and builds index)\n", - " if engine.init_db(new_doc_dir=str(sample_dir), overwrite=True):\n", - " print(\"Database initialized with sample documents.\")\n", - " else:\n", - " print(\"Database initialization failed.\")\n", - "\n", - " # -------------------------------\n", - " # Add additional records\n", - " # -------------------------------\n", - " add_dir = Path(\"additional_docs\")\n", - " add_dir.mkdir(exist_ok=True)\n", - "\n", - " additional_texts = [\n", - " \"Additional record about MongoDBAtlasVectorSearch integration.\",\n", - " \"Additional record on document ingestion for RAG.\",\n", - " ]\n", - "\n", - " for i, text in enumerate(additional_texts):\n", - " file_path = add_dir / f\"additional_doc_{i}.txt\"\n", - " with open(file_path, \"w\", encoding=\"utf-8\") as f:\n", - " f.write(text)\n", - "\n", - " if engine.add_records(new_doc_dir=str(add_dir)):\n", - " print(\"Additional records added.\")\n", - " else:\n", - " print(\"No additional records were added.\")\n", - "\n", - " # -------------------------------\n", - " # Execute a query\n", - " # -------------------------------\n", - " query_text = \"Explain MongoDB vector search.\"\n", - " result = engine.query(query_text)\n", - " 
print(\"Query result:\", result)\n", - "\n", - "\n", - "if __name__ == \"__main__\":\n", - " main()" + "query_engine = MongoDBQueryEngine(connection_string=\"\", embedding_function=openai_ef, database_name=\"vector_db_4\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query_engine.init_db(new_doc_paths=filename)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(query_engine.query(\"How much money did Nvidia spend in research and development\"))" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query_engine.add_records(new_doc_paths_or_urls=filename)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { From 0adeb5354617e94a248f2b3468b1c789e0a4acb0 Mon Sep 17 00:00:00 2001 From: sitloboi2012 Date: Sat, 22 Feb 2025 19:58:43 +0700 Subject: [PATCH 03/12] update mongodb query engine to use docling Signed-off-by: sitloboi2012 --- autogen/agentchat/contrib/rag/mongodb.py | 48 ++++++++++++++++++------ pyproject.toml | 2 + 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/autogen/agentchat/contrib/rag/mongodb.py b/autogen/agentchat/contrib/rag/mongodb.py index b10fa6564e..42fd5c9dda 100644 --- a/autogen/agentchat/contrib/rag/mongodb.py +++ b/autogen/agentchat/contrib/rag/mongodb.py @@ -7,12 +7,14 @@ from typing import Any, Callable, Optional from autogen.agentchat.contrib.rag.query_engine import VectorDbQueryEngine -from autogen.agentchat.contrib.vectordb.base import Document, VectorDBFactory +from autogen.agentchat.contrib.vectordb.base import VectorDBFactory from autogen.agentchat.contrib.vectordb.mongodb import MongoDBAtlasVectorDB from autogen.import_utils import optional_import_block, require_optional_import with optional_import_block(): - from llama_index.core import 
SimpleDirectoryReader, StorageContext, VectorStoreIndex + from llama_index.core import StorageContext, VectorStoreIndex + from llama_index.node_parser.docling import DoclingNodeParser + from llama_index.readers.docling import DoclingReader from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch @@ -77,7 +79,9 @@ def init_db( # type: ignore[no-untyped-def] if not self.connect_db(): return False - self.add_records(new_doc_dir, new_doc_paths) # type: ignore[no-untyped-call] + if new_doc_dir or new_doc_paths: + self.add_records(new_doc_dir, new_doc_paths) # type: ignore[no-untyped-call] + self.indexer = VectorStoreIndex.from_vector_store( self.vector_search_engine, storage_context=self.storage_context ) @@ -104,15 +108,35 @@ def add_records(self, new_doc_dir=None, new_doc_paths_or_urls=None, *args, **kwa if new_doc_paths_or_urls: document_list.append(new_doc_paths_or_urls) - document_reader = SimpleDirectoryReader(input_files=document_list).load_data() - self.vector_db.insert_docs([ - Document( # type: ignore[typeddict-item, typeddict-unknown-key] - id=document.id_, - content=document.text, - metadata=document.metadata, - ) - for document in document_reader - ]) + reader = DoclingReader(export_type=DoclingReader.ExportType.JSON) + node_parser = DoclingNodeParser() + + documents = reader.load_data(document_list) + parser_nodes = node_parser._parse_nodes(documents) + + # print("Parser data: ", parser) + + # # document_reader = SimpleDirectoryReader(input_files=document_list).load_data() + # self.vector_db.insert_docs([ + # Document( # type: ignore[typeddict-item, typeddict-unknown-key] + # id=document.id_, + # content=document.get_content(), + # metadata=document.metadata, + # ) + # for document in parser + # ]) + + docs_to_insert = [] + for node in parser_nodes: + doc_dict = { + "id": node.id_, # Ensure the key is 'id' + "content": node.get_content(), + "metadata": node.metadata, + } + docs_to_insert.append(doc_dict) + + # Insert documents into vector DB. 
+ self.vector_db.insert_docs(docs_to_insert) # type: ignore[arg-type] def query(self, question, *args, **kwargs): # type: ignore[no-untyped-def] response = self.indexer.as_chat_engine().query(question) # type: ignore[attr-defined] diff --git a/pyproject.toml b/pyproject.toml index 0044a5e33e..432eb200dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -140,6 +140,8 @@ rag = [ "llama-index>=0.12,<1", "llama-index-vector-stores-chroma==0.4.1", "llama-index-vector-stores-mongodb==0.6.0", + "llama-index-node-parser-docling==0.3.1", + "llama-index-readers-docling==0.3.1" ] From 37356a35f91edbe132fd56c60d0d20e0309ef26a Mon Sep 17 00:00:00 2001 From: sitloboi2012 Date: Sat, 22 Feb 2025 21:56:27 +0700 Subject: [PATCH 04/12] update and finalize the mongodb query engine with documentation Signed-off-by: sitloboi2012 --- autogen/agentchat/contrib/rag/mongodb.py | 229 +++++++++++++++-------- notebook/docling_md_query_engine.ipynb | 37 ++-- 2 files changed, 170 insertions(+), 96 deletions(-) diff --git a/autogen/agentchat/contrib/rag/mongodb.py b/autogen/agentchat/contrib/rag/mongodb.py index 42fd5c9dda..878dec0c27 100644 --- a/autogen/agentchat/contrib/rag/mongodb.py +++ b/autogen/agentchat/contrib/rag/mongodb.py @@ -4,7 +4,7 @@ import logging from pathlib import Path -from typing import Any, Callable, Optional +from typing import Any, Callable, List, Optional, Union from autogen.agentchat.contrib.rag.query_engine import VectorDbQueryEngine from autogen.agentchat.contrib.vectordb.base import VectorDBFactory @@ -12,12 +12,10 @@ from autogen.import_utils import optional_import_block, require_optional_import with optional_import_block(): - from llama_index.core import StorageContext, VectorStoreIndex - from llama_index.node_parser.docling import DoclingNodeParser - from llama_index.readers.docling import DoclingReader + from llama_index.core import Document, SimpleDirectoryReader, StorageContext, VectorStoreIndex + from llama_index.core.node_parser import SentenceSplitter from 
llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch - DEFAULT_COLLECTION_NAME = "docling-parsed-docs" logging.basicConfig(level=logging.INFO) @@ -26,6 +24,19 @@ @require_optional_import(["pymongo", "llama_index"], "rag") class MongoDBQueryEngine(VectorDbQueryEngine): + """ + A query engine backed by MongoDB Atlas that supports document insertion and querying. + + This engine initializes a vector database, builds an index from input documents, + and allows querying using the chat engine interface. + + Attributes: + vector_db (MongoDBAtlasVectorDB): The MongoDB vector database instance. + vector_search_engine (MongoDBAtlasVectorSearch): The vector search engine. + storage_context (StorageContext): The storage context for the vector store. + indexer (Optional[VectorStoreIndex]): The index built from the documents. + """ + def __init__( self, connection_string: str = "", @@ -34,8 +45,17 @@ def __init__( collection_name: str = DEFAULT_COLLECTION_NAME, index_name: str = "vector_index", ): - super().__init__() + """ + Initialize the MongoDBQueryEngine. + Args: + connection_string (str): The MongoDB connection string. + database_name (str): The name of the database to use. + embedding_function (Optional[Callable[..., Any]]): The function to compute embeddings. + collection_name (str): The name of the collection to use. + index_name (str): The name of the vector index. 
+ """ + super().__init__() self.vector_db: MongoDBAtlasVectorDB = VectorDBFactory.create_vector_db( # type: ignore[assignment] db_type="mongodb", connection_string=connection_string, @@ -45,100 +65,151 @@ def __init__( collection_name=collection_name, ) self.vector_search_engine = MongoDBAtlasVectorSearch( - mongodb_client=self.vector_db.client, db_name=database_name, collection_name=collection_name + mongodb_client=self.vector_db.client, + db_name=database_name, + collection_name=collection_name, ) self.storage_context = StorageContext.from_defaults(vector_store=self.vector_search_engine) - self.indexer = None + self.indexer: Optional[VectorStoreIndex] = None # type: ignore[no-any-unimported] + + def connect_db(self, *args, **kwargs) -> bool: # type: ignore[no-untyped-def] + """ + Connect to the MongoDB database by issuing a ping. + + Returns: + bool: True if the connection is successful; False otherwise. + """ + try: + self.vector_db.client.admin.command("ping") + logger.info("Connected to MongoDB successfully.") + return True + except Exception as error: + logger.error("Failed to connect to MongoDB: %s", error) + return False def init_db( # type: ignore[no-untyped-def] self, - new_doc_dir: Optional[str | Path] = None, - new_doc_paths: Optional[list[str | Path]] = None, + new_doc_dir: Optional[Union[str, Path]] = None, + new_doc_paths: Optional[List[Union[str, Path]]] = None, *args, **kwargs, - ) -> bool: # type: ignore[no-untyped-def] - """Initialize the database with the input documents or records. - - This method initializes database with the input documents or records. - Usually, it takes the following steps, - 1. connecting to a database. - 2. insert records - 3. build indexes etc. + ) -> bool: + """ + Initialize the database by loading documents from the given directory or file paths, + then building an index. Args: - new_doc_dir: a dir of input documents that are used to create the records in database. 
- new_doc_paths: - a list of input documents that are used to create the records in database. - a document can be a path to a file or a url. - *args: Any additional arguments - **kwargs: Any additional keyword arguments + new_doc_dir (Optional[Union[str, Path]]): Directory containing input documents. + new_doc_paths (Optional[List[Union[str, Path]]]): List of document paths or URLs. Returns: - bool: True if initialization is successful, False otherwise + bool: True if initialization is successful; False otherwise. """ if not self.connect_db(): return False - if new_doc_dir or new_doc_paths: - self.add_records(new_doc_dir, new_doc_paths) # type: ignore[no-untyped-call] - - self.indexer = VectorStoreIndex.from_vector_store( - self.vector_search_engine, storage_context=self.storage_context - ) - return True + # Gather document paths. + document_list: List[Union[str, Path]] = [] + if new_doc_dir: + document_list.extend(Path(new_doc_dir).glob("**/*")) + if new_doc_paths: + document_list.extend(new_doc_paths) - def connect_db(self, *args, **kwargs) -> bool: # type: ignore[no-untyped-def] - """ - Connect to the MongoDB database by issuing a ping. + if not document_list: + logger.warning("No input documents provided to initialize the database.") + return False - Returns: - True if the connection is successful; False otherwise. 
- """ try: - self.vector_db.client.admin.command("ping") + documents = SimpleDirectoryReader(input_files=document_list).load_data() + self.indexer = VectorStoreIndex.from_documents(documents, storage_context=self.storage_context) + logger.info("Database initialized with %d documents.", len(documents)) return True - except Exception as error: - logger.error("Failed to connect to MongoDB: %s", error) + except Exception as e: + logger.error("Failed to initialize the database: %s", e) return False - def add_records(self, new_doc_dir=None, new_doc_paths_or_urls=None, *args, **kwargs): # type: ignore[no-untyped-def] - document_list = [] # type: ignore[var-annotated] + def add_records( # type: ignore[no-untyped-def, override] + self, + new_doc_dir: Optional[Union[str, Path]] = None, + new_doc_paths_or_urls: Optional[Union[List[Union[str, Path]], Union[str, Path]]] = None, + chunk_size: int = 500, + chunk_overlap: int = 300, + *args, + **kwargs, + ) -> None: + """ + Load, parse, and insert documents into the index. + + This method uses a SentenceSplitter to break documents into chunks before insertion. + + Args: + new_doc_dir (Optional[Union[str, Path]]): Directory containing input documents. + new_doc_paths_or_urls (Optional[Union[List[Union[str, Path]], Union[str, Path]]]): + List of document paths or a single document path/URL. + """ + # Collect document paths. 
+ document_list: List[Union[str, Path]] = [] if new_doc_dir: document_list.extend(Path(new_doc_dir).glob("**/*")) if new_doc_paths_or_urls: - document_list.append(new_doc_paths_or_urls) - - reader = DoclingReader(export_type=DoclingReader.ExportType.JSON) - node_parser = DoclingNodeParser() - - documents = reader.load_data(document_list) - parser_nodes = node_parser._parse_nodes(documents) - - # print("Parser data: ", parser) - - # # document_reader = SimpleDirectoryReader(input_files=document_list).load_data() - # self.vector_db.insert_docs([ - # Document( # type: ignore[typeddict-item, typeddict-unknown-key] - # id=document.id_, - # content=document.get_content(), - # metadata=document.metadata, - # ) - # for document in parser - # ]) - - docs_to_insert = [] - for node in parser_nodes: - doc_dict = { - "id": node.id_, # Ensure the key is 'id' - "content": node.get_content(), - "metadata": node.metadata, - } - docs_to_insert.append(doc_dict) - - # Insert documents into vector DB. - self.vector_db.insert_docs(docs_to_insert) # type: ignore[arg-type] - - def query(self, question, *args, **kwargs): # type: ignore[no-untyped-def] - response = self.indexer.as_chat_engine().query(question) # type: ignore[attr-defined] - - return response + if isinstance(new_doc_paths_or_urls, (list, tuple)): + document_list.extend(new_doc_paths_or_urls) + else: + document_list.append(new_doc_paths_or_urls) + + if not document_list: + logger.warning("No documents found for adding records.") + return + + try: + raw_documents = SimpleDirectoryReader(input_files=document_list).load_data() + except Exception as e: + logger.error("Error loading documents: %s", e) + return + + node_parser = SentenceSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + doc_chunks: List[Document] = [] # type: ignore[no-any-unimported] + + for document in raw_documents: + try: + # Split the document text into chunks. 
+ chunks = node_parser.split_text(document.text) + metadata = document.metadata + doc_id = document.id_ # Ensure this is not a tuple. + for chunk in chunks: + new_doc = Document(text=chunk, id_=doc_id, metadata=metadata) + doc_chunks.append(new_doc) + except Exception as e: + logger.error("Error parsing document %s: %s", document.id_, e) + + if not doc_chunks: + logger.warning("No document chunks created for insertion.") + return + + # Insert document chunks using the indexer. + try: + for doc in doc_chunks: + self.indexer.insert(doc) # type: ignore[union-attr] + logger.info("Inserted %d document chunks successfully.", len(doc_chunks)) + except Exception as e: + logger.error("Error inserting documents into the index: %s", e) + + def query(self, question: str, *args, **kwargs) -> Any: # type: ignore[no-untyped-def] + """ + Query the index using the given question. + + Args: + question (str): The query string. + + Returns: + Any: The response from the chat engine, or None if an error occurs. 
+ """ + try: + response = self.indexer.as_chat_engine().query(question) # type: ignore[union-attr] + return response + except Exception as e: + logger.error("Query failed: %s", e) + return None diff --git a/notebook/docling_md_query_engine.ipynb b/notebook/docling_md_query_engine.ipynb index 6daf3d6a3d..2bfbe59ec0 100644 --- a/notebook/docling_md_query_engine.ipynb +++ b/notebook/docling_md_query_engine.ipynb @@ -174,17 +174,6 @@ "# Docling MD Query Engine MongoDB" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from chromadb.utils import embedding_functions\n", - "\n", - "openai_ef = embedding_functions.OpenAIEmbeddingFunction(api_key=\"\", model_name=\"text-embedding-ada-002\")" - ] - }, { "cell_type": "code", "execution_count": null, @@ -193,7 +182,13 @@ "source": [ "import os\n", "\n", - "os.environ[\"OPENAI_API_KEY\"] = \"\"" + "from chromadb.utils import embedding_functions\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = \"\"\n", + "openai_ef = embedding_functions.OpenAIEmbeddingFunction(\n", + " api_key=\"\",\n", + " model_name=\"text-embedding-ada-002\",\n", + ")" ] }, { @@ -202,7 +197,7 @@ "metadata": {}, "outputs": [], "source": [ - "filename = \"/root/ag2/test/agents/experimental/document_agent/pdf_parsed/nvidia_10k_2024.md\"" + "input_dir = \"/root/ag2/test/agents/experimental/document_agent/pdf_parsed/\"" ] }, { @@ -213,7 +208,11 @@ "source": [ "from autogen.agentchat.contrib.rag.mongodb import MongoDBQueryEngine\n", "\n", - "query_engine = MongoDBQueryEngine(connection_string=\"\", embedding_function=openai_ef, database_name=\"vector_db_4\")" + "query_engine = MongoDBQueryEngine(\n", + " connection_string=\"\",\n", + " embedding_function=openai_ef,\n", + " database_name=\"vector_db_1\",\n", + ")" ] }, { @@ -222,7 +221,7 @@ "metadata": {}, "outputs": [], "source": [ - "query_engine.init_db(new_doc_paths=filename)" + "query_engine.init_db(new_doc_paths=[input_dir + \"nvidia_10k_2024.md\"])" 
] }, { @@ -240,7 +239,7 @@ "metadata": {}, "outputs": [], "source": [ - "query_engine.add_records(new_doc_paths_or_urls=filename)" + "query_engine.add_records(new_doc_paths_or_urls=[input_dir + \"Toast_financial_report.md\"])" ] }, { @@ -248,7 +247,11 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "question = \"What is the trading symbol for Toast\"\n", + "answer = query_engine.query(question)\n", + "print(answer)" + ] } ], "metadata": { From d5c928d877070d4954a996147cb9f490d256ae13 Mon Sep 17 00:00:00 2001 From: sitloboi2012 Date: Sun, 23 Feb 2025 23:36:29 +0700 Subject: [PATCH 05/12] refactor the add_records again to simplify the solution, update based on comment Signed-off-by: sitloboi2012 --- autogen/agentchat/contrib/rag/mongodb.py | 51 +++++++----------------- notebook/docling_md_query_engine.ipynb | 7 ++-- pyproject.toml | 4 +- 3 files changed, 19 insertions(+), 43 deletions(-) diff --git a/autogen/agentchat/contrib/rag/mongodb.py b/autogen/agentchat/contrib/rag/mongodb.py index 878dec0c27..04846ec8c2 100644 --- a/autogen/agentchat/contrib/rag/mongodb.py +++ b/autogen/agentchat/contrib/rag/mongodb.py @@ -6,14 +6,12 @@ from pathlib import Path from typing import Any, Callable, List, Optional, Union -from autogen.agentchat.contrib.rag.query_engine import VectorDbQueryEngine from autogen.agentchat.contrib.vectordb.base import VectorDBFactory from autogen.agentchat.contrib.vectordb.mongodb import MongoDBAtlasVectorDB from autogen.import_utils import optional_import_block, require_optional_import with optional_import_block(): - from llama_index.core import Document, SimpleDirectoryReader, StorageContext, VectorStoreIndex - from llama_index.core.node_parser import SentenceSplitter + from llama_index.core import SimpleDirectoryReader, StorageContext, VectorStoreIndex from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch DEFAULT_COLLECTION_NAME = "docling-parsed-docs" @@ -23,7 +21,7 @@ 
@require_optional_import(["pymongo", "llama_index"], "rag") -class MongoDBQueryEngine(VectorDbQueryEngine): +class MongoDBQueryEngine: """ A query engine backed by MongoDB Atlas that supports document insertion and querying. @@ -70,9 +68,9 @@ def __init__( collection_name=collection_name, ) self.storage_context = StorageContext.from_defaults(vector_store=self.vector_search_engine) - self.indexer: Optional[VectorStoreIndex] = None # type: ignore[no-any-unimported] + self.index: Optional[VectorStoreIndex] = None # type: ignore[no-any-unimported] - def connect_db(self, *args, **kwargs) -> bool: # type: ignore[no-untyped-def] + def connect_db(self, *args: Any, **kwargs: Any) -> bool: """ Connect to the MongoDB database by issuing a ping. @@ -87,12 +85,12 @@ def connect_db(self, *args, **kwargs) -> bool: # type: ignore[no-untyped-def] logger.error("Failed to connect to MongoDB: %s", error) return False - def init_db( # type: ignore[no-untyped-def] + def init_db( self, new_doc_dir: Optional[Union[str, Path]] = None, new_doc_paths: Optional[List[Union[str, Path]]] = None, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ) -> bool: """ Initialize the database by loading documents from the given directory or file paths, @@ -128,14 +126,12 @@ def init_db( # type: ignore[no-untyped-def] logger.error("Failed to initialize the database: %s", e) return False - def add_records( # type: ignore[no-untyped-def, override] + def add_records( self, new_doc_dir: Optional[Union[str, Path]] = None, new_doc_paths_or_urls: Optional[Union[List[Union[str, Path]], Union[str, Path]]] = None, - chunk_size: int = 500, - chunk_overlap: int = 300, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ) -> None: """ Load, parse, and insert documents into the index. 
@@ -167,37 +163,18 @@ def add_records( # type: ignore[no-untyped-def, override] logger.error("Error loading documents: %s", e) return - node_parser = SentenceSplitter( - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - ) - doc_chunks: List[Document] = [] # type: ignore[no-any-unimported] - - for document in raw_documents: - try: - # Split the document text into chunks. - chunks = node_parser.split_text(document.text) - metadata = document.metadata - doc_id = document.id_ # Ensure this is not a tuple. - for chunk in chunks: - new_doc = Document(text=chunk, id_=doc_id, metadata=metadata) - doc_chunks.append(new_doc) - except Exception as e: - logger.error("Error parsing document %s: %s", document.id_, e) - - if not doc_chunks: + if not raw_documents: logger.warning("No document chunks created for insertion.") return - # Insert document chunks using the indexer. try: - for doc in doc_chunks: + for doc in raw_documents: self.indexer.insert(doc) # type: ignore[union-attr] - logger.info("Inserted %d document chunks successfully.", len(doc_chunks)) + logger.info("Inserted %d document chunks successfully.", len(raw_documents)) except Exception as e: logger.error("Error inserting documents into the index: %s", e) - def query(self, question: str, *args, **kwargs) -> Any: # type: ignore[no-untyped-def] + def query(self, question: str, *args: Any, **kwargs: Any) -> Any: """ Query the index using the given question. 
diff --git a/notebook/docling_md_query_engine.ipynb b/notebook/docling_md_query_engine.ipynb index 2bfbe59ec0..efeaf577af 100644 --- a/notebook/docling_md_query_engine.ipynb +++ b/notebook/docling_md_query_engine.ipynb @@ -211,7 +211,7 @@ "query_engine = MongoDBQueryEngine(\n", " connection_string=\"\",\n", " embedding_function=openai_ef,\n", - " database_name=\"vector_db_1\",\n", + " database_name=\"vector_db_2\",\n", ")" ] }, @@ -221,7 +221,8 @@ "metadata": {}, "outputs": [], "source": [ - "query_engine.init_db(new_doc_paths=[input_dir + \"nvidia_10k_2024.md\"])" + "# nvidia_10k_2024.md\n", + "query_engine.init_db(new_doc_paths=[input_dir + \"Toast_financial_report.md\"])" ] }, { @@ -239,7 +240,7 @@ "metadata": {}, "outputs": [], "source": [ - "query_engine.add_records(new_doc_paths_or_urls=[input_dir + \"Toast_financial_report.md\"])" + "query_engine.add_records(new_doc_paths_or_urls=[input_dir + \"nvidia_10k_2024.md\"])" ] }, { diff --git a/pyproject.toml b/pyproject.toml index 432eb200dc..2843f39f9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,9 +139,7 @@ rag = [ "chromadb>=0.5,<1", "llama-index>=0.12,<1", "llama-index-vector-stores-chroma==0.4.1", - "llama-index-vector-stores-mongodb==0.6.0", - "llama-index-node-parser-docling==0.3.1", - "llama-index-readers-docling==0.3.1" + "llama-index-vector-stores-mongodb==0.6.0" ] From b8747d410438ee59c9c9588427d97f2e73bdd86f Mon Sep 17 00:00:00 2001 From: sitloboi2012 Date: Mon, 24 Feb 2025 00:20:01 +0700 Subject: [PATCH 06/12] add on test case for mongodb query engine WIP Signed-off-by: sitloboi2012 --- .../contrib/rag/test_mongodb_query_engine.py | 156 ++++++++++++++++++ 1 file changed, 156 insertions(+) create mode 100644 test/agentchat/contrib/rag/test_mongodb_query_engine.py diff --git a/test/agentchat/contrib/rag/test_mongodb_query_engine.py b/test/agentchat/contrib/rag/test_mongodb_query_engine.py new file mode 100644 index 0000000000..580560d0d5 --- /dev/null +++ 
b/test/agentchat/contrib/rag/test_mongodb_query_engine.py @@ -0,0 +1,156 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 +# +# Portions derived from https://github.com/microsoft/autogen are under the MIT License. +# SPDX-License-Identifier: MIT +# !/usr/bin/env python3 -m pytest + +import unittest +from pathlib import Path +from unittest.mock import MagicMock, patch + +# Replace 'your_module' with the actual module name where MongoDBQueryEngine is defined. +from autogen.agentchat.contrib.rag.mongodb import MongoDBQueryEngine + + +# A dummy client to simulate MongoDB's behavior. +class DummyCollection: + def __init__(self): + # Some methods might be called on the collection; add those as needed. + self.list_collection_names = MagicMock(return_value=[]) + + +class DummyDatabase: + def __init__(self): + # Instead of returning a new DummyCollection every time, you can use one or simulate a dict. + self._dummy_collection = DummyCollection() + + def __getitem__(self, key): + # Return a dummy collection when a collection is accessed. + return self._dummy_collection + + def list_collection_names(self): + # Provide a dummy implementation for list_collection_names. + return [] + + +class DummyClient: + def __init__(self, should_fail=False): + self.should_fail = should_fail + self.admin = MagicMock() + + def command(self, command): + if self.should_fail: + raise Exception("Ping failed") + return "ok" + + def __getitem__(self, key): + # Return a DummyDatabase when a database is accessed. + return DummyDatabase() + + +# A dummy chat engine to simulate query responses. +class DummyChatEngine: + def query(self, question): + return f"Response to: {question}" + + +class TestMongoDBQueryEngine(unittest.TestCase): + def setUp(self): + # Patch the vector database factory and search engine constructors. 
+ patcher_factory = patch("autogen.agentchat.contrib.vectordb.base.VectorDBFactory.create_vector_db") + self.mock_create_vector_db = patcher_factory.start() + self.addCleanup(patcher_factory.stop) + + patcher_search = patch("llama_index.vector_stores.mongodb.MongoDBAtlasVectorSearch") + self.mock_mongo_search = patcher_search.start() + self.addCleanup(patcher_search.stop) + + # Set up a dummy MongoDB client. + self.dummy_client = DummyClient() + dummy_vector_db = MagicMock() + dummy_vector_db.client = self.dummy_client + self.mock_create_vector_db.return_value = dummy_vector_db + + # Instantiate the engine with dummy parameters. + self.engine = MongoDBQueryEngine(connection_string="dummy", database_name="test_db") + # Pre-assign a dummy indexer (used in query and add_records). + self.engine.indexer = MagicMock() + + def test_connect_db_success(self): + # Simulate a successful ping. + self.dummy_client.should_fail = False + result = self.engine.connect_db() + self.assertTrue(result) + # Verify that the ping command was called. + self.dummy_client.admin.command.assert_called_with("ping") + + def test_connect_db_failure(self): + # Simulate a failure during the ping. + self.dummy_client.should_fail = True + self.dummy_client.admin.command.side_effect = Exception("Ping failed") + result = self.engine.connect_db() + self.assertFalse(result) + + @patch("llama_index.core.SimpleDirectoryReader") + @patch("your_module.VectorStoreIndex.from_documents") + def test_init_db_with_documents(self, mock_from_documents, mock_simple_dir_reader): + # Create dummy documents. + dummy_documents = ["doc1", "doc2"] + reader_instance = MagicMock() + reader_instance.load_data.return_value = dummy_documents + mock_simple_dir_reader.return_value = reader_instance + + # Patch Path.glob to return a list with at least one file. 
+ with patch.object(Path, "glob", return_value=[Path("dummy.txt")]): + result = self.engine.init_db(new_doc_dir="dummy_dir") + self.assertTrue(result) + # Ensure that the index was built with the dummy documents. + mock_from_documents.assert_called_with(dummy_documents, storage_context=self.engine.storage_context) + + def test_init_db_no_documents(self): + # Without any document directory or file paths, the method should return False. + result = self.engine.init_db() + self.assertFalse(result) + + @patch("your_module.SimpleDirectoryReader") + def test_add_records_no_documents(self, mock_simple_dir_reader): + # When no document paths or directory are provided, expect a warning and no processing. + with self.assertLogs(level="WARNING") as cm: + self.engine.add_records() + self.assertTrue(any("No documents found for adding records." in msg for msg in cm.output)) + + @patch("your_module.SimpleDirectoryReader") + def test_add_records_with_documents(self, mock_simple_dir_reader): + # Create dummy documents to be loaded. + dummy_documents = ["doc1", "doc2", "doc3"] + reader_instance = MagicMock() + reader_instance.load_data.return_value = dummy_documents + mock_simple_dir_reader.return_value = reader_instance + + # Ensure that the indexer's insert method is callable. + self.engine.indexer.insert = MagicMock() + + # Provide a list of dummy document paths. + self.engine.add_records(new_doc_paths_or_urls=["dummy1.txt", "dummy2.txt"]) + # Verify that insert was called for each dummy document. + self.assertEqual(self.engine.indexer.insert.call_count, len(dummy_documents)) + + def test_query_success(self): + # Simulate a working chat engine. + dummy_chat_engine = DummyChatEngine() + self.engine.indexer.as_chat_engine.return_value = dummy_chat_engine + + response = self.engine.query("Test question") + self.assertEqual(response, "Response to: Test question") + + def test_query_failure(self): + # Simulate failure by having as_chat_engine raise an exception. 
+ self.engine.indexer.as_chat_engine.side_effect = Exception("Query failed") + response = self.engine.query("Test question") + self.assertIsNone(response) + + +if __name__ == "__main__": + unittest.main() From dcb2713639ab0fb6e918cf506bfdee6f14b265c0 Mon Sep 17 00:00:00 2001 From: sitloboi2012 Date: Mon, 24 Feb 2025 00:21:33 +0700 Subject: [PATCH 07/12] add on test case for mongodb query engine WIP Signed-off-by: sitloboi2012 --- .../contrib/rag/test_mongodb_query_engine.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/test/agentchat/contrib/rag/test_mongodb_query_engine.py b/test/agentchat/contrib/rag/test_mongodb_query_engine.py index 580560d0d5..085a4457bf 100644 --- a/test/agentchat/contrib/rag/test_mongodb_query_engine.py +++ b/test/agentchat/contrib/rag/test_mongodb_query_engine.py @@ -17,23 +17,25 @@ # A dummy client to simulate MongoDB's behavior. class DummyCollection: def __init__(self): - # Some methods might be called on the collection; add those as needed. self.list_collection_names = MagicMock(return_value=[]) class DummyDatabase: def __init__(self): - # Instead of returning a new DummyCollection every time, you can use one or simulate a dict. self._dummy_collection = DummyCollection() def __getitem__(self, key): - # Return a dummy collection when a collection is accessed. + # When accessing a collection via subscripting, return the dummy collection. return self._dummy_collection def list_collection_names(self): - # Provide a dummy implementation for list_collection_names. + # Return an empty list or a list of dummy collection names. return [] + def create_collection(self, name, *args, **kwargs): + # Simulate creation of a new collection by returning a dummy collection. + return DummyCollection() + class DummyClient: def __init__(self, should_fail=False): @@ -46,7 +48,7 @@ def command(self, command): return "ok" def __getitem__(self, key): - # Return a DummyDatabase when a database is accessed. 
+ # Return a dummy database when subscripting the client. return DummyDatabase() From 0b896248da59475bd51898e979622aa5f6e81de8 Mon Sep 17 00:00:00 2001 From: sitloboi2012 Date: Mon, 24 Feb 2025 22:02:32 +0700 Subject: [PATCH 08/12] rename the file to mongodb_query_engine.py, rework the init_db and connect_db to align with reviewer's comment and expectation, update the jupyter notebook again to align with the new mongodb_query_engine code, add on llm arg into the query function Signed-off-by: sitloboi2012 --- .../{mongodb.py => mongodb_query_engine.py} | 124 ++++++++++++------ notebook/docling_md_query_engine.ipynb | 31 +++-- .../contrib/rag/test_mongodb_query_engine.py | 2 +- 3 files changed, 110 insertions(+), 47 deletions(-) rename autogen/agentchat/contrib/rag/{mongodb.py => mongodb_query_engine.py} (54%) diff --git a/autogen/agentchat/contrib/rag/mongodb.py b/autogen/agentchat/contrib/rag/mongodb_query_engine.py similarity index 54% rename from autogen/agentchat/contrib/rag/mongodb.py rename to autogen/agentchat/contrib/rag/mongodb_query_engine.py index 04846ec8c2..98810e0836 100644 --- a/autogen/agentchat/contrib/rag/mongodb.py +++ b/autogen/agentchat/contrib/rag/mongodb_query_engine.py @@ -11,8 +11,11 @@ from autogen.import_utils import optional_import_block, require_optional_import with optional_import_block(): + from langchain.base_language import BaseLanguageModel from llama_index.core import SimpleDirectoryReader, StorageContext, VectorStoreIndex + from llama_index.core.llms.llm import LLM from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch + from pymongo import MongoClient DEFAULT_COLLECTION_NAME = "docling-parsed-docs" @@ -46,39 +49,74 @@ def __init__( """ Initialize the MongoDBQueryEngine. + Note: The actual connection and creation of the vector database is deferred to + connect_db (to use an existing collection) or init_db (to create a new collection). 
+ """ + self.connection_string = connection_string + self.database_name = database_name + self.embedding_function = embedding_function + self.collection_name = collection_name + self.index_name = index_name + + # These will be initialized later. + self.vector_db: Optional[MongoDBAtlasVectorDB] = None + self.vector_search_engine = None + self.storage_context = None + self.index: Optional[VectorStoreIndex] = None # type: ignore[no-any-unimported] + + def _setup_vector_db(self, overwrite: bool) -> None: + """ + Helper method to create the vector database, vector search engine, and storage context. + Args: - connection_string (str): The MongoDB connection string. - database_name (str): The name of the database to use. - embedding_function (Optional[Callable[..., Any]]): The function to compute embeddings. - collection_name (str): The name of the collection to use. - index_name (str): The name of the vector index. + overwrite (bool): If True, create a new collection (overwriting if exists). + If False, use an existing collection. """ - super().__init__() - self.vector_db: MongoDBAtlasVectorDB = VectorDBFactory.create_vector_db( # type: ignore[assignment] + # Pass the overwrite flag to the factory if supported. 
+ self.vector_db: MongoDBAtlasVectorDB = VectorDBFactory.create_vector_db( # type: ignore[assignment, no-redef] db_type="mongodb", - connection_string=connection_string, - database_name=database_name, - index_name=index_name, - embedding_function=embedding_function, - collection_name=collection_name, + connection_string=self.connection_string, + database_name=self.database_name, + index_name=self.index_name, + embedding_function=self.embedding_function, + collection_name=self.collection_name, + overwrite=overwrite, # new parameter to control creation behavior ) self.vector_search_engine = MongoDBAtlasVectorSearch( - mongodb_client=self.vector_db.client, - db_name=database_name, - collection_name=collection_name, + mongodb_client=self.vector_db.client, # type: ignore[union-attr] + db_name=self.database_name, + collection_name=self.collection_name, ) self.storage_context = StorageContext.from_defaults(vector_store=self.vector_search_engine) - self.index: Optional[VectorStoreIndex] = None # type: ignore[no-any-unimported] + self.index = VectorStoreIndex.from_vector_store(self.vector_search_engine, storage_context=self.storage_context) - def connect_db(self, *args: Any, **kwargs: Any) -> bool: + def connect_db(self, overwrite: bool = False, *args: Any, **kwargs: Any) -> bool: """ - Connect to the MongoDB database by issuing a ping. + Connect to the MongoDB database by issuing a ping using an existing collection. + This method first checks if the target database and collection exist. + - If not, it raises an error instructing the user to run init_db. + - If the collection exists and overwrite is True, it reinitializes the database. + - Otherwise, it uses the existing collection. Returns: bool: True if the connection is successful; False otherwise. """ try: - self.vector_db.client.admin.command("ping") + # Check if the target collection exists. 
+ client = MongoClient(self.connection_string) + db = client[self.database_name] + if self.collection_name not in db.list_collection_names(): + raise ValueError( + f"Collection '{self.collection_name}' not found in database '{self.database_name}'. " + "Please run init_db to create a new collection." + ) + # Reinitialize if the caller requested overwrite. + if overwrite: + logger.info("Overwriting existing collection as requested.") + self._setup_vector_db(overwrite=True) + else: + self._setup_vector_db(overwrite=False) + self.vector_db.client.admin.command("ping") # type: ignore[union-attr] logger.info("Connected to MongoDB successfully.") return True except Exception as error: @@ -94,7 +132,8 @@ def init_db( ) -> bool: """ Initialize the database by loading documents from the given directory or file paths, - then building an index. + then building an index. This method is intended for first-time creation of the database, + so it expects that the collection does not already exist (i.e. overwrite is False). Args: new_doc_dir (Optional[Union[str, Path]]): Directory containing input documents. @@ -103,23 +142,31 @@ def init_db( Returns: bool: True if initialization is successful; False otherwise. """ - if not self.connect_db(): - return False - - # Gather document paths. - document_list: List[Union[str, Path]] = [] - if new_doc_dir: - document_list.extend(Path(new_doc_dir).glob("**/*")) - if new_doc_paths: - document_list.extend(new_doc_paths) - - if not document_list: - logger.warning("No input documents provided to initialize the database.") - return False - try: + # Check if the collection already exists. + client = MongoClient(self.connection_string) + db = client[self.database_name] + if self.collection_name in db.list_collection_names(): + raise ValueError( + f"Collection '{self.collection_name}' already exists in database '{self.database_name}'. " + "Use connect_db with overwrite=True to reinitialize it." + ) + # Set up the database without overwriting. 
+ self._setup_vector_db(overwrite=False) + self.vector_db.client.admin.command("ping") # type: ignore[union-attr] + # Gather document paths. + document_list: List[Union[str, Path]] = [] + if new_doc_dir: + document_list.extend(Path(new_doc_dir).glob("**/*")) + if new_doc_paths: + document_list.extend(new_doc_paths) + + if not document_list: + logger.warning("No input documents provided to initialize the database.") + return False + documents = SimpleDirectoryReader(input_files=document_list).load_data() - self.indexer = VectorStoreIndex.from_documents(documents, storage_context=self.storage_context) + self.index = VectorStoreIndex.from_documents(documents, storage_context=self.storage_context) logger.info("Database initialized with %d documents.", len(documents)) return True except Exception as e: @@ -169,23 +216,24 @@ def add_records( try: for doc in raw_documents: - self.indexer.insert(doc) # type: ignore[union-attr] + self.index.insert(doc) # type: ignore[union-attr] logger.info("Inserted %d document chunks successfully.", len(raw_documents)) except Exception as e: logger.error("Error inserting documents into the index: %s", e) - def query(self, question: str, *args: Any, **kwargs: Any) -> Any: + def query(self, question: str, llm: Union[str, LLM, "BaseLanguageModel"], *args: Any, **kwargs: Any) -> Any: # type: ignore[no-any-unimported, type-arg] """ Query the index using the given question. Args: question (str): The query string. + llm (Union[str, LLM, BaseLanguageModel]): The language model to use. Returns: Any: The response from the chat engine, or None if an error occurs. 
""" try: - response = self.indexer.as_chat_engine().query(question) # type: ignore[union-attr] + response = self.index.as_chat_engine(llm=llm).query(question) # type: ignore[union-attr] return response except Exception as e: logger.error("Query failed: %s", e) diff --git a/notebook/docling_md_query_engine.ipynb b/notebook/docling_md_query_engine.ipynb index efeaf577af..1068784de0 100644 --- a/notebook/docling_md_query_engine.ipynb +++ b/notebook/docling_md_query_engine.ipynb @@ -27,7 +27,8 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install sentence-transformers" + "%pip install sentence-transformers\n", + "%pip install llama-index-llms-langchain" ] }, { @@ -183,12 +184,15 @@ "import os\n", "\n", "from chromadb.utils import embedding_functions\n", + "from langchain_openai import ChatOpenAI\n", "\n", "os.environ[\"OPENAI_API_KEY\"] = \"\"\n", "openai_ef = embedding_functions.OpenAIEmbeddingFunction(\n", " api_key=\"\",\n", " model_name=\"text-embedding-ada-002\",\n", - ")" + ")\n", + "\n", + "llm = ChatOpenAI()" ] }, { @@ -206,15 +210,26 @@ "metadata": {}, "outputs": [], "source": [ - "from autogen.agentchat.contrib.rag.mongodb import MongoDBQueryEngine\n", + "from autogen.agentchat.contrib.rag.mongodb_query_engine import MongoDBQueryEngine\n", "\n", "query_engine = MongoDBQueryEngine(\n", " connection_string=\"\",\n", " embedding_function=openai_ef,\n", - " database_name=\"vector_db_2\",\n", + " database_name=\"vector_db_1\",\n", ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query_engine.connect_db()\n", + "# first time run will return error and tell you to run init_db first\n", + "# from the second time when you run this cell, it will work" + ] + }, { "cell_type": "code", "execution_count": null, @@ -231,7 +246,9 @@ "metadata": {}, "outputs": [], "source": [ - "print(query_engine.query(\"How much money did Nvidia spend in research and development\"))" + "question = \"What is the trading 
symbol for Toast\"\n", + "answer = query_engine.query(question, llm)\n", + "print(answer)" ] }, { @@ -249,9 +266,7 @@ "metadata": {}, "outputs": [], "source": [ - "question = \"What is the trading symbol for Toast\"\n", - "answer = query_engine.query(question)\n", - "print(answer)" + "print(query_engine.query(\"How much money did Nvidia spend in research and development\", llm))" ] } ], diff --git a/test/agentchat/contrib/rag/test_mongodb_query_engine.py b/test/agentchat/contrib/rag/test_mongodb_query_engine.py index 085a4457bf..8a42fd932e 100644 --- a/test/agentchat/contrib/rag/test_mongodb_query_engine.py +++ b/test/agentchat/contrib/rag/test_mongodb_query_engine.py @@ -11,7 +11,7 @@ from unittest.mock import MagicMock, patch # Replace 'your_module' with the actual module name where MongoDBQueryEngine is defined. -from autogen.agentchat.contrib.rag.mongodb import MongoDBQueryEngine +from autogen.agentchat.contrib.rag.mongodb_query_engine import MongoDBQueryEngine # A dummy client to simulate MongoDB's behavior. 
From 8378e77bc2256499bbb2659bbdf9ed4d21750313 Mon Sep 17 00:00:00 2001 From: sitloboi2012 Date: Mon, 24 Feb 2025 22:05:27 +0700 Subject: [PATCH 09/12] remove the LLM def in query function Signed-off-by: sitloboi2012 --- autogen/agentchat/contrib/rag/mongodb_query_engine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/autogen/agentchat/contrib/rag/mongodb_query_engine.py b/autogen/agentchat/contrib/rag/mongodb_query_engine.py index 98810e0836..e65d935d78 100644 --- a/autogen/agentchat/contrib/rag/mongodb_query_engine.py +++ b/autogen/agentchat/contrib/rag/mongodb_query_engine.py @@ -13,7 +13,6 @@ with optional_import_block(): from langchain.base_language import BaseLanguageModel from llama_index.core import SimpleDirectoryReader, StorageContext, VectorStoreIndex - from llama_index.core.llms.llm import LLM from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch from pymongo import MongoClient @@ -221,7 +220,7 @@ def add_records( except Exception as e: logger.error("Error inserting documents into the index: %s", e) - def query(self, question: str, llm: Union[str, LLM, "BaseLanguageModel"], *args: Any, **kwargs: Any) -> Any: # type: ignore[no-any-unimported, type-arg] + def query(self, question: str, llm: Union[str, "BaseLanguageModel"], *args: Any, **kwargs: Any) -> Any: # type: ignore[no-any-unimported, type-arg] """ Query the index using the given question. 
From 445e7d1f927ac2d16f67fc73e39d49650da08389 Mon Sep 17 00:00:00 2001 From: sitloboi2012 Date: Mon, 24 Feb 2025 22:43:38 +0700 Subject: [PATCH 10/12] update test case for mongodb query engine Signed-off-by: sitloboi2012 --- .../contrib/rag/test_mongodb_query_engine.py | 351 +++++++++++------- 1 file changed, 216 insertions(+), 135 deletions(-) diff --git a/test/agentchat/contrib/rag/test_mongodb_query_engine.py b/test/agentchat/contrib/rag/test_mongodb_query_engine.py index 8a42fd932e..fabce2e500 100644 --- a/test/agentchat/contrib/rag/test_mongodb_query_engine.py +++ b/test/agentchat/contrib/rag/test_mongodb_query_engine.py @@ -6,153 +6,234 @@ # SPDX-License-Identifier: MIT # !/usr/bin/env python3 -m pytest -import unittest -from pathlib import Path -from unittest.mock import MagicMock, patch +import os -# Replace 'your_module' with the actual module name where MongoDBQueryEngine is defined. -from autogen.agentchat.contrib.rag.mongodb_query_engine import MongoDBQueryEngine +import pytest +from llama_index.core import SimpleDirectoryReader, VectorStoreIndex +# Import the class and constant from your module. +from autogen.agentchat.contrib.rag.mongodb_query_engine import ( + DEFAULT_COLLECTION_NAME, + MongoDBQueryEngine, +) +from autogen.agentchat.contrib.vectordb.base import VectorDBFactory +from autogen.import_utils import skip_on_missing_imports -# A dummy client to simulate MongoDB's behavior. -class DummyCollection: - def __init__(self): - self.list_collection_names = MagicMock(return_value=[]) +# ----- Fake classes for simulating MongoDB behavior ----- # -class DummyDatabase: - def __init__(self): - self._dummy_collection = DummyCollection() +@pytest.mark.openai +@skip_on_missing_imports(["pymongo", "openai", "llama_index"], "rag") +class FakeDBExists: + def list_collection_names(self): + return [DEFAULT_COLLECTION_NAME] def __getitem__(self, key): - # When accessing a collection via subscripting, return the dummy collection. 
- return self._dummy_collection + return {} # Return a dummy collection + +class FakeDBNoCollection: def list_collection_names(self): - # Return an empty list or a list of dummy collection names. return [] - def create_collection(self, name, *args, **kwargs): - # Simulate creation of a new collection by returning a dummy collection. - return DummyCollection() + def __getitem__(self, key): + return {} # Return a dummy collection + +class FakeMongoClientExists: + def __init__(self, connection_string): + self.connection_string = connection_string + self.admin = self -class DummyClient: - def __init__(self, should_fail=False): - self.should_fail = should_fail - self.admin = MagicMock() + def command(self, cmd): + if cmd == "ping": + return {"ok": 1} + raise Exception("Ping failed") - def command(self, command): - if self.should_fail: - raise Exception("Ping failed") - return "ok" + def __getitem__(self, name): + return FakeDBExists() - def __getitem__(self, key): - # Return a dummy database when subscripting the client. - return DummyDatabase() - - -# A dummy chat engine to simulate query responses. -class DummyChatEngine: - def query(self, question): - return f"Response to: {question}" - - -class TestMongoDBQueryEngine(unittest.TestCase): - def setUp(self): - # Patch the vector database factory and search engine constructors. - patcher_factory = patch("autogen.agentchat.contrib.vectordb.base.VectorDBFactory.create_vector_db") - self.mock_create_vector_db = patcher_factory.start() - self.addCleanup(patcher_factory.stop) - - patcher_search = patch("llama_index.vector_stores.mongodb.MongoDBAtlasVectorSearch") - self.mock_mongo_search = patcher_search.start() - self.addCleanup(patcher_search.stop) - - # Set up a dummy MongoDB client. - self.dummy_client = DummyClient() - dummy_vector_db = MagicMock() - dummy_vector_db.client = self.dummy_client - self.mock_create_vector_db.return_value = dummy_vector_db - - # Instantiate the engine with dummy parameters. 
- self.engine = MongoDBQueryEngine(connection_string="dummy", database_name="test_db") - # Pre-assign a dummy indexer (used in query and add_records). - self.engine.indexer = MagicMock() - - def test_connect_db_success(self): - # Simulate a successful ping. - self.dummy_client.should_fail = False - result = self.engine.connect_db() - self.assertTrue(result) - # Verify that the ping command was called. - self.dummy_client.admin.command.assert_called_with("ping") - - def test_connect_db_failure(self): - # Simulate a failure during the ping. - self.dummy_client.should_fail = True - self.dummy_client.admin.command.side_effect = Exception("Ping failed") - result = self.engine.connect_db() - self.assertFalse(result) - - @patch("llama_index.core.SimpleDirectoryReader") - @patch("your_module.VectorStoreIndex.from_documents") - def test_init_db_with_documents(self, mock_from_documents, mock_simple_dir_reader): - # Create dummy documents. - dummy_documents = ["doc1", "doc2"] - reader_instance = MagicMock() - reader_instance.load_data.return_value = dummy_documents - mock_simple_dir_reader.return_value = reader_instance - - # Patch Path.glob to return a list with at least one file. - with patch.object(Path, "glob", return_value=[Path("dummy.txt")]): - result = self.engine.init_db(new_doc_dir="dummy_dir") - self.assertTrue(result) - # Ensure that the index was built with the dummy documents. - mock_from_documents.assert_called_with(dummy_documents, storage_context=self.engine.storage_context) - - def test_init_db_no_documents(self): - # Without any document directory or file paths, the method should return False. - result = self.engine.init_db() - self.assertFalse(result) - - @patch("your_module.SimpleDirectoryReader") - def test_add_records_no_documents(self, mock_simple_dir_reader): - # When no document paths or directory are provided, expect a warning and no processing. 
- with self.assertLogs(level="WARNING") as cm: - self.engine.add_records() - self.assertTrue(any("No documents found for adding records." in msg for msg in cm.output)) - - @patch("your_module.SimpleDirectoryReader") - def test_add_records_with_documents(self, mock_simple_dir_reader): - # Create dummy documents to be loaded. - dummy_documents = ["doc1", "doc2", "doc3"] - reader_instance = MagicMock() - reader_instance.load_data.return_value = dummy_documents - mock_simple_dir_reader.return_value = reader_instance - - # Ensure that the indexer's insert method is callable. - self.engine.indexer.insert = MagicMock() - - # Provide a list of dummy document paths. - self.engine.add_records(new_doc_paths_or_urls=["dummy1.txt", "dummy2.txt"]) - # Verify that insert was called for each dummy document. - self.assertEqual(self.engine.indexer.insert.call_count, len(dummy_documents)) - - def test_query_success(self): - # Simulate a working chat engine. - dummy_chat_engine = DummyChatEngine() - self.engine.indexer.as_chat_engine.return_value = dummy_chat_engine - - response = self.engine.query("Test question") - self.assertEqual(response, "Response to: Test question") - - def test_query_failure(self): - # Simulate failure by having as_chat_engine raise an exception. 
- self.engine.indexer.as_chat_engine.side_effect = Exception("Query failed") - response = self.engine.query("Test question") - self.assertIsNone(response) - - -if __name__ == "__main__": - unittest.main() + +class FakeMongoClientNoCollection: + def __init__(self, connection_string): + self.connection_string = connection_string + self.admin = self + + def command(self, cmd): + if cmd == "ping": + return {"ok": 1} + raise Exception("Ping failed") + + def __getitem__(self, name): + return FakeDBNoCollection() + + +# ----- Fake vector DB and index implementations ----- # + + +class FakeVectorDB: + def __init__(self, client): + self.client = client + + +class FakeIndex: + def __init__(self, docs=None): + self.docs = docs or [] + + def as_chat_engine(self, llm): + # Return a fake chat engine with a query method. + class FakeChatEngine: + def query(self, question): + return f"Answer to {question}" + + return FakeChatEngine() + + def insert(self, doc): + self.docs.append(doc) + + +# Fake MongoDBAtlasVectorSearch so that no real collection creation is attempted. +class FakeMongoDBAtlasVectorSearch: + def __init__(self, mongodb_client, db_name, collection_name): + self.client = mongodb_client # so that admin.command("ping") works. + self.db_name = db_name + self.collection_name = collection_name + self.stores_text = True # Added attribute to mimic real behavior. + + +# Fake create_vector_db function to be used in place of the factory. +def fake_create_vector_db( + db_type, connection_string, database_name, index_name, embedding_function, collection_name, overwrite +): + # Choose a fake MongoClient based on the connection string. + if "exists" in connection_string: + client = FakeMongoClientExists(connection_string) + else: + client = FakeMongoClientNoCollection(connection_string) + return FakeVectorDB(client) + + +# ----- Pytest tests ----- # + + +def test_connect_db_no_collection(monkeypatch): + """ + Test connect_db when the target collection does not exist. 
+ It should catch the error and return False. + """ + monkeypatch.setattr("autogen.agentchat.contrib.rag.mongodb_query_engine.MongoClient", FakeMongoClientNoCollection) + monkeypatch.setattr(VectorDBFactory, "create_vector_db", fake_create_vector_db) + engine = MongoDBQueryEngine( + connection_string="dummy_no_collection", database_name="vector_db", collection_name=DEFAULT_COLLECTION_NAME + ) + result = engine.connect_db() + assert result is False + + +def test_connect_db_existing(monkeypatch): + """ + Test connect_db when the collection exists. + It should succeed and return True. + """ + monkeypatch.setattr("autogen.agentchat.contrib.rag.mongodb_query_engine.MongoClient", FakeMongoClientExists) + # Override MongoDBAtlasVectorSearch with our fake. + monkeypatch.setattr( + "autogen.agentchat.contrib.rag.mongodb_query_engine.MongoDBAtlasVectorSearch", FakeMongoDBAtlasVectorSearch + ) + monkeypatch.setattr(VectorDBFactory, "create_vector_db", fake_create_vector_db) + # Override from_vector_store to accept keyword arguments. + monkeypatch.setattr(VectorStoreIndex, "from_vector_store", lambda vs, **kwargs: FakeIndex()) + engine = MongoDBQueryEngine( + connection_string="dummy_exists", database_name="vector_db", collection_name=DEFAULT_COLLECTION_NAME + ) + result = engine.connect_db() + assert result is True + + +def test_init_db_existing_collection(monkeypatch): + """ + Test init_db when the collection already exists. + It should raise an error internally and return False. + """ + monkeypatch.setattr("autogen.agentchat.contrib.rag.mongodb_query_engine.MongoClient", FakeMongoClientExists) + engine = MongoDBQueryEngine( + connection_string="dummy_exists", database_name="vector_db", collection_name=DEFAULT_COLLECTION_NAME + ) + # Use a dummy document name instead of an absolute path. + result = engine.init_db(new_doc_paths=["dummy_doc.md"]) + # Since the collection exists, init_db should return False. 
+ assert result is False + + +def test_init_db_no_documents(monkeypatch): + """ + Test init_db when no documents are provided. + It should log a warning and return False. + """ + monkeypatch.setattr("autogen.agentchat.contrib.rag.mongodb_query_engine.MongoClient", FakeMongoClientNoCollection) + monkeypatch.setattr(VectorDBFactory, "create_vector_db", fake_create_vector_db) + # Override load_data to return an empty list. + monkeypatch.setattr(SimpleDirectoryReader, "load_data", lambda self: []) + engine = MongoDBQueryEngine( + connection_string="dummy_no_collection", database_name="vector_db", collection_name=DEFAULT_COLLECTION_NAME + ) + result = engine.init_db(new_doc_paths=[]) + assert result is False + + +def test_init_db_success(monkeypatch): + """ + Test successful initialization of the database. + It should load documents and build the index. + """ + monkeypatch.setattr("autogen.agentchat.contrib.rag.mongodb_query_engine.MongoClient", FakeMongoClientNoCollection) + monkeypatch.setattr( + "autogen.agentchat.contrib.rag.mongodb_query_engine.MongoDBAtlasVectorSearch", FakeMongoDBAtlasVectorSearch + ) + monkeypatch.setattr(VectorDBFactory, "create_vector_db", fake_create_vector_db) + # Simulate document loading returning two dummy docs. + monkeypatch.setattr(SimpleDirectoryReader, "load_data", lambda self: ["doc1", "doc2"]) + # Override from_documents to return a FakeIndex containing the docs. + monkeypatch.setattr(VectorStoreIndex, "from_documents", lambda docs, **kwargs: FakeIndex(docs)) + engine = MongoDBQueryEngine( + connection_string="dummy_no_collection", database_name="vector_db", collection_name=DEFAULT_COLLECTION_NAME + ) + # Use a dummy document name. + result = engine.init_db(new_doc_paths=["dummy_doc.md"]) + assert result is True + # Our fake loader returns ["doc1", "doc2"] + assert engine.index.docs == ["doc1", "doc2"] + + +def test_add_records(monkeypatch): + """ + Test that add_records loads documents and inserts them into the index. 
+ """ + fake_index = FakeIndex() + engine = MongoDBQueryEngine( + connection_string="dummy", database_name="vector_db", collection_name=DEFAULT_COLLECTION_NAME + ) + engine.index = fake_index + # Override __init__ of SimpleDirectoryReader to bypass file existence checks. + monkeypatch.setattr( + SimpleDirectoryReader, "__init__", lambda self, input_files: setattr(self, "input_files", input_files) + ) + # Override load_data to return dummy records. + monkeypatch.setattr(SimpleDirectoryReader, "load_data", lambda self: ["record1", "record2"]) + # Force os.path.exists to always return True so that file existence checks pass. + monkeypatch.setattr(os.path, "exists", lambda path: True) + engine.add_records(new_doc_paths_or_urls=["dummy_path"]) + assert fake_index.docs == ["record1", "record2"] + + +def test_query(monkeypatch): + """ + Test that query returns the expected response from the fake chat engine. + """ + fake_index = FakeIndex() + engine = MongoDBQueryEngine( + connection_string="dummy", database_name="vector_db", collection_name=DEFAULT_COLLECTION_NAME + ) + engine.index = fake_index + answer = engine.query("What is testing?", llm="dummy_llm") + assert answer == "Answer to What is testing?" 
From 2e581c9667b90e35f6d623968ad3d5dc9b1a167a Mon Sep 17 00:00:00 2001 From: sitloboi2012 Date: Tue, 25 Feb 2025 22:12:40 +0700 Subject: [PATCH 11/12] update llm into __init__, update pyproject.toml to include llama-index-llms-langchain, update notebook for instruction usage Signed-off-by: sitloboi2012 --- .../contrib/rag/mongodb_query_engine.py | 11 +- notebook/mongodb_query_engine.ipynb | 203 ++++++++++++++++++ pyproject.toml | 3 +- 3 files changed, 212 insertions(+), 5 deletions(-) create mode 100644 notebook/mongodb_query_engine.ipynb diff --git a/autogen/agentchat/contrib/rag/mongodb_query_engine.py b/autogen/agentchat/contrib/rag/mongodb_query_engine.py index e65d935d78..70601dd2ff 100644 --- a/autogen/agentchat/contrib/rag/mongodb_query_engine.py +++ b/autogen/agentchat/contrib/rag/mongodb_query_engine.py @@ -11,8 +11,8 @@ from autogen.import_utils import optional_import_block, require_optional_import with optional_import_block(): - from langchain.base_language import BaseLanguageModel from llama_index.core import SimpleDirectoryReader, StorageContext, VectorStoreIndex + from llama_index.llms.langchain.base import LLM from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch from pymongo import MongoClient @@ -37,13 +37,14 @@ class MongoDBQueryEngine: indexer (Optional[VectorStoreIndex]): The index built from the documents. """ - def __init__( + def __init__( # type: ignore[no-any-unimported] self, connection_string: str = "", database_name: str = "vector_db", embedding_function: Optional[Callable[..., Any]] = None, collection_name: str = DEFAULT_COLLECTION_NAME, index_name: str = "vector_index", + llm: Union[str, LLM] = "gpt-4o", ): """ Initialize the MongoDBQueryEngine. 
@@ -63,6 +64,8 @@ def __init__( self.storage_context = None self.index: Optional[VectorStoreIndex] = None # type: ignore[no-any-unimported] + self.llm = llm + def _setup_vector_db(self, overwrite: bool) -> None: """ Helper method to create the vector database, vector search engine, and storage context. @@ -220,7 +223,7 @@ def add_records( except Exception as e: logger.error("Error inserting documents into the index: %s", e) - def query(self, question: str, llm: Union[str, "BaseLanguageModel"], *args: Any, **kwargs: Any) -> Any: # type: ignore[no-any-unimported, type-arg] + def query(self, question: str, *args: Any, **kwargs: Any) -> Any: # type: ignore[no-any-unimported, type-arg] """ Query the index using the given question. @@ -232,7 +235,7 @@ def query(self, question: str, llm: Union[str, "BaseLanguageModel"], *args: Any, Any: The response from the chat engine, or None if an error occurs. """ try: - response = self.index.as_chat_engine(llm=llm).query(question) # type: ignore[union-attr] + response = self.index.as_chat_engine(llm=self.llm).query(question) # type: ignore[union-attr] return response except Exception as e: logger.error("Query failed: %s", e) diff --git a/notebook/mongodb_query_engine.ipynb b/notebook/mongodb_query_engine.ipynb new file mode 100644 index 0000000000..b27829b3db --- /dev/null +++ b/notebook/mongodb_query_engine.ipynb @@ -0,0 +1,203 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# MongoDB Query Engine Tutorial\n", + "\n", + "This notebook demonstrates the use of the `MongoDBQueryEngine` for retrieval-augmented question answering over documents using MongoDB. It shows how to set up the engine with MongoDB and simple text parser using LlamaIndex parsed Markdown files, and execute natural language queries against the indexed data. \n", + "\n", + "The `MongoDBQueryEngine` integrates cloud MongoDB Atlas but also MongoDB localhost vector storage with LlamaIndex for efficient document retrieval." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install llama-index==0.12.16\n", + "%pip install llama-index-llms-langchain" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Before calling the `MongoDBQueryEngine`, we will create and import our own embedding function. You can replace this with any embedding function built-in from Langchain or Llama-Index or your custom build" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from chromadb.utils import embedding_functions\n", + "from langchain_openai import ChatOpenAI\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = \"\"\n", + "openai_ef = embedding_functions.OpenAIEmbeddingFunction(\n", + " api_key=\"\",\n", + " model_name=\"text-embedding-ada-002\",\n", + ")\n", + "\n", + "llm = ChatOpenAI()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will define the folder path that contains our document that we want to input. 
In this case, I will use an experimental document from the ag2 test folder" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "input_dir = \"/ag2/test/agents/experimental/document_agent/pdf_parsed/\"\n", + "# you might need to change the folder path based on which folder you want to input" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Initialize the `MongoDBQueryEngine` with:\n", + "- MongoDB Atlas Connection String (you can also provide your localhost MongoDB Connection String as well)\n", + "- Pre-defined embedding function\n", + "- Database Name\n", + "- Pre-defined LLM Model (Optional)\n", + "\n", + "Besides that, you can also customize your `collection_name` and `index_name`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from autogen.agentchat.contrib.rag.mongodb_query_engine import MongoDBQueryEngine\n", + "\n", + "query_engine = MongoDBQueryEngine(\n", + " connection_string=\"\", embedding_function=openai_ef, database_name=\"vector_db\", llm=llm\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Because this is our first time running this and we do not have any database yet, we will call the `init_db` function along with a list of input documents or a folder path.\n", + "\n", + "`init_db` is intended for first-time creation: if the collection already exists it returns `False`. To overwrite and re-create an existing database, call `connect_db` with the arg `overwrite=True` instead" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# nvidia_10k_2024.md\n", + "query_engine.init_db(new_doc_paths=[input_dir + \"Toast_financial_report.md\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Using `query` to answer user input question" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "question 
= \"What is the trading symbol for Toast\"\n", + "answer = query_engine.query(question)\n", + "print(answer)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To add new documents, we can use `add_records`, which can take a new document path or a new document dir as input" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query_engine.add_records(new_doc_paths_or_urls=[input_dir + \"nvidia_10k_2024.md\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(query_engine.query(\"How much money did Nvidia spend in research and development\", llm))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you already have a MongoDB Vector Database and just want to connect to it without having to initialize it again, you can call `connect_db`. You can also overwrite and set up your connected Vector Database again by using the arg `overwrite=True`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query_engine.connect_db()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(query_engine.query(\"How much money did Nvidia spend in research and development\", llm))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ag2", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml index 2843f39f9c..5cde60026b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,7 +139,8 @@ rag = [ "chromadb>=0.5,<1", 
"llama-index>=0.12,<1", "llama-index-vector-stores-chroma==0.4.1", - "llama-index-vector-stores-mongodb==0.6.0" + "llama-index-vector-stores-mongodb==0.6.0", + "llama-index-llms-langchain==0.6.0" ] From 9397d24fbea1ce91477dd1deb7edafa6ccb4565a Mon Sep 17 00:00:00 2001 From: sitloboi2012 Date: Tue, 25 Feb 2025 22:39:45 +0700 Subject: [PATCH 12/12] replace the as_chat_engine to as_query_engine in MongoDBQueryEngine, update notebook Signed-off-by: sitloboi2012 --- .../agentchat/contrib/rag/mongodb_query_engine.py | 2 +- notebook/mongodb_query_engine.ipynb | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/autogen/agentchat/contrib/rag/mongodb_query_engine.py b/autogen/agentchat/contrib/rag/mongodb_query_engine.py index 70601dd2ff..1b34a4a56e 100644 --- a/autogen/agentchat/contrib/rag/mongodb_query_engine.py +++ b/autogen/agentchat/contrib/rag/mongodb_query_engine.py @@ -235,7 +235,7 @@ def query(self, question: str, *args: Any, **kwargs: Any) -> Any: # type: ignor Any: The response from the chat engine, or None if an error occurs. 
""" try: - response = self.index.as_chat_engine(llm=self.llm).query(question) # type: ignore[union-attr] + response = self.index.as_query_engine(llm=self.llm).query(question) # type: ignore[union-attr] return response except Exception as e: logger.error("Query failed: %s", e) diff --git a/notebook/mongodb_query_engine.ipynb b/notebook/mongodb_query_engine.ipynb index b27829b3db..a37c4a2ddd 100644 --- a/notebook/mongodb_query_engine.ipynb +++ b/notebook/mongodb_query_engine.ipynb @@ -37,15 +37,15 @@ "import os\n", "\n", "from chromadb.utils import embedding_functions\n", - "from langchain_openai import ChatOpenAI\n", + "from llama_index.llms.openai import OpenAI\n", "\n", - "os.environ[\"OPENAI_API_KEY\"] = \"\"\n", + "os.environ[\"OPENAI_API_KEY\"] = \"\"\n", "openai_ef = embedding_functions.OpenAIEmbeddingFunction(\n", - " api_key=\"\",\n", + " api_key=\"\",\n", " model_name=\"text-embedding-ada-002\",\n", ")\n", "\n", - "llm = ChatOpenAI()" + "llm = OpenAI(model=\"gpt-4o\")" ] }, { @@ -87,7 +87,7 @@ "from autogen.agentchat.contrib.rag.mongodb_query_engine import MongoDBQueryEngine\n", "\n", "query_engine = MongoDBQueryEngine(\n", - " connection_string=\"\", embedding_function=openai_ef, database_name=\"vector_db\", llm=llm\n", + " connection_string=\"\", embedding_function=openai_ef, database_name=\"vector_db\", llm=llm\n", ")" ] }, @@ -150,7 +150,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(query_engine.query(\"How much money did Nvidia spend in research and development\", llm))" + "print(query_engine.query(\"How much money did Nvidia spend in research and development\"))" ] }, { @@ -175,7 +175,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(query_engine.query(\"How much money did Nvidia spend in research and development\", llm))" + "print(query_engine.query(\"How much money did Nvidia spend in research and development\"))" ] } ],