Skip to content

Commit

Permalink
Adding begin/endAssembly to Tpetra::FEMultiVector
Browse files Browse the repository at this point in the history
  • Loading branch information
tjfulle committed Apr 13, 2021
1 parent f07cabe commit 9920fce
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 10 deletions.
17 changes: 17 additions & 0 deletions packages/tpetra/core/src/Tpetra_FEMultiVector_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,15 @@ namespace Tpetra {
//! Declare the end of a phase of owned+shared modifications.
void endFill ();

//! Declare the beginning of a phase of owned+shared modifications.
void beginAssembly ();

//! Declare the end of a phase of owned+shared modifications.
void endAssembly ();

void beginModify ();

This comment has been minimized.

Copy link
@skennon10

skennon10 Apr 13, 2021

Contributor

Suggest add comments like for begin/endAssembly.

void endModify ();

/// \brief Declare the end of a phase of owned+shared
/// modifications; same as endFill().
void globalAssemble ();
Expand Down Expand Up @@ -199,6 +208,14 @@ namespace Tpetra {
FE_ACTIVE_OWNED
};

enum class FillState
{
open, // matrix is "open". Values can freely summed in to and replaced
modify, // matrix is open for modification. *local* values can be replaced
closed
};
Teuchos::RCP<FillState> fillState_;

//! Whichever MultiVector is <i>not</i> currently active.
Teuchos::RCP<base_type> inactiveMultiVector_;

Expand Down
47 changes: 47 additions & 0 deletions packages/tpetra/core/src/Tpetra_FEMultiVector_def.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ FEMultiVector (const Teuchos::RCP<const map_type>& map,
inactiveMultiVector_ =
Teuchos::rcp (new base_type (importer_->getSourceMap (), dv));
}
fillState_ = Teuchos::rcp(new FillState(FillState::closed));
}

template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
Expand Down Expand Up @@ -128,6 +129,52 @@ endFill ()
}
}

template<class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
void FEMultiVector<Scalar, LocalOrdinal, GlobalOrdinal, Node>::beginAssembly() {
const char tfecfFuncName[] = "FEMultiVector::beginAssembly: ";
TEUCHOS_TEST_FOR_EXCEPTION_CLASS_FUNC(
*fillState_ != FillState::closed,
std::runtime_error,
"Cannot beginAssembly, matrix is not in a closed state"
);
*fillState_ = FillState::open;
this->beginFill();
}

template<class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
void FEMultiVector<Scalar, LocalOrdinal, GlobalOrdinal, Node>::endAssembly() {
const char tfecfFuncName[] = "FEMultiVector::endAssembly: ";
TEUCHOS_TEST_FOR_EXCEPTION_CLASS_FUNC(
*fillState_ != FillState::open,
std::runtime_error,
"Cannot endAssembly, matrix is not open to fill."

This comment has been minimized.

Copy link
@skennon10

skennon10 Apr 13, 2021

Contributor

nit: replace "to fill" with "for assembly"?

);
*fillState_ = FillState::closed;
this->endFill();
}

template<class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
void FEMultiVector<Scalar, LocalOrdinal, GlobalOrdinal, Node>::beginModify() {
const char tfecfFuncName[] = "FEMultiVector::beginModify: ";
TEUCHOS_TEST_FOR_EXCEPTION_CLASS_FUNC(
*fillState_ != FillState::closed,
std::runtime_error,
"Cannot beginModify, matrix is not in a closed state"
);
*fillState_ = FillState::modify;
}

template<class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
void FEMultiVector<Scalar, LocalOrdinal, GlobalOrdinal, Node>::endModify() {
const char tfecfFuncName[] = "FEMultiVector::endModify: ";
TEUCHOS_TEST_FOR_EXCEPTION_CLASS_FUNC(
*fillState_ != FillState::modify,
std::runtime_error,
"Cannot endModify, matrix is not open to modify."
);
*fillState_ = FillState::closed;
}

template <class Scalar, class LocalOrdinal, class GlobalOrdinal, class Node>
void
FEMultiVector<Scalar, LocalOrdinal, GlobalOrdinal, Node>::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,11 @@ namespace {
Vdomain.doExport(Vcolumn,*importer,Tpetra::ADD);


Vfe.beginFill();
Vfe.beginAssembly();
Vfe.putScalar(ZERO);
for(size_t i=0; i<Ndomain; i++)
Vfe.getDataNonConst(0)[i] = domainMap->getGlobalElement(i);
Vfe.endFill();
Vfe.endAssembly();
vector_check(Ndomain,Vfe,Vdomain);

// 2) Test column -> domain (with off-proc addition)
Expand All @@ -172,9 +172,9 @@ namespace {
Vdomain.doExport(Vcolumn,*importer,Tpetra::ADD);

Vfe.putScalar(ZERO);
Vfe.beginFill();
Vfe.beginAssembly();
Vfe.putScalar(ONE);
Vfe.endFill();
Vfe.endAssembly();
vector_check(Ncolumn,Vfe,Vdomain);
} catch (std::exception& e) {
err << "Proc " << myRank << ": " << e.what () << std::endl;
Expand Down Expand Up @@ -219,10 +219,10 @@ namespace {
Tpetra::FEMultiVector<Scalar,LO,GO,Node> v2(map,importer,1);
Tpetra::FEMultiVector<Scalar,LO,GO,Node> v3(map,importer,1);

// Just check to make sure beginFill() / endFill() compile
Tpetra::beginFill(v1,v2,v3);
// Just check to make sure beginAssembly() / endAssembly() compile
Tpetra::beginAssembly(v1,v2,v3);

Tpetra::endFill(v1,v2,v3);
Tpetra::endAssembly(v1,v2,v3);
}

#define UNIT_TEST_GROUP( SC, LO, GO, NO ) \
Expand Down
6 changes: 3 additions & 3 deletions packages/tpetra/core/test/FEMultiVector/Fix3101.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,13 @@ int FEMultiVectorTest::intTest()

// Add contributions to owned vertices and copies of off-processor vertices
try {
femv->beginFill();
femv->beginAssembly();
for (lno_t i = 0; i < nLocalOwned + nLocalCopy; i++) {
gno_t gid = mapWithCopies->getGlobalElement(i);
femv->replaceGlobalValue(gid, 0, gid);
femv->replaceGlobalValue(gid, 1, me);
}
femv->endFill();
femv->endAssembly();
}
catch (std::exception &e) {
std::cout << "FAIL: Exception thrown in Fill: " << e.what() << std::endl;
Expand All @@ -140,7 +140,7 @@ int FEMultiVectorTest::intTest()

printFEMV("After doOwnedToOwnedPlusShared ");

// Check results: after ADD in endFill,
// Check results: after ADD in endAssembly,
// - overlapping entries of vec 0 should be 2 * gid
// nonoverlapping entries of vec 0 should be gid
// - overlapping entries of vec 1 should be me + (np + me-1) % np;
Expand Down

0 comments on commit 9920fce

Please sign in to comment.