From c9b310395b86797d93506e645043481c5176a2f0 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Mon, 12 Aug 2024 12:00:10 -0500 Subject: [PATCH] fixup! datadeps: Add at-stencil helper --- src/stencil.jl | 59 +++++++++++++++++++++++++++--------------- src/utils/haloarray.jl | 11 +++++--- 2 files changed, 45 insertions(+), 25 deletions(-) diff --git a/src/stencil.jl b/src/stencil.jl index 808e01459..37e085dd0 100644 --- a/src/stencil.jl +++ b/src/stencil.jl @@ -1,3 +1,8 @@ +# FIXME: Remove me +const Read = In +const Write = Out +const ReadWrite = InOut + function get_neighbor_edge(arr, dim, dir, dist) if dir == -1 start_idx = CartesianIndex(ntuple(i -> i == dim ? (lastindex(arr, i) - dist + 1) : firstindex(arr, i), ndims(arr))) @@ -13,8 +18,7 @@ function get_neighbor_corner(chunk, corner_side, neigh_dist) stop_idx = CartesianIndex(ntuple(i -> corner_side[i] == 0 ? lastindex(chunk, i) : (firstindex(chunk, i) + neigh_dist - 1), ndims(chunk))) return collect(@view chunk[start_idx:stop_idx]) end -function get_neighborhood_chunks(chunks, idx, neigh_dist) - @assert neigh_dist == 1 +function get_neighborhood_chunks(chunks, idx, neigh_dist, boundary) chunk_dist = 1 # Get the center accesses = Any[chunks[idx]] @@ -47,7 +51,7 @@ function get_neighborhood_chunks(chunks, idx, neigh_dist) @assert length(accesses) == 1+2*ndims(chunks)+2^ndims(chunks) "Accesses mismatch: $(length(accesses))" return accesses end -function build_halo(neigh_dist, center::Array{T,N}, all_neighbors...) where {T,N} +function build_halo(neigh_dist, boundary, center::Array{T,N}, all_neighbors...) where {T,N} # FIXME: Don't collect views edges = collect.(all_neighbors[1:(2*N)]) corners = collect.(all_neighbors[((2^N)+1):end]) @@ -62,6 +66,17 @@ function load_neighborhood(arr::HaloArray{T,N}, idx, neigh_dist) where {T,N} return collect(@view arr[start_idx:stop_idx]) end +struct Wrap end +boundary_init(::Wrap, arr, size) = similar(arr, eltype(arr), size) +boundary_has_transition(::Wrap) = true +boundary_transition(::Wrap, idx, size) = mod1(idx, size) + +struct Pad{T} + padval::T +end +boundary_init(::Pad{T}, arr, size) where T = Fill(padval, size) +boundary_has_transition(::Pad) = false + """ @stencil idx in range begin body end @@ -70,13 +85,15 @@ region. The `idx` variable is used to iterate over `range`, which must be a `DArray`. An example usage may look like: ```julia +import Dagger: @stencil, Wrap + A = zeros(Blocks(3, 3), Int, 9, 9) A[5, 5] = 1 B = zeros(Blocks(3, 3), Int, 9, 9) Dagger.@spawn_datadeps() do @stencil idx in A begin # Sum values of all neighbors with self - A[idx] = sum(@neighbors(A[idx], 1, :wrap)) + A[idx] = sum(@neighbors(A[idx], 1, Wrap())) # Decrement all values by 1 A[idx] -= 1 # Copy A to B @@ -93,16 +110,16 @@ size, shape, and chunk layout as `A`. Additionally, the `@neighbors` macro can be used to access a neighborhood of values around `A[idx]`, at a configurable distance (in this case, 1 element -distance) and with or without wrapping (in this case, `:wrap` enables -wrapping). Neighborhoods are computed with respect to neighboring chunks as -well - if a neighborhood would overflow from the current chunk into one or more -neighboring chunks, values from those neighboring chunks will be included in -the neighborhood. +distance) and with various kinds of boundary conditions (in this case, `Wrap()` +specifies wrapping behavior on the boundaries). Neighborhoods are computed with +respect to neighboring chunks as well - if a neighborhood would overflow from +the current chunk into one or more neighboring chunks, values from those +neighboring chunks will be included in the neighborhood. Note that, while `@stencil` may look like a `for` loop, it does not follow the same semantics; in particular, an expression within `@stencil` occurs "all at once" (across all indices) before the next expression occurs. This means that -`A[idx] = sum(@neighbors(A[idx], 1, :wrap))` will write the sum of +`A[idx] = sum(@neighbors(A[idx], 1, Wrap()))` will write the sum of neighbors for all `idx` values into `A[idx]` before `A[idx] -= 1` decrements the values `A` by 1, and that occurs before any of the values are copied to `B` in `B[idx] = A[idx]`. Of course, pipelining and other optimizations may still @@ -123,18 +140,16 @@ macro stencil(index_ex, orig_ex) @assert write_idx == index_var "Can only write to $index_var: $write_ex" accessed_vars = Set{Symbol}() read_vars = Set{Symbol}() - neighborhoods = Dict{Symbol, Tuple{Int, Bool}}() + neighborhoods = Dict{Symbol, Tuple{Any, Any}}() push!(accessed_vars, write_var) - MacroTools.prewalk(read_ex) do read_inner_ex + prewalk(read_ex) do read_inner_ex if @capture(read_inner_ex, read_var_[read_idx_]) && read_idx == index_var push!(accessed_vars, read_var) push!(read_vars, read_var) - elseif @capture(read_inner_ex, @neighbors(read_var_[read_idx_], neigh_dist_, wrapping_)) && read_idx == index_var - @assert neigh_dist == 1 "Neighborhoods greater than 1 not yet supported" - @assert wrapping == :wrap "Unknown wrapping pattern `$(repr(wrapping))`" + elseif @capture(read_inner_ex, @neighbors(read_var_[read_idx_], neigh_dist_, boundary_)) && read_idx == index_var push!(accessed_vars, read_var) push!(read_vars, read_var) - neighborhoods[read_var] = (neigh_dist, true) + neighborhoods[read_var] = (neigh_dist, boundary) end return read_inner_ex end @@ -151,11 +166,11 @@ macro stencil(index_ex, orig_ex) # Generate function with transformed body @gensym inner_index_var - new_inner_ex_body = MacroTools.prewalk(inner_ex) do old_inner_ex + new_inner_ex_body = prewalk(inner_ex) do old_inner_ex if @capture(old_inner_ex, read_var_[read_idx_]) && read_idx == index_var # Direct access return :($read_var[$inner_index_var]) - elseif @capture(old_inner_ex, @neighbors(read_var_[read_idx_], neigh_dist_, wrapping_)) && read_idx == index_var + elseif @capture(old_inner_ex, @neighbors(read_var_[read_idx_], neigh_dist_, boundary_)) && read_idx == index_var # Neighborhood access return :($load_neighborhood($read_var, $inner_index_var, $neigh_dist)) end @@ -179,10 +194,10 @@ macro stencil(index_ex, orig_ex) for read_var in read_vars if read_var in keys(neighborhoods) # Generate a neighborhood copy operation - neigh_dist, dowrap = neighborhoods[read_var] + neigh_dist, boundary = neighborhoods[read_var] deps_inner_ex = Expr(:block) @gensym neighbor_copy_var - push!(neighbor_copy_all_ex.args, :($neighbor_copy_var = Dagger.@spawn $build_halo($neigh_dist, map($Read, $get_neighborhood_chunks($chunks($read_var), $chunk_idx, $neigh_dist))...))) + push!(neighbor_copy_all_ex.args, :($neighbor_copy_var = Dagger.@spawn $build_halo($neigh_dist, $boundary, map($Read, $get_neighborhood_chunks($chunks($read_var), $chunk_idx, $neigh_dist, $boundary))...))) push!(deps_ex, Expr(:kw, read_var, :($Read($neighbor_copy_var)))) else push!(deps_ex, Expr(:kw, read_var, :($Read($chunks($read_var)[$chunk_idx])))) @@ -192,12 +207,14 @@ macro stencil(index_ex, orig_ex) # Generate loop push!(final_ex.args, quote - for $chunk_idx in $index_range + for $chunk_idx in $CartesianIndices($chunks($index_range)) $neighbor_copy_all_ex $spawn_ex end end) end + @show final_ex + return esc(final_ex) end diff --git a/src/utils/haloarray.jl b/src/utils/haloarray.jl index 83b9e0d19..835131d7b 100644 --- a/src/utils/haloarray.jl +++ b/src/utils/haloarray.jl @@ -1,8 +1,8 @@ # Define the HaloArray type with minimized halo storage -struct HaloArray{T,N,E,C} <: AbstractArray{T,N} - center::Array{T,N} - edges::NTuple{E, Array{T,N}} - corners::NTuple{C, Array{T,N}} +struct HaloArray{T,N,E,C,A,EA,CA} <: AbstractArray{T,N} + center::A + edges::NTuple{E, EA} + corners::NTuple{C, CA} halo_width::NTuple{N,Int} end @@ -20,6 +20,9 @@ function HaloArray{T,N}(center_size::NTuple{N,Int}, halo_width::NTuple{N,Int}) w return HaloArray{T,N,2N,2^N}(center, edges, corners, halo_width) end +HaloArray(center::AT, edges::NTuple{E, EA}, corners::NTuple{C, CA}, halo_width::NTuple{N, Int}) where {T,N,AT<:AbstractArray{T,N},C,E,CA,EA} = + HaloArray{T,N,E,C,AT,EA,CA}(center, edges, corners, halo_width) + Base.size(tile::HaloArray) = size(tile.center) .+ 2 .* tile.halo_width function Base.axes(tile::HaloArray{T,N,H}) where {T,N,H} ntuple(N) do i