Skip to content

Commit

Permalink
typing fixed for RAG and Graph RAG
Browse files Browse the repository at this point in the history
  • Loading branch information
davorrunje committed Feb 20, 2025
1 parent 78cc764 commit 0ed53a3
Show file tree
Hide file tree
Showing 18 changed files with 154 additions and 113 deletions.
6 changes: 3 additions & 3 deletions .secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@
"filename": "test/agentchat/contrib/graph_rag/test_native_neo4j_graph_rag.py",
"hashed_secret": "5baa61e4c9b93f3f0682250b6cf8331b7ee68fd8",
"is_verified": false,
"line_number": 78,
"line_number": 75,
"is_secret": false
}
],
Expand All @@ -783,7 +783,7 @@
"filename": "test/agentchat/contrib/graph_rag/test_neo4j_graph_rag.py",
"hashed_secret": "5baa61e4c9b93f3f0682250b6cf8331b7ee68fd8",
"is_verified": false,
"line_number": 38,
"line_number": 35,
"is_secret": false
}
],
Expand Down Expand Up @@ -1616,5 +1616,5 @@
}
]
},
"generated_at": "2025-02-19T11:06:40Z"
"generated_at": "2025-02-20T10:09:56Z"
}
6 changes: 5 additions & 1 deletion autogen/agentchat/contrib/graph_rag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,8 @@
#
# SPDX-License-Identifier: Apache-2.0

__all__: list[str] = []
from .document import Document, DocumentType
from .graph_query_engine import GraphQueryEngine, GraphStoreQueryResult
from .graph_rag_capability import GraphRagCapability

__all__ = ["Document", "DocumentType", "GraphQueryEngine", "GraphRagCapability", "GraphStoreQueryResult"]
10 changes: 6 additions & 4 deletions autogen/agentchat/contrib/graph_rag/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Optional
from typing import Any, Optional

__all__ = ["Document", "DocumentType"]


class DocumentType(Enum):
Expand All @@ -23,5 +25,5 @@ class Document:
"""A wrapper of an input document: its type, optional in-memory data, and optional path or URL."""

doctype: DocumentType
data: Optional[object] = None
path_or_url: Optional[str] = ""
data: Optional[Any] = None
path_or_url: Optional[str] = field(default_factory=lambda: "")
20 changes: 10 additions & 10 deletions autogen/agentchat/contrib/graph_rag/falkor_graph_query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os
import warnings
from typing import Optional
from typing import Any, Optional

from ....import_utils import optional_import_block, require_optional_import
from .document import Document
Expand All @@ -23,7 +23,7 @@
class FalkorGraphQueryEngine:
"""This is a wrapper for FalkorDB KnowledgeGraph."""

