Skip to content

Commit

Permalink
Merge pull request #688 from JuliaSymbolics/b/cse-dag
Browse files Browse the repository at this point in the history
Optimize CSE: Transition to DAG Representation with Hash Consing for Faster Equality Checks
  • Loading branch information
ChrisRackauckas authored Jan 23, 2025
2 parents 4b0d24e + 9c0c27c commit d4d2d9a
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 45 deletions.
98 changes: 56 additions & 42 deletions src/code.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Code

using StaticArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions
using StaticArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions,
DocStringExtensions

export toexpr, Assignment, (), Let, Func, DestructuredArgs, LiteralExpr,
SetArray, MakeArray, MakeSparseArray, MakeTuple, AtIndex,
Expand Down Expand Up @@ -696,6 +697,52 @@ end

@inline newsym(::Type{T}) where T = Sym{T}(gensym("cse"))

"""
$(SIGNATURES)
Perform a topological sort on a symbolic expression represented as a Directed Acyclic
Graph (DAG).
This function takes a symbolic expression `graph` (potentially containing shared common
sub-expressions) and returns an array of `Assignment` objects. Each `Assignment`
represents a node in the sorted order, assigning a fresh symbol to its corresponding
expression. The order ensures that all dependencies of a node appear before the node itself
in the array.
Hash consing is assumed, meaning that structurally identical expressions are represented by
the same object in memory. This allows for efficient equality checks using `IdDict`.
"""
function topological_sort(graph)
    # Assignments in dependency order: every operand is bound before its user.
    sorted_nodes = Assignment[]
    # Identity-keyed memo: node => replacement (a fresh symbol, or the node itself
    # for leaves). A node reached a second time reuses the first replacement.
    visited = IdDict()

    # Post-order depth-first walk. Returns the value that should stand in for
    # `node` inside its parent expression.
    function dfs(node)
        haskey(visited, node) && return visited[node]

        if iscall(node)
            # Rebuild the call on top of the already-processed arguments.
            args = map(dfs, arguments(node))
            new_node = maketerm(typeof(node), operation(node), args, metadata(node))
            sym = newsym(symtype(new_node))
        elseif _is_array_of_symbolics(node)
            # Containers are processed element-wise and bound to a symbol typed
            # after the rebuilt container.
            new_node = map(dfs, node)
            sym = newsym(typeof(new_node))
        else
            # Leaves are passed through unchanged and produce no assignment.
            visited[node] = node
            return node
        end

        # `Assignment(lhs, rhs)` is spelled out instead of the infix `←` form.
        push!(sorted_nodes, Assignment(sym, new_node))
        visited[node] = sym
        return sym
    end

    dfs(graph)
    return sorted_nodes
end

function _cse!(mem, expr)
iscall(expr) || return expr
op = _cse!(mem, operation(expr))
Expand All @@ -714,12 +761,16 @@ function _cse!(mem, expr)
end

"""
    cse(expr)

Perform common sub-expression elimination on `expr` and return a `Let` block.

`expr` is flattened with `topological_sort`. All resulting assignments except the
last become the bindings of the `Let`; the right-hand side of the last assignment
becomes its body. When `topological_sort` yields no assignments, `expr` is returned
as the body of a `Let` with no bindings.
"""
function cse(expr)
    sorted_nodes = topological_sort(expr)
    isempty(sorted_nodes) && return Let(Assignment[], expr)

    # The root of `expr` is pushed last by the post-order walk, so its right-hand
    # side is used directly as the body instead of being bound to a symbol.
    last_assignment = pop!(sorted_nodes)
    return Let(sorted_nodes, rhs(last_assignment))
end


function _cse(exprs::AbstractArray)
letblock = cse(Term{Any}(tuple, vec(exprs)))
letblock.pairs, reshape(arguments(letblock.body), size(exprs))
Expand All @@ -746,41 +797,4 @@ function cse(x::MakeSparseArray)
end
end


# Walk `t` recursively and increment `state[node]` once for every call node
# encountered. A non-call `t` is returned as-is; otherwise `nothing` is returned.
function cse_state!(state, t)
    iscall(t) || return t
    state[t] = Base.get(state, t, 0) + 1
    for arg in arguments(t)
        cse_state!(state, arg)
    end
    return nothing
end

