Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement getHostBasis() in newer Basis subclasses #8811

Merged
merged 10 commits into from
Mar 1, 2021
15 changes: 10 additions & 5 deletions packages/intrepid2/src/Discretization/Basis/Intrepid2_Basis.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ class Basis;
template <typename DeviceType = void, typename OutputType = double, typename PointType = double>
using BasisPtr = Teuchos::RCP<Basis<DeviceType,OutputType,PointType> >;

/** \brief Pointer to a Basis whose device type is on the host (Kokkos::HostSpace::device_type), allowing host access to input and output views, and ensuring host execution of basis evaluation.
*/
template <typename OutputType = double, typename PointType = double>
using HostBasisPtr = BasisPtr<typename Kokkos::HostSpace::device_type, OutputType, PointType>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you do defaulting for the output type and point type ? This should match to this output type and point type. It does not need to default with double.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kyungjoo-kim, I followed the pattern of the BasisPtr typedef that is immediately above this typedef. It's true it doesn't need to default to double, but I think it's good to be consistent between the two typedefs.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is internal typedef of the basis. Defaulting the template arguments here and in the above does not make sense to me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kyungjoo-kim, ah, I see the confusion. Actually, these typedefs are not internal to the Basis class.


/** \class Intrepid2::Basis
\brief An abstract base class that defines interface for concrete basis implementations for
Finite Element (FEM) and Finite Volume/Finite Difference (FVD) discrete spaces.
Expand Down Expand Up @@ -898,11 +903,11 @@ using BasisPtr = Teuchos::RCP<Basis<DeviceType,OutputType,PointType> >;
">>> ERROR (Basis::getSubCellRefBasis): this method is not supported or should be overridden accordingly by derived classes.");
}

/** \brief creates and returns a basis object allocated on host.
\return pointer to a basis allocated on host.
*/
virtual BasisPtr<typename Kokkos::HostSpace::device_type, OutputValueType, PointValueType>
/** \brief Creates and returns a Basis object whose DeviceType template argument is Kokkos::HostSpace::device_type, but is otherwise identical to this.

\return Pointer to the new Basis object.
*/
virtual HostBasisPtr<OutputValueType, PointValueType>
getHostBasis() const {
INTREPID2_TEST_FOR_EXCEPTION( true, std::logic_error,
">>> ERROR (Basis::getHostBasis): this method is not supported or should be overridden accordingly by derived classes.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ namespace Intrepid2
using OutputValueType = typename LineBasisHGRAD::OutputValueType;
using PointValueType = typename LineBasisHGRAD::PointValueType;

using Basis = ::Intrepid2::Basis<ExecutionSpace,OutputValueType,PointValueType>;
using Basis = typename LineBasisHGRAD::BasisBase;
using BasisPtr = Teuchos::RCP<Basis>;

// line bases
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ namespace Intrepid2
{
template<class HGRAD_LINE, class HVOL_LINE>
class Basis_Derived_HCURL_Family1_HEX
: public Basis_TensorBasis3<HVOL_LINE,
HGRAD_LINE,
HGRAD_LINE>
: public Basis_TensorBasis3<typename HGRAD_LINE::BasisBase>
{
public:
using OutputViewType = typename HGRAD_LINE::OutputViewType;
Expand All @@ -79,7 +77,7 @@ namespace Intrepid2
using LineGradBasis = HGRAD_LINE;
using LineVolBasis = HVOL_LINE;

using TensorBasis3 = Basis_TensorBasis3<LineVolBasis, LineGradBasis, LineGradBasis>;
using TensorBasis3 = Basis_TensorBasis3<typename HGRAD_LINE::BasisBase>;
public:
/** \brief Constructor.
\param [in] polyOrder_x - the polynomial order in the x dimension.
Expand Down Expand Up @@ -224,9 +222,7 @@ namespace Intrepid2

template<class HGRAD_LINE, class HVOL_LINE>
class Basis_Derived_HCURL_Family2_HEX
: public Basis_TensorBasis3<HGRAD_LINE,
HVOL_LINE,
HGRAD_LINE>
: public Basis_TensorBasis3<typename HGRAD_LINE::BasisBase>
{
public:
using OutputViewType = typename HGRAD_LINE::OutputViewType;
Expand All @@ -236,7 +232,7 @@ namespace Intrepid2
using LineGradBasis = HGRAD_LINE;
using LineVolBasis = HVOL_LINE;

using TensorBasis3 = Basis_TensorBasis3<LineGradBasis, LineVolBasis, LineGradBasis>;
using TensorBasis3 = Basis_TensorBasis3<typename HGRAD_LINE::BasisBase>;
public:
/** \brief Constructor.
\param [in] polyOrder_x - the polynomial order in the x dimension.
Expand Down Expand Up @@ -388,9 +384,7 @@ namespace Intrepid2

template<class HGRAD_LINE, class HVOL_LINE>
class Basis_Derived_HCURL_Family3_HEX
: public Basis_TensorBasis3<HGRAD_LINE,
HGRAD_LINE,
HVOL_LINE>
: public Basis_TensorBasis3<typename HGRAD_LINE::BasisBase>
{
using OutputViewType = typename HGRAD_LINE::OutputViewType;
using PointViewType = typename HGRAD_LINE::PointViewType ;
Expand All @@ -399,7 +393,7 @@ namespace Intrepid2
using LineGradBasis = HGRAD_LINE;
using LineVolBasis = HVOL_LINE;

using TensorBasis3 = Basis_TensorBasis3<LineGradBasis, LineGradBasis, LineVolBasis>;
using TensorBasis3 = Basis_TensorBasis3<typename HGRAD_LINE::BasisBase>;
public:
/** \brief Constructor.
\param [in] polyOrder_x - the polynomial order in the x dimension.
Expand Down Expand Up @@ -540,11 +534,11 @@ namespace Intrepid2

template<class HGRAD_LINE, class HVOL_LINE>
class Basis_Derived_HCURL_Family1_Family2_HEX
: public Basis_DirectSumBasis <typename HGRAD_LINE::ExecutionSpace, typename HGRAD_LINE::OutputValueType, typename HGRAD_LINE::PointValueType>
: public Basis_DirectSumBasis <typename HGRAD_LINE::BasisBase>
{
using Family1 = Basis_Derived_HCURL_Family1_HEX<HGRAD_LINE, HVOL_LINE>;
using Family2 = Basis_Derived_HCURL_Family2_HEX<HGRAD_LINE, HVOL_LINE>;
using DirectSumBasis = Basis_DirectSumBasis <typename HGRAD_LINE::ExecutionSpace, typename HGRAD_LINE::OutputValueType, typename HGRAD_LINE::PointValueType>;
using DirectSumBasis = Basis_DirectSumBasis <typename HGRAD_LINE::BasisBase>;
public:
/** \brief Constructor.
\param [in] polyOrder_x - the polynomial order in the x dimension.
Expand All @@ -560,11 +554,11 @@ namespace Intrepid2

template<class HGRAD_LINE, class HVOL_LINE>
class Basis_Derived_HCURL_HEX
: public Basis_DirectSumBasis <typename HGRAD_LINE::ExecutionSpace, typename HGRAD_LINE::OutputValueType, typename HGRAD_LINE::PointValueType>
: public Basis_DirectSumBasis <typename HGRAD_LINE::BasisBase>
{
using Family12 = Basis_Derived_HCURL_Family1_Family2_HEX<HGRAD_LINE, HVOL_LINE>;
using Family3 = Basis_Derived_HCURL_Family3_HEX <HGRAD_LINE, HVOL_LINE>;
using DirectSumBasis = Basis_DirectSumBasis <typename HGRAD_LINE::ExecutionSpace, typename HGRAD_LINE::OutputValueType, typename HGRAD_LINE::PointValueType>;
using DirectSumBasis = Basis_DirectSumBasis <typename HGRAD_LINE::BasisBase>;

std::string name_;
ordinal_type order_x_;
Expand All @@ -577,6 +571,8 @@ namespace Intrepid2
using ExecutionSpace = typename HGRAD_LINE::ExecutionSpace;
using OutputValueType = typename HGRAD_LINE::OutputValueType;
using PointValueType = typename HGRAD_LINE::PointValueType;

using BasisBase = typename HGRAD_LINE::BasisBase;

/** \brief Constructor.
\param [in] polyOrder_x - the polynomial order in the x dimension.
Expand Down Expand Up @@ -631,7 +627,7 @@ namespace Intrepid2
\param [in] subCellOrd - position of the subCell among of the subCells having the same dimension
\return pointer to the subCell basis of dimension subCellDim and position subCellOrd
*/
BasisPtr<ExecutionSpace, OutputValueType, PointValueType>
Teuchos::RCP<BasisBase>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change meaningful ? This means that the Teuchos::RCP<BasisBase> should be the same as BasisPtr. That just means that the BasisBase should be defined and inherited from the Basis class. I am not sure how this improves things in your implementation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is meaningful. BasisBase may have as its first template argument a DeviceType that is distinct from the ExecutionSpace. If so, then BasisPtr<ExecutionSpace, OutputValueType, PointValueType> does not match the return type in the base class.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work in the virtual override context if you change ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not entirely sure what you are asking. If you are asking whether things compile if I don't make this change, the answer is no, they do not compile.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean that both returns teuchos rcp and the compiler may consider this class is overriding function. If they consider this function is different, the compiler may redirect the function to its super class interface. Previously, it probably does not compile because it uses execution space template argument instead of device argument. As you mentioned, it compiles but does it override the function correctly when the object is casted to its super class (in other words, does this make sure that it has the same function signature to override) ? Unit test will tell this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kyungjoo-kim, the method is marked as override. If the compiler considered this to be a different function from that in the base class, it would refuse to compile it.

getSubCellRefBasis(const ordinal_type subCellDim, const ordinal_type subCellOrd) const override{

using LineBasis = HVOL_LINE;
Expand Down Expand Up @@ -674,7 +670,18 @@ namespace Intrepid2

INTREPID2_TEST_FOR_EXCEPTION(true,std::invalid_argument,"Input parameters out of bounds");
}

/** \brief Creates and returns a Basis object whose DeviceType template argument is Kokkos::HostSpace::device_type, but is otherwise identical to this.
\return Pointer to the new Basis object.
*/
virtual HostBasisPtr<OutputValueType, PointValueType>
getHostBasis() const override {
using HostBasis = Basis_Derived_HCURL_HEX<typename HGRAD_LINE::HostBasis, typename HVOL_LINE::HostBasis>;

auto hostBasis = Teuchos::rcp(new HostBasis(order_x_, order_y_, order_z_, pointType_));

return hostBasis;
}
};
} // end namespace Intrepid2

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ namespace Intrepid2
{
template<class HGRAD_LINE, class HVOL_LINE>
class Basis_Derived_HCURL_Family1_QUAD
: public Basis_TensorBasis<typename HGRAD_LINE::ExecutionSpace, typename HGRAD_LINE::OutputValueType, typename HGRAD_LINE::PointValueType>
: public Basis_TensorBasis<typename HGRAD_LINE::BasisBase>
{
public:
using ExecutionSpace = typename HGRAD_LINE::ExecutionSpace;
Expand All @@ -77,10 +77,12 @@ namespace Intrepid2
using PointViewType = typename HGRAD_LINE::PointViewType ;
using ScalarViewType = typename HGRAD_LINE::ScalarViewType;

using BasisBase = typename HGRAD_LINE::BasisBase;

using LineGradBasis = HGRAD_LINE;
using LineHVolBasis = HVOL_LINE;

using TensorBasis = Basis_TensorBasis<ExecutionSpace, OutputValueType, PointValueType>;
using TensorBasis = Basis_TensorBasis<BasisBase>;
public:
/** \brief Constructor.
\param [in] polyOrder_x - the polynomial order in the x dimension.
Expand Down Expand Up @@ -191,7 +193,7 @@ namespace Intrepid2

template<class HGRAD_LINE, class HVOL_LINE>
class Basis_Derived_HCURL_Family2_QUAD
: public Basis_TensorBasis<typename HGRAD_LINE::ExecutionSpace, typename HGRAD_LINE::OutputValueType, typename HGRAD_LINE::PointValueType>
: public Basis_TensorBasis<typename HGRAD_LINE::BasisBase>
{

public:
Expand All @@ -206,7 +208,9 @@ namespace Intrepid2
using LineGradBasis = HGRAD_LINE;
using LineHVolBasis = HVOL_LINE;

using TensorBasis = Basis_TensorBasis<ExecutionSpace, OutputValueType, PointValueType>;
using BasisBase = typename HGRAD_LINE::BasisBase;

using TensorBasis = Basis_TensorBasis<BasisBase>;

/** \brief Constructor.
\param [in] polyOrder_x - the polynomial order in the x dimension.
Expand Down Expand Up @@ -317,11 +321,13 @@ namespace Intrepid2

template<class HGRAD_LINE, class HVOL_LINE>
class Basis_Derived_HCURL_QUAD
: public Basis_DirectSumBasis <typename HGRAD_LINE::ExecutionSpace, typename HGRAD_LINE::OutputValueType, typename HGRAD_LINE::PointValueType>
: public Basis_DirectSumBasis <typename HGRAD_LINE::BasisBase>
{
using Family1 = Basis_Derived_HCURL_Family1_QUAD<HGRAD_LINE, HVOL_LINE>;
using Family2 = Basis_Derived_HCURL_Family2_QUAD<HGRAD_LINE, HVOL_LINE>;
using DirectSumBasis = Basis_DirectSumBasis <typename HGRAD_LINE::ExecutionSpace, typename HGRAD_LINE::OutputValueType, typename HGRAD_LINE::PointValueType>;
using DirectSumBasis = Basis_DirectSumBasis <typename HGRAD_LINE::BasisBase>;

using BasisBase = typename HGRAD_LINE::BasisBase;

protected:
std::string name_;
Expand Down Expand Up @@ -388,7 +394,7 @@ namespace Intrepid2
\param [in] subCellOrd - position of the subCell among of the subCells having the same dimension
\return pointer to the subCell basis of dimension subCellDim and position subCellOrd
*/
BasisPtr<ExecutionSpace, OutputValueType, PointValueType>
Teuchos::RCP<BasisBase>
getSubCellRefBasis(const ordinal_type subCellDim, const ordinal_type subCellOrd) const override{
if(subCellDim == 1) {
switch(subCellOrd) {
Expand All @@ -404,6 +410,18 @@ namespace Intrepid2
INTREPID2_TEST_FOR_EXCEPTION(true,std::invalid_argument,"Input parameters out of bounds");
}

/** \brief Creates and returns a Basis object whose DeviceType template argument is Kokkos::HostSpace::device_type, but is otherwise identical to this.
\return Pointer to the new Basis object.
*/
virtual HostBasisPtr<OutputValueType, PointValueType>
getHostBasis() const override {
using HostBasis = Basis_Derived_HCURL_QUAD<typename HGRAD_LINE::HostBasis, typename HVOL_LINE::HostBasis>;

auto hostBasis = Teuchos::rcp(new HostBasis(order_x_, order_y_, pointType_));

return hostBasis;
}
};
} // end namespace Intrepid2

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ namespace Intrepid2
template<class HGRAD_LINE, class HVOL_LINE>
class Basis_Derived_HDIV_Family1_HEX
:
public Basis_TensorBasis3<HGRAD_LINE, HVOL_LINE, HVOL_LINE>
public Basis_TensorBasis3<typename HGRAD_LINE::BasisBase>
{
public:
using OutputViewType = typename HGRAD_LINE::OutputViewType;
Expand All @@ -81,7 +81,7 @@ namespace Intrepid2
using LineGradBasis = HGRAD_LINE;
using LineHVolBasis = HVOL_LINE;

using TensorBasis3 = Basis_TensorBasis3<LineGradBasis, LineHVolBasis, LineHVolBasis>;
using TensorBasis3 = Basis_TensorBasis3<typename HGRAD_LINE::BasisBase>;
public:
/** \brief Constructor.
\param [in] polyOrder_x - the polynomial order in the x dimension.
Expand Down Expand Up @@ -205,7 +205,7 @@ namespace Intrepid2
template<class HGRAD_LINE, class HVOL_LINE>
class Basis_Derived_HDIV_Family2_HEX
:
public Basis_TensorBasis3<HVOL_LINE, HGRAD_LINE, HVOL_LINE>
public Basis_TensorBasis3<typename HGRAD_LINE::BasisBase>
{
public:
using OutputViewType = typename HGRAD_LINE::OutputViewType;
Expand All @@ -215,7 +215,7 @@ namespace Intrepid2
using LineGradBasis = HGRAD_LINE;
using LineHVolBasis = HVOL_LINE;

using TensorBasis3 = Basis_TensorBasis3<LineHVolBasis, LineGradBasis, LineHVolBasis>;
using TensorBasis3 = Basis_TensorBasis3<typename HGRAD_LINE::BasisBase>;
public:
/** \brief Constructor.
\param [in] polyOrder_x - the polynomial order in the x dimension.
Expand Down Expand Up @@ -344,7 +344,7 @@ namespace Intrepid2

template<class HGRAD_LINE, class HVOL_LINE>
class Basis_Derived_HDIV_Family3_HEX
: public Basis_TensorBasis3<HVOL_LINE, HVOL_LINE, HGRAD_LINE>
: public Basis_TensorBasis3<typename HGRAD_LINE::BasisBase>
{
public:
using OutputViewType = typename HGRAD_LINE::OutputViewType;
Expand All @@ -354,7 +354,7 @@ namespace Intrepid2
using LineGradBasis = HGRAD_LINE;
using LineHVolBasis = HVOL_LINE;

using TensorBasis3 = Basis_TensorBasis3<LineHVolBasis, LineHVolBasis, LineGradBasis>;
using TensorBasis3 = Basis_TensorBasis3<typename HGRAD_LINE::BasisBase>;
public:
/** \brief Constructor.
\param [in] polyOrder_x - the polynomial order in the x dimension.
Expand Down Expand Up @@ -483,11 +483,11 @@ namespace Intrepid2
// which is to say that we go 3,1,2.
template<class HGRAD_LINE, class HVOL_LINE>
class Basis_Derived_HDIV_Family3_Family1_HEX
: public Basis_DirectSumBasis <typename HGRAD_LINE::ExecutionSpace, typename HGRAD_LINE::OutputValueType, typename HGRAD_LINE::PointValueType>
: public Basis_DirectSumBasis <typename HGRAD_LINE::BasisBase>
{
using Family3 = Basis_Derived_HDIV_Family3_HEX<HGRAD_LINE, HVOL_LINE>;
using Family1 = Basis_Derived_HDIV_Family1_HEX<HGRAD_LINE, HVOL_LINE>;
using DirectSumBasis = Basis_DirectSumBasis<typename HGRAD_LINE::ExecutionSpace, typename HGRAD_LINE::OutputValueType, typename HGRAD_LINE::PointValueType>;
using DirectSumBasis = Basis_DirectSumBasis<typename HGRAD_LINE::BasisBase>;
public:
/** \brief Constructor.
\param [in] polyOrder_x - the polynomial order in the x dimension.
Expand All @@ -505,11 +505,11 @@ namespace Intrepid2

template<class HGRAD_LINE, class HVOL_LINE>
class Basis_Derived_HDIV_HEX
: public Basis_DirectSumBasis <typename HGRAD_LINE::ExecutionSpace, typename HGRAD_LINE::OutputValueType, typename HGRAD_LINE::PointValueType>
: public Basis_DirectSumBasis <typename HGRAD_LINE::BasisBase>
{
using Family31 = Basis_Derived_HDIV_Family3_Family1_HEX<HGRAD_LINE, HVOL_LINE>;
using Family2 = Basis_Derived_HDIV_Family2_HEX <HGRAD_LINE, HVOL_LINE>;
using DirectSumBasis = Basis_DirectSumBasis<typename HGRAD_LINE::ExecutionSpace, typename HGRAD_LINE::OutputValueType, typename HGRAD_LINE::PointValueType>;
using DirectSumBasis = Basis_DirectSumBasis<typename HGRAD_LINE::BasisBase>;

std::string name_;
ordinal_type order_x_;
Expand All @@ -521,6 +521,8 @@ namespace Intrepid2
using ExecutionSpace = typename HGRAD_LINE::ExecutionSpace;
using OutputValueType = typename HGRAD_LINE::OutputValueType;
using PointValueType = typename HGRAD_LINE::PointValueType;

using BasisBase = typename HGRAD_LINE::BasisBase;

/** \brief Constructor.
\param [in] polyOrder_x - the polynomial order in the x dimension.
Expand Down Expand Up @@ -575,8 +577,8 @@ namespace Intrepid2
\param [in] subCellOrd - position of the subCell among of the subCells having the same dimension
\return pointer to the subCell basis of dimension subCellDim and position subCellOrd
*/
BasisPtr<ExecutionSpace, OutputValueType, PointValueType>
getSubCellRefBasis(const ordinal_type subCellDim, const ordinal_type subCellOrd) const override{
Teuchos::RCP<BasisBase>
getSubCellRefBasis(const ordinal_type subCellDim, const ordinal_type subCellOrd) const override {

using QuadBasis = Basis_Derived_HVOL_QUAD<HVOL_LINE>;

Expand All @@ -599,7 +601,19 @@ namespace Intrepid2

INTREPID2_TEST_FOR_EXCEPTION(true,std::invalid_argument,"Input parameters out of bounds");
}


/** \brief Creates and returns a Basis object whose DeviceType template argument is Kokkos::HostSpace::device_type, but is otherwise identical to this.

\return Pointer to the new Basis object.
*/
virtual HostBasisPtr<OutputValueType, PointValueType>
getHostBasis() const override {
using HostBasis = Basis_Derived_HDIV_HEX<typename HGRAD_LINE::HostBasis, typename HVOL_LINE::HostBasis>;

auto hostBasis = Teuchos::rcp(new HostBasis(order_x_, order_y_, order_z_, pointType_));

return hostBasis;
}
};
} // end namespace Intrepid2

Expand Down
Loading