diff --git a/compiler/src/main/scala/com/rawlabs/compiler/CompilerService.scala b/compiler/src/main/scala/com/rawlabs/compiler/CompilerService.scala index 84b65e52e..876cc7b51 100644 --- a/compiler/src/main/scala/com/rawlabs/compiler/CompilerService.scala +++ b/compiler/src/main/scala/com/rawlabs/compiler/CompilerService.scala @@ -12,28 +12,12 @@ package com.rawlabs.compiler +import com.rawlabs.protocol.raw.{Type, Value} import org.graalvm.polyglot.Engine import java.io.OutputStream import scala.collection.mutable -import com.rawlabs.utils.core.{RawException, RawService, RawSettings} - -// Exception that wraps the underlying error so that it includes the extra debug info. -final class CompilerServiceException( - message: String, - val debugInfo: List[(String, String)] = List.empty, - cause: Throwable = null -) extends RawException(message, cause) { - - def this(t: Throwable, debugInfo: List[(String, String)]) = this(t.getMessage, debugInfo, t) - - def this(t: Throwable, environment: ProgramEnvironment) = { - this(t, CompilerService.getDebugInfo(environment)) - } - - def this(t: Throwable) = this(t.getMessage, cause = t) - -} +import com.rawlabs.utils.core.{RawService, RawSettings} object CompilerService { @@ -90,19 +74,6 @@ object CompilerService { } } - def getDebugInfo(environment: ProgramEnvironment): List[(String, String)] = { - List( - "Trace ID" -> environment.maybeTraceId.getOrElse(""), - "Arguments" -> environment.maybeArguments - .map(args => args.map { case (k, v) => s"$k -> $v" }.mkString("\n")) - .getOrElse(""), - "Uid" -> environment.uid.toString, - "Scopes" -> environment.scopes.mkString(","), - "Options" -> environment.options.map { case (k, v) => s"$k -> $v" }.mkString("\n") - //"Settings" -> runtimeContext.settings.toString - ) - } - } trait CompilerService extends RawService { @@ -112,24 +83,27 @@ trait CompilerService extends RawService { def language: Set[String] // Get the description of a source program. - @throws[CompilerServiceException] def getProgramDescription( source: String, environment: ProgramEnvironment - ): GetProgramDescriptionResponse + ): Either[List[ErrorMessage], ProgramDescription] // Execute a source program and write the results to the output stream. - @throws[CompilerServiceException] def execute( source: String, environment: ProgramEnvironment, maybeDecl: Option[String], outputStream: OutputStream, maxRows: Option[Long] = None - ): ExecutionResponse + ): Either[ExecutionError, ExecutionSuccess] + + def eval( + source: String, + environment: ProgramEnvironment, + maybeDecl: Option[String] + ): Either[ExecutionError, EvalSuccess] // Format a source program. - @throws[CompilerServiceException] def formatCode( source: String, environment: ProgramEnvironment, @@ -138,7 +112,6 @@ trait CompilerService extends RawService { ): FormatCodeResponse // Auto-complete a source program. - @throws[CompilerServiceException] def dotAutoComplete( source: String, environment: ProgramEnvironment, @@ -146,7 +119,6 @@ trait CompilerService extends RawService { ): AutoCompleteResponse // Auto-complete a word in a source program. - @throws[CompilerServiceException] def wordAutoComplete( source: String, environment: ProgramEnvironment, @@ -155,38 +127,38 @@ trait CompilerService extends RawService { ): AutoCompleteResponse // Get the hover information for a source program. - @throws[CompilerServiceException] def hover(source: String, environment: ProgramEnvironment, position: Pos): HoverResponse // Rename an identifier in a source program. - @throws[CompilerServiceException] def rename(source: String, environment: ProgramEnvironment, position: Pos): RenameResponse // Go to definition of an identifier in a source program. - @throws[CompilerServiceException] def goToDefinition(source: String, environment: ProgramEnvironment, position: Pos): GoToDefinitionResponse // Validate a source program. - @throws[CompilerServiceException] def validate(source: String, environment: ProgramEnvironment): ValidateResponse // Validate a source program for the AI service. - @throws[CompilerServiceException] def aiValidate(source: String, environment: ProgramEnvironment): ValidateResponse } final case class Pos(line: Int, column: Int) -sealed trait GetProgramDescriptionResponse -final case class GetProgramDescriptionFailure(errors: List[ErrorMessage]) extends GetProgramDescriptionResponse -final case class GetProgramDescriptionSuccess(programDescription: ProgramDescription) - extends GetProgramDescriptionResponse +final case class ExecutionSuccess(complete: Boolean) -sealed trait ExecutionResponse -final case class ExecutionSuccess(complete: Boolean) extends ExecutionResponse -final case class ExecutionValidationFailure(errors: List[ErrorMessage]) extends ExecutionResponse -final case class ExecutionRuntimeFailure(error: String) extends ExecutionResponse +sealed trait ExecutionError +object ExecutionError { + final case class ValidationError(errors: List[ErrorMessage]) extends ExecutionError + + final case class RuntimeError(error: String) extends ExecutionError +} + +sealed trait EvalSuccess +object EvalSuccess { + final case class IteratorValue(innerType: Type, valueIterator: Iterator[Value] with AutoCloseable) extends EvalSuccess + final case class ResultValue(valueType: Type, value: Value) extends EvalSuccess +} final case class FormatCodeResponse(code: Option[String]) final case class HoverResponse(completion: Option[Completion]) diff --git a/compiler/src/main/scala/com/rawlabs/compiler/RawValues.scala b/compiler/src/main/scala/com/rawlabs/compiler/RawValues.scala index 4798d7d80..7ab5d3939 100644 --- a/compiler/src/main/scala/com/rawlabs/compiler/RawValues.scala +++ b/compiler/src/main/scala/com/rawlabs/compiler/RawValues.scala @@ -64,5 +64,6 @@ final case class RawInterval( seconds: Int, millis: Int ) extends RawValue -final case class RawRecord(fields: Map[String, RawValue]) extends RawValue +final case class RawRecord(atts: Seq[RawRecordAttr]) extends RawValue +final case class RawRecordAttr(idn: String, value: RawValue) final case class RawList(values: List[RawValue]) extends RawValue diff --git a/compiler/src/main/scala/com/rawlabs/compiler/TypeConverter.scala b/compiler/src/main/scala/com/rawlabs/compiler/TypeConverter.scala new file mode 100644 index 000000000..7a251e825 --- /dev/null +++ b/compiler/src/main/scala/com/rawlabs/compiler/TypeConverter.scala @@ -0,0 +1,153 @@ +/* + * Copyright 2024 RAW Labs S.A. + * + * Use of this software is governed by the Business Source License + * included in the file licenses/BSL.txt. + * + * As of the Change Date specified in that file, in accordance with + * the Business Source License, use of this software will be governed + * by the Apache License, Version 2.0, included in the file + * licenses/APL.txt. + */ + +package com.rawlabs.compiler + +import com.rawlabs.protocol.raw.{ + AnyType, + AttrType, + BinaryType, + BoolType, + ByteType, + DateType, + DecimalType, + DoubleType, + FloatType, + IntType, + IntervalType, + IterableType, + ListType, + LongType, + OrType, + RecordType, + ShortType, + StringType, + TimeType, + TimestampType, + Type, + UndefinedType +} + +import scala.collection.JavaConverters._ + +object TypeConverter { + + /** + * Converts a raw compiler API type to a gRPC protocol type. + * + * @param t the raw compiler API type + * @return the gRPC protocol type + */ + def toProtocolType(t: com.rawlabs.compiler.RawType): Type = { + t match { + case com.rawlabs.compiler.RawAnyType() => Type.newBuilder().setAny(AnyType.newBuilder().build()).build() + case com.rawlabs.compiler.RawUndefinedType(nullable, triable) => Type + .newBuilder() + .setUndefined(UndefinedType.newBuilder().setNullable(nullable).setTriable(triable).build()) + .build() + case com.rawlabs.compiler.RawByteType(nullable, triable) => Type + .newBuilder() + .setByte(ByteType.newBuilder().setNullable(nullable).setTriable(triable).build()) + .build() + case com.rawlabs.compiler.RawShortType(nullable, triable) => Type + .newBuilder() + .setShort(ShortType.newBuilder().setNullable(nullable).setTriable(triable).build()) + .build() + case com.rawlabs.compiler.RawIntType(nullable, triable) => Type + .newBuilder() + .setInt(IntType.newBuilder().setNullable(nullable).setTriable(triable).build()) + .build() + case com.rawlabs.compiler.RawLongType(nullable, triable) => Type + .newBuilder() + .setLong(LongType.newBuilder().setNullable(nullable).setTriable(triable).build()) + .build() + case com.rawlabs.compiler.RawFloatType(nullable, triable) => Type + .newBuilder() + .setFloat(FloatType.newBuilder().setNullable(nullable).setTriable(triable).build()) + .build() + case com.rawlabs.compiler.RawDoubleType(nullable, triable) => Type + .newBuilder() + .setDouble(DoubleType.newBuilder().setNullable(nullable).setTriable(triable).build()) + .build() + case com.rawlabs.compiler.RawDecimalType(nullable, triable) => Type + .newBuilder() + .setDecimal(DecimalType.newBuilder().setNullable(nullable).setTriable(triable).build()) + .build() + case com.rawlabs.compiler.RawBoolType(nullable, triable) => Type + .newBuilder() + .setBool(BoolType.newBuilder().setNullable(nullable).setTriable(triable).build()) + .build() + case com.rawlabs.compiler.RawStringType(nullable, triable) => Type + .newBuilder() + .setString(StringType.newBuilder().setNullable(nullable).setTriable(triable).build()) + .build() + case com.rawlabs.compiler.RawBinaryType(nullable, triable) => Type + .newBuilder() + .setBinary(BinaryType.newBuilder().setNullable(nullable).setTriable(triable).build()) + .build() + case com.rawlabs.compiler.RawDateType(nullable, triable) => Type + .newBuilder() + .setDate(DateType.newBuilder().setNullable(nullable).setTriable(triable).build()) + .build() + case com.rawlabs.compiler.RawTimeType(nullable, triable) => Type + .newBuilder() + .setTime(TimeType.newBuilder().setNullable(nullable).setTriable(triable).build()) + .build() + case com.rawlabs.compiler.RawTimestampType(nullable, triable) => Type + .newBuilder() + .setTimestamp(TimestampType.newBuilder().setNullable(nullable).setTriable(triable).build()) + .build() + case com.rawlabs.compiler.RawIntervalType(nullable, triable) => Type + .newBuilder() + .setInterval(IntervalType.newBuilder().setNullable(nullable).setTriable(triable).build()) + .build() + case com.rawlabs.compiler.RawRecordType(atts, nullable, triable) => Type + .newBuilder() + .setRecord( + RecordType + .newBuilder() + .addAllAtts( + atts.map(f => AttrType.newBuilder().setIdn(f.idn).setTipe(toProtocolType(f.tipe)).build()).asJava + ) + .setNullable(nullable) + .setTriable(triable) + .build() + ) + .build() + case com.rawlabs.compiler.RawListType(t, nullable, triable) => Type + .newBuilder() + .setList( + ListType.newBuilder().setInnerType(toProtocolType(t)).setNullable(nullable).setTriable(triable).build() + ) + .build() + case com.rawlabs.compiler.RawIterableType(t, nullable, triable) => Type + .newBuilder() + .setIterable( + IterableType.newBuilder().setInnerType(toProtocolType(t)).setNullable(nullable).setTriable(triable).build() + ) + .build() + case com.rawlabs.compiler.RawOrType(ors, nullable, triable) => Type + .newBuilder() + .setOr( + OrType + .newBuilder() + .addAllOrs(ors.map(toProtocolType).asJava) + .setNullable(nullable) + .setTriable(triable) + .build() + ) + .build() + case t => throw new AssertionError(s"support for $t not implemented yet") + } + } + +} diff --git a/python-compiler/src/main/scala/com/rawlabs/python/compiler/PythonCompilerService.scala b/python-compiler/src/main/scala/com/rawlabs/python/compiler/PythonCompilerService.scala index 3aed791cf..5d09f39bc 100644 --- a/python-compiler/src/main/scala/com/rawlabs/python/compiler/PythonCompilerService.scala +++ b/python-compiler/src/main/scala/com/rawlabs/python/compiler/PythonCompilerService.scala @@ -12,7 +12,37 @@ package com.rawlabs.python.compiler -import com.rawlabs.compiler.{AutoCompleteResponse, CompilerService, CompilerServiceException, ExecutionResponse, ExecutionRuntimeFailure, ExecutionSuccess, FormatCodeResponse, GetProgramDescriptionResponse, GoToDefinitionResponse, HoverResponse, Pos, ProgramEnvironment, RawBool, RawByte, RawDate, RawDecimal, RawDouble, RawFloat, RawInt, RawInterval, RawLong, RawNull, RawShort, RawString, RawTime, RawTimestamp, RawValue, RenameResponse, ValidateResponse} +import com.rawlabs.compiler.{ + AutoCompleteResponse, + CompilerService, + ErrorMessage, + EvalSuccess, + ExecutionError, + ExecutionSuccess, + FormatCodeResponse, + GoToDefinitionResponse, + HoverResponse, + Pos, + ProgramDescription, + ProgramEnvironment, + RawBool, + RawByte, + RawDate, + RawDecimal, + RawDouble, + RawFloat, + RawInt, + RawInterval, + RawLong, + RawNull, + RawShort, + RawString, + RawTime, + RawTimestamp, + RawValue, + RenameResponse, + ValidateResponse +} import com.rawlabs.compiler.writers.{PolyglotBinaryWriter, PolyglotCsvWriter, PolyglotJsonWriter, PolyglotTextWriter} import com.rawlabs.utils.core.{RawSettings, RawUtils} import org.graalvm.polyglot.{Context, Engine, PolyglotAccess, PolyglotException, Source, Value} @@ -37,7 +67,10 @@ class PythonCompilerService(engineDefinition: (Engine, Boolean))(implicit protec override def language: Set[String] = Set("python") - override def getProgramDescription(source: String, environment: ProgramEnvironment): GetProgramDescriptionResponse = { + override def getProgramDescription( + source: String, + environment: ProgramEnvironment + ): Either[List[ErrorMessage], ProgramDescription] = { ??? // val source = """ //import ast @@ -106,7 +139,7 @@ class PythonCompilerService(engineDefinition: (Engine, Boolean))(implicit protec maybeDecl: Option[String], outputStream: OutputStream, maxRows: Option[Long] - ): ExecutionResponse = { + ): Either[ExecutionError, ExecutionSuccess] = { val ctx = buildTruffleContext(environment, maybeOutputStream = Some(outputStream)) ctx.initialize("python") ctx.enter() @@ -121,7 +154,18 @@ class PythonCompilerService(engineDefinition: (Engine, Boolean))(implicit protec val f = bindings.getMember(decl) environment.maybeArguments match { case Some(args) => - val polyglotArguments = args.map(arg => rawValueToPolyglotValue(arg._2, ctx)) + val maybePolyglotArguments = args.map(arg => rawValueToPolyglotValue(arg._2, ctx)) + val unsupportedMandatoryPolyglotArguments = maybePolyglotArguments.zipWithIndex.collect { + case (None, idx) => ErrorMessage( + s"unsupported mandatory argument at position ${idx + 1}", + List.empty, + "" + ) + } + if (unsupportedMandatoryPolyglotArguments.nonEmpty) { + return Left(ExecutionError.ValidationError(unsupportedMandatoryPolyglotArguments.to)) + } + val polyglotArguments = maybePolyglotArguments.flatten f.execute(polyglotArguments: _*) case None => f.execute() } @@ -140,9 +184,9 @@ class PythonCompilerService(engineDefinition: (Engine, Boolean))(implicit protec try { w.write(v) w.flush() - ExecutionSuccess(complete = true) + Right(ExecutionSuccess(complete = true)) } catch { - case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) + case ex: IOException => Left(ExecutionError.RuntimeError(ex.getMessage)) } finally { RawUtils.withSuppressNonFatalException(w.close()) } @@ -151,9 +195,9 @@ class PythonCompilerService(engineDefinition: (Engine, Boolean))(implicit protec try { w.write(v) w.flush() - ExecutionSuccess(complete = true) + Right(ExecutionSuccess(complete = true)) } catch { - case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) + case ex: IOException => Left(ExecutionError.RuntimeError(ex.getMessage)) } finally { RawUtils.withSuppressNonFatalException(w.close()) } @@ -161,26 +205,26 @@ class PythonCompilerService(engineDefinition: (Engine, Boolean))(implicit protec val w = new PolyglotTextWriter(outputStream) try { w.writeAndFlush(v) - ExecutionSuccess(complete = true) + Right(ExecutionSuccess(complete = true)) } catch { - case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) + case ex: IOException => Left(ExecutionError.RuntimeError(ex.getMessage)) } case Some("binary") => val w = new PolyglotBinaryWriter(outputStream) try { w.writeAndFlush(v) - ExecutionSuccess(complete = true) + Right(ExecutionSuccess(complete = true)) } catch { - case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) + case ex: IOException => Left(ExecutionError.RuntimeError(ex.getMessage)) } - case _ => ExecutionRuntimeFailure("unknown output format") + case _ => Left(ExecutionError.RuntimeError("unknown output format")) } } catch { case ex: PolyglotException => if (ex.isInterrupted) { throw new InterruptedException() } else { - ExecutionRuntimeFailure(ex.getMessage) + Left(ExecutionError.RuntimeError(ex.getMessage)) } } finally { ctx.leave() @@ -188,7 +232,7 @@ class PythonCompilerService(engineDefinition: (Engine, Boolean))(implicit protec } } - private def rawValueToPolyglotValue(rawValue: RawValue, ctx: Context): Value = { + private def rawValueToPolyglotValue(rawValue: RawValue, ctx: Context): Option[Value] = { val code: String = rawValue match { case RawNull() => "None" case RawByte(v) => ??? @@ -204,12 +248,18 @@ class PythonCompilerService(engineDefinition: (Engine, Boolean))(implicit protec case RawTime(v) => ??? case RawTimestamp(v) => ??? case RawInterval(years, months, weeks, days, hours, minutes, seconds, millis) => ??? - case _ => throw new CompilerServiceException("type not supported") + case _ => return None } val value = ctx.eval("python", code) - ctx.asValue(value) + Some(ctx.asValue(value)) } + override def eval( + source: String, + environment: ProgramEnvironment, + maybeDecl: Option[String] + ): Either[ExecutionError, EvalSuccess] = ??? + override def formatCode( source: String, environment: ProgramEnvironment, diff --git a/python-compiler/src/test/scala/com/rawlabs/python/compiler/TestPythonCompilerService.scala b/python-compiler/src/test/scala/com/rawlabs/python/compiler/TestPythonCompilerService.scala index a53b8b765..e33cb4a80 100644 --- a/python-compiler/src/test/scala/com/rawlabs/python/compiler/TestPythonCompilerService.scala +++ b/python-compiler/src/test/scala/com/rawlabs/python/compiler/TestPythonCompilerService.scala @@ -47,7 +47,7 @@ class TestPythonCompilerService extends RawTestSuite with SettingsTestContext wi Map("output-format" -> "json") ) val baos = new ByteArrayOutputStream() - assert(compilerService.execute("1+1", environment, None, baos) == ExecutionSuccess(true)) + assert(compilerService.execute("1+1", environment, None, baos) == Right(ExecutionSuccess(true))) assert(baos.toString() == "2") } @@ -61,7 +61,9 @@ class TestPythonCompilerService extends RawTestSuite with SettingsTestContext wi Map("output-format" -> "json") ) val baos = new ByteArrayOutputStream() - assert(compilerService.execute("def f(): return 1+1", environment, Some("f"), baos) == ExecutionSuccess(true)) + assert( + compilerService.execute("def f(): return 1+1", environment, Some("f"), baos) == Right(ExecutionSuccess(true)) + ) assert(baos.toString() == "2") } @@ -75,7 +77,9 @@ class TestPythonCompilerService extends RawTestSuite with SettingsTestContext wi Map("output-format" -> "json") ) val baos = new ByteArrayOutputStream() - assert(compilerService.execute("def f(v): return v*2", environment, Some("f"), baos) == ExecutionSuccess(true)) + assert( + compilerService.execute("def f(v): return v*2", environment, Some("f"), baos) == Right(ExecutionSuccess(true)) + ) assert(baos.toString() == "4") } diff --git a/snapi-compiler/src/main/scala/com/rawlabs/snapi/compiler/SnapiCompilerService.scala b/snapi-compiler/src/main/scala/com/rawlabs/snapi/compiler/SnapiCompilerService.scala index ac368504b..97658bef8 100644 --- a/snapi-compiler/src/main/scala/com/rawlabs/snapi/compiler/SnapiCompilerService.scala +++ b/snapi-compiler/src/main/scala/com/rawlabs/snapi/compiler/SnapiCompilerService.scala @@ -12,22 +12,19 @@ package com.rawlabs.snapi.compiler +import com.google.protobuf.ByteString +import com.rawlabs.compiler.utils.RecordFieldsNaming import com.rawlabs.compiler.{ AutoCompleteResponse, CompilerService, - CompilerServiceException, DeclDescription, ErrorMessage, ErrorPosition, ErrorRange, - ExecutionResponse, - ExecutionRuntimeFailure, + EvalSuccess, + ExecutionError, ExecutionSuccess, - ExecutionValidationFailure, FormatCodeResponse, - GetProgramDescriptionFailure, - GetProgramDescriptionResponse, - GetProgramDescriptionSuccess, GoToDefinitionResponse, HoverResponse, Message, @@ -51,9 +48,23 @@ import com.rawlabs.compiler.{ RawTimestamp, RawValue, RenameResponse, + TypeConverter, ValidateResponse } import com.rawlabs.compiler.writers.{PolyglotBinaryWriter, PolyglotTextWriter} +import com.rawlabs.protocol.raw +import com.rawlabs.protocol.raw.{ + ValueBinary, + ValueBool, + ValueByte, + ValueError, + ValueInterval, + ValueList, + ValueNull, + ValueRecord, + ValueRecordField, + ValueShort +} import com.rawlabs.snapi.compiler.SnapiCompilerService.getTruffleClassLoader import com.rawlabs.snapi.compiler.writers.{SnapiCsvWriter, SnapiJsonWriter} import com.rawlabs.utils.core.{RawSettings, RawUid, RawUtils} @@ -75,6 +86,7 @@ import com.rawlabs.snapi.frontend.snapi.extensions.builtin.{BinaryPackage, CsvPa import java.io.{IOException, OutputStream} import scala.collection.mutable import scala.util.control.NonFatal +import scala.collection.JavaConverters._ object SnapiCompilerService extends CustomClassAndModuleLoader { val LANGUAGE: Set[String] = Set("snapi") @@ -150,18 +162,13 @@ class SnapiCompilerService(engineDefinition: (Engine, Boolean))(implicit protect } def parse(source: String, environment: ProgramEnvironment): ParseResponse = { - val programContext = getProgramContext(environment.uid, environment) - try { - val positions = new Positions() - val parser = new Antlr4SyntaxAnalyzer(positions, true) - val parseResult = parser.parse(source) - if (parseResult.isSuccess) { - ParseSuccess(parseResult.tree) - } else { - ParseFailure(parseResult.errors) - } - } catch { - case NonFatal(t) => throw new CompilerServiceException(t, programContext.dumpDebugInfo) + val positions = new Positions() + val parser = new Antlr4SyntaxAnalyzer(positions, true) + val parseResult = parser.parse(source) + if (parseResult.isSuccess) { + ParseSuccess(parseResult.tree) + } else { + ParseFailure(parseResult.errors) } } @@ -173,15 +180,11 @@ class SnapiCompilerService(engineDefinition: (Engine, Boolean))(implicit protect environment, _ => { val programContext = getProgramContext(environment.uid, environment) - try { - val tree = new TreeWithPositions(source, ensureTree = false, frontend = true)(programContext) - if (tree.valid) { - GetTypeSuccess(tree.rootType) - } else { - GetTypeFailure(tree.errors) - } - } catch { - case NonFatal(t) => throw new CompilerServiceException(t, programContext.dumpDebugInfo) + val tree = new TreeWithPositions(source, ensureTree = false, frontend = true)(programContext) + if (tree.valid) { + GetTypeSuccess(tree.rootType) + } else { + GetTypeFailure(tree.errors) } } ) @@ -190,40 +193,36 @@ class SnapiCompilerService(engineDefinition: (Engine, Boolean))(implicit protect override def getProgramDescription( source: String, environment: ProgramEnvironment - ): GetProgramDescriptionResponse = { + ): Either[List[ErrorMessage], ProgramDescription] = { withTruffleContext( environment, _ => { val programContext = getProgramContext(environment.uid, environment) - try { - val tree = new TreeWithPositions(source, ensureTree = false, frontend = true)(programContext) - if (tree.valid) { - val TreeDescription(decls, maybeType, comment) = tree.description - val formattedDecls = decls.map { - case (idn, programDecls) => - val formattedDecls = programDecls.map { - case TreeDeclDescription(None, outType, comment) => - DeclDescription(None, snapiTypeToRawType(outType), comment) - case TreeDeclDescription(Some(params), outType, comment) => - val formattedParams = params.map { - case TreeParamDescription(idn, tipe, required) => - ParamDescription(idn, snapiTypeToRawType(tipe), defaultValue = None, comment = None, required) - } - DeclDescription(Some(formattedParams), snapiTypeToRawType(outType), comment) - } - (idn, formattedDecls) - } - val programDescription = ProgramDescription( - formattedDecls, - maybeType.map(t => DeclDescription(None, snapiTypeToRawType(t), None)), - comment - ) - GetProgramDescriptionSuccess(programDescription) - } else { - GetProgramDescriptionFailure(tree.errors.collect { case e: ErrorMessage => e }) + val tree = new TreeWithPositions(source, ensureTree = false, frontend = true)(programContext) + if (tree.valid) { + val TreeDescription(decls, maybeType, comment) = tree.description + val formattedDecls = decls.map { + case (idn, programDecls) => + val formattedDecls = programDecls.map { + case TreeDeclDescription(None, outType, comment) => + DeclDescription(None, snapiTypeToRawType(outType), comment) + case TreeDeclDescription(Some(params), outType, comment) => + val formattedParams = params.map { + case TreeParamDescription(idn, tipe, required) => + ParamDescription(idn, snapiTypeToRawType(tipe), defaultValue = None, comment = None, required) + } + DeclDescription(Some(formattedParams), snapiTypeToRawType(outType), comment) + } + (idn, formattedDecls) } - } catch { - case NonFatal(t) => throw new CompilerServiceException(t, programContext.dumpDebugInfo) + val programDescription = ProgramDescription( + formattedDecls, + maybeType.map(t => DeclDescription(None, snapiTypeToRawType(t), None)), + comment + ) + Right(programDescription) + } else { + Left(tree.errors.collect { case e: ErrorMessage => e }) } } ) @@ -235,12 +234,140 @@ class SnapiCompilerService(engineDefinition: (Engine, Boolean))(implicit protect maybeDecl: Option[String], outputStream: OutputStream, maxRows: Option[Long] - ): ExecutionResponse = { + ): Either[ExecutionError, ExecutionSuccess] = { val ctx = buildTruffleContext(environment, maybeOutputStream = Some(outputStream)) ctx.initialize("snapi") ctx.enter() try { - val (v, tipe) = maybeDecl match { + getValueAndType(source, environment, maybeDecl, ctx) match { + case Left(err) => Left(err) + case Right((v, tipe)) => environment.options + .get("output-format") + .map(_.toLowerCase) match { + case Some("csv") => + if (!CsvPackage.outputWriteSupport(tipe)) { + return Left(ExecutionError.RuntimeError("unsupported type")) + } + val windowsLineEnding = environment.options.get("windows-line-ending") match { + case Some("true") => true + case _ => false + } + val lineSeparator = if (windowsLineEnding) "\r\n" else "\n" + val w = new SnapiCsvWriter(outputStream, lineSeparator, maxRows) + try { + w.write(v, tipe.asInstanceOf[SnapiTypeWithProperties]) + w.flush() + Right(ExecutionSuccess(w.complete)) + } catch { + case ex: IOException => Left(ExecutionError.RuntimeError(ex.getMessage)) + } finally { + RawUtils.withSuppressNonFatalException(w.close()) + } + case Some("json") => + if (!JsonPackage.outputWriteSupport(tipe)) { + return Left(ExecutionError.RuntimeError("unsupported type")) + } + val w = new SnapiJsonWriter(outputStream, maxRows) + try { + w.write(v, tipe.asInstanceOf[SnapiTypeWithProperties]) + w.flush() + Right(ExecutionSuccess(w.complete)) + } catch { + case ex: IOException => Left(ExecutionError.RuntimeError(ex.getMessage)) + } finally { + RawUtils.withSuppressNonFatalException(w.close()) + } + case Some("text") => + if (!StringPackage.outputWriteSupport(tipe)) { + return Left(ExecutionError.RuntimeError("unsupported type")) + } + val w = new PolyglotTextWriter(outputStream) + try { + w.writeAndFlush(v) + Right(ExecutionSuccess(complete = true)) + } catch { + case ex: IOException => Left(ExecutionError.RuntimeError(ex.getMessage)) + } + case Some("binary") => + if (!BinaryPackage.outputWriteSupport(tipe)) { + return Left(ExecutionError.RuntimeError("unsupported type")) + } + val w = new PolyglotBinaryWriter(outputStream) + try { + w.writeAndFlush(v) + Right(ExecutionSuccess(complete = true)) + } catch { + case ex: IOException => Left(ExecutionError.RuntimeError(ex.getMessage)) + } + case _ => Left(ExecutionError.RuntimeError("unknown output format")) + } + } + } catch { + case ex: PolyglotException => + // (msb): The following are various "hacks" to ensure the inner language InterruptException propagates "out". + // Unfortunately, I do not find a more reliable alternative; the branch that does seem to work is the one + // that does startsWith. That said, I believe with Truffle, the expectation is that one is supposed to + // "cancel the context", but in our case this doesn't quite match the current architecture, where we have + // other non-Truffle languages and also, we have parts of the pipeline that are running outside of Truffle + // and which must handle interruption as well. + if (ex.isInterrupted) { + throw new InterruptedException() + } else if (ex.getCause.isInstanceOf[InterruptedException]) { + throw ex.getCause + } else if (ex.getMessage.startsWith("java.lang.InterruptedException")) { + throw new InterruptedException() + } else if (ex.isGuestException) { + if (ex.isInternalError) { + // An internal error. It means a regular Exception thrown from the language (e.g. a Java Exception, + // or a RawTruffleInternalErrorException, which isn't an AbstractTruffleException) + throw ex + } else { + val err = ex.getGuestObject + if (err != null && err.hasMembers && err.hasMember("errors")) { + // A validation exception, semantic or syntax error (both come as the same kind of error) + // that has a list of errors and their positions. + val errorsValue = err.getMember("errors") + val errors = (0L until errorsValue.getArraySize).map { i => + val errorValue = errorsValue.getArrayElement(i) + val message = errorValue.asString + val positions = (0L until errorValue.getArraySize).map { j => + val posValue = errorValue.getArrayElement(j) + val beginValue = posValue.getMember("begin") + val endValue = posValue.getMember("end") + val begin = ErrorPosition(beginValue.getMember("line").asInt, beginValue.getMember("column").asInt) + val end = ErrorPosition(endValue.getMember("line").asInt, endValue.getMember("column").asInt) + ErrorRange(begin, end) + } + ErrorMessage(message, positions.to, ParserErrors.ParserErrorCode) + } + Left(ExecutionError.ValidationError(errors.to)) + } else { + // A runtime failure during execution. The query could be a failed tryable, or a runtime error (e.g. a + // file not found) hit when processing a reader that evaluates as a _collection_ (processed outside the + // evaluation of the query). + Left(ExecutionError.RuntimeError(ex.getMessage)) + } + } + } else { + // Unexpected error. For now we throw the PolyglotException. + throw ex + } + } finally { + ctx.leave() + ctx.close() + } + } + + private def getValueAndType( + source: String, + environment: ProgramEnvironment, + maybeDecl: Option[String], + ctx: Context + ): Either[ExecutionError, (Value, Type)] = { + import ExecutionError._ + + try { + maybeDecl match { case Some(decl) => // Eval the code and extract the function referred to by 'decl' val truffleSource = Source @@ -273,24 +400,58 @@ class SnapiCompilerService(engineDefinition: (Engine, Boolean))(implicit protect // Mandatory args have to be all provided if (mandatoryArgs.length != funType.ms.size) { - return ExecutionRuntimeFailure("missing mandatory arguments") + return Left( + ValidationError( + List(ErrorMessage("missing mandatory arguments", List.empty, ParserErrors.ParserErrorCode)) + ) + ) + } + + // Mandatory arguments are converted to Polyglot values + val maybeMandatoryPolyglotArguments = mandatoryArgs.map(arg => rawValueToPolyglotValue(arg, ctx)) + val unsupportedMandatoryPolyglotArguments = maybeMandatoryPolyglotArguments.zipWithIndex.collect { + case (None, idx) => ErrorMessage( + s"unsupported mandatory argument at position ${idx + 1}", + List.empty, + ParserErrors.ParserErrorCode + ) + } + if (unsupportedMandatoryPolyglotArguments.nonEmpty) { + return Left(ValidationError(unsupportedMandatoryPolyglotArguments.to)) } - val mandatoryPolyglotArguments = mandatoryArgs.map(arg => rawValueToPolyglotValue(arg, ctx)) + val mandatoryPolyglotArguments: Array[Value] = maybeMandatoryPolyglotArguments.flatten + + // Optional arguments are converted to Polyglot values + val maybeOptionalPolyglotArguments: Map[String, Option[Value]] = funType.os.collect { + case arg if optionalArgs.contains(arg.i) => + val paramValue = optionalArgs(arg.i) + arg.i -> rawValueToPolyglotValue(paramValue, ctx) + }.toMap + val unsupportedOptionalPolyglotArguments = maybeOptionalPolyglotArguments.collect { + case (i, None) => + ErrorMessage(s"unsupported optional argument $i", List.empty, ParserErrors.ParserErrorCode) + } + if (unsupportedOptionalPolyglotArguments.nonEmpty) { + return Left(ValidationError(unsupportedOptionalPolyglotArguments.to)) + } + // Optional arguments can be missing from the provided arguments. // We replace the missing ones by their default value. - val optionalPolyglotArguments = funType.os.map { arg => - optionalArgs.get(arg.i) match { + val optionalPolyglotArguments: Seq[Value] = funType.os.map { arg => + maybeOptionalPolyglotArguments.get(arg.i) match { // if the argument is provided, use it - case Some(paramValue) => rawValueToPolyglotValue(paramValue, ctx) + case Some(paramValue) => paramValue.get // else, the argument has a default value that can be obtained from `f`. case None => f.invokeMember("default_" + arg.i) } } + // All arguments are there. Call .execute. val result = f.execute(mandatoryPolyglotArguments ++ optionalPolyglotArguments: _*) val tipe = funType.r // Return the result and its type. - (result, tipe) + Right((result, tipe)) + case None => val truffleSource = Source .newBuilder("snapi", source, "unnamed") @@ -300,123 +461,363 @@ class SnapiCompilerService(engineDefinition: (Engine, Boolean))(implicit protect // The value type is found in polyglot bindings after calling eval(). val rawType = ctx.getPolyglotBindings.getMember("@type").asString() val ParseTypeSuccess(tipe) = parseType(rawType, environment.uid, internal = true) - (result, tipe) + Right((result, tipe)) } + } catch { + case TruffleValidationError(errors) => Left(ValidationError(errors)) + case TruffleRuntimeError(message) => Left(RuntimeError(message)) + } + } - environment.options - .get("output-format") - .map(_.toLowerCase) match { - case Some("csv") => - if (!CsvPackage.outputWriteSupport(tipe)) { - return ExecutionRuntimeFailure("unsupported type") - } - val windowsLineEnding = environment.options.get("windows-line-ending") match { - case Some("true") => true - case _ => false - } - val lineSeparator = if (windowsLineEnding) "\r\n" else "\n" - val w = new SnapiCsvWriter(outputStream, lineSeparator, maxRows) + override def eval( + source: String, + environment: ProgramEnvironment, + maybeDecl: Option[String] + ): Either[ExecutionError, EvalSuccess] = { + + // 1) Build the Truffle context + val ctx = buildTruffleContext(environment, maybeOutputStream = None) + ctx.initialize("snapi") + ctx.enter() + + try { + // 2) Get the value+type. This is where we handle immediate validation errors, which are caught by the + // SnapiLanguage during the parse call. + getValueAndType(source, environment, maybeDecl, ctx) match { + case Left(error) => + // 3) We got an immediate error. + // We must close the context now, or we leak it. + ctx.leave() + ctx.close(true) // Cancel the context and any executing threads. + Left(error) + + case Right((v, t)) => + // 4) We have a value so let's produce the final evaluation result. + buildEvalResult(ctx, v, t.asInstanceOf[SnapiTypeWithProperties]) + + } + } catch { + case t: Throwable => + // 6) We caught some other exception. + // We must close the context now, or we leak it. + ctx.leave() + ctx.close(true) // Cancel the context and any executing threads. + throw t + } + } + + /** + * buildTruffleIterator determines if the top-level type is an iterable (SnapiIterableType / SnapiListType). + * - If yes, returns Right(Iterator[RawValue]). + * - Else returns Left(RawValue). + * Also returns the corresponding RawType to describe the final shape. + */ + private def buildEvalResult( + ctx: Context, + v: Value, + t: SnapiTypeWithProperties + ): Either[ExecutionError.RuntimeError, EvalSuccess] = { + + import EvalSuccess._ + import ExecutionError._ + + // If top-level is triable AND holds an exception, return it as a runtime error. + if (maybeIsException(v, t)) { + try { + v.throwException() + } catch { + case NonFatal(ex) => + // Close context here directly, since we're done + ctx.leave() + ctx.close(true) + return Left(ExecutionError.RuntimeError(ex.getMessage)) + } + } + + t match { + case SnapiIterableType(innerType, _) if !v.isNull => + assert(!maybeIsException(v, t)) + + // 1) If top-level is an iterable (and not top-level triable error), we produce an iterator + // Note that the iterator does not close itself on error, as the consumer is expected to call close() itself + + // First obtain the polyglot iterator + val polyIt = try { - w.write(v, tipe.asInstanceOf[SnapiTypeWithProperties]) - w.flush() - ExecutionSuccess(w.complete) + v.getIterator } catch { - case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) - } finally { - RawUtils.withSuppressNonFatalException(w.close()) + case TruffleRuntimeError(message) => return Left(RuntimeError(message)) } - case Some("json") => - if (!JsonPackage.outputWriteSupport(tipe)) { - return ExecutionRuntimeFailure("unsupported type") + + val valueIterator = new Iterator[raw.Value] with AutoCloseable { + // Lazy to avoid calling getIterator if we don't need it + // Note we do not need to check if v.isException because we know 'v' is not a triable. + + override def hasNext: Boolean = { + // We do not expect this call to fail; only the creation of the polyIt can fail... + polyIt.hasIteratorNextElement } - val w = new SnapiJsonWriter(outputStream, maxRows) - try { - w.write(v, tipe.asInstanceOf[SnapiTypeWithProperties]) - w.flush() - ExecutionSuccess(w.complete) - } catch { - case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) - } finally { - RawUtils.withSuppressNonFatalException(w.close()) + + override def next(): com.rawlabs.protocol.raw.Value = { + // We do not expect this call to fail; only the creation of the polyIt can fail... + val truffleVal = polyIt.getIteratorNextElement + + // Now convert the Truffle value to protocolValue. + // If the value is an exception, as per our semantic, it will be converted now to a ValueError. + val protocolVal = fromTruffleValue(truffleVal, innerType.asInstanceOf[SnapiTypeWithProperties]) + + protocolVal } - case Some("text") => - if (!StringPackage.outputWriteSupport(tipe)) { - return ExecutionRuntimeFailure("unsupported type") + + override def close(): Unit = { + ctx.leave() + ctx.close(true) } - val w = new PolyglotTextWriter(outputStream) - try { - w.writeAndFlush(v) - ExecutionSuccess(complete = true) - } catch { - case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) + } + + // The context remains open since we are returning an iterator. + // It is up to the caller to close that iterator, which then triggers the Truffle context closure. + + val protocolType = TypeConverter.toProtocolType(snapiTypeToRawType(innerType).get) + Right(IteratorValue(protocolType, valueIterator)) + + case _ => + // 2) Otherwise, produce a single value iterator + val protocolValue = fromTruffleValue(v, t) + val protocolType = TypeConverter.toProtocolType(snapiTypeToRawType(t).get) + val evalResult = ResultValue(protocolType, protocolValue) + + // Close context here directly + ctx.leave() + ctx.close(true) + + Right(evalResult) + } + } + + /** + * Check if top-level is triable and an exception, in which case fromTruffleValue + * would yield a RawError. But for iterables, we skip that check at top-level + * because we want an Iterator that possibly yields items or errors in iteration. + */ + private def maybeIsException(v: Value, st: SnapiTypeWithProperties): Boolean = { + val tryableProp = SnapiIsTryableTypeProperty() + st.props.contains(tryableProp) && v.isException + } + + /** + * Convert a single Truffle Value to a RawValue (no iteration at top level). + * It recurses into records, or-lambda, intervals, etc. If the type is triable + * and the value is an exception, we produce a RawError; if the type is nullable + * and value is null, we produce RawNull, etc. + * + * This mirrors your SnapiJsonWriter logic, but instead of writing JSON, + * we build up the appropriate RawValue. + */ + private def fromTruffleValue( + v: Value, + t: SnapiTypeWithProperties + ): com.rawlabs.protocol.raw.Value = { + val tryable = SnapiIsTryableTypeProperty() + val nullable = SnapiIsNullableTypeProperty() + + // 1) If triable + if (t.props.contains(tryable)) { + // if the value is an exception => produce RawError + if (v.isException) { + try { + // attempt to "throw" to get the actual Java exception + v.throwException() + // If we get here, it didn't actually throw? We'll produce RawError anyway. + throw new AssertionError("triable error did not throw exception!") + } catch { + case NonFatal(ex) => com.rawlabs.protocol.raw.Value + .newBuilder() + .setError(ValueError.newBuilder().setMessage(ex.getMessage).build()) + .build() + } + } else { + // remove the property and keep going + fromTruffleValue(v, t.cloneAndRemoveProp(tryable).asInstanceOf[SnapiTypeWithProperties]) + } + } + // 2) If nullable + else if (t.props.contains(nullable)) { + if (v.isNull) com.rawlabs.protocol.raw.Value.newBuilder().setNull(ValueNull.newBuilder()).build() + else fromTruffleValue(v, t.cloneAndRemoveProp(nullable).asInstanceOf[SnapiTypeWithProperties]) + } + // 3) Otherwise match on the underlying type + else { + t match { + case _: SnapiBinaryType => + // read all bytes from the buffer + val bytes = Array.ofDim[Byte](v.getBufferSize.toInt) + var idx = 0 + while (idx < bytes.length) { + bytes(idx) = v.readBufferByte(idx.toLong) + idx += 1 } - case Some("binary") => - if (!BinaryPackage.outputWriteSupport(tipe)) { - return ExecutionRuntimeFailure("unsupported type") + com.rawlabs.protocol.raw.Value + .newBuilder() + .setBinary(ValueBinary.newBuilder().setV(ByteString.copyFrom(bytes))) + .build() + + case _: SnapiBoolType => + com.rawlabs.protocol.raw.Value.newBuilder().setBool(ValueBool.newBuilder().setV(v.asBoolean())).build() + case _: SnapiByteType => + com.rawlabs.protocol.raw.Value.newBuilder().setByte(ValueByte.newBuilder().setV(v.asByte())).build() + case _: SnapiShortType => + com.rawlabs.protocol.raw.Value.newBuilder().setShort(ValueShort.newBuilder().setV(v.asShort())).build() + case _: SnapiIntType => com.rawlabs.protocol.raw.Value + .newBuilder() + .setInt(com.rawlabs.protocol.raw.ValueInt.newBuilder().setV(v.asInt())) + .build() + case _: SnapiLongType => com.rawlabs.protocol.raw.Value + .newBuilder() + .setLong(com.rawlabs.protocol.raw.ValueLong.newBuilder().setV(v.asLong())) + .build() + case _: SnapiFloatType => com.rawlabs.protocol.raw.Value + .newBuilder() + .setFloat(com.rawlabs.protocol.raw.ValueFloat.newBuilder().setV(v.asFloat())) + .build() + case _: SnapiDoubleType => com.rawlabs.protocol.raw.Value + .newBuilder() + .setDouble(com.rawlabs.protocol.raw.ValueDouble.newBuilder().setV(v.asDouble())) + .build() + case _: SnapiDecimalType => + // If asString() returns decimal textual form, parse into BigDecimal + val txt = v.asString() + com.rawlabs.protocol.raw.Value + .newBuilder() + .setDecimal(com.rawlabs.protocol.raw.ValueDecimal.newBuilder().setV(txt)) + .build() + + case _: SnapiStringType => com.rawlabs.protocol.raw.Value + .newBuilder() + .setString(com.rawlabs.protocol.raw.ValueString.newBuilder().setV(v.asString())) + .build() + + case _: SnapiDateType => + // v.asDate() => LocalDate + com.rawlabs.protocol.raw.Value + .newBuilder() + .setDate( + com.rawlabs.protocol.raw.ValueDate + .newBuilder() + .setYear(v.asDate().getYear) + .setMonth(v.asDate().getMonthValue) + .setDay(v.asDate().getDayOfMonth) + ) + .build() + + case _: SnapiTimeType => + // v.asTime() => LocalTime + com.rawlabs.protocol.raw.Value + .newBuilder() + .setTime( + com.rawlabs.protocol.raw.ValueTime + .newBuilder() + .setHour(v.asTime().getHour) + .setMinute(v.asTime().getMinute) + .setSecond(v.asTime().getSecond) + .setNano(v.asTime().getNano) + ) + .build() + + case _: SnapiTimestampType => + // Typically we treat v.asDate() as LocalDate, v.asTime() as LocalTime, then combine + com.rawlabs.protocol.raw.Value + .newBuilder() + .setTimestamp( + com.rawlabs.protocol.raw.ValueTimestamp + .newBuilder() + .setYear(v.asDate().getYear) + .setMonth(v.asDate().getMonthValue) + .setDay(v.asDate().getDayOfMonth) + .setHour(v.asTime().getHour) + .setMinute(v.asTime().getMinute) + .setSecond(v.asTime().getSecond) + .setNano(v.asTime().getNano) + ) + .build() + + case _: SnapiIntervalType => + val duration = v.asDuration() + val days = duration.toDays + val hours = duration.toHoursPart + val minutes = duration.toMinutesPart + val seconds = duration.toSecondsPart + val millis = duration.toMillisPart + + com.rawlabs.protocol.raw.Value + .newBuilder() + .setInterval( + ValueInterval + .newBuilder() + .setYears(0) + .setMonths(0) + .setWeeks(0) + .setDays(days.toInt) + .setHours(hours) + .setMinutes(minutes) + .setSeconds(seconds) + .setMillis(millis) + ) + .build() + + case SnapiRecordType(attributes, _) => + // Snapi language produces record fields that can be renamed (while the type is not!) + // So this compensates for that, so that when we do TRuffle value.getMember(...) we can use the + // distinct name + val names = new java.util.Vector[String]() + attributes.foreach(a => names.add(a.idn)) + val distincted = RecordFieldsNaming.makeDistinct(names).asScala + + // Build a RawRecord + val recordAttrs = attributes.zip(distincted).map { + case (att, distinctFieldName) => + val fieldName = att.idn + val memberVal = v.getMember(distinctFieldName) + val fieldValue = fromTruffleValue( + memberVal, + att.tipe.asInstanceOf[SnapiTypeWithProperties] + ) + ValueRecordField.newBuilder().setName(fieldName).setValue(fieldValue).build() } - val w = new PolyglotBinaryWriter(outputStream) - try { - w.writeAndFlush(v) - ExecutionSuccess(complete = true) - } catch { - case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) + com.rawlabs.protocol.raw.Value + .newBuilder() + .setRecord(ValueRecord.newBuilder().addAllFields(recordAttrs.asJava)) + .build() + + case SnapiIterableType(innerType, _) => + val items = mutable.ArrayBuffer[com.rawlabs.protocol.raw.Value]() + val iterator = v.getIterator + while (iterator.hasIteratorNextElement) { + val elem = iterator.getIteratorNextElement + items += fromTruffleValue(elem, innerType.asInstanceOf[SnapiTypeWithProperties]) } - case _ => ExecutionRuntimeFailure("unknown output format") - } - } catch { - case ex: PolyglotException => - // (msb): The following are various "hacks" to ensure the inner language InterruptException propagates "out". - // Unfortunately, I do not find a more reliable alternative; the branch that does seem to work is the one - // that does startsWith. That said, I believe with Truffle, the expectation is that one is supposed to - // "cancel the context", but in our case this doesn't quite match the current architecture, where we have - // other non-Truffle languages and also, we have parts of the pipeline that are running outside of Truffle - // and which must handle interruption as well. - if (ex.isInterrupted) { - throw new InterruptedException() - } else if (ex.getCause.isInstanceOf[InterruptedException]) { - throw ex.getCause - } else if (ex.getMessage.startsWith("java.lang.InterruptedException")) { - throw new InterruptedException() - } else if (ex.isGuestException) { - if (ex.isInternalError) { - // An internal error. It means a regular Exception thrown from the language (e.g. a Java Exception, - // or a RawTruffleInternalErrorException, which isn't an AbstractTruffleException) - val programContext = getProgramContext(environment.uid, environment) - throw new CompilerServiceException(ex, programContext.dumpDebugInfo) - } else { - val err = ex.getGuestObject - if (err != null && err.hasMembers && err.hasMember("errors")) { - // A validation exception, semantic or syntax error (both come as the same kind of error) - // that has a list of errors and their positions. - val errorsValue = err.getMember("errors") - val errors = (0L until errorsValue.getArraySize).map { i => - val errorValue = errorsValue.getArrayElement(i) - val message = errorValue.asString - val positions = (0L until errorValue.getArraySize).map { j => - val posValue = errorValue.getArrayElement(j) - val beginValue = posValue.getMember("begin") - val endValue = posValue.getMember("end") - val begin = ErrorPosition(beginValue.getMember("line").asInt, beginValue.getMember("column").asInt) - val end = ErrorPosition(endValue.getMember("line").asInt, endValue.getMember("column").asInt) - ErrorRange(begin, end) - } - ErrorMessage(message, positions.to, ParserErrors.ParserErrorCode) - } - ExecutionValidationFailure(errors.to) - } else { - // A runtime failure during execution. The query could be a failed tryable, or a runtime error (e.g. a - // file not found) hit when processing a reader that evaluates as a _collection_ (processed outside the - // evaluation of the query). - ExecutionRuntimeFailure(ex.getMessage) - } + com.rawlabs.protocol.raw.Value.newBuilder().setList(ValueList.newBuilder().addAllValues(items.asJava)).build() + + case SnapiListType(innerType, _) => + val size = v.getArraySize + val items = (0L until size).map { i => + val elem = v.getArrayElement(i) + fromTruffleValue(elem, innerType.asInstanceOf[SnapiTypeWithProperties]) } - } else { - // Unexpected error. For now we throw the PolyglotException. - throw ex - } - } finally { - ctx.leave() - ctx.close() + com.rawlabs.protocol.raw.Value.newBuilder().setList(ValueList.newBuilder().addAllValues(items.asJava)).build() + + case SnapiOrType(tipes, _) if tipes.exists(SnapiTypeUtils.getProps(_).nonEmpty) => + // A trick to make sur inner types do not have properties + val inners = tipes.map { case inner: SnapiTypeWithProperties => SnapiTypeUtils.resetProps(inner, Set.empty) } + val orProps = tipes.flatMap { case inner: SnapiTypeWithProperties => inner.props }.toSet + fromTruffleValue(v, SnapiOrType(inners, orProps)) + + case SnapiOrType(inners, _) => + // We can check which index the union picked. Typically you do: + val index = v.invokeMember("getIndex").asInt() + val value = v.invokeMember("getValue") + fromTruffleValue(value, inners(index).asInstanceOf[SnapiTypeWithProperties]) + } } } @@ -426,15 +827,10 @@ class SnapiCompilerService(engineDefinition: (Engine, Boolean))(implicit protect maybeIndent: Option[Int], maybeWidth: Option[Int] ): FormatCodeResponse = { - val programContext = getProgramContext(environment.uid, environment) - try { - val pretty = new SourceCommentsPrettyPrinter(maybeIndent, maybeWidth) - pretty.prettyCode(source) match { - case Right(code) => FormatCodeResponse(Some(code)) - case Left(_) => FormatCodeResponse(None) - } - } catch { - case NonFatal(t) => throw new CompilerServiceException(t, programContext.dumpDebugInfo) + val pretty = new SourceCommentsPrettyPrinter(maybeIndent, maybeWidth) + pretty.prettyCode(source) match { + case Right(code) => FormatCodeResponse(Some(code)) + case Left(_) => FormatCodeResponse(None) } } @@ -447,15 +843,11 @@ class SnapiCompilerService(engineDefinition: (Engine, Boolean))(implicit protect environment, _ => { val programContext = getProgramContext(environment.uid, environment) - try { - withLspTree(source, lspService => lspService.dotAutoComplete(source, environment, position))( - programContext - ) match { - case Right(value) => value - case Left(_) => AutoCompleteResponse(Array.empty) - } - } catch { - case NonFatal(t) => throw new CompilerServiceException(t, programContext.dumpDebugInfo) + withLspTree(source, lspService => lspService.dotAutoComplete(source, environment, position))( + programContext + ) match { + case Right(value) => value + case Left(_) => AutoCompleteResponse(Array.empty) } } ) @@ -471,15 +863,11 @@ class SnapiCompilerService(engineDefinition: (Engine, Boolean))(implicit protect environment, _ => { val programContext = getProgramContext(environment.uid, environment) - try { - withLspTree(source, lspService => lspService.wordAutoComplete(source, environment, prefix, position))( - programContext - ) match { - case Right(value) => value - case Left(_) => AutoCompleteResponse(Array.empty) - } - } catch { - case NonFatal(t) => throw new CompilerServiceException(t, programContext.dumpDebugInfo) + withLspTree(source, lspService => lspService.wordAutoComplete(source, environment, prefix, position))( + programContext + ) match { + case Right(value) => value + case Left(_) => AutoCompleteResponse(Array.empty) } } ) @@ -490,13 +878,9 @@ class SnapiCompilerService(engineDefinition: (Engine, Boolean))(implicit protect environment, _ => { val programContext = getProgramContext(environment.uid, environment) - try { - withLspTree(source, lspService => lspService.hover(source, environment, position))(programContext) match { - case Right(value) => value - case Left(_) => HoverResponse(None) - } - } catch { - case NonFatal(t) => throw new CompilerServiceException(t, programContext.dumpDebugInfo) + withLspTree(source, lspService => lspService.hover(source, environment, position))(programContext) match { + case Right(value) => value + case Left(_) => HoverResponse(None) } } ) @@ -507,13 +891,9 @@ class SnapiCompilerService(engineDefinition: (Engine, Boolean))(implicit protect environment, _ => { val programContext = getProgramContext(environment.uid, environment) - try { - withLspTree(source, lspService => lspService.rename(source, environment, position))(programContext) match { - case Right(value) => value - case Left(_) => RenameResponse(Array.empty) - } - } catch { - case NonFatal(t) => throw new CompilerServiceException(t, programContext.dumpDebugInfo) + withLspTree(source, lspService => lspService.rename(source, environment, position))(programContext) match { + case Right(value) => value + case Left(_) => RenameResponse(Array.empty) } } ) @@ -528,15 +908,11 @@ class SnapiCompilerService(engineDefinition: (Engine, Boolean))(implicit protect environment, _ => { val programContext = getProgramContext(environment.uid, environment) - try { - withLspTree(source, lspService => lspService.definition(source, environment, position))( - programContext - ) match { - case Right(value) => value - case Left(_) => GoToDefinitionResponse(None) - } - } catch { - case NonFatal(t) => throw new CompilerServiceException(t, programContext.dumpDebugInfo) + withLspTree(source, lspService => lspService.definition(source, environment, position))( + programContext + ) match { + case Right(value) => value + case Left(_) => GoToDefinitionResponse(None) } } ) @@ -547,16 +923,12 @@ class SnapiCompilerService(engineDefinition: (Engine, Boolean))(implicit protect environment, _ => { val programContext = getProgramContext(environment.uid, environment) - try { - withLspTree( - source, - lspService => lspService.validate - )(programContext) match { - case Right(value) => value - case Left((err, pos)) => ValidateResponse(parseError(err, pos)) - } - } catch { - case NonFatal(t) => throw new CompilerServiceException(t, programContext.dumpDebugInfo) + withLspTree( + source, + lspService => lspService.validate + )(programContext) match { + case Right(value) => value + case Left((err, pos)) => ValidateResponse(parseError(err, pos)) } } ) @@ -650,7 +1022,7 @@ class SnapiCompilerService(engineDefinition: (Engine, Boolean))(implicit protect } } - private def rawValueToPolyglotValue(rawValue: RawValue, ctx: Context): Value = { + private def rawValueToPolyglotValue(rawValue: RawValue, ctx: Context): Option[Value] = { val code: String = rawValue match { case RawNull() => "let x: undefined = null in x" case RawByte(v) => s"let x: byte = ${v}b in x" @@ -669,10 +1041,10 @@ class SnapiCompilerService(engineDefinition: (Engine, Boolean))(implicit protect s"""let x: timestamp = Timestamp.Build(${v.getYear}, ${v.getMonthValue}, ${v.getDayOfMonth}, ${v.getHour}, ${v.getMinute}, millis=${v.getNano / 1000000}) in x""" case RawInterval(years, months, weeks, days, hours, minutes, seconds, millis) => s"""let x: interval = Interval.Build(years=$years, months=$months, weeks=$weeks, days=$days, hours=$hours, minutes=$minutes, seconds=$seconds, millis=$millis) in x""" - case _ => throw new CompilerServiceException("type not supported") + case _ => return None } val value = ctx.eval("snapi", code) - ctx.asValue(value) + Some(ctx.asValue(value)) } private def buildTruffleContext( diff --git a/snapi-compiler/src/main/scala/com/rawlabs/snapi/compiler/TruffleExceptionExtractors.scala b/snapi-compiler/src/main/scala/com/rawlabs/snapi/compiler/TruffleExceptionExtractors.scala new file mode 100644 index 000000000..c124ca2ad --- /dev/null +++ b/snapi-compiler/src/main/scala/com/rawlabs/snapi/compiler/TruffleExceptionExtractors.scala @@ -0,0 +1,49 @@ +/* + * Copyright 2024 RAW Labs S.A. + * + * Use of this software is governed by the Business Source License + * included in the file licenses/BSL.txt. + * + * As of the Change Date specified in that file, in accordance with + * the Business Source License, use of this software will be governed + * by the Apache License, Version 2.0, included in the file + * licenses/APL.txt. + */ + +package com.rawlabs.snapi.compiler + +import com.rawlabs.compiler.{ErrorMessage, ErrorPosition, ErrorRange} +import com.rawlabs.snapi.frontend.snapi.antlr4.ParserErrors +import org.graalvm.polyglot.PolyglotException + +object TruffleValidationError { + def unapply(t: Throwable): Option[List[ErrorMessage]] = t match { + case ex: PolyglotException + if ex.isGuestException && ex.getGuestObject != null && + ex.getGuestObject.hasMembers && ex.getGuestObject.hasMember("errors") => + val err = ex.getGuestObject + // A validation error occurred in SnapiLanguage during the parse method + val errorsValue = err.getMember("errors") + val errors = (0L until errorsValue.getArraySize).map { i => + val errorValue = errorsValue.getArrayElement(i) + val message = errorValue.asString + val positions = (0L until errorValue.getArraySize).map { j => + val posValue = errorValue.getArrayElement(j) + val beginValue = posValue.getMember("begin") + val endValue = posValue.getMember("end") + val begin = ErrorPosition(beginValue.getMember("line").asInt, beginValue.getMember("column").asInt) + val end = ErrorPosition(endValue.getMember("line").asInt, endValue.getMember("column").asInt) + ErrorRange(begin, end) + } + ErrorMessage(message, positions.to, ParserErrors.ParserErrorCode) + } + Some(errors.to) + case _ => None + } +} + +object TruffleRuntimeError { + def unapply(t: Throwable): Option[String] = t match { + case ex: PolyglotException if ex.isGuestException => Some(ex.getMessage) + } +} diff --git a/snapi-compiler/src/test/scala/com/rawlabs/snapi/compiler/tests/SnapiTestContext.scala b/snapi-compiler/src/test/scala/com/rawlabs/snapi/compiler/tests/SnapiTestContext.scala index 32dc4b4d6..998df04f5 100644 --- a/snapi-compiler/src/test/scala/com/rawlabs/snapi/compiler/tests/SnapiTestContext.scala +++ b/snapi-compiler/src/test/scala/com/rawlabs/snapi/compiler/tests/SnapiTestContext.scala @@ -15,9 +15,9 @@ package com.rawlabs.snapi.compiler.tests import com.rawlabs.compiler.{ AutoCompleteResponse, CompilerServiceTestContext, - ExecutionRuntimeFailure, + EvalSuccess, + ExecutionError, ExecutionSuccess, - ExecutionValidationFailure, FormatCodeResponse, GoToDefinitionResponse, HoverResponse, @@ -55,6 +55,7 @@ import com.rawlabs.protocol.compiler.{ SQLServerConfig, SnowflakeConfig } +import com.rawlabs.protocol.raw.Value import com.rawlabs.utils.core._ import com.rawlabs.snapi.compiler.SnapiOutputTestContext @@ -607,6 +608,93 @@ trait SnapiTestContext def typeErrorAs(expectedErrors: String*) = new TypeErrorAs(expectedErrors) + ///////////////////////////////////////////////////////////////////////// + // evalSingle + ///////////////////////////////////////////////////////////////////////// + + class EvalSingle(expected: Value) extends Matcher[TestData] { + def apply(data: TestData) = { + compilerService.eval(data.q, getQueryEnvironment(), maybeDecl = None) match { + case Right(EvalSuccess.ResultValue(_, actual)) => MatchResult( + actual == expected, + s"""results didn't match! + |expected: $expected + |actual: $actual""".stripMargin, + s"""results matched: + |$actual""".stripMargin + ) + case _ => MatchResult(false, "didn't evaluate to a value", "???") + } + } + } + def evalSingle(expected: Value) = new EvalSingle(expected) + + ///////////////////////////////////////////////////////////////////////// + // evalIterator + ///////////////////////////////////////////////////////////////////////// + + class EvalIterator(expected: Seq[Value]) extends Matcher[TestData] { + def apply(data: TestData) = { + compilerService.eval(data.q, getQueryEnvironment(), maybeDecl = None) match { + case Right(EvalSuccess.IteratorValue(_, actual)) => + try { + val actualList = actual.toList + val expectedList = expected.toList + MatchResult( + actualList == expectedList, + s"""results didn't match! + |expected: $expectedList + |actual: $actualList""".stripMargin, + s"""results matched: + |$actualList""".stripMargin + ) + } finally actual.close() + case _ => MatchResult(false, "didn't evaluate to an iterator value", "???") + } + } + } + def evalIterator(expected: Value*) = new EvalIterator(expected) + + ///////////////////////////////////////////////////////////////////////// + // evalTypeErrorAs + ///////////////////////////////////////////////////////////////////////// + + class EvalTypeErrorAs(expected: Seq[String]) extends Matcher[TestData] { + override def apply(data: TestData): MatchResult = { + compilerService.eval(data.q, getQueryEnvironment(), maybeDecl = None) match { + case Left(ExecutionError.ValidationError(actual)) => + val leftOvers = expected.filter(e => !actual.exists(_.message.contains(e))) + MatchResult( + leftOvers.isEmpty, + s"didn't include error '${leftOvers.mkString(",")}' in '$actual'", + """failed as expected""" + ) + case _ => MatchResult(false, "didn't evaluate to a type error", "???") + } + } + } + + def evalTypeErrorAs(errors: String*) = new EvalTypeErrorAs(errors) + + ///////////////////////////////////////////////////////////////////////// + // evalRunErrorAs + ///////////////////////////////////////////////////////////////////////// + + class EvalRunErrorAs(expected: String) extends Matcher[TestData] { + override def apply(data: TestData): MatchResult = { + compilerService.eval(data.q, getQueryEnvironment(), maybeDecl = None) match { + case Left(ExecutionError.RuntimeError(actual)) => MatchResult( + expected.contains(actual), + s"""dind't include error '$expected' in '$actual'""", + """failed as expected""" + ) + case _ => MatchResult(false, "didn't evaluate to a runtime error", "???") + } + } + } + + def evalRunErrorAs(msg: String) = new EvalRunErrorAs(msg) + ///////////////////////////////////////////////////////////////////////// // evaluateTo ///////////////////////////////////////////////////////////////////////// @@ -861,9 +949,9 @@ trait SnapiTestContext maybeDecl, outputStream ) match { - case ExecutionValidationFailure(errs) => Left(errs.map(err => err.toString).mkString(",")) - case ExecutionRuntimeFailure(err) => Left(err) - case ExecutionSuccess(_) => Right(Path.of(outputStream.toString)) + case Left(ExecutionError.ValidationError(errs)) => Left(errs.map(err => err.toString).mkString(",")) + case Left(ExecutionError.RuntimeError(err)) => Left(err) + case Right(ExecutionSuccess(_)) => Right(Path.of(outputStream.toString)) } } finally { outputStream.close() @@ -910,9 +998,9 @@ trait SnapiTestContext logger.debug(s"Test infrastructure now writing output result to temporary location: $path") try { compilerService.execute(query, getQueryEnvironment(maybeArgs, scopes, options), maybeDecl, outputStream) match { - case ExecutionValidationFailure(errs) => Left(errs.map(err => err.toString).mkString(",")) - case ExecutionRuntimeFailure(err) => Left(err) - case ExecutionSuccess(_) => Right(path) + case Left(ExecutionError.ValidationError(errs)) => Left(errs.map(err => err.toString).mkString(",")) + case Left(ExecutionError.RuntimeError(err)) => Left(err) + case Right(ExecutionSuccess(_)) => Right(path) } } finally { outputStream.close() diff --git a/snapi-compiler/src/test/scala/com/rawlabs/snapi/compiler/tests/output/EvalTest.scala b/snapi-compiler/src/test/scala/com/rawlabs/snapi/compiler/tests/output/EvalTest.scala new file mode 100644 index 000000000..b7f859820 --- /dev/null +++ b/snapi-compiler/src/test/scala/com/rawlabs/snapi/compiler/tests/output/EvalTest.scala @@ -0,0 +1,270 @@ +/* + * Copyright 2024 RAW Labs S.A. + * + * Use of this software is governed by the Business Source License + * included in the file licenses/BSL.txt. + * + * As of the Change Date specified in that file, in accordance with + * the Business Source License, use of this software will be governed + * by the Apache License, Version 2.0, included in the file + * licenses/APL.txt. + */ + +package com.rawlabs.snapi.compiler.tests.output + +import com.rawlabs.protocol.raw.{ + Value, + ValueBinary, + ValueBool, + ValueByte, + ValueDate, + ValueDecimal, + ValueDouble, + ValueError, + ValueFloat, + ValueInt, + ValueList, + ValueLong, + ValueNull, + ValueRecord, + ValueRecordField, + ValueShort, + ValueString, + ValueTime, + ValueTimestamp +} +import com.rawlabs.snapi.compiler.tests.SnapiTestContext +import com.rawlabs.snapi.frontend.snapi.SnapiInterpolator +import com.rawlabs.utils.sources.filesystem.local.LocalLocationsTestContext + +import java.io.File + +// FIXME (msb): This test should remain in SnapiCompiler while all other tests should become Truffle "value" tests. +class EvalTest extends SnapiTestContext with LocalLocationsTestContext { + + ////////////////////////// + // Basic tests + ////////////////////////// + + // Test single value + test(snapi"""42""")(it => it should evalSingle(Value.newBuilder().setInt(ValueInt.newBuilder().setV(42)).build())) + + // Test iterator + test(snapi"""Collection.Build(1,2,3)""") { it => + it should evalIterator( + Value.newBuilder().setInt(ValueInt.newBuilder().setV(1)).build(), + Value.newBuilder().setInt(ValueInt.newBuilder().setV(2)).build(), + Value.newBuilder().setInt(ValueInt.newBuilder().setV(3)).build() + ) + } + + // Test list (vs iterator) + test(snapi"""[1,2,3]""") { it => + it should evalSingle( + Value + .newBuilder() + .setList( + ValueList + .newBuilder() + .addValues(Value.newBuilder().setInt(ValueInt.newBuilder().setV(1)).build()) + .addValues(Value.newBuilder().setInt(ValueInt.newBuilder().setV(2)).build()) + .addValues(Value.newBuilder().setInt(ValueInt.newBuilder().setV(3)).build()) + .build() + ) + .build() + ) + } + + // Test validation failure + test(snapi"""x + 1""")(it => it should evalTypeErrorAs("x is not declared")) + + // Test runtime failure + test(snapi"""Regex.Groups("aaa", "(\\d+)") """)(it => + it should evalRunErrorAs("string 'aaa' does not match pattern '(\\d+)'") + ) + + // Test failure during iteration + test("""Collection.Build( + | 1, + | 2, + | Error.Build("foo") + |)""".stripMargin) { it => + it should evalIterator( + Value.newBuilder().setInt(ValueInt.newBuilder().setV(1)).build(), + Value.newBuilder().setInt(ValueInt.newBuilder().setV(2)).build(), + Value.newBuilder().setError(ValueError.newBuilder().setMessage("foo")).build() + ) + } + + // Record (w/ duplicated fields remain) + test("""{a: 1, b: 2, a: 3, c: 4, a: 5}""") { it => + it should evalSingle( + Value + .newBuilder() + .setRecord( + ValueRecord + .newBuilder() + .addFields( + ValueRecordField + .newBuilder() + .setName("a") + .setValue(Value.newBuilder().setInt(ValueInt.newBuilder().setV(1)).build()) + ) + .addFields( + ValueRecordField + .newBuilder() + .setName("b") + .setValue(Value.newBuilder().setInt(ValueInt.newBuilder().setV(2)).build()) + ) + .addFields( + ValueRecordField + .newBuilder() + .setName("a") + .setValue(Value.newBuilder().setInt(ValueInt.newBuilder().setV(3)).build()) + ) + .addFields( + ValueRecordField + .newBuilder() + .setName("c") + .setValue(Value.newBuilder().setInt(ValueInt.newBuilder().setV(4)).build()) + ) + .addFields( + ValueRecordField + .newBuilder() + .setName("a") + .setValue(Value.newBuilder().setInt(ValueInt.newBuilder().setV(5)).build()) + ) + ) + .build() + ) + } + + ////////////////////////// + // Binary data + ////////////////////////// + + test("""Binary.FromString("Hello World")""") { it => + val expected = "Hello World".getBytes("UTF-8") + it should evalSingle( + Value + .newBuilder() + .setBinary(ValueBinary.newBuilder().setV(com.google.protobuf.ByteString.copyFrom(expected))) + .build() + ) + } + + test(s"""Binary.Read("$peopleExcel")""") { it => + val expected = new File(peopleExcel.drop("file:".length)) + it should evalSingle( + Value + .newBuilder() + .setBinary( + ValueBinary.newBuilder().setV(com.google.protobuf.ByteString.readFrom(new java.io.FileInputStream(expected))) + ) + .build() + ) + } + + // A nullable result is handled, and null results to an empty file. + test(s"""Binary.Read(if (1 == 0) then "$peopleExcel" else null)""") { it => + it should evalSingle(Value.newBuilder().setNull(ValueNull.newBuilder().build()).build()) + } + + // An error fails the execution. + test(s"""Binary.Read("file:/not/found")""")(it => + it should evalRunErrorAs("file system error: path not found: /not/found") + ) + + ////////////////////////// + // Text data + ////////////////////////// + + test(""" "Hello World" """) { it => + val expected = "Hello World" + it should evalSingle(Value.newBuilder().setString(ValueString.newBuilder().setV(expected)).build()) + } + + test(""" if true then "Hello World" else null""") { it => + val expected = "Hello World" + it should evalSingle(Value.newBuilder().setString(ValueString.newBuilder().setV(expected)).build()) + } + + test(""" if false then "Hello World" else null""") { it => + it should evalSingle(Value.newBuilder().setNull(ValueNull.newBuilder()).build()) + } + + ////////////////////////// + // Remaining types + ////////////////////////// + + // format: off + test("""Collection.Build( + | { + | byteCol: Byte.From("1"), + | shortCol: Short.From("10"), + | intCol: Int.From("100"), + | longCol: Long.From("1000"), + | floatCol: Float.From("3.14"), + | doubleCol: Double.From("6.28"), + | decimalCol: Decimal.From("9.42"), + | boolCol: true, + | dateCol: Date.Parse("12/25/2023", "M/d/yyyy"), + | timeCol: Time.Parse("01:02:03", "H:m:s"), + | timestampCol: Timestamp.Parse("12/25/2023 01:02:03", "M/d/yyyy H:m:s"), + | binaryCol: Binary.FromString("Hello World!"), + | stringCol: "Hello,World!" + | }, + | { + | byteCol: Byte.From("120"), + | shortCol: Short.From("2500"), + | intCol: Int.From("25000"), + | longCol: Long.From("9223372036854775807"), + | floatCol: Float.From("30.14"), + | doubleCol: Double.From("60.28"), + | decimalCol: Decimal.From("90.42"), + | boolCol: false, + | dateCol: Date.Parse("2/5/2023", "M/d/yyyy"), + | timeCol: Time.Parse("11:12:13", "H:m:s"), + | timestampCol: Timestamp.Parse("2/5/2023 11:12:13", "M/d/yyyy H:m:s"), + | binaryCol: Binary.FromString("Olala!"), + | stringCol: "Ciao World!" + | } + |)""".stripMargin) { it => + it should evalIterator( + Value.newBuilder() + .setRecord(ValueRecord.newBuilder() + .addFields(ValueRecordField.newBuilder().setName("byteCol").setValue(Value.newBuilder().setByte(ValueByte.newBuilder().setV(1)).build())) + .addFields(ValueRecordField.newBuilder().setName("shortCol").setValue(Value.newBuilder().setShort(ValueShort.newBuilder().setV(10)).build())) + .addFields(ValueRecordField.newBuilder().setName("intCol").setValue(Value.newBuilder().setInt(ValueInt.newBuilder().setV(100)).build())) + .addFields(ValueRecordField.newBuilder().setName("longCol").setValue(Value.newBuilder().setLong(ValueLong.newBuilder().setV(1000)).build())) + .addFields(ValueRecordField.newBuilder().setName("floatCol").setValue(Value.newBuilder().setFloat(ValueFloat.newBuilder().setV(3.14f)).build())) + .addFields(ValueRecordField.newBuilder().setName("doubleCol").setValue(Value.newBuilder().setDouble(ValueDouble.newBuilder().setV(6.28)).build())) + .addFields(ValueRecordField.newBuilder().setName("decimalCol").setValue(Value.newBuilder().setDecimal(ValueDecimal.newBuilder().setV("9.42")).build())) + .addFields(ValueRecordField.newBuilder().setName("boolCol").setValue(Value.newBuilder().setBool(ValueBool.newBuilder().setV(true)).build())) + .addFields(ValueRecordField.newBuilder().setName("dateCol").setValue(Value.newBuilder().setDate(ValueDate.newBuilder().setYear(2023).setMonth(12).setDay(25)).build())) + .addFields(ValueRecordField.newBuilder().setName("timeCol").setValue(Value.newBuilder().setTime(ValueTime.newBuilder().setHour(1).setMinute(2).setSecond(3)).build())) + .addFields(ValueRecordField.newBuilder().setName("timestampCol").setValue(Value.newBuilder().setTimestamp(ValueTimestamp.newBuilder().setYear(2023).setMonth(12).setDay(25).setHour(1).setMinute(2).setSecond(3)).build())) + .addFields(ValueRecordField.newBuilder().setName("binaryCol").setValue(Value.newBuilder().setBinary(ValueBinary.newBuilder().setV(com.google.protobuf.ByteString.copyFrom("Hello World!".getBytes("UTF-8")))).build())) + .addFields(ValueRecordField.newBuilder().setName("stringCol").setValue(Value.newBuilder().setString(ValueString.newBuilder().setV("Hello,World!")).build())) + ).build(), + Value.newBuilder() + .setRecord(ValueRecord.newBuilder() + .addFields(ValueRecordField.newBuilder().setName("byteCol").setValue(Value.newBuilder().setByte(ValueByte.newBuilder().setV(120)).build())) + .addFields(ValueRecordField.newBuilder().setName("shortCol").setValue(Value.newBuilder().setShort(ValueShort.newBuilder().setV(2500)).build())) + .addFields(ValueRecordField.newBuilder().setName("intCol").setValue(Value.newBuilder().setInt(ValueInt.newBuilder().setV(25000)).build())) + .addFields(ValueRecordField.newBuilder().setName("longCol").setValue(Value.newBuilder().setLong(ValueLong.newBuilder().setV(9223372036854775807L)).build())) + .addFields(ValueRecordField.newBuilder().setName("floatCol").setValue(Value.newBuilder().setFloat(ValueFloat.newBuilder().setV(30.14f)).build())) + .addFields(ValueRecordField.newBuilder().setName("doubleCol").setValue(Value.newBuilder().setDouble(ValueDouble.newBuilder().setV(60.28)).build())) + .addFields(ValueRecordField.newBuilder().setName("decimalCol").setValue(Value.newBuilder().setDecimal(ValueDecimal.newBuilder().setV("90.42")).build())) + .addFields(ValueRecordField.newBuilder().setName("boolCol").setValue(Value.newBuilder().setBool(ValueBool.newBuilder().setV(false)).build())) + .addFields(ValueRecordField.newBuilder().setName("dateCol").setValue(Value.newBuilder().setDate(ValueDate.newBuilder().setYear(2023).setMonth(2).setDay(5)).build())) + .addFields(ValueRecordField.newBuilder().setName("timeCol").setValue(Value.newBuilder().setTime(ValueTime.newBuilder().setHour(11).setMinute(12).setSecond(13)).build())) + .addFields(ValueRecordField.newBuilder().setName("timestampCol").setValue(Value.newBuilder().setTimestamp(ValueTimestamp.newBuilder().setYear(2023).setMonth(2).setDay(5).setHour(11).setMinute(12).setSecond(13)).build())) + .addFields(ValueRecordField.newBuilder().setName("binaryCol").setValue(Value.newBuilder().setBinary(ValueBinary.newBuilder().setV(com.google.protobuf.ByteString.copyFrom("Olala!".getBytes("UTF-8")))).build())) + .addFields(ValueRecordField.newBuilder().setName("stringCol").setValue(Value.newBuilder().setString(ValueString.newBuilder().setV("Ciao World!")).build())) + ).build() + ) + } + // format: on + +} diff --git a/snapi-compiler/src/test/scala/com/rawlabs/snapi/compiler/tests/regressions/RD10767Test.scala b/snapi-compiler/src/test/scala/com/rawlabs/snapi/compiler/tests/regressions/RD10767Test.scala index e86424abf..8f65324f4 100644 --- a/snapi-compiler/src/test/scala/com/rawlabs/snapi/compiler/tests/regressions/RD10767Test.scala +++ b/snapi-compiler/src/test/scala/com/rawlabs/snapi/compiler/tests/regressions/RD10767Test.scala @@ -12,7 +12,7 @@ package com.rawlabs.snapi.compiler.tests.regressions -import com.rawlabs.compiler.{GetProgramDescriptionSuccess, ProgramEnvironment} +import com.rawlabs.compiler.ProgramEnvironment import com.rawlabs.snapi.frontend.snapi.SnapiInterpolator import com.rawlabs.snapi.compiler.tests.SnapiTestContext @@ -39,7 +39,7 @@ class RD10767Test extends SnapiTestContext { None ) compilerService.getProgramDescription(it.q, programEnvironment) match { - case GetProgramDescriptionSuccess(desc) => + case Right(desc) => assert(desc.maybeRunnable.isDefined, "Expected a runnable program") val decls = desc.decls("data_type") assert(decls.head.outType.isEmpty) @@ -62,7 +62,7 @@ class RD10767Test extends SnapiTestContext { None ) compilerService.getProgramDescription(it.q, programEnvironment) match { - case GetProgramDescriptionSuccess(desc) => + case Right(desc) => assert(desc.maybeRunnable.isDefined, "Expected a runnable program") val decls = desc.decls("func") assert(decls.head.outType.isEmpty) diff --git a/snapi-frontend/src/main/scala/com/rawlabs/snapi/frontend/base/ProgramContext.scala b/snapi-frontend/src/main/scala/com/rawlabs/snapi/frontend/base/ProgramContext.scala index 03a69715e..81507562b 100644 --- a/snapi-frontend/src/main/scala/com/rawlabs/snapi/frontend/base/ProgramContext.scala +++ b/snapi-frontend/src/main/scala/com/rawlabs/snapi/frontend/base/ProgramContext.scala @@ -12,7 +12,7 @@ package com.rawlabs.snapi.frontend.base -import com.rawlabs.compiler.{CompilerService, ProgramEnvironment} +import com.rawlabs.compiler.ProgramEnvironment import com.rawlabs.utils.core.RawSettings /** @@ -27,7 +27,16 @@ trait ProgramContext { def settings: RawSettings = compilerContext.settings def dumpDebugInfo: List[(String, String)] = { - CompilerService.getDebugInfo(programEnvironment) + List( + "Trace ID" -> programEnvironment.maybeTraceId.getOrElse(""), + "Arguments" -> programEnvironment.maybeArguments + .map(args => args.map { case (k, v) => s"$k -> $v" }.mkString("\n")) + .getOrElse(""), + "Uid" -> programEnvironment.uid.toString, + "Scopes" -> programEnvironment.scopes.mkString(","), + "Options" -> programEnvironment.options.map { case (k, v) => s"$k -> $v" }.mkString("\n") + //"Settings" -> runtimeContext.settings.toString + ) } } diff --git a/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/NamedParametersPreparedStatement.scala b/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/NamedParametersPreparedStatement.scala index 4610d03ec..fc5874fe5 100644 --- a/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/NamedParametersPreparedStatement.scala +++ b/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/NamedParametersPreparedStatement.scala @@ -16,6 +16,7 @@ import com.rawlabs.compiler.{ ErrorMessage, ErrorPosition, ErrorRange, + ExecutionError, RawBinary, RawBool, RawByte, @@ -463,9 +464,13 @@ class NamedParametersPreparedStatement( def errorPosition(p: Position): ErrorPosition = ErrorPosition(p.line, p.column) ErrorRange(errorPosition(position), errorPosition(position1)) } + def executeWith( parameters: Seq[(String, RawValue)] - ): Either[String, NamedParametersPreparedStatementExecutionResult] = { + ): Either[ExecutionError, NamedParametersPreparedStatementExecutionResult] = { + + import ExecutionError._ + val mandatoryParameters = { for ( (name, diagnostic) <- declaredTypeInfo @@ -478,8 +483,19 @@ class NamedParametersPreparedStatement( setParam(p, v) mandatoryParameters.remove(p) } - if (mandatoryParameters.nonEmpty) Left(s"no value was specified for ${mandatoryParameters.mkString(", ")}") - else + if (mandatoryParameters.nonEmpty) { + Left( + ValidationError( + List( + ErrorMessage( + s"no value was specified for ${mandatoryParameters.mkString(", ")}", + Nil, + ErrorCode.SqlErrorCode + ) + ) + ) + ) + } else try { val isResultSet = stmt.execute() if (isResultSet) Right(NamedParametersPreparedStatementResultSet(stmt.getResultSet)) @@ -495,8 +511,9 @@ class NamedParametersPreparedStatement( // that has ... changed (e.g. the database doesn't have the table anymore, a remote service // account has expired (RD-10895)). We report these errors to the user. case ex: PSQLException => + // These are still considered validation errors. val error = ex.getMessage // it has the code, the message, hint, etc. - Left(error) + Left(RuntimeError(error)) } } diff --git a/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/SqlCompilerService.scala b/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/SqlCompilerService.scala index bed44dbdb..085da8c1e 100644 --- a/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/SqlCompilerService.scala +++ b/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/SqlCompilerService.scala @@ -14,14 +14,10 @@ package com.rawlabs.sql.compiler import com.google.common.cache.{CacheBuilder, CacheLoader} import com.rawlabs.compiler._ +import com.rawlabs.protocol.raw.{Value, ValueInt, ValueRecord, ValueRecordField} import com.rawlabs.sql.compiler.antlr4.{ParseProgramResult, SqlIdnNode, SqlParamUseNode, SqlSyntaxAnalyzer} import com.rawlabs.sql.compiler.metadata.UserMetadataCache -import com.rawlabs.sql.compiler.writers.{ - StatusCsvWriter, - StatusJsonWriter, - TypedResultSetCsvWriter, - TypedResultSetJsonWriter -} +import com.rawlabs.sql.compiler.writers._ import com.rawlabs.utils.core.{RawSettings, RawUtils} import org.bitbucket.inkytonik.kiama.util.Positions @@ -87,76 +83,68 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends override def getProgramDescription( source: String, environment: ProgramEnvironment - ): GetProgramDescriptionResponse = { - try { - logger.debug(s"Getting program description: $source") - safeParse(source) match { - case Left(errors) => GetProgramDescriptionFailure(errors) - case Right(parsedTree) => + ): Either[List[ErrorMessage], ProgramDescription] = { + safeParse(source) match { + case Left(errors) => Left(errors) + case Right(parsedTree) => + try { + val conn = connectionPool.getConnection(environment.jdbcUrl.get) try { - val conn = connectionPool.getConnection(environment.jdbcUrl.get) - try { - val stmt = new NamedParametersPreparedStatement(conn, parsedTree) - val description = stmt.queryMetadata match { - case Right(info) => - val queryParamInfo = info.parameters - val outputType = pgRowTypeToIterableType(info.outputType) - val parameterInfo = queryParamInfo - .map { - case (name, paramInfo) => SqlTypesUtils.rawTypeFromPgType(paramInfo.pgType).map { rawType => - // we ignore tipe.nullable and mark all parameters as nullable - val paramType = rawType match { - case RawAnyType() => rawType; - case other => other.cloneNullable - } - ParamDescription( - name, - Some(paramType), - paramInfo.default, - comment = paramInfo.comment, - required = paramInfo.default.isEmpty - ) + val stmt = new NamedParametersPreparedStatement(conn, parsedTree) + val description = stmt.queryMetadata match { + case Right(info) => + val queryParamInfo = info.parameters + val outputType = pgRowTypeToIterableType(info.outputType) + val parameterInfo = queryParamInfo + .map { + case (name, paramInfo) => SqlTypesUtils.rawTypeFromPgType(paramInfo.pgType).map { rawType => + // we ignore tipe.nullable and mark all parameters as nullable + val paramType = rawType match { + case RawAnyType() => rawType; + case other => other.cloneNullable } - } - .foldLeft(Right(Seq.empty): Either[Seq[String], Seq[ParamDescription]]) { - case (Left(errors), Left(error)) => Left(errors :+ error) - case (_, Left(error)) => Left(Seq(error)) - case (Right(params), Right(param)) => Right(params :+ param) - case (errors @ Left(_), _) => errors - case (_, Right(param)) => Right(Seq(param)) - } - (outputType, parameterInfo) match { - case (Right(iterableType), Right(ps)) => - // Regardless if there are parameters, we declare a main function with the output type. - // This permits the publish endpoints from the UI (https://raw-labs.atlassian.net/browse/RD-10359) - val ok = ProgramDescription( - Map.empty, - Some(DeclDescription(Some(ps.toVector), Some(iterableType), None)), - None - ) - GetProgramDescriptionSuccess(ok) - case _ => - val errorMessages = - outputType.left.getOrElse(Seq.empty) ++ parameterInfo.left.getOrElse(Seq.empty) - GetProgramDescriptionFailure(treeErrors(parsedTree, errorMessages).toList) + ParamDescription( + name, + Some(paramType), + paramInfo.default, + comment = paramInfo.comment, + required = paramInfo.default.isEmpty + ) + } } - case Left(errors) => GetProgramDescriptionFailure(errors) - } - RawUtils.withSuppressNonFatalException(stmt.close()) - description - } catch { - case e: NamedParametersPreparedStatementException => GetProgramDescriptionFailure(e.errors) - } finally { - RawUtils.withSuppressNonFatalException(conn.close()) + .foldLeft(Right(Seq.empty): Either[Seq[String], Seq[ParamDescription]]) { + case (Left(errors), Left(error)) => Left(errors :+ error) + case (_, Left(error)) => Left(Seq(error)) + case (Right(params), Right(param)) => Right(params :+ param) + case (errors @ Left(_), _) => errors + case (_, Right(param)) => Right(Seq(param)) + } + (outputType, parameterInfo) match { + case (Right(iterableType), Right(ps)) => + // Regardless if there are parameters, we declare a main function with the output type. + // This permits the publish endpoints from the UI (https://raw-labs.atlassian.net/browse/RD-10359) + val ok = ProgramDescription( + Map.empty, + Some(DeclDescription(Some(ps.toVector), Some(iterableType), None)), + None + ) + Right(ok) + case _ => + val errorMessages = outputType.left.getOrElse(Seq.empty) ++ parameterInfo.left.getOrElse(Seq.empty) + Left(treeErrors(parsedTree, errorMessages).toList) + } + case Left(errors) => Left(errors) } + RawUtils.withSuppressNonFatalException(stmt.close()) + description } catch { - case ex: SQLException if isConnectionFailure(ex) => - logger.warn("SqlConnectionPool connection failure", ex) - GetProgramDescriptionFailure(List(treeError(parsedTree, ex.getMessage))) + case e: NamedParametersPreparedStatementException => Left(e.errors) + } finally { + RawUtils.withSuppressNonFatalException(conn.close()) } - } - } catch { - case NonFatal(t) => throw new CompilerServiceException(t, environment) + } catch { + case ex: SQLException if isConnectionFailure(ex) => Left(List(treeError(parsedTree, ex.getMessage))) + } } } @@ -175,11 +163,12 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends maybeDecl: Option[String], outputStream: OutputStream, maxRows: Option[Long] - ): ExecutionResponse = { + ): Either[ExecutionError, ExecutionSuccess] = { + import ExecutionError._ + try { - logger.debug(s"Executing: $source") safeParse(source) match { - case Left(errors) => ExecutionValidationFailure(errors) + case Left(errors) => Left(ValidationError(errors)) case Right(parsedTree) => val conn = connectionPool.getConnection(environment.jdbcUrl.get) try { @@ -197,26 +186,23 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends // No ResultSet, it was an update. Return a status in the expected format. updateResultRendering(environment, outputStream, count, maxRows) } - case Left(error) => ExecutionRuntimeFailure(error) + case Left(error) => Left(error) } - case Left(errors) => ExecutionRuntimeFailure(errors.mkString(", ")) + case Left(errors) => Left(RuntimeError(errors.mkString(", "))) } - case Left(errors) => ExecutionValidationFailure(errors) + case Left(errors) => Left(ValidationError(errors)) } } finally { RawUtils.withSuppressNonFatalException(pstmt.close()) } } catch { - case e: NamedParametersPreparedStatementException => ExecutionValidationFailure(e.errors) + case e: NamedParametersPreparedStatementException => Left(ValidationError(e.errors)) } finally { RawUtils.withSuppressNonFatalException(conn.close()) } } } catch { - case ex: SQLException if isConnectionFailure(ex) => - logger.warn("SqlConnectionPool connection failure", ex) - ExecutionRuntimeFailure(ex.getMessage) - case NonFatal(t) => throw new CompilerServiceException(t, environment) + case ex: SQLException if isConnectionFailure(ex) => Left(RuntimeError(ex.getMessage)) } } @@ -226,13 +212,15 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends v: ResultSet, outputStream: OutputStream, maxRows: Option[Long] - ): ExecutionResponse = { + ): Either[ExecutionError, ExecutionSuccess] = { + import ExecutionError._ + environment.options .get("output-format") .map(_.toLowerCase) match { case Some("csv") => if (!TypedResultSetCsvWriter.outputWriteSupport(tipe)) { - ExecutionRuntimeFailure("unsupported type") + RuntimeError("unsupported type") } val windowsLineEnding = environment.options.get("windows-line-ending") match { case Some("true") => true @@ -242,28 +230,27 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends val w = new TypedResultSetCsvWriter(outputStream, lineSeparator, maxRows) try { w.write(v, tipe) - ExecutionSuccess(w.complete) + Right(ExecutionSuccess(w.complete)) } catch { - case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) + case ex: IOException => Left(RuntimeError(ex.getMessage)) } finally { RawUtils.withSuppressNonFatalException(w.close()) } case Some("json") => if (!TypedResultSetJsonWriter.outputWriteSupport(tipe)) { - ExecutionRuntimeFailure("unsupported type") + RuntimeError("unsupported type") } val w = new TypedResultSetJsonWriter(outputStream, maxRows) try { w.write(v, tipe) - ExecutionSuccess(w.complete) + Right(ExecutionSuccess(w.complete)) } catch { - case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) + case ex: IOException => Left(RuntimeError(ex.getMessage)) } finally { RawUtils.withSuppressNonFatalException(w.close()) } - case _ => ExecutionRuntimeFailure("unknown output format") + case _ => Left(RuntimeError("unknown output format")) } - } private def updateResultRendering( @@ -271,7 +258,9 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends stream: OutputStream, count: Int, maybeLong: Option[Long] - ) = { + ): Either[ExecutionError, ExecutionSuccess] = { + import ExecutionError._ + environment.options .get("output-format") .map(_.toLowerCase) match { @@ -284,8 +273,9 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends val writer = new StatusCsvWriter(stream, lineSeparator) try { writer.write(count) + Right(ExecutionSuccess(true)) } catch { - case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) + case ex: IOException => Left(RuntimeError(ex.getMessage)) } finally { RawUtils.withSuppressNonFatalException(writer.close()) } @@ -293,54 +283,166 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends val w = new StatusJsonWriter(stream) try { w.write(count) - ExecutionSuccess(true) + Right(ExecutionSuccess(true)) } catch { - case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) + case ex: IOException => Left(RuntimeError(ex.getMessage)) } finally { RawUtils.withSuppressNonFatalException(w.close()) } - case _ => ExecutionRuntimeFailure("unknown output format") + case _ => Left(RuntimeError("unknown output format")) + } + } + + override def eval( + source: String, + environment: ProgramEnvironment, + maybeDecl: Option[String] + ): Either[ExecutionError, EvalSuccess] = { + + import EvalSuccess._ + import ExecutionError._ + + // 1) Parse + safeParse(source) match { + case Left(parseErrors) => Left(ValidationError(parseErrors)) + + case Right(parsedTree) => + // 2) Attempt to get a connection from the pool + val conn = + try connectionPool.getConnection(environment.jdbcUrl.get) + catch { + case ex: SQLException if isConnectionFailure(ex) => + return Left(ValidationError(List(ErrorMessage(ex.getMessage, Nil, ErrorCode.SqlErrorCode)))) + } + + // If parse and connection succeeded, proceed: + try { + // 3) Build statement + val pstmt = new NamedParametersPreparedStatement(conn, parsedTree, environment.scopes) + // We do NOT close `pstmt` right away. We only close if we fail or once iteration is done. + // So no try/finally here that automatically kills it. + val pstmtAutoClose = asAutoCloseable(pstmt)(_.close()) + + // 4) Query metadata + pstmt.queryMetadata match { + case Left(errors) => + // If we can't get metadata, close everything now. + closeQuietly(pstmtAutoClose, conn) + Left(ValidationError(errors)) + + case Right(info) => pgRowTypeToIterableType(info.outputType) match { + case Left(errs) => + // Another error => close. + closeQuietly(pstmtAutoClose, conn) + Left(ValidationError(List(ErrorMessage(errs.mkString(", "), Nil, ErrorCode.SqlErrorCode)))) + + case Right(iterableType) => + // Actually run the statement + val arguments = environment.maybeArguments.getOrElse(Array.empty) + pstmt.executeWith(arguments) match { + case Left(error) => + // Execution failed => close + closeQuietly(pstmtAutoClose, conn) + Left(error) + + case Right(result) => + // We have a success; produce a streaming or single-value iterator + val protocolType = TypeConverter.toProtocolType(iterableType.innerType) + + result match { + case NamedParametersPreparedStatementResultSet(rs) => + // Return a streaming iterator that closes RS, stmt, conn in close() + val valueIterator = new TypedResultSetRawValueIterator(rs, iterableType) with AutoCloseable { + override def close(): Unit = { + // Freed when the caller is done iterating + closeQuietly(rs, pstmtAutoClose, conn) + } + } + + Right(IteratorValue(protocolType, valueIterator)) + + case NamedParametersPreparedStatementUpdate(countV) => + // A single-value scenario (the "UPDATE count" integer). + // It wraps the count in a record with "update_count" + val resultValue = Value + .newBuilder() + .setRecord( + ValueRecord + .newBuilder() + .addFields( + ValueRecordField + .newBuilder() + .setName("update_count") + .setValue(Value.newBuilder().setInt(ValueInt.newBuilder().setV(countV)).build()) + ) + .build() + ) + .build() + + // Close everything now + closeQuietly(pstmtAutoClose) + closeQuietly(conn) + + Right(ResultValue(protocolType, resultValue)) + } + } + } + } + + } catch { + // If building/executing statement fails badly: + case e: NamedParametersPreparedStatementException => + closeQuietly(conn) // stmt might not even have been created + Left(ValidationError(e.errors)) + + case t: Throwable => + closeQuietly(conn) + throw t + } } - ExecutionSuccess(true) } + /** Utility to close multiple resources ignoring non-fatal exceptions. */ + private def closeQuietly(resources: AutoCloseable*): Unit = { + resources.foreach { r => + try r.close() + catch { case NonFatal(_) => () } + } + } + + /** + * Helper to wrap anything that has a `.close()` method into an `AutoCloseable`. + * That way we can pass it to `closeQuietly(...)`. + */ + private def asAutoCloseable[T](obj: T)(closeFn: T => Unit): AutoCloseable = + new AutoCloseable { def close(): Unit = closeFn(obj) } + override def formatCode( source: String, environment: ProgramEnvironment, maybeIndent: Option[Int], maybeWidth: Option[Int] ): FormatCodeResponse = { - try { - FormatCodeResponse(Some(source)) - } catch { - case NonFatal(t) => throw new CompilerServiceException(t, environment) - } + FormatCodeResponse(Some(source)) } override def dotAutoComplete(source: String, environment: ProgramEnvironment, position: Pos): AutoCompleteResponse = { - try { - logger.debug(s"dotAutoComplete at position: $position") - val analyzer = new SqlCodeUtils(parse(source)) - // The editor removes the dot in the completion event - // So we call the identifier with +1 column - analyzer.identifierUnder(Pos(position.line, position.column + 1)) match { - case Some(idn: SqlIdnNode) => - val metadataBrowser = metadataBrowsers.get(environment.jdbcUrl.get) - val matches = metadataBrowser.getDotCompletionMatches(idn) - val collectedValues = matches.collect { - case (idns, tipe) => - // If the last identifier is quoted, we need to quote the completion - val name = if (idns.last.quoted) '"' + idns.last.value + '"' else idns.last.value - LetBindCompletion(name, tipe) - } - logger.debug(s"dotAutoComplete returned ${collectedValues.size} matches") - AutoCompleteResponse(collectedValues.toArray) - case Some(_: SqlParamUseNode) => - AutoCompleteResponse(Array.empty) // dot completion makes no sense on parameters - case _ => AutoCompleteResponse(Array.empty) - } - } catch { - case NonFatal(t) => throw new CompilerServiceException(t, environment) + val analyzer = new SqlCodeUtils(parse(source)) + // The editor removes the dot in the completion event + // So we call the identifier with +1 column + analyzer.identifierUnder(Pos(position.line, position.column + 1)) match { + case Some(idn: SqlIdnNode) => + val metadataBrowser = metadataBrowsers.get(environment.jdbcUrl.get) + val matches = metadataBrowser.getDotCompletionMatches(idn) + val collectedValues = matches.collect { + case (idns, tipe) => + // If the last identifier is quoted, we need to quote the completion + val name = if (idns.last.quoted) '"' + idns.last.value + '"' else idns.last.value + LetBindCompletion(name, tipe) + } + AutoCompleteResponse(collectedValues.toArray) + case Some(_: SqlParamUseNode) => AutoCompleteResponse(Array.empty) // dot completion makes no sense on parameters + case _ => AutoCompleteResponse(Array.empty) } } @@ -350,72 +452,59 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends prefix: String, position: Pos ): AutoCompleteResponse = { - try { - logger.debug(s"wordAutoComplete at position: $position") - val tree = parse(source) - val analyzer = new SqlCodeUtils(tree) - val item = analyzer.identifierUnder(position) - logger.debug(s"idn $item") - val matches: Seq[Completion] = item match { - case Some(idn: SqlIdnNode) => - val metadataBrowser = metadataBrowsers.get(environment.jdbcUrl.get) - val matches = metadataBrowser.getWordCompletionMatches(idn) - matches.collect { case (idns, value) => LetBindCompletion(idns.last.value, value) } - case Some(use: SqlParamUseNode) => tree.params.collect { - case (p, paramDescription) if p.startsWith(use.name) => - FunParamCompletion(p, paramDescription.tipe.getOrElse("")) - }.toSeq - case _ => Array.empty[Completion] - } - AutoCompleteResponse(matches.toArray) - } catch { - case NonFatal(t) => throw new CompilerServiceException(t, environment) + val tree = parse(source) + val analyzer = new SqlCodeUtils(tree) + val item = analyzer.identifierUnder(position) + val matches: Seq[Completion] = item match { + case Some(idn: SqlIdnNode) => + val metadataBrowser = metadataBrowsers.get(environment.jdbcUrl.get) + val matches = metadataBrowser.getWordCompletionMatches(idn) + matches.collect { case (idns, value) => LetBindCompletion(idns.last.value, value) } + case Some(use: SqlParamUseNode) => tree.params.collect { + case (p, paramDescription) if p.startsWith(use.name) => + FunParamCompletion(p, paramDescription.tipe.getOrElse("")) + }.toSeq + case _ => Array.empty[Completion] } + AutoCompleteResponse(matches.toArray) } override def hover(source: String, environment: ProgramEnvironment, position: Pos): HoverResponse = { - try { - logger.debug(s"Hovering at position: $position") - val tree = parse(source) - val analyzer = new SqlCodeUtils(tree) - analyzer - .identifierUnder(position) - .map { - case identifier: SqlIdnNode => - val metadataBrowser = metadataBrowsers.get(environment.jdbcUrl.get) - val matches = metadataBrowser.getWordCompletionMatches(identifier) - matches.headOption - .map { case (names, tipe) => HoverResponse(Some(TypeCompletion(formatIdns(names), tipe))) } - .getOrElse(HoverResponse(None)) - case use: SqlParamUseNode => + val tree = parse(source) + val analyzer = new SqlCodeUtils(tree) + analyzer + .identifierUnder(position) + .map { + case identifier: SqlIdnNode => + val metadataBrowser = metadataBrowsers.get(environment.jdbcUrl.get) + val matches = metadataBrowser.getWordCompletionMatches(identifier) + matches.headOption + .map { case (names, tipe) => HoverResponse(Some(TypeCompletion(formatIdns(names), tipe))) } + .getOrElse(HoverResponse(None)) + case use: SqlParamUseNode => + try { + val conn = connectionPool.getConnection(environment.jdbcUrl.get) try { - val conn = connectionPool.getConnection(environment.jdbcUrl.get) + val pstmt = new NamedParametersPreparedStatement(conn, tree) try { - val pstmt = new NamedParametersPreparedStatement(conn, tree) - try { - pstmt.parameterInfo(use.name) match { - case Right(typeInfo) => HoverResponse(Some(TypeCompletion(use.name, typeInfo.pgType.typeName))) - case Left(_) => HoverResponse(None) - } - } finally { - RawUtils.withSuppressNonFatalException(pstmt.close()) + pstmt.parameterInfo(use.name) match { + case Right(typeInfo) => HoverResponse(Some(TypeCompletion(use.name, typeInfo.pgType.typeName))) + case Left(_) => HoverResponse(None) } - } catch { - case _: NamedParametersPreparedStatementException => HoverResponse(None) } finally { - RawUtils.withSuppressNonFatalException(conn.close()) + RawUtils.withSuppressNonFatalException(pstmt.close()) } } catch { - case ex: SQLException if isConnectionFailure(ex) => - logger.warn("SqlConnectionPool connection failure", ex) - HoverResponse(None) + case _: NamedParametersPreparedStatementException => HoverResponse(None) + } finally { + RawUtils.withSuppressNonFatalException(conn.close()) } - case other => throw new AssertionError(s"Unexpected node type: $other") - } - .getOrElse(HoverResponse(None)) - } catch { - case NonFatal(t) => throw new CompilerServiceException(t, environment) - } + } catch { + case ex: SQLException if isConnectionFailure(ex) => HoverResponse(None) + } + case other => throw new AssertionError(s"Unexpected node type: $other") + } + .getOrElse(HoverResponse(None)) } private def formatIdns(idns: Seq[SqlIdentifier]): String = { @@ -423,11 +512,7 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends } override def rename(source: String, environment: ProgramEnvironment, position: Pos): RenameResponse = { - try { - RenameResponse(Array.empty) - } catch { - case NonFatal(t) => throw new CompilerServiceException(t, environment) - } + RenameResponse(Array.empty) } override def goToDefinition( @@ -435,55 +520,39 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends environment: ProgramEnvironment, position: Pos ): GoToDefinitionResponse = { - try { - GoToDefinitionResponse(None) - } catch { - case NonFatal(t) => throw new CompilerServiceException(t, environment) - } + GoToDefinitionResponse(None) } override def validate(source: String, environment: ProgramEnvironment): ValidateResponse = { - try { - logger.debug(s"Validating: $source") - safeParse(source) match { - case Left(errors) => ValidateResponse(errors) - case Right(parsedTree) => + safeParse(source) match { + case Left(errors) => ValidateResponse(errors) + case Right(parsedTree) => + try { + val conn = connectionPool.getConnection(environment.jdbcUrl.get) try { - val conn = connectionPool.getConnection(environment.jdbcUrl.get) + val stmt = new NamedParametersPreparedStatement(conn, parsedTree) try { - val stmt = new NamedParametersPreparedStatement(conn, parsedTree) - try { - stmt.queryMetadata match { - case Right(_) => ValidateResponse(List.empty) - case Left(errors) => ValidateResponse(errors) - } - } finally { - RawUtils.withSuppressNonFatalException(stmt.close()) + stmt.queryMetadata match { + case Right(_) => ValidateResponse(List.empty) + case Left(errors) => ValidateResponse(errors) } - } catch { - case e: NamedParametersPreparedStatementException => ValidateResponse(e.errors) } finally { - RawUtils.withSuppressNonFatalException(conn.close()) + RawUtils.withSuppressNonFatalException(stmt.close()) } } catch { - case ex: SQLException if isConnectionFailure(ex) => - logger.warn("SqlConnectionPool connection failure", ex) - ValidateResponse(List(treeError(parsedTree, ex.getMessage))) + case e: NamedParametersPreparedStatementException => ValidateResponse(e.errors) + } finally { + RawUtils.withSuppressNonFatalException(conn.close()) } - } - } catch { - case NonFatal(t) => - logger.debug(t.getMessage) - throw new CompilerServiceException(t, environment) + } catch { + case ex: SQLException if isConnectionFailure(ex) => + ValidateResponse(List(treeError(parsedTree, ex.getMessage))) + } } } override def aiValidate(source: String, environment: ProgramEnvironment): ValidateResponse = { - try { - ValidateResponse(List.empty) - } catch { - case NonFatal(t) => throw new CompilerServiceException(t, environment) - } + ValidateResponse(List.empty) } override def doStop(): Unit = { diff --git a/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/writers/TypedResultSetRawValueIterator.scala b/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/writers/TypedResultSetRawValueIterator.scala new file mode 100644 index 000000000..09e7012cb --- /dev/null +++ b/sql-compiler/src/main/scala/com/rawlabs/sql/compiler/writers/TypedResultSetRawValueIterator.scala @@ -0,0 +1,590 @@ +/* + * Copyright 2025 RAW Labs S.A. + * + * Use of this software is governed by the Business Source License + * included in the file licenses/BSL.txt. + * + * As of the Change Date specified in that file, in accordance with + * the Business Source License, use of this software will be governed + * by the Apache License, Version 2.0, included in the file + * licenses/APL.txt. + */ + +package com.rawlabs.sql.compiler.writers + +import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper} +import com.google.protobuf.ByteString +import com.rawlabs.compiler._ +import com.rawlabs.protocol.raw._ +import com.rawlabs.sql.compiler.SqlIntervals.stringToInterval +import com.typesafe.scalalogging.StrictLogging +import org.postgresql.util.{PGInterval, PGobject} + +import java.sql.{ResultSet, Timestamp} +import java.time.temporal.ChronoField +import scala.annotation.tailrec +import scala.collection.JavaConverters._ + +/** + * Reads a JDBC ResultSet described by t = RawIterableType(RawRecordType(...)) + * and yields an Iterator[Value], each being a ValueRecord of that row. + */ +class TypedResultSetRawValueIterator( + resultSet: ResultSet, + t: RawType +) extends Iterator[Value] + with StrictLogging { + + private val mapper = new ObjectMapper() + + // We expect t to be RawIterableType(RawRecordType(...)) + private val attributes = t match { + case RawIterableType(RawRecordType(atts, _, _), _, _) => atts + case _ => throw new IllegalArgumentException( + s"TypedResultSetRawValueIterator can only handle Iterable of Record. Got: $t" + ) + } + + private var fetched = false + private var hasMore = false + private var rowsRead: Long = 0 + + override def hasNext: Boolean = { + if (!fetched) { + if (resultSet.next()) { + hasMore = true + rowsRead += 1 + } else { + hasMore = false + } + fetched = true + } + hasMore + } + + override def next(): Value = { + if (!hasNext) { + throw new NoSuchElementException("No more rows in the ResultSet.") + } + fetched = false + + // Build a ValueRecord for this row + val rowAttrs = attributes.indices.map { i => + val fieldName = attributes(i).idn + val fieldType = attributes(i).tipe + val colValue = readValue(resultSet, i + 1, fieldType) + ValueRecordField.newBuilder().setName(fieldName).setValue(colValue).build() + } + + Value + .newBuilder() + .setRecord(ValueRecord.newBuilder().addAllFields(rowAttrs.asJava)) + .build() + } + + /** + * Recursively read a single column from the result set and produce RawValue. + */ + @tailrec + private def readValue(rs: ResultSet, colIndex: Int, tipe: RawType): Value = { + if (tipe.nullable) { + // Check null first + rs.getObject(colIndex) + if (rs.wasNull()) { + buildNullValue() + } else { + readValue(rs, colIndex, tipe.cloneNotNullable) + } + } else tipe match { + case _: RawBoolType => + val b = rs.getBoolean(colIndex) + if (rs.wasNull()) buildNullValue() else boolValue(b) + + case _: RawByteType => + val b = rs.getByte(colIndex) + if (rs.wasNull()) buildNullValue() else byteValue(b) + + case _: RawShortType => + val s = rs.getShort(colIndex) + if (rs.wasNull()) buildNullValue() else shortValue(s) + + case _: RawIntType => + val i = rs.getInt(colIndex) + if (rs.wasNull()) buildNullValue() else intValue(i) + + case _: RawLongType => + val l = rs.getLong(colIndex) + if (rs.wasNull()) buildNullValue() else longValue(l) + + case _: RawFloatType => + val f = rs.getFloat(colIndex) + if (rs.wasNull()) buildNullValue() else floatValue(f) + + case _: RawDoubleType => + val d = rs.getDouble(colIndex) + if (rs.wasNull()) buildNullValue() else doubleValue(d) + + case _: RawDecimalType => + val dec = rs.getBigDecimal(colIndex) + if (rs.wasNull() || dec == null) buildNullValue() else decimalValue(dec.toString) + + case _: RawStringType => + val s = rs.getString(colIndex) + if (rs.wasNull() || s == null) buildNullValue() else stringValue(s) + + case RawListType(inner, _, _) => + val arrayObj = rs.getArray(colIndex) + if (rs.wasNull() || arrayObj == null) { + buildNullValue() + } else { + // Convert to array + val arrayVals = arrayObj.getArray.asInstanceOf[Array[AnyRef]] + val converted = arrayVals.map { v => + convertArrayElementToRawValue(v, inner, rs.getMetaData.getColumnTypeName(colIndex).toLowerCase) + }.toList + listValue(converted) + } + + case _: RawDateType => + val date = rs.getDate(colIndex) + if (rs.wasNull() || date == null) { + buildNullValue() + } else { + val localDate = date.toLocalDate + dateValue(localDate.getYear, localDate.getMonthValue, localDate.getDayOfMonth) + } + + case _: RawTimeType => + val sqlTime = rs.getTime(colIndex) + if (rs.wasNull() || sqlTime == null) { + buildNullValue() + } else { + // Attempt to preserve milliseconds + val localTime = sqlTime.toLocalTime + // The raw .toLocalTime might discard fractional seconds. + // So we re-derive them from the underlying time in millis: + val asMillis = sqlTime.getTime % 1000 + val fixedTime = localTime.`with`(ChronoField.MILLI_OF_SECOND, asMillis) + timeValue( + fixedTime.getHour, + fixedTime.getMinute, + fixedTime.getSecond, + fixedTime.getNano + ) + } + + case _: RawTimestampType => + val ts = rs.getTimestamp(colIndex) + if (rs.wasNull() || ts == null) { + buildNullValue() + } else { + val ldt = ts.toLocalDateTime + // Optionally fix fractional part if needed + timestampValue( + ldt.getYear, + ldt.getMonthValue, + ldt.getDayOfMonth, + ldt.getHour, + ldt.getMinute, + ldt.getSecond, + ldt.getNano + ) + } + + case _: RawIntervalType => + val rawStr = rs.getString(colIndex) + if (rs.wasNull() || rawStr == null) { + buildNullValue() + } else { + val interval = stringToInterval(rawStr) + intervalValue(interval) + } + + case _: RawBinaryType => + val bytes = rs.getBytes(colIndex) + if (rs.wasNull() || bytes == null) { + buildNullValue() + } else { + binaryValue(bytes) + } + + case _: RawAnyType => + // Single column typed as ANY + val colTypeName = rs.getMetaData.getColumnTypeName(colIndex).toLowerCase + // e.g. "json", "jsonb", "hstore", or something else + handleAnySingleValue(rs, colIndex, colTypeName) + + case _ => throw new IllegalArgumentException(s"Unsupported type: $tipe") + } + } + + /** + * Convert array elements to RawValue. + * If the array is typed ANY (e.g. `_json`, `_hstore`), we handle that similarly. + */ + @tailrec + private def convertArrayElementToRawValue( + element: AnyRef, + tipe: RawType, + pgType: String + ): Value = { + // If the element is null, just return RawNull: + if (element == null) return buildNullValue() + + if (tipe.nullable) { + // If the subtype is nullable, we treat a null element or the object as if it might be null + if (element == null) buildNullValue() + else convertArrayElementToRawValue(element, tipe.cloneNotNullable, pgType) + } else tipe match { + case _: RawBoolType => boolValue(element.asInstanceOf[Boolean]) + case _: RawByteType => byteValue(element.asInstanceOf[Byte]) + case _: RawShortType => shortValue(element.asInstanceOf[Short]) + case _: RawIntType => intValue(element.asInstanceOf[Int]) + case _: RawLongType => longValue(element.asInstanceOf[Long]) + case _: RawFloatType => floatValue(element.asInstanceOf[Float]) + case _: RawDoubleType => doubleValue(element.asInstanceOf[Double]) + case _: RawDecimalType => decimalValue(element.asInstanceOf[java.math.BigDecimal].toString) + case _: RawStringType => stringValue(element.asInstanceOf[String]) + + case _: RawIntervalType => + val pgint = element.asInstanceOf[PGInterval] + intervalValue( + RawInterval( + pgint.getYears, + pgint.getMonths, + 0, + pgint.getDays, + pgint.getHours, + pgint.getMinutes, + pgint.getWholeSeconds, + pgint.getMicroSeconds + ) + ) + + case _: RawDateType => + val d = element.asInstanceOf[java.sql.Date] + val ld = d.toLocalDate + dateValue(ld.getYear, ld.getMonthValue, ld.getDayOfMonth) + + case _: RawTimeType => + val t = element.asInstanceOf[java.sql.Time] + val localTime = t.toLocalTime + val asMillis = t.getTime % 1000 + val fixedTime = localTime.`with`(ChronoField.MILLI_OF_SECOND, asMillis) + timeValue(fixedTime.getHour, fixedTime.getMinute, fixedTime.getSecond, fixedTime.getNano) + + case _: RawTimestampType => + val ts = element.asInstanceOf[Timestamp] + val ldt = ts.toLocalDateTime + timestampValue( + ldt.getYear, + ldt.getMonthValue, + ldt.getDayOfMonth, + ldt.getHour, + ldt.getMinute, + ldt.getSecond, + ldt.getNano + ) + + case _: RawAnyType => pgType match { + case "_jsonb" | "_json" => + val data = element.asInstanceOf[String] + val json = mapper.readTree(data) + jsonNodeToRawValue(json) + case "_hstore" => + val item = element.asInstanceOf[PGobject] + hstoreToRawRecord(item.getValue) + } + + case _ => throw new IllegalArgumentException(s"Unsupported type: $tipe") + + } + } + + /** + * Handle a single column typed ANY (non-array). + * We look at the column type name to decide how to parse the underlying value. + */ + private def handleAnySingleValue(rs: ResultSet, colIndex: Int, colTypeName: String): Value = { + colTypeName match { + case "json" | "jsonb" => + val data = rs.getString(colIndex) + if (rs.wasNull() || data == null) buildNullValue() + else { + val node = mapper.readTree(data) + jsonNodeToRawValue(node) + } + + case "hstore" => + val obj = rs.getObject(colIndex) + if (rs.wasNull() || obj == null) { + buildNullValue() + } else obj match { + case pg: PGobject if pg.getValue != null => hstoreToRawRecord(pg.getValue) + case m: java.util.Map[_, _] => mapToRawRecord(m) + case s: String => hstoreToRawRecord(s) + case _ => + // fallback, treat as string + stringValue(obj.toString) + } + + case "_json" | "_jsonb" | "_hstore" => + // The driver might let us read an array from getObject or getArray. + // If we’re here, it suggests the declared column is an array of ANY. + // Let's do a fallback approach: + val arrObj = rs.getArray(colIndex) + if (rs.wasNull() || arrObj == null) { + buildNullValue() + } else { + val arrayVals = arrObj.getArray.asInstanceOf[Array[AnyRef]] + // For ANY array, we convert each element using convertAnyElement + val converted = arrayVals.map(convertAnyElement).toList + listValue(converted) + } + + case _ => + // fallback – treat it as a string + val data = rs.getString(colIndex) + if (rs.wasNull() || data == null) buildNullValue() else stringValue(data) + } + } + + /** + * Convert a single array-element typed ANY at runtime by inspecting the object. + */ + private def convertAnyElement(element: AnyRef): Value = { + if (element == null) return buildNullValue() + + element match { + case pg: PGobject => + val pgType = pg.getType.toLowerCase + pgType match { + case "json" | "jsonb" => + val data = pg.getValue + if (data == null) buildNullValue() + else jsonNodeToRawValue(mapper.readTree(data)) + + case "hstore" => + val data = pg.getValue + if (data == null) buildNullValue() else hstoreToRawRecord(data) + + // fallback for other PGobject types + case _ => + if (pg.getValue == null) buildNullValue() + else stringValue(pg.getValue) + } + + case m: java.util.Map[_, _] => + // Likely hstore returned as a map + mapToRawRecord(m) + + case s: String => + // Could be JSON, could be hstore, or just a string + // We'll try JSON parse first, fallback to hstore parse, else raw string + try { + val node = mapper.readTree(s) + jsonNodeToRawValue(node) + } catch { + case _: Throwable => + // maybe hstore + hstoreToRawRecord(s) + } + + case arr: Array[_] => + // Possibly a nested array scenario + val subVals = arr.map(x => convertAnyElement(x.asInstanceOf[AnyRef])).toList + listValue(subVals) + + // fallback for other possible data + case other => stringValue(other.toString) + } + } + + /** + * Convert a Jackson JsonNode → RawValue (recursive). + */ + private def jsonNodeToRawValue(node: JsonNode): Value = { + if (node.isNull) { + buildNullValue() + } else if (node.isObject) { + val fields = node + .fields() + .asScala + .map { entry => + val key = entry.getKey + val valueNode = entry.getValue + ValueRecordField + .newBuilder() + .setName(key) + .setValue(jsonNodeToRawValue(valueNode)) + .build() + } + .toSeq + Value.newBuilder().setRecord(ValueRecord.newBuilder().addAllFields(fields.asJava)).build() + } else if (node.isArray) { + val elems = node.elements().asScala.map(jsonNodeToRawValue).toList + listValue(elems) + } else if (node.isTextual) { + stringValue(node.asText()) + } else if (node.isIntegralNumber) { + longValue(node.asLong()) + } else if (node.isFloatingPointNumber) { + doubleValue(node.asDouble()) + } else if (node.isBoolean) { + boolValue(node.asBoolean()) + } else { + // fallback or error + stringValue(node.toString) + } + } + + /** + * Convert an hstore string into a RawRecord. + * For example: "key1"=>"val1", "key2"=>"val2" + */ + private def hstoreToRawRecord(hstoreStr: String): Value = { + if (hstoreStr.trim.isEmpty) { + // handle empty string + Value.newBuilder().setRecord(ValueRecord.newBuilder()).build() + } else { + // naive parse by splitting on commas that aren't inside quotes + val pairs = hstoreStr.split(",(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)").toList + val fields = pairs.map { pair => + val kv = pair.split("=>", 2).map(_.trim) + if (kv.length != 2) { + // malformed chunk + throw new IllegalArgumentException(s"Malformed hstore chunk: '$pair'") + } + val key = kv(0).replaceAll("^\"|\"$", "") // remove leading/trailing quotes + val value = kv(1).replaceAll("^\"|\"$", "") + val v = if (value == "NULL") buildNullValue() else stringValue(value) + ValueRecordField.newBuilder().setName(key).setValue(v).build() + } + Value.newBuilder().setRecord(ValueRecord.newBuilder().addAllFields(fields.asJava)).build() + } + } + + /** + * Convert a Map[String, String] (like some drivers do for hstore) to a RawRecord. + */ + private def mapToRawRecord(m: java.util.Map[_, _]): Value = { + val fields = m.asScala.collect { + case (k: String, v: String) => + val valField = + if (v == null) buildNullValue() + else stringValue(v) + ValueRecordField.newBuilder().setName(k).setValue(valField).build() + }.toSeq + Value.newBuilder().setRecord(ValueRecord.newBuilder().addAllFields(fields.asJava)).build() + } + + // ---- Helper methods to build various RawValue types ---- + private def buildNullValue(): Value = Value.newBuilder().setNull(ValueNull.newBuilder()).build() + + private def boolValue(b: Boolean): Value = Value.newBuilder().setBool(ValueBool.newBuilder().setV(b)).build() + private def byteValue(b: Byte): Value = Value.newBuilder().setByte(ValueByte.newBuilder().setV(b)).build() + private def shortValue(s: Short): Value = Value.newBuilder().setShort(ValueShort.newBuilder().setV(s)).build() + private def intValue(i: Int): Value = Value.newBuilder().setInt(ValueInt.newBuilder().setV(i)).build() + private def longValue(l: Long): Value = Value.newBuilder().setLong(ValueLong.newBuilder().setV(l)).build() + private def floatValue(f: Float): Value = Value.newBuilder().setFloat(ValueFloat.newBuilder().setV(f)).build() + private def doubleValue(d: Double): Value = Value.newBuilder().setDouble(ValueDouble.newBuilder().setV(d)).build() + private def decimalValue(str: String): Value = + Value.newBuilder().setDecimal(ValueDecimal.newBuilder().setV(str)).build() + private def stringValue(s: String): Value = Value.newBuilder().setString(ValueString.newBuilder().setV(s)).build() + + private def listValue(elems: Seq[Value]): Value = { + Value.newBuilder().setList(ValueList.newBuilder().addAllValues(elems.asJava)).build() + } + + private def dateValue(year: Int, month: Int, day: Int): Value = { + Value + .newBuilder() + .setDate( + ValueDate.newBuilder().setYear(year).setMonth(month).setDay(day) + ) + .build() + } + + private def timeValue(hour: Int, minute: Int, second: Int, nano: Int): Value = { + Value + .newBuilder() + .setTime( + ValueTime + .newBuilder() + .setHour(hour) + .setMinute(minute) + .setSecond(second) + .setNano(nano) + ) + .build() + } + + private def timestampValue( + year: Int, + month: Int, + day: Int, + hour: Int, + minute: Int, + second: Int, + nano: Int + ): Value = { + Value + .newBuilder() + .setTimestamp( + ValueTimestamp + .newBuilder() + .setYear(year) + .setMonth(month) + .setDay(day) + .setHour(hour) + .setMinute(minute) + .setSecond(second) + .setNano(nano) + ) + .build() + } + + private def intervalValue(i: RawInterval): Value = { + Value + .newBuilder() + .setInterval( + ValueInterval + .newBuilder() + .setYears(i.years) + .setMonths(i.months) + .setWeeks(i.weeks) + .setDays(i.days) + .setHours(i.hours) + .setMinutes(i.minutes) + .setSeconds(i.seconds) + .setMillis(i.millis.toInt) + ) + .build() + } + + private def binaryValue(bytes: Array[Byte]): Value = { + Value + .newBuilder() + .setBinary( + ValueBinary + .newBuilder() + .setV( + ByteString.copyFrom(bytes) + ) + ) + .build() + } + + /** + * Close the underlying ResultSet. Typically you'd also close the Statement that created it. + */ + def close(): Unit = { + if (resultSet != null) { + try { + resultSet.close() + } catch { + case _: Throwable => // ignore + } + } + } +} diff --git a/sql-compiler/src/test/scala/com/rawlabs/sql/compiler/TestSqlCompilerServiceAirports.scala b/sql-compiler/src/test/scala/com/rawlabs/sql/compiler/TestSqlCompilerServiceAirports.scala index f46a6e7aa..afecc695f 100644 --- a/sql-compiler/src/test/scala/com/rawlabs/sql/compiler/TestSqlCompilerServiceAirports.scala +++ b/sql-compiler/src/test/scala/com/rawlabs/sql/compiler/TestSqlCompilerServiceAirports.scala @@ -13,45 +13,22 @@ package com.rawlabs.sql.compiler import com.dimafeng.testcontainers.{ForAllTestContainer, PostgreSQLContainer} -import com.rawlabs.compiler.{ - CompilerService, - ErrorMessage, - ErrorPosition, - ExecutionRuntimeFailure, - ExecutionSuccess, - ExecutionValidationFailure, - GetProgramDescriptionFailure, - GetProgramDescriptionSuccess, - LetBindCompletion, - ParamDescription, - Pos, - ProgramEnvironment, - RawAttrType, - RawDate, - RawDateType, - RawDecimalType, - RawInt, - RawIntType, - RawIterableType, - RawLongType, - RawNull, - RawRecordType, - RawString, - RawStringType, - RawValue, - TypeCompletion, - ValidateResponse -} -import org.testcontainers.utility.DockerImageName +import com.rawlabs.compiler._ +import com.rawlabs.protocol.raw.{Type, Value} import com.rawlabs.utils.core._ +import org.scalatest.EitherValues +import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper +import org.testcontainers.utility.DockerImageName import java.io.ByteArrayOutputStream import java.sql.DriverManager import java.time.LocalDate +import scala.collection.JavaConverters._ import scala.io.Source class TestSqlCompilerServiceAirports extends RawTestSuite + with EitherValues with ForAllTestContainer with SettingsTestContext with TrainingWheelsContext { @@ -102,25 +79,14 @@ class TestSqlCompilerServiceAirports super.afterAll() } - private def asJson(params: Map[String, RawValue] = Map.empty, scopes: Set[String] = Set.empty): ProgramEnvironment = { + private def mkEnv(params: Map[String, RawValue] = Map.empty, scopes: Set[String] = Set.empty): ProgramEnvironment = { ProgramEnvironment( user, if (params.isEmpty) None else Some(params.toArray), scopes, Map.empty, Map.empty, - Map("output-format" -> "json"), - jdbcUrl = Some(jdbcUrl) - ) - } - private def asCsv(params: Map[String, RawValue] = Map.empty, scopes: Set[String] = Set.empty): ProgramEnvironment = { - ProgramEnvironment( - user, - if (params.isEmpty) None else Some(params.toArray), - scopes, - Map.empty, Map.empty, - Map("output-format" -> "csv"), jdbcUrl = Some(jdbcUrl) ) } @@ -142,131 +108,189 @@ class TestSqlCompilerServiceAirports | ARRAY['{"a": 2}'::jsonb, '{"b": 3}'::jsonb, '{"c": 4}'::jsonb, '{"d": 5}'::jsonb] AS jsonb_array, | ARRAY['"a" => "2", "b" => "3"'::hstore, '"c" => "4", "d" => "5"'::hstore] AS hstore_array, | ARRAY['apple', 'banana', 'cherry'] AS text_array;""".stripMargin) { t => - val v = compilerService.validate(t.q, asJson()) + val v = compilerService.validate(t.q, mkEnv()) assert(v.messages.isEmpty) - val GetProgramDescriptionSuccess(description) = compilerService.getProgramDescription(t.q, asJson()) + val Right(description) = compilerService.getProgramDescription(t.q, mkEnv()) val Some(main) = description.maybeRunnable assert(main.params.contains(Vector.empty)) - val baos = new ByteArrayOutputStream() - assert( - compilerService.execute( - t.q, - asJson(), - None, - baos - ) == ExecutionSuccess(true) - ) - assert( - baos.toString() == - """[ - | { - | "interval": "P1D", - | "short_array": [ - | 1, - | 2, - | 3, - | 4 - | ], - | "integer_array": [ - | 1, - | 2, - | 3, - | 4 - | ], - | "float_array": [ - | 1.1, - | 2.2, - | 3.3, - | 4.4 - | ], - | "double_array": [ - | 1.1, - | 2.2, - | 3.3, - | 4.4 - | ], - | "decimal_array": [ - | 1.1, - | 2.2, - | 3.3, - | 4.4 - | ], - | "boolean_array": [ - | true, - | false, - | true, - | false - | ], - | "date_array": [ - | "2021-01-01", - | "2021-01-02", - | "2021-01-03", - | "2021-01-04" - | ], - | "time_array": [ - | "12:00:00.000", - | "13:00:00.000", - | "14:00:00.000", - | "15:00:00.000" - | ], - | "timestamp_array": [ - | "2021-01-01T12:00:00.000", - | "2021-01-02T13:00:00.000", - | "2021-01-03T14:00:00.000", - | "2021-01-04T15:00:00.000" - | ], - | "interval_array": [ - | "P1D", - | "P2D", - | "P3D", - | "P4D" - | ], - | "json_array": [ - | { - | "a": 2.0 - | }, - | { - | "b": 3.0 - | }, - | { - | "c": 4.0 - | }, - | { - | "d": 5.0 - | } - | ], - | "jsonb_array": [ - | { - | "a": 2.0 - | }, - | { - | "b": 3.0 - | }, - | { - | "c": 4.0 - | }, - | { - | "d": 5.0 - | } - | ], - | "hstore_array": [ - | { - | "a": "2", - | "b": "3" - | }, - | { - | "c": "4", - | "d": "5" - | } - | ], - | "text_array": [ - | "apple", - | "banana", - | "cherry" - | ] - | } - |]""".stripMargin.replaceAll("\\s+", "") - ) + val EvalSuccess.IteratorValue(tipe, it) = compilerService.eval(t.q, mkEnv(), None).value + assert(tipe.hasRecord) + val colTypes = tipe.getRecord.getAttsList + val colValues = it.next().getRecord.getFieldsList + it.hasNext shouldBe false + colTypes.size shouldBe 15 + colValues.size shouldBe colTypes.size + // Col 0 + colTypes.get(0).getIdn shouldBe "interval" + colTypes.get(0).getTipe.hasInterval shouldBe true + val intervalCol = colValues.get(0).getValue.getInterval + intervalCol.getYears shouldBe 0 + intervalCol.getMonths shouldBe 0 + intervalCol.getDays shouldBe 1 + intervalCol.getHours shouldBe 0 + intervalCol.getMinutes shouldBe 0 + intervalCol.getSeconds shouldBe 0 + intervalCol.getMillis shouldBe 0 + // Col 1 + colTypes.get(1).getIdn shouldBe "short_array" + colTypes.get(1).getTipe.hasList shouldBe true + colTypes.get(1).getTipe.getList.getInnerType.hasShort shouldBe true + val shortArrayCol = colValues.get(1).getValue.getList.getValuesList.asScala.map(v => v.getShort.getV) + shortArrayCol shouldBe Seq(1, 2, 3, 4) + // Col 2 + colTypes.get(2).getIdn shouldBe "integer_array" + colTypes.get(2).getTipe.hasList shouldBe true + colTypes.get(2).getTipe.getList.getInnerType.hasInt shouldBe true + val integerArrayCol = colValues.get(2).getValue.getList.getValuesList.asScala.map(v => v.getInt.getV) + integerArrayCol shouldBe Seq(1, 2, 3, 4) + // Col 3 + colTypes.get(3).getIdn shouldBe "float_array" + colTypes.get(3).getTipe.hasList shouldBe true + colTypes.get(3).getTipe.getList.getInnerType.hasFloat shouldBe true + val floatArrayCol = colValues.get(3).getValue.getList.getValuesList.asScala.map(v => v.getFloat.getV) + floatArrayCol shouldBe Seq(1.1f, 2.2f, 3.3f, 4.4f) + // Col 4 + colTypes.get(4).getIdn shouldBe "double_array" + colTypes.get(4).getTipe.hasList shouldBe true + colTypes.get(4).getTipe.getList.getInnerType.hasDouble shouldBe true + val doubleArrayCol = colValues.get(4).getValue.getList.getValuesList.asScala.map(v => v.getDouble.getV) + doubleArrayCol shouldBe Seq(1.1, 2.2, 3.3, 4.4) + // Col 5 + colTypes.get(5).getIdn shouldBe "decimal_array" + colTypes.get(5).getTipe.hasList shouldBe true + colTypes.get(5).getTipe.getList.getInnerType.hasDecimal shouldBe true + val decimalArrayCol = colValues.get(5).getValue.getList.getValuesList.asScala.map(v => v.getDecimal.getV) + decimalArrayCol shouldBe Seq("1.1", "2.2", "3.3", "4.4") + // Col 6 + colTypes.get(6).getIdn shouldBe "boolean_array" + colTypes.get(6).getTipe.hasList shouldBe true + colTypes.get(6).getTipe.getList.getInnerType.hasBool shouldBe true + val booleanArrayCol = colValues.get(6).getValue.getList.getValuesList.asScala.map(v => v.getBool.getV) + booleanArrayCol shouldBe Seq(true, false, true, false) + // Col 7 + colTypes.get(7).getIdn shouldBe "date_array" + colTypes.get(7).getTipe.hasList shouldBe true + colTypes.get(7).getTipe.getList.getInnerType.hasDate shouldBe true + val dateArrayCol = colValues + .get(7) + .getValue + .getList + .getValuesList + .asScala + .map(v => v.getDate) + .map(d => (d.getYear, d.getMonth, d.getDay)) + dateArrayCol shouldBe Seq( + (2021, 1, 1), + (2021, 1, 2), + (2021, 1, 3), + (2021, 1, 4) + ) + // Col 8 + colTypes.get(8).getIdn shouldBe "time_array" + colTypes.get(8).getTipe.hasList shouldBe true + colTypes.get(8).getTipe.getList.getInnerType.hasTime shouldBe true + val timeArrayCol = colValues + .get(8) + .getValue + .getList + .getValuesList + .asScala + .map(v => v.getTime) + .map(v => (v.getHour, v.getMinute, v.getSecond, v.getNano)) + timeArrayCol shouldBe Seq( + (12, 0, 0, 0), + (13, 0, 0, 0), + (14, 0, 0, 0), + (15, 0, 0, 0) + ) + // Col 9 + colTypes.get(9).getIdn shouldBe "timestamp_array" + colTypes.get(9).getTipe.hasList shouldBe true + colTypes.get(9).getTipe.getList.getInnerType.hasTimestamp shouldBe true + val timestampArrayCol = colValues + .get(9) + .getValue + .getList + .getValuesList + .asScala + .map(v => v.getTimestamp) + .map(v => (v.getYear, v.getMonth, v.getDay, v.getHour, v.getMinute, v.getSecond, v.getNano)) + timestampArrayCol shouldBe Seq( + (2021, 1, 1, 12, 0, 0, 0), + (2021, 1, 2, 13, 0, 0, 0), + (2021, 1, 3, 14, 0, 0, 0), + (2021, 1, 4, 15, 0, 0, 0) + ) + // Col 10 + colTypes.get(10).getIdn shouldBe "interval_array" + colTypes.get(10).getTipe.hasList shouldBe true + colTypes.get(10).getTipe.getList.getInnerType.hasInterval shouldBe true + val intervalArrayCol = colValues.get(10).getValue.getList.getValuesList.asScala.map(v => v.getInterval) + intervalArrayCol.map(i => + (i.getYears, i.getMonths, i.getDays, i.getHours, i.getMinutes, i.getSeconds, i.getMillis) + ) shouldBe Seq( + (0, 0, 1, 0, 0, 0, 0), + (0, 0, 2, 0, 0, 0, 0), + (0, 0, 3, 0, 0, 0, 0), + (0, 0, 4, 0, 0, 0, 0) + ) + // Col 11 + colTypes.get(11).getIdn shouldBe "json_array" + colTypes.get(11).getTipe.hasList shouldBe true + colTypes.get(11).getTipe.getList.getInnerType.hasAny shouldBe true + val jsonArrayCol = colValues + .get(11) + .getValue + .getList + .getValuesList + .asScala + .map(v => v.getRecord) + .map(r => r.getFieldsList.asScala.map(f => f.getName -> f.getValue.getLong.getV).toMap) + jsonArrayCol shouldBe Seq( + Map("a" -> 2L), + Map("b" -> 3L), + Map("c" -> 4L), + Map("d" -> 5L) + ) + // Col 12 + colTypes.get(12).getIdn shouldBe "jsonb_array" + colTypes.get(12).getTipe.hasList shouldBe true + colTypes.get(12).getTipe.getList.getInnerType.hasAny shouldBe true + val jsonbArrayCol = colValues + .get(12) + .getValue + .getList + .getValuesList + .asScala + .map(v => v.getRecord) + .map(r => r.getFieldsList.asScala.map(f => f.getName -> f.getValue.getLong.getV).toMap) + jsonbArrayCol shouldBe Seq( + Map("a" -> 2L), + Map("b" -> 3L), + Map("c" -> 4L), + Map("d" -> 5L) + ) + // Col 13 + colTypes.get(13).getIdn shouldBe "hstore_array" + colTypes.get(13).getTipe.hasList shouldBe true + colTypes.get(13).getTipe.getList.getInnerType.hasAny shouldBe true + val hstoreArrayCol = colValues + .get(13) + .getValue + .getList + .getValuesList + .asScala + .map(v => v.getRecord.getFieldsList.asScala.map(f => f.getName -> f.getValue.getString.getV).toMap) + hstoreArrayCol shouldBe Seq( + Map("a" -> "2", "b" -> "3"), + Map("c" -> "4", "d" -> "5") + ) + // Col 14 + colTypes.get(14).getIdn shouldBe "text_array" + colTypes.get(14).getTipe.hasList shouldBe true + colTypes.get(14).getTipe.getList.getInnerType.hasString shouldBe true + val textArrayCol = colValues.get(14).getValue.getList.getValuesList.asScala.map(v => v.getString.getV) + textArrayCol shouldBe Seq("apple", "banana", "cherry") } // To be sure our offset checks aren't fooled by internal postgres parameters called $1, $2, ..., $10 (with several digits) @@ -287,7 +311,7 @@ class TestSqlCompilerServiceAirports |AND :n > 0 |AND :n > 0 |AND :n > 0""".stripMargin) { t => - val GetProgramDescriptionFailure(errors) = compilerService.getProgramDescription(t.q, asJson()) + val Left(errors) = compilerService.getProgramDescription(t.q, mkEnv()) assert(errors.size == 1) val error = errors.head assert(error.message.contains("syntax error at or near \",\"")) @@ -301,7 +325,7 @@ class TestSqlCompilerServiceAirports |WHERE :city::json IS NULL | OR :city::integer != 3 | OR :city::xml IS NULL""".stripMargin) { t => - val v = compilerService.validate(t.q, asJson()) + val v = compilerService.validate(t.q, mkEnv()) assert(v.messages.size == 2) assert(v.messages(0).positions(0).begin.line == 2) // first error is about json (one position, the :city::json) assert(v.messages(1).positions(0).begin.line == 4) // second error is about xml (one position, the :city::xml) @@ -310,7 +334,7 @@ class TestSqlCompilerServiceAirports test("""-- @type v double precisionw |SELECT :v FROM example.airports where city = :city""".stripMargin) { t => - val hover = compilerService.hover(t.q, asJson(), Pos(2, 48)) + val hover = compilerService.hover(t.q, mkEnv(), Pos(2, 48)) // the typo in type declaration doesn't block hover info about a correct one assert(hover.completion.contains(TypeCompletion("city", "varchar"))) @@ -319,7 +343,7 @@ class TestSqlCompilerServiceAirports // Quoted value test("""select * from exam""".stripMargin) { t => - val completion = compilerService.wordAutoComplete(t.q, asJson(), "c", Pos(1, 19)) + val completion = compilerService.wordAutoComplete(t.q, mkEnv(), "c", Pos(1, 19)) assert( completion.completions.toSet === Set(LetBindCompletion("example", "schema")) ) @@ -327,7 +351,7 @@ class TestSqlCompilerServiceAirports // Quoted value ignore("""do something to see if a schema has the same name as a column and it still works""") { t => - val completion = compilerService.wordAutoComplete(t.q, asJson(), "c", Pos(1, 19)) // right after 'm' + val completion = compilerService.wordAutoComplete(t.q, mkEnv(), "c", Pos(1, 19)) // right after 'm' assert( completion.completions.toSet === Set(LetBindCompletion("example", "schema")) ) @@ -335,7 +359,7 @@ class TestSqlCompilerServiceAirports // Quoted value test("""select * from "example"."airp""".stripMargin) { t => - val completion = compilerService.wordAutoComplete(t.q, asJson(), "c", Pos(1, 30)) + val completion = compilerService.wordAutoComplete(t.q, mkEnv(), "c", Pos(1, 30)) assert( completion.completions.toSet === Set(LetBindCompletion("airports", "table")) ) @@ -344,7 +368,7 @@ class TestSqlCompilerServiceAirports test("""SELECT * FROM example.airports |WHERE airports.c |""".stripMargin) { t => - val completion = compilerService.wordAutoComplete(t.q, asJson(), "c", Pos(2, 17)) + val completion = compilerService.wordAutoComplete(t.q, mkEnv(), "c", Pos(2, 17)) assert( completion.completions.toSet === Set( LetBindCompletion("city", "character varying"), @@ -357,7 +381,7 @@ class TestSqlCompilerServiceAirports test("""SELECT * FROM example.airports |WHERE airports."c |""".stripMargin) { t => - val completion = compilerService.wordAutoComplete(t.q, asJson(), "c", Pos(2, 18)) + val completion = compilerService.wordAutoComplete(t.q, mkEnv(), "c", Pos(2, 18)) assert( completion.completions.toSet === Set( LetBindCompletion("city", "character varying"), @@ -371,9 +395,9 @@ class TestSqlCompilerServiceAirports |WHERE ai. |AND airports. |""".stripMargin) { t => - val hover = compilerService.hover(t.q, asJson(), Pos(1, 16)) + val hover = compilerService.hover(t.q, mkEnv(), Pos(1, 16)) assert(hover.completion.contains(TypeCompletion("example", "schema"))) - val completion = compilerService.wordAutoComplete(t.q, asJson(), "", Pos(2, 9)) + val completion = compilerService.wordAutoComplete(t.q, mkEnv(), "", Pos(2, 9)) assert( completion.completions.toSet === Set( LetBindCompletion("airport_id", "integer"), @@ -382,7 +406,7 @@ class TestSqlCompilerServiceAirports ) // The calls to the dotAutoComplete have to point to the place before the dot - val dotCompletion = compilerService.dotAutoComplete(t.q, asJson(), Pos(3, 15)) + val dotCompletion = compilerService.dotAutoComplete(t.q, mkEnv(), Pos(3, 15)) assert( dotCompletion.completions.toSet === airportColumns ) @@ -391,9 +415,9 @@ class TestSqlCompilerServiceAirports test("""SELECT * FROM example.airports |WHERE ai. |""".stripMargin) { t => - val hover = compilerService.hover(t.q, asJson(), Pos(1, 16)) + val hover = compilerService.hover(t.q, mkEnv(), Pos(1, 16)) assert(hover.completion.contains(TypeCompletion("example", "schema"))) - val completion = compilerService.wordAutoComplete(t.q, asJson(), "", Pos(2, 9)) + val completion = compilerService.wordAutoComplete(t.q, mkEnv(), "", Pos(2, 9)) assert( completion.completions.toSet === Set( LetBindCompletion("airport_id", "integer"), @@ -406,10 +430,10 @@ class TestSqlCompilerServiceAirports test("""SELECT * FROM example.airports |WHERE airports. |""".stripMargin) { t => - val hover = compilerService.hover(t.q, asJson(), Pos(1, 16)) + val hover = compilerService.hover(t.q, mkEnv(), Pos(1, 16)) assert(hover.completion.contains(TypeCompletion("example", "schema"))) // The calls to the dotAutoComplete have to point to the place before the dot - val dotCompletion = compilerService.dotAutoComplete(t.q, asJson(), Pos(2, 15)) + val dotCompletion = compilerService.dotAutoComplete(t.q, mkEnv(), Pos(2, 15)) assert( dotCompletion.completions.toSet === airportColumns ) @@ -419,9 +443,9 @@ class TestSqlCompilerServiceAirports test("""SELECT * FROM example.airports |WHERE example.airports. |""".stripMargin) { t => - val hover = compilerService.hover(t.q, asJson(), Pos(1, 16)) + val hover = compilerService.hover(t.q, mkEnv(), Pos(1, 16)) assert(hover.completion.contains(TypeCompletion("example", "schema"))) - val completion = compilerService.wordAutoComplete(t.q, asJson(), "", Pos(2, 17)) + val completion = compilerService.wordAutoComplete(t.q, mkEnv(), "", Pos(2, 17)) assert( completion.completions.toSet === Set( LetBindCompletion("airports", "table") @@ -429,7 +453,7 @@ class TestSqlCompilerServiceAirports ) // The calls to the dotAutoComplete have to point to the place before the dot - val dotCompletion = compilerService.dotAutoComplete(t.q, asJson(), Pos(2, 23)) + val dotCompletion = compilerService.dotAutoComplete(t.q, mkEnv(), Pos(2, 23)) assert( dotCompletion.completions.toSet === airportColumns ) @@ -439,9 +463,9 @@ class TestSqlCompilerServiceAirports test("""SELECT * FROM "example"."airports" |WHERE "ai |""".stripMargin) { t => - val hover = compilerService.hover(t.q, asJson(), Pos(1, 17)) + val hover = compilerService.hover(t.q, mkEnv(), Pos(1, 17)) assert(hover.completion.contains(TypeCompletion("example", "schema"))) - val completion = compilerService.wordAutoComplete(t.q, asJson(), "", Pos(2, 10)) + val completion = compilerService.wordAutoComplete(t.q, mkEnv(), "", Pos(2, 10)) assert( completion.completions.toSet === Set( LetBindCompletion("airport_id", "integer"), @@ -454,10 +478,10 @@ class TestSqlCompilerServiceAirports test("""SELECT * FROM "example"."airports" |WHERE "airports". |""".stripMargin) { t => - val hover = compilerService.hover(t.q, asJson(), Pos(1, 17)) + val hover = compilerService.hover(t.q, mkEnv(), Pos(1, 17)) assert(hover.completion.contains(TypeCompletion("example", "schema"))) // The calls to the dotAutoComplete have to point to the place before the dot - val dotCompletion = compilerService.dotAutoComplete(t.q, asJson(), Pos(2, 17)) + val dotCompletion = compilerService.dotAutoComplete(t.q, mkEnv(), Pos(2, 17)) assert( dotCompletion.completions.toSet === Set( LetBindCompletion("icao", "character varying"), @@ -480,10 +504,10 @@ class TestSqlCompilerServiceAirports test("""SELECT * FROM EXAMPLE.AIRPORTS |WHERE AIRPORTS. |""".stripMargin) { t => - val hover = compilerService.hover(t.q, asJson(), Pos(1, 16)) + val hover = compilerService.hover(t.q, mkEnv(), Pos(1, 16)) assert(hover.completion.contains(TypeCompletion("example", "schema"))) // The calls to the dotAutoComplete have to point to the place before the dot - val dotCompletion = compilerService.dotAutoComplete(t.q, asJson(), Pos(2, 15)) + val dotCompletion = compilerService.dotAutoComplete(t.q, mkEnv(), Pos(2, 15)) assert( dotCompletion.completions.toSet === Set( LetBindCompletion("icao", "character varying"), @@ -506,9 +530,9 @@ class TestSqlCompilerServiceAirports test("""SELECT * FROM EXAMPLE.AIRPORTS |WHERE AI. |""".stripMargin) { t => - val hover = compilerService.hover(t.q, asJson(), Pos(1, 16)) + val hover = compilerService.hover(t.q, mkEnv(), Pos(1, 16)) assert(hover.completion.contains(TypeCompletion("example", "schema"))) - val completion = compilerService.wordAutoComplete(t.q, asJson(), "", Pos(2, 9)) + val completion = compilerService.wordAutoComplete(t.q, mkEnv(), "", Pos(2, 9)) assert( completion.completions.toSet === Set( LetBindCompletion("airport_id", "integer"), @@ -522,13 +546,13 @@ class TestSqlCompilerServiceAirports test("""SELECT * FROM "EXAMPLE"."AIRPORTS" |WHERE "AIRPORTS". |""".stripMargin) { t => - val hover = compilerService.hover(t.q, asJson(), Pos(1, 16)) + val hover = compilerService.hover(t.q, mkEnv(), Pos(1, 16)) assert(hover.completion.isEmpty) - val completion = compilerService.wordAutoComplete(t.q, asJson(), "", Pos(2, 9)) + val completion = compilerService.wordAutoComplete(t.q, mkEnv(), "", Pos(2, 9)) assert(completion.completions.isEmpty) // The calls to the dotAutoComplete have to point to the place before the dot - val dotCompletion = compilerService.dotAutoComplete(t.q, asJson(), Pos(2, 17)) + val dotCompletion = compilerService.dotAutoComplete(t.q, mkEnv(), Pos(2, 17)) assert(dotCompletion.completions.isEmpty) } @@ -536,12 +560,12 @@ class TestSqlCompilerServiceAirports |WHERE airports.city = 'Porto' |AND airports.country = 'Portugal' |""".stripMargin) { t => - val hover = compilerService.hover(t.q, asJson(), Pos(1, 16)) + val hover = compilerService.hover(t.q, mkEnv(), Pos(1, 16)) assert(hover.completion.contains(TypeCompletion("example", "schema"))) - val completion = compilerService.wordAutoComplete(t.q, asJson(), "exa", Pos(1, 18)) + val completion = compilerService.wordAutoComplete(t.q, mkEnv(), "exa", Pos(1, 18)) assert(completion.completions sameElements Array(LetBindCompletion("example", "schema"))) // The calls to the dotAutoComplete have to point to the place before the dot - val dotCompletion = compilerService.dotAutoComplete(t.q, asJson(), Pos(1, 22)) + val dotCompletion = compilerService.dotAutoComplete(t.q, mkEnv(), Pos(1, 22)) assert( dotCompletion.completions.toSet === Set( LetBindCompletion("airports", "table"), @@ -552,49 +576,63 @@ class TestSqlCompilerServiceAirports } test("SELECT * FROM example.airports WHERE city = 'Braganca'") { t => - val v = compilerService.validate(t.q, asJson()) + val v = compilerService.validate(t.q, mkEnv()) assert(v.messages.isEmpty) - val GetProgramDescriptionSuccess(description) = compilerService.getProgramDescription(t.q, asJson()) + val Right(description) = compilerService.getProgramDescription(t.q, mkEnv()) val Some(main) = description.maybeRunnable assert( main.outType.get == airportType ) assert(main.params.contains(Vector.empty)) - val baos = new ByteArrayOutputStream() - assert( - compilerService.execute( - t.q, - asJson(), - None, - baos - ) == ExecutionSuccess(true) - ) - assert( - baos.toString() == - """[ - | { - | "airport_id": 1618, - | "name": "Braganca", - | "city": "Braganca", - | "country": "Portugal", - | "iata_faa": "BGC", - | "icao": "LPBG", - | "latitude": 41.857800, - | "longitude": -6.707125, - | "altitude": 2241.000, - | "timezone": 0, - | "dst": "E", - | "tz": "Europe/Lisbon" - | } - |]""".stripMargin.replaceAll("\\s+", "") - ) + val EvalSuccess.IteratorValue(tipe, it) = compilerService.eval(t.q, mkEnv(), None).value + val row = it.next() + it.hasNext shouldBe false + it.close() + val columns = row.getRecord.getFieldsList + val colTypes = tipe.getRecord.getAttsList.asScala.map(_.getTipe) + colTypes(0).hasInt shouldBe true + columns.get(0).getName shouldBe "airport_id" + columns.get(0).getValue.getInt.getV shouldBe 1618 + colTypes(1).hasString shouldBe true + columns.get(1).getName shouldBe "name" + columns.get(1).getValue.getString.getV shouldBe "Braganca" + colTypes(2).hasString shouldBe true + columns.get(2).getName shouldBe "city" + columns.get(2).getValue.getString.getV shouldBe "Braganca" + colTypes(3).hasString shouldBe true + columns.get(3).getName shouldBe "country" + columns.get(3).getValue.getString.getV shouldBe "Portugal" + colTypes(4).hasString shouldBe true + columns.get(4).getName shouldBe "iata_faa" + columns.get(4).getValue.getString.getV shouldBe "BGC" + colTypes(5).hasString shouldBe true + columns.get(5).getName shouldBe "icao" + columns.get(5).getValue.getString.getV shouldBe "LPBG" + colTypes(6).hasDecimal shouldBe true + columns.get(6).getName shouldBe "latitude" + columns.get(6).getValue.getDecimal.getV shouldBe "41.857800" + colTypes(7).hasDecimal shouldBe true + columns.get(7).getName shouldBe "longitude" + columns.get(7).getValue.getDecimal.getV shouldBe "-6.707125" + colTypes(8).hasDecimal shouldBe true + columns.get(8).getName shouldBe "altitude" + columns.get(8).getValue.getDecimal.getV shouldBe "2241.000" + colTypes(9).hasInt shouldBe true + columns.get(9).getName shouldBe "timezone" + columns.get(9).getValue.getDouble.getV shouldBe 0 + colTypes(10).hasString shouldBe true + columns.get(10).getName shouldBe "dst" + columns.get(10).getValue.getString.getV shouldBe "E" + colTypes(11).hasString shouldBe true + columns.get(11).getName shouldBe "tz" + columns.get(11).getValue.getString.getV shouldBe "Europe/Lisbon" } test("SELECT * FROM example.airports WHERE city = :city") { t => - val environment = asCsv(Map("city" -> RawString("Braganca"))) + val environment = mkEnv(Map("city" -> RawString("Braganca"))) val v = compilerService.validate(t.q, environment) assert(v.messages.isEmpty) - val GetProgramDescriptionSuccess(description) = compilerService.getProgramDescription(t.q, environment) + val Right(description) = compilerService.getProgramDescription(t.q, environment) assert(description.decls.isEmpty) val Some(main) = description.maybeRunnable assert( @@ -605,29 +643,56 @@ class TestSqlCompilerServiceAirports assert(param.idn == "city") assert(param.tipe.get == RawStringType(true, false)) assert(param.defaultValue.isEmpty) - val baos = new ByteArrayOutputStream() - assert( - compilerService.execute( - t.q, - environment, - None, - baos - ) == ExecutionSuccess(true) - ) - assert( - baos.toString() == - """airport_id,name,city,country,iata_faa,icao,latitude,longitude,altitude,timezone,dst,tz - |1618,Braganca,Braganca,Portugal,BGC,LPBG,41.857800,-6.707125,2241.000,0,E,Europe/Lisbon - |""".stripMargin - ) + val EvalSuccess.IteratorValue(tipe, it) = compilerService.eval(t.q, environment, None).value + val row = it.next() + it.hasNext shouldBe false + it.close() + val columns = row.getRecord.getFieldsList + val colTypes = tipe.getRecord.getAttsList.asScala.map(_.getTipe) + colTypes(0).hasInt shouldBe true + columns.get(0).getName shouldBe "airport_id" + columns.get(0).getValue.getInt.getV shouldBe 1618 + colTypes(1).hasString shouldBe true + columns.get(1).getName shouldBe "name" + columns.get(1).getValue.getString.getV shouldBe "Braganca" + colTypes(2).hasString shouldBe true + columns.get(2).getName shouldBe "city" + columns.get(2).getValue.getString.getV shouldBe "Braganca" + colTypes(3).hasString shouldBe true + columns.get(3).getName shouldBe "country" + columns.get(3).getValue.getString.getV shouldBe "Portugal" + colTypes(4).hasString shouldBe true + columns.get(4).getName shouldBe "iata_faa" + columns.get(4).getValue.getString.getV shouldBe "BGC" + colTypes(5).hasString shouldBe true + columns.get(5).getName shouldBe "icao" + columns.get(5).getValue.getString.getV shouldBe "LPBG" + colTypes(6).hasDecimal shouldBe true + columns.get(6).getName shouldBe "latitude" + columns.get(6).getValue.getDecimal.getV shouldBe "41.857800" + colTypes(7).hasDecimal shouldBe true + columns.get(7).getName shouldBe "longitude" + columns.get(7).getValue.getDecimal.getV shouldBe "-6.707125" + colTypes(8).hasDecimal shouldBe true + columns.get(8).getName shouldBe "altitude" + columns.get(8).getValue.getDecimal.getV shouldBe "2241.000" + colTypes(9).hasInt shouldBe true + columns.get(9).getName shouldBe "timezone" + columns.get(9).getValue.getDouble.getV shouldBe 0 + colTypes(10).hasString shouldBe true + columns.get(10).getName shouldBe "dst" + columns.get(10).getValue.getString.getV shouldBe "E" + colTypes(11).hasString shouldBe true + columns.get(11).getName shouldBe "tz" + columns.get(11).getValue.getString.getV shouldBe "Europe/Lisbon" } test("""-- a query with no default, called without the parameter |SELECT * FROM example.airports WHERE city = :city""".stripMargin) { t => - val environment = asJson(Map("city" -> RawNull())) + val environment = mkEnv(Map("city" -> RawNull())) val v = compilerService.validate(t.q, environment) assert(v.messages.isEmpty) - val GetProgramDescriptionSuccess(description) = compilerService.getProgramDescription(t.q, environment) + val Right(description) = compilerService.getProgramDescription(t.q, environment) val Some(main) = description.maybeRunnable assert( main.outType.get == airportType @@ -637,25 +702,31 @@ class TestSqlCompilerServiceAirports assert(param.idn == "city") assert(param.tipe.get == RawStringType(true, false)) assert(param.defaultValue.isEmpty) - val baos = new ByteArrayOutputStream() - assert( - compilerService.execute( - t.q, - environment, - None, - baos - ) == ExecutionSuccess(true) - ) - assert(baos.toString() == "[]") + val EvalSuccess.IteratorValue(tipe, it) = compilerService.eval(t.q, environment, None).value + it.hasNext shouldBe false + it.close() + val colTypes = tipe.getRecord.getAttsList.asScala.map(_.getTipe) + colTypes(0).hasInt shouldBe true + colTypes(1).hasString shouldBe true + colTypes(2).hasString shouldBe true + colTypes(3).hasString shouldBe true + colTypes(4).hasString shouldBe true + colTypes(5).hasString shouldBe true + colTypes(6).hasDecimal shouldBe true + colTypes(7).hasDecimal shouldBe true + colTypes(8).hasDecimal shouldBe true + colTypes(9).hasInt shouldBe true + colTypes(10).hasString shouldBe true + colTypes(11).hasString shouldBe true } test("""-- a query with a default, called without the parameter |-- @default city 'Athens' |SELECT COUNT(*) AS n FROM example.airports WHERE city = :city""".stripMargin) { t => - val environment = asJson() + val environment = mkEnv() val v = compilerService.validate(t.q, environment) assert(v.messages.isEmpty) - val GetProgramDescriptionSuccess(description) = compilerService.getProgramDescription(t.q, environment) + val Right(description) = compilerService.getProgramDescription(t.q, environment) assert(description.decls.isEmpty) val Some(main) = description.maybeRunnable assert( @@ -672,23 +743,24 @@ class TestSqlCompilerServiceAirports assert(param.idn == "city") assert(param.tipe.contains(RawStringType(true, false))) assert(param.defaultValue.contains(RawString("Athens"))) - val baos = new ByteArrayOutputStream() - assert( - compilerService.execute( - t.q, - environment, - None, - baos - ) == ExecutionSuccess(true) - ) - assert(baos.toString() == """[{"n":6}]""") + val EvalSuccess.IteratorValue(tipe, it) = compilerService.eval(t.q, environment, None).value + val row = it.next() + it.hasNext shouldBe false + it.close() + val columns = row.getRecord.getFieldsList + columns.size shouldBe 1 + val colTypes = tipe.getRecord.getAttsList.asScala.map(_.getTipe) + colTypes.size shouldBe 1 + colTypes(0).hasLong shouldBe true + columns.get(0).getName shouldBe "n" + columns.get(0).getValue.getLong.getV shouldBe 6 } test("""-- @type age intger |-- @default age 'tralala |-- @param whatever an unknown parameter |SELECT COUNT(*) FROM example.airports WHERE :city = city""".stripMargin) { t => - val environment = asJson(Map.empty) + val environment = mkEnv(Map.empty) val v = compilerService.validate(t.q, environment) val expectedErrors = List( 1 -> "unsupported type intger", @@ -709,7 +781,7 @@ class TestSqlCompilerServiceAirports |-- @default age 'tralala' |-- @param whatever an unknown parameter |SELECT COUNT(*) FROM example.airports WHERE :age = city""".stripMargin) { t => - val environment = asJson(Map.empty) + val environment = mkEnv(Map.empty) val v = compilerService.validate(t.q, environment) val expectedErrors = List( @@ -728,10 +800,10 @@ class TestSqlCompilerServiceAirports } test("SELECT * FROM example.airports WHERE city = :param and airport_id = :param") { t => - val environment = asCsv(Map("param" -> RawString("Braganca"))) + val environment = mkEnv(Map("param" -> RawString("Braganca"))) val v = compilerService.validate(t.q, environment) assert(v.messages.nonEmpty) - val GetProgramDescriptionFailure(errors) = compilerService.getProgramDescription(t.q, environment) + val Left(errors) = compilerService.getProgramDescription(t.q, environment) assert(errors.size === 1) assert(errors.head.message === "a parameter cannot be both varchar and integer") assert(errors.head.positions(0).begin === ErrorPosition(1, 45)) @@ -743,74 +815,77 @@ class TestSqlCompilerServiceAirports test("""SELECT * |FROM""".stripMargin) { t => val expectedError = "syntax error at end of input" - val validation = compilerService.validate(t.q, asJson()) + val validation = compilerService.validate(t.q, mkEnv()) assert(validation.messages.exists(_.message.contains(expectedError))) - val GetProgramDescriptionFailure(descriptionErrors) = compilerService.getProgramDescription(t.q, asJson()) + val Left(descriptionErrors) = compilerService.getProgramDescription(t.q, mkEnv()) assert(descriptionErrors.exists(_.message.contains(expectedError))) - val baos = new ByteArrayOutputStream() - val ExecutionValidationFailure(executionErrors) = compilerService.execute(t.q, asJson(), None, baos) - assert(executionErrors.exists(_.message.contains(expectedError))) + val Left(ExecutionError.ValidationError(errorMessages)) = compilerService.eval(t.q, mkEnv(), None) + assert(errorMessages.exists(_.message.contains(expectedError))) } test("SELECT * FROM inexistent-table") { t => val expectedError = "syntax error at or near \"-\"" - val v = compilerService.validate(t.q, asJson()) + val v = compilerService.validate(t.q, mkEnv()) assert(v.messages.exists(_.message.contains(expectedError))) - val GetProgramDescriptionFailure(descriptionErrors) = compilerService.getProgramDescription(t.q, asJson()) + val Left(descriptionErrors) = compilerService.getProgramDescription(t.q, mkEnv()) assert(descriptionErrors.exists(_.message.contains(expectedError))) - val baos = new ByteArrayOutputStream() - val ExecutionValidationFailure(executionErrors) = compilerService.execute(t.q, asJson(), None, baos) - assert(executionErrors.exists(_.message.contains(expectedError))) + val Left(ExecutionError.ValidationError(errorMessages)) = compilerService.eval(t.q, mkEnv(), None) + assert(errorMessages.exists(_.message.contains(expectedError))) } test("SELECT * FROM inexistent_table") { t => val expectedErrors = Set("relation \"inexistent_table\" does not exist", "Did you forget to add credentials?") - val v = compilerService.validate(t.q, asJson()) + val v = compilerService.validate(t.q, mkEnv()) val failures = v.messages.collect { case errorMessage: ErrorMessage => errorMessage } assert(failures.exists(failure => expectedErrors.forall(failure.message.contains))) - val GetProgramDescriptionFailure(descriptionErrors) = compilerService.getProgramDescription(t.q, asJson()) + val Left(descriptionErrors) = compilerService.getProgramDescription(t.q, mkEnv()) assert(descriptionErrors.exists(error => expectedErrors.forall(error.message.contains))) - val baos = new ByteArrayOutputStream() - val ExecutionValidationFailure(executionErrors) = compilerService.execute(t.q, asJson(), None, baos) - assert(executionErrors.exists(error => expectedErrors.forall(error.message.contains))) + val Left(ExecutionError.ValidationError(errorMessages)) = compilerService.eval(t.q, mkEnv(), None) + assert(errorMessages.exists(error => expectedErrors.forall(error.message.contains))) } test("""-- @default a 1234 |-- @type a date |SELECT LENGTH(:a)""".stripMargin) { t => val expectedErrors = Set("cannot cast type integer to date", "function length(date) does not exist") - val v = compilerService.validate(t.q, asJson()) + val v = compilerService.validate(t.q, mkEnv()) val failures = v.messages.collect { case errorMessage: ErrorMessage => errorMessage } assert(failures.forall(err => expectedErrors.exists(err.message.contains))) - val GetProgramDescriptionFailure(descriptionErrors) = compilerService.getProgramDescription(t.q, asJson()) + val Left(descriptionErrors) = compilerService.getProgramDescription(t.q, mkEnv()) assert(descriptionErrors.forall(err => expectedErrors.exists(err.message.contains))) - val baos = new ByteArrayOutputStream() - val ExecutionValidationFailure(executionErrors) = compilerService.execute(t.q, asJson(), None, baos) - assert(executionErrors.forall(err => expectedErrors.exists(err.message.contains))) + val Left(ExecutionError.ValidationError(errorMessages)) = compilerService.eval(t.q, mkEnv(), None) + assert(errorMessages.forall(err => expectedErrors.exists(err.message.contains))) } test("""/* @default a 1 + 1 */ |SELECT :a + :a AS v""".stripMargin) { t => - val v = compilerService.validate(t.q, asJson()) + val v = compilerService.validate(t.q, mkEnv()) assert(v.messages.isEmpty, v.messages.mkString(",")) - val GetProgramDescriptionSuccess(description) = compilerService.getProgramDescription(t.q, asJson()) + val Right(description) = compilerService.getProgramDescription(t.q, mkEnv()) assert(!description.maybeRunnable.get.params.get.head.required) assert(description.maybeRunnable.get.params.get.head.defaultValue.contains(RawInt(2))) - val baos = new ByteArrayOutputStream() - assert(compilerService.execute(t.q, asJson(), None, baos) == ExecutionSuccess(true)) - assert(baos.toString() == """[{"v":4}]""") + val EvalSuccess.IteratorValue(tipe, it) = compilerService.eval(t.q, mkEnv(), None).value + val row = it.next() + it.hasNext shouldBe false + it.close() + val columns = row.getRecord.getFieldsList + columns.size shouldBe 1 + val colTypes = tipe.getRecord.getAttsList.asScala.map(_.getTipe) + colTypes.size shouldBe 1 + colTypes(0).hasInt shouldBe true + columns.get(0).getName shouldBe "v" + columns.get(0).getValue.getInt.getV shouldBe 4 } test("SELECT * FROM wrong.relation") { t => val expectedErrors = Set("relation \"wrong.relation\" does not exist", "Did you forget to add credentials?") - val v = compilerService.validate(t.q, asJson()) + val v = compilerService.validate(t.q, mkEnv()) val failures = v.messages.collect { case errorMessage: ErrorMessage => errorMessage } assert(failures.exists(failure => expectedErrors.forall(failure.message.contains))) - val GetProgramDescriptionFailure(descriptionErrors) = compilerService.getProgramDescription(t.q, asJson()) + val Left(descriptionErrors) = compilerService.getProgramDescription(t.q, mkEnv()) assert(descriptionErrors.exists(error => expectedErrors.forall(error.message.contains))) - val baos = new ByteArrayOutputStream() - val ExecutionValidationFailure(executionErrors) = compilerService.execute(t.q, asJson(), None, baos) - assert(executionErrors.exists(error => expectedErrors.forall(error.message.contains))) + val Left(ExecutionError.ValidationError(errorMessages)) = compilerService.eval(t.q, mkEnv(), None) + assert(errorMessages.forall(err => expectedErrors.exists(err.message.contains))) } private val airportType = RawIterableType( @@ -836,15 +911,43 @@ class TestSqlCompilerServiceAirports false ) + private def fetchOneRowResult(q: String, environment: ProgramEnvironment): Seq[(Type, Value)] = { + val EvalSuccess.IteratorValue(tipe, it) = compilerService.eval(q, environment, None).value + val row = it.next() + it.hasNext shouldBe false + it.close() + val columns = row.getRecord.getFieldsList.asScala + val atts = tipe.getRecord.getAttsList.asScala + columns.size shouldBe atts.size + atts.zip(columns).foreach { case (a, v) => v.getName shouldBe a.getIdn } + atts.zip(columns).map { case (t, v) => (t.getTipe, v.getValue) } + } + + private def fetchStatementResult(q: String, environment: ProgramEnvironment): Int = { + val EvalSuccess.ResultValue(tipe, v) = compilerService.eval(q, environment, None).value + val atts = tipe.getRecord.getAttsList.asScala + atts.size shouldBe 1 + tipe.getRecord.getAtts(0).getIdn shouldBe "update_count" + v.getRecord.getFields(0).getValue.getInt.getV + } + + private def fetchCountQueryResult(q: String, environment: ProgramEnvironment) = { + val res = fetchOneRowResult(q, environment) + res.size shouldBe 1 + res.head._1.hasLong shouldBe true + res.head._2.hasLong shouldBe true + res.head._2.getLong.getV + } + test(""" |SELECT COUNT(*) FROM example.airports |WHERE city = :name OR country = :name""".stripMargin) { t => - val withCity = asJson(Map("name" -> RawString("Braganca"))) - val withCountry = asJson(Map("name" -> RawString("Portugal"))) - val withNull = asJson(Map("name" -> RawNull())) + val withCity = mkEnv(Map("name" -> RawString("Braganca"))) + val withCountry = mkEnv(Map("name" -> RawString("Portugal"))) + val withNull = mkEnv(Map("name" -> RawNull())) val v = compilerService.validate(t.q, withCity) assert(v.messages.isEmpty) - val GetProgramDescriptionSuccess(description) = compilerService.getProgramDescription(t.q, withCity) + val Right(description) = compilerService.getProgramDescription(t.q, withCity) assert(description.decls.isEmpty) val Some(main) = description.maybeRunnable assert( @@ -859,26 +962,21 @@ class TestSqlCompilerServiceAirports assert(param.idn == "name") assert(param.tipe.get == RawStringType(true, false)) assert(param.defaultValue.isEmpty) - val baos = new ByteArrayOutputStream() - baos.reset() - assert(compilerService.execute(t.q, withCity, None, baos) == ExecutionSuccess(true)) - assert(baos.toString() == """[{"count":1}]""") - baos.reset() - assert(compilerService.execute(t.q, withNull, None, baos) == ExecutionSuccess(true)) - assert(baos.toString() == """[{"count":0}]""") - baos.reset() - assert(compilerService.execute(t.q, withCountry, None, baos) == ExecutionSuccess(true)) - assert(baos.toString() == """[{"count":39}]""") + + fetchCountQueryResult(t.q, withCity) shouldBe 1 + fetchCountQueryResult(t.q, withNull) shouldBe 0 + fetchCountQueryResult(t.q, withCountry) shouldBe 39 + } test(""" |SELECT COUNT(*) FROM example.airports |WHERE city = COALESCE(:name, 'Lyon')""".stripMargin) { t => - val withCity = asJson(Map("name" -> RawString("Braganca"))) - val withNull = asJson(Map("name" -> RawNull())) + val withCity = mkEnv(Map("name" -> RawString("Braganca"))) + val withNull = mkEnv(Map("name" -> RawNull())) val v = compilerService.validate(t.q, withCity) assert(v.messages.isEmpty) - val GetProgramDescriptionSuccess(description) = compilerService.getProgramDescription(t.q, withCity) + val Right(description) = compilerService.getProgramDescription(t.q, withCity) assert(description.decls.isEmpty) val Some(main) = description.maybeRunnable assert( @@ -893,20 +991,15 @@ class TestSqlCompilerServiceAirports assert(param.idn == "name") assert(param.tipe.get == RawStringType(true, false)) assert(param.defaultValue.isEmpty) - val baos = new ByteArrayOutputStream() - baos.reset() - assert(compilerService.execute(t.q, withCity, None, baos) == ExecutionSuccess(true)) - assert(baos.toString() == """[{"count":1}]""") - baos.reset() - assert(compilerService.execute(t.q, withNull, None, baos) == ExecutionSuccess(true)) - assert(baos.toString() == """[{"count":3}]""") + fetchCountQueryResult(t.q, withCity) shouldBe 1 + fetchCountQueryResult(t.q, withNull) shouldBe 3 } test("""-- @param s just an int |SELECT DATE '2002-01-01' - :s::int AS x -- RD-10538""".stripMargin) { t => - val v = compilerService.validate(t.q, asJson()) + val v = compilerService.validate(t.q, mkEnv()) assert(v.messages.isEmpty) - val GetProgramDescriptionSuccess(description) = compilerService.getProgramDescription(t.q, asJson()) + val Right(description) = compilerService.getProgramDescription(t.q, mkEnv()) assert(description.decls.isEmpty) val Some(main) = description.maybeRunnable assert( @@ -927,22 +1020,16 @@ class TestSqlCompilerServiceAirports Vector(ParamDescription("s", Some(RawIntType(true, false)), None, Some("just an int"), required = true)) ) ) - val baos = new ByteArrayOutputStream() - assert( - compilerService.execute( - t.q, - asJson(), - None, - baos - ) == ExecutionRuntimeFailure("no value was specified for s") - ) + val Left(ExecutionError.ValidationError(errorMessages)) = compilerService.eval(t.q, mkEnv(), None) + errorMessages.size shouldBe 1 + errorMessages.head shouldBe ErrorMessage("no value was specified for s", List(), "sqlError", List()) } test("""-- @default s CAST(null AS INTEGER) |SELECT DATE '2002-01-01' - :s AS x -- RD-10538""".stripMargin) { t => - val v = compilerService.validate(t.q, asJson()) + val v = compilerService.validate(t.q, mkEnv()) assert(v.messages.isEmpty) - val GetProgramDescriptionSuccess(description) = compilerService.getProgramDescription(t.q, asJson()) + val Right(description) = compilerService.getProgramDescription(t.q, mkEnv()) val Some(main) = description.maybeRunnable assert( main.outType.contains( @@ -964,23 +1051,17 @@ class TestSqlCompilerServiceAirports Vector(ParamDescription("s", Some(RawIntType(true, false)), Some(RawNull()), comment = None, required = false)) ) ) - val baos = new ByteArrayOutputStream() - assert( - compilerService.execute( - t.q, - asJson(), - None, - baos - ) == ExecutionSuccess(true) - ) - assert( - baos.toString() == - """[ - | { - | "x": null - | } - |]""".stripMargin.replaceAll("\\s+", "") - ) + val EvalSuccess.IteratorValue(tipe, it) = compilerService.eval(t.q, mkEnv(), None).value + val row = it.next() + it.hasNext shouldBe false + it.close() + val columns = row.getRecord.getFieldsList + columns.size shouldBe 1 + val colTypes = tipe.getRecord.getAttsList.asScala.map(_.getTipe) + colTypes.size shouldBe 1 + colTypes(0).hasDate shouldBe true + columns.get(0).getName shouldBe "x" + columns.get(0).getValue.hasNull shouldBe true } private val airportColumns = Set( @@ -1001,9 +1082,8 @@ class TestSqlCompilerServiceAirports test("""SELECT COUNT(*) FROM example.airports;""") { t => val baos = new ByteArrayOutputStream() baos.reset() - val noParam = asJson() - assert(compilerService.execute(t.q, noParam, None, baos) == ExecutionSuccess(true)) - assert(baos.toString() == """[{"count":8107}]""") + val noParam = mkEnv() + fetchCountQueryResult(t.q, noParam) shouldBe 8107 } test( // RD-10505 @@ -1013,9 +1093,8 @@ class TestSqlCompilerServiceAirports ) { t => val baos = new ByteArrayOutputStream() baos.reset() - val noParam = asJson() - assert(compilerService.execute(t.q, noParam, None, baos) == ExecutionSuccess(true)) - assert(baos.toString() == """[{"count":8107}]""") + val noParam = mkEnv() + fetchCountQueryResult(t.q, noParam) shouldBe 8107 } test( // RD-10505 @@ -1027,9 +1106,8 @@ class TestSqlCompilerServiceAirports ) { t => val baos = new ByteArrayOutputStream() baos.reset() - val noParam = asJson() - assert(compilerService.execute(t.q, noParam, None, baos) == ExecutionSuccess(true)) - assert(baos.toString() == """[{"count":8107}]""") + val noParam = mkEnv() + fetchCountQueryResult(t.q, noParam) shouldBe 8107 } // #RD-10612: hovering on a parameter name doesn't return the parameter type + fails internally @@ -1037,28 +1115,28 @@ class TestSqlCompilerServiceAirports // However, because of the function SqlCodeUtils.identifiers right now it returns an identifier with an empty string. // The state machine in that function bails out because it finds the ':' at the start of the string. test("SELECT :v > 12 AS column") { t => - val hover = compilerService.hover(t.q, asJson(), Pos(1, 10)) + val hover = compilerService.hover(t.q, mkEnv(), Pos(1, 10)) assert(hover.completion.contains(TypeCompletion("v", "integer"))) // postgres type } test("SELECT :city > 12, city FROM example.airports WHERE airport_id = :city LIMIT 2 ") { t => - val hover1 = compilerService.hover(t.q, asJson(), Pos(1, 11)) + val hover1 = compilerService.hover(t.q, mkEnv(), Pos(1, 11)) assert(hover1.completion.contains(TypeCompletion("city", "integer"))) // has to be the Postgres type - val hover2 = compilerService.hover(t.q, asJson(), Pos(1, 20)) + val hover2 = compilerService.hover(t.q, mkEnv(), Pos(1, 20)) assert(hover2.completion.contains(TypeCompletion("city", "character varying"))) } // RD-10865 (mistakenly passing snapi code) test("""[{a: 12}, null]""".stripMargin) { t => - val v = compilerService.validate(t.q, asJson()) + val v = compilerService.validate(t.q, mkEnv()) assert(v.messages.exists(_.message contains "the input does not form a valid statement or expression")) } test("""RD-10948+10961""") { _ => val q = """: |""".stripMargin - val ValidateResponse(errors) = compilerService.validate(q, asJson()) + val ValidateResponse(errors) = compilerService.validate(q, mkEnv()) assert(errors.nonEmpty) } @@ -1070,7 +1148,7 @@ class TestSqlCompilerServiceAirports |) as i(id, first_name, email, password) |WHERE email = :email AND password: |""".stripMargin - val ValidateResponse(errors) = compilerService.validate(q, asJson()) + val ValidateResponse(errors) = compilerService.validate(q, mkEnv()) assert(errors.nonEmpty) } @@ -1084,74 +1162,72 @@ class TestSqlCompilerServiceAirports |) as i(id, first_name, last_name, birthday) |WHERE id = :id && id = : |""".stripMargin - val ValidateResponse(errors) = compilerService.validate(q, asJson()) + val ValidateResponse(errors) = compilerService.validate(q, mkEnv()) assert(errors.nonEmpty) } test("""scopes work""") { _ => - val baos = new ByteArrayOutputStream() - def runWith(q: String, scopes: Set[String]): String = { - val env = asJson(scopes = scopes) + def runWith(q: String, scopes: Set[String]): Seq[(String, String)] = { + val env = mkEnv(scopes = scopes) assert(compilerService.validate(q, env).messages.isEmpty) - val GetProgramDescriptionSuccess(_) = compilerService.getProgramDescription(q, env) - baos.reset() - assert(compilerService.execute(q, env, None, baos) == ExecutionSuccess(true)) - baos.toString + val Right(_) = compilerService.getProgramDescription(q, env) + val EvalSuccess.IteratorValue(_, it) = compilerService + .eval(q, env, None) + .value + val r = it.map { v => + val fields = v.getRecord.getFieldsList + (fields.get(0).getName, fields.get(0).getValue.getString.getV) + }.toList + it.close() + r } -// assert(runWith("SELECT e.airport_id FROM example.airports e", Set.empty) == """[]""") - assert(runWith("SELECT token\nFROM scopes", Set.empty) == """[]""") - assert(runWith("SELECT * FROM scopes value ORDER by value", Set.empty) == """[]""") - assert(runWith("SELECT * FROM scopes AS value ORDER by value", Set("ADMIN")) == """[{"token":"ADMIN"}]""") - assert( - runWith( - "SELECT token FROM scopes value ORDER by value", - Set("ADMIN", "SALES", "DEV") - ) == """[{"token":"ADMIN"},{"token":"DEV"},{"token":"SALES"}]""" - ) - assert( - runWith( - """SELECT 'DEV' IN (SELECT * FROM scopes) AS isDev, - | 'ADMIN' IN (SELECT token FROM scopes) AS isAdmin""".stripMargin, - Set("ADMIN") - ) == """[{"isdev":false,"isadmin":true}]""" - ) + runWith("SELECT token\nFROM scopes", Set.empty) shouldBe Seq.empty + runWith("SELECT * FROM scopes value ORDER by value", Set.empty) shouldBe Seq.empty + runWith("SELECT * FROM scopes AS value ORDER by value", Set("ADMIN")) shouldBe Seq(("token", "ADMIN")) + runWith( + "SELECT token FROM scopes value ORDER by value", + Set("ADMIN", "SALES", "DEV") + ) shouldBe Seq(("token", "ADMIN"), ("token", "DEV"), ("token", "SALES")) + // more complex query + val r1 = fetchOneRowResult( + """SELECT 'DEV' IN (SELECT * FROM scopes) AS isDev, + | 'ADMIN' IN (SELECT token FROM scopes) AS isAdmin""".stripMargin, + mkEnv(scopes = Set("ADMIN")) + ) + r1.size shouldBe 2 + r1(0)._1.hasBool shouldBe true + r1(0)._2.getBool.getV shouldBe false + r1(1)._1.hasBool shouldBe true + r1(1)._2.getBool.getV shouldBe true // demo CASE WHEN to hide a certain field - val q = """SELECT - | CASE WHEN 'DEV' IN (SELECT * FROM scopes) THEN trip_id END AS trip_id, -- "AS trip_id" to name it normally - | departure_date, - | arrival_date - |FROM example.trips - |WHERE reason = 'Holidays' AND departure_date = DATE '2016-02-27'""".stripMargin - assert( - runWith(q, Set("ADMIN")) - == """[{"trip_id":null,"departure_date":"2016-02-27","arrival_date":"2016-03-06"}]""" - ) - assert( - runWith(q, Set("DEV")) - == """[{"trip_id":0,"departure_date":"2016-02-27","arrival_date":"2016-03-06"}]""" - ) + val r2 = fetchOneRowResult( + """SELECT + | CASE WHEN 'DEV' IN (SELECT * FROM scopes) THEN trip_id END AS trip_id, -- "AS trip_id" to name it normally + | departure_date, + | arrival_date + |FROM example.trips + |WHERE reason = 'Holidays' AND departure_date = DATE '2016-02-27'""".stripMargin, + mkEnv(scopes = Set("ADMIN")) + ) + r2.size shouldBe 3 // Three columns + r2(0)._1.hasInt shouldBe true + r2(0)._2.hasNull shouldBe true + r2(1)._1.hasDate shouldBe true + r2(1)._2.getDate.getYear shouldBe 2016 + r2(1)._2.getDate.getMonth shouldBe 2 + r2(1)._2.getDate.getDay shouldBe 27 + r2(2)._2.getDate.getYear shouldBe 2016 + r2(2)._2.getDate.getMonth shouldBe 3 + r2(2)._2.getDate.getDay shouldBe 6 } test("""-- @param p |-- @type p integer |-- SELECT :p + 10; |""".stripMargin) { t => - val v = compilerService.validate(t.q, asJson()) + val v = compilerService.validate(t.q, mkEnv()) assert(v.messages.isEmpty) - val baos = new ByteArrayOutputStream() - assert( - compilerService.execute( - t.q, - asJson(Map("p" -> RawInt(5))), - None, - baos - ) == ExecutionSuccess(true) - ) - // The code does nothing, but we don't get an error when running it in Postgres. - assert( - baos.toString() === - """[{"update_count":0}]""".stripMargin - ) + fetchStatementResult(t.q, mkEnv(Map("p" -> RawInt(5)))) shouldBe 0 } test("""select @@ -1162,139 +1238,117 @@ class TestSqlCompilerServiceAirports | example.airports |limit 10; |""".stripMargin) { t => - val v = compilerService.validate(t.q, asJson()) + val v = compilerService.validate(t.q, mkEnv()) assert(v.messages.size == 1) assert(v.messages(0).message == "schema \"country\" does not exist") } test("""SELECT pg_typeof(NOW())""".stripMargin) { t => - val ValidateResponse(errors) = compilerService.validate(t.q, asJson()) + val ValidateResponse(errors) = compilerService.validate(t.q, mkEnv()) assert(errors.isEmpty) - val GetProgramDescriptionFailure(errors2) = compilerService.getProgramDescription(t.q, asJson()) + val Left(errors2) = compilerService.getProgramDescription(t.q, mkEnv()) errors2.map(_.message).contains("unsupported type: regtype") } test("""SELECT CAST(pg_typeof(NOW()) AS VARCHAR)""".stripMargin) { t => - val ValidateResponse(errors) = compilerService.validate(t.q, asJson()) + val ValidateResponse(errors) = compilerService.validate(t.q, mkEnv()) assert(errors.isEmpty) - val GetProgramDescriptionSuccess(_) = compilerService.getProgramDescription(t.q, asJson()) - val baos = new ByteArrayOutputStream() - baos.reset() - assert(compilerService.execute(t.q, asJson(), None, baos) == ExecutionSuccess(true)) - assert(baos.toString() == """[{"pg_typeof":"timestamp with time zone"}]""") - + val Right(_) = compilerService.getProgramDescription(t.q, mkEnv()) + val row = fetchOneRowResult(t.q, mkEnv()) + row.size shouldBe 1 + row.head._1.hasString shouldBe true + row.head._2.getString.getV shouldBe "timestamp with time zone" } test("""SELECT NOW()""".stripMargin) { t => // NOW() is a timestamp with timezone. The one of the SQL connection. This test is to make sure // it works (we cannot assert on the result). - val ValidateResponse(errors) = compilerService.validate(t.q, asJson()) + val ValidateResponse(errors) = compilerService.validate(t.q, mkEnv()) assert(errors.isEmpty) - val GetProgramDescriptionSuccess(description) = compilerService.getProgramDescription(t.q, asJson()) - val baos = new ByteArrayOutputStream() - baos.reset() - assert(compilerService.execute(t.q, asJson(), None, baos) == ExecutionSuccess(true)) + val Right(description) = compilerService.getProgramDescription(t.q, mkEnv()) + val row = fetchOneRowResult(t.q, mkEnv()) + row.size shouldBe 1 } test("""SELECT TIMESTAMP '2001-07-01 12:13:14.567' AS t""".stripMargin) { t => - val ValidateResponse(errors) = compilerService.validate(t.q, asJson()) + val ValidateResponse(errors) = compilerService.validate(t.q, mkEnv()) assert(errors.isEmpty) - val GetProgramDescriptionSuccess(_) = compilerService.getProgramDescription(t.q, asJson()) - val baos = new ByteArrayOutputStream() - for (env <- Seq(asJson(), asCsv())) { - baos.reset() - assert(compilerService.execute(t.q, env, None, baos) == ExecutionSuccess(true)) - assert(baos.toString().contains("12:13:14.567")) - } - } - - test("""SELECT TIMESTAMP '2001-07-01 12:13:14' AS t""".stripMargin) { t => - val ValidateResponse(errors) = compilerService.validate(t.q, asJson()) - assert(errors.isEmpty) - val GetProgramDescriptionSuccess(_) = compilerService.getProgramDescription(t.q, asJson()) - val baos = new ByteArrayOutputStream() - baos.reset() - assert(compilerService.execute(t.q, asCsv(), None, baos) == ExecutionSuccess(true)) - assert(baos.toString().contains("12:13:14")) - assert(!baos.toString().contains("12:13:14.000")) + val Right(_) = compilerService.getProgramDescription(t.q, mkEnv()) + val row = fetchOneRowResult(t.q, mkEnv()) + row.size shouldBe 1 + val ts = row.head._2.getTimestamp + ts.getYear shouldBe 2001 + ts.getMonth shouldBe 7 + ts.getDay shouldBe 1 + ts.getHour shouldBe 12 + ts.getMinute shouldBe 13 + ts.getSecond shouldBe 14 + ts.getNano shouldBe 567000000 } test("""SELECT TIME '12:13:14.567' AS t""".stripMargin) { t => - val ValidateResponse(errors) = compilerService.validate(t.q, asJson()) - assert(errors.isEmpty) - val GetProgramDescriptionSuccess(_) = compilerService.getProgramDescription(t.q, asJson()) - val baos = new ByteArrayOutputStream() - for (env <- Seq(asJson(), asCsv())) { - baos.reset() - assert(compilerService.execute(t.q, env, None, baos) == ExecutionSuccess(true)) - assert(baos.toString().contains("12:13:14.567")) - } - } - - test("""SELECT TIME '12:13:14' AS t""".stripMargin) { t => - val ValidateResponse(errors) = compilerService.validate(t.q, asJson()) + val ValidateResponse(errors) = compilerService.validate(t.q, mkEnv()) assert(errors.isEmpty) - val GetProgramDescriptionSuccess(_) = compilerService.getProgramDescription(t.q, asJson()) - val baos = new ByteArrayOutputStream() - baos.reset() - assert(compilerService.execute(t.q, asCsv(), None, baos) == ExecutionSuccess(true)) - assert(baos.toString().contains("12:13:14")) - assert(!baos.toString().contains("12:13:14.000")) + val Right(_) = compilerService.getProgramDescription(t.q, mkEnv()) + val row = fetchOneRowResult(t.q, mkEnv()) + row.size shouldBe 1 + val ts = row.head._2.getTime + ts.getHour shouldBe 12 + ts.getMinute shouldBe 13 + ts.getSecond shouldBe 14 + ts.getNano shouldBe 567000000 } test("""-- @default t TIME '12:13:14.567' |SELECT :t AS t""".stripMargin) { t => - val ValidateResponse(errors) = compilerService.validate(t.q, asJson()) + val ValidateResponse(errors) = compilerService.validate(t.q, mkEnv()) assert(errors.isEmpty) - val GetProgramDescriptionSuccess(_) = compilerService.getProgramDescription(t.q, asJson()) + val Right(_) = compilerService.getProgramDescription(t.q, mkEnv()) val baos = new ByteArrayOutputStream() baos.reset() - for (env <- Seq(asJson(), asCsv())) { - baos.reset() - assert(compilerService.execute(t.q, env, None, baos) == ExecutionSuccess(true)) - assert(baos.toString().contains("12:13:14.567")) - } + val row = fetchOneRowResult(t.q, mkEnv()) + row.size shouldBe 1 + val ts = row.head._2.getTime + ts.getHour shouldBe 12 + ts.getMinute shouldBe 13 + ts.getSecond shouldBe 14 + ts.getNano shouldBe 567000000 } test("""-- @type x integer |-- @default x null |SELECT :x AS x""".stripMargin) { t => - val baos = new ByteArrayOutputStream() - baos.reset() - assert(compilerService.execute(t.q, asJson(), None, baos) == ExecutionSuccess(true)) - assert(baos.toString() === """[{"x":null}]""") - baos.reset() - assert(compilerService.execute(t.q, asJson(Map("x" -> RawInt(12))), None, baos) == ExecutionSuccess(true)) - assert(baos.toString() === """[{"x":12}]""") + val rowWithDefault = fetchOneRowResult(t.q, mkEnv()) + rowWithDefault.size shouldBe 1 + rowWithDefault.head._2.hasNull shouldBe true + val rowWith12 = fetchOneRowResult(t.q, mkEnv(Map("x" -> RawInt(12)))) + rowWith12.size shouldBe 1 + rowWith12.head._2.getInt.getV shouldBe 12 } test("""-- @type x varchar |-- @default x null |SELECT :x AS x""".stripMargin) { t => - val baos = new ByteArrayOutputStream() - assert(compilerService.execute(t.q, asJson(), None, baos) == ExecutionSuccess(true)) - assert(baos.toString() === """[{"x":null}]""") - baos.reset() - assert(compilerService.execute(t.q, asJson(Map("x" -> RawString("tralala"))), None, baos) == ExecutionSuccess(true)) - assert(baos.toString() === """[{"x":"tralala"}]""") + val rowWithDefault = fetchOneRowResult(t.q, mkEnv()) + rowWithDefault.size shouldBe 1 + rowWithDefault.head._2.hasNull shouldBe true + val rowWithTralala = fetchOneRowResult(t.q, mkEnv(Map("x" -> RawString("tralala")))) + rowWithTralala.size shouldBe 1 + rowWithTralala.head._2.getString.getV shouldBe "tralala" } test("""-- @type x date |-- @default x null |SELECT :x AS x""".stripMargin) { t => - val baos = new ByteArrayOutputStream() - assert(compilerService.execute(t.q, asJson(), None, baos) == ExecutionSuccess(true)) - assert(baos.toString() === """[{"x":null}]""") - baos.reset() - assert( - compilerService.execute( - t.q, - asJson(Map("x" -> RawDate(LocalDate.of(2008, 9, 29)))), - None, - baos - ) == ExecutionSuccess(true) - ) - assert(baos.toString() === """[{"x":"2008-09-29"}]""") + val rowWithDefault = fetchOneRowResult(t.q, mkEnv()) + rowWithDefault.size shouldBe 1 + rowWithDefault.head._2.hasNull shouldBe true + val rowWithDate = fetchOneRowResult(t.q, mkEnv(Map("x" -> RawDate(LocalDate.of(2008, 9, 29))))) + rowWithDate.size shouldBe 1 + val d = rowWithDate.head._2.getDate + d.getYear shouldBe 2008 + d.getMonth shouldBe 9 + d.getDay shouldBe 29 } test("""RD-14898""".stripMargin) { _ => @@ -1305,10 +1359,10 @@ class TestSqlCompilerServiceAirports | AND airport_id = :iata |""".stripMargin) // hover ':location' returns the type of location (varchar) - val v1 = compilerService.hover(t.q, asJson(), Pos(2, 57)) + val v1 = compilerService.hover(t.q, mkEnv(), Pos(2, 57)) assert(v1.completion.contains(TypeCompletion("location", "varchar"))) // hover ':iata' returns the type of iata (integer) - val v2 = compilerService.hover(t.q, asJson(), Pos(3, 57)) + val v2 = compilerService.hover(t.q, mkEnv(), Pos(3, 57)) assert(v2.completion.contains(TypeCompletion("iata", "integer"))) } { @@ -1318,86 +1372,26 @@ class TestSqlCompilerServiceAirports | AND airport_id = :iata |""".stripMargin) // hover ':location' still returns the type of location (varchar) - val v1 = compilerService.hover(t.q, asJson(), Pos(2, 57)) + val v1 = compilerService.hover(t.q, mkEnv(), Pos(2, 57)) assert(v1.completion.contains(TypeCompletion("location", "varchar"))) // hover ':iata' doesn't return anything since it's ignored. - val v2 = compilerService.hover(t.q, asJson(), Pos(3, 57)) + val v2 = compilerService.hover(t.q, mkEnv(), Pos(3, 57)) assert(v2.completion.isEmpty) } } - test("SELECT 'a=>1,b=>tralala'::hstore AS r -- JSON") { t => - val baos = new ByteArrayOutputStream() - assert( - compilerService.execute( - t.q, - asJson(), - None, - baos - ) == ExecutionSuccess(true) - ) - assert(baos.toString() === """[{"r":{"a":"1","b":"tralala"}}]""") - } - - test("SELECT NULL::hstore AS r -- JSON") { t => - val baos = new ByteArrayOutputStream() - assert( - compilerService.execute( - t.q, - asJson(), - None, - baos - ) == ExecutionSuccess(true) - ) - assert(baos.toString() === """[{"r":null}]""") - } - - // TODO What do we do in CSV? - ignore("SELECT 'a=>1,b=>tralala'::hstore AS r -- CSV") { t => - val baos = new ByteArrayOutputStream() - assert( - compilerService.execute( - t.q, - asCsv(), - None, - baos - ) == ExecutionSuccess(true) - ) - assert( - baos.toString() === - """r - |{"a":"1","b":"tralala"} - |""".stripMargin - ) - } - - ignore("SELECT NULL::hstore AS r -- CSV") { t => - val baos = new ByteArrayOutputStream() - assert( - compilerService.execute( - t.q, - asCsv(), - None, - baos - ) == ExecutionSuccess(true) - ) - assert(baos.toString() === """[{"r":{"a":"1","b":"tralala"}}]""") + // Checking nested records with hstore + test("SELECT 'a=>1,b=>tralala'::hstore AS r") { t => + val Seq((_, v)) = fetchOneRowResult(t.q, mkEnv()) + v.getRecord.getFields(0).getName shouldBe "a" + v.getRecord.getFields(0).getValue.getString.getV shouldBe "1" + v.getRecord.getFields(1).getName shouldBe "b" + v.getRecord.getFields(1).getValue.getString.getV shouldBe "tralala" } - test("SELECT a.* FROM example.airports a ORDER BY airport_id LIMIT 1") { t => - val baos = new ByteArrayOutputStream() - assert( - compilerService.execute( - t.q, - asCsv(), - None, - baos - ) == ExecutionSuccess(true) - ) - assert(baos.toString() === """airport_id,name,city,country,iata_faa,icao,latitude,longitude,altitude,timezone,dst,tz - |1,Goroka,Goroka,Papua New Guinea,GKA,AYGA,-6.081689,145.391881,5282.000,10,U,Pacific/Port_Moresby - |""".stripMargin) - + test("SELECT NULL::hstore AS r") { t => + val Seq((_, v)) = fetchOneRowResult(t.q, mkEnv()) + v.hasNull shouldBe true } test("INSERT") { _ => @@ -1406,100 +1400,32 @@ class TestSqlCompilerServiceAirports """INSERT INTO example.airports (airport_id, name, city, country, iata_faa, icao, latitude, longitude, altitude, timezone, dst, tz) |VALUES (8108, :airport, :city, :country, 'FC', 'FC', 0.0, 0.0, 0.0, 0, 'U', 'UTC') |""".stripMargin - val baos = new ByteArrayOutputStream() - assert( - compilerService.execute( - q, - asCsv(params = - Map("airport" -> RawString("FAKE"), "city" -> RawString("Fake City"), "country" -> RawString("Fake Country")) - ), - None, - baos - ) == ExecutionSuccess(true) - ) - assert( - baos.toString() === - """update_count - |1 - |""".stripMargin - ) - baos.reset() - assert( - compilerService.execute( - "SELECT city, country FROM example.airports WHERE name = :a", - asCsv(params = Map("a" -> RawString("FAKE"))), - None, - baos - ) == ExecutionSuccess(true) - ) - assert( - baos.toString() === - """city,country - |Fake City,Fake Country - |""".stripMargin + fetchStatementResult( + q, + mkEnv( + Map("airport" -> RawString("FAKE"), "city" -> RawString("Fake City"), "country" -> RawString("Fake Country")) + ) + ) shouldBe 1 + val columns = fetchOneRowResult( + "SELECT city, country FROM example.airports WHERE name = :a", + mkEnv(params = Map("a" -> RawString("FAKE"))) ) - + columns.size shouldBe 2 + columns(0)._2.getString.getV shouldBe "Fake City" + columns(1)._2.getString.getV shouldBe "Fake Country" } - test("UPDATE (CSV output)") { _ => - val baos = new ByteArrayOutputStream() - assert( - compilerService.execute( - "UPDATE example.airports SET city = :newName WHERE country = :c", - asCsv(params = Map("newName" -> RawString("La Roche sur Foron"), "c" -> RawString("Portugal"))), - None, - baos - ) == ExecutionSuccess(true) - ) - assert( - baos.toString() === - """update_count - |39 - |""".stripMargin - ) - baos.reset() - assert( - compilerService.execute( - "SELECT DISTINCT city FROM example.airports WHERE country = :c", - asCsv(params = Map("c" -> RawString("Portugal"))), - None, - baos - ) == ExecutionSuccess(true) - ) - assert( - baos.toString() === - """city - |La Roche sur Foron - |""".stripMargin + test("UPDATE") { _ => + fetchStatementResult( + "UPDATE example.airports SET city = :newName WHERE country = :c", + mkEnv(params = Map("newName" -> RawString("La Roche sur Foron"), "c" -> RawString("Portugal"))) + ) shouldBe 39 + val columns = fetchOneRowResult( + "SELECT DISTINCT city FROM example.airports WHERE country = :c", + mkEnv(params = Map("c" -> RawString("Portugal"))) ) + columns.size shouldBe 1 + columns.head._2.getString.getV shouldBe "La Roche sur Foron" } - test("UPDATE (Json output)") { _ => - val baos = new ByteArrayOutputStream() - assert( - compilerService.execute( - "UPDATE example.airports SET city = :newName WHERE country = :c", - asJson(params = Map("newName" -> RawString("Lausanne"), "c" -> RawString("Portugal"))), - None, - baos - ) == ExecutionSuccess(true) - ) - assert( - baos.toString() === - """[{"update_count":39}]""".stripMargin - ) - baos.reset() - assert( - compilerService.execute( - "SELECT DISTINCT city FROM example.airports WHERE country = :c", - asJson(params = Map("c" -> RawString("Portugal"))), - None, - baos - ) == ExecutionSuccess(true) - ) - assert( - baos.toString() === - """[{"city":"Lausanne"}]""" - ) - } } diff --git a/sql-compiler/src/test/scala/com/rawlabs/sql/compiler/TestSqlConnectionFailures.scala b/sql-compiler/src/test/scala/com/rawlabs/sql/compiler/TestSqlConnectionFailures.scala index 7b82bd54e..426223cf6 100644 --- a/sql-compiler/src/test/scala/com/rawlabs/sql/compiler/TestSqlConnectionFailures.scala +++ b/sql-compiler/src/test/scala/com/rawlabs/sql/compiler/TestSqlConnectionFailures.scala @@ -13,29 +13,12 @@ package com.rawlabs.sql.compiler import com.dimafeng.testcontainers.{ForAllTestContainer, PostgreSQLContainer} -import com.rawlabs.compiler.{ - AutoCompleteResponse, - CompilerService, - ExecutionResponse, - ExecutionRuntimeFailure, - ExecutionSuccess, - GetProgramDescriptionFailure, - GetProgramDescriptionResponse, - GetProgramDescriptionSuccess, - HoverResponse, - LetBindCompletion, - Pos, - ProgramEnvironment, - RawInt, - TypeCompletion, - ValidateResponse -} +import com.rawlabs.compiler._ +import com.rawlabs.utils.core._ import org.scalatest.matchers.must.Matchers.be import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper import org.testcontainers.utility.DockerImageName -import com.rawlabs.utils.core._ -import java.io.ByteArrayOutputStream import java.sql.DriverManager import java.util.concurrent.{Executors, TimeUnit} import scala.io.Source @@ -125,10 +108,10 @@ class TestSqlConnectionFailures val pool = Executors.newFixedThreadPool(others.size) try { // All other users run a long query which picks a connection for them - val futures = others.map(user => pool.submit(() => runExecute(compilerService, user, longRunningQuery, 5))) + val futures = others.map(user => pool.submit(() => runEval(compilerService, user, longRunningQuery, 5))) val results = futures.map(_.get(60, TimeUnit.SECONDS)) results.foreach { - case ExecutionSuccess(complete) => complete shouldBe true + case Right(n) => logger.debug(s"Got $n rows") case r => fail(s"unexpected result $r") } // The user is able to get a connection to run all LSP calls. @@ -161,10 +144,10 @@ class TestSqlConnectionFailures val compilerService = new SqlCompilerService() val pool = Executors.newFixedThreadPool(others.size) try { - val futures = others.map(user => pool.submit(() => runExecute(compilerService, user, longRunningQuery, 5))) + val futures = others.map(user => pool.submit(() => runEval(compilerService, user, longRunningQuery, 5))) val results = futures.map(_.get(60, TimeUnit.SECONDS)) results.foreach { - case ExecutionSuccess(complete) => complete shouldBe true + case Right(n) => logger.debug(s"Got $n rows") case r => fail(s"unexpected result $r") } // hover returns nothing @@ -194,7 +177,7 @@ class TestSqlConnectionFailures val compilerService = new SqlCompilerService() val pool = Executors.newFixedThreadPool(users.size) try { - val futures = users.map(user => pool.submit(() => runExecute(compilerService, user, longRunningQuery, 5))) + val futures = users.map(user => pool.submit(() => runEval(compilerService, user, longRunningQuery, 5))) Thread.sleep(2000) // give some time to make sure they're all running val hoverResponse = runHover(compilerService, joe, "SELECT * FROM example.airports", Pos(1, 17)) assert(hoverResponse.completion.contains(TypeCompletion("example", "schema"))) @@ -209,7 +192,7 @@ class TestSqlConnectionFailures }.toSet === Set("airports", "trips", "machines")) val results = futures.map(_.get(60, TimeUnit.SECONDS)) results.foreach { - case ExecutionSuccess(complete) => complete shouldBe true + case Right(n) => logger.debug(s"Got $n rows") case r => fail(s"unexpected result $r") } } finally { @@ -228,7 +211,7 @@ class TestSqlConnectionFailures val pool = Executors.newFixedThreadPool(users.size) try { // All users run a long query - val futures = users.map(user => pool.submit(() => runExecute(compilerService, user, longRunningQuery, 5))) + val futures = users.map(user => pool.submit(() => runEval(compilerService, user, longRunningQuery, 5))) Thread.sleep(2000) // give some time to make sure they're all running // hover is None val hoverResponse = runHover(compilerService, joe, "SELECT * FROM example.airports", Pos(1, 17)) @@ -243,7 +226,7 @@ class TestSqlConnectionFailures assert(dotCompletionResponse.completions.isEmpty) val results = futures.map(_.get(60, TimeUnit.SECONDS)) results.foreach { - case ExecutionSuccess(complete) => complete shouldBe true + case Right(n) => logger.debug(s"Got $n rows") case r => fail(s"unexpected result $r") } } finally { @@ -252,7 +235,7 @@ class TestSqlConnectionFailures } } - test("[execute] enough connections in total") { _ => + test("[eval] enough connections in total") { _ => // Each user runs three times the same long query, one call at a time. The same connection is reused per user. // This is confirmed by setting max-connections-per-db to 1 although several calls are performed per DB. // In total, there's one connection per user. Setting max-connections to nUsers is working. @@ -263,10 +246,10 @@ class TestSqlConnectionFailures val iterations = 1 to nCalls try { val results = users - .map(user => user -> iterations.map(_ => runExecute(compilerService, user, longRunningQuery, 0))) + .map(user => user -> iterations.map(_ => runEval(compilerService, user, longRunningQuery, 0))) .toMap for (userResults <- results.values; r <- userResults) r match { - case ExecutionSuccess(complete) => complete shouldBe true + case Right(n) => logger.debug(s"Got $n rows") case _ => fail(s"unexpected result $r") } } finally { @@ -274,7 +257,7 @@ class TestSqlConnectionFailures } } - test("[execute] enough connections per user") { _ => + test("[eval] enough connections per user") { _ => // We run `execute` _in parallel_ using a long query. Each user runs it `nCalls` times. So we have // a total number of queries of nUsers x nCalls. We set max-connections to that value to be sure and // set max-connections-per-db to nCalls so that all concurrent queries can run. @@ -287,12 +270,12 @@ class TestSqlConnectionFailures try { val futures = users .map(user => - user -> iterations.map(_ => pool.submit(() => runExecute(compilerService, user, longRunningQuery, 5))) + user -> iterations.map(_ => pool.submit(() => runEval(compilerService, user, longRunningQuery, 5))) ) .toMap val results = futures.mapValues(_.map(_.get(60, TimeUnit.SECONDS))) for (userResults <- results.values; r <- userResults) r match { - case ExecutionSuccess(complete) => complete shouldBe true + case Right(n) => logger.debug(s"Got $n rows") case _ => fail(s"unexpected result $r") } } finally { @@ -301,7 +284,7 @@ class TestSqlConnectionFailures } } - test("[execute] not enough connections") { _ => + test("[eval] not enough connections") { _ => // Each user runs twice execute, one call at a time. The same connection can be reused per user. // In total, there's one connection per user. Setting max-connections to nUsers - 1 triggers the // expected failure. The number of errors hit should be positive (checked in the end) @@ -312,14 +295,16 @@ class TestSqlConnectionFailures val iterations = 1 to nCalls try { val results = users - .map(user => user -> iterations.map(_ => runExecute(compilerService, user, longRunningQuery, 0))) + .map(user => user -> iterations.map(_ => runEval(compilerService, user, longRunningQuery, 0))) .toMap + var errorCount = 0 for (userResults <- results.values; r <- userResults) r match { - case ExecutionSuccess(complete) => complete shouldBe true - case ExecutionRuntimeFailure(error) => error shouldBe "no connections available" + case Right(n) => logger.debug(s"Got $n rows") + case Left(ExecutionError.ValidationError(Seq(error))) => + error.message shouldBe "no connections available" + errorCount += 1 case _ => fail(s"unexpected result $r") } - val errorCount = results.values.map(_.count(_.isInstanceOf[ExecutionRuntimeFailure])).sum errorCount should be > 0 } finally { compilerService.stop() @@ -339,13 +324,14 @@ class TestSqlConnectionFailures val results = users .map(user => user -> iterations.map(_ => runGetProgramDescription(compilerService, user, longValidateQuery))) .toMap + var errorCount = 0 for (userResults <- results.values; r <- userResults) r match { - case GetProgramDescriptionSuccess(_) => - case GetProgramDescriptionFailure(errors) => + case Right(_) => + case Left(errors) => errors.size shouldBe 1 errors.head.message shouldBe "no connections available" + errorCount += 1 } - val errorCount = results.values.map(_.count(_.isInstanceOf[GetProgramDescriptionFailure])).sum errorCount should be > 0 } finally { compilerService.stop() @@ -365,20 +351,21 @@ class TestSqlConnectionFailures val results = users .map(user => user -> iterations.map(_ => runValidate(compilerService, user, longValidateQuery))) .toMap + var errorCount = 0 for (userResults <- results.values; r <- userResults) r match { case ValidateResponse(errors) if errors.isEmpty => case ValidateResponse(errors) => errors.size shouldBe 1 errors.head.message shouldBe "no connections available" + errorCount += 1 } - val errorCount = results.values.map(_.count(_.messages.nonEmpty)).sum errorCount should be > 0 } finally { compilerService.stop() } } - test("[execute] not enough connections per user") { _ => + test("[eval] not enough connections per user") { _ => // We run `execute` in parallel using a long query. Each user runs it `nCalls` times. So we have // a total number of queries of nUsers x nCalls. We set max-connections to that value to be sure but // set max-connections-per-db to two so that all concurrent queries cannot all get a connection although @@ -392,16 +379,18 @@ class TestSqlConnectionFailures try { val futures = users .map(user => - user -> iterations.map(_ => pool.submit(() => runExecute(compilerService, user, longRunningQuery, 5))) + user -> iterations.map(_ => pool.submit(() => runEval(compilerService, user, longRunningQuery, 5))) ) .toMap val results = futures.mapValues(_.map(_.get(60, TimeUnit.SECONDS))) + var errorCount = 0 for (userResults <- results.values; r <- userResults) r match { - case ExecutionSuccess(complete) => complete shouldBe true - case ExecutionRuntimeFailure(error) => error shouldBe "too many connections active" + case Right(n) => logger.debug(s"Got $n rows") + case Left(ExecutionError.ValidationError(Seq(error))) => + error.message shouldBe "too many connections active" + errorCount += 1 case _ => fail(s"unexpected result $r") } - val errorCount = results.values.map(_.count(_.isInstanceOf[ExecutionRuntimeFailure])).sum errorCount should be > 0 } finally { pool.close() @@ -429,13 +418,14 @@ class TestSqlConnectionFailures ) .toMap val results = futures.mapValues(_.map(_.get(60, TimeUnit.SECONDS))) + var errorCount = 0 for (userResults <- results.values; r <- userResults) r match { - case GetProgramDescriptionSuccess(_) => - case GetProgramDescriptionFailure(errors) => + case Right(_) => + case Left(errors) => errors.size shouldBe 1 errors.head.message shouldBe "too many connections active" + errorCount += 1 } - val errorCount = results.values.map(_.count(_.isInstanceOf[GetProgramDescriptionFailure])).sum errorCount should be > 0 } finally { pool.close() @@ -461,13 +451,14 @@ class TestSqlConnectionFailures ) .toMap val results = futures.mapValues(_.map(_.get(60, TimeUnit.SECONDS))) + var errorCount = 0 for (userResults <- results.values; r <- userResults) r match { case ValidateResponse(errors) if errors.isEmpty => case ValidateResponse(errors) => errors.size shouldBe 1 errors.head.message shouldBe "too many connections active" + errorCount += 1 } - val errorCount = results.values.map(_.count(_.messages.nonEmpty)).sum errorCount should be > 0 } finally { pool.close() @@ -475,26 +466,27 @@ class TestSqlConnectionFailures } } - private def runExecute( + private def runEval( compilerService: CompilerService, user: RawUid, code: String, arg: Int - ): ExecutionResponse = { + ): Either[ExecutionError, Long] = { val env = ProgramEnvironment( user, Some(Array("arg" -> RawInt(arg))), Set.empty, Map.empty, Map.empty, - Map("output-format" -> "json"), + Map.empty, jdbcUrl = Some(jdbcUrl(user)) ) - val baos = new ByteArrayOutputStream() - try { - compilerService.execute(code, env, None, baos) - } finally { - baos.close() + compilerService.eval(code, env, None).right.map { + case EvalSuccess.IteratorValue(_, it) => + val n = it.size + it.close() + n + case _ => 1 } } @@ -510,7 +502,7 @@ class TestSqlConnectionFailures Set.empty, Map.empty, Map.empty, - Map("output-format" -> "json"), + Map.empty, jdbcUrl = Some(jdbcUrl(user)) ) compilerService.hover(code, env, pos) @@ -529,7 +521,7 @@ class TestSqlConnectionFailures Set.empty, Map.empty, Map.empty, - Map("output-format" -> "json"), + Map.empty, jdbcUrl = Some(jdbcUrl(user)) ) compilerService.wordAutoComplete(code, env, prefix, pos) @@ -547,7 +539,7 @@ class TestSqlConnectionFailures Set.empty, Map.empty, Map.empty, - Map("output-format" -> "json"), + Map.empty, jdbcUrl = Some(jdbcUrl(user)) ) compilerService.dotAutoComplete(code, env, pos) @@ -557,14 +549,14 @@ class TestSqlConnectionFailures compilerService: CompilerService, user: RawUid, code: String - ): GetProgramDescriptionResponse = { + ): Either[List[ErrorMessage], ProgramDescription] = { val env = ProgramEnvironment( user, None, Set.empty, Map.empty, Map.empty, - Map("output-format" -> "json"), + Map.empty, jdbcUrl = Some(jdbcUrl(user)) ) compilerService.getProgramDescription(code, env) @@ -581,7 +573,7 @@ class TestSqlConnectionFailures Set.empty, Map.empty, Map.empty, - Map("output-format" -> "json"), + Map.empty, jdbcUrl = Some(jdbcUrl(user)) ) compilerService.validate(code, env)