Skip to content

Commit

Permalink
[dITjfMpu] apoc.coll.indexOf unexpectedly treats collections differen…
Browse files Browse the repository at this point in the history
…tly than the same hardcoded list (#422)

* [dITjfMpu] apoc.coll.indexOf unexpectedly treats collections differently than the same hardcoded list

* [dITjfMpu] various changes

* [dITjfMpu] small review changes
  • Loading branch information
vga91 authored May 29, 2023
1 parent 9ecb4c5 commit 2c1f37c
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 8 deletions.
28 changes: 28 additions & 0 deletions common/src/main/java/apoc/util/Util.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import apoc.result.VirtualNode;
import apoc.result.VirtualRelationship;
import apoc.util.collection.Iterators;
import org.apache.commons.collections4.ListUtils;
import org.apache.commons.compress.archivers.ArchiveEntry;
import org.apache.commons.compress.archivers.ArchiveInputStream;
import org.apache.commons.io.IOUtils;
Expand All @@ -36,8 +37,10 @@
import org.neo4j.graphdb.schema.ConstraintType;
import org.neo4j.internal.schema.ConstraintDescriptor;
import org.neo4j.kernel.impl.coreapi.InternalTransaction;
import org.neo4j.kernel.impl.util.ValueUtils;
import org.neo4j.logging.NullLog;
import org.neo4j.procedure.Mode;
import org.neo4j.values.AnyValue;
import org.neo4j.values.storable.CoordinateReferenceSystem;
import org.neo4j.values.storable.PointValue;
import org.neo4j.values.storable.Values;
Expand Down Expand Up @@ -1101,4 +1104,29 @@ public static boolean isWindows() {
.toLowerCase()
.contains("win");
}

public static <T> boolean valueEquals(T one, T other) {
if (one == null || other == null) {
return false;
}
return ValueUtils.of(one)
.equals(ValueUtils.of(other));
}

public static boolean containsValueEquals(Collection<Object> collection, Object value) {
return collection.stream()
.anyMatch(i -> Util.valueEquals(value, i));
}

public static <T> List<AnyValue> toAnyValues(List<T> list) {
return list.stream()
.map(ValueUtils::of)
.collect(Collectors.toList());
}

public static int indexOf(List<Object> list, Object value) {
return ListUtils.indexOf(list,
(i) -> Util.valueEquals(i, value)
);
}
}
24 changes: 16 additions & 8 deletions core/src/main/java/apoc/coll/Coll.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package apoc.coll;

import apoc.result.ListResult;
import apoc.util.Util;
import com.google.common.util.concurrent.AtomicDouble;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.commons.lang3.tuple.Pair;
Expand All @@ -34,6 +35,7 @@
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;
import org.neo4j.procedure.UserFunction;
import org.neo4j.values.AnyValue;

import java.lang.reflect.Array;
import java.text.Collator;
Expand All @@ -58,6 +60,8 @@
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static apoc.util.Util.containsValueEquals;
import static apoc.util.Util.toAnyValues;
import static java.util.Arrays.asList;

