Skip to content

Commit 4595e31

Browse files
neko-kaipshirshov
andauthored
Fix Activation-specific configs being lost in config writer (#2087)
* Bug: Activation-specific configs lost in config writer * fix, but some cleanups might be required. Reboot removed * wip * wip * wip * wip --------- Co-authored-by: Pavel Shirshov <[email protected]>
1 parent 80ed8ba commit 4595e31

File tree

13 files changed

+499
-353
lines changed

13 files changed

+499
-353
lines changed

distage/distage-core/src/main/scala/izumi/distage/bootstrap/BootstrapLocator.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import izumi.distage.model.reflection.{DIKey, MirrorProvider}
1616
import izumi.distage.planning.*
1717
import izumi.distage.planning.sequential.{ForwardingRefResolverDefaultImpl, FwdrefLoopBreaker, SanityCheckerDefaultImpl}
1818
import izumi.distage.planning.solver.SemigraphSolver.SemigraphSolverImpl
19-
import izumi.distage.planning.solver.{GraphPreparations, PlanSolver, SemigraphSolver}
19+
import izumi.distage.planning.solver.{GraphQueries, PlanSolver, SemigraphSolver}
2020
import izumi.distage.provisioning.*
2121
import izumi.distage.provisioning.strategies.*
2222
import izumi.fundamentals.platform.functional.Identity
@@ -71,7 +71,7 @@ object BootstrapLocator {
7171
val sanityChecker = new SanityCheckerDefaultImpl()
7272
val resolver = new PlanSolver.Impl(
7373
new SemigraphSolverImpl[DIKey, Int, InstantiationOp](),
74-
new GraphPreparations(new BindingTranslator.Impl()),
74+
new GraphQueries(new BindingTranslator.Impl()),
7575
)
7676

7777
new PlannerDefaultImpl(
@@ -111,7 +111,7 @@ object BootstrapLocator {
111111
make[MirrorProvider].fromValue(mirrorProvider)
112112

113113
make[PlanSolver].from[PlanSolver.Impl]
114-
make[GraphPreparations]
114+
make[GraphQueries]
115115

116116
make[SemigraphSolver[DIKey, Int, InstantiationOp]].from[SemigraphSolverImpl[DIKey, Int, InstantiationOp]]
117117

distage/distage-core/src/main/scala/izumi/distage/planning/SubcontextHandler.scala

+7
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,11 @@ object SubcontextHandler {
6363

6464
}
6565
}
66+
67+
class TracingHandler() extends SubcontextHandler[Nothing] {
68+
override def handle(binding: Binding, c: ImplDef.ContextImpl): Either[Nothing, SingletonWiring] = {
69+
Right(SingletonWiring.PrepareSubcontext(c.extractingFunction, Plan.empty, c.implType, c.externalKeys, Set.empty))
70+
}
71+
}
72+
6673
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
package izumi.distage.planning.solver
2+
3+
import izumi.distage.model.definition.ModuleBase
4+
import izumi.distage.model.definition.conflicts.{Annotated, Node}
5+
import izumi.distage.model.plan.ExecutableOp.{CreateSet, InstantiationOp}
6+
import izumi.distage.model.plan.Roots
7+
import izumi.distage.model.planning.PlanIssue.*
8+
import izumi.distage.model.planning.{ActivationChoices, AxisPoint, PlanIssue}
9+
import izumi.distage.model.reflection.{DIKey, SafeType}
10+
import izumi.distage.planning.SubcontextHandler
11+
import izumi.distage.planning.solver.GenericSemigraphTraverse.{TraversalFailure, TraversalResult}
12+
import izumi.distage.planning.solver.SemigraphSolver.SemiEdgeSeq
13+
import izumi.distage.provisioning.strategies.ImportStrategyDefaultImpl
14+
import izumi.functional.IzEither.*
15+
import izumi.fundamentals.collections.IzCollections.*
16+
import izumi.fundamentals.collections.nonempty.{NEList, NESet}
17+
import izumi.fundamentals.collections.{ImmutableMultiMap, MutableMultiMap}
18+
import izumi.reflect.TagK
19+
20+
import java.util.concurrent.TimeUnit
21+
import scala.annotation.nowarn
22+
import scala.collection.mutable
23+
import scala.concurrent.duration.FiniteDuration
24+
25+
object GenericSemigraphTraverse {
26+
case class TraversalResult(visitedKeys: Set[DIKey], time: FiniteDuration, maybeIssues: Option[NESet[PlanIssue]])
27+
case class TraversalFailure[Err](time: FiniteDuration, issues: NEList[Err])
28+
}
29+
30+
abstract class GenericSemigraphTraverse[Err](
31+
queries: GraphQueries,
32+
subcontextHandler: SubcontextHandler[Err],
33+
) {
34+
35+
def traverse[F[_]: TagK](
36+
bindings: ModuleBase,
37+
roots: Roots,
38+
providedKeys: DIKey => Boolean,
39+
excludedActivations: Set[NESet[AxisPoint]],
40+
): Either[TraversalFailure[Err], TraversalResult] = {
41+
val before = System.currentTimeMillis()
42+
(for {
43+
ops <- queries.computeOperationsUnsafe(subcontextHandler, bindings).map(_.toSeq)
44+
} yield {
45+
val allAxis: Map[String, Set[String]] = ops.flatMap(_._1.axis).groupBy(_.axis).map {
46+
case (axis, points) =>
47+
(axis, points.map(_.value).toSet)
48+
}
49+
val (mutators, defns) = ops.partition(_._3.isMutator)
50+
val justOps = defns.map { case (k, op, _) => k -> op }
51+
52+
val setOps = queries
53+
.computeSetsUnsafe(justOps)
54+
.map {
55+
case (k, (s, _)) =>
56+
(Annotated(k, None, Set.empty), Node(s.members, s))
57+
58+
}.toMultimapView
59+
.map {
60+
case (k, v) =>
61+
val members = v.flatMap(_.deps).toSet
62+
(k, Node(members, v.head.meta.copy(members = members): InstantiationOp))
63+
}
64+
.toSeq
65+
66+
val opsMatrix: Seq[(Annotated[DIKey], Node[DIKey, InstantiationOp])] = queries.toDeps(justOps)
67+
68+
val matrix: SemiEdgeSeq[Annotated[DIKey], DIKey, InstantiationOp] = SemiEdgeSeq(opsMatrix ++ setOps)
69+
70+
val matrixToTrace = defns.map { case (k, op, _) => (k.key, (op, k.axis)) }.toMultimap
71+
val justMutators = mutators.map { case (k, op, _) => (k.key, (op, k.axis)) }.toMultimap
72+
73+
val rootKeys: Set[DIKey] = queries.getRoots(roots, justOps)
74+
val execOpIndex: MutableMultiMap[DIKey, InstantiationOp] = queries.executableOpIndex(matrix)
75+
76+
val mutVisited = mutable.HashSet.empty[(DIKey, Set[AxisPoint])]
77+
val effectType = SafeType.getK[F]
78+
79+
val issues =
80+
trace(allAxis, mutVisited, matrixToTrace, execOpIndex, justMutators, providedKeys, excludedActivations, rootKeys, effectType, bindings)
81+
82+
val visitedKeys: Set[DIKey] = mutVisited.map(_._1).toSet
83+
val after = System.currentTimeMillis()
84+
val time: FiniteDuration = FiniteDuration(after - before, TimeUnit.MILLISECONDS)
85+
86+
val maybeIssues: Option[NESet[PlanIssue]] = NESet.from(issues)
87+
88+
TraversalResult(visitedKeys, time, maybeIssues)
89+
}).left.map {
90+
errs => TraversalFailure(FiniteDuration(System.currentTimeMillis() - before, TimeUnit.MILLISECONDS), errs)
91+
}
92+
}
93+
94+
@nowarn("msg=Unused import")
95+
protected[this] def trace(
96+
allAxis: Map[String, Set[String]],
97+
allVisited: mutable.HashSet[(DIKey, Set[AxisPoint])],
98+
matrix: ImmutableMultiMap[DIKey, (InstantiationOp, Set[AxisPoint])],
99+
execOpIndex: MutableMultiMap[DIKey, InstantiationOp],
100+
justMutators: ImmutableMultiMap[DIKey, (InstantiationOp, Set[AxisPoint])],
101+
providedKeys: DIKey => Boolean,
102+
excludedActivations: Set[NESet[AxisPoint]],
103+
rootKeys: Set[DIKey],
104+
effectType: SafeType,
105+
bindings: ModuleBase,
106+
): Set[PlanIssue] = {
107+
import scala.collection.compat.*
108+
109+
@inline def go(visited: Set[DIKey], current: Set[(DIKey, DIKey)], currentActivation: Set[AxisPoint]): RecursionResult = RecursionResult(current.iterator.map {
110+
case (key, dependee) =>
111+
if (visited.contains(key) || allVisited.contains((key, currentActivation))) {
112+
Right(Iterator.empty)
113+
} else {
114+
@inline def reportMissing[A](key: DIKey, dependee: DIKey): Left[NEList[MissingImport], Nothing] = {
115+
val origins = queries.allImportingBindings(matrix, currentActivation)(key, dependee)
116+
val similarBindings = ImportStrategyDefaultImpl.findSimilarImports(bindings, key)
117+
Left(NEList(MissingImport(key, dependee, origins, similarBindings.similarSame, similarBindings.similarSub)))
118+
}
119+
120+
@inline def reportMissingIfNotProvided[A](key: DIKey, dependee: DIKey)(orElse: => Either[NEList[PlanIssue], A]): Either[NEList[PlanIssue], A] = {
121+
if (providedKeys(key)) orElse else reportMissing(key, dependee)
122+
}
123+
124+
matrix.get(key) match {
125+
case None =>
126+
reportMissingIfNotProvided(key, dependee)(Right(Iterator.empty))
127+
128+
case Some(allOps) =>
129+
val ops = allOps.filterNot(o => queries.isIgnoredActivation(excludedActivations, o._2))
130+
val ac = ActivationChoices(currentActivation)
131+
132+
val withoutCurrentActivations = {
133+
val withoutImpossibleActivationsIter = ops.iterator.filter(ac `allValid` _._2)
134+
withoutImpossibleActivationsIter.map {
135+
case (op, activations) =>
136+
(op, activations diff currentActivation, activations)
137+
}.toSet
138+
}
139+
140+
for {
141+
// we ignore activations for set definitions
142+
opsWithMergedSets <- {
143+
val (setOps, otherOps) = withoutCurrentActivations.partitionMap {
144+
case (s: CreateSet, _, _) => Left(s)
145+
case a => Right(a)
146+
}
147+
for {
148+
mergedSets <- setOps.groupBy(_.target).values.biTraverse {
149+
ops =>
150+
for {
151+
members <- ops.iterator
152+
.flatMap(_.members)
153+
.biFlatTraverse {
154+
memberKey =>
155+
matrix.get(memberKey) match {
156+
case Some(value) if value.sizeIs == 1 =>
157+
if (ac.allValid(value.head._2)) Right(List(memberKey)) else Right(Nil)
158+
case Some(value) =>
159+
Left(NEList(InconsistentSetMembers(memberKey, NEList.unsafeFrom(value.iterator.map(_._1.origin.value).toList))))
160+
case None =>
161+
reportMissingIfNotProvided(memberKey, key)(Right(List(memberKey)))
162+
}
163+
}.to(Set)
164+
} yield {
165+
(ops.head.copy(members = members), Set.empty[AxisPoint], Set.empty[AxisPoint])
166+
}
167+
}
168+
} yield otherOps ++ mergedSets
169+
}
170+
_ <-
171+
verifyStep(currentActivation, providedKeys, key, dependee, reportMissing, ops, opsWithMergedSets)
172+
next <- checkConflicts(allAxis, opsWithMergedSets, execOpIndex, excludedActivations, effectType)
173+
} yield {
174+
allVisited.add((key, currentActivation))
175+
176+
val mutators =
177+
justMutators.getOrElse(key, Set.empty).iterator.filter(ac `allValid` _._2).flatMap(m => queries.depsOf(execOpIndex, m._1)).toSeq
178+
179+
val goNext = next.iterator.map {
180+
case (nextActivation, nextDeps) =>
181+
() =>
182+
go(
183+
visited = visited + key,
184+
current = (nextDeps ++ mutators).map(_ -> key),
185+
currentActivation = currentActivation ++ nextActivation,
186+
)
187+
}
188+
189+
goNext
190+
}
191+
}
192+
}
193+
})
194+
195+
// for trampoline
196+
sealed trait RecResult {
197+
type RecursionResult <: Iterator[Either[NEList[PlanIssue], Iterator[() => RecursionResult]]]
198+
}
199+
type RecursionResult = RecResult#RecursionResult
200+
@inline def RecursionResult(a: Iterator[Either[NEList[PlanIssue], Iterator[() => RecursionResult]]]): RecursionResult = a.asInstanceOf[RecursionResult]
201+
202+
// trampoline
203+
val errors = Set.newBuilder[PlanIssue]
204+
val remainder = mutable.Stack(() => go(Set.empty, Set.from(rootKeys.map(r => r -> r)), Set.empty))
205+
206+
while (remainder.nonEmpty) {
207+
val i = remainder.pop().apply()
208+
while (i.hasNext) {
209+
i.next() match {
210+
case Right(nextSteps) =>
211+
remainder pushAll nextSteps
212+
case Left(newErrors) =>
213+
errors ++= newErrors
214+
}
215+
}
216+
}
217+
218+
errors.result()
219+
}
220+
221+
protected def verifyStep(
222+
currentActivation: Set[AxisPoint],
223+
providedKeys: DIKey => Boolean,
224+
key: DIKey,
225+
dependee: DIKey,
226+
reportMissing: (DIKey, DIKey) => Left[NEList[MissingImport], Nothing],
227+
ops: Set[(InstantiationOp, Set[AxisPoint])],
228+
opsWithMergedSets: Set[(InstantiationOp, Set[AxisPoint], Set[AxisPoint])],
229+
): Either[NEList[PlanIssue], Unit]
230+
231+
protected def checkConflicts(
232+
allAxis: Map[String, Set[String]],
233+
withoutCurrentActivations: Set[(InstantiationOp, Set[AxisPoint], Set[AxisPoint])],
234+
execOpIndex: MutableMultiMap[DIKey, InstantiationOp],
235+
excludedActivations: Set[NESet[AxisPoint]],
236+
effectType: SafeType,
237+
): Either[NEList[PlanIssue], Seq[(Set[AxisPoint], Set[DIKey])]]
238+
239+
}

distage/distage-core/src/main/scala/izumi/distage/planning/solver/GraphPreparations.scala distage/distage-core/src/main/scala/izumi/distage/planning/solver/GraphQueries.scala

+67-3
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,90 @@ import izumi.distage.model.definition.BindingTag.AxisTag
44
import izumi.distage.model.definition.conflicts.{Annotated, Node}
55
import izumi.distage.model.definition.{Binding, ModuleBase}
66
import izumi.distage.model.plan.ExecutableOp.{CreateSet, InstantiationOp, MonadicOp, WiringOp}
7+
import izumi.distage.model.plan.operations.OperationOrigin
78
import izumi.distage.model.plan.{ExecutableOp, Roots, Wiring}
89
import izumi.distage.model.planning.AxisPoint
910
import izumi.distage.model.reflection.DIKey
11+
import izumi.distage.model.reflection.DIKey.SetElementKey
1012
import izumi.distage.planning.solver.SemigraphSolver.SemiEdgeSeq
1113
import izumi.distage.planning.{BindingTranslator, SubcontextHandler}
1214
import izumi.functional.IzEither.*
13-
import izumi.fundamentals.collections.MutableMultiMap
14-
import izumi.fundamentals.collections.nonempty.NEList
15+
import izumi.fundamentals.collections.{ImmutableMultiMap, MutableMultiMap}
16+
import izumi.fundamentals.collections.nonempty.{NEList, NESet}
1517
import izumi.fundamentals.graphs.WeakEdge
1618
import izumi.fundamentals.graphs.struct.IncidenceMatrix
1719
import izumi.fundamentals.graphs.tools.gc.Tracer
1820

1921
import scala.annotation.nowarn
2022

2123
@nowarn("msg=Unused import")
22-
class GraphPreparations(
24+
class GraphQueries(
2325
bindingTranslator: BindingTranslator
2426
) {
2527

2628
import scala.collection.compat.*
29+
final def isIgnoredActivation(excludedActivations: Set[NESet[AxisPoint]], activation: Set[AxisPoint]): Boolean = {
30+
excludedActivations.exists(_ subsetOf activation)
31+
}
32+
33+
final def allImportingBindings(
34+
matrix: ImmutableMultiMap[DIKey, (InstantiationOp, Set[AxisPoint])],
35+
currentActivation: Set[AxisPoint],
36+
)(importedKey: DIKey,
37+
d: DIKey,
38+
): Set[OperationOrigin] = {
39+
// FIXME: reuse formatting from conflictingAxisTagsHint
40+
matrix
41+
.getOrElse(d, Set.empty)
42+
.collect {
43+
case (op, activations) if activations.subsetOf(currentActivation) && (op match {
44+
case CreateSet(_, members, _) => members
45+
case op: ExecutableOp.WiringOp => op.wiring.requiredKeys
46+
case op: ExecutableOp.MonadicOp => Set(op.effectKey)
47+
}).contains(importedKey) =>
48+
op.origin.value
49+
}
50+
}
51+
52+
def nextDepsToVisit(
53+
execOpIndex: MutableMultiMap[DIKey, InstantiationOp],
54+
withoutCurrentActivations: Set[(InstantiationOp, Set[AxisPoint], Set[AxisPoint])],
55+
): Right[Nothing, Seq[(Set[AxisPoint], Set[DIKey])]] = {
56+
val next = withoutCurrentActivations.iterator.map {
57+
case (op, activations, _) =>
58+
// TODO: I'm not sure if it's "correct" to "activate" all the points together but it simplifies things greatly
59+
val deps = depsOf(execOpIndex, op)
60+
val acts = op match {
61+
case _: CreateSet =>
62+
Set.empty[AxisPoint]
63+
case _ =>
64+
activations
65+
}
66+
(acts, deps)
67+
}.toSeq
68+
Right(next)
69+
}
70+
71+
final def depsOf(
72+
execOpIndex: MutableMultiMap[DIKey, InstantiationOp],
73+
op: InstantiationOp,
74+
): Set[DIKey] = {
75+
op match {
76+
case cs: CreateSet =>
77+
// we completely ignore weak members, they don't make any difference in case they are unreachable through other paths
78+
val members = cs.members.filter {
79+
case m: SetElementKey =>
80+
getSetElementWeakEdges(execOpIndex, m).isEmpty
81+
case _ =>
82+
true
83+
}
84+
members
85+
case op: ExecutableOp.WiringOp =>
86+
toDep(op).deps
87+
case op: ExecutableOp.MonadicOp =>
88+
toDep(op).deps
89+
}
90+
}
2791

2892
def findWeakSetMembers(
2993
setOps: Map[Annotated[DIKey], Node[DIKey, InstantiationOp]],

distage/distage-core/src/main/scala/izumi/distage/planning/solver/PlanSolver.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ object PlanSolver {
3737

3838
@nowarn("msg=Unused import")
3939
class Impl(
40-
resolver: SemigraphSolver[DIKey, Int, InstantiationOp],
41-
preps: GraphPreparations,
40+
resolver: SemigraphSolver[DIKey, Int, InstantiationOp],
41+
preps: GraphQueries,
4242
) extends PlanSolver {
4343

4444
import scala.collection.compat.*

0 commit comments

Comments
 (0)