Skip to content

Commit

Permalink
Merge pull request #580 from JuliaParallel/jps/datadeps-no-haswritedep
Browse files Browse the repository at this point in the history
datadeps: Don't skip copy on no writedep
  • Loading branch information
jpsamaroo authored Jan 22, 2025
2 parents 2e155c9 + 96b4f89 commit 496f68b
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 13 deletions.
2 changes: 2 additions & 0 deletions src/array/darray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ function Base.collect(d::DArray; tree=false)
end
end

Base.wait(A::DArray) = foreach(wait, A.chunks)

### show

#= FIXME
Expand Down
9 changes: 8 additions & 1 deletion src/array/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,17 @@ Base.last(A::DArray) = A[end]

# In-place operations

function imap!(f, A)
for idx in eachindex(A)
A[idx] = f(A[idx])
end
return A
end

function Base.map!(f, a::DArray{T}) where T
Dagger.spawn_datadeps() do
for ca in chunks(a)
Dagger.@spawn map!(f, InOut(ca), ca)
Dagger.@spawn imap!(f, InOut(ca))
end
end
return a
Expand Down
4 changes: 2 additions & 2 deletions src/array/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ function Random.rand!(rng::AbstractRNG, A::DArray{T}) where T
Dagger.spawn_datadeps() do
for Ac in chunks(A)
rng = randfork(rng, part_sz)
Dagger.@spawn map!(_->rand(rng, T), InOut(Ac), Ac)
Dagger.@spawn imap!(InOut(_->rand(rng, T)), InOut(Ac))
end
end
return A
Expand All @@ -19,7 +19,7 @@ function Random.randn!(rng::AbstractRNG, A::DArray{T}) where T
Dagger.spawn_datadeps() do
for Ac in chunks(A)
rng = randfork(rng, part_sz)
Dagger.@spawn map!(_->randn(rng, T), InOut(Ac), Ac)
Dagger.@spawn imap!(InOut(_->randn(rng, T)), InOut(Ac))
end
end
return A
Expand Down
70 changes: 63 additions & 7 deletions src/datadeps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,18 +147,22 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState
# The mapping of memory space to remote argument copies
remote_args::Dict{MemorySpace,IdDict{Any,Any}}

# Cache of whether arguments supports in-place move
supports_inplace_cache::IdDict{Any,Bool}

# The aliasing analysis state
alias_state::State

function DataDepsState(aliasing::Bool)
dependencies = Pair{DTask,Vector{Tuple{Bool,Bool,<:AbstractAliasing,<:Any,<:Any}}}[]
remote_args = Dict{MemorySpace,IdDict{Any,Any}}()
supports_inplace_cache = IdDict{Any,Bool}()
if aliasing
state = DataDepsAliasingState()
else
state = DataDepsNonAliasingState()
end
return new{typeof(state)}(aliasing, dependencies, remote_args, state)
return new{typeof(state)}(aliasing, dependencies, remote_args, supports_inplace_cache, state)
end
end

Expand All @@ -168,6 +172,12 @@ function aliasing(astate::DataDepsAliasingState, arg, dep_mod)
end
end

function supports_inplace_move(state::DataDepsState, arg)
return get!(state.supports_inplace_cache, arg) do
return supports_inplace_move(arg)
end
end

# Determine which arguments could be written to, and thus need tracking

