Skip to content

Commit

Permalink
Belos: Caching of state vectors
Browse files Browse the repository at this point in the history
Signed-off-by: Christian Glusa <[email protected]>
  • Loading branch information
cgcgcg committed Jan 24, 2025
1 parent 6c60489 commit bbad992
Show file tree
Hide file tree
Showing 7 changed files with 249 additions and 311 deletions.
95 changes: 25 additions & 70 deletions packages/belos/src/BelosBlockCGIter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class BlockCGIter : virtual public CGIteration<ScalarType, MV, OP> {
TEUCHOS_TEST_FOR_EXCEPTION(true, std::logic_error, "Stub");
}

void initializeCG (CGIterationState<ScalarType,MV>& /* newstate */) {
void initializeCG (CGIterationState<ScalarType,MV>& /* newstate */, Teuchos::RCP<MV> /* R_0 */) {
TEUCHOS_TEST_FOR_EXCEPTION(true, std::logic_error, "Stub");
}

Expand Down Expand Up @@ -118,11 +118,6 @@ class BlockCGIter : virtual public CGIteration<ScalarType, MV, OP> {
TEUCHOS_TEST_FOR_EXCEPTION(true, std::logic_error, "Stub");
}


private:
void setStateSize() {
TEUCHOS_TEST_FOR_EXCEPTION(true, std::logic_error, "Stub");
}
};

/// \brief Partial specialization for ScalarType types for which
Expand Down Expand Up @@ -192,15 +187,15 @@ class BlockCGIter<ScalarType, MV, OP, true> :
* \note For any pointer in \c newstate which directly points to the multivectors in
* the solver, the data is not copied.
*/
void initializeCG(CGIterationState<ScalarType,MV>& newstate);
void initializeCG(CGIterationState<ScalarType,MV>& newstate, Teuchos::RCP<MV> R_0);

/*! \brief Initialize the solver with the initial vectors from the linear problem
* or random data.
*/
void initialize()
{
CGIterationState<ScalarType,MV> empty;
initializeCG(empty);
initializeCG(empty, Teuchos::null);
}

