Skip to content

Commit

Permalink
Tpetra: TAFC Converted to use Kokkos (Prototype 2)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ian Halim committed Aug 14, 2024
1 parent 0fc88a3 commit 400e9d3
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 120 deletions.
63 changes: 30 additions & 33 deletions packages/tpetra/core/src/Tpetra_CrsMatrix_def.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8302,59 +8302,45 @@ CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::
<< std::endl;
std::cerr << os.str ();
}
// Make sure that host has the latest version, since we're
// using the version on host. If host has the latest
// version, syncing to host does nothing.
destMat->numExportPacketsPerLID_.sync_host ();
Teuchos::ArrayView<const size_t> numExportPacketsPerLID =
getArrayViewFromDualView (destMat->numExportPacketsPerLID_);
destMat->numImportPacketsPerLID_.sync_host ();
Teuchos::ArrayView<size_t> numImportPacketsPerLID =
getArrayViewFromDualView (destMat->numImportPacketsPerLID_);

destMat->numExportPacketsPerLID_.sync_device(); //KEEP
auto numExportPacketsPerLID = destMat->numExportPacketsPerLID_.view_device();
auto numImportPacketsPerLID = destMat->numImportPacketsPerLID_.view_device();
if (verbose) {
std::ostringstream os;
os << *verbosePrefix << "Calling 3-arg doReversePostsAndWaits"
<< std::endl;
std::cerr << os.str ();
}
Distor.doReversePostsAndWaits(destMat->numExportPacketsPerLID_.view_host(), 1,
destMat->numImportPacketsPerLID_.view_host());
Distor.doReversePostsAndWaits(numExportPacketsPerLID, 1, numImportPacketsPerLID);
if (verbose) {
std::ostringstream os;
os << *verbosePrefix << "Finished 3-arg doReversePostsAndWaits"
<< std::endl;
std::cerr << os.str ();
}

size_t totalImportPackets = 0;
for (Array_size_type i = 0; i < numImportPacketsPerLID.size (); ++i) {
totalImportPackets += numImportPacketsPerLID[i];
}
size_t totalImportPackets = Kokkos::Experimental::reduce(typename Node::execution_space(), numImportPacketsPerLID);

// Reallocation MUST go before setting the modified flag,
// because it may clear out the flags.
destMat->reallocImportsIfNeeded (totalImportPackets, verbose,
verbosePrefix.get ());
destMat->imports_.modify_host ();
auto hostImports = destMat->imports_.view_host();
// This is a legacy host pack/unpack path, so use the host
// version of exports_.
destMat->exports_.sync_host ();
auto hostExports = destMat->exports_.view_host();
destMat->imports_.modify_host (); //KEEP?
auto deviceImports = destMat->imports_.view_device();
auto deviceExports = destMat->exports_.view_device();
if (verbose) {
std::ostringstream os;
os << *verbosePrefix << "Calling 4-arg doReversePostsAndWaits"
os << *verbosePrefix << "Calling 4-arg doReversePostsAndWaitsKokkos"
<< std::endl;
std::cerr << os.str ();
}
Distor.doReversePostsAndWaits (hostExports,
numExportPacketsPerLID,
hostImports,
numImportPacketsPerLID);
destMat->imports_.sync_device(); //KEEP
destMat->numImportPacketsPerLID_.modify_device(); //KEEP
Distor.doReversePostsAndWaitsKokkos (deviceExports, numExportPacketsPerLID, deviceImports, numImportPacketsPerLID);
destMat->numImportPacketsPerLID_.sync_host(); //KEEP
if (verbose) {
std::ostringstream os;
os << *verbosePrefix << "Finished 4-arg doReversePostsAndWaits"
os << *verbosePrefix << "Finished 4-arg doReversePostsAndWaitsKokkos"
<< std::endl;
std::cerr << os.str ();
}
Expand Down Expand Up @@ -8397,6 +8383,7 @@ CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::
<< std::endl;
std::cerr << os.str ();
}
destMat->numExportPacketsPerLID_.sync_device (); //KEEP
auto numExportPacketsPerLID = destMat->numExportPacketsPerLID_.view_device();
auto numImportPacketsPerLID = destMat->numImportPacketsPerLID_.view_device();
if (verbose) {
Expand All @@ -8405,7 +8392,6 @@ CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::
<< std::endl;
std::cerr << os.str ();
}
destMat->numExportPacketsPerLID_.sync_device (); //remove later when above section is converted to device
Distor.doPostsAndWaits(numExportPacketsPerLID, 1, numImportPacketsPerLID);
if (verbose) {
std::ostringstream os;
Expand All @@ -8416,8 +8402,11 @@ CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::

size_t totalImportPackets = Kokkos::Experimental::reduce(typename Node::execution_space(), numImportPacketsPerLID);

// Reallocation MUST go before setting the modified flag,
// because it may clear out the flags.
destMat->reallocImportsIfNeeded (totalImportPackets, verbose,
verbosePrefix.get ());
destMat->imports_.modify_host (); //KEEP?
auto deviceImports = destMat->imports_.view_device();
auto deviceExports = destMat->exports_.view_device();
if (verbose) {
Expand All @@ -8426,10 +8415,10 @@ CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::
<< std::endl;
std::cerr << os.str ();
}
Distor.doPostsAndWaitsKokkos (deviceExports,
numExportPacketsPerLID,
deviceImports,
numImportPacketsPerLID);
destMat->imports_.sync_device (); //KEEP
destMat->numImportPacketsPerLID_.modify_device (); //KEEP
Distor.doPostsAndWaitsKokkos (deviceExports, numExportPacketsPerLID, deviceImports, numImportPacketsPerLID);
destMat->numImportPacketsPerLID_.sync_host (); //KEEP
if (verbose) {
std::ostringstream os;
os << *verbosePrefix << "Finished 4-arg doPostsAndWaitsKokkos"
Expand Down Expand Up @@ -8679,6 +8668,11 @@ CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::

// Backwards compatibility measure. We'll use this again below.

// TODO JHU Need to track down why numImportPacketsPerLID_ has not been correctly marked as modified on host (which it has been)
// TODO JHU somewhere above, e.g., call to Distor.doPostsAndWaits().
// TODO JHU This only becomes apparent as we begin to convert TAFC to run on device.
destMat->numImportPacketsPerLID_.modify_host(); //FIXME

# ifdef HAVE_TPETRA_MMM_TIMINGS
RCP<TimeMonitor> tmCopySPRdata = rcp(new TimeMonitor(*TimeMonitor::getNewTimer(prefix + std::string("TAFC unpack-count-resize + copy same-perm-remote data"))));
# endif
Expand All @@ -8687,6 +8681,9 @@ CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::
ArrayRCP<LO> CSR_colind_LID;
ArrayRCP<Scalar> CSR_vals;

destMat->imports_.sync_device ();
destMat->numImportPacketsPerLID_.sync_device ();

size_t N = BaseRowMap->getLocalNumElements ();

auto RemoteLIDs_d = RemoteLIDs.view_device();
Expand Down
Loading

0 comments on commit 400e9d3

Please sign in to comment.