"Whether `arg` has any writedep in this datadeps region."
Expand Down Expand Up @@ -323,6 +333,30 @@ function populate_return_info!(state::DataDepsState{DataDepsNonAliasingState}, t
astate.data_origin[task] = space
end

"""
supports_inplace_move(x) -> Bool
Returns `false` if `x` doesn't support being copied into from another object
like `x`, via `move!`. This is used in `spawn_datadeps` to prevent attempting
to copy between values which don't support mutation or otherwise don't have an
implemented `move!` and want to skip in-place copies. When this returns
`false`, datadeps will instead perform out-of-place copies for each non-local
use of `x`, and the data in `x` will not be updated when the `spawn_datadeps`
region returns.
"""
supports_inplace_move(x) = true
supports_inplace_move(t::DTask) = supports_inplace_move(fetch(t; raw=true))
function supports_inplace_move(c::Chunk)
# FIXME: Use MemPool.access_ref
pid = root_worker_id(c.processor)
if pid == myid()
return supports_inplace_move(poolget(c.handle))
else
return remotecall_fetch(supports_inplace_move, pid, c)
end
end
supports_inplace_move(::Function) = false

# Read/write dependency management
function get_write_deps!(state::DataDepsState, ainfo_or_arg, task, write_num, syncdeps)
_get_write_deps!(state, ainfo_or_arg, task, write_num, syncdeps)
Expand Down Expand Up @@ -677,8 +711,15 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
# Is the data written previously or now?
arg, deps = unwrap_inout(arg)
arg = arg isa DTask ? fetch(arg; raw=true) : arg
if !type_may_alias(typeof(arg)) || !has_writedep(state, arg, deps, task)
@dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (unwritten)"
if !type_may_alias(typeof(arg))
@dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (immutable)"
spec.args[idx] = pos => arg
continue
end

# Is the data writeable?
if !supports_inplace_move(state, arg)
@dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (non-writeable)"
spec.args[idx] = pos => arg
continue
end
Expand Down Expand Up @@ -738,7 +779,10 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
# Validate that we're not accidentally performing a copy
for (idx, (_, arg)) in enumerate(spec.args)
_, deps = unwrap_inout(task_args[idx][2])
if is_writedep(arg, deps, task)
# N.B. We only do this check when the argument supports in-place
# moves, because for the moment, we are not guaranteeing updates or
# write-back of results
if is_writedep(arg, deps, task) && supports_inplace_move(state, arg)
arg_space = memory_space(arg)
@assert arg_space == our_space "($(repr(spec.f)))[$idx] Tried to pass $(typeof(arg)) from $arg_space to $our_space"
end
Expand All @@ -750,6 +794,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
arg, deps = unwrap_inout(arg)
arg = arg isa DTask ? fetch(arg; raw=true) : arg
type_may_alias(typeof(arg)) || continue
supports_inplace_move(state, arg) || continue
if queue.aliasing
for (dep_mod, _, writedep) in deps
ainfo = aliasing(astate, arg, dep_mod)
Expand Down Expand Up @@ -830,6 +875,12 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
continue
end

# Skip non-writeable arguments
if !supports_inplace_move(state, arg)
@dagdebug nothing :spawn_datadeps "Skipped copy-from (non-writeable)"
continue
end

# Get the set of writers
ainfo_writes = get!(Vector{Tuple{AbstractAliasing,<:Any,MemorySpace}}, arg_writes, arg)

Expand Down Expand Up @@ -877,8 +928,13 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
for arg in keys(astate.data_origin)
# Is the data previously written?
arg, deps = unwrap_inout(arg)
if !type_may_alias(typeof(arg)) || !has_writedep(state, arg, deps)
@dagdebug nothing :spawn_datadeps "Skipped copy-from (unwritten)"
if !type_may_alias(typeof(arg))
@dagdebug nothing :spawn_datadeps "Skipped copy-from (immutable)"
end

# Can the data be written back to?
if !supports_inplace_move(state, arg)
@dagdebug nothing :spawn_datadeps "Skipped copy-from (non-writeable)"
end

# Is the source of truth elsewhere?
Expand Down Expand Up @@ -912,7 +968,7 @@ Dagger tasks launched within `f` may wrap their arguments with `In`, `Out`, or
argument, respectively. These argument dependencies will be used to specify
which tasks depend on each other based on the following rules:
- Dependencies across different arguments are independent; only dependencies on the same argument synchronize with each other ("same-ness" is determined based on `isequal`)
- Dependencies across unrelated arguments are independent; only dependencies on arguments which overlap in memory synchronize with each other
- `InOut` is the same as `In` and `Out` applied simultaneously, and synchronizes with the union of the `In` and `Out` effects
- Any two or more `In` dependencies do not synchronize with each other, and may execute in parallel
- An `Out` dependency synchronizes with any previous `In` and `Out` dependencies
Expand Down
9 changes: 8 additions & 1 deletion src/thunk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,14 @@ function show_thunk(io::IO, t)
end
print(io, ")")
end
Base.show(io::IO, t::Thunk) = show_thunk(io, t)
function Base.show(io::IO, t::Thunk)
lazy_level = parse(Int, get(ENV, "JULIA_DAGGER_SHOW_THUNK_VERBOSITY", "0"))
if lazy_level == 0
show_thunk(io, t)
else
show_thunk(IOContext(io, :lazy_level => lazy_level), t)
end
end
Base.summary(t::Thunk) = repr(t)

inputs(x::Thunk) = x.inputs
Expand Down
4 changes: 2 additions & 2 deletions test/datadeps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ function taskdeps_for_task(logs::Dict{Int,<:Dict}, tid::Int)
end
error("Task $tid not found in logs")
end
function test_task_dominators(logs::Dict, tid::Int, doms::Vector; all_tids::Vector=[], nondom_check::Bool=true)
function test_task_dominators(logs::Dict, tid::Int, doms::Vector; all_tids::Vector=[], nondom_check::Bool=false)
g = SimpleDiGraph()
tid_to_v = Dict{Int,Int}()
seen = Set{Int}()
Expand Down Expand Up @@ -165,7 +165,7 @@ function test_datadeps(;args_chunks::Bool,
end
tid_1, tid_2 = task_id.(ts)
test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2])
test_task_dominators(logs, tid_2, []; all_tids=[tid_1, tid_2])
test_task_dominators(logs, tid_2, []; all_tids=[tid_1, tid_2], nondom_check=false)

# R->W Aliasing
ts = []
Expand Down

0 comments on commit 496f68b

Please sign in to comment.