From 7189f4692d9d2ccc304eef7ab988d424002d2e9e Mon Sep 17 00:00:00 2001 From: Guillaume Martres Date: Thu, 3 Oct 2024 14:25:51 +0200 Subject: [PATCH 1/4] Support refinement type on enum case This used to fail with: trait in value x extends enum EC, but extending enums is prohibited. --- compiler/src/dotty/tools/dotc/typer/Typer.scala | 2 +- tests/pos/enum-refinement.scala | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) create mode 100644 tests/pos/enum-refinement.scala diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 8ba63dfc1e67..1229dc9d5d31 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -3196,7 +3196,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer .withType(dummy.termRef) if (!cls.isOneOf(AbstractOrTrait) && !ctx.isAfterTyper) checkRealizableBounds(cls, cdef.sourcePos.withSpan(cdef.nameSpan)) - if cls.isEnum || firstParentTpe.classSymbol.isEnum then + if cls.isEnum || !cls.isRefinementClass && firstParentTpe.classSymbol.isEnum then checkEnum(cdef, cls, firstParent) val cdef1 = assignType(cpy.TypeDef(cdef)(name, impl1), cls) diff --git a/tests/pos/enum-refinement.scala b/tests/pos/enum-refinement.scala new file mode 100644 index 000000000000..e357125489cd --- /dev/null +++ b/tests/pos/enum-refinement.scala @@ -0,0 +1,12 @@ +enum Enum: + case EC(val x: Int) + +val a: Enum.EC { val x: 1 } = Enum.EC(1).asInstanceOf[Enum.EC { val x: 1 }] + +import scala.language.experimental.modularity + +enum EnumT: + case EC(tracked val x: Int) + +val b: EnumT.EC { val x: 1 } = EnumT.EC(1) + From 4aa59eb5b942261be999e7ea7972511de7d161b1 Mon Sep 17 00:00:00 2001 From: Guillaume Martres Date: Mon, 30 Sep 2024 14:25:03 +0200 Subject: [PATCH 2/4] Minimal support for dependent case classes This lets us write: trait A: type B case class CC(a: A, b: a.B) Pattern matching works but isn't dependent yet: x match case CC(a, b) => val a1: A = a // Dependent pattern matching is not currently supported // val b1: a1.B = b val b1 = b // Type is CC#a.B (for my usecase this isn't a problem, I'm working on a type constraint API which lets me write things like `case class CC(a: Int, b: Int GreaterThan[a.type])`) Because case class pattern matching relies on the product selectors `_N`, making it dependent is a bit tricky, currently we generate: case class CC(a: A, b: a.B): def _1: A = a def _2: a.B = b So the type of `_2` is not obviously related to the type of `_1`, we probably need to change what we generate into: case class CC(a: A, b: a.B): @uncheckedStable def _1: a.type = a def _2: _1.B = b But this can be done in a separate PR. Fixes #8073. --- .../dotc/transform/SyntheticMembers.scala | 73 +++++++++------- .../src/dotty/tools/dotc/typer/Namer.scala | 14 +-- tests/neg/i8069.scala | 8 -- tests/run-macros/tasty-extractors-2.check | 2 +- tests/run/i8073.scala | 86 +++++++++++++++++++ tests/run/i8073b.scala | 86 +++++++++++++++++++ 6 files changed, 216 insertions(+), 53 deletions(-) delete mode 100644 tests/neg/i8069.scala create mode 100644 tests/run/i8073.scala create mode 100644 tests/run/i8073b.scala diff --git a/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala b/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala index 376e43b3982d..33fb2de8afe4 100644 --- a/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala +++ b/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala @@ -504,53 +504,64 @@ class SyntheticMembers(thisPhase: DenotTransformer) { /** The class * * ``` - * case class C[T <: U](x: T, y: String*) + * trait U: + * type Elem + * + * case class C[T <: U](a: T, b: a.Elem, c: String*) * ``` * * gets the `fromProduct` method: * * ``` * def fromProduct(x$0: Product): MirroredMonoType = - * new C[U]( - * x$0.productElement(0).asInstanceOf[U], - * x$0.productElement(1).asInstanceOf[Seq[String]]: _*) + * val a$1 = x$0.productElement(0).asInstanceOf[U] + * val b$1 = x$0.productElement(1).asInstanceOf[a$1.Elem] + * val c$1 = x$0.productElement(2).asInstanceOf[Seq[String]] + * new C[U](a$1, b$1, c$1*) * ``` * where * ``` * type MirroredMonoType = C[?] * ``` */ - def fromProductBody(caseClass: Symbol, param: Tree, optInfo: Option[MirrorImpl.OfProduct])(using Context): Tree = - def extractParams(tpe: Type): List[Type] = - tpe.asInstanceOf[MethodType].paramInfos - - def computeFromCaseClass: (Type, List[Type]) = - val (baseRef, baseInfo) = - val rawRef = caseClass.typeRef - val rawInfo = caseClass.primaryConstructor.info - optInfo match - case Some(info) => - (rawRef.asSeenFrom(info.pre, caseClass.owner), rawInfo.asSeenFrom(info.pre, caseClass.owner)) - case _ => - (rawRef, rawInfo) - baseInfo match + def fromProductBody(caseClass: Symbol, productParam: Tree, optInfo: Option[MirrorImpl.OfProduct])(using Context): Tree = + val classRef = optInfo match + case Some(info) => TypeRef(info.pre, caseClass) + case _ => caseClass.typeRef + val (newPrefix, constrMeth) = + val constr = TermRef(classRef, caseClass.primaryConstructor) + (constr.info: @unchecked) match case tl: PolyType => val tvars = constrained(tl) val targs = for tvar <- tvars yield tvar.instantiate(fromBelow = false) - (baseRef.appliedTo(targs), extractParams(tl.instantiate(targs))) - case methTpe => - (baseRef, extractParams(methTpe)) - end computeFromCaseClass - - val (classRefApplied, paramInfos) = computeFromCaseClass - val elems = - for ((formal, idx) <- paramInfos.zipWithIndex) yield - val elem = - param.select(defn.Product_productElement).appliedTo(Literal(Constant(idx))) - .ensureConforms(formal.translateFromRepeated(toArray = false)) - if (formal.isRepeatedParam) ctx.typer.seqToRepeated(elem) else elem - New(classRefApplied, elems) + (AppliedType(classRef, targs), tl.instantiate(targs).asInstanceOf[MethodType]) + case mt: MethodType => + (classRef, mt) + + // Create symbols for the vals corresponding to each parameter + // If there are dependent parameters, the infos won't be correct yet. + val bindingSyms = constrMeth.paramRefs.map: pref => + newSymbol(ctx.owner, pref.paramName.freshened, Synthetic, + pref.underlying.translateFromRepeated(toArray = false), coord = ctx.owner.span.focus) + val bindingRefs = bindingSyms.map(TermRef(NoPrefix, _)) + // Fix the infos for dependent parameters + if constrMeth.isParamDependent then + bindingSyms.foreach: bindingSym => + bindingSym.info = bindingSym.info.substParams(constrMeth, bindingRefs) + + val bindingDefs = bindingSyms.zipWithIndex.map: (bindingSym, idx) => + ValDef(bindingSym, + productParam.select(defn.Product_productElement).appliedTo(Literal(Constant(idx))) + .ensureConforms(bindingSym.info)) + + val newArgs = bindingRefs.lazyZip(constrMeth.paramInfos).map: (bindingRef, paramInfo) => + val refTree = ref(bindingRef) + if paramInfo.isRepeatedParam then ctx.typer.seqToRepeated(refTree) else refTree + Block( + bindingDefs, + New(newPrefix, newArgs) + ) end fromProductBody /** For an enum T: diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index bc4e1a332ff6..197604c3aca3 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -1957,9 +1957,7 @@ class Namer { typer: Typer => if isConstructor then // set result type tree to unit, but take the current class as result type of the symbol typedAheadType(ddef.tpt, defn.UnitType) - val mt = wrapMethType(effectiveResultType(sym, paramSymss)) - if sym.isPrimaryConstructor then checkCaseClassParamDependencies(mt, sym.owner) - mt + wrapMethType(effectiveResultType(sym, paramSymss)) else val paramFn = if Feature.enabled(Feature.modularity) && sym.isAllOf(Given | Method) then wrapRefinedMethType else wrapMethType valOrDefDefSig(ddef, sym, paramSymss, paramFn) @@ -2001,16 +1999,6 @@ class Namer { typer: Typer => ddef.trailingParamss.foreach(completeParams) end completeTrailingParamss - /** Checks an implementation restriction on case classes. */ - def checkCaseClassParamDependencies(mt: Type, cls: Symbol)(using Context): Unit = - mt.stripPoly match - case mt: MethodType if cls.is(Case) && mt.isParamDependent => - // See issue #8073 for background - report.error( - em"""Implementation restriction: case classes cannot have dependencies between parameters""", - cls.srcPos) - case _ => - private def setParamTrackedWithAccessors(psym: Symbol, ownerTpe: Type)(using Context): Unit = for acc <- ownerTpe.decls.lookupAll(psym.name) if acc.is(ParamAccessor) do acc.resetFlag(PrivateLocal) diff --git a/tests/neg/i8069.scala b/tests/neg/i8069.scala deleted file mode 100644 index 50f8b7a3480e..000000000000 --- a/tests/neg/i8069.scala +++ /dev/null @@ -1,8 +0,0 @@ -trait A: - type B - -enum Test: - case Test(a: A, b: a.B) // error: Implementation restriction: case classes cannot have dependencies between parameters - -case class Test2(a: A, b: a.B) // error: Implementation restriction: case classes cannot have dependencies between parameters - diff --git a/tests/run-macros/tasty-extractors-2.check b/tests/run-macros/tasty-extractors-2.check index 5dd6af8d8b04..15d844670b7a 100644 --- a/tests/run-macros/tasty-extractors-2.check +++ b/tests/run-macros/tasty-extractors-2.check @@ -49,7 +49,7 @@ TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "Unit") Inlined(None, Nil, Block(List(ClassDef("Foo", DefDef("", List(TermParamClause(Nil)), Inferred(), None), List(Apply(Select(New(Inferred()), ""), Nil)), None, List(DefDef("a", Nil, Inferred(), Some(Literal(IntConstant(0))))))), Literal(UnitConstant()))) TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "Unit") -Inlined(None, Nil, Block(List(ClassDef("Foo", DefDef("", List(TermParamClause(Nil)), Inferred(), None), List(Apply(Select(New(Inferred()), ""), Nil), TypeSelect(Select(Ident("_root_"), "scala"), "Product"), TypeSelect(Select(Ident("_root_"), "scala"), "Serializable")), None, List(DefDef("hashCode", List(TermParamClause(Nil)), Inferred(), Some(Apply(Ident("_hashCode"), List(This(Some("Foo")))))), DefDef("equals", List(TermParamClause(List(ValDef("x$0", Inferred(), None)))), Inferred(), Some(Apply(Select(Apply(Select(This(Some("Foo")), "eq"), List(TypeApply(Select(Ident("x$0"), "$asInstanceOf$"), List(Inferred())))), "||"), List(Match(Ident("x$0"), List(CaseDef(Bind("x$0", Typed(Wildcard(), Inferred())), None, Apply(Select(Literal(BooleanConstant(true)), "&&"), List(Apply(Select(Ident("x$0"), "canEqual"), List(This(Some("Foo"))))))), CaseDef(Wildcard(), None, Literal(BooleanConstant(false))))))))), DefDef("toString", List(TermParamClause(Nil)), Inferred(), Some(Apply(Ident("_toString"), List(This(Some("Foo")))))), DefDef("canEqual", List(TermParamClause(List(ValDef("that", Inferred(), None)))), Inferred(), Some(TypeApply(Select(Ident("that"), "isInstanceOf"), List(Inferred())))), DefDef("productArity", Nil, Inferred(), Some(Literal(IntConstant(0)))), DefDef("productPrefix", Nil, Inferred(), Some(Literal(StringConstant("Foo")))), DefDef("productElement", List(TermParamClause(List(ValDef("n", Inferred(), None)))), Inferred(), Some(Match(Ident("n"), List(CaseDef(Wildcard(), None, Apply(Ident("throw"), List(Apply(Select(New(Inferred()), ""), List(Apply(Select(Ident("n"), "toString"), Nil)))))))))), DefDef("productElementName", List(TermParamClause(List(ValDef("n", Inferred(), None)))), Inferred(), Some(Match(Ident("n"), List(CaseDef(Wildcard(), None, Apply(Ident("throw"), List(Apply(Select(New(Inferred()), ""), List(Apply(Select(Ident("n"), "toString"), Nil)))))))))), DefDef("copy", List(TermParamClause(Nil)), Inferred(), Some(Apply(Select(New(Inferred()), ""), Nil))))), ValDef("Foo", TypeIdent("Foo$"), Some(Apply(Select(New(TypeIdent("Foo$")), ""), Nil))), ClassDef("Foo$", DefDef("", List(TermParamClause(Nil)), Inferred(), None), List(Apply(Select(New(Inferred()), ""), Nil), Inferred()), Some(ValDef("_", Singleton(Ident("Foo")), None)), List(DefDef("apply", List(TermParamClause(Nil)), Inferred(), Some(Apply(Select(New(Inferred()), ""), Nil))), DefDef("unapply", List(TermParamClause(List(ValDef("x$1", Inferred(), None)))), Singleton(Literal(BooleanConstant(true))), Some(Literal(BooleanConstant(true)))), DefDef("toString", Nil, Inferred(), Some(Literal(StringConstant("Foo")))), TypeDef("MirroredMonoType", TypeBoundsTree(Inferred(), Inferred())), DefDef("fromProduct", List(TermParamClause(List(ValDef("x$0", Inferred(), None)))), Inferred(), Some(Apply(Select(New(Inferred()), ""), Nil)))))), Literal(UnitConstant()))) +Inlined(None, Nil, Block(List(ClassDef("Foo", DefDef("", List(TermParamClause(Nil)), Inferred(), None), List(Apply(Select(New(Inferred()), ""), Nil), TypeSelect(Select(Ident("_root_"), "scala"), "Product"), TypeSelect(Select(Ident("_root_"), "scala"), "Serializable")), None, List(DefDef("hashCode", List(TermParamClause(Nil)), Inferred(), Some(Apply(Ident("_hashCode"), List(This(Some("Foo")))))), DefDef("equals", List(TermParamClause(List(ValDef("x$0", Inferred(), None)))), Inferred(), Some(Apply(Select(Apply(Select(This(Some("Foo")), "eq"), List(TypeApply(Select(Ident("x$0"), "$asInstanceOf$"), List(Inferred())))), "||"), List(Match(Ident("x$0"), List(CaseDef(Bind("x$0", Typed(Wildcard(), Inferred())), None, Apply(Select(Literal(BooleanConstant(true)), "&&"), List(Apply(Select(Ident("x$0"), "canEqual"), List(This(Some("Foo"))))))), CaseDef(Wildcard(), None, Literal(BooleanConstant(false))))))))), DefDef("toString", List(TermParamClause(Nil)), Inferred(), Some(Apply(Ident("_toString"), List(This(Some("Foo")))))), DefDef("canEqual", List(TermParamClause(List(ValDef("that", Inferred(), None)))), Inferred(), Some(TypeApply(Select(Ident("that"), "isInstanceOf"), List(Inferred())))), DefDef("productArity", Nil, Inferred(), Some(Literal(IntConstant(0)))), DefDef("productPrefix", Nil, Inferred(), Some(Literal(StringConstant("Foo")))), DefDef("productElement", List(TermParamClause(List(ValDef("n", Inferred(), None)))), Inferred(), Some(Match(Ident("n"), List(CaseDef(Wildcard(), None, Apply(Ident("throw"), List(Apply(Select(New(Inferred()), ""), List(Apply(Select(Ident("n"), "toString"), Nil)))))))))), DefDef("productElementName", List(TermParamClause(List(ValDef("n", Inferred(), None)))), Inferred(), Some(Match(Ident("n"), List(CaseDef(Wildcard(), None, Apply(Ident("throw"), List(Apply(Select(New(Inferred()), ""), List(Apply(Select(Ident("n"), "toString"), Nil)))))))))), DefDef("copy", List(TermParamClause(Nil)), Inferred(), Some(Apply(Select(New(Inferred()), ""), Nil))))), ValDef("Foo", TypeIdent("Foo$"), Some(Apply(Select(New(TypeIdent("Foo$")), ""), Nil))), ClassDef("Foo$", DefDef("", List(TermParamClause(Nil)), Inferred(), None), List(Apply(Select(New(Inferred()), ""), Nil), Inferred()), Some(ValDef("_", Singleton(Ident("Foo")), None)), List(DefDef("apply", List(TermParamClause(Nil)), Inferred(), Some(Apply(Select(New(Inferred()), ""), Nil))), DefDef("unapply", List(TermParamClause(List(ValDef("x$1", Inferred(), None)))), Singleton(Literal(BooleanConstant(true))), Some(Literal(BooleanConstant(true)))), DefDef("toString", Nil, Inferred(), Some(Literal(StringConstant("Foo")))), TypeDef("MirroredMonoType", TypeBoundsTree(Inferred(), Inferred())), DefDef("fromProduct", List(TermParamClause(List(ValDef("x$0", Inferred(), None)))), Inferred(), Some(Block(Nil, Apply(Select(New(Inferred()), ""), Nil))))))), Literal(UnitConstant()))) TypeRef(ThisType(TypeRef(NoPrefix(), "scala")), "Unit") Inlined(None, Nil, Block(List(ClassDef("Foo1", DefDef("", List(TermParamClause(List(ValDef("a", TypeIdent("Int"), None)))), Inferred(), None), List(Apply(Select(New(Inferred()), ""), Nil)), None, List(ValDef("a", Inferred(), None)))), Literal(UnitConstant()))) diff --git a/tests/run/i8073.scala b/tests/run/i8073.scala new file mode 100644 index 000000000000..6b5bfc3b9832 --- /dev/null +++ b/tests/run/i8073.scala @@ -0,0 +1,86 @@ +import scala.deriving.Mirror + +trait A: + type B + +object Test: + case class CC(a: A, b: a.B) + + def test1(): Unit = + val generic = summon[Mirror.Of[CC]] + // No language syntax for type projection of a singleton type + // summon[generic.MirroredElemTypes =:= (A, CC#a.B)] + + val aa: A { type B = Int } = new A { type B = Int } + val x: CC { val a: aa.type } = CC(aa, 1).asInstanceOf[CC { val a: aa.type }] // manual `tracked` + + val dependent = summon[Mirror.Of[x.type]] + summon[dependent.MirroredElemTypes =:= (A, x.a.B)] + + assert(CC(aa, 1) == generic.fromProduct((aa, 1))) + assert(CC(aa, 1) == dependent.fromProduct((aa, 1))) + + x match + case CC(a, b) => + val a1: A = a + // Dependent pattern matching is not currently supported + // val b1: a1.B = b + val b1 = b // Type is CC#a.B + + end test1 + + case class CCPoly[T <: A](a: T, b: a.B) + + def test2(): Unit = + val generic = summon[Mirror.Of[CCPoly[A]]] + // No language syntax for type projection of a singleton type + // summon[generic.MirroredElemTypes =:= (A, CCPoly[A]#a.B)] + + val aa: A { type B = Int } = new A { type B = Int } + val x: CCPoly[aa.type] = CCPoly(aa, 1) + + val dependent = summon[Mirror.Of[x.type]] + summon[dependent.MirroredElemTypes =:= (aa.type, x.a.B)] + + assert(CCPoly[A](aa, 1) == generic.fromProduct((aa, 1))) + assert(CCPoly[A](aa, 1) == dependent.fromProduct((aa, 1))) + + x match + case CCPoly(a, b) => + val a1: A = a + // Dependent pattern matching is not currently supported + // val b1: a1.B = b + val b1 = b // Type is CC#a.B + + end test2 + + enum Enum: + case EC(a: A, b: a.B) + + def test3(): Unit = + val generic = summon[Mirror.Of[Enum.EC]] + // No language syntax for type projection of a singleton type + // summon[generic.MirroredElemTypes =:= (A, Enum.EC#a.B)] + + val aa: A { type B = Int } = new A { type B = Int } + val x: Enum.EC { val a: aa.type } = Enum.EC(aa, 1).asInstanceOf[Enum.EC { val a: aa.type }] // manual `tracked` + + val dependent = summon[Mirror.Of[x.type]] + summon[dependent.MirroredElemTypes =:= (A, x.a.B)] + + assert(Enum.EC(aa, 1) == generic.fromProduct((aa, 1))) + assert(Enum.EC(aa, 1) == dependent.fromProduct((aa, 1))) + + x match + case Enum.EC(a, b) => + val a1: A = a + // Dependent pattern matching is not currently supported + // val b1: a1.B = b + val b1 = b // Type is Enum.EC#a.B + + end test3 + + def main(args: Array[String]): Unit = + test1() + test2() + test3() diff --git a/tests/run/i8073b.scala b/tests/run/i8073b.scala new file mode 100644 index 000000000000..cc85731d01df --- /dev/null +++ b/tests/run/i8073b.scala @@ -0,0 +1,86 @@ +import scala.deriving.Mirror + +trait A: + type B + +// Test local mirrors +@main def Test = + case class CC(a: A, b: a.B) + + def test1(): Unit = + val generic = summon[Mirror.Of[CC]] + // No language syntax for type projection of a singleton type + // summon[generic.MirroredElemTypes =:= (A, CC#a.B)] + + val aa: A { type B = Int } = new A { type B = Int } + val x: CC { val a: aa.type } = CC(aa, 1).asInstanceOf[CC { val a: aa.type }] // manual `tracked` + + val dependent = summon[Mirror.Of[x.type]] + summon[dependent.MirroredElemTypes =:= (A, x.a.B)] + + assert(CC(aa, 1) == generic.fromProduct((aa, 1))) + assert(CC(aa, 1) == dependent.fromProduct((aa, 1))) + + x match + case CC(a, b) => + val a1: A = a + // Dependent pattern matching is not currently supported + // val b1: a1.B = b + val b1 = b // Type is CC#a.B + + end test1 + + case class CCPoly[T <: A](a: T, b: a.B) + + def test2(): Unit = + val generic = summon[Mirror.Of[CCPoly[A]]] + // No language syntax for type projection of a singleton type + // summon[generic.MirroredElemTypes =:= (A, CCPoly[A]#a.B)] + + val aa: A { type B = Int } = new A { type B = Int } + val x: CCPoly[aa.type] = CCPoly(aa, 1) + + val dependent = summon[Mirror.Of[x.type]] + summon[dependent.MirroredElemTypes =:= (aa.type, x.a.B)] + + assert(CCPoly[A](aa, 1) == generic.fromProduct((aa, 1))) + assert(CCPoly[A](aa, 1) == dependent.fromProduct((aa, 1))) + + x match + case CCPoly(a, b) => + val a1: A = a + // Dependent pattern matching is not currently supported + // val b1: a1.B = b + val b1 = b // Type is CC#a.B + + end test2 + + enum Enum: + case EC(a: A, b: a.B) + + def test3(): Unit = + val generic = summon[Mirror.Of[Enum.EC]] + // No language syntax for type projection of a singleton type + // summon[generic.MirroredElemTypes =:= (A, Enum.EC#a.B)] + + val aa: A { type B = Int } = new A { type B = Int } + val x: Enum.EC { val a: aa.type } = Enum.EC(aa, 1).asInstanceOf[Enum.EC { val a: aa.type }] // manual `tracked` + + val dependent = summon[Mirror.Of[x.type]] + summon[dependent.MirroredElemTypes =:= (A, x.a.B)] + + assert(Enum.EC(aa, 1) == generic.fromProduct((aa, 1))) + assert(Enum.EC(aa, 1) == dependent.fromProduct((aa, 1))) + + x match + case Enum.EC(a, b) => + val a1: A = a + // Dependent pattern matching is not currently supported + // val b1: a1.B = b + val b1 = b // Type is Enum.EC#a.B + + end test3 + + test1() + test2() + test3() From 18bd314dfb16f846875f9a5ad52671e08ac41ca9 Mon Sep 17 00:00:00 2001 From: Guillaume Martres Date: Tue, 18 Feb 2025 20:48:31 +0100 Subject: [PATCH 3/4] Fix fragile transformation of fromProduct when using `@unroll` UnrollDefinitions assumed that the body of `fromProduct` had a specific shape which is no longer the case with the dependent case class support introduced in the previous commit. This caused compiler crashes for tests/run/unroll-multiple.scala and tests/run/unroll-caseclass-integration This commit fixes this by directly generating the correct fromProduct in SyntheticMembers. This should also prevent crashes in situations where code is injected into existing trees like the code coverage support or external compiler plugins. --- .../dotty/tools/dotc/core/Definitions.scala | 1 + .../src/dotty/tools/dotc/core/StdNames.scala | 1 + .../dotc/transform/SyntheticMembers.scala | 53 ++++++++++-- .../dotc/transform/UnrollDefinitions.scala | 86 ++++--------------- 4 files changed, 67 insertions(+), 74 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 6e2e924edf65..49b8fbdd3f15 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -618,6 +618,7 @@ class Definitions { @tu lazy val Int_== : Symbol = IntClass.requiredMethod(nme.EQ, List(IntType)) @tu lazy val Int_>= : Symbol = IntClass.requiredMethod(nme.GE, List(IntType)) @tu lazy val Int_<= : Symbol = IntClass.requiredMethod(nme.LE, List(IntType)) + @tu lazy val Int_> : Symbol = IntClass.requiredMethod(nme.GT, List(IntType)) @tu lazy val LongType: TypeRef = valueTypeRef("scala.Long", java.lang.Long.TYPE, LongEnc, nme.specializedTypeNames.Long) def LongClass(using Context): ClassSymbol = LongType.symbol.asClass @tu lazy val Long_+ : Symbol = LongClass.requiredMethod(nme.PLUS, List(LongType)) diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 56d71c7fb57e..90e5544f19af 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -425,6 +425,7 @@ object StdNames { val array_length : N = "array_length" val array_update : N = "array_update" val arraycopy: N = "arraycopy" + val arity: N = "arity" val as: N = "as" val asTerm: N = "asTerm" val asModule: N = "asModule" diff --git a/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala b/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala index 33fb2de8afe4..926a19224e79 100644 --- a/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala +++ b/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala @@ -523,21 +523,47 @@ class SyntheticMembers(thisPhase: DenotTransformer) { * ``` * type MirroredMonoType = C[?] * ``` + * + * However, if the last parameter is annotated `@unroll` then we generate: + * + * def fromProduct(x$0: Product): MirroredMonoType = + * val arity = x$0.productArity + * val a$1 = x$0.productElement(0).asInstanceOf[U] + * val b$1 = x$0.productElement(1).asInstanceOf[a$1.Elem] + * val c$1 = ( + * if arity > 2 then + * x$0.productElement(2) + * else + * + * ).asInstanceOf[Seq[String]] + * new C[U](a$1, b$1, c$1*) */ def fromProductBody(caseClass: Symbol, productParam: Tree, optInfo: Option[MirrorImpl.OfProduct])(using Context): Tree = val classRef = optInfo match case Some(info) => TypeRef(info.pre, caseClass) case _ => caseClass.typeRef - val (newPrefix, constrMeth) = + val (newPrefix, constrMeth, constrSyms) = val constr = TermRef(classRef, caseClass.primaryConstructor) + val symss = caseClass.primaryConstructor.paramSymss (constr.info: @unchecked) match case tl: PolyType => val tvars = constrained(tl) val targs = for tvar <- tvars yield tvar.instantiate(fromBelow = false) - (AppliedType(classRef, targs), tl.instantiate(targs).asInstanceOf[MethodType]) + (AppliedType(classRef, targs), tl.instantiate(targs).asInstanceOf[MethodType], symss(1)) case mt: MethodType => - (classRef, mt) + (classRef, mt, symss.head) + + // Index of the first parameter marked `@unroll` or -1 + val unrolledFrom = + constrSyms.indexWhere(_.hasAnnotation(defn.UnrollAnnot)) + + // `val arity = x$0.productArity` + val arityDef: Option[ValDef] = + if unrolledFrom != -1 then + Some(SyntheticValDef(nme.arity, productParam.select(defn.Product_productArity).withSpan(ctx.owner.span.focus))) + else None + val arityRefTree = arityDef.map(vd => ref(vd.symbol)) // Create symbols for the vals corresponding to each parameter // If there are dependent parameters, the infos won't be correct yet. @@ -550,16 +576,29 @@ class SyntheticMembers(thisPhase: DenotTransformer) { bindingSyms.foreach: bindingSym => bindingSym.info = bindingSym.info.substParams(constrMeth, bindingRefs) + def defaultGetterAtIndex(idx: Int): Tree = + val defaultGetterPrefix = caseClass.primaryConstructor.name.toTermName + ref(caseClass.companionModule).select(NameKinds.DefaultGetterName(defaultGetterPrefix, idx)) + val bindingDefs = bindingSyms.zipWithIndex.map: (bindingSym, idx) => - ValDef(bindingSym, - productParam.select(defn.Product_productElement).appliedTo(Literal(Constant(idx))) - .ensureConforms(bindingSym.info)) + val selection = productParam.select(defn.Product_productElement).appliedTo(Literal(Constant(idx))) + val rhs = ( + if unrolledFrom != -1 && idx >= unrolledFrom then + If(arityRefTree.get.select(defn.Int_>).appliedTo(Literal(Constant(idx))), + thenp = + selection, + elsep = + defaultGetterAtIndex(idx)) + else + selection + ).ensureConforms(bindingSym.info) + ValDef(bindingSym, rhs) val newArgs = bindingRefs.lazyZip(constrMeth.paramInfos).map: (bindingRef, paramInfo) => val refTree = ref(bindingRef) if paramInfo.isRepeatedParam then ctx.typer.seqToRepeated(refTree) else refTree Block( - bindingDefs, + arityDef.toList ::: bindingDefs, New(newPrefix, newArgs) ) end fromProductBody diff --git a/compiler/src/dotty/tools/dotc/transform/UnrollDefinitions.scala b/compiler/src/dotty/tools/dotc/transform/UnrollDefinitions.scala index b431a81afeac..44379b88bf16 100644 --- a/compiler/src/dotty/tools/dotc/transform/UnrollDefinitions.scala +++ b/compiler/src/dotty/tools/dotc/transform/UnrollDefinitions.scala @@ -228,46 +228,9 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer { forwarderDef } - private def generateFromProduct(startParamIndices: List[Int], paramCount: Int, defdef: DefDef)(using Context) = { - cpy.DefDef(defdef)( - name = defdef.name, - paramss = defdef.paramss, - tpt = defdef.tpt, - rhs = Match( - ref(defdef.paramss.head.head.asInstanceOf[ValDef].symbol).select(termName("productArity")), - startParamIndices.map { paramIndex => - val Apply(select, args) = defdef.rhs: @unchecked - CaseDef( - Literal(Constant(paramIndex)), - EmptyTree, - Apply( - select, - args.take(paramIndex) ++ - Range(paramIndex, paramCount).map(n => - ref(defdef.symbol.owner.companionModule) - .select(DefaultGetterName(defdef.symbol.owner.primaryConstructor.name.toTermName, n)) - ) - ) - ) - } :+ CaseDef( - Underscore(defn.IntType), - EmptyTree, - defdef.rhs - ) - ) - ).setDefTree - } - - private enum Gen: - case Substitute(origin: Symbol, newDef: DefDef) - case Forwarders(origin: Symbol, forwarders: List[DefDef]) + case class Forwarders(origin: Symbol, forwarders: List[DefDef]) - def origin: Symbol - def extras: List[DefDef] = this match - case Substitute(_, d) => d :: Nil - case Forwarders(_, ds) => ds - - private def generateSyntheticDefs(tree: Tree, compute: ComputeIndices)(using Context): Option[Gen] = tree match { + private def generateSyntheticDefs(tree: Tree, compute: ComputeIndices)(using Context): Option[Forwarders] = tree match { case defdef: DefDef if defdef.paramss.nonEmpty => import dotty.tools.dotc.core.NameOps.isConstructorName @@ -277,38 +240,29 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer { val isCaseApply = defdef.name == nme.apply && defdef.symbol.owner.companionClass.is(CaseClass) - val isCaseFromProduct = defdef.name == nme.fromProduct && defdef.symbol.owner.companionClass.is(CaseClass) - val annotated = if (isCaseCopy) defdef.symbol.owner.primaryConstructor else if (isCaseApply) defdef.symbol.owner.companionClass.primaryConstructor - else if (isCaseFromProduct) defdef.symbol.owner.companionClass.primaryConstructor else defdef.symbol compute(annotated) match { case Nil => None case (paramClauseIndex, annotationIndices) :: Nil => val paramCount = annotated.paramSymss(paramClauseIndex).size - if isCaseFromProduct then - Some(Gen.Substitute( - origin = defdef.symbol, - newDef = generateFromProduct(annotationIndices, paramCount, defdef) - )) - else - val generatedDefs = - val indices = (annotationIndices :+ paramCount).sliding(2).toList.reverse - indices.foldLeft(List.empty[DefDef]): - case (defdefs, paramIndex :: nextParamIndex :: Nil) => - generateSingleForwarder( - defdef, - paramIndex, - paramCount, - nextParamIndex, - paramClauseIndex, - isCaseApply - ) :: defdefs - case _ => unreachable("sliding with at least 2 elements") - Some(Gen.Forwarders(origin = defdef.symbol, forwarders = generatedDefs)) + val generatedDefs = + val indices = (annotationIndices :+ paramCount).sliding(2).toList.reverse + indices.foldLeft(List.empty[DefDef]): + case (defdefs, paramIndex :: nextParamIndex :: Nil) => + generateSingleForwarder( + defdef, + paramIndex, + paramCount, + nextParamIndex, + paramClauseIndex, + isCaseApply + ) :: defdefs + case _ => unreachable("sliding with at least 2 elements") + Some(Forwarders(origin = defdef.symbol, forwarders = generatedDefs)) case multiple => report.error("Cannot have multiple parameter lists containing `@unroll` annotation", defdef.srcPos) @@ -323,14 +277,12 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer { val generatedBody = tmpl.body.flatMap(generateSyntheticDefs(_, compute)) val generatedConstr0 = generateSyntheticDefs(tmpl.constr, compute) val allGenerated = generatedBody ++ generatedConstr0 - val bodySubs = generatedBody.collect({ case s: Gen.Substitute => s.origin }).toSet - val otherDecls = tmpl.body.filterNot(d => d.symbol.exists && bodySubs(d.symbol)) if allGenerated.nonEmpty then - val byName = (tmpl.constr :: otherDecls).groupMap(_.symbol.name.toString)(_.symbol) + val byName = (tmpl.constr :: tmpl.body).groupMap(_.symbol.name.toString)(_.symbol) for syntheticDefs <- allGenerated - dcl <- syntheticDefs.extras + dcl <- syntheticDefs.forwarders do val replaced = dcl.symbol byName.get(dcl.name.toString).foreach { syms => @@ -348,7 +300,7 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer { tmpl.parents, tmpl.derived, tmpl.self, - otherDecls ++ allGenerated.flatMap(_.extras) + tmpl.body ++ allGenerated.flatMap(_.forwarders) ) } From 246793a354642d0c9ef87b1ca41caa51dd0d8314 Mon Sep 17 00:00:00 2001 From: Guillaume Martres Date: Wed, 19 Feb 2025 20:19:12 +0100 Subject: [PATCH 4/4] Fix Ycheck false-positive in the compiler after previous commit This accounts for singletons wrapping an ErasedValueType. --- compiler/src/dotty/tools/dotc/transform/FirstTransform.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/src/dotty/tools/dotc/transform/FirstTransform.scala b/compiler/src/dotty/tools/dotc/transform/FirstTransform.scala index c66e6b9471cb..8d01d2415340 100644 --- a/compiler/src/dotty/tools/dotc/transform/FirstTransform.scala +++ b/compiler/src/dotty/tools/dotc/transform/FirstTransform.scala @@ -62,7 +62,7 @@ class FirstTransform extends MiniPhase with SymTransformer { thisPhase => case Select(qual, name) if !name.is(OuterSelectName) && tree.symbol.exists => val qualTpe = qual.tpe assert( - qualTpe.isErasedValueType || qualTpe.derivesFrom(tree.symbol.owner) || + qualTpe.widenDealias.isErasedValueType || qualTpe.derivesFrom(tree.symbol.owner) || tree.symbol.is(JavaStatic) && qualTpe.derivesFrom(tree.symbol.enclosingClass), i"non member selection of ${tree.symbol.showLocated} from ${qualTpe} in $tree") case _: TypeTree =>