Skip to content

Commit

Permalink
update and finalize the mongodb query engine with documentation
Browse files Browse the repository at this point in the history
Signed-off-by: sitloboi2012 <[email protected]>
  • Loading branch information
sitloboi2012 committed Feb 23, 2025
1 parent e2f90bd commit 6b69f27
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 96 deletions.
229 changes: 150 additions & 79 deletions autogen/agentchat/contrib/rag/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,18 @@

import logging
from pathlib import Path
from typing import Any, Callable, Optional
from typing import Any, Callable, List, Optional, Union

from autogen.agentchat.contrib.rag.query_engine import VectorDbQueryEngine
from autogen.agentchat.contrib.vectordb.base import VectorDBFactory
from autogen.agentchat.contrib.vectordb.mongodb import MongoDBAtlasVectorDB
from autogen.import_utils import optional_import_block, require_optional_import

with optional_import_block():
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.node_parser.docling import DoclingNodeParser
from llama_index.readers.docling import DoclingReader
from llama_index.core import Document, SimpleDirectoryReader, StorageContext, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch


DEFAULT_COLLECTION_NAME = "docling-parsed-docs"

logging.basicConfig(level=logging.INFO)
Expand All @@ -26,6 +24,19 @@

@require_optional_import(["pymongo", "llama_index"], "rag")
class MongoDBQueryEngine(VectorDbQueryEngine):
"""
A query engine backed by MongoDB Atlas that supports document insertion and querying.
This engine initializes a vector database, builds an index from input documents,
and allows querying using the chat engine interface.
Attributes:
vector_db (MongoDBAtlasVectorDB): The MongoDB vector database instance.
vector_search_engine (MongoDBAtlasVectorSearch): The vector search engine.
storage_context (StorageContext): The storage context for the vector store.
indexer (Optional[VectorStoreIndex]): The index built from the documents.
"""

