Skip to content

Commit

Permalink
changes review
Browse files Browse the repository at this point in the history
  • Loading branch information
vga91 committed Feb 5, 2025
1 parent c4bf9c1 commit ba2e72f
Show file tree
Hide file tree
Showing 12 changed files with 315 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -526,25 +526,25 @@ CALL apoc.load.jdbcUpdate('jdbc:derby:derbyDB','UPDATE PERSON SET NAME = ? WHERE
CALL apoc.load.jdbc('jdbc:derby:derbyDB', 'PERSON',[],{credentials:{user:'apoc',password:'Ap0c!#Db'}})
----

== Load JDBC - Analytics
== JDBC Analytics

You can use the `apoc.jdbc.analytics(<cypherQuery>, <jdbcUrl>, <sqlQueryOverTemporaryTable>, <paramsList>, $config)`
to create a temporary table starting from a Cypher query
and delegate complex analytics to the database defined JDBC URL.
and delegate complex analytics to the database defined by JDBC URL.

Please note that the returning SQL column names have to be consistent with the one provided by the Cypher query.

In addition to the configurations of the `apoc.load.jdbc` procedure, the `apoc.jdbc.analytics` provides the following ones:

[cols="1m,2,1"]
[opts=header, cols="1m,2,1"]
|===
| name | description | default value
| tableName | the temporary table name | neo4j_tmp_table
| provider | the SQL provider, to handle data type based on it, possible values are "POSTGRES", "MYSQL" and "DEFAULT" | "DEFAULT"
| provider | the SQL provider, to handle data type based on it, possible values are "POSTGRES", "MYSQL" and "DUCKDB" | "DUCKDB"
|===


It is possible to specify a provider in the config parameters.
The default value is "DUCKDB".

You can reproduce the following queries using some nodes:

Expand All @@ -568,7 +568,6 @@ Fields of the SQL query should be consistent with the Cypher query.
For detailed information go to https://duckdb.org/docs/sql/functions/window_functions.html#rank



[source,cypher]
----
CALL apoc.jdbc.analytics(
Expand Down Expand Up @@ -626,6 +625,31 @@ CALL apoc.jdbc.analytics(
)
----

In DuckDB, we can also use an in-memory instance using the `jdbc:duckdb:` URL:

[source,cypher]
----
CALL apoc.jdbc.analytics(
"MATCH (n:City) RETURN n.country AS country, n.name AS name, n.year AS year, n.population AS population",
'jdbc:duckdb:',
"PIVOT 'neo4j_tmp_table'
ON year
USING sum(population)
ORDER by name"
)
----

[source,cypher]
----
CALL apoc.jdbc.analytics(
"MATCH (n:City) RETURN n.country AS country, n.name AS name, n.year AS year, n.population AS population",
$url,
"PIVOT 'neo4j_tmp_table'
ON year
USING sum(population)
ORDER by name"
)
----

=== MySQL

Expand Down
23 changes: 18 additions & 5 deletions extended-it/src/test/java/apoc/load/MySQLJdbcTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package apoc.load;

import apoc.load.jdbc.AbstractJdbcTest;
import apoc.load.jdbc.Analytics;
import apoc.load.jdbc.Jdbc;
import apoc.util.s3.MySQLContainerExtension;
import apoc.util.TestUtil;
import apoc.util.Util;
Expand All @@ -20,7 +23,7 @@
import java.time.ZonedDateTime;
import java.util.Map;

import static apoc.load.Analytics.PROVIDER_CONF_KEY;
import static apoc.load.jdbc.Analytics.PROVIDER_CONF_KEY;
import static apoc.util.MapUtil.map;
import static apoc.util.TestUtil.testCall;
import static apoc.util.TestUtil.testResult;
Expand Down Expand Up @@ -68,21 +71,26 @@ public void testIssue3496() {

@Test
public void testLoadJdbcAnalytics() {
String cypher = "MATCH (n:City) RETURN n.country AS country, n.name AS name, n.year AS year, n.population AS population";

String sql = """
SELECT
country,
name,
year,
date,
time,
datetime,
localtime,
localdatetime,
duration,
population,
RANK() OVER (PARTITION BY country ORDER BY year DESC) AS 'rank'
FROM %s
ORDER BY country, name;
""".formatted(Analytics.TABLE_NAME_DEFAULT_CONF_KEY);
testResult(db, "CALL apoc.jdbc.analytics($queryCypher, $url, $sql, [], $config)",
map(
"queryCypher", cypher,
"queryCypher", MATCH_SQL_ANALYTICS,
"sql", sql,
"url", mysql.getJdbcUrl(),
"config", map(PROVIDER_CONF_KEY, Analytics.Provider.MYSQL.name())
Expand All @@ -92,13 +100,18 @@ public void testLoadJdbcAnalytics() {

@Test
public void testLoadJdbcAnalyticsWindow() {
String cypher = "MATCH (n:City) RETURN n.country AS country, n.name AS name, n.year AS year, n.population AS population";

String sql = """
SELECT
country,
name,
year,
date,
time,
datetime,
localtime,
localdatetime,
duration,
population,
ROW_NUMBER() OVER (PARTITION BY country ORDER BY year DESC) AS 'rank'
FROM %s
Expand All @@ -107,7 +120,7 @@ public void testLoadJdbcAnalyticsWindow() {

testResult(db, "CALL apoc.jdbc.analytics($queryCypher, $url, $sql, [], $config)",
map(
"queryCypher", cypher,
"queryCypher", MATCH_SQL_ANALYTICS,
"sql", sql,
"url", mysql.getJdbcUrl(),
"config", map(PROVIDER_CONF_KEY, Analytics.Provider.MYSQL.name())
Expand Down
45 changes: 29 additions & 16 deletions extended-it/src/test/java/apoc/load/PostgresJdbcTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package apoc.load;

import apoc.load.jdbc.AbstractJdbcTest;
import apoc.load.jdbc.Analytics;
import apoc.load.jdbc.Jdbc;
import apoc.periodic.Periodic;
import apoc.text.Strings;
import apoc.util.TestUtil;
Expand All @@ -20,7 +23,7 @@
import java.util.Map;
import java.util.concurrent.TimeUnit;

import static apoc.load.Analytics.PROVIDER_CONF_KEY;
import static apoc.load.jdbc.Analytics.PROVIDER_CONF_KEY;
import static apoc.util.MapUtil.map;
import static apoc.util.TestUtil.testCall;
import static apoc.util.TestUtil.testResult;
Expand All @@ -44,7 +47,7 @@ public class PostgresJdbcTest extends AbstractJdbcTest {
public static void setUp() throws Exception {
postgress = new PostgreSQLContainer().withInitScript("init_postgres.sql");
postgress.start();
TestUtil.registerProcedure(db,Jdbc.class, Periodic.class, Strings.class, Analytics.class);
TestUtil.registerProcedure(db, Jdbc.class, Periodic.class, Strings.class, Analytics.class);
db.executeTransactionally("CALL apoc.load.driver('org.postgresql.Driver')");

String movies = Util.readResourceFile(ANALYTICS_CYPHER_FILE);
Expand Down Expand Up @@ -145,22 +148,27 @@ public void testIssue4141PeriodicIterateWithJdbc() throws Exception {

@Test
public void testLoadJdbcAnalytics() {
String cypher = "MATCH (n:City) RETURN n.country AS country, n.name AS name, n.year AS year, n.population AS population";

String sql = """
SELECT
country,
name,
year,
population,
RANK() OVER (PARTITION BY country ORDER BY year DESC) rank
FROM %s
ORDER BY rank, country, name;
""".formatted(Analytics.TABLE_NAME_DEFAULT_CONF_KEY);
SELECT
country,
name,
year,
date,
time,
datetime,
localtime,
localdatetime,
duration,
population,
RANK() OVER (PARTITION BY country ORDER BY year DESC) rank
FROM %s
ORDER BY rank, country, name;
""".formatted(Analytics.TABLE_NAME_DEFAULT_CONF_KEY);

testResult(db, "CALL apoc.jdbc.analytics($queryCypher, $url, $sql, [], $config)",
map(
"queryCypher", cypher,
"queryCypher", MATCH_SQL_ANALYTICS,
"sql", sql,
"url", getUrl(postgress),
"config", map(PROVIDER_CONF_KEY, Analytics.Provider.POSTGRES.name())
Expand All @@ -170,13 +178,18 @@ public void testLoadJdbcAnalytics() {

@Test
public void testLoadJdbcAnalyticsWindow() {
String cypher = "MATCH (n:City) RETURN n.country AS country, n.name AS name, n.year AS year, n.population AS population";

String sql = """
SELECT
country,
name,
year,
date,
time,
datetime,
localtime,
localdatetime,
duration,
population,
ROW_NUMBER() OVER (PARTITION BY country ORDER BY year DESC) rank
FROM %s
Expand All @@ -185,10 +198,10 @@ public void testLoadJdbcAnalyticsWindow() {

testResult(db, "CALL apoc.jdbc.analytics($queryCypher, $url, $sql, [], $config)",
map(
"queryCypher", cypher,
"queryCypher", MATCH_SQL_ANALYTICS,
"sql", sql,
"url", getUrl(postgress),
"config", map(PROVIDER_CONF_KEY, Analytics.Provider.MYSQL.name())
"config", map(PROVIDER_CONF_KEY, Analytics.Provider.POSTGRES.name())
),
r -> commonAnalyticsAssertions(r, 2L, 4L, 6L));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package apoc.load;
package apoc.load.jdbc;

import apoc.Extended;
import apoc.load.util.LoadJdbcConfig;
Expand All @@ -13,28 +13,34 @@
import org.neo4j.procedure.Procedure;

import java.sql.Connection;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.time.*;
import java.time.format.DateTimeFormatter;
import java.util.*;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static apoc.load.Jdbc.executeQuery;
import static apoc.load.Jdbc.executeUpdate;
import static apoc.load.util.JdbcUtil.getConnection;
import static apoc.load.util.JdbcUtil.getUrlOrKey;
import static apoc.load.jdbc.Jdbc.executeQuery;
import static apoc.load.jdbc.Jdbc.executeUpdate;
import static apoc.load.util.JdbcUtil.*;

@Extended
public class Analytics {
public static final String PROVIDER_CONF_KEY = "provider";
public static final String TABLE_NAME_CONF_KEY = "tableName";
public static final String TABLE_NAME_DEFAULT_CONF_KEY = "neo4j_tmp_table";

enum Provider {
DEFAULT,
POSTGRES,
MYSQL
public enum Provider {
DUCKDB(DUCK_TYPE_MAP, "\"%s\" %s"),
POSTGRES(POSTGRES_TYPE_MAP, "\"%s\" %s"),
MYSQL(MYSQL_TYPE_MAP, "`%s` %s");

public final Map<Class<?>, String> typeMap;
public final String tableTypeTemplate;
Provider(Map<Class<?>, String> typeMap, String tableTypeTemplate) {
this.typeMap = typeMap;
this.tableTypeTemplate = tableTypeTemplate;
}
}

@Context
Expand All @@ -55,7 +61,7 @@ public Stream<RowResult> aggregate(
@Name(value = "params", defaultValue = "[]") List<Object> params,
@Name(value = "config",defaultValue = "{}") Map<String, Object> config) {
AtomicReference<String> createTable = new AtomicReference<>();
final Provider provider = Provider.valueOf((String) config.getOrDefault(PROVIDER_CONF_KEY, Provider.DEFAULT.name()));
final Provider provider = Provider.valueOf((String) config.getOrDefault(PROVIDER_CONF_KEY, Provider.DUCKDB.name()));
final String tableName = (String) config.getOrDefault(TABLE_NAME_CONF_KEY, TABLE_NAME_DEFAULT_CONF_KEY);

AtomicReference<String> columns = new AtomicReference<>();
Expand All @@ -74,7 +80,6 @@ public Stream<RowResult> aggregate(

// convert Neo4j row result to SQL row
final String row = getStreamSortedByKey(map)
// .map(e -> addFieldToTempTable(e, sqlTypesForTempTable, provider))
.map(Map.Entry::getValue)
.map(Analytics::formatSqlValue)
.collect(Collectors.joining(","));
Expand Down Expand Up @@ -130,7 +135,10 @@ public Stream<RowResult> aggregate(
*/
private String getTempTableClause(Map<String, Object> map, Provider provider, String tableName) {
String sqlFields = getStreamSortedByKey(map)
.map(e -> e.getKey() + " " + mapSqlType(provider, e.getValue()))
.map(e -> {
String type = mapSqlType(provider, e.getValue());
return provider.tableTypeTemplate.formatted(e.getKey(), type);
})
.collect(Collectors.joining(","));

return "CREATE TEMPORARY TABLE %s (%s)".formatted(tableName, sqlFields);
Expand All @@ -142,23 +150,22 @@ private static Stream<Map.Entry<String, Object>> getStreamSortedByKey(Map<String
.sorted(Map.Entry.comparingByKey());
}


private static String formatSqlValue(Object x) {
final String stringValue = x.toString();
if (x instanceof Number) return stringValue;
private static String formatSqlValue(Object val) {
String stringValue = val.toString();
if (val instanceof Number) {
return stringValue;
}
if (val instanceof ZonedDateTime zonedDateTime) {
stringValue = toSqlCompatibleDateTime(zonedDateTime);
}
if (val instanceof OffsetTime zonedDateTime) {
stringValue = toSqlCompatibleTimeFormat(zonedDateTime);
}
return String.format("'%s'", stringValue.replace("'", "''"));
}

private String mapSqlType(Provider provider, Object value) {
return switch (provider) {
case MYSQL, POSTGRES -> {
if (value instanceof Number) yield "INTEGER";
else yield "VARCHAR(1000)";
}
default -> {
if (value instanceof Number) yield "INTEGER";
else yield "VARCHAR";
}
};
Class<?> clazz = value.getClass();
return provider.typeMap.getOrDefault(clazz, VARCHAR_TYPE);
}
}
Loading

0 comments on commit ba2e72f

Please sign in to comment.