Skip to content

Commit

Permalink
Merge Pull Request #9739 from rppawlo/Trilinos/phalanx-const-nonconst
Browse files Browse the repository at this point in the history
Automatically Merged using Trilinos Pull Request AutoTester
PR Title: Phalanx: add conversion utility
PR Author: rppawlo
  • Loading branch information
trilinos-autotester authored Sep 24, 2021
2 parents 57d48c3 + 6cb3075 commit 9e330c6
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 26 deletions.
19 changes: 14 additions & 5 deletions packages/panzer/adapters-stk/example/main_driver/main_driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include "Teuchos_ConfigDefs.hpp"
#include "Teuchos_RCP.hpp"
#include "Teuchos_TimeMonitor.hpp"
#include "Teuchos_StackedTimer.hpp"
#include "Teuchos_DefaultComm.hpp"
#include "Teuchos_CommHelpers.hpp"
#include "Teuchos_GlobalMPISession.hpp"
Expand Down Expand Up @@ -99,11 +100,8 @@ int main(int argc, char *argv[])
#endif

try {

Teuchos::RCP<Teuchos::Time> total_time =
Teuchos::TimeMonitor::getNewTimer("User App: Total Time");

Teuchos::TimeMonitor timer(*total_time);
const auto stackedTimer = Teuchos::rcp(new Teuchos::StackedTimer("Panzer Main Driver"));
Teuchos::TimeMonitor::setStackedTimer(stackedTimer);

Teuchos::RCP<const Teuchos::Comm<int> > comm = Teuchos::DefaultComm<int>::getComm();

Expand Down Expand Up @@ -393,6 +391,17 @@ int main(int argc, char *argv[])
}
}
}

