fixup! Add support for worker state callbacks
JamesWrigley committed Dec 26, 2024
1 parent 90f44f6 commit 8b04241
Showing 4 changed files with 94 additions and 40 deletions.
docs/src/index.md (6 changes: 4 additions & 2 deletions)
@@ -55,8 +55,10 @@ DistributedNext.cluster_cookie(::Any)
## Callbacks

```@docs
-DistributedNext.add_worker_added_callback
-DistributedNext.remove_worker_added_callback
+DistributedNext.add_worker_starting_callback
+DistributedNext.remove_worker_starting_callback
+DistributedNext.add_worker_started_callback
+DistributedNext.remove_worker_started_callback
DistributedNext.add_worker_exiting_callback
DistributedNext.remove_worker_exiting_callback
DistributedNext.add_worker_exited_callback
ext/ReviseExt.jl (4 changes: 2 additions & 2 deletions)
@@ -23,8 +23,8 @@ Revise.is_master_worker(worker::DistributedNextWorker) = worker.id == 1

function __init__()
Revise.register_workers_function(get_workers)
-DistributedNext.add_worker_added_callback(pid -> Revise.init_worker(DistributedNextWorker(pid));
-key="DistributedNext-integration")
+DistributedNext.add_worker_started_callback(pid -> Revise.init_worker(DistributedNextWorker(pid));
+key="DistributedNext-integration")
end

end
src/cluster.jl (94 changes: 69 additions & 25 deletions)
@@ -463,23 +463,17 @@ function addprocs(manager::ClusterManager; kwargs...)

cluster_mgmt_from_master_check()

+# Call worker-starting callbacks
+warning_interval = params[:callback_warning_interval]
+_run_callbacks_concurrently("worker-starting", worker_starting_callbacks,
+warning_interval, [(manager, kwargs)])
+
+# Add new workers
new_workers = @lock worker_lock addprocs_locked(manager::ClusterManager, params)

-callback_tasks = Dict{Any, Task}()
-for worker in new_workers
-for (name, callback) in worker_added_callbacks
-callback_tasks[name] = Threads.@spawn callback(worker)
-end
-end
-
-running_callbacks = () -> ["'$(key)'" for (key, task) in callback_tasks if !istaskdone(task)]
-while timedwait(() -> isempty(running_callbacks()), params[:callback_warning_interval]) === :timed_out
-callbacks_str = join(running_callbacks(), ", ")
-@warn "Waiting for these worker-added callbacks to finish: $(callbacks_str)"
-end
-
-# Wait on the tasks so that exceptions bubble up
-wait.(values(callback_tasks))
+# Call worker-started callbacks
+_run_callbacks_concurrently("worker-started", worker_started_callbacks,
+warning_interval, new_workers)

return new_workers
end
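To make the new control flow concrete, here is a minimal sketch of how the two phases fire around `addprocs` (illustrative only, not part of the commit; the callback keys are invented):

```julia
using DistributedNext

# Runs once per addprocs() call, before any workers are launched.
DistributedNext.add_worker_starting_callback((manager, params) -> @info("Launching workers via $(typeof(manager))");
                                              key="demo-starting")

# Runs once per new worker ID, after the workers are up and connected.
DistributedNext.add_worker_started_callback(pid -> @info("Worker $pid is ready");
                                            key="demo-started")

addprocs(2)  # logs the starting message once, then the started message twice
```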
@@ -870,7 +864,8 @@ const HDR_COOKIE_LEN=16
const map_pid_wrkr = Dict{Int, Union{Worker, LocalProcess}}()
const map_sock_wrkr = IdDict()
const map_del_wrkr = Set{Int}()
-const worker_added_callbacks = Dict{Any, Base.Callable}()
+const worker_starting_callbacks = Dict{Any, Base.Callable}()
+const worker_started_callbacks = Dict{Any, Base.Callable}()
const worker_exiting_callbacks = Dict{Any, Base.Callable}()
const worker_exited_callbacks = Dict{Any, Base.Callable}()

@@ -882,9 +877,29 @@ end

# Callbacks

-function _add_callback(f, key, dict)
-if !hasmethod(f, Tuple{Int})
-throw(ArgumentError("Callback function is invalid, it must be able to accept a single Int argument"))
+function _run_callbacks_concurrently(callbacks_name, callbacks_dict, warning_interval, arglist)
+callback_tasks = Dict{Any, Task}()
+for args in arglist
+for (name, callback) in callbacks_dict
+callback_tasks[name] = Threads.@spawn callback(args...)
+end
+end
+
+running_callbacks = () -> ["'$(key)'" for (key, task) in callback_tasks if !istaskdone(task)]
+while timedwait(() -> isempty(running_callbacks()), warning_interval) === :timed_out
+callbacks_str = join(running_callbacks(), ", ")
+@warn "Waiting for these $(callbacks_name) callbacks to finish: $(callbacks_str)"
+end
+
+# Wait on the tasks so that exceptions bubble up
+wait.(values(callback_tasks))
+end
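This helper is internal, but calling it directly shows the behaviour it implements: the callbacks run as concurrent tasks, and a warning is logged every `warning_interval` seconds while any are still running. A hedged sketch, not part of the commit:

```julia
using DistributedNext

callbacks = Dict{Any, Base.Callable}("fast" => pid -> nothing,
                                     "slow" => pid -> sleep(2))

# Warns roughly every 0.5 seconds that 'slow' is still running, then returns
# once both tasks have finished (a callback exception would be rethrown by the wait).
DistributedNext._run_callbacks_concurrently("worker-started", callbacks, 0.5, [1])
```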

+function _add_callback(f, key, dict; arg_types=Tuple{Int})
+desired_signature = "f(" * join(["::$(t)" for t in arg_types.types], ", ") * ")"
+
+if !hasmethod(f, arg_types)
+throw(ArgumentError("Callback function is invalid, it must be able to be called with these argument types: $(desired_signature)"))
elseif haskey(dict, key)
throw(ArgumentError("A callback function with key '$(key)' already exists"))
end
@@ -900,29 +915,58 @@ end
_remove_callback(key, dict) = delete!(dict, key)

"""
-add_worker_added_callback(f::Base.Callable; key=nothing)
+add_worker_starting_callback(f::Base.Callable; key=nothing)
+Register a callback to be called on the master process immediately before new
+workers are started. The callback `f` will be called with the `ClusterManager`
+instance that is being used and a dictionary of parameters related to adding
+workers, i.e. `f(manager, params)`. The `params` dictionary is specific to the
+`manager` type. Note that the `LocalManager` and `SSHManager` cluster managers
+in DistributedNext are not fully documented yet, see the
+[managers.jl](https://github.com/JuliaParallel/DistributedNext.jl/blob/master/src/managers.jl)
+file for their definitions.
+!!! warning
+Adding workers can fail so it is not guaranteed that the workers requested
+will exist.
+The worker-starting callbacks will be executed concurrently. If one throws an
+exception it will not be caught and will bubble up through [`addprocs`](@ref).
+Keep in mind that the callbacks will add to the time taken to launch workers; so
+try to either keep the callbacks fast to execute, or do the actual work
+asynchronously by spawning a task in the callback (beware of race conditions if
+you do this).
+"""
+add_worker_starting_callback(f::Base.Callable; key=nothing) = _add_callback(f, key, worker_starting_callbacks;
+arg_types=Tuple{ClusterManager, Dict})
+
+remove_worker_starting_callback(key) = _remove_callback(key, worker_starting_callbacks)
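As a usage sketch (illustrative only, not part of the commit; the key string is made up), a worker-starting callback has to accept the manager and the parameter dictionary, matching the `arg_types` passed to `_add_callback` above, and registering it under an explicit key makes it easy to remove later:

```julia
using DistributedNext

# Log which cluster manager is about to launch workers.
function log_launch(manager, params)
    @info "About to start workers via $(typeof(manager))"
end

DistributedNext.add_worker_starting_callback(log_launch; key="log-launch")

# Later, once the hook is no longer wanted:
DistributedNext.remove_worker_starting_callback("log-launch")
```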

"""
add_worker_started_callback(f::Base.Callable; key=nothing)
Register a callback to be called on the master process whenever a worker is
added. The callback will be called with the added worker ID,
e.g. `f(w::Int)`. Chooses and returns a unique key for the callback if `key` is
not specified.
-The worker-added callbacks will be executed concurrently. If one throws an
+The worker-started callbacks will be executed concurrently. If one throws an
exception it will not be caught and will bubble up through [`addprocs()`](@ref).
Keep in mind that the callbacks will add to the time taken to launch workers; so
try to either keep the callbacks fast to execute, or do the actual
initialization asynchronously by spawning a task in the callback (beware of race
conditions if you do this).
"""
-add_worker_added_callback(f::Base.Callable; key=nothing) = _add_callback(f, key, worker_added_callbacks)
+add_worker_started_callback(f::Base.Callable; key=nothing) = _add_callback(f, key, worker_started_callbacks)

"""
-remove_worker_added_callback(key)
+remove_worker_started_callback(key)
-Remove the callback for `key` that was added with [`add_worker_added_callback()`](@ref).
+Remove the callback for `key` that was added with [`add_worker_started_callback()`](@ref).
"""
-remove_worker_added_callback(key) = _remove_callback(key, worker_added_callbacks)
+remove_worker_started_callback(key) = _remove_callback(key, worker_started_callbacks)
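By contrast, a worker-started callback receives only the new worker's ID, and the chosen key is returned when none is supplied. A sketch, not part of the diff; it assumes `gethostname` from Base is what you want to run on the workers:

```julia
using DistributedNext

# Report where each new worker ended up as soon as it connects.
started_key = DistributedNext.add_worker_started_callback() do pid
    host = remotecall_fetch(gethostname, pid)
    @info "Worker $pid is running on $host"
end

addprocs(1)  # the callback fires during this call

DistributedNext.remove_worker_started_callback(started_key)
```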

"""
add_worker_exiting_callback(f::Base.Callable; key=nothing)
test/distributed_exec.jl (30 changes: 19 additions & 11 deletions)
@@ -1939,40 +1939,47 @@ end
@testset "Worker state callbacks" begin
rmprocs(other_workers())

+# Adding a callback with an invalid signature should fail
+@test_throws ArgumentError DistributedNext.add_worker_started_callback(() -> nothing)

# Smoke test to ensure that all the callbacks are executed
-added_workers = Int[]
+starting_managers = []
+started_workers = Int[]
exiting_workers = Int[]
exited_workers = Int[]
-added_key = DistributedNext.add_worker_added_callback(pid -> (push!(added_workers, pid); error("foo")))
+starting_key = DistributedNext.add_worker_starting_callback((manager, kwargs) -> push!(starting_managers, manager))
+started_key = DistributedNext.add_worker_started_callback(pid -> (push!(started_workers, pid); error("foo")))
exiting_key = DistributedNext.add_worker_exiting_callback(pid -> push!(exiting_workers, pid))
exited_key = DistributedNext.add_worker_exited_callback(pid -> push!(exited_workers, pid))

-# Test that the worker-added exception bubbles up
+# Test that the worker-started exception bubbles up
@test_throws TaskFailedException addprocs(1)

pid = only(workers())
-@test added_workers == [pid]
+@test only(starting_managers) isa DistributedNext.LocalManager
+@test started_workers == [pid]
rmprocs(workers())
@test exiting_workers == [pid]
@test exited_workers == [pid]

# Trying to reset an existing callback should fail
-@test_throws ArgumentError DistributedNext.add_worker_added_callback(Returns(nothing); key=added_key)
+@test_throws ArgumentError DistributedNext.add_worker_started_callback(Returns(nothing); key=started_key)

# Remove the callbacks
-DistributedNext.remove_worker_added_callback(added_key)
+DistributedNext.remove_worker_starting_callback(starting_key)
+DistributedNext.remove_worker_started_callback(started_key)
DistributedNext.remove_worker_exiting_callback(exiting_key)
DistributedNext.remove_worker_exited_callback(exited_key)

# Test that the worker-exiting `callback_timeout` option works and that we
-# get warnings about slow worker-added callbacks.
+# get warnings about slow worker-started callbacks.
event = Base.Event()
callback_task = nothing
-added_key = DistributedNext.add_worker_added_callback(_ -> sleep(0.5))
+started_key = DistributedNext.add_worker_started_callback(_ -> sleep(0.5))
exiting_key = DistributedNext.add_worker_exiting_callback(_ -> (callback_task = current_task(); wait(event)))

@test_logs (:warn, r"Waiting for these worker-added callbacks.+") match_mode=:any addprocs(1; callback_warning_interval=0.05)
DistributedNext.remove_worker_added_callback(added_key)
@test_logs (:warn, r"Waiting for these worker-started callbacks.+") match_mode=:any addprocs(1; callback_warning_interval=0.05)
DistributedNext.remove_worker_started_callback(started_key)

@test_logs (:warn, r"Some worker-exiting callbacks have not yet finished.+") rmprocs(workers(); callback_timeout=0.5)
DistributedNext.remove_worker_exiting_callback(exiting_key)
@@ -1981,7 +1988,8 @@ end
wait(callback_task)

# Test that the initial callbacks were indeed removed
-@test length(added_workers) == 1
+@test length(starting_managers) == 1
+@test length(started_workers) == 1
@test length(exiting_workers) == 1
@test length(exited_workers) == 1
end
