Skip to content

Commit

Permalink
Merge Pull Request #7477 from etphipp/Trilinos/sacado_fixes
Browse files Browse the repository at this point in the history
Automatically Merged using Trilinos Pull Request AutoTester
PR Title: Sacado:  Several changes for Sacado
PR Author: etphipp
  • Loading branch information
trilinos-autotester authored Jun 7, 2020
2 parents 1691771 + 3cf736d commit 0749fa0
Show file tree
Hide file tree
Showing 33 changed files with 556 additions and 2,585 deletions.
7 changes: 7 additions & 0 deletions packages/sacado/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ TRIBITS_ADD_OPTION_AND_DEFINE(
ON
)

TRIBITS_ADD_OPTION_AND_DEFINE(
${PACKAGE_NAME}_SFAD_INIT_DEFAULT_CONSTRUCTOR
SACADO_SFAD_INIT_DEFAULT_CONSTRUCTOR
"Force SFad (in the new design) to initialize value and derivative components in the default constructor (adds additional runtime cost, but protects against uninitialized use)."
OFF
)

TRIBITS_ADD_OPTION_AND_DEFINE(
${PACKAGE_NAME}_ENABLE_HIERARCHICAL
SACADO_VIEW_CUDA_HIERARCHICAL
Expand Down
3 changes: 3 additions & 0 deletions packages/sacado/cmake/Sacado_config.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,6 @@

/* Define if want to make the new Fad design the default, replacing the old one */
#cmakedefine SACADO_NEW_FAD_DESIGN_IS_DEFAULT

/* Force SFad (in the new design) to initialize value and derivative components in the default constructor (adds additional runtime cost, but protects against uninitialized use). */
#cmakedefine SACADO_SFAD_INIT_DEFAULT_CONSTRUCTOR
4 changes: 4 additions & 0 deletions packages/sacado/src/Sacado_CacheFad_Ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,10 @@ FAD_UNARYOP_MACRO(sqrt,
SqrtOp,
a = value_type(1)/(value_type(2)*std::sqrt(v)),
std::sqrt(v))
FAD_UNARYOP_MACRO(safe_sqrt,
SafeSqrtOp,
a = (v == value_type(0.0) ? value_type(0.0) : value_type(value_type(1)/(value_type(2)*std::sqrt(v)))),
std::sqrt(v))
FAD_UNARYOP_MACRO(cos,
CosOp,
a = -std::sin(v),
Expand Down
4 changes: 4 additions & 0 deletions packages/sacado/src/Sacado_ConfigDefs.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ Questions? Contact David M. Gay ([email protected]) or Eric T. Phipps
#define KOKKOS_FUNCTION /* */
#endif

#ifndef KOKKOS_DEFAULTED_FUNCTION
#define KOKKOS_DEFAULTED_FUNCTION /* */
#endif

#ifndef KOKKOS_INLINE_FUNCTION
#define KOKKOS_INLINE_FUNCTION inline
#endif
Expand Down
4 changes: 4 additions & 0 deletions packages/sacado/src/Sacado_ELRCacheFad_Ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,10 @@ FAD_UNARYOP_MACRO(sqrt,
SqrtOp,
a = scalar_type(1.0)/(scalar_type(2.0)*std::sqrt(v)),
std::sqrt(v))
FAD_UNARYOP_MACRO(safe_sqrt,
SafeSqrtOp,
a = (v == value_type(0.0) ? value_type(0.0) : value_type(scalar_type(1.0)/(scalar_type(2.0)*std::sqrt(v)))),
std::sqrt(v))
FAD_UNARYOP_MACRO(cos,
CosOp,
a = -std::sin(v),
Expand Down
7 changes: 7 additions & 0 deletions packages/sacado/src/Sacado_ELRFad_Ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,13 @@ FAD_UNARYOP_MACRO(sqrt,
false,
expr.dx(i)/(value_type(2)* std::sqrt(expr.val())),
expr.fastAccessDx(i)/(value_type(2)* std::sqrt(expr.val())))
FAD_UNARYOP_MACRO(safe_sqrt,
SafeSqrtOp,
std::sqrt(expr.val()),
expr.val() == value_type(0.0) ? value_type(0.0) : value_type(value_type(0.5)*bar/std::sqrt(expr.val())),
false,
expr.val() == value_type(0.0) ? value_type(0.0) : value_type(expr.dx(i)/(value_type(2)*std::sqrt(expr.val()))),
expr.val() == value_type(0.0) ? value_type(0.0) : value_type(expr.fastAccessDx(i)/(value_type(2)*std::sqrt(expr.val()))))
FAD_UNARYOP_MACRO(cos,
CosOp,
std::cos(expr.val()),
Expand Down
138 changes: 138 additions & 0 deletions packages/sacado/src/Sacado_Fad_Ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,144 @@ FAD_UNARYOP_MACRO(cbrt,

#undef FAD_UNARYOP_MACRO

// Special handling for safe_sqrt() to provide specializations of SafeSqrtOp for
// "simd" value types that use if_then_else(). The only reason for not using
// if_then_else() always is to avoid evaluating the derivative if the value is
// zero to avoid throwing FPEs.
namespace Sacado {
namespace Fad {

template <typename ExprT, bool is_simd>
class SafeSqrtOp {};

template <typename ExprT>
struct ExprSpec< SafeSqrtOp<ExprT> > {
typedef typename ExprSpec<ExprT>::type type;
};

//
// Implementation for simd type using if_then_else()
//
template <typename ExprT>
class Expr< SafeSqrtOp<ExprT,true>,ExprSpecDefault > {
public:

typedef typename ExprT::value_type value_type;
typedef typename ExprT::scalar_type scalar_type;
typedef typename ExprT::base_expr_type base_expr_type;

KOKKOS_INLINE_FUNCTION
explicit Expr(const ExprT& expr_) : expr(expr_) {}

KOKKOS_INLINE_FUNCTION
int size() const { return expr.size(); }

KOKKOS_INLINE_FUNCTION
bool hasFastAccess() const { return expr.hasFastAccess(); }

KOKKOS_INLINE_FUNCTION
bool isPassive() const { return expr.isPassive();}

KOKKOS_INLINE_FUNCTION
bool updateValue() const { return expr.updateValue(); }

KOKKOS_INLINE_FUNCTION
void cache() const {}

KOKKOS_INLINE_FUNCTION
value_type val() const {
using std::sqrt;
return sqrt(expr.val());
}

KOKKOS_INLINE_FUNCTION
value_type dx(int i) const {
using std::sqrt;
return if_then_else(
expr.val() == value_type(0.0), value_type(0.0),
value_type(expr.dx(i)/(value_type(2)*sqrt(expr.val()))));
}

KOKKOS_INLINE_FUNCTION
value_type fastAccessDx(int i) const {
using std::sqrt;
return if_then_else(
expr.val() == value_type(0.0), value_type(0.0),
value_type(expr.fastAccessDx(i)/(value_type(2)*sqrt(expr.val()))));
}

protected:

const ExprT& expr;
};

//
// Specialization for scalar types using ternary operator
//
template <typename ExprT>
class Expr< SafeSqrtOp<ExprT,false>,ExprSpecDefault > {
public:

typedef typename ExprT::value_type value_type;
typedef typename ExprT::scalar_type scalar_type;
typedef typename ExprT::base_expr_type base_expr_type;

KOKKOS_INLINE_FUNCTION
explicit Expr(const ExprT& expr_) : expr(expr_) {}

KOKKOS_INLINE_FUNCTION
int size() const { return expr.size(); }

KOKKOS_INLINE_FUNCTION
bool hasFastAccess() const { return expr.hasFastAccess(); }

KOKKOS_INLINE_FUNCTION
bool isPassive() const { return expr.isPassive();}

KOKKOS_INLINE_FUNCTION
bool updateValue() const { return expr.updateValue(); }

KOKKOS_INLINE_FUNCTION
void cache() const {}

KOKKOS_INLINE_FUNCTION
value_type val() const {
using std::sqrt;
return sqrt(expr.val());
}

KOKKOS_INLINE_FUNCTION
value_type dx(int i) const {
using std::sqrt;
return expr.val() == value_type(0.0) ? value_type(0.0) :
value_type(expr.dx(i)/(value_type(2)*sqrt(expr.val())));
}

KOKKOS_INLINE_FUNCTION
value_type fastAccessDx(int i) const {
using std::sqrt;
return expr.val() == value_type(0.0) ? value_type(0.0) :
value_type(expr.fastAccessDx(i)/(value_type(2)*sqrt(expr.val())));
}

protected:

const ExprT& expr;
};

template <typename T>
KOKKOS_INLINE_FUNCTION
Expr< SafeSqrtOp< Expr<T> > >
safe_sqrt (const Expr<T>& expr)
{
typedef SafeSqrtOp< Expr<T> > expr_t;

return Expr<expr_t>(expr);
}
}

}

#define FAD_BINARYOP_MACRO(OPNAME,OP,USING,VALUE,DX,FASTACCESSDX,VAL_CONST_DX_1,VAL_CONST_DX_2,CONST_DX_1,CONST_DX_2,CONST_FASTACCESSDX_1,CONST_FASTACCESSDX_2) \
namespace Sacado { \
namespace Fad { \
Expand Down
2 changes: 2 additions & 0 deletions packages/sacado/src/Sacado_Fad_Ops_Fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ namespace Sacado {
#ifdef HAVE_SACADO_CXX11
template <typename ExprT> class CbrtOp;
#endif
template <typename ExprT, bool is_simd = IsSimdType<ExprT>::value>
class SafeSqrtOp;

template <typename ExprT1, typename ExprT2> class AdditionOp;
template <typename ExprT1, typename ExprT2> class SubtractionOp;
Expand Down
29 changes: 29 additions & 0 deletions packages/sacado/src/Sacado_MathFunctions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,35 @@ UNARYFUNC_MACRO(cbrt, CbrtOp)

#undef UNARYFUNC_MACRO

namespace Sacado {
namespace Fad {
template <typename T>
KOKKOS_INLINE_FUNCTION
Expr< SafeSqrtOp< Expr<T> > > safe_sqrt (const Expr<T>&);
}

namespace ELRFad {
template <typename T> class SafeSqrtOp;
template <typename T>
KOKKOS_INLINE_FUNCTION
Expr< SafeSqrtOp< Expr<T> > > safe_sqrt (const Expr<T>&);
}

namespace CacheFad {
template <typename T> class SafeSqrtOp;
template <typename T>
KOKKOS_INLINE_FUNCTION
Expr< SafeSqrtOp< Expr<T> > > safe_sqrt (const Expr<T>&);
}

namespace ELRCacheFad {
template <typename T> class SafeSqrtOp;
template <typename T>
KOKKOS_INLINE_FUNCTION
Expr< SafeSqrtOp< Expr<T> > > safe_sqrt (const Expr<T>&);
}
}

#define BINARYFUNC_MACRO(OP,FADOP) \
namespace Sacado { \
\
Expand Down
9 changes: 9 additions & 0 deletions packages/sacado/src/Sacado_cmath.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ namespace Sacado {
return cond ? a : b;
}

// Special version of sqrt(x) that avoids the NaN if x==0 in the derivative.
// The default implementation just calls the standard sqrt(x).
template <typename T>
KOKKOS_INLINE_FUNCTION
T safe_sqrt(const T& x) {
using std::sqrt;
return sqrt(x);
}

}

#endif // SACADO_CMATH_HPP
6 changes: 6 additions & 0 deletions packages/sacado/src/new_design/Sacado_Fad_Exp_GeneralFad.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,15 @@ namespace Sacado {
using ExtenderType::ExtenderType;

//! Default constructor
KOKKOS_DEFAULTED_FUNCTION
GeneralFad() = default;

//! Copy constructor
KOKKOS_DEFAULTED_FUNCTION
GeneralFad(const GeneralFad& x) = default;

//! Move constructor
KOKKOS_DEFAULTED_FUNCTION
GeneralFad(GeneralFad&& x) = default;

//! Constructor with value (disabled for ViewFad)
Expand All @@ -118,6 +121,7 @@ namespace Sacado {
}

//! Destructor
KOKKOS_DEFAULTED_FUNCTION
~GeneralFad() = default;

//! Set %GeneralFad object as the \c ith independent variable
Expand Down Expand Up @@ -212,10 +216,12 @@ namespace Sacado {
}

//! Assignment with GeneralFad right-hand-side
KOKKOS_DEFAULTED_FUNCTION
GeneralFad&
operator=(const GeneralFad& x) = default;

//! Move assignment with GeneralFad right-hand-side
KOKKOS_DEFAULTED_FUNCTION
GeneralFad&
operator=(GeneralFad&& x) = default;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ UNARYFUNC_MACRO(exp, ExpOp)
UNARYFUNC_MACRO(log, LogOp)
UNARYFUNC_MACRO(log10, Log10Op)
UNARYFUNC_MACRO(sqrt, SqrtOp)
UNARYFUNC_MACRO(safe_sqrt, SafeSqrtOp)
UNARYFUNC_MACRO(cos, CosOp)
UNARYFUNC_MACRO(sin, SinOp)
UNARYFUNC_MACRO(tan, TanOp)
Expand Down
Loading

0 comments on commit 0749fa0

Please sign in to comment.