Skip to content

Commit

Permalink
[dITjfMpu] apoc.coll.indexOf unexpectedly treats collections differen…
Browse files Browse the repository at this point in the history
…tly than the same hardcoded list (neo4j/apoc#422) (#3600)

* [dITjfMpu] apoc.coll.indexOf unexpectedly treats collections differently than the same hardcoded list

* [dITjfMpu] various changes

* [dITjfMpu] small review changes
  • Loading branch information
vga91 authored May 31, 2023
1 parent 4a92bc1 commit 1feeee3
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 12 deletions.
32 changes: 20 additions & 12 deletions core/src/main/java/apoc/coll/Coll.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package apoc.coll;

import apoc.result.ListResult;
import apoc.util.Util;
import com.google.common.util.concurrent.AtomicDouble;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
Expand All @@ -35,6 +36,7 @@
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;
import org.neo4j.procedure.UserFunction;
import org.neo4j.values.AnyValue;

import java.lang.reflect.Array;
import java.text.Collator;
Expand All @@ -59,6 +61,8 @@
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static apoc.util.Util.containsValueEquals;
import static apoc.util.Util.toAnyValues;
import static java.util.Arrays.asList;

public class Coll {
Expand Down Expand Up @@ -391,12 +395,12 @@ public Stream<ListResult> split(@Name("values") List<Object> list, @Name("value"
if (list==null || list.isEmpty()) return Stream.empty();
List<Object> l = new ArrayList<>(list);
List<List<Object>> result = new ArrayList<>(10);
int idx = l.indexOf(value);
int idx = Util.indexOf(l, value);
while (idx != -1) {
List<Object> subList = l.subList(0, idx);
if (!subList.isEmpty()) result.add(subList);
l = l.subList(idx+1,l.size());
idx = l.indexOf(value);
idx = Util.indexOf(l, value);
}
if (!l.isEmpty()) result.add(l);
return result.stream().map(ListResult::new);
Expand Down Expand Up @@ -473,14 +477,17 @@ public List<Object> remove(@Name("coll") List<Object> coll, @Name("index") long
public long indexOf(@Name("coll") List<Object> coll, @Name("value") Object value) {
// return reduce(res=[0,-1], x in $list | CASE WHEN x=$value AND res[1]=-1 THEN [res[0], res[0]+1] ELSE [res[0]+1, res[1]] END)[1] as value
if (coll == null || coll.isEmpty()) return -1;
return new ArrayList<>(coll).indexOf(value);
return Util.indexOf(coll, value);
}

@UserFunction
@Description("apoc.coll.containsAll(coll, values) optimized contains-all operation (using a HashSet) (returns single row or not)")
public boolean containsAll(@Name("coll") List<Object> coll, @Name("values") List<Object> values) {
if (coll == null || coll.isEmpty() || values == null) return false;
return new HashSet<>(coll).containsAll(values);
Set<Object> objects = new HashSet<>(coll);

return values.stream()
.allMatch( i -> containsValueEquals(objects, i));
}

@UserFunction
Expand Down Expand Up @@ -524,7 +531,8 @@ public boolean isEqualCollection(@Name("coll") List<Object> first, @Name("values
@Description("apoc.coll.toSet([list]) returns a unique list backed by a set")
public List<Object> toSet(@Name("values") List<Object> list) {
if (list == null) return null;
return new SetBackedList(new LinkedHashSet(list));
List<AnyValue> anyValues = toAnyValues(list);
return new SetBackedList(new LinkedHashSet(anyValues));
}

@UserFunction
Expand Down Expand Up @@ -604,17 +612,17 @@ public List<Object> union(@Name("first") List<Object> first, @Name("second") Lis
@UserFunction
@Description("apoc.coll.subtract(first, second) - returns unique set of first list with all elements of second list removed")
public List<Object> subtract(@Name("first") List<Object> first, @Name("second") List<Object> second) {
if (first == null) return null;
Set<Object> set = new HashSet<>(first);
if (second!=null) set.removeAll(second);
return new SetBackedList(set);
if (first == null) return null;
List<Object> list = new ArrayList<>(toAnyValues(first));
if (second!=null) list.removeAll(toAnyValues(second));
return list;
}
@UserFunction
@Description("apoc.coll.removeAll(first, second) - returns first list with all elements of second list removed")
public List<Object> removeAll(@Name("first") List<Object> first, @Name("second") List<Object> second) {
if (first == null) return null;
List<Object> list = new ArrayList<>(first);
if (second!=null) list.removeAll(second);
if (first == null) return null;
List<Object> list = new ArrayList<>(toAnyValues(first));
if (second!=null) list.removeAll(toAnyValues(second));
return list;
}

Expand Down
28 changes: 28 additions & 0 deletions core/src/main/java/apoc/util/Util.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import apoc.export.util.ExportConfig;
import apoc.result.VirtualNode;
import apoc.result.VirtualRelationship;
import org.apache.commons.collections4.ListUtils;
import org.apache.commons.compress.archivers.ArchiveEntry;
import org.apache.commons.compress.archivers.ArchiveInputStream;
import org.apache.commons.io.IOUtils;
Expand All @@ -50,8 +51,10 @@
import org.neo4j.logging.Log;
import org.neo4j.graphdb.ExecutionPlanDescription;
import org.neo4j.graphdb.Result;
import org.neo4j.kernel.impl.util.ValueUtils;
import org.neo4j.logging.NullLog;
import org.neo4j.procedure.Mode;
import org.neo4j.values.AnyValue;
import org.neo4j.procedure.TerminationGuard;
import org.neo4j.values.storable.CoordinateReferenceSystem;
import org.neo4j.values.storable.PointValue;
Expand Down Expand Up @@ -1202,4 +1205,29 @@ public static boolean isWindows() {
.toLowerCase()
.contains("win");
}

public static <T> boolean valueEquals(T one, T other) {
if (one == null || other == null) {
return false;
}
return ValueUtils.of(one)
.equals(ValueUtils.of(other));
}

public static boolean containsValueEquals(Collection<Object> collection, Object value) {
return collection.stream()
.anyMatch(i -> Util.valueEquals(value, i));
}

public static <T> List<AnyValue> toAnyValues(List<T> list) {
return list.stream()
.map(ValueUtils::of)
.collect(Collectors.toList());
}

public static int indexOf(List<Object> list, Object value) {
return ListUtils.indexOf(list,
(i) -> Util.valueEquals(i, value)
);
}
}
103 changes: 103 additions & 0 deletions core/src/test/java/apoc/coll/CollTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import apoc.convert.Json;
import apoc.util.TestUtil;
import org.junit.After;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
Expand All @@ -39,6 +40,19 @@
import static org.neo4j.internal.helpers.collection.Iterables.asSet;

public class CollTest {
// query that procedures a list,
// with both entity types, via collect(..), and hardcoded one
private static final String QUERY_WITH_MIXED_TYPES = "MATCH (n:Test) " +
"WITH n ORDER BY n.a " +
"WITH COLLECT({something: n.something}) + { something: [] } + {something: 'alpha'} + {something: [1,2,3]} AS collection\n";

private static final String QUERY_WITH_ARRAY = "CREATE (:Test {a: 1, something: 'alpha' }), " +
"(:Test { a: 2, something: [] }), " +
"(:Test { a: 3, something: 'beta' })," +
"(:Test { a: 4, something: [1,2,3] })";

public static final Map<String, String> MAP_WITH_ALPHA = Map.of("something", "alpha");
public static final Map<String, String> MAP_WITH_BETA = Map.of("something", "beta");

@ClassRule
public static DbmsRule db = new ImpermanentDbmsRule();
Expand All @@ -48,6 +62,11 @@ public class CollTest {
TestUtil.registerProcedure(db, Coll.class, Json.class);
}

@After
public void after() {
db.executeTransactionally("MATCH (n) DETACH DELETE n");
}

@Test
public void testRunningTotal() throws Exception {
testCall(db, "RETURN apoc.coll.runningTotal([1,2,3,4,5.5,1]) as value",
Expand Down Expand Up @@ -197,6 +216,14 @@ public void testIndexOf() throws Exception {
testCall(db, "RETURN apoc.coll.indexOf([1,2,3],null) AS value", r -> assertEquals(-1L, r.get("value")));
}

@Test
public void testIndexOfWithCollections() {
db.executeTransactionally(QUERY_WITH_ARRAY);
testCall(db, QUERY_WITH_MIXED_TYPES + "RETURN apoc.coll.indexOf(collection, { something: [] }) AS index",
r -> assertEquals(1L, r.get("index"))
);
}

@Test
public void testSplit() throws Exception {
testResult(db, "CALL apoc.coll.split([1,2,3,2,4,5],2)", r -> {
Expand Down Expand Up @@ -224,6 +251,63 @@ public void testSplit() throws Exception {
});
}

@Test
public void testSplitOfWithBothHardcodedAndEntityTypes() {
db.executeTransactionally(QUERY_WITH_ARRAY);
testResult(db, QUERY_WITH_MIXED_TYPES + "CALL apoc.coll.split(collection, { something: [] }) YIELD value RETURN value",
r -> {
Map<String, Object> row = r.next();
List<Map<String, Object>> value = (List<Map<String, Object>>) row.get("value");
assertEquals(List.of(MAP_WITH_ALPHA), value);
row = r.next();
value = (List<Map<String, Object>>) row.get("value");
assertEquals(2, value.size());
assertEquals(MAP_WITH_BETA, value.get(0));
// in this case the `[1,2,3]` in `{ something: [1,2,3] }` is an array
assertMapWithNumericArray(value.get(1));

row = r.next();
value = (List<Map<String, Object>>) row.get("value");

assertEquals(2, value.size());
assertEquals(MAP_WITH_ALPHA, value.get(0));
// in this case the `[1,2,3]` in `{ something: [1,2,3] }` is an ArrayList
assertEquals(Map.of("something", List.of(1L,2L,3L)), value.get(1));

assertFalse(r.hasNext());
});
}

@Test
public void testRemoveWithBothHardcodedAndEntityTypes() {
db.executeTransactionally(QUERY_WITH_ARRAY);
testCall(db, QUERY_WITH_MIXED_TYPES + "RETURN apoc.coll.removeAll(collection, [{ something: [] }, { something: 'alpha' }]) AS value",
row -> {
List<Map<String, Object>> value = (List<Map<String, Object>>) row.get("value");
assertEquals(3, value.size());
assertEquals(MAP_WITH_BETA, value.get(0));
// in this case the `[1,2,3]` in `{ something: [1,2,3] }` is an array
assertMapWithNumericArray(value.get(1));
// in this case the `[1,2,3]` in `{ something: [1,2,3] }` is an ArrayList
assertEquals(Map.of("something", List.of(1L,2L,3L)), value.get(2));
});
}

@Test
public void testCollToSetWithBothHardcodedAndEntityTypes() {
db.executeTransactionally(QUERY_WITH_ARRAY);

testCall(db, QUERY_WITH_MIXED_TYPES + "RETURN apoc.coll.toSet(collection) AS value",
row -> {
List<Map<String, Object>> value = (List<Map<String, Object>>) row.get("value");
assertEquals(4, value.size());
assertEquals(MAP_WITH_ALPHA, value.get(0));
assertMapWithEmptyArray(value.get(1));
assertEquals(MAP_WITH_BETA, value.get(2));
assertMapWithNumericArray(value.get(3));
});
}

@Test
public void testSet() throws Exception {
testCall(db, "RETURN apoc.coll.set(null,0,4) AS value", r -> assertNull(r.get("value")));
Expand Down Expand Up @@ -284,6 +368,25 @@ public void testContainsAll() throws Exception {
testCall(db, "RETURN apoc.coll.containsAll([1,2,3],null) AS value", (res) -> assertEquals(false, res.get("value")));
}

@Test
public void testContainsAllOfWithCollections() {
db.executeTransactionally(QUERY_WITH_ARRAY);

testCall(db, QUERY_WITH_MIXED_TYPES + "RETURN apoc.coll.containsAll(collection, [{ something: [] }]) AS value",
row -> assertTrue( (boolean) row.get("value") )
);
}

private static void assertMapWithEmptyArray(Map map) {
assertEquals(1, map.size());
assertArrayEquals(new String[]{}, (String[]) map.get("something"));
}

private static void assertMapWithNumericArray(Map map) {
assertEquals(1, map.size());
assertArrayEquals(new long[] {1,2,3}, (long[]) map.get("something"));
}

@Test
public void testContainsAllSorted() throws Exception {
testCall(db, "RETURN apoc.coll.containsAllSorted([1,2,3],[1,2]) AS value", (res) -> assertEquals(true, res.get("value")));
Expand Down

0 comments on commit 1feeee3

Please sign in to comment.