Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minimal support for dependent case classes #21698

Merged
merged 4 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,7 @@ class Definitions {
@tu lazy val Int_== : Symbol = IntClass.requiredMethod(nme.EQ, List(IntType))
@tu lazy val Int_>= : Symbol = IntClass.requiredMethod(nme.GE, List(IntType))
@tu lazy val Int_<= : Symbol = IntClass.requiredMethod(nme.LE, List(IntType))
@tu lazy val Int_> : Symbol = IntClass.requiredMethod(nme.GT, List(IntType))
@tu lazy val LongType: TypeRef = valueTypeRef("scala.Long", java.lang.Long.TYPE, LongEnc, nme.specializedTypeNames.Long)
def LongClass(using Context): ClassSymbol = LongType.symbol.asClass
@tu lazy val Long_+ : Symbol = LongClass.requiredMethod(nme.PLUS, List(LongType))
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ object StdNames {
val array_length : N = "array_length"
val array_update : N = "array_update"
val arraycopy: N = "arraycopy"
val arity: N = "arity"
val as: N = "as"
val asTerm: N = "asTerm"
val asModule: N = "asModule"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class FirstTransform extends MiniPhase with SymTransformer { thisPhase =>
case Select(qual, name) if !name.is(OuterSelectName) && tree.symbol.exists =>
val qualTpe = qual.tpe
assert(
qualTpe.isErasedValueType || qualTpe.derivesFrom(tree.symbol.owner) ||
qualTpe.widenDealias.isErasedValueType || qualTpe.derivesFrom(tree.symbol.owner) ||
tree.symbol.is(JavaStatic) && qualTpe.derivesFrom(tree.symbol.enclosingClass),
i"non member selection of ${tree.symbol.showLocated} from ${qualTpe} in $tree")
case _: TypeTree =>
Expand Down
112 changes: 81 additions & 31 deletions compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -504,53 +504,103 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
/** The class
*
* ```
* case class C[T <: U](x: T, y: String*)
* trait U:
* type Elem
*
* case class C[T <: U](a: T, b: a.Elem, c: String*)
* ```
*
* gets the `fromProduct` method:
*
* ```
* def fromProduct(x$0: Product): MirroredMonoType =
* new C[U](
* x$0.productElement(0).asInstanceOf[U],
* x$0.productElement(1).asInstanceOf[Seq[String]]: _*)
* val a$1 = x$0.productElement(0).asInstanceOf[U]
* val b$1 = x$0.productElement(1).asInstanceOf[a$1.Elem]
* val c$1 = x$0.productElement(2).asInstanceOf[Seq[String]]
* new C[U](a$1, b$1, c$1*)
* ```
* where
* ```
* type MirroredMonoType = C[?]
* ```
*
* However, if the last parameter is annotated `@unroll` then we generate:
*
* def fromProduct(x$0: Product): MirroredMonoType =
* val arity = x$0.productArity
* val a$1 = x$0.productElement(0).asInstanceOf[U]
* val b$1 = x$0.productElement(1).asInstanceOf[a$1.Elem]
* val c$1 = (
* if arity > 2 then
* x$0.productElement(2)
* else
* <default getter for the third parameter of C>
* ).asInstanceOf[Seq[String]]
* new C[U](a$1, b$1, c$1*)
*/
def fromProductBody(caseClass: Symbol, param: Tree, optInfo: Option[MirrorImpl.OfProduct])(using Context): Tree =
def extractParams(tpe: Type): List[Type] =
tpe.asInstanceOf[MethodType].paramInfos

def computeFromCaseClass: (Type, List[Type]) =
val (baseRef, baseInfo) =
val rawRef = caseClass.typeRef
val rawInfo = caseClass.primaryConstructor.info
optInfo match
case Some(info) =>
(rawRef.asSeenFrom(info.pre, caseClass.owner), rawInfo.asSeenFrom(info.pre, caseClass.owner))
case _ =>
(rawRef, rawInfo)
baseInfo match
def fromProductBody(caseClass: Symbol, productParam: Tree, optInfo: Option[MirrorImpl.OfProduct])(using Context): Tree =
val classRef = optInfo match
case Some(info) => TypeRef(info.pre, caseClass)
case _ => caseClass.typeRef
val (newPrefix, constrMeth, constrSyms) =
val constr = TermRef(classRef, caseClass.primaryConstructor)
val symss = caseClass.primaryConstructor.paramSymss
(constr.info: @unchecked) match
case tl: PolyType =>
val tvars = constrained(tl)
val targs = for tvar <- tvars yield
tvar.instantiate(fromBelow = false)
(baseRef.appliedTo(targs), extractParams(tl.instantiate(targs)))
case methTpe =>
(baseRef, extractParams(methTpe))
end computeFromCaseClass

val (classRefApplied, paramInfos) = computeFromCaseClass
val elems =
for ((formal, idx) <- paramInfos.zipWithIndex) yield
val elem =
param.select(defn.Product_productElement).appliedTo(Literal(Constant(idx)))
.ensureConforms(formal.translateFromRepeated(toArray = false))
if (formal.isRepeatedParam) ctx.typer.seqToRepeated(elem) else elem
New(classRefApplied, elems)
(AppliedType(classRef, targs), tl.instantiate(targs).asInstanceOf[MethodType], symss(1))
case mt: MethodType =>
(classRef, mt, symss.head)

// Index of the first parameter marked `@unroll` or -1
val unrolledFrom =
constrSyms.indexWhere(_.hasAnnotation(defn.UnrollAnnot))

// `val arity = x$0.productArity`
val arityDef: Option[ValDef] =
if unrolledFrom != -1 then
Some(SyntheticValDef(nme.arity, productParam.select(defn.Product_productArity).withSpan(ctx.owner.span.focus)))
else None
val arityRefTree = arityDef.map(vd => ref(vd.symbol))

// Create symbols for the vals corresponding to each parameter
// If there are dependent parameters, the infos won't be correct yet.
val bindingSyms = constrMeth.paramRefs.map: pref =>
newSymbol(ctx.owner, pref.paramName.freshened, Synthetic,
pref.underlying.translateFromRepeated(toArray = false), coord = ctx.owner.span.focus)
val bindingRefs = bindingSyms.map(TermRef(NoPrefix, _))
// Fix the infos for dependent parameters
if constrMeth.isParamDependent then
bindingSyms.foreach: bindingSym =>
bindingSym.info = bindingSym.info.substParams(constrMeth, bindingRefs)

def defaultGetterAtIndex(idx: Int): Tree =
val defaultGetterPrefix = caseClass.primaryConstructor.name.toTermName
ref(caseClass.companionModule).select(NameKinds.DefaultGetterName(defaultGetterPrefix, idx))

val bindingDefs = bindingSyms.zipWithIndex.map: (bindingSym, idx) =>
val selection = productParam.select(defn.Product_productElement).appliedTo(Literal(Constant(idx)))
val rhs = (
if unrolledFrom != -1 && idx >= unrolledFrom then
If(arityRefTree.get.select(defn.Int_>).appliedTo(Literal(Constant(idx))),
thenp =
selection,
elsep =
defaultGetterAtIndex(idx))
else
selection
).ensureConforms(bindingSym.info)
ValDef(bindingSym, rhs)

val newArgs = bindingRefs.lazyZip(constrMeth.paramInfos).map: (bindingRef, paramInfo) =>
val refTree = ref(bindingRef)
if paramInfo.isRepeatedParam then ctx.typer.seqToRepeated(refTree) else refTree
Block(
arityDef.toList ::: bindingDefs,
New(newPrefix, newArgs)
)
end fromProductBody

/** For an enum T:
Expand Down
86 changes: 19 additions & 67 deletions compiler/src/dotty/tools/dotc/transform/UnrollDefinitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -228,46 +228,9 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
forwarderDef
}

private def generateFromProduct(startParamIndices: List[Int], paramCount: Int, defdef: DefDef)(using Context) = {
cpy.DefDef(defdef)(
name = defdef.name,
paramss = defdef.paramss,
tpt = defdef.tpt,
rhs = Match(
ref(defdef.paramss.head.head.asInstanceOf[ValDef].symbol).select(termName("productArity")),
startParamIndices.map { paramIndex =>
val Apply(select, args) = defdef.rhs: @unchecked
CaseDef(
Literal(Constant(paramIndex)),
EmptyTree,
Apply(
select,
args.take(paramIndex) ++
Range(paramIndex, paramCount).map(n =>
ref(defdef.symbol.owner.companionModule)
.select(DefaultGetterName(defdef.symbol.owner.primaryConstructor.name.toTermName, n))
)
)
)
} :+ CaseDef(
Underscore(defn.IntType),
EmptyTree,
defdef.rhs
)
)
).setDefTree
}

private enum Gen:
case Substitute(origin: Symbol, newDef: DefDef)
case Forwarders(origin: Symbol, forwarders: List[DefDef])
case class Forwarders(origin: Symbol, forwarders: List[DefDef])

def origin: Symbol
def extras: List[DefDef] = this match
case Substitute(_, d) => d :: Nil
case Forwarders(_, ds) => ds

private def generateSyntheticDefs(tree: Tree, compute: ComputeIndices)(using Context): Option[Gen] = tree match {
private def generateSyntheticDefs(tree: Tree, compute: ComputeIndices)(using Context): Option[Forwarders] = tree match {
case defdef: DefDef if defdef.paramss.nonEmpty =>
import dotty.tools.dotc.core.NameOps.isConstructorName

Expand All @@ -277,38 +240,29 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
val isCaseApply =
defdef.name == nme.apply && defdef.symbol.owner.companionClass.is(CaseClass)

val isCaseFromProduct = defdef.name == nme.fromProduct && defdef.symbol.owner.companionClass.is(CaseClass)

val annotated =
if (isCaseCopy) defdef.symbol.owner.primaryConstructor
else if (isCaseApply) defdef.symbol.owner.companionClass.primaryConstructor
else if (isCaseFromProduct) defdef.symbol.owner.companionClass.primaryConstructor
else defdef.symbol

compute(annotated) match {
case Nil => None
case (paramClauseIndex, annotationIndices) :: Nil =>
val paramCount = annotated.paramSymss(paramClauseIndex).size
if isCaseFromProduct then
Some(Gen.Substitute(
origin = defdef.symbol,
newDef = generateFromProduct(annotationIndices, paramCount, defdef)
))
else
val generatedDefs =
val indices = (annotationIndices :+ paramCount).sliding(2).toList.reverse
indices.foldLeft(List.empty[DefDef]):
case (defdefs, paramIndex :: nextParamIndex :: Nil) =>
generateSingleForwarder(
defdef,
paramIndex,
paramCount,
nextParamIndex,
paramClauseIndex,
isCaseApply
) :: defdefs
case _ => unreachable("sliding with at least 2 elements")
Some(Gen.Forwarders(origin = defdef.symbol, forwarders = generatedDefs))
val generatedDefs =
val indices = (annotationIndices :+ paramCount).sliding(2).toList.reverse
indices.foldLeft(List.empty[DefDef]):
case (defdefs, paramIndex :: nextParamIndex :: Nil) =>
generateSingleForwarder(
defdef,
paramIndex,
paramCount,
nextParamIndex,
paramClauseIndex,
isCaseApply
) :: defdefs
case _ => unreachable("sliding with at least 2 elements")
Some(Forwarders(origin = defdef.symbol, forwarders = generatedDefs))

case multiple =>
report.error("Cannot have multiple parameter lists containing `@unroll` annotation", defdef.srcPos)
Expand All @@ -323,14 +277,12 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
val generatedBody = tmpl.body.flatMap(generateSyntheticDefs(_, compute))
val generatedConstr0 = generateSyntheticDefs(tmpl.constr, compute)
val allGenerated = generatedBody ++ generatedConstr0
val bodySubs = generatedBody.collect({ case s: Gen.Substitute => s.origin }).toSet
val otherDecls = tmpl.body.filterNot(d => d.symbol.exists && bodySubs(d.symbol))

if allGenerated.nonEmpty then
val byName = (tmpl.constr :: otherDecls).groupMap(_.symbol.name.toString)(_.symbol)
val byName = (tmpl.constr :: tmpl.body).groupMap(_.symbol.name.toString)(_.symbol)
for
syntheticDefs <- allGenerated
dcl <- syntheticDefs.extras
dcl <- syntheticDefs.forwarders
do
val replaced = dcl.symbol
byName.get(dcl.name.toString).foreach { syms =>
Expand All @@ -348,7 +300,7 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
tmpl.parents,
tmpl.derived,
tmpl.self,
otherDecls ++ allGenerated.flatMap(_.extras)
tmpl.body ++ allGenerated.flatMap(_.forwarders)
)
}

Expand Down
14 changes: 1 addition & 13 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1957,9 +1957,7 @@ class Namer { typer: Typer =>
if isConstructor then
// set result type tree to unit, but take the current class as result type of the symbol
typedAheadType(ddef.tpt, defn.UnitType)
val mt = wrapMethType(effectiveResultType(sym, paramSymss))
if sym.isPrimaryConstructor then checkCaseClassParamDependencies(mt, sym.owner)
mt
wrapMethType(effectiveResultType(sym, paramSymss))
else
val paramFn = if Feature.enabled(Feature.modularity) && sym.isAllOf(Given | Method) then wrapRefinedMethType else wrapMethType
valOrDefDefSig(ddef, sym, paramSymss, paramFn)
Expand Down Expand Up @@ -2001,16 +1999,6 @@ class Namer { typer: Typer =>
ddef.trailingParamss.foreach(completeParams)
end completeTrailingParamss

/** Checks an implementation restriction on case classes. */
def checkCaseClassParamDependencies(mt: Type, cls: Symbol)(using Context): Unit =
mt.stripPoly match
case mt: MethodType if cls.is(Case) && mt.isParamDependent =>
// See issue #8073 for background
report.error(
em"""Implementation restriction: case classes cannot have dependencies between parameters""",
cls.srcPos)
case _ =>

private def setParamTrackedWithAccessors(psym: Symbol, ownerTpe: Type)(using Context): Unit =
for acc <- ownerTpe.decls.lookupAll(psym.name) if acc.is(ParamAccessor) do
acc.resetFlag(PrivateLocal)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3196,7 +3196,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
.withType(dummy.termRef)
if (!cls.isOneOf(AbstractOrTrait) && !ctx.isAfterTyper)
checkRealizableBounds(cls, cdef.sourcePos.withSpan(cdef.nameSpan))
if cls.isEnum || firstParentTpe.classSymbol.isEnum then
if cls.isEnum || !cls.isRefinementClass && firstParentTpe.classSymbol.isEnum then
checkEnum(cdef, cls, firstParent)
val cdef1 = assignType(cpy.TypeDef(cdef)(name, impl1), cls)

Expand Down
8 changes: 0 additions & 8 deletions tests/neg/i8069.scala

This file was deleted.

12 changes: 12 additions & 0 deletions tests/pos/enum-refinement.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
enum Enum:
case EC(val x: Int)

val a: Enum.EC { val x: 1 } = Enum.EC(1).asInstanceOf[Enum.EC { val x: 1 }]

import scala.language.experimental.modularity

enum EnumT:
case EC(tracked val x: Int)

val b: EnumT.EC { val x: 1 } = EnumT.EC(1)

Loading
Loading