From 00611ba39dcd2a23305226fe0a84eaeb47dde0a4 Mon Sep 17 00:00:00 2001 From: yliuuuu <107505258+yliuuuu@users.noreply.github.com> Date: Tue, 9 Apr 2024 15:17:42 -0700 Subject: [PATCH] PartiQL Eval - Planner Mode (#1385) --- .../kotlin/partiql.conventions.gradle.kts | 1 + .../operator/rel/RelJoinNestedLoop.kt | 2 +- .../eval/internal/operator/rel/RelOffset.kt | 3 +- .../src/main/resources/partiql_plan.ion | 6 + .../org/partiql/planner/PartiQLPlanner.kt | 1 - .../partiql/planner/PartiQLPlannerBuilder.kt | 14 +- .../org/partiql/planner/internal/Env.kt | 17 +- .../{ => internal}/PartiQLPlannerDefault.kt | 10 +- .../partiql/planner/internal/PlannerFlag.kt | 20 + .../PlanningProblemDetails.kt} | 87 +- .../planner/internal/ProblemGenerator.kt | 88 + .../org/partiql/planner/internal/ir/Nodes.kt | 31 +- .../internal/transforms/PlanTransform.kt | 75 +- .../planner/internal/typer/PlanTyper.kt | 1606 ++++++++--------- .../main/resources/partiql_plan_internal.ion | 26 +- .../kotlin/org/partiql/planner/PlanTest.kt | 55 +- .../planner/PlannerErrorReportingTests.kt | 458 +++++ .../internal/typer/PartiQLTyperTestBase.kt | 4 +- .../planner/internal/typer/PlanTyperTest.kt | 5 +- .../internal/typer/PlanTyperTestsPorted.kt | 330 ++-- .../planner/util/PlanNodeEquivalentVisitor.kt | 16 +- .../test/resources/outputs/basics/select.sql | 34 +- 22 files changed, 1731 insertions(+), 1158 deletions(-) rename partiql-planner/src/main/kotlin/org/partiql/planner/{ => internal}/PartiQLPlannerDefault.kt (82%) create mode 100644 partiql-planner/src/main/kotlin/org/partiql/planner/internal/PlannerFlag.kt rename partiql-planner/src/main/kotlin/org/partiql/planner/{Errors.kt => internal/PlanningProblemDetails.kt} (64%) create mode 100644 partiql-planner/src/main/kotlin/org/partiql/planner/internal/ProblemGenerator.kt create mode 100644 partiql-planner/src/test/kotlin/org/partiql/planner/PlannerErrorReportingTests.kt diff --git a/buildSrc/src/main/kotlin/partiql.conventions.gradle.kts 
b/buildSrc/src/main/kotlin/partiql.conventions.gradle.kts index 9d07abba0b..b584e08693 100644 --- a/buildSrc/src/main/kotlin/partiql.conventions.gradle.kts +++ b/buildSrc/src/main/kotlin/partiql.conventions.gradle.kts @@ -50,6 +50,7 @@ java { tasks.test { useJUnitPlatform() // Enable JUnit5 jvmArgs.addAll(listOf("-Duser.language=en", "-Duser.country=US")) + jvmArgs.add("-Djunit.jupiter.execution.timeout.mode=disabled_on_debug") maxHeapSize = "4g" testLogging { events.add(TestLogEvent.FAILED) diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelJoinNestedLoop.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelJoinNestedLoop.kt index 685ada0640..ed761c10d2 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelJoinNestedLoop.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelJoinNestedLoop.kt @@ -63,7 +63,7 @@ internal abstract class RelJoinNestedLoop : RelPeeking() { toReturn = join(result.isTrue(), lhsRecord!!, rhsRecord) } // Move the pointer to the next row for the RHS - if (toReturn == null) rhsRecord = rhs.next() + if (toReturn == null) rhsRecord = if (rhs.hasNext()) rhs.next() else null } while (toReturn == null) return toReturn diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelOffset.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelOffset.kt index 2b02521cf8..f9bbef6e5c 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelOffset.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelOffset.kt @@ -33,11 +33,12 @@ internal class RelOffset( override fun hasNext(): Boolean { if (!init) { - for (record in input) { + while (input.hasNext()) { if (_seen >= _offset) { break } _seen = _seen.add(BigInteger.ONE) + input.next() } init = true } diff --git a/partiql-plan/src/main/resources/partiql_plan.ion 
b/partiql-plan/src/main/resources/partiql_plan.ion index 3f6f3fade5..2d8d9351b8 100644 --- a/partiql-plan/src/main/resources/partiql_plan.ion +++ b/partiql-plan/src/main/resources/partiql_plan.ion @@ -195,6 +195,12 @@ rex::{ err::{ message: string, + causes: list::['.rex.op'] + }, + + missing::{ + message: string, + causes: list::['.rex.op'] }, ], } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlanner.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlanner.kt index 14ecc0eb4f..9e5b729307 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlanner.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlanner.kt @@ -50,7 +50,6 @@ public interface PartiQLPlanner { public val catalogs: Map = emptyMap(), public val instant: Instant = Instant.now(), ) - public companion object { @JvmStatic diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerBuilder.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerBuilder.kt index 2b51cddfd9..7af3c275fc 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerBuilder.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerBuilder.kt @@ -1,5 +1,7 @@ package org.partiql.planner +import org.partiql.planner.internal.PartiQLPlannerDefault +import org.partiql.planner.internal.PlannerFlag import org.partiql.spi.connector.ConnectorMetadata /** @@ -7,11 +9,14 @@ import org.partiql.spi.connector.ConnectorMetadata * * Usage: * PartiQLPlanner.builder() + * .signalMode() * .addPass(myPass) * .build() */ public class PartiQLPlannerBuilder { + private val flags: MutableSet = mutableSetOf() + private val passes: MutableList = mutableListOf() /** @@ -19,7 +24,7 @@ public class PartiQLPlannerBuilder { * * @return */ - public fun build(): PartiQLPlanner = PartiQLPlannerDefault(passes) + public fun build(): PartiQLPlanner = PartiQLPlannerDefault(passes, flags) /** * Java style method for adding 
a planner pass to this planner builder. @@ -41,6 +46,13 @@ public class PartiQLPlannerBuilder { this.passes.addAll(passes) } + /** + * Java style method for setting the planner to signal mode + */ + public fun signalMode(): PartiQLPlannerBuilder = this.apply { + this.flags.add(PlannerFlag.SIGNAL_MODE) + } + /** * Java style method for assigning a Catalog name to [ConnectorMetadata]. * diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt index fcf385d995..664b321122 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt @@ -93,9 +93,22 @@ internal class Env(private val session: PartiQLPlanner.Session) { // Invoke FnResolver to determine if we made a match val variants = item.handle.entity.getVariants() val match = FnResolver.resolve(variants, args.map { it.type }) + // If Type mismatch, then we return a missingOp whose trace is all possible candidates. if (match == null) { - // unable to make a match, consider returning helpful error messages given the item.variants. 
- return null + val candidates = variants.map { fnSignature -> + rexOpCallDynamicCandidate( + fn = refFn( + item.catalog, + path = item.handle.path.steps, + signature = fnSignature + ), + coercions = emptyList() + ) + } + return ProblemGenerator.missingRex( + rexOpCallDynamic(args, candidates, false), + ProblemGenerator.incompatibleTypesForOp(args.map { it.type }, path.normalized.joinToString(".")) + ) } return when (match) { is FnMatch.Dynamic -> { diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerDefault.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PartiQLPlannerDefault.kt similarity index 82% rename from partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerDefault.kt rename to partiql-planner/src/main/kotlin/org/partiql/planner/internal/PartiQLPlannerDefault.kt index e9d8f777c6..a9815a8bdb 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerDefault.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PartiQLPlannerDefault.kt @@ -1,9 +1,10 @@ -package org.partiql.planner +package org.partiql.planner.internal import org.partiql.ast.Statement import org.partiql.ast.normalize.normalize import org.partiql.errors.ProblemCallback -import org.partiql.planner.internal.Env +import org.partiql.planner.PartiQLPlanner +import org.partiql.planner.PartiQLPlannerPass import org.partiql.planner.internal.transforms.AstToPlan import org.partiql.planner.internal.transforms.PlanTransform import org.partiql.planner.internal.typer.PlanTyper @@ -13,6 +14,7 @@ import org.partiql.planner.internal.typer.PlanTyper */ internal class PartiQLPlannerDefault( private val passes: List, + private val flags: Set ) : PartiQLPlanner { override fun plan( @@ -31,12 +33,12 @@ internal class PartiQLPlannerDefault( val root = AstToPlan.apply(ast, env) // 3. 
Resolve variables - val typer = PlanTyper(env, onProblem) + val typer = PlanTyper(env) val typed = typer.resolve(root) val internal = org.partiql.planner.internal.ir.PartiQLPlan(typed) // 4. Assert plan has been resolved — translating to public API - var plan = PlanTransform.transform(internal, onProblem) + var plan = PlanTransform(flags).transform(internal, onProblem) // 5. Apply all passes for (pass in passes) { diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PlannerFlag.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PlannerFlag.kt new file mode 100644 index 0000000000..bd91be0535 --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PlannerFlag.kt @@ -0,0 +1,20 @@ +package org.partiql.planner.internal + +internal enum class PlannerFlag { + /** + * Determine the planner behavior upon encounter an operation that always returns MISSING. + * + * If this flag is included: + * + * The problematic operation will be tracked in problem callback as a error. + * + * The result plan will turn the problematic operation into an error node. + * + * Otherwise: + * + * The problematic operation will be tracked in problem callback as a missing. + * + * The result plan will turn the problematic operation into a missing node. 
+ */ + SIGNAL_MODE +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PlanningProblemDetails.kt similarity index 64% rename from partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt rename to partiql-planner/src/main/kotlin/org/partiql/planner/internal/PlanningProblemDetails.kt index 8d388a17b7..e05ef49e11 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/Errors.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/PlanningProblemDetails.kt @@ -1,4 +1,4 @@ -package org.partiql.planner +package org.partiql.planner.internal import org.partiql.errors.ProblemDetails import org.partiql.errors.ProblemSeverity @@ -13,21 +13,45 @@ import org.partiql.types.StaticType * This information can be used to generate end-user readable error messages and is also easy to assert * equivalence in unit tests. */ -public sealed class PlanningProblemDetails( +internal open class PlanningProblemDetails( override val severity: ProblemSeverity, - public val messageFormatter: () -> String, + val messageFormatter: () -> String, ) : ProblemDetails { + companion object { + private fun quotationHint(caseSensitive: Boolean) = + if (caseSensitive) { + // Individuals that are new to SQL often try to use double quotes for string literals. + // Let's help them out a bit. + " Hint: did you intend to use single-quotes (') here? Remember that double-quotes (\") denote " + + "quoted identifiers and single-quotes denote strings." + } else { + "" + } + + private fun Identifier.sql(): String = when (this) { + is Identifier.Qualified -> this.sql() + is Identifier.Symbol -> this.sql() + } + + private fun Identifier.Qualified.sql(): String = root.sql() + "." 
+ steps.joinToString(".") { it.sql() } + + private fun Identifier.Symbol.sql(): String = when (caseSensitivity) { + Identifier.CaseSensitivity.SENSITIVE -> "\"$symbol\"" + Identifier.CaseSensitivity.INSENSITIVE -> symbol + } + } + override fun toString(): String = message override val message: String get() = messageFormatter() - public data class ParseError(val parseErrorMessage: String) : + data class ParseError(val parseErrorMessage: String) : PlanningProblemDetails(ProblemSeverity.ERROR, { parseErrorMessage }) - public data class CompileError(val errorMessage: String) : + data class CompileError(val errorMessage: String) : PlanningProblemDetails(ProblemSeverity.ERROR, { errorMessage }) - public data class UndefinedVariable(val id: BindingPath) : + data class UndefinedVariable(val id: BindingPath) : PlanningProblemDetails( ProblemSeverity.ERROR, { @@ -37,7 +61,7 @@ public sealed class PlanningProblemDetails( } ) - public data class UndefinedDmlTarget(val variableName: String, val caseSensitive: Boolean) : + data class UndefinedDmlTarget(val variableName: String, val caseSensitive: Boolean) : PlanningProblemDetails( ProblemSeverity.ERROR, { @@ -47,25 +71,25 @@ public sealed class PlanningProblemDetails( } ) - public data class VariablePreviouslyDefined(val variableName: String) : + data class VariablePreviouslyDefined(val variableName: String) : PlanningProblemDetails( ProblemSeverity.ERROR, { "The variable '$variableName' was previously defined." } ) - public data class UnimplementedFeature(val featureName: String) : + data class UnimplementedFeature(val featureName: String) : PlanningProblemDetails( ProblemSeverity.ERROR, { "The syntax at this location is valid but utilizes unimplemented PartiQL feature '$featureName'" } ) - public object InvalidDmlTarget : + object InvalidDmlTarget : PlanningProblemDetails( ProblemSeverity.ERROR, { "Expression is not a valid DML target. Hint: only table names are allowed here." 
} ) - public object InsertValueDisallowed : + object InsertValueDisallowed : PlanningProblemDetails( ProblemSeverity.ERROR, { @@ -74,7 +98,7 @@ public sealed class PlanningProblemDetails( } ) - public object InsertValuesDisallowed : + object InsertValuesDisallowed : PlanningProblemDetails( ProblemSeverity.ERROR, { @@ -83,14 +107,14 @@ public sealed class PlanningProblemDetails( } ) - public data class UnexpectedType( + data class UnexpectedType( val actualType: StaticType, val expectedTypes: Set, ) : PlanningProblemDetails(ProblemSeverity.ERROR, { "Unexpected type $actualType, expected one of ${expectedTypes.joinToString()}" }) - public data class UnknownFunction( + data class UnknownFunction( val identifier: String, val args: List, ) : PlanningProblemDetails(ProblemSeverity.ERROR, { @@ -98,12 +122,17 @@ public sealed class PlanningProblemDetails( "Unknown function `$identifier($types)" }) - public object ExpressionAlwaysReturnsNullOrMissing : PlanningProblemDetails( + data class ExpressionAlwaysReturnsMissing(val reason: String? = null) : PlanningProblemDetails( + severity = ProblemSeverity.ERROR, + messageFormatter = { "Expression always returns null or missing: caused by $reason" } + ) + + object ExpressionAlwaysReturnsNullOrMissing : PlanningProblemDetails( severity = ProblemSeverity.ERROR, messageFormatter = { "Expression always returns null or missing." } ) - public data class InvalidArgumentTypeForFunction( + data class InvalidArgumentTypeForFunction( val functionName: String, val expectedType: StaticType, val actualType: StaticType, @@ -113,7 +142,7 @@ public sealed class PlanningProblemDetails( messageFormatter = { "Invalid argument type for $functionName. 
Expected $expectedType but got $actualType" } ) - public data class IncompatibleTypesForOp( + data class IncompatibleTypesForOp( val actualTypes: List, val operator: String, ) : @@ -122,31 +151,9 @@ public sealed class PlanningProblemDetails( messageFormatter = { "${actualTypes.joinToString()} is/are incompatible data types for the '$operator' operator." } ) - public data class UnresolvedExcludeExprRoot(val root: String) : + data class UnresolvedExcludeExprRoot(val root: String) : PlanningProblemDetails( ProblemSeverity.ERROR, { "Exclude expression given an unresolvable root '$root'" } ) } - -private fun quotationHint(caseSensitive: Boolean) = - if (caseSensitive) { - // Individuals that are new to SQL often try to use double quotes for string literals. - // Let's help them out a bit. - " Hint: did you intend to use single-quotes (') here? Remember that double-quotes (\") denote " + - "quoted identifiers and single-quotes denote strings." - } else { - "" - } - -private fun Identifier.sql(): String = when (this) { - is Identifier.Qualified -> this.sql() - is Identifier.Symbol -> this.sql() -} - -private fun Identifier.Qualified.sql(): String = root.sql() + "." 
+ steps.joinToString(".") { it.sql() } - -private fun Identifier.Symbol.sql(): String = when (caseSensitivity) { - Identifier.CaseSensitivity.SENSITIVE -> "\"$symbol\"" - Identifier.CaseSensitivity.INSENSITIVE -> symbol -} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ProblemGenerator.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ProblemGenerator.kt new file mode 100644 index 0000000000..92f137a557 --- /dev/null +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ProblemGenerator.kt @@ -0,0 +1,88 @@ +package org.partiql.planner.internal + +import org.partiql.errors.Problem +import org.partiql.errors.ProblemDetails +import org.partiql.errors.ProblemLocation +import org.partiql.errors.ProblemSeverity +import org.partiql.errors.UNKNOWN_PROBLEM_LOCATION +import org.partiql.planner.internal.ir.Rex +import org.partiql.planner.internal.ir.rex +import org.partiql.planner.internal.ir.rexOpErr +import org.partiql.planner.internal.ir.rexOpMissing +import org.partiql.spi.BindingPath +import org.partiql.types.StaticType +import org.partiql.planner.internal.ir.Identifier as InternalIdentifier + +/** + * Used to report problems during planning phase. 
+ */ +internal object ProblemGenerator { + fun problem(problemLocation: ProblemLocation, problemDetails: ProblemDetails): Problem = Problem( + problemLocation, + problemDetails + ) + + fun asWarning(problem: Problem): Problem { + val details = problem.details as PlanningProblemDetails + return if (details.severity == ProblemSeverity.WARNING) problem + else Problem( + problem.sourceLocation, + PlanningProblemDetails(ProblemSeverity.WARNING, details.messageFormatter) + ) + } + fun asError(problem: Problem): Problem { + val details = problem.details as PlanningProblemDetails + return if (details.severity == ProblemSeverity.ERROR) problem + else Problem( + problem.sourceLocation, + PlanningProblemDetails(ProblemSeverity.ERROR, details.messageFormatter) + ) + } + + fun missingRex(causes: List, problem: Problem): Rex = + rex(StaticType.MISSING, rexOpMissing(problem, causes)) + + fun missingRex(causes: Rex.Op, problem: Problem): Rex = + rex(StaticType.MISSING, rexOpMissing(problem, listOf(causes))) + + fun errorRex(causes: List, problem: Problem): Rex = + rex(StaticType.ANY, rexOpErr(problem, causes)) + + fun errorRex(trace: Rex.Op, problem: Problem): Rex = + rex(StaticType.ANY, rexOpErr(problem, listOf(trace))) + + private fun InternalIdentifier.debug(): String = when (this) { + is InternalIdentifier.Qualified -> (listOf(root.debug()) + steps.map { it.debug() }).joinToString(".") + is InternalIdentifier.Symbol -> when (caseSensitivity) { + InternalIdentifier.CaseSensitivity.SENSITIVE -> "\"$symbol\"" + InternalIdentifier.CaseSensitivity.INSENSITIVE -> symbol + } + } + + fun undefinedFunction(identifier: InternalIdentifier, args: List, location: ProblemLocation = UNKNOWN_PROBLEM_LOCATION): Problem = + problem(location, PlanningProblemDetails.UnknownFunction(identifier.debug(), args)) + + fun undefinedFunction(identifier: String, args: List, location: ProblemLocation = UNKNOWN_PROBLEM_LOCATION): Problem = + problem(location, 
PlanningProblemDetails.UnknownFunction(identifier, args)) + + fun undefinedVariable(id: BindingPath, location: ProblemLocation = UNKNOWN_PROBLEM_LOCATION): Problem = + problem(location, PlanningProblemDetails.UndefinedVariable(id)) + + fun incompatibleTypesForOp(actualTypes: List, operator: String, location: ProblemLocation = UNKNOWN_PROBLEM_LOCATION): Problem = + problem(location, PlanningProblemDetails.IncompatibleTypesForOp(actualTypes, operator)) + + fun unresolvedExcludedExprRoot(root: InternalIdentifier, location: ProblemLocation = UNKNOWN_PROBLEM_LOCATION): Problem = + problem(location, PlanningProblemDetails.UnresolvedExcludeExprRoot(root.debug())) + + fun unresolvedExcludedExprRoot(root: String, location: ProblemLocation = UNKNOWN_PROBLEM_LOCATION): Problem = + problem(location, PlanningProblemDetails.UnresolvedExcludeExprRoot(root)) + + fun expressionAlwaysReturnsMissing(reason: String? = null, location: ProblemLocation = UNKNOWN_PROBLEM_LOCATION): Problem = + problem(location, PlanningProblemDetails.ExpressionAlwaysReturnsMissing(reason)) + + fun unexpectedType(actualType: StaticType, expectedTypes: Set, location: ProblemLocation = UNKNOWN_PROBLEM_LOCATION): Problem = + problem(location, PlanningProblemDetails.UnexpectedType(actualType, expectedTypes)) + + fun compilerError(message: String, location: ProblemLocation = UNKNOWN_PROBLEM_LOCATION): Problem = + problem(location, PlanningProblemDetails.CompileError(message)) +} diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt index 3b479a787e..691572a1dd 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt @@ -5,6 +5,7 @@ package org.partiql.planner.`internal`.ir +import org.partiql.errors.Problem import org.partiql.planner.internal.ir.builder.IdentifierQualifiedBuilder import 
org.partiql.planner.internal.ir.builder.IdentifierSymbolBuilder import org.partiql.planner.internal.ir.builder.PartiQlPlanBuilder @@ -274,6 +275,7 @@ internal data class Rex( is Select -> visitor.visitRexOpSelect(this, ctx) is TupleUnion -> visitor.visitRexOpTupleUnion(this, ctx) is Err -> visitor.visitRexOpErr(this, ctx) + is Missing -> visitor.visitRexOpMissing(this, ctx) } internal data class Lit( @@ -736,9 +738,14 @@ internal data class Rex( } internal data class Err( - @JvmField internal val message: String, + @JvmField internal val problem: Problem, + @JvmField internal val causes: List, ) : Op() { - public override val children: List = emptyList() + public override val children: List by lazy { + val kids = mutableListOf() + kids.addAll(causes) + kids.filterNotNull() + } public override fun accept(visitor: PlanVisitor, ctx: C): R = visitor.visitRexOpErr(this, ctx) @@ -747,6 +754,24 @@ internal data class Rex( internal fun builder(): RexOpErrBuilder = RexOpErrBuilder() } } + + internal data class Missing( + @JvmField internal val problem: Problem, + @JvmField internal val causes: List, + ) : Op() { + public override val children: List by lazy { + val kids = mutableListOf() + kids.addAll(causes) + kids.filterNotNull() + } + + public override fun accept(visitor: PlanVisitor, ctx: C): R = visitor.visitRexOpMissing(this, ctx) + + internal companion object { + @JvmStatic + internal fun builder(): RexOpErrBuilder = RexOpErrBuilder() + } + } } internal companion object { @@ -1186,7 +1211,7 @@ internal data class Rel( visitor.visitRelOpExclude(this, ctx) internal data class Path( - @JvmField internal val root: Rex.Op.Var, + @JvmField internal val root: Rex.Op, @JvmField internal val steps: List, ) : PlanNode() { public override val children: List by lazy { diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt index 
e7102da68f..b3f3179ebe 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt @@ -4,6 +4,9 @@ import org.partiql.errors.ProblemCallback import org.partiql.plan.PlanNode import org.partiql.plan.partiQLPlan import org.partiql.plan.rexOpCast +import org.partiql.plan.rexOpErr +import org.partiql.planner.internal.PlannerFlag +import org.partiql.planner.internal.ProblemGenerator import org.partiql.planner.internal.ir.Identifier import org.partiql.planner.internal.ir.PartiQLPlan import org.partiql.planner.internal.ir.Ref @@ -21,11 +24,14 @@ import org.partiql.value.PartiQLValueExperimental * * Ideally this class becomes very small as the internal IR will be a thin wrapper over the public API. */ -internal object PlanTransform { +internal class PlanTransform( + flags: Set +) { + private val signalMode = flags.contains(PlannerFlag.SIGNAL_MODE) fun transform(node: PartiQLPlan, onProblem: ProblemCallback): org.partiql.plan.PartiQLPlan { val symbols = Symbols.empty() - val visitor = Visitor(symbols, onProblem) + val visitor = Visitor(symbols, signalMode, onProblem) val statement = visitor.visitStatement(node.statement, Unit) return partiQLPlan( catalogs = symbols.build(), @@ -35,6 +41,7 @@ internal object PlanTransform { private class Visitor( private val symbols: Symbols, + private val signalMode: Boolean, private val onProblem: ProblemCallback, ) : PlanBaseVisitor() { @@ -116,7 +123,7 @@ internal object PlanTransform { super.visitRexOpVar(node, ctx) as org.partiql.plan.Rex.Op override fun visitRexOpVarUnresolved(node: Rex.Op.Var.Unresolved, ctx: Unit) = - org.partiql.plan.Rex.Op.Err("Unresolved variable $node") + error("The Internal Plan Node Rex.Op.Var.Unresolved should be converted to an MISSING Node during type resolution if resolution failed") override fun visitRexOpVarGlobal(node: Rex.Op.Var.Global, ctx: Unit) = 
org.partiql.plan.Rex.Op.Global( ref = visitRef(node.ref, ctx) @@ -163,7 +170,7 @@ internal object PlanTransform { } override fun visitRexOpCallUnresolved(node: Rex.Op.Call.Unresolved, ctx: Unit): PlanNode { - error("Unresolved function ${node.identifier}") + error("The Internal Node Rex.Op.Call.Unresolved should be converted to an Err Node during type resolution if resolution failed") } override fun visitRexOpCallStatic(node: Rex.Op.Call.Static, ctx: Unit): org.partiql.plan.Rex.Op { @@ -232,7 +239,28 @@ internal object PlanTransform { override fun visitRexOpTupleUnion(node: Rex.Op.TupleUnion, ctx: Unit) = org.partiql.plan.Rex.Op.TupleUnion(args = node.args.map { visitRex(it, ctx) }) - override fun visitRexOpErr(node: Rex.Op.Err, ctx: Unit) = org.partiql.plan.Rex.Op.Err(node.message) + override fun visitRexOpErr(node: Rex.Op.Err, ctx: Unit): PlanNode { + // track the error in call back + val trace = node.causes.map { visitRexOp(it, ctx) } + onProblem(ProblemGenerator.asError(node.problem)) + return org.partiql.plan.Rex.Op.Err(node.problem.toString(), trace) + } + + @OptIn(PartiQLValueExperimental::class) + override fun visitRexOpMissing(node: Rex.Op.Missing, ctx: Unit): PlanNode { + // gather problem from subtree. 
+ val trace = node.causes.map { visitRexOp(it, ctx) } + return when (signalMode) { + true -> { + onProblem.invoke(ProblemGenerator.asError(node.problem)) + rexOpErr(node.problem.toString(), trace) + } + false -> { + onProblem.invoke(ProblemGenerator.asWarning(node.problem)) + org.partiql.plan.rexOpMissing(node.problem.toString(), trace) + } + } + } // RELATION OPERATORS @@ -361,20 +389,31 @@ internal object PlanTransform { override fun visitRelOpExclude(node: Rel.Op.Exclude, ctx: Unit) = org.partiql.plan.Rel.Op.Exclude( input = visitRel(node.input, ctx), - paths = node.paths.map { visitRelOpExcludePath(it, ctx) }, - ) - - override fun visitRelOpExcludePath(node: Rel.Op.Exclude.Path, ctx: Unit): org.partiql.plan.Rel.Op.Exclude.Path { - val root = when (node.root) { - is Rex.Op.Var.Unresolved -> org.partiql.plan.Rex.Op.Var(-1, -1) // unresolved in `PlanTyper` results in error - is Rex.Op.Var.Local -> visitRexOpVarLocal(node.root, ctx) - is Rex.Op.Var.Global -> error("EXCLUDE only disallows values coming from the input record.") + paths = node.paths.mapNotNull { + val root = when (val root = it.root) { + is Rex.Op.Var.Unresolved -> error("EXCLUDE expression has an unresolvable root") // unresolved in `PlanTyper` results in error + is Rex.Op.Var.Local -> visitRexOpVarLocal(root, ctx) + is Rex.Op.Var.Global -> error("EXCLUDE only disallows values coming from the input record.") + is Rex.Op.Err -> { + // trace error + visitRexOpErr(root, ctx) + // this is: an erroneous exclude path is removed for continuation + return@mapNotNull null + } + is Rex.Op.Missing -> { + // trace missing + visitRexOpMissing(root, ctx) + // this is: an exclude path that always returns missing is removed for continuation + return@mapNotNull null + } + else -> error("Should be converted to an error node") + } + org.partiql.plan.Rel.Op.Exclude.Path( + root = root, + steps = it.steps.map { visitRelOpExcludeStep(it, ctx) }, + ) } - return org.partiql.plan.Rel.Op.Exclude.Path( - root = root, - steps = 
node.steps.map { visitRelOpExcludeStep(it, ctx) }, - ) - } + ) override fun visitRelOpExcludeStep(node: Rel.Op.Exclude.Step, ctx: Unit): org.partiql.plan.Rel.Op.Exclude.Step { return org.partiql.plan.Rel.Op.Exclude.Step( diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt index 9dbf331d94..29043a72ae 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt @@ -16,11 +16,8 @@ package org.partiql.planner.internal.typer -import org.partiql.errors.Problem -import org.partiql.errors.ProblemCallback -import org.partiql.errors.UNKNOWN_PROBLEM_LOCATION -import org.partiql.planner.PlanningProblemDetails import org.partiql.planner.internal.Env +import org.partiql.planner.internal.ProblemGenerator import org.partiql.planner.internal.exclude.ExcludeRepr import org.partiql.planner.internal.ir.Identifier import org.partiql.planner.internal.ir.PlanNode @@ -47,9 +44,9 @@ import org.partiql.planner.internal.ir.relOpSort import org.partiql.planner.internal.ir.relOpUnpivot import org.partiql.planner.internal.ir.relType import org.partiql.planner.internal.ir.rex +import org.partiql.planner.internal.ir.rexOpCase import org.partiql.planner.internal.ir.rexOpCaseBranch import org.partiql.planner.internal.ir.rexOpCollection -import org.partiql.planner.internal.ir.rexOpErr import org.partiql.planner.internal.ir.rexOpLit import org.partiql.planner.internal.ir.rexOpPathIndex import org.partiql.planner.internal.ir.rexOpPathKey @@ -91,7 +88,6 @@ import org.partiql.value.BoolValue import org.partiql.value.PartiQLValueExperimental import org.partiql.value.TextValue import org.partiql.value.boolValue -import org.partiql.value.missingValue import org.partiql.value.stringValue /** @@ -102,8 +98,7 @@ import org.partiql.value.stringValue */ 
@OptIn(PartiQLValueExperimental::class) internal class PlanTyper( - private val env: Env, - private val onProblem: ProblemCallback, + private val env: Env ) { /** @@ -247,13 +242,21 @@ internal class PlanTyper( // type limit expression using outer scope with global resolution // TODO: Assert expression doesn't contain locals or upvalues. val limit = node.limit.type(input.type.schema, outer, Scope.GLOBAL) - // check types - assertAsInt(limit.type) // compute output schema val type = input.type - // rewrite - val op = relOpLimit(input, limit) - return rel(type, op) + // check types + return if (!assertAsInt(limit.type)) + rel( + type, + relOpLimit( + input, + ProblemGenerator.missingRex( + listOf(limit.op), + ProblemGenerator.unexpectedType(limit.type, setOf(StaticType.INT)) + ) + ) + ) + else rel(type, relOpLimit(input, limit)) } override fun visitRelOpOffset(node: Rel.Op.Offset, ctx: Rel.Type?): Rel { @@ -262,13 +265,21 @@ internal class PlanTyper( // type offset expression using outer scope with global resolution // TODO: Assert expression doesn't contain locals or upvalues. val offset = node.offset.type(input.type.schema, outer, Scope.GLOBAL) - // check types - assertAsInt(offset.type) // compute output schema val type = input.type - // rewrite - val op = relOpOffset(input, offset) - return rel(type, op) + // check types + return if (!assertAsInt(offset.type)) + rel( + type, + relOpOffset( + input, + ProblemGenerator.missingRex( + listOf(offset.op), + ProblemGenerator.unexpectedType(offset.type, setOf(StaticType.INT)) + ) + ) + ) + else rel(type, relOpOffset(input, offset)) } override fun visitRelOpProject(node: Rel.Op.Project, ctx: Rel.Type?): Rel { @@ -340,6 +351,8 @@ internal class PlanTyper( * some other semantic pass * - currently does not give an error */ + + // TODO: better error reporting with exclude. 
override fun visitRelOpExclude(node: Rel.Op.Exclude, ctx: Rel.Type?): Rel { // compute input schema val input = visitRel(node.input, ctx) @@ -360,17 +373,21 @@ internal class PlanTyper( val path = root.identifier.toBindingPath() val resolved = locals.resolve(path) if (resolved == null) { - handleUnresolvedExcludeRoot(root.identifier) - root + ProblemGenerator.missingRex( + emptyList(), + ProblemGenerator.unresolvedExcludedExprRoot(root.identifier) + ).op } else { // root of exclude is always a symbol resolved.op as Rex.Op.Var } } is Rex.Op.Var.Local, is Rex.Op.Var.Global -> root + else -> error("Expect exclude path root to be Rex.Op.Var") } relOpExcludePath(resolvedRoot, path.steps) } + val subsumedPaths = newPaths .groupBy(keySelector = { it.root }, valueTransform = { it.steps }) // combine exclude paths with the same resolved root before subsumption .map { (root, allSteps) -> @@ -382,6 +399,9 @@ internal class PlanTyper( } override fun visitRelOpAggregate(node: Rel.Op.Aggregate, ctx: Rel.Type?): Rel { + // TODO: Do we need to report aggregation call always returns MISSING? + // Currently aggregation is part of the rel op + // The rel op should produce a set of binding tuple, which missing should be allowed. // compute input schema val input = visitRel(node.input, ctx) @@ -454,10 +474,8 @@ internal class PlanTyper( Scope.LOCAL -> locals.resolve(path) ?: env.resolveObj(path) Scope.GLOBAL -> env.resolveObj(path) ?: locals.resolve(path) } - if (resolvedVar == null) { - handleUndefinedVariable(node.identifier) - return rexErr("Undefined variable `${node.identifier.debug()}`") - } + // Trace for unresolved var is empty for now. 
+ ?: return ProblemGenerator.missingRex(emptyList(), ProblemGenerator.undefinedVariable(path)) return visitRex(resolvedVar, null) } @@ -466,17 +484,21 @@ internal class PlanTyper( override fun visitRexOpPathIndex(node: Rex.Op.Path.Index, ctx: StaticType?): Rex { val root = visitRex(node.root, node.root.type) val key = visitRex(node.key, node.key.type) - if (key.type !is IntType) { - handleAlwaysMissing() - return rex(MISSING, rexOpErr("Collections must be indexed with integers, found ${key.type}")) - } val elementTypes = root.type.allTypes.map { type -> - val rootType = type as? CollectionType ?: return@map MISSING - if (rootType !is ListType && rootType !is SexpType) { + if (type !is ListType && type !is SexpType) { return@map MISSING } - rootType.elementType + (type as CollectionType).elementType }.toSet() + + // TODO: For now we just log a single error. + // Ideally we can log more detailed information such as key not integer, etc. + if (elementTypes.all { it is MissingType } || key.type !is IntType) { + return ProblemGenerator.missingRex( + rexOpPathIndex(root, key), + ProblemGenerator.expressionAlwaysReturnsMissing("Path Navigation always returns MISSING") + ) + } val finalType = unionOf(elementTypes).flatten() return rex(finalType.swallowAny(), rexOpPathIndex(root, key)) } @@ -493,10 +515,6 @@ internal class PlanTyper( else -> MISSING } } - if (toAddTypes.size == key.type.allTypes.size && toAddTypes.all { it is MissingType }) { - handleAlwaysMissing() - return rex(MISSING, rexOpErr("Expected string but found: ${key.type}")) - } val pathTypes = root.type.allTypes.map { type -> val struct = type as? 
StructType ?: return@map MISSING @@ -507,7 +525,7 @@ internal class PlanTyper( val id = identifierSymbol(lit.string!!, Identifier.CaseSensitivity.SENSITIVE) inferStructLookup(struct, id).first } else { - error("Expected text literal, but got $lit") + return@map MISSING } } else { // cannot infer type of non-literal path step because we don't know its value @@ -515,6 +533,15 @@ internal class PlanTyper( ANY } }.toSet() + + // TODO: For now, we just log a single error. + // Ideally we can add more details such as key is not an text, root is not a struct, etc. + if (pathTypes.all { it == MISSING } || + (toAddTypes.size == key.type.allTypes.size && toAddTypes.all { it is MissingType }) // key value check + ) return ProblemGenerator.missingRex( + rexOpPathKey(root, key), + ProblemGenerator.expressionAlwaysReturnsMissing("Path Navigation failed") + ) val finalType = unionOf(pathTypes + toAddTypes).flatten() return rex(finalType.swallowAny(), rexOpPathKey(root, key)) } @@ -522,8 +549,8 @@ internal class PlanTyper( override fun visitRexOpPathSymbol(node: Rex.Op.Path.Symbol, ctx: StaticType?): Rex { val root = visitRex(node.root, node.root.type) - val paths = root.type.allTypes.map { type -> - val struct = type as? StructType ?: return@map rex(MISSING, rexOpLit(missingValue())) + val paths = root.type.allTypes.mapNotNull { type -> + val struct = type as? 
StructType ?: return@mapNotNull null val (pathType, replacementId) = inferStructLookup( struct, identifierSymbol(node.key, Identifier.CaseSensitivity.INSENSITIVE) ) @@ -534,7 +561,16 @@ internal class PlanTyper( ) } } + + if (paths.isEmpty()) return ProblemGenerator.missingRex( + rexOpPathSymbol(root, node.key), + ProblemGenerator.expressionAlwaysReturnsMissing("Path Navigation failed - Expect Root to be of type Struct but is ${root.type}") + ) val type = unionOf(paths.map { it.type }.toSet()).flatten() + if (type is MissingType) return ProblemGenerator.missingRex( + rexOpPathSymbol(root, node.key), + ProblemGenerator.expressionAlwaysReturnsMissing("Path Navigation always returns MISSING") + ) // replace step only if all are disambiguated val firstPathOp = paths.first().op @@ -559,22 +595,12 @@ internal class PlanTyper( private fun rexString(str: String) = rex(STRING, rexOpLit(stringValue(str))) - override fun visitRexOpPath(node: Rex.Op.Path, ctx: StaticType?): Rex { - val path = super.visitRexOpPath(node, ctx) as Rex - if (path.type == MISSING) { - handleAlwaysMissing() - return rexErr("Path always returns missing: ${node.debug()}") - } - return path - } - override fun visitRexOpCastUnresolved(node: Rex.Op.Cast.Unresolved, ctx: StaticType?): Rex { val arg = visitRex(node.arg, null) - val cast = env.resolveCast(arg, node.target) - if (cast == null) { - handleUnknownCast(node) - return rexErr("Invalid CAST operator") - } + val cast = env.resolveCast(arg, node.target) ?: return ProblemGenerator.errorRex( + node.copy(node.target, arg), + ProblemGenerator.undefinedFunction("CAST( AS ${node.target})", listOf(arg.type)) + ) return visitRexOpCastResolved(cast, null) } @@ -592,902 +618,812 @@ internal class PlanTyper( override fun visitRexOpCallUnresolved(node: Rex.Op.Call.Unresolved, ctx: StaticType?): Rex { // Type the arguments - val args = node.args.map { - val arg = visitRex(it, null) - if (arg.op is Rex.Op.Err) { - // don't attempt to resolve a function which has 
erroneous arguments. - return arg - } - arg - } + val args = node.args.map { visitRex(it, null) } + // Attempt to resolve in the environment val path = node.identifier.toBindingPath() val rex = env.resolveFn(path, args) - if (rex == null) { - handleUnknownFunction(node, args) - val name = node.identifier.debug() - val types = args.joinToString { "<${it.type}>" } - return rexErr("Unable to resolve function $name($types)") + ?: return ProblemGenerator.errorRex( + args.map { it.op }, ProblemGenerator.undefinedFunction(node.identifier, args.map { it.type }) + ) + // Pass off to Rex.Op.Call.Static or Rex.Op.Call.Dynamic for typing. + return visitRex(rex, null) } - // Pass off to Rex.Op.Call.Static or Rex.Op.Call.Dynamic for typing. - return visitRex(rex, null) - } - /** - * Resolve and type scalar function calls. - * - * @param node - * @param ctx - * @return - */ - @OptIn(FnExperimental::class) - override fun visitRexOpCallStatic(node: Rex.Op.Call.Static, ctx: StaticType?): Rex { - // Apply the coercions as explicit casts - val args: List = node.args.map { - // Propagate MISSING argument. - if (it.type == MissingType && node.fn.signature.isMissingCall) { - handleAlwaysMissing() - return rex(MISSING, node) - } - // Type the coercions - when (val op = it.op) { - is Rex.Op.Cast.Resolved -> visitRexOpCastResolved(op, null) - else -> it + /** + * Resolve and type scalar function calls. 
+ * + * @param node + * @param ctx + * @return + */ + @OptIn(FnExperimental::class) + override fun visitRexOpCallStatic(node: Rex.Op.Call.Static, ctx: StaticType?): Rex { + // Apply the coercions as explicit casts + val args: List = node.args.map { + // Type the coercions + when (val op = it.op) { + is Rex.Op.Cast.Resolved -> visitRexOpCastResolved(op, null) + else -> it + } } + // Infer fn return type + val type = inferFnType(node.fn.signature, args) + if (type is MissingType) + return ProblemGenerator.missingRex(node, ProblemGenerator.expressionAlwaysReturnsMissing("function always returns missing")) + return rex(type, node) } - // Infer fn return type - val type = inferFnType(node.fn.signature, args) - return rex(type, node) - } - /** - * Typing of a dynamic function call. - * - * isMissable TRUE when the argument permutations may not definitively invoke one of the candidates. - * You can think of [isMissable] as being the same as "not exhaustive". For example, if we have ABS(INT | STRING), then - * this function call [isMissable] because there isn't an `ABS(STRING)` function signature AKA we haven't exhausted - * all the arguments. On the other hand, take an "exhaustive" scenario: ABS(INT | DEC). In this case, [isMissable] - * is false because we have functions for each potential argument AKA we have exhausted the arguments. - * - * - * @param node - * @param ctx - * @return - */ - @OptIn(FnExperimental::class) - override fun visitRexOpCallDynamic(node: Rex.Op.Call.Dynamic, ctx: StaticType?): Rex { - var isMissingCall = false - val types = node.candidates.map { candidate -> - isMissingCall = isMissingCall || candidate.fn.signature.isMissingCall - inferFnType(candidate.fn.signature, node.args) - }.toMutableSet() - - // We had a branch (arg type permutation) without a candidate. - if (!node.exhaustive) { - types.add(MISSING) - } + /** + * Typing of a dynamic function call. 
+ * + * isMissable TRUE when the argument permutations may not definitively invoke one of the candidates. + * You can think of [isMissable] as being the same as "not exhaustive". For example, if we have ABS(INT | STRING), then + * this function call [isMissable] because there isn't an `ABS(STRING)` function signature AKA we haven't exhausted + * all the arguments. On the other hand, take an "exhaustive" scenario: ABS(INT | DEC). In this case, [isMissable] + * is false because we have functions for each potential argument AKA we have exhausted the arguments. + * + * + * @param node + * @param ctx + * @return + */ + @OptIn(FnExperimental::class) + override fun visitRexOpCallDynamic(node: Rex.Op.Call.Dynamic, ctx: StaticType?): Rex { + var isMissingCall = false + val types = node.candidates.map { candidate -> + isMissingCall = isMissingCall || candidate.fn.signature.isMissingCall + inferFnType(candidate.fn.signature, node.args) + }.toMutableSet() + + // We had a branch (arg type permutation) without a candidate. 
+ if (!node.exhaustive) { + types.add(MISSING) + } - return rex(type = unionOf(types).flatten(), op = node) - } + return rex(type = unionOf(types).flatten(), op = node) + } - override fun visitRexOpCase(node: Rex.Op.Case, ctx: StaticType?): Rex { - // Type branches and prune branches known to never execute - val newBranches = node.branches.map { visitRexOpCaseBranch(it, it.rex.type) } - .filterNot { isLiteralBool(it.condition, false) } - - newBranches.forEach { branch -> - if (canBeBoolean(branch.condition.type).not()) { - onProblem.invoke( - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.IncompatibleTypesForOp(branch.condition.type.allTypes, "CASE_WHEN") - ) + override fun visitRexOpCase(node: Rex.Op.Case, ctx: StaticType?): Rex { + // Type branches and prune branches known to never execute + val newBranches = node.branches.map { visitRexOpCaseBranch(it, it.rex.type) } + .filterNot { isLiteralBool(it.condition, false) } + + val default = visitRex(node.default, node.default.type) + + // Calculate final expression (short-circuit to first branch if the condition is always TRUE). + val resultTypes = ( + newBranches + // do not add to output type if the condition is missing or error, it can never be reached. + .filterNot { it.condition.op is Rex.Op.Missing || it.condition.op is Rex.Op.Err } + .map { it.rex } + .map { it.type } + listOf(default.type) + ).toSet() + if (resultTypes.all { it is MissingType }) { + return ProblemGenerator.missingRex( + rexOpCase(newBranches, default), + ProblemGenerator.expressionAlwaysReturnsMissing("Case expression always returns missing") ) } - } - val default = visitRex(node.default, node.default.type) - - // Calculate final expression (short-circuit to first branch if the condition is always TRUE). 
- val resultTypes = newBranches.map { it.rex }.map { it.type } + listOf(default.type) - return when (newBranches.size) { - 0 -> default - else -> when (isLiteralBool(newBranches[0].condition, true)) { - true -> newBranches[0].rex - false -> rex( - type = unionOf(resultTypes.toSet()).flatten(), - node.copy(branches = newBranches, default = default) - ) + return when (newBranches.size) { + 0 -> default + else -> when (isLiteralBool(newBranches[0].condition, true)) { + true -> newBranches[0].rex + false -> rex( + type = unionOf(resultTypes).flatten(), + node.copy(branches = newBranches, default = default) + ) + } } } - } - /** - * In this context, Boolean means PartiQLValueType Bool, which can be nullable. - * Hence, we permit Static Type BOOL, Static Type NULL, Static Type Missing here. - */ - private fun canBeBoolean(type: StaticType): Boolean { - return type.flatten().allTypes.any { - // TODO: This is a quick fix to unblock the typing or case expression. - // We need to model the truth value better in typer. - it is BoolType || it is NullType || it is MissingType + /** + * In this context, Boolean means PartiQLValueType Bool, which can be nullable. + * Hence, we permit Static Type BOOL, Static Type NULL, Static Type Missing here. + */ + private fun canBeBoolean(type: StaticType): Boolean { + return type.flatten().allTypes.any { + // TODO: This is a quick fix to unblock the typing or case expression. + // We need to model the truth value better in typer. + it is BoolType || it is NullType || it is MissingType + } } - } - @OptIn(PartiQLValueExperimental::class) - private fun isLiteralBool(rex: Rex, bool: Boolean): Boolean { - val op = rex.op as? Rex.Op.Lit ?: return false - val value = op.value as? BoolValue ?: return false - return value.value == bool - } + @OptIn(PartiQLValueExperimental::class) + private fun isLiteralBool(rex: Rex, bool: Boolean): Boolean { + val op = rex.op as? Rex.Op.Lit ?: return false + val value = op.value as? 
BoolValue ?: return false + return value.value == bool + } - /** - * We need special handling for: - * ``` - * CASE - * WHEN a IS STRUCT THEN a - * ELSE { 'a': a } - * END - * ``` - * When we type the above, if we know that `a` can be many different types (one of them being a struct), - * then when we see the top-level `a IS STRUCT`, then we can assume that the `a` on the RHS is definitely a - * struct. We handle this by using [foldCaseBranch]. - */ - override fun visitRexOpCaseBranch(node: Rex.Op.Case.Branch, ctx: StaticType?): Rex.Op.Case.Branch { - val visitedCondition = visitRex(node.condition, node.condition.type) - val visitedReturn = visitRex(node.rex, node.rex.type) - return foldCaseBranch(visitedCondition, visitedReturn) - } + /** + * We need special handling for: + * ``` + * CASE + * WHEN a IS STRUCT THEN a + * ELSE { 'a': a } + * END + * ``` + * When we type the above, if we know that `a` can be many different types (one of them being a struct), + * then when we see the top-level `a IS STRUCT`, then we can assume that the `a` on the RHS is definitely a + * struct. We handle this by using [foldCaseBranch]. + */ + override fun visitRexOpCaseBranch(node: Rex.Op.Case.Branch, ctx: StaticType?): Rex.Op.Case.Branch { + val visitedCondition = visitRex(node.condition, node.condition.type) + val visitedReturn = visitRex(node.rex, node.rex.type) + return foldCaseBranch(visitedCondition, visitedReturn) + } - /** - * This takes in a branch condition and its result expression. - * - * 1. If the condition is a type check T (ie ` IS T`), then this function will be typed as T. - * 2. If a branch condition is known to be false, it will be removed. - * - * TODO: Currently, this only folds type checking for STRUCTs. We need to add support for all other types. - * - * TODO: I added a check for [Rex.Op.Var.Outer] as it seemed odd to replace a general expression like: - * `WHEN { 'a': { 'b': 1} }.a IS STRUCT THEN { 'a': { 'b': 1} }.a.b`. 
We can discuss this later, but I'm - * currently limiting the scope of this intentionally. - */ - @OptIn(FnExperimental::class) - private fun foldCaseBranch(condition: Rex, result: Rex): Rex.Op.Case.Branch { - return when (val call = condition.op) { - is Rex.Op.Call.Dynamic -> { - val rex = call.candidates.map { candidate -> - val fn = candidate.fn + /** + * This takes in a branch condition and its result expression. + * + * 1. If the condition is a type check T (ie ` IS T`), then this function will be typed as T. + * 2. If a branch condition is known to be false, it will be removed. + * + * TODO: Currently, this only folds type checking for STRUCTs. We need to add support for all other types. + * + * TODO: I added a check for [Rex.Op.Var.Outer] as it seemed odd to replace a general expression like: + * `WHEN { 'a': { 'b': 1} }.a IS STRUCT THEN { 'a': { 'b': 1} }.a.b`. We can discuss this later, but I'm + * currently limiting the scope of this intentionally. + */ + @OptIn(FnExperimental::class) + private fun foldCaseBranch(condition: Rex, result: Rex): Rex.Op.Case.Branch { + return when (val call = condition.op) { + is Rex.Op.Call.Dynamic -> { + val rex = call.candidates.map { candidate -> + val fn = candidate.fn + if (fn.signature.name.equals("is_struct", ignoreCase = true).not()) { + return rexOpCaseBranch(condition, result) + } + val ref = call.args.getOrNull(0) ?: error("IS STRUCT requires an argument.") + // Replace the result's type + val type = AnyOfType(ref.type.allTypes.filterIsInstance().toSet()) + val replacementVal = ref.copy(type = type) + when (ref.op is Rex.Op.Var.Local) { + true -> RexReplacer.replace(result, ref, replacementVal) + false -> result + } + } + val type = rex.toUnionType().flatten() + return rexOpCaseBranch(condition, result.copy(type)) + } + is Rex.Op.Call.Static -> { + val fn = call.fn if (fn.signature.name.equals("is_struct", ignoreCase = true).not()) { return rexOpCaseBranch(condition, result) } val ref = call.args.getOrNull(0) ?: 
error("IS STRUCT requires an argument.") + val simplifiedCondition = when { + ref.type.allTypes.all { it is StructType } -> rex(BOOL, rexOpLit(boolValue(true))) + ref.type.allTypes.none { it is StructType } -> rex(BOOL, rexOpLit(boolValue(false))) + else -> condition + } + // Replace the result's type val type = AnyOfType(ref.type.allTypes.filterIsInstance().toSet()) val replacementVal = ref.copy(type = type) - when (ref.op is Rex.Op.Var.Local) { + val rex = when (ref.op is Rex.Op.Var.Local) { true -> RexReplacer.replace(result, ref, replacementVal) false -> result } + return rexOpCaseBranch(simplifiedCondition, rex) } - val type = rex.toUnionType().flatten() - return rexOpCaseBranch(condition, result.copy(type)) + else -> rexOpCaseBranch(condition, result) } - is Rex.Op.Call.Static -> { - val fn = call.fn - if (fn.signature.name.equals("is_struct", ignoreCase = true).not()) { - return rexOpCaseBranch(condition, result) - } - val ref = call.args.getOrNull(0) ?: error("IS STRUCT requires an argument.") - val simplifiedCondition = when { - ref.type.allTypes.all { it is StructType } -> rex(BOOL, rexOpLit(boolValue(true))) - ref.type.allTypes.none { it is StructType } -> rex(BOOL, rexOpLit(boolValue(false))) - else -> condition - } - - // Replace the result's type - val type = AnyOfType(ref.type.allTypes.filterIsInstance().toSet()) - val replacementVal = ref.copy(type = type) - val rex = when (ref.op is Rex.Op.Var.Local) { - true -> RexReplacer.replace(result, ref, replacementVal) - false -> result - } - return rexOpCaseBranch(simplifiedCondition, rex) - } - else -> rexOpCaseBranch(condition, result) } - } - override fun visitRexOpCollection(node: Rex.Op.Collection, ctx: StaticType?): Rex { - if (ctx!! 
!is CollectionType) { - handleUnexpectedType(ctx, setOf(StaticType.LIST, StaticType.BAG, StaticType.SEXP)) - return rex(StaticType.NULL_OR_MISSING, rexOpErr("Expected collection type")) - } - val values = node.values.map { visitRex(it, it.type) } - val t = when (values.size) { - 0 -> ANY - else -> values.toUnionType() - } - val type = when (ctx as CollectionType) { - is BagType -> BagType(t) - is ListType -> ListType(t) - is SexpType -> SexpType(t) + override fun visitRexOpCollection(node: Rex.Op.Collection, ctx: StaticType?): Rex { + if (ctx!! !is CollectionType) { + return ProblemGenerator.missingRex( + node, + ProblemGenerator.unexpectedType(ctx, setOf(StaticType.LIST, StaticType.BAG, StaticType.SEXP)) + ) + } + val values = node.values.map { visitRex(it, it.type) } + val t = when (values.size) { + 0 -> ANY + else -> values.toUnionType() + } + val type = when (ctx as CollectionType) { + is BagType -> BagType(t) + is ListType -> ListType(t) + is SexpType -> SexpType(t) + } + return rex(type, rexOpCollection(values)) } - return rex(type, rexOpCollection(values)) - } - @OptIn(PartiQLValueExperimental::class) - override fun visitRexOpStruct(node: Rex.Op.Struct, ctx: StaticType?): Rex { - val fields = node.fields.mapNotNull { - val k = visitRex(it.k, it.k.type) - val v = visitRex(it.v, it.v.type) - if (v.type is MissingType) { - null - } else { - rexOpStructField(k, v) + @OptIn(PartiQLValueExperimental::class) + override fun visitRexOpStruct(node: Rex.Op.Struct, ctx: StaticType?): Rex { + val fields = node.fields.mapNotNull { + val k = visitRex(it.k, it.k.type) + val v = visitRex(it.v, it.v.type) + if (v.op is Rex.Op.Missing) { + rexOpStructField(k, v) + } + // ignore literal missing + else if (v.type is MissingType) { + null + } else { + rexOpStructField(k, v) + } } - } - var structIsClosed = true - val structTypeFields = mutableListOf() - val structKeysSeent = mutableSetOf() - for (field in fields) { - when (field.k.op) { - is Rex.Op.Lit -> { - // A field is only 
included in the StructType if its key is a text literal - val key = field.k.op - if (key.value is TextValue<*>) { - val name = key.value.string!! - val type = field.v.type - structKeysSeent.add(name) - structTypeFields.add(StructType.Field(name, type)) - } + var structIsClosed = true + val structTypeFields = mutableListOf() + val structKeysSeent = mutableSetOf() + for (field in fields) { + // if a field op is an rex op missing, trace the field in the struct for error reporting + // but do not add the missing into the struct typing. + if (field.v.op is Rex.Op.Missing || field.v.op is Rex.Op.Err) { + continue } - else -> { - if (field.k.type.allTypes.any { it.isText() }) { - // If the non-literal could be text, StructType will have open content. - structIsClosed = false - } else { - // A field with a non-literal key name is not included in the StructType. + when (field.k.op) { + is Rex.Op.Lit -> { + // A field is only included in the StructType if its key is a text literal + val key = field.k.op + if (key.value is TextValue<*>) { + val name = key.value.string!! + val type = field.v.type + structKeysSeent.add(name) + structTypeFields.add(StructType.Field(name, type)) + } + } + else -> { + if (field.k.type.allTypes.any { it.isText() }) { + // If the non-literal could be text, StructType will have open content. + structIsClosed = false + } else { + // A field with a non-literal key name is not included in the StructType. 
+ } } } } + val type = StructType( + fields = structTypeFields, + contentClosed = structIsClosed, + constraints = setOf( + TupleConstraint.Open(!structIsClosed), + TupleConstraint.UniqueAttrs( + structKeysSeent.size == fields.filterNot { it.v.op is Rex.Op.Missing || it.v.op is Rex.Op.Err }.size + ) + ), + ) + return rex(type, rexOpStruct(fields)) } - val type = StructType( - fields = structTypeFields, - contentClosed = structIsClosed, - constraints = setOf( - TupleConstraint.Open(!structIsClosed), - TupleConstraint.UniqueAttrs(structKeysSeent.size == fields.size) - ), - ) - return rex(type, rexOpStruct(fields)) - } - override fun visitRexOpPivot(node: Rex.Op.Pivot, ctx: StaticType?): Rex { - val stack = locals.outer + listOf(locals) - val rel = node.rel.type(stack) - val typeEnv = TypeEnv(rel.type.schema, stack) - val typer = RexTyper(typeEnv, Scope.LOCAL) - val key = typer.visitRex(node.key, null) - val value = typer.visitRex(node.value, null) - val type = StructType( - contentClosed = false, constraints = setOf(TupleConstraint.Open(true)) - ) - val op = rexOpPivot(key, value, rel) - return rex(type, op) - } - - override fun visitRexOpSubquery(node: Rex.Op.Subquery, ctx: StaticType?): Rex { - val rel = node.rel.type(locals.outer + listOf(locals)) - val newTypeEnv = TypeEnv(schema = rel.type.schema, outer = locals.outer + listOf(locals)) - val constructor = node.constructor.type(newTypeEnv) - val subquery = rexOpSubquery(constructor, rel, node.coercion) - return when (node.coercion) { - Rex.Op.Subquery.Coercion.SCALAR -> visitRexOpSubqueryScalar(subquery, constructor.type) - Rex.Op.Subquery.Coercion.ROW -> visitRexOpSubqueryRow(subquery, constructor.type) + override fun visitRexOpPivot(node: Rex.Op.Pivot, ctx: StaticType?): Rex { + val stack = locals.outer + listOf(locals) + val rel = node.rel.type(stack) + val typeEnv = TypeEnv(rel.type.schema, stack) + val typer = RexTyper(typeEnv, Scope.LOCAL) + val key = typer.visitRex(node.key, null) + val value = 
typer.visitRex(node.value, null) + val type = StructType( + contentClosed = false, constraints = setOf(TupleConstraint.Open(true)) + ) + val op = rexOpPivot(key, value, rel) + return rex(type, op) } - } - /** - * Calculate output type of a row-value subquery. - */ - private fun visitRexOpSubqueryRow(subquery: Rex.Op.Subquery, cons: StaticType): Rex { - if (cons !is StructType) { - return rexErr("Subquery with non-SQL SELECT cannot be coerced to a row-value expression. Found constructor type: $cons") + override fun visitRexOpSubquery(node: Rex.Op.Subquery, ctx: StaticType?): Rex { + val rel = node.rel.type(locals.outer + listOf(locals)) + val newTypeEnv = TypeEnv(schema = rel.type.schema, outer = locals.outer + listOf(locals)) + val constructor = node.constructor.type(newTypeEnv) + val subquery = rexOpSubquery(constructor, rel, node.coercion) + return when (node.coercion) { + Rex.Op.Subquery.Coercion.SCALAR -> visitRexOpSubqueryScalar(subquery, constructor.type) + Rex.Op.Subquery.Coercion.ROW -> visitRexOpSubqueryRow(subquery, constructor.type) + } } - // Do a simple cardinality check for the moment. - // TODO we can only check cardinality if we know we are in a a comparison operator. - // val n = coercion.columns.size - // val m = cons.fields.size - // if (n != m) { - // return rexErr("Cannot coercion subquery with $m attributes to a row-value-expression with $n attributes") - // } - // If we made it this far, then we can coerce this subquery to the desired complex value - val type = StaticType.LIST - val op = subquery - return rex(type, op) - } - /** - * Calculate output type of a scalar subquery. - */ - private fun visitRexOpSubqueryScalar(subquery: Rex.Op.Subquery, cons: StaticType): Rex { - if (cons !is StructType) { - return rexErr("Subquery with non-SQL SELECT cannot be coerced to a scalar. Found constructor type: $cons") - } - val n = cons.fields.size - if (n != 1) { - return rexErr("SELECT constructor with $n attributes cannot be coerced to a scalar. 
Found constructor type: $cons") + /** + * Calculate output type of a row-value subquery. + */ + private fun visitRexOpSubqueryRow(subquery: Rex.Op.Subquery, cons: StaticType): Rex { + if (cons !is StructType) { + return ProblemGenerator.errorRex( + subquery, + ProblemGenerator.compilerError("Subquery with non-SQL SELECT cannot be coerced to a row-value expression. Found constructor type: $cons") + ) + } + // Do a simple cardinality check for the moment. + // TODO we can only check cardinality if we know we are in a a comparison operator. + // val n = coercion.columns.size + // val m = cons.fields.size + // if (n != m) { + // return rexErr("Cannot coercion subquery with $m attributes to a row-value-expression with $n attributes") + // } + // If we made it this far, then we can coerce this subquery to the desired complex value + val type = StaticType.LIST + val op = subquery + return rex(type, op) } - // If we made it this far, then we can coerce this subquery to a scalar - val type = cons.fields.first().value - val op = subquery - return rex(type, op) - } - override fun visitRexOpSelect(node: Rex.Op.Select, ctx: StaticType?): Rex { - val rel = node.rel.type(locals.outer + listOf(locals)) - val newTypeEnv = TypeEnv(schema = rel.type.schema, outer = locals.outer + listOf(locals)) - var constructor = node.constructor.type(newTypeEnv) - var constructorType = constructor.type - // add the ordered property to the constructor - if (constructorType is StructType) { - // TODO: We shouldn't need to copy the ordered constraint. - constructorType = constructorType.copy( - constraints = constructorType.constraints + setOf(TupleConstraint.Ordered) - ) - constructor = rex(constructorType, constructor.op) - } - val type = when (rel.isOrdered()) { - true -> ListType(constructor.type) - else -> BagType(constructor.type) + /** + * Calculate output type of a scalar subquery. 
+ */ + private fun visitRexOpSubqueryScalar(subquery: Rex.Op.Subquery, cons: StaticType): Rex { + if (cons !is StructType) { + return ProblemGenerator.errorRex( + subquery, + ProblemGenerator.compilerError("Subquery with non-SQL SELECT cannot be coerced to a scalar. Found constructor type: $cons") + ) + } + val n = cons.fields.size + if (n != 1) { + return ProblemGenerator.errorRex( + subquery, + ProblemGenerator.compilerError("SELECT constructor with $n attributes cannot be coerced to a scalar. Found constructor type: $cons") + ) + } + // If we made it this far, then we can coerce this subquery to a scalar + val type = cons.fields.first().value + val op = subquery + return rex(type, op) } - return rex(type, rexOpSelect(constructor, rel)) - } - override fun visitRexOpTupleUnion(node: Rex.Op.TupleUnion, ctx: StaticType?): Rex { - val args = node.args.map { visitRex(it, ctx) } - val type = when (args.size) { - 0 -> StructType( - fields = emptyMap(), contentClosed = true, - constraints = setOf( - TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true), TupleConstraint.Ordered + override fun visitRexOpSelect(node: Rex.Op.Select, ctx: StaticType?): Rex { + val rel = node.rel.type(locals.outer + listOf(locals)) + val newTypeEnv = TypeEnv(schema = rel.type.schema, outer = locals.outer + listOf(locals)) + var constructor = node.constructor.type(newTypeEnv) + var constructorType = constructor.type + // add the ordered property to the constructor + if (constructorType is StructType) { + // TODO: We shouldn't need to copy the ordered constraint. 
+ constructorType = constructorType.copy( + constraints = constructorType.constraints + setOf(TupleConstraint.Ordered) ) - ) - else -> { - val argTypes = args.map { it.type } - val potentialTypes = buildArgumentPermutations(argTypes).map { argumentList -> - calculateTupleUnionOutputType(argumentList) - } - unionOf(potentialTypes.toSet()).flatten() + constructor = rex(constructorType, constructor.op) } + val type = when (rel.isOrdered()) { + true -> ListType(constructor.type) + else -> BagType(constructor.type) + } + return rex(type, rexOpSelect(constructor, rel)) } - val op = rexOpTupleUnion(args) - return rex(type, op) - } - override fun visitRexOpErr(node: Rex.Op.Err, ctx: StaticType?): PlanNode { - val type = ctx ?: ANY - return rex(type, node) - } - - // Helpers - - /** - * Given a list of [args], this calculates the output type of `TUPLEUNION(args)`. NOTE: This does NOT handle union - * types intentionally. This function expects that all arguments be flattened, and, if need be, that you invoke - * this function multiple times based on the permutations of arguments. - * - * The signature of TUPLEUNION is: (LIST) -> STRUCT. - * - * If any of the arguments are NULL (or potentially NULL), we return NULL. - * If any of the arguments are non-struct, we return MISSING. - * - * Now, assuming all the other arguments are STRUCT, then we compute the output based on a number of factors: - * - closed content - * - ordering - * - unique attributes - * - * If all arguments are closed content, then the output is closed content. - * If all arguments are ordered, then the output is ordered. - * If all arguments contain unique attributes AND all arguments are closed AND no fields clash, the output has - * unique attributes. 
- */ - private fun calculateTupleUnionOutputType(args: List): StaticType { - val structFields = mutableListOf() - var structAmount = 0 - var structIsClosed = true - var structIsOrdered = true - var uniqueAttrs = true - val possibleOutputTypes = mutableListOf() - args.forEach { arg -> - when (arg) { - is StructType -> { - structAmount += 1 - structFields.addAll(arg.fields) - structIsClosed = structIsClosed && arg.constraints.contains(TupleConstraint.Open(false)) - structIsOrdered = structIsOrdered && arg.constraints.contains(TupleConstraint.Ordered) - uniqueAttrs = uniqueAttrs && arg.constraints.contains(TupleConstraint.UniqueAttrs(true)) - } - is AnyOfType -> { - onProblem.invoke( - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.CompileError("TupleUnion wasn't normalized to exclude union types.") - ) + override fun visitRexOpTupleUnion(node: Rex.Op.TupleUnion, ctx: StaticType?): Rex { + val args = node.args.map { visitRex(it, ctx) } + val type = when (args.size) { + 0 -> StructType( + fields = emptyMap(), contentClosed = true, + constraints = setOf( + TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true), TupleConstraint.Ordered ) - possibleOutputTypes.add(MISSING) - } - is NullType -> { - return NULL - } + ) else -> { - return MISSING + val argTypes = args.map { it.type } + val potentialTypes = buildArgumentPermutations(argTypes).map { argumentList -> + calculateTupleUnionOutputType(argumentList) + } + unionOf(potentialTypes.toSet()).flatten() } } + val op = rexOpTupleUnion(args) + return rex(type, op) } - uniqueAttrs = when { - structIsClosed.not() && structAmount > 1 -> false - else -> uniqueAttrs - } - uniqueAttrs = uniqueAttrs && (structFields.size == structFields.distinctBy { it.key }.size) - val orderedConstraint = when (structIsOrdered) { - true -> TupleConstraint.Ordered - false -> null - } - val constraints = setOfNotNull( - TupleConstraint.Open(!structIsClosed), TupleConstraint.UniqueAttrs(uniqueAttrs), orderedConstraint - ) - 
return StructType( - fields = structFields.map { it }, contentClosed = structIsClosed, constraints = constraints - ) - } - /** - * We are essentially making permutations of arguments that maintain the same initial ordering. For example, - * consider the following args: - * ``` - * [ 0 = UNION(INT, STRING), 1 = (DECIMAL, TIMESTAMP) ] - * ``` - * This function will return: - * ``` - * [ - * [ 0 = INT, 1 = DECIMAL ], - * [ 0 = INT, 1 = TIMESTAMP ], - * [ 0 = STRING, 1 = DECIMAL ], - * [ 0 = STRING, 1 = TIMESTAMP ] - * ] - * ``` - * - * Essentially, this becomes useful specifically in the case of TUPLEUNION, since we can make sure that - * the ordering of argument's attributes remains the same. For example: - * ``` - * TUPLEUNION( UNION(STRUCT(a, b), STRUCT(c)), UNION(STRUCT(d, e), STRUCT(f)) ) - * ``` - * - * Then, the output of the tupleunion will have the output types of all of the below: - * ``` - * TUPLEUNION(STRUCT(a,b), STRUCT(d,e)) --> STRUCT(a, b, d, e) - * TUPLEUNION(STRUCT(a,b), STRUCT(f)) --> STRUCT(a, b, f) - * TUPLEUNION(STRUCT(c), STRUCT(d,e)) --> STRUCT(c, d, e) - * TUPLEUNION(STRUCT(c), STRUCT(f)) --> STRUCT(c, f) - * ``` - */ - private fun buildArgumentPermutations(args: List): Sequence> { - val flattenedArgs = args.map { it.flatten().allTypes } - return buildArgumentPermutations(flattenedArgs, accumulator = emptyList()) - } - - private fun buildArgumentPermutations( - args: List>, - accumulator: List, - ): Sequence> { - if (args.isEmpty()) { - return sequenceOf(accumulator) - } - val first = args.first() - val rest = when (args.size) { - 1 -> emptyList() - else -> args.subList(1, args.size) + override fun visitRexOpErr(node: Rex.Op.Err, ctx: StaticType?): PlanNode { + val type = ctx ?: ANY + return rex(type, node) } - return sequence { - first.forEach { argSubType -> - yieldAll(buildArgumentPermutations(rest, accumulator + listOf(argSubType))) - } - } - } - // Helpers + override fun visitRexOpMissing(node: Rex.Op.Missing, ctx: StaticType?): PlanNode 
{ + val type = ctx ?: MISSING + return rex(type, node) + } - /** - * Logic is as follows: - * 1. If [struct] is closed and ordered: - * - If no item is found, return [MissingType] - * - Else, grab first matching item and make sensitive. - * 2. If [struct] is closed - * - AND no item is found, return [MissingType] - * - AND only one item is present -> grab item and make sensitive. - * - AND more than one item is present, keep sensitivity and grab item. - * 3. If [struct] is open, return [AnyType] - * - * @return a [Pair] where the [Pair.first] represents the type of the [step] and the [Pair.second] represents - * the disambiguated [key]. - */ - private fun inferStructLookup(struct: StructType, key: Identifier.Symbol): Pair { - val binding = key.toBindingName() - val isClosed = struct.constraints.contains(TupleConstraint.Open(false)) - val isOrdered = struct.constraints.contains(TupleConstraint.Ordered) - val (name, type) = when { - // 1. Struct is closed and ordered - isClosed && isOrdered -> { - struct.fields.firstOrNull { entry -> binding.matches(entry.key) }?.let { - (sensitive(it.key) to it.value) - } ?: (key to MISSING) - } - // 2. Struct is closed - isClosed -> { - val matches = struct.fields.filter { entry -> binding.matches(entry.key) } - when (matches.size) { - 0 -> (key to MISSING) - 1 -> matches.first().let { (sensitive(it.key) to it.value) } + // Helpers + + /** + * Given a list of [args], this calculates the output type of `TUPLEUNION(args)`. NOTE: This does NOT handle union + * types intentionally. This function expects that all arguments be flattened, and, if need be, that you invoke + * this function multiple times based on the permutations of arguments. + * + * The signature of TUPLEUNION is: (LIST) -> STRUCT. + * + * If any of the arguments are NULL (or potentially NULL), we return NULL. + * If any of the arguments are non-struct, we return MISSING. 
+ * + * Now, assuming all the other arguments are STRUCT, then we compute the output based on a number of factors: + * - closed content + * - ordering + * - unique attributes + * + * If all arguments are closed content, then the output is closed content. + * If all arguments are ordered, then the output is ordered. + * If all arguments contain unique attributes AND all arguments are closed AND no fields clash, the output has + * unique attributes. + */ + private fun calculateTupleUnionOutputType(args: List): StaticType { + val structFields = mutableListOf() + var structAmount = 0 + var structIsClosed = true + var structIsOrdered = true + var uniqueAttrs = true + val possibleOutputTypes = mutableListOf() + args.forEach { arg -> + when (arg) { + is StructType -> { + structAmount += 1 + structFields.addAll(arg.fields) + structIsClosed = structIsClosed && arg.constraints.contains(TupleConstraint.Open(false)) + structIsOrdered = structIsOrdered && arg.constraints.contains(TupleConstraint.Ordered) + uniqueAttrs = uniqueAttrs && arg.constraints.contains(TupleConstraint.UniqueAttrs(true)) + } + is AnyOfType -> { + error("TupleUnion wasn't normalized to exclude union types.") + } + is NullType -> { + return NULL + } else -> { - val firstKey = matches.first().key - val sharedKey = when (matches.all { it.key == firstKey }) { - true -> sensitive(firstKey) - false -> key - } - sharedKey to unionOf(matches.map { it.value }.toSet()).flatten() + return MISSING } } } - // 3. 
Struct is open - else -> (key to ANY) + uniqueAttrs = when { + structIsClosed.not() && structAmount > 1 -> false + else -> uniqueAttrs + } + uniqueAttrs = uniqueAttrs && (structFields.size == structFields.distinctBy { it.key }.size) + val orderedConstraint = when (structIsOrdered) { + true -> TupleConstraint.Ordered + false -> null + } + val constraints = setOfNotNull( + TupleConstraint.Open(!structIsClosed), TupleConstraint.UniqueAttrs(uniqueAttrs), orderedConstraint + ) + return StructType( + fields = structFields.map { it }, contentClosed = structIsClosed, constraints = constraints + ) } - return type to name - } - private fun sensitive(str: String): Identifier.Symbol = - identifierSymbol(str, Identifier.CaseSensitivity.SENSITIVE) - - @OptIn(FnExperimental::class) - private fun inferFnType(fn: FnSignature, args: List): StaticType { - - // Determine role of NULL and MISSING in the return type - var hadNull = false - var hadNullable = false - var hadMissing = false - var hadMissable = false - for (arg in args) { - val t = arg.type - when { - t is MissingType -> hadMissing = true - t is NullType -> hadNull = true - t.isMissable() -> hadMissable = true - t.isNullable() -> hadNullable = true - } + /** + * We are essentially making permutations of arguments that maintain the same initial ordering. For example, + * consider the following args: + * ``` + * [ 0 = UNION(INT, STRING), 1 = (DECIMAL, TIMESTAMP) ] + * ``` + * This function will return: + * ``` + * [ + * [ 0 = INT, 1 = DECIMAL ], + * [ 0 = INT, 1 = TIMESTAMP ], + * [ 0 = STRING, 1 = DECIMAL ], + * [ 0 = STRING, 1 = TIMESTAMP ] + * ] + * ``` + * + * Essentially, this becomes useful specifically in the case of TUPLEUNION, since we can make sure that + * the ordering of argument's attributes remains the same. 
For example: + * ``` + * TUPLEUNION( UNION(STRUCT(a, b), STRUCT(c)), UNION(STRUCT(d, e), STRUCT(f)) ) + * ``` + * + * Then, the output of the tupleunion will have the output types of all of the below: + * ``` + * TUPLEUNION(STRUCT(a,b), STRUCT(d,e)) --> STRUCT(a, b, d, e) + * TUPLEUNION(STRUCT(a,b), STRUCT(f)) --> STRUCT(a, b, f) + * TUPLEUNION(STRUCT(c), STRUCT(d,e)) --> STRUCT(c, d, e) + * TUPLEUNION(STRUCT(c), STRUCT(f)) --> STRUCT(c, f) + * ``` + */ + private fun buildArgumentPermutations(args: List): Sequence> { + val flattenedArgs = args.map { it.flatten().allTypes } + return buildArgumentPermutations(flattenedArgs, accumulator = emptyList()) } - // True iff NULL CALL and had a NULL arg; - val isNull = (fn.isNullCall && hadNull) - - // True iff NULL CALL and had a NULLABLE arg; or is a NULLABLE operator - val isNullable = (fn.isNullCall && hadNullable) || fn.isNullable - - // True iff MISSING CALL and had a MISSING arg. - val isMissing = fn.isMissingCall && hadMissing - - // True iff MISSING CALL and had a MISSABLE arg - val isMissable = (fn.isMissingCall && hadMissable) && fn.isMissable - - // Return type with calculated nullability - var type: StaticType = when { - isMissing -> MISSING - // Edge cases for EQ and boolean connective - // If function can not return missing or null, can not propagate missing or null - // AKA, the Function IS MISSING - // return signature return type - !fn.isMissable && !fn.isMissingCall && !fn.isNullable && !fn.isNullCall -> fn.returns.toNonNullStaticType() - isNull || (!fn.isMissable && hadMissing) -> fn.returns.toStaticType() - isNullable -> fn.returns.toStaticType() - else -> fn.returns.toNonNullStaticType() + private fun buildArgumentPermutations( + args: List>, + accumulator: List, + ): Sequence> { + if (args.isEmpty()) { + return sequenceOf(accumulator) + } + val first = args.first() + val rest = when (args.size) { + 1 -> emptyList() + else -> args.subList(1, args.size) + } + return sequence { + first.forEach { argSubType 
-> + yieldAll(buildArgumentPermutations(rest, accumulator + listOf(argSubType))) + } + } } - // Propagate MISSING unless this operator explicitly doesn't return missing (fn.isMissable = false). - if (isMissable) { - type = unionOf(type, MISSING) + // Helpers + + /** + * Logic is as follows: + * 1. If [struct] is closed and ordered: + * - If no item is found, return [MissingType] + * - Else, grab first matching item and make sensitive. + * 2. If [struct] is closed + * - AND no item is found, return [MissingType] + * - AND only one item is present -> grab item and make sensitive. + * - AND more than one item is present, keep sensitivity and grab item. + * 3. If [struct] is open, return [AnyType] + * + * @return a [Pair] where the [Pair.first] represents the type of the [step] and the [Pair.second] represents + * the disambiguated [key]. + */ + private fun inferStructLookup(struct: StructType, key: Identifier.Symbol): Pair { + val binding = key.toBindingName() + val isClosed = struct.constraints.contains(TupleConstraint.Open(false)) + val isOrdered = struct.constraints.contains(TupleConstraint.Ordered) + val (name, type) = when { + // 1. Struct is closed and ordered + isClosed && isOrdered -> { + struct.fields.firstOrNull { entry -> binding.matches(entry.key) }?.let { + (sensitive(it.key) to it.value) + } ?: (key to MISSING) + } + // 2. Struct is closed + isClosed -> { + val matches = struct.fields.filter { entry -> binding.matches(entry.key) } + when (matches.size) { + 0 -> (key to MISSING) + 1 -> matches.first().let { (sensitive(it.key) to it.value) } + else -> { + val firstKey = matches.first().key + val sharedKey = when (matches.all { it.key == firstKey }) { + true -> sensitive(firstKey) + false -> key + } + sharedKey to unionOf(matches.map { it.value }.toSet()).flatten() + } + } + } + // 3. 
Struct is open + else -> (key to ANY) + } + return type to name } - return type.flatten() - } + private fun sensitive(str: String): Identifier.Symbol = + identifierSymbol(str, Identifier.CaseSensitivity.SENSITIVE) + + @OptIn(FnExperimental::class) + private fun inferFnType(fn: FnSignature, args: List): StaticType { + + // Determine role of NULL and MISSING in the return type + var hadNull = false + var hadNullable = false + var hadMissing = false + var hadMissable = false + for (arg in args) { + val t = arg.type + when { + t is MissingType -> hadMissing = true + t is NullType -> hadNull = true + t.isMissable() -> hadMissable = true + t.isNullable() -> hadNullable = true + } + } - /** - * Resolution and typing of aggregation function calls. - * - * I've chosen to place this in RexTyper because all arguments will be typed using the same locals. - * There's no need to create new RexTyper instances for each argument. There is no reason to limit aggregations - * to a single argument (covar, corr, pct, etc.) but in practice we typically only have single . - * - * This method is _very_ similar to scalar function resolution, so it is temping to DRY these two out; but the - * separation is cleaner as the typing of NULLS is subtly different. - * - * SQL-99 6.16 General Rules on - * Let TX be the single-column table that is the result of applying the - * to each row of T and eliminating null values <--- all NULL values are eliminated as inputs - */ - @OptIn(FnExperimental::class) - fun resolveAgg(node: Rel.Op.Aggregate.Call.Unresolved): Pair { + // True iff NULL CALL and had a NULL arg; + val isNull = (fn.isNullCall && hadNull) + + // True iff NULL CALL and had a NULLABLE arg; or is a NULLABLE operator + val isNullable = (fn.isNullCall && hadNullable) || fn.isNullable + + // True iff MISSING CALL and had a MISSING arg. 
+ val isMissing = fn.isMissingCall && hadMissing + + // True iff MISSING CALL and had a MISSABLE arg + val isMissable = (fn.isMissingCall && hadMissable) && fn.isMissable + + // Return type with calculated nullability + var type: StaticType = when { + isMissing -> MISSING + // Edge cases for EQ and boolean connective + // If function can not return missing or null, can not propagate missing or null + // AKA, the Function IS MISSING + // return signature return type + !fn.isMissable && !fn.isMissingCall && !fn.isNullable && !fn.isNullCall -> fn.returns.toNonNullStaticType() + isNull || (!fn.isMissable && hadMissing) -> fn.returns.toStaticType() + isNullable -> fn.returns.toStaticType() + else -> fn.returns.toNonNullStaticType() + } - // Type the arguments - var isMissable = false - val args = node.args.map { - val arg = visitRex(it, null) - if (arg.op is Rex.Op.Err) { - // don't attempt to resolve an aggregation with erroneous arguments. - handleUnknownAggregation(node) - return node to ANY - } else if (arg.type is MissingType) { - handleAlwaysMissing() - return relOpAggregateCallUnresolved(node.name, node.setQuantifier, listOf(rexErr("MISSING"))) to MissingType - } else if (arg.type.isMissable()) { - isMissable = true + // Propagate MISSING unless this operator explicitly doesn't return missing (fn.isMissable = false). + if (isMissable) { + type = unionOf(type, MISSING) } - arg - } - // Resolve the function - val call = env.resolveAgg(node.name, node.setQuantifier, args) - if (call == null) { - handleUnknownAggregation(node) - return node to ANY + return type.flatten() } - // Treat MISSING as NULL in aggregations. - val isNullable = call.agg.signature.isNullable || isMissable - val returns = call.agg.signature.returns - val type: StaticType = when { - isNullable -> returns.toStaticType() - else -> returns.toNonNullStaticType() + /** + * Resolution and typing of aggregation function calls. 
+ * + * I've chosen to place this in RexTyper because all arguments will be typed using the same locals. + * There's no need to create new RexTyper instances for each argument. There is no reason to limit aggregations + * to a single argument (covar, corr, pct, etc.) but in practice we typically only have single . + * + * This method is _very_ similar to scalar function resolution, so it is temping to DRY these two out; but the + * separation is cleaner as the typing of NULLS is subtly different. + * + * SQL-99 6.16 General Rules on + * Let TX be the single-column table that is the result of applying the + * to each row of T and eliminating null values <--- all NULL values are eliminated as inputs + */ + @OptIn(FnExperimental::class) + fun resolveAgg(node: Rel.Op.Aggregate.Call.Unresolved): Pair { + // Type the arguments + var isMissable = false + val args = node.args.map { visitRex(it, null) } + val argsResolved = relOpAggregateCallUnresolved(node.name, node.setQuantifier, args) + + // Resolve the function + val call = env.resolveAgg(node.name, node.setQuantifier, args) ?: return argsResolved to ANY + if (args.any { it.type == MISSING }) return argsResolved to MISSING + if (args.any { it.type.isMissable() }) isMissable = true + + // Treat MISSING as NULL in aggregations. + val isNullable = call.agg.signature.isNullable || isMissable + val returns = call.agg.signature.returns + val type: StaticType = when { + isNullable -> returns.toStaticType() + else -> returns.toNonNullStaticType() + } + // + return call to type } - // - return call to type } - } - - // HELPERS - - private fun Rel.type(stack: List, strategy: Scope = Scope.LOCAL): Rel = - RelTyper(stack, strategy).visitRel(this, null) - - /** - * This types the [Rex] given the input record ([input]) and [stack] of [TypeEnv] (representing the outer scopes). 
- */ - private fun Rex.type(input: List, stack: List, strategy: Scope = Scope.LOCAL) = - RexTyper(TypeEnv(input, stack), strategy).visitRex(this, this.type) - - /** - * This types the [Rex] given a [TypeEnv]. We use the [TypeEnv.schema] as the input schema and the [TypeEnv.outer] - * as the outer scopes/ - */ - private fun Rex.type(typeEnv: TypeEnv, strategy: Scope = Scope.LOCAL) = - RexTyper(typeEnv, strategy).visitRex(this, this.type) - private fun rexErr(message: String) = rex(MISSING, rexOpErr(message)) + // HELPERS - /** - * I found decorating the tree with the binding names (for resolution) was easier than associating introduced - * bindings with a node via an id->list map. ONLY because right now I don't think we have a good way - * of managing ids when trees are rewritten. - * - * We need a good answer for these questions before going for it: - * - If you copy, should the id should come along for the ride? - * - If someone writes their own pass and forgets to copy the id, then resolution could break. - * - * We may be able to eliminate this issue by keeping everything internal and running the typing pass first. - * This is simple enough for now. - */ - private fun Rel.Type.copyWithSchema(types: List): Rel.Type { - assert(types.size == schema.size) { "Illegal copy, types size does not matching bindings list size" } - return this.copy(schema = schema.mapIndexed { i, binding -> binding.copy(type = types[i]) }) - } + private fun Rel.type(stack: List, strategy: Scope = Scope.LOCAL): Rel = + RelTyper(stack, strategy).visitRel(this, null) - private fun Identifier.toBindingPath() = when (this) { - is Identifier.Qualified -> this.toBindingPath() - is Identifier.Symbol -> BindingPath(listOf(this.toBindingName())) - } + /** + * This types the [Rex] given the input record ([input]) and [stack] of [TypeEnv] (representing the outer scopes). 
+ */ + private fun Rex.type(input: List, stack: List, strategy: Scope = Scope.LOCAL) = + RexTyper(TypeEnv(input, stack), strategy).visitRex(this, this.type) - private fun Identifier.Qualified.toBindingPath() = - BindingPath(steps = listOf(this.root.toBindingName()) + steps.map { it.toBindingName() }) + /** + * This types the [Rex] given a [TypeEnv]. We use the [TypeEnv.schema] as the input schema and the [TypeEnv.outer] + * as the outer scopes/ + */ + private fun Rex.type(typeEnv: TypeEnv, strategy: Scope = Scope.LOCAL) = + RexTyper(typeEnv, strategy).visitRex(this, this.type) - private fun Identifier.Symbol.toBindingName() = BindingName( - name = symbol, - case = when (caseSensitivity) { - Identifier.CaseSensitivity.SENSITIVE -> BindingCase.SENSITIVE - Identifier.CaseSensitivity.INSENSITIVE -> BindingCase.INSENSITIVE + /** + * I found decorating the tree with the binding names (for resolution) was easier than associating introduced + * bindings with a node via an id->list map. ONLY because right now I don't think we have a good way + * of managing ids when trees are rewritten. + * + * We need a good answer for these questions before going for it: + * - If you copy, should the id should come along for the ride? + * - If someone writes their own pass and forgets to copy the id, then resolution could break. + * + * We may be able to eliminate this issue by keeping everything internal and running the typing pass first. + * This is simple enough for now. 
+ */ + private fun Rel.Type.copyWithSchema(types: List): Rel.Type { + assert(types.size == schema.size) { "Illegal copy, types size does not matching bindings list size" } + return this.copy(schema = schema.mapIndexed { i, binding -> binding.copy(type = types[i]) }) } - ) - - private fun Rel.isOrdered(): Boolean = type.props.contains(Rel.Prop.ORDERED) - /** - * Produce a union type from all the - */ - private fun List.toUnionType(): StaticType = AnyOfType(map { it.type }.toSet()).flatten() - - private fun getElementTypeForFromSource(fromSourceType: StaticType): StaticType = when (fromSourceType) { - is BagType -> fromSourceType.elementType - is ListType -> fromSourceType.elementType - is AnyType -> ANY - is AnyOfType -> AnyOfType(fromSourceType.types.map { getElementTypeForFromSource(it) }.toSet()) - // All the other types coerce into a bag of themselves (including null/missing/sexp). - else -> fromSourceType - } - - private fun assertAsInt(type: StaticType) { - if (type.flatten().allTypes.any { variant -> variant is IntType }.not()) { - handleUnexpectedType(type, setOf(StaticType.INT)) + private fun Identifier.toBindingPath() = when (this) { + is Identifier.Qualified -> this.toBindingPath() + is Identifier.Symbol -> BindingPath(listOf(this.toBindingName())) } - } - - // ERRORS - - private fun handleUndefinedVariable(id: Identifier) { - val publicId = id.toBindingPath() - onProblem( - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, details = PlanningProblemDetails.UndefinedVariable(publicId) - ) - ) - } - - private fun handleUnexpectedType(actual: StaticType, expected: Set) { - onProblem( - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = PlanningProblemDetails.UnexpectedType(actual, expected), - ) - ) - } - private fun handleUnknownCast(node: Rex.Op.Cast.Unresolved) { - onProblem( - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = PlanningProblemDetails.UnknownFunction( - identifier = "CAST( AS ${node.target})", args = 
listOf(node.arg.type) - ) - ) - ) - } + private fun Identifier.Qualified.toBindingPath() = + BindingPath(steps = listOf(this.root.toBindingName()) + steps.map { it.toBindingName() }) - private fun handleUnknownAggregation(node: Rel.Op.Aggregate.Call.Unresolved) { - onProblem( - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = PlanningProblemDetails.UnknownFunction( - identifier = node.name, - args = node.args.map { it.type } - ) - ) - ) - } - - private fun handleUnknownFunction(node: Rex.Op.Call.Unresolved, args: List) { - onProblem( - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = PlanningProblemDetails.UnknownFunction( - identifier = node.identifier.debug(), - args = args.map { it.type } - ) - ) + private fun Identifier.Symbol.toBindingName() = BindingName( + name = symbol, + case = when (caseSensitivity) { + Identifier.CaseSensitivity.SENSITIVE -> BindingCase.SENSITIVE + Identifier.CaseSensitivity.INSENSITIVE -> BindingCase.INSENSITIVE + } ) - } - private fun handleAlwaysMissing() { - onProblem( - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = PlanningProblemDetails.ExpressionAlwaysReturnsNullOrMissing - ) - ) - } + private fun Rel.isOrdered(): Boolean = type.props.contains(Rel.Prop.ORDERED) - private fun handleUnresolvedExcludeRoot(root: Identifier) { - onProblem( - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = PlanningProblemDetails.UnresolvedExcludeExprRoot(root.debug()) - ) - ) - } + /** + * Produce a union type from all the + */ + private fun List.toUnionType(): StaticType = AnyOfType(map { it.type }.toSet()).flatten() + + private fun getElementTypeForFromSource(fromSourceType: StaticType): StaticType = when (fromSourceType) { + is BagType -> fromSourceType.elementType + is ListType -> fromSourceType.elementType + is AnyType -> ANY + is AnyOfType -> AnyOfType(fromSourceType.types.map { getElementTypeForFromSource(it) }.toSet()) + // All the other types coerce into a bag of themselves 
(including null/missing/sexp). + else -> fromSourceType + } - // HELPERS + private fun assertAsInt(type: StaticType) = + type.flatten().allTypes.any { variant -> variant is IntType } - private fun Identifier.debug(): String = when (this) { - is Identifier.Qualified -> (listOf(root.debug()) + steps.map { it.debug() }).joinToString(".") - is Identifier.Symbol -> when (caseSensitivity) { - Identifier.CaseSensitivity.SENSITIVE -> "\"$symbol\"" - Identifier.CaseSensitivity.INSENSITIVE -> symbol + // HELPERS + /** + * This will make all binding values nullables. If the value is a struct, each field will be nullable. + * + * Note, this does not handle union types or nullable struct types. + */ + private fun List.pad() = map { + val type = when (val t = it.type) { + is StructType -> t.withNullableFields() + else -> t.asNullable() + } + relBinding(it.name, type) } - } - /** - * This will make all binding values nullables. If the value is a struct, each field will be nullable. - * - * Note, this does not handle union types or nullable struct types. 
- */ - private fun List.pad() = map { - val type = when (val t = it.type) { - is StructType -> t.withNullableFields() - else -> t.asNullable() + private fun StructType.withNullableFields(): StructType { + return copy(fields.map { it.copy(value = it.value.asNullable()) }) } - relBinding(it.name, type) - } - private fun StructType.withNullableFields(): StructType { - return copy(fields.map { it.copy(value = it.value.asNullable()) }) - } - - private fun excludeBindings(input: List, item: Rel.Op.Exclude.Path): List { - var matchedRoot = false - val output = input.map { - when (val root = item.root) { - is Rex.Op.Var.Unresolved -> { - when (val id = root.identifier) { - is Identifier.Symbol -> { - if (id.isEquivalentTo(it.name)) { - matchedRoot = true - // recompute the StaticType of this binding after applying the exclusions - val type = it.type.exclude(item.steps, lastStepOptional = false) - it.copy(type = type) - } else { - it + private fun excludeBindings(input: List, item: Rel.Op.Exclude.Path): List { + var matchedRoot = false + val output = input.map { + when (val root = item.root) { + is Rex.Op.Var.Unresolved -> { + when (val id = root.identifier) { + is Identifier.Symbol -> { + if (id.isEquivalentTo(it.name)) { + matchedRoot = true + // recompute the StaticType of this binding after applying the exclusions + val type = it.type.exclude(item.steps, lastStepOptional = false) + it.copy(type = type) + } else { + it + } } + is Identifier.Qualified -> it } - is Identifier.Qualified -> it } + is Rex.Op.Var.Local, is Rex.Op.Var.Global -> it + else -> it } - is Rex.Op.Var.Local, is Rex.Op.Var.Global -> it } + return output } - if (!matchedRoot && item.root is Rex.Op.Var.Unresolved) handleUnresolvedExcludeRoot(item.root.identifier) - return output - } - private fun Identifier.Symbol.isEquivalentTo(other: String): Boolean = when (caseSensitivity) { - Identifier.CaseSensitivity.SENSITIVE -> symbol.equals(other) - Identifier.CaseSensitivity.INSENSITIVE -> symbol.equals(other, 
ignoreCase = true) - } + private fun Identifier.Symbol.isEquivalentTo(other: String): Boolean = when (caseSensitivity) { + Identifier.CaseSensitivity.SENSITIVE -> symbol.equals(other) + Identifier.CaseSensitivity.INSENSITIVE -> symbol.equals(other, ignoreCase = true) + } - /** - * Pretty-print a path and its root type. - * - * @return - */ - private fun Rex.Op.Path.debug(): String { - val steps = mutableListOf() - var curr: Rex = rex(ANY, this) - while (true) { - curr = when (val op = curr.op) { - is Rex.Op.Path.Index -> { - steps.add("${op.key}") - op.root - } - is Rex.Op.Path.Key -> { - val k = op.key.op - if (k is Rex.Op.Lit && k.value is TextValue<*>) { - steps.add("${k.value.string}") - } else { + /** + * Pretty-print a path and its root type. + * + * @return + */ + private fun Rex.Op.Path.debug(): String { + val steps = mutableListOf() + var curr: Rex = rex(ANY, this) + while (true) { + curr = when (val op = curr.op) { + is Rex.Op.Path.Index -> { steps.add("${op.key}") + op.root } - op.root - } - is Rex.Op.Path.Symbol -> { - steps.add(op.key) - op.root + is Rex.Op.Path.Key -> { + val k = op.key.op + if (k is Rex.Op.Lit && k.value is TextValue<*>) { + steps.add("${k.value.string}") + } else { + steps.add("${op.key}") + } + op.root + } + is Rex.Op.Path.Symbol -> { + steps.add(op.key) + op.root + } + else -> break } - else -> break } + // curr is root + return "`${steps.joinToString(".")}` on root $curr" } - // curr is root - return "`${steps.joinToString(".")}` on root $curr" } -} + \ No newline at end of file diff --git a/partiql-planner/src/main/resources/partiql_plan_internal.ion b/partiql-planner/src/main/resources/partiql_plan_internal.ion index 4e1c06830a..12125f645a 100644 --- a/partiql-planner/src/main/resources/partiql_plan_internal.ion +++ b/partiql-planner/src/main/resources/partiql_plan_internal.ion @@ -5,6 +5,7 @@ imports::{ static_type::'org.partiql.types.StaticType', fn_signature::'org.partiql.spi.fn.FnSignature', 
agg_signature::'org.partiql.spi.fn.AggSignature', + problem::'org.partiql.errors.Problem' ], } @@ -212,8 +213,29 @@ rex::{ args: list::[rex], }, + // Internal Error Node: + // Upon encountering an error, e.g., unknown_function(1) + // an error node will be populated in the plan to replace the top node + // i.e., + // |_ Rex.Op.Call.Unresolved["unknown_function"] + // |_ Lit[value=Int32ValueImpl(value = 1))] + // will become + // |_ Rex.Op.Err[Problem(location = ...., message = "unknown function `unknown function...`)] + // |_ Lit[value=Int32ValueImpl(value = 1))] err::{ - message: string, + problem: problem, + causes: list::['.rex.op'], + }, + + // Internal MISSING Node: + // Upon encountering an operation that always returns missing, + // e.g., t.a where a does not exist in t + // a missing node will be populated in the plan to replace the top node + // i.e., + // + missing::{ + problem: problem, + causes: list::['.rex.op'], }, ], } @@ -334,7 +356,7 @@ rel::{ paths: list::[path], _: [ path::{ - root: '.rex.op.var', + root: '.rex.op', steps: list::[step], }, step::{ diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/PlanTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/PlanTest.kt index 6a4e514619..7c8d0bdc8e 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/PlanTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/PlanTest.kt @@ -6,6 +6,7 @@ import org.junit.jupiter.api.DynamicNode import org.junit.jupiter.api.DynamicTest import org.junit.jupiter.api.TestFactory import org.partiql.parser.PartiQLParser +import org.partiql.plan.PartiQLPlan import org.partiql.plan.PlanNode import org.partiql.plan.debug.PlanPrinter import org.partiql.planner.test.PartiQLTest @@ -75,22 +76,22 @@ class PlanTest { override fun getUserId(): String = "user-id" } - val session: (PartiQLTest.Key) -> PartiQLPlanner.Session = { key -> - PartiQLPlanner.Session( - queryId = key.toString(), + val pipeline: (PartiQLTest, Boolean) ->
PartiQLPlanner.Result = { test, isSignalMode -> + val session = PartiQLPlanner.Session( + queryId = test.key.toString(), userId = "user_id", currentCatalog = "default", - currentDirectory = listOf(), + currentDirectory = listOf("SCHEMA"), catalogs = mapOf("default" to buildMetadata("default")), - instant = Instant.now(), + instant = Instant.now() ) - } - - val pipeline: (PartiQLTest) -> PartiQLPlanner.Result = { test -> val problemCollector = ProblemCollector() val ast = PartiQLParser.default().parse(test.statement).root - val planner = PartiQLPlanner.default() - planner.plan(ast, session(test.key), problemCollector) + val planner = when (isSignalMode) { + true -> PartiQLPlanner.builder().signalMode().build() + else -> PartiQLPlanner.builder().build() + } + planner.plan(ast, session, problemCollector) } fun buildMetadata(catalogName: String): ConnectorMetadata { @@ -98,8 +99,8 @@ class PlanTest { // Insert binding val name = BindingPath( listOf( - BindingName("default", BindingCase.INSENSITIVE), - BindingName("a", BindingCase.INSENSITIVE), + BindingName("SCHEMA", BindingCase.INSENSITIVE), + BindingName("T", BindingCase.INSENSITIVE), ) ) val obj = MemoryObject(type) @@ -132,28 +133,34 @@ class PlanTest { val group = parent.name val tests = parse(group, file) - val children = tests.map { + val children = tests.map { test -> // Prepare - val displayName = it.key.toString() + val displayName = test.key.toString() // Assert DynamicTest.dynamicTest(displayName) { - val input = input[it.key] ?: error("no test cases") - - val inputPlan = pipeline.invoke(input).plan - val outputPlan = pipeline.invoke(it).plan - assert(inputPlan.isEquaivalentTo(outputPlan)) { - buildString { - this.appendLine("expect plan equivalence") - PlanPrinter.append(this, inputPlan) - PlanPrinter.append(this, outputPlan) - } + val input = input[test.key] ?: error("no test cases") + + listOf(true, false).forEach { isSignal -> + val inputPlan = pipeline.invoke(input, isSignal).plan + val outputPlan = 
pipeline.invoke(test, isSignal).plan + assertPlanEqual(inputPlan, outputPlan) } } } return dynamicContainer(file.nameWithoutExtension, children) } + private fun assertPlanEqual(inputPlan: PartiQLPlan, outputPlan: PartiQLPlan) { + assert(inputPlan.isEquaivalentTo(outputPlan)) { + buildString { + this.appendLine("expect plan equivalence") + PlanPrinter.append(this, inputPlan) + PlanPrinter.append(this, outputPlan) + } + } + } + private fun parse(group: String, file: File): List { val tests = mutableListOf() var name = "" diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerErrorReportingTests.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerErrorReportingTests.kt new file mode 100644 index 0000000000..bf5ea6f9c7 --- /dev/null +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/PlannerErrorReportingTests.kt @@ -0,0 +1,458 @@ +package org.partiql.planner + +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import org.partiql.ast.Statement +import org.partiql.errors.Problem +import org.partiql.errors.ProblemSeverity +import org.partiql.parser.PartiQLParserBuilder +import org.partiql.plan.debug.PlanPrinter +import org.partiql.planner.util.ProblemCollector +import org.partiql.plugins.memory.MemoryCatalog +import org.partiql.plugins.memory.MemoryConnector +import org.partiql.spi.connector.ConnectorSession +import org.partiql.types.BagType +import org.partiql.types.StaticType +import org.partiql.types.StructType +import org.partiql.types.TupleConstraint + +internal class PlannerErrorReportingTests { + val catalogName = "mode_test" + val userId = "test-user" + val queryId = "query" + + val catalog = MemoryCatalog + .PartiQL() + .name(catalogName) + .define("missing_binding", StaticType.MISSING) + .define("atomic", StaticType.INT2) + .define("collection_no_missing_atomic", BagType(StaticType.INT2)) + .define("collection_contain_missing_atomic", 
BagType(StaticType.unionOf(StaticType.INT2, StaticType.MISSING))) + .define("struct_no_missing", closedStruct(StructType.Field("f1", StaticType.INT2))) + .define( + "struct_with_missing", + closedStruct( + StructType.Field("f1", StaticType.unionOf(StaticType.INT2, StaticType.MISSING)), + StructType.Field("f2", StaticType.MISSING), + ) + ) + .build() + + val metadata = MemoryConnector(catalog).getMetadata( + object : ConnectorSession { + override fun getQueryId(): String = "q" + override fun getUserId(): String = "s" + } + ) + + val session = PartiQLPlanner.Session( + queryId = queryId, + userId = userId, + currentCatalog = catalogName, + catalogs = mapOf(catalogName to metadata), + ) + + val parser = PartiQLParserBuilder().build() + + val statement: ((String) -> Statement) = { query -> + parser.parse(query).root + } + + fun assertProblem( + plan: org.partiql.plan.PlanNode, + problems: List, + vararg block: () -> Boolean + ) { + block.forEachIndexed { index, function -> + assert(function.invoke()) { + buildString { + this.appendLine("assertion #${index + 1} failed") + + this.appendLine("--------Plan---------") + PlanPrinter.append(this, plan) + + this.appendLine("----------problems---------") + problems.forEach { + this.appendLine(it.toString()) + } + } + } + } + } + + data class TestCase( + val query: String, + val isSignal: Boolean, + val assertion: (List) -> List<() -> Boolean>, + val expectedType: StaticType = StaticType.MISSING + ) + + companion object { + fun closedStruct(vararg field: StructType.Field): StructType = + StructType( + field.toList(), + contentClosed = true, + emptyList(), + setOf( + TupleConstraint.Open(false), + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Ordered + ) + ) + + private fun assertOnProblemCount(warningCount: Int, errorCount: Int): (List) -> List<() -> Boolean> = { problems -> + listOf( + { problems.filter { it.details.severity == ProblemSeverity.WARNING }.size == warningCount }, + { problems.filter { it.details.severity == 
ProblemSeverity.ERROR }.size == errorCount }, + ) + } + + /** + * These tests focus on MissingOpBehavior. + */ + @JvmStatic + fun testProblems() = listOf( + // Literal MISSING Does not throw warnings or errors in either mode. + TestCase( + "MISSING", + false, + assertOnProblemCount(0, 0) + ), + TestCase( + "MISSING", + true, + assertOnProblemCount(0, 0) + ), + // Unresolved variable + TestCase( + "var_not_exist", + false, + assertOnProblemCount(1, 0) + ), + TestCase( + "var_not_exist", + true, + assertOnProblemCount(0, 1) + ), + + // Function propagates missing in quiet mode + TestCase( + "1 + MISSING", + false, + assertOnProblemCount(1, 0) + ), + // This will be a non-resolved function error. + // As plus does not contain a function that matches the argument types + // int32 and missing. + // Error in signaling mode. + TestCase( + "1 + MISSING", + true, + assertOnProblemCount(0, 1) + ), + // Attempting to do path navigation(symbol) on missing(which is not tuple) + // returns missing in quiet mode, and error out in signal mode + TestCase( + "MISSING.a", + false, + assertOnProblemCount(1, 0) + ), + TestCase( + "MISSING.a", + true, + assertOnProblemCount(0, 1) + ), + // Attempting to do path navigation(index) on missing(which is not list) + // returns missing in quiet mode, and error out in signal mode + TestCase( + "MISSING[1]", + false, + assertOnProblemCount(1, 0) + ), + TestCase( + "MISSING[1]", + true, + assertOnProblemCount(0, 1) + ), + // Attempting to do path navigation(key) on missing(which is not tuple) + // returns missing in quiet mode, and error out in signal mode + TestCase( + "MISSING['a']", + false, + assertOnProblemCount(1, 0) + ), + TestCase( + "MISSING['a']", + true, + assertOnProblemCount(0, 1) + ), + // Chained, demonstrate missing trace.
+ TestCase( + "MISSING['a'].a", + false, + assertOnProblemCount(2, 0) + ), + TestCase( + "MISSING['a'].a", + true, + assertOnProblemCount(0, 2) + ), + TestCase( + """ + -- one branch is missing, no problem + CASE WHEN + 1 = 1 THEN MISSING + ELSE 2 END + """.trimIndent(), + false, + assertOnProblemCount(0, 0), + StaticType.unionOf(StaticType.INT4, StaticType.MISSING) + ), + TestCase( + """ + -- one branch is missing, no problem + CASE WHEN + 1 = 1 THEN MISSING + ELSE 2 END + """.trimIndent(), + true, + assertOnProblemCount(0, 0), + StaticType.unionOf(StaticType.INT4, StaticType.MISSING) + ), + TestCase( + """ + -- both branches are missing, problem + CASE WHEN + 1 = 1 THEN MISSING + ELSE MISSING END + """.trimIndent(), + false, + assertOnProblemCount(1, 0), + ), + TestCase( + """ + -- both branches are missing, problem + CASE WHEN + 1 = 1 THEN MISSING + ELSE MISSING END + """.trimIndent(), + true, + assertOnProblemCount(0, 1), + ), + ) + + /** + * Those tests focus on continuation + */ + @JvmStatic + fun testContinuation() = listOf( + // Continuation with data type mismatch + // the expected type for this case is missing. + // as we know for sure that a + b returns missing. + TestCase( + " 'a' + 'b' ", + false, + assertOnProblemCount(1, 0), + StaticType.MISSING + ), + TestCase( + " 'a' + 'b' ", + true, + assertOnProblemCount(0, 1), + StaticType.MISSING + ), + + // No function with given name is registered. + // always going to return error regardless of mode. + // The expected type for continuation is ANY. 
+ TestCase( + "not_a_function(1)", + false, + assertOnProblemCount(0, 1), + StaticType.ANY + ), + TestCase( + "not_a_function(1)", + true, + assertOnProblemCount(0, 1), + StaticType.ANY + ), + + // 1 + not_a_function(1) + // The continuation will return all numeric type + TestCase( + "1 + not_a_function(1)", + false, + assertOnProblemCount(0, 1), + StaticType.unionOf( + StaticType.INT4, + StaticType.INT8, + StaticType.INT8, + StaticType.INT, + StaticType.FLOAT, + StaticType.DECIMAL, // Parameter? + StaticType.MISSING, +// StaticType.NULL // TODO: There is a bug in function resolution, null type is not there. + ).flatten() + ), + TestCase( + "1 + not_a_function(1)", + false, + assertOnProblemCount(0, 1), + StaticType.unionOf( + StaticType.INT4, + StaticType.INT8, + StaticType.INT8, + StaticType.INT, + StaticType.FLOAT, + StaticType.DECIMAL, // Parameter? + StaticType.MISSING, +// StaticType.NULL // TODO: There is a bug in function resolution, null type is not there. + ).flatten() + ), + + TestCase( + """ + SELECT + t.f1, -- SUCCESS + t.f2 -- no such field + FROM struct_no_missing as t + """.trimIndent(), + false, + assertOnProblemCount(1, 0), + BagType(closedStruct(StructType.Field("f1", StaticType.INT2))) + ), + TestCase( + """ + SELECT + t.f1, -- SUCCESS + t.f2 -- no such field + FROM struct_no_missing as t + """.trimIndent(), + true, + assertOnProblemCount(0, 1), + BagType(closedStruct(StructType.Field("f1", StaticType.INT2))) + ), + TestCase( + """ + SELECT + t.f1, -- OK + t.f2, -- always missing + t.f3 -- no such field + FROM struct_with_missing as t + """.trimIndent(), + false, + assertOnProblemCount(2, 0), + BagType(closedStruct(StructType.Field("f1", StaticType.unionOf(StaticType.INT2, StaticType.MISSING)))) + ), + TestCase( + """ + SELECT + t.f1, -- OK + t.f2, -- always missing + t.f3 -- no such field + FROM struct_with_missing as t + """.trimIndent(), + true, + assertOnProblemCount(0, 2), + BagType(closedStruct(StructType.Field("f1", 
StaticType.unionOf(StaticType.INT2, StaticType.MISSING)))) + ), + + // TODO: EXCLUDE ERROR reporting is not completed. + // Currently we only handle root resolution. + // i.e., if the root of the exclude path is not resolved, + // we can report the problem. + // but we have not yet handled the situation in which + // the root is resolvable but the path is not. + TestCase( + """ + SELECT * + EXCLUDE t1.f1 -- no such root + FROM struct_no_missing as t + """.trimIndent(), + false, + assertOnProblemCount(1, 0), + BagType(closedStruct(StructType.Field("f1", StaticType.INT2))) + ), + TestCase( + """ + SELECT * + EXCLUDE t1.f1 -- no such root + FROM struct_no_missing as t + """.trimIndent(), + true, + assertOnProblemCount(0, 1), + BagType(closedStruct(StructType.Field("f1", StaticType.INT2))) + ), +// TestCase( +// """ +// SELECT * +// EXCLUDE t.f2 -- no such field +// FROM struct_no_missing as t +// """.trimIndent(), +// false, +// assertOnProblemCount(1, 0), +// BagType(closedStruct(StructType.Field("f1", StaticType.INT2))) +// ), +// TestCase( +// """ +// SELECT * +// EXCLUDE t.f2 -- no such field +// FROM struct_no_missing as t +// """.trimIndent(), +// true, +// assertOnProblemCount(0, 1), +// BagType(closedStruct(StructType.Field("f1", StaticType.INT2))) +// ), + ) + } + + private fun runTestCase(tc: TestCase) { + val planner = when (tc.isSignal) { + true -> PartiQLPlanner.builder().signalMode().build() + else -> PartiQLPlanner.builder().build() + } + val pc = ProblemCollector() + val res = planner.plan(statement(tc.query), session, pc) + val problems = pc.problems + val plan = res.plan + + assertProblem( + plan, problems, + *tc.assertion(problems).toTypedArray() + ) + tc.expectedType.assertStaticTypeEqual((plan.statement as org.partiql.plan.Statement.Query).root.type) + } + + @ParameterizedTest + @MethodSource("testProblems") + fun testProblems(tc: TestCase) = runTestCase(tc) + + @ParameterizedTest + @MethodSource("testContinuation") + fun testContinuation(tc: 
TestCase) = runTestCase(tc) + + private fun StaticType.assertStaticTypeEqual(other: StaticType) { + val thisAll = this.allTypes.toSet() + val otherAll = other.allTypes.toSet() + val diff = (thisAll - otherAll) + (otherAll - thisAll) + assert(diff.isEmpty()) { + buildString { + this.appendLine("expected: ") + thisAll.forEach { + this.append("$it, ") + } + this.appendLine() + this.appendLine("actual") + otherAll.forEach { + this.append("$it, ") + } + this.appendLine() + this.appendLine("diff") + diff.forEach { + this.append("$it, ") + } + } + } + } +} diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt index c5eef7767f..262885378a 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt @@ -7,7 +7,7 @@ import org.partiql.parser.PartiQLParser import org.partiql.plan.Statement import org.partiql.plan.debug.PlanPrinter import org.partiql.planner.PartiQLPlanner -import org.partiql.planner.PlanningProblemDetails +import org.partiql.planner.internal.PlanningProblemDetails import org.partiql.planner.test.PartiQLTest import org.partiql.planner.test.PartiQLTestProvider import org.partiql.planner.util.ProblemCollector @@ -46,7 +46,7 @@ abstract class PartiQLTyperTestBase { currentCatalog = catalog, catalogs = mapOf( catalog to metadata - ) + ), ) } diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTest.kt index c2ae2a310a..4108f18c42 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTest.kt @@ -171,7 +171,7 @@ class PlanTyperTest { ) private 
fun getTyper(): PlanTyperWrapper { - val collector = ProblemCollector() + ProblemCollector() val env = Env( PartiQLPlanner.Session( queryId = Random().nextInt().toString(), @@ -183,13 +183,12 @@ class PlanTyperTest { ), ) ) - return PlanTyperWrapper(PlanTyper(env, collector), collector) + return PlanTyperWrapper(PlanTyper(env)) } } private class PlanTyperWrapper( internal val typer: PlanTyper, - internal val collector: ProblemCollector, ) /** diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt index a7427f3b92..aa9e289637 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt @@ -11,13 +11,12 @@ import org.junit.jupiter.params.provider.ArgumentsProvider import org.junit.jupiter.params.provider.ArgumentsSource import org.junit.jupiter.params.provider.MethodSource import org.partiql.errors.Problem -import org.partiql.errors.UNKNOWN_PROBLEM_LOCATION import org.partiql.parser.PartiQLParser import org.partiql.plan.PartiQLPlan import org.partiql.plan.Statement import org.partiql.plan.debug.PlanPrinter import org.partiql.planner.PartiQLPlanner -import org.partiql.planner.PlanningProblemDetails +import org.partiql.planner.internal.ProblemGenerator import org.partiql.planner.internal.typer.PlanTyperTestsPorted.TestCase.ErrorTestCase import org.partiql.planner.internal.typer.PlanTyperTestsPorted.TestCase.SuccessTestCase import org.partiql.planner.internal.typer.PlanTyperTestsPorted.TestCase.ThrowingExceptionTestCase @@ -92,12 +91,14 @@ class PlanTyperTestsPorted { companion object { private val parser = PartiQLParser.default() - private val planner = PartiQLPlanner.default() + private val planner = PartiQLPlanner.builder().signalMode().build() - private fun assertProblemExists(problem: () 
-> Problem) = ProblemHandler { problems, ignoreSourceLocation -> + private fun assertProblemExists(problem: Problem) = ProblemHandler { problems, ignoreSourceLocation -> when (ignoreSourceLocation) { - true -> assertTrue("Expected to find ${problem.invoke()} in $problems") { problems.any { it.details == problem.invoke().details } } - false -> assertTrue("Expected to find ${problem.invoke()} in $problems") { problems.any { it == problem.invoke() } } + true -> assertTrue("Expected to find $problem in $problems") { + problems.any { it.details == problem.details } + } + false -> assertTrue("Expected to find $problem in $problems") { problems.any { it == problem } } } } @@ -659,18 +660,15 @@ class PlanTyperTestsPorted { name = "Current User (String) PLUS String", query = "CURRENT_USER + 'hello'", expected = StaticType.MISSING, - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UnknownFunction( - "plus", - listOf( - StaticType.unionOf(StaticType.STRING, StaticType.NULL), - StaticType.STRING, - ), - ) + problemHandler = assertProblemExists( + ProblemGenerator.incompatibleTypesForOp( + listOf( + StaticType.unionOf(StaticType.STRING, StaticType.NULL), + StaticType.STRING, + ), + "PLUS", ) - } + ) ), ) @@ -735,29 +733,23 @@ class PlanTyperTestsPorted { name = "BITWISE_AND_MISSING_OPERAND", query = "1 & MISSING", expected = StaticType.MISSING, - problemHandler = assertProblemExists { - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = PlanningProblemDetails.UnknownFunction( - "bitwise_and", - listOf(StaticType.INT4, StaticType.MISSING) - ) + problemHandler = assertProblemExists( + ProblemGenerator.incompatibleTypesForOp( + listOf(StaticType.INT4, StaticType.MISSING), + "BITWISE_AND", ) - } + ) ), ErrorTestCase( name = "BITWISE_AND_NON_INT_OPERAND", query = "1 & 'NOT AN INT'", expected = StaticType.MISSING, - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - 
PlanningProblemDetails.UnknownFunction( - "bitwise_and", - listOf(StaticType.INT4, StaticType.STRING) - ) + problemHandler = assertProblemExists( + ProblemGenerator.incompatibleTypesForOp( + listOf(StaticType.INT4, StaticType.STRING), + "BITWISE_AND", ) - } + ) ), ) @@ -908,12 +900,9 @@ class PlanTyperTestsPorted { ) ) ), - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable(insensitiveId("a")) - ) - } + problemHandler = assertProblemExists( + ProblemGenerator.undefinedVariable(insensitiveId("a")) + ) ), ) @@ -2012,12 +2001,9 @@ class PlanTyperTestsPorted { ) ) ), - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UnresolvedExcludeExprRoot("nonsense") - ) - } + problemHandler = assertProblemExists( + ProblemGenerator.unresolvedExcludedExprRoot("nonsense") + ) ), // EXCLUDE regression test (behavior subject to change pending RFC); could give error/warning SuccessTestCase( @@ -2126,12 +2112,9 @@ class PlanTyperTestsPorted { catalogPath = listOf("ddb"), query = "SELECT * FROM pets ORDER BY unknown_col", expected = TABLE_AWS_DDB_PETS_LIST, - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable(insensitiveId("unknown_col")) - ) - } + problemHandler = assertProblemExists( + ProblemGenerator.undefinedVariable(insensitiveId("unknown_col")) + ) ), ) @@ -2646,12 +2629,11 @@ class PlanTyperTestsPorted { SELECT VALUE 1 FROM "pql"."main"['employer'] AS e; """, expected = BagType(StaticType.INT4), - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable(idQualified("pql" to BindingCase.SENSITIVE, "main" to BindingCase.SENSITIVE)) + problemHandler = assertProblemExists( + ProblemGenerator.undefinedVariable( + idQualified("pql" to BindingCase.SENSITIVE, "main" to BindingCase.SENSITIVE) ) - } + ) ), ErrorTestCase( name = "Show 
that we can't use [] to reference a schema in a catalog. It can only be used on tuples.", @@ -2659,12 +2641,9 @@ class PlanTyperTestsPorted { SELECT VALUE 1 FROM "pql"['main']."employer" AS e; """, expected = BagType(StaticType.INT4), - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable(sensitiveId("pql")) - ) - } + problemHandler = assertProblemExists( + ProblemGenerator.undefinedVariable(sensitiveId("pql")) + ) ), SuccessTestCase( name = "Tuple indexing syntax on literal tuple with literal string key", @@ -2680,12 +2659,9 @@ class PlanTyperTestsPorted { { 'aBc': 1, 'AbC': 2.0 }['Ab' || 'C']; """, expected = StaticType.MISSING, - problemHandler = assertProblemExists { - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = PlanningProblemDetails.ExpressionAlwaysReturnsNullOrMissing - ) - } + problemHandler = assertProblemExists( + ProblemGenerator.expressionAlwaysReturnsMissing("Path Navigation always returns MISSING") + ) ), // The reason this is ANY is because we do not have support for constant-folding. 
We don't know what // CAST('Ab' || 'C' AS STRING) will evaluate to, and therefore, we don't know what the indexing operation @@ -2964,15 +2940,12 @@ class PlanTyperTestsPorted { >> AS t """.trimIndent(), expected = BagType(StaticType.MISSING), - problemHandler = assertProblemExists { - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = PlanningProblemDetails.UnknownFunction( - "pos", - listOf(StaticType.STRING) - ) + problemHandler = assertProblemExists( + ProblemGenerator.incompatibleTypesForOp( + listOf(StaticType.STRING), + "POS", ) - } + ) ), ErrorTestCase( name = """ @@ -2987,15 +2960,12 @@ class PlanTyperTestsPorted { >> AS t """.trimIndent(), expected = BagType(StaticType.MISSING), - problemHandler = assertProblemExists { - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = PlanningProblemDetails.UnknownFunction( - "pos", - listOf(StaticType.unionOf(StaticType.STRING, StaticType.BAG)) - ) + problemHandler = assertProblemExists( + ProblemGenerator.incompatibleTypesForOp( + listOf(StaticType.unionOf(StaticType.STRING, StaticType.BAG)), + "POS", ) - } + ) ), ErrorTestCase( name = """ @@ -3010,12 +2980,12 @@ class PlanTyperTestsPorted { """.trimIndent(), expected = BagType(StaticType.MISSING), // This is because we don't attempt to resolve function when args are error - problemHandler = assertProblemExists { - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = PlanningProblemDetails.ExpressionAlwaysReturnsNullOrMissing + problemHandler = assertProblemExists( + ProblemGenerator.incompatibleTypesForOp( + listOf(StaticType.MISSING), + "POS", ) - } + ) ), ErrorTestCase( name = """ @@ -3026,15 +2996,12 @@ class PlanTyperTestsPorted { +MISSING """.trimIndent(), expected = StaticType.MISSING, - problemHandler = assertProblemExists { - Problem( - sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = PlanningProblemDetails.UnknownFunction( - "pos", - listOf(StaticType.MISSING) - ) + problemHandler = assertProblemExists( + 
ProblemGenerator.incompatibleTypesForOp( + listOf(StaticType.MISSING), + "POS", ) - } + ) ), ) @@ -3250,7 +3217,7 @@ class PlanTyperTestsPorted { USER_ID, tc.catalog, tc.catalogPath, - catalogs = mapOf(*catalogs.toTypedArray()) + catalogs = mapOf(*catalogs.toTypedArray()), ) val hasQuery = tc.query != null @@ -3291,7 +3258,7 @@ class PlanTyperTestsPorted { USER_ID, tc.catalog, tc.catalogPath, - catalogs = mapOf(*catalogs.toTypedArray()) + catalogs = mapOf(*catalogs.toTypedArray()), ) val collector = ProblemCollector() @@ -3337,7 +3304,7 @@ class PlanTyperTestsPorted { USER_ID, tc.catalog, tc.catalogPath, - catalogs = mapOf(*catalogs.toTypedArray()) + catalogs = mapOf(*catalogs.toTypedArray()), ) val collector = ProblemCollector() val exception = assertThrows { @@ -3372,12 +3339,9 @@ class PlanTyperTestsPorted { ) ), ), - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable(insensitiveId("pets")) - ) - } + problemHandler = assertProblemExists( + ProblemGenerator.undefinedVariable(insensitiveId("pets")) + ) ), TestCase.ErrorTestCase( name = "Pets should not be accessible #2", @@ -3394,12 +3358,9 @@ class PlanTyperTestsPorted { ) ), ), - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable(insensitiveId("pets")) - ) - } + problemHandler = assertProblemExists( + ProblemGenerator.undefinedVariable(insensitiveId("pets")) + ) ), TestCase.SuccessTestCase( name = "Project all explicitly", @@ -3450,19 +3411,16 @@ class PlanTyperTestsPorted { ) ), ), - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable( - BindingPath( - steps = listOf( - BindingName("ddb", BindingCase.INSENSITIVE), - BindingName("pets", BindingCase.INSENSITIVE), - ) + problemHandler = assertProblemExists( + ProblemGenerator.undefinedVariable( + BindingPath( + steps = listOf( + BindingName("ddb", 
BindingCase.INSENSITIVE), + BindingName("pets", BindingCase.INSENSITIVE), ) ) ) - } + ) ), TestCase.SuccessTestCase( name = "Test #10", @@ -3616,15 +3574,12 @@ class PlanTyperTestsPorted { catalogPath = DB_SCHEMA_MARKETS, query = "order_info.customer_id IN 'hello'", expected = StaticType.MISSING, - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UnknownFunction( - "in_collection", - listOf(StaticType.INT4, StaticType.STRING), - ) + problemHandler = assertProblemExists( + ProblemGenerator.incompatibleTypesForOp( + listOf(StaticType.INT4, StaticType.STRING), + "IN_COLLECTION" ) - } + ) ), SuccessTestCase( name = "BETWEEN", @@ -3639,19 +3594,16 @@ class PlanTyperTestsPorted { catalogPath = DB_SCHEMA_MARKETS, query = "order_info.customer_id BETWEEN 1 AND 'a'", expected = StaticType.MISSING, - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UnknownFunction( - "between", - listOf( - StaticType.INT4, - StaticType.INT4, - StaticType.STRING - ), - ) + problemHandler = assertProblemExists( + ProblemGenerator.incompatibleTypesForOp( + listOf( + StaticType.INT4, + StaticType.INT4, + StaticType.STRING + ), + "BETWEEN" ) - } + ) ), SuccessTestCase( name = "LIKE", @@ -3666,15 +3618,12 @@ class PlanTyperTestsPorted { catalogPath = DB_SCHEMA_MARKETS, query = "order_info.ship_option LIKE 3", expected = StaticType.MISSING, - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UnknownFunction( - "like", - listOf(StaticType.STRING, StaticType.INT4), - ) + problemHandler = assertProblemExists( + ProblemGenerator.incompatibleTypesForOp( + listOf(StaticType.STRING, StaticType.INT4), + "LIKE", ) - } + ) ), SuccessTestCase( name = "Case Insensitive success", @@ -3684,12 +3633,13 @@ class PlanTyperTestsPorted { expected = TYPE_BOOL ), // MISSING = 1 + // TODO: Semantic not finalized ErrorTestCase( name = "Case Sensitive failure", 
catalog = CATALOG_DB, catalogPath = DB_SCHEMA_MARKETS, query = "order_info.\"CUSTOMER_ID\" = 1", - expected = StaticType.MISSING + expected = TYPE_BOOL ), SuccessTestCase( name = "Case Sensitive success", @@ -3724,15 +3674,15 @@ class PlanTyperTestsPorted { catalog = CATALOG_DB, catalogPath = DB_SCHEMA_MARKETS, query = "non_existing_column = 1", + // non_existing_column get typed as missing // Function resolves to EQ__ANY_ANY__BOOL // Which can return BOOL Or NULL - expected = StaticType.MISSING, - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable(insensitiveId("non_existing_column")) + expected = TYPE_BOOL, + problemHandler = assertProblemExists( + ProblemGenerator.undefinedVariable( + insensitiveId("non_existing_column") ) - } + ) ), ErrorTestCase( name = "Bad comparison", @@ -3740,15 +3690,12 @@ class PlanTyperTestsPorted { catalogPath = DB_SCHEMA_MARKETS, query = "order_info.customer_id = 1 AND 1", expected = StaticType.MISSING, - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UnknownFunction( - "and", - listOf(StaticType.BOOL, StaticType.INT4), - ) + problemHandler = assertProblemExists( + ProblemGenerator.incompatibleTypesForOp( + listOf(StaticType.BOOL, StaticType.INT4), + "AND", ) - } + ) ), ErrorTestCase( name = "Bad comparison", @@ -3756,15 +3703,12 @@ class PlanTyperTestsPorted { catalogPath = DB_SCHEMA_MARKETS, query = "1 AND order_info.customer_id = 1", expected = StaticType.MISSING, - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UnknownFunction( - "and", - listOf(StaticType.INT4, StaticType.BOOL), - ) + problemHandler = assertProblemExists( + ProblemGenerator.incompatibleTypesForOp( + listOf(StaticType.INT4, StaticType.BOOL), + "AND", ) - } + ) ), ErrorTestCase( name = "Unknown column", @@ -3782,12 +3726,9 @@ class PlanTyperTestsPorted { ) ) ), - problemHandler = 
assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UndefinedVariable(insensitiveId("unknown_col")) - ) - } + problemHandler = assertProblemExists( + ProblemGenerator.undefinedVariable(insensitiveId("unknown_col")) + ) ), SuccessTestCase( name = "LIMIT INT", @@ -3802,12 +3743,11 @@ class PlanTyperTestsPorted { catalogPath = listOf("ddb"), query = "SELECT * FROM pets LIMIT '5'", expected = TABLE_AWS_DDB_PETS, - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UnexpectedType(StaticType.STRING, setOf(StaticType.INT)) + problemHandler = assertProblemExists( + ProblemGenerator.unexpectedType( + StaticType.STRING, setOf(StaticType.INT) ) - } + ) ), SuccessTestCase( name = "OFFSET INT", @@ -3822,12 +3762,9 @@ class PlanTyperTestsPorted { catalogPath = listOf("ddb"), query = "SELECT * FROM pets LIMIT 1 OFFSET '5'", expected = TABLE_AWS_DDB_PETS, - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UnexpectedType(StaticType.STRING, setOf(StaticType.INT)) - ) - } + problemHandler = assertProblemExists( + ProblemGenerator.unexpectedType(StaticType.STRING, setOf(StaticType.INT)) + ) ), SuccessTestCase( name = "CAST", @@ -3993,15 +3930,12 @@ class PlanTyperTestsPorted { name = "TRIM_2_error", query = "trim(2 FROM ' Hello, World! 
')", expected = StaticType.MISSING, - problemHandler = assertProblemExists { - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UnknownFunction( - "trim_chars", - args = listOf(StaticType.STRING, StaticType.INT4) - ) + problemHandler = assertProblemExists( + ProblemGenerator.incompatibleTypesForOp( + listOf(StaticType.STRING, StaticType.INT4), + "TRIM_CHARS", ) - } + ) ), ) } diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/util/PlanNodeEquivalentVisitor.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/util/PlanNodeEquivalentVisitor.kt index 9434cf2523..ffe582f45f 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/util/PlanNodeEquivalentVisitor.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/util/PlanNodeEquivalentVisitor.kt @@ -61,8 +61,8 @@ class PlanNodeEquivalentVisitor : PlanBaseVisitor() { override fun visitRexOpErr(node: Rex.Op.Err, ctx: PlanNode): Boolean { if (!super.visitRexOpErr(node, ctx)) return false - ctx as Rex.Op.Err - if (node.message != ctx.message) return false +// ctx as Rex.Op.Err +// if (node.message != ctx.message) return false return true } @@ -96,8 +96,8 @@ class PlanNodeEquivalentVisitor : PlanBaseVisitor() { override fun visitRelOpErr(node: Rel.Op.Err, ctx: PlanNode): Boolean { if (!super.visitRelOpErr(node, ctx)) return false - ctx as Rel.Op.Err - if (node.message != ctx.message) return false +// ctx as Rel.Op.Err +// if (node.message != ctx.message) return false return true } @@ -109,12 +109,16 @@ class PlanNodeEquivalentVisitor : PlanBaseVisitor() { return true } - override fun defaultReturn(node: PlanNode, ctx: PlanNode): Boolean { + override fun defaultVisit(node: PlanNode, ctx: PlanNode): Boolean { if (ctx.javaClass != node.javaClass) return false if (node.children.size != ctx.children.size) return false node.children.forEachIndexed { index, planNode -> - if (planNode.accept(this, ctx.children[index])) return false + if (!planNode.accept(this, ctx.children[index])) 
return false } return true } + + override fun defaultReturn(node: PlanNode, ctx: PlanNode): Boolean { + return false + } } diff --git a/partiql-planner/src/test/resources/outputs/basics/select.sql b/partiql-planner/src/test/resources/outputs/basics/select.sql index ceead5f36a..97b2079ccf 100644 --- a/partiql-planner/src/test/resources/outputs/basics/select.sql +++ b/partiql-planner/src/test/resources/outputs/basics/select.sql @@ -1,50 +1,50 @@ --#[select-00] -SELECT "T"['a'] AS "a", "T"['b'] AS "b", "T"['c'] AS "c" FROM "default"."T" AS "T"; +SELECT "T"['a'] AS "a", "T"['b'] AS "b", "T"['c'] AS "c" FROM "default"."SCHEMA"."T" AS "T"; --#[select-01] -SELECT "T".* FROM "default"."T" AS "T"; +SELECT "T".* FROM "default"."SCHEMA"."T" AS "T"; --#[select-02] -SELECT "T"['a'] AS "a", "T"['b'] AS "b", "T"['c'] AS "c" FROM "default"."T" AS "T"; +SELECT "T"['a'] AS "a", "T"['b'] AS "b", "T"['c'] AS "c" FROM "default"."SCHEMA"."T" AS "T"; --#[select-03] -SELECT VALUE "T"['a'] FROM "default"."T" AS "T"; +SELECT VALUE "T"['a'] FROM "default"."SCHEMA"."T" AS "T"; --#[select-04] -SELECT "t1".*, "t2".* FROM "default"."T" AS "t1" INNER JOIN "default"."T" AS "t2" ON true; +SELECT "t1".*, "t2".* FROM "default"."SCHEMA"."T" AS "t1" INNER JOIN "default"."SCHEMA"."T" AS "t2" ON true; --#[select-05] -SELECT "T"['d'].* FROM "default"."T" AS "T"; +SELECT "T"['d'].* FROM "default"."SCHEMA"."T" AS "T"; --#[select-06] -SELECT "T" AS "t", "T"['d'].* FROM "default"."T" AS "T"; +SELECT "T" AS "t", "T"['d'].* FROM "default"."SCHEMA"."T" AS "T"; --#[select-07] -SELECT "T"['d'].*, "T"['d'].* FROM "default"."T" AS "T"; +SELECT "T"['d'].*, "T"['d'].* FROM "default"."SCHEMA"."T" AS "T"; --#[select-08] -SELECT "T"['d'].* FROM "default"."T" AS "T"; +SELECT "T"['d'].* FROM "default"."SCHEMA"."T" AS "T"; --#[select-09] -SELECT "T".* FROM "default"."T" AS "T"; +SELECT "T".* FROM "default"."SCHEMA"."T" AS "T"; --#[select-10] -SELECT "T"['c'] || CURRENT_USER AS "_1" FROM "default"."T" AS "T"; +SELECT 
"T"['c'] || CURRENT_USER AS "_1" FROM "default"."SCHEMA"."T" AS "T"; --#[select-11] -SELECT CURRENT_USER AS "CURRENT_USER" FROM "default"."T" AS "T"; +SELECT CURRENT_USER AS "CURRENT_USER" FROM "default"."SCHEMA"."T" AS "T"; --#[select-12] -SELECT CURRENT_DATE AS "CURRENT_DATE" FROM "default"."T" AS "T"; +SELECT CURRENT_DATE AS "CURRENT_DATE" FROM "default"."SCHEMA"."T" AS "T"; --#[select-13] -SELECT DATE_DIFF(DAY, CURRENT_DATE, CURRENT_DATE) AS "_1" FROM "default"."T" AS "T"; +SELECT DATE_DIFF(DAY, CURRENT_DATE, CURRENT_DATE) AS "_1" FROM "default"."SCHEMA"."T" AS "T"; --#[select-14] -SELECT DATE_ADD(DAY, 5, CURRENT_DATE) AS "_1" FROM "default"."T" AS "T" +SELECT DATE_ADD(DAY, 5, CURRENT_DATE) AS "_1" FROM "default"."SCHEMA"."T" AS "T" --#[select-15] -SELECT DATE_ADD(DAY, -5, CURRENT_DATE) AS "_1" FROM "default"."T" AS "T" +SELECT DATE_ADD(DAY, -5, CURRENT_DATE) AS "_1" FROM "default"."SCHEMA"."T" AS "T" --#[select-16] -SELECT "t"['a'] AS "a" FROM "default"."T" AS "t"; \ No newline at end of file +SELECT "t"['a'] AS "a" FROM "default"."SCHEMA"."T" AS "t"; \ No newline at end of file