Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added the possibility to return iterators from the match_* functions #26

Merged
merged 2 commits into from
Aug 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 1 addition & 77 deletions data2neo/neo4j/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List, Union

from .graph_elements import Node, Relationship, Subgraph, Attribute
from .cypher import cypher_join, _match_clause, encode_value, encode_key
from .matching import match_nodes, match_relationships

def create(graph: Subgraph, session: Session):
"""
Expand Down Expand Up @@ -47,79 +47,3 @@ def pull(graph: Subgraph, session: Session):
"""
session.execute_read(graph.__db_pull__)


def match_nodes(session: Session, *labels: List[str], **properties: dict):
"""
Matches nodes in the database.

Args:
labels (List[str]): The labels to match.
session (Session): The `session <https://neo4j.com/docs/api/python-driver/current/api.html#session>`_ to use.
properties (dict): The properties to match.
"""
flat_params = [tuple(labels),]
data = []
for k, v in properties.items():
data.append(v)
flat_params.append(k)

if len(data) > 1:
data = [data]

unwind = "UNWIND $data as r" if len(data) > 0 else ""
clause = cypher_join(unwind, _match_clause('n', tuple(flat_params), "r"), "RETURN n, LABELS(n), ID(n)", data=data)

records = session.run(*clause).data()
# Convert to Node
out = []
for record in records:
node = Node.from_dict(record['LABELS(n)'], record['n'], identity=record['ID(n)'])
out.append(node)
return out


def match_relationships(session: Session, from_node: Node =None, to_node:Node =None, rel_type: str =None, **properties: dict):
"""
Matches relationships in the database.

Args:
session (Session): The `session <https://neo4j.com/docs/api/python-driver/current/api.html#session>`_ to use.
from_node (Node): The node to match the relationship from (Default: None)
to_node (Node): The node to match the relationship to (Default: None)
rel_type (str): The type of the relationship to match (Default: None)
properties (dict): The properties to match.
"""
if from_node is not None:
assert from_node.identity is not None, "from_node must have an identity"

if to_node is not None:
assert to_node.identity is not None, "to_node must have an identity"

params = ""
for k, v in properties.items():
if params != "":
params += ", "
params += f"{encode_key(k)}: {encode_value(v)}"

clauses = []
if from_node is not None:
clauses.append(f"ID(from_node) = {from_node.identity}")
if to_node is not None:
clauses.append(f"ID(to_node) = {to_node.identity}")
if rel_type is not None:
clauses.append(f"type(r) = {encode_value(rel_type)}")

clause = cypher_join(
f"MATCH (from_node)-[r {{{params}}}]->(to_node)",
"WHERE" if len(clauses) > 0 else "",
" AND ".join(clauses),
"RETURN PROPERTIES(r), TYPE(r), ID(r), from_node, LABELS(from_node), ID(from_node), to_node, LABELS(to_node), ID(to_node)"
)
records = session.run(*clause).data()
out = []
for record in records:
fn = Node.from_dict(record['LABELS(from_node)'], record['from_node'], identity=record['ID(from_node)']) if from_node is None else from_node
tn = Node.from_dict(record['LABELS(to_node)'], record['to_node'], identity=record['ID(to_node)']) if to_node is None else to_node
rel = Relationship.from_dict(fn, tn, record['TYPE(r)'], record['PROPERTIES(r)'], identity=record['ID(r)'])
out.append(rel)
return out
120 changes: 120 additions & 0 deletions data2neo/neo4j/matching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from neo4j import Session
from typing import List, Union

from .graph_elements import Node, Relationship, Subgraph, Attribute
from .cypher import cypher_join, _match_clause, encode_value, encode_key
from abc import ABC, abstractmethod

class ResultIterator(ABC):
def __init__(self, count, match):
self._count = count
self._match = match

def __len__(self):
return self._count

@abstractmethod
def __iter__(self):
pass

class NodeIterator(ResultIterator):
def __iter__(self):
for record in self._match:
node = Node.from_dict(record['LABELS(n)'], record['n'], identity=record['ID(n)'])
yield node

class RelationshipIterator(ResultIterator):
def __iter__(self):
for record in self._match:
fn = Node.from_dict(record['LABELS(from_node)'], record['from_node'], identity=record['ID(from_node)'])
tn = Node.from_dict(record['LABELS(to_node)'], record['to_node'], identity=record['ID(to_node)'])
rel = Relationship.from_dict(fn, tn, record['TYPE(r)'], record['PROPERTIES(r)'], identity=record['ID(r)'])
yield rel

def match_nodes(session: Session, *labels: List[str], return_iterator=False, **properties: dict):
"""
Matches nodes in the database.

Args:
session (Session): The `session <https://neo4j.com/docs/api/python-driver/current/api.html#session>`_ to use.
labels (List[str]): The labels to match.
return_iterator (bool): Whether to return an iterator or a list (Default: False)
properties (dict): The properties to match.
"""
flat_params = [tuple(labels),]
data = []
for k, v in properties.items():
data.append(v)
flat_params.append(k)

if len(data) > 1:
data = [data]

unwind = "UNWIND $data as r" if len(data) > 0 else ""