# Rewrite `x`, replacing every sub-expression whose entry in `state` is greater
# than 1 by a named symbol.
#
# - `assignments`: vector that receives one `Assignment(sym, x)` per newly named
#   sub-expression (mutated).
# - `counter`: `Ref`-like counter used to number the generated names (mutated).
# - `names`: expression => symbol table of names handed out so far (mutated).
# - `name`: prefix for generated symbol names.
# - `state`: expression => count lookup; presumably the occurrence counts filled
#   in by `cse_state!` (confirm against the caller).
#
# Returns the symbol standing in for `x`, a rebuilt term, or `x` itself.
function cse_block!(assignments, counter, names, name, state, x)
    if get(state, x, 0) > 1
        # Reuse the name if this sub-expression has already been bound.
        haskey(names, x) && return names[x]

        sym = Sym{symtype(x)}(Symbol(name, counter[]))
        names[x] = sym
        # `Assignment(lhs, rhs)` is spelled out instead of the infix `←` form.
        push!(assignments, Assignment(sym, x))
        counter[] += 1
        return sym
    elseif iscall(x)
        args = map(arguments(x)) do a
            cse_block!(assignments, counter, names, name, state, a)
        end
        if isterm(x)
            return term(operation(x), args...)
        else
            return maketerm(typeof(x), operation(x), args, metadata(x))
        end
    else
        return x
    end
end

# Build a `Let` for `t`: the bindings are the assignments collected by
# `cse_block!`, the body is the rewritten expression it returns. Generated names
# are numbered from 1 and prefixed with `name`.
function cse_block(state, t, name=Symbol("var-", hash(t)))
    bindings = Assignment[]
    next_index = Ref{Int}(1)
    known_names = Dict{Any, BasicSymbolic}()
    body = cse_block!(bindings, next_index, known_names, name, state, t)
    return Let(bindings, body)
end

end
52 changes: 49 additions & 3 deletions test/cse.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,56 @@
using SymbolicUtils, SymbolicUtils.Code, Test
using SymbolicUtils.Code: topological_sort

@testset "CSE" begin
    @syms x
    t = cse(hypot(hypot(cos(x), sin(x)), atan(cos(x), sin(x))))

    @test t isa Let
    # Every call node below the root gets its own binding: cos(x), sin(x), the
    # inner hypot and the atan. The root itself becomes the body of the `Let`.
    @test length(t.pairs) == 4
    # The root's direct operands are the last two bindings, so their symbols are
    # the ones that appear in the body.
    @test occursin(t.pairs[3].lhs, t.body)
    @test occursin(t.pairs[4].lhs, t.body)
end

@testset "DAG CSE" begin
    @syms a b

    # A sub-expression shared by two different parents is bound only once.
    product = sin(a + b) * (a + b)
    product_nodes = topological_sort(product)
    @test length(product_nodes) == 3
    @test isequal(product_nodes[1].rhs, a + b)
    @test isequal(sin(product_nodes[1].lhs), product_nodes[2].rhs)

    # The same sub-expression used twice by a single parent.
    power = (a + b)^(a + b)
    power_nodes = topological_sort(power)
    @test length(power_nodes) == 2
    @test isequal(power_nodes[1].rhs, a + b)
    ab_node = power_nodes[1].lhs
    @test isequal(ab_node^ab_node, power_nodes[2].rhs)
    power_let = cse(power)
    @test length(power_let.pairs) == 1
    @test isequal(power_let.pairs[1].rhs, a + b)
    corresponding_sym = power_let.pairs[1].lhs
    @test isequal(power_let.body, corresponding_sym^corresponding_sym)

    # A single call: one node, which `cse` turns into the body with no bindings.
    total = a + b
    total_nodes = topological_sort(total)
    @test length(total_nodes) == 1
    @test isequal(total_nodes[1].rhs, a + b)
    total_let = cse(total)
    @test isempty(total_let.pairs)
    @test isequal(total_let.body, a + b)

    # A bare symbol: nothing to sort, nothing to bind.
    leaf = a
    leaf_nodes = topological_sort(leaf)
    @test isempty(leaf_nodes)
    leaf_let = cse(leaf)
    @test isempty(leaf_let.pairs)
    @test isequal(leaf_let.body, a)

    # array symbolics
    # https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/688#pullrequestreview-2554931739
    @syms c
    function foo end
    ex = term(foo, [a^2 + b^2, b^2 + c], c; type = Real)
    array_nodes = topological_sort(ex)
    @test length(array_nodes) == 6
end

0 comments on commit d4d2d9a

Please sign in to comment.