diff --git a/packages/muelu/src/Graph/PairwiseAggregation/MueLu_NotayAggregationFactory_decl.hpp b/packages/muelu/src/Graph/PairwiseAggregation/MueLu_NotayAggregationFactory_decl.hpp index 59c5056a3be1..c695ce7fdf77 100644 --- a/packages/muelu/src/Graph/PairwiseAggregation/MueLu_NotayAggregationFactory_decl.hpp +++ b/packages/muelu/src/Graph/PairwiseAggregation/MueLu_NotayAggregationFactory_decl.hpp @@ -18,6 +18,7 @@ #include #include "MueLu_LWGraph_fwd.hpp" +#include "MueLu_LWGraph_kokkos_fwd.hpp" #include "MueLu_Exceptions.hpp" #include "MueLu_SingleLevelFactoryBase.hpp" diff --git a/packages/muelu/src/Graph/PairwiseAggregation/MueLu_NotayAggregationFactory_def.hpp b/packages/muelu/src/Graph/PairwiseAggregation/MueLu_NotayAggregationFactory_def.hpp index 33b8118288d8..b432ffb1d868 100644 --- a/packages/muelu/src/Graph/PairwiseAggregation/MueLu_NotayAggregationFactory_def.hpp +++ b/packages/muelu/src/Graph/PairwiseAggregation/MueLu_NotayAggregationFactory_def.hpp @@ -24,6 +24,7 @@ #include "MueLu_Aggregates.hpp" #include "MueLu_LWGraph.hpp" +#include "MueLu_LWGraph_kokkos.hpp" #include "MueLu_Level.hpp" #include "MueLu_MasterList.hpp" #include "MueLu_Monitor.hpp" @@ -60,10 +61,10 @@ RCP NotayAggregationFactoryset >("A", null, "Generating factory of the matrix"); - validParamList->set >("Graph", null, "Generating factory of the graph"); - validParamList->set >("DofsPerNode", null, "Generating factory for variable \'DofsPerNode\', usually the same as for \'Graph\'"); - validParamList->set >("AggregateQualities", null, "Generating factory for variable \'AggregateQualities\'"); + validParamList->set>("A", null, "Generating factory of the matrix"); + validParamList->set>("Graph", null, "Generating factory of the graph"); + validParamList->set>("DofsPerNode", null, "Generating factory for variable \'DofsPerNode\', usually the same as for \'Graph\'"); + validParamList->set>("AggregateQualities", null, "Generating factory for variable \'AggregateQualities\'"); return validParamList; } @@ -113,8 +114,14 @@ void NotayAggregationFactory::Build(L "NotayAggregationFactory::Build(): \"aggregation: pairwise: size\"" " must be a strictly positive integer"); - RCP graph = Get >(currentLevel, "Graph"); - RCP A = Get >(currentLevel, "A"); + RCP graph; + if (IsType>(currentLevel, "Graph")) + graph = Get>(currentLevel, "Graph"); + else { + auto graph_k = Get>(currentLevel, "Graph"); + graph = graph_k->copyToHost(); + } + RCP A = Get>(currentLevel, "A"); // Setup aggregates & aggStat objects RCP aggregates = rcp(new Aggregates(*graph)); @@ -161,8 +168,8 @@ void NotayAggregationFactory::Build(L if (ordering == O_RANDOM) MueLu::NotayUtils::RandomReorder(orderingVector); else if (ordering == O_CUTHILL_MCKEE) { - RCP > rcmVector = MueLu::Utilities::CuthillMcKee(*A); - auto localVector = rcmVector->getData(0); + RCP> rcmVector = MueLu::Utilities::CuthillMcKee(*A); + auto localVector = rcmVector->getData(0); for (LO i = 0; i < numRows; i++) orderingVector[i] = localVector[i]; } @@ -198,7 +205,7 @@ void NotayAggregationFactory::Build(L // Directly compute rowsum from A, rather than coarseA row_sum_type rowSum("rowSum", numLocalAggregates); { - std::vector > agg2vertex(numLocalAggregates); + std::vector> agg2vertex(numLocalAggregates); auto vertex2AggId = aggregates->GetVertex2AggId()->getData(0); for (LO i = 0; i < (LO)numRows; i++) { if (aggStat[i] != AGGREGATED) @@ -248,8 +255,8 @@ void NotayAggregationFactory::Build(L if (ordering == O_RANDOM) MueLu::NotayUtils::RandomReorder(localOrderingVector); else if (ordering == O_CUTHILL_MCKEE) { - RCP > rcmVector = MueLu::Utilities::CuthillMcKee(*A); - auto localVector = rcmVector->getData(0); + RCP> rcmVector = MueLu::Utilities::CuthillMcKee(*A); + auto localVector = rcmVector->getData(0); for (LO i = 0; i < numRows; i++) localOrderingVector[i] = localVector[i]; }