Skip to content

Commit

Permalink
Implement FD shmem
Browse files Browse the repository at this point in the history
Add shmem to FD operators

Fixes

Passing simple example

Fixes

Use shared memory for DivergenceF2C operator stencils
  • Loading branch information
charleskawczynski authored and Charlie Kawczynski committed Feb 13, 2025
1 parent c19162a commit 82ff728
Show file tree
Hide file tree
Showing 7 changed files with 366 additions and 15 deletions.
2 changes: 2 additions & 0 deletions ext/ClimaCoreCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ include(joinpath("cuda", "operators_integral.jl"))
include(joinpath("cuda", "remapping_interpolate_array.jl"))
include(joinpath("cuda", "limiters.jl"))
include(joinpath("cuda", "operators_sem_shmem.jl"))
include(joinpath("cuda", "operators_fd_shmem_common.jl"))
include(joinpath("cuda", "operators_fd_shmem.jl"))
include(joinpath("cuda", "operators_thomas_algorithm.jl"))
include(joinpath("cuda", "matrix_fields_single_field_solve.jl"))
include(joinpath("cuda", "matrix_fields_multiple_field_solve.jl"))
Expand Down
39 changes: 39 additions & 0 deletions ext/cuda/data_layouts_threadblock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,42 @@ end
ij,
slabidx,
) = Operators.is_valid_index(space, ij, slabidx)

##### shmem fd kernel partition
"""
    fd_stencil_partition(us::DataLayouts.UniversalSize, n_max_threads::Integer = 256)

Return the launch configuration used by the finite-difference stencil kernels
as a named tuple `(; threads, blocks, Nvthreads)`:

 - `threads = (Nvthreads,)`: threads per block, all along the vertical,
 - `blocks = (Nvblocks, Nh, Nq * Nq)`: vertical blocks, horizontal elements,
   and one block per horizontal nodal point,
 - `Nvthreads`: the number of vertical threads per block.

Two internal invariants are asserted: `Nvthreads >= Nv + 1` and
`Nq * Nq ≤ n_max_threads`.
"""
@inline function fd_stencil_partition(
    us::DataLayouts.UniversalSize,
    n_max_threads::Integer = 256;
)
    (Nq, _, _, Nv, Nh) = DataLayouts.universal_size(us)
    Nvthreads = min(fld(n_max_threads, Nq) + 2, maximum_allowable_threads()[1])
    # NOTE: a `+ 1` may be needed here to guarantee that shared memory is
    # populated at the last cell face.
    Nvblocks = cld(Nv, Nvthreads)
    @assert Nvthreads >= Nv + 1 "Nvthreads must exceed number of vertical points by 1, for shared memory."
    @assert Nq * Nq ≤ n_max_threads
    # This is basically a rotation permutation of the spectral element config
    return (;
        threads = (Nvthreads,),
        blocks = (Nvblocks, Nh, Nq * Nq),
        Nvthreads,
    )
end
"""
    fd_stencil_universal_index(space::Spaces.AbstractSpace, us)

Build the `CartesianIndex((i, j, 1, v, h))` of the calling CUDA thread for the
launch layout produced by `fd_stencil_partition`: the thread/block x-index
gives the vertical position, block y gives `h`, and block z is unpacked into
the horizontal nodal pair `(i, j)`.

For face spaces the vertical entry is `vid - half`; for center spaces it is
`vid`. Any other space raises `error("Invalid space")`.
"""
@inline function fd_stencil_universal_index(space::Spaces.AbstractSpace, us)
    thread_v = CUDA.threadIdx().x
    (block_v, h, ij) = CUDA.blockIdx()
    # Global (1-based) vertical thread id across all vertical blocks.
    vid = thread_v + (block_v - 1) * CUDA.blockDim().x
    (Nq,) = DataLayouts.universal_size(us)
    (i, j) = Tuple(CartesianIndices((Nq, Nq))[ij])
    on_faces =
        space isa Spaces.FaceExtrudedFiniteDifferenceSpace ||
        space isa Spaces.FaceFiniteDifferenceSpace
    on_centers =
        space isa Spaces.CenterExtrudedFiniteDifferenceSpace ||
        space isa Spaces.CenterFiniteDifferenceSpace
    v = if on_faces
        # NOTE(review): this yields a half-integer level; confirm that
        # `CartesianIndex` accepts it for face spaces.
        vid - half
    elseif on_centers
        vid
    else
        error("Invalid space")
    end
    return CartesianIndex((i, j, 1, v, h))
end
# Return `true` when the horizontal-element entry of `I` (its 5th component)
# lies within `1:Nh` for the given universal size. Only `I[5]` is checked.
@inline fd_stencil_is_valid_index(I::CI5, us::UniversalSize) =
    1 ≤ I[5] ≤ DataLayouts.get_Nh(us)
47 changes: 47 additions & 0 deletions ext/cuda/operators_fd_shmem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import ClimaCore: DataLayouts, Spaces, Geometry, RecursiveApply, DataLayouts
import CUDA
import ClimaCore.Operators: DivergenceF2C
import ClimaCore.Operators: return_eltype, get_local_geometry

"""
    fd_operator_shmem(space, ::Val{Nvt}, op::DivergenceF2C, arg)

Allocate the work buffer for a `DivergenceF2C` stencil: a
`CUDA.CuStaticSharedArray` of length `Nvt` whose element type is
`return_eltype(op, arg)`. `space` is accepted but not used by this method.
"""
Base.@propagate_inbounds function fd_operator_shmem(
    space,
    ::Val{Nvt},
    op::DivergenceF2C,
    arg,
) where {Nvt}
    # One slot per thread of the vertical block.
    return CUDA.CuStaticSharedArray(return_eltype(op, arg), (Nvt,))
end

"""
    fd_operator_fill_shmem!(op::DivergenceF2C, Ju³, loc, space, idx, hidx, arg)

Evaluate `arg` at `idx - half` and store `Geometry.Jcontravariant3` of that
value (using the local geometry at the same location) into the calling
thread's slot `Ju³[threadIdx().x]` of the work buffer.

`loc` is the boundary/interior window forwarded to `Operators.getidx`.
"""
Base.@propagate_inbounds function fd_operator_fill_shmem!(
    op::DivergenceF2C,
    Ju³,
    loc,
    space,
    idx,
    hidx,
    arg,
)
    # Qualified because this file only does `import CUDA`.
    vt = CUDA.threadIdx().x
    # The buffer is filled at the location half a level below `idx`.
    lg = Geometry.LocalGeometry(space, idx - half, hidx)
    u³ = Operators.getidx(space, arg, loc, idx - half, hidx)
    Ju³[vt] = Geometry.Jcontravariant3(u³, lg)
end

"""
    fd_operator_evaluate(op::DivergenceF2C, Ju³, loc, space, idx, hidx, args...)

Evaluate the `DivergenceF2C` stencil at `idx` from the pre-filled work buffer
`Ju³`: the difference of the values at `idx + half` and `idx - half`, scaled
by `invJ` of the local geometry at `idx`.

The buffer must already hold, in slot `vt`, the value for `idx - half` of the
thread with `threadIdx().x == vt` (see `fd_operator_fill_shmem!`). `loc` and
`args` are accepted but not used by this method.
"""
Base.@propagate_inbounds function fd_operator_evaluate(
    op::DivergenceF2C,
    Ju³,
    loc,
    space,
    idx,
    hidx,
    args...,
)
    # Qualified because this file only does `import CUDA`.
    vt = CUDA.threadIdx().x
    local_geometry = Geometry.LocalGeometry(space, idx, hidx)
    Ju³₋ = Ju³[vt] # corresponds to idx - half
    Ju³₊ = Ju³[vt + 1] # corresponds to idx + half
    # `⊟` / `⊠` are the RecursiveApply subtraction / multiplication operators.
    return (Ju³₊ ⊟ Ju³₋) ⊠ local_geometry.invJ
end
134 changes: 134 additions & 0 deletions ext/cuda/operators_fd_shmem_common.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import ClimaCore: DataLayouts, Spaces, Geometry, RecursiveApply, DataLayouts
import CUDA
import ClimaCore.Operators: DivergenceF2C
import ClimaCore.Operators: return_eltype, get_local_geometry
import ClimaCore.Operators: getidx

# We don't support all operators yet, so let's use a custom getidx
"""
    shmem_getidx(space, bc::StencilBroadcasted{CUDAColumnStencilStyle}, loc, idx, hidx)

Evaluate the stencil operator of `bc` at `(idx, hidx)` by forwarding its
operator, work buffer and arguments to `fd_operator_evaluate`.
"""
Base.@propagate_inbounds function shmem_getidx(
    space,
    bc::StencilBroadcasted{CUDAColumnStencilStyle},
    loc,
    idx,
    hidx,
)
    operator = bc.op
    buffer = bc.work
    return fd_operator_evaluate(
        operator,
        buffer,
        loc,
        space,
        idx,
        hidx,
        bc.args...,
    )
end

"""
fd_allocate_shmem(Val(Nvt), b)
Create a new broadcasted object with the necessary shared memory allocated,
using `Nvt` nodal points per block.
"""
# Fallback: anything that is not a broadcasted object is returned unchanged.
@inline fd_allocate_shmem(::Val{Nvt}, obj) where {Nvt} = obj
# Pointwise broadcast: rebuild it with the same function and axes, after
# recursing into each of its arguments.
@inline function fd_allocate_shmem(
    ::Val{Nvt},
    bc::Broadcasted{Style},
) where {Nvt, Style}
    new_args = _fd_allocate_shmem(Val(Nvt), bc.args...)
    return Broadcasted{Style}(bc.f, new_args, bc.axes)
end
# Stencil broadcast: recurse into the arguments first, then ask the operator
# for its work buffer and rebuild the stencil broadcast carrying that buffer.
@inline function fd_allocate_shmem(
    ::Val{Nvt},
    sbc::StencilBroadcasted{Style},
) where {Nvt, Style}
    new_args = _fd_allocate_shmem(Val(Nvt), sbc.args...)
    buffer = fd_operator_shmem(sbc.axes, Val(Nvt), sbc.op, new_args...)
    return StencilBroadcasted{Style}(sbc.op, new_args, sbc.axes, buffer)
end

# Apply `fd_allocate_shmem` to every element of an argument list, returning a
# tuple. The recursion peels one argument at a time and ends at the empty list.
@inline _fd_allocate_shmem(::Val{Nvt}) where {Nvt} = ()
@inline function _fd_allocate_shmem(::Val{Nvt}, arg, xargs...) where {Nvt}
    first_allocated = fd_allocate_shmem(Val(Nvt), arg)
    rest_allocated = _fd_allocate_shmem(Val(Nvt), xargs...)
    return (first_allocated, rest_allocated...)
end

"""
fd_resolve_shmem!(
sbc::StencilBroadcasted,
idx,
hidx,
bds
)
Recursively stores the arguments to all operators into shared memory, at the
given indices (if they are valid).
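This function does not synchronize threads itself: the caller must call
`CUDA.sync_threads()` afterwards, collectively on all threads of the block,
before any thread reads the shared memory filled here.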
"""
Base.@propagate_inbounds function fd_resolve_shmem!(
    sbc::StencilBroadcasted,
    idx,
    hidx,
    bds,
)
    (li, lw, rw, ri) = bds
    space = axes(sbc)

    # Fill the buffers of any nested operators before this operator's own.
    _fd_resolve_shmem!(idx, hidx, bds, sbc.args...)

    # Do we need an _extra_ + 1 since shmem is loaded at idx - half? need to
    # generalize to boundary window?
    (li <= idx <= ri + 1) || return nothing

    # Pick the window for this index and let the operator fill its slot.
    if li <= idx <= (lw - 1)
        fd_operator_fill_shmem!(
            sbc.op,
            sbc.work,
            LeftBoundaryWindow{Spaces.left_boundary_name(space)}(),
            space,
            idx,
            hidx,
            sbc.args...,
        )
    elseif (rw + 1 + 1) <= idx <= ri
        fd_operator_fill_shmem!(
            sbc.op,
            sbc.work,
            RightBoundaryWindow{Spaces.right_boundary_name(space)}(),
            space,
            idx,
            hidx,
            sbc.args...,
        )
    else
        fd_operator_fill_shmem!(
            sbc.op,
            sbc.work,
            Interior(),
            space,
            idx,
            hidx,
            sbc.args...,
        )
    end
    return nothing
end

# Call `fd_resolve_shmem!` on each argument in order. The recursion peels one
# argument at a time and ends at the empty argument list.
@inline function _fd_resolve_shmem!(idx, hidx, bds)
    return nothing
end
@inline function _fd_resolve_shmem!(idx, hidx, bds, arg, xargs...)
    fd_resolve_shmem!(arg, idx, hidx, bds)
    return _fd_resolve_shmem!(idx, hidx, bds, xargs...)
end


# Pointwise broadcast: it has no buffer of its own, so only recurse into its
# arguments.
Base.@propagate_inbounds function fd_resolve_shmem!(bc::Broadcasted, idx, hidx, bds)
    _fd_resolve_shmem!(idx, hidx, bds, bc.args...)
    return nothing
end
# Fallback for any other object: nothing to fill.
Base.@propagate_inbounds fd_resolve_shmem!(obj, idx, hidx, bds) = nothing
99 changes: 87 additions & 12 deletions ext/cuda/operators_finite_difference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import ClimaCore: Spaces, Quadratures, Topologies
import Base.Broadcast: Broadcasted
import ClimaComms
using CUDA: @cuda
import ClimaCore.Utilities: half
import ClimaCore.Operators: AbstractStencilStyle, strip_space
import ClimaCore.Operators: setidx!, getidx
import ClimaCore.Operators: StencilBroadcasted
Expand All @@ -10,6 +11,15 @@ import ClimaCore.Operators: LeftBoundaryWindow, RightBoundaryWindow, Interior
struct CUDAColumnStencilStyle <: AbstractStencilStyle end
AbstractStencilStyle(::ClimaComms.CUDADevice) = CUDAColumnStencilStyle

# Four-argument `getidx` for CUDA column-stencil broadcasts: forwards the
# operator, its work buffer and its axes to `operator_evaluate`. The `space`
# argument is not used; the broadcast's own axes are passed instead.
# NOTE(review): no import of `operator_evaluate` is visible in this part of
# the file — confirm it is in scope in the enclosing module.
Base.@propagate_inbounds function getidx(
    space,
    sbc::StencilBroadcasted{CUDAColumnStencilStyle},
    ij,
    slabidx,
)
    return operator_evaluate(sbc.op, sbc.work, sbc.axes, ij, slabidx)
end

function Base.copyto!(
out::Field,
bc::Union{
Expand All @@ -21,25 +31,38 @@ function Base.copyto!(
bounds = Operators.window_bounds(space, bc)
out_fv = Fields.field_values(out)
us = DataLayouts.UniversalSize(out_fv)
args =
(strip_space(out, space), strip_space(bc, space), axes(out), bounds, us)

threads = threads_via_occupancy(copyto_stencil_kernel!, args)
n_max_threads = min(threads, get_N(us))
p = partition(out_fv, n_max_threads)

auto_launch!(
copyto_stencil_kernel!,
args;
threads_s = p.threads,
blocks_s = p.blocks,
p = fd_stencil_partition(us)
args = (
strip_space(out, space),
strip_space(bc, space),
axes(out),
bounds,
us,
Val(p.Nvthreads),
)

if bc.op isa DivergenceF2C{@NamedTuple{}}
auto_launch!(
copyto_stencil_kernel_shmem!,
args;
threads_s = p.threads,
blocks_s = p.blocks,
)
else
auto_launch!(
copyto_stencil_kernel!,
args;
threads_s = p.threads,
blocks_s = p.blocks,
)
end
call_post_op_callback() && post_op_callback(out, out, bc)
return out
end
import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nij, get_Nh

function copyto_stencil_kernel!(out, bc, space, bds, us)
function copyto_stencil_kernel!(out, bc, space, bds, us, ::Val{Nvt}) where {Nvt}
@inbounds begin
out_fv = Fields.field_values(out)
I = universal_index(out_fv)
Expand All @@ -64,3 +87,55 @@ function copyto_stencil_kernel!(out, bc, space, bds, us)
end
return nothing
end


"""
    copyto_stencil_kernel_shmem!(out, bc′, space, bds, us, ::Val{Nvt})

CUDA kernel that evaluates the stencil broadcast `bc′` into `out` using a
per-block work buffer:

 1. compute this thread's index and bail out if its `h` entry is invalid,
 2. reconstruct the broadcast on `space` and allocate its buffers with `Nvt`
    slots,
 3. fill the buffers (`fd_resolve_shmem!`), then `CUDA.sync_threads()`,
 4. for in-range levels, evaluate with the left / right / interior window and
    write the result with `setidx!`.

The universal size is recomputed from `out`; the `us` argument is not read.
"""
function copyto_stencil_kernel_shmem!(
    out,
    bc′,
    space,
    bds,
    us,
    ::Val{Nvt},
) where {Nvt}
    @inbounds begin
        out_fv = Fields.field_values(out)
        us_out = DataLayouts.UniversalSize(out_fv)
        I = fd_stencil_universal_index(space, us_out)
        # Check that hidx is in bounds.
        # NOTE(review): this exit precedes `CUDA.sync_threads()`; it is only
        # safe if the check is uniform across a thread block — confirm.
        fd_stencil_is_valid_index(I, us_out) || return nothing

        (li, lw, rw, ri) = bds
        (i, j, _, v, h) = I.I
        hidx = (i, j, h)
        idx = v - 1 + li
        bc = Operators.reconstruct_placeholder_broadcasted(space, bc′)
        bc_shmem = fd_allocate_shmem(Val(Nvt), bc) # allocates shmem

        fd_resolve_shmem!(bc_shmem, idx, hidx, bds) # recursively fills shmem
        CUDA.sync_threads()

        # Check that idx is in bounds.
        nv = Spaces.nlevels(space)
        in_column = Operators.is_face_space(space) ? idx + half <= nv : idx <= nv
        in_column || return nothing

        # Evaluate through `shmem_getidx` with the window matching `idx`.
        val = if li <= idx <= (lw - 1)
            shmem_getidx(
                space,
                bc_shmem,
                LeftBoundaryWindow{Spaces.left_boundary_name(space)}(),
                idx,
                hidx,
            )
        elseif (rw + 1) <= idx <= ri
            shmem_getidx(
                space,
                bc_shmem,
                RightBoundaryWindow{Spaces.right_boundary_name(space)}(),
                idx,
                hidx,
            )
        else
            # Remaining case: lw <= idx <= rw.
            shmem_getidx(space, bc_shmem, Interior(), idx, hidx)
        end
        setidx!(space, out, idx, hidx, val)
    end
    return nothing
end
Loading

0 comments on commit 82ff728

Please sign in to comment.