Skip to content

Commit

Permalink
optimize factorExpression
Browse files Browse the repository at this point in the history
  • Loading branch information
gusiri committed Feb 6, 2025
1 parent 184771b commit eced66b
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 30 deletions.
13 changes: 11 additions & 2 deletions prover/protocol/compiler/globalcs/factoring.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"io"
"reflect"
"sync"

"github.com/consensys/linea-monorepo/prover/protocol/accessors"
"github.com/consensys/linea-monorepo/prover/protocol/serialization"
Expand All @@ -13,12 +14,20 @@ import (
"github.com/consensys/linea-monorepo/prover/utils"
)

// factorExpressionList applies [factorExpression] over a list of expression
// factorExpressionList applies [factorExpression] over a list of expressions
func factorExpressionList(comp *wizard.CompiledIOP, exprList []*symbolic.Expression) []*symbolic.Expression {
res := make([]*symbolic.Expression, len(exprList))
var wg sync.WaitGroup

for i, expr := range exprList {
res[i] = factorExpression(comp, expr)
wg.Add(1)
go func(i int, expr *symbolic.Expression) {
defer wg.Done()
res[i] = factorExpression(comp, expr)
}(i, expr)
}

wg.Wait()
return res
}

Expand Down
11 changes: 10 additions & 1 deletion prover/symbolic/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package symbolic
import (
"fmt"
"reflect"
"sync"

"github.com/consensys/gnark/frontend"
"github.com/consensys/linea-monorepo/prover/maths/common/mempool"
Expand Down Expand Up @@ -244,9 +245,17 @@ func (e *Expression) ReconstructBottomUp(
// LinComb or Product or PolyEval. This is an intermediate expression.
case LinComb, Product, PolyEval:
children := make([]*Expression, len(e.Children))
var wg sync.WaitGroup
wg.Add(len(e.Children))

for i, c := range e.Children {
children[i] = c.ReconstructBottomUp(constructor)
go func(i int, c *Expression) {
defer wg.Done()
children[i] = c.ReconstructBottomUp(constructor)
}(i, c)
}

wg.Wait()
return constructor(e, children)
}

Expand Down
16 changes: 14 additions & 2 deletions prover/symbolic/simplify/cost_stat.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package simplify

import (
"math/bits"
"sync"

sym "github.com/consensys/linea-monorepo/prover/symbolic"
)
Expand All @@ -21,10 +22,21 @@ func (s *costStats) add(cost costStats) {
// Returns the cost stats of a boarded expression
func evaluateCostStat(expr *sym.Expression) (s costStats) {
board := expr.Board()
var wg sync.WaitGroup
var mu sync.Mutex

for i := 1; i < len(board.Nodes); i++ {
s_ := evaluateNodeCosts(board.Nodes[i]...)
s.add(s_)
wg.Add(1)
go func(nodes []sym.Node) {
defer wg.Done()
s_ := evaluateNodeCosts(nodes...)
mu.Lock()
s.add(s_)
mu.Unlock()
}(board.Nodes[i])
}
wg.Wait()

return s
}

Expand Down
31 changes: 15 additions & 16 deletions prover/symbolic/simplify/factor.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"math"
"sort"
"sync"

"github.com/consensys/linea-monorepo/prover/maths/field"
sym "github.com/consensys/linea-monorepo/prover/symbolic"
Expand All @@ -17,7 +18,8 @@ import (
func factorizeExpression(expr *sym.Expression, iteration int) *sym.Expression {
res := expr
initEsh := expr.ESHash
alreadyWalked := map[field.Element]*sym.Expression{}
alreadyWalked := sync.Map{}
factorMemo := sync.Map{}

logrus.Infof("factoring expression : init stats %v", evaluateCostStat(expr))

Expand All @@ -26,10 +28,9 @@ func factorizeExpression(expr *sym.Expression, iteration int) *sym.Expression {
scoreInit := evaluateCostStat(res)

res = res.ReconstructBottomUp(func(lincomb *sym.Expression, newChildren []*sym.Expression) *sym.Expression {

// Time save, we reuse the results we got for that particular node.
if ret, ok := alreadyWalked[lincomb.ESHash]; ok {
return ret
if ret, ok := alreadyWalked.Load(lincomb.ESHash); ok {
return ret.(*sym.Expression)
}

// Incorporate the new children inside of the expression to account
Expand All @@ -55,31 +56,29 @@ func factorizeExpression(expr *sym.Expression, iteration int) *sym.Expression {

group := findGdChildrenGroup(new)

logrus.Tracef("found children group: %v\n", group)

if len(group) < 1 {
if k > 0 {
logrus.Tracef("finished factoring : %v opportunities", k)
}
return new
}

logrus.Tracef(
"factoring an expression with a set of %v siblings",
len(group),
)
// Memoize the factorLinCompFromGroup result
cacheKey := fmt.Sprintf("%v-%v", new.ESHash, group)

new = factorLinCompFromGroup(new, group)
if cachedResult, ok := factorMemo.Load(cacheKey); ok {
new = cachedResult.(*sym.Expression)

} else {
new = factorLinCompFromGroup(new, group)
factorMemo.Store(cacheKey, new)
}

if len(new.Children) >= prevSize {
logrus.Tracef("factorization did not help. stopping")
return new
}

prevSize = len(new.Children)
}

logrus.Tracef("finished factoring slow node")
alreadyWalked.Store(new.ESHash, new)
return new
})

Expand Down
25 changes: 16 additions & 9 deletions prover/symbolic/simplify/rmpolyeval.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,27 @@ func removePolyEval(e *sym.Expression) *sym.Expression {
x := newChildren[0]
cs := newChildren[1:]

if len(cs) == 0 {
return oldExpr // Handle edge case where there are no coefficients
}

acc := cs[0]
xPowi := x

// Precompute powers of x
powersOfX := make([]*sym.Expression, len(cs))
powersOfX[0] = x
for i := 1; i < len(cs); i++ {
// We don't use the default constructor because it will collapse the
// intermediate terms into a single term. The intermediates are useful because
// they tell the evaluator to reuse the intermediate terms instead of
// computing x^i for every term.
powersOfX[i] = sym.NewProduct([]*sym.Expression{powersOfX[i-1], x}, []int{1, 1})
}

for i := 1; i < len(cs); i++ {
// Here we want to use the default constructor to ensure that we
// will have a merged sum at the end.
acc = sym.Add(acc, sym.Mul(xPowi, cs[i]))
if i+1 < len(cs) {
// We don't use the default construct because it will collapse the
// xPowi into a single term. The intermediate are useful because
// it tells the evaluator to reuse the intermediate terms instead of
// computing x^i for every term.
xPowi = sym.NewProduct([]*sym.Expression{xPowi, x}, []int{1, 1})
}
acc = sym.Add(acc, sym.Mul(powersOfX[i-1], cs[i]))
}

if oldExpr.ESHash != acc.ESHash {
Expand Down

0 comments on commit eced66b

Please sign in to comment.