Skip to content

Commit

Permalink
Use the neo4j_database parameter everywhere (#216)
Browse files Browse the repository at this point in the history
* Use self.neo4j_database for all queries in Neo4jWriter

* Make sure all execute_query can be run against a custom database

* Update CHANGELOG

* Update docstring + update examples not to use undocumented feature for neo4j driver

* Expose neo4j_database in SimpleKGBuilder

* Update CHANGELOG

* Simplify changelog
  • Loading branch information
stellasia authored Nov 22, 2024
1 parent 8285d56 commit 4d49525
Show file tree
Hide file tree
Showing 29 changed files with 158 additions and 59 deletions.
11 changes: 9 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,15 @@
## Next

### Added
- Introduced optional lexical graph configuration for SimpleKGPipeline, enhancing flexibility in customizing node labels and relationship types in the lexical graph.
- Ability to provide description and list of properties for entities and relations in the SimpleKGPipeline constructor.
- Introduced optional lexical graph configuration for `SimpleKGPipeline`, enhancing flexibility in customizing node labels and relationship types in the lexical graph.
- Introduced optional `neo4j_database` parameter for `SimpleKGPipeline`, `Neo4jChunkReader` and `Text2CypherRetriever`.
- Ability to provide description and list of properties for entities and relations in the `SimpleKGPipeline` constructor.

### Fixed
- `neo4j_database` parameter is now used for all queries in the `Neo4jWriter`.

### Changed
- Updated all examples to use `neo4j_database` parameter instead of an undocumented neo4j driver constructor.

## 1.2.0

Expand Down
3 changes: 2 additions & 1 deletion examples/build_graph/simple_kg_builder_from_pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ async def define_and_run_pipeline(
entities=ENTITIES,
relations=RELATIONS,
potential_schema=POTENTIAL_SCHEMA,
neo4j_database=DATABASE,
)
return await kg_builder.run_async(file_path=str(file_path))

Expand All @@ -62,7 +63,7 @@ async def main() -> PipelineResult:
"response_format": {"type": "json_object"},
},
)
with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
res = await define_and_run_pipeline(driver, llm)
await llm.async_client.close()
return res
Expand Down
5 changes: 3 additions & 2 deletions examples/build_graph/simple_kg_builder_from_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# Neo4j db infos
URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")
DATABASE = "neo4j"
DATABASE = "newdb"

