Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Belos CG: Caching of state vectors #13751

Merged
merged 3 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 80 additions & 86 deletions packages/belos/src/BelosBlockCGIter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,41 @@

namespace Belos {

//! @name BlockCGIteration Structures
//@{

/** \brief Structure to contain pointers to BlockCGIteration state variables.
*
* This struct is utilized by BlockCGIteration::initialize() and BlockCGIteration::getState().
*/
template <class ScalarType, class MV>
class BlockCGIterationState : public CGIterationStateBase<ScalarType, MV> {

public:
BlockCGIterationState() = default;

BlockCGIterationState(Teuchos::RCP<const MV> tmp) {
initialize(tmp);
}

virtual ~BlockCGIterationState() = default;

void initialize(Teuchos::RCP<const MV> tmp, int _numVectors) {
using MVT = MultiVecTraits<ScalarType, MV>;
this->R = MVT::Clone( *tmp, _numVectors );
this->Z = MVT::Clone( *tmp, _numVectors );
this->P = MVT::Clone( *tmp, _numVectors );
this->AP = MVT::Clone(*tmp, _numVectors );

CGIterationStateBase<ScalarType, MV>::initialize(tmp, _numVectors);
}

bool matches(Teuchos::RCP<const MV> tmp, int _numVectors=1) const {
return CGIterationStateBase<ScalarType, MV>::matches(tmp, _numVectors);
}

};

/// \class BlockCGIter
/// \brief Implementation of the block preconditioned Conjugate
/// Gradient (CG) iteration.
Expand Down Expand Up @@ -69,15 +104,19 @@ class BlockCGIter : virtual public CGIteration<ScalarType, MV, OP> {
TEUCHOS_TEST_FOR_EXCEPTION(true, std::logic_error, "Stub");
}

void initializeCG (CGIterationState<ScalarType,MV>& /* newstate */) {
void initializeCG (Teuchos::RCP<BlockCGIterationState<ScalarType,MV> > /* newstate */, Teuchos::RCP<MV> /* R_0 */) {
TEUCHOS_TEST_FOR_EXCEPTION(true, std::logic_error, "Stub");
}

void initialize () {
TEUCHOS_TEST_FOR_EXCEPTION(true, std::logic_error, "Stub");
}

CGIterationState<ScalarType,MV> getState () const {
Teuchos::RCP<CGIterationStateBase<ScalarType,MV> > getState () const {
TEUCHOS_TEST_FOR_EXCEPTION(true, std::logic_error, "Stub");
}

void setState(Teuchos::RCP<CGIterationStateBase<ScalarType,MV> > state) {
TEUCHOS_TEST_FOR_EXCEPTION(true, std::logic_error, "Stub");
}

Expand Down Expand Up @@ -118,11 +157,6 @@ class BlockCGIter : virtual public CGIteration<ScalarType, MV, OP> {
TEUCHOS_TEST_FOR_EXCEPTION(true, std::logic_error, "Stub");
}


private:
void setStateSize() {
TEUCHOS_TEST_FOR_EXCEPTION(true, std::logic_error, "Stub");
}
};

/// \brief Partial specialization for ScalarType types for which
Expand All @@ -137,10 +171,10 @@ class BlockCGIter<ScalarType, MV, OP, true> :
//
// Convenience typedefs
//
typedef MultiVecTraits<ScalarType,MV> MVT;
typedef OperatorTraits<ScalarType,MV,OP> OPT;
typedef Teuchos::ScalarTraits<ScalarType> SCT;
typedef typename SCT::magnitudeType MagnitudeType;
using MVT = MultiVecTraits<ScalarType, MV>;
using OPT = OperatorTraits<ScalarType, MV, OP>;
using SCT = Teuchos::ScalarTraits<ScalarType>;
using MagnitudeType = typename SCT::magnitudeType;

//! @name Constructors/Destructor
//@{
Expand All @@ -157,7 +191,7 @@ class BlockCGIter<ScalarType, MV, OP, true> :
Teuchos::ParameterList &params );

//! Destructor.
virtual ~BlockCGIter() {};
virtual ~BlockCGIter() = default;
//@}


Expand Down Expand Up @@ -192,32 +226,39 @@ class BlockCGIter<ScalarType, MV, OP, true> :
* \note For any pointer in \c newstate which directly points to the multivectors in
* the solver, the data is not copied.
*/
void initializeCG(CGIterationState<ScalarType,MV>& newstate);
void initializeCG(Teuchos::RCP<CGIterationStateBase<ScalarType,MV> > newstate, Teuchos::RCP<MV> R_0);

/*! \brief Initialize the solver with the initial vectors from the linear problem
* or random data.
*/
void initialize()
{
CGIterationState<ScalarType,MV> empty;
initializeCG(empty);
initializeCG(Teuchos::null, Teuchos::null);
}

/*! \brief Get the current state of the linear solver.
*
* The data is only valid if isInitialized() == \c true.
*
* \returns A CGIterationState object containing const pointers to the current solver state.
* \returns A BlockCGIterationState object containing const pointers to the current solver state.
*/
CGIterationState<ScalarType,MV> getState() const {
CGIterationState<ScalarType,MV> state;
state.R = R_;
state.P = P_;
state.AP = AP_;
state.Z = Z_;
Teuchos::RCP<CGIterationStateBase<ScalarType,MV> > getState() const {
auto state = Teuchos::rcp(new BlockCGIterationState<ScalarType,MV>());
state->R = R_;
state->P = P_;
state->AP = AP_;
state->Z = Z_;
return state;
}

void setState(Teuchos::RCP<CGIterationStateBase<ScalarType,MV> > state) {
auto s = Teuchos::rcp_dynamic_cast<BlockCGIterationState<ScalarType,MV> >(state, true);
R_ = s->R;
Z_ = s->Z;
P_ = s->P;
AP_ = s->AP;
}

//@}


Expand Down Expand Up @@ -260,7 +301,7 @@ class BlockCGIter<ScalarType, MV, OP, true> :
void setDoCondEst(bool /* val */){/*ignored*/}

//! Gets the diagonal for condition estimation (NOT_IMPLEMENTED)
Teuchos::ArrayView<MagnitudeType> getDiag() {
Teuchos::ArrayView<MagnitudeType> getDiag() {
Teuchos::ArrayView<MagnitudeType> temp;
return temp;
}
Expand All @@ -276,9 +317,6 @@ class BlockCGIter<ScalarType, MV, OP, true> :

//
// Internal methods
//
//! Method for initalizing the state storage needed by block CG.
void setStateSize();

//
// Classes inputed through constructor that define the linear problem to be solved.
Expand Down Expand Up @@ -348,40 +386,6 @@ class BlockCGIter<ScalarType, MV, OP, true> :
setBlockSize( bs );
}

template <class ScalarType, class MV, class OP>
void BlockCGIter<ScalarType,MV,OP,true>::setStateSize ()
{
if (! stateStorageInitialized_) {
// Check if there is any multivector to clone from.
Teuchos::RCP<const MV> lhsMV = lp_->getLHS();
Teuchos::RCP<const MV> rhsMV = lp_->getRHS();
if (lhsMV == Teuchos::null && rhsMV == Teuchos::null) {
stateStorageInitialized_ = false;
return;
}
else {
// Initialize the state storage If the subspace has not be
// initialized before, generate it using the LHS or RHS from
// lp_.
if (R_ == Teuchos::null || MVT::GetNumberVecs(*R_)!=blockSize_) {
// Get the multivector that is not null.
Teuchos::RCP<const MV> tmp = ( (rhsMV!=Teuchos::null)? rhsMV: lhsMV );
TEUCHOS_TEST_FOR_EXCEPTION
(tmp == Teuchos::null,std:: invalid_argument,
"Belos::BlockCGIter::setStateSize: LinearProblem lacks "
"multivectors from which to clone.");
R_ = MVT::Clone (*tmp, blockSize_);
Z_ = MVT::Clone (*tmp, blockSize_);
P_ = MVT::Clone (*tmp, blockSize_);
AP_ = MVT::Clone (*tmp, blockSize_);
}

// State storage has now been initialized.
stateStorageInitialized_ = true;
}
}
}

template <class ScalarType, class MV, class OP>
void BlockCGIter<ScalarType,MV,OP,true>::setBlockSize (int blockSize)
{
Expand All @@ -398,55 +402,50 @@ class BlockCGIter<ScalarType, MV, OP, true> :
}
blockSize_ = blockSize;
initialized_ = false;
// Use the current blockSize_ to initialize the state storage.
setStateSize ();
}

template <class ScalarType, class MV, class OP>
void BlockCGIter<ScalarType,MV,OP,true>::
initializeCG (CGIterationState<ScalarType,MV>& newstate)
initializeCG (Teuchos::RCP<CGIterationStateBase<ScalarType,MV> > newstate, Teuchos::RCP<MV> R_0)
{
const char prefix[] = "Belos::BlockCGIter::initialize: ";

// Initialize the state storage if it isn't already.
if (! stateStorageInitialized_) {
setStateSize();
}

TEUCHOS_TEST_FOR_EXCEPTION
(! stateStorageInitialized_, std::invalid_argument,
prefix << "Cannot initialize state storage!");
Teuchos::RCP<const MV> lhsMV = lp_->getLHS();
Teuchos::RCP<const MV> rhsMV = lp_->getRHS();
Teuchos::RCP<const MV> tmp = ( (rhsMV!=Teuchos::null)? rhsMV: lhsMV );
TEUCHOS_ASSERT(!newstate.is_null());
if (!Teuchos::rcp_dynamic_cast<BlockCGIterationState<ScalarType,MV> >(newstate, true)->matches(tmp, blockSize_))
newstate->initialize(tmp, blockSize_);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hkthorn It's getting called using blockSize_. The argument numRHS in BlockCGIterationState::initialize() is a leftover from the refactor. In order to be consistent, I could just rename numRHS to numVectors in all the iteration states.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that would avoid confusion in the future, if you don't mind.

setState(newstate);

// NOTE: In BlockCGIter R_, the initial residual, is required!!!
const char errstr[] = "Specified multivectors must have a consistent "
"length and width.";

// Create convenience variables for zero and one.
//const MagnitudeType zero = Teuchos::ScalarTraits<MagnitudeType>::zero(); // unused

if (newstate.R != Teuchos::null) {
{

TEUCHOS_TEST_FOR_EXCEPTION
(MVT::GetGlobalLength(*newstate.R) != MVT::GetGlobalLength(*R_),
(MVT::GetGlobalLength(*R_0) != MVT::GetGlobalLength(*R_),
std::invalid_argument, prefix << errstr );
TEUCHOS_TEST_FOR_EXCEPTION
(MVT::GetNumberVecs(*newstate.R) != blockSize_,
(MVT::GetNumberVecs(*R_0) != blockSize_,
std::invalid_argument, prefix << errstr );

// Copy basis vectors from newstate into V
if (newstate.R != R_) {
if (R_0 != R_) {
// copy over the initial residual (unpreconditioned).
MVT::Assign( *newstate.R, *R_ );
MVT::Assign( *R_0, *R_ );
}
// Compute initial direction vectors
// Initially, they are set to the preconditioned residuals
//
if ( lp_->getLeftPrec() != Teuchos::null ) {
lp_->applyLeftPrec( *R_, *Z_ );
if ( lp_->getRightPrec() != Teuchos::null ) {
Teuchos::RCP<MV> tmp = MVT::Clone( *Z_, blockSize_ );
lp_->applyRightPrec( *Z_, *tmp );
Z_ = tmp;
Teuchos::RCP<MV> tmp2 = MVT::Clone( *Z_, blockSize_ );
lp_->applyRightPrec( *Z_, *tmp2 );
Z_ = tmp2;
}
}
else if ( lp_->getRightPrec() != Teuchos::null ) {
Expand All @@ -457,11 +456,6 @@ class BlockCGIter<ScalarType, MV, OP, true> :
}
MVT::Assign( *Z_, *P_ );
}
else {
TEUCHOS_TEST_FOR_EXCEPTION
(newstate.R == Teuchos::null, std::invalid_argument,
prefix << "BlockCGStateIterState does not have initial residual.");
}

// The solver is initialized
initialized_ = true;
Expand Down
Loading
Loading