From 78cc764a1b70eb4ae202544773962785adc6dd2a Mon Sep 17 00:00:00 2001 From: AgentGenie Date: Wed, 19 Feb 2025 21:50:30 -0800 Subject: [PATCH] Refactor query engine interface --- .../contrib/graph_rag/graph_query_engine.py | 3 +- .../graph_rag/neo4j_graph_query_engine.py | 4 +- .../neo4j_native_graph_query_engine.py | 4 +- autogen/agentchat/contrib/rag/__init__.py | 9 +++ autogen/agentchat/contrib/rag/query_engine.py | 58 +++++++++++++++++++ 5 files changed, 73 insertions(+), 5 deletions(-) create mode 100644 autogen/agentchat/contrib/rag/__init__.py create mode 100644 autogen/agentchat/contrib/rag/query_engine.py diff --git a/autogen/agentchat/contrib/graph_rag/graph_query_engine.py b/autogen/agentchat/contrib/graph_rag/graph_query_engine.py index 0432459b42..b46930681f 100644 --- a/autogen/agentchat/contrib/graph_rag/graph_query_engine.py +++ b/autogen/agentchat/contrib/graph_rag/graph_query_engine.py @@ -5,7 +5,7 @@ # Portions derived from https://github.com/microsoft/autogen are under the MIT License. # SPDX-License-Identifier: MIT from dataclasses import dataclass, field -from typing import Optional, Protocol +from typing import Optional, Protocol, runtime_checkable from .document import Document @@ -22,6 +22,7 @@ class GraphStoreQueryResult: results: list = field(default_factory=list) +@runtime_checkable class GraphQueryEngine(Protocol): """An abstract base class that represents a graph query engine on top of a underlying graph database. 
diff --git a/autogen/agentchat/contrib/graph_rag/neo4j_graph_query_engine.py b/autogen/agentchat/contrib/graph_rag/neo4j_graph_query_engine.py index d01cb7c286..a7d16c3ff4 100644 --- a/autogen/agentchat/contrib/graph_rag/neo4j_graph_query_engine.py +++ b/autogen/agentchat/contrib/graph_rag/neo4j_graph_query_engine.py @@ -12,7 +12,7 @@ from ....import_utils import optional_import_block, require_optional_import from .document import Document, DocumentType -from .graph_query_engine import GraphQueryEngine, GraphStoreQueryResult +from .graph_query_engine import GraphStoreQueryResult with optional_import_block(): from llama_index.core import PropertyGraphIndex, SimpleDirectoryReader @@ -31,7 +31,7 @@ @require_optional_import("llama_index", "neo4j") -class Neo4jGraphQueryEngine(GraphQueryEngine): +class Neo4jGraphQueryEngine: """This class serves as a wrapper for a property graph query engine backed by LlamaIndex and Neo4j, facilitating the creating, connecting, updating, and querying of LlamaIndex property graphs. diff --git a/autogen/agentchat/contrib/graph_rag/neo4j_native_graph_query_engine.py b/autogen/agentchat/contrib/graph_rag/neo4j_native_graph_query_engine.py index 909da92c7b..446a68bd61 100644 --- a/autogen/agentchat/contrib/graph_rag/neo4j_native_graph_query_engine.py +++ b/autogen/agentchat/contrib/graph_rag/neo4j_native_graph_query_engine.py @@ -8,7 +8,7 @@ from ....import_utils import optional_import_block, require_optional_import from .document import Document, DocumentType -from .graph_query_engine import GraphQueryEngine, GraphStoreQueryResult +from .graph_query_engine import GraphStoreQueryResult with optional_import_block(): from neo4j import GraphDatabase @@ -26,7 +26,7 @@ @require_optional_import(["neo4j", "neo4j_graphrag"], "neo4j") -class Neo4jNativeGraphQueryEngine(GraphQueryEngine): +class Neo4jNativeGraphQueryEngine: """A graph query engine implemented using the Neo4j GraphRAG SDK. 
Provides functionality to initialize a knowledge graph, create a vector index, and query the graph using Neo4j and LLM. diff --git a/autogen/agentchat/contrib/rag/__init__.py b/autogen/agentchat/contrib/rag/__init__.py new file mode 100644 index 0000000000..1495f4ec7f --- /dev/null +++ b/autogen/agentchat/contrib/rag/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from .query_engine import VectorDbQueryEngine + +__all__ = [ + "VectorDbQueryEngine", +] diff --git a/autogen/agentchat/contrib/rag/query_engine.py b/autogen/agentchat/contrib/rag/query_engine.py new file mode 100644 index 0000000000..c1dad41b12 --- /dev/null +++ b/autogen/agentchat/contrib/rag/query_engine.py @@ -0,0 +1,58 @@ +# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors +# +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path +from typing import Optional, Protocol, Union, runtime_checkable + + +@runtime_checkable +class VectorDbQueryEngine(Protocol): + """An abstract base class that represents a query engine on top of an underlying vector database. + + This interface defines the basic methods for RAG. + """ + + def init_db( + self, + new_doc_dir: Optional[Union[Path, str]] = None, + new_doc_paths: Optional[list[Union[Path, str]]] = None, + /, + *args, + **kwargs, + ) -> bool: + """This method initializes the database with the input documents or records. + Usually, it takes the following steps, + 1. connect to a database. + 2. insert records + 3. build indexes etc. + + Args: + new_doc_dir: a dir of input documents that are used to create the records in database. + new_doc_paths: + a list of input documents that are used to create the records in database. + a document can be a path to a file or a url. + + Returns: + bool: True if initialization is successful, False otherwise + """ + ...
+ + def add_records( + self, + new_doc_dir: Optional[Union[Path, str]] = None, + new_doc_paths_or_urls: Optional[list[Union[Path, str]]] = None, + /, + *args, + **kwargs, + ) -> bool: + """Add new documents to the underlying database and add to the index.""" + ... + + def connect_db(self, *args, **kwargs) -> bool: + """This method connects to the database.""" + ... + + def query(self, question: str, /, *args, **kwargs) -> str: + """This method transforms a string format question into a database query and returns the result.""" + ...