From dbfe428a7e86502cbbffad463c500705c5472659 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 18 Jun 2024 10:43:46 -0500 Subject: [PATCH 1/5] parser: Fix expression escaping --- src/thunk.jl | 12 +++++------- test/thunk.jl | 4 ++++ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/thunk.jl b/src/thunk.jl index 177239979..4d94d9b96 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -306,7 +306,7 @@ generated thunks. macro par(exs...) opts = exs[1:end-1] ex = exs[end] - _par(ex; lazy=true, opts=opts) + return esc(_par(ex; lazy=true, opts=opts)) end """ @@ -348,7 +348,7 @@ also passes along any options in an `Options` struct. For example, macro spawn(exs...) opts = exs[1:end-1] ex = exs[end] - _par(ex; lazy=false, opts=opts) + return esc(_par(ex; lazy=false, opts=opts)) end struct ExpandedBroadcast{F} end @@ -372,17 +372,16 @@ function _par(ex::Expr; lazy=true, recur=true, opts=()) args = ex.args[2:end] kwargs = Expr(:parameters) end - opts = esc.(opts) args_ex = _par.(args; lazy=lazy, recur=false) kwargs_ex = _par.(kwargs.args; lazy=lazy, recur=false) if lazy - return :(Dagger.delayed($(esc(f)), $Options(;$(opts...)))($(args_ex...); $(kwargs_ex...))) + return :(Dagger.delayed($f, $Options(;$(opts...)))($(args_ex...); $(kwargs_ex...))) else - sync_var = esc(Base.sync_varname) + sync_var = Base.sync_varname @gensym result return quote let args = ($(args_ex...),) - $result = $spawn($(esc(f)), $Options(;$(opts...)), args...; $(kwargs_ex...)) + $result = $spawn($f, $Options(;$(opts...)), args...; $(kwargs_ex...)) if $(Expr(:islocal, sync_var)) put!($sync_var, schedule(Task(()->wait($result)))) end @@ -394,7 +393,6 @@ function _par(ex::Expr; lazy=true, recur=true, opts=()) return Expr(ex.head, _par.(ex.args, lazy=lazy, recur=recur, opts=opts)...) end end -_par(ex::Symbol; kwargs...) = esc(ex) _par(ex; kwargs...) = ex """ diff --git a/test/thunk.jl b/test/thunk.jl index db92ee340..42ed7125d 100644 --- a/test/thunk.jl +++ b/test/thunk.jl @@ -79,6 +79,10 @@ end @test fetch(@spawn A .+ B) ≈ A .+ B @test fetch(@spawn A .* B) ≈ A .* B end + @testset "inner macro" begin + A = rand(4) + @test fetch(@spawn sum(@view A[2:3])) ≈ sum(@view A[2:3]) + end @testset "waiting" begin a = @spawn sleep(1) @test !isready(a) From 8d29bd8444e1fbbeb1a559c9199ad43b0c226aec Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 18 Jun 2024 10:44:30 -0500 Subject: [PATCH 2/5] tests: Instantiate before loading packages --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 3f4e1b1ca..98e61967d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,6 +31,7 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ pushfirst!(LOAD_PATH, joinpath(@__DIR__, "..")) using Pkg Pkg.activate(@__DIR__) + Pkg.instantiate() using ArgParse s = ArgParseSettings(description = "Dagger Testsuite") From 515e731c5d5d7da27a0d37dd9e68a78f361f9749 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 18 Jun 2024 12:06:10 -0500 Subject: [PATCH 3/5] parser: Support do-blocks --- src/thunk.jl | 40 +++++++++++++++++++++++++--------------- test/thunk.jl | 34 +++++++++++++++++++++++++++++++++- 2 files changed, 58 insertions(+), 16 deletions(-) diff --git a/src/thunk.jl b/src/thunk.jl index 4d94d9b96..62fa31dec 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -363,25 +363,29 @@ function replace_broadcast(fn::Symbol) end function _par(ex::Expr; lazy=true, recur=true, opts=()) - if ex.head == :call && recur - f = replace_broadcast(ex.args[1]) - if length(ex.args) >= 2 && Meta.isexpr(ex.args[2], :parameters) - args = ex.args[3:end] - kwargs = ex.args[2] - else - args = ex.args[2:end] - kwargs = Expr(:parameters) + body = nothing + if recur && @capture(ex, f_(allargs__)) || @capture(ex, f_(allargs__) do cargs_ body_ end) + f = replace_broadcast(f) + args = filter(arg->!Meta.isexpr(arg, :parameters), allargs) + kwargs = filter(arg->Meta.isexpr(arg, :parameters), allargs) + if !isempty(kwargs) + kwargs = only(kwargs).args + end + if body !== nothing + f = quote + ($(args...); $(kwargs...))->$f($(args...); $(kwargs...)) do $cargs + $body + end + end end - args_ex = _par.(args; lazy=lazy, recur=false) - kwargs_ex = _par.(kwargs.args; lazy=lazy, recur=false) if lazy - return :(Dagger.delayed($f, $Options(;$(opts...)))($(args_ex...); $(kwargs_ex...))) + return :(Dagger.delayed($f, $Options(;$(opts...)))($(args...); $(kwargs...))) else sync_var = Base.sync_varname @gensym result return quote - let args = ($(args_ex...),) - $result = $spawn($f, $Options(;$(opts...)), args...; $(kwargs_ex...)) + let + $result = $spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...)) if $(Expr(:islocal, sync_var)) put!($sync_var, schedule(Task(()->wait($result)))) end @@ -389,11 +393,17 @@ function _par(ex::Expr; lazy=true, recur=true, opts=()) end end end + elseif lazy + # Recurse into the expression + return Expr(ex.head, _par_inner.(ex.args, lazy=lazy, recur=recur, opts=opts)...) else - return Expr(ex.head, _par.(ex.args, lazy=lazy, recur=recur, opts=opts)...) + throw(ArgumentError("Invalid Dagger task expression: $ex")) end end -_par(ex; kwargs...) = ex +_par(ex; kwargs...) = throw(ArgumentError("Invalid Dagger task expression: $ex")) + +_par_inner(ex; kwargs...) = ex +_par_inner(ex::Expr; kwargs...) = _par(ex; kwargs...) """ Dagger.spawn(f, args...; kwargs...) -> DTask diff --git a/test/thunk.jl b/test/thunk.jl index 42ed7125d..66ebc01f8 100644 --- a/test/thunk.jl +++ b/test/thunk.jl @@ -81,7 +81,39 @@ end end @testset "inner macro" begin A = rand(4) - @test fetch(@spawn sum(@view A[2:3])) ≈ sum(@view A[2:3]) + t = @spawn sum(@view A[2:3]) + @test t isa Dagger.DTask + @test fetch(t) ≈ sum(@view A[2:3]) + end + @testset "do block" begin + A = rand(4) + + t = @spawn sum(A) do a + a + 1 + end + @test t isa Dagger.DTask + @test fetch(t) ≈ sum(a->a+1, A) + + t = @spawn sum(A; dims=1) do a + a + 1 + end + @test t isa Dagger.DTask + @test fetch(t) ≈ sum(a->a+1, A; dims=1) + + do_f = f -> f(42) + t = @spawn do_f() do x + x + 1 + end + @test t isa Dagger.DTask + @test fetch(t) == 43 + end + @testset "invalid expression" begin + @test_throws LoadError eval(:(@spawn 1)) + @test_throws LoadError eval(:(@spawn begin 1 end)) + @test_throws LoadError eval(:(@spawn begin + 1+1 + 1+1 + end)) end @testset "waiting" begin a = @spawn sleep(1) From 705266a6fa08d1fdde7c516138c64d9418406f31 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 18 Jun 2024 12:51:24 -0500 Subject: [PATCH 4/5] parser: Support direct anonymous function calls --- src/thunk.jl | 17 +++++++++++++---- test/thunk.jl | 11 +++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/src/thunk.jl b/src/thunk.jl index 62fa31dec..4f69bf1f8 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -363,8 +363,9 @@ function replace_broadcast(fn::Symbol) end function _par(ex::Expr; lazy=true, recur=true, opts=()) + f = nothing body = nothing - if recur && @capture(ex, f_(allargs__)) || @capture(ex, f_(allargs__) do cargs_ body_ end) + if recur && @capture(ex, f_(allargs__)) || @capture(ex, f_(allargs__) do cargs_ body_ end) || @capture(ex, allargs__->body_) f = replace_broadcast(f) args = filter(arg->!Meta.isexpr(arg, :parameters), allargs) kwargs = filter(arg->Meta.isexpr(arg, :parameters), allargs) @@ -372,9 +373,17 @@ function _par(ex::Expr; lazy=true, recur=true, opts=()) kwargs = only(kwargs).args end if body !== nothing - f = quote - ($(args...); $(kwargs...))->$f($(args...); $(kwargs...)) do $cargs - $body + if f !== nothing + f = quote + ($(args...); $(kwargs...))->$f($(args...); $(kwargs...)) do $cargs + $body + end + end + else + f = quote + ($(args...); $(kwargs...))->begin + $body + end end end end diff --git a/test/thunk.jl b/test/thunk.jl index 66ebc01f8..7d18a292a 100644 --- a/test/thunk.jl +++ b/test/thunk.jl @@ -107,6 +107,17 @@ end @test t isa Dagger.DTask @test fetch(t) == 43 end + @testset "anonymous direct call" begin + A = rand(4) + + t = @spawn A->sum(A) + @test t isa Dagger.DTask + @test fetch(t) == sum(A) + + t = @spawn A->sum(A; dims=1) + @test t isa Dagger.DTask + @test fetch(t) == sum(A; dims=1) + end @testset "invalid expression" begin @test_throws LoadError eval(:(@spawn 1)) @test_throws LoadError eval(:(@spawn begin 1 end)) From 86f2f5a1e4f468b925a01b05940a148628d879c9 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 18 Jun 2024 15:48:06 -0500 Subject: [PATCH 5/5] parser: Support getindex --- src/thunk.jl | 8 +++++++- test/thunk.jl | 17 +++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/thunk.jl b/src/thunk.jl index 4f69bf1f8..017a69a9d 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -365,8 +365,14 @@ end function _par(ex::Expr; lazy=true, recur=true, opts=()) f = nothing body = nothing - if recur && @capture(ex, f_(allargs__)) || @capture(ex, f_(allargs__) do cargs_ body_ end) || @capture(ex, allargs__->body_) + arg1 = nothing + if recur && @capture(ex, f_(allargs__)) || @capture(ex, f_(allargs__) do cargs_ body_ end) || @capture(ex, allargs__->body_) || @capture(ex, arg1_[allargs__]) f = replace_broadcast(f) + if arg1 !== nothing + # Indexing (A[2,3]) + f = Base.getindex + pushfirst!(allargs, arg1) + end args = filter(arg->!Meta.isexpr(arg, :parameters), allargs) kwargs = filter(arg->Meta.isexpr(arg, :parameters), allargs) if !isempty(kwargs) diff --git a/test/thunk.jl b/test/thunk.jl index 7d18a292a..95763e8b9 100644 --- a/test/thunk.jl +++ b/test/thunk.jl @@ -118,6 +118,23 @@ end @test t isa Dagger.DTask @test fetch(t) == sum(A; dims=1) end + @testset "getindex" begin + A = rand(4, 4) + + t = @spawn A[1, 2] + @test t isa Dagger.DTask + @test fetch(t) == A[1, 2] + + B = Dagger.@spawn rand(4, 4) + t = @spawn B[1, 2] + @test t isa Dagger.DTask + @test fetch(t) == fetch(B)[1, 2] + + R = Ref(42) + t = @spawn R[] + @test t isa Dagger.DTask + @test fetch(t) == 42 + end @testset "invalid expression" begin @test_throws LoadError eval(:(@spawn 1)) @test_throws LoadError eval(:(@spawn begin 1 end))