Skip to content

Commit

Permalink
applyPhysBC: GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiqunZhang committed Jan 17, 2024
1 parent 71cb2b1 commit 4c02d45
Showing 1 changed file with 34 additions and 4 deletions.
38 changes: 34 additions & 4 deletions Src/LinearSolvers/MLMG/AMReX_MLCurlCurl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ void MLCurlCurl::averageDownSolutionRHS (int camrlev, MF& crse_sol, MF& crse_rhs
}

void
MLCurlCurl::apply (int amrlev, int mglev, MF& out, MF& in, BCMode bc_mode,
StateMode s_mode, const MLMGBndryT<MF>* /*bndry*/) const
MLCurlCurl::apply (int amrlev, int mglev, MF& out, MF& in, BCMode /*bc_mode*/,
StateMode /*s_mode*/, const MLMGBndryT<MF>* /*bndry*/) const
{
applyBC(amrlev, mglev, in);

Expand Down Expand Up @@ -450,11 +450,33 @@ void MLCurlCurl::applyBC (int amrlev, int mglev, MultiFab& mf) const
applyPhysBC(amrlev, mglev, mf);
}

#ifdef AMREX_USE_GPU
/**
 * \brief One unit of physical-boundary work collected by
 *        MLCurlCurl::applyPhysBC in GPU builds.
 *
 * applyPhysBC creates one tag per boundary region and later hands the whole
 * list to a single ParallelFor call, whose kernel reads \c face and \c fab
 * from the tag.
 *
 * The members must stay in this order: tags are aggregate-initialised
 * positionally as MLCurlCurlBCTag{array, box, face}.
 */
struct MLCurlCurlBCTag
{
    Array4<Real> fab;   //!< Array view the boundary kernel operates on.
    Box          bx;    //!< Index region associated with this tag.
    Orientation  face;  //!< Domain face this tag belongs to.

    //! Region covered by this tag (presumably the accessor the tag-based
    //! ParallelFor uses to find each tag's iteration space -- confirm
    //! against AMReX_TagParallelFor.H).
    [[nodiscard]] AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
    const Box& box () const noexcept
    {
        return this->bx;
    }
};
#endif

void MLCurlCurl::applyPhysBC (int amrlev, int mglev, MultiFab& mf) const
{
auto const idxtype = mf.ixType();
Box const domain = amrex::convert(this->m_geom[amrlev][mglev].Domain(), idxtype);
for (MFIter mfi(mf); mfi.isValid(); ++mfi) {

MFItInfo mfi_info{};

#ifdef AMREX_USE_GPU
Vector<MLCurlCurlBCTag> tags;
mfi_info.DisableDeviceSync();
#endif

#ifdef AMREX_USE_OMP
#pragma omp parallel if (Gpu::notInLaunchRegion());
#endif
for (MFIter mfi(mf,mfi_info); mfi.isValid(); ++mfi) {
auto const& vbx = mfi.validbox();
auto const& a = mf.array(mfi);
for (OrientationIter oit; oit; ++oit) {
Expand All @@ -468,7 +490,7 @@ void MLCurlCurl::applyPhysBC (int amrlev, int mglev, MultiFab& mf) const
int shift = face.isLow() ? -1 : 1;
b.setRange(idim, domain[face] + shift, 1);
#ifdef AMREX_USE_GPU
static_assert(false, "MLCurlCurl: todo");
tags.emplace_back(MLCurlCurlBCTag{a,b,face});
#else
amrex::LoopOnCpu(b, [&] (int i, int j, int k)
{
Expand All @@ -478,6 +500,14 @@ void MLCurlCurl::applyPhysBC (int amrlev, int mglev, MultiFab& mf) const
}
}
}

#ifdef AMREX_USE_GPU
ParallelFor(tags,
[=] AMREX_GPU_DEVICE (int i, int j, int k, MLCurlCurlBCTag const& tag) noexcept
{
mlcurlcurl_bc_symmetry(i, j, k, tag.face, idxtype, tag.fab);
});
#endif
}

iMultiFab const& MLCurlCurl::getDotMask (int amrlev, int mglev, int idim) const
Expand Down

0 comments on commit 4c02d45

Please sign in to comment.