From 081e4aeadf9f371727315b8cda15f24a3b2a92dc Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 9 Jan 2025 16:35:25 +0800 Subject: [PATCH 01/11] Topological sort directed acyclic graph --- src/code.jl | 25 +++++++++++++++++++++++++ test/cse.jl | 10 ++++++++++ 2 files changed, 35 insertions(+) diff --git a/src/code.jl b/src/code.jl index 8b1953b20..f47564231 100644 --- a/src/code.jl +++ b/src/code.jl @@ -696,6 +696,31 @@ end @inline newsym(::Type{T}) where T = Sym{T}(gensym("cse")) +function topological_sort(graph) + sorted_nodes = Assignment[] + visited = IdDict() + + function dfs(node) + if haskey(visited, node) + return visited[node] + end + if iscall(node) + args = map(dfs, arguments(node)) + new_node = maketerm(typeof(node), operation(node), args, metadata(node)) + sym = newsym(symtype(new_node)) + push!(sorted_nodes, sym ← new_node) + visited[node] = sym + return sym + else + visited[node] = node + return node + end + end + + dfs(graph) + return sorted_nodes +end + function _cse!(mem, expr) iscall(expr) || return expr op = _cse!(mem, operation(expr)) diff --git a/test/cse.jl b/test/cse.jl index fcc36d47d..95795f181 100644 --- a/test/cse.jl +++ b/test/cse.jl @@ -8,3 +8,13 @@ using SymbolicUtils, SymbolicUtils.Code, Test @test occursin(t.pairs[1].lhs, t.body) @test occursin(t.pairs[2].lhs, t.body) end + +@testset "DAG CSE" begin + @syms a b + expr = sin(a + b) * (a + b) + sorted_nodes = topological_sort(expr) + @test length(sorted_nodes) == 3 + expr = (a + b)^(a + b) + sorted_nodes = topological_sort(expr) + @test length(sorted_nodes) == 2 +end From 3530de66487dc61b7729e0787a2dfa0827c40401 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 9 Jan 2025 16:41:46 +0800 Subject: [PATCH 02/11] Refactor CSE by leveraging DAG structure --- src/code.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/code.jl b/src/code.jl index f47564231..2e94dfa31 100644 --- a/src/code.jl +++ b/src/code.jl @@ -739,12 +739,16 @@ function _cse!(mem, expr) end function cse(expr) - state = Dict{Any, Int}() - cse_state!(state, expr) - cse_block(state, expr) + sorted_nodes = topological_sort(expr) + if isempty(sorted_nodes) + return Let(Assignment[], expr) + else + last_assignment = pop!(sorted_nodes) + body = rhs(last_assignment) + return Let(sorted_nodes, body) + end end - function _cse(exprs::AbstractArray) letblock = cse(Term{Any}(tuple, vec(exprs))) letblock.pairs, reshape(arguments(letblock.body), size(exprs)) From 5f1f8ffffa20cc94dd4a91cb60d334f4a96a78d6 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 9 Jan 2025 16:50:49 +0800 Subject: [PATCH 03/11] Import `topological_sort` in CSE test --- test/cse.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/cse.jl b/test/cse.jl index 95795f181..b2ef8132f 100644 --- a/test/cse.jl +++ b/test/cse.jl @@ -1,4 +1,6 @@ using SymbolicUtils, SymbolicUtils.Code, Test +using SymbolicUtils.Code: topological_sort + @testset "CSE" begin @syms x t = cse(hypot(hypot(cos(x), sin(x)), atan(cos(x), sin(x)))) From b38a2a613b40c0460bc86773b7120df03d462223 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 9 Jan 2025 17:13:48 +0800 Subject: [PATCH 04/11] Fix CSE tests due to new DAG implementation --- test/cse.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/cse.jl b/test/cse.jl index b2ef8132f..103f5ac6f 100644 --- a/test/cse.jl +++ b/test/cse.jl @@ -6,9 +6,9 @@ using SymbolicUtils.Code: topological_sort t = cse(hypot(hypot(cos(x), sin(x)), atan(cos(x), sin(x)))) @test t isa Let - @test length(t.pairs) == 2 - @test occursin(t.pairs[1].lhs, t.body) - @test occursin(t.pairs[2].lhs, t.body) + @test length(t.pairs) == 4 + @test occursin(t.pairs[3].lhs, t.body) + @test occursin(t.pairs[4].lhs, t.body) end @testset "DAG CSE" begin From eaf4446cf26c351b902ea1aeb4a320df164dc16d Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 9 Jan 2025 21:20:47 +0800 Subject: [PATCH 05/11] Add tests for CSE DAG --- test/cse.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/cse.jl b/test/cse.jl index 103f5ac6f..abaf8e83e 100644 --- a/test/cse.jl +++ b/test/cse.jl @@ -16,7 +16,12 @@ end expr = sin(a + b) * (a + b) sorted_nodes = topological_sort(expr) @test length(sorted_nodes) == 3 + @test isequal(sorted_nodes[1].rhs, a + b) + @test isequal(sin(sorted_nodes[1].lhs), sorted_nodes[2].rhs) expr = (a + b)^(a + b) sorted_nodes = topological_sort(expr) @test length(sorted_nodes) == 2 + @test isequal(sorted_nodes[1].rhs, a + b) + ab_node = sorted_nodes[1].lhs + @test isequal(ab_node^ab_node, sorted_nodes[2].rhs) end From f2673bffa573a43b8951c1eae54636472397af48 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 9 Jan 2025 21:36:56 +0800 Subject: [PATCH 06/11] Add docstring for `topological_sort` --- src/code.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/code.jl b/src/code.jl index 2e94dfa31..361f145aa 100644 --- a/src/code.jl +++ b/src/code.jl @@ -696,6 +696,21 @@ end @inline newsym(::Type{T}) where T = Sym{T}(gensym("cse")) +""" +$(SIGNATURES) + +Perform a topological sort on a symbolic expression represented as a Directed Acyclic +Graph (DAG). + +This function takes a symbolic expression `graph` (potentially containing shared common +sub-expressions) and returns an array of `Assignment` objects. Each `Assignment` +represents a node in the sorted order, assigning a fresh symbol to its corresponding +expression. The order ensures that all dependencies of a node appear before the node itself +in the array. + +Hash consing is assumed, meaning that structurally identical expressions are represented by +the same object in memory. This allows for efficient equality checks using `IdDict`. +""" function topological_sort(graph) sorted_nodes = Assignment[] visited = IdDict() From 3010908d7deb5e60da86c7a340a1be248e550262 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 9 Jan 2025 21:39:18 +0800 Subject: [PATCH 07/11] Remove outdated CSE functions --- src/code.jl | 37 ------------------------------------- 1 file changed, 37 deletions(-) diff --git a/src/code.jl b/src/code.jl index 361f145aa..3a47b0ae0 100644 --- a/src/code.jl +++ b/src/code.jl @@ -790,41 +790,4 @@ function cse(x::MakeSparseArray) end end - -function cse_state!(state, t) - !iscall(t) && return t - state[t] = Base.get(state, t, 0) + 1 - foreach(x->cse_state!(state, x), arguments(t)) -end - -function cse_block!(assignments, counter, names, name, state, x) - if get(state, x, 0) > 1 - if haskey(names, x) - return names[x] - else - sym = Sym{symtype(x)}(Symbol(name, counter[])) - names[x] = sym - push!(assignments, sym ← x) - counter[] += 1 - return sym - end - elseif iscall(x) - args = map(a->cse_block!(assignments, counter, names, name, state,a), arguments(x)) - if isterm(x) - return term(operation(x), args...) - else - return maketerm(typeof(x), operation(x), args, metadata(x)) - end - else - return x - end -end - -function cse_block(state, t, name=Symbol("var-", hash(t))) - assignments = Assignment[] - counter = Ref{Int}(1) - names = Dict{Any, BasicSymbolic}() - Let(assignments, cse_block!(assignments, counter, names, name, state, t)) -end - end From 48cc5a62b4fbc9a7daf1a80e8ef0dff938d4fd05 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 9 Jan 2025 21:45:30 +0800 Subject: [PATCH 08/11] Import DocStringExtensions in SymbolicUtils.Code --- src/code.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/code.jl b/src/code.jl index 3a47b0ae0..9c532faba 100644 --- a/src/code.jl +++ b/src/code.jl @@ -1,6 +1,7 @@ module Code -using StaticArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions +using StaticArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions, + DocStringExtensions export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr, SetArray, MakeArray, MakeSparseArray, MakeTuple, AtIndex, From 700a8c6dcdb94ca0fb524e8654c3267f524fca64 Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 9 Jan 2025 22:13:31 +0800 Subject: [PATCH 09/11] Add more tests for DAG CSE --- test/cse.jl | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/test/cse.jl b/test/cse.jl index abaf8e83e..f3fc7b575 100644 --- a/test/cse.jl +++ b/test/cse.jl @@ -18,10 +18,31 @@ end @test length(sorted_nodes) == 3 @test isequal(sorted_nodes[1].rhs, a + b) @test isequal(sin(sorted_nodes[1].lhs), sorted_nodes[2].rhs) + expr = (a + b)^(a + b) sorted_nodes = topological_sort(expr) @test length(sorted_nodes) == 2 @test isequal(sorted_nodes[1].rhs, a + b) ab_node = sorted_nodes[1].lhs @test isequal(ab_node^ab_node, sorted_nodes[2].rhs) + let_expr = cse(expr) + @test length(let_expr.pairs) == 1 + @test isequal(let_expr.pairs[1].rhs, a + b) + corresponding_sym = let_expr.pairs[1].lhs + @test isequal(let_expr.body, corresponding_sym^corresponding_sym) + + expr = a + b + sorted_nodes = topological_sort(expr) + @test length(sorted_nodes) == 1 + @test isequal(sorted_nodes[1].rhs, a + b) + let_expr = cse(expr) + @test isempty(let_expr.pairs) + @test isequal(let_expr.body, a + b) + + expr = a + sorted_nodes = topological_sort(expr) + @test isempty(sorted_nodes) + let_expr = cse(expr) + @test isempty(let_expr.pairs) + @test isequal(let_expr.body, a) end From 4426e88ec00622aa8c7ad777196110b726d5f45b Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 16 Jan 2025 22:22:09 -0800 Subject: [PATCH 10/11] Handle array symbolics in CSE `topological_sort` Co-authored-by: Aayush Sabharwal --- src/code.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/code.jl b/src/code.jl index 9c532faba..0f8638182 100644 --- a/src/code.jl +++ b/src/code.jl @@ -727,6 +727,12 @@ function topological_sort(graph) push!(sorted_nodes, sym ← new_node) visited[node] = sym return sym + elseif _is_array_of_symbolics(node) + new_node = map(dfs, node) + sym = newsym(typeof(new_node)) + push!(sorted_nodes, sym ← new_node) + visited[node] = sym + return sym else visited[node] = node return node From 9c0c27ca8748bc2f588aed6164fbe94027c6b9ea Mon Sep 17 00:00:00 2001 From: "Bowen S. Zhu" Date: Thu, 16 Jan 2025 22:35:41 -0800 Subject: [PATCH 11/11] Test CSE `topological_sort` on array of symbolics --- test/cse.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/cse.jl b/test/cse.jl index f3fc7b575..7e2ef68ea 100644 --- a/test/cse.jl +++ b/test/cse.jl @@ -45,4 +45,12 @@ end let_expr = cse(expr) @test isempty(let_expr.pairs) @test isequal(let_expr.body, a) + + # array symbolics + # https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/688#pullrequestreview-2554931739 + @syms c + function foo end + ex = term(foo, [a^2 + b^2, b^2 + c], c; type = Real) + sorted_nodes = topological_sort(ex) + @test length(sorted_nodes) == 6 end