Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support parametrized TypeRep in Scala AST #10612

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ class TypeOrderingSpec extends AnyWordSpec with Matchers {
TypeOrdering.compare
) shouldBe primTypesInProtoOrder
}
"order parametrized TypeReps" in {
TypeOrdering.compare(
Ast.TTypeRepGeneric(Ast.KStar),
Ast.TTypeRepGeneric(Ast.KStar),
) shouldBe 0
TypeOrdering.compare(
Ast.TTypeRepGeneric(Ast.KStar),
Ast.TTypeRepGeneric(Ast.KArrow(Ast.KStar, Ast.KStar)),
) shouldBe -1
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ private[engine] final class Preprocessor(compiledPackages: MutableCompiledPackag
}
case Ast.TTyCon(_) | Ast.TNat(_) | Ast.TBuiltin(_) | Ast.TVar(_) =>
go(typesToProcess, tmplToProcess0, tyConAlreadySeen0, tmplsAlreadySeen0)
case Ast.TSynApp(_, _) | Ast.TForall(_, _) | Ast.TStruct(_) =>
case Ast.TSynApp(_, _) | Ast.TForall(_, _) | Ast.TStruct(_) | Ast.TTypeRepGeneric(_) =>
// We assume that getDependencies is always given serializable types
ResultError(
Error.Preprocessing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,8 @@ object InterfaceReader {
toIfaceType(ctx, arg, FrontStack.empty) flatMap (tArg =>
toIfaceType(ctx, tyfun, tArg +: args)
)
case Ast.TForall(_, _) | Ast.TStruct(_) | Ast.TNat(_) | Ast.TSynApp(_, _) =>
case Ast.TForall(_, _) | Ast.TStruct(_) | Ast.TNat(_) | Ast.TSynApp(_, _) |
Ast.TTypeRepGeneric(_) =>
unserializableDataType(ctx, s"unserializable data type: ${a.pretty}")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class TypeSpec extends AnyWordSpec with Matchers {
case _: Pkg.TStruct => sys.error("cannot use structs in interface type")
case _: Pkg.TForall => sys.error("cannot use forall in interface type")
case _: Pkg.TSynApp => sys.error("cannot use type synonym in interface type")
case _: Pkg.TTypeRepGeneric => sys.error("cannot use TypeRepGeneric in interface type")
}

go(pkgTyp00, BackStack.empty)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,10 @@ private[lf] final class Compiler(
SBFromAny(ty)(compile(e))
case ETypeRep(typ) =>
SEValue(STypeRep(typ))
case ETypeRepGeneric(_, typ) =>
SEValue(STypeRep(typ))
case ETypeRepGenericApp(_, _) =>
SEBuiltin(SBTypeRepApp)
case EToAnyException(ty, e) =>
SBToAny(ty)(compile(e))
case EFromAnyException(ty, e) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,12 @@ private[speedy] sealed abstract class SBuiltin(val arity: Int) {
case otherwise => unexpectedType(i, "Exception", otherwise)
}

final protected def getSTypeRep(args: util.ArrayList[SValue], i: Int): Ast.Type =
args.get(i) match {
case STypeRep(x) => x
case otherwise => unexpectedType(i, "STNat", otherwise)
}

final protected def checkToken(args: util.ArrayList[SValue], i: Int): Unit =
args.get(i) match {
case SToken => ()
Expand Down Expand Up @@ -356,6 +362,14 @@ private[lf] object SBuiltin {
}
}

final case object SBTypeRepApp extends SBuiltinPure(2) {
override private[speedy] def executePure(args: util.ArrayList[SValue]): STypeRep = {
val t1 = getSTypeRep(args, 0)
val t2 = getSTypeRep(args, 1)
STypeRep(Ast.TApp(t1, t2))
}
}

final case object SBShiftNumeric extends SBuiltinPure(3) {
override private[speedy] def executePure(args: util.ArrayList[SValue]): SNumeric = {
val inputScale = getSTNat(args, 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import java.util
import com.daml.lf.data._
import com.daml.lf.interpretation.{Error => IE}
import com.daml.lf.language.Ast._
import com.daml.lf.language.Util._
import com.daml.lf.speedy.SError.{SError, SErrorCrash}
import com.daml.lf.speedy.SExpr._
import com.daml.lf.speedy.SResult.{SResultError, SResultFinalValue, SResultNeedPackage}
Expand Down Expand Up @@ -1565,6 +1566,22 @@ class SBuiltinTest extends AnyFreeSpec with Matchers with TableDrivenPropertyChe
}
}

"TypeRepGeneric/TypeRepGenericApp" - {
"should produce typerep for Unit" in {
eval(e"type_rep_generic @* @Unit") shouldBe Right(SValue.STypeRep(TUnit))
}
"should produce typerep for Option" in {
eval(e"type_rep_generic @(* -> *) @Option") shouldBe Right(
SValue.STypeRep(TBuiltin(BTOptional))
)
}
"should produce typerep for Option Unit" in {
eval(
e"type_rep_generic_app @* @* (type_rep_generic @(* -> *) @Option) (type_rep_generic @* @Unit)"
) shouldBe Right(SValue.STypeRep(TOptional(TUnit)))
}
}

}

object SBuiltinTest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ object Ast {
/** Unique textual representation of template Id * */
final case class ETypeRep(typ: Type) extends Expr

final case class ETypeRepGeneric(kind: Kind, typ: Type) extends Expr
final case class ETypeRepGenericApp(argKind: Kind, resKind: Kind) extends Expr

/** Throw an exception */
final case class EThrow(returnType: Type, exceptionType: Type, exception: Expr) extends Expr

Expand Down Expand Up @@ -240,6 +243,8 @@ object Ast {
.map { case (n, t) => n + ": " + prettyType(t, precTForall) }
.toSeq
.mkString(", ") + ")"
case TTypeRepGeneric(kind) =>
"TypeRep<" + kind.pretty + ">"
}

def prettyForAll(t: Type): String = t match {
Expand Down Expand Up @@ -283,6 +288,8 @@ object Ast {
/** Structs */
final case class TStruct(fields: Struct[Type]) extends Type

final case class TTypeRepGeneric(kind: Kind) extends Type

sealed abstract class BuiltinType extends Product with Serializable

case object BTInt64 extends BuiltinType
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package com.daml.lf.language

import Ast._

object KindOrdering extends Ordering[Ast.Kind] {
@inline
def compare(x: Kind, y: Kind): Int = {
var diff = 0
var stackX = List(Iterator.single(x))
var stackY = List(Iterator.single(y))

@inline
def push(xs: Iterator[Kind], ys: Iterator[Kind]): Unit = {
stackX = xs :: stackX
stackY = ys :: stackY
}

@inline
def pop(): Unit = {
stackX = stackX.tail
stackY = stackY.tail
}

@inline
def step(tuple: (Kind, Kind)): Unit =
tuple match {
case (KStar, KStar) => diff = 0
case (KNat, KNat) => diff = 0
case (KArrow(x1, x2), KArrow(y1, y2)) =>
push(Iterator(x1, x2), Iterator(y1, y2))
case (k1, k2) =>
diff = kindRank(k1) compareTo kindRank(k2)
}

while (diff == 0 && stackX.nonEmpty) {
diff = stackX.head.hasNext compare stackY.head.hasNext
if (diff == 0)
if (stackX.head.hasNext)
step((stackX.head.next(), stackY.head.next()))
else
pop()
}

diff
}

private[this] def kindRank(kind: Ast.Kind): Int = kind match {
case KStar => 0
case KNat => 1
case KArrow(_, _) => 2
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ object LanguageVersion {
val choiceObservers = v1_11
val bigNumeric = v1_13
val exceptions = v1_14
val typeRepGeneric = v1_dev

/** Unstable, experimental features. This should stay in 1.dev forever.
* Features implemented with this flag should be moved to a separate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ object TypeOrdering extends Ordering[Type] {
push(xs.values, ys.values)
case (Ast.TApp(x1, x2), Ast.TApp(y1, y2)) =>
push(Iterator(x1, x2), Iterator(y1, y2))
case (Ast.TTypeRepGeneric(k1), Ast.TTypeRepGeneric(k2)) =>
diff = KindOrdering.compare(k1, k2)
case (t1, t2) =>
diff = typeRank(t1) compareTo typeRank(t2)
}
Expand Down Expand Up @@ -100,6 +102,8 @@ object TypeOrdering extends Ordering[Type] {
case Ast.TNat(_) => 2
case Ast.TStruct(_) => 3
case Ast.TApp(_, _) => 4
// TODO fixme
case Ast.TTypeRepGeneric(_) => 5
case Ast.TVar(_) | Ast.TForall(_, _) | Ast.TSynApp(_, _) =>
InternalError.illegalArgumentException(
NameOf.qualifiedNameOfCurrentFunc,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ object Util {
TApp(TApp(TBuiltin(BTArrow), targ), tres)
}

class ParametricType1(bType: BuiltinType) {
val cons = TBuiltin(bType)
class ParametricType1(val cons: Type) {
def this(bType: BuiltinType) = this(TBuiltin(bType))
def apply(typ: Type): Type =
TApp(cons, typ)
def unapply(typ: TApp): Option[Type] = typ match {
Expand Down Expand Up @@ -77,6 +77,7 @@ object Util {
val TParty = TBuiltin(BTParty)
val TAny = TBuiltin(BTAny)
val TTypeRep = TBuiltin(BTTypeRep)
def TTypeRepGen(kind: Kind) = new ParametricType1(TTypeRepGeneric(kind))
val TBigNumeric = TBuiltin(BTBigNumeric)
val TRoundingMode = TBuiltin(BTRoundingMode)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ private[daml] class AstRewriter(
TForall(binder, apply(body))
case TStruct(fields) =>
TStruct(fields.mapValues(apply))
case TTypeRepGeneric(kind) =>
TTypeRepGeneric(kind)
}

def apply(nameWithType: (Name, Type)): (Name, Type) = nameWithType match {
Expand All @@ -60,9 +62,11 @@ private[daml] class AstRewriter(
exprRule(x)
else
x match {
case EVar(_) | EBuiltin(_) | EPrimCon(_) | EPrimLit(_) | ETypeRep(_) |
EExperimental(_, _) =>
case EVar(_) | EBuiltin(_) | EPrimCon(_) | EPrimLit(_) | EExperimental(_, _) |
ETypeRepGenericApp(_, _) =>
x
case ETypeRep(ty) => ETypeRep(apply(ty))
case ETypeRepGeneric(kind, ty) => ETypeRepGeneric(kind, apply(ty))
case EVal(ref) =>
EVal(apply(ref))
case ELocation(loc, expr) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ private[parser] class ExprParser[P](parserParameters: ParserParameters[P]) {
eFromAny |
eToAnyException |
eFromAnyException |
eToTextTypeConName |
eTypeRep |
eTypeRepGeneric |
eTypeRepGenericApp |
eThrow |
(id ^? builtinFunctions) ^^ EBuiltin |
caseOf |
Expand Down Expand Up @@ -217,9 +219,19 @@ private[parser] class ExprParser[P](parserParameters: ParserParameters[P]) {
EThrow(retType, excepType, exception)
}

private lazy val eToTextTypeConName: Parser[Expr] =
private lazy val eTypeRep: Parser[Expr] =
`type_rep` ~>! argTyp ^^ ETypeRep

private lazy val eTypeRepGeneric: Parser[Expr] =
`type_rep_generic` ~>! argKind ~ argTyp ^^ { case kind ~ typ =>
ETypeRepGeneric(kind, typ)
}

private lazy val eTypeRepGenericApp: Parser[Expr] =
`type_rep_generic_app` ~>! argKind ~ argKind ^^ { case k1 ~ k2 =>
ETypeRepGenericApp(k1, k2)
}

private lazy val pattern: Parser[CasePat] =
primCon ^^ CPPrimCon |
(`nil` ^^^ CPNil) |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ private[parser] object Lexer extends RegexParsers {
"to_any" -> `to_any`,
"from_any" -> `from_any`,
"type_rep" -> `type_rep`,
"type_rep_generic" -> `type_rep_generic`,
"type_rep_generic_app" -> `type_rep_generic_app`,
"loc" -> `loc`,
"to_any_exception" -> `to_any_exception`,
"from_any_exception" -> `from_any_exception`,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ private[parser] object Token {
case object `to_any` extends Token
case object `from_any` extends Token
case object `type_rep` extends Token
case object `type_rep_generic` extends Token
case object `type_rep_generic_app` extends Token
case object `loc` extends Token
case object `to_any_exception` extends Token
case object `from_any_exception` extends Token
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,18 @@ private[parser] class TypeParser[P](parameters: ParserParameters[P]) {
private lazy val tTypeSynApp: Parser[Type] =
`|` ~> fullIdentifier ~ rep(typ0) <~ `|` ^^ { case id ~ tys => TSynApp(id, ImmArray(tys)) }

private lazy val tTypeRepGeneric: Parser[Type] =
Id("TypeRepGeneric") ~>! (`[` ~> KindParser.kind <~ `]`) ^^ { k =>
TTypeRepGeneric(k)
}

lazy val typ0: Parser[Type] =
`(` ~> typ <~ `)` |
tNat |
tForall |
tStruct |
tTypeSynApp |
tTypeRepGeneric |
(id ^? builtinTypes) ^^ TBuiltin |
fullIdentifier ^^ TTyCon.apply |
id ^^ TVar.apply
Expand All @@ -80,5 +86,6 @@ private[parser] class TypeParser[P](parameters: ParserParameters[P]) {
lazy val typ: Parser[Type] = rep1sep(typ1, `->`) ^^ (_.reduceRight(TFun))

private[parser] lazy val argTyp: Parser[Type] = `@` ~> typ0
private[parser] lazy val argKind: Parser[Kind] = `@` ~> KindParser.kind0

}
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class ParsersSpec extends AnyWordSpec with ScalaCheckPropertyChecks with Matcher
"forall (a: *). Mod:T a" -> TForall((α.name, KStar), TApp(T, α)),
"<f1: a, f2: Bool, f3:Mod:T>" ->
TStruct(Struct.assertFromSeq(List(n"f1" -> α, n"f2" -> TBuiltin(BTBool), n"f3" -> T))),
"TypeRepGeneric[* -> *]" -> TTypeRepGeneric(KArrow(KStar, KStar)),
)

forEvery(testCases)((stringToParse, expectedType) =>
Expand Down Expand Up @@ -331,6 +332,12 @@ class ParsersSpec extends AnyWordSpec with ScalaCheckPropertyChecks with Matcher
EFromAnyException(E, e"anyException"),
"throw @Unit @Mod:E exception" ->
EThrow(TUnit, E, e"exception"),
"type_rep_generic @(* -> *) @Option" -> ETypeRepGeneric(
KArrow(KStar, KStar),
TBuiltin(BTOptional),
),
"type_rep_generic @* @Unit" -> ETypeRepGeneric(KStar, TUnit),
"type_rep_generic_app @* @(* -> *)" -> ETypeRepGenericApp(KStar, KArrow(KStar, KStar)),
)

forEvery(testCases)((stringToParse, expectedExp) =>
Expand Down Expand Up @@ -729,6 +736,8 @@ class ParsersSpec extends AnyWordSpec with ScalaCheckPropertyChecks with Matcher
"to_any",
"from_any",
"type_rep",
"type_rep_generic",
"type_rep_generic_app",
"loc",
"to_any_exception",
"from_any_exception",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ object Repl {
.map { case (n, t) => n + ": " + prettyType(t, precTForall) }
.toSeq
.mkString(", ") + ")"
case TTypeRepGeneric(_) => "typerepgeneric"
}

def prettyForAll(t: Type): String = t match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ private[validation] object AlphaEquiv {
case (TStruct(fs1), TStruct(fs2)) =>
(fs1.names sameElements fs2.names) &&
(fs1.values zip fs2.values).forall((alphaEquiv _).tupled)
case (TTypeRepGeneric(k1), TTypeRepGeneric(k2)) =>
k1 == k2
case _ => false
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ private[validation] object Serializability {
case BTBigNumeric =>
unserializable(URBigNumeric)
}
case TTypeRepGeneric(_) => unserializable(URTypeRepGeneric)
case TForall(_, _) =>
unserializable(URForall)
case TStruct(_) =>
Expand Down
Loading