Skip to content

Commit

Permalink
use adapt_structure macro where-ever possible
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Jan 19, 2025
1 parent eed7fb2 commit 3fff1fa
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 71 deletions.
35 changes: 2 additions & 33 deletions src/solvers/dgsem/basis_lobatto_legendre.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,30 +128,7 @@ In particular, not the nodes themselves are returned.

@inline get_nodes(basis::LobattoLegendreBasis) = basis.nodes

function Adapt.adapt_structure(to, basis::LobattoLegendreBasis)
# Do not adapt SVector fields, i.e. nodes, weights and inverse_weights
(; nodes, weights, inverse_weights) = basis
inverse_vandermonde_legendre = Adapt.adapt_structure(to,
basis.inverse_vandermonde_legendre)
boundary_interpolation = basis.boundary_interpolation
derivative_matrix = Adapt.adapt_structure(to, basis.derivative_matrix)
derivative_split = Adapt.adapt_structure(to, basis.derivative_split)
derivative_split_transpose = Adapt.adapt_structure(to,
basis.derivative_split_transpose)
derivative_dhat = Adapt.adapt_structure(to, basis.derivative_dhat)
return LobattoLegendreBasis{real(basis), nnodes(basis), typeof(basis.nodes),
typeof(inverse_vandermonde_legendre),
typeof(boundary_interpolation),
typeof(derivative_matrix)}(nodes,
weights,
inverse_weights,
inverse_vandermonde_legendre,
boundary_interpolation,
derivative_matrix,
derivative_split,
derivative_split_transpose,
derivative_dhat)
end
Adapt.@adapt_structure(LobattoLegendreBasis)

