Skip to content

Commit

Permalink
Implement pmap
Browse files Browse the repository at this point in the history
  • Loading branch information
shinaoka committed Dec 26, 2024
1 parent 1f4eb0c commit b2cf5c9
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 2 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ version = "0.1.0"
[deps]
Coverage = "a2441757-f6aa-5fb2-8edb-039e3f45d037"
CoverageTools = "c36e975a-824b-4404-a568-ef97ca766997"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19"

[compat]
Coverage = "1"
CoverageTools = "1"
Distributed = "1.11.0"
Graphs = "1.12.0"
NamedGraphs = "0.6.3"
julia = "1.9"
Expand Down
1 change: 1 addition & 0 deletions src/ComputationalGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module ComputationalGraphs

using Graphs
using NamedGraphs
using Distributed

include("computationalgraph.jl")

Expand Down
58 changes: 56 additions & 2 deletions src/computationalgraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,54 @@ function compute_computable_nodes!(obj::ComputationalGraph{K,N})::Int where {K,N
return length(computable_nodes)
end

"""
Prepare minimal data (dependencies' values + computefunc) for a node.
"""
function compute_node_data(
obj::ComputationalGraph{K,N}, node::K
)::Tuple{Dict{K,N},Function} where {K,N}
# Check if node is computable
if !is_computable(obj, node)
error("Node $node is not computable")
end
# Get dependency nodes
dependencies = inneighbors(obj.graph, node)
# Collect values of dependency nodes (Dict{depnode => value})
args = Dict(depnode => obj.nodevalue[depnode] for depnode in dependencies)
# Get function to execute
func = obj.computefunc[node]
# Return tuple of (dependency node values, function)
return (args, func)
end

function _compute_node_remote(args::Dict{K,N}, func::Function) where {K,N}
# Small function to execute on worker nodes
return func(args)
end

"""
Like compute_computable_nodes!, but computes nodes in parallel using distributed workers.
Returns the number of computed nodes.
"""
function compute_computable_nodes_pmap!(obj::ComputationalGraph{K,N})::Int where {K,N}
computable_nodes = get_computable_nodes(obj)

@show length(computable_nodes)
# Create (args, func) and pmap them
# Only (args, func) are transferred to workers
results = pmap(node -> begin
(args, f) = compute_node_data(obj, node) # Generate minimal data locally
_compute_node_remote(args, f) # Execute on worker
end, computable_nodes)

# Update obj.nodevalue[node] with results returned from workers
for (i, node) in enumerate(computable_nodes)
obj.nodevalue[node] = results[i]
end

return length(results)
end

"""
Unassigned all unneseccary intermediate nodes from the computational graph.
Return the number of unassigned nodes.
Expand All @@ -177,10 +225,16 @@ end
Compute all nodes in the computational graph.
Return the number of calls of compute_computable_nodes!.
"""
function compute_all_nodes!(obj::ComputationalGraph{K,N}; callgc=true)::Int where {K,N}
function compute_all_nodes!(
obj::ComputationalGraph{K,N}; callgc=true, distributed=false
)::Int where {K,N}
count = 0
while true
computed_count = compute_computable_nodes!(obj)
computed_count = if distributed
compute_computable_nodes_pmap!(obj)
else
compute_computable_nodes!(obj)
end
unassign_intermediate_nodes!(obj)
if callgc
GC.gc()
Expand Down
8 changes: 8 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
using Distributed

nworkers_ = 2
if nworkers() < nworkers_
addprocs(nworkers_ - nworkers())
end

@everywhere import ComputationalGraphs as CG
using Test

include("test_computationalgraph.jl")
24 changes: 24 additions & 0 deletions test/test_computationalgraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,28 @@ import ComputationalGraphs as CG
@test cg.nodevalue["A"] === nothing
@test cg.nodevalue["B"] === nothing
end

@testset "pmap" begin
cg = CG.ComputationalGraph{Any,Float64}()

N = 24

for n in 1:N
CG.add_node!(cg, "A$n", 1.0)
CG.add_node!(cg, "B$n", 1.0)
CG.add_node!(cg, "C$n")

CG.add_dependency!(
cg,
"C$n";
dependencies=["A$n", "B$n"],
computefunc=x -> (sleep(0.1); sum(values(x))),
)
end

@test CG.compute_all_nodes!(cg; distributed=true) == 1
@test all([
cg.nodevalue["C$n"] == cg.nodevalue["A$n"] + cg.nodevalue["B$n"] for n in 1:N
])
end
end;

0 comments on commit b2cf5c9

Please sign in to comment.