Skip to content

Commit

Permalink
Merge pull request #575 from mhoemmen/Fix-574
Browse files Browse the repository at this point in the history
Fix #574 (improve KokkosBlas::dot accuracy for float & complex<float>)
  • Loading branch information
ndellingwood authored Jan 27, 2020
2 parents 13ceecc + 66c2a6c commit 2872ee5
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 9 deletions.
68 changes: 64 additions & 4 deletions src/Kokkos_InnerProductSpaceTraits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ class InnerProductSpaceTraits {
/// complex. In that case, see the partial specialization for
/// Kokkos::complex below to see our convention for which input gets
/// conjugated.
static KOKKOS_FORCEINLINE_FUNCTION dot_type
dot (const val_type& x, const val_type& y) {
static KOKKOS_FORCEINLINE_FUNCTION
dot_type dot (const val_type& x, const val_type& y) {
return x * y;
}
};
Expand Down Expand Up @@ -196,8 +196,8 @@ class InnerProductSpaceTraits<Kokkos::complex<T> > {
mag_type norm (const val_type& x) {
return ArithTraits<val_type>::abs (x);
}
static KOKKOS_FORCEINLINE_FUNCTION dot_type
dot (const val_type& x, const val_type& y) {
static KOKKOS_FORCEINLINE_FUNCTION
dot_type dot (const val_type& x, const val_type& y) {
return Kokkos::conj (x) * y;
}
};
Expand Down Expand Up @@ -291,6 +291,66 @@ struct InnerProductSpaceTraits<qd_real>
};
#endif // HAVE_KOKKOS_QD

template<class ResultType, class InputType1, class InputType2>
KOKKOS_INLINE_FUNCTION void
updateDot(ResultType& sum, const InputType1& x, const InputType2& y)
{
// FIXME (mfh 22 Jan 2020) We should actually pick the type with the
// greater precision.
sum += InnerProductSpaceTraits<InputType1>::dot(x, y);
}

KOKKOS_INLINE_FUNCTION void
updateDot(double& sum, const double x, const double y)
{
sum += x * y;
}

KOKKOS_INLINE_FUNCTION void
updateDot(double& sum, const float x, const float y)
{
sum += x * y;
}

// This exists because complex<float> += complex<double> is not defined.
KOKKOS_INLINE_FUNCTION void
updateDot(Kokkos::complex<double>& sum,
const Kokkos::complex<float> x,
const Kokkos::complex<float> y)
{
const auto tmp = Kokkos::conj(x) * y;
sum += Kokkos::complex<double>(tmp.real(), tmp.imag());
}

// This exists in case people call the overload of KokkosBlas::dot
// that takes an output View, and the output View has element type
// Kokkos::complex<float>.
KOKKOS_INLINE_FUNCTION void
updateDot(Kokkos::complex<float>& sum,
const Kokkos::complex<float> x,
const Kokkos::complex<float> y)
{
sum += Kokkos::conj(x) * y;
}

// This exists because Kokkos::complex<double> =
// Kokkos::complex<float> is not defined.
template<class Out, class In>
struct CastPossiblyComplex {
static Out cast(const In& x) {
return x;
}
};

template<class OutReal, class InReal>
struct CastPossiblyComplex<Kokkos::complex<OutReal>, Kokkos::complex<InReal>>
{
static Kokkos::complex<OutReal>
cast (const Kokkos::complex<InReal>& x) {
return {x.real(), x.imag()};
}
};

} // namespace Details
} // namespace Kokkos

Expand Down
28 changes: 24 additions & 4 deletions src/blas/KokkosBlas1_dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,39 @@ dot (const XVector& x, const YVector& y)
typename YVector::device_type,
Kokkos::MemoryTraits<Kokkos::Unmanaged> > YVector_Internal;

typedef Kokkos::View<typename XVector::non_const_value_type,
using dot_type =
typename Kokkos::Details::InnerProductSpaceTraits<
typename XVector::non_const_value_type>::dot_type;
// Some platforms, such as Mac Clang, seem to get poor accuracy with
// float and complex<float>. Work around some Trilinos test
// failures by using a higher-precision type for intermediate dot
// product sums.
constexpr bool is_complex_float =
std::is_same<dot_type, Kokkos::complex<float>>::value;
constexpr bool is_real_float = std::is_same<dot_type, float>::value;
using result_type = typename std::conditional<is_complex_float,
Kokkos::complex<double>,
typename std::conditional<is_real_float,
double,
dot_type
>::type
>::type;
using RVector_Internal = Kokkos::View<result_type,
Kokkos::LayoutLeft,
Kokkos::HostSpace,
Kokkos::MemoryTraits<Kokkos::Unmanaged> > RVector_Internal;
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

typename XVector::non_const_value_type result = 0;
result_type result {};
RVector_Internal R = RVector_Internal(&result);
XVector_Internal X = x;
YVector_Internal Y = y;

Impl::Dot<RVector_Internal,XVector_Internal,YVector_Internal>::dot (R,X,Y);
Kokkos::fence();
return result;
// mfh 22 Jan 2020: We need the line below because
// Kokkos::complex<T> lacks a constructor that takes a
// Kokkos::complex<U> with U != T.
return Kokkos::Details::CastPossiblyComplex<dot_type, result_type>::cast(result);
}

/// \brief Compute the column-wise dot products of two multivectors.
Expand Down
2 changes: 1 addition & 1 deletion src/blas/impl/KokkosBlas1_dot_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ struct DotFunctor
KOKKOS_FORCEINLINE_FUNCTION void
operator() (const size_type &i, value_type& sum) const
{
sum += IPT::dot (m_x(i), m_y(i)); // m_x(i) * m_y(i)
Kokkos::Details::updateDot(sum, m_x(i), m_y(i)); // sum += m_x(i) * m_y(i)
}

KOKKOS_INLINE_FUNCTION void
Expand Down

0 comments on commit 2872ee5

Please sign in to comment.