Skip to content

Commit

Permalink
Promote Desugarer to be the first pass (#536)
Browse files Browse the repository at this point in the history
closes #531
  • Loading branch information
konnov authored Feb 3, 2021
1 parent b7d8974 commit f9cd507
Show file tree
Hide file tree
Showing 10 changed files with 427 additions and 296 deletions.
2 changes: 2 additions & 0 deletions UNRELEASED.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
* new command-line option for `typecheck`:
- enable inference of polymorphic types: `--infer-poly`
* updates to ADR002 and the manual
* support for parallel assignments `<<x', y'>> = <<1, 2>>`, see #531

### Bugfixes

* Boolean values are now supported in TLC config files, see #512
* Promoting Desugarer to run as the first preprocessing pass, see #531
* Proper error on invalid type annotations, the parser is strengthened with Scalacheck, see #332
* Typechecking quantifiers over tuples, see #482
* Fixed a parsing bug for strings that contain '-', see #539
12 changes: 12 additions & 0 deletions test/tla/Fix531.tla
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
-------------------------------- MODULE Fix531 --------------------------------
EXTENDS Integers

VARIABLE f

Init ==
f = [<<a, b>> \in { <<1, 3>>, <<2, 4>> } |-> a + b]

Next ==
UNCHANGED f

===============================================================================
8 changes: 8 additions & 0 deletions test/tla/cli-integration-tests.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ EXITCODE: OK

## running the check command

### check Fix531.tla reports no error: regression for issue 531

```sh
$ apalache-mc check --length=1 Fix531.tla | sed 's/I@.*//'
...
The outcome is: NoError
...
```

### check UnchangedExpr471.tla reports no error: regression for issue 471

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,17 @@ class CheckerModule extends AbstractModule {
bind(classOf[Pass])
.annotatedWith(Names.named("AfterParser"))
.to(classOf[ConfigurationPass])
// the next pass is DesugarerPass
bind(classOf[DesugarerPass])
.to(classOf[DesugarerPassImpl])
bind(classOf[Pass])
.annotatedWith(Names.named("AfterConfiguration"))
.to(classOf[DesugarerPass])
// the next pass is UnrollPass
bind(classOf[UnrollPass])
.to(classOf[UnrollPassImpl])
bind(classOf[Pass])
.annotatedWith(Names.named("AfterConfiguration"))
.annotatedWith(Names.named("AfterDesugarer"))
.to(classOf[UnrollPass])
// the next pass is PrimingPass
bind(classOf[PrimingPass])
Expand Down
238 changes: 129 additions & 109 deletions tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/Desugarer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,109 +2,121 @@ package at.forsyte.apalache.tla.pp

import at.forsyte.apalache.tla.lir._
import at.forsyte.apalache.tla.lir.convenience._
import at.forsyte.apalache.tla.lir.oper.{TlaActionOper, TlaBoolOper, TlaFunOper, TlaOper, TlaSetOper}
import at.forsyte.apalache.tla.lir.transformations.standard.FlatLanguagePred
import at.forsyte.apalache.tla.lir.transformations.{LanguageWatchdog, TlaExTransformation, TransformationTracker}
import at.forsyte.apalache.tla.lir.oper._
import at.forsyte.apalache.tla.lir.transformations.{TlaExTransformation, TransformationTracker}

import javax.inject.Singleton

/**
* <p>Remove all annoying syntactic sugar. In the future we should move most of the pre-processing code to this class,
* unless it really changes the expressive power.</p>
*
* <p>This transformation assumes that all operator definitions and internal
* let-definitions have been inlined.</p>
*
* <p>TODO: can we make transformation tracking more precise?</p>
*
* @author Igor Konnov
*/
* <p>Remove all annoying syntactic sugar.</p>
*
* @author Igor Konnov
*/
@Singleton
class Desugarer(tracker: TransformationTracker) extends TlaExTransformation {

override def apply(expr: TlaEx): TlaEx = {
LanguageWatchdog(FlatLanguagePred()).check(expr)
transform(expr)
}

def transform: TlaExTransformation = tracker.trackEx {
case ex @ NameEx(_) => ex
case ex @ ValEx(_) => ex
case ex @ NullEx => ex

case OperEx(TlaFunOper.except, fun, args @ _*) =>
val trArgs = args map transform
val (accessors, newValues) = TlaOper.deinterleave(trArgs)
val nonSingletons = accessors.collect { case OperEx(TlaFunOper.tuple, lst @ _*) => lst.size > 1 }
if (nonSingletons.isEmpty) {
// only singleton tuples, construct the same EXCEPT, but with transformed fun and args
OperEx(TlaFunOper.except, transform(fun) +: trArgs :_*)
} else {
// multiple accesses, e.g., ![i][j] = ...
expandExcept(transform(fun), accessors, newValues)
}
case ex @ NameEx(_) => ex
case ex @ ValEx(_) => ex
case ex @ NullEx => ex

case OperEx(TlaFunOper.except, fun, args @ _*) =>
val trArgs = args map transform
val (accessors, newValues) = TlaOper.deinterleave(trArgs)
val nonSingletons = accessors.collect { case OperEx(TlaFunOper.tuple, lst @ _*) => lst.size > 1 }
if (nonSingletons.isEmpty) {
// only singleton tuples, construct the same EXCEPT, but with transformed fun and args
OperEx(TlaFunOper.except, transform(fun) +: trArgs: _*)
} else {
// multiple accesses, e.g., ![i][j] = ...
expandExcept(transform(fun), accessors, newValues)
}

case OperEx(TlaActionOper.unchanged, args @ _*) =>
// flatten all tuples, e.g., convert <<x, <<y, z>> >> to [x, y, z]
val flatArgs = flattenTuplesInUnchanged(tla.tuple(args.map(transform) :_*))
// map every x to x' = x
val eqs = flatArgs map { x => tla.eql(tla.prime(x), x) }
// x' = x /\ y' = y /\ z' = z
eqs match {
case Seq() =>
// results from UNCHANGED <<>>, UNCHANGED << << >> >>, etc.
tla.bool(true)

case Seq(one) =>
one

case _ =>
tla.and(eqs: _*)
}
case OperEx(TlaActionOper.unchanged, args @ _*) =>
// flatten all tuples, e.g., convert <<x, <<y, z>> >> to [x, y, z]
val flatArgs = flattenTuplesInUnchanged(tla.tuple(args.map(transform): _*))
// map every x to x' = x
val eqs = flatArgs map { x => tla.eql(tla.prime(x), x) }
// x' = x /\ y' = y /\ z' = z
eqs match {
case Seq() =>
// results from UNCHANGED <<>>, UNCHANGED << << >> >>, etc.
tla.bool(true)

case Seq(one) =>
one

case _ =>
tla.and(eqs: _*)
}

case OperEx(TlaSetOper.filter, boundEx, setEx, predEx) =>
// rewrite { <<x, <<y, z>> >> \in XYZ: P(x, y, z) }
// to { x_y_z \in XYZ: P(x_y_z[1], x_y_z[1][1], x_y_z[1][2]) }
OperEx(TlaSetOper.filter, collapseTuplesInFilter(transform(boundEx), transform(setEx), transform(predEx)) :_*)

case OperEx(TlaBoolOper.exists, boundEx, setEx, predEx) =>
// rewrite \E <<x, <<y, z>> >> \in XYZ: P(x, y, z)
// to \E x_y_z \in XYZ: P(x_y_z[1], x_y_z[1][1], x_y_z[1][2])
OperEx(TlaBoolOper.exists, collapseTuplesInFilter(transform(boundEx), transform(setEx), transform(predEx)) :_*)

case OperEx(TlaBoolOper.forall, boundEx, setEx, predEx) =>
// rewrite \A <<x, <<y, z>> >> \in XYZ: P(x, y, z)
// to \A x_y_z \in XYZ: P(x_y_z[1], x_y_z[1][1], x_y_z[1][2])
OperEx(TlaBoolOper.forall, collapseTuplesInFilter(transform(boundEx), transform(setEx), transform(predEx)) :_*)

case OperEx(TlaSetOper.map, args @ _*) =>
// rewrite { <<x, <<y, z>> >> \in XYZ |-> e(x, y, z) }
// to { x_y_z \in XYZ |-> e(x_y_z[1], x_y_z[1][1], x_y_z[1][2])
val trArgs = args map transform
OperEx(TlaSetOper.map, collapseTuplesInMap(trArgs.head, trArgs.tail) :_*)

case OperEx(funDefOp, args @ _*) if (funDefOp == TlaFunOper.funDef || funDefOp == TlaFunOper.recFunDef) =>
val trArgs = args map transform
val fun = trArgs.head
val (vars, sets) = TlaOper.deinterleave(trArgs.tail)
val (onlyVar, onlySet) =
if (vars.length > 1) {
val pair = (tla.tuple(vars :_*), tla.times(sets :_*))
// track the modification to point to the first variable and set
tracker.hold(vars.head, pair._1)
tracker.hold(sets.head, pair._2)
pair
} else {
(vars.head, sets.head)
}
// transform the function into a single-argument function and collapse tuples
OperEx(funDefOp, collapseTuplesInMap(fun, Seq(onlyVar, onlySet)) :_*)
case OperEx(TlaOper.eq, OperEx(TlaFunOper.tuple, largs @ _*), OperEx(TlaFunOper.tuple, rargs @ _*)) =>
// <<e_1, ..., e_k>> = <<f_1, ..., f_n>>
// produce pairwise comparison
if (largs.length != rargs.length) {
tla.bool(false)
} else {
val eqs = largs.zip(rargs) map { case (l, r) => tla.eql(this(l), this(r)) }
tla.and(eqs: _*)
}

case OperEx(op, args @ _*) =>
OperEx(op, args map transform :_*)
case OperEx(TlaOper.ne, OperEx(TlaFunOper.tuple, largs @ _*), OperEx(TlaFunOper.tuple, rargs @ _*)) =>
// <<e_1, ..., e_k>> /= <<f_1, ..., f_n>>
// produce pairwise comparison
if (largs.length != rargs.length) {
tla.bool(true)
} else {
val neqs = largs.zip(rargs) map { case (l, r) => tla.neql(this(l), this(r)) }
tla.or(neqs: _*)
}

case OperEx(TlaSetOper.filter, boundEx, setEx, predEx) =>
// rewrite { <<x, <<y, z>> >> \in XYZ: P(x, y, z) }
// to { x_y_z \in XYZ: P(x_y_z[1], x_y_z[1][1], x_y_z[1][2]) }
OperEx(TlaSetOper.filter, collapseTuplesInFilter(transform(boundEx), transform(setEx), transform(predEx)): _*)

case OperEx(TlaBoolOper.exists, boundEx, setEx, predEx) =>
// rewrite \E <<x, <<y, z>> >> \in XYZ: P(x, y, z)
// to \E x_y_z \in XYZ: P(x_y_z[1], x_y_z[1][1], x_y_z[1][2])
OperEx(TlaBoolOper.exists, collapseTuplesInFilter(transform(boundEx), transform(setEx), transform(predEx)): _*)

case OperEx(TlaBoolOper.forall, boundEx, setEx, predEx) =>
// rewrite \A <<x, <<y, z>> >> \in XYZ: P(x, y, z)
// to \A x_y_z \in XYZ: P(x_y_z[1], x_y_z[1][1], x_y_z[1][2])
OperEx(TlaBoolOper.forall, collapseTuplesInFilter(transform(boundEx), transform(setEx), transform(predEx)): _*)

case OperEx(TlaSetOper.map, args @ _*) =>
// rewrite { <<x, <<y, z>> >> \in XYZ |-> e(x, y, z) }
// to { x_y_z \in XYZ |-> e(x_y_z[1], x_y_z[1][1], x_y_z[1][2])
val trArgs = args map transform
OperEx(TlaSetOper.map, collapseTuplesInMap(trArgs.head, trArgs.tail): _*)

case OperEx(funDefOp, args @ _*) if (funDefOp == TlaFunOper.funDef || funDefOp == TlaFunOper.recFunDef) =>
val trArgs = args map transform
val fun = trArgs.head
val (vars, sets) = TlaOper.deinterleave(trArgs.tail)
val (onlyVar, onlySet) =
if (vars.length > 1) {
val pair = (tla.tuple(vars: _*), tla.times(sets: _*))
// track the modification to point to the first variable and set
tracker.hold(vars.head, pair._1)
tracker.hold(sets.head, pair._2)
pair
} else {
(vars.head, sets.head)
}
// transform the function into a single-argument function and collapse tuples
OperEx(funDefOp, collapseTuplesInMap(fun, Seq(onlyVar, onlySet)): _*)

case OperEx(op, args @ _*) =>
OperEx(op, args map transform: _*)

case LetInEx( body, defs@_* ) =>
LetInEx( transform( body ), defs map { d => d.copy( body = transform( d.body ) ) } : _* )
case LetInEx(body, defs @ _*) =>
LetInEx(transform(body), defs map { d => d.copy(body = transform(d.body)) }: _*)
}

private def flattenTuplesInUnchanged(ex: TlaEx): Seq[TlaEx] = ex match {
Expand All @@ -117,7 +129,7 @@ class Desugarer(tracker: TransformationTracker) extends TlaExTransformation {
}

case ValEx(_) =>
Seq() // no point in priming literals
Seq() // no point in priming literals

case _ =>
// in general, UNCHANGED e becomes e' = e
Expand All @@ -129,10 +141,10 @@ class Desugarer(tracker: TransformationTracker) extends TlaExTransformation {
def unfoldKey(indicesInPrefix: Seq[TlaEx], indicesInSuffix: Seq[TlaEx], newValue: TlaEx): TlaEx = {
// produce [f[i_1]...[i_m] EXCEPT ![i_m+1] = unfoldKey(...) ]
indicesInSuffix match {
case Nil => newValue // nothing to unfold, just return g
case Nil => newValue // nothing to unfold, just return g
case oneMoreIndex +: otherIndices =>
// f[i_1]...[i_m]
val funApp = indicesInPrefix.foldLeft(topFun) ((f, i) => tla.appFun(f, i))
val funApp = indicesInPrefix.foldLeft(topFun)((f, i) => tla.appFun(f, i))
// the recursive call defines another chain of EXCEPTS
val rhs = unfoldKey(indicesInPrefix :+ oneMoreIndex, otherIndices, newValue)
OperEx(TlaFunOper.except, funApp, tla.tuple(oneMoreIndex), rhs)
Expand All @@ -149,16 +161,17 @@ class Desugarer(tracker: TransformationTracker) extends TlaExTransformation {
}
val expandedPairs = accessors.zip(newValues).map((eachPair _).tupled)
val expandedArgs = (TlaOper.interleave _).tupled(expandedPairs.unzip)
OperEx(TlaFunOper.except, topFun +: expandedArgs :_*)
OperEx(TlaFunOper.except, topFun +: expandedArgs: _*)
}

/**
* Transform filter expressions like {<< x, y >> \in S: x = 1} to { x_y \in S: x_y[1] = 1 }
* @param boundEx a bound expression, e.g., x or << x, y >>
* @param setEx a set expression, e.g., S
* @param predEx a predicate expression, e.g., x == 1
* @return transformed arguments
*/
* Transform filter expressions like {<< x, y >> \in S: x = 1} to { x_y \in S: x_y[1] = 1 }
*
* @param boundEx a bound expression, e.g., x or << x, y >>
* @param setEx a set expression, e.g., S
* @param predEx a predicate expression, e.g., x == 1
* @return transformed arguments
*/
def collapseTuplesInFilter(boundEx: TlaEx, setEx: TlaEx, predEx: TlaEx): Seq[TlaEx] = {
val boundName = mkTupleName(boundEx) // rename a tuple into a name, if needed
// variable substitutions for the variables inside the tuples
Expand All @@ -168,11 +181,12 @@ class Desugarer(tracker: TransformationTracker) extends TlaExTransformation {
}

/**
* Transform filter expressions like {x : << x, y >> \in S} to { x_y[1] : x_y \in S }
* @param mapEx the mapping, e.g., x
* @param args bindings and sets
* @return transformed arguments
*/
* Transform filter expressions like {x : << x, y >> \in S} to { x_y[1] : x_y \in S }
*
* @param mapEx the mapping, e.g., x
* @param args bindings and sets
* @return transformed arguments
*/
def collapseTuplesInMap(mapEx: TlaEx, args: Seq[TlaEx]): Seq[TlaEx] = {
val (boundEs, setEs) = TlaOper.deinterleave(args)
val boundNames = boundEs map mkTupleName // rename tuples into a names, if needed
Expand All @@ -190,12 +204,17 @@ class Desugarer(tracker: TransformationTracker) extends TlaExTransformation {
case OperEx(TlaFunOper.tuple, _*) =>
val tupleName = mkTupleName(ex) // introduce a name, e.g., x_y_z for <<x, <<y, z>> >>
val indices = assignIndicesInTuple(Map(), ex, Seq())

def indexToTlaEx(index: Seq[Int]): TlaEx = {
index.foldLeft(tla.name(tupleName): TlaEx) { (e, i) => tla.appFun(e, tla.int(i)) }
index.foldLeft(tla.name(tupleName): TlaEx) { (e, i) =>
tla.appFun(e, tla.int(i))
}
}

// map every variable inside the tuple to a tuple access, e.g., x -> x_y_z[1] and z -> x_y_z[1][2]
indices.foldLeft(subs) { (m, p) => m + (p._1 -> indexToTlaEx(p._2))}
indices.foldLeft(subs) { (m, p) =>
m + (p._1 -> indexToTlaEx(p._2))
}

case _ =>
throw new IllegalArgumentException("Unexpected %s among set filter parameters".format(ex))
Expand All @@ -211,7 +230,8 @@ class Desugarer(tracker: TransformationTracker) extends TlaExTransformation {
(args map mkTupleName) mkString "_"

case _ =>
throw new IllegalArgumentException("Unexpected %s among set filter parameters".format(ex)) }
throw new IllegalArgumentException("Unexpected %s among set filter parameters".format(ex))
}
}

private def assignIndicesInTuple(map: Map[String, Seq[Int]], ex: TlaEx, myIndex: Seq[Int]): Map[String, Seq[Int]] = {
Expand All @@ -234,12 +254,12 @@ class Desugarer(tracker: TransformationTracker) extends TlaExTransformation {
def rename(e: TlaEx): TlaEx = e match {
case NameEx(name) => if (!subs.contains(name)) e else subs(name)

case LetInEx( body, defs@_* ) =>
val newDefs = defs.map( d => d.copy( body = rename( d.body ) ) )
LetInEx( rename( body ), newDefs : _* )
case LetInEx(body, defs @ _*) =>
val newDefs = defs.map(d => d.copy(body = rename(d.body)))
LetInEx(rename(body), newDefs: _*)

case OperEx(op, args @ _*) =>
OperEx(op, args map rename :_*)
OperEx(op, args map rename: _*)

case _ => e
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package at.forsyte.apalache.tla.pp.passes

import at.forsyte.apalache.infra.passes.{Pass, TlaModuleMixin}

/**
* A pass that does TLA+ desugaring.
*
* @author Igor Konnov
*/
trait DesugarerPass extends Pass with TlaModuleMixin
Loading

0 comments on commit f9cd507

Please sign in to comment.