Skip to content

Commit

Permalink
Fix adapt for vert topo and Topology2D
Browse files Browse the repository at this point in the history
  • Loading branch information
Charlie Kawczynski committed Feb 14, 2025
1 parent c19162a commit 124f35d
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/Topologies/interval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ struct IntervalTopology{
boundaries::B
end

Adapt.@adapt_structure IntervalTopology

## gpu
struct DeviceIntervalTopology{B} <: AbstractIntervalTopology
boundaries::B
Expand Down
39 changes: 39 additions & 0 deletions src/Topologies/topology2d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,45 @@ mutable struct Topology2D{
ghost_face_neighbor_loc::Vector{Int}
end

function Adapt.adapt_structure(to, topo::Topology2D)
return Topology2D(
Adapt.adapt(to, topo.context),
Adapt.adapt(to, topo.mesh),
Adapt.adapt(to, topo.elemorder),
Adapt.adapt(to, topo.orderindex),
topo.elempid,
topo.local_elem_gidx,
topo.neighbor_pids,
topo.send_elem_lidx,
topo.send_elem_lengths,
topo.recv_elem_gidx,
topo.recv_elem_lengths,
Adapt.adapt(to, topo.interior_faces),
Adapt.adapt(to, topo.ghost_faces),
Adapt.adapt(to, topo.local_vertices),
Adapt.adapt(to, topo.local_vertex_offset),
Adapt.adapt(to, topo.ghost_vertices),
Adapt.adapt(to, topo.ghost_vertex_offset),
Adapt.adapt(to, topo.local_neighbor_elem),
Adapt.adapt(to, topo.local_neighbor_elem_offset),
topo.ghost_neighbor_elem,
topo.ghost_neighbor_elem_offset,
Adapt.adapt(to, topo.boundaries),
topo.internal_elems,
topo.perimeter_elems,
topo.nglobalvertices,
topo.nglobalfaces,
topo.ghost_vertex_gcidx,
topo.ghost_face_gcidx,
topo.comm_vertex_lengths,
topo.comm_face_lengths,
topo.ghost_vertex_neighbor_loc,
topo.ghost_vertex_comm_idx_offset,
Adapt.adapt(to, topo.repr_ghost_vertex),
topo.ghost_face_neighbor_loc,
)
end

ClimaComms.device(topology::Topology2D) = ClimaComms.device(topology.context)
ClimaComms.array_type(topology::Topology2D) =
ClimaComms.array_type(topology.context.device)
Expand Down
3 changes: 3 additions & 0 deletions test/Fields/unit_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,9 @@ function test_adapt(cpu_space_in)
# cpu -> gpu
gpu_f_out = ClimaCore.to_device(ClimaComms.CUDADevice(), cpu_f_in)
@test parent(Fields.field_values(gpu_f_out)) isa CUDA.CuArray
@test ClimaComms.device(gpu_f_out) isa ClimaComms.CUDADevice
@test ClimaComms.array_type(gpu_f_out) == CUDA.CuArray

# gpu -> gpu
cpu_f_out =
ClimaCore.to_device(ClimaComms.CPUSingleThreaded(), gpu_f_out)
Expand Down

0 comments on commit 124f35d

Please sign in to comment.