From 656f9525e29c318e8b0283922a85c86862e23b23 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Mon, 30 Sep 2024 16:21:58 -0500 Subject: [PATCH] parser: Re-support broadcast --- src/thunk.jl | 21 +++++++++++++-------- test/thunk.jl | 7 +++++++ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/src/thunk.jl b/src/thunk.jl index 02353fd17..dc961f303 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -307,7 +307,7 @@ generated thunks. macro par(exs...) opts = exs[1:end-1] ex = exs[end] - return esc(_par(ex; lazy=true, opts=opts)) + return esc(_par(__module__, ex; lazy=true, opts=opts)) end """ @@ -349,7 +349,7 @@ also passes along any options in an `Options` struct. For example, macro spawn(exs...) opts = exs[1:end-1] ex = exs[end] - return esc(_par(ex; lazy=false, opts=opts)) + return esc(_par(__module__, ex; lazy=false, opts=opts)) end struct ExpandedBroadcast{F} end @@ -365,8 +365,9 @@ end to_namedtuple(;kwargs...) = (;kwargs...) -function _par(ex::Expr; lazy=true, recur=true, opts=()) +function _par(mod, ex::Expr; lazy=true, recur=true, opts=()) f = nothing + bf = nothing body = nothing arg1 = nothing arg2 = nothing @@ -375,7 +376,11 @@ function _par(ex::Expr; lazy=true, recur=true, opts=()) @capture(ex, allargs__->body_) || @capture(ex, arg1_[allargs__]) || @capture(ex, arg1_.arg2_) || - @capture(ex, (;allargs__)) + @capture(ex, (;allargs__)) || + @capture(ex, bf_.(allargs__)) + if bf !== nothing + f = ExpandedBroadcast{mod.eval(bf)}() + end f = replace_broadcast(f) if arg1 !== nothing if arg2 !== nothing @@ -429,15 +434,15 @@ function _par(ex::Expr; lazy=true, recur=true, opts=()) end elseif lazy # Recurse into the expression - return Expr(ex.head, _par_inner.(ex.args, lazy=lazy, recur=recur, opts=opts)...) + return Expr(ex.head, _par_inner.(Ref(mod), ex.args, lazy=lazy, recur=recur, opts=opts)...) else throw(ArgumentError("Invalid Dagger task expression: $ex")) end end -_par(ex; kwargs...) = throw(ArgumentError("Invalid Dagger task expression: $ex")) +_par(mod, ex; kwargs...) = throw(ArgumentError("Invalid Dagger task expression: $ex")) -_par_inner(ex; kwargs...) = ex -_par_inner(ex::Expr; kwargs...) = _par(ex; kwargs...) +_par_inner(mod, ex; kwargs...) = ex +_par_inner(mod, ex::Expr; kwargs...) = _par(mod, ex; kwargs...) """ Dagger.spawn(f, args...; kwargs...) -> DTask diff --git a/test/thunk.jl b/test/thunk.jl index 22a755dba..e6fb7e86b 100644 --- a/test/thunk.jl +++ b/test/thunk.jl @@ -160,6 +160,13 @@ end @test t isa Dagger.DTask @test fetch(t) == fetch(nt2).b end + @testset "broadcast" begin + x = randn(100) + + t = @spawn abs.(x) + @test t isa Dagger.DTask + @test fetch(t) == abs.(x) + end @testset "invalid expression" begin @test_throws LoadError eval(:(@spawn 1)) @test_throws LoadError eval(:(@spawn begin 1 end))