Skip to content

Commit

Permalink
Clean up Discover macro and codegen (#4461)
Browse files Browse the repository at this point in the history
* Make `Discover` return a `class` (that can be evolved by adding
fields) rather than a `Tuple` (which cannot)
* Simplify handling of `millDiscover` flags, in particular we do not
need them to be defined for subfolder base modules
* Remove unused `ObjectDataInstrument`, `Snippet`, `ObjectData`
* Remove `MILL_SPLICED_CODE_START_MARKER`

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
lihaoyi and autofix-ci[bot] authored Feb 3, 2025
1 parent 737dec9 commit 351fc0f
Show file tree
Hide file tree
Showing 18 changed files with 135 additions and 517 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@ object RootSubfolderModuleCollisionTests extends UtestIntegrationTestSuite {
import tester._
val res = eval(("resolve", "_"))
assert(res.isSuccess == false)
assert(res.err.contains("cannot override final member"))
assert(res.err.contains(
" final lazy val sub: _root_.build_.sub.package_.type = _root_.build_.sub.package_ // subfolder module referenc"
))
assert(res.err.contains("Reference to sub is ambiguous."))
assert(res.err.contains("It is both defined in class package_"))
assert(res.err.contains("and inherited subsequently in class package_"))
}
}
}

This file was deleted.

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package mill.integration
import mill.testkit.UtestIntegrationTestSuite
import utest._

