Skip to content

Commit

Permalink
Merge pull request #13751 from cgcgcg/belosCGcaching
Browse files Browse the repository at this point in the history
Belos CG: Caching of state vectors
  • Loading branch information
cgcgcg authored Feb 5, 2025
2 parents d1b229a + aff1279 commit 363438e
Show file tree
Hide file tree
Showing 8 changed files with 603 additions and 524 deletions.
166 changes: 80 additions & 86 deletions packages/belos/src/BelosBlockCGIter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,41 @@

namespace Belos {

//! @name BlockCGIteration Structures
//@{

/** \brief Structure to contain pointers to BlockCGIteration state variables.
*
* This struct is utilized by BlockCGIteration::initialize() and BlockCGIteration::getState().
*/
template <class ScalarType, class MV>
class BlockCGIterationState : public CGIterationStateBase<ScalarType, MV> {

public:
BlockCGIterationState() = default;

BlockCGIterationState(Teuchos::RCP<const MV> tmp) {
initialize(tmp);
}

virtual ~BlockCGIterationState() = default;

void initialize(Teuchos::RCP<const MV> tmp, int _numVectors) {
using MVT = MultiVecTraits<ScalarType, MV>;
this->R = MVT::Clone( *tmp, _numVectors );
this->Z = MVT::Clone( *tmp, _numVectors );
this->P = MVT::Clone( *tmp, _numVectors );
this->AP = MVT::Clone(*tmp, _numVectors );

CGIterationStateBase<ScalarType, MV>::initialize(tmp, _numVectors);
}

bool matches(Teuchos::RCP<const MV> tmp, int _numVectors=1) const {
return CGIterationStateBase<ScalarType, MV>::matches(tmp, _numVectors);
}

};

/// \class BlockCGIter
/// \brief Implementation of the block preconditioned Conjugate
/// Gradient (CG) iteration.
Expand Down Expand Up @@ -69,15 +104,19 @@ class BlockCGIter : virtual public CGIteration<ScalarType, MV, OP> {
TEUCHOS_TEST_FOR_EXCEPTION(true, std::logic_error, "Stub");
}

void initializeCG (CGIterationState<ScalarType,MV>& /* newstate */) {
void initializeCG (Teuchos::RCP<BlockCGIterationState<ScalarType,MV> > /* newstate */, Teuchos::RCP<MV> /* R_0 */) {
TEUCHOS_TEST_FOR_EXCEPTION(true, std::logic_error, "Stub");
}

void initialize () {
TEUCHOS_TEST_FOR_EXCEPTION(true, std::logic_error, "Stub");
}

CGIterationState<ScalarType,MV> getState () const {
Teuchos::RCP<CGIterationStateBase<ScalarType,MV> > getState () const {
TEUCHOS_TEST_FOR_EXCEPTION(true, std::logic_error, "Stub");
}

void setState(Teuchos::RCP<CGIterationStateBase<ScalarType,MV> > state) {
TEUCHOS_TEST_FOR_EXCEPTION(true, std::logic_error, "Stub");
}

Expand Down Expand Up @@ -118,11 +157,6 @@ class BlockCGIter : virtual public CGIteration<ScalarType, MV, OP> {
TEUCHOS_TEST_FOR_EXCEPTION(true, std::logic_error, "Stub");
}


private:
void setStateSize() {
TEUCHOS_TEST_FOR_EXCEPTION(true, std::logic_error, "Stub");
}
};

/// \brief Partial specialization for ScalarType types for which
Expand All @@ -137,10 +171,10 @@ class BlockCGIter<ScalarType, MV, OP, true> :
//
// Convenience typedefs
//
typedef MultiVecTraits<ScalarType,MV> MVT;
typedef OperatorTraits<ScalarType,MV,OP> OPT;
typedef Teuchos::ScalarTraits<ScalarType> SCT;
typedef typename SCT::magnitudeType MagnitudeType;
using MVT = MultiVecTraits<ScalarType, MV>;
using OPT = OperatorTraits<ScalarType, MV, OP>;
using SCT = Teuchos::ScalarTraits<ScalarType>;
using MagnitudeType = typename SCT::magnitudeType;

//! @name Constructors/Destructor
//@{
Expand All @@ -157,7 +191,7 @@ class BlockCGIter<ScalarType, MV, OP, true> :
Teuchos::ParameterList &params );

//! Destructor.
virtual ~BlockCGIter() {};
virtual ~BlockCGIter() = default;
//@}


Expand Down Expand Up @@ -192,32 +226,39 @@ class BlockCGIter<ScalarType, MV, OP, true> :
* \note For any pointer in \c newstate which directly points to the multivectors in
* the solver, the data is not copied.
*/
void initializeCG(CGIterationState<ScalarType,MV>& newstate);
void initializeCG(Teuchos::RCP<CGIterationStateBase<ScalarType,MV> > newstate, Teuchos::RCP<MV> R_0);

/*! \brief Initialize the solver with the initial vectors from the linear problem
* or random data.
*/
void initialize()
{
CGIterationState<ScalarType,MV> empty;
initializeCG(empty);
initializeCG(Teuchos::null, Teuchos::null);
}

/*! \brief Get the current state of the linear solver.
*
* The data is only valid if isInitialized() == \c true.
*
* \returns A CGIterationState object containing const pointers to the current solver state.
* \returns A BlockCGIterationState object containing const pointers to the current solver state.
*/
CGIterationState<ScalarType,MV> getState() const {
CGIterationState<ScalarType,MV> state;
state.R = R_;
state.P = P_;
state.AP = AP_;
state.Z = Z_;
Teuchos::RCP<CGIterationStateBase<ScalarType,MV> > getState() const {
auto state = Teuchos::rcp(new BlockCGIterationState<ScalarType,MV>());
state->R = R_;
state->P = P_;
state->AP = AP_;
state->Z = Z_;
return state;
}

void setState(Teuchos::RCP<CGIterationStateBase<ScalarType,MV> > state) {
auto s = Teuchos::rcp_dynamic_cast<BlockCGIterationState<ScalarType,MV> >(state, true);
R_ = s->R;
Z_ = s->Z;
P_ = s->P;
AP_ = s->AP;
}

//@}


Expand Down Expand Up @@ -260,7 +301,7 @@ class BlockCGIter<ScalarType, MV, OP, true> :
void setDoCondEst(bool /* val */){/*ignored*/}

//! Gets the diagonal for condition estimation (NOT_IMPLEMENTED)
Teuchos::ArrayView<MagnitudeType> getDiag() {
Teuchos::ArrayView<MagnitudeType> getDiag() {
Teuchos::ArrayView<MagnitudeType> temp;
return temp;
}
Expand All @@ -276,9 +317,6 @@ class BlockCGIter<ScalarType, MV, OP, true> :

//
// Internal methods
//
//! Method for initalizing the state storage needed by block CG.
void setStateSize();

//
// Classes inputed through constructor that define the linear problem to be solved.
Expand Down Expand Up @@ -348,40 +386,6 @@ class BlockCGIter<ScalarType, MV, OP, true> :
setBlockSize( bs );
}

template <class ScalarType, class MV, class OP>
void BlockCGIter<ScalarType,MV,OP,true>::setStateSize ()
{
if (! stateStorageInitialized_) {
// Check if there is any multivector to clone from.
Teuchos::RCP<const MV> lhsMV = lp_->getLHS();
Teuchos::RCP<const MV> rhsMV = lp_->getRHS();
if (lhsMV == Teuchos::null && rhsMV == Teuchos::null) {
stateStorageInitialized_ = false;
return;
}
else {
// Initialize the state storage If the subspace has not be
// initialized before, generate it using the LHS or RHS from
// lp_.
if (R_ == Teuchos::null || MVT::GetNumberVecs(*R_)!=blockSize_) {
// Get the multivector that is not null.
Teuchos::RCP<const MV> tmp = ( (rhsMV!=Teuchos::null)? rhsMV: lhsMV );
TEUCHOS_TEST_FOR_EXCEPTION
(tmp == Teuchos::null,std:: invalid_argument,
"Belos::BlockCGIter::setStateSize: LinearProblem lacks "
"multivectors from which to clone.");
R_ = MVT::Clone (*tmp, blockSize_);
Z_ = MVT::Clone (*tmp, blockSize_);
P_ = MVT::Clone (*tmp, blockSize_);
AP_ = MVT::Clone (*tmp, blockSize_);
}

// State storage has now been initialized.
stateStorageInitialized_ = true;
}
}
}

template <class ScalarType, class MV, class OP>
void BlockCGIter<ScalarType,MV,OP,true>::setBlockSize (int blockSize)
{
Expand All @@ -398,55 +402,50 @@ class BlockCGIter<ScalarType, MV, OP, true> :
}
blockSize_ = blockSize;
initialized_ = false;
// Use the current blockSize_ to initialize the state storage.
setStateSize ();
}

template <class ScalarType, class MV, class OP>
void BlockCGIter<ScalarType,MV,OP,true>::
initializeCG (CGIterationState<ScalarType,MV>& newstate)
initializeCG (Teuchos::RCP<CGIterationStateBase<ScalarType,MV> > newstate, Teuchos::RCP<MV> R_0)
{
const char prefix[] = "Belos::BlockCGIter::initialize: ";

// Initialize the state storage if it isn't already.
if (! stateStorageInitialized_) {
setStateSize();
}

TEUCHOS_TEST_FOR_EXCEPTION
(! stateStorageInitialized_, std::invalid_argument,
prefix << "Cannot initialize state storage!");
Teuchos::RCP<const MV> lhsMV = lp_->getLHS();
Teuchos::RCP<const MV> rhsMV = lp_->getRHS();
Teuchos::RCP<const MV> tmp = ( (rhsMV!=Teuchos::null)? rhsMV: lhsMV );
TEUCHOS_ASSERT(!newstate.is_null());
if (!Teuchos::rcp_dynamic_cast<BlockCGIterationState<ScalarType,MV> >(newstate, true)->matches(tmp, blockSize_))
newstate->initialize(tmp, blockSize_);
setState(newstate);

// NOTE: In BlockCGIter R_, the initial residual, is required!!!
const char errstr[] = "Specified multivectors must have a consistent "
"length and width.";

// Create convenience variables for zero and one.
//const MagnitudeType zero = Teuchos::ScalarTraits<MagnitudeType>::zero(); // unused

if (newstate.R != Teuchos::null) {
{

TEUCHOS_TEST_FOR_EXCEPTION
(MVT::GetGlobalLength(*newstate.R) != MVT::GetGlobalLength(*R_),
(MVT::GetGlobalLength(*R_0) != MVT::GetGlobalLength(*R_),
std::invalid_argument, prefix << errstr );
TEUCHOS_TEST_FOR_EXCEPTION
(MVT::GetNumberVecs(*newstate.R) != blockSize_,
(MVT::GetNumberVecs(*R_0) != blockSize_,
std::invalid_argument, prefix << errstr );

// Copy basis vectors from newstate into V
if (newstate.R != R_) {
if (R_0 != R_) {
// copy over the initial residual (unpreconditioned).
MVT::Assign( *newstate.R, *R_ );
MVT::Assign( *R_0, *R_ );
}
// Compute initial direction vectors
// Initially, they are set to the preconditioned residuals
//
if ( lp_->getLeftPrec() != Teuchos::null ) {
lp_->applyLeftPrec( *R_, *Z_ );
if ( lp_->getRightPrec() != Teuchos::null ) {
Teuchos::RCP<MV> tmp = MVT::Clone( *Z_, blockSize_ );
lp_->applyRightPrec( *Z_, *tmp );
Z_ = tmp;
Teuchos::RCP<MV> tmp2 = MVT::Clone( *Z_, blockSize_ );
lp_->applyRightPrec( *Z_, *tmp2 );
Z_ = tmp2;
}
}
else if ( lp_->getRightPrec() != Teuchos::null ) {
Expand All @@ -457,11 +456,6 @@ class BlockCGIter<ScalarType, MV, OP, true> :
}
MVT::Assign( *Z_, *P_ );
}
else {
TEUCHOS_TEST_FOR_EXCEPTION
(newstate.R == Teuchos::null, std::invalid_argument,
prefix << "BlockCGStateIterState does not have initial residual.");
}

// The solver is initialized
initialized_ = true;
Expand Down
Loading

0 comments on commit 363438e

Please sign in to comment.