"""
integrate(f, u, basis::LobattoLegendreBasis)
Expand Down Expand Up @@ -241,15 +218,7 @@ end

@inline polydeg(mortar::LobattoLegendreMortarL2) = nnodes(mortar) - 1

function Adapt.adapt_structure(to, mortar::LobattoLegendreMortarL2)
forward_upper = Adapt.adapt_structure(to, mortar.forward_upper)
forward_lower = Adapt.adapt_structure(to, mortar.forward_lower)
reverse_upper = Adapt.adapt_structure(to, mortar.reverse_upper)
reverse_lower = Adapt.adapt_structure(to, mortar.reverse_lower)
return LobattoLegendreMortarL2{real(mortar), nnodes(mortar), typeof(forward_upper),
typeof(reverse_upper)}(forward_upper, forward_lower,
reverse_upper, reverse_lower)
end
Adapt.@adapt_structure(LobattoLegendreMortarL2)

# TODO: We can create EC mortars along the lines of the following implementation.
# abstract type AbstractMortarEC{RealT} <: AbstractMortar{RealT} end
Expand Down
14 changes: 9 additions & 5 deletions src/solvers/dgsem_p4est/containers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ end
function KernelAbstractions.get_backend(elements::P4estElementContainer)
return KernelAbstractions.get_backend(elements.node_coordinates)
end
# Adapt.@adapt_structure(P4estElementContainer)
function Adapt.adapt_structure(to,
elements::P4estElementContainer{NDIMS, RealT, uEltype}) where {
NDIMS,
Expand All @@ -158,7 +159,7 @@ function Adapt.adapt_structure(to,
_contravariant_vectors = Adapt.adapt_structure(to, elements._contravariant_vectors)
_inverse_jacobian = Adapt.adapt_structure(to, elements._inverse_jacobian)
_surface_flux_values = Adapt.adapt_structure(to, elements._surface_flux_values)

# Wrap arrays again
node_coordinates = unsafe_wrap_or_alloc(to, _node_coordinates,
size(elements.node_coordinates))
Expand Down Expand Up @@ -298,15 +299,18 @@ end
function KernelAbstractions.get_backend(interfaces::P4estInterfaceContainer)
return KernelAbstractions.get_backend(interfaces.u)
end
# Adapt.@adapt_structure(P4estInterfaceContainer)
function Adapt.adapt_structure(to, interfaces::P4estInterfaceContainer)
# Adapt underlying storage
_u = Adapt.adapt_structure(to, interfaces._u)
_neighbor_ids = Adapt.adapt_structure(to, interfaces._neighbor_ids)
_node_indices = Adapt.adapt_structure(to, interfaces._node_indices)
# Wrap arrays again
u = unsafe_wrap_or_alloc(to, _u, size(interfaces.u))
neighbor_ids = unsafe_wrap_or_alloc(to, _neighbor_ids, size(interfaces.neighbor_ids))
node_indices = unsafe_wrap_or_alloc(to, _node_indices, size(interfaces.node_indices))
neighbor_ids = unsafe_wrap_or_alloc(to, _neighbor_ids,
size(interfaces.neighbor_ids))
node_indices = unsafe_wrap_or_alloc(to, _node_indices,
size(interfaces.node_indices))

NDIMS = ndims(interfaces)
new_type_params = (NDIMS,
Expand Down Expand Up @@ -449,7 +453,7 @@ function Adapt.adapt_structure(to, boundaries::P4estBoundaryContainer)
neighbor_ids = Adapt.adapt_structure(to, boundaries.neighbor_ids)
node_indices = Adapt.adapt_structure(to, boundaries.node_indices)
name = boundaries.name

NDIMS = ndims(boundaries)
return P4estBoundaryContainer{NDIMS, eltype(boundaries), NDIMS + 1, typeof(u),
typeof(neighbor_ids), typeof(node_indices),
Expand Down Expand Up @@ -583,6 +587,7 @@ end
function KernelAbstractions.get_backend(mortars::P4estMortarContainer)
return KernelAbstractions.get_backend(mortars.u)
end
# Adapt.@adapt_structure P4estMortarContainer
function Adapt.adapt_structure(to, mortars::P4estMortarContainer)
# Adapt underlying storage
_u = Adapt.adapt_structure(to, mortars._u)
Expand All @@ -594,7 +599,6 @@ function Adapt.adapt_structure(to, mortars::P4estMortarContainer)
neighbor_ids = unsafe_wrap_or_alloc(to, _neighbor_ids, size(mortars.neighbor_ids))
node_indices = unsafe_wrap_or_alloc(to, _node_indices, size(mortars.node_indices))


NDIMS = ndims(mortars)
new_type_params = (NDIMS,
eltype(mortars),
Expand Down
3 changes: 2 additions & 1 deletion src/solvers/dgsem_p4est/containers_parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ end
function KernelAbstractions.get_backend(mpi_interfaces::P4estMPIInterfaceContainer)
return KernelAbstractions.get_backend(mpi_interfaces.u)
end
# Adapt.@adapt_structure(P4estMPIInterfaceContainer)
function Adapt.adapt_structure(to, mpi_interfaces::P4estMPIInterfaceContainer)
# Adapt Vectors and underlying storage
_u = Adapt.adapt_structure(to, mpi_interfaces._u)
Expand Down Expand Up @@ -201,7 +202,7 @@ function init_mpi_mortars(mesh::Union{ParallelP4estMesh, ParallelT8codeMesh}, eq

mpi_mortars = P4estMPIMortarContainer{NDIMS, uEltype, RealT, NDIMS + 1, NDIMS + 2,
NDIMS + 3, typeof(u),
typeof(_u),
typeof(_u),
Array, false}(u, local_neighbor_ids,
local_neighbor_positions,
node_indices, normal_directions,
Expand Down
25 changes: 3 additions & 22 deletions src/solvers/dgsem_p4est/dg_parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
@muladd begin
#! format: noindent

mutable struct P4estMPICache{BufferType <: DenseVector, VecInt <: DenseVector{<:Integer}}
mutable struct P4estMPICache{BufferType <: DenseVector,
VecInt <: DenseVector{<:Integer}}
mpi_neighbor_ranks::Vector{Int}
mpi_neighbor_interfaces::VecOfArrays{VecInt}
mpi_neighbor_mortars::VecOfArrays{VecInt}
Expand Down Expand Up @@ -46,27 +47,7 @@ end

@inline Base.eltype(::P4estMPICache{BufferType}) where {BufferType} = eltype(BufferType)

function Adapt.adapt_structure(to, mpi_cache::P4estMPICache)
mpi_neighbor_ranks = mpi_cache.mpi_neighbor_ranks
mpi_neighbor_interfaces = Adapt.adapt_structure(to, mpi_cache.mpi_neighbor_interfaces)
mpi_neighbor_mortars = Adapt.adapt_structure(to, mpi_cache.mpi_neighbor_mortars)
mpi_send_buffers = Adapt.adapt_structure(to, mpi_cache.mpi_send_buffers)
mpi_recv_buffers = Adapt.adapt_structure(to, mpi_cache.mpi_recv_buffers)
mpi_send_requests = mpi_cache.mpi_send_requests
mpi_recv_requests = mpi_cache.mpi_recv_requests
n_elements_by_rank = mpi_cache.n_elements_by_rank
n_elements_global = mpi_cache.n_elements_global
first_element_global_id = mpi_cache.first_element_global_id

@assert eltype(mpi_send_buffers) == eltype(mpi_recv_buffers)
BufferType = eltype(mpi_send_buffers)
VecInt = eltype(mpi_neighbor_interfaces)
return P4estMPICache{BufferType, VecInt}(mpi_neighbor_ranks, mpi_neighbor_interfaces,
mpi_neighbor_mortars, mpi_send_buffers,
mpi_recv_buffers, mpi_send_requests,
mpi_recv_requests, n_elements_by_rank,
n_elements_global, first_element_global_id)
end
Adapt.@adapt_structure(P4estMPICache)

function start_mpi_send!(mpi_cache::P4estMPICache, mesh, equations, dg, cache)
data_size = nvariables(equations) * nnodes(dg)^(ndims(mesh) - 1)
Expand Down
11 changes: 1 addition & 10 deletions src/solvers/dgsem_unstructured/sort_boundary_conditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,5 @@ function initialize!(boundary_types_container::UnstructuredSortedBoundaryTypes{N
return boundary_types_container
end

function Adapt.adapt_structure(to, bcs::UnstructuredSortedBoundaryTypes)
boundary_indices = Adapt.adapt_structure(to, bcs.boundary_indices)
n_boundary_types = length(bcs.boundary_condition_types)
return UnstructuredSortedBoundaryTypes{n_boundary_types,
typeof(bcs.boundary_condition_types),
eltype(boundary_indices)}(bcs.boundary_condition_types,
boundary_indices,
bcs.boundary_dictionary,
bcs.boundary_symbol_indices)
end
Adapt.@adapt_structure(UnstructuredSortedBoundaryTypes)
end # @muladd

0 comments on commit 3fff1fa

Please sign in to comment.