diff --git a/src/thunk.jl b/src/thunk.jl index 177239979..017a69a9d 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -306,7 +306,7 @@ generated thunks. macro par(exs...) opts = exs[1:end-1] ex = exs[end] - _par(ex; lazy=true, opts=opts) + return esc(_par(ex; lazy=true, opts=opts)) end """ @@ -348,7 +348,7 @@ also passes along any options in an `Options` struct. For example, macro spawn(exs...) opts = exs[1:end-1] ex = exs[end] - _par(ex; lazy=false, opts=opts) + return esc(_par(ex; lazy=false, opts=opts)) end struct ExpandedBroadcast{F} end @@ -363,26 +363,44 @@ function replace_broadcast(fn::Symbol) end function _par(ex::Expr; lazy=true, recur=true, opts=()) - if ex.head == :call && recur - f = replace_broadcast(ex.args[1]) - if length(ex.args) >= 2 && Meta.isexpr(ex.args[2], :parameters) - args = ex.args[3:end] - kwargs = ex.args[2] - else - args = ex.args[2:end] - kwargs = Expr(:parameters) + f = nothing + body = nothing + arg1 = nothing + if recur && @capture(ex, f_(allargs__)) || @capture(ex, f_(allargs__) do cargs_ body_ end) || @capture(ex, allargs__->body_) || @capture(ex, arg1_[allargs__]) + f = replace_broadcast(f) + if arg1 !== nothing + # Indexing (A[2,3]) + f = Base.getindex + pushfirst!(allargs, arg1) + end + args = filter(arg->!Meta.isexpr(arg, :parameters), allargs) + kwargs = filter(arg->Meta.isexpr(arg, :parameters), allargs) + if !isempty(kwargs) + kwargs = only(kwargs).args + end + if body !== nothing + if f !== nothing + f = quote + ($(args...); $(kwargs...))->$f($(args...); $(kwargs...)) do $cargs + $body + end + end + else + f = quote + ($(args...); $(kwargs...))->begin + $body + end + end + end end - opts = esc.(opts) - args_ex = _par.(args; lazy=lazy, recur=false) - kwargs_ex = _par.(kwargs.args; lazy=lazy, recur=false) if lazy - return :(Dagger.delayed($(esc(f)), $Options(;$(opts...)))($(args_ex...); $(kwargs_ex...))) + return :(Dagger.delayed($f, $Options(;$(opts...)))($(args...); $(kwargs...))) else - sync_var = esc(Base.sync_varname) + sync_var = Base.sync_varname @gensym result return quote - let args = ($(args_ex...),) - $result = $spawn($(esc(f)), $Options(;$(opts...)), args...; $(kwargs_ex...)) + let + $result = $spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...)) if $(Expr(:islocal, sync_var)) put!($sync_var, schedule(Task(()->wait($result)))) end @@ -390,12 +408,17 @@ function _par(ex::Expr; lazy=true, recur=true, opts=()) end end end + elseif lazy + # Recurse into the expression + return Expr(ex.head, _par_inner.(ex.args, lazy=lazy, recur=recur, opts=opts)...) else - return Expr(ex.head, _par.(ex.args, lazy=lazy, recur=recur, opts=opts)...) + throw(ArgumentError("Invalid Dagger task expression: $ex")) end end -_par(ex::Symbol; kwargs...) = esc(ex) -_par(ex; kwargs...) = ex +_par(ex; kwargs...) = throw(ArgumentError("Invalid Dagger task expression: $ex")) + +_par_inner(ex; kwargs...) = ex +_par_inner(ex::Expr; kwargs...) = _par(ex; kwargs...) """ Dagger.spawn(f, args...; kwargs...) -> DTask diff --git a/test/runtests.jl b/test/runtests.jl index 3f4e1b1ca..98e61967d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,6 +31,7 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ pushfirst!(LOAD_PATH, joinpath(@__DIR__, "..")) using Pkg Pkg.activate(@__DIR__) + Pkg.instantiate() using ArgParse s = ArgParseSettings(description = "Dagger Testsuite") diff --git a/test/thunk.jl b/test/thunk.jl index db92ee340..95763e8b9 100644 --- a/test/thunk.jl +++ b/test/thunk.jl @@ -79,6 +79,70 @@ end @test fetch(@spawn A .+ B) ≈ A .+ B @test fetch(@spawn A .* B) ≈ A .* B end + @testset "inner macro" begin + A = rand(4) + t = @spawn sum(@view A[2:3]) + @test t isa Dagger.DTask + @test fetch(t) ≈ sum(@view A[2:3]) + end + @testset "do block" begin + A = rand(4) + + t = @spawn sum(A) do a + a + 1 + end + @test t isa Dagger.DTask + @test fetch(t) ≈ sum(a->a+1, A) + + t = @spawn sum(A; dims=1) do a + a + 1 + end + @test t isa Dagger.DTask + @test fetch(t) ≈ sum(a->a+1, A; dims=1) + + do_f = f -> f(42) + t = @spawn do_f() do x + x + 1 + end + @test t isa Dagger.DTask + @test fetch(t) == 43 + end + @testset "anonymous direct call" begin + A = rand(4) + + t = @spawn A->sum(A) + @test t isa Dagger.DTask + @test fetch(t) == sum(A) + + t = @spawn A->sum(A; dims=1) + @test t isa Dagger.DTask + @test fetch(t) == sum(A; dims=1) + end + @testset "getindex" begin + A = rand(4, 4) + + t = @spawn A[1, 2] + @test t isa Dagger.DTask + @test fetch(t) == A[1, 2] + + B = Dagger.@spawn rand(4, 4) + t = @spawn B[1, 2] + @test t isa Dagger.DTask + @test fetch(t) == fetch(B)[1, 2] + + R = Ref(42) + t = @spawn R[] + @test t isa Dagger.DTask + @test fetch(t) == 42 + end + @testset "invalid expression" begin + @test_throws LoadError eval(:(@spawn 1)) + @test_throws LoadError eval(:(@spawn begin 1 end)) + @test_throws LoadError eval(:(@spawn begin + 1+1 + 1+1 + end)) + end @testset "waiting" begin a = @spawn sleep(1) @test !isready(a)