Skip to content

Commit

Permalink
MueLu NotayAggregationFactory: Allow LWGraph_kokkos
Browse files Browse the repository at this point in the history
Signed-off-by: Christian Glusa <[email protected]>
  • Loading branch information
cgcgcg committed Jul 26, 2024
1 parent 088f3f3 commit c9cb561
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <Xpetra_Matrix_fwd.hpp>

#include "MueLu_LWGraph_fwd.hpp"
#include "MueLu_LWGraph_kokkos_fwd.hpp"
#include "MueLu_Exceptions.hpp"
#include "MueLu_SingleLevelFactoryBase.hpp"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include "MueLu_Aggregates.hpp"
#include "MueLu_LWGraph.hpp"
#include "MueLu_LWGraph_kokkos.hpp"
#include "MueLu_Level.hpp"
#include "MueLu_MasterList.hpp"
#include "MueLu_Monitor.hpp"
Expand Down Expand Up @@ -60,10 +61,10 @@ RCP<const ParameterList> NotayAggregationFactory<Scalar, LocalOrdinal, GlobalOrd
#undef SET_VALID_ENTRY

// general variables needed in AggregationFactory
validParamList->set<RCP<const FactoryBase> >("A", null, "Generating factory of the matrix");
validParamList->set<RCP<const FactoryBase> >("Graph", null, "Generating factory of the graph");
validParamList->set<RCP<const FactoryBase> >("DofsPerNode", null, "Generating factory for variable \'DofsPerNode\', usually the same as for \'Graph\'");
validParamList->set<RCP<const FactoryBase> >("AggregateQualities", null, "Generating factory for variable \'AggregateQualities\'");
validParamList->set<RCP<const FactoryBase>>("A", null, "Generating factory of the matrix");
validParamList->set<RCP<const FactoryBase>>("Graph", null, "Generating factory of the graph");
validParamList->set<RCP<const FactoryBase>>("DofsPerNode", null, "Generating factory for variable \'DofsPerNode\', usually the same as for \'Graph\'");
validParamList->set<RCP<const FactoryBase>>("AggregateQualities", null, "Generating factory for variable \'AggregateQualities\'");

return validParamList;
}
Expand Down Expand Up @@ -113,8 +114,14 @@ void NotayAggregationFactory<Scalar, LocalOrdinal, GlobalOrdinal, Node>::Build(L
"NotayAggregationFactory::Build(): \"aggregation: pairwise: size\""
" must be a strictly positive integer");

RCP<const LWGraph> graph = Get<RCP<LWGraph> >(currentLevel, "Graph");
RCP<const Matrix> A = Get<RCP<Matrix> >(currentLevel, "A");
RCP<const LWGraph> graph;
if (IsType<RCP<LWGraph>>(currentLevel, "Graph"))
graph = Get<RCP<LWGraph>>(currentLevel, "Graph");
else {
auto graph_k = Get<RCP<LWGraph_kokkos>>(currentLevel, "Graph");
graph = graph_k->copyToHost();
}
RCP<const Matrix> A = Get<RCP<Matrix>>(currentLevel, "A");

// Setup aggregates & aggStat objects
RCP<Aggregates> aggregates = rcp(new Aggregates(*graph));
Expand Down Expand Up @@ -161,8 +168,8 @@ void NotayAggregationFactory<Scalar, LocalOrdinal, GlobalOrdinal, Node>::Build(L
if (ordering == O_RANDOM)
MueLu::NotayUtils::RandomReorder(orderingVector);
else if (ordering == O_CUTHILL_MCKEE) {
RCP<Xpetra::Vector<LO, LO, GO, NO> > rcmVector = MueLu::Utilities<SC, LO, GO, NO>::CuthillMcKee(*A);
auto localVector = rcmVector->getData(0);
RCP<Xpetra::Vector<LO, LO, GO, NO>> rcmVector = MueLu::Utilities<SC, LO, GO, NO>::CuthillMcKee(*A);
auto localVector = rcmVector->getData(0);
for (LO i = 0; i < numRows; i++)
orderingVector[i] = localVector[i];
}
Expand Down Expand Up @@ -198,7 +205,7 @@ void NotayAggregationFactory<Scalar, LocalOrdinal, GlobalOrdinal, Node>::Build(L
// Directly compute rowsum from A, rather than coarseA
row_sum_type rowSum("rowSum", numLocalAggregates);
{
std::vector<std::vector<LO> > agg2vertex(numLocalAggregates);
std::vector<std::vector<LO>> agg2vertex(numLocalAggregates);
auto vertex2AggId = aggregates->GetVertex2AggId()->getData(0);
for (LO i = 0; i < (LO)numRows; i++) {
if (aggStat[i] != AGGREGATED)
Expand Down Expand Up @@ -248,8 +255,8 @@ void NotayAggregationFactory<Scalar, LocalOrdinal, GlobalOrdinal, Node>::Build(L
if (ordering == O_RANDOM)
MueLu::NotayUtils::RandomReorder(localOrderingVector);
else if (ordering == O_CUTHILL_MCKEE) {
RCP<Xpetra::Vector<LO, LO, GO, NO> > rcmVector = MueLu::Utilities<SC, LO, GO, NO>::CuthillMcKee(*A);
auto localVector = rcmVector->getData(0);
RCP<Xpetra::Vector<LO, LO, GO, NO>> rcmVector = MueLu::Utilities<SC, LO, GO, NO>::CuthillMcKee(*A);
auto localVector = rcmVector->getData(0);
for (LO i = 0; i < numRows; i++)
localOrderingVector[i] = localVector[i];
}
Expand Down

0 comments on commit c9cb561

Please sign in to comment.