diff --git a/simd/src/Kokkos_SIMD.hpp b/simd/src/Kokkos_SIMD.hpp index e5d54b0ff1..9280763407 100644 --- a/simd/src/Kokkos_SIMD.hpp +++ b/simd/src/Kokkos_SIMD.hpp @@ -29,6 +29,10 @@ #include #endif +#ifdef __ARM_NEON +#include +#endif + namespace Kokkos { namespace Experimental { @@ -40,6 +44,8 @@ namespace Impl { using host_native = avx512_fixed_size<8>; #elif defined(KOKKOS_ARCH_AVX2) using host_native = avx2_fixed_size<4>; +#elif defined(__ARM_NEON) +using host_native = neon_fixed_size<2>; #else using host_native = scalar; #endif @@ -134,6 +140,8 @@ class abi_set {}; using host_abi_set = abi_set>; #elif defined(KOKKOS_ARCH_AVX2) using host_abi_set = abi_set>; +#elif defined(__ARM_NEON) +using host_abi_set = abi_set>; #else using host_abi_set = abi_set; #endif diff --git a/simd/src/Kokkos_SIMD_Common.hpp b/simd/src/Kokkos_SIMD_Common.hpp index 9b2c0f81d7..c29d49fb3a 100644 --- a/simd/src/Kokkos_SIMD_Common.hpp +++ b/simd/src/Kokkos_SIMD_Common.hpp @@ -136,6 +136,34 @@ template return simd([&](std::size_t i) { return lhs[i] * rhs[i]; }); } +// fallback simd shift using generator constructor +// At the time of this writing, these fallbacks are only used +// to shift vectors of 64-bit unsigned integers for the NEON backend + +template +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd operator>>( + simd const& lhs, unsigned int rhs) { + return simd([&](std::size_t i) { return lhs[i] >> rhs; }); +} + +template +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd operator<<( + simd const& lhs, unsigned int rhs) { + return simd([&](std::size_t i) { return lhs[i] << rhs; }); +} + +template +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd operator>>( + simd const& lhs, simd const& rhs) { + return simd([&](std::size_t i) { return lhs[i] >> rhs[i]; }); +} + +template +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd operator<<( + simd const& lhs, simd const& rhs) { + return simd([&](std::size_t i) { return lhs[i] << rhs[i]; }); +} + // The code below provides: // operator@(simd, Arithmetic) // operator@(Arithmetic, simd) diff --git a/simd/src/Kokkos_SIMD_NEON.hpp b/simd/src/Kokkos_SIMD_NEON.hpp new file mode 100644 index 0000000000..2473004098 --- /dev/null +++ b/simd/src/Kokkos_SIMD_NEON.hpp @@ -0,0 +1,995 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER + +#ifndef KOKKOS_SIMD_NEON_HPP +#define KOKKOS_SIMD_NEON_HPP + +#include +#include + +#include + +#include + +namespace Kokkos { + +namespace Experimental { + +namespace simd_abi { + +template +class neon_fixed_size {}; + +} // namespace simd_abi + +namespace Impl { + +template +class neon_mask; + +template +class neon_mask { + uint64x2_t m_value; + + public: + class reference { + uint64x2_t& m_mask; + int m_lane; + + public: + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference(uint64x2_t& mask_arg, + int lane_arg) + : m_mask(mask_arg), m_lane(lane_arg) {} + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference + operator=(bool value) const { + // this switch statement is needed because the lane argument has to be a + // constant + switch (m_lane) { + case 0: + m_mask = vsetq_lane_u64(value ? 0xFFFFFFFFFFFFFFFFULL : 0, m_mask, 0); + break; + case 1: + m_mask = vsetq_lane_u64(value ? 0xFFFFFFFFFFFFFFFFULL : 0, m_mask, 1); + break; + } + return *this; + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION operator bool() const { + switch (m_lane) { + case 0: return vgetq_lane_u64(m_mask, 0) != 0; + case 1: return vgetq_lane_u64(m_mask, 1) != 0; + } + return false; + } + }; + using value_type = bool; + using abi_type = simd_abi::neon_fixed_size<2>; + using implementation_type = uint64x2_t; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION neon_mask() = default; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION explicit neon_mask(value_type value) + : m_value(vmovq_n_u64(value ? 0xFFFFFFFFFFFFFFFFULL : 0)) {} + template + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION neon_mask( + neon_mask const& other) { + operator[](0) = bool(other[0]); + operator[](1) = bool(other[1]); + } + template + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION neon_mask(neon_mask const& other) + : neon_mask(static_cast(other)) {} + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION static constexpr std::size_t size() { + return 2; + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit neon_mask( + uint64x2_t const& value_in) + : m_value(value_in) {} + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit operator uint64x2_t() + const { + return m_value; + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference operator[](std::size_t i) { + return reference(m_value, int(i)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type + operator[](std::size_t i) const { + return static_cast( + reference(const_cast(m_value), int(i))); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION Derived + operator||(neon_mask const& other) const { + return Derived(vorrq_u64(m_value, other.m_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION Derived + operator&&(neon_mask const& other) const { + return Derived(vandq_u64(m_value, other.m_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION Derived operator!() const { + auto const true_value = static_cast(neon_mask(true)); + return Derived(veorq_u64(m_value, true_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION bool operator==( + neon_mask const& other) const { + uint64x2_t const elementwise_equality = vceqq_u64(m_value, other.m_value); + uint32x2_t const narrow_elementwise_equality = + vqmovn_u64(elementwise_equality); + uint64x1_t const overall_equality_neon = + vreinterpret_u64_u32(narrow_elementwise_equality); + uint64_t const overall_equality = vget_lane_u64(overall_equality_neon, 0); + return overall_equality == 0xFFFFFFFFFFFFFFFFULL; + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION bool operator!=( + neon_mask const& other) const { + return !operator==(other); + } +}; + +template +class neon_mask { + uint32x2_t m_value; + + public: + class reference { + uint32x2_t& m_mask; + int m_lane; + + public: + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference(uint32x2_t& mask_arg, + int lane_arg) + : m_mask(mask_arg), m_lane(lane_arg) {} + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference + operator=(bool value) const { + switch (m_lane) { + case 0: + m_mask = vset_lane_u32(value ? 0xFFFFFFFFU : 0, m_mask, 0); + break; + case 1: + m_mask = vset_lane_u32(value ? 0xFFFFFFFFU : 0, m_mask, 1); + break; + } + return *this; + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION operator bool() const { + switch (m_lane) { + case 0: return vget_lane_u32(m_mask, 0) != 0; + case 1: return vget_lane_u32(m_mask, 1) != 0; + } + return false; + } + }; + using value_type = bool; + using abi_type = simd_abi::neon_fixed_size<2>; + using implementation_type = uint32x2_t; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION neon_mask() = default; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION explicit neon_mask(value_type value) + : m_value(vmov_n_u32(value ? 0xFFFFFFFFU : 0)) {} + template + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION neon_mask(neon_mask const& other) + : m_value(vqmovn_u64(static_cast(other))) {} + template + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION neon_mask(neon_mask const& other) + : m_value(static_cast(other)) {} + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION static constexpr std::size_t size() { + return 2; + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit neon_mask( + uint32x2_t const& value_in) + : m_value(value_in) {} + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit operator uint32x2_t() + const { + return m_value; + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference operator[](std::size_t i) { + return reference(m_value, int(i)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type + operator[](std::size_t i) const { + return static_cast( + reference(const_cast(m_value), int(i))); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION Derived + operator||(neon_mask const& other) const { + return Derived(vorr_u32(m_value, other.m_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION Derived + operator&&(neon_mask const& other) const { + return Derived(vand_u32(m_value, other.m_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION Derived operator!() const { + auto const true_value = static_cast(neon_mask(true)); + return Derived(veor_u32(m_value, true_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION bool operator==( + neon_mask const& other) const { + uint32x2_t const elementwise_equality = vceq_u32(m_value, other.m_value); + uint64x1_t const overall_equality_neon = + vreinterpret_u64_u32(elementwise_equality); + uint64_t const overall_equality = vget_lane_u64(overall_equality_neon, 0); + return overall_equality == 0xFFFFFFFFFFFFFFFFULL; + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION bool operator!=( + neon_mask const& other) const { + return !operator==(other); + } +}; + +} // namespace Impl + +template +class simd_mask> + : public Impl::neon_mask>, + sizeof(T) * 8> { + using base_type = Impl::neon_mask>, + sizeof(T) * 8>; + + public: + using implementation_type = typename base_type::implementation_type; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd_mask() = default; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION explicit simd_mask(bool value) + : base_type(value) {} + template + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd_mask( + simd_mask> const& other) + : base_type(other) {} + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd_mask( + implementation_type const& value) + : base_type(value) {} +}; + +template <> +class simd> { + float64x2_t m_value; + + public: + using value_type = double; + using abi_type = simd_abi::neon_fixed_size<2>; + using mask_type = simd_mask; + class reference { + float64x2_t& m_value; + int m_lane; + + public: + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference(float64x2_t& mask_arg, + int lane_arg) + : m_value(mask_arg), m_lane(lane_arg) {} + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference + operator=(double value) const { + switch (m_lane) { + case 0: m_value = vsetq_lane_f64(value, m_value, 0); break; + case 1: m_value = vsetq_lane_f64(value, m_value, 1); break; + } + return *this; + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION operator double() const { + switch (m_lane) { + case 0: return vgetq_lane_f64(m_value, 0); + case 1: return vgetq_lane_f64(m_value, 1); + } + return 0; + } + }; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd() = default; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(simd const&) = default; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(simd&&) = default; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd& operator=(simd const&) = default; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd& operator=(simd&&) = default; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION static constexpr std::size_t size() { + return 2; + } + template , + bool> = false> + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(U&& value) + : m_value(vmovq_n_f64(value_type(value))) {} + template ()); } + std::is_invocable_r_v>, + bool> = false> + KOKKOS_FORCEINLINE_FUNCTION simd(G&& gen) { + m_value = vsetq_lane_f64(gen(std::integral_constant()), + m_value, 0); + m_value = vsetq_lane_f64(gen(std::integral_constant()), + m_value, 1); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd( + float64x2_t const& value_in) + : m_value(value_in) {} + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference operator[](std::size_t i) { + return reference(m_value, int(i)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type + operator[](std::size_t i) const { + return reference(const_cast(this)->m_value, int(i)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_from(value_type const* ptr, + element_aligned_tag) { + m_value = vld1q_f64(ptr); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_to( + value_type* ptr, element_aligned_tag) const { + vst1q_f64(ptr, m_value); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit + operator float64x2_t() const { + return m_value; + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type + operator<(simd const& other) const { + return mask_type(vcltq_f64(m_value, other.m_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type + operator>(simd const& other) const { + return mask_type(vcgtq_f64(m_value, other.m_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type + operator<=(simd const& other) const { + return mask_type(vcleq_f64(m_value, other.m_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type + operator>=(simd const& other) const { + return mask_type(vcgeq_f64(m_value, other.m_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type + operator==(simd const& other) const { + return mask_type(vceqq_f64(m_value, other.m_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type + operator!=(simd const& other) const { + return !(operator==(other)); + } +}; + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION + simd> + operator*(simd> const& lhs, + simd> const& rhs) { + return simd>( + vmulq_f64(static_cast(lhs), static_cast(rhs))); +} + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION + simd> + operator/(simd> const& lhs, + simd> const& rhs) { + return simd>( + vdivq_f64(static_cast(lhs), static_cast(rhs))); +} + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION + simd> + operator+(simd> const& lhs, + simd> const& rhs) { + return simd>( + vaddq_f64(static_cast(lhs), static_cast(rhs))); +} + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION + simd> + operator-(simd> const& lhs, + simd> const& rhs) { + return simd>( + vsubq_f64(static_cast(lhs), static_cast(rhs))); +} + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION + simd> + operator-(simd> const& a) { + return simd>( + vnegq_f64(static_cast(a))); +} + +KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION +simd> abs( + simd> const& a) { + return simd>( + vabsq_f64(static_cast(a))); +} + +KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION +simd> copysign( + simd> const& a, + simd> const& b) { + uint64x2_t const sign_mask = vreinterpretq_u64_f64(vmovq_n_f64(-0.0)); + return simd>(vreinterpretq_f64_u64( + vorrq_u64(vreinterpretq_u64_f64(static_cast(abs(a))), + vandq_u64(sign_mask, vreinterpretq_u64_f64( + static_cast(b)))))); +} + +KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION +simd> sqrt( + simd> const& a) { + return simd>( + vsqrtq_f64(static_cast(a))); +} + +KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION +simd> fma( + simd> const& a, + simd> const& b, + simd> const& c) { + return simd>( + vfmaq_f64(static_cast(c), static_cast(b), + static_cast(a))); +} + +KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION +simd> max( + simd> const& a, + simd> const& b) { + return simd>( + vmaxq_f64(static_cast(a), static_cast(b))); +} + +KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION +simd> min( + simd> const& a, + simd> const& b) { + return simd>( + vminq_f64(static_cast(a), static_cast(b))); +} + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION + simd> + condition(simd_mask> const& a, + simd> const& b, + simd> const& c) { + return simd>( + vbslq_f64(static_cast(a), static_cast(b), + static_cast(c))); +} + +template <> +class simd> { + int32x2_t m_value; + + public: + using value_type = std::int32_t; + using abi_type = simd_abi::neon_fixed_size<2>; + using mask_type = simd_mask; + class reference { + int32x2_t& m_value; + int m_lane; + + public: + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference(int32x2_t& value_arg, + int lane_arg) + : m_value(value_arg), m_lane(lane_arg) {} + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference + operator=(std::int32_t value) const { + switch (m_lane) { + case 0: m_value = vset_lane_s32(value, m_value, 0); break; + case 1: m_value = vset_lane_s32(value, m_value, 1); break; + } + return *this; + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION operator std::int32_t() const { + switch (m_lane) { + case 0: return vget_lane_s32(m_value, 0); + case 1: return vget_lane_s32(m_value, 1); + } + return 0; + } + }; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd() = default; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(simd const&) = default; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(simd&&) = default; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd& operator=(simd const&) = default; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd& operator=(simd&&) = default; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION static constexpr std::size_t size() { + return 2; + } + template , + bool> = false> + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(U&& value) + : m_value(vmov_n_s32(value_type(value))) {} + template >, + bool> = false> + KOKKOS_FORCEINLINE_FUNCTION simd(G&& gen) { + m_value = vset_lane_s32(gen(std::integral_constant()), + m_value, 0); + m_value = vset_lane_s32(gen(std::integral_constant()), + m_value, 1); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd( + int32x2_t const& value_in) + : m_value(value_in) {} + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION explicit simd( + simd const& other); + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference operator[](std::size_t i) { + return reference(m_value, int(i)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type + operator[](std::size_t i) const { + return reference(const_cast(this)->m_value, int(i)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_from(value_type const* ptr, + element_aligned_tag) { + m_value = vld1_s32(ptr); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_to( + value_type* ptr, element_aligned_tag) const { + vst1_s32(ptr, m_value); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit operator int32x2_t() + const { + return m_value; + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type + operator==(simd const& other) const { + return mask_type(vceq_s32(m_value, other.m_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type + operator>(simd const& other) const { + return mask_type(vcgt_s32(m_value, other.m_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type + operator<(simd const& other) const { + return mask_type(vclt_s32(m_value, other.m_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type + operator<=(simd const& other) const { + return mask_type(vcle_s32(m_value, other.m_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type + operator>=(simd const& other) const { + return mask_type(vcge_s32(m_value, other.m_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type + operator!=(simd const& other) const { + return !((*this) == other); + } +}; + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION + simd> + operator-(simd> const& a) { + return simd>( + vneg_s32(static_cast(a))); +} + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION + simd> + operator-(simd> const& lhs, + simd> const& rhs) { + return simd>( + vsub_s32(static_cast(lhs), static_cast(rhs))); +} + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION + simd> + operator+(simd> const& lhs, + simd> const& rhs) { + return simd>( + vadd_s32(static_cast(lhs), static_cast(rhs))); +} + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION + simd> + condition(simd_mask> const& a, + simd> const& b, + simd> const& c) { + return simd>( + vbsl_s32(static_cast(a), static_cast(b), + static_cast(c))); +} + +template <> +class simd> { + int64x2_t m_value; + + public: + using value_type = std::int64_t; + using abi_type = simd_abi::neon_fixed_size<2>; + using mask_type = simd_mask; + class reference { + int64x2_t& m_value; + int m_lane; + + public: + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference(int64x2_t& value_arg, + int lane_arg) + : m_value(value_arg), m_lane(lane_arg) {} + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference + operator=(std::int64_t value) const { + switch (m_lane) { + case 0: m_value = vsetq_lane_s64(value, m_value, 0); break; + case 1: m_value = vsetq_lane_s64(value, m_value, 1); break; + } + return *this; + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION operator std::int64_t() const { + switch (m_lane) { + case 0: return vgetq_lane_s64(m_value, 0); + case 1: return vgetq_lane_s64(m_value, 1); + } + return 0; + } + }; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd() = default; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(simd const&) = default; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(simd&&) = default; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd& operator=(simd const&) = default; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd& operator=(simd&&) = default; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION static constexpr std::size_t size() { + return 2; + } + template , + bool> = false> + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(U&& value) + : m_value(vmovq_n_s64(value_type(value))) {} + template >, + bool> = false> + KOKKOS_FORCEINLINE_FUNCTION simd(G&& gen) { + m_value = vsetq_lane_s64(gen(std::integral_constant()), + m_value, 0); + m_value = vsetq_lane_s64(gen(std::integral_constant()), + m_value, 1); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd( + int64x2_t const& value_in) + : m_value(value_in) {} + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION explicit simd( + simd const&); + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference operator[](std::size_t i) { + return reference(m_value, int(i)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type + operator[](std::size_t i) const { + return reference(const_cast(this)->m_value, int(i)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_from(value_type const* ptr, + element_aligned_tag) { + m_value = vld1q_s64(ptr); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void copy_to( + value_type* ptr, element_aligned_tag) const { + vst1q_s64(ptr, m_value); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit operator int64x2_t() + const { + return m_value; + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type + operator==(simd const& other) const { + return mask_type(vceqq_s64(m_value, other.m_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type + operator>(simd const& other) const { + return mask_type(vcgtq_s64(m_value, other.m_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type + operator<(simd const& other) const { + return mask_type(vcltq_s64(m_value, other.m_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type + operator<=(simd const& other) const { + return mask_type(vcleq_s64(m_value, other.m_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type + operator>=(simd const& other) const { + return mask_type(vcgeq_s64(m_value, other.m_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type + operator!=(simd const& other) const { + return !((*this) == other); + } +}; + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION + simd> + operator-(simd> const& a) { + return simd>( + vnegq_s64(static_cast(a))); +} + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION + simd> + operator-(simd> const& lhs, + simd> const& rhs) { + return simd>( + vsubq_s64(static_cast(lhs), static_cast(rhs))); +} + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION + simd> + operator+(simd> const& lhs, + simd> const& rhs) { + return simd>( + vaddq_s64(static_cast(lhs), static_cast(rhs))); +} + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION + simd> + condition(simd_mask> const& a, + simd> const& b, + simd> const& c) { + return simd>( + vbslq_s64(static_cast(a), static_cast(b), + static_cast(c))); +} + +template <> +class simd> { + uint64x2_t m_value; + + public: + using value_type = std::uint64_t; + using abi_type = simd_abi::neon_fixed_size<2>; + using mask_type = simd_mask; + class reference { + uint64x2_t& m_value; + int m_lane; + + public: + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference(uint64x2_t& value_arg, + int lane_arg) + : m_value(value_arg), m_lane(lane_arg) {} + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference + operator=(std::uint64_t value) const { + switch (m_lane) { + case 0: m_value = vsetq_lane_u64(value, m_value, 0); break; + case 1: m_value = vsetq_lane_u64(value, m_value, 1); break; + } + return *this; + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION operator std::uint64_t() const { + switch (m_lane) { + case 0: return vgetq_lane_u64(m_value, 0); + case 1: return vgetq_lane_u64(m_value, 1); + } + return 0; + } + }; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd() = default; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(simd const&) = default; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(simd&&) = default; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd& operator=(simd const&) = default; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd& operator=(simd&&) = default; + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION static constexpr std::size_t size() { + return 2; + } + template , + bool> = false> + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd(U&& value) + : m_value(vmovq_n_u64(value_type(value))) {} + template >, + bool> = false> + KOKKOS_FORCEINLINE_FUNCTION simd(G&& gen) { + m_value = vsetq_lane_u64(gen(std::integral_constant()), + m_value, 0); + m_value = vsetq_lane_u64(gen(std::integral_constant()), + m_value, 1); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit simd( + uint64x2_t const& value_in) + : m_value(value_in) {} + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION explicit simd( + simd const& other) + : m_value( + vreinterpretq_u64_s64(vmovl_s32(static_cast(other)))) {} + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION reference operator[](std::size_t i) { + return reference(m_value, int(i)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION value_type + operator[](std::size_t i) const { + return reference(const_cast(this)->m_value, int(i)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd + operator&(simd const& other) const { + return simd(vandq_u64(m_value, other.m_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd + operator|(simd const& other) const { + return simd(vorrq_u64(m_value, other.m_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr explicit operator uint64x2_t() + const { + return m_value; + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd + operator<<(unsigned int rhs) const { + return simd(vshlq_u64(m_value, vmovq_n_s64(std::int64_t(rhs)))); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd + operator>>(unsigned int rhs) const { + return simd(vshlq_u64(m_value, vmovq_n_s64(-std::int64_t(rhs)))); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type + operator==(simd const& other) const { + return mask_type(vceqq_u64(m_value, other.m_value)); + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION mask_type + operator!=(simd const& other) const { + return !((*this) == other); + } +}; + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION + simd> + operator-(simd> const& lhs, + simd> const& rhs) { + return simd>( + vsubq_u64(static_cast(lhs), static_cast(rhs))); +} + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION + simd> + operator+(simd> const& lhs, + simd> const& rhs) { + return simd>( + vaddq_u64(static_cast(lhs), static_cast(rhs))); +} + +KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION +simd>::simd( + simd> const& other) + : m_value( + vmovn_s64(vreinterpretq_s64_u64(static_cast(other)))) {} + +KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION +simd>::simd( + simd> const& other) + : m_value(vreinterpretq_s64_u64(static_cast(other))) {} + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION + simd> + condition(simd_mask> const& a, + simd> const& b, + simd> const& c) { + return simd>( + vbslq_u64(static_cast(a), static_cast(b), + static_cast(c))); +} + +template <> +class const_where_expression>, + simd>> { + public: + using abi_type = simd_abi::neon_fixed_size<2>; + using value_type = simd; + using mask_type = simd_mask; + + protected: + value_type& m_value; + mask_type const& m_mask; + + public: + const_where_expression(mask_type const& mask_arg, value_type const& value_arg) + : m_value(const_cast(value_arg)), m_mask(mask_arg) {} + [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr mask_type const& + mask() const { + return m_mask; + } + [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr value_type const& + value() const { + return m_value; + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION + void copy_to(double* mem, element_aligned_tag) const { + if (m_mask[0]) mem[0] = m_value[0]; + if (m_mask[1]) mem[1] = m_value[1]; + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION + void scatter_to( + double* mem, + simd> const& index) const { + if (m_mask[0]) mem[index[0]] = m_value[0]; + if (m_mask[1]) mem[index[1]] = m_value[1]; + } +}; + +template <> +class where_expression>, + simd>> + : public const_where_expression< + simd_mask>, + simd>> { + public: + where_expression( + simd_mask> const& mask_arg, + simd>& value_arg) + : const_where_expression(mask_arg, value_arg) {} + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION + void copy_from(double const* mem, element_aligned_tag) { + if (m_mask[0]) m_value[0] = mem[0]; + if (m_mask[1]) m_value[1] = mem[1]; + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION + void gather_from( + double const* mem, + simd> const& index) { + if (m_mask[0]) m_value[0] = mem[index[0]]; + if (m_mask[1]) m_value[1] = mem[index[1]]; + } + template >>, + bool> = false> + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION void operator=(U&& x) { + auto const x_as_value_type = + static_cast>>( + std::forward(x)); + m_value = static_cast>>( + vbslq_f64(static_cast(m_mask), + static_cast(x_as_value_type), + static_cast(m_value))); + } +}; + +template <> +class const_where_expression< + simd_mask>, + simd>> { + public: + using abi_type = simd_abi::neon_fixed_size<2>; + using value_type = simd; + using mask_type = simd_mask; + + protected: + value_type& m_value; + mask_type const& m_mask; + + public: + const_where_expression(mask_type const& mask_arg, value_type const& value_arg) + : m_value(const_cast(value_arg)), m_mask(mask_arg) {} + [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr mask_type const& + mask() const { + return m_mask; + } + [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr value_type const& + value() const { + return m_value; + } + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION + void copy_to(std::int32_t* mem, element_aligned_tag) const { + if (m_mask[0]) mem[0] = m_value[0]; + if (m_mask[1]) mem[1] = m_value[1]; + } +}; + +template <> +class where_expression>, + simd>> + : public const_where_expression< + simd_mask>, + simd>> { + public: + where_expression( + simd_mask> const& mask_arg, + simd>& value_arg) + : const_where_expression(mask_arg, value_arg) {} + KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION + void copy_from(std::int32_t const* mem, element_aligned_tag) { + if (m_mask[0]) m_value[0] = mem[0]; + if (m_mask[1]) m_value[1] = mem[1]; + } +}; + +} // namespace Experimental +} // namespace Kokkos + +#endif diff --git a/simd/unit_tests/TestSIMD.cpp b/simd/unit_tests/TestSIMD.cpp index ad6ce9bac3..7a4ecf19ed 100644 --- a/simd/unit_tests/TestSIMD.cpp +++ b/simd/unit_tests/TestSIMD.cpp @@ -170,7 +170,7 @@ void host_check_binary_op_one_loader(BinaryOp binary_op, std::size_t n, simd_type expected_result; for (std::size_t lane = 0; lane < nlanes; ++lane) { expected_result[lane] = - binary_op.on_host(first_arg[lane], second_arg[lane]); + binary_op.on_host(T(first_arg[lane]), T(second_arg[lane])); } simd_type const computed_result = binary_op.on_host(first_arg, second_arg); host_check_equality(expected_result, computed_result, nlanes); @@ -298,6 +298,61 @@ inline void host_check_mask_ops() { EXPECT_FALSE(all_of(mask_type(false))); } +template +inline void host_check_conversions() { + { + auto a = Kokkos::Experimental::simd(1); + auto b = Kokkos::Experimental::simd(a); + EXPECT_TRUE(all_of(b == decltype(b)(1))); + } + { + auto a = Kokkos::Experimental::simd(1); + auto b = Kokkos::Experimental::simd(a); + EXPECT_TRUE(all_of(b == decltype(b)(1))); + } + { + auto a = Kokkos::Experimental::simd(1); + auto b = Kokkos::Experimental::simd(a); + EXPECT_TRUE(all_of(b == decltype(b)(1))); + } + { + auto a = Kokkos::Experimental::simd_mask(true); + auto b = Kokkos::Experimental::simd_mask(a); + EXPECT_TRUE(b == decltype(b)(true)); + } + { + auto a = Kokkos::Experimental::simd_mask(true); + auto b = Kokkos::Experimental::simd_mask(a); + EXPECT_TRUE(b == decltype(b)(true)); + } + { + auto a = Kokkos::Experimental::simd_mask(true); + auto b = Kokkos::Experimental::simd_mask(a); + EXPECT_TRUE(b == decltype(b)(true)); + } + { + auto a = Kokkos::Experimental::simd_mask(true); + auto b = Kokkos::Experimental::simd_mask(a); + EXPECT_TRUE(b == decltype(b)(true)); + } +} + +template +inline void host_check_shifts() { + auto a = Kokkos::Experimental::simd(8); + auto b = a >> 1; + EXPECT_TRUE(all_of(b == decltype(b)(4))); +} + +template +inline void host_check_condition() { + auto a = Kokkos::Experimental::condition( + Kokkos::Experimental::simd(1) > 0, + Kokkos::Experimental::simd(16), + Kokkos::Experimental::simd(20)); + EXPECT_TRUE(all_of(a == decltype(a)(16))); +} + template KOKKOS_INLINE_FUNCTION void device_check_math_ops() { std::size_t constexpr n = 11; @@ -321,16 +376,80 @@ KOKKOS_INLINE_FUNCTION void device_check_mask_ops() { checker.truth(!all_of(mask_type(false))); } +template +KOKKOS_INLINE_FUNCTION void device_check_conversions() { + kokkos_checker checker; + { + auto a = Kokkos::Experimental::simd(1); + auto b = Kokkos::Experimental::simd(a); + checker.truth(all_of(b == decltype(b)(1))); + } + { + auto a = Kokkos::Experimental::simd(1); + auto b = Kokkos::Experimental::simd(a); + checker.truth(all_of(b == decltype(b)(1))); + } + { + auto a = Kokkos::Experimental::simd(1); + auto b = Kokkos::Experimental::simd(a); + checker.truth(all_of(b == decltype(b)(1))); + } + { + auto a = Kokkos::Experimental::simd_mask(true); + auto b = Kokkos::Experimental::simd_mask(a); + checker.truth(b == decltype(b)(true)); + } + { + auto a = Kokkos::Experimental::simd_mask(true); + auto b = Kokkos::Experimental::simd_mask(a); + checker.truth(b == decltype(b)(true)); + } + { + auto a = Kokkos::Experimental::simd_mask(true); + auto b = Kokkos::Experimental::simd_mask(a); + checker.truth(b == decltype(b)(true)); + } + { + auto a = Kokkos::Experimental::simd_mask(true); + auto b = Kokkos::Experimental::simd_mask(a); + checker.truth(b == decltype(b)(true)); + } +} + +template +KOKKOS_INLINE_FUNCTION void device_check_shifts() { + kokkos_checker checker; + auto a = Kokkos::Experimental::simd(8); + auto b = a >> 1; + checker.truth(all_of(b == decltype(b)(4))); +} + +template +KOKKOS_INLINE_FUNCTION void device_check_condition() { + kokkos_checker checker; + auto a = Kokkos::Experimental::condition( + Kokkos::Experimental::simd(1) > 0, + Kokkos::Experimental::simd(16), + Kokkos::Experimental::simd(20)); + checker.truth(all_of(a == decltype(a)(16))); +} + template inline void host_check_abi() { host_check_math_ops(); host_check_mask_ops(); + host_check_conversions(); + host_check_shifts(); + host_check_condition(); } template KOKKOS_INLINE_FUNCTION void device_check_abi() { device_check_math_ops(); device_check_mask_ops(); + device_check_conversions(); + device_check_shifts(); + device_check_condition(); } inline void host_check_abis(Kokkos::Experimental::Impl::abi_set<>) {}