Skip to content

Commit

Permalink
fixup! datadeps: Add at-stencil helper
Browse files Browse the repository at this point in the history
  • Loading branch information
jpsamaroo committed Aug 12, 2024
1 parent e96ed37 commit c9b3103
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 25 deletions.
59 changes: 38 additions & 21 deletions src/stencil.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# FIXME: Remove me
# Temporary aliases exposing `In`, `Out`, and `InOut` under the names `Read`,
# `Write`, and `ReadWrite`. `Read` is interpolated into the task dependencies
# that `@stencil` generates further down in this file.
# NOTE(review): `In`/`Out`/`InOut` are not defined in this file; presumably they
# are the datadeps argument wrappers from the enclosing module — confirm.
const Read = In
const Write = Out
const ReadWrite = InOut

function get_neighbor_edge(arr, dim, dir, dist)
if dir == -1
start_idx = CartesianIndex(ntuple(i -> i == dim ? (lastindex(arr, i) - dist + 1) : firstindex(arr, i), ndims(arr)))
Expand All @@ -13,8 +18,7 @@ function get_neighbor_corner(chunk, corner_side, neigh_dist)
stop_idx = CartesianIndex(ntuple(i -> corner_side[i] == 0 ? lastindex(chunk, i) : (firstindex(chunk, i) + neigh_dist - 1), ndims(chunk)))
return collect(@view chunk[start_idx:stop_idx])
end
function get_neighborhood_chunks(chunks, idx, neigh_dist)
@assert neigh_dist == 1
function get_neighborhood_chunks(chunks, idx, neigh_dist, boundary)
chunk_dist = 1
# Get the center
accesses = Any[chunks[idx]]
Expand Down Expand Up @@ -47,7 +51,7 @@ function get_neighborhood_chunks(chunks, idx, neigh_dist)
@assert length(accesses) == 1+2*ndims(chunks)+2^ndims(chunks) "Accesses mismatch: $(length(accesses))"
return accesses
end
function build_halo(neigh_dist, center::Array{T,N}, all_neighbors...) where {T,N}
function build_halo(neigh_dist, boundary, center::Array{T,N}, all_neighbors...) where {T,N}
# FIXME: Don't collect views
edges = collect.(all_neighbors[1:(2*N)])
corners = collect.(all_neighbors[((2^N)+1):end])
Expand All @@ -62,6 +66,17 @@ function load_neighborhood(arr::HaloArray{T,N}, idx, neigh_dist) where {T,N}
return collect(@view arr[start_idx:stop_idx])
end

"""
    Wrap()

Boundary condition marker that selects wrapping behavior at the boundaries:
an out-of-range index is mapped back into `1:size` with `mod1`.
"""
struct Wrap end

# Allocate an (uninitialized) buffer of shape `size` with the element type of `arr`.
function boundary_init(::Wrap, arr, size)
    return similar(arr, eltype(arr), size)
end

# Wrapping remaps indices, so an index transition is available.
function boundary_has_transition(::Wrap)
    return true
end

# Map `idx` into `1:size`; e.g. `0` becomes `size` and `size + 1` becomes `1`.
function boundary_transition(::Wrap, idx, size)
    return mod1(idx, size)
end

"""
    Pad(padval)

Boundary condition marker carrying a constant value `padval`, which
`boundary_init` uses to fill a buffer of the requested size.
"""
struct Pad{T}
    padval::T
end

# The fill value is a field of the `Pad` instance, so the argument must be
# named in order to read it; a bare `padval` is not in scope here and would
# throw an `UndefVarError` when called.
# NOTE(review): `Fill` is not defined in this file; presumably `FillArrays.Fill`
# — confirm it is imported by the enclosing module.
boundary_init(pad::Pad, arr, size) = Fill(pad.padval, size)

# No index transition is reported for constant padding.
boundary_has_transition(::Pad) = false

