Skip to content

Commit

Permalink
Minimal support for dependent case classes (#21698)
Browse files Browse the repository at this point in the history
This lets us write:

```scala
    trait A:
      type B

    case class CC(a: A, b: a.B)
```

Pattern matching works but isn't dependent yet:

```scala
    x match
      case CC(a, b) =>
        val a1: A = a
        // Dependent pattern matching is not currently supported
        // val b1: a1.B = b
        val b1 = b // Type is CC#a.B
```

(for my usecase this isn't a problem, I'm working on a type constraint
API which lets me write things like `case class CC(a: Int, b: Int
GreaterThan[a.type])`)

Because case class pattern matching relies on the product selectors
`_N`, making it dependent is a bit tricky, currently we generate:

```scala
    case class CC(a: A, b: a.B):
      def _1: A = a
      def _2: a.B = b
```

So the type of `_2` is not obviously related to the type of `_1`, we
probably need to change what we generate into:

```scala
    case class CC(a: A, b: a.B):
      @uncheckedStable def _1: a.type = a
      def _2: _1.B = b
```

But this can be done in a separate PR.

Fixes #8073.
  • Loading branch information
smarter authored Feb 19, 2025
2 parents 5929c87 + 246793a commit 345b2da
Show file tree
Hide file tree
Showing 12 changed files with 290 additions and 122 deletions.
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,7 @@ class Definitions {
@tu lazy val Int_== : Symbol = IntClass.requiredMethod(nme.EQ, List(IntType))
@tu lazy val Int_>= : Symbol = IntClass.requiredMethod(nme.GE, List(IntType))
@tu lazy val Int_<= : Symbol = IntClass.requiredMethod(nme.LE, List(IntType))
@tu lazy val Int_> : Symbol = IntClass.requiredMethod(nme.GT, List(IntType))
@tu lazy val LongType: TypeRef = valueTypeRef("scala.Long", java.lang.Long.TYPE, LongEnc, nme.specializedTypeNames.Long)
def LongClass(using Context): ClassSymbol = LongType.symbol.asClass
@tu lazy val Long_+ : Symbol = LongClass.requiredMethod(nme.PLUS, List(LongType))
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ object StdNames {
val array_length : N = "array_length"
val array_update : N = "array_update"
val arraycopy: N = "arraycopy"
val arity: N = "arity"
val as: N = "as"
val asTerm: N = "asTerm"
val asModule: N = "asModule"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class FirstTransform extends MiniPhase with SymTransformer { thisPhase =>
case Select(qual, name) if !name.is(OuterSelectName) && tree.symbol.exists =>
val qualTpe = qual.tpe
assert(
qualTpe.isErasedValueType || qualTpe.derivesFrom(tree.symbol.owner) ||
qualTpe.widenDealias.isErasedValueType || qualTpe.derivesFrom(tree.symbol.owner) ||
tree.symbol.is(JavaStatic) && qualTpe.derivesFrom(tree.symbol.enclosingClass),
i"non member selection of ${tree.symbol.showLocated} from ${qualTpe} in $tree")
case _: TypeTree =>
Expand Down
112 changes: 81 additions & 31 deletions compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -504,53 +504,103 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
/** The class
*
* ```
* case class C[T <: U](x: T, y: String*)
* trait U:
* type Elem
*
* case class C[T <: U](a: T, b: a.Elem, c: String*)
* ```
*
* gets the `fromProduct` method:
*
* ```
* def fromProduct(x$0: Product): MirroredMonoType =
* new C[U](
* x$0.productElement(0).asInstanceOf[U],
* x$0.productElement(1).asInstanceOf[Seq[String]]: _*)
* val a$1 = x$0.productElement(0).asInstanceOf[U]
* val b$1 = x$0.productElement(1).asInstanceOf[a$1.Elem]
* val c$1 = x$0.productElement(2).asInstanceOf[Seq[String]]
* new C[U](a$1, b$1, c$1*)
* ```
* where
* ```
* type MirroredMonoType = C[?]
* ```
*
* However, if the last parameter is annotated `@unroll` then we generate:
*
* def fromProduct(x$0: Product): MirroredMonoType =
* val arity = x$0.productArity
* val a$1 = x$0.productElement(0).asInstanceOf[U]
* val b$1 = x$0.productElement(1).asInstanceOf[a$1.Elem]
* val c$1 = (
* if arity > 2 then
* x$0.productElement(2)
* else
* <default getter for the third parameter of C>
* ).asInstanceOf[Seq[String]]
* new C[U](a$1, b$1, c$1*)
*/
def fromProductBody(caseClass: Symbol, param: Tree, optInfo: Option[MirrorImpl.OfProduct])(using Context): Tree =
def extractParams(tpe: Type): List[Type] =
tpe.asInstanceOf[MethodType].paramInfos

def computeFromCaseClass: (Type, List[Type]) =
val (baseRef, baseInfo) =
val rawRef = caseClass.typeRef
val rawInfo = caseClass.primaryConstructor.info
optInfo match
case Some(info) =>
(rawRef.asSeenFrom(info.pre, caseClass.owner), rawInfo.asSeenFrom(info.pre, caseClass.owner))
case _ =>
(rawRef, rawInfo)
baseInfo match
def fromProductBody(caseClass: Symbol, productParam: Tree, optInfo: Option[MirrorImpl.OfProduct])(using Context): Tree =
val classRef = optInfo match
case Some(info) => TypeRef(info.pre, caseClass)
case _ => caseClass.typeRef
val (newPrefix, constrMeth, constrSyms) =
val constr = TermRef(classRef, caseClass.primaryConstructor)
val symss = caseClass.primaryConstructor.paramSymss
(constr.info: @unchecked) match
case tl: PolyType =>
val tvars = constrained(tl)
val targs = for tvar <- tvars yield
tvar.instantiate(fromBelow = false)
(baseRef.appliedTo(targs), extractParams(tl.instantiate(targs)))
case methTpe =>
(baseRef, extractParams(methTpe))
end computeFromCaseClass

val (classRefApplied, paramInfos) = computeFromCaseClass
val elems =
for ((formal, idx) <- paramInfos.zipWithIndex) yield
val elem =
param.select(defn.Product_productElement).appliedTo(Literal(Constant(idx)))
.ensureConforms(formal.translateFromRepeated(toArray = false))
if (formal.isRepeatedParam) ctx.typer.seqToRepeated(elem) else elem
New(classRefApplied, elems)
(AppliedType(classRef, targs), tl.instantiate(targs).asInstanceOf[MethodType], symss(1))
case mt: MethodType =>
(classRef, mt, symss.head)

// Index of the first parameter marked `@unroll` or -1
val unrolledFrom =
constrSyms.indexWhere(_.hasAnnotation(defn.UnrollAnnot))

// `val arity = x$0.productArity`
val arityDef: Option[ValDef] =
if unrolledFrom != -1 then
Some(SyntheticValDef(nme.arity, productParam.select(defn.Product_productArity).withSpan(ctx.owner.span.focus)))
else None
val arityRefTree = arityDef.map(vd => ref(vd.symbol))

// Create symbols for the vals corresponding to each parameter
// If there are dependent parameters, the infos won't be correct yet.
val bindingSyms = constrMeth.paramRefs.map: pref =>
newSymbol(ctx.owner, pref.paramName.freshened, Synthetic,
pref.underlying.translateFromRepeated(toArray = false), coord = ctx.owner.span.focus)
val bindingRefs = bindingSyms.map(TermRef(NoPrefix, _))
// Fix the infos for dependent parameters
if constrMeth.isParamDependent then
bindingSyms.foreach: bindingSym =>
bindingSym.info = bindingSym.info.substParams(constrMeth, bindingRefs)

def defaultGetterAtIndex(idx: Int): Tree =
val defaultGetterPrefix = caseClass.primaryConstructor.name.toTermName
ref(caseClass.companionModule).select(NameKinds.DefaultGetterName(defaultGetterPrefix, idx))

val bindingDefs = bindingSyms.zipWithIndex.map: (bindingSym, idx) =>
val selection = productParam.select(defn.Product_productElement).appliedTo(Literal(Constant(idx)))
val rhs = (
if unrolledFrom != -1 && idx >= unrolledFrom then
If(arityRefTree.get.select(defn.Int_>).appliedTo(Literal(Constant(idx))),
thenp =
selection,
elsep =
defaultGetterAtIndex(idx))
else
selection
).ensureConforms(bindingSym.info)
ValDef(bindingSym, rhs)

val newArgs = bindingRefs.lazyZip(constrMeth.paramInfos).map: (bindingRef, paramInfo) =>
val refTree = ref(bindingRef)
if paramInfo.isRepeatedParam then ctx.typer.seqToRepeated(refTree) else refTree
Block(
arityDef.toList ::: bindingDefs,
New(newPrefix, newArgs)
)
end fromProductBody

/** For an enum T:
Expand Down
86 changes: 19 additions & 67 deletions compiler/src/dotty/tools/dotc/transform/UnrollDefinitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -228,46 +228,9 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
forwarderDef
}

private def generateFromProduct(startParamIndices: List[Int], paramCount: Int, defdef: DefDef)(using Context) = {
cpy.DefDef(defdef)(
name = defdef.name,
paramss = defdef.paramss,
tpt = defdef.tpt,
rhs = Match(
ref(defdef.paramss.head.head.asInstanceOf[ValDef].symbol).select(termName("productArity")),
startParamIndices.map { paramIndex =>
val Apply(select, args) = defdef.rhs: @unchecked
CaseDef(
Literal(Constant(paramIndex)),
EmptyTree,
Apply(
select,
args.take(paramIndex) ++
Range(paramIndex, paramCount).map(n =>
ref(defdef.symbol.owner.companionModule)
.select(DefaultGetterName(defdef.symbol.owner.primaryConstructor.name.toTermName, n))
)
)
)
} :+ CaseDef(
Underscore(defn.IntType),
EmptyTree,
defdef.rhs
)
)
).setDefTree
}

private enum Gen:
case Substitute(origin: Symbol, newDef: DefDef)
case Forwarders(origin: Symbol, forwarders: List[DefDef])
case class Forwarders(origin: Symbol, forwarders: List[DefDef])

def origin: Symbol
def extras: List[DefDef] = this match
case Substitute(_, d) => d :: Nil
case Forwarders(_, ds) => ds

private def generateSyntheticDefs(tree: Tree, compute: ComputeIndices)(using Context): Option[Gen] = tree match {
private def generateSyntheticDefs(tree: Tree, compute: ComputeIndices)(using Context): Option[Forwarders] = tree match {
case defdef: DefDef if defdef.paramss.nonEmpty =>
import dotty.tools.dotc.core.NameOps.isConstructorName

Expand All @@ -277,38 +240,29 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
val isCaseApply =
defdef.name == nme.apply && defdef.symbol.owner.companionClass.is(CaseClass)

val isCaseFromProduct = defdef.name == nme.fromProduct && defdef.symbol.owner.companionClass.is(CaseClass)

val annotated =
if (isCaseCopy) defdef.symbol.owner.primaryConstructor
else if (isCaseApply) defdef.symbol.owner.companionClass.primaryConstructor
else if (isCaseFromProduct) defdef.symbol.owner.companionClass.primaryConstructor
else defdef.symbol

compute(annotated) match {
case Nil => None
case (paramClauseIndex, annotationIndices) :: Nil =>
val paramCount = annotated.paramSymss(paramClauseIndex).size
if isCaseFromProduct then
Some(Gen.Substitute(
origin = defdef.symbol,
newDef = generateFromProduct(annotationIndices, paramCount, defdef)
))
else
val generatedDefs =
val indices = (annotationIndices :+ paramCount).sliding(2).toList.reverse
indices.foldLeft(List.empty[DefDef]):
case (defdefs, paramIndex :: nextParamIndex :: Nil) =>
generateSingleForwarder(
defdef,
paramIndex,
paramCount,
nextParamIndex,
paramClauseIndex,
isCaseApply
) :: defdefs
case _ => unreachable("sliding with at least 2 elements")
Some(Gen.Forwarders(origin = defdef.symbol, forwarders = generatedDefs))
val generatedDefs =
val indices = (annotationIndices :+ paramCount).sliding(2).toList.reverse
indices.foldLeft(List.empty[DefDef]):
case (defdefs, paramIndex :: nextParamIndex :: Nil) =>
generateSingleForwarder(
defdef,
paramIndex,
paramCount,
nextParamIndex,
paramClauseIndex,
isCaseApply
) :: defdefs
case _ => unreachable("sliding with at least 2 elements")
Some(Forwarders(origin = defdef.symbol, forwarders = generatedDefs))

case multiple =>
report.error("Cannot have multiple parameter lists containing `@unroll` annotation", defdef.srcPos)
Expand All @@ -323,14 +277,12 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
val generatedBody = tmpl.body.flatMap(generateSyntheticDefs(_, compute))
val generatedConstr0 = generateSyntheticDefs(tmpl.constr, compute)
val allGenerated = generatedBody ++ generatedConstr0
val bodySubs = generatedBody.collect({ case s: Gen.Substitute => s.origin }).toSet
val otherDecls = tmpl.body.filterNot(d => d.symbol.exists && bodySubs(d.symbol))

if allGenerated.nonEmpty then
val byName = (tmpl.constr :: otherDecls).groupMap(_.symbol.name.toString)(_.symbol)
val byName = (tmpl.constr :: tmpl.body).groupMap(_.symbol.name.toString)(_.symbol)
for
syntheticDefs <- allGenerated
dcl <- syntheticDefs.extras
dcl <- syntheticDefs.forwarders
do
val replaced = dcl.symbol
byName.get(dcl.name.toString).foreach { syms =>
Expand All @@ -348,7 +300,7 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
tmpl.parents,
tmpl.derived,
tmpl.self,
otherDecls ++ allGenerated.flatMap(_.extras)
tmpl.body ++ allGenerated.flatMap(_.forwarders)
)
}

Expand Down
14 changes: 1 addition & 13 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1957,9 +1957,7 @@ class Namer { typer: Typer =>
if isConstructor then
// set result type tree to unit, but take the current class as result type of the symbol
typedAheadType(ddef.tpt, defn.UnitType)
val mt = wrapMethType(effectiveResultType(sym, paramSymss))
if sym.isPrimaryConstructor then checkCaseClassParamDependencies(mt, sym.owner)
mt
wrapMethType(effectiveResultType(sym, paramSymss))
else
val paramFn = if Feature.enabled(Feature.modularity) && sym.isAllOf(Given | Method) then wrapRefinedMethType else wrapMethType
valOrDefDefSig(ddef, sym, paramSymss, paramFn)
Expand Down Expand Up @@ -2001,16 +1999,6 @@ class Namer { typer: Typer =>
ddef.trailingParamss.foreach(completeParams)
end completeTrailingParamss

/** Checks an implementation restriction on case classes. */
def checkCaseClassParamDependencies(mt: Type, cls: Symbol)(using Context): Unit =
mt.stripPoly match
case mt: MethodType if cls.is(Case) && mt.isParamDependent =>
// See issue #8073 for background
report.error(
em"""Implementation restriction: case classes cannot have dependencies between parameters""",
cls.srcPos)
case _ =>

private def setParamTrackedWithAccessors(psym: Symbol, ownerTpe: Type)(using Context): Unit =
for acc <- ownerTpe.decls.lookupAll(psym.name) if acc.is(ParamAccessor) do
acc.resetFlag(PrivateLocal)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3198,7 +3198,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
.withType(dummy.termRef)
if (!cls.isOneOf(AbstractOrTrait) && !ctx.isAfterTyper)
checkRealizableBounds(cls, cdef.sourcePos.withSpan(cdef.nameSpan))
if cls.isEnum || firstParentTpe.classSymbol.isEnum then
if cls.isEnum || !cls.isRefinementClass && firstParentTpe.classSymbol.isEnum then
checkEnum(cdef, cls, firstParent)
val cdef1 = assignType(cpy.TypeDef(cdef)(name, impl1), cls)

Expand Down
8 changes: 0 additions & 8 deletions tests/neg/i8069.scala

This file was deleted.

12 changes: 12 additions & 0 deletions tests/pos/enum-refinement.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
enum Enum:
case EC(val x: Int)

val a: Enum.EC { val x: 1 } = Enum.EC(1).asInstanceOf[Enum.EC { val x: 1 }]

import scala.language.experimental.modularity

enum EnumT:
case EC(tracked val x: Int)

val b: EnumT.EC { val x: 1 } = EnumT.EC(1)

Loading

0 comments on commit 345b2da

Please sign in to comment.