Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use scalars for temporal types #235

Merged
merged 1 commit into from
Jun 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions core/src/main/kotlin/org/neo4j/graphql/GraphQLExtensions.kt
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@ fun GraphQLType.requiredName(): String = requireNotNull(name()) { "name is requi

fun GraphQLType.isList() = this is GraphQLList || (this is GraphQLNonNull && this.wrappedType is GraphQLList)
fun GraphQLType.isNeo4jType() = this.innerName().startsWith("_Neo4j")
fun GraphQLType.isNeo4jTemporalType() = NEO4j_TEMPORAL_TYPES.contains(this.innerName())

fun GraphQLType.isNeo4jSpatialType() = this.innerName().startsWith("_Neo4jPoint")
fun TypeDefinition<*>.isNeo4jSpatialType() = this.name.startsWith("_Neo4jPoint")

fun GraphQLFieldDefinition.isNeo4jType(): Boolean = this.type.isNeo4jType()
fun GraphQLFieldDefinition.isNeo4jTemporalType(): Boolean = this.type.isNeo4jTemporalType()

fun GraphQLFieldDefinition.isRelationship() = !type.isNeo4jType() && this.type.inner().let { it is GraphQLFieldsContainer }

Expand Down Expand Up @@ -184,7 +185,7 @@ fun Value<*>.toJavaValue(): Any? = when (this) {

fun GraphQLFieldDefinition.isID() = this.type.inner() == Scalars.GraphQLID
fun GraphQLFieldDefinition.isNativeId() = this.name == ProjectionBase.NATIVE_ID
fun GraphQLFieldDefinition.isIgnored() = getDirective(DirectiveConstants.IGNORE) != null
fun GraphQLFieldDefinition.isIgnored() = getDirective(DirectiveConstants.IGNORE) != null
fun FieldDefinition.isIgnored(): Boolean = hasDirective(DirectiveConstants.IGNORE)

fun GraphQLFieldsContainer.getIdField() = this.getRelevantFieldDefinitions().find { it.isID() }
Expand Down
43 changes: 36 additions & 7 deletions core/src/main/kotlin/org/neo4j/graphql/Neo4jTypes.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ data class TypeDefinition(
val inputDefinition: String = typeDefinition + "Input"
)

class Neo4jTemporalConverter(name: String) : Neo4jSimpleConverter(name) {
override fun projectField(variable: SymbolicName, field: Field, name: String): Any {
return Cypher.call("toString").withArgs(variable.property(field.name)).asFunction()
}

override fun createCondition(property: Property, parameter: Parameter<Any>, conditionCreator: (Expression, Expression) -> Condition): Condition {
return conditionCreator(property, toExpression(parameter))
}
}

class Neo4jTimeConverter(name: String) : Neo4jConverter(name) {

override fun createCondition(
Expand Down Expand Up @@ -63,23 +73,30 @@ class Neo4jPointConverter(name: String) : Neo4jConverter(name) {
}

open class Neo4jConverter(
val name: String,
name: String,
val prefixedName: String = "_Neo4j$name",
val typeDefinition: TypeDefinition = TypeDefinition(name, prefixedName)
) {
) : Neo4jSimpleConverter(name) {
}

open class Neo4jSimpleConverter(val name: String) {
protected fun toExpression(parameter: Expression): Expression {
return Cypher.call(name.toLowerCase()).withArgs(parameter).asFunction()
}

open fun createCondition(
property: Property,
parameter: Parameter<Any>,
conditionCreator: (Expression, Expression) -> Condition
): Condition = conditionCreator(property, parameter)

open fun createCondition(
objectField: ObjectField,
field: GraphQLFieldDefinition,
parameter: Parameter<Any>,
conditionCreator: (Expression, Expression) -> Condition,
propertyContainer: PropertyContainer
): Condition = conditionCreator(propertyContainer.property(field.name, objectField.name), parameter)

): Condition = createCondition(propertyContainer.property(field.name, objectField.name), parameter, conditionCreator)

open fun projectField(variable: SymbolicName, field: Field, name: String): Any = variable.property(field.name, name)

Expand All @@ -89,10 +106,10 @@ open class Neo4jConverter(
}
}

fun getNeo4jTypeConverter(field: GraphQLFieldDefinition): Neo4jConverter = getNeo4jTypeConverter(field.type.innerName())
fun getNeo4jTypeConverter(field: GraphQLFieldDefinition): Neo4jSimpleConverter = getNeo4jTypeConverter(field.type.innerName())

fun getNeo4jTypeConverter(name: String): Neo4jConverter =
neo4jConverter[name] ?: throw RuntimeException("Type $name not found")
private fun getNeo4jTypeConverter(name: String): Neo4jSimpleConverter =
neo4jConverter[name] ?: neo4jScalarConverter[name] ?: throw RuntimeException("Type $name not found")

private val neo4jConverter = listOf(
Neo4jTimeConverter("LocalTime"),
Expand All @@ -105,4 +122,16 @@ private val neo4jConverter = listOf(
.map { it.prefixedName to it }
.toMap()

private val neo4jScalarConverter = listOf(
Neo4jTemporalConverter("LocalTime"),
Neo4jTemporalConverter("Date"),
Neo4jTemporalConverter("DateTime"),
Neo4jTemporalConverter("Time"),
Neo4jTemporalConverter("LocalDateTime")
)
.map { it.name to it }
.toMap()

val NEO4j_TEMPORAL_TYPES = neo4jScalarConverter.keys

val neo4jTypeDefinitions = neo4jConverter.values.map { it.typeDefinition }
102 changes: 50 additions & 52 deletions core/src/main/kotlin/org/neo4j/graphql/Predicates.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,41 +14,53 @@ typealias CypherDSL = org.neo4j.cypherdsl.core.Cypher

enum class FieldOperator(
val suffix: String,
val op: String,
private val conditionCreator: (Expression, Expression) -> Condition,
val not: Boolean = false,
val requireParam: Boolean = true,
val distance: Boolean = false
val distance: Boolean = false,
val list: Boolean = false
) {
EQ("", "=", { lhs, rhs -> lhs.isEqualTo(rhs) }),
IS_NULL("", "", { lhs, _ -> lhs.isNull }, requireParam = false),
IS_NOT_NULL("_not", "", { lhs, _ -> lhs.isNotNull }, true, requireParam = false),
NEQ("_not", "=", { lhs, rhs -> lhs.isEqualTo(rhs).not() }, true),
GTE("_gte", ">=", { lhs, rhs -> lhs.gte(rhs) }),
GT("_gt", ">", { lhs, rhs -> lhs.gt(rhs) }),
LTE("_lte", "<=", { lhs, rhs -> lhs.lte(rhs) }),
LT("_lt", "<", { lhs, rhs -> lhs.lt(rhs) }),

NIN("_not_in", "IN", { lhs, rhs -> lhs.`in`(rhs).not() }, true),
IN("_in", "IN", { lhs, rhs -> lhs.`in`(rhs) }),
NC("_not_contains", "CONTAINS", { lhs, rhs -> lhs.contains(rhs).not() }, true),
NSW("_not_starts_with", "STARTS WITH", { lhs, rhs -> lhs.startsWith(rhs).not() }, true),
NEW("_not_ends_with", "ENDS WITH", { lhs, rhs -> lhs.endsWith(rhs).not() }, true),
C("_contains", "CONTAINS", { lhs, rhs -> lhs.contains(rhs) }),
SW("_starts_with", "STARTS WITH", { lhs, rhs -> lhs.startsWith(rhs) }),
EW("_ends_with", "ENDS WITH", { lhs, rhs -> lhs.endsWith(rhs) }),
MATCHES("_matches", "=~", { lhs, rhs -> lhs.matches(rhs) }),


DISTANCE(NEO4j_POINT_DISTANCE_FILTER_SUFFIX, "=", { lhs, rhs -> lhs.isEqualTo(rhs) }, distance = true),
DISTANCE_LT(NEO4j_POINT_DISTANCE_FILTER_SUFFIX + "_lt", "<", { lhs, rhs -> lhs.lt(rhs) }, distance = true),
DISTANCE_LTE(NEO4j_POINT_DISTANCE_FILTER_SUFFIX + "_lte", "<=", { lhs, rhs -> lhs.lte(rhs) }, distance = true),
DISTANCE_GT(NEO4j_POINT_DISTANCE_FILTER_SUFFIX + "_gt", ">", { lhs, rhs -> lhs.gt(rhs) }, distance = true),
DISTANCE_GTE(NEO4j_POINT_DISTANCE_FILTER_SUFFIX + "_gte", ">=", { lhs, rhs -> lhs.gte(rhs) }, distance = true);

val list = op == "IN"

fun resolveCondition(variablePrefix: String, queriedField: String, propertyContainer: PropertyContainer, field: GraphQLFieldDefinition?, value: Any, suffix: String? = null): List<Condition> {
EQ("", { lhs, rhs -> lhs.isEqualTo(rhs) }),
IS_NULL("", { lhs, _ -> lhs.isNull }, requireParam = false),
IS_NOT_NULL("_not", { lhs, _ -> lhs.isNotNull }, true, requireParam = false),
NEQ("_not", { lhs, rhs -> lhs.isEqualTo(rhs).not() }, not = true),
GTE("_gte", { lhs, rhs -> lhs.gte(rhs) }),
GT("_gt", { lhs, rhs -> lhs.gt(rhs) }),
LTE("_lte", { lhs, rhs -> lhs.lte(rhs) }),
LT("_lt", { lhs, rhs -> lhs.lt(rhs) }),

NIN("_not_in", { lhs, rhs -> lhs.`in`(rhs).not() }, not = true, list = true),
IN("_in", { lhs, rhs -> lhs.`in`(rhs) }, list = true),
NC("_not_contains", { lhs, rhs -> lhs.contains(rhs).not() }, not = true),
NSW("_not_starts_with", { lhs, rhs -> lhs.startsWith(rhs).not() }, not = true),
NEW("_not_ends_with", { lhs, rhs -> lhs.endsWith(rhs).not() }, not = true),
C("_contains", { lhs, rhs -> lhs.contains(rhs) }),
SW("_starts_with", { lhs, rhs -> lhs.startsWith(rhs) }),
EW("_ends_with", { lhs, rhs -> lhs.endsWith(rhs) }),
MATCHES("_matches", { lhs, rhs -> lhs.matches(rhs) }),


DISTANCE(NEO4j_POINT_DISTANCE_FILTER_SUFFIX, { lhs, rhs -> lhs.isEqualTo(rhs) }, distance = true),
DISTANCE_LT(NEO4j_POINT_DISTANCE_FILTER_SUFFIX + "_lt", { lhs, rhs -> lhs.lt(rhs) }, distance = true),
DISTANCE_LTE(NEO4j_POINT_DISTANCE_FILTER_SUFFIX + "_lte", { lhs, rhs -> lhs.lte(rhs) }, distance = true),
DISTANCE_GT(NEO4j_POINT_DISTANCE_FILTER_SUFFIX + "_gt", { lhs, rhs -> lhs.gt(rhs) }, distance = true),
DISTANCE_GTE(NEO4j_POINT_DISTANCE_FILTER_SUFFIX + "_gte", { lhs, rhs -> lhs.gte(rhs) }, distance = true);

fun resolveCondition(
variablePrefix: String,
queriedField: String,
propertyContainer: PropertyContainer,
field: GraphQLFieldDefinition?,
value: Any,
schemaConfig: SchemaConfig,
suffix: String? = null
): List<Condition> {
if (schemaConfig.useTemporalScalars && field?.type?.isNeo4jTemporalType() == true) {
val neo4jTypeConverter = getNeo4jTypeConverter(field)
val parameter = queryParameter(value, variablePrefix, queriedField, null, suffix)
.withValue(value.toJavaValue())
return listOf(neo4jTypeConverter.createCondition(propertyContainer.property(field.name), parameter, conditionCreator))
}
return if (field?.type?.isNeo4jType() == true && value is ObjectValue) {
resolveNeo4jTypeConditions(variablePrefix, queriedField, propertyContainer, field, value, suffix)
} else if (field?.isNativeId() == true) {
Expand Down Expand Up @@ -96,20 +108,6 @@ enum class FieldOperator(

companion object {

fun resolve(queriedField: String, field: GraphQLFieldDefinition, value: Any?): FieldOperator? {
val fieldName = field.name
if (value == null) {
return listOf(IS_NULL, IS_NOT_NULL).find { queriedField == fieldName + it.suffix }
}
val ops = enumValues<FieldOperator>().filterNot { it == IS_NULL || it == IS_NOT_NULL }
return ops.find { queriedField == fieldName + it.suffix }
?: if (field.type.isNeo4jSpatialType()) {
ops.find { queriedField == fieldName + NEO4j_POINT_DISTANCE_FILTER_SUFFIX + it.suffix }
} else {
null
}
}

fun forType(type: TypeDefinition<*>, isNeo4jType: Boolean): List<FieldOperator> =
when {
type.name == TypeBoolean.name -> listOf(EQ, NEQ)
Expand All @@ -128,17 +126,17 @@ enum class FieldOperator(
fun fieldName(fieldName: String) = fieldName + suffix
}

enum class RelationOperator(val suffix: String, val op: String) {
SOME("_some", "ANY"),
enum class RelationOperator(val suffix: String) {
SOME("_some"),

EVERY("_every", "ALL"),
EVERY("_every"),

SINGLE("_single", "SINGLE"),
NONE("_none", "NONE"),
SINGLE("_single"),
NONE("_none"),

// `eq` if queried with an object, `not exists` if queried with null
EQ_OR_NOT_EXISTS("", ""),
NOT("_not", "");
EQ_OR_NOT_EXISTS(""),
NOT("_not");

fun fieldName(fieldName: String) = fieldName + suffix

Expand Down
5 changes: 5 additions & 0 deletions core/src/main/kotlin/org/neo4j/graphql/SchemaConfig.kt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ data class SchemaConfig @JvmOverloads constructor(
* additionally the separated filter arguments will no longer be generated.
*/
val useWhereFilter: Boolean = false,

/**
* if enabled the `Date`, `Time`, `LocalTime`, `DateTime` and `LocalDateTime` are used as scalars
*/
val useTemporalScalars: Boolean = false,
) {
data class CRUDConfig(val enabled: Boolean = true, val exclude: List<String> = emptyList())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ abstract class BaseDataFetcherForContainer(schemaConfig: SchemaConfig) : BaseDat
val dynamicPrefix = field.dynamicPrefix()
propertyFields[field.name] = when {
dynamicPrefix != null -> dynamicPrefixCallback(field, dynamicPrefix)
field.isNeo4jType() -> neo4jTypeCallback(field)
field.isNeo4jType() || (schemaConfig.useTemporalScalars && field.isNeo4jTemporalType()) -> neo4jTypeCallback(field)
else -> defaultCallback(field)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class OptimizedFilterHandler(val type: GraphQLFieldsContainer, schemaConfig: Sch
* @param value the value passed to the graphQL field
* @param parentPassThroughWiths all the nodes, required to be passed through via WITH
*/
class NestingLevelHandler(
inner class NestingLevelHandler(
private val parsedQuery: ParsedQuery,
private val useDistinct: Boolean,
private val current: PropertyContainer,
Expand Down Expand Up @@ -113,7 +113,7 @@ class OptimizedFilterHandler(val type: GraphQLFieldsContainer, schemaConfig: Sch
}

// WHERE MATCH all predicates for current
val condition = parsedQuery.getFieldConditions(current, variablePrefix, "")
val condition = parsedQuery.getFieldConditions(current, variablePrefix, "", schemaConfig)
val matchQueryWithWhere = matchQueryWithoutWhere.where(condition)

return if (additionalConditions != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ open class ProjectionBase(
type: GraphQLFieldsContainer,
variables: Map<String, Any>
): Condition {
var result = parsedQuery.getFieldConditions(propertyContainer, variablePrefix, variableSuffix)
var result = parsedQuery.getFieldConditions(propertyContainer, variablePrefix, variableSuffix, schemaConfig)

for (predicate in parsedQuery.relationPredicates) {
val objectField = predicate.queryField
Expand Down Expand Up @@ -269,6 +269,9 @@ open class ProjectionBase(
}

} else when {
schemaConfig.useTemporalScalars && fieldDefinition.isNeo4jTemporalType() -> {
projections += getNeo4jTypeConverter(fieldDefinition).projectField(variable, field, "")
}
isObjectField -> {
if (fieldDefinition.isNeo4jType()) {
if (propertiesToSkipDeepProjection.contains(fieldDefinition.name)) {
Expand Down
7 changes: 4 additions & 3 deletions core/src/main/kotlin/org/neo4j/graphql/parser/QueryParser.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ class ParsedQuery(
val and: List<Value<*>>? = null
) {

fun getFieldConditions(propertyContainer: PropertyContainer, variablePrefix: String, variableSuffix: String): Condition =
fun getFieldConditions(propertyContainer: PropertyContainer, variablePrefix: String, variableSuffix: String, schemaConfig: SchemaConfig): Condition =
fieldPredicates
.flatMap { it.createCondition(propertyContainer, variablePrefix, variableSuffix) }
.flatMap { it.createCondition(propertyContainer, variablePrefix, variableSuffix, schemaConfig) }
.reduceOrNull { result, condition -> result.and(condition) }
?: Conditions.noCondition()
}
Expand All @@ -43,13 +43,14 @@ class FieldPredicate(
index: Int
) : Predicate<FieldOperator>(op, queryField, normalizeName(fieldDefinition.name, op.suffix.toCamelCase()), index) {

fun createCondition(propertyContainer: PropertyContainer, variablePrefix: String, variableSuffix: String) =
fun createCondition(propertyContainer: PropertyContainer, variablePrefix: String, variableSuffix: String, schemaConfig: SchemaConfig) =
op.resolveCondition(
variablePrefix,
normalizedName,
propertyContainer,
fieldDefinition,
queryField.value,
schemaConfig,
variableSuffix
)

Expand Down
6 changes: 6 additions & 0 deletions core/src/main/resources/neo4j_types.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,9 @@ type _Neo4jPoint{
# * `9157`: represents CRS `cartesian-3d`
srid: Int
}

scalar Date
scalar Time
scalar LocalTime
scalar DateTime
scalar LocalDateTime
Original file line number Diff line number Diff line change
@@ -1,24 +1,20 @@
package org.neo4j.graphql.utils

import graphql.language.InterfaceTypeDefinition
import graphql.language.ScalarTypeDefinition
import graphql.schema.GraphQLScalarType
import graphql.schema.GraphQLSchema
import graphql.schema.GraphQLType
import graphql.schema.diff.DiffSet
import graphql.schema.diff.SchemaDiff
import graphql.schema.diff.reporting.CapturingReporter
import graphql.schema.idl.RuntimeWiring
import graphql.schema.idl.SchemaGenerator
import graphql.schema.idl.SchemaParser
import graphql.schema.idl.SchemaPrinter
import graphql.schema.idl.*
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Assumptions
import org.junit.jupiter.api.DynamicNode
import org.junit.jupiter.api.DynamicTest
import org.neo4j.graphql.DynamicProperties
import org.neo4j.graphql.SchemaBuilder
import org.neo4j.graphql.SchemaConfig
import org.neo4j.graphql.requiredName
import org.neo4j.graphql.*
import org.opentest4j.AssertionFailedError
import java.util.*
import java.util.regex.Pattern
Expand Down Expand Up @@ -51,9 +47,18 @@ class GraphQLSchemaTestSuite(fileName: String) : AsciiDocTestSuite(
reg
.getTypes(InterfaceTypeDefinition::class.java)
.forEach { typeDefinition -> runtimeWiring.type(typeDefinition.name) { it.typeResolver { null } } }
expectedSchema = schemaGenerator.makeExecutableSchema(reg, runtimeWiring
.scalar(DynamicProperties.INSTANCE)
.build())
reg
.scalars()
.filterNot { entry -> ScalarInfo.GRAPHQL_SPECIFICATION_SCALARS_DEFINITIONS.containsKey(entry.key) }
.forEach { (name, definition) ->
runtimeWiring.scalar(GraphQLScalarType.newScalar()
.name(name)
.definition(definition)
.coercing(NoOpCoercing)
.build()
)
}
expectedSchema = schemaGenerator.makeExecutableSchema(reg, runtimeWiring.build())

diff(expectedSchema, augmentedSchema)
diff(augmentedSchema, expectedSchema)
Expand Down
Loading