"""
@stencil idx in range begin body end
Expand All @@ -70,13 +85,15 @@ region. The `idx` variable is used to iterate over `range`, which must be a
`DArray`. An example usage may look like:
```julia
import Dagger: @stencil, Wrap
A = zeros(Blocks(3, 3), Int, 9, 9)
A[5, 5] = 1
B = zeros(Blocks(3, 3), Int, 9, 9)
Dagger.@spawn_datadeps() do
@stencil idx in A begin
# Sum values of all neighbors with self
A[idx] = sum(@neighbors(A[idx], 1, :wrap))
A[idx] = sum(@neighbors(A[idx], 1, Wrap()))
# Decrement all values by 1
A[idx] -= 1
# Copy A to B
Expand All @@ -93,16 +110,16 @@ size, shape, and chunk layout as `A`.
Additionally, the `@neighbors` macro can be used to access a neighborhood of
values around `A[idx]`, at a configurable distance (in this case, 1 element
distance) and with or without wrapping (in this case, `:wrap` enables
wrapping). Neighborhoods are computed with respect to neighboring chunks as
well - if a neighborhood would overflow from the current chunk into one or more
neighboring chunks, values from those neighboring chunks will be included in
the neighborhood.
distance) and with various kinds of boundary conditions (in this case, `Wrap()`
specifies wrapping behavior on the boundaries). Neighborhoods are computed with
respect to neighboring chunks as well - if a neighborhood would overflow from
the current chunk into one or more neighboring chunks, values from those
neighboring chunks will be included in the neighborhood.
Note that, while `@stencil` may look like a `for` loop, it does not follow the
same semantics; in particular, an expression within `@stencil` occurs "all at
once" (across all indices) before the next expression occurs. This means that
`A[idx] = sum(@neighbors(A[idx], 1, :wrap))` will write the sum of
`A[idx] = sum(@neighbors(A[idx], 1, Wrap()))` will write the sum of
neighbors for all `idx` values into `A[idx]` before `A[idx] -= 1` decrements
the values `A` by 1, and that occurs before any of the values are copied to `B`
in `B[idx] = A[idx]`. Of course, pipelining and other optimizations may still
Expand All @@ -123,18 +140,16 @@ macro stencil(index_ex, orig_ex)
@assert write_idx == index_var "Can only write to $index_var: $write_ex"
accessed_vars = Set{Symbol}()
read_vars = Set{Symbol}()
neighborhoods = Dict{Symbol, Tuple{Int, Bool}}()
neighborhoods = Dict{Symbol, Tuple{Any, Any}}()
push!(accessed_vars, write_var)
MacroTools.prewalk(read_ex) do read_inner_ex
prewalk(read_ex) do read_inner_ex
if @capture(read_inner_ex, read_var_[read_idx_]) && read_idx == index_var
push!(accessed_vars, read_var)
push!(read_vars, read_var)
elseif @capture(read_inner_ex, @neighbors(read_var_[read_idx_], neigh_dist_, wrapping_)) && read_idx == index_var
@assert neigh_dist == 1 "Neighborhoods greater than 1 not yet supported"
@assert wrapping == :wrap "Unknown wrapping pattern `$(repr(wrapping))`"
elseif @capture(read_inner_ex, @neighbors(read_var_[read_idx_], neigh_dist_, boundary_)) && read_idx == index_var
push!(accessed_vars, read_var)
push!(read_vars, read_var)
neighborhoods[read_var] = (neigh_dist, true)
neighborhoods[read_var] = (neigh_dist, boundary)
end
return read_inner_ex
end
Expand All @@ -151,11 +166,11 @@ macro stencil(index_ex, orig_ex)

# Generate function with transformed body
@gensym inner_index_var
new_inner_ex_body = MacroTools.prewalk(inner_ex) do old_inner_ex
new_inner_ex_body = prewalk(inner_ex) do old_inner_ex
if @capture(old_inner_ex, read_var_[read_idx_]) && read_idx == index_var
# Direct access
return :($read_var[$inner_index_var])
elseif @capture(old_inner_ex, @neighbors(read_var_[read_idx_], neigh_dist_, wrapping_)) && read_idx == index_var
elseif @capture(old_inner_ex, @neighbors(read_var_[read_idx_], neigh_dist_, boundary_)) && read_idx == index_var
# Neighborhood access
return :($load_neighborhood($read_var, $inner_index_var, $neigh_dist))
end
Expand All @@ -179,10 +194,10 @@ macro stencil(index_ex, orig_ex)
for read_var in read_vars
if read_var in keys(neighborhoods)
# Generate a neighborhood copy operation
neigh_dist, dowrap = neighborhoods[read_var]
neigh_dist, boundary = neighborhoods[read_var]
deps_inner_ex = Expr(:block)
@gensym neighbor_copy_var
push!(neighbor_copy_all_ex.args, :($neighbor_copy_var = Dagger.@spawn $build_halo($neigh_dist, map($Read, $get_neighborhood_chunks($chunks($read_var), $chunk_idx, $neigh_dist))...)))
push!(neighbor_copy_all_ex.args, :($neighbor_copy_var = Dagger.@spawn $build_halo($neigh_dist, $boundary, map($Read, $get_neighborhood_chunks($chunks($read_var), $chunk_idx, $neigh_dist, $boundary))...)))
push!(deps_ex, Expr(:kw, read_var, :($Read($neighbor_copy_var))))
else
push!(deps_ex, Expr(:kw, read_var, :($Read($chunks($read_var)[$chunk_idx]))))
Expand All @@ -192,12 +207,14 @@ macro stencil(index_ex, orig_ex)

# Generate loop
push!(final_ex.args, quote
for $chunk_idx in $index_range
for $chunk_idx in $CartesianIndices($chunks($index_range))
$neighbor_copy_all_ex
$spawn_ex
end
end)
end

@show final_ex

return esc(final_ex)
end
11 changes: 7 additions & 4 deletions src/utils/haloarray.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Define the HaloArray type with minimized halo storage
struct HaloArray{T,N,E,C} <: AbstractArray{T,N}
center::Array{T,N}
edges::NTuple{E, Array{T,N}}
corners::NTuple{C, Array{T,N}}
struct HaloArray{T,N,E,C,A,EA,CA} <: AbstractArray{T,N}
center::A
edges::NTuple{E, EA}
corners::NTuple{C, CA}
halo_width::NTuple{N,Int}
end

Expand All @@ -20,6 +20,9 @@ function HaloArray{T,N}(center_size::NTuple{N,Int}, halo_width::NTuple{N,Int}) w
return HaloArray{T,N,2N,2^N}(center, edges, corners, halo_width)
end

"""
    HaloArray(center, edges, corners, halo_width)

Construct a `HaloArray` from existing storage, deriving every type parameter
from the arguments: the element type and dimensionality come from `center`,
the counts and element types of `edges` and `corners` come from their tuples.
"""
function HaloArray(
    center::AT,
    edges::NTuple{E,EA},
    corners::NTuple{C,CA},
    halo_width::NTuple{N,Int},
) where {T,N,E,C,AT<:AbstractArray{T,N},EA,CA}
    return HaloArray{T,N,E,C,AT,EA,CA}(center, edges, corners, halo_width)
end

# Total extent per dimension: the center block plus a halo of width
# `halo_width[i]` on both sides of dimension `i`.
function Base.size(tile::HaloArray)
    return size(tile.center) .+ 2 .* tile.halo_width
end
function Base.axes(tile::HaloArray{T,N,H}) where {T,N,H}
ntuple(N) do i
Expand Down

0 comments on commit c9b3103

Please sign in to comment.