diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index db4d09e1..6a565870 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -6,6 +6,7 @@ os: linux arch: x86_64 command: "julia --project -e 'using Pkg; Pkg.develop(;path=\"lib/TimespanLogging\")'" + .bench: &bench if: build.message =~ /\[run benchmarks\]/ agents: @@ -14,6 +15,7 @@ os: linux arch: x86_64 num_cpus: 16 + steps: - label: Julia 1.9 timeout_in_minutes: 90 diff --git a/docs/make.jl b/docs/make.jl index c21c03f2..8f1f97f5 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -22,6 +22,7 @@ makedocs(; "Task Spawning" => "task-spawning.md", "Data Management" => "data-management.md", "Distributed Arrays" => "darray.md", + "Streaming Tasks" => "streaming.md", "Scopes" => "scopes.md", "Processors" => "processors.md", "Task Queues" => "task-queues.md", diff --git a/docs/src/index.md b/docs/src/index.md index 27eb28dd..152b95cc 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -394,3 +394,38 @@ Dagger.@spawn copyto!(C, X) In contrast to the previous example, here, the tasks are executed without argument annotations. As a result, there is a possibility of the `copyto!` task being executed before the `sort!` task, leading to unexpected results in the output array `C`. +## Quickstart: Streaming + +Dagger.jl provides a streaming API that allows you to process data in a streaming fashion, where data is processed as it becomes available, rather than waiting for the entire dataset to be loaded into memory. + +For more details: [Streaming](@ref) + +### Syntax + +The `Dagger.spawn_streaming()` function is used to create a streaming region, +where tasks are executed continuously, processing data as it becomes available: + +```julia +# Open a file to write to on this worker +f = Dagger.@mutable open("output.txt", "w") +t = Dagger.spawn_streaming() do + # Generate random numbers continuously + val = Dagger.@spawn rand() + # Write each random number to a file + Dagger.@spawn (f, val) -> begin + if val < 0.01 + # Finish streaming when the random number is less than 0.01 + Dagger.finish_stream() + end + println(f, val) + end +end +# Wait for all values to be generated and written +wait(t) +``` + +The above example demonstrates a streaming region that generates random numbers +continuously and writes each random number to a file. The streaming region is +terminated when a random number less than 0.01 is generated, which is done by +calling `Dagger.finish_stream()` (this terminates the current task, and will +also terminate all streaming tasks launched by `spawn_streaming`). diff --git a/docs/src/streaming.md b/docs/src/streaming.md new file mode 100644 index 00000000..41c111e8 --- /dev/null +++ b/docs/src/streaming.md @@ -0,0 +1,105 @@ +# Streaming + +Dagger tasks have a limited lifetime - they are created, execute, finish, and +are eventually destroyed when they're no longer needed. Thus, if one wants +to run the same kind of computations over and over, one might re-create a +similar set of tasks for each unit of data that needs processing. + +This might be fine for computations which take a long time to run (thus +dwarfing the cost of task creation, which is quite small), or when working with +a limited set of data, but this approach is not great for doing lots of small +computations on a large (or endless) amount of data. For example, processing +image frames from a webcam, reacting to messages from a message bus, reading +samples from a software radio, etc. 
All of these tasks are better suited to a
+"streaming" model of data processing, where data is simply piped into a
+continuously-running task (or DAG of tasks) forever, or until the data runs
+out.
+
+Thankfully, if you have a problem which is best modeled as a streaming system
+of tasks, Dagger has you covered! Building on its support for
+[Task Queues](@ref), Dagger provides a means to convert an entire DAG of
+tasks into a streaming DAG, where data flows into and out of each task
+asynchronously, using the `spawn_streaming` function:
+
+```julia
+Dagger.spawn_streaming() do # enters a streaming region
+    vals = Dagger.@spawn rand()
+    print_vals = Dagger.@spawn println(vals)
+end # exits the streaming region, and starts the DAG running
+```
+
+In the above example, `vals` is a Dagger task which has been transformed to run
+in a streaming manner - instead of just calling `rand()` once and returning its
+result, it will re-run `rand()` endlessly, continuously producing new random
+values. In typical Dagger style, `print_vals` is a Dagger task which depends on
+`vals`, but in streaming form - it will continuously `println` the random
+values produced from `vals`. Both tasks will run forever, and will run
+efficiently, only doing the work necessary to generate, transfer, and consume
+values.
+
+As the comments point out, `spawn_streaming` creates a streaming region, during
+which `vals` and `print_vals` are created and configured. Both tasks are halted
+until `spawn_streaming` returns, allowing large DAGs to be built all at once,
+without any task losing a single value. If desired, streaming regions can be
+connected, although some values might be lost while tasks are being connected:
+
+```julia
+vals = Dagger.spawn_streaming() do
+    Dagger.@spawn rand()
+end
+
+# Some values might be generated by `vals` but thrown away
+# before `print_vals` is fully set up and connected to it
+
+print_vals = Dagger.spawn_streaming() do
+    Dagger.@spawn println(vals)
+end
+```
+
+More complicated streaming DAGs can be easily constructed, without doing
+anything different. For example, we can generate multiple streams of random
+numbers, write them all to their own files, and print the combined results:
+
+```julia
+Dagger.spawn_streaming() do
+    all_vals = [Dagger.spawn(rand) for i in 1:4]
+    all_vals_written = map(1:4) do i
+        Dagger.spawn(all_vals[i]) do val
+            open("results_$i.txt"; write=true, create=true, append=true) do io
+                println(io, repr(val))
+            end
+            return val
+        end
+    end
+    Dagger.spawn(all_vals_written...) do all_vals_written...
+        vals_sum = sum(all_vals_written)
+        println(vals_sum)
+    end
+end
+```
+
+If you want to stop the streaming DAG and tear it all down, you can call
+`Dagger.cancel!(all_vals[1])` (or call it on any other task in the streaming
+DAG) to terminate all streaming tasks.
+
+Alternatively, tasks can stop themselves from the inside with
+`finish_stream`, optionally returning a value that can be `fetch`'d. Let's
+do this when our randomly-drawn number falls within some arbitrary range:
+
+```julia
+vals = Dagger.spawn_streaming() do
+    Dagger.spawn() do
+        x = rand()
+        if x < 0.001
+            # That's good enough, let's be done
+            return Dagger.finish_stream("Finished!")
+        end
+        return x
+    end
+end
+fetch(vals)
+```
+
+In this example, the call to `fetch` will hang (while random numbers continue
+to be drawn), until a drawn number is less than 0.001; at that point, `fetch`
+will return with `"Finished!"`, and the task `vals` will have terminated.
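+
+Streaming tasks also recognize a few per-task options (read by
+`initialize_streaming!`): `stream_input_buffer_amount`, `stream_output_buffer_amount`,
+`stream_buffer_type`, and `stream_max_evals`, the last of which caps how many
+times the task body runs before the stream finishes. As a minimal sketch
+(assuming these options are passed with the usual `@spawn key=value` syntax):
+
+```julia
+Dagger.spawn_streaming() do
+    # Produce at most 100 values, buffering up to 8 of them at a time
+    vals = Dagger.@spawn stream_max_evals=100 stream_output_buffer_amount=8 rand()
+    # Consume them with a matching input buffer
+    Dagger.@spawn stream_input_buffer_amount=8 println(vals)
+end
+```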
diff --git a/src/Dagger.jl b/src/Dagger.jl index 8bc2c24a..fd6395a4 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -21,6 +21,7 @@ if !isdefined(Base, :ScopedValues) else import Base.ScopedValues: ScopedValue, with end +import TaskLocalValues: TaskLocalValue if !isdefined(Base, :get_extension) import Requires: @require @@ -55,16 +56,16 @@ include("processor.jl") include("threadproc.jl") include("context.jl") include("utils/processors.jl") +include("dtask.jl") +include("cancellation.jl") include("task-tls.jl") include("scopes.jl") include("utils/scopes.jl") -include("dtask.jl") include("queue.jl") include("thunk.jl") include("submission.jl") include("chunks.jl") include("memory-spaces.jl") -include("cancellation.jl") # Task scheduling include("compute.jl") @@ -76,6 +77,11 @@ include("sch/Sch.jl"); using .Sch # Data dependency task queue include("datadeps.jl") +# Streaming +include("stream.jl") +include("stream-buffers.jl") +include("stream-transfer.jl") + # Array computations include("array/darray.jl") include("array/alloc.jl") @@ -169,6 +175,20 @@ function __init__() ThreadProc(myid(), tid) end end + + # Set up @dagdebug categories, if specified + try + if haskey(ENV, "JULIA_DAGGER_DEBUG") + empty!(DAGDEBUG_CATEGORIES) + for category in split(ENV["JULIA_DAGGER_DEBUG"], ",") + if category != "" + push!(DAGDEBUG_CATEGORIES, Symbol(category)) + end + end + end + catch err + @warn "Error parsing JULIA_DAGGER_DEBUG" exception=err + end end end # module diff --git a/src/array/indexing.jl b/src/array/indexing.jl index 82f44fbf..69725eb7 100644 --- a/src/array/indexing.jl +++ b/src/array/indexing.jl @@ -1,5 +1,3 @@ -import TaskLocalValues: TaskLocalValue - ### getindex struct GetIndex{T,N} <: ArrayOp{T,N} diff --git a/src/cancellation.jl b/src/cancellation.jl index c982fd20..63993a0e 100644 --- a/src/cancellation.jl +++ b/src/cancellation.jl @@ -1,11 +1,61 @@ +# DTask-level cancellation + +mutable struct CancelToken + @atomic cancelled::Bool + @atomic graceful::Bool + event::Base.Event +end +CancelToken() = CancelToken(false, false, Base.Event()) +function cancel!(token::CancelToken; graceful::Bool=true) + if !graceful + @atomic token.graceful = false + end + @atomic token.cancelled = true + notify(token.event) + return +end +function is_cancelled(token::CancelToken; must_force::Bool=false) + if token.cancelled[] + if must_force && token.graceful[] + # If we're only responding to forced cancellation, ignore graceful cancellations + return false + end + return true + end + return false +end +Base.wait(token::CancelToken) = wait(token.event) +# TODO: Enable this for safety +#Serialization.serialize(io::AbstractSerializer, ::CancelToken) = +# throw(ConcurrencyViolationError("Cannot serialize a CancelToken")) + +const DTASK_CANCEL_TOKEN = TaskLocalValue{Union{CancelToken,Nothing}}(()->nothing) + +function clone_cancel_token_remote(orig_token::CancelToken, wid::Integer) + remote_token = remotecall_fetch(wid) do + return poolset(CancelToken()) + end + errormonitor_tracked("remote cancel_token communicator", Threads.@spawn begin + wait(orig_token) + @dagdebug nothing :cancel "Cancelling remote token on worker $wid" + MemPool.access_ref(remote_token) do remote_token + cancel!(remote_token) + end + end) +end + +# Global-level cancellation + """ - cancel!(task::DTask; force::Bool=false, halt_sch::Bool=false) + cancel!(task::DTask; force::Bool=false, graceful::Bool=true, halt_sch::Bool=false) Cancels `task` at any point in its lifecycle, causing the scheduler to abandon -it. 
If `force` is `true`, the task will be interrupted with an -`InterruptException` (not recommended, this is unsafe). If `halt_sch` is -`true`, the scheduler will be halted after the task is cancelled (it will -restart automatically upon the next `@spawn`/`spawn` call). +it. + +# Keyword arguments +- `force`: If `true`, the task will be interrupted with an `InterruptException` (not recommended, this is unsafe). +- `graceful`: If `true`, the task will be allowed to finish its current execution before being cancelled; otherwise, it will be cancelled as soon as possible. +- `halt_sch`: If `true`, the scheduler will be halted after the task is cancelled (it will restart automatically upon the next `@spawn`/`spawn` call). As an example, the following code will cancel task `t` before it finishes executing: @@ -21,24 +71,24 @@ tasks which are waiting to run. Using `cancel!` is generally a much safer alternative to Ctrl+C, as it cooperates with the scheduler and runtime and avoids unintended side effects. """ -function cancel!(task::DTask; force::Bool=false, halt_sch::Bool=false) +function cancel!(task::DTask; force::Bool=false, graceful::Bool=true, halt_sch::Bool=false) tid = lock(Dagger.Sch.EAGER_ID_MAP) do id_map id_map[task.uid] end - cancel!(tid; force, halt_sch) + cancel!(tid; force, graceful, halt_sch) end function cancel!(tid::Union{Int,Nothing}=nothing; - force::Bool=false, halt_sch::Bool=false) + force::Bool=false, graceful::Bool=true, halt_sch::Bool=false) remotecall_fetch(1, tid, force, halt_sch) do tid, force, halt_sch state = Sch.EAGER_STATE[] # Check that the scheduler isn't stopping or has already stopped if !isnothing(state) && !state.halt.set - @lock state.lock _cancel!(state, tid, force, halt_sch) + @lock state.lock _cancel!(state, tid, force, graceful, halt_sch) end end end -function _cancel!(state, tid, force, halt_sch) +function _cancel!(state, tid, force, graceful, halt_sch) @assert islocked(state.lock) # Get the scheduler uid @@ -48,7 +98,7 @@ function _cancel!(state, tid, force, halt_sch) for task in state.ready tid !== nothing && task.id != tid && continue @dagdebug tid :cancel "Cancelling ready task" - state.cache[task] = InterruptException() + state.cache[task] = DTaskFailedException(task, task, InterruptException()) state.errored[task] = true Sch.set_failed!(state, task) end @@ -58,7 +108,7 @@ function _cancel!(state, tid, force, halt_sch) for task in keys(state.waiting) tid !== nothing && task.id != tid && continue @dagdebug tid :cancel "Cancelling waiting task" - state.cache[task] = InterruptException() + state.cache[task] = DTaskFailedException(task, task, InterruptException()) state.errored[task] = true Sch.set_failed!(state, task) end @@ -80,11 +130,11 @@ function _cancel!(state, tid, force, halt_sch) Tf === typeof(Sch.eager_thunk) && continue istaskdone(task) && continue any_cancelled = true - @dagdebug tid :cancel "Cancelling running task ($Tf)" if force @dagdebug tid :cancel "Interrupting running task ($Tf)" Threads.@spawn Base.throwto(task, InterruptException()) else + @dagdebug tid :cancel "Cancelling running task ($Tf)" # Tell the processor to just drop this task task_occupancy = task_spec[4] time_util = task_spec[2] @@ -93,6 +143,7 @@ function _cancel!(state, tid, force, halt_sch) push!(istate.cancelled, tid) to_proc = istate.proc put!(istate.return_queue, (myid(), to_proc, tid, (InterruptException(), nothing))) + cancel!(istate.cancel_tokens[tid]; graceful) end end end diff --git a/src/compute.jl b/src/compute.jl index f421eacc..093b527f 100644 --- 
a/src/compute.jl +++ b/src/compute.jl @@ -36,12 +36,6 @@ end Base.@deprecate gather(ctx, x) collect(ctx, x) Base.@deprecate gather(x) collect(x) -cleanup() = cleanup(Context(global_context())) -function cleanup(ctx::Context) - Sch.cleanup(ctx) - nothing -end - function get_type(s::String) local T for t in split(s, ".") diff --git a/src/dtask.jl b/src/dtask.jl index 68f2d3c1..b597db5f 100644 --- a/src/dtask.jl +++ b/src/dtask.jl @@ -39,6 +39,16 @@ end Options(;options...) = Options((;options...)) Options(options...) = Options((;options...)) +""" + DTaskMetadata + +Represents some useful metadata pertaining to a `DTask`: +- `return_type::Type` - The inferred return type of the task +""" +mutable struct DTaskMetadata + return_type::Type +end + """ DTask @@ -50,9 +60,11 @@ more details. mutable struct DTask uid::UInt future::ThunkFuture + metadata::DTaskMetadata finalizer_ref::DRef thunk_ref::DRef - DTask(uid, future, finalizer_ref) = new(uid, future, finalizer_ref) + + DTask(uid, future, metadata, finalizer_ref) = new(uid, future, metadata, finalizer_ref) end const EagerThunk = DTask @@ -73,6 +85,32 @@ function Base.fetch(t::DTask; raw=false) end return fetch(t.future; raw) end +function waitany(tasks::Vector{DTask}) + if isempty(tasks) + return + end + cond = Threads.Condition() + for task in tasks + Sch.errormonitor_tracked("waitany listener", Threads.@spawn begin + wait(task) + @lock cond notify(cond) + end) + end + @lock cond wait(cond) + return +end +function waitall(tasks::Vector{DTask}) + if isempty(tasks) + return + end + @sync for task in tasks + Threads.@spawn begin + wait(task) + @lock cond notify(cond) + end + end + return +end function Base.show(io::IO, t::DTask) status = if istaskstarted(t) isready(t) ? "finished" : "running" diff --git a/src/options.jl b/src/options.jl index 1c1e3ff2..00196dd5 100644 --- a/src/options.jl +++ b/src/options.jl @@ -20,6 +20,12 @@ function with_options(f, options::NamedTuple) end with_options(f; options...) = with_options(f, NamedTuple(options)) +function _without_options(f) + with(options_context => NamedTuple()) do + f() + end +end + """ get_options(key::Symbol, default) -> Any get_options(key::Symbol) -> Any diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 5c330841..b894f452 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -259,9 +259,11 @@ end Combine `SchedulerOptions` and `ThunkOptions` into a new `ThunkOptions`. """ function Base.merge(sopts::SchedulerOptions, topts::ThunkOptions) - single = topts.single !== nothing ? topts.single : sopts.single - allow_errors = topts.allow_errors !== nothing ? topts.allow_errors : sopts.allow_errors - proclist = topts.proclist !== nothing ? topts.proclist : sopts.proclist + select_option = (sopt, topt) -> isnothing(topt) ? 
sopt : topt + + single = select_option(sopts.single, topts.single) + allow_errors = select_option(sopts.allow_errors, topts.allow_errors) + proclist = select_option(sopts.proclist, topts.proclist) ThunkOptions(single, proclist, topts.time_util, @@ -313,9 +315,6 @@ function populate_defaults(opts::ThunkOptions, Tf, Targs) ) end -function cleanup(ctx) -end - # Eager scheduling include("eager.jl") @@ -687,6 +686,9 @@ function schedule!(ctx, state, procs=procs_to_use(ctx)) safepoint(state) @assert length(procs) > 0 + # Remove processors that aren't yet initialized + procs = filter(p -> haskey(state.worker_chans, Dagger.root_worker_id(p)), procs) + populate_processor_cache_list!(state, procs) # Schedule tasks @@ -1186,6 +1188,7 @@ struct ProcessorInternalState proc_occupancy::Base.RefValue{UInt32} time_pressure::Base.RefValue{UInt64} cancelled::Set{Int} + cancel_tokens::Dict{Int,Dagger.CancelToken} done::Base.RefValue{Bool} end struct ProcessorState @@ -1205,7 +1208,7 @@ function proc_states(f::Base.Callable, uid::UInt64) end end proc_states(f::Base.Callable) = - proc_states(f, task_local_storage(:_dagger_sch_uid)::UInt64) + proc_states(f, Dagger.get_tls().sch_uid) task_tid_for_processor(::Processor) = nothing task_tid_for_processor(proc::Dagger.ThreadProc) = proc.tid @@ -1335,7 +1338,14 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re # Execute the task and return its result t = @task begin + # Set up cancellation + cancel_token = Dagger.CancelToken() + Dagger.DTASK_CANCEL_TOKEN[] = cancel_token + lock(istate.queue) do _ + istate.cancel_tokens[thunk_id] = cancel_token + end was_cancelled = false + result = try do_task(to_proc, task) catch err @@ -1352,6 +1362,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re # Task was cancelled, so occupancy and pressure are # already reduced pop!(istate.cancelled, thunk_id) + delete!(istate.cancel_tokens, thunk_id) was_cancelled = true end end @@ -1367,8 +1378,11 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re if unwrap_nested_exception(err) isa InvalidStateException || !isopen(return_queue) @dagdebug thunk_id :execute "Return queue is closed, failing to put result" chan=return_queue exception=(err, catch_backtrace()) else - rethrow(err) + rethrow() end + finally + # Ensure that any spawned tasks get cleaned up + Dagger.cancel!(cancel_token) end end lock(istate.queue) do _ @@ -1418,6 +1432,7 @@ function do_tasks(to_proc, return_queue, tasks) Dict{Int,Vector{Any}}(), Ref(UInt32(0)), Ref(UInt64(0)), Set{Int}(), + Dict{Int,Dagger.CancelToken}(), Ref(false)) runner = start_processor_runner!(istate, uid, return_queue) @static if VERSION < v"1.9" @@ -1659,6 +1674,7 @@ function do_task(to_proc, task_desc) sch_handle, processor=to_proc, task_spec=task_desc, + cancel_token=Dagger.DTASK_CANCEL_TOKEN[], )) res = Dagger.with_options(propagated) do diff --git a/src/sch/dynamic.jl b/src/sch/dynamic.jl index e02085ee..5b917fdb 100644 --- a/src/sch/dynamic.jl +++ b/src/sch/dynamic.jl @@ -17,7 +17,7 @@ struct SchedulerHandle end "Gets the scheduler handle for the currently-executing thunk." -sch_handle() = task_local_storage(:_dagger_sch_handle)::SchedulerHandle +sch_handle() = Dagger.get_tls().sch_handle::SchedulerHandle "Thrown when the scheduler halts before finishing processing the DAG." 
struct SchedulerHaltedException <: Exception end diff --git a/src/sch/eager.jl b/src/sch/eager.jl index 87a10978..aea0abbf 100644 --- a/src/sch/eager.jl +++ b/src/sch/eager.jl @@ -6,7 +6,7 @@ const EAGER_STATE = Ref{Union{ComputeState,Nothing}}(nothing) function eager_context() if EAGER_CONTEXT[] === nothing - EAGER_CONTEXT[] = Context([myid(),workers()...]) + EAGER_CONTEXT[] = Context(procs()) end return EAGER_CONTEXT[] end @@ -124,6 +124,13 @@ function eager_cleanup(state, uid) # N.B. cache and errored expire automatically delete!(state.thunk_dict, tid) end + remotecall_wait(1, uid) do uid + lock(Dagger.EAGER_THUNK_STREAMS) do global_streams + if haskey(global_streams, uid) + delete!(global_streams, uid) + end + end + end end function _find_thunk(e::Dagger.DTask) @@ -134,3 +141,6 @@ function _find_thunk(e::Dagger.DTask) unwrap_weak_checked(EAGER_STATE[].thunk_dict[tid]) end end +Dagger.task_id(t::Dagger.DTask) = lock(EAGER_ID_MAP) do id_map + id_map[t.uid] +end diff --git a/src/sch/util.jl b/src/sch/util.jl index e81703db..2e090b26 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -29,6 +29,10 @@ unwrap_nested_exception(err::CapturedException) = unwrap_nested_exception(err.ex) unwrap_nested_exception(err::RemoteException) = unwrap_nested_exception(err.captured) +unwrap_nested_exception(err::DTaskFailedException) = + unwrap_nested_exception(err.ex) +unwrap_nested_exception(err::TaskFailedException) = + unwrap_nested_exception(err.t.exception) unwrap_nested_exception(err) = err "Gets a `NamedTuple` of options propagated by `thunk`." @@ -406,12 +410,19 @@ function has_capacity(state, p, gp, time_util, alloc_util, occupancy, sig) else get(state.signature_alloc_cost, sig, UInt64(0)) end::UInt64 - est_occupancy = if occupancy !== nothing && haskey(occupancy, T) - # Clamp to 0-1, and scale between 0 and `typemax(UInt32)` - Base.unsafe_trunc(UInt32, clamp(occupancy[T], 0, 1) * typemax(UInt32)) - else - typemax(UInt32) - end::UInt32 + est_occupancy::UInt32 = typemax(UInt32) + if occupancy !== nothing + occ = nothing + if haskey(occupancy, T) + occ = occupancy[T] + elseif haskey(occupancy, Any) + occ = occupancy[Any] + end + if occ !== nothing + # Clamp to 0-1, and scale between 0 and `typemax(UInt32)` + est_occupancy = Base.unsafe_trunc(UInt32, clamp(occ, 0, 1) * typemax(UInt32)) + end + end #= FIXME: Estimate if cached data can be swapped to storage storage = storage_resource(p) real_alloc_util = state.worker_storage_pressure[gp][storage] diff --git a/src/stream-buffers.jl b/src/stream-buffers.jl new file mode 100644 index 00000000..9770933f --- /dev/null +++ b/src/stream-buffers.jl @@ -0,0 +1,64 @@ +"A process-local ring buffer." 
+mutable struct ProcessRingBuffer{T} + read_idx::Int + write_idx::Int + @atomic count::Int + buffer::Vector{T} + @atomic open::Bool + function ProcessRingBuffer{T}(len::Int=1024) where T + buffer = Vector{T}(undef, len) + return new{T}(1, 1, 0, buffer, true) + end +end +Base.isempty(rb::ProcessRingBuffer) = (@atomic rb.count) == 0 +isfull(rb::ProcessRingBuffer) = (@atomic rb.count) == length(rb.buffer) +capacity(rb::ProcessRingBuffer) = length(rb.buffer) +Base.length(rb::ProcessRingBuffer) = @atomic rb.count +Base.isopen(rb::ProcessRingBuffer) = @atomic rb.open +function Base.close(rb::ProcessRingBuffer) + @atomic rb.open = false +end +function Base.put!(rb::ProcessRingBuffer{T}, x) where T + while isfull(rb) + yield() + if !isopen(rb) + throw(InvalidStateException("ProcessRingBuffer is closed", :closed)) + end + task_may_cancel!(; must_force=true) + end + to_write_idx = mod1(rb.write_idx, length(rb.buffer)) + rb.buffer[to_write_idx] = convert(T, x) + rb.write_idx += 1 + @atomic rb.count += 1 +end +function Base.take!(rb::ProcessRingBuffer) + while isempty(rb) + yield() + if !isopen(rb) && isempty(rb) + throw(InvalidStateException("ProcessRingBuffer is closed", :closed)) + end + if task_cancelled() && isempty(rb) + # We respect a graceful cancellation only if the buffer is empty. + # Otherwise, we may have values to continue communicating. + task_may_cancel!() + end + task_may_cancel!(; must_force=true) + end + to_read_idx = rb.read_idx + rb.read_idx += 1 + @atomic rb.count -= 1 + to_read_idx = mod1(to_read_idx, length(rb.buffer)) + return rb.buffer[to_read_idx] +end + +""" +`take!()` all the elements from a buffer and put them in a `Vector`. +""" +function collect!(rb::ProcessRingBuffer{T}) where T + output = Vector{T}(undef, rb.count) + for i in 1:rb.count + output[i] = take!(rb) + end + + return output +end diff --git a/src/stream-transfer.jl b/src/stream-transfer.jl new file mode 100644 index 00000000..96e61fb9 --- /dev/null +++ b/src/stream-transfer.jl @@ -0,0 +1,71 @@ +struct RemoteChannelFetcher + chan::RemoteChannel + RemoteChannelFetcher() = new(RemoteChannel()) +end +const _THEIR_TID = TaskLocalValue{Int}(()->0) +function stream_push_values!(fetcher::RemoteChannelFetcher, T, our_store::StreamStore, their_stream::Stream, buffer) + our_tid = STREAM_THUNK_ID[] + our_uid = our_store.uid + their_uid = their_stream.uid + if _THEIR_TID[] == 0 + _THEIR_TID[] = remotecall_fetch(1) do + lock(Sch.EAGER_ID_MAP) do id_map + id_map[their_uid] + end + end + end + their_tid = _THEIR_TID[] + @dagdebug our_tid :stream_push "taking output value: $our_tid -> $their_tid" + value = try + take!(buffer) + catch + close(fetcher.chan) + rethrow() + end + @lock our_store.lock notify(our_store.lock) + @dagdebug our_tid :stream_push "pushing output value: $our_tid -> $their_tid" + try + put!(fetcher.chan, value) + catch err + if err isa InvalidStateException && !isopen(fetcher.chan) + @dagdebug our_tid :stream_push "channel closed: $our_tid -> $their_tid" + throw(InterruptException()) + end + # N.B. 
We don't close the buffer to allow for eventual reconnection + rethrow() + end + @dagdebug our_tid :stream_push "finished pushing output value: $our_tid -> $their_tid" +end +function stream_pull_values!(fetcher::RemoteChannelFetcher, T, our_store::StreamStore, their_stream::Stream, buffer) + our_tid = STREAM_THUNK_ID[] + our_uid = our_store.uid + their_uid = their_stream.uid + if _THEIR_TID[] == 0 + _THEIR_TID[] = remotecall_fetch(1) do + lock(Sch.EAGER_ID_MAP) do id_map + id_map[their_uid] + end + end + end + their_tid = _THEIR_TID[] + @dagdebug our_tid :stream_pull "pulling input value: $their_tid -> $our_tid" + value = try + take!(fetcher.chan) + catch err + if err isa InvalidStateException && !isopen(fetcher.chan) + @dagdebug our_tid :stream_pull "channel closed: $their_tid -> $our_tid" + throw(InterruptException()) + end + # N.B. We don't close the buffer to allow for eventual reconnection + rethrow() + end + @dagdebug our_tid :stream_pull "putting input value: $their_tid -> $our_tid" + try + put!(buffer, value) + catch + close(fetcher.chan) + rethrow() + end + @lock our_store.lock notify(our_store.lock) + @dagdebug our_tid :stream_pull "finished putting input value: $their_tid -> $our_tid" +end diff --git a/src/stream.jl b/src/stream.jl new file mode 100644 index 00000000..07a3dae9 --- /dev/null +++ b/src/stream.jl @@ -0,0 +1,707 @@ +mutable struct StreamStore{T,B} + uid::UInt + waiters::Vector{Int} + input_streams::Dict{UInt,Any} # FIXME: Concrete type + output_streams::Dict{UInt,Any} # FIXME: Concrete type + input_buffers::Dict{UInt,B} + output_buffers::Dict{UInt,B} + input_buffer_amount::Int + output_buffer_amount::Int + input_fetchers::Dict{UInt,Any} + output_fetchers::Dict{UInt,Any} + open::Bool + migrating::Bool + lock::Threads.Condition + StreamStore{T,B}(uid::UInt, input_buffer_amount::Integer, output_buffer_amount::Integer) where {T,B} = + new{T,B}(uid, zeros(Int, 0), + Dict{UInt,Any}(), Dict{UInt,Any}(), + Dict{UInt,B}(), Dict{UInt,B}(), + input_buffer_amount, output_buffer_amount, + Dict{UInt,Any}(), Dict{UInt,Any}(), + true, false, Threads.Condition()) +end + +function tid_to_uid(thunk_id) + lock(Sch.EAGER_ID_MAP) do id_map + for (uid, otid) in id_map + if thunk_id == otid + return uid + end + end + end +end + +function Base.put!(store::StreamStore{T,B}, value) where {T,B} + thunk_id = STREAM_THUNK_ID[] + @lock store.lock begin + if !isopen(store) + @dagdebug thunk_id :stream "closed!" + throw(InvalidStateException("Stream is closed", :closed)) + end + @dagdebug thunk_id :stream "adding $value ($(length(store.output_streams)) outputs)" + for output_uid in keys(store.output_streams) + if !haskey(store.output_buffers, output_uid) + initialize_output_stream!(store, output_uid) + end + buffer = store.output_buffers[output_uid] + while isfull(buffer) + if !isopen(store) + @dagdebug thunk_id :stream "closed!" 
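+                # The store was closed while this task was blocked waiting for
+                # space in the output buffer, so give up on this put!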
+ throw(InvalidStateException("Stream is closed", :closed)) + end + @dagdebug thunk_id :stream "buffer full ($(length(buffer)) values), waiting" + wait(store.lock) + if !isfull(buffer) + @dagdebug thunk_id :stream "buffer has space ($(length(buffer)) values), continuing" + end + task_may_cancel!() + end + put!(buffer, value) + end + notify(store.lock) + end +end + +function Base.take!(store::StreamStore, id::UInt) + thunk_id = STREAM_THUNK_ID[] + @lock store.lock begin + if !haskey(store.output_buffers, id) + @assert haskey(store.output_streams, id) + error("Must first check isempty(store, id) before taking from a stream") + end + buffer = store.output_buffers[id] + while isempty(buffer) && isopen(store, id) + @dagdebug thunk_id :stream "no elements, not taking" + wait(store.lock) + task_may_cancel!() + end + @dagdebug thunk_id :stream "wait finished" + if !isopen(store, id) + @dagdebug thunk_id :stream "closed!" + throw(InvalidStateException("Stream is closed", :closed)) + end + unlock(store.lock) + value = try + take!(buffer) + finally + lock(store.lock) + end + @dagdebug thunk_id :stream "value accepted" + notify(store.lock) + return value + end +end + +""" +Returns whether the store is actively open. Only check this when deciding if +new values can be pushed. +""" +Base.isopen(store::StreamStore) = store.open + +""" +Returns whether the store is actively open, or if closing, still has remaining +messages for `id`. Only check this when deciding if existing values can be +taken. +""" +function Base.isopen(store::StreamStore, id::UInt) + @lock store.lock begin + if !haskey(store.output_buffers, id) + @assert haskey(store.output_streams, id) + return store.open + end + if !isempty(store.output_buffers[id]) + return true + end + return store.open + end +end + +function Base.close(store::StreamStore) + @lock store.lock begin + store.open || return + store.open = false + for buffer in values(store.input_buffers) + close(buffer) + end + for buffer in values(store.output_buffers) + close(buffer) + end + notify(store.lock) + end +end + +# FIXME: Just pass Stream directly, rather than its uid +function add_waiters!(store::StreamStore{T,B}, waiters::Vector{Pair{UInt,Any}}) where {T,B} + our_uid = store.uid + @lock store.lock begin + for (output_uid, output_fetcher) in waiters + store.output_streams[output_uid] = task_to_stream(output_uid) + push!(store.waiters, output_uid) + store.output_fetchers[output_uid] = output_fetcher + end + notify(store.lock) + end +end + +function remove_waiters!(store::StreamStore, waiters::Vector{UInt}) + @lock store.lock begin + for w in waiters + delete!(store.output_buffers, w) + idx = findfirst(wo->wo==w, store.waiters) + deleteat!(store.waiters, idx) + delete!(store.input_streams, w) + end + notify(store.lock) + end +end + +mutable struct Stream{T,B} + uid::UInt + store::Union{StreamStore{T,B},Nothing} + store_ref::Chunk + function Stream{T,B}(uid::UInt, input_buffer_amount::Integer, output_buffer_amount::Integer) where {T,B} + # Creates a new output stream + store = StreamStore{T,B}(uid, input_buffer_amount, output_buffer_amount) + store_ref = tochunk(store) + return new{T,B}(uid, store, store_ref) + end + function Stream(stream::Stream{T,B}) where {T,B} + # References an existing output stream + return new{T,B}(stream.uid, nothing, stream.store_ref) + end +end + +struct StreamingValue{B} + buffer::B +end +Base.take!(sv::StreamingValue) = take!(sv.buffer) + +function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::Stream{IT,IB}) where 
{IT,OT,IB,OB} + input_uid = input_stream.uid + our_uid = our_store.uid + local buffer, input_fetcher + @lock our_store.lock begin + if haskey(our_store.input_buffers, input_uid) + return StreamingValue(our_store.input_buffers[input_uid]) + end + + buffer = initialize_stream_buffer(OB, IT, our_store.input_buffer_amount) + # FIXME: Also pass a RemoteChannel to track remote closure + our_store.input_buffers[input_uid] = buffer + input_fetcher = our_store.input_fetchers[input_uid] + end + thunk_id = STREAM_THUNK_ID[] + tls = get_tls() + Sch.errormonitor_tracked("streaming input: $input_uid -> $our_uid", Threads.@spawn begin + set_tls!(tls) + STREAM_THUNK_ID[] = thunk_id + try + while isopen(our_store) + stream_pull_values!(input_fetcher, IT, our_store, input_stream, buffer) + end + catch err + unwrapped_err = Sch.unwrap_nested_exception(err) + if unwrapped_err isa InterruptException || (unwrapped_err isa InvalidStateException && !isopen(buffer)) + return + else + rethrow() + end + finally + @dagdebug STREAM_THUNK_ID[] :stream "input stream closed" + end + end) + return StreamingValue(buffer) +end +initialize_input_stream!(our_store::StreamStore, arg) = arg +function initialize_output_stream!(our_store::StreamStore{T,B}, output_uid::UInt) where {T,B} + @assert islocked(our_store.lock) + @dagdebug STREAM_THUNK_ID[] :stream "initializing output stream $output_uid" + buffer = initialize_stream_buffer(B, T, our_store.output_buffer_amount) + our_store.output_buffers[output_uid] = buffer + our_uid = our_store.uid + output_stream = our_store.output_streams[output_uid] + output_fetcher = our_store.output_fetchers[output_uid] + thunk_id = STREAM_THUNK_ID[] + tls = get_tls() + Sch.errormonitor_tracked("streaming output: $our_uid -> $output_uid", Threads.@spawn begin + set_tls!(tls) + STREAM_THUNK_ID[] = thunk_id + try + while true + if !isopen(our_store) && isempty(buffer) + # Only exit if the buffer is empty; otherwise, we need to + # continue draining it + break + end + stream_push_values!(output_fetcher, T, our_store, output_stream, buffer) + end + catch err + unwrapped_err = Sch.unwrap_nested_exception(err) + if unwrapped_err isa InterruptException || (unwrapped_err isa InvalidStateException && !isopen(buffer)) + return + else + rethrow() + end + finally + @dagdebug thunk_id :stream "output stream closed" + end + end) +end + +Base.put!(stream::Stream, @nospecialize(value)) = put!(stream.store, value) + +function Base.isopen(stream::Stream, id::UInt)::Bool + return MemPool.access_ref(stream.store_ref.handle, id) do store, id + return isopen(store::StreamStore, id) + end +end + +function Base.close(stream::Stream) + MemPool.access_ref(stream.store_ref.handle) do store + close(store::StreamStore) + return + end + return +end + +function add_waiters!(stream::Stream, waiters::Vector{Pair{UInt,Any}}) + MemPool.access_ref(stream.store_ref.handle, waiters) do store, waiters + add_waiters!(store::StreamStore, waiters) + return + end + return +end + +function remove_waiters!(stream::Stream, waiters::Vector{UInt}) + MemPool.access_ref(stream.store_ref.handle, waiters) do store, waiters + remove_waiters!(store::StreamStore, waiters) + return + end + return +end + +struct StreamingFunction{F, S} + f::F + stream::S + max_evals::Int + + StreamingFunction(f::F, stream::S, max_evals) where {F, S} = + new{F, S}(f, stream, max_evals) +end + +function migrate_stream!(stream::Stream, w::Integer=myid()) + # Perform migration of the StreamStore + # MemPool will block access to the new ref until the migration completes + # 
FIXME: Do this ownership check with MemPool.access_ref, + # in case stream was already migrated + if stream.store_ref.handle.owner != w + thunk_id = STREAM_THUNK_ID[] + @dagdebug thunk_id :stream "Beginning migration... ($(length(stream.store.input_streams)) -> $(length(stream.store.output_streams)))" + + # TODO: Wire up listener to ferry cancel_token notifications to remote + # worker once migrations occur during runtime + tls = get_tls() + @assert w == myid() "Only pull-based migration is currently supported" + #remote_cancel_token = clone_cancel_token_remote(get_tls().cancel_token, worker_id) + + new_store_ref = MemPool.migrate!(stream.store_ref.handle, w; + pre_migration=store->begin + # Lock store to prevent any further modifications + # N.B. Serialization automatically unlocks the migrated copy + lock((store::StreamStore).lock) + + # Return the serializeable unsent inputs/outputs. We can't send the + # buffers themselves because they may be mmap'ed or something. + unsent_inputs = Dict(uid => collect!(buffer) for (uid, buffer) in store.input_buffers) + unsent_outputs = Dict(uid => collect!(buffer) for (uid, buffer) in store.output_buffers) + empty!(store.input_buffers) + empty!(store.output_buffers) + return (unsent_inputs, unsent_outputs) + end, + dest_post_migration=(store, unsent)->begin + # Initialize the StreamStore on the destination with the unsent inputs/outputs. + STREAM_THUNK_ID[] = thunk_id + @assert !in_task() + set_tls!(tls) + #get_tls().cancel_token = MemPool.access_ref(identity, remote_cancel_token; local_only=true) + unsent_inputs, unsent_outputs = unsent + for (input_uid, inputs) in unsent_inputs + input_stream = store.input_streams[input_uid] + initialize_input_stream!(store, input_stream) + for item in inputs + put!(store.input_buffers[input_uid], item) + end + end + for (output_uid, outputs) in unsent_outputs + initialize_output_stream!(store, output_uid) + for item in outputs + put!(store.output_buffers[output_uid], item) + end + end + + # Reset the state of this new store + store.open = true + store.migrating = false + end, + post_migration=store->begin + # Indicate that this store has migrated + store.migrating = true + store.open = false + + # Unlock the store + unlock((store::StreamStore).lock) + end) + if w == myid() + stream.store_ref.handle = new_store_ref # FIXME: It's not valid to mutate the Chunk handle, but we want to update this to enable fast location queries + stream.store = MemPool.access_ref(identity, new_store_ref; local_only=true) + end + + @dagdebug thunk_id :stream "Migration complete ($(length(stream.store.input_streams)) -> $(length(stream.store.output_streams)))" + end +end + +struct StreamingTaskQueue <: AbstractTaskQueue + tasks::Vector{Pair{DTaskSpec,DTask}} + self_streams::Dict{UInt,Any} + StreamingTaskQueue() = new(Pair{DTaskSpec,DTask}[], + Dict{UInt,Any}()) +end + +function enqueue!(queue::StreamingTaskQueue, spec::Pair{DTaskSpec,DTask}) + push!(queue.tasks, spec) + initialize_streaming!(queue.self_streams, spec...) +end + +function enqueue!(queue::StreamingTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) + append!(queue.tasks, specs) + for (spec, task) in specs + initialize_streaming!(queue.self_streams, spec, task) + end +end + +function initialize_streaming!(self_streams, spec, task) + @assert !isa(spec.f, StreamingFunction) "Task is already in streaming form" + + # Calculate the return type of the called function + T_old = Base.uniontypes(task.metadata.return_type) + T_old = map(t->(t !== Union{} && t <: FinishStream) ? 
first(t.parameters) : t, T_old) + # N.B. We treat non-dominating error paths as unreachable + T_old = filter(t->t !== Union{}, T_old) + T = task.metadata.return_type = !isempty(T_old) ? Union{T_old...} : Any + + # Get input buffer configuration + input_buffer_amount = get(spec.options, :stream_input_buffer_amount, 1) + if input_buffer_amount <= 0 + throw(ArgumentError("Input buffering is required; please specify a `stream_input_buffer_amount` greater than 0")) + end + + # Get output buffer configuration + output_buffer_amount = get(spec.options, :stream_output_buffer_amount, 1) + if output_buffer_amount <= 0 + throw(ArgumentError("Output buffering is required; please specify a `stream_output_buffer_amount` greater than 0")) + end + + # Create the Stream + buffer_type = get(spec.options, :stream_buffer_type, ProcessRingBuffer) + stream = Stream{T,buffer_type}(task.uid, input_buffer_amount, output_buffer_amount) + self_streams[task.uid] = stream + + # Get max evaluation count + max_evals = get(spec.options, :stream_max_evals, -1) + if max_evals == 0 + throw(ArgumentError("stream_max_evals cannot be 0")) + end + + # Wrap the function in a StreamingFunction + spec.f = StreamingFunction(spec.f, stream, max_evals) + + # Mark the task as non-blocking + spec.options = merge(spec.options, (;occupancy=Dict(Any=>0))) + + # Register Stream globally + remotecall_wait(1, task.uid, stream) do uid, stream + lock(EAGER_THUNK_STREAMS) do global_streams + global_streams[uid] = stream + end + end +end + +""" +Starts a streaming region, within which all tasks run continuously and +concurrently. Any `DTask` argument that is itself a streaming task will be +treated as a streaming input/output. The streaming region will automatically +handle the buffering and synchronization of these tasks' values. + +# Keyword Arguments +- `teardown::Bool=true`: If `true`, the streaming region will automatically + cancel all tasks if any task fails or is cancelled. Otherwise, a failing task + will not cancel the other tasks, which will continue running. +""" +function spawn_streaming(f::Base.Callable; teardown::Bool=true) + queue = StreamingTaskQueue() + result = with_options(f; task_queue=queue) + if length(queue.tasks) > 0 + finalize_streaming!(queue.tasks, queue.self_streams) + enqueue!(queue.tasks) + + if teardown + # Start teardown monitor + dtasks = map(last, queue.tasks)::Vector{DTask} + Sch.errormonitor_tracked("streaming teardown", Threads.@spawn begin + # Wait for any task to finish + waitany(dtasks) + + # Cancel all tasks + for task in dtasks + cancel!(task; graceful=false) + end + end) + end + end + return result +end + +struct FinishStream{T,R} + value::Union{Some{T},Nothing} + result::R +end + +finish_stream(value::T; result::R=nothing) where {T,R} = FinishStream{T,R}(Some{T}(value), result) + +finish_stream(; result::R=nothing) where R = FinishStream{Union{},R}(nothing, result) + +const STREAM_THUNK_ID = TaskLocalValue{Int}(()->0) + +chunktype(sf::StreamingFunction{F}) where F = F + +struct StreamMigrating end + +function (sf::StreamingFunction)(args...; kwargs...) 
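+    # Streaming task entry point: record the thunk ID, migrate the output
+    # StreamStore to this worker if needed, then run the streaming loop either
+    # locally or (when the store lives on another worker) via remotecall_fetch;
+    # a StreamMigrating result restarts the loop after a migration.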
+ thunk_id = Sch.sch_handle().thunk_id.id + STREAM_THUNK_ID[] = thunk_id + + # Migrate our output stream store to this worker + if sf.stream isa Stream + remote_cancel_token = migrate_stream!(sf.stream) + end + + @label start + @dagdebug thunk_id :stream "Starting StreamingFunction" + worker_id = sf.stream.store_ref.handle.owner # FIXME: Not valid to access the owner directly + result = if worker_id == myid() + _run_streamingfunction(nothing, nothing, sf, args...; kwargs...) + else + tls = get_tls() + remotecall_fetch(_run_streamingfunction, worker_id, tls, remote_cancel_token, sf, args...; kwargs...) + end + if result === StreamMigrating() + @goto start + end + return result +end + +function _run_streamingfunction(tls, cancel_token, sf, args...; kwargs...) + @nospecialize sf args kwargs + + store = sf.stream.store = MemPool.access_ref(identity, sf.stream.store_ref.handle; local_only=true) + @assert isopen(store) + + if tls !== nothing + # Setup TLS on this new task + tls.cancel_token = MemPool.access_ref(identity, cancel_token; local_only=true) + set_tls!(tls) + end + + thunk_id = Sch.sch_handle().thunk_id.id + STREAM_THUNK_ID[] = thunk_id + + # FIXME: Remove when scheduler is distributed + uid = remotecall_fetch(1, thunk_id) do thunk_id + lock(Sch.EAGER_ID_MAP) do id_map + for (uid, otid) in id_map + if thunk_id == otid + return uid + end + end + end + end + + try + # TODO: This kwarg song-and-dance is required to ensure that we don't + # allocate boxes within `stream!`, when possible + kwarg_names = map(name->Val{name}(), map(first, (kwargs...,))) + kwarg_values = map(last, (kwargs...,)) + args = map(arg->initialize_input_stream!(store, arg), args) + kwarg_values = map(kwarg->initialize_input_stream!(store, kwarg), kwarg_values) + return stream!(sf, uid, (args...,), kwarg_names, kwarg_values) + finally + if !sf.stream.store.migrating + # Remove ourself as a waiter for upstream Streams + streams = Set{Stream}() + for (idx, arg) in enumerate(args) + if arg isa Stream + push!(streams, arg) + end + end + for (idx, (pos, arg)) in enumerate(kwargs) + if arg isa Stream + push!(streams, arg) + end + end + for stream in streams + @dagdebug thunk_id :stream "dropping waiter" + remove_waiters!(stream, uid) + @dagdebug thunk_id :stream "dropped waiter" + end + + # Ensure downstream tasks also terminate + close(sf.stream) + @dagdebug thunk_id :stream "closed stream store" + end + end +end + +# N.B We specialize to minimize/eliminate allocations +function stream!(sf::StreamingFunction, uid, + args::Tuple, kwarg_names::Tuple, kwarg_values::Tuple) + f = move(task_processor(), sf.f) + counter = 0 + + while true + # Yield to other (streaming) tasks + yield() + + # Exit streaming on cancellation + task_may_cancel!() + + # Exit streaming on migration + if sf.stream.store.migrating + error("FIXME: max_evals should be retained") + @dagdebug STREAM_THUNK_ID[] :stream "returning for migration" + return StreamMigrating() + end + + # Get values from Stream args/kwargs + stream_args = _stream_take_values!(args) + stream_kwarg_values = _stream_take_values!(kwarg_values) + stream_kwargs = _stream_namedtuple(kwarg_names, stream_kwarg_values) + + if length(stream_args) > 0 || length(stream_kwarg_values) > 0 + # Notify tasks that input buffers may have space + @lock sf.stream.store.lock notify(sf.stream.store.lock) + end + + # Run a single cycle of f + counter += 1 + @dagdebug STREAM_THUNK_ID[] :stream "executing $f (eval $counter)" + stream_result = f(stream_args...; stream_kwargs...) 
+ + # Exit streaming on graceful request + if stream_result isa FinishStream + if stream_result.value !== nothing + value = something(stream_result.value) + put!(sf.stream, value) + end + @dagdebug STREAM_THUNK_ID[] :stream "voluntarily returning" + return stream_result.result + end + + # Put the result into the output stream + put!(sf.stream, stream_result) + + # Exit streaming on eval limit + if sf.max_evals > 0 && counter >= sf.max_evals + @dagdebug STREAM_THUNK_ID[] :stream "max evals reached (eval $counter)" + return + end + end +end + +function _stream_take_values!(args) + return ntuple(length(args)) do idx + arg = args[idx] + if arg isa StreamingValue + return take!(arg) + else + return arg + end + end +end + +@inline @generated function _stream_namedtuple(kwarg_names::Tuple, + stream_kwarg_values::Tuple) + name_ex = Expr(:tuple, map(name->QuoteNode(name.parameters[1]), kwarg_names.parameters)...) + NT = :(NamedTuple{$name_ex,$stream_kwarg_values}) + return :($NT(stream_kwarg_values)) +end + +# Default for buffers, can be customized +initialize_stream_buffer(B, T, buffer_amount) = B{T}(buffer_amount) + +const EAGER_THUNK_STREAMS = LockedObject(Dict{UInt,Any}()) +function task_to_stream(uid::UInt) + if myid() != 1 + return remotecall_fetch(task_to_stream, 1, uid) + end + lock(EAGER_THUNK_STREAMS) do global_streams + if haskey(global_streams, uid) + return global_streams[uid] + end + return + end +end + +function finalize_streaming!(tasks::Vector{Pair{DTaskSpec,DTask}}, self_streams) + stream_waiter_changes = Dict{UInt,Vector{Pair{UInt,Any}}}() + + for (spec, task) in tasks + @assert haskey(self_streams, task.uid) + our_stream = self_streams[task.uid] + + # Adapt args to accept Stream output of other streaming tasks + for (idx, (pos, arg)) in enumerate(spec.args) + if arg isa DTask + # Check if this is a streaming task + if haskey(self_streams, arg.uid) + other_stream = self_streams[arg.uid] + else + other_stream = task_to_stream(arg.uid) + end + + if other_stream !== nothing + # Generate Stream handle for input + # FIXME: Be configurable + input_fetcher = RemoteChannelFetcher() + other_stream_handle = Stream(other_stream) + spec.args[idx] = pos => other_stream_handle + our_stream.store.input_streams[arg.uid] = other_stream_handle + our_stream.store.input_fetchers[arg.uid] = input_fetcher + + # Add this task as a waiter for the associated output Stream + changes = get!(stream_waiter_changes, arg.uid) do + Pair{UInt,Any}[] + end + push!(changes, task.uid => input_fetcher) + end + end + end + + # Filter out all streaming options + to_filter = (:stream_buffer_type, + :stream_input_buffer_amount, :stream_output_buffer_amount, + :stream_max_evals) + spec.options = NamedTuple(filter(opt -> !(opt[1] in to_filter), + Base.pairs(spec.options))) + if haskey(spec.options, :propagates) + propagates = filter(opt -> !(opt in to_filter), + spec.options.propagates) + spec.options = merge(spec.options, (;propagates)) + end + end + + # Notify Streams of any new waiters + for (uid, waiters) in stream_waiter_changes + stream = task_to_stream(uid) + add_waiters!(stream, waiters) + end +end diff --git a/src/submission.jl b/src/submission.jl index 7312e378..f2353927 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -218,15 +218,27 @@ function eager_process_options_submission_to_local(id_map, options::NamedTuple) return options end end + +function DTaskMetadata(spec::DTaskSpec) + f = spec.f isa StreamingFunction ? 
spec.f.f : spec.f + arg_types = ntuple(i->chunktype(spec.args[i][2]), length(spec.args)) + return_type = Base.promote_op(f, arg_types...) + return DTaskMetadata(return_type) +end + function eager_spawn(spec::DTaskSpec) # Generate new DTask uid = eager_next_id() future = ThunkFuture() + metadata = DTaskMetadata(spec) finalizer_ref = poolset(DTaskFinalizer(uid); device=MemPool.CPURAMDevice()) # Create unlaunched DTask - return DTask(uid, future, finalizer_ref) + return DTask(uid, future, metadata, finalizer_ref) end + +chunktype(t::DTask) = t.metadata.return_type + function eager_launch!((spec, task)::Pair{DTaskSpec,DTask}) # Assign a name, if specified eager_assign_name!(spec, task) diff --git a/src/task-tls.jl b/src/task-tls.jl index ea188e00..5c7d0375 100644 --- a/src/task-tls.jl +++ b/src/task-tls.jl @@ -1,41 +1,81 @@ # In-Thunk Helpers +mutable struct DTaskTLS + processor::Processor + sch_uid::UInt + sch_handle::Any # FIXME: SchedulerHandle + task_spec::Vector{Any} # FIXME: TaskSpec + cancel_token::CancelToken +end + +const DTASK_TLS = TaskLocalValue{Union{DTaskTLS,Nothing}}(()->nothing) + +Base.copy(tls::DTaskTLS) = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec, tls.cancel_token) + """ - task_processor() + get_tls() -> DTaskTLS -Get the current processor executing the current Dagger task. +Gets all Dagger TLS variable as a `DTaskTLS`. """ -task_processor() = task_local_storage(:_dagger_processor)::Processor -@deprecate thunk_processor() task_processor() +get_tls() = DTASK_TLS[]::DTaskTLS """ - in_task() + set_tls!(tls) + +Sets all Dagger TLS variables from `tls`, which may be a `DTaskTLS` or a `NamedTuple`. +""" +function set_tls!(tls) + DTASK_TLS[] = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec, tls.cancel_token) +end + +""" + in_task() -> Bool Returns `true` if currently executing in a [`DTask`](@ref), else `false`. """ -in_task() = haskey(task_local_storage(), :_dagger_sch_uid) -@deprecate in_thunk() in_task() +in_task() = DTASK_TLS[] !== nothing +@deprecate(in_thunk(), in_task()) """ - get_tls() + task_id() -> Int -Gets all Dagger TLS variable as a `NamedTuple`. +Returns the ID of the current [`DTask`](@ref). """ -get_tls() = ( - sch_uid=task_local_storage(:_dagger_sch_uid), - sch_handle=task_local_storage(:_dagger_sch_handle), - processor=task_processor(), - task_spec=task_local_storage(:_dagger_task_spec), -) +task_id() = get_tls().sch_handle.thunk_id.id """ - set_tls!(tls) + task_processor() -> Processor -Sets all Dagger TLS variables from the `NamedTuple` `tls`. +Get the current processor executing the current [`DTask`](@ref). """ -function set_tls!(tls) - task_local_storage(:_dagger_sch_uid, tls.sch_uid) - task_local_storage(:_dagger_sch_handle, tls.sch_handle) - task_local_storage(:_dagger_processor, tls.processor) - task_local_storage(:_dagger_task_spec, tls.task_spec) +task_processor() = get_tls().processor +@deprecate(thunk_processor(), task_processor()) + +""" + task_cancelled(; must_force::Bool=false) -> Bool + +Returns `true` if the current [`DTask`](@ref) has been cancelled, else `false`. +If `must_force=true`, then only return `true` if the cancellation was forced. +""" +task_cancelled(; must_force::Bool=false) = + is_cancelled(get_tls().cancel_token; must_force) + +""" + task_may_cancel!(; must_force::Bool=false) + +Throws an `InterruptException` if the current [`DTask`](@ref) has been cancelled. +If `must_force=true`, then only throw if the cancellation was forced. 
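+
+Long-running or looping task bodies should call this periodically (as the
+streaming loop in `stream!` does) so that cancellation can take effect promptly.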
+""" +function task_may_cancel!(;must_force::Bool=false) + if task_cancelled(;must_force) + throw(InterruptException()) + end end + +""" + task_cancel!(; graceful::Bool=true) + +Cancels the current [`DTask`](@ref). If `graceful=true`, then the task will be +cancelled gracefully, otherwise it will be forced. +""" +task_cancel!(; graceful::Bool=true) = cancel!(get_tls().cancel_token; graceful) diff --git a/src/threadproc.jl b/src/threadproc.jl index 09099889..b75c90ca 100644 --- a/src/threadproc.jl +++ b/src/threadproc.jl @@ -27,8 +27,9 @@ function execute!(proc::ThreadProc, @nospecialize(f), @nospecialize(args...); @n return result[] catch err if err isa InterruptException + # Direct interrupt hit us, propagate cancellation signal + # FIXME: We should tell the scheduler that the user hit Ctrl-C if !istaskdone(task) - # Propagate cancellation signal Threads.@spawn Base.throwto(task, InterruptException()) end end diff --git a/src/utils/dagdebug.jl b/src/utils/dagdebug.jl index 9a9d2416..6a71e5c5 100644 --- a/src/utils/dagdebug.jl +++ b/src/utils/dagdebug.jl @@ -2,7 +2,8 @@ function istask end function task_id end const DAGDEBUG_CATEGORIES = Symbol[:global, :submit, :schedule, :scope, - :take, :execute, :move, :processor, :cancel] + :take, :execute, :move, :processor, :cancel, + :stream] macro dagdebug(thunk, category, msg, args...) cat_sym = category.value @gensym id @@ -31,6 +32,10 @@ macro dagdebug(thunk, category, msg, args...) $debug_ex_noid end end + + # Always yield to reduce differing behavior for debug vs. non-debug + # TODO: Remove this eventually + yield() end end) end diff --git a/src/utils/tasks.jl b/src/utils/tasks.jl index c2796cf2..ddd8da2e 100644 --- a/src/utils/tasks.jl +++ b/src/utils/tasks.jl @@ -18,3 +18,115 @@ function set_task_tid!(task::Task, tid::Integer) end @assert Threads.threadid(task) == tid "jl_set_task_tid failed!" 
end + +if isdefined(Base, :waitany) +import Base: waitany, waitall +else +# Vendored from Base +# License is MIT +waitany(tasks; throw=true) = _wait_multiple(tasks, throw) +waitall(tasks; failfast=true, throw=true) = _wait_multiple(tasks, throw, true, failfast) +function _wait_multiple(waiting_tasks, throwexc=false, all=false, failfast=false) + tasks = Task[] + + for t in waiting_tasks + t isa Task || error("Expected an iterator of `Task` object") + push!(tasks, t) + end + + if (all && !failfast) || length(tasks) <= 1 + exception = false + # Force everything to finish synchronously for the case of waitall + # with failfast=false + for t in tasks + _wait(t) + exception |= istaskfailed(t) + end + if exception && throwexc + exceptions = [TaskFailedException(t) for t in tasks if istaskfailed(t)] + throw(CompositeException(exceptions)) + else + return tasks, Task[] + end + end + + exception = false + nremaining::Int = length(tasks) + done_mask = falses(nremaining) + for (i, t) in enumerate(tasks) + if istaskdone(t) + done_mask[i] = true + exception |= istaskfailed(t) + nremaining -= 1 + else + done_mask[i] = false + end + end + + if nremaining == 0 + return tasks, Task[] + elseif any(done_mask) && (!all || (failfast && exception)) + if throwexc && (!all || failfast) && exception + exceptions = [TaskFailedException(t) for t in tasks[done_mask] if istaskfailed(t)] + throw(CompositeException(exceptions)) + else + return tasks[done_mask], tasks[.~done_mask] + end + end + + chan = Channel{Int}(Inf) + sentinel = current_task() + waiter_tasks = fill(sentinel, length(tasks)) + + for (i, done) in enumerate(done_mask) + done && continue + t = tasks[i] + if istaskdone(t) + done_mask[i] = true + exception |= istaskfailed(t) + nremaining -= 1 + exception && failfast && break + else + waiter = @task put!(chan, i) + waiter.sticky = false + _wait2(t, waiter) + waiter_tasks[i] = waiter + end + end + + while nremaining > 0 + i = take!(chan) + t = tasks[i] + waiter_tasks[i] = sentinel + done_mask[i] = true + exception |= istaskfailed(t) + nremaining -= 1 + + # stop early if requested, unless there is something immediately + # ready to consume from the channel (using a race-y check) + if (!all || (failfast && exception)) && !isready(chan) + break + end + end + + close(chan) + + if nremaining == 0 + return tasks, Task[] + else + remaining_mask = .~done_mask + for i in findall(remaining_mask) + waiter = waiter_tasks[i] + donenotify = tasks[i].donenotify::ThreadSynchronizer + @lock donenotify Base.list_deletefirst!(donenotify.waitq, waiter) + end + done_tasks = tasks[done_mask] + if throwexc && exception + exceptions = [TaskFailedException(t) for t in done_tasks if istaskfailed(t)] + throw(CompositeException(exceptions)) + else + return done_tasks, tasks[remaining_mask] + end + end +end +end diff --git a/test/mutation.jl b/test/mutation.jl index b6ac7143..fa2f62bc 100644 --- a/test/mutation.jl +++ b/test/mutation.jl @@ -48,7 +48,7 @@ end x = Dagger.@mutable worker=w Ref{Int}() @test fetch(Dagger.@spawn mutable_update!(x)) == w wo_scope = Dagger.ProcessScope(wo) - @test_throws_unwrap Dagger.DTaskFailedException fetch(Dagger.@spawn scope=wo_scope mutable_update!(x)) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) fetch(Dagger.@spawn scope=wo_scope mutable_update!(x)) end end # @testset "@mutable" diff --git a/test/processors.jl b/test/processors.jl index e97a1d23..6e56876d 100644 --- a/test/processors.jl +++ b/test/processors.jl @@ -37,9 +37,9 @@ end end @testset "Processor 
exhaustion" begin opts = ThunkOptions(proclist=[OptOutProc]) - @test_throws_unwrap Dagger.DTaskFailedException ex isa Dagger.Sch.SchedulingException ex.reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) opts = ThunkOptions(proclist=(proc)->false) - @test_throws_unwrap Dagger.DTaskFailedException ex isa Dagger.Sch.SchedulingException ex.reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) opts = ThunkOptions(proclist=nothing) @test collect(delayed(sum; options=opts)([1,2,3])) == 6 end diff --git a/test/runtests.jl b/test/runtests.jl index cfdab817..79ba890d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,7 @@ tests = [ ("Mutation", "mutation.jl"), ("Task Queues", "task-queues.jl"), ("Datadeps", "datadeps.jl"), + ("Streaming", "streaming.jl"), ("Domain Utilities", "domain.jl"), ("Array - Allocation", "array/allocation.jl"), ("Array - Indexing", "array/indexing.jl"), @@ -35,7 +36,10 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ pushfirst!(LOAD_PATH, joinpath(@__DIR__, "..")) using Pkg Pkg.activate(@__DIR__) - Pkg.instantiate() + try + Pkg.instantiate() + catch + end using ArgParse s = ArgParseSettings(description = "Dagger Testsuite") @@ -52,6 +56,12 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ arg_type = Int default = additional_workers help = "How many additional workers to launch" + "-v", "--verbose" + action = :store_true + help = "Run the tests with debug logs from Dagger" + "-O", "--offline" + action = :store_true + help = "Set Pkg into offline mode" end end @@ -81,12 +91,20 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ parsed_args["simulate"] && exit(0) additional_workers = parsed_args["procs"] + + if parsed_args["verbose"] + ENV["JULIA_DEBUG"] = "Dagger" + end + + if parsed_args["offline"] + Pkg.UPDATED_REGISTRY_THIS_SESSION[] = true + Pkg.offline(true) + end else to_test = all_test_names @info "Running all tests" end - using Distributed if additional_workers > 0 # We put this inside a branch because addprocs() takes a minimum of 1s to diff --git a/test/scheduler.jl b/test/scheduler.jl index b9fe0187..b12ad3e1 100644 --- a/test/scheduler.jl +++ b/test/scheduler.jl @@ -182,7 +182,7 @@ end @testset "allow errors" begin opts = ThunkOptions(;allow_errors=true) a = delayed(error; options=opts)("Test") - @test_throws_unwrap Dagger.DTaskFailedException collect(a) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) collect(a) end end @@ -396,7 +396,7 @@ end ([Dagger.tochunk(MyStruct(1)), Dagger.tochunk(1)], sizeof(MyStruct)+sizeof(Int)), ] for arg in args - if arg isa Chunk + if arg isa Dagger.Chunk aff = Dagger.affinity(arg) @test aff[1] == OSProc(1) @test aff[2] == MemPool.approx_size(MemPool.poolget(arg.handle)) @@ -477,7 +477,7 @@ end @test res == 2 @testset "self as input" begin a = delayed(dynamic_add_thunk_self_dominated)(1) - @test_throws_unwrap Dagger.Sch.DynamicThunkException reason="Cannot fetch result of dominated thunk" collect(Context(), a) + @test_throws_unwrap (RemoteException, Dagger.Sch.DynamicThunkException) reason="Cannot fetch result of 
dominated thunk" collect(Context(), a) end end @testset "Fetch/Wait" begin @@ -487,11 +487,11 @@ end end @testset "self" begin a = delayed(dynamic_fetch_self)(1) - @test_throws_unwrap Dagger.Sch.DynamicThunkException reason="Cannot fetch own result" collect(Context(), a) + @test_throws_unwrap (RemoteException, Dagger.Sch.DynamicThunkException) reason="Cannot fetch own result" collect(Context(), a) end @testset "dominated" begin a = delayed(identity)(delayed(dynamic_fetch_dominated)(1)) - @test_throws_unwrap Dagger.Sch.DynamicThunkException reason="Cannot fetch result of dominated thunk" collect(Context(), a) + @test_throws_unwrap (RemoteException, Dagger.Sch.DynamicThunkException) reason="Cannot fetch result of dominated thunk" collect(Context(), a) end end end @@ -540,7 +540,7 @@ end t = Dagger.@spawn scope=Dagger.scope(worker=1, thread=1) sleep(100) start_time = time_ns() Dagger.cancel!(t) - @test_throws_unwrap Dagger.DTaskFailedException fetch(t) + @test_throws_unwrap (Dagger.DTaskFailedException, InterruptException) fetch(t) t = Dagger.@spawn scope=Dagger.scope(worker=1, thread=1) yield() fetch(t) finish_time = time_ns() diff --git a/test/scopes.jl b/test/scopes.jl index 5f82a71a..065e5158 100644 --- a/test/scopes.jl +++ b/test/scopes.jl @@ -1,3 +1,4 @@ +#@everywhere ENV["JULIA_DEBUG"] = "Dagger" @testset "Chunk Scopes" begin wid1, wid2 = addprocs(2, exeflags=["-t 2"]) @everywhere [wid1,wid2] using Dagger @@ -56,7 +57,7 @@ # Different nodes for (ch1, ch2) in [(ns1_ch, ns2_ch), (ns2_ch, ns1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end end @testset "Process Scope" begin @@ -75,7 +76,7 @@ # Different process for (ch1, ch2) in [(ps1_ch, ps2_ch), (ps2_ch, ps1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end # Same process and node @@ -83,7 +84,7 @@ # Different process and node for (ch1, ch2) in [(ps1_ch, ns2_ch), (ns2_ch, ps1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end end @testset "Exact Scope" begin @@ -104,14 +105,14 @@ # Different process, different processor for (ch1, ch2) in [(es1_ch, es2_ch), (es2_ch, es1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end # Same process, different processor es1_2 = ExactScope(Dagger.ThreadProc(wid1, 2)) es1_2_ch = Dagger.tochunk(nothing, OSProc(), es1_2) for (ch1, ch2) in [(es1_ch, es1_2_ch), (es1_2_ch, es1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end end @testset "Union Scope" begin diff --git 
a/test/streaming.jl b/test/streaming.jl new file mode 100644 index 00000000..c3bf0e40 --- /dev/null +++ b/test/streaming.jl @@ -0,0 +1,415 @@ +const ACCUMULATOR = Dict{Int,Vector{Real}}() +@everywhere function accumulator(x=0) + tid = Dagger.task_id() + remotecall_wait(1, tid, x) do tid, x + acc = get!(Vector{Real}, ACCUMULATOR, tid) + push!(acc, x) + end + return +end +@everywhere accumulator(xs...) = accumulator(sum(xs)) +@everywhere accumulator(::Nothing) = accumulator(0) + +function catch_interrupt(f) + try + f() + catch err + if err isa Dagger.DTaskFailedException && err.ex isa InterruptException + return + elseif err isa Dagger.Sch.SchedulingException + return + end + rethrow() + end +end + +function merge_testset!(inner::Test.DefaultTestSet) + outer = Test.get_testset() + append!(outer.results, inner.results) + outer.n_passed += inner.n_passed +end + +function test_finishes(f, message::String; timeout=10, ignore_timeout=false, max_evals=10) + t = @eval Threads.@spawn begin + tset = nothing + try + @testset $message begin + try + @testset $message begin + Dagger.with_options(;stream_max_evals=$max_evals) do + catch_interrupt($f) + end + end + finally + tset = Test.get_testset() + end + end + catch + end + return tset + end + + timed_out = timedwait(()->istaskdone(t), timeout) == :timed_out + if timed_out + if !ignore_timeout + @warn "Testing task timed out: $message" + end + Dagger.cancel!(;halt_sch=true, graceful=false) + @everywhere GC.gc() + fetch(Dagger.@spawn 1+1) + end + + tset = fetch(t)::Test.DefaultTestSet + merge_testset!(tset) + return !timed_out +end + +all_scopes = [Dagger.ExactScope(proc) for proc in Dagger.all_processors()] +for idx in 1:5 + if idx == 1 + scopes = [Dagger.scope(worker = 1, thread = 1)] + scope_str = "Worker 1" + elseif idx == 2 && nprocs() > 1 + scopes = [Dagger.scope(worker = 2, thread = 1)] + scope_str = "Worker 2" + else + scopes = all_scopes + scope_str = "All Workers" + end + + @testset "Single Task Control Flow ($scope_str)" begin + @test !test_finishes("Single task running forever"; max_evals=1_000_000, ignore_timeout=true) do + local x + Dagger.spawn_streaming(;teardown=false) do + x = Dagger.@spawn scope=rand(scopes) () -> begin + y = rand() + sleep(1) + return y + end + end + @test_throws_unwrap InterruptException fetch(x) + end + + @test test_finishes("Single task without result") do + local x + Dagger.spawn_streaming(;teardown=false) do + x = Dagger.@spawn scope=rand(scopes) rand() + end + @test fetch(x) === nothing + end + + @test test_finishes("Single task with result"; max_evals=1_000_000) do + local x + Dagger.spawn_streaming(;teardown=false) do + x = Dagger.@spawn scope=rand(scopes) () -> begin + x = rand() + if x < 0.1 + return Dagger.finish_stream(x; result=123) + end + return x + end + end + @test fetch(x) == 123 + end + end + + @testset "Non-Streaming Inputs ($scope_str)" begin + @test test_finishes("() -> A") do + local A + Dagger.spawn_streaming(;teardown=false) do + A = Dagger.@spawn scope=rand(scopes) accumulator() + end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(==(0), values[A_tid]) + end + @test test_finishes("42 -> A") do + local A + Dagger.spawn_streaming(;teardown=false) do + A = Dagger.@spawn scope=rand(scopes) accumulator(42) + end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(==(42), values[A_tid]) + 
end + @test test_finishes("(42, 43) -> A") do + local A + Dagger.spawn_streaming(;teardown=false) do + A = Dagger.@spawn scope=rand(scopes) accumulator(42, 43) + end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(==(42 + 43), values[A_tid]) + end + end + + @testset "Non-Streaming Outputs ($scope_str)" begin + @test test_finishes("x -> A") do + local x, A + Dagger.spawn_streaming(;teardown=false) do + x = Dagger.@spawn scope=rand(scopes) rand() + end + Dagger._without_options() do + A = Dagger.@spawn accumulator(x) + end + @test fetch(x) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 1 + @test all(v -> 0 <= v <= 10, values[A_tid]) + end + + @test test_finishes("x -> (A, B)") do + local x, A, B + Dagger.spawn_streaming(;teardown=false) do + x = Dagger.@spawn scope=rand(scopes) rand() + end + Dagger._without_options() do + A = Dagger.@spawn accumulator(x) + B = Dagger.@spawn accumulator(x) + end + @test fetch(x) === nothing + @test fetch(A) === nothing + @test fetch(B) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 1 + @test all(v -> 0 <= v <= 10, values[A_tid]) + B_tid = Dagger.task_id(B) + @test length(values[B_tid]) == 1 + @test all(v -> 0 <= v <= 10, values[B_tid]) + end + end + + @testset "Teardown" begin + @test test_finishes("teardown=true"; max_evals=1_000_000, ignore_timeout=true) do + local x, y + Dagger.spawn_streaming(;teardown=true) do + x = Dagger.@spawn scope=rand(scopes) () -> begin + sleep(0.1) + return rand() + end + y = Dagger.with_options(;stream_max_evals=10) do + Dagger.@spawn scope=rand(scopes) identity(x) + end + end + @test fetch(y) === nothing + sleep(1) # Wait for teardown + @test istaskdone(x) + fetch(x) + end + @test !test_finishes("teardown=false"; max_evals=1_000_000, ignore_timeout=true) do + local x, y + Dagger.spawn_streaming(;teardown=false) do + x = Dagger.@spawn scope=rand(scopes) () -> begin + sleep(0.1) + return rand() + end + y = Dagger.with_options(;stream_max_evals=10) do + Dagger.@spawn scope=rand(scopes) identity(x) + end + end + @test fetch(y) === nothing + sleep(1) # Wait to ensure `x` task is still running + @test !istaskdone(x) + @test_throws_unwrap InterruptException fetch(x) + end + end + + @testset "Multiple Tasks ($scope_str)" begin + @test test_finishes("x -> A") do + local x, A + Dagger.spawn_streaming(;teardown=false) do + x = Dagger.@spawn scope=rand(scopes) rand() + A = Dagger.@spawn scope=rand(scopes) accumulator(x) + end + @test fetch(x) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 0 <= v <= 1, values[A_tid]) + end + + @test test_finishes("(x, A)") do + local x, A + Dagger.spawn_streaming(;teardown=false) do + x = Dagger.@spawn scope=rand(scopes) rand() + A = Dagger.@spawn scope=rand(scopes) accumulator(1.0) + end + @test fetch(x) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> v == 1, values[A_tid]) + end + + @test test_finishes("x -> y -> A") do + local x, y, A + Dagger.spawn_streaming(;teardown=false) do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) x+1 + A = 
Dagger.@spawn scope=rand(scopes) accumulator(y) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 1 <= v <= 2, values[A_tid]) + end + + @test test_finishes("x -> (y, A)") do + local x, y, A + Dagger.spawn_streaming(;teardown=false) do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) x+1 + A = Dagger.@spawn scope=rand(scopes) accumulator(x) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 0 <= v <= 1, values[A_tid]) + end + + @test test_finishes("(x, y) -> A") do + local x, y, A + Dagger.spawn_streaming(;teardown=false) do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) rand() + A = Dagger.@spawn scope=rand(scopes) accumulator(x, y) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 0 <= v <= 2, values[A_tid]) + end + + @test test_finishes("(x, y) -> z -> A") do + local x, y, z, A + Dagger.spawn_streaming(;teardown=false) do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) rand() + z = Dagger.@spawn scope=rand(scopes) x + y + A = Dagger.@spawn scope=rand(scopes) accumulator(z) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(z) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 0 <= v <= 2, values[A_tid]) + end + + @test test_finishes("x -> (y, z) -> A") do + local x, y, z, A + Dagger.spawn_streaming(;teardown=false) do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) x + 1 + z = Dagger.@spawn scope=rand(scopes) x + 2 + A = Dagger.@spawn scope=rand(scopes) accumulator(y, z) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(z) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 3 <= v <= 5, values[A_tid]) + end + + @test test_finishes("(x, y) -> z -> (A, B)") do + local x, y, z, A, B + Dagger.spawn_streaming(;teardown=false) do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) rand() + z = Dagger.@spawn scope=rand(scopes) x + y + A = Dagger.@spawn scope=rand(scopes) accumulator(z) + B = Dagger.@spawn scope=rand(scopes) accumulator(z) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(z) === nothing + @test fetch(A) === nothing + @test fetch(B) === nothing + + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 0 <= v <= 2, values[A_tid]) + B_tid = Dagger.task_id(B) + @test length(values[B_tid]) == 10 + @test all(v -> 0 <= v <= 2, values[B_tid]) + end + + for T in (Float64, Int32, BigFloat) + @test test_finishes("Stream eltype $T") do + local x, A + Dagger.spawn_streaming(;teardown=false) do + x = Dagger.@spawn scope=rand(scopes) rand(T) + A = Dagger.@spawn scope=rand(scopes) accumulator(x) + end + @test 
fetch(x) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> v isa T, values[A_tid]) + end + end + end + + @testset "Max Evals ($scope_str)" begin + @test test_finishes("max_evals=0"; max_evals=0) do + @test_throws ArgumentError Dagger.spawn_streaming(;teardown=false) do + A = Dagger.@spawn scope=rand(scopes) accumulator() + end + end + @test test_finishes("max_evals=1"; max_evals=1) do + local A + Dagger.spawn_streaming(;teardown=false) do + A = Dagger.@spawn scope=rand(scopes) accumulator() + end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 1 + end + @test test_finishes("max_evals=100"; max_evals=100) do + local A + Dagger.spawn_streaming(;teardown=false) do + A = Dagger.@spawn scope=rand(scopes) accumulator() + end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 100 + end + end + + # FIXME: Varying buffer amounts + + #= TODO: Zero-allocation test + # First execution of a streaming task will almost guaranteed allocate (compiling, setup, etc.) + # BUT, second and later executions could possibly not allocate any further ("steady-state") + # We want to be able to validate that the steady-state execution for certain tasks is non-allocating + =# +end diff --git a/test/thunk.jl b/test/thunk.jl index e6fb7e86..73879545 100644 --- a/test/thunk.jl +++ b/test/thunk.jl @@ -69,7 +69,7 @@ end A = rand(4, 4) @test fetch(@spawn sum(A; dims=1)) ≈ sum(A; dims=1) - @test_throws_unwrap Dagger.DTaskFailedException fetch(@spawn sum(A; fakearg=2)) + @test_throws_unwrap (Dagger.DTaskFailedException, MethodError) fetch(@spawn sum(A; fakearg=2)) @test fetch(@spawn reduce(+, A; dims=1, init=2.0)) ≈ reduce(+, A; dims=1, init=2.0) @@ -194,7 +194,7 @@ end a = @spawn error("Test") wait(a) @test isready(a) - @test_throws_unwrap Dagger.DTaskFailedException fetch(a) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(a) b = @spawn 1+2 @test fetch(b) == 3 end @@ -207,8 +207,7 @@ end catch err err end - ex = Dagger.Sch.unwrap_nested_exception(ex) - ex_str = sprint(io->Base.showerror(io,ex)) + ex_str = sprint(io->Base.showerror(io, ex)) @test occursin(r"^DTaskFailedException:", ex_str) @test occursin("Test", ex_str) @test !occursin("Root Task", ex_str) @@ -218,36 +217,35 @@ end catch err err end - ex = Dagger.Sch.unwrap_nested_exception(ex) - ex_str = sprint(io->Base.showerror(io,ex)) + ex_str = sprint(io->Base.showerror(io, ex)) @test occursin("Test", ex_str) @test occursin("Root Task", ex_str) end @testset "single dependent" begin a = @spawn error("Test") b = @spawn a+2 - @test_throws_unwrap Dagger.DTaskFailedException fetch(a) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(a) end @testset "multi dependent" begin a = @spawn error("Test") b = @spawn a+2 c = @spawn a*2 - @test_throws_unwrap Dagger.DTaskFailedException fetch(b) - @test_throws_unwrap Dagger.DTaskFailedException fetch(c) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(b) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(c) end @testset "dependent chain" begin a = @spawn error("Test") - @test_throws_unwrap Dagger.DTaskFailedException fetch(a) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(a) b = @spawn a+1 - @test_throws_unwrap 
Dagger.DTaskFailedException fetch(b) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(b) c = @spawn b+2 - @test_throws_unwrap Dagger.DTaskFailedException fetch(c) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(c) end @testset "single input" begin a = @spawn 1+1 b = @spawn (a->error("Test"))(a) @test fetch(a) == 2 - @test_throws_unwrap Dagger.DTaskFailedException fetch(b) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(b) end @testset "multi input" begin a = @spawn 1+1 @@ -255,7 +253,7 @@ end c = @spawn ((a,b)->error("Test"))(a,b) @test fetch(a) == 2 @test fetch(b) == 4 - @test_throws_unwrap Dagger.DTaskFailedException fetch(c) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(c) end @testset "diamond" begin a = @spawn 1+1 @@ -265,7 +263,7 @@ end @test fetch(a) == 2 @test fetch(b) == 3 @test fetch(c) == 4 - @test_throws_unwrap Dagger.DTaskFailedException fetch(d) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(d) end end @testset "remote spawn" begin @@ -283,7 +281,7 @@ end t1 = Dagger.@spawn 1+"fail" Dagger.@spawn t1+1 end - @test_throws_unwrap Dagger.DTaskFailedException fetch(t2) + @test_throws_unwrap (Dagger.DTaskFailedException, MethodError) fetch(t2) end @testset "undefined function" begin # Issues #254, #255 diff --git a/test/util.jl b/test/util.jl index f01b3d95..1131a9eb 100644 --- a/test/util.jl +++ b/test/util.jl @@ -14,7 +14,7 @@ end replace_obj!(ex::Symbol, obj) = Expr(:(.), obj, QuoteNode(ex)) replace_obj!(ex, obj) = ex function _test_throws_unwrap(terr, ex; to_match=[]) - @gensym rerr + @gensym oerr rerr match_expr = Expr(:block) for m in to_match if m.head == :(=) @@ -35,12 +35,17 @@ function _test_throws_unwrap(terr, ex; to_match=[]) end end quote - $rerr = try - $(esc(ex)) + $oerr, $rerr = try + nothing, $(esc(ex)) catch err - Dagger.Sch.unwrap_nested_exception(err) + (err, Dagger.Sch.unwrap_nested_exception(err)) + end + if $terr isa Tuple + @test $oerr isa $terr[1] + @test $rerr isa $terr[2] + else + @test $rerr isa $terr end - @test $rerr isa $terr $match_expr end end
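
As a usage sketch for the cancellation helpers added in `src/` above: `Dagger.cancel!` requests cancellation of a `DTask` from the outside, while `task_may_cancel!` lets a long-running task body notice that request and bail out early. This is illustrative only; `task_may_cancel!` is an internal helper in this diff, and its assumed spelling here (`Dagger.task_may_cancel!`) may not be part of the public API.

```julia
using Dagger

# Long-running work that cooperatively polls for cancellation.
function slow_work()
    acc = 0.0
    for _ in 1:10_000
        acc += rand()
        sleep(0.01)
        # Assumed internal helper from this diff: throws InterruptException
        # once cancellation has been requested for this task.
        Dagger.task_may_cancel!()
    end
    return acc
end

t = Dagger.@spawn slow_work()

sleep(0.5)
Dagger.cancel!(t)   # request cancellation of this DTask

# A cancelled task surfaces its failure when fetched (as in the scheduler tests above)
try
    fetch(t)
catch err
    @info "Task was cancelled" err
end
```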
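
The `waitany`/`waitall` definitions vendored into `src/utils/tasks.jl` above follow the `Base.waitany`/`Base.waitall` API from newer Julia releases (and simply import Base's versions when they exist). Both take a collection of `Task`s and return a `(done_tasks, remaining_tasks)` pair. A small sketch of that behavior, assuming the Base semantics carry over unchanged:

```julia
# On Julia versions that define Base.waitany/waitall these are the Base
# functions; otherwise the vendored definitions above (inside the Dagger
# module) provide the same behavior.
tasks = [Threads.@spawn(sleep(d)) for d in (0.1, 0.5, 1.0)]

# Returns as soon as at least one task has completed.
done, remaining = waitany(tasks)
@assert !isempty(done)

# Blocks until every task has completed (or fails fast on error by default).
done, remaining = waitall(tasks)
@assert isempty(remaining)
```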
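
The `test/util.jl` change above extends the test-suite-local `@test_throws_unwrap` macro: passing a tuple now asserts both the outer (wrapping) exception type and the innermost unwrapped exception type, which is why the test updates throughout this diff switch from a single type to a pair. A short sketch of the tuple form, only meaningful inside this test suite after `include("util.jl")`:

```julia
using Test

a = Dagger.@spawn error("Test")

# The raw exception must be a DTaskFailedException, and unwrapping it via
# Dagger.Sch.unwrap_nested_exception must yield the original ErrorException.
@test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(a)
```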