diff --git a/src/main/java/apoc/nodes/Grouping.java b/src/main/java/apoc/nodes/Grouping.java index 95c5c7439c..122ba3aa4c 100644 --- a/src/main/java/apoc/nodes/Grouping.java +++ b/src/main/java/apoc/nodes/Grouping.java @@ -30,51 +30,28 @@ public class Grouping { private static final int BATCHSIZE = 10000; + + private static final String ASTERISK = "*"; + @Context public GraphDatabaseAPI db; @Context public Log log; - static class Key { - private final int hash; - private final String label; - private final Map values; - - public Key(String label, Map values) { - this.label = label; - this.values = values; - hash = 31 * label.hashCode() + values.hashCode(); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - Key key = (Key) o; - return label.equals(key.label) && values.equals(key.values); - } - - @Override - public int hashCode() { - return hash; - } - } - @Procedure @Description("Group all nodes and their relationships by given keys, create virtual nodes and relationships for the summary information, you can provide an aggregations map [{kids:'sum',age:['min','max','avg'],gender:'collect'},{`*`,'count'}]") public Stream group(@Name("labels") List labels, @Name("groupByProperties") List groupByProperties, - @Name(value = "aggregations",defaultValue = "[{\"*\":\"count\"},{\"*\":\"count\"}]") List> aggregations) { + @Name(value = "aggregations", defaultValue = "[{\"*\":\"count\"},{\"*\":\"count\"}]") List> aggregations) { String[] keys = groupByProperties.toArray(new String[groupByProperties.size()]); - Map> nodeAggNames = (aggregations.size()>0) ? toStringListMap(aggregations.get(0)) : emptyMap(); - String[] nodeAggKeys = keyArray(nodeAggNames, "*"); + Map> nodeAggNames = (aggregations.size() > 0) ? toStringListMap(aggregations.get(0)) : emptyMap(); + String[] nodeAggKeys = keyArray(nodeAggNames, ASTERISK); - Map> relAggNames = (aggregations.size()>1) ? toStringListMap(aggregations.get(1)) : emptyMap(); - String[] relAggKeys = keyArray(relAggNames, "*");; + Map> relAggNames = (aggregations.size() > 1) ? toStringListMap(aggregations.get(1)) : emptyMap(); + String[] relAggKeys = keyArray(relAggNames, ASTERISK); - Map> grouped = new ConcurrentHashMap<>(); - Map virtual = new ConcurrentHashMap<>(); - Map virtualRels = new ConcurrentHashMap<>(); + Map> grouped = new ConcurrentHashMap<>(); + Map virtualNodes = new ConcurrentHashMap<>(); + Map virtualRels = new ConcurrentHashMap<>(); List futures = new ArrayList<>(1000); @@ -83,27 +60,32 @@ public Stream group(@Name("labels") List labels, @Name("gro Label label = Label.label(labelName); Label[] singleLabel = {label}; - try (ResourceIterator nodes = db.findNodes(label)) { + try (ResourceIterator nodes = (labelName.equals(ASTERISK)) ? db.getAllNodes().iterator() : db.findNodes(label)) { while (nodes.hasNext()) { List batch = Util.take(nodes, BATCHSIZE); futures.add(Util.inTxFuture(pool, db, () -> { try { for (Node node : batch) { - Key key = keyFor(node, labelName, keys); - grouped.compute(key, (k, v) -> {if (v == null) v = new HashSet<>(); v.add(node); return v;}); - virtual.compute(key, (k, v) -> { - if (v == null) { - v = new VirtualNode(singleLabel,node.getProperties(keys),db); - } - Node vn = v; - if (!nodeAggNames.isEmpty()) { - aggregate(vn, nodeAggNames, nodeAggKeys.length > 0 ? node.getProperties(nodeAggKeys) : Collections.emptyMap()); - } - return vn;} + NodeKey key = keyFor(node, labelName, keys); + grouped.compute(key, (k, v) -> { + if (v == null) v = new HashSet<>(); + v.add(node); + return v; + }); + virtualNodes.compute(key, (k, v) -> { + if (v == null) { + v = new VirtualNode(singleLabel, propertiesFor(node, keys), db); + } + Node vn = v; + if (!nodeAggNames.isEmpty()) { + aggregate(vn, nodeAggNames, nodeAggKeys.length > 0 ? node.getProperties(nodeAggKeys) : Collections.emptyMap()); + } + return vn; + } ); } - } catch(Exception e) { - log.debug("Error grouping nodes",e); + } catch (Exception e) { + log.debug("Error grouping nodes", e); } return null; })); @@ -113,27 +95,27 @@ public Stream group(@Name("labels") List labels, @Name("gro } Util.waitForFutures(futures); futures.clear(); - Iterator>> entries = grouped.entrySet().iterator(); + Iterator>> entries = grouped.entrySet().iterator(); int size = 0; - List>> batch = new ArrayList<>(); + List>> batch = new ArrayList<>(); while (entries.hasNext()) { - Map.Entry> outerEntry = entries.next(); + Map.Entry> outerEntry = entries.next(); batch.add(outerEntry); size += outerEntry.getValue().size(); if (size > BATCHSIZE || !entries.hasNext()) { - ArrayList>> submitted = new ArrayList<>(batch); + ArrayList>> submitted = new ArrayList<>(batch); batch.clear(); size = 0; futures.add(Util.inTxFuture(pool, db, () -> { try { - for (Map.Entry> entry : submitted) { + for (Map.Entry> entry : submitted) { for (Node node : entry.getValue()) { - Key startKey = entry.getKey(); - Node v1 = virtual.get(startKey); + NodeKey startKey = entry.getKey(); + Node v1 = virtualNodes.get(startKey); for (Relationship rel : node.getRelationships(Direction.OUTGOING)) { Node endNode = rel.getEndNode(); - for (Key endKey : keysFor(endNode, labels, keys)) { - Node v2 = virtual.get(endKey); + for (NodeKey endKey : keysFor(endNode, labels, keys)) { + Node v2 = virtualNodes.get(endKey); if (v2 == null) continue; virtualRels.compute(new RelKey(startKey, endKey, rel), (rk, vRel) -> { if (vRel == null) vRel = v1.createRelationshipTo(v2, rel.getType()); @@ -146,8 +128,8 @@ public Stream group(@Name("labels") List labels, @Name("gro } } } - } catch(Exception e) { - log.debug("Error grouping relationships",e); + } catch (Exception e) { + log.debug("Error grouping relationships", e); } return null; })); @@ -155,22 +137,22 @@ public Stream group(@Name("labels") List labels, @Name("gro } } Util.waitForFutures(futures); - return fixAggregates(virtual.values()).stream().map( n -> new GraphResult(singletonList(n), fixAggregates(Iterables.asList(n.getRelationships())))); + return fixAggregates(virtualNodes.values()).stream().map(n -> new GraphResult(singletonList(n), fixAggregates(Iterables.asList(n.getRelationships())))); } - public Map> toStringListMap(Map input) { + private Map> toStringListMap(Map input) { Map> nodeAggNames = new LinkedHashMap<>(input.size()); input.forEach((k, v) -> nodeAggNames.put(k, v instanceof List ? ((List) v).stream().map(Object::toString).collect(Collectors.toList()) : singletonList(v.toString()))); return nodeAggNames; } - public String[] keyArray(Map map, String... removeKeys) { + private String[] keyArray(Map map, String... removeKeys) { List keys = new ArrayList<>(map.keySet()); for (String key : removeKeys) keys.remove(key); return keys.toArray(new String[keys.size()]); } - private ,T extends PropertyContainer> C fixAggregates(C pcs) { + private , T extends PropertyContainer> C fixAggregates(C pcs) { for (PropertyContainer pc : pcs) { pc.getAllProperties().entrySet().forEach((entry) -> { Object v = entry.getValue(); @@ -181,18 +163,18 @@ private ,T extends PropertyContainer> C fixAggregates(C } if (k.matches("^avg_.+") && v instanceof double[]) { double[] values = (double[]) v; - entry.setValue(values[1] == 0 ? 0 : values[0]/ values[1]); + entry.setValue(values[1] == 0 ? 0 : values[0] / values[1]); } if (k.matches("^collect_.+") && v instanceof Collection) { - entry.setValue(((Collection)v).toArray()); + entry.setValue(((Collection) v).toArray()); } }); } return pcs; } - public void aggregate(PropertyContainer pc, Map> aggregations, Map properties) { - aggregations.forEach((k2,aggNames) -> { + private void aggregate(PropertyContainer pc, Map> aggregations, Map properties) { + aggregations.forEach((k2, aggNames) -> { for (String aggName : aggNames) { String key = aggName + "_" + k2; if ("count_*".equals(key)) { @@ -232,29 +214,98 @@ public void aggregate(PropertyContainer pc, Map> aggregatio }); } - public Key keyFor(Node node, String label, String[] keys) { - Map props = node.getProperties(keys); - return new Key(label, props); + /** + * Returns the properties for the given node according to the specified keys. If a node does not have a property + * assigned to given key, the value is set to {@code null}. + * + * @param node node + * @param keys property keys + * @return node properties for keys + */ + private Map propertiesFor(Node node, String[] keys) { + Map props = new HashMap<>(keys.length); + + for (String key : keys) { + props.put(key, node.getProperty(key, null)); + } + + return props; } - public Collection keysFor(Node node, List labels, String[] keys) { - Map props = node.getProperties(keys); - List result=new ArrayList<>(labels.size()); - for (Label label : node.getLabels()) { - if (labels.contains(label.name())) { - result.add(new Key(label.name(), props)); + /** + * Creates a grouping key for the given node using its label and grouping properties. + * + * @param node node + * @param label node label + * @param keys property keys + * @return grouping key + */ + private NodeKey keyFor(Node node, String label, String[] keys) { + return new NodeKey(label, propertiesFor(node, keys)); + } + + /** + * Creates a grouping key for each specified label. + * + * @param node node + * @param labels node labels + * @param keys property keys + * @return grouping keys + */ + private Collection keysFor(Node node, List labels, String[] keys) { + Map props = propertiesFor(node, keys); + List result = new ArrayList<>(labels.size()); + if (labels.contains(ASTERISK)) { + result.add(new NodeKey(ASTERISK, props)); + } else { + for (Label label : node.getLabels()) { + if (labels.contains(label.name())) { + result.add(new NodeKey(label.name(), props)); + } } } return result; } + /** + * Represents a grouping key for nodes. + */ + static class NodeKey { + private final int hash; + private final String label; + private final Map values; + + NodeKey(String label, Map values) { + this.label = label; + this.values = values; + hash = 31 * label.hashCode() + values.hashCode(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + NodeKey key = (NodeKey) o; + return label.equals(key.label) && values.equals(key.values); + } + + @Override + public int hashCode() { + return hash; + } + } + + /** + * Represents a grouping key for relationships. + */ private static class RelKey { private final int hash; - private final Key startKey; - private final Key endKey; + private final NodeKey startKey; + private final NodeKey endKey; private final String type; - public RelKey(Key startKey, Key endKey, Relationship rel) { + RelKey(NodeKey startKey, NodeKey endKey, Relationship rel) { this.startKey = startKey; this.endKey = endKey; this.type = rel.getType().name(); diff --git a/src/test/java/apoc/nodes/GroupingTest.java b/src/test/java/apoc/nodes/GroupingTest.java index 12c6cd5bc7..f38fdf6a06 100644 --- a/src/test/java/apoc/nodes/GroupingTest.java +++ b/src/test/java/apoc/nodes/GroupingTest.java @@ -1,8 +1,9 @@ package apoc.nodes; -import apoc.coll.Coll; import apoc.util.TestUtil; -import org.junit.*; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; import org.neo4j.graphdb.GraphDatabaseService; import org.neo4j.graphdb.Node; import org.neo4j.graphdb.Relationship; @@ -11,11 +12,10 @@ import java.util.List; import java.util.Map; -import static apoc.util.TestUtil.testCall; import static apoc.util.TestUtil.testResult; import static apoc.util.Util.map; -import static java.util.Arrays.asList; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; /** * @author mh @@ -28,9 +28,19 @@ public class GroupingTest { public static void setUp() throws Exception { db = new TestGraphDatabaseFactory().newImpermanentDatabase(); TestUtil.registerProcedure(db, Grouping.class); - db.execute("CREATE (a:Person {name:'Alice',female:true,age:32,kids:1})-[:KNOWS]->(b:Person {name:'Bob', female:false, age:42,kids:3})<-[:KNOWS]-(c:Person {name:'Cath',female:true,age:28,kids:2})") - .close(); + db.execute("CREATE " + + "(alice:Person {name:'Alice', gender:'female', age:32, kids:1})," + + "(bob:Person {name:'Bob', gender:'male', age:42, kids:3})," + + "(eve:Person {name:'Eve', gender:'female', age:28, kids:2})," + + "(graphs:Forum {name:'Graphs', members:23})," + + "(dbs:Forum {name:'Databases', members:42})," + + "(alice)-[:KNOWS {since:2017}]->(bob)," + + "(eve)-[:KNOWS {since:2018}]->(bob)," + + "(alice)-[:MEMBER_OF]->(graphs)," + + "(alice)-[:MEMBER_OF]->(dbs)," + + "(bob)-[:MEMBER_OF]->(dbs)," + + "(eve)-[:MEMBER_OF]->(graphs)"); } @AfterClass @@ -38,33 +48,95 @@ public static void tearDown() { db.shutdown(); } + @Test + public void testGroupAllNodes() throws Exception { + Map female = map("gender", "female", "count_*", 2L, "min_age", 28L); + Map male = map("gender", "male", "count_*", 1L, "min_age", 42L); + Map other = map("gender", null, "count_*", 2L); + + testResult(db, "CALL apoc.nodes.group(" + + "['*'],['gender'],[" + + "{`*`:'count', age:'min'}," + + "{`*`:'count'}" + + "])", + (result) -> { + assertTrue(result.hasNext()); + + String[] keys = {"count_*", "gender", "min_age"}; + while (result.hasNext()) { + Map row = result.next(); + List nodes = (List) row.get("nodes"); + + assertEquals(1, nodes.size()); + Node node = nodes.get(0); + Object value = node.getProperty("gender"); + + List rels = (List) row.get("relationships"); + if (value == null) { + assertEquals(other, node.getProperties(keys)); + assertEquals(0L, rels.size()); + } else if (value.equals("female")) { + assertEquals(female, node.getProperties(keys)); + assertEquals(2L, rels.size()); + Relationship rel = rels.get(0); + Object count = rel.getProperty("count_*"); + if (count.equals(3L)) { // MEMBER_OF + assertEquals(other, rel.getEndNode().getProperties(keys)); + } else if (count.equals(2L)) { // KNOWS + assertEquals(male, rel.getEndNode().getProperties(keys)); + } else { + assertTrue("Unexpected count value: " + count, false); + } + } else if (value.equals("male")) { + assertEquals(male, node.getProperties(keys)); + assertEquals(1L, rels.size()); + Relationship rel = rels.get(0); // MEMBER_OF + assertEquals(1L, rel.getProperty("count_*")); + assertEquals(other, rel.getEndNode().getProperties(keys)); + } else { + assertTrue("Unexpected value: " + value, false); + } + } + }); + } + @Test public void testGroupNode() throws Exception { - Map female = map("female", true, "count_*", 2L, "sum_kids", 3L, "min_age", 28L, "max_age", 32L, "avg_age", 30D); - Map male = map("female", false, "count_*", 1L, "sum_kids", 3L, "min_age", 42L, "max_age", 42L, "avg_age", 42D); - testResult(db, "CALL apoc.nodes.group(['Person'],['female'],[{`*`:'count',kids:'sum',age:['min','max','avg'],gender:'collect'},{`*`:'count'}])", + Map female = map("gender", "female", "count_*", 2L, "sum_kids", 3L, "min_age", 28L, "max_age", 32L, "avg_age", 30D); + Map male = map("gender", "male", "count_*", 1L, "sum_kids", 3L, "min_age", 42L, "max_age", 42L, "avg_age", 42D); + testResult(db, "CALL apoc.nodes.group(" + + "['Person'],['gender'],[" + + "{`*`:'count', kids:'sum', age:['min', 'max', 'avg'], gender:'collect'}," + + "{`*`:'count', since:['min', 'max']}" + + "])", (result) -> { assertTrue(result.hasNext()); Map row = result.next(); List nodes = (List) row.get("nodes"); assertEquals(1,nodes.size()); Node node = nodes.get(0); - String[] keys = {"count_*", "female", "sum_kids", "min_age", "max_age", "avg_age"}; - assertEquals(node.getProperty("female").equals(true) ? female : male, node.getProperties(keys)); + String[] keys = {"count_*", "gender", "sum_kids", "min_age", "max_age", "avg_age"}; + assertEquals(node.getProperty("gender").equals("female") ? + female : male, node.getProperties(keys)); List rels = (List) row.get("relationships"); assertEquals(1,rels.size()); Relationship rel = rels.get(0); assertEquals(2L,rel.getProperty("count_*")); + assertEquals(2017L, rel.getProperty("min_since")); + assertEquals(2018L, rel.getProperty("max_since")); assertEquals("KNOWS",rel.getType().name()); node = rel.getOtherNode(node); - assertEquals(node.getProperty("female").equals(true) ? female : male, node.getProperties(keys)); + assertEquals(node.getProperty("gender").equals("female") ? + female : male, node.getProperties(keys)); assertTrue(result.hasNext()); row = result.next(); + System.out.println(row); nodes = (List) row.get("nodes"); assertEquals(1,nodes.size()); node = nodes.get(0); - assertEquals(node.getProperty("female").equals(true) ? female : male, node.getProperties(keys)); + assertEquals(node.getProperty("gender").equals("female") ? + female : male, node.getProperties(keys)); rels = (List) row.get("relationships"); assertEquals(0,rels.size());