stackedTimer->stopBaseTimer();
{
Teuchos::StackedTimer::OutputOptions options;
options.output_fraction = true;
options.output_minmax = true;
options.output_histogram = true;
options.num_histogram = 5;
stackedTimer->report(std::cout, Teuchos::DefaultComm<int>::getComm(), options);
}

}
catch (std::exception& e) {
*out << "*********** Caught Exception: Begin Error Report ***********" << std::endl;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#ifndef PHALANX_GET_NON_CONST_DYN_RANK_VIEW_FROM_CONST_MDFIELD_HPP
#define PHALANX_GET_NON_CONST_DYN_RANK_VIEW_FROM_CONST_MDFIELD_HPP

#include "Phalanx_MDField.hpp"
#include "Kokkos_DynRankView.hpp"
#include "Teuchos_Assert.hpp"

namespace PHX {

template<typename Scalar,typename...Props>
Kokkos::DynRankView<Scalar,typename PHX::DevLayout<Scalar>::type,Kokkos::MemoryUnmanaged>
getNonConstDynRankViewFromConstMDField(const PHX::MDField<const Scalar,Props...>& f) {

using drv_type = Kokkos::DynRankView<Scalar,typename PHX::DevLayout<Scalar>::type,Kokkos::MemoryUnmanaged>;
using nonconst_data_type = typename Sacado::ScalarType< typename drv_type::value_type >::type*;
const int rank = f.rank();
Kokkos::DynRankView<Scalar,typename PHX::DevLayout<Scalar>::type,Kokkos::MemoryUnmanaged> tmp;

#ifdef PHX_DEBUG
TEUCHOS_ASSERT( (rank > 0) && (rank < 6) );
#endif

if (Sacado::IsFad<Scalar>::value) {
const int num_derivatives = Kokkos::dimension_scalar(f.get_static_view());
if (rank==1)
tmp = Kokkos::DynRankView<Scalar,typename PHX::DevLayout<Scalar>::type,Kokkos::MemoryUnmanaged>(const_cast<nonconst_data_type>(f.get_static_view().data()),f.extent(0),num_derivatives);
else if (rank==2)
tmp = Kokkos::DynRankView<Scalar,typename PHX::DevLayout<Scalar>::type,Kokkos::MemoryUnmanaged>(const_cast<nonconst_data_type>(f.get_static_view().data()),f.extent(0),f.extent(1),num_derivatives);
else if (rank==3)
tmp = Kokkos::DynRankView<Scalar,typename PHX::DevLayout<Scalar>::type,Kokkos::MemoryUnmanaged>(const_cast<nonconst_data_type>(f.get_static_view().data()),f.extent(0),f.extent(1),f.extent(2),num_derivatives);
else if (rank==4)
tmp = Kokkos::DynRankView<Scalar,typename PHX::DevLayout<Scalar>::type,Kokkos::MemoryUnmanaged>(const_cast<nonconst_data_type>(f.get_static_view().data()),f.extent(0),f.extent(1),f.extent(2),f.extent(3),num_derivatives);
else if (rank==5)
tmp = Kokkos::DynRankView<Scalar,typename PHX::DevLayout<Scalar>::type,Kokkos::MemoryUnmanaged>(const_cast<nonconst_data_type>(f.get_static_view().data()),f.extent(0),f.extent(1),f.extent(2),f.extent(3),f.extent(4),num_derivatives);
}
else {
if (rank==1)
tmp = Kokkos::DynRankView<Scalar,typename PHX::DevLayout<Scalar>::type,Kokkos::MemoryUnmanaged>(const_cast<nonconst_data_type>(f.get_static_view().data()),f.extent(0));
else if (rank==2)
tmp = Kokkos::DynRankView<Scalar,typename PHX::DevLayout<Scalar>::type,Kokkos::MemoryUnmanaged>(const_cast<nonconst_data_type>(f.get_static_view().data()),f.extent(0),f.extent(1));
else if (rank==3)
tmp = Kokkos::DynRankView<Scalar,typename PHX::DevLayout<Scalar>::type,Kokkos::MemoryUnmanaged>(const_cast<nonconst_data_type>(f.get_static_view().data()),f.extent(0),f.extent(1),f.extent(2));
else if (rank==4)
tmp = Kokkos::DynRankView<Scalar,typename PHX::DevLayout<Scalar>::type,Kokkos::MemoryUnmanaged>(const_cast<nonconst_data_type>(f.get_static_view().data()),f.extent(0),f.extent(1),f.extent(2),f.extent(3));
else if (rank==5)
tmp = Kokkos::DynRankView<Scalar,typename PHX::DevLayout<Scalar>::type,Kokkos::MemoryUnmanaged>(const_cast<nonconst_data_type>(f.get_static_view().data()),f.extent(0),f.extent(1),f.extent(2),f.extent(3),f.extent(4));
}

return tmp;
}

} // namespace PHX

#endif
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#include "Phalanx_ExtentTraits.hpp"
#include "Phalanx_MDField.hpp"
#include "Teuchos_UnitTestHarness.hpp"
#include "Teuchos_TypeNameTraits.hpp"
#include "Kokkos_DynRankView.hpp"
#include "Phalanx_GetNonConstDynRankViewFromConstMDField.hpp"

// This test demonstrates how to get a nonconst DynRankView from a
// const MDField for double and FAD scalar types.
Expand All @@ -24,9 +23,6 @@ const int num_pts = 8;
const int num_equations = 32;
const int num_derivatives = 4;

template<typename T> struct remove_low_level_const {using type = T;};
template<typename T> struct remove_low_level_const<T const *> {using type = T*;};

namespace {
// function input for a,b,c are all CONST
template<typename Scalar>
Expand All @@ -36,24 +32,11 @@ namespace {
const_mdfield<Scalar>& b,
const_mdfield<Scalar>& c)
{
// Demonstrate getting a nonconst view from a const view
using data_type = decltype(c.get_static_view().data());
using nonconst_data_type = typename remove_low_level_const<data_type>::type;
std::cout << "\ndata_type = " << Teuchos::demangleName(typeid(data_type).name()) << std::endl;
std::cout << "nonconst_data_type = " << Teuchos::demangleName(typeid(nonconst_data_type).name()) << std::endl;

// NOTE: the FAD types need the DevLayout for contiguous mapping
// on cuda. This is embedded in the MDField.
Kokkos::DynRankView<Scalar,typename PHX::DevLayout<Scalar>::type,Kokkos::MemoryUnmanaged> tmp_c;
if (Sacado::IsFad<Scalar>::value)
tmp_c = Kokkos::DynRankView<Scalar,typename PHX::DevLayout<Scalar>::type,Kokkos::MemoryUnmanaged>(const_cast<nonconst_data_type>(c.get_static_view().data()),c.extent(0),c.extent(1),c.extent(2),num_derivatives);
else
tmp_c = Kokkos::DynRankView<Scalar,typename PHX::DevLayout<Scalar>::type,Kokkos::MemoryUnmanaged>(const_cast<nonconst_data_type>(c.get_static_view().data()),c.extent(0),c.extent(1),c.extent(2));

std::cout << "ext_0=" << c.extent(0) << ", ext_1=" << c.extent(1) << ", ext_2=" << c.extent(2) << ", ext_3=" << c.extent(3) << std::endl;

auto tmp_a = a.get_static_view();
auto tmp_b = b.get_static_view();
// Demonstrate getting a nonconst view from a const view
auto tmp_c = PHX::getNonConstDynRankViewFromConstMDField<Scalar>(c);

Kokkos::MDRangePolicy<exec_t,Kokkos::Rank<3>> policy({0,0,0},{num_cells,num_pts,num_equations});
Kokkos::parallel_for("use non-const DynRankView from const View",policy,KOKKOS_LAMBDA (const int cell,const int pt,const int eq) {
tmp_c(cell,pt,eq) = tmp_a(cell,pt,eq) + tmp_b(cell,pt,eq);
Expand All @@ -62,6 +45,7 @@ namespace {
}
}

// Check double values to make sure layout hasn't been switched.
TEUCHOS_UNIT_TEST(NonConstDynRankViewFromView,double) {
using ScalarType = double;
non_const_mdfield<ScalarType> a("a","layout",num_cells,num_pts,num_equations);
Expand All @@ -80,6 +64,7 @@ TEUCHOS_UNIT_TEST(NonConstDynRankViewFromView,double) {
}
}

// Check DFAD values to make sure layout hasn't been switched.
TEUCHOS_UNIT_TEST(NonConstDynRankViewFromView,FAD) {
using RealType = double;
using ScalarType = Sacado::Fad::DFad<RealType>;
Expand Down

0 comments on commit 9e330c6

Please sign in to comment.