From 1feeee30f3db70bd7c5f210c8729e55a743bc6ce Mon Sep 17 00:00:00 2001 From: Giuseppe Villani Date: Wed, 31 May 2023 09:13:41 +0200 Subject: [PATCH] [dITjfMpu] apoc.coll.indexOf unexpectedly treats collections differently than the same hardcoded list (neo4j/apoc#422) (#3600) * [dITjfMpu] apoc.coll.indexOf unexpectedly treats collections differently than the same hardcoded list * [dITjfMpu] various changes * [dITjfMpu] small review changes --- core/src/main/java/apoc/coll/Coll.java | 32 ++++--- core/src/main/java/apoc/util/Util.java | 28 ++++++ core/src/test/java/apoc/coll/CollTest.java | 103 +++++++++++++++++++++ 3 files changed, 151 insertions(+), 12 deletions(-) diff --git a/core/src/main/java/apoc/coll/Coll.java b/core/src/main/java/apoc/coll/Coll.java index e6c0f9e78b..d81e145428 100644 --- a/core/src/main/java/apoc/coll/Coll.java +++ b/core/src/main/java/apoc/coll/Coll.java @@ -19,6 +19,7 @@ package apoc.coll; import apoc.result.ListResult; +import apoc.util.Util; import com.google.common.util.concurrent.AtomicDouble; import org.apache.commons.lang3.mutable.MutableInt; import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation; @@ -35,6 +36,7 @@ import org.neo4j.procedure.Name; import org.neo4j.procedure.Procedure; import org.neo4j.procedure.UserFunction; +import org.neo4j.values.AnyValue; import java.lang.reflect.Array; import java.text.Collator; @@ -59,6 +61,8 @@ import java.util.stream.IntStream; import java.util.stream.Stream; +import static apoc.util.Util.containsValueEquals; +import static apoc.util.Util.toAnyValues; import static java.util.Arrays.asList; public class Coll { @@ -391,12 +395,12 @@ public Stream split(@Name("values") List list, @Name("value" if (list==null || list.isEmpty()) return Stream.empty(); List l = new ArrayList<>(list); List> result = new ArrayList<>(10); - int idx = l.indexOf(value); + int idx = Util.indexOf(l, value); while (idx != -1) { List subList = l.subList(0, idx); if (!subList.isEmpty()) 
result.add(subList); l = l.subList(idx+1,l.size()); - idx = l.indexOf(value); + idx = Util.indexOf(l, value); } if (!l.isEmpty()) result.add(l); return result.stream().map(ListResult::new); @@ -473,14 +477,17 @@ public List remove(@Name("coll") List coll, @Name("index") long public long indexOf(@Name("coll") List coll, @Name("value") Object value) { // return reduce(res=[0,-1], x in $list | CASE WHEN x=$value AND res[1]=-1 THEN [res[0], res[0]+1] ELSE [res[0]+1, res[1]] END)[1] as value if (coll == null || coll.isEmpty()) return -1; - return new ArrayList<>(coll).indexOf(value); + return Util.indexOf(coll, value); } @UserFunction @Description("apoc.coll.containsAll(coll, values) optimized contains-all operation (using a HashSet) (returns single row or not)") public boolean containsAll(@Name("coll") List coll, @Name("values") List values) { if (coll == null || coll.isEmpty() || values == null) return false; - return new HashSet<>(coll).containsAll(values); + Set objects = new HashSet<>(coll); + + return values.stream() + .allMatch( i -> containsValueEquals(objects, i)); } @UserFunction @@ -524,7 +531,8 @@ public boolean isEqualCollection(@Name("coll") List first, @Name("values @Description("apoc.coll.toSet([list]) returns a unique list backed by a set") public List toSet(@Name("values") List list) { if (list == null) return null; - return new SetBackedList(new LinkedHashSet(list)); + List anyValues = toAnyValues(list); + return new SetBackedList(new LinkedHashSet(anyValues)); } @UserFunction @@ -604,17 +612,17 @@ public List union(@Name("first") List first, @Name("second") Lis @UserFunction @Description("apoc.coll.subtract(first, second) - returns unique set of first list with all elements of second list removed") public List subtract(@Name("first") List first, @Name("second") List second) { - if (first == null) return null; - Set set = new HashSet<>(first); - if (second!=null) set.removeAll(second); - return new SetBackedList(set); + if (first == null) return null; 
+ List list = new ArrayList<>(toAnyValues(first)); + if (second!=null) list.removeAll(toAnyValues(second)); + return list; } @UserFunction @Description("apoc.coll.removeAll(first, second) - returns first list with all elements of second list removed") public List removeAll(@Name("first") List first, @Name("second") List second) { - if (first == null) return null; - List list = new ArrayList<>(first); - if (second!=null) list.removeAll(second); + if (first == null) return null; + List list = new ArrayList<>(toAnyValues(first)); + if (second!=null) list.removeAll(toAnyValues(second)); return list; } diff --git a/core/src/main/java/apoc/util/Util.java b/core/src/main/java/apoc/util/Util.java index b5e353fbe2..f945c90cff 100644 --- a/core/src/main/java/apoc/util/Util.java +++ b/core/src/main/java/apoc/util/Util.java @@ -25,6 +25,7 @@ import apoc.export.util.ExportConfig; import apoc.result.VirtualNode; import apoc.result.VirtualRelationship; +import org.apache.commons.collections4.ListUtils; import org.apache.commons.compress.archivers.ArchiveEntry; import org.apache.commons.compress.archivers.ArchiveInputStream; import org.apache.commons.io.IOUtils; @@ -50,8 +51,10 @@ import org.neo4j.logging.Log; import org.neo4j.graphdb.ExecutionPlanDescription; import org.neo4j.graphdb.Result; +import org.neo4j.kernel.impl.util.ValueUtils; import org.neo4j.logging.NullLog; import org.neo4j.procedure.Mode; +import org.neo4j.values.AnyValue; import org.neo4j.procedure.TerminationGuard; import org.neo4j.values.storable.CoordinateReferenceSystem; import org.neo4j.values.storable.PointValue; @@ -1202,4 +1205,29 @@ public static boolean isWindows() { .toLowerCase() .contains("win"); } + + public static boolean valueEquals(T one, T other) { + if (one == null || other == null) { + return false; + } + return ValueUtils.of(one) + .equals(ValueUtils.of(other)); + } + + public static boolean containsValueEquals(Collection collection, Object value) { + return collection.stream() + .anyMatch(i 
-> Util.valueEquals(value, i)); + } + + public static List toAnyValues(List list) { + return list.stream() + .map(ValueUtils::of) + .collect(Collectors.toList()); + } + + public static int indexOf(List list, Object value) { + return ListUtils.indexOf(list, + (i) -> Util.valueEquals(i, value) + ); + } } diff --git a/core/src/test/java/apoc/coll/CollTest.java b/core/src/test/java/apoc/coll/CollTest.java index 4fb8daccc4..1b65ec3442 100644 --- a/core/src/test/java/apoc/coll/CollTest.java +++ b/core/src/test/java/apoc/coll/CollTest.java @@ -20,6 +20,7 @@ import apoc.convert.Json; import apoc.util.TestUtil; +import org.junit.After; import org.junit.BeforeClass; import org.junit.ClassRule; import org.junit.Test; @@ -39,6 +40,19 @@ import static org.neo4j.internal.helpers.collection.Iterables.asSet; public class CollTest { + // query that produces a list, + // with both entity types, via collect(..), and hardcoded one + private static final String QUERY_WITH_MIXED_TYPES = "MATCH (n:Test) " + + "WITH n ORDER BY n.a " + + "WITH COLLECT({something: n.something}) + { something: [] } + {something: 'alpha'} + {something: [1,2,3]} AS collection\n"; + + private static final String QUERY_WITH_ARRAY = "CREATE (:Test {a: 1, something: 'alpha' }), " + + "(:Test { a: 2, something: [] }), " + + "(:Test { a: 3, something: 'beta' })," + + "(:Test { a: 4, something: [1,2,3] })"; + + public static final Map MAP_WITH_ALPHA = Map.of("something", "alpha"); + public static final Map MAP_WITH_BETA = Map.of("something", "beta"); @ClassRule public static DbmsRule db = new ImpermanentDbmsRule(); @@ -48,6 +62,11 @@ public class CollTest { TestUtil.registerProcedure(db, Coll.class, Json.class); } + @After + public void after() { + db.executeTransactionally("MATCH (n) DETACH DELETE n"); + } + @Test public void testRunningTotal() throws Exception { testCall(db, "RETURN apoc.coll.runningTotal([1,2,3,4,5.5,1]) as value", @@ -197,6 +216,14 @@ public void testIndexOf() throws Exception { testCall(db, 
"RETURN apoc.coll.indexOf([1,2,3],null) AS value", r -> assertEquals(-1L, r.get("value"))); } + @Test + public void testIndexOfWithCollections() { + db.executeTransactionally(QUERY_WITH_ARRAY); + testCall(db, QUERY_WITH_MIXED_TYPES + "RETURN apoc.coll.indexOf(collection, { something: [] }) AS index", + r -> assertEquals(1L, r.get("index")) + ); + } + @Test public void testSplit() throws Exception { testResult(db, "CALL apoc.coll.split([1,2,3,2,4,5],2)", r -> { @@ -224,6 +251,63 @@ public void testSplit() throws Exception { }); } + @Test + public void testSplitOfWithBothHardcodedAndEntityTypes() { + db.executeTransactionally(QUERY_WITH_ARRAY); + testResult(db, QUERY_WITH_MIXED_TYPES + "CALL apoc.coll.split(collection, { something: [] }) YIELD value RETURN value", + r -> { + Map row = r.next(); + List> value = (List>) row.get("value"); + assertEquals(List.of(MAP_WITH_ALPHA), value); + row = r.next(); + value = (List>) row.get("value"); + assertEquals(2, value.size()); + assertEquals(MAP_WITH_BETA, value.get(0)); + // in this case the `[1,2,3]` in `{ something: [1,2,3] }` is an array + assertMapWithNumericArray(value.get(1)); + + row = r.next(); + value = (List>) row.get("value"); + + assertEquals(2, value.size()); + assertEquals(MAP_WITH_ALPHA, value.get(0)); + // in this case the `[1,2,3]` in `{ something: [1,2,3] }` is an ArrayList + assertEquals(Map.of("something", List.of(1L,2L,3L)), value.get(1)); + + assertFalse(r.hasNext()); + }); + } + + @Test + public void testRemoveWithBothHardcodedAndEntityTypes() { + db.executeTransactionally(QUERY_WITH_ARRAY); + testCall(db, QUERY_WITH_MIXED_TYPES + "RETURN apoc.coll.removeAll(collection, [{ something: [] }, { something: 'alpha' }]) AS value", + row -> { + List> value = (List>) row.get("value"); + assertEquals(3, value.size()); + assertEquals(MAP_WITH_BETA, value.get(0)); + // in this case the `[1,2,3]` in `{ something: [1,2,3] }` is an array + assertMapWithNumericArray(value.get(1)); + // in this case the `[1,2,3]` in 
`{ something: [1,2,3] }` is an ArrayList + assertEquals(Map.of("something", List.of(1L,2L,3L)), value.get(2)); + }); + } + + @Test + public void testCollToSetWithBothHardcodedAndEntityTypes() { + db.executeTransactionally(QUERY_WITH_ARRAY); + + testCall(db, QUERY_WITH_MIXED_TYPES + "RETURN apoc.coll.toSet(collection) AS value", + row -> { + List> value = (List>) row.get("value"); + assertEquals(4, value.size()); + assertEquals(MAP_WITH_ALPHA, value.get(0)); + assertMapWithEmptyArray(value.get(1)); + assertEquals(MAP_WITH_BETA, value.get(2)); + assertMapWithNumericArray(value.get(3)); + }); + } + @Test public void testSet() throws Exception { testCall(db, "RETURN apoc.coll.set(null,0,4) AS value", r -> assertNull(r.get("value"))); @@ -284,6 +368,25 @@ public void testContainsAll() throws Exception { testCall(db, "RETURN apoc.coll.containsAll([1,2,3],null) AS value", (res) -> assertEquals(false, res.get("value"))); } + @Test + public void testContainsAllOfWithCollections() { + db.executeTransactionally(QUERY_WITH_ARRAY); + + testCall(db, QUERY_WITH_MIXED_TYPES + "RETURN apoc.coll.containsAll(collection, [{ something: [] }]) AS value", + row -> assertTrue( (boolean) row.get("value") ) + ); + } + + private static void assertMapWithEmptyArray(Map map) { + assertEquals(1, map.size()); + assertArrayEquals(new String[]{}, (String[]) map.get("something")); + } + + private static void assertMapWithNumericArray(Map map) { + assertEquals(1, map.size()); + assertArrayEquals(new long[] {1,2,3}, (long[]) map.get("something")); + } + @Test public void testContainsAllSorted() throws Exception { testCall(db, "RETURN apoc.coll.containsAllSorted([1,2,3],[1,2]) AS value", (res) -> assertEquals(true, res.get("value")));