def __init__(
def __init__( # type: ignore[no-any-unimported]
self,
name: str,
host: str = "127.0.0.1",
Expand Down Expand Up @@ -57,10 +57,10 @@ def __init__(
self.model = model or OpenAiGenerativeModel("gpt-4o")
self.model_config = KnowledgeGraphModelConfig.with_model(model)
self.ontology = ontology
self.knowledge_graph = None
self.knowledge_graph: Optional["KnowledgeGraph"] = None # type: ignore[no-any-unimported]
self.falkordb = FalkorDB(host=self.host, port=self.port, username=self.username, password=self.password)

def connect_db(self):
def connect_db(self) -> None:
"""Connect to an existing knowledge graph."""
if self.name in self.falkordb.list_graphs():
try:
Expand All @@ -86,11 +86,11 @@ def connect_db(self):
else:
raise ValueError(f"Knowledge graph '{self.name}' does not exist")

def init_db(self, input_doc: list[Document]):
def init_db(self, input_doc: list[Document]) -> None:
"""Build the knowledge graph with input documents."""
sources = []
for doc in input_doc:
if os.path.exists(doc.path_or_url):
if doc.path_or_url and os.path.exists(doc.path_or_url):
sources.append(Source(doc.path_or_url))

if sources:
Expand Down Expand Up @@ -123,7 +123,7 @@ def init_db(self, input_doc: list[Document]):
def add_records(self, new_records: list[Document]) -> bool:
raise NotImplementedError("This method is not supported by FalkorDB SDK yet.")

def query(self, question: str, n_results: int = 1, **kwargs) -> GraphStoreQueryResult:
def query(self, question: str, n_results: int = 1, **kwargs: Any) -> GraphStoreQueryResult:
"""Query the knowledge graph with a question and optional message history.
Args:
Expand Down Expand Up @@ -153,17 +153,17 @@ def delete(self) -> bool:
self.falkordb.select_graph(self.ontology_table_name).delete()
return True

def __get_ontology_storage_graph(self) -> "Graph":
def __get_ontology_storage_graph(self) -> "Graph": # type: ignore[no-any-unimported]
return self.falkordb.select_graph(self.ontology_table_name)

def _save_ontology_to_db(self, ontology: "Ontology"):
def _save_ontology_to_db(self, ontology: "Ontology") -> None: # type: ignore[no-any-unimported]
"""Save graph ontology to a separate table with {graph_name}_ontology"""
if self.ontology_table_name in self.falkordb.list_graphs():
raise ValueError(f"Knowledge graph {self.name} is already created.")
graph = self.__get_ontology_storage_graph()
ontology.save_to_graph(graph)

def _load_ontology_from_db(self) -> "Ontology":
def _load_ontology_from_db(self) -> "Ontology": # type: ignore[no-any-unimported]
if self.ontology_table_name not in self.falkordb.list_graphs():
raise ValueError(f"Knowledge graph {self.name} has not been created.")
graph = self.__get_ontology_storage_graph()
Expand Down
18 changes: 12 additions & 6 deletions autogen/agentchat/contrib/graph_rag/falkor_graph_rag_capability.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,18 @@ def __init__(self, query_engine: FalkorGraphQueryEngine):
"""Initialize GraphRAG capability with a graph query engine"""
self.query_engine = query_engine

def add_to_agent(self, agent: UserProxyAgent):
def add_to_agent(self, agent: ConversableAgent) -> None:
"""Add FalkorDB GraphRAG capability to a UserProxyAgent.
Args:
agent: The UserProxyAgent instance to add the capability to.
The agent is restricted to a UserProxyAgent to make sure the returned message contains only information retrieved from the graph DB rather than from any LLM.
"""
if not isinstance(agent, UserProxyAgent):
raise Exception("FalkorDB GraphRAG capability can only be added to a UserProxyAgent.")

self.graph_rag_agent = agent

# Validate the agent config
Expand Down Expand Up @@ -62,7 +70,8 @@ def _reply_using_falkordb_query(
Returns:
A tuple containing a boolean indicating success and the assistant's reply.
"""
question = self._messages_summary(messages, recipient.system_message)
# todo: fix typing, this is not correct
question = self._messages_summary(messages, recipient.system_message) # type: ignore[arg-type]
result: GraphStoreQueryResult = self.query_engine.query(question)

return True, result.answer if result.answer else "I'm sorry, I don't have an answer for that."
Expand All @@ -77,10 +86,7 @@ def _messages_summary(self, messages: Union[dict[str, Any], str], system_message
<content>
"""
if isinstance(messages, str):
if system_message:
summary = f"IMPORTANT: {system_message}\nContext:\n\n{messages}"
else:
return messages
return (f"IMPORTANT: {system_message}\n" if system_message else "") + f"Context:\n\n{messages}"

elif isinstance(messages, list):
summary = ""
Expand Down
12 changes: 7 additions & 5 deletions autogen/agentchat/contrib/graph_rag/graph_query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from dataclasses import dataclass, field
from typing import Optional, Protocol, runtime_checkable
from typing import Any, Optional, Protocol, runtime_checkable

from .document import Document

__all__ = ["GraphQueryEngine", "GraphStoreQueryResult"]


@dataclass
class GraphStoreQueryResult:
Expand All @@ -19,7 +21,7 @@ class GraphStoreQueryResult:
"""

answer: Optional[str] = None
results: list = field(default_factory=list)
results: list[Any] = field(default_factory=list)


@runtime_checkable
Expand All @@ -29,7 +31,7 @@ class GraphQueryEngine(Protocol):
This interface defines the basic methods for graph-based RAG.
"""

def init_db(self, input_doc: Optional[list[Document]] = None):
def init_db(self, input_doc: Optional[list[Document]] = None) -> None:
"""This method initializes the graph database with the input documents or records.
Usually, it takes the following steps,
1. connecting to a graph database.
Expand All @@ -42,10 +44,10 @@ def init_db(self, input_doc: Optional[list[Document]] = None):
"""
pass

def add_records(self, new_records: list) -> bool:
def add_records(self, new_records: list[Any]) -> bool:
"""Add new records to the underlying database and add to the graph if required."""
pass

def query(self, question: str, n_results: int = 1, **kwargs) -> GraphStoreQueryResult:
def query(self, question: str, n_results: int = 1, **kwarg: Any) -> GraphStoreQueryResult:
"""This method transforms a string-format question into a database query and returns the result."""
pass
6 changes: 4 additions & 2 deletions autogen/agentchat/contrib/graph_rag/graph_rag_capability.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from ..capabilities.agent_capability import AgentCapability
from .graph_query_engine import GraphQueryEngine

__all__ = ["GraphRagCapability"]


class GraphRagCapability(AgentCapability):
"""A graph-based RAG capability uses a graph query engine to give a conversable agent the graph-based RAG ability.
Expand Down Expand Up @@ -52,10 +54,10 @@ class GraphRagCapability(AgentCapability):
```
"""

def __init__(self, query_engine: GraphQueryEngine):
def __init__(self, query_engine: GraphQueryEngine) -> None:
"""Initialize graph-based RAG capability with a graph query engine"""
...

def add_to_agent(self, agent: ConversableAgent):
def add_to_agent(self, agent: ConversableAgent) -> None:
"""Add the capability to an agent"""
...
28 changes: 15 additions & 13 deletions autogen/agentchat/contrib/graph_rag/neo4j_graph_query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
import os
import sys
from typing import Optional, Union
from typing import Any, Optional, Union

if sys.version_info >= (3, 10):
from typing import TypeAlias
Expand All @@ -17,6 +17,7 @@
with optional_import_block():
from llama_index.core import PropertyGraphIndex, SimpleDirectoryReader
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.chat_engine.types import ChatMode
from llama_index.core.indices.property_graph import (
DynamicLLMPathExtractor,
SchemaLLMPathExtractor,
Expand All @@ -25,6 +26,7 @@
from llama_index.core.llms import LLM
from llama_index.core.readers.json import JSONReader
from llama_index.core.schema import Document as LlamaDocument
from llama_index.core.schema import TransformComponent
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
from llama_index.llms.openai import OpenAI
Expand All @@ -51,7 +53,7 @@ class Neo4jGraphQueryEngine:
For usage, please refer to example notebook/agentchat_graph_rag_neo4j.ipynb
"""

def __init__(
def __init__( # type: ignore[no-any-unimported]
self,
host: str = "bolt://localhost",
port: int = 7687,
Expand Down Expand Up @@ -94,7 +96,7 @@ def __init__(
self.schema = schema
self.strict = strict

def init_db(self, input_doc: Optional[list[Document]] = None):
def init_db(self, input_doc: Optional[list[Document]] = None) -> None:
"""Build the knowledge graph with input documents."""
self.documents = self._load_doc(input_doc if input_doc is not None else [])

Expand All @@ -120,7 +122,7 @@ def init_db(self, input_doc: Optional[list[Document]] = None):
show_progress=True,
)

def connect_db(self):
def connect_db(self) -> None:
"""Connect to an existing knowledge graph database."""
self.graph_store = Neo4jPropertyGraphStore(
username=self.username,
Expand All @@ -139,7 +141,7 @@ def connect_db(self):
show_progress=True,
)

def add_records(self, new_records: list) -> bool:
def add_records(self, new_records: list[Document]) -> bool:
"""Add new records to the knowledge graph. Must be local files.
Args:
Expand All @@ -166,7 +168,7 @@ def add_records(self, new_records: list) -> bool:
print(f"Error adding records: {e}")
return False

def query(self, question: str, n_results: int = 1, **kwargs) -> GraphStoreQueryResult:
def query(self, question: str, n_results: int = 1, **kwargs: Any) -> GraphStoreQueryResult:
"""Query the property graph with a question using LlamaIndex chat engine.
We use the condense_plus_context chat mode
which condenses the conversation history and the user query into a standalone question,
Expand All @@ -185,7 +187,7 @@ def query(self, question: str, n_results: int = 1, **kwargs) -> GraphStoreQueryR

# Initialize chat engine if not already initialized
if not hasattr(self, "chat_engine"):
self.chat_engine = self.index.as_chat_engine(chat_mode="condense_plus_context", llm=self.llm)
self.chat_engine = self.index.as_chat_engine(chat_mode=ChatMode.CONDENSE_PLUS_CONTEXT, llm=self.llm)

response = self.chat_engine.chat(question)
return GraphStoreQueryResult(answer=str(response))
Expand All @@ -197,7 +199,7 @@ def _clear(self) -> None:
with self.graph_store._driver.session() as session:
session.run("MATCH (n) DETACH DELETE n;")

def _load_doc(self, input_doc: list[Document]) -> list["LlamaDocument"]:
def _load_doc(self, input_doc: list[Document]) -> list["LlamaDocument"]: # type: ignore[no-any-unimported]
"""Load documents from the input files. Currently supports the following file types:
.csv - comma-separated values
.docx - Microsoft Word
Expand All @@ -214,7 +216,7 @@ def _load_doc(self, input_doc: list[Document]) -> list["LlamaDocument"]:
.json JSON files
"""
for doc in input_doc:
if not os.path.exists(doc.path_or_url):
if not os.path.exists(doc.path_or_url): # type: ignore[arg-type]
raise ValueError(f"Document file not found: {doc.path_or_url}")

common_type_input_files = []
Expand All @@ -228,11 +230,11 @@ def _load_doc(self, input_doc: list[Document]) -> list["LlamaDocument"]:
if common_type_input_files:
loaded_documents.extend(SimpleDirectoryReader(input_files=common_type_input_files).load_data())
for json_file in json_type_input_files:
loaded_documents.extend(JSONReader().load_data(input_file=json_file))
loaded_documents.extend(JSONReader().load_data(input_file=json_file)) # type: ignore[arg-type]

return loaded_documents

def _create_kg_extractors(self):
def _create_kg_extractors(self) -> list["TransformComponent"]: # type: ignore[no-any-unimported]
"""If strict is True,
extract paths following a strict schema of allowed relationships for each entity.
Expand All @@ -242,13 +244,13 @@ def _create_kg_extractors(self):
# To add more extractors, please refer to https://docs.llamaindex.ai/en/latest/module_guides/indexing/lpg_index_guide/#construction
"""
#
kg_extractors = [
kg_extractors: list["TransformComponent"] = [ # type: ignore[no-any-unimported]
SchemaLLMPathExtractor(
llm=self.llm,
possible_entities=self.entities,
possible_relations=self.relations,
kg_validation_schema=self.schema,
strict=self.strict,
strict=self.strict if self.strict else False,
),
]

Expand Down
Loading

0 comments on commit 0ed53a3

Please sign in to comment.