Skip to content

Commit

Permalink
update FECrsMatrix to add begin/endAssembly and begin/endModify
Browse files Browse the repository at this point in the history
  • Loading branch information
tjfulle committed Mar 24, 2021
1 parent 4a18d9b commit 7df3306
Show file tree
Hide file tree
Showing 4 changed files with 377 additions and 69 deletions.
24 changes: 14 additions & 10 deletions packages/tpetra/core/src/Tpetra_CrsMatrix_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -842,13 +842,14 @@ namespace Tpetra {
MultiVector<S2,LO2,GO2,N2> & R);

// This friend declaration allows for batching of apply calls
template <class MatrixArray, class MultiVectorArray>
friend void batchedApply(const MatrixArray &Matrices,
template <class MatrixArray, class MultiVectorArray>
friend void batchedApply(const MatrixArray &Matrices,
const typename std::remove_pointer<typename MultiVectorArray::value_type>::type & X,
MultiVectorArray &Y,
typename std::remove_pointer<typename MatrixArray::value_type>::type::scalar_type alpha,
typename std::remove_pointer<typename MatrixArray::value_type>::type::scalar_type beta,
Teuchos::RCP<Teuchos::ParameterList> params);

public:
//@}
//! @name Methods for inserting, modifying, or removing entries
Expand Down Expand Up @@ -1013,7 +1014,7 @@ namespace Tpetra {
const Scalar vals[],
const LocalOrdinal cols[]);

private:
protected:
/// \brief Implementation detail of replaceGlobalValues.
///
/// \param rowVals [in/out] On input: Values of the row of the
Expand All @@ -1024,7 +1025,7 @@ namespace Tpetra {
/// \param inds [in] Global column indices of that row to modify.
/// \param newVals [in] For each k, replace the value in rowVals
/// corresponding to local column index inds[k] with newVals[k].
LocalOrdinal
virtual LocalOrdinal
replaceGlobalValuesImpl (impl_scalar_type rowVals[],
const crs_graph_type& graph,
const RowInfo& rowInfo,
Expand Down Expand Up @@ -1102,7 +1103,7 @@ namespace Tpetra {
const Scalar vals[],
const GlobalOrdinal cols[]) const;

private:
protected:
/// \brief Implementation detail of replaceLocalValues.
///
/// \param rowVals [in/out] On input: Values of the row of the
Expand All @@ -1113,7 +1114,7 @@ namespace Tpetra {
/// \param inds [in] Local column indices of that row to modify.
/// \param newVals [in] For each k, replace the value in rowVals
/// corresponding to local column index inds[k] with newVals[k].
LocalOrdinal
virtual LocalOrdinal
replaceLocalValuesImpl (impl_scalar_type rowVals[],
const crs_graph_type& graph,
const RowInfo& rowInfo,
Expand Down Expand Up @@ -1229,7 +1230,8 @@ namespace Tpetra {
/// error other than one or more invalid column indices, this
/// method returns
/// Teuchos::OrdinalTraits<LocalOrdinal>::invalid().
LocalOrdinal
protected:
virtual LocalOrdinal
sumIntoGlobalValuesImpl (impl_scalar_type rowVals[],
const crs_graph_type& graph,
const RowInfo& rowInfo,
Expand Down Expand Up @@ -1310,7 +1312,7 @@ namespace Tpetra {
const GlobalOrdinal cols[],
const bool atomic = useAtomicUpdatesByDefault);

private:
protected:
/// \brief Implementation detail of sumIntoLocalValues.
///
/// \param rowVals [in/out] On input: Values of the row of the
Expand All @@ -1323,7 +1325,7 @@ namespace Tpetra {
/// corresponding to local column index inds[k] by newVals[k].
/// \param atomic [in] Whether to use atomic updates (+=) when
/// incrementing values.
LocalOrdinal
virtual LocalOrdinal
sumIntoLocalValuesImpl (impl_scalar_type rowVals[],
const crs_graph_type& graph,
const RowInfo& rowInfo,
Expand Down Expand Up @@ -3469,13 +3471,15 @@ namespace Tpetra {
/// are in the column Map on the calling process. That is, the
/// entries of gblColInds (and their corresponding vals entries)
/// are "prefiltered," if we needed to filter them.
void
protected:
virtual void
insertGlobalValuesImpl (crs_graph_type& graph,
RowInfo& rowInfo,
const GlobalOrdinal gblColInds[],
const impl_scalar_type vals[],
const size_t numInputEnt);

private:
/// \brief Like insertGlobalValues(), but with column filtering.
///
/// "Column filtering" means that if the matrix has a column Map,
Expand Down
75 changes: 75 additions & 0 deletions packages/tpetra/core/src/Tpetra_FECrsMatrix_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,73 @@ class FECrsMatrix :
//! Activates the owned+shared mode for assembly
void beginFill();

//! Migrates data to the owned mode
void endAssembly();

//! Activates the owned+shared mode for assembly
void beginAssembly();

//! Closes modification phase
void endModify();

//! Activates the owned mode for modifying local values
void beginModify();

private:

/// \brief Whether sumIntoLocalValues and sumIntoGlobalValues
/// should use atomic updates by default.
///
/// \warning This is an implementation detail.
static const bool useAtomicUpdatesByDefault =
#ifdef KOKKOS_ENABLE_SERIAL
! std::is_same<execution_space, Kokkos::Serial>::value;
#else
true;
#endif // KOKKOS_ENABLE_SERIAL

//! Overloads of modification methods
LocalOrdinal
replaceGlobalValuesImpl (impl_scalar_type rowVals[],
const crs_graph_type& graph,
const RowInfo& rowInfo,
const GlobalOrdinal inds[],
const impl_scalar_type newVals[],
const LocalOrdinal numElts) const;

LocalOrdinal
replaceLocalValuesImpl (impl_scalar_type rowVals[],
const crs_graph_type& graph,
const RowInfo& rowInfo,
const LocalOrdinal inds[],
const impl_scalar_type newVals[],
const LocalOrdinal numElts) const;

LocalOrdinal
sumIntoGlobalValuesImpl (impl_scalar_type rowVals[],
const crs_graph_type& graph,
const RowInfo& rowInfo,
const GlobalOrdinal inds[],
const impl_scalar_type newVals[],
const LocalOrdinal numElts,
const bool atomic = useAtomicUpdatesByDefault) const;

LocalOrdinal
sumIntoLocalValuesImpl (impl_scalar_type rowVals[],
const crs_graph_type& graph,
const RowInfo& rowInfo,
const LocalOrdinal inds[],
const impl_scalar_type newVals[],
const LocalOrdinal numElts,
const bool atomic = useAtomicUpdatesByDefault) const;

void
insertGlobalValuesImpl (crs_graph_type& graph,
RowInfo& rowInfo,
const GlobalOrdinal gblColInds[],
const impl_scalar_type vals[],
const size_t numInputEnt);

protected:
/// \brief Migrate data from the owned+shared to the owned matrix
/// Since this is non-unique -> unique, we need a combine mode.
Expand Down Expand Up @@ -254,6 +321,14 @@ class FECrsMatrix :
// This is in RCP to make shallow copies of the FECrsMatrix work correctly
Teuchos::RCP<FEWhichActive> activeCrsMatrix_;

enum class FillState
{
open, // matrix is "open". Values can freely summed in to and replaced
modify, // matrix is open for modification. *local* values can be replaced
closed
};
Teuchos::RCP<FillState> fillState_;

}; // end class FECrsMatrix


Expand Down
159 changes: 156 additions & 3 deletions packages/tpetra/core/src/Tpetra_FECrsMatrix_def.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ namespace Tpetra {
template<class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
FECrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::
FECrsMatrix(const Teuchos::RCP<const fe_crs_graph_type>& graph,
const Teuchos::RCP<Teuchos::ParameterList>& params) :
const Teuchos::RCP<Teuchos::ParameterList>& params) :
// We want the OWNED_PLUS_SHARED graph here
// NOTE: The casts below are terrible, but necesssary
crs_matrix_type( graph->inactiveCrsGraph_.is_null() ? Teuchos::rcp_const_cast<crs_graph_type>(Teuchos::rcp_dynamic_cast<const crs_graph_type>(graph)) : graph->inactiveCrsGraph_,params),
Expand All @@ -76,13 +76,15 @@ FECrsMatrix(const Teuchos::RCP<const fe_crs_graph_type>& graph,

// Make an "inactive" matrix, if we need to
if(!graph->inactiveCrsGraph_.is_null() ) {
// We are *requiring* memory aliasing here, so we'll grab the first chunk of the Owned+Shared matrix's values array to make the
// We are *requiring* memory aliasing here, so we'll grab the first chunk of the Owned+Shared matrix's values array to make the
// guy for the Owned matrix.
values_type myvals = this->getLocalMatrix().values;

size_t numOwnedVals = graph->getLocalGraph().entries.extent(0); // OwnedVals
inactiveCrsMatrix_ = Teuchos::rcp(new crs_matrix_type(graph,Kokkos::subview(myvals,Kokkos::pair<size_t,size_t>(0,numOwnedVals))));
}

fillState_ = Teuchos::rcp(new FillState(FillState::closed));
}


Expand Down Expand Up @@ -111,7 +113,7 @@ void FECrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::switchActiveCrsMatr
*activeCrsMatrix_ = FE_ACTIVE_OWNED_PLUS_SHARED;

if(inactiveCrsMatrix_.is_null()) return;

this->swap(*inactiveCrsMatrix_);

}//end switchActiveCrsMatrix
Expand All @@ -138,8 +140,159 @@ void FECrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::beginFill() {
this->resumeFill();
}

template<class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
void FECrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::beginAssembly() {
const char tfecfFuncName[] = "FECrsMatrix::beginAssembly: ";
TEUCHOS_TEST_FOR_EXCEPTION_CLASS_FUNC(
*fillState_ != FillState::closed,
std::runtime_error,
"Cannot beginAssembly, matrix is not in a closed state"
);
*fillState_ = FillState::open;
this->beginFill();
}

template<class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
void FECrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::endAssembly() {
const char tfecfFuncName[] = "FECrsMatrix::endAssembly: ";
TEUCHOS_TEST_FOR_EXCEPTION_CLASS_FUNC(
*fillState_ != FillState::open,
std::runtime_error,
"Cannot endAssembly, matrix is not open to fill."
);
*fillState_ = FillState::closed;
this->endFill();
}

template<class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
void FECrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::beginModify() {
const char tfecfFuncName[] = "FECrsMatrix::beginModify: ";
TEUCHOS_TEST_FOR_EXCEPTION_CLASS_FUNC(
*fillState_ != FillState::closed,
std::runtime_error,
"Cannot beginModify, matrix is not in a closed state"
);
*fillState_ = FillState::modify;
this->resumeFill();
}

template<class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
void FECrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::endModify() {
const char tfecfFuncName[] = "FECrsMatrix::endModify: ";
TEUCHOS_TEST_FOR_EXCEPTION_CLASS_FUNC(
*fillState_ != FillState::modify,
std::runtime_error,
"Cannot endModify, matrix is not open to modify."
);
*fillState_ = FillState::closed;
this->fillComplete();
}

template<class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
LocalOrdinal
FECrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::replaceGlobalValuesImpl(
impl_scalar_type rowVals[],
const crs_graph_type& graph,
const RowInfo& rowInfo,
const GlobalOrdinal inds[],
const impl_scalar_type newVals[],
const LocalOrdinal numElts) const
{
const char tfecfFuncName[] = "FECrsMatrix::replaceGlobalValues: ";
TEUCHOS_TEST_FOR_EXCEPTION_CLASS_FUNC(
*fillState_ != FillState::open,
std::runtime_error,
"Cannot replace global values, matrix is not open to fill."
);
return CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::replaceGlobalValuesImpl(
rowVals, graph, rowInfo, inds, newVals, numElts
);
}

template<class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
LocalOrdinal
FECrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::replaceLocalValuesImpl(
impl_scalar_type rowVals[],
const crs_graph_type& graph,
const RowInfo& rowInfo,
const LocalOrdinal inds[],
const impl_scalar_type newVals[],
const LocalOrdinal numElts) const
{
const char tfecfFuncName[] = "FECrsMatrix::replaceLocalValues: ";
TEUCHOS_TEST_FOR_EXCEPTION_CLASS_FUNC(
*fillState_ != FillState::open && *fillState_ != FillState::modify,
std::runtime_error,
"Cannot replace local values, matrix is not open to fill/modify."
);
return CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::replaceLocalValuesImpl(
rowVals, graph, rowInfo, inds, newVals, numElts
);
}

template<class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
LocalOrdinal
FECrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::sumIntoGlobalValuesImpl(
impl_scalar_type rowVals[],
const crs_graph_type& graph,
const RowInfo& rowInfo,
const GlobalOrdinal inds[],
const impl_scalar_type newVals[],
const LocalOrdinal numElts,
const bool atomic) const
{
const char tfecfFuncName[] = "FECrsMatrix::sumIntoGlobalValues: ";
TEUCHOS_TEST_FOR_EXCEPTION_CLASS_FUNC(
*fillState_ != FillState::open,
std::runtime_error,
"Cannot sum in to global values, matrix is not open to fill."
);
return CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::sumIntoGlobalValuesImpl(
rowVals, graph, rowInfo, inds, newVals, numElts, atomic
);
}

template<class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
LocalOrdinal
FECrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::sumIntoLocalValuesImpl(
impl_scalar_type rowVals[],
const crs_graph_type& graph,
const RowInfo& rowInfo,
const LocalOrdinal inds[],
const impl_scalar_type newVals[],
const LocalOrdinal numElts,
const bool atomic) const
{
const char tfecfFuncName[] = "FECrsMatrix::sumIntoLocalValues: ";
TEUCHOS_TEST_FOR_EXCEPTION_CLASS_FUNC(
*fillState_ != FillState::open,
std::runtime_error,
"Cannot sum in to local values, matrix is not open to fill."
);
return CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::sumIntoLocalValuesImpl(
rowVals, graph, rowInfo, inds, newVals, numElts, atomic
);
}

template<class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
void
FECrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::insertGlobalValuesImpl(
crs_graph_type& graph,
RowInfo& rowInfo,
const GlobalOrdinal gblColInds[],
const impl_scalar_type vals[],
const size_t numInputEnt)
{
const char tfecfFuncName[] = "FECrsMatrix::insertGlobalValues: ";
TEUCHOS_TEST_FOR_EXCEPTION_CLASS_FUNC(
*fillState_ != FillState::open,
std::runtime_error,
"Cannot sum in to local values, matrix is not open to fill."
);
return CrsMatrix<Scalar, LocalOrdinal, GlobalOrdinal, Node>::insertGlobalValuesImpl(
graph, rowInfo, gblColInds, vals, numInputEnt
);
}

} // end namespace Tpetra

Expand Down
Loading

0 comments on commit 7df3306

Please sign in to comment.