public class Coll {
Expand Down Expand Up @@ -389,12 +393,12 @@ public Stream<ListResult> split(@Name("coll") List<Object> list, @Name("value")
if (list==null || list.isEmpty()) return Stream.empty();
List<Object> l = new ArrayList<>(list);
List<List<Object>> result = new ArrayList<>(10);
int idx = l.indexOf(value);
int idx = Util.indexOf(l, value);
while (idx != -1) {
List<Object> subList = l.subList(0, idx);
if (!subList.isEmpty()) result.add(subList);
l = l.subList(idx+1,l.size());
idx = l.indexOf(value);
idx = Util.indexOf(l, value);
}
if (!l.isEmpty()) result.add(l);
return result.stream().map(ListResult::new);
Expand Down Expand Up @@ -471,14 +475,17 @@ public List<Object> remove(@Name("coll") List<Object> coll, @Name("index") long
public long indexOf(@Name("coll") List<Object> coll, @Name("value") Object value) {
// return reduce(res=[0,-1], x in $list | CASE WHEN x=$value AND res[1]=-1 THEN [res[0], res[0]+1] ELSE [res[0]+1, res[1]] END)[1] as value
if (coll == null || coll.isEmpty()) return -1;
return new ArrayList<>(coll).indexOf(value);
return Util.indexOf(coll, value);
}

@UserFunction("apoc.coll.containsAll")
@Description("Returns whether or not all of the given values exist in the given collection (using a HashSet).")
public boolean containsAll(@Name("coll1") List<Object> coll, @Name("coll2") List<Object> values) {
if (coll == null || coll.isEmpty() || values == null) return false;
return new HashSet<>(coll).containsAll(values);
Set<Object> objects = new HashSet<>(coll);

return values.stream()
.allMatch( i -> containsValueEquals(objects, i));
}

@UserFunction("apoc.coll.containsSorted")
Expand Down Expand Up @@ -522,7 +529,8 @@ public boolean isEqualCollection(@Name("coll") List<Object> first, @Name("values
@Description("Returns a unique list from the given list.")
public List<Object> toSet(@Name("coll") List<Object> list) {
if (list == null) return null;
return new SetBackedList(new LinkedHashSet(list));
List<AnyValue> anyValues = toAnyValues(list);
return new SetBackedList(new LinkedHashSet(anyValues));
}

@UserFunction("apoc.coll.sumLongs")
Expand Down Expand Up @@ -602,9 +610,9 @@ public List<Object> union(@Name("list1") List<Object> first, @Name("list2") List
@UserFunction("apoc.coll.removeAll")
@Description("Returns the first list with all elements of the second list removed.")
public List<Object> removeAll(@Name("list1") List<Object> first, @Name("list2") List<Object> second) {
if (first == null) return null;
List<Object> list = new ArrayList<>(first);
if (second!=null) list.removeAll(second);
if (first == null) return null;
List<Object> list = new ArrayList<>(toAnyValues(first));
if (second!=null) list.removeAll(toAnyValues(second));
return list;
}
@UserFunction("apoc.coll.subtract")
Expand Down
106 changes: 106 additions & 0 deletions core/src/test/java/apoc/coll/CollTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import apoc.convert.Json;
import apoc.util.TestUtil;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.ClassRule;
Expand All @@ -40,6 +41,22 @@
import static org.junit.Assert.*;

public class CollTest {
// query that procedures a list,
// with both entity types, via collect(..), and hardcoded one
private static final String QUERY_WITH_MIXED_TYPES = """
WITH COLLECT {
MATCH (n:Test)
RETURN {something: n.something}
ORDER BY n.a} + { something: [] } + {something: 'alpha'} + {something: [1,2,3]} AS collection
""";

private static final String QUERY_WITH_ARRAY = "CREATE (:Test {a: 1, something: 'alpha' }), " +
"(:Test { a: 2, something: [] }), " +
"(:Test { a: 3, something: 'beta' })," +
"(:Test { a: 4, something: [1,2,3] })";

public static final Map<String, String> MAP_WITH_ALPHA = Map.of("something", "alpha");
public static final Map<String, String> MAP_WITH_BETA = Map.of("something", "beta");

@ClassRule
public static DbmsRule db = new ImpermanentDbmsRule();
Expand All @@ -54,6 +71,11 @@ public static void teardown() {
db.shutdown();
}

@After
public void after() {
db.executeTransactionally("MATCH (n) DETACH DELETE n");
}

@Test
public void testRunningTotal() {
testCall(db, "RETURN apoc.coll.runningTotal([1,2,3,4,5.5,1]) as value",
Expand Down Expand Up @@ -203,6 +225,14 @@ public void testIndexOf() {
testCall(db, "RETURN apoc.coll.indexOf([1,2,3],null) AS value", r -> assertEquals(-1L, r.get("value")));
}

@Test
public void testIndexOfWithCollections() {
db.executeTransactionally(QUERY_WITH_ARRAY);
testCall(db, QUERY_WITH_MIXED_TYPES + "RETURN apoc.coll.indexOf(collection, { something: [] }) AS index",
r -> assertEquals(1L, r.get("index"))
);
}

@Test
public void testSplit() {
testResult(db, "CALL apoc.coll.split([1,2,3,2,4,5],2)", r -> {
Expand Down Expand Up @@ -230,6 +260,63 @@ public void testSplit() {
});
}

@Test
public void testSplitOfWithBothHardcodedAndEntityTypes() {
db.executeTransactionally(QUERY_WITH_ARRAY);
testResult(db, QUERY_WITH_MIXED_TYPES + "CALL apoc.coll.split(collection, { something: [] }) YIELD value RETURN value",
r -> {
Map<String, Object> row = r.next();
List<Map<String, Object>> value = (List<Map<String, Object>>) row.get("value");
assertEquals(List.of(MAP_WITH_ALPHA), value);
row = r.next();
value = (List<Map<String, Object>>) row.get("value");
assertEquals(2, value.size());
assertEquals(MAP_WITH_BETA, value.get(0));
// in this case the `[1,2,3]` in `{ something: [1,2,3] }` is an array
assertMapWithNumericArray(value.get(1));

row = r.next();
value = (List<Map<String, Object>>) row.get("value");

assertEquals(2, value.size());
assertEquals(MAP_WITH_ALPHA, value.get(0));
// in this case the `[1,2,3]` in `{ something: [1,2,3] }` is an ArrayList
assertEquals(Map.of("something", List.of(1L,2L,3L)), value.get(1));

assertFalse(r.hasNext());
});
}

@Test
public void testRemoveWithBothHardcodedAndEntityTypes() {
db.executeTransactionally(QUERY_WITH_ARRAY);
testCall(db, QUERY_WITH_MIXED_TYPES + "RETURN apoc.coll.removeAll(collection, [{ something: [] }, { something: 'alpha' }]) AS value",
row -> {
List<Map<String, Object>> value = (List<Map<String, Object>>) row.get("value");
assertEquals(3, value.size());
assertEquals(MAP_WITH_BETA, value.get(0));
// in this case the `[1,2,3]` in `{ something: [1,2,3] }` is an array
assertMapWithNumericArray(value.get(1));
// in this case the `[1,2,3]` in `{ something: [1,2,3] }` is an ArrayList
assertEquals(Map.of("something", List.of(1L,2L,3L)), value.get(2));
});
}

@Test
public void testCollToSetWithBothHardcodedAndEntityTypes() {
db.executeTransactionally(QUERY_WITH_ARRAY);

testCall(db, QUERY_WITH_MIXED_TYPES + "RETURN apoc.coll.toSet(collection) AS value",
row -> {
List<Map<String, Object>> value = (List<Map<String, Object>>) row.get("value");
assertEquals(4, value.size());
assertEquals(MAP_WITH_ALPHA, value.get(0));
assertMapWithEmptyArray(value.get(1));
assertEquals(MAP_WITH_BETA, value.get(2));
assertMapWithNumericArray(value.get(3));
});
}

@Test
public void testSet() {
testCall(db, "RETURN apoc.coll.set(null,0,4) AS value", r -> assertNull(r.get("value")));
Expand Down Expand Up @@ -290,6 +377,25 @@ public void testContainsAll() {
testCall(db, "RETURN apoc.coll.containsAll([1,2,3],null) AS value", (res) -> assertEquals(false, res.get("value")));
}

@Test
public void testContainsAllOfWithCollections() {
db.executeTransactionally(QUERY_WITH_ARRAY);

testCall(db, QUERY_WITH_MIXED_TYPES + "RETURN apoc.coll.containsAll(collection, [{ something: [] }]) AS value",
row -> assertTrue( (boolean) row.get("value") )
);
}

private static void assertMapWithEmptyArray(Map map) {
assertEquals(1, map.size());
assertArrayEquals(new String[]{}, (String[]) map.get("something"));
}

private static void assertMapWithNumericArray(Map map) {
assertEquals(1, map.size());
assertArrayEquals(new long[] {1,2,3}, (long[]) map.get("something"));
}

@Test
public void testContainsAllSorted() {
testCall(db, "RETURN apoc.coll.containsAllSorted([1,2,3],[1,2]) AS value", (res) -> assertEquals(true, res.get("value")));
Expand Down

0 comments on commit 2c1f37c

Please sign in to comment.