Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detect when resolve needs ModuleRef when likely cyclic references #3878

Merged
merged 18 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion main/resolve/src/mill/resolve/Resolve.scala
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,8 @@ trait Resolve[T] {
rootModule = rootModule,
remainingQuery = sel.value.toList,
current = rootResolved,
querySoFar = Segments()
querySoFar = Segments(),
seenModules = Set.empty
) match {
case ResolveCore.Success(value) => Right(value)
case ResolveCore.NotFound(segments, found, next, possibleNexts) =>
Expand Down
109 changes: 76 additions & 33 deletions main/resolve/src/mill/resolve/ResolveCore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,43 @@ private object ResolveCore {
def makeResultException(e: Throwable, base: Exception): Left[String, Nothing] =
mill.api.Result.makeResultException(e, base)

def cyclicModuleErrorMsg(segments: Segments): String = {
s"Cyclic module reference detected at ${segments.render}, " +
s"it's required to wrap it in ModuleRef."
}
def resolve(
rootModule: BaseModule,
remainingQuery: List[Segment],
current: Resolved,
querySoFar: Segments
querySoFar: Segments,
seenModules: Set[Class[_]] = Set.empty
): Result = {
def moduleClasses(resolved: Iterable[Resolved]): Set[Class[_]] = {
resolved.collect { case Resolved.Module(_, cls) => cls }.toSet
}

remainingQuery match {
case Nil => Success(Set(current))
case head :: tail =>
def recurse(searchModules: Set[Resolved]): Result = {
val (failures, successesLists) = searchModules
.map(r => resolve(rootModule, tail, r, querySoFar ++ Seq(head)))
val results = searchModules
.map { r =>
val rClasses = moduleClasses(Set(r))
if (seenModules.intersect(rClasses).nonEmpty) {
Error(cyclicModuleErrorMsg(r.segments))
} else {
resolve(
rootModule,
tail,
r,
querySoFar ++ Seq(head),
seenModules ++ moduleClasses(Set(current))
)
}
}
.partitionMap { case s: Success => Right(s.value); case f: Failed => Left(f) }

val (failures, successesLists) = results
val (errors, notFounds) = failures.partitionMap {
case s: NotFound => Right(s)
case s: Error => Left(s.msg)
Expand All @@ -87,11 +110,12 @@ private object ResolveCore {
case 1 => notFounds.head
case _ => notFoundResult(rootModule, querySoFar, current, head)
}

}

(head, current) match {
case (Segment.Label(singleLabel), m: Resolved.Module) =>
val resOrErr = singleLabel match {
val resOrErr: Either[String, Iterable[Resolved]] = singleLabel match {
case "__" =>
val self = Seq(Resolved.Module(m.segments, m.cls))
val transitiveOrErr =
Expand All @@ -100,7 +124,8 @@ private object ResolveCore {
m.cls,
None,
current.segments,
Nil
Nil,
seenModules
)

transitiveOrErr.map(transitive => self ++ transitive)
Expand All @@ -122,7 +147,8 @@ private object ResolveCore {
m.cls,
None,
current.segments,
typePattern
typePattern,
seenModules
)

transitiveOrErr.map(transitive => self ++ transitive)
Expand Down Expand Up @@ -244,25 +270,41 @@ private object ResolveCore {
cls: Class[_],
nameOpt: Option[String],
segments: Segments,
typePattern: Seq[String]
typePattern: Seq[String],
seenModules: Set[Class[_]]
): Either[String, Set[Resolved]] = {
val direct =
resolveDirectChildren(rootModule, cls, nameOpt, segments, typePattern)
direct.flatMap { direct =>
for {
directTraverse <-
resolveDirectChildren(rootModule, cls, nameOpt, segments, Nil)
indirect0 = directTraverse
.collect { case m: Resolved.Module =>
resolveTransitiveChildren(
if (seenModules.contains(cls)) Left(cyclicModuleErrorMsg(segments))
else {
val errOrDirect = resolveDirectChildren(rootModule, cls, nameOpt, segments, typePattern)
val directTraverse = resolveDirectChildren(rootModule, cls, nameOpt, segments, Nil)

val errOrModules = directTraverse.map { modules =>
modules.flatMap {
case m: Resolved.Module => Some(m)
case _ => None
}
}

val errOrIndirect0 = errOrModules match {
case Right(modules) =>
modules.flatMap { m =>
Some(resolveTransitiveChildren(
rootModule,
m.cls,
nameOpt,
m.segments,
typePattern
)
typePattern,
seenModules + cls
))
}
indirect <- EitherOps.sequence(indirect0).map(_.flatten)
case Left(err) => Seq(Left(err))
}

val errOrIndirect = EitherOps.sequence(errOrIndirect0).map(_.flatten)

for {
direct <- errOrDirect
indirect <- errOrIndirect
} yield direct ++ indirect
}
}
Expand Down Expand Up @@ -306,7 +348,6 @@ private object ResolveCore {
segments: Segments,
typePattern: Seq[String] = Nil
): Either[String, Set[Resolved]] = {

val crossesOrErr = if (classOf[Cross[_]].isAssignableFrom(cls) && nameOpt.isEmpty) {
instantiateModule(rootModule, segments).map {
case cross: Cross[_] =>
Expand All @@ -318,21 +359,23 @@ private object ResolveCore {
}
} else Right(Nil)

crossesOrErr.flatMap { crosses =>
val filteredCrosses = crosses.filter { c =>
classMatchesTypePred(typePattern)(c.cls)
def expandSegments(direct: Seq[(Resolved, Option[Module => Either[String, Module]])]) = {
direct.map {
case (Resolved.Module(s, cls), _) => Resolved.Module(segments ++ s, cls)
case (Resolved.NamedTask(s), _) => Resolved.NamedTask(segments ++ s)
case (Resolved.Command(s), _) => Resolved.Command(segments ++ s)
}
}

resolveDirectChildren0(rootModule, segments, cls, nameOpt, typePattern)
.map(
_.map {
case (Resolved.Module(s, cls), _) => Resolved.Module(segments ++ s, cls)
case (Resolved.NamedTask(s), _) => Resolved.NamedTask(segments ++ s)
case (Resolved.Command(s), _) => Resolved.Command(segments ++ s)
}
.toSet
.++(filteredCrosses)
)
for {
crosses <- crossesOrErr
filteredCrosses = crosses.filter { c =>
classMatchesTypePred(typePattern)(c.cls)
}
direct0 <- resolveDirectChildren0(rootModule, segments, cls, nameOpt, typePattern)
direct <- Right(expandSegments(direct0))
} yield {
direct.toSet ++ filteredCrosses
}
}

Expand Down
85 changes: 85 additions & 0 deletions main/resolve/test/src/mill/main/ResolveTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1110,5 +1110,90 @@ object ResolveTests extends TestSuite {
Right(Set(_.concrete.tests.inner.foo, _.concrete.tests.inner.innerer.bar))
)
}
test("cyclicModuleRefInitError") {
val check = new Checker(TestGraphs.CyclicModuleRefInitError)
test - check.checkSeq0(
Seq("__"),
isShortError(_, "Cyclic module reference detected at myA.a,")
)
test - check(
"_",
Right(Set(_.foo))
)
test - check.checkSeq0(
Seq("myA.__"),
isShortError(_, "Cyclic module reference detected at myA.a,")
)
test - check.checkSeq0(
Seq("myA.a.__"),
isShortError(_, "Cyclic module reference detected at myA.a,")
)
test - check.checkSeq0(
Seq("myA.a._"),
isShortError(_, "Cyclic module reference detected at myA.a.a,")
)
test - check.checkSeq0(
Seq("myA.a._.a"),
isShortError(_, "Cyclic module reference detected at myA.a.a,")
)
test - check.checkSeq0(
Seq("myA.a.b.a"),
isShortError(_, "Cyclic module reference detected at myA.a.b.a,")
)
}
test("cyclicModuleRefInitError2") {
val check = new Checker(TestGraphs.CyclicModuleRefInitError2)
test - check.checkSeq0(
Seq("__"),
isShortError(_, "Cyclic module reference detected at A.myA.a,")
)
}
test("cyclicModuleRefInitError3") {
val check = new Checker(TestGraphs.CyclicModuleRefInitError3)
test - check.checkSeq0(
Seq("__"),
isShortError(_, "Cyclic module reference detected at A.b.a,")
)
test - check.checkSeq0(
Seq("A.__"),
isShortError(_, "Cyclic module reference detected at A.b.a,")
)
test - check.checkSeq0(
Seq("A.b.__.a.b"),
isShortError(_, "Cyclic module reference detected at A.b.a,")
)
}
test("crossedCyclicModuleRefInitError") {
val check = new Checker(TestGraphs.CrossedCyclicModuleRefInitError)
test - check.checkSeq0(
Seq("__"),
isShortError(_, "Cyclic module reference detected at cross[210].c2[210].c1,")
)
}
test("nonCyclicModules") {
val check = new Checker(TestGraphs.NonCyclicModules)
test - check(
"__",
Right(Set(_.foo))
)
}
test("moduleRefWithNonModuleRefChild") {
val check = new Checker(TestGraphs.ModuleRefWithNonModuleRefChild)
test - check(
"__",
Right(Set(_.foo))
)
}
test("moduleRefCycle") {
val check = new Checker(TestGraphs.ModuleRefCycle)
test - check(
"__",
Right(Set(_.foo))
)
test - check(
"__._",
Right(Set(_.foo))
)
}
}
}
94 changes: 94 additions & 0 deletions main/test/src/mill/util/TestGraphs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -667,4 +667,98 @@ object TestGraphs {
}
}

object CyclicModuleRefInitError extends TestBaseModule {
import mill.Agg
def foo = Task { "foo" }

// See issue: https://github.com/com-lihaoyi/mill/issues/3715
trait CommonModule extends TestBaseModule {
def foo = Task { "foo" }
def moduleDeps: Seq[CommonModule] = Seq.empty
def a = myA
def b = myB
}

object myA extends A
trait A extends CommonModule
object myB extends B
trait B extends CommonModule {
override def moduleDeps = super.moduleDeps ++ Agg(a)
}
}

object CyclicModuleRefInitError2 extends TestBaseModule {
// The cycle is in the child
def A = CyclicModuleRefInitError
}

object CyclicModuleRefInitError3 extends TestBaseModule {
// The cycle is in directly here
object A extends Module {
def b = B
}
object B extends Module {
def a = A
}
}

object CrossedCyclicModuleRefInitError extends TestBaseModule {
object cross extends mill.Cross[Cross]("210", "211", "212")
trait Cross extends Cross.Module[String] {
def suffix = Task { crossValue }
def c2 = cross2
}

object cross2 extends mill.Cross[Cross2]("210", "211", "212")
trait Cross2 extends Cross.Module[String] {
override def millSourcePath = super.millSourcePath / crossValue
def suffix = Task { crossValue }
def c1 = cross
}
}

// The module names repeat, but it's not actually cyclic and is meant to confuse the cycle detection.
object NonCyclicModules extends TestBaseModule {
def foo = Task { "foo" }

object A extends Module {
def b = B
}
object B extends Module {
object A extends Module {
def b = B
}
def a = A

object B extends Module {
object B extends Module {}
object A extends Module {
def b = B
}
def a = A
}
}
}

// This edge case shouldn't be an error
object ModuleRefWithNonModuleRefChild extends TestBaseModule {
def foo = Task { "foo" }

def aRef = A
def a = ModuleRef(A)

object A extends TestBaseModule {}
}

object ModuleRefCycle extends TestBaseModule {
def foo = Task { "foo" }

// The cycle is in directly here
object A extends Module {
def b = ModuleRef(B)
}
object B extends Module {
def a = ModuleRef(A)
}
}
}
Loading