object DocAnnotationsTests extends UtestIntegrationTestSuite {
object InspectTests extends UtestIntegrationTestSuite {
def globMatches(glob: String, input: String): Boolean = {
StringContext
.glob(
Expand Down
27 changes: 15 additions & 12 deletions integration/feature/scala-3-syntax/resources/build.mill
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,27 @@ import $packages._
import $file.foo.Box
import $file.foo.{given Box[Int]}


given Cross.ToSegments[DayValue](d => List(d.toString))

given mainargs.TokensReader.Simple[DayValue] with
def shortName = "day"

def read(strs: Seq[String]) =
try
Right(DayValue.valueOf(strs.head))
catch
case _: Exception => Left("not a day")

enum DayValue:
case Monday, Tuesday, Wednesday, Thursday, Friday, Saturday, Sunday

object `package` extends RootModule:

def someTopLevelCommand(): Command[Unit] = Task.Command:
println(s"Hello, world! ${summon[Box[Int]]} ${build.sub.subTask()}")
end someTopLevelCommand

given Cross.ToSegments[DayValue](d => List(d.toString))

given mainargs.TokensReader.Simple[DayValue] with
def shortName = "day"
def read(strs: Seq[String]) =
try
Right(DayValue.valueOf(strs.head))
catch
case _: Exception => Left("not a day")

enum DayValue:
case Monday, Tuesday, Wednesday, Thursday, Friday, Saturday, Sunday

object day extends Cross[DayModule](DayValue.values.toSeq)

Expand Down
139 changes: 61 additions & 78 deletions main/define/src/mill/define/Discover.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,13 @@ import scala.collection.mutable
* the `Task.Command` methods we find. This mapping from `Class[_]` to `MainData`
* can then be used later to look up the `MainData` for any module.
*/
case class Discover private (
value: Map[
Class[_],
(Seq[String], Seq[mainargs.MainData[_, _]], Seq[String])
],
dummy: Int = 0 /* avoid conflict with Discover.apply(value: Map) below*/
) {
@deprecated("Binary compatibility shim", "Mill 0.11.4")
private[define] def this(value: Map[Class[_], Seq[mainargs.MainData[_, _]]]) =
this(value.view.mapValues((Nil, _, Nil)).toMap)
@deprecated("Binary compatibility shim", "Mill 0.11.4")
private[define] def copy(value: Map[Class[_], Seq[mainargs.MainData[_, _]]]): Discover = {
new Discover(value.view.mapValues((Nil, _, Nil)).toMap, dummy)
}
}
class Discover(val classInfo: Map[Class[_], Discover.Node], val allNames: Seq[String])

object Discover {
def apply2[T](value: Map[Class[_], (Seq[String], Seq[mainargs.MainData[_, _]], Seq[String])])
: Discover =
new Discover(value)

@deprecated("Binary compatibility shim", "Mill 0.11.4")
def apply[T](value: Map[Class[_], Seq[mainargs.MainData[_, _]]]): Discover =
new Discover(value.view.mapValues((Nil, _, Nil)).toMap)
class Node(
val entryPoints: Seq[mainargs.MainData[_, _]],
val declaredNames: Seq[String]
)

inline def apply[T]: Discover = ${ Router.applyImpl[T] }

Expand All @@ -46,7 +29,7 @@ object Discover {
import mainargs.Macros.*
import scala.util.control.NonFatal

def applyImpl[T: Type](using Quotes): Expr[Discover] = {
def applyImpl[T: Type](using quotes: Quotes): Expr[Discover] = {
import quotes.reflect.*
val seen = mutable.Set.empty[TypeRepr]
val moduleSym = Symbol.requiredClass("mill.define.Module")
Expand All @@ -62,10 +45,8 @@ object Discover {
} {
rec(memberTpe)
memberTpe.asType match {
case '[mill.define.Cross[m]] =>
rec(TypeRepr.of[m])
case _ =>
() // no cross argument to extract
case '[mill.define.Cross[m]] => rec(TypeRepr.of[m])
case _ => () // no cross argument to extract
}
}
}
Expand Down Expand Up @@ -107,89 +88,91 @@ object Discover {
)
)

def sortedMethods(curCls: TypeRepr, sub: TypeRepr, methods: Seq[Symbol]): Seq[Symbol] =
for {
m <- methods.toList.sortBy(_.fullName)
mType = curCls.memberType(m)
returnType = methodReturn(mType)
if returnType <:< sub
} yield m

// Make sure we sort the types and methods to keep the output deterministic;
// otherwise the compiler likes to give us stuff in random orders, which
// causes the code to be generated in random order resulting in code hashes
// changing unnecessarily
val mapping = for {
val mapping: Seq[(Expr[(Class[_], Node)], Seq[String])] = for {
discoveredModuleType <- seen.toSeq.sortBy(_.typeSymbol.fullName)
curCls = discoveredModuleType
methods = filterDefs(curCls.typeSymbol.methodMembers)
declMethods = filterDefs(curCls.typeSymbol.declaredMethods)
overridesRoutes = {
_ = {
assertParamListCounts(
curCls,
methods,
(TypeRepr.of[mill.define.Command[?]], 1, "`Task.Command`"),
(TypeRepr.of[mill.define.Target[?]], 0, "Target")
)
}

def sortedMethods(sub: TypeRepr, methods: Seq[Symbol] = methods): Seq[Symbol] =
for {
m <- methods.toList.sortBy(_.fullName)
mType = curCls.memberType(m)
returnType = methodReturn(mType)
if returnType <:< sub
} yield m

Tuple3(
for {
m <- sortedMethods(sub = TypeRepr.of[mill.define.NamedTask[?]])
} yield m.name, // .decoded // we don't need to decode the name in Scala 3
for {
m <- sortedMethods(sub = TypeRepr.of[mill.define.Command[?]])
} yield curCls.asType match {
case '[t] =>
val expr =
try
createMainData[Any, t](
m,
m.annotations.find(_.tpe =:= TypeRepr.of[mainargs.main]).getOrElse('{
new mainargs.main()
}.asTerm),
m.paramSymss
).asExprOf[mainargs.MainData[?, ?]]
catch {
case NonFatal(e) =>
val (before, Array(after, _*)) = e.getStackTrace().span(e =>
!(e.getClassName() == "mill.define.Discover$Router$" && e.getMethodName() == "applyImpl")
): @unchecked
val trace =
(before :+ after).map(_.toString).mkString("trace:\n", "\n", "\n...")
report.errorAndAbort(
s"Error generating maindata for ${m.fullName}: ${e}\n$trace",
m.pos.getOrElse(Position.ofMacroExpansion)
)
}
// report.warning(s"generated maindata for ${m.fullName}:\n${expr.asTerm.show}", m.pos.getOrElse(Position.ofMacroExpansion))
expr
},
for
m <- sortedMethods(sub = TypeRepr.of[mill.define.Task[?]], methods = declMethods)
yield m.name.toString
)
names =
sortedMethods(curCls, sub = TypeRepr.of[mill.define.NamedTask[?]], methods).map(_.name)
entryPoints = for {
m <- sortedMethods(curCls, sub = TypeRepr.of[mill.define.Command[?]], methods)
} yield curCls.asType match {
case '[t] =>
val expr =
try
createMainData[Any, t](
m,
m.annotations.find(_.tpe =:= TypeRepr.of[mainargs.main]).getOrElse('{
new mainargs.main()
}.asTerm),
m.paramSymss
).asExprOf[mainargs.MainData[?, ?]]
catch {
case NonFatal(e) =>
val (before, Array(after, _*)) = e.getStackTrace().span(e =>
!(e.getClassName() == "mill.define.Discover$Router$" && e.getMethodName() == "applyImpl")
): @unchecked
val trace =
(before :+ after).map(_.toString).mkString("trace:\n", "\n", "\n...")
report.errorAndAbort(
s"Error generating maindata for ${m.fullName}: ${e}\n$trace",
m.pos.getOrElse(Position.ofMacroExpansion)
)
}
expr
}
if overridesRoutes._1.nonEmpty || overridesRoutes._2.nonEmpty || overridesRoutes._3.nonEmpty
declaredNames =
sortedMethods(
curCls,
sub = TypeRepr.of[mill.define.NamedTask[?]],
declMethods
).map(_.name)
if names.nonEmpty || entryPoints.nonEmpty
} yield {
val (names, mainDataExprs, taskNames) = overridesRoutes
// by wrapping the `overridesRoutes` in a lambda function we kind of work around
// the problem of generating a *huge* macro method body that finally exceeds the
// JVM's maximum allowed method size
val overridesLambda = '{
def triple() = (${ Expr(names) }, ${ Expr.ofList(mainDataExprs) }, ${ Expr(taskNames) })
def triple() =
new Node(${ Expr.ofList(entryPoints) }, ${ Expr(declaredNames) })
triple()
}
val lhs =
Ref(defn.Predef_classOf).appliedToType(discoveredModuleType.widen).asExprOf[Class[?]]
'{ $lhs -> $overridesLambda }
('{ $lhs -> $overridesLambda }, names)
}

val expr: Expr[Discover] =
'{
// TODO: we can not import this here, so we have to import at the use site now, or redesign?
// import mill.main.TokenReaders.*
// import mill.api.JsonFormatters.*
Discover.apply2(Map(${ Varargs(mapping) }*))
new Discover(
Map[Class[_], Node](${ Varargs(mapping.map(_._1)) }*),
${ Expr(mapping.iterator.flatMap(_._2).distinct.toList.sorted) }
)
}
// TODO: if needed for debugging, we can re-enable this
// report.warning(s"generated discovery for ${TypeRepr.of[T].show}:\n${expr.asTerm.show}", TypeRepr.of[T].typeSymbol.pos.getOrElse(Position.ofMacroExpansion))
Expand Down
6 changes: 3 additions & 3 deletions main/resolve/src/mill/resolve/Resolve.scala
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,9 @@ object Resolve {
nullCommandDefaults: Boolean,
allowPositionalCommandArgs: Boolean
): Iterable[Either[String, Command[_]]] = for {
(cls, (names, entryPoints, _)) <- discover.value
(cls, node) <- discover.classInfo
if cls.isAssignableFrom(target.getClass)
ep <- entryPoints
ep <- node.entryPoints
if ep.name == name
} yield {
def withNullDefault(a: mainargs.ArgSig): mainargs.ArgSig = {
Expand Down Expand Up @@ -303,7 +303,7 @@ trait Resolve[T] {
) match {
case ResolveCore.Success(value) => Right(value)
case ResolveCore.NotFound(segments, found, next, possibleNexts) =>
val allPossibleNames = rootModule.millDiscover.value.values.flatMap(_._1).toSet
val allPossibleNames = rootModule.millDiscover.allNames.toSet
Left(ResolveNotFoundHandler(
selector = sel,
segments = segments,
Expand Down
13 changes: 7 additions & 6 deletions main/src/mill/main/MainModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -265,9 +265,9 @@ trait MainModule extends BaseModule0 {
val mainDataOpt = evaluator
.rootModule
.millDiscover
.value
.classInfo
.get(t.ctx.enclosingCls)
.flatMap(_._2.find(_.name == t.ctx.segments.last.value))
.flatMap(_.entryPoints.find(_.name == t.ctx.segments.last.value))
.headOption

mainDataOpt match {
Expand Down Expand Up @@ -352,10 +352,11 @@ trait MainModule extends BaseModule0 {
case _ => None
}

val methodMap = evaluator.rootModule.millDiscover.value
val tasks = methodMap.get(cls).map {
case (_, _, tasks) => tasks.map(task => s"${t.module}.$task")
}.toSeq.flatten
val methodMap = evaluator.rootModule.millDiscover.classInfo
val tasks = methodMap
.get(cls)
.map { node => node.declaredNames.map(task => s"${t.module}.$task") }
.toSeq.flatten
pprint.Tree.Lazy { ctx =>
Iterator(
// module name(module/file:line)
Expand Down
13 changes: 2 additions & 11 deletions main/src/mill/main/Subfolder.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package mill.main;
import mill._
import mill.define.{Caller, Ctx, Segments, Discover}
import mill.define.{Caller, Ctx, Segments}

object SubfolderModule {
class Info(val millSourcePath0: os.Path, val segments: Seq[String]) {
Expand All @@ -23,13 +23,4 @@ abstract class SubfolderModule()(implicit
fileName = millFile0,
enclosing = Caller(null)
)
) with Module {
// SCALA 3: REINTRODUCED millDiscover because we need to splice the millDiscover from
// child modules into the parent module - this isnt wasteful because the parent module
// doesnt scan the children - hence why it is being spliced in in the Scala 3 version.

// Dummy `millDiscover` defined but never actually used and overriden by codegen.
// Provided for IDEs to think that one is available and not show errors in
// build.mill/package.mill even though they can't see the codegen
def millDiscover: Discover = sys.error("RootModule#millDiscover must be overriden")
}
) with Module {}
Loading

0 comments on commit 351fc0f

Please sign in to comment.