From 124f35dd9f0414974669e6b95a7a8893ffb126b7 Mon Sep 17 00:00:00 2001 From: Charlie Kawczynski Date: Fri, 14 Feb 2025 11:04:06 -0800 Subject: [PATCH] Fix adapt for vert topo and Topology2D --- src/Topologies/interval.jl | 2 ++ src/Topologies/topology2d.jl | 39 ++++++++++++++++++++++++++++++++++++ test/Fields/unit_field.jl | 3 +++ 3 files changed, 44 insertions(+) diff --git a/src/Topologies/interval.jl b/src/Topologies/interval.jl index f1084e0d31..72c584d70a 100644 --- a/src/Topologies/interval.jl +++ b/src/Topologies/interval.jl @@ -16,6 +16,8 @@ struct IntervalTopology{ boundaries::B end +Adapt.@adapt_structure IntervalTopology + ## gpu struct DeviceIntervalTopology{B} <: AbstractIntervalTopology boundaries::B diff --git a/src/Topologies/topology2d.jl b/src/Topologies/topology2d.jl index 607a4e40d1..3b5b4bb369 100644 --- a/src/Topologies/topology2d.jl +++ b/src/Topologies/topology2d.jl @@ -124,6 +124,45 @@ mutable struct Topology2D{ ghost_face_neighbor_loc::Vector{Int} end +function Adapt.adapt_structure(to, topo::Topology2D) + return Topology2D( + Adapt.adapt(to, topo.context), + Adapt.adapt(to, topo.mesh), + Adapt.adapt(to, topo.elemorder), + Adapt.adapt(to, topo.orderindex), + topo.elempid, + topo.local_elem_gidx, + topo.neighbor_pids, + topo.send_elem_lidx, + topo.send_elem_lengths, + topo.recv_elem_gidx, + topo.recv_elem_lengths, + Adapt.adapt(to, topo.interior_faces), + Adapt.adapt(to, topo.ghost_faces), + Adapt.adapt(to, topo.local_vertices), + Adapt.adapt(to, topo.local_vertex_offset), + Adapt.adapt(to, topo.ghost_vertices), + Adapt.adapt(to, topo.ghost_vertex_offset), + Adapt.adapt(to, topo.local_neighbor_elem), + Adapt.adapt(to, topo.local_neighbor_elem_offset), + topo.ghost_neighbor_elem, + topo.ghost_neighbor_elem_offset, + Adapt.adapt(to, topo.boundaries), + topo.internal_elems, + topo.perimeter_elems, + topo.nglobalvertices, + topo.nglobalfaces, + topo.ghost_vertex_gcidx, + topo.ghost_face_gcidx, + topo.comm_vertex_lengths, + topo.comm_face_lengths, + topo.ghost_vertex_neighbor_loc, + topo.ghost_vertex_comm_idx_offset, + Adapt.adapt(to, topo.repr_ghost_vertex), + topo.ghost_face_neighbor_loc, + ) +end + ClimaComms.device(topology::Topology2D) = ClimaComms.device(topology.context) ClimaComms.array_type(topology::Topology2D) = ClimaComms.array_type(topology.context.device) diff --git a/test/Fields/unit_field.jl b/test/Fields/unit_field.jl index f05afb048f..1abec71de6 100644 --- a/test/Fields/unit_field.jl +++ b/test/Fields/unit_field.jl @@ -730,6 +730,9 @@ function test_adapt(cpu_space_in) # cpu -> gpu gpu_f_out = ClimaCore.to_device(ClimaComms.CUDADevice(), cpu_f_in) @test parent(Fields.field_values(gpu_f_out)) isa CUDA.CuArray + @test ClimaComms.device(gpu_f_out) isa ClimaComms.CUDADevice + @test ClimaComms.array_type(gpu_f_out) == CUDA.CuArray + # gpu -> gpu cpu_f_out = ClimaCore.to_device(ClimaComms.CPUSingleThreaded(), gpu_f_out)