Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing DeclarationSorter #777

Merged
merged 10 commits into from
Apr 28, 2021
1 change: 1 addition & 0 deletions UNRELEASED.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
* Parser: supporting annotations in multiline comments, see #718
* Parser: supporting TLA+ identifiers in annotations, see #768
* Parser: better parser for annotations, see #757
* Parser: fixed two bugs in the declaration sorter, see #645 and #758
* The command `config --enable-stats=true` creates `$HOME/.tlaplus` if needed, see #762

### Changes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class EtcTypeChecker(varPool: TypeVarPool, inferPolytypes: Boolean = false) exte
computeRec(ctx, solver, mkConst(ex.sourceRef, knownType))
}
} else {
onTypeError(ex.sourceRef, s"Undefined name $name. Introduce a type annotation.")
onTypeError(ex.sourceRef, s"Found no annotation for $name. Did you write one?")
konnov marked this conversation as resolved.
Show resolved Hide resolved
throw new UnwindException
}

Expand Down Expand Up @@ -170,7 +170,7 @@ class EtcTypeChecker(varPool: TypeVarPool, inferPolytypes: Boolean = false) exte
onTypeFound(name.sourceRef, nameType)
computeRec(ctx, solver, mkApp(ex.sourceRef, Seq(nameType), args: _*))
} else {
onTypeError(ex.sourceRef, s"Undefined operator name $name. Introduce a type annotation.")
onTypeError(ex.sourceRef, s"It looks like the operator $name is used before it is defined. Is it true?")
konnov marked this conversation as resolved.
Show resolved Hide resolved
throw new UnwindException
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package at.forsyte.apalache.tla.lir.transformations.standard

import at.forsyte.apalache.tla.lir.transformations.TlaModuleTransformation
import at.forsyte.apalache.tla.lir._
import at.forsyte.apalache.tla.lir.oper.TlaOper
import at.forsyte.apalache.tla.lir.transformations.TlaModuleTransformation
import at.forsyte.apalache.tla.lir.transformations.impl.StableTopologicalSort
import com.typesafe.scalalogging.LazyLogging

import java.io.{FileWriter, PrintWriter}
import scala.collection.immutable.HashMap

/**
Expand All @@ -19,12 +21,13 @@ import scala.collection.immutable.HashMap
*
* @author Igor Konnov
*/
class DeclarationSorter extends TlaModuleTransformation {
class DeclarationSorter extends TlaModuleTransformation with LazyLogging {
type Edges = Map[UID, Set[UID]]
type MapIdToDecl = Map[UID, TlaDecl]

override def apply(input: TlaModule): TlaModule = {
// save the original order of the declarations
val idToDecl: Map[UID, TlaDecl] = input.declarations.foldLeft(Map.empty[UID, TlaDecl]) { case (map, d) =>
val idToDecl: MapIdToDecl = input.declarations.foldLeft(Map.empty[UID, TlaDecl]) { case (map, d) =>
map + (d.ID -> d)
}

Expand All @@ -38,13 +41,46 @@ class DeclarationSorter extends TlaModuleTransformation {
sorted.map(idToDecl)

case Right(witnesses) =>
val filename = explainCyclicDependency(idToDecl, depsGraph, witnesses)
logger.error(s"Check the dependency graph in $filename. You can view it with graphviz.")
val opers = witnesses.map(idToDecl).map(_.name).mkString(", ")
throw new CyclicDependencyError("Found a cyclic dependency among operators: " + opers)
}

new TlaModule(input.name, sortedDeclarations)
}

// Dump the dependency graph to a dot file. Otherwise, it is very hard to see what is going on.
private def explainCyclicDependency(idToDecl: MapIdToDecl, depsGraph: Edges, witnesses: Set[UID]): String = {
val filename = "dependencies.dot"

def getName(uid: UID): String = {
idToDecl.get(uid).map(_.name).getOrElse("undefined")
}

def printToDot(writer: PrintWriter): Unit = {
writer.println("digraph dependencies {")

for (fromId <- witnesses) {
writer.println(""" n%s[label="%s"];""".format(fromId, getName(fromId)))
for (toId <- depsGraph.getOrElse(fromId, Set.empty)) {
if (witnesses.contains(toId)) {
writer.println(""" n%s -> n%s;""".format(fromId, toId))
}
}
}
writer.println("}")
}

val writer = new PrintWriter(new FileWriter(filename, false))
try {
printToDot(writer)
filename
} finally {
writer.close()
}
}

// For every declaration ID id1, compute the list of distinct ID of the declarations that must be defined before id1
private def computeDependenciesGraph(declarations: Seq[TlaDecl]): Edges = {
// create a map from declaration names to their IDs
Expand All @@ -56,7 +92,6 @@ class DeclarationSorter extends TlaModuleTransformation {
map + (defId -> (map(defId) ++ uses))
}

val findUses = findExprUses(nameToId)
// create a map that contains the list of used-by IDs for every declaration, excluding the declaration itself
val initMap = Map(declarations.map { d => d.ID -> Set.empty[UID] }: _*)
declarations.foldLeft(initMap) {
Expand All @@ -67,11 +102,13 @@ class DeclarationSorter extends TlaModuleTransformation {
map + (d.ID -> Set.empty[UID])

case (map, d @ TlaAssumeDecl(body)) =>
val uses = (findUses(body) - d.ID)
val uses = findExprUses(nameToId)(body) - d.ID
updateDependencies(map, d.ID, uses)

case (map, d @ TlaOperDecl(_, _, body)) =>
val uses = (findUses(body) - d.ID)
case (map, d @ TlaOperDecl(_, params, body)) =>
// the operator parameters may shadow the name of top-level operator, so we have to exclude parameter names
val nameToIdWithoutParams = nameToId -- params.map(_.name)
val uses = findExprUses(nameToIdWithoutParams)(body) - d.ID
updateDependencies(map, d.ID, uses)

case (map, _) => map
Expand All @@ -86,10 +123,14 @@ class DeclarationSorter extends TlaModuleTransformation {
// A singleton with the id, if the name is registered; otherwise, empty set.
nameToId.get(name).fold(Set.empty[UID])(Set(_))

case OperEx(TlaOper.apply, NameEx(name), _*) =>
case OperEx(TlaOper.apply, NameEx(name), args @ _*) =>
// This may be an application of a user-defined operator.
// A singleton with the id, if the name is registered; otherwise, empty set.
nameToId.get(name).fold(Set.empty[UID])(Set(_))
// A singleton with the id, if the name is registered
// (that is, for a top-level definition); otherwise, return the empty set (that is, for a LET-IN).
val base = nameToId.get(name).map(Set(_)).getOrElse(Set.empty)
args.foldLeft(base) { (u, arg) =>
u ++ usesRec(arg)
}

case OperEx(_, args @ _*) =>
// join the uses of the arguments
Expand All @@ -98,9 +139,12 @@ class DeclarationSorter extends TlaModuleTransformation {
}

case LetInEx(body, decls @ _*) =>
// join the uses of the body and the declarations
// Join the uses of the body and the declarations.
// We do not track dependencies between the LET-IN operators, because they are scoped and ordered.
decls.foldLeft(usesRec(body)) { (u, d) =>
u ++ usesRec(d.body)
// the operator parameters may shadow the name of top-level operator, so we have to exclude parameter names
val nameToIdWithoutParams = nameToId -- d.formalParams.map(_.name)
u ++ findExprUses(nameToIdWithoutParams)(d.body)
}

case _ =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,28 @@ class TestDeclarationSorter extends FunSuite with BeforeAndAfterEach {
assertThrows[CyclicDependencyError](DeclarationSorter.instance(input))
}

test("regression: a cycle hidden via a call") {
// regression for #758
val foo = tla.declOp("Foo", tla.int(1))
val bar = tla.declOp("Bar", tla.appOp(tla.name("Foo"), tla.appOp(tla.name("Baz"))))
val baz = tla.declOp("Baz", tla.appOp(tla.name("Foo"), tla.appOp(tla.name("Bar"))))
val input = new TlaModule("test", List(foo, bar, baz))
assertThrows[CyclicDependencyError](DeclarationSorter.instance(input))
}

test("regression: a false cycle") {
// regression for #645

// The following two declarations do not form a cycle, as Foo uses it's parameter 'pid', not calling the operator 'pid'.
// Foo(pid) == 1
val foo = tla.declOp("Foo", tla.name("pid"), OperParam("pid"))
// pid == Foo(2)
val pid = tla.declOp("pid", tla.appOp(tla.name("Foo"), tla.int(2)))
val input = new TlaModule("test", List(foo, pid))
val expected = new TlaModule("test", List(foo, pid))
assert(expected == DeclarationSorter.instance(input))
}

test("Foo uses VARIABLE x out of order") {
val foo = tla.declOp("Foo", tla.name("x"))
val x = TlaVarDecl("x")
Expand Down