def __init__(
self,
connection_string: str = "",
Expand All @@ -34,8 +45,17 @@ def __init__(
collection_name: str = DEFAULT_COLLECTION_NAME,
index_name: str = "vector_index",
):
super().__init__()
"""
Initialize the MongoDBQueryEngine.
Args:
connection_string (str): The MongoDB connection string.
database_name (str): The name of the database to use.
embedding_function (Optional[Callable[..., Any]]): The function to compute embeddings.
collection_name (str): The name of the collection to use.
index_name (str): The name of the vector index.
"""
super().__init__()
self.vector_db: MongoDBAtlasVectorDB = VectorDBFactory.create_vector_db( # type: ignore[assignment]
db_type="mongodb",
connection_string=connection_string,
Expand All @@ -45,100 +65,151 @@ def __init__(
collection_name=collection_name,
)
self.vector_search_engine = MongoDBAtlasVectorSearch(
mongodb_client=self.vector_db.client, db_name=database_name, collection_name=collection_name
mongodb_client=self.vector_db.client,
db_name=database_name,
collection_name=collection_name,
)
self.storage_context = StorageContext.from_defaults(vector_store=self.vector_search_engine)
self.indexer = None
self.indexer: Optional[VectorStoreIndex] = None # type: ignore[no-any-unimported]

def connect_db(self, *args, **kwargs) -> bool:  # type: ignore[no-untyped-def]
    """
    Check that the MongoDB server behind ``self.vector_db`` is reachable.

    A ``ping`` admin command is issued through the vector DB's client. Any
    exception raised while doing so is logged and reported as a failed
    connection instead of being propagated.

    Args:
        *args: Accepted for interface compatibility; not used here.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        bool: True when the ping completes, False when it raises.
    """
    try:
        self.vector_db.client.admin.command("ping")
        logger.info("Connected to MongoDB successfully.")
    except Exception as ping_error:
        # Best-effort probe: callers only need a yes/no answer.
        logger.error("Failed to connect to MongoDB: %s", ping_error)
        return False
    return True

def init_db(  # type: ignore[no-untyped-def]
    self,
    new_doc_dir: Optional[Union[str, Path]] = None,
    new_doc_paths: Optional[List[Union[str, Path]]] = None,
    *args,
    **kwargs,
) -> bool:
    """
    Initialize the database by loading documents and building the vector index.

    The connection is verified first via ``connect_db``. Input files are then
    collected from ``new_doc_dir`` (recursively) and ``new_doc_paths``, loaded
    with ``SimpleDirectoryReader`` and indexed into ``self.storage_context``.
    On success the resulting index is stored in ``self.indexer``.

    Args:
        new_doc_dir (Optional[Union[str, Path]]): Directory that is searched
            recursively for input files. Sub-directories themselves are skipped.
        new_doc_paths (Optional[List[Union[str, Path]]]): Paths of documents to
            load. A single ``str``/``Path`` is also accepted and treated as a
            one-element list.
            NOTE(review): entries are passed to ``SimpleDirectoryReader`` as
            ``input_files``; whether URLs are supported is not visible here.
        *args: Accepted for interface compatibility; not used here.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        bool: True if the index was built; False if the connection check
        failed, no input documents were found, or loading/indexing raised.
    """
    if not self.connect_db():
        return False

    # Gather document paths.
    document_list: List[Union[str, Path]] = []
    if new_doc_dir:
        # The recursive glob also yields directories; the reader's
        # ``input_files`` is a list of files, so keep regular files only.
        document_list.extend(path for path in Path(new_doc_dir).glob("**/*") if path.is_file())
    if new_doc_paths:
        if isinstance(new_doc_paths, (str, Path)):
            # extend() on a bare string would add one entry per character.
            document_list.append(new_doc_paths)
        else:
            document_list.extend(new_doc_paths)

    if not document_list:
        logger.warning("No input documents provided to initialize the database.")
        return False

    try:
        documents = SimpleDirectoryReader(input_files=document_list).load_data()
        self.indexer = VectorStoreIndex.from_documents(documents, storage_context=self.storage_context)
    except Exception as e:
        logger.error("Failed to initialize the database: %s", e)
        return False

    logger.info("Database initialized with %d documents.", len(documents))
    return True

def add_records(  # type: ignore[no-untyped-def, override]
    self,
    new_doc_dir: Optional[Union[str, Path]] = None,
    new_doc_paths_or_urls: Optional[Union[List[Union[str, Path]], Union[str, Path]]] = None,
    chunk_size: int = 500,
    chunk_overlap: int = 300,
    *args,
    **kwargs,
) -> None:
    """
    Load, split, and insert documents into the existing index.

    Files are collected from ``new_doc_dir`` (recursively) and
    ``new_doc_paths_or_urls``, loaded with ``SimpleDirectoryReader``, split into
    chunks with a ``SentenceSplitter`` and inserted one by one through
    ``self.indexer``. Failures are logged rather than raised.

    Args:
        new_doc_dir (Optional[Union[str, Path]]): Directory that is searched
            recursively for input files. Sub-directories themselves are skipped.
        new_doc_paths_or_urls (Optional[Union[List[Union[str, Path]], Union[str, Path]]]):
            A list/tuple of document paths, or a single document path.
            NOTE(review): entries are passed to ``SimpleDirectoryReader`` as
            ``input_files``; whether URLs are supported is not visible here.
        chunk_size (int): ``chunk_size`` handed to the ``SentenceSplitter``.
        chunk_overlap (int): ``chunk_overlap`` handed to the ``SentenceSplitter``.
        *args: Accepted for interface compatibility; not used here.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        None
    """
    # Fail fast with an actionable message instead of loading and chunking
    # everything only to hit an AttributeError on ``None.insert`` at the end.
    if self.indexer is None:
        logger.error("Cannot add records: the index has not been initialized; call init_db() first.")
        return

    # Collect document paths.
    document_list: List[Union[str, Path]] = []
    if new_doc_dir:
        # The recursive glob also yields directories; the reader's
        # ``input_files`` is a list of files, so keep regular files only.
        document_list.extend(path for path in Path(new_doc_dir).glob("**/*") if path.is_file())
    if new_doc_paths_or_urls:
        if isinstance(new_doc_paths_or_urls, (list, tuple)):
            document_list.extend(new_doc_paths_or_urls)
        else:
            document_list.append(new_doc_paths_or_urls)

    if not document_list:
        logger.warning("No documents found for adding records.")
        return

    try:
        raw_documents = SimpleDirectoryReader(input_files=document_list).load_data()
    except Exception as e:
        logger.error("Error loading documents: %s", e)
        return

    node_parser = SentenceSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    doc_chunks: List[Document] = []  # type: ignore[no-any-unimported]

    for document in raw_documents:
        try:
            # Every chunk keeps the id and metadata of the document it came from.
            doc_chunks.extend(
                Document(text=chunk, id_=document.id_, metadata=document.metadata)
                for chunk in node_parser.split_text(document.text)
            )
        except Exception as e:
            # One unparsable document must not abort the remaining ones.
            logger.error("Error parsing document %s: %s", document.id_, e)

    if not doc_chunks:
        logger.warning("No document chunks created for insertion.")
        return

    # Insert document chunks using the indexer.
    try:
        for doc in doc_chunks:
            self.indexer.insert(doc)
    except Exception as e:
        logger.error("Error inserting documents into the index: %s", e)
        return

    logger.info("Inserted %d document chunks successfully.", len(doc_chunks))

def query(self, question: str, *args, **kwargs) -> Any:  # type: ignore[no-untyped-def]
    """
    Query the index using the given question.

    The question is answered through the chat engine obtained from
    ``self.indexer.as_chat_engine()``. Errors are logged, not raised.

    Args:
        question (str): The query string.
        *args: Accepted for interface compatibility; not used here.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        Any: The response from the chat engine, or None if the index has not
        been initialized or the query raised.
    """
    # ``indexer`` stays None until init_db() succeeds; report that explicitly
    # rather than as an opaque "'NoneType' object has no attribute ..." error.
    if self.indexer is None:
        logger.error("Query failed: the index has not been initialized; call init_db() first.")
        return None

    try:
        return self.indexer.as_chat_engine().query(question)
    except Exception as e:
        logger.error("Query failed: %s", e)
        return None
37 changes: 20 additions & 17 deletions notebook/docling_md_query_engine.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -174,17 +174,6 @@
"# Docling MD Query Engine MongoDB"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from chromadb.utils import embedding_functions\n",
"\n",
"openai_ef = embedding_functions.OpenAIEmbeddingFunction(api_key=\"\", model_name=\"text-embedding-ada-002\")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -193,7 +182,13 @@
"source": [
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"\""
"from chromadb.utils import embedding_functions\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"\"\n",
"openai_ef = embedding_functions.OpenAIEmbeddingFunction(\n",
" api_key=\"\",\n",
" model_name=\"text-embedding-ada-002\",\n",
")"
]
},
{
Expand All @@ -202,7 +197,7 @@
"metadata": {},
"outputs": [],
"source": [
"filename = \"/root/ag2/test/agents/experimental/document_agent/pdf_parsed/nvidia_10k_2024.md\""
"input_dir = \"/root/ag2/test/agents/experimental/document_agent/pdf_parsed/\""
]
},
{
Expand All @@ -213,7 +208,11 @@
"source": [
"from autogen.agentchat.contrib.rag.mongodb import MongoDBQueryEngine\n",
"\n",
"query_engine = MongoDBQueryEngine(connection_string=\"\", embedding_function=openai_ef, database_name=\"vector_db_4\")"
"query_engine = MongoDBQueryEngine(\n",
" connection_string=\"\",\n",
" embedding_function=openai_ef,\n",
" database_name=\"vector_db_1\",\n",
")"
]
},
{
Expand All @@ -222,7 +221,7 @@
"metadata": {},
"outputs": [],
"source": [
"query_engine.init_db(new_doc_paths=filename)"
"query_engine.init_db(new_doc_paths=[input_dir + \"nvidia_10k_2024.md\"])"
]
},
{
Expand All @@ -240,15 +239,19 @@
"metadata": {},
"outputs": [],
"source": [
"query_engine.add_records(new_doc_paths_or_urls=filename)"
"query_engine.add_records(new_doc_paths_or_urls=[input_dir + \"Toast_financial_report.md\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
"source": [
"question = \"What is the trading symbol for Toast\"\n",
"answer = query_engine.query(question)\n",
"print(answer)"
]
}
],
"metadata": {
Expand Down

0 comments on commit 6b69f27

Please sign in to comment.