Skip to content

Commit

Permalink
[NOID] Fixes #2182: added apoc.agg.rollup procedure (#4064)
Browse files Browse the repository at this point in the history
* Fixes #2182: added apoc.agg.rollup procedure

* updated extended.txt
  • Loading branch information
vga91 committed Jan 22, 2025
1 parent 9ad5cd5 commit c80a35b
Show file tree
Hide file tree
Showing 12 changed files with 1,249 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ apoc.agg.multiStats(value :: NODE | RELATIONSHIP, keys :: LIST OF STRING) :: (MA
|===


[[usage-apoc.data.email]]
[[usage-apoc.agg.multiStats]]
== Usage Examples

Given this dataset:
Expand Down

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions docs/asciidoc/modules/ROOT/pages/overview/apoc.agg/index.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,13 @@ Returns index of the `element` that match the given `predicate`
apoc.agg.multiStats(nodeOrRel, keys) - Return a multi-dimensional aggregation
|label:function[]
|label:apoc-full[]

|xref::overview/apoc.agg/apoc.agg.rollup.adoc[apoc.agg.rollup icon:book[]]

apoc.agg.rollup(<ANY>, [groupKeys], [aggKeys])

Emulate an Oracle/Mysql rollup command: `ROLLUP groupKeys, SUM(aggKey1), AVG(aggKey1), COUNT(aggKey1), SUM(aggKey2), AVG(aggKey2), ... `
|label:function[]
|label:apoc-full[]
|===

Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,13 @@ apoc.bitwise.op(60,'\|',13) bitwise operations a & b, a \| b, a ^ b, ~a, a >> b,

apoc.agg.multiStats(nodeOrRel, keys) - Return a multi-dimensional aggregation
|label:procedure[]

|xref::overview/apoc.agg/apoc.agg.rollup.adoc[apoc.agg.rollup icon:book[]]

apoc.agg.rollup(<ANY>, [groupKeys], [aggKeys])

Emulate an Oracle/Mysql rollup command: `ROLLUP groupKeys, SUM(aggKey1), AVG(aggKey1), COUNT(aggKey1), SUM(aggKey2), AVG(aggKey2), ... `
|label:procedure[]
|===


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ This file is generated by DocsTest, so don't change it!
*** xref::overview/apoc.atomic/apoc.atomic.update.adoc[]
** xref::overview/apoc.bitwise/index.adoc[]
*** xref::overview/apoc.bitwise/apoc.bitwise.op.adoc[]
*** xref::overview/apoc.agg/apoc.agg.rollup.adoc[]
** xref::overview/apoc.bolt/index.adoc[]
*** xref::overview/apoc.bolt/apoc.bolt.execute.adoc[]
*** xref::overview/apoc.bolt/apoc.bolt.load.adoc[]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
The function supports the following properties in its optional `config` map parameter:

.Config parameters
[opts=header, cols="1,1,1,3"]
|===
| name | type | default | description
| cube | boolean | false | If `true`, emulates the https://docs.oracle.com/cd/F49540_01/DOC/server.815/a68003/rollup_c.htm#32311[CUBE] clause
instead of the https://docs.oracle.com/cd/F49540_01/DOC/server.815/a68003/rollup_c.htm#32084[ROLLUP] one.
|===
53 changes: 53 additions & 0 deletions full/src/main/java/apoc/agg/AggregationUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package apoc.agg;

import java.util.Map;

/**
 * Static helpers that keep running COUNT / SUM / AVG entries up to date inside a
 * mutable result map. The map keys used for the three entries are supplied by the caller.
 */
public class AggregationUtil {

    private AggregationUtil() {
        // static utility class, not meant to be instantiated
    }

    /**
     * Folds one more value into the running aggregates stored in {@code partialResult}.
     *
     * <p>The entry under {@code countKey} is incremented for every call, whatever the type of
     * {@code property}. The entries under {@code sumKey} and {@code avgKey} are only touched
     * when {@code property} is a {@link Number}.
     *
     * <p>NOTE(review): after the first numeric value, the average is computed as
     * {@code sum / count}, where {@code count} also includes non-numeric values seen so far.
     * Confirm this is the intended semantics when numeric and non-numeric values are mixed.
     *
     * @param partialResult mutable map holding the running aggregates; updated in place
     * @param property      the value to aggregate
     * @param countKey      map key of the running count (stored as a {@code Long})
     * @param sumKey        map key of the running sum (a {@code Long} while every summed value is a
     *                      {@code Long} and the total fits in a long, a {@code Double} otherwise)
     * @param avgKey        map key of the running average (stored as a {@code Double})
     */
    public static void updateAggregationValues(Map<String, Number> partialResult, Object property, String countKey, String sumKey, String avgKey) {
        Number count = updateCountValue(partialResult, countKey);

        updateSumAndAvgValues(partialResult, property, count.doubleValue(), sumKey, avgKey);
    }

    /** Increments (or initialises to 1) the counter stored under {@code countKey} and returns the new value. */
    private static Number updateCountValue(Map<String, Number> partialResult, String countKey) {
        return partialResult.merge(countKey, 1L, (previous, one) -> previous.longValue() + 1L);
    }

    /** Updates the sum and average entries; does nothing when {@code property} is not a {@link Number}. */
    private static void updateSumAndAvgValues(Map<String, Number> partialResult, Object property, double count, String sumKey, String avgKey) {
        if (!(property instanceof Number)) {
            return;
        }

        Number numberProp = (Number) property;

        Number sum = partialResult.compute(sumKey, (key, previousSum) -> add(previousSum, numberProp));

        partialResult.compute(avgKey, (key, previousAvg) -> {
            if (previousAvg == null) {
                // first numeric value seen for this key
                return numberProp.doubleValue();
            }
            return sum.doubleValue() / count;
        });
    }

    /**
     * Adds {@code addend} to {@code current}, where {@code current} may be {@code null} (no sum yet).
     * Integer arithmetic is kept only while both operands are {@code Long}; everything else is
     * summed as doubles.
     */
    private static Number add(Number current, Number addend) {
        if (current == null) {
            if (addend instanceof Long) {
                return addend;
            }
            return addend.doubleValue();
        }
        if (current instanceof Long && addend instanceof Long) {
            long left = (Long) current;
            long right = (Long) addend;
            try {
                return Math.addExact(left, right);
            } catch (ArithmeticException overflow) {
                // The exact total does not fit in a long: rather than silently wrapping around
                // to a wrong (possibly negative) sum, continue with an approximate double sum.
                return (double) left + (double) right;
            }
        }
        return current.doubleValue() + addend.doubleValue();
    }
}
26 changes: 7 additions & 19 deletions full/src/main/java/apoc/agg/MultiStats.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.neo4j.procedure.UserAggregationResult;
import org.neo4j.procedure.UserAggregationUpdate;

import static apoc.agg.AggregationUtil.updateAggregationValues;
@Extended
public class MultiStats {

Expand All @@ -38,26 +39,13 @@ public void aggregate(@Name("value") Object value, @Name(value = "keys") List<St

map.compute(property.toString(), (propKey, propVal) -> {
Map<String, Number> propMap = Objects.requireNonNullElseGet(propVal, HashMap::new);
Number count = propMap.compute(
"count", ((subKey, subVal) -> subVal == null ? 1 : subVal.longValue() + 1));
if (property instanceof Number) {
Number numberProp = (Number) property;
Number sum = propMap.compute("sum", ((subKey, subVal) -> {
if (subVal == null) return numberProp;
if (subVal instanceof Long && numberProp instanceof Long) {
Long long2 = (Long) numberProp;
Long long1 = (Long) subVal;
return long1 + long2;
}
return subVal.doubleValue() + numberProp.doubleValue();
}));

propMap.compute(
"avg",
((subKey, subVal) -> subVal == null
? numberProp.doubleValue()
: sum.doubleValue() / count.doubleValue()));
}
String countKey = "count";
String sumKey = "sum";
String avgKey = "avg";

updateAggregationValues(propMap, property, countKey, sumKey, avgKey);

return propMap;
});
return map;
Expand Down
195 changes: 195 additions & 0 deletions full/src/main/java/apoc/agg/Rollup.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
package apoc.agg;

import apoc.Extended;
import apoc.util.Util;
import org.apache.commons.collections4.ListUtils;
import org.neo4j.graphdb.Entity;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.UserAggregationFunction;
import org.neo4j.procedure.UserAggregationResult;
import org.neo4j.procedure.UserAggregationUpdate;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static apoc.agg.AggregationUtil.updateAggregationValues;


/**
 * Registers the {@code apoc.agg.rollup} user aggregation function, which groups entities by a
 * list of property keys and computes COUNT / SUM / AVG of other properties for every
 * ROLLUP (or, with {@code cube: true}, CUBE) grouping level.
 */
@Extended
public class Rollup {
    /** Marker placed in a group-key position that has been rolled up, i.e. aggregated across all its values. */
    public static final String NULL_ROLLUP = "[NULL]";

    @UserAggregationFunction("apoc.agg.rollup")
    @Description("apoc.agg.rollup(<ANY>, [groupKeys], [aggKeys])" +
            "\n Emulate an Oracle/Mysql rollup command: `ROLLUP groupKeys, SUM(aggKey1), AVG(aggKey1), COUNT(aggKey1), SUM(aggKey2), AVG(aggKey2), ... `")
    public RollupFunction rollup() {
        return new RollupFunction();
    }

    public static class RollupFunction {
        /** Running aggregates, keyed by the list of group values (or {@link #NULL_ROLLUP} markers) of each output row. */
        private final Map<List<Object>, Map<String, Number>> rolledUpData = new HashMap<>();
        /** Group keys of the most recently aggregated row; used to label the columns in {@link #result()}. */
        private List<String> groupKeysRes = null;

        /**
         * Generates every variant of {@code elements} in which each position either keeps its
         * element or is replaced by the {@link #NULL_ROLLUP} placeholder (2^n lists for n elements).
         *
         * @param elements the source elements; not modified
         * @return all keep-or-placeholder combinations, each with the same length as {@code elements}
         */
        public static <T> List<List<T>> generateCombinationsWithPlaceholder(List<T> elements) {
            List<List<T>> result = new ArrayList<>();
            generateCombinationsWithPlaceholder(elements, 0, new ArrayList<>(), result);
            return result;
        }

        /** Recursive helper: decides position {@code index}, then recurses on the following positions. */
        private static <T> void generateCombinationsWithPlaceholder(List<T> elements, int index, List<T> current, List<List<T>> result) {
            if (index == elements.size()) {
                result.add(new ArrayList<>(current));
                return;
            }

            // branch 1: keep the original element at this position
            current.add(elements.get(index));
            generateCombinationsWithPlaceholder(elements, index + 1, current, result);
            current.remove(current.size() - 1);

            // branch 2: replace the element with the NULL_ROLLUP placeholder.
            // The cast is only sound when T can hold a String (String or Object element lists).
            @SuppressWarnings("unchecked")
            T placeholder = (T) NULL_ROLLUP;
            current.add(placeholder);
            generateCombinationsWithPlaceholder(elements, index + 1, current, result);
            current.remove(current.size() - 1);
        }

        /**
         * Folds one entity into the aggregates of every grouping level it belongs to.
         *
         * @param value     the node or relationship to aggregate; {@code null} rows are ignored
         * @param groupKeys property keys to group by; when empty nothing is aggregated
         * @param aggKeys   property keys for which COUNT / SUM / AVG entries are maintained
         * @param config    optional settings; {@code cube: true} emulates CUBE instead of ROLLUP
         */
        @UserAggregationUpdate
        public void aggregate(
                @Name("value") Object value,
                @Name(value = "groupKeys") List<String> groupKeys,
                @Name(value = "aggKeys") List<String> aggKeys,
                @Name(value = "config", defaultValue = "{}") Map<String, Object> config) {

            if (value == null) {
                // nothing to aggregate for this row; skipping avoids a NullPointerException below
                return;
            }

            boolean cube = config != null && Util.toBoolean(config.get("cube"));

            Entity entity = (Entity) value;

            if (groupKeys.isEmpty()) {
                return;
            }
            groupKeysRes = groupKeys;

            // Resolve the group values once. A missing property yields null in BOTH modes, so it can
            // never collide with the NULL_ROLLUP marker (which previously made CUBE aggregate the
            // same entity several times into the subtotal rows).
            List<Object> groupValues = new ArrayList<>(groupKeys.size());
            for (String groupKey : groupKeys) {
                groupValues.add(entity.getProperty(groupKey, null));
            }

            /*
            if true:
                emulate the CUBE command: https://docs.oracle.com/cd/F49540_01/DOC/server.815/a68003/rollup_c.htm#32311
            else:
                emulate the ROLLUP command: https://docs.oracle.com/cd/F49540_01/DOC/server.815/a68003/rollup_c.htm#32084
            */
            if (cube) {
                // every keep-or-roll-up combination of the group values is one grouping level
                for (List<Object> partialKey : generateCombinationsWithPlaceholder(groupValues)) {
                    rollupAggregationProperties(aggKeys, entity, partialKey);
                }
                return;
            }

            int size = groupValues.size();
            for (int i = 0; i <= size; i++) {
                // keep the first i values and add NULL_ROLLUP to remaining elements,
                // e.g. `[<firstGroupKey>, `NULL_ROLLUP`, `NULL_ROLLUP`]`
                List<Object> partialKey = ListUtils.union(groupValues.subList(0, i), Collections.nCopies(size - i, NULL_ROLLUP));
                rollupAggregationProperties(aggKeys, entity, partialKey);
            }
        }

        /**
         * Updates the COUNT / SUM / AVG entries of the row identified by {@code partialKey}
         * for every aggregation key the entity actually has. The row is created if absent,
         * even when the entity has none of the aggregation keys.
         */
        private void rollupAggregationProperties(List<String> aggKeys, Entity entity, List<Object> partialKey) {
            Map<String, Number> partialResult = rolledUpData.computeIfAbsent(partialKey, key -> new HashMap<>());
            for (String aggKey : aggKeys) {
                if (!entity.hasProperty(aggKey)) {
                    continue;
                }

                Object property = entity.getProperty(aggKey);

                String countKey = String.format("COUNT(%s)", aggKey);
                String sumKey = String.format("SUM(%s)", aggKey);
                String avgKey = String.format("AVG(%s)", aggKey);

                updateAggregationValues(partialResult, property, countKey, sumKey, avgKey);
            }
        }

        /**
         * Transform a Map.of(ListGroupKeys, MapOfAggResults) in a List of Map.of(AggResult + ListGroupKeyToMap),
         * sorted by the group keys in their declared order.
         *
         * @return one map per grouping row, holding the group values and the aggregate entries
         */
        @UserAggregationResult
        public Object result() {
            List<Map<String, Object>> rows = new ArrayList<>(rolledUpData.size());
            for (Map.Entry<List<Object>, Map<String, Number>> entry : rolledUpData.entrySet()) {
                Map<String, Object> row = new HashMap<>();
                List<Object> groupValues = entry.getKey();
                for (int i = 0; i < groupKeysRes.size(); i++) {
                    row.put(groupKeysRes.get(i), groupValues.get(i));
                }
                row.putAll(entry.getValue());
                rows.add(row);
            }
            rows.sort(this::compareRows);
            return rows;
        }

        /** Compares two result rows group key by group key; the first differing key decides. */
        private int compareRows(Map<String, Object> row1, Map<String, Object> row2) {
            for (String key : groupKeysRes) {
                int cmp = compareValues(row1.get(key), row2.get(key));
                if (cmp != 0) {
                    return cmp;
                }
            }
            return 0;
        }

        /**
         * We use this instead of e.g. apoc.coll.sortMulti
         * since we have to handle the NULL_ROLLUP values as well.
         *
         * <p>Ordering: regular values first, then {@link #NULL_ROLLUP}, then {@code null}.
         * Regular values of the same class use their natural order, numbers of different classes
         * are compared as doubles, and any other pair is ordered by type name. Unlike returning 0
         * for incomparable pairs, this keeps the order transitive for mixed-type columns
         * (large longs compared against doubles may still tie because of rounding).
         */
        private static int compareValues(Object value1, Object value2) {
            if (value1 == null || value2 == null) {
                // nulls sort last
                return Boolean.compare(value1 == null, value2 == null);
            }
            boolean rolledUp1 = NULL_ROLLUP.equals(value1);
            boolean rolledUp2 = NULL_ROLLUP.equals(value2);
            if (rolledUp1 || rolledUp2) {
                // rolled-up markers sort after every regular value
                return Boolean.compare(rolledUp1, rolledUp2);
            }
            if (value1.getClass() == value2.getClass() && value1 instanceof Comparable) {
                try {
                    @SuppressWarnings("unchecked")
                    Comparable<Object> comparable = (Comparable<Object>) value1;
                    return comparable.compareTo(value2);
                } catch (ClassCastException ignored) {
                    // best effort: a same-class value that still refuses the comparison is treated as equal
                    return 0;
                }
            }
            if (value1 instanceof Number && value2 instanceof Number) {
                // e.g. Long vs Double
                return Double.compare(((Number) value1).doubleValue(), ((Number) value2).doubleValue());
            }
            return typeGroup(value1).compareTo(typeGroup(value2));
        }

        /** Name used to order values of unrelated types; all numbers share one group. */
        private static String typeGroup(Object value) {
            return value instanceof Number ? Number.class.getName() : value.getClass().getName();
        }
    }
}
1 change: 1 addition & 0 deletions full/src/main/resources/extended.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
apoc.agg.position
apoc.agg.row
apoc.agg.multiStats
apoc.agg.rollup
apoc.algo.aStarWithPoint
apoc.algo.travellingSalesman
apoc.bolt.execute
Expand Down
Loading

0 comments on commit c80a35b

Please sign in to comment.