Skip to content

Commit

Permalink
[MACROS] Made macro-level IR immutable.
Browse files Browse the repository at this point in the history
Resolves #153.

Adapted the ComprehensionModel nodes so all fields are immutable.

The only place where mutability is left is in the Generator type, as
otherwise we would have to also rewrite the FoldGroupFusion optimization,
which won't be trivial and at the moment seems unnecessary.
The Generator type should be fixed as we make further progress on #147.
  • Loading branch information
aalexandrov committed Jan 22, 2016
1 parent 0927bcd commit 0422fac
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -447,16 +447,25 @@ private[emma] trait ComprehensionAnalysis
def normalizePredicates(tree: Tree)
(implicit cfGraph: CFGraph, compView: ComprehensionView) = {

for {
ComprehendedTerm(_, _, ExpressionRoot(expr), _) <- compView.terms
comprehension @ Comprehension(_, qualifiers) <- expr
} comprehension.qualifiers = qualifiers flatMap {
case Filter(ScalaExpr(x)) =>
// Normalize the tree
(x ->> deMorgan ->> distributeOrOverAnd ->> cleanConjuncts)
.collect { case Some(nf) => Filter(ScalaExpr(nf)) }

case q => q :: Nil
for (ComprehendedTerm(_, _, root @ ExpressionRoot(expr), _) <- compView.terms) {
root.expr = new ExpressionTransformer {
override def transform(expr: Expression): Expression = expr match {
case comprehension: Comprehension =>
val hd = comprehension.hd
val qualifiers = comprehension.qualifiers flatMap {
case Filter(ScalaExpr(x)) =>
// Normalize the tree
(x ->> deMorgan ->> distributeOrOverAnd ->> cleanConjuncts)
.collect { case Some(nf) => Filter(ScalaExpr(nf)) }

case q =>
q :: Nil
}
Comprehension(hd, qualifiers)
case _ =>
super.transform(expr)
}
}.transform(expr)
}

tree
Expand Down Expand Up @@ -533,7 +542,7 @@ private[emma] trait ComprehensionAnalysis
s"Unexpected structure in predicate disjunction: ${showCode(expr)}")
}

return into
into
}

def collectConjuncts(from: Tree,
Expand All @@ -554,7 +563,7 @@ private[emma] trait ComprehensionAnalysis
//throw new RuntimeException("Unexpected structure in predicate conjunction")
}

return into.toList
into.toList
}

collectConjuncts(tree) map { _.getTree }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ private[emma] trait ComprehensionModel extends BlackBoxUtil { model =>
val DATA_BAG = typeOf[eu.stratosphere.emma.api.DataBag[Nothing]].typeConstructor
val GROUP = typeOf[eu.stratosphere.emma.api.Group[Nothing, Nothing]].typeConstructor

// --------------------------------------------------------------------------
// Comprehension Syntax
// --------------------------------------------------------------------------

// TODO: unify with ReflectUtil.syntax.let (with two distinct in methods)

case class letexpr(bindings: (Symbol, Tree)*) {
def in(expr: Expression): Expression = expr.substitute { bindings.toMap }
}

// --------------------------------------------------------------------------
// Comprehension Model
// --------------------------------------------------------------------------
Expand Down Expand Up @@ -111,16 +121,17 @@ private[emma] trait ComprehensionModel extends BlackBoxUtil { model =>
}.flatten
}

/** Substitute `find` with `replacement` in all enclosing trees. */
/** Bind a dictionary of [[Symbol]]-value pairs in all enclosing trees and assign the result to `expr`. */
def substitute(dict: Map[Symbol, Tree]): Unit =
expr = expr.substitute(dict)

/** Substitute `find` with `replacement` in all enclosing trees and assign the result to `expr`. */
def replace(find: Tree, replacement: Tree): Unit =
expr = new ExpressionTransformer {
override def xform(tree: Tree) = model.replace(tree)(find, replacement)
}.transform(expr)
expr = expr.replace(find, replacement)

/** Rename `key` as `alias` in all enclosing trees and assign the result to `expr`. */
def rename(key: Symbol, alias: TermSymbol): Unit =
expr = new ExpressionTransformer {
override def xform(tree: Tree) = model.rename(tree, key, alias)
}.transform(expr)
expr = expr.rename(key, alias)

override def toString =
prettyPrint(expr)
Expand Down Expand Up @@ -192,23 +203,41 @@ private[emma] trait ComprehensionModel extends BlackBoxUtil { model =>
cet.traverse(context)
cet.env
}

/** Bind a dictionary of [[Symbol]]-value pairs in all enclosing trees. */
def substitute(dict: Map[Symbol, Tree]): Expression =
new ExpressionTransformer {
override def xform(tree: Tree) = model.substitute(tree)(dict)
}.transform(this)

