Skip to content

Commit

Permalink
Add Cypher Version handling to APOC inner queries
Browse files Browse the repository at this point in the history
  • Loading branch information
gem-neo4j committed Feb 5, 2025
1 parent 46c8c9e commit a1736e0
Show file tree
Hide file tree
Showing 25 changed files with 376 additions and 148 deletions.
11 changes: 8 additions & 3 deletions common/src/main/java/apoc/cypher/CypherUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,24 @@
import static java.util.stream.Collectors.toList;

import apoc.result.CypherStatementMapResult;
import apoc.util.Util;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.stream.Stream;
import org.neo4j.graphdb.Transaction;
import org.neo4j.internal.kernel.api.procs.ProcedureCallContext;
import org.neo4j.procedure.Name;

public class CypherUtils {
public static Stream<CypherStatementMapResult> runCypherQuery(
Transaction tx, @Name("cypher") String statement, @Name("params") Map<String, Object> params) {
Transaction tx,
@Name("cypher") String statement,
@Name("params") Map<String, Object> params,
ProcedureCallContext procedureCallContext) {
if (params == null) params = Collections.emptyMap();
return tx.execute(withParamMapping(statement, params.keySet()), params).stream()
.map(CypherStatementMapResult::new);
String query = Util.prefixQueryWithCheck(procedureCallContext, withParamMapping(statement, params.keySet()));
return tx.execute(query, params).stream().map(CypherStatementMapResult::new);
}

public static String withParamMapping(String fragment, Collection<String> keys) {
Expand Down
108 changes: 108 additions & 0 deletions common/src/main/java/apoc/util/Util.java
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
import org.neo4j.graphdb.schema.IndexType;
import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.graphdb.security.URLAccessValidationError;
import org.neo4j.internal.kernel.api.procs.ProcedureCallContext;
import org.neo4j.internal.schema.ConstraintDescriptor;
import org.neo4j.kernel.api.QueryLanguage;
import org.neo4j.kernel.impl.coreapi.InternalTransaction;
Expand Down Expand Up @@ -1363,6 +1364,113 @@ public static <T> T withBackOffRetries(Supplier<T> func, long initialTimeout, lo
return result;
}

/**
* A helper function to give the Pre-parser text that can be appended to a query to force it to be
* of a certain Cypher Version
* @param procedureCallContext the injectable context from the procedure framework
* @return Cypher Pre-parser for Cypher Version setting
*/
public static String getCypherVersionPrefix(ProcedureCallContext procedureCallContext) {
return procedureCallContext.calledwithQueryLanguage().equals(QueryLanguage.CYPHER_5)
? "CYPHER 5 "
: "CYPHER 25 ";
}

public static String getCypherVersionString(ProcedureCallContext procedureCallContext) {
return procedureCallContext.calledwithQueryLanguage().equals(QueryLanguage.CYPHER_5) ? "5" : "25";
}

private static final Pattern CYPHER_VERSION_PATTERN =
Pattern.compile("^(CYPHER)(?:\\s+(\\d+))?", Pattern.CASE_INSENSITIVE);
public static final Pattern PLANNER_PATTERN =
Pattern.compile("\\bplanner\\s*=\\s*[^\\s]*", Pattern.CASE_INSENSITIVE);
public static final Pattern RUNTIME_PATTERN = Pattern.compile("\\bruntime\\s*=", Pattern.CASE_INSENSITIVE);
public static final String CYPHER_RUNTIME_SLOTTED = " runtime=slotted ";

public static String prefixQueryWithCheck(ProcedureCallContext procedureCallContext, String query) {
return prefixQueryWithCheck(getCypherVersionString(procedureCallContext), query);
}

/**
* A helper function to prefix a query with a Cypher Version; if it is not already prefixed.
* @param cypherVersion a string of the version number
* @return The prefixed query
*/
public static String prefixQueryWithCheck(String cypherVersion, String query) {
List<String> cypherPrefix = extractCypherPrefix(query, cypherVersion);
if (Objects.equals(cypherPrefix.getFirst(), "")) {
return cypherPrefix.get(1) + " " + query;
}
return query.replaceFirst("(?i)" + cypherPrefix.getFirst(), cypherPrefix.get(1) + " ");
}

// Extract the Cypher prefix, add version if missing
// This will return a list of 2 items, the first being what was extracted (empty string if nothing)
// The second being what the prefix should be.
public static List<String> extractCypherPrefix(String input, String cypherVersion) {
String cypherVersionPrefix = "CYPHER " + cypherVersion;
Matcher matcher = CYPHER_VERSION_PATTERN.matcher(input);
if (matcher.find()) {
String cypher = matcher.group(1); // Always "CYPHER"
String version = matcher.group(2); // Optional version
return List.of(
version != null ? cypher + " " + version : cypher,
version != null ? cypher + " " + version : cypherVersionPrefix);
}
return List.of("", cypherVersionPrefix); // No prefix was found
}

/**
* A helper function prefix a query with a Cypher Version, it will not check if the query is already prefixed or not.
* To do that call: prefixQueryWithCheck
* @param procedureCallContext the injectable context from the procedure framework
* @return The prefixed query
*/
public static String prefixQuery(ProcedureCallContext procedureCallContext, String query) {
return procedureCallContext.calledwithQueryLanguage().equals(QueryLanguage.CYPHER_5)
? "CYPHER 5 "
: "CYPHER 25 " + query;
}

public static String prefixQuery(String cypherVersion, String query) {
return "CYPHER " + cypherVersion + " " + query;
}

public static String slottedRuntime(String cypherIterate, String cypherVersion) {
if (RUNTIME_PATTERN.matcher(cypherIterate).find()) {
return cypherIterate;
}

return prependQueryOption(cypherIterate, CYPHER_RUNTIME_SLOTTED, cypherVersion);
}

public enum Planner {
DEFAULT,
COST,
IDP,
DP
}

public static String applyPlanner(String query, Planner planner, String cypherVersion) {
if (planner.equals(Planner.DEFAULT)) {
return Util.prefixQueryWithCheck(cypherVersion, query);
}
Matcher matcher = PLANNER_PATTERN.matcher(query);
String cypherPlanner = String.format(" planner=%s ", planner.name().toLowerCase());
if (matcher.find()) {
return Util.prefixQueryWithCheck(cypherVersion, matcher.replaceFirst(cypherPlanner));
}
return prependQueryOption(query, cypherPlanner, cypherVersion);
}

private static String prependQueryOption(String query, String cypherOption, String cypherVersion) {
List<String> cypherPrefix = Util.extractCypherPrefix(query, cypherVersion);
if (Objects.equals(cypherPrefix.getFirst(), "")) {
return cypherPrefix.get(1) + cypherOption + query;
}
return query.replaceFirst("(?i)" + cypherPrefix.getFirst(), cypherPrefix.get(1) + cypherOption);
}

// Get the current supported query language versions, if this list changes
// this function will error, on error please update!
public static List<String> getCypherVersions() {
Expand Down
8 changes: 6 additions & 2 deletions core/src/main/java/apoc/atomic/Atomic.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.neo4j.graphdb.Entity;
import org.neo4j.graphdb.NotFoundException;
import org.neo4j.graphdb.Transaction;
import org.neo4j.internal.kernel.api.procs.ProcedureCallContext;
import org.neo4j.procedure.*;

/**
Expand All @@ -45,6 +46,9 @@ public class Atomic {
@Context
public Transaction tx;

@Context
public ProcedureCallContext procedureCallContext;

/**
* increment a property's value
*/
Expand Down Expand Up @@ -281,10 +285,10 @@ public Stream<AtomicResults> update(
executionContext,
(context) -> {
oldValue[0] = entity.getProperty(property);
String statement = "WITH $container as n with n set n." + Util.sanitize(property, true) + "="
String statement = "WITH $container AS n WITH n SET n." + Util.sanitize(property, true) + "="
+ operation + ";";
Map<String, Object> properties = MapUtil.map("container", entity);
return context.tx.execute(statement, properties);
return context.tx.execute(Util.prefixQuery(procedureCallContext, statement), properties);
},
retryAttempts);

Expand Down
12 changes: 10 additions & 2 deletions core/src/main/java/apoc/coll/Coll.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import org.neo4j.graphdb.Relationship;
import org.neo4j.graphdb.Result;
import org.neo4j.graphdb.Transaction;
import org.neo4j.internal.kernel.api.procs.ProcedureCallContext;
import org.neo4j.kernel.impl.util.ValueUtils;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
Expand All @@ -69,6 +70,9 @@ public class Coll {
@Context
public Transaction tx;

@Context
public ProcedureCallContext procedureCallContext;

@UserFunction("apoc.coll.stdev")
@Description("Returns sample or population standard deviation with `isBiasCorrected` true or false respectively.")
public Number stdev(
Expand Down Expand Up @@ -177,8 +181,10 @@ public Object min(@Name(value = "values", description = "The list to find the mi
if (list == null || list.isEmpty()) return null;
if (list.size() == 1) return list.get(0);

var preparser = "CYPHER " + Util.getCypherVersionString(procedureCallContext) + " runtime=slotted ";
try (Result result = tx.execute(
"cypher runtime=slotted return reduce(res=null, x in $list | CASE WHEN res IS NULL OR x<res THEN x ELSE res END) as value",
preparser
+ "return reduce(res=null, x in $list | CASE WHEN res IS NULL OR x<res THEN x ELSE res END) as value",
Collections.singletonMap("list", list))) {
return result.next().get("value");
}
Expand All @@ -190,8 +196,10 @@ public Object min(@Name(value = "values", description = "The list to find the mi
public Object max(@Name(value = "values", description = "The list to find the maximum in.") List<Object> list) {
if (list == null || list.isEmpty()) return null;
if (list.size() == 1) return list.get(0);
var preparser = "CYPHER " + Util.getCypherVersionString(procedureCallContext) + " runtime=slotted ";
try (Result result = tx.execute(
"cypher runtime=slotted return reduce(res=null, x in $list | CASE WHEN res IS NULL OR res<x THEN x ELSE res END) as value",
preparser
+ "RETURN reduce(res=null, x in $list | CASE WHEN res IS NULL OR res<x THEN x ELSE res END) AS value",
Collections.singletonMap("list", list))) {
return result.next().get("value");
}
Expand Down
28 changes: 18 additions & 10 deletions core/src/main/java/apoc/cypher/Cypher.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.neo4j.graphdb.Result;
import org.neo4j.graphdb.Transaction;
import org.neo4j.graphdb.security.AuthorizationViolationException;
import org.neo4j.internal.kernel.api.procs.ProcedureCallContext;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
Expand All @@ -73,14 +74,17 @@ public class Cypher {
@Context
public Pools pools;

@Context
public ProcedureCallContext procedureCallContext;

@NotThreadSafe
@Procedure("apoc.cypher.run")
@Description("Runs a dynamically constructed read-only statement with the given parameters.")
public Stream<CypherStatementMapResult> run(
@Name(value = "statement", description = "The Cypher statement to run.") String statement,
@Name(value = "params", description = "The parameters for the given Cypher statement.")
Map<String, Object> params) {
return runCypherQuery(tx, statement, params);
return runCypherQuery(tx, statement, params, procedureCallContext);
}

@Procedure(name = "apoc.cypher.runMany", mode = WRITE)
Expand All @@ -107,7 +111,8 @@ private Stream<Cypher.RowResult> streamInNewTx(String cypher, Map<String, Object
// At this point you may have questions like;
// - "Why do we execute this statement in a new transaction?"
// My guess is as good as yours. This is the way of the apoc. Safe travels.
final var results = new RunManyResultSpliterator(innerTx.execute(cypher, params), stats);
final var results = new RunManyResultSpliterator(
innerTx.execute(Util.prefixQueryWithCheck(procedureCallContext, cypher), params), stats);
return StreamSupport.stream(results, false).onClose(results::close).onClose(innerTx::commit);
} catch (AuthorizationViolationException accessModeException) {
// We meet again, few people make it this far into this world!
Expand Down Expand Up @@ -188,7 +193,7 @@ public Stream<CypherStatementMapResult> doIt(
@Name(value = "statement", description = "The Cypher statement to run.") String statement,
@Name(value = "params", description = "The parameters for the given Cypher statement.")
Map<String, Object> params) {
return runCypherQuery(tx, statement, params);
return runCypherQuery(tx, statement, params, procedureCallContext);
}

@Procedure(name = "apoc.cypher.runWrite", mode = WRITE)
Expand All @@ -206,7 +211,7 @@ public Stream<CypherStatementMapResult> runSchema(
@Name(value = "statement", description = "The Cypher schema statement to run.") String statement,
@Name(value = "params", description = "The parameters for the given Cypher statement.")
Map<String, Object> params) {
return runCypherQuery(tx, statement, params);
return runCypherQuery(tx, statement, params, procedureCallContext);
}

@NotThreadSafe
Expand All @@ -231,8 +236,9 @@ public Stream<CaseMapResult> when(
if (targetQuery.isEmpty()) {
return Stream.of(new CaseMapResult(Collections.emptyMap()));
} else {
return tx.execute(withParamMapping(targetQuery, params.keySet()), params).stream()
.map(CaseMapResult::new);
String query =
Util.prefixQueryWithCheck(procedureCallContext, withParamMapping(targetQuery, params.keySet()));
return tx.execute(query, params).stream().map(CaseMapResult::new);
}
}

Expand Down Expand Up @@ -294,16 +300,18 @@ public Stream<CaseMapResult> whenCase(
String ifQuery = (String) caseItr.next();

if (condition) {
return tx.execute(withParamMapping(ifQuery, params.keySet()), params).stream()
.map(CaseMapResult::new);
String query =
Util.prefixQueryWithCheck(procedureCallContext, withParamMapping(ifQuery, params.keySet()));
return tx.execute(query, params).stream().map(CaseMapResult::new);
}
}

if (elseQuery.isEmpty()) {
return Stream.of(new CaseMapResult(Collections.emptyMap()));
} else {
return tx.execute(withParamMapping(elseQuery, params.keySet()), params).stream()
.map(CaseMapResult::new);
String query =
Util.prefixQueryWithCheck(procedureCallContext, withParamMapping(elseQuery, params.keySet()));
return tx.execute(query, params).stream().map(CaseMapResult::new);
}
}

Expand Down
7 changes: 6 additions & 1 deletion core/src/main/java/apoc/cypher/CypherFunctions.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@

import static apoc.cypher.CypherUtils.withParamMapping;

import apoc.util.Util;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.neo4j.graphdb.ResourceIterator;
import org.neo4j.graphdb.Result;
import org.neo4j.graphdb.Transaction;
import org.neo4j.internal.kernel.api.procs.ProcedureCallContext;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
Expand All @@ -40,10 +42,13 @@ public class CypherFunctions {
@Context
public Transaction tx;

@Context
public ProcedureCallContext procedureCallContext;

public Object runFirstColumn(String statement, Map<String, Object> params, boolean expectMultipleValues) {
if (params == null) params = Collections.emptyMap();
String resolvedStatement = withParamMapping(statement, params.keySet());
if (!resolvedStatement.contains(" runtime")) resolvedStatement = "cypher runtime=slotted " + resolvedStatement;
resolvedStatement = Util.slottedRuntime(resolvedStatement, Util.getCypherVersionString(procedureCallContext));
try (Result result = tx.execute(resolvedStatement, params)) {

String firstColumn = result.columns().get(0);
Expand Down
8 changes: 7 additions & 1 deletion core/src/main/java/apoc/cypher/Timeboxed.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.neo4j.graphdb.Result;
import org.neo4j.graphdb.Transaction;
import org.neo4j.graphdb.TransactionTerminatedException;
import org.neo4j.internal.kernel.api.procs.ProcedureCallContext;
import org.neo4j.kernel.api.QueryLanguage;
import org.neo4j.kernel.api.procedure.QueryLanguageScope;
import org.neo4j.logging.Log;
Expand Down Expand Up @@ -63,6 +64,9 @@ public class Timeboxed {
@Context
public TerminationGuard terminationGuard;

@Context
public ProcedureCallContext procedureCallContext;

private static final Map<String, Object> POISON = Collections.singletonMap("__magic", "POISON");

@NotThreadSafe
Expand Down Expand Up @@ -105,7 +109,9 @@ public Stream<CypherStatementMapResult> runTimeboxed(
pools.getDefaultExecutorService().submit(() -> {
try (Transaction innerTx = db.beginTx()) {
txAtomic.set(innerTx);
Result result = innerTx.execute(cypher, params == null ? Collections.EMPTY_MAP : params);
Result result = innerTx.execute(
Util.prefixQueryWithCheck(procedureCallContext, cypher),
params == null ? Collections.EMPTY_MAP : params);
while (result.hasNext()) {
if (Util.transactionIsTerminated(terminationGuard)) {
txAtomic.get().close();
Expand Down
6 changes: 5 additions & 1 deletion core/src/main/java/apoc/example/Examples.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.neo4j.graphdb.QueryStatistics;
import org.neo4j.graphdb.Result;
import org.neo4j.graphdb.Transaction;
import org.neo4j.internal.kernel.api.procs.ProcedureCallContext;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
Expand All @@ -34,6 +35,9 @@ public class Examples {
@Context
public Transaction tx;

@Context
public ProcedureCallContext procedureCallContext;

public static class ExamplesProgressInfo {
@Description("The name of the file containing the movies example.")
public final String file;
Expand Down Expand Up @@ -89,7 +93,7 @@ public ExamplesProgressInfo(long nodes, long relationships, long properties, lon
public Stream<ExamplesProgressInfo> movies() {
long start = System.currentTimeMillis();
String file = "movies.cypher";
Result result = tx.execute(Util.readResourceFile(file));
Result result = tx.execute(Util.prefixQuery(procedureCallContext, Util.readResourceFile(file)));
QueryStatistics stats = result.getQueryStatistics();
ExamplesProgressInfo progress = new ExamplesProgressInfo(
stats.getNodesCreated(),
Expand Down
Loading

0 comments on commit a1736e0

Please sign in to comment.