Skip to content

Commit

Permalink
tpetra: adding back and deprecating getLocalRowView and
Browse files Browse the repository at this point in the history
getGlobalRowView in Tpetra_BlockMultiVector and Tpetra_BlockVector
These functions were inadvertantly removed without deprecation during
UVM removal.
The functions are, however, deprecated.  Users should use
getLocalBlockHost with access flags instead of these functions
  • Loading branch information
kddevin committed May 3, 2021
1 parent c00ff3b commit a2127a9
Show file tree
Hide file tree
Showing 8 changed files with 179 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ localApplyBlockNoTrans (Tpetra::BlockCrsMatrix<Scalar, LO, GO, Node>& A,

for (LO j = 0; j < numVecs; ++j) {
for (LO lclRow = 0; lclRow < numLocalMeshRows; ++lclRow) {
auto Y_cur = Y.getLocalBlock (lclRow, j, Tpetra::Access::ReadWrite);
auto Y_cur = Y.getLocalBlockHost (lclRow, j, Tpetra::Access::ReadWrite);
if (beta == zero) {
FILL (Y_lcl, zero);
} else if (beta == one) {
Expand All @@ -132,7 +132,7 @@ localApplyBlockNoTrans (Tpetra::BlockCrsMatrix<Scalar, LO, GO, Node>& A,

auto A_cur_1d = Kokkos::subview (val, absBlkOff * blockSize * blockSize);
little_blk_type A_cur (A_cur_1d.data (), blockSize, blockSize);
auto X_cur = X.getLocalBlock (meshCol, j, Tpetra::Access::ReadOnly);
auto X_cur = X.getLocalBlockHost (meshCol, j, Tpetra::Access::ReadOnly);

GEMV (alpha, A_cur, X_cur, Y_lcl); // Y_lcl += alpha*A_cur*X_cur
} // for each entry in the current local row of the matrix
Expand Down
52 changes: 38 additions & 14 deletions packages/tpetra/core/src/Tpetra_BlockMultiVector_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -577,39 +577,63 @@ class BlockMultiVector :
/// is invalid on the calling process.
bool sumIntoGlobalValues (const GO globalRowIndex, const LO colIndex, const Scalar vals[]);

#ifdef TPETRA_ENABLE_DEPRECATED_CODE
/// \brief Get a writeable view of the entries at the given mesh
/// point, using a local index.
///
/// \param localRowIndex [in] Local index of the mesh point.
/// \param colIndex [in] Column (vector) to view.
/// \param vals [out] View of the entries at the given mesh point.
///
/// \return true if successful, else false. This method will
/// <i>not</i> succeed if the given local index of the mesh point
/// is invalid on the calling process.
// TPETRA_DEPRECATED
bool getLocalRowView (const LO localRowIndex, const LO colIndex, Scalar*& vals);

/// \brief Get a writeable view of the entries at the given mesh
/// point, using a global index.
///
/// \param globalRowIndex [in] Global index of the mesh point.
/// \param colIndex [in] Column (vector) to view.
/// \param vals [out] View of the entries at the given mesh point.
///
/// \return true if successful, else false. This method will
/// <i>not</i> succeed if the given global index of the mesh point
/// is invalid on the calling process.
// TPETRA_DEPRECATED
bool getGlobalRowView (const GO globalRowIndex, const LO colIndex, Scalar*& vals);

/// \brief Get a host view of the degrees of freedom at the given
/// mesh point.
///
/// \warning This method's interface may change or disappear at any
/// time. Please do not rely on it in your code yet.
///
/// Prefer using \c auto to let the compiler compute the return
/// type. This gives us the freedom to change this type in the
/// future. If you insist not to use \c auto, then please use the
/// \c little_vec_type typedef to deduce the correct return type;
/// don't try to hard-code the return type yourself.
#ifdef TPETRA_ENABLE_DEPRECATED_CODE
//TPETRA_DEPRECATED
little_host_vec_type getLocalBlock (const LO localRowIndex, const LO colIndex) const;
#endif //TPETRA_DEPRECATED
little_host_vec_type getLocalBlock (const LO localRowIndex, const LO colIndex);

#endif // TPETRA_ENABLE_DEPRECATED_CODE

const_little_host_vec_type getLocalBlock(
const_little_host_vec_type getLocalBlockHost(
const LO localRowIndex,
const LO colIndex,
Access::ReadOnlyStruct) const;
const Access::ReadOnlyStruct) const;

little_host_vec_type getLocalBlock(
little_host_vec_type getLocalBlockHost(
const LO localRowIndex,
const LO colIndex,
Access::ReadWriteStruct);
const Access::ReadWriteStruct);

/// \brief Get a local block on host, with the intent to overwrite all blocks in the BlockMultiVector
/// before accessing the data on device. If you only intend to modify some blocks on host, use ReadWrite
/// instead (otherwise, previous changes on device may be lost)
little_host_vec_type getLocalBlock(
/// before accessing the data on device. If you intend to modify only some blocks on host, use
/// Access::ReadWrite instead (otherwise, previous changes on device may be lost)
little_host_vec_type getLocalBlockHost(
const LO localRowIndex,
const LO colIndex,
Access::OverwriteAllStruct);
const Access::OverwriteAllStruct);
//@}

protected:
Expand Down
60 changes: 46 additions & 14 deletions packages/tpetra/core/src/Tpetra_BlockMultiVector_def.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ replaceLocalValuesImpl (const LO localRowIndex,
const LO colIndex,
const Scalar vals[])
{
auto X_dst = getLocalBlock (localRowIndex, colIndex, Access::ReadWrite);
auto X_dst = getLocalBlockHost (localRowIndex, colIndex, Access::ReadWrite);
typename const_little_vec_type::HostMirror::const_type X_src (reinterpret_cast<const impl_scalar_type*> (vals),
getBlockSize ());
Kokkos::deep_copy (X_dst, X_src);
Expand Down Expand Up @@ -361,7 +361,7 @@ sumIntoLocalValuesImpl (const LO localRowIndex,
const LO colIndex,
const Scalar vals[])
{
auto X_dst = getLocalBlock (localRowIndex, colIndex, Access::ReadWrite);
auto X_dst = getLocalBlockHost (localRowIndex, colIndex, Access::ReadWrite);
typename const_little_vec_type::HostMirror::const_type X_src (reinterpret_cast<const impl_scalar_type*> (vals),
getBlockSize ());
AXPY (static_cast<impl_scalar_type> (STS::one ()), X_src, X_dst);
Expand Down Expand Up @@ -400,12 +400,43 @@ sumIntoGlobalValues (const GO globalRowIndex,

#ifdef TPETRA_ENABLE_DEPRECATED_CODE

template<class Scalar, class LO, class GO, class Node>
bool
// TPETRA_DEPRECATED
BlockMultiVector<Scalar, LO, GO, Node>::
getLocalRowView (const LO localRowIndex, const LO colIndex, Scalar*& vals)
{
if (! meshMap_.isNodeLocalElement (localRowIndex)) {
return false;
} else {
auto X_ij = getLocalBlockHost (localRowIndex, colIndex, Access::ReadWrite);
vals = reinterpret_cast<Scalar*> (X_ij.data ());
return true;
}
}

template<class Scalar, class LO, class GO, class Node>
bool
// TPETRA_DEPRECATED
BlockMultiVector<Scalar, LO, GO, Node>::
getGlobalRowView (const GO globalRowIndex, const LO colIndex, Scalar*& vals)
{
const LO localRowIndex = meshMap_.getLocalElement (globalRowIndex);
if (localRowIndex == Teuchos::OrdinalTraits<LO>::invalid ()) {
return false;
} else {
auto X_ij = getLocalBlockHost (localRowIndex, colIndex, Access::ReadWrite);
vals = reinterpret_cast<Scalar*> (X_ij.data ());
return true;
}
}

template<class Scalar, class LO, class GO, class Node>
typename BlockMultiVector<Scalar, LO, GO, Node>::little_host_vec_type
TPETRA_DEPRECATED
// TPETRA_DEPRECATED
BlockMultiVector<Scalar, LO, GO, Node>::
getLocalBlock (const LO localRowIndex,
const LO colIndex) const
const LO colIndex)
{
if (! isValidLocalMeshIndex (localRowIndex)) {
return little_host_vec_type ();
Expand All @@ -417,14 +448,15 @@ getLocalBlock (const LO localRowIndex,
return little_host_vec_type (blockRaw, blockSize);
}
}
#endif

#endif // TPETRA_ENABLE_DEPRECATED_CODE

template<class Scalar, class LO, class GO, class Node>
typename BlockMultiVector<Scalar, LO, GO, Node>::const_little_host_vec_type
BlockMultiVector<Scalar, LO, GO, Node>::
getLocalBlock (const LO localRowIndex,
const LO colIndex,
Access::ReadOnlyStruct) const
getLocalBlockHost (const LO localRowIndex,
const LO colIndex,
const Access::ReadOnlyStruct) const
{
if (!isValidLocalMeshIndex(localRowIndex)) {
return const_little_host_vec_type();
Expand All @@ -441,9 +473,9 @@ getLocalBlock (const LO localRowIndex,
template<class Scalar, class LO, class GO, class Node>
typename BlockMultiVector<Scalar, LO, GO, Node>::little_host_vec_type
BlockMultiVector<Scalar, LO, GO, Node>::
getLocalBlock (const LO localRowIndex,
const LO colIndex,
Access::OverwriteAllStruct)
getLocalBlockHost (const LO localRowIndex,
const LO colIndex,
const Access::OverwriteAllStruct)
{
if (!isValidLocalMeshIndex(localRowIndex)) {
return little_host_vec_type();
Expand All @@ -460,9 +492,9 @@ getLocalBlock (const LO localRowIndex,
template<class Scalar, class LO, class GO, class Node>
typename BlockMultiVector<Scalar, LO, GO, Node>::little_host_vec_type
BlockMultiVector<Scalar, LO, GO, Node>::
getLocalBlock (const LO localRowIndex,
const LO colIndex,
Access::ReadWriteStruct)
getLocalBlockHost (const LO localRowIndex,
const LO colIndex,
const Access::ReadWriteStruct)
{
if (!isValidLocalMeshIndex(localRowIndex)) {
return little_host_vec_type();
Expand Down
44 changes: 34 additions & 10 deletions packages/tpetra/core/src/Tpetra_BlockVector_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,12 +315,36 @@ class BlockVector : public BlockMultiVector<Scalar, LO, GO, Node> {
/// is invalid on the calling process.
bool sumIntoGlobalValues (const GO globalRowIndex, const Scalar vals[]);

#ifdef TPETRA_ENABLE_DEPRECATED_CODE
/// \brief Get a writeable view of the entries at the given mesh
/// point, using a local index.
///
/// \param localRowIndex [in] Local index of the mesh point.
/// \param vals [in] Input values with which to replace whatever
/// existing values are at the mesh point.
///
/// \return true if successful, else false. This method will
/// <i>not</i> succeed if the given local index of the mesh point
/// is invalid on the calling process.
bool getLocalRowView (const LO localRowIndex, Scalar*& vals);

/// \brief Get a writeable view of the entries at the given mesh
/// point, using a global index.
///
/// \param globalRowIndex [in] Global index of the mesh point.
/// \param vals [in] Input values with which to replace whatever
/// existing values are at the mesh point.
///
/// \return true if successful, else false. This method will
/// <i>not</i> succeed if the given global index of the mesh point
/// is invalid on the calling process.
bool getGlobalRowView (const GO globalRowIndex, Scalar*& vals);

#endif //TPETRA_ENABLE_DEPRECATED_CODE

/// \brief Get a view of the degrees of freedom at the given mesh point,
/// using a local index.
///
/// \warning This method's interface may change or disappear at any
/// time. Please do not rely on it in your code yet.
///
/// The preferred way to refer to little_vec_type is to get it from
/// BlockVector's typedef. This is because different
/// specializations of BlockVector reserve the right to use
Expand All @@ -329,14 +353,14 @@ class BlockVector : public BlockMultiVector<Scalar, LO, GO, Node> {
/// refactor version.
#ifdef TPETRA_ENABLE_DEPRECATED_CODE
//TPETRA_DEPRECATED
little_host_vec_type getLocalBlock (const LO localRowIndex) const;
little_host_vec_type getLocalBlock (const LO localRowIndex);
#endif
const_little_host_vec_type getLocalBlock (const LO localRowIndex,
Access::ReadOnlyStruct) const;
little_host_vec_type getLocalBlock (const LO localRowIndex,
Access::OverwriteAllStruct);
little_host_vec_type getLocalBlock (const LO localRowIndex,
Access::ReadWriteStruct);
const_little_host_vec_type getLocalBlockHost (const LO localRowIndex,
Access::ReadOnlyStruct) const;
little_host_vec_type getLocalBlockHost (const LO localRowIndex,
Access::OverwriteAllStruct);
little_host_vec_type getLocalBlockHost (const LO localRowIndex,
Access::ReadWriteStruct);
//@}
};

Expand Down
39 changes: 27 additions & 12 deletions packages/tpetra/core/src/Tpetra_BlockVector_def.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,27 @@ namespace Tpetra {
}

#ifdef TPETRA_ENABLE_DEPRECATED_CODE
template<class Scalar, class LO, class GO, class Node>
bool
// TPETRA_DEPRECATED
BlockVector<Scalar, LO, GO, Node>::
getLocalRowView (const LO localRowIndex, Scalar*& vals) {
return ((base_type*) this)->getLocalRowView (localRowIndex, 0, vals);
}

template<class Scalar, class LO, class GO, class Node>
bool
// TPETRA_DEPRECATED
BlockVector<Scalar, LO, GO, Node>::
getGlobalRowView (const GO globalRowIndex, Scalar*& vals) {
return ((base_type*) this)->getGlobalRowView (globalRowIndex, 0, vals);
}

template<class Scalar, class LO, class GO, class Node>
typename BlockVector<Scalar, LO, GO, Node>::little_host_vec_type
TPETRA_DEPRECATED
// TPETRA_DEPRECATED
BlockVector<Scalar, LO, GO, Node>::
getLocalBlock (const LO localRowIndex) const
getLocalBlock (const LO localRowIndex)
{
if (! this->isValidLocalMeshIndex (localRowIndex)) {
return little_host_vec_type ();
Expand All @@ -165,31 +181,30 @@ namespace Tpetra {
template<class Scalar, class LO, class GO, class Node>
typename BlockVector<Scalar, LO, GO, Node>::const_little_host_vec_type
BlockVector<Scalar, LO, GO, Node>::
getLocalBlock (const LO localRowIndex, Access::ReadOnlyStruct) const
getLocalBlockHost (const LO localRowIndex, Access::ReadOnlyStruct) const
{
return ((const base_type*) this)->getLocalBlock(localRowIndex, 0,
Access::ReadOnly);
return ((const base_type*) this)->getLocalBlockHost(localRowIndex, 0,
Access::ReadOnly);
}

template<class Scalar, class LO, class GO, class Node>
typename BlockVector<Scalar, LO, GO, Node>::little_host_vec_type
BlockVector<Scalar, LO, GO, Node>::
getLocalBlock (const LO localRowIndex, Access::ReadWriteStruct)
getLocalBlockHost (const LO localRowIndex, Access::ReadWriteStruct)
{
return ((base_type*) this)->getLocalBlock(localRowIndex, 0,
Access::ReadWrite);
return ((base_type*) this)->getLocalBlockHost(localRowIndex, 0,
Access::ReadWrite);
}

template<class Scalar, class LO, class GO, class Node>
typename BlockVector<Scalar, LO, GO, Node>::little_host_vec_type
BlockVector<Scalar, LO, GO, Node>::
getLocalBlock (const LO localRowIndex, Access::OverwriteAllStruct)
getLocalBlockHost (const LO localRowIndex, Access::OverwriteAllStruct)
{
return ((base_type*) this)->getLocalBlock(localRowIndex, 0,
Access::OverwriteAll);
return ((base_type*) this)->getLocalBlockHost(localRowIndex, 0,
Access::OverwriteAll);
}


} // namespace Tpetra

//
Expand Down
Loading

0 comments on commit a2127a9

Please sign in to comment.