/** Substitute `find` with `replacement` in all enclosing trees. */
def replace(find: Tree, replacement: Tree): Expression =
new ExpressionTransformer {
override def xform(tree: Tree) = model.replace(tree)(find, replacement)
}.transform(this)

/** Rename `key` as `alias` in all enclosing trees. */
def rename(key: Symbol, alias: TermSymbol): Expression =
new ExpressionTransformer {
override def xform(tree: Tree) = model.rename(tree, key, alias)
}.transform(this)
}

// Monads

sealed trait MonadExpression extends Expression

case class MonadJoin(var expr: MonadExpression) extends MonadExpression {
case class MonadJoin(expr: MonadExpression) extends MonadExpression {
def tpe = DATA_BAG(expr.tpe)
def descend[U](f: Expression => U) = expr foreach f
}

case class MonadUnit(var expr: MonadExpression) extends MonadExpression {
case class MonadUnit(expr: MonadExpression) extends MonadExpression {
def tpe = DATA_BAG(expr.tpe)
def descend[U](f: Expression => U) = expr foreach f
}

case class Comprehension(var hd: Expression, var qualifiers: List[Qualifier])
case class Comprehension(hd: Expression, qualifiers: List[Qualifier])
extends MonadExpression {

def tpe = DATA_BAG(hd.tpe)
Expand All @@ -223,7 +252,7 @@ private[emma] trait ComprehensionModel extends BlackBoxUtil { model =>

sealed trait Qualifier extends Expression

case class Filter(var expr: Expression) extends Qualifier {
case class Filter(expr: Expression) extends Qualifier {
def tpe = typeOf[Boolean]
def descend[U](f: Expression => U) = expr foreach f
}
Expand All @@ -235,7 +264,7 @@ private[emma] trait ComprehensionModel extends BlackBoxUtil { model =>

// Environment & Host Language Connectors

case class ScalaExpr(var tree: Tree) extends Expression {
case class ScalaExpr(tree: Tree) extends Expression {
def tpe = tree.preciseType
def descend[U](f: Expression => U) = ()
}
Expand All @@ -261,7 +290,7 @@ private[emma] trait ComprehensionModel extends BlackBoxUtil { model =>
def descend[U](f: Expression => U) = ()
}

case class TempSink(id: TermName, var xs: Expression) extends Combinator {
case class TempSink(id: TermName, xs: Expression) extends Combinator {
val tpe = xs.tpe
def descend[U](f: Expression => U) = xs foreach f
}
Expand All @@ -277,7 +306,7 @@ private[emma] trait ComprehensionModel extends BlackBoxUtil { model =>
def descend[U](f: Expression => U) = xs foreach f
}

case class Filter(var p: Tree, var xs: Expression) extends Combinator {
case class Filter(p: Tree, xs: Expression) extends Combinator {
def tpe = xs.tpe
def descend[U](f: Expression => U) = xs foreach f
}
Expand All @@ -303,7 +332,7 @@ private[emma] trait ComprehensionModel extends BlackBoxUtil { model =>
}
}

case class Group(var key: Tree, var xs: Expression) extends Combinator {
case class Group(key: Tree, xs: Expression) extends Combinator {

def tpe = {
val K = key.preciseType.typeArgs(1)
Expand All @@ -315,14 +344,12 @@ private[emma] trait ComprehensionModel extends BlackBoxUtil { model =>
xs foreach f
}

case class Fold(var empty: Tree, var sng: Tree, var union: Tree, var xs: Expression,
var origin: Tree) extends Combinator {
case class Fold(empty: Tree, sng: Tree, union: Tree, xs: Expression, origin: Tree) extends Combinator {
def tpe = sng.preciseType.typeArgs(1) // Function[A, B]#B
def descend[U](f: Expression => U) = xs foreach f
}

case class FoldGroup(var key: Tree, var empty: Tree, var sng: Tree, var union: Tree,
var xs: Expression) extends Combinator {
case class FoldGroup(key: Tree, empty: Tree, sng: Tree, union: Tree, xs: Expression) extends Combinator {

def tpe = {
val K = key.preciseType.typeArgs(1)
Expand All @@ -334,12 +361,12 @@ private[emma] trait ComprehensionModel extends BlackBoxUtil { model =>
xs foreach f
}

case class Distinct(var xs: Expression) extends Combinator {
case class Distinct(xs: Expression) extends Combinator {
def tpe = xs.tpe
def descend[U](f: Expression => U) = xs foreach f
}

case class Union(var xs: Expression, var ys: Expression) extends Combinator {
case class Union(xs: Expression, ys: Expression) extends Combinator {

def tpe = ys.tpe

Expand All @@ -349,7 +376,7 @@ private[emma] trait ComprehensionModel extends BlackBoxUtil { model =>
}
}

case class Diff(var xs: Expression, var ys: Expression) extends Combinator {
case class Diff(xs: Expression, ys: Expression) extends Combinator {

def tpe = ys.tpe

Expand All @@ -359,7 +386,7 @@ private[emma] trait ComprehensionModel extends BlackBoxUtil { model =>
}
}

case class StatefulCreate(var xs: Expression, stateType: Type, keyType: Type)
case class StatefulCreate(xs: Expression, stateType: Type, keyType: Type)
extends Combinator {
def tpe = typeOf[Unit]
def descend[U](f: Expression => U) = xs foreach f
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,20 @@ trait ComprehensionCombination extends ComprehensionRewriteEngine {
def fire(rm: RuleMatch) = {
val RuleMatch(_, parent, gen, filter) = rm
val p = lambda(gen.lhs) { unComprehend(filter.expr) }
gen.rhs = combinator.Filter(p, gen.rhs)
parent.qualifiers = parent.qualifiers diff List(filter)
parent // return new parent
// construct new parent components
val hd = parent.hd
val qualifiers = parent.qualifiers.foldLeft(List.empty[Qualifier])((qs, q) => q match {
case _ if q == filter =>
// do not include original guard in the result
qs
case _ if q == gen =>
// wrap the rhs of the generator in a filter combinator
qs :+ gen.copy(rhs = combinator.Filter(p, gen.rhs))
case _ =>
qs :+ q
})
// return new parent
Comprehension(hd, qualifiers) // return new parent
}
}

Expand Down Expand Up @@ -149,10 +160,12 @@ trait ComprehensionCombination extends ComprehensionRewriteEngine {
def fire(rm: RuleMatch) = {
val RuleMatch(_, parent, g1, g2 @ Generator(_, rhs)) = rm
val f = lambda(g1.lhs) { unComprehend(rhs) }
parent.qualifiers = for (q <- parent.qualifiers if q != g1)
// construct new parent components
val hd = parent.hd
val qualifiers = for (q <- parent.qualifiers if q != g1)
yield if (q == g2) Generator(g2.lhs, combinator.FlatMap(f, g1.rhs)) else q

parent
// return new parent
Comprehension(hd, qualifiers)
}
}

Expand Down Expand Up @@ -207,23 +220,13 @@ trait ComprehensionCombination extends ComprehensionRewriteEngine {
val term = mk.freeTerm($"join".toString, tpe)
val qs = suffix drop 2 filter { _ != filter }

for { // substitute [v._1/x] in affected expressions
expr @ ScalaExpr(tree) <- parent
if expr.usedVars(root) contains xs.lhs
body = q"$term._1" :: xType
} expr.tree = let (xs.lhs -> body) in tree

for { // substitute [v._2/y] in affected expressions
expr @ ScalaExpr(tree) <- parent
if expr.usedVars(root) contains ys.lhs
body = q"$term._2" :: yType
} expr.tree = let (ys.lhs -> body) in tree

// modify parent qualifier list
parent.qualifiers = prefix ::: Generator(term, join) :: qs

// return the modified parent
parent
// construct new parent components
val hd = parent.hd
val qualifiers = prefix ::: Generator(term, join) :: qs
// return new parent
letexpr (
xs.lhs -> (q"$term._1" :: xType),
ys.lhs -> (q"$term._2" :: yType)) in Comprehension(hd, qualifiers)
}

private def parseJoinPredicate(root: Expression, xs: Generator, ys: Generator, p: ScalaExpr):
Expand Down Expand Up @@ -291,23 +294,13 @@ trait ComprehensionCombination extends ComprehensionRewriteEngine {
val tpe = PAIR(xs.tpe, ys.tpe)
val term = mk.freeTerm($"cross".toString, tpe)

for { // substitute [v._1/x] in affected expressions
expr @ ScalaExpr(tree) <- parent
if expr.usedVars(root) contains xs.lhs
body = q"$term._1" :: xs.tpe
} expr.tree = let (xs.lhs -> body) in tree

for { // substitute [v._2/y] in affected expressions
expr @ ScalaExpr(tree) <- parent
if expr.usedVars(root) contains ys.lhs
body = q"$term._2" :: ys.tpe
} expr.tree = let (ys.lhs -> body) in tree

// modify parent qualifier list
parent.qualifiers = prefix ::: Generator(term, cross) :: suffix.drop(2)

// return the modified parent
parent
// construct new parent components
val hd = parent.hd
val qualifiers = prefix ::: Generator(term, cross) :: suffix.drop(2)
// return new parent
letexpr (
xs.lhs -> (q"$term._1" :: xs.tpe),
ys.lhs -> (q"$term._2" :: ys.tpe)) in Comprehension(hd, qualifiers)
}
}

Expand Down
Loading

0 comments on commit 0422fac

Please sign in to comment.