diff --git a/autogen/agentchat/contrib/rag/mongodb_query_engine.py b/autogen/agentchat/contrib/rag/mongodb_query_engine.py new file mode 100644 index 0000000000..1b34a4a56e --- /dev/null +++ b/autogen/agentchat/contrib/rag/mongodb_query_engine.py @@ -0,0 +1,242 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +import logging +from pathlib import Path +from typing import Any, Callable, List, Optional, Union + +from autogen.agentchat.contrib.vectordb.base import VectorDBFactory +from autogen.agentchat.contrib.vectordb.mongodb import MongoDBAtlasVectorDB +from autogen.import_utils import optional_import_block, require_optional_import + +with optional_import_block(): + from llama_index.core import SimpleDirectoryReader, StorageContext, VectorStoreIndex + from llama_index.llms.langchain.base import LLM + from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch + from pymongo import MongoClient + +DEFAULT_COLLECTION_NAME = "docling-parsed-docs" + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +@require_optional_import(["pymongo", "llama_index"], "rag") +class MongoDBQueryEngine: + """ + A query engine backed by MongoDB Atlas that supports document insertion and querying. + + This engine initializes a vector database, builds an index from input documents, + and allows querying using the chat engine interface. + + Attributes: + vector_db (MongoDBAtlasVectorDB): The MongoDB vector database instance. + vector_search_engine (MongoDBAtlasVectorSearch): The vector search engine. + storage_context (StorageContext): The storage context for the vector store. + indexer (Optional[VectorStoreIndex]): The index built from the documents. + """ + + def __init__( # type: ignore[no-any-unimported] + self, + connection_string: str = "", + database_name: str = "vector_db", + embedding_function: Optional[Callable[..., Any]] = None, + collection_name: str = DEFAULT_COLLECTION_NAME, + index_name: str = "vector_index", + llm: Union[str, LLM] = "gpt-4o", + ): + """ + Initialize the MongoDBQueryEngine. + + Note: The actual connection and creation of the vector database is deferred to + connect_db (to use an existing collection) or init_db (to create a new collection). + """ + self.connection_string = connection_string + self.database_name = database_name + self.embedding_function = embedding_function + self.collection_name = collection_name + self.index_name = index_name + + # These will be initialized later. + self.vector_db: Optional[MongoDBAtlasVectorDB] = None + self.vector_search_engine = None + self.storage_context = None + self.index: Optional[VectorStoreIndex] = None # type: ignore[no-any-unimported] + + self.llm = llm + + def _setup_vector_db(self, overwrite: bool) -> None: + """ + Helper method to create the vector database, vector search engine, and storage context. + + Args: + overwrite (bool): If True, create a new collection (overwriting if exists). + If False, use an existing collection. + """ + # Pass the overwrite flag to the factory if supported. + self.vector_db: MongoDBAtlasVectorDB = VectorDBFactory.create_vector_db( # type: ignore[assignment, no-redef] + db_type="mongodb", + connection_string=self.connection_string, + database_name=self.database_name, + index_name=self.index_name, + embedding_function=self.embedding_function, + collection_name=self.collection_name, + overwrite=overwrite, # new parameter to control creation behavior + ) + self.vector_search_engine = MongoDBAtlasVectorSearch( + mongodb_client=self.vector_db.client, # type: ignore[union-attr] + db_name=self.database_name, + collection_name=self.collection_name, + ) + self.storage_context = StorageContext.from_defaults(vector_store=self.vector_search_engine) + self.index = VectorStoreIndex.from_vector_store(self.vector_search_engine, storage_context=self.storage_context) + + def connect_db(self, overwrite: bool = False, *args: Any, **kwargs: Any) -> bool: + """ + Connect to the MongoDB database by issuing a ping using an existing collection. + This method first checks if the target database and collection exist. + - If not, it raises an error instructing the user to run init_db. + - If the collection exists and overwrite is True, it reinitializes the database. + - Otherwise, it uses the existing collection. + + Returns: + bool: True if the connection is successful; False otherwise. + """ + try: + # Check if the target collection exists. + client = MongoClient(self.connection_string) + db = client[self.database_name] + if self.collection_name not in db.list_collection_names(): + raise ValueError( + f"Collection '{self.collection_name}' not found in database '{self.database_name}'. " + "Please run init_db to create a new collection." + ) + # Reinitialize if the caller requested overwrite. + if overwrite: + logger.info("Overwriting existing collection as requested.") + self._setup_vector_db(overwrite=True) + else: + self._setup_vector_db(overwrite=False) + self.vector_db.client.admin.command("ping") # type: ignore[union-attr] + logger.info("Connected to MongoDB successfully.") + return True + except Exception as error: + logger.error("Failed to connect to MongoDB: %s", error) + return False + + def init_db( + self, + new_doc_dir: Optional[Union[str, Path]] = None, + new_doc_paths: Optional[List[Union[str, Path]]] = None, + *args: Any, + **kwargs: Any, + ) -> bool: + """ + Initialize the database by loading documents from the given directory or file paths, + then building an index. This method is intended for first-time creation of the database, + so it expects that the collection does not already exist (i.e. overwrite is False). + + Args: + new_doc_dir (Optional[Union[str, Path]]): Directory containing input documents. + new_doc_paths (Optional[List[Union[str, Path]]]): List of document paths or URLs. + + Returns: + bool: True if initialization is successful; False otherwise. + """ + try: + # Check if the collection already exists. + client = MongoClient(self.connection_string) + db = client[self.database_name] + if self.collection_name in db.list_collection_names(): + raise ValueError( + f"Collection '{self.collection_name}' already exists in database '{self.database_name}'. " + "Use connect_db with overwrite=True to reinitialize it." + ) + # Set up the database without overwriting. + self._setup_vector_db(overwrite=False) + self.vector_db.client.admin.command("ping") # type: ignore[union-attr] + # Gather document paths. + document_list: List[Union[str, Path]] = [] + if new_doc_dir: + document_list.extend(Path(new_doc_dir).glob("**/*")) + if new_doc_paths: + document_list.extend(new_doc_paths) + + if not document_list: + logger.warning("No input documents provided to initialize the database.") + return False + + documents = SimpleDirectoryReader(input_files=document_list).load_data() + self.index = VectorStoreIndex.from_documents(documents, storage_context=self.storage_context) + logger.info("Database initialized with %d documents.", len(documents)) + return True + except Exception as e: + logger.error("Failed to initialize the database: %s", e) + return False + + def add_records( + self, + new_doc_dir: Optional[Union[str, Path]] = None, + new_doc_paths_or_urls: Optional[Union[List[Union[str, Path]], Union[str, Path]]] = None, + *args: Any, + **kwargs: Any, + ) -> None: + """ + Load, parse, and insert documents into the index. + + This method uses a SentenceSplitter to break documents into chunks before insertion. + + Args: + new_doc_dir (Optional[Union[str, Path]]): Directory containing input documents. + new_doc_paths_or_urls (Optional[Union[List[Union[str, Path]], Union[str, Path]]]): + List of document paths or a single document path/URL. + """ + # Collect document paths. + document_list: List[Union[str, Path]] = [] + if new_doc_dir: + document_list.extend(Path(new_doc_dir).glob("**/*")) + if new_doc_paths_or_urls: + if isinstance(new_doc_paths_or_urls, (list, tuple)): + document_list.extend(new_doc_paths_or_urls) + else: + document_list.append(new_doc_paths_or_urls) + + if not document_list: + logger.warning("No documents found for adding records.") + return + + try: + raw_documents = SimpleDirectoryReader(input_files=document_list).load_data() + except Exception as e: + logger.error("Error loading documents: %s", e) + return + + if not raw_documents: + logger.warning("No document chunks created for insertion.") + return + + try: + for doc in raw_documents: + self.index.insert(doc) # type: ignore[union-attr] + logger.info("Inserted %d document chunks successfully.", len(raw_documents)) + except Exception as e: + logger.error("Error inserting documents into the index: %s", e) + + def query(self, question: str, *args: Any, **kwargs: Any) -> Any: # type: ignore[no-any-unimported, type-arg] + """ + Query the index using the given question. + + Args: + question (str): The query string. + llm (Union[str, LLM, BaseLanguageModel]): The language model to use. + + Returns: + Any: The response from the chat engine, or None if an error occurs. + """ + try: + response = self.index.as_query_engine(llm=self.llm).query(question) # type: ignore[union-attr] + return response + except Exception as e: + logger.error("Query failed: %s", e) + return None diff --git a/notebook/docling_md_query_engine.ipynb b/notebook/docling_md_query_engine.ipynb index c40bca9f42..1068784de0 100644 --- a/notebook/docling_md_query_engine.ipynb +++ b/notebook/docling_md_query_engine.ipynb @@ -21,6 +21,16 @@ "%pip install llama-index==0.12.16" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install sentence-transformers\n", + "%pip install llama-index-llms-langchain" + ] + }, { "cell_type": "code", "execution_count": null, @@ -157,11 +167,112 @@ "answer = query_engine.query(question)\n", "print(answer)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Docling MD Query Engine MongoDB" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from chromadb.utils import embedding_functions\n", + "from langchain_openai import ChatOpenAI\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = \"\"\n", + "openai_ef = embedding_functions.OpenAIEmbeddingFunction(\n", + " api_key=\"\",\n", + " model_name=\"text-embedding-ada-002\",\n", + ")\n", + "\n", + "llm = ChatOpenAI()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "input_dir = \"/root/ag2/test/agents/experimental/document_agent/pdf_parsed/\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from autogen.agentchat.contrib.rag.mongodb_query_engine import MongoDBQueryEngine\n", + "\n", + "query_engine = MongoDBQueryEngine(\n", + " connection_string=\"\",\n", + " embedding_function=openai_ef,\n", + " database_name=\"vector_db_1\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query_engine.connect_db()\n", + "# first time run will return error and tell you to run init_db first\n", + "# from the second time when you run this cell, it will work" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# nvidia_10k_2024.md\n", + "query_engine.init_db(new_doc_paths=[input_dir + \"Toast_financial_report.md\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "question = \"What is the trading symbol for Toast\"\n", + "answer = query_engine.query(question, llm)\n", + "print(answer)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query_engine.add_records(new_doc_paths_or_urls=[input_dir + \"nvidia_10k_2024.md\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(query_engine.query(\"How much money did Nvidia spend in research and development\", llm))" + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "ag2", "language": "python", "name": "python3" }, @@ -175,7 +286,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.21" + "version": "3.11.11" } }, "nbformat": 4, diff --git a/notebook/mongodb_query_engine.ipynb b/notebook/mongodb_query_engine.ipynb new file mode 100644 index 0000000000..a37c4a2ddd --- /dev/null +++ b/notebook/mongodb_query_engine.ipynb @@ -0,0 +1,203 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# MongoDB Query Engine Tutorial\n", + "\n", + "This notebook demonstrates the use of the `MongoDBQueryEngine` for retrieval-augmented question answering over documents using MongoDB. It shows how to set up the engine with MongoDB and simple text parser using LlamaIndex parsed Markdown files, and execute natural language queries against the indexed data. \n", + "\n", + "The `MongoDBQueryEngine` integrates cloud MongoDB Atlas but also MongoDB localhost vector storage with LlamaIndex for efficient document retrieval." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install llama-index==0.12.16\n", + "%pip install llama-index-llms-langchain" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Before calling the `MongoDBQueryEngine`, we will create and import our own embedding function. You can replace this with any embedding function built-in from Langchain or Llama-Index or your custom build" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from chromadb.utils import embedding_functions\n", + "from llama_index.llms.openai import OpenAI\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = \"\"\n", + "openai_ef = embedding_functions.OpenAIEmbeddingFunction(\n", + " api_key=\"\",\n", + " model_name=\"text-embedding-ada-002\",\n", + ")\n", + "\n", + "llm = OpenAI(model=\"gpt-4o\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will define the folder path that contains our document that we want to input. In this case, I will use experimental document from ag2 test folder" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "input_dir = \"/ag2/test/agents/experimental/document_agent/pdf_parsed/\"\n", + "# you might need to change the folder path based on which folder you want to input" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Initialize the `MongoDBQueryEngine` with:\n", + "- MongoDB Atlas Connection String (you can also provide your localhost MongoDB Connection String as well)\n", + "- Pre-defined embedding function\n", + "- Database Name\n", + "- Pre-defined LLM Model (Optional)\n", + "\n", + "Beside that, you can also customize your `collection_name` and `index_name`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from autogen.agentchat.contrib.rag.mongodb_query_engine import MongoDBQueryEngine\n", + "\n", + "query_engine = MongoDBQueryEngine(\n", + " connection_string=\"\", embedding_function=openai_ef, database_name=\"vector_db\", llm=llm\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Because this is our first time running this and we do not have any database yet so we will call the `init_db` function along with a list of input document or folder path that we want.\n", + "\n", + "You can also use this `init_db` to overwrite and re-create your database again as well, simply use the arg `overwrite = True`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# nvidia_10k_2024.md\n", + "query_engine.init_db(new_doc_paths=[input_dir + \"Toast_financial_report.md\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Using `query` to answer user input question" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "question = \"What is the trading symbol for Toast\"\n", + "answer = query_engine.query(question)\n", + "print(answer)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To add-on new document, we can use `add_records`, this could take new document path or new document dir as input" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query_engine.add_records(new_doc_paths_or_urls=[input_dir + \"nvidia_10k_2024.md\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(query_engine.query(\"How much money did Nvidia spend in research and development\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In case that you already have a MongoDB Vector Database and you just want to connect to it without having to initialize it again, you can call the `connect_db`. You can also overwrite and re-setup your connected Vector Database by using the arg `overwrite=True`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query_engine.connect_db()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(query_engine.query(\"How much money did Nvidia spend in research and development\"))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ag2", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml index 9b07b10f62..5cde60026b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,6 +139,8 @@ rag = [ "chromadb>=0.5,<1", "llama-index>=0.12,<1", "llama-index-vector-stores-chroma==0.4.1", + "llama-index-vector-stores-mongodb==0.6.0", + "llama-index-llms-langchain==0.6.0" ] diff --git a/test/agentchat/contrib/rag/test_mongodb_query_engine.py b/test/agentchat/contrib/rag/test_mongodb_query_engine.py new file mode 100644 index 0000000000..fabce2e500 --- /dev/null +++ b/test/agentchat/contrib/rag/test_mongodb_query_engine.py @@ -0,0 +1,239 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 +# +# Portions derived from https://github.com/microsoft/autogen are under the MIT License. +# SPDX-License-Identifier: MIT +# !/usr/bin/env python3 -m pytest + +import os + +import pytest +from llama_index.core import SimpleDirectoryReader, VectorStoreIndex + +# Import the class and constant from your module. +from autogen.agentchat.contrib.rag.mongodb_query_engine import ( + DEFAULT_COLLECTION_NAME, + MongoDBQueryEngine, +) +from autogen.agentchat.contrib.vectordb.base import VectorDBFactory +from autogen.import_utils import skip_on_missing_imports + +# ----- Fake classes for simulating MongoDB behavior ----- # + + +@pytest.mark.openai +@skip_on_missing_imports(["pymongo", "openai", "llama_index"], "rag") +class FakeDBExists: + def list_collection_names(self): + return [DEFAULT_COLLECTION_NAME] + + def __getitem__(self, key): + return {} # Return a dummy collection + + +class FakeDBNoCollection: + def list_collection_names(self): + return [] + + def __getitem__(self, key): + return {} # Return a dummy collection + + +class FakeMongoClientExists: + def __init__(self, connection_string): + self.connection_string = connection_string + self.admin = self + + def command(self, cmd): + if cmd == "ping": + return {"ok": 1} + raise Exception("Ping failed") + + def __getitem__(self, name): + return FakeDBExists() + + +class FakeMongoClientNoCollection: + def __init__(self, connection_string): + self.connection_string = connection_string + self.admin = self + + def command(self, cmd): + if cmd == "ping": + return {"ok": 1} + raise Exception("Ping failed") + + def __getitem__(self, name): + return FakeDBNoCollection() + + +# ----- Fake vector DB and index implementations ----- # + + +class FakeVectorDB: + def __init__(self, client): + self.client = client + + +class FakeIndex: + def __init__(self, docs=None): + self.docs = docs or [] + + def as_chat_engine(self, llm): + # Return a fake chat engine with a query method. + class FakeChatEngine: + def query(self, question): + return f"Answer to {question}" + + return FakeChatEngine() + + def insert(self, doc): + self.docs.append(doc) + + +# Fake MongoDBAtlasVectorSearch so that no real collection creation is attempted. +class FakeMongoDBAtlasVectorSearch: + def __init__(self, mongodb_client, db_name, collection_name): + self.client = mongodb_client # so that admin.command("ping") works. + self.db_name = db_name + self.collection_name = collection_name + self.stores_text = True # Added attribute to mimic real behavior. + + +# Fake create_vector_db function to be used in place of the factory. +def fake_create_vector_db( + db_type, connection_string, database_name, index_name, embedding_function, collection_name, overwrite +): + # Choose a fake MongoClient based on the connection string. + if "exists" in connection_string: + client = FakeMongoClientExists(connection_string) + else: + client = FakeMongoClientNoCollection(connection_string) + return FakeVectorDB(client) + + +# ----- Pytest tests ----- # + + +def test_connect_db_no_collection(monkeypatch): + """ + Test connect_db when the target collection does not exist. + It should catch the error and return False. + """ + monkeypatch.setattr("autogen.agentchat.contrib.rag.mongodb_query_engine.MongoClient", FakeMongoClientNoCollection) + monkeypatch.setattr(VectorDBFactory, "create_vector_db", fake_create_vector_db) + engine = MongoDBQueryEngine( + connection_string="dummy_no_collection", database_name="vector_db", collection_name=DEFAULT_COLLECTION_NAME + ) + result = engine.connect_db() + assert result is False + + +def test_connect_db_existing(monkeypatch): + """ + Test connect_db when the collection exists. + It should succeed and return True. + """ + monkeypatch.setattr("autogen.agentchat.contrib.rag.mongodb_query_engine.MongoClient", FakeMongoClientExists) + # Override MongoDBAtlasVectorSearch with our fake. + monkeypatch.setattr( + "autogen.agentchat.contrib.rag.mongodb_query_engine.MongoDBAtlasVectorSearch", FakeMongoDBAtlasVectorSearch + ) + monkeypatch.setattr(VectorDBFactory, "create_vector_db", fake_create_vector_db) + # Override from_vector_store to accept keyword arguments. + monkeypatch.setattr(VectorStoreIndex, "from_vector_store", lambda vs, **kwargs: FakeIndex()) + engine = MongoDBQueryEngine( + connection_string="dummy_exists", database_name="vector_db", collection_name=DEFAULT_COLLECTION_NAME + ) + result = engine.connect_db() + assert result is True + + +def test_init_db_existing_collection(monkeypatch): + """ + Test init_db when the collection already exists. + It should raise an error internally and return False. + """ + monkeypatch.setattr("autogen.agentchat.contrib.rag.mongodb_query_engine.MongoClient", FakeMongoClientExists) + engine = MongoDBQueryEngine( + connection_string="dummy_exists", database_name="vector_db", collection_name=DEFAULT_COLLECTION_NAME + ) + # Use a dummy document name instead of an absolute path. + result = engine.init_db(new_doc_paths=["dummy_doc.md"]) + # Since the collection exists, init_db should return False. + assert result is False + + +def test_init_db_no_documents(monkeypatch): + """ + Test init_db when no documents are provided. + It should log a warning and return False. + """ + monkeypatch.setattr("autogen.agentchat.contrib.rag.mongodb_query_engine.MongoClient", FakeMongoClientNoCollection) + monkeypatch.setattr(VectorDBFactory, "create_vector_db", fake_create_vector_db) + # Override load_data to return an empty list. + monkeypatch.setattr(SimpleDirectoryReader, "load_data", lambda self: []) + engine = MongoDBQueryEngine( + connection_string="dummy_no_collection", database_name="vector_db", collection_name=DEFAULT_COLLECTION_NAME + ) + result = engine.init_db(new_doc_paths=[]) + assert result is False + + +def test_init_db_success(monkeypatch): + """ + Test successful initialization of the database. + It should load documents and build the index. + """ + monkeypatch.setattr("autogen.agentchat.contrib.rag.mongodb_query_engine.MongoClient", FakeMongoClientNoCollection) + monkeypatch.setattr( + "autogen.agentchat.contrib.rag.mongodb_query_engine.MongoDBAtlasVectorSearch", FakeMongoDBAtlasVectorSearch + ) + monkeypatch.setattr(VectorDBFactory, "create_vector_db", fake_create_vector_db) + # Simulate document loading returning two dummy docs. + monkeypatch.setattr(SimpleDirectoryReader, "load_data", lambda self: ["doc1", "doc2"]) + # Override from_documents to return a FakeIndex containing the docs. + monkeypatch.setattr(VectorStoreIndex, "from_documents", lambda docs, **kwargs: FakeIndex(docs)) + engine = MongoDBQueryEngine( + connection_string="dummy_no_collection", database_name="vector_db", collection_name=DEFAULT_COLLECTION_NAME + ) + # Use a dummy document name. + result = engine.init_db(new_doc_paths=["dummy_doc.md"]) + assert result is True + # Our fake loader returns ["doc1", "doc2"] + assert engine.index.docs == ["doc1", "doc2"] + + +def test_add_records(monkeypatch): + """ + Test that add_records loads documents and inserts them into the index. + """ + fake_index = FakeIndex() + engine = MongoDBQueryEngine( + connection_string="dummy", database_name="vector_db", collection_name=DEFAULT_COLLECTION_NAME + ) + engine.index = fake_index + # Override __init__ of SimpleDirectoryReader to bypass file existence checks. + monkeypatch.setattr( + SimpleDirectoryReader, "__init__", lambda self, input_files: setattr(self, "input_files", input_files) + ) + # Override load_data to return dummy records. + monkeypatch.setattr(SimpleDirectoryReader, "load_data", lambda self: ["record1", "record2"]) + # Force os.path.exists to always return True so that file existence checks pass. + monkeypatch.setattr(os.path, "exists", lambda path: True) + engine.add_records(new_doc_paths_or_urls=["dummy_path"]) + assert fake_index.docs == ["record1", "record2"] + + +def test_query(monkeypatch): + """ + Test that query returns the expected response from the fake chat engine. + """ + fake_index = FakeIndex() + engine = MongoDBQueryEngine( + connection_string="dummy", database_name="vector_db", collection_name=DEFAULT_COLLECTION_NAME + ) + engine.index = fake_index + answer = engine.query("What is testing?", llm="dummy_llm") + assert answer == "Answer to What is testing?"