# Text to process
TEXT = """The son of Duke Leto Atreides and the Lady Jessica, Paul is the heir of House Atreides,
Expand Down Expand Up @@ -67,6 +67,7 @@ async def define_and_run_pipeline(
relations=RELATIONS,
potential_schema=POTENTIAL_SCHEMA,
from_pdf=False,
neo4j_database=DATABASE,
)
return await kg_builder.run_async(text=TEXT)

Expand All @@ -79,7 +80,7 @@ async def main() -> PipelineResult:
"response_format": {"type": "json_object"},
},
)
with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
res = await define_and_run_pipeline(driver, llm)
await llm.async_client.close()
return res
Expand Down
2 changes: 1 addition & 1 deletion examples/customize/answer/custom_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
driver = neo4j.GraphDatabase.driver(
URI,
auth=AUTH,
database=DATABASE,
)

embedder = OpenAIEmbeddings()
Expand All @@ -33,6 +32,7 @@
index_name=INDEX,
retrieval_query="WITH node, score RETURN node.title as title, node.plot as plot",
embedder=embedder,
neo4j_database=DATABASE,
)

llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
Expand Down
2 changes: 1 addition & 1 deletion examples/customize/answer/langchain_compatiblity.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
driver = neo4j.GraphDatabase.driver(
URI,
auth=AUTH,
database=DATABASE,
)

embedder = OpenAIEmbeddings(model="text-embedding-ada-002")
Expand All @@ -31,6 +30,7 @@
index_name=INDEX,
retrieval_query="WITH node, score RETURN node.title as title, node.plot as plot",
embedder=embedder, # type: ignore[arg-type, unused-ignore]
neo4j_database=DATABASE,
)

llm = ChatOpenAI(model="gpt-4o", temperature=0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def my_result_formatter(record: neo4j.Record) -> RetrieverResultItem:
)


with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
# Initialize the retriever
retriever = VectorCypherRetriever(
driver=driver,
Expand All @@ -48,7 +48,7 @@ def my_result_formatter(record: neo4j.Record) -> RetrieverResultItem:
retrieval_query=RETRIEVAL_QUERY,
result_formatter=my_result_formatter,
# optionally, set neo4j database
# neo4j_database="neo4j",
neo4j_database=DATABASE,
)

# Perform the similarity search for a text query
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@


# Connect to Neo4j database
driver = neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE)
driver = neo4j.GraphDatabase.driver(URI, auth=AUTH)


query_text = "Find a movie about astronauts"
Expand All @@ -52,6 +52,7 @@
index_name=INDEX_NAME,
embedder=OpenAIEmbeddings(),
return_properties=["title", "plot"],
neo4j_database=DATABASE,
)
print(retriever.search(query_text=query_text, top_k=top_k_results))
print()
Expand Down
2 changes: 1 addition & 1 deletion examples/question_answering/graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def formatter(record: neo4j.Record) -> RetrieverResultItem:
driver = neo4j.GraphDatabase.driver(
URI,
auth=AUTH,
database=DATABASE,
)

embedder = OpenAIEmbeddings()
Expand All @@ -46,6 +45,7 @@ def formatter(record: neo4j.Record) -> RetrieverResultItem:
retrieval_query="with node, score return node.title as title, node.plot as plot",
result_formatter=formatter,
embedder=embedder,
neo4j_database=DATABASE,
)

llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
Expand Down
4 changes: 2 additions & 2 deletions examples/retrieve/hybrid_cypher_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
# the name of all actors starring in that movie
RETRIEVAL_QUERY = " MATCH (node)<-[:ACTED_IN]-(p:Person) RETURN node.title as movieTitle, node.plot as moviePlot, collect(p.name) as actors, score as similarityScore"

with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
# Initialize the retriever
retriever = HybridCypherRetriever(
driver=driver,
Expand All @@ -37,7 +37,7 @@
# (see corresponding example in 'customize' directory)
# result_formatter=None,
# optionally, set neo4j database
# neo4j_database="neo4j",
neo4j_database=DATABASE,
)

# Perform the similarity search for a text query
Expand Down
4 changes: 2 additions & 2 deletions examples/retrieve/hybrid_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
FULLTEXT_INDEX_NAME = "movieFulltext"


with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
# Initialize the retriever
retriever = HybridRetriever(
driver=driver,
Expand All @@ -31,7 +31,7 @@
# (see corresponding example in 'customize' directory)
# result_formatter=None,
# optionally, set neo4j database
# neo4j_database="neo4j",
neo4j_database=DATABASE,
)

# Perform the similarity search for a text query
Expand Down
4 changes: 2 additions & 2 deletions examples/retrieve/similarity_search_for_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
INDEX_NAME = "moviePlotsEmbedding"


with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
# Initialize the retriever
retriever = VectorRetriever(
driver=driver,
Expand All @@ -29,7 +29,7 @@
# (see corresponding example in 'customize' directory)
# result_formatter=None,
# optionally, set neo4j database
# neo4j_database="neo4j",
neo4j_database=DATABASE,
)

# Perform the similarity search for a text query
Expand Down
3 changes: 2 additions & 1 deletion examples/retrieve/similarity_search_for_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
INDEX_NAME = "moviePlotsEmbedding"


with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
# Initialize the retriever
retriever = VectorRetriever(
driver=driver,
index_name=INDEX_NAME,
neo4j_database=DATABASE,
)

# Perform the similarity search for a vector query
Expand Down
1 change: 1 addition & 0 deletions examples/retrieve/text2cypher_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
# optionally, you can also provide your own prompt
# for the text2Cypher generation step
# custom_prompt="",
neo4j_database=DATABASE,
)

# Generate a Cypher query using the LLM, send it to the Neo4j database, and return the results
Expand Down
4 changes: 2 additions & 2 deletions examples/retrieve/vector_cypher_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
# the name of all actors starring in that movie
RETRIEVAL_QUERY = " MATCH (node)<-[:ACTED_IN]-(p:Person) RETURN node.title as movieTitle, node.plot as moviePlot, collect(p.name) as actors, score as similarityScore"

with neo4j.GraphDatabase.driver(URI, auth=AUTH, database=DATABASE) as driver:
with neo4j.GraphDatabase.driver(URI, auth=AUTH) as driver:
# Initialize the retriever
retriever = VectorCypherRetriever(
driver=driver,
Expand All @@ -34,7 +34,7 @@
# (see corresponding example in 'customize' directory)
# result_formatter=None,
# optionally, set neo4j database
# neo4j_database="neo4j",
neo4j_database=DATABASE,
)

# Perform the similarity search for a text query
Expand Down
31 changes: 22 additions & 9 deletions src/neo4j_graphrag/experimental/components/kg_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class Neo4jWriter(KGWriter):
Args:
driver (neo4j.driver): The Neo4j driver to connect to the database.
neo4j_database (Optional[str]): The name of the Neo4j database to write to. Defaults to 'neo4j' if not provided.
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
batch_size (int): The number of nodes or relationships to write to the database in a batch. Defaults to 1000.
Example:
Expand All @@ -99,7 +99,7 @@ class Neo4jWriter(KGWriter):
AUTH = ("neo4j", "password")
DATABASE = "neo4j"
driver = GraphDatabase.driver(URI, auth=AUTH, database=DATABASE)
driver = GraphDatabase.driver(URI, auth=AUTH)
writer = Neo4jWriter(driver=driver, neo4j_database=DATABASE)
pipeline = Pipeline()
Expand All @@ -119,10 +119,11 @@ def __init__(
self.is_version_5_23_or_above = self._check_if_version_5_23_or_above()

def _db_setup(self) -> None:
# create index on __Entity__.id
# create index on __KGBuilder__.id
# used when creating the relationships
self.driver.execute_query(
"CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)"
"CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)",
database_=self.neo4j_database,
)

@staticmethod
Expand Down Expand Up @@ -150,10 +151,16 @@ def _upsert_nodes(
parameters = {"rows": self._nodes_to_rows(nodes, lexical_graph_config)}
if self.is_version_5_23_or_above:
self.driver.execute_query(
UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters
UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE,
parameters_=parameters,
database_=self.neo4j_database,
)
else:
self.driver.execute_query(UPSERT_NODE_QUERY, parameters_=parameters)
self.driver.execute_query(
UPSERT_NODE_QUERY,
parameters_=parameters,
database_=self.neo4j_database,
)

def _get_version(self) -> tuple[int, ...]:
records, _, _ = self.driver.execute_query(
Expand Down Expand Up @@ -187,10 +194,16 @@ def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None:
parameters = {"rows": [rel.model_dump() for rel in rels]}
if self.is_version_5_23_or_above:
self.driver.execute_query(
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, parameters_=parameters
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE,
parameters_=parameters,
database_=self.neo4j_database,
)
else:
self.driver.execute_query(UPSERT_RELATIONSHIP_QUERY, parameters_=parameters)
self.driver.execute_query(
UPSERT_RELATIONSHIP_QUERY,
parameters_=parameters,
database_=self.neo4j_database,
)

@validate_call
async def run(
Expand All @@ -202,7 +215,7 @@ async def run(
Args:
graph (Neo4jGraph): The knowledge graph to upsert into the database.
lexical_graph_config (LexicalGraphConfig):
lexical_graph_config (LexicalGraphConfig): Node labels and relationship types for the lexical graph.
"""
try:
self._db_setup()
Expand Down
38 changes: 37 additions & 1 deletion src/neo4j_graphrag/experimental/components/neo4j_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations

from typing import Optional

import neo4j
from pydantic import validate_call

Expand All @@ -26,13 +28,39 @@


class Neo4jChunkReader(Component):
"""Reads text chunks from a Neo4j database.
Args:
driver (neo4j.driver): The Neo4j driver to connect to the database.
fetch_embeddings (bool): If True, the embedding property is also returned. Defaults to False.
neo4j_database (Optional[str]): The name of the Neo4j database. If not provided, this defaults to the server's default database ("neo4j" by default) (`see reference to documentation <https://neo4j.com/docs/operations-manual/current/database-administration/#manage-databases-default>`_).
Example:
.. code-block:: python
from neo4j import GraphDatabase
from neo4j_graphrag.experimental.components.neo4j_reader import Neo4jChunkReader
URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")
DATABASE = "neo4j"
driver = GraphDatabase.driver(URI, auth=AUTH)
reader = Neo4jChunkReader(driver=driver, neo4j_database=DATABASE)
await reader.run()
"""

def __init__(
self,
driver: neo4j.Driver,
fetch_embeddings: bool = False,
neo4j_database: Optional[str] = None,
):
self.driver = driver
self.fetch_embeddings = fetch_embeddings
self.neo4j_database = neo4j_database

def _get_query(
self,
Expand All @@ -56,12 +84,20 @@ async def run(
self,
lexical_graph_config: LexicalGraphConfig = LexicalGraphConfig(),
) -> TextChunks:
"""Reads text chunks from a Neo4j database.
Args:
lexical_graph_config (LexicalGraphConfig): Node labels and relationship types for the lexical graph.
"""
query = self._get_query(
lexical_graph_config.chunk_node_label,
lexical_graph_config.chunk_index_property,
lexical_graph_config.chunk_embedding_property,
)
result, _, _ = self.driver.execute_query(query)
result, _, _ = self.driver.execute_query(
query,
database_=self.neo4j_database,
)
chunks = []
for record in result:
chunk = record.get("chunk")
Expand Down
Loading

0 comments on commit 4d49525

Please sign in to comment.