Skip to content

Commit

Permalink
Refactors Neo4j version checking code (#255)
Browse files Browse the repository at this point in the history
* Refactors Neo4j version checking code into a set of utility functions

* Adds the copyright header to new files

* Updated base retriever

* Fixed hybrid retriever unit tests

* Updated Text2Cypher unit tests

* Updated vector retriever tests

* Added the ability to return the edition of the db in get_version

* Updated kg writer
  • Loading branch information
alexthomas93 authored Jan 23, 2025
1 parent 3b7ded4 commit e0b5e86
Show file tree
Hide file tree
Showing 12 changed files with 426 additions and 256 deletions.
30 changes: 6 additions & 24 deletions src/neo4j_graphrag/experimental/components/kg_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
UPSERT_RELATIONSHIP_QUERY,
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE,
)
from neo4j_graphrag.utils.version_utils import (
get_version,
is_version_5_23_or_above,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -116,7 +120,8 @@ def __init__(
self.driver = driver
self.neo4j_database = neo4j_database
self.batch_size = batch_size
self.is_version_5_23_or_above = self._check_if_version_5_23_or_above()
version_tuple, _, _ = get_version(self.driver, self.neo4j_database)
self.is_version_5_23_or_above = is_version_5_23_or_above(version_tuple)

def _db_setup(self) -> None:
# create index on __KGBuilder__.id
Expand Down Expand Up @@ -162,29 +167,6 @@ def _upsert_nodes(
database_=self.neo4j_database,
)

def _get_version(self) -> tuple[int, ...]:
records, _, _ = self.driver.execute_query(
"CALL dbms.components()", database_=self.neo4j_database
)
version = records[0]["versions"][0]
# Drop everything after the '-' first
version_main, *_ = version.split("-")
# Convert each number between '.' into int
version_tuple = tuple(map(int, version_main.split(".")))
# If no patch version, consider it's 0
if len(version_tuple) < 3:
version_tuple = (*version_tuple, 0)
return version_tuple

def _check_if_version_5_23_or_above(self) -> bool:
"""
Check if the connected Neo4j database version supports the required features.
Sets a flag if the connected Neo4j version is 5.23 or above.
"""
version_tuple = self._get_version()
return version_tuple >= (5, 23, 0)

def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None:
"""Upserts a single relationship into the Neo4j database.
Expand Down
61 changes: 14 additions & 47 deletions src/neo4j_graphrag/retrievers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@

from neo4j_graphrag.exceptions import Neo4jVersionError
from neo4j_graphrag.types import RawSearchResult, RetrieverResult, RetrieverResultItem
from neo4j_graphrag.utils.version_utils import (
get_version,
has_metadata_filtering_support,
has_vector_index_support,
is_version_5_23_or_above,
)

T = ParamSpec("T")
P = TypeVar("P")
Expand Down Expand Up @@ -85,53 +91,14 @@ def __init__(self, driver: neo4j.Driver, neo4j_database: Optional[str] = None):
self.driver = driver
self.neo4j_database = neo4j_database
if self.VERIFY_NEO4J_VERSION:
self._verify_version()

def _get_version(self) -> tuple[tuple[int, ...], bool]:
records, _, _ = self.driver.execute_query(
"CALL dbms.components()",
database_=self.neo4j_database,
routing_=neo4j.RoutingControl.READ,
)
version = records[0]["versions"][0]
# drop everything after the '-' first
version_main, *_ = version.split("-")
# convert each number between '.' into int
version_tuple = tuple(map(int, version_main.split(".")))
# if no patch version, consider it's 0
if len(version_tuple) < 3:
version_tuple = (*version_tuple, 0)
return version_tuple, "aura" in version

def _check_if_version_5_23_or_above(self, version_tuple: tuple[int, ...]) -> bool:
"""
Check if the connected Neo4j database version supports the required features.
Sets a flag if the connected Neo4j version is 5.23 or above.
"""
return version_tuple >= (5, 23, 0)

def _verify_version(self) -> None:
"""
Check if the connected Neo4j database version supports vector indexing.
Queries the Neo4j database to retrieve its version and compares it
against a target version (5.18.1) that is known to support vector
indexing. Raises a Neo4jMinVersionError if the connected Neo4j version is
not supported.
"""
version_tuple, is_aura = self._get_version()
self.neo4j_version_is_5_23_or_above = self._check_if_version_5_23_or_above(
version_tuple
)

if is_aura:
target_version = (5, 18, 0)
else:
target_version = (5, 18, 1)

if version_tuple < target_version:
raise Neo4jVersionError()
version_tuple, is_aura, _ = get_version(self.driver, self.neo4j_database)
self.neo4j_version_is_5_23_or_above = is_version_5_23_or_above(
version_tuple
)
if not has_vector_index_support(
version_tuple
) or not has_metadata_filtering_support(version_tuple, is_aura):
raise Neo4jVersionError()

def _fetch_index_infos(self, vector_index_name: str) -> None:
"""Fetch the node label and embedding property from the index definition
Expand Down
101 changes: 101 additions & 0 deletions src/neo4j_graphrag/utils/version_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import neo4j


def get_version(
    driver: neo4j.Driver, database: Optional[str] = None
) -> tuple[tuple[int, ...], bool, bool]:
    """
    Retrieves the Neo4j database version, and reports whether it is running on the
    Aura platform and whether it is the enterprise edition.

    The information comes from the first record returned by
    ``CALL dbms.components()``, executed as a read query.

    Args:
        driver (neo4j.Driver): Neo4j Python driver instance to execute the query.
        database (str, optional): The name of the Neo4j database to query. Defaults to None.

    Returns:
        tuple[tuple[int, ...], bool, bool]:
            - A tuple of integers representing the database version (major, minor, patch) or
              (year, month, patch) for later versions. Missing trailing components are
              filled with 0, so the tuple always has at least three elements.
            - A boolean indicating whether the database is hosted on the Aura platform
              (the version string contains "aura").
            - A boolean indicating whether the database is running the enterprise edition.
    """
    records, _, _ = driver.execute_query(
        "CALL dbms.components()",
        database_=database,
        routing_=neo4j.RoutingControl.READ,
    )
    component = records[0]
    version = component["versions"][0]
    edition = component["edition"]
    # drop everything after the first '-' (e.g. an "-aura" suffix)
    version_main = version.split("-", 1)[0]
    # convert each number between '.' into int
    version_tuple = tuple(map(int, version_main.split(".")))
    # pad every missing component with 0 so that the result can be compared
    # against three-part thresholds such as (5, 23, 0); a negative repeat
    # count yields an empty tuple, so longer versions are left untouched
    version_tuple += (0,) * (3 - len(version_tuple))
    return version_tuple, "aura" in version, edition == "enterprise"


def is_version_5_23_or_above(version_tuple: tuple[int, ...]) -> bool:
    """
    Tells whether a Neo4j version is at least 5.23.0.

    Args:
        version_tuple (tuple[int, ...]): The database version as a tuple of integers,
            either (major, minor, patch) or (year, month, patch) for later versions.

    Returns:
        bool: True if the version is 5.23.0 or above, False otherwise.
    """
    minimum_version = (5, 23, 0)
    return version_tuple >= minimum_version


def has_vector_index_support(version_tuple: tuple[int, ...]) -> bool:
    """
    Checks if a Neo4j database supports vector indexing based on its version.

    Args:
        version_tuple (tuple[int, ...]): The database version as a tuple of integers,
            either (major, minor, patch) or (year, month, patch) for later versions.

    Returns:
        bool: True if the version is 5.11.0 or above, False otherwise.
    """
    minimum_version = (5, 11, 0)
    return version_tuple >= minimum_version


def has_metadata_filtering_support(
    version_tuple: tuple[int, ...], is_aura: bool
) -> bool:
    """
    Checks if a Neo4j database supports vector index metadata filtering based on its
    version and platform.

    Args:
        version_tuple (tuple[int, ...]): The database version as a tuple of integers,
            either (major, minor, patch) or (year, month, patch) for later versions.
        is_aura (bool): Whether the database is hosted on the Aura platform.

    Returns:
        bool: True if the version is at least 5.18.0 on Aura, or at least 5.18.1
            otherwise; False if it is older than that.
    """
    # Aura is accepted one patch release earlier than other deployments.
    target_version = (5, 18, 0) if is_aura else (5, 18, 1)
    return version_tuple >= target_version
28 changes: 13 additions & 15 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,37 +51,35 @@ def retriever_mock() -> MagicMock:


@pytest.fixture(scope="function")
@patch("neo4j_graphrag.retrievers.VectorRetriever._verify_version")
def vector_retriever(
_verify_version_mock: MagicMock, driver: MagicMock
) -> VectorRetriever:
@patch("neo4j_graphrag.retrievers.base.get_version")
def vector_retriever(mock_get_version: MagicMock, driver: MagicMock) -> VectorRetriever:
mock_get_version.return_value = ((5, 23, 0), False, False)
return VectorRetriever(driver, "my-index")


@pytest.fixture(scope="function")
@patch("neo4j_graphrag.retrievers.VectorCypherRetriever._verify_version")
@patch("neo4j_graphrag.retrievers.base.get_version")
def vector_cypher_retriever(
_verify_version_mock: MagicMock, driver: MagicMock
mock_get_version: MagicMock, driver: MagicMock
) -> VectorCypherRetriever:
retrieval_query = """
RETURN node.id AS node_id, node.text AS text, score
"""
mock_get_version.return_value = ((5, 23, 0), False, False)
retrieval_query = "RETURN node.id AS node_id, node.text AS text, score"
return VectorCypherRetriever(driver, "my-index", retrieval_query)


@pytest.fixture(scope="function")
@patch("neo4j_graphrag.retrievers.HybridRetriever._verify_version")
def hybrid_retriever(
_verify_version_mock: MagicMock, driver: MagicMock
) -> HybridRetriever:
@patch("neo4j_graphrag.retrievers.base.get_version")
def hybrid_retriever(mock_get_version: MagicMock, driver: MagicMock) -> HybridRetriever:
mock_get_version.return_value = ((5, 23, 0), False, False)
return HybridRetriever(driver, "my-index", "my-fulltext-index")


@pytest.fixture(scope="function")
@patch("neo4j_graphrag.retrievers.Text2CypherRetriever._verify_version")
@patch("neo4j_graphrag.retrievers.base.get_version")
def t2c_retriever(
_verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock
mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock
) -> Text2CypherRetriever:
mock_get_version.return_value = ((5, 23, 0), False, False)
return Text2CypherRetriever(driver, llm)


Expand Down
Loading

0 comments on commit e0b5e86

Please sign in to comment.