Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dITjfMpu] apoc.coll.indexOf unexpectedly treats collections differently than the same hardcoded list (neo4j/apoc#422) #3600

Merged
merged 1 commit into from
May 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions core/src/main/java/apoc/coll/Coll.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package apoc.coll;

import apoc.result.ListResult;
import apoc.util.Util;
import com.google.common.util.concurrent.AtomicDouble;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
Expand All @@ -35,6 +36,7 @@
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;
import org.neo4j.procedure.UserFunction;
import org.neo4j.values.AnyValue;

import java.lang.reflect.Array;
import java.text.Collator;
Expand All @@ -59,6 +61,8 @@
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static apoc.util.Util.containsValueEquals;
import static apoc.util.Util.toAnyValues;
import static java.util.Arrays.asList;

public class Coll {
Expand Down Expand Up @@ -391,12 +395,12 @@ public Stream<ListResult> split(@Name("values") List<Object> list, @Name("value"
if (list==null || list.isEmpty()) return Stream.empty();
List<Object> l = new ArrayList<>(list);
List<List<Object>> result = new ArrayList<>(10);
int idx = l.indexOf(value);
int idx = Util.indexOf(l, value);
while (idx != -1) {
List<Object> subList = l.subList(0, idx);
if (!subList.isEmpty()) result.add(subList);
l = l.subList(idx+1,l.size());
idx = l.indexOf(value);
idx = Util.indexOf(l, value);
}
if (!l.isEmpty()) result.add(l);
return result.stream().map(ListResult::new);
Expand Down Expand Up @@ -473,14 +477,17 @@ public List<Object> remove(@Name("coll") List<Object> coll, @Name("index") long
public long indexOf(@Name("coll") List<Object> coll, @Name("value") Object value) {
// return reduce(res=[0,-1], x in $list | CASE WHEN x=$value AND res[1]=-1 THEN [res[0], res[0]+1] ELSE [res[0]+1, res[1]] END)[1] as value
if (coll == null || coll.isEmpty()) return -1;
return new ArrayList<>(coll).indexOf(value);
return Util.indexOf(coll, value);
}

@UserFunction
@Description("apoc.coll.containsAll(coll, values) optimized contains-all operation (using a HashSet) (returns single row or not)")
public boolean containsAll(@Name("coll") List<Object> coll, @Name("values") List<Object> values) {
if (coll == null || coll.isEmpty() || values == null) return false;
return new HashSet<>(coll).containsAll(values);
Set<Object> objects = new HashSet<>(coll);

return values.stream()
.allMatch( i -> containsValueEquals(objects, i));
}

@UserFunction
Expand Down Expand Up @@ -524,7 +531,8 @@ public boolean isEqualCollection(@Name("coll") List<Object> first, @Name("values
@Description("apoc.coll.toSet([list]) returns a unique list backed by a set")
public List<Object> toSet(@Name("values") List<Object> list) {
if (list == null) return null;
return new SetBackedList(new LinkedHashSet(list));
List<AnyValue> anyValues = toAnyValues(list);
return new SetBackedList(new LinkedHashSet(anyValues));
}

@UserFunction
Expand Down Expand Up @@ -604,17 +612,17 @@ public List<Object> union(@Name("first") List<Object> first, @Name("second") Lis
@UserFunction
@Description("apoc.coll.subtract(first, second) - returns unique set of first list with all elements of second list removed")
public List<Object> subtract(@Name("first") List<Object> first, @Name("second") List<Object> second) {
if (first == null) return null;
Set<Object> set = new HashSet<>(first);
if (second!=null) set.removeAll(second);
return new SetBackedList(set);
if (first == null) return null;
List<Object> list = new ArrayList<>(toAnyValues(first));
if (second!=null) list.removeAll(toAnyValues(second));
return list;
}
@UserFunction
@Description("apoc.coll.removeAll(first, second) - returns first list with all elements of second list removed")
public List<Object> removeAll(@Name("first") List<Object> first, @Name("second") List<Object> second) {
if (first == null) return null;
List<Object> list = new ArrayList<>(first);
if (second!=null) list.removeAll(second);
if (first == null) return null;
List<Object> list = new ArrayList<>(toAnyValues(first));
if (second!=null) list.removeAll(toAnyValues(second));
return list;
}

Expand Down
28 changes: 28 additions & 0 deletions core/src/main/java/apoc/util/Util.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import apoc.export.util.ExportConfig;
import apoc.result.VirtualNode;
import apoc.result.VirtualRelationship;
import org.apache.commons.collections4.ListUtils;
import org.apache.commons.compress.archivers.ArchiveEntry;
import org.apache.commons.compress.archivers.ArchiveInputStream;
import org.apache.commons.io.IOUtils;
Expand All @@ -50,8 +51,10 @@
import org.neo4j.logging.Log;
import org.neo4j.graphdb.ExecutionPlanDescription;
import org.neo4j.graphdb.Result;
import org.neo4j.kernel.impl.util.ValueUtils;
import org.neo4j.logging.NullLog;
import org.neo4j.procedure.Mode;
import org.neo4j.values.AnyValue;
import org.neo4j.procedure.TerminationGuard;
import org.neo4j.values.storable.CoordinateReferenceSystem;
import org.neo4j.values.storable.PointValue;
Expand Down Expand Up @@ -1202,4 +1205,29 @@ public static boolean isWindows() {
.toLowerCase()
.contains("win");
}

public static <T> boolean valueEquals(T one, T other) {
if (one == null || other == null) {
return false;
}
return ValueUtils.of(one)
.equals(ValueUtils.of(other));
}

public static boolean containsValueEquals(Collection<Object> collection, Object value) {
return collection.stream()
.anyMatch(i -> Util.valueEquals(value, i));
}

public static <T> List<AnyValue> toAnyValues(List<T> list) {
return list.stream()
.map(ValueUtils::of)
.collect(Collectors.toList());
}

public static int indexOf(List<Object> list, Object value) {
return ListUtils.indexOf(list,
(i) -> Util.valueEquals(i, value)
);
}
}
103 changes: 103 additions & 0 deletions core/src/test/java/apoc/coll/CollTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import apoc.convert.Json;
import apoc.util.TestUtil;
import org.junit.After;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
Expand All @@ -39,6 +40,19 @@
import static org.neo4j.internal.helpers.collection.Iterables.asSet;

public class CollTest {
// query that procedures a list,
// with both entity types, via collect(..), and hardcoded one
private static final String QUERY_WITH_MIXED_TYPES = "MATCH (n:Test) " +
"WITH n ORDER BY n.a " +
"WITH COLLECT({something: n.something}) + { something: [] } + {something: 'alpha'} + {something: [1,2,3]} AS collection\n";

private static final String QUERY_WITH_ARRAY = "CREATE (:Test {a: 1, something: 'alpha' }), " +
"(:Test { a: 2, something: [] }), " +
"(:Test { a: 3, something: 'beta' })," +
"(:Test { a: 4, something: [1,2,3] })";

public static final Map<String, String> MAP_WITH_ALPHA = Map.of("something", "alpha");
public static final Map<String, String> MAP_WITH_BETA = Map.of("something", "beta");

@ClassRule
public static DbmsRule db = new ImpermanentDbmsRule();
Expand All @@ -48,6 +62,11 @@ public class CollTest {
TestUtil.registerProcedure(db, Coll.class, Json.class);
}

@After
public void after() {
db.executeTransactionally("MATCH (n) DETACH DELETE n");
}

@Test
public void testRunningTotal() throws Exception {
testCall(db, "RETURN apoc.coll.runningTotal([1,2,3,4,5.5,1]) as value",
Expand Down Expand Up @@ -197,6 +216,14 @@ public void testIndexOf() throws Exception {
testCall(db, "RETURN apoc.coll.indexOf([1,2,3],null) AS value", r -> assertEquals(-1L, r.get("value")));
}

@Test
public void testIndexOfWithCollections() {
db.executeTransactionally(QUERY_WITH_ARRAY);
testCall(db, QUERY_WITH_MIXED_TYPES + "RETURN apoc.coll.indexOf(collection, { something: [] }) AS index",
r -> assertEquals(1L, r.get("index"))
);
}

@Test
public void testSplit() throws Exception {
testResult(db, "CALL apoc.coll.split([1,2,3,2,4,5],2)", r -> {
Expand Down Expand Up @@ -224,6 +251,63 @@ public void testSplit() throws Exception {
});
}

@Test
public void testSplitOfWithBothHardcodedAndEntityTypes() {
db.executeTransactionally(QUERY_WITH_ARRAY);
testResult(db, QUERY_WITH_MIXED_TYPES + "CALL apoc.coll.split(collection, { something: [] }) YIELD value RETURN value",
r -> {
Map<String, Object> row = r.next();
List<Map<String, Object>> value = (List<Map<String, Object>>) row.get("value");
assertEquals(List.of(MAP_WITH_ALPHA), value);
row = r.next();
value = (List<Map<String, Object>>) row.get("value");
assertEquals(2, value.size());
assertEquals(MAP_WITH_BETA, value.get(0));
// in this case the `[1,2,3]` in `{ something: [1,2,3] }` is an array
assertMapWithNumericArray(value.get(1));

row = r.next();
value = (List<Map<String, Object>>) row.get("value");

assertEquals(2, value.size());
assertEquals(MAP_WITH_ALPHA, value.get(0));
// in this case the `[1,2,3]` in `{ something: [1,2,3] }` is an ArrayList
assertEquals(Map.of("something", List.of(1L,2L,3L)), value.get(1));

assertFalse(r.hasNext());
});
}

@Test
public void testRemoveWithBothHardcodedAndEntityTypes() {
db.executeTransactionally(QUERY_WITH_ARRAY);
testCall(db, QUERY_WITH_MIXED_TYPES + "RETURN apoc.coll.removeAll(collection, [{ something: [] }, { something: 'alpha' }]) AS value",
row -> {
List<Map<String, Object>> value = (List<Map<String, Object>>) row.get("value");
assertEquals(3, value.size());
assertEquals(MAP_WITH_BETA, value.get(0));
// in this case the `[1,2,3]` in `{ something: [1,2,3] }` is an array
assertMapWithNumericArray(value.get(1));
// in this case the `[1,2,3]` in `{ something: [1,2,3] }` is an ArrayList
assertEquals(Map.of("something", List.of(1L,2L,3L)), value.get(2));
});
}

@Test
public void testCollToSetWithBothHardcodedAndEntityTypes() {
db.executeTransactionally(QUERY_WITH_ARRAY);

testCall(db, QUERY_WITH_MIXED_TYPES + "RETURN apoc.coll.toSet(collection) AS value",
row -> {
List<Map<String, Object>> value = (List<Map<String, Object>>) row.get("value");
assertEquals(4, value.size());
assertEquals(MAP_WITH_ALPHA, value.get(0));
assertMapWithEmptyArray(value.get(1));
assertEquals(MAP_WITH_BETA, value.get(2));
assertMapWithNumericArray(value.get(3));
});
}

@Test
public void testSet() throws Exception {
testCall(db, "RETURN apoc.coll.set(null,0,4) AS value", r -> assertNull(r.get("value")));
Expand Down Expand Up @@ -284,6 +368,25 @@ public void testContainsAll() throws Exception {
testCall(db, "RETURN apoc.coll.containsAll([1,2,3],null) AS value", (res) -> assertEquals(false, res.get("value")));
}

@Test
public void testContainsAllOfWithCollections() {
db.executeTransactionally(QUERY_WITH_ARRAY);

testCall(db, QUERY_WITH_MIXED_TYPES + "RETURN apoc.coll.containsAll(collection, [{ something: [] }]) AS value",
row -> assertTrue( (boolean) row.get("value") )
);
}

private static void assertMapWithEmptyArray(Map map) {
assertEquals(1, map.size());
assertArrayEquals(new String[]{}, (String[]) map.get("something"));
}

private static void assertMapWithNumericArray(Map map) {
assertEquals(1, map.size());
assertArrayEquals(new long[] {1,2,3}, (long[]) map.get("something"));
}

@Test
public void testContainsAllSorted() throws Exception {
testCall(db, "RETURN apoc.coll.containsAllSorted([1,2,3],[1,2]) AS value", (res) -> assertEquals(true, res.get("value")));
Expand Down