/*! \brief Get the current state of the linear solver.
Expand All @@ -218,6 +213,13 @@ class BlockCGIter<ScalarType, MV, OP, true> :
return state;
}

void setState(CGIterationState<ScalarType, MV> state) {
R_ = state.R;
Z_ = state.Z;
P_ = state.P;
AP_ = state.AP;
}

//@}


Expand Down Expand Up @@ -276,9 +278,6 @@ class BlockCGIter<ScalarType, MV, OP, true> :

//
// Internal methods
//
//! Method for initalizing the state storage needed by block CG.
void setStateSize();

//
// Classes inputed through constructor that define the linear problem to be solved.
Expand Down Expand Up @@ -348,40 +347,6 @@ class BlockCGIter<ScalarType, MV, OP, true> :
setBlockSize( bs );
}

template <class ScalarType, class MV, class OP>
void BlockCGIter<ScalarType,MV,OP,true>::setStateSize ()
{
if (! stateStorageInitialized_) {
// Check if there is any multivector to clone from.
Teuchos::RCP<const MV> lhsMV = lp_->getLHS();
Teuchos::RCP<const MV> rhsMV = lp_->getRHS();
if (lhsMV == Teuchos::null && rhsMV == Teuchos::null) {
stateStorageInitialized_ = false;
return;
}
else {
// Initialize the state storage If the subspace has not be
// initialized before, generate it using the LHS or RHS from
// lp_.
if (R_ == Teuchos::null || MVT::GetNumberVecs(*R_)!=blockSize_) {
// Get the multivector that is not null.
Teuchos::RCP<const MV> tmp = ( (rhsMV!=Teuchos::null)? rhsMV: lhsMV );
TEUCHOS_TEST_FOR_EXCEPTION
(tmp == Teuchos::null,std:: invalid_argument,
"Belos::BlockCGIter::setStateSize: LinearProblem lacks "
"multivectors from which to clone.");
R_ = MVT::Clone (*tmp, blockSize_);
Z_ = MVT::Clone (*tmp, blockSize_);
P_ = MVT::Clone (*tmp, blockSize_);
AP_ = MVT::Clone (*tmp, blockSize_);
}

// State storage has now been initialized.
stateStorageInitialized_ = true;
}
}
}

template <class ScalarType, class MV, class OP>
void BlockCGIter<ScalarType,MV,OP,true>::setBlockSize (int blockSize)
{
Expand All @@ -398,55 +363,50 @@ class BlockCGIter<ScalarType, MV, OP, true> :
}
blockSize_ = blockSize;
initialized_ = false;
// Use the current blockSize_ to initialize the state storage.
setStateSize ();
}

template <class ScalarType, class MV, class OP>
void BlockCGIter<ScalarType,MV,OP,true>::
initializeCG (CGIterationState<ScalarType,MV>& newstate)
initializeCG (CGIterationState<ScalarType,MV>& newstate, Teuchos::RCP<MV> R_0)
{
const char prefix[] = "Belos::BlockCGIter::initialize: ";

// Initialize the state storage if it isn't already.
if (! stateStorageInitialized_) {
setStateSize();
Teuchos::RCP<const MV> lhsMV = lp_->getLHS();
Teuchos::RCP<const MV> rhsMV = lp_->getRHS();
Teuchos::RCP<const MV> tmp = ( (rhsMV!=Teuchos::null)? rhsMV: lhsMV );
if (!newstate.isInitialized() || !newstate.matches(tmp, blockSize_)) {
newstate.initialize(tmp, BlockCG, blockSize_);
}

TEUCHOS_TEST_FOR_EXCEPTION
(! stateStorageInitialized_, std::invalid_argument,
prefix << "Cannot initialize state storage!");
setState(newstate);

// NOTE: In BlockCGIter R_, the initial residual, is required!!!
const char errstr[] = "Specified multivectors must have a consistent "
"length and width.";

// Create convenience variables for zero and one.
//const MagnitudeType zero = Teuchos::ScalarTraits<MagnitudeType>::zero(); // unused

if (newstate.R != Teuchos::null) {
{

TEUCHOS_TEST_FOR_EXCEPTION
(MVT::GetGlobalLength(*newstate.R) != MVT::GetGlobalLength(*R_),
(MVT::GetGlobalLength(*R_0) != MVT::GetGlobalLength(*R_),
std::invalid_argument, prefix << errstr );
TEUCHOS_TEST_FOR_EXCEPTION
(MVT::GetNumberVecs(*newstate.R) != blockSize_,
(MVT::GetNumberVecs(*R_0) != blockSize_,
std::invalid_argument, prefix << errstr );

// Copy basis vectors from newstate into V
if (newstate.R != R_) {
if (R_0 != R_) {
// copy over the initial residual (unpreconditioned).
MVT::Assign( *newstate.R, *R_ );
MVT::Assign( *R_0, *R_ );
}
// Compute initial direction vectors
// Initially, they are set to the preconditioned residuals
//
if ( lp_->getLeftPrec() != Teuchos::null ) {
lp_->applyLeftPrec( *R_, *Z_ );
if ( lp_->getRightPrec() != Teuchos::null ) {
Teuchos::RCP<MV> tmp = MVT::Clone( *Z_, blockSize_ );
lp_->applyRightPrec( *Z_, *tmp );
Z_ = tmp;
Teuchos::RCP<MV> tmp2 = MVT::Clone( *Z_, blockSize_ );
lp_->applyRightPrec( *Z_, *tmp2 );
Z_ = tmp2;
}
}
else if ( lp_->getRightPrec() != Teuchos::null ) {
Expand All @@ -457,11 +417,6 @@ class BlockCGIter<ScalarType, MV, OP, true> :
}
MVT::Assign( *Z_, *P_ );
}
else {
TEUCHOS_TEST_FOR_EXCEPTION
(newstate.R == Teuchos::null, std::invalid_argument,
prefix << "BlockCGStateIterState does not have initial residual.");
}

// The solver is initialized
initialized_ = true;
Expand Down
12 changes: 6 additions & 6 deletions packages/belos/src/BelosBlockCGSolMgr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,8 @@ namespace Belos {
bool assertPositiveDefiniteness_;
bool foldConvergenceDetectionIntoAllreduce_;

Teuchos::RCP<CGIterationState<ScalarType, MV >> state_;

//! Prefix label for all the timers.
std::string label_;

Expand Down Expand Up @@ -900,9 +902,9 @@ ReturnType BlockCGSolMgr<ScalarType,MV,OP,true>::solve() {
RCP<MV> R_0 = MVT::CloneViewNonConst( *(rcp_const_cast<MV>(problem_->getInitResVec())), currIdx );

// Set the new state and initialize the solver.
CGIterationState<ScalarType,MV> newstate;
newstate.R = R_0;
block_cg_iter->initializeCG(newstate);
if (state_.is_null())
state_ = Teuchos::rcp(new CGIterationState<ScalarType, MV>());
block_cg_iter->initializeCG(*state_, R_0);

while(1) {

Expand Down Expand Up @@ -964,9 +966,7 @@ ReturnType BlockCGSolMgr<ScalarType,MV,OP,true>::solve() {
block_cg_iter->setBlockSize( have );

// Set the new state and initialize the solver.
CGIterationState<ScalarType,MV> defstate;
defstate.R = R_0;
block_cg_iter->initializeCG(defstate);
block_cg_iter->initializeCG(*state_, R_0);
}
//
// None of the linear systems converged. Check whether the
Expand Down
Loading

0 comments on commit bbad992

Please sign in to comment.