clause = cypher_join(unwind, _match_clause('n', tuple(flat_params), "r"), "RETURN n, LABELS(n), ID(n)", data=data)
count_clause = cypher_join(unwind, _match_clause('n', tuple(flat_params), "r"), "RETURN count(n)", data=data)

count = session.run(*count_clause).single().value()

match = session.run(*clause)
iterator = NodeIterator(count, match)

if return_iterator:
return iterator
else:
return list(iterator)

def match_relationships(session: Session, from_node: Node =None, to_node:Node =None, rel_type: str =None, return_iterator=False, **properties: dict):
"""
Matches relationships in the database.

Args:
session (Session): The `session <https://neo4j.com/docs/api/python-driver/current/api.html#session>`_ to use.
from_node (Node): The node to match the relationship from (Default: None)
to_node (Node): The node to match the relationship to (Default: None)
rel_type (str): The type of the relationship to match (Default: None)
return_iterator (bool): Whether to return an iterator or a list (Default: False)
properties (dict): The properties to match.
"""
if from_node is not None:
assert from_node.identity is not None, "from_node must have an identity"

if to_node is not None:
assert to_node.identity is not None, "to_node must have an identity"

params = ""
for k, v in properties.items():
if params != "":
params += ", "
params += f"{encode_key(k)}: {encode_value(v)}"

clauses = []
if from_node is not None:
clauses.append(f"ID(from_node) = {from_node.identity}")
if to_node is not None:
clauses.append(f"ID(to_node) = {to_node.identity}")
if rel_type is not None:
clauses.append(f"type(r) = {encode_value(rel_type)}")

clause = cypher_join(
f"MATCH (from_node)-[r {{{params}}}]->(to_node)",
"WHERE" if len(clauses) > 0 else "",
" AND ".join(clauses),
"RETURN PROPERTIES(r), TYPE(r), ID(r), from_node, LABELS(from_node), ID(from_node), to_node, LABELS(to_node), ID(to_node)"
)
count_clause = cypher_join(
f"MATCH (from_node)-[r {{{params}}}]->(to_node)",
"WHERE" if len(clauses) > 0 else "",
" AND ".join(clauses),
"RETURN count(r)"
)
count = session.run(*count_clause).single().value()

match = session.run(*clause)

if return_iterator:
return RelationshipIterator(count, match)
else:
return list(RelationshipIterator(count, match))
76 changes: 76 additions & 0 deletions tests/unit/neo4j/test_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,43 @@ def test_match_nodes(session):
nodes = match_nodes(session, name="test1", anotherattr="test")
assert(len(nodes) == 1)
assert(check_node(nodes, 1))

def test_match_nodes_with_iterator(session):
# match by single label
nodes = match_nodes(session, "test", return_iterator=True)
assert(len(nodes) == 2)
nodes = list(nodes)
assert(len(nodes) == 2)
assert(check_node(nodes, 1))
assert(check_node(nodes, 2))

# match by multiple labels
nodes = match_nodes(session, "test", "second", return_iterator=True)
assert(len(nodes) == 1)
nodes = list(nodes)
assert(len(nodes) == 1)
assert(check_node(nodes, 1))

# match by properties with no label
nodes = match_nodes(session, name="test3", return_iterator=True)
assert(len(nodes) == 1)
nodes = list(nodes)
assert(len(nodes) == 1)
assert(check_node(nodes, 3))

# match by properties with label
nodes = match_nodes(session, "test", name="test1", return_iterator=True)
assert(len(nodes) == 1)
nodes = list(nodes)
assert(len(nodes) == 1)
assert(check_node(nodes, 1))

# match by two properties
nodes = match_nodes(session, name="test1", anotherattr="test", return_iterator=True)
assert(len(nodes) == 1)
nodes = list(nodes)
assert(len(nodes) == 1)
assert(check_node(nodes, 1))

def test_match_relationships(session):
# match by type
Expand Down Expand Up @@ -109,3 +146,42 @@ def test_match_relationships(session):
assert(len(rels) == 1)
assert(check_rel(rels, 1))

def test_match_relationships_with_iterator(session):
# match by type
rels = match_relationships(session, rel_type="to", return_iterator=True)
assert(len(rels) == 2)
rels = list(rels)
assert(len(rels) == 2)
assert(check_rel(rels, 1))
assert(check_rel(rels, 2))

# match by properties
rels = match_relationships(session, rel_type="to", id=1, return_iterator=True)
assert(len(rels) == 1)
rels = list(rels)
assert(len(rels) == 1)
assert(check_rel(rels, 1))

# match by multiple properties
rels = match_relationships(session, rel_type="to", id=2, anotherattr="test", return_iterator=True)
assert(len(rels) == 1)
rels = list(rels)
assert(len(rels) == 1)
assert(check_rel(rels,2))

# match by from node
n1 = match_nodes(session, "test", id=1)[0]
rels = match_relationships(session, from_node=n1, return_iterator=True)
assert(len(rels) == 2)
rels = list(rels)
assert(len(rels) == 2)
assert(check_rel(rels, 1))
assert(check_rel(rels, 2))

# match by to node
n2 = match_nodes(session, "test", id=2)[0]
rels = match_relationships(session, to_node=n2, return_iterator=True)
assert(len(rels) == 1)
rels = list(rels)
assert(len(rels) == 1)
assert(check_rel(rels, 1))
Loading