diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index bc4b340cf6..11c4bd65ad 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -235,6 +235,18 @@ KOKKOSKERNELS_GENERATE_ETI(Blas3_trsm trsm TYPE_LISTS FLOATS LAYOUTS DEVICES ) +KOKKOSKERNELS_GENERATE_ETI(Blas3_trmm trmm + HEADER_LIST HEADERS + SOURCE_LIST SOURCES + TYPE_LISTS FLOATS LAYOUTS DEVICES +) + +KOKKOSKERNELS_GENERATE_ETI(Blas_trtri trtri + HEADER_LIST HEADERS + SOURCE_LIST SOURCES + TYPE_LISTS FLOATS LAYOUTS DEVICES +) + KOKKOSKERNELS_GENERATE_ETI(Sparse_sptrsv_solve sptrsv_solve HEADER_LIST ETI_HEADERS SOURCE_LIST SOURCES diff --git a/src/batched/KokkosBatched_Trmm_Decl.hpp b/src/batched/KokkosBatched_Trmm_Decl.hpp new file mode 100644 index 0000000000..9844de8431 --- /dev/null +++ b/src/batched/KokkosBatched_Trmm_Decl.hpp @@ -0,0 +1,69 @@ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#ifndef __KOKKOSBATCHED_TRMM_DECL_HPP__ +#define __KOKKOSBATCHED_TRMM_DECL_HPP__ + +#include "KokkosBatched_Util.hpp" +#include "KokkosBatched_Vector.hpp" + +namespace KokkosBatched { + + template + struct SerialTrmm { + template + KOKKOS_INLINE_FUNCTION + static int + invoke(const ScalarType alpha, + const AViewType &A, + const BViewType &B); + }; +} // namespace KokkosBatched +#endif // __KOKKOSBATCHED_TRMM_DECL_HPP__ diff --git a/src/batched/KokkosBatched_Trmm_Serial_Impl.hpp b/src/batched/KokkosBatched_Trmm_Serial_Impl.hpp new file mode 100644 index 0000000000..4d52388d3e --- /dev/null +++ b/src/batched/KokkosBatched_Trmm_Serial_Impl.hpp @@ -0,0 +1,288 @@ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#ifndef __KOKKOSBATCHED_TRMM_SERIAL_IMPL_HPP__ +#define __KOKKOSBATCHED_TRMM_SERIAL_IMPL_HPP__ + +#include "KokkosBatched_Util.hpp" +#include "KokkosBatched_Trmm_Serial_Internal.hpp" + +namespace KokkosBatched { + //// Lower non-transpose //// + template + struct SerialTrmm { + template + KOKKOS_INLINE_FUNCTION + static int + invoke(const ScalarType alpha, + const AViewType &A, + const BViewType &B) { + return SerialTrmmInternalLeftLower::invoke(ArgDiag::use_unit_diag, + false, + A.extent(0), A.extent(1), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride_0(), A.stride_1(), + B.data(), B.stride_0(), B.stride_1()); + } + }; + template + struct SerialTrmm { + template + KOKKOS_INLINE_FUNCTION + static int + invoke(const ScalarType alpha, + const AViewType &A, + const BViewType &B) { + return SerialTrmmInternalRightLower::invoke(ArgDiag::use_unit_diag, + false, + A.extent(0), A.extent(1), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride_0(), A.stride_1(), + B.data(), B.stride_0(), B.stride_1()); + } + }; + //// Lower transpose ///// + template + struct SerialTrmm { + template + KOKKOS_INLINE_FUNCTION + static int + invoke(const ScalarType alpha, + const AViewType &A, + const BViewType &B) { + return SerialTrmmInternalLeftUpper::invoke(ArgDiag::use_unit_diag, + false, + A.extent(1), A.extent(0), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride_1(), A.stride_0(), + B.data(), B.stride_0(), B.stride_1()); + } + }; + template + struct SerialTrmm { + template + KOKKOS_INLINE_FUNCTION + static int + invoke(const ScalarType alpha, + const AViewType &A, + const BViewType &B) { + return SerialTrmmInternalRightUpper::invoke(ArgDiag::use_unit_diag, + false, + A.extent(1), A.extent(0), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride_1(), A.stride_0(), + B.data(), B.stride_0(), B.stride_1()); + } + }; + //// Lower conjugate-transpose //// + template + struct SerialTrmm { + template + KOKKOS_INLINE_FUNCTION + static int + invoke(const ScalarType alpha, + const AViewType &A, + const BViewType &B) { + return SerialTrmmInternalLeftUpper::invoke(ArgDiag::use_unit_diag, + true, + A.extent(1), A.extent(0), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride_1(), A.stride_0(), + B.data(), B.stride_0(), B.stride_1()); + } + }; + template + struct SerialTrmm { + template + KOKKOS_INLINE_FUNCTION + static int + invoke(const ScalarType alpha, + const AViewType &A, + const BViewType &B) { + return SerialTrmmInternalRightUpper::invoke(ArgDiag::use_unit_diag, + true, + A.extent(1), A.extent(0), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride_1(), A.stride_0(), + B.data(), B.stride_0(), B.stride_1()); + } + }; + //// Upper non-transpose //// + template + struct SerialTrmm { + template + KOKKOS_INLINE_FUNCTION + static int + invoke(const ScalarType alpha, + const AViewType &A, + const BViewType &B) { + return SerialTrmmInternalLeftUpper::invoke(ArgDiag::use_unit_diag, + false, + A.extent(0), A.extent(1), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride_0(), A.stride_1(), + B.data(), B.stride_0(), B.stride_1()); + } + }; + template + struct SerialTrmm { + template + KOKKOS_INLINE_FUNCTION + static int + invoke(const ScalarType alpha, + const AViewType &A, + const BViewType &B) { + return SerialTrmmInternalRightUpper::invoke(ArgDiag::use_unit_diag, + false, + A.extent(0), A.extent(1), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride_0(), A.stride_1(), + B.data(), B.stride_0(), B.stride_1()); + } + }; + //// Upper transpose ///// + template + struct SerialTrmm { + template + KOKKOS_INLINE_FUNCTION + static int + invoke(const ScalarType alpha, + const AViewType &A, + const BViewType &B) { + return SerialTrmmInternalLeftLower::invoke(ArgDiag::use_unit_diag, + false, + A.extent(1), A.extent(0), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride_1(), A.stride_0(), + B.data(), B.stride_0(), B.stride_1()); + } + }; + template + struct SerialTrmm { + template + KOKKOS_INLINE_FUNCTION + static int + invoke(const ScalarType alpha, + const AViewType &A, + const BViewType &B) { + return SerialTrmmInternalRightLower::invoke(ArgDiag::use_unit_diag, + false, + A.extent(1), A.extent(0), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride_1(), A.stride_0(), + B.data(), B.stride_0(), B.stride_1()); + } + }; + //// Upper conjugate-transpose //// + template + struct SerialTrmm { + template + KOKKOS_INLINE_FUNCTION + static int + invoke(const ScalarType alpha, + const AViewType &A, + const BViewType &B) { + return SerialTrmmInternalLeftLower::invoke(ArgDiag::use_unit_diag, + true, + A.extent(1), A.extent(0), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride_1(), A.stride_0(), + B.data(), B.stride_0(), B.stride_1()); + } + }; + template + struct SerialTrmm { + template + KOKKOS_INLINE_FUNCTION + static int + invoke(const ScalarType alpha, + const AViewType &A, + const BViewType &B) { + return SerialTrmmInternalRightLower::invoke(ArgDiag::use_unit_diag, + true, + A.extent(1), A.extent(0), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride_1(), A.stride_0(), + B.data(), B.stride_0(), B.stride_1()); + } + }; +} // namespace KokkosBatched + +#endif // __KOKKOSBATCHED_TRMM_SERIAL_IMPL_HPP__ diff --git a/src/batched/KokkosBatched_Trmm_Serial_Internal.hpp b/src/batched/KokkosBatched_Trmm_Serial_Internal.hpp new file mode 100644 index 0000000000..35d8206cbb --- /dev/null +++ b/src/batched/KokkosBatched_Trmm_Serial_Internal.hpp @@ -0,0 +1,437 @@ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#ifndef __KOKKOSBATCHED_TRMM_SERIAL_INTERNAL_HPP__ +#define __KOKKOSBATCHED_TRMM_SERIAL_INTERNAL_HPP__ + +#include "KokkosBatched_Util.hpp" + +#include "KokkosBatched_Set_Internal.hpp" +#include "KokkosBatched_Scale_Internal.hpp" + +namespace KokkosBatched { + + template + struct SerialTrmmInternalLeftLower { + template + KOKKOS_INLINE_FUNCTION + static int + invoke(const bool use_unit_diag, + const bool do_conj, + const int am, const int an, + const int bm, const int bn, + const ScalarType alpha, + const ValueType *__restrict__ A, const int as0, const int as1, + /**/ ValueType *__restrict__ B, const int bs0, const int bs1); + }; + + template + struct SerialTrmmInternalLeftUpper { + template + KOKKOS_INLINE_FUNCTION + static int + invoke(const bool use_unit_diag, + const bool do_conj, + const int am, const int an, + const int bm, const int bn, + const ScalarType alpha, + const ValueType *__restrict__ A, const int as0, const int as1, + /**/ ValueType *__restrict__ B, const int bs0, const int bs1); + }; + + template + struct SerialTrmmInternalRightLower { + template + KOKKOS_INLINE_FUNCTION + static int + invoke(const bool use_unit_diag, + const bool do_conj, + const int am, const int an, + const int bm, const int bn, + const ScalarType alpha, + const ValueType *__restrict__ A, const int as0, const int as1, + /**/ ValueType *__restrict__ B, const int bs0, const int bs1); + }; + + template + struct SerialTrmmInternalRightUpper { + template + KOKKOS_INLINE_FUNCTION + static int + invoke(const bool use_unit_diag, + const bool do_conj, + const int am, const int an, + const int bm, const int bn, + const ScalarType alpha, + const ValueType *__restrict__ A, const int as0, const int as1, + /**/ ValueType *__restrict__ B, const int bs0, const int bs1); + }; + + // ech-note: use_unit_diag intentionally ignored for now. Compiler can optimize + // it out. Assuming that branching logic (especially on GPU) to handle use_unit_diag + // will use more cycles than simply doing 1.0*B[idx] for the copy if use_unit_diag. + template<> + template + KOKKOS_INLINE_FUNCTION + int + SerialTrmmInternalLeftLower:: + invoke(const bool use_unit_diag, + const bool do_conj, + const int am, const int an, + const int bm, const int bn, + const ScalarType alpha, + const ValueType *__restrict__ A, const int as0, const int as1, + /**/ ValueType *__restrict__ B, const int bs0, const int bs1) { + + const ScalarType one(1.0), zero(0.0); + typedef Kokkos::Details::ArithTraits AT; + int left_m = am; + int right_n = bn; + //echo-TODO: See about coniditionally setting conjOp at compile time. + //auto conjOp = noop; + //if (do_conj) { + // conjOp = AT::conj; + //} + //printf("SerialTrmmInternalLeftLower\n"); + + auto dotLowerLeftConj = [&](const ValueType *__restrict__ __A, const int __as0, const int __as1, const int __left_row, ValueType *__restrict__ __B, const int __bs0, const int __bs1, const int __right_col) { + auto B_elems = __left_row; + ScalarType sum = 0; +#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) +#pragma unroll +#endif + for (int i = 0; i <= B_elems; i++) { + // sum += A[left_row, i] * B[i, right_col] + sum += AT::conj(__A[__left_row*__as0 + i*__as1]) * __B[i*__bs0 + __bs1*__right_col]; + } + return sum; + }; + + auto dotLowerLeft = [&](const ValueType *__restrict__ __A, const int __as0, const int __as1, const int __left_row, ValueType *__restrict__ __B, const int __bs0, const int __bs1, const int __right_col) { + auto B_elems = __left_row; + ScalarType sum = 0; +#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) +#pragma unroll +#endif + for (int i = 0; i <= B_elems; i++) { + // sum += A[left_row, i] * B[i, right_col] + sum += __A[__left_row*__as0 + i*__as1] * __B[i*__bs0 + __bs1*__right_col]; + } + return sum; + }; + + if (bm <= 0 || bn <= 0 || am <= 0 || an <= 0) + return 0; + + if (alpha == zero) + SerialSetInternal::invoke(bm, bn, zero, B, bs0, bs1); + else { + if (alpha != one) + SerialScaleInternal::invoke(bm, bn, alpha, B, bs0, bs1); + +#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) +#pragma unroll +#endif + for (int m = left_m-1; m >= 0; m--) { +#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) +#pragma unroll +#endif + for (int n = 0; n < right_n; n++) { + if (do_conj) { + B[m*bs0 + n*bs1] = dotLowerLeftConj(A, as0, as1, m, B, bs0, bs1, n); + } else { + B[m*bs0 + n*bs1] = dotLowerLeft(A, as0, as1, m, B, bs0, bs1, n); + } + } + } + } + return 0; + } + + // ech-note: use_unit_diag intentionally ignored for now. Compiler can optimize + // it out. Assuming that branching logic (especially on GPU) to handle use_unit_diag + // will use more cycles than simply doing 1.0*B[idx] for the copy if use_unit_diag. + template<> + template + KOKKOS_INLINE_FUNCTION + int + SerialTrmmInternalRightLower:: + invoke(const bool use_unit_diag, + const bool do_conj, + const int am, const int an, + const int bm, const int bn, + const ScalarType alpha, + const ValueType *__restrict__ A, const int as0, const int as1, + /**/ ValueType *__restrict__ B, const int bs0, const int bs1) { + + const ScalarType one(1.0), zero(0.0); + typedef Kokkos::Details::ArithTraits AT; + int left_m = bm; + int right_n = an; + //echo-TODO: See about coniditionally setting conjOp at compile time. + //auto conjOp = noop; + //if (do_conj) { + // conjOp = AT::conj; + //} + + // Lower triangular matrix is on RHS with the base facing down. + // Everytime we compute a new output row of B, we must shift over to the + // right by one in A's column to ensure we skip the 0's. + auto dotLowerRightConj = [&](const ValueType *__restrict__ __A, const int __as0, const int __as1, const int __am, const int __left_row, ValueType *__restrict__ __B, const int __bs0, const int __bs1, const int __right_col) { + auto B_elems = __am - 1; + ScalarType sum = 0; +#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) +#pragma unroll +#endif + for (int i = __right_col; i <= B_elems; i++) { + // sum += B[left_row, i] * A[i, right_col] + sum += __B[__bs0*__left_row + i*__bs1] * AT::conj(__A[i*__as0 + __right_col*__as1]); + } + return sum; + }; + + auto dotLowerRight = [&](const ValueType *__restrict__ __A, const int __as0, const int __as1, const int __am, const int __left_row, ValueType *__restrict__ __B, const int __bs0, const int __bs1, const int __right_col) { + auto B_elems = __am - 1; + ScalarType sum = 0; +#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) +#pragma unroll +#endif + for (int i = __right_col; i <= B_elems; i++) { + // sum += B[left_row, i] * A[i, right_col] + sum += __B[__bs0*__left_row + i*__bs1] * __A[i*__as0 + __right_col*__as1]; + } + return sum; + }; + + if (bm <= 0 || bn <= 0 || am <= 0 || an <= 0) + return 0; + + if (alpha == zero) + SerialSetInternal::invoke(bm, bn, zero, B, bs0, bs1); + else { + if (alpha != one) + SerialScaleInternal::invoke(bm, bn, alpha, B, bs0, bs1); + +#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) +#pragma unroll +#endif + for (int m = 0; m < left_m; m++) { +#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) +#pragma unroll +#endif + for (int n = 0; n < right_n; n++) { + if (do_conj) { + B[m*bs0 + n*bs1] = dotLowerRightConj(A, as0, as1, am, m, B, bs0, bs1, n); + } else { + B[m*bs0 + n*bs1] = dotLowerRight(A, as0, as1, am, m, B, bs0, bs1, n); + } + } + } + } + return 0; + } + + template<> + template + KOKKOS_INLINE_FUNCTION + int + SerialTrmmInternalLeftUpper:: + invoke(const bool use_unit_diag, + const bool do_conj, + const int am, const int an, + const int bm, const int bn, + const ScalarType alpha, + const ValueType *__restrict__ A, const int as0, const int as1, + /**/ ValueType *__restrict__ B, const int bs0, const int bs1) { + + const ScalarType one(1.0), zero(0.0); + typedef Kokkos::Details::ArithTraits AT; + int left_m = am; + int right_n = bn; + //echo-TODO: See about coniditionally setting conjOp at compile time. + //auto conjOp = noop; + //if (do_conj) { + // conjOp = AT::conj; + //} + + auto dotUpperLeftConj = [&](const ValueType *__restrict__ __A, const int __as0, const int __as1, const int __an, const int __left_row, ValueType *__restrict__ __B, const int __bs0, const int __bs1, const int __right_col) { + auto B_elems = __an - __left_row - 1; + ScalarType sum = 0; +#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) +#pragma unroll +#endif + for (int i = 0; i <= B_elems; i++) { + // sum += A[left_row, i+left_row] * B[i+left_row, right_col] + sum += AT::conj(__A[__left_row*__as0 + (i+__left_row)*__as1]) * __B[(i+__left_row)*__bs0 + __bs1*__right_col]; + } + return sum; + }; + + auto dotUpperLeft = [&](const ValueType *__restrict__ __A, const int __as0, const int __as1, const int __an, const int __left_row, ValueType *__restrict__ __B, const int __bs0, const int __bs1, const int __right_col) { + auto B_elems = __an - __left_row - 1; + ScalarType sum = 0; +#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) +#pragma unroll +#endif + for (int i = 0; i <= B_elems; i++) { + // sum += A[left_row, i+left_row] * B[i+left_row, right_col] + sum += __A[__left_row*__as0 + (i+__left_row)*__as1] * __B[(i+__left_row)*__bs0 + __bs1*__right_col]; + } + return sum; + }; + + if (bm <= 0 || bn <= 0 || am <= 0 || an <= 0) + return 0; + + if (alpha == zero) + SerialSetInternal::invoke(bm, bn, zero, B, bs0, bs1); + else { + if (alpha != one) + SerialScaleInternal::invoke(bm, bn, alpha, B, bs0, bs1); + +#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) +#pragma unroll +#endif + for (int m = 0; m < left_m; ++m) { +#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) +#pragma unroll +#endif + for (int n = 0; n < right_n; ++n) { + if (do_conj) { + B[m*bs0 + n*bs1] = dotUpperLeftConj(A, as0, as1, an, m, B, bs0, bs1, n); + } else { + B[m*bs0 + n*bs1] = dotUpperLeft(A, as0, as1, an, m, B, bs0, bs1, n); + } + } + } + } + return 0; + } + + template<> + template + KOKKOS_INLINE_FUNCTION + int + SerialTrmmInternalRightUpper:: + invoke(const bool use_unit_diag, + const bool do_conj, + const int am, const int an, + const int bm, const int bn, + const ScalarType alpha, + const ValueType *__restrict__ A, const int as0, const int as1, + /**/ ValueType *__restrict__ B, const int bs0, const int bs1) { + + const ScalarType one(1.0), zero(0.0); + typedef Kokkos::Details::ArithTraits AT; + int left_m = bm; + int right_n = an; + //echo-TODO: See about coniditionally setting conjOp at compile time. + //auto conjOp = noop; + //if (do_conj) { + // conjOp = AT::conj; + //} + + auto dotUpperRightConj = [&](const ValueType *__restrict__ __A, const int __as0, const int __as1, const int __left_row, ValueType *__restrict__ __B, const int __bs0, const int __bs1, const int __right_col) { + auto B_elems = __right_col; + ScalarType sum = 0; +#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) +#pragma unroll +#endif + for (int i = 0; i <= B_elems; i++) { + // sum += B[left_row, i] * A[i, right_col] + sum += __B[__left_row*__bs0 + i*__bs1] * AT::conj(__A[i*__as0 + __right_col*__as1]); + } + return sum; + }; + + auto dotUpperRight = [&](const ValueType *__restrict__ __A, const int __as0, const int __as1, const int __left_row, ValueType *__restrict__ __B, const int __bs0, const int __bs1, const int __right_col) { + auto B_elems = __right_col; + ScalarType sum = 0; +#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) +#pragma unroll +#endif + for (int i = 0; i <= B_elems; i++) { + // sum += B[left_row, i] * A[i, right_col] + sum += __B[__left_row*__bs0 + i*__bs1] * __A[i*__as0 + __right_col*__as1]; + } + return sum; + }; + + if (bm <= 0 || bn <= 0 || am <= 0 || an <= 0) + return 0; + + if (alpha == zero) + SerialSetInternal::invoke(bm, bn, zero, B, bs0, bs1); + else { + if (alpha != one) + SerialScaleInternal::invoke(bm, bn, alpha, B, bs0, bs1); + +#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) +#pragma unroll +#endif + for (int m = 0; m < left_m; ++m) { +#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) +#pragma unroll +#endif + for (int n = right_n - 1; n >= 0; --n) { + if (do_conj) { + B[m*bs0 + n*bs1] = dotUpperRightConj(A, as0, as1, m, B, bs0, bs1, n); + } else { + B[m*bs0 + n*bs1] = dotUpperRight(A, as0, as1, m, B, bs0, bs1, n); + } + } + } + } + return 0; + } +} // namespace KokkosBatched +#endif // __KOKKOSBATCHED_TRMM_SERIAL_INTERNAL_HPP__ diff --git a/src/batched/KokkosBatched_Trtri_Decl.hpp b/src/batched/KokkosBatched_Trtri_Decl.hpp new file mode 100644 index 0000000000..1d5cf0632d --- /dev/null +++ b/src/batched/KokkosBatched_Trtri_Decl.hpp @@ -0,0 +1,64 @@ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#ifndef __KOKKOSBATCHED_TRTRI_DECL_HPP__ +#define __KOKKOSBATCHED_TRTRI_DECL_HPP__ + +#include "KokkosBatched_Util.hpp" +#include "KokkosBatched_Vector.hpp" + +namespace KokkosBatched { + + template + struct SerialTrtri { + template + KOKKOS_INLINE_FUNCTION + static int + invoke(const AViewType &A); + }; +} // namespace KokkosBatched +#endif // __KOKKOSBATCHED_TRTRI_DECL_HPP__ diff --git a/src/batched/KokkosBatched_Trtri_Serial_Impl.hpp b/src/batched/KokkosBatched_Trtri_Serial_Impl.hpp new file mode 100644 index 0000000000..bc2da0e066 --- /dev/null +++ b/src/batched/KokkosBatched_Trtri_Serial_Impl.hpp @@ -0,0 +1,76 @@ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#ifndef __KOKKOSBATCHED_TRTRI_SERIAL_IMPL_HPP__ +#define __KOKKOSBATCHED_TRTRI_SERIAL_IMPL_HPP__ + +#include "KokkosBatched_Util.hpp" +#include "KokkosBatched_Trtri_Serial_Internal.hpp" + +namespace KokkosBatched { + template + struct SerialTrtri { + template + KOKKOS_INLINE_FUNCTION + static int + invoke(const AViewType &A) { + return SerialTrtriInternalLower::invoke(ArgDiag::use_unit_diag, + A.extent(0), A.extent(1), + A.data(), A.stride_0(), A.stride_1()); + } + }; + template + struct SerialTrtri { + template + KOKKOS_INLINE_FUNCTION + static int + invoke(const AViewType &A) { + return SerialTrtriInternalUpper::invoke(ArgDiag::use_unit_diag, + A.extent(0), A.extent(1), + A.data(), A.stride(0), A.stride(1)); + } + }; +} // namespace KokkosBatched + +#endif // __KOKKOSBATCHED_TRTRI_SERIAL_IMPL_HPP__ diff --git a/src/batched/KokkosBatched_Trtri_Serial_Internal.hpp b/src/batched/KokkosBatched_Trtri_Serial_Internal.hpp new file mode 100644 index 0000000000..f3c7f3c960 --- /dev/null +++ b/src/batched/KokkosBatched_Trtri_Serial_Internal.hpp @@ -0,0 +1,186 @@ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#ifndef __KOKKOSBATCHED_TRTRI_SERIAL_INTERNAL_HPP__ +#define __KOKKOSBATCHED_TRTRI_SERIAL_INTERNAL_HPP__ + +#include "KokkosBatched_Util.hpp" +#include "KokkosBatched_Trmm_Serial_Internal.hpp" + +namespace KokkosBatched { + + template + struct SerialTrtriInternalLower { + template + KOKKOS_INLINE_FUNCTION + static int + invoke(const bool use_unit_diag, + const int am, const int an, + ValueType *__restrict__ A, const int as0, const int as1); + }; + + template + struct SerialTrtriInternalUpper { + template + KOKKOS_INLINE_FUNCTION + static int + invoke(const bool use_unit_diag, + const int am, const int an, + ValueType *__restrict__ A, const int as0, const int as1); + }; + + template<> + template + KOKKOS_INLINE_FUNCTION + int + SerialTrtriInternalLower:: + invoke(const bool use_unit_diag, + const int am, const int an, + ValueType *__restrict__ A, const int as0, const int as1) { + ValueType one(1.0), zero(0.0), A_ii; + if (!use_unit_diag) { +#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) +#pragma unroll +#endif + // Check for singularity + for (int i = 0; i < am; ++i) + if (A[i*as0 + i*as1] == zero) + return i+1; + } + +#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) +#pragma unroll +#endif + for (int i = am - 1; i >= 0; --i) { + A[i*as0 + i*as1] = one / A[i*as0 + i*as1]; + + if (i < am - 1) { + if (use_unit_diag) + A_ii = -one; + else + A_ii = -A[i*as0 + i*as1]; + + ValueType *__restrict__ A_subblock = &A[(i+1)*as0 + (i+1)*as1]; + int A_subblock_m = am - i - 1, + A_subblock_n = am - i - 1; + ValueType *__restrict__ A_col_vec = &A[(i+1)*as0 + i*as1]; + int A_col_vec_m = am - i - 1, + A_col_vec_n = 1; + // TRMV/TRMM −− x=Ax + // A((j+1):n,j) = A((j+1):n,(j+1):n) ∗ A((j+1):n,j) ; + SerialTrmmInternalLeftLower::invoke(use_unit_diag, + false, + A_subblock_m, A_subblock_n, + A_col_vec_m, A_col_vec_n, + one, + A_subblock, as0, as1, + A_col_vec, as0, as1); + + // SCAL -- x=ax + // A((j+1):n,j) = A_ii * A((j+1):n,j) + SerialScaleInternal::invoke(A_col_vec_m, A_col_vec_n, A_ii, A_col_vec, as0, as1); + } + } + return 0; + } + + template<> + template + KOKKOS_INLINE_FUNCTION + int + SerialTrtriInternalUpper:: + invoke(const bool use_unit_diag, + const int am, const int an, + ValueType *__restrict__ A, const int as0, const int as1) { + ValueType one(1.0), zero(0.0), A_ii; + + + if (!use_unit_diag) { +#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) +#pragma unroll +#endif + // Check for singularity + for (int i = 0; i < am; ++i) + if (A[i*as0 + i*as1] == zero) + return i+1; + } + +#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL) +#pragma unroll +#endif + for (int i = 0; i < am; ++i) { + A[i*as0 + i*as1] = one / A[i*as0 + i*as1]; + + if (i > 0) { + if (use_unit_diag) + A_ii = -one; + else + A_ii = -A[i*as0 + i*as1]; + + ValueType *__restrict__ A_subblock = &A[0*as0 + 0*as1]; + int A_subblock_m = i, + A_subblock_n = i; + ValueType *__restrict__ A_col_vec = &A[0*as0 + i*as1]; + int A_col_vec_m = i, + A_col_vec_n = 1; + // TRMV/TRMM −− x=Ax + // A(1:(j-1),j) = A(1:(j-1),1:(j-1)) ∗ A(1:(j-1),j) ; + //SerialTrmm + SerialTrmmInternalLeftUpper::invoke(use_unit_diag, + false, + A_subblock_m, A_subblock_n, + A_col_vec_m, A_col_vec_n, + one, + A_subblock, as0, as1, + A_col_vec, as0, as1); + + // SCAL -- x=ax + // A((j+1):n,j) = A_ii * A((j+1):n,j) + SerialScaleInternal::invoke(A_col_vec_m, A_col_vec_n, A_ii, A_col_vec, as0, as1); + } + } + return 0; + } +} // namespace KokkosBatched +#endif // __KOKKOSBATCHED_TRTRI_SERIAL_INTERNAL_HPP__ diff --git a/src/batched/KokkosBatched_Util.hpp b/src/batched/KokkosBatched_Util.hpp index fc1f91d5f4..bcb69f4812 100644 --- a/src/batched/KokkosBatched_Util.hpp +++ b/src/batched/KokkosBatched_Util.hpp @@ -290,6 +290,8 @@ namespace KokkosBatched { using Gemm = Level3; using Trsm = Level3; + using Trmm = Level3; + using Trtri = Level3; // TODO: Need new level for Trtri? using LU = Level3; using InverseLU = Level3; using SolveLU = Level3; diff --git a/src/blas/KokkosBlas_trtri.hpp b/src/blas/KokkosBlas_trtri.hpp index 2e0230fb6b..39c191f4d7 100644 --- a/src/blas/KokkosBlas_trtri.hpp +++ b/src/blas/KokkosBlas_trtri.hpp @@ -120,7 +120,7 @@ trtri (const char uplo[], } // Create A matrix view type alias - using AViewInternalType = Kokkos::View >; diff --git a/src/blas/impl/KokkosBlas3_trmm_impl.hpp b/src/blas/impl/KokkosBlas3_trmm_impl.hpp new file mode 100644 index 0000000000..73b9da8873 --- /dev/null +++ b/src/blas/impl/KokkosBlas3_trmm_impl.hpp @@ -0,0 +1,195 @@ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#ifndef KOKKOSBLAS3_TRMM_IMPL_HPP_ +#define KOKKOSBLAS3_TRMM_IMPL_HPP_ + +/** + * \file KokkosBlas3_trmm_impl.hpp + * \brief Implementation of triangular matrix multiply + */ + + +#include "KokkosKernels_config.h" +#include "Kokkos_Core.hpp" +#include "Kokkos_ArithTraits.hpp" +#include "KokkosBatched_Set_Internal.hpp" +#include "KokkosBatched_Scale_Internal.hpp" +#include "KokkosBatched_Trmm_Decl.hpp" +#include "KokkosBatched_Trmm_Serial_Impl.hpp" + +using namespace KokkosBatched; + +namespace KokkosBlas { + namespace Impl { + + template + void SerialTrmm_Invoke (const char side[], + const char uplo[], + const char trans[], + const char diag[], + typename BViewType::const_value_type& alpha, + const AViewType& A, + const BViewType& B) + { + char __side = tolower(side[0]), + __uplo = tolower(uplo[0]), + __trans = tolower(trans[0]); + //__diag = tolower(diag[0]); + bool do_conj = true; + + // Ignoring diag, see "ech-note" in KokkosBatched_Trmm_Serial_Internal.hpp + + //// Lower non-transpose //// + if (__side == 'l' && __uplo == 'l' && __trans == 'n') + SerialTrmmInternalLeftLower::invoke(Diag::Unit::use_unit_diag, + !do_conj, + A.extent(0), A.extent(1), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride(0), A.stride(1), + B.data(), B.stride(0), B.stride(1)); + if (__side == 'r' && __uplo == 'l' && __trans == 'n') + SerialTrmmInternalRightLower::invoke(Diag::Unit::use_unit_diag, + !do_conj, + A.extent(0), A.extent(1), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride(0), A.stride(1), + B.data(), B.stride(0), B.stride(1)); + //// Lower transpose ///// + // Transpose A by simply swapping the dimensions (extent) and stride parameters + if (__side == 'l' && __uplo == 'l' && __trans == 't') + SerialTrmmInternalLeftUpper::invoke(Diag::Unit::use_unit_diag, + !do_conj, + A.extent(1), A.extent(0), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride(1), A.stride(0), + B.data(), B.stride(0), B.stride(1)); + if (__side == 'r' && __uplo == 'l' && __trans == 't') + SerialTrmmInternalRightUpper::invoke(Diag::Unit::use_unit_diag, + !do_conj, + A.extent(1), A.extent(0), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride(1), A.stride(0), + B.data(), B.stride(0), B.stride(1)); + + //// Lower conjugate-transpose //// + // Conjugate-Transpose A by simply swapping the dimensions (extent) and stride parameters + if (__side == 'l' && __uplo == 'l' && __trans == 'c') + SerialTrmmInternalLeftUpper::invoke(Diag::Unit::use_unit_diag, + do_conj, + A.extent(1), A.extent(0), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride(1), A.stride(0), + B.data(), B.stride(0), B.stride(1)); + if (__side == 'r' && __uplo == 'l' && __trans == 'c') + SerialTrmmInternalRightUpper::invoke(Diag::Unit::use_unit_diag, + do_conj, + A.extent(1), A.extent(0), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride(1), A.stride(0), + B.data(), B.stride(0), B.stride(1)); + //// Upper non-transpose //// + if (__side == 'l' && __uplo == 'u' && __trans == 'n') + SerialTrmmInternalLeftUpper::invoke(Diag::Unit::use_unit_diag, + !do_conj, + A.extent(0), A.extent(1), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride(0), A.stride(1), + B.data(), B.stride(0), B.stride(1)); + if (__side == 'r' && __uplo == 'u' && __trans == 'n') + SerialTrmmInternalRightUpper::invoke(Diag::Unit::use_unit_diag, + !do_conj, + A.extent(0), A.extent(1), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride(0), A.stride(1), + B.data(), B.stride(0), B.stride(1)); + //// Upper transpose + // Transpose A by simply swapping the dimensions (extent) and stride parameters + if (__side == 'l' && __uplo == 'u' && __trans == 't') + SerialTrmmInternalLeftLower::invoke(Diag::Unit::use_unit_diag, + !do_conj, + A.extent(1), A.extent(0), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride(1), A.stride(0), + B.data(), B.stride(0), B.stride(1)); + if (__side == 'r' && __uplo == 'u' && __trans == 't') + SerialTrmmInternalRightLower::invoke(Diag::Unit::use_unit_diag, + !do_conj, + A.extent(1), A.extent(0), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride(1), A.stride(0), + B.data(), B.stride(0), B.stride(1)); + + //// Upper conjugate-transpose //// + // Conjugate-Transpose A by simply swapping the dimensions (extent) and stride parameters + if (__side == 'l' && __uplo == 'u' && __trans == 'c') + SerialTrmmInternalLeftLower::invoke(Diag::Unit::use_unit_diag, + do_conj, + A.extent(1), A.extent(0), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride(1), A.stride(0), + B.data(), B.stride(0), B.stride(1)); + if (__side == 'r' && __uplo == 'u' && __trans == 'c') + SerialTrmmInternalRightLower::invoke(Diag::Unit::use_unit_diag, + do_conj, + A.extent(1), A.extent(0), + B.extent(0), B.extent(1), + alpha, + A.data(), A.stride(1), A.stride(0), + B.data(), B.stride(0), B.stride(1)); + } + } // namespace Impl +} // namespace KokkosBlas +#endif // KOKKOSBLAS3_TRMM_IMPL_HPP_ diff --git a/src/blas/impl/KokkosBlas3_trmm_spec.hpp b/src/blas/impl/KokkosBlas3_trmm_spec.hpp index 0143596477..13c87a299e 100644 --- a/src/blas/impl/KokkosBlas3_trmm_spec.hpp +++ b/src/blas/impl/KokkosBlas3_trmm_spec.hpp @@ -48,7 +48,7 @@ #include "Kokkos_Core.hpp" #if !defined(KOKKOSKERNELS_ETI_ONLY) || KOKKOSKERNELS_IMPL_COMPILE_LIBRARY -//#include +#include #endif namespace KokkosBlas { @@ -71,7 +71,7 @@ struct trmm_eti_spec_avail { Kokkos::MemoryTraits >, \ Kokkos::View, \ Kokkos::MemoryTraits > \ - > { enum : bool { value = false }; }; + > { enum : bool { value = true }; }; // // This Macros provides the ETI specialization of trmm, currently not available. @@ -81,7 +81,7 @@ struct trmm_eti_spec_avail { // Include the actual specialization declarations #include -//TODO: #include +#include namespace KokkosBlas { namespace Impl { @@ -107,8 +107,7 @@ struct TRMM{ const BVIT& B); }; -// TODO: Fall-back ETI implementation of KokkosBlas::trmm. -#if 0 && (!defined(KOKKOSKERNELS_ETI_ONLY) || KOKKOSKERNELS_IMPL_COMPILE_LIBRARY) +#if !defined(KOKKOSKERNELS_ETI_ONLY) || KOKKOSKERNELS_IMPL_COMPILE_LIBRARY template struct TRMM { @@ -187,6 +186,6 @@ template struct TRMM< \ KOKKOSBLAS3_TRMM_ETI_SPEC_INST_LAYOUTS(SCALAR, LAYOUT, LAYOUT, EXEC_SPACE, MEM_SPACE) #include -//#include +#include #endif // KOKKOSBLAS3_TRMM_SPEC_HPP_ diff --git a/src/blas/impl/KokkosBlas_trtri_impl.hpp b/src/blas/impl/KokkosBlas_trtri_impl.hpp new file mode 100644 index 0000000000..c8cc4c7efa --- /dev/null +++ b/src/blas/impl/KokkosBlas_trtri_impl.hpp @@ -0,0 +1,100 @@ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#ifndef KOKKOSBLAS_TRTRI_IMPL_HPP_ +#define KOKKOSBLAS_TRTRI_IMPL_HPP_ + +/** + * \file KokkosBlas_trtri_impl.hpp + * \brief Implementation of triangular matrix inverse + */ + + +#include "KokkosKernels_config.h" +#include "Kokkos_Core.hpp" +#include "KokkosBatched_Trtri_Decl.hpp" +#include "KokkosBatched_Trtri_Serial_Impl.hpp" + +using namespace KokkosBatched; + +namespace KokkosBlas { + namespace Impl { + + template + void SerialTrtri_Invoke (const RViewType &R, + const char uplo[], + const char diag[], + const AViewType &A) + { + char __uplo = tolower(uplo[0]), + __diag = tolower(diag[0]); + + //// Lower //// + if (__uplo == 'l') { + if (__diag == 'u') { + R() = SerialTrtriInternalLower::invoke(Diag::Unit::use_unit_diag, + A.extent(0), A.extent(1), + A.data(), A.stride(0), A.stride(1)); + } else { + R() = SerialTrtriInternalLower::invoke(Diag::NonUnit::use_unit_diag, + A.extent(0), A.extent(1), + A.data(), A.stride(0), A.stride(1)); + } + } else { + //// Upper //// + if (__diag == 'u') { + R() = SerialTrtriInternalUpper::invoke(Diag::Unit::use_unit_diag, + A.extent(0), A.extent(1), + A.data(), A.stride(0), A.stride(1)); + } else { + R() = SerialTrtriInternalUpper::invoke(Diag::NonUnit::use_unit_diag, + A.extent(0), A.extent(1), + A.data(), A.stride(0), A.stride(1)); + } + } + } + } // namespace Impl +} // namespace KokkosBlas +#endif // KOKKOSBLAS_TRTRI_IMPL_HPP_ diff --git a/src/blas/impl/KokkosBlas_trtri_spec.hpp b/src/blas/impl/KokkosBlas_trtri_spec.hpp index e0fc72943b..01f53d04e1 100644 --- a/src/blas/impl/KokkosBlas_trtri_spec.hpp +++ b/src/blas/impl/KokkosBlas_trtri_spec.hpp @@ -48,7 +48,7 @@ #include "Kokkos_Core.hpp" #if !defined(KOKKOSKERNELS_ETI_ONLY) || KOKKOSKERNELS_IMPL_COMPILE_LIBRARY -//#include +#include #endif namespace KokkosBlas { @@ -69,13 +69,13 @@ struct trtri_eti_spec_avail { struct trtri_eti_spec_avail< \ Kokkos::View >, \ - Kokkos::View, \ - Kokkos::MemoryTraits >, \ - > { enum : bool { value = false }; }; + Kokkos::View, \ + Kokkos::MemoryTraits > \ + > { enum : bool { value = true }; }; // Include the actual specialization declarations #include -//TODO: #include +#include namespace KokkosBlas { namespace Impl { @@ -98,8 +98,7 @@ struct TRTRI{ const AVIT& A); }; -// TODO: Fall-back ETI implementation of KokkosBlas::trtri. -#if 0 && (!defined(KOKKOSKERNELS_ETI_ONLY) || KOKKOSKERNELS_IMPL_COMPILE_LIBRARY) +#if !defined(KOKKOSKERNELS_ETI_ONLY) || KOKKOSKERNELS_IMPL_COMPILE_LIBRARY template struct TRTRI { static void @@ -116,12 +115,11 @@ struct TRTRI { Kokkos::Profiling::pushRegion(KOKKOSKERNELS_IMPL_COMPILE_LIBRARY?"KokkosBlas::trtri[ETI]":"KokkosBlas::trtri[noETI]"); typename AVIT::HostMirror host_A = Kokkos::create_mirror_view(A); + typename RVIT::HostMirror host_R = Kokkos::create_mirror_view(R); Kokkos::deep_copy(host_A, A); - // TODO: Why does this always execute in host space? kokkos parallel operations - // can execute in device space. - SerialTrtri_Invoke (uplo, diag, host_A); + SerialTrtri_Invoke (R, uplo, diag, host_A); Kokkos::deep_copy(A, host_A); @@ -145,7 +143,7 @@ struct TRTRI { extern template struct TRTRI< \ Kokkos::View >, \ - Kokkos::View, \ + Kokkos::View, \ Kokkos::MemoryTraits >, \ false, true>; @@ -153,11 +151,11 @@ extern template struct TRTRI< \ template struct TRTRI< \ Kokkos::View >, \ - Kokkos::View, \ + Kokkos::View, \ Kokkos::MemoryTraits >, \ false, true>; #include -// TODO: #include +#include #endif // KOKKOSBLAS_TRTRI_SPEC_HPP_ diff --git a/src/impl/generated_specializations_cpp/trmm/KokkosBlas3_trmm_eti_spec_inst.cpp.in b/src/impl/generated_specializations_cpp/trmm/KokkosBlas3_trmm_eti_spec_inst.cpp.in new file mode 100644 index 0000000000..b6976565c1 --- /dev/null +++ b/src/impl/generated_specializations_cpp/trmm/KokkosBlas3_trmm_eti_spec_inst.cpp.in @@ -0,0 +1,54 @@ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + + +#define KOKKOSKERNELS_IMPL_COMPILE_LIBRARY true +#include "KokkosKernels_config.h" +#include "KokkosBlas3_trmm_spec.hpp" + +namespace KokkosBlas { +namespace Impl { +@BLAS3_TRMM_ETI_INST_BLOCK@ + } //IMPL +} //Kokkos diff --git a/src/impl/generated_specializations_cpp/trtri/KokkosBlas_trtri_eti_spec_inst.cpp.in b/src/impl/generated_specializations_cpp/trtri/KokkosBlas_trtri_eti_spec_inst.cpp.in new file mode 100644 index 0000000000..f5a20fd85d --- /dev/null +++ b/src/impl/generated_specializations_cpp/trtri/KokkosBlas_trtri_eti_spec_inst.cpp.in @@ -0,0 +1,54 @@ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + + +#define KOKKOSKERNELS_IMPL_COMPILE_LIBRARY true +#include "KokkosKernels_config.h" +#include "KokkosBlas_trtri_spec.hpp" + +namespace KokkosBlas { +namespace Impl { +@BLAS_TRTRI_ETI_INST_BLOCK@ + } //IMPL +} //Kokkos diff --git a/src/impl/generated_specializations_hpp/KokkosBlas3_trmm_eti_spec_avail.hpp.in b/src/impl/generated_specializations_hpp/KokkosBlas3_trmm_eti_spec_avail.hpp.in new file mode 100644 index 0000000000..1b4a7f720c --- /dev/null +++ b/src/impl/generated_specializations_hpp/KokkosBlas3_trmm_eti_spec_avail.hpp.in @@ -0,0 +1,54 @@ +#ifndef KOKKOSBLAS3_TRMM_ETI_SPEC_AVAIL_HPP_ +#define KOKKOSBLAS3_TRMM_ETI_SPEC_AVAIL_HPP_ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +namespace KokkosBlas { +namespace Impl { + +@BLAS3_TRMM_ETI_AVAIL_BLOCK@ + +} // Impl +} // KokkosBlas +#endif // KOKKOSBLAS3_TRMM_ETI_SPEC_AVAIL_HPP_ diff --git a/src/impl/generated_specializations_hpp/KokkosBlas3_trmm_eti_spec_decl.hpp.in b/src/impl/generated_specializations_hpp/KokkosBlas3_trmm_eti_spec_decl.hpp.in new file mode 100644 index 0000000000..cb76daa82b --- /dev/null +++ b/src/impl/generated_specializations_hpp/KokkosBlas3_trmm_eti_spec_decl.hpp.in @@ -0,0 +1,54 @@ +#ifndef KOKKOSBLAS3_TRMM_ETI_SPEC_DECL_HPP_ +#define KOKKOSBLAS3_TRMM_ETI_SPEC_DECL_HPP_ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +namespace KokkosBlas { +namespace Impl { + +@BLAS3_TRMM_ETI_DECL_BLOCK@ + +} // Impl +} // KokkosBlas +#endif // KOKKOSBLAS3_TRMM_ETI_SPEC_DECL_HPP_ diff --git a/src/impl/generated_specializations_hpp/KokkosBlas_trtri_eti_spec_avail.hpp.in b/src/impl/generated_specializations_hpp/KokkosBlas_trtri_eti_spec_avail.hpp.in new file mode 100644 index 0000000000..988f69389b --- /dev/null +++ b/src/impl/generated_specializations_hpp/KokkosBlas_trtri_eti_spec_avail.hpp.in @@ -0,0 +1,54 @@ +#ifndef KOKKOSBLAS_TRTRI_ETI_SPEC_AVAIL_HPP_ +#define KOKKOSBLAS_TRTRI_ETI_SPEC_AVAIL_HPP_ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +namespace KokkosBlas { +namespace Impl { + +@BLAS_TRTRI_ETI_AVAIL_BLOCK@ + +} // Impl +} // KokkosBlas +#endif // KOKKOSBLAS_TRTRI_ETI_SPEC_AVAIL_HPP_ diff --git a/src/impl/generated_specializations_hpp/KokkosBlas_trtri_eti_spec_decl.hpp.in b/src/impl/generated_specializations_hpp/KokkosBlas_trtri_eti_spec_decl.hpp.in new file mode 100644 index 0000000000..6469c42257 --- /dev/null +++ b/src/impl/generated_specializations_hpp/KokkosBlas_trtri_eti_spec_decl.hpp.in @@ -0,0 +1,54 @@ +#ifndef KOKKOSBLAS_TRTRI_ETI_SPEC_DECL_HPP_ +#define KOKKOSBLAS_TRTRI_ETI_SPEC_DECL_HPP_ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +namespace KokkosBlas { +namespace Impl { + +@BLAS_TRTRI_ETI_DECL_BLOCK@ + +} // Impl +} // KokkosBlas +#endif // KOKKOSBLAS_TRTRI_ETI_SPEC_DECL_HPP_ diff --git a/src/impl/tpls/KokkosBlas3_trmm_tpl_spec_decl.hpp b/src/impl/tpls/KokkosBlas3_trmm_tpl_spec_decl.hpp index c08abe2b70..c323db7e4a 100644 --- a/src/impl/tpls/KokkosBlas3_trmm_tpl_spec_decl.hpp +++ b/src/impl/tpls/KokkosBlas3_trmm_tpl_spec_decl.hpp @@ -131,24 +131,24 @@ KOKKOSBLAS3_TRMM_BLAS(Kokkos::complex, std::complex, LAYOUTA, LAY // Explicitly define the TRMM class for all permutations listed below -//KOKKOSBLAS3_DTRMM_BLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::HostSpace, true) +KOKKOSBLAS3_DTRMM_BLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::HostSpace, true) KOKKOSBLAS3_DTRMM_BLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::HostSpace, false) -//KOKKOSBLAS3_DTRMM_BLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::HostSpace, true) +KOKKOSBLAS3_DTRMM_BLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::HostSpace, true) KOKKOSBLAS3_DTRMM_BLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::HostSpace, false) -//KOKKOSBLAS3_STRMM_BLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::HostSpace, true) +KOKKOSBLAS3_STRMM_BLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::HostSpace, true) KOKKOSBLAS3_STRMM_BLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::HostSpace, false) -//KOKKOSBLAS3_STRMM_BLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::HostSpace, true) +KOKKOSBLAS3_STRMM_BLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::HostSpace, true) KOKKOSBLAS3_STRMM_BLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::HostSpace, false) -//KOKKOSBLAS3_ZTRMM_BLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::HostSpace, true) +KOKKOSBLAS3_ZTRMM_BLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::HostSpace, true) KOKKOSBLAS3_ZTRMM_BLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::HostSpace, false) -//KOKKOSBLAS3_ZTRMM_BLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::HostSpace, true) +KOKKOSBLAS3_ZTRMM_BLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::HostSpace, true) KOKKOSBLAS3_ZTRMM_BLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::HostSpace, false) -//KOKKOSBLAS3_CTRMM_BLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::HostSpace, true) +KOKKOSBLAS3_CTRMM_BLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::HostSpace, true) KOKKOSBLAS3_CTRMM_BLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::HostSpace, false) -//KOKKOSBLAS3_CTRMM_BLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::HostSpace, true) +KOKKOSBLAS3_CTRMM_BLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::HostSpace, true) KOKKOSBLAS3_CTRMM_BLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::HostSpace, false) } @@ -255,44 +255,44 @@ KOKKOSBLAS3_TRMM_CUBLAS(Kokkos::complex, cuComplex, cublasCtrmm, LAYOUTA, // Explicitly define the TRMM class for all permutations listed below -//KOKKOSBLAS3_DTRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaSpace, true) +KOKKOSBLAS3_DTRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaSpace, true) KOKKOSBLAS3_DTRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaSpace, false) -//KOKKOSBLAS3_DTRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaSpace, true) +KOKKOSBLAS3_DTRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaSpace, true) KOKKOSBLAS3_DTRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaSpace, false) -//KOKKOSBLAS3_DTRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaUVMSpace, true) +KOKKOSBLAS3_DTRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaUVMSpace, true) KOKKOSBLAS3_DTRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaUVMSpace, false) -//KOKKOSBLAS3_DTRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaUVMSpace, true) +KOKKOSBLAS3_DTRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaUVMSpace, true) KOKKOSBLAS3_DTRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaUVMSpace, false) -//KOKKOSBLAS3_STRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaSpace, true) +KOKKOSBLAS3_STRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaSpace, true) KOKKOSBLAS3_STRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaSpace, false) -//KOKKOSBLAS3_STRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaSpace, true) +KOKKOSBLAS3_STRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaSpace, true) KOKKOSBLAS3_STRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaSpace, false) -//KOKKOSBLAS3_STRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaUVMSpace, true) +KOKKOSBLAS3_STRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaUVMSpace, true) KOKKOSBLAS3_STRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaUVMSpace, false) -//KOKKOSBLAS3_STRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaUVMSpace, true) +KOKKOSBLAS3_STRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaUVMSpace, true) KOKKOSBLAS3_STRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaUVMSpace, false) -//KOKKOSBLAS3_ZTRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaSpace, true) +KOKKOSBLAS3_ZTRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaSpace, true) KOKKOSBLAS3_ZTRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaSpace, false) -//KOKKOSBLAS3_ZTRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaSpace, true) +KOKKOSBLAS3_ZTRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaSpace, true) KOKKOSBLAS3_ZTRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaSpace, false) -//KOKKOSBLAS3_ZTRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaUVMSpace, true) +KOKKOSBLAS3_ZTRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaUVMSpace, true) KOKKOSBLAS3_ZTRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaUVMSpace, false) -//KOKKOSBLAS3_ZTRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaUVMSpace, true) +KOKKOSBLAS3_ZTRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaUVMSpace, true) KOKKOSBLAS3_ZTRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaUVMSpace, false) -//KOKKOSBLAS3_CTRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaSpace, true) +KOKKOSBLAS3_CTRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaSpace, true) KOKKOSBLAS3_CTRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaSpace, false) -//KOKKOSBLAS3_CTRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaSpace, true) +KOKKOSBLAS3_CTRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaSpace, true) KOKKOSBLAS3_CTRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaSpace, false) -//KOKKOSBLAS3_CTRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaUVMSpace, true) +KOKKOSBLAS3_CTRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaUVMSpace, true) KOKKOSBLAS3_CTRMM_CUBLAS( Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::CudaUVMSpace, false) -//KOKKOSBLAS3_CTRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaUVMSpace, true) +KOKKOSBLAS3_CTRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaUVMSpace, true) KOKKOSBLAS3_CTRMM_CUBLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::CudaUVMSpace, false) } // namespace Impl diff --git a/src/impl/tpls/KokkosBlas_trtri_tpl_spec_avail.hpp b/src/impl/tpls/KokkosBlas_trtri_tpl_spec_avail.hpp index b6da2c40ee..fa651f531f 100644 --- a/src/impl/tpls/KokkosBlas_trtri_tpl_spec_avail.hpp +++ b/src/impl/tpls/KokkosBlas_trtri_tpl_spec_avail.hpp @@ -60,7 +60,7 @@ template \ struct trtri_tpl_spec_avail< \ Kokkos::View >, \ - Kokkos::View, \ + Kokkos::View, \ Kokkos::MemoryTraits > \ > { enum : bool { value = true }; }; diff --git a/src/impl/tpls/KokkosBlas_trtri_tpl_spec_decl.hpp b/src/impl/tpls/KokkosBlas_trtri_tpl_spec_decl.hpp index ea5d6ac4c2..79d52a7602 100644 --- a/src/impl/tpls/KokkosBlas_trtri_tpl_spec_decl.hpp +++ b/src/impl/tpls/KokkosBlas_trtri_tpl_spec_decl.hpp @@ -57,7 +57,7 @@ template \ struct TRTRI< \ Kokkos::View >, \ - Kokkos::View, \ + Kokkos::View, \ Kokkos::MemoryTraits >, \ true, ETI_SPEC_AVAIL> { \ typedef SCALAR_TYPE SCALAR; \ @@ -100,7 +100,7 @@ template \ struct TRTRI< \ Kokkos::View >, \ - Kokkos::View, \ + Kokkos::View, \ Kokkos::MemoryTraits >, \ true, ETI_SPEC_AVAIL> { \ typedef SCALAR_TYPE SCALAR; \ @@ -169,24 +169,24 @@ KOKKOSBLAS_TRTRI_BLAS_MAGMA(Kokkos::complex, magmaFloatComplex_ptr, magma KOKKOSBLAS_TRTRI_BLAS_MAGMA(Kokkos::complex, magmaFloatComplex_ptr, magma_ctrtri_gpu, LAYOUTA, Kokkos::CudaUVMSpace, ETI_SPEC_AVAIL) \ // Handle layout permutations -//KOKKOSBLAS_DTRTRI_BLAS(Kokkos::LayoutLeft, Kokkos::LayoutLeft, Kokkos::HostSpace, true) +KOKKOSBLAS_DTRTRI_BLAS(Kokkos::LayoutLeft, true) KOKKOSBLAS_DTRTRI_BLAS(Kokkos::LayoutLeft, false) -//KOKKOSBLAS_DTRTRI_BLAS(Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::HostSpace, true) +KOKKOSBLAS_DTRTRI_BLAS(Kokkos::LayoutRight, true) KOKKOSBLAS_DTRTRI_BLAS(Kokkos::LayoutRight, false) -//KOKKOSBLAS_STRTRI_BLAS(Kokkos::LayoutLeft, Kokkos::LayoutLeft, true) +KOKKOSBLAS_STRTRI_BLAS(Kokkos::LayoutLeft, true) KOKKOSBLAS_STRTRI_BLAS(Kokkos::LayoutLeft, false) -//KOKKOSBLAS_STRTRI_BLAS(Kokkos::LayoutRight, Kokkos::LayoutRight, true) +KOKKOSBLAS_STRTRI_BLAS(Kokkos::LayoutRight, true) KOKKOSBLAS_STRTRI_BLAS(Kokkos::LayoutRight, false) -//KOKKOSBLAS_ZTRTRI_BLAS(Kokkos::LayoutLeft, Kokkos::LayoutLeft, true) +KOKKOSBLAS_ZTRTRI_BLAS(Kokkos::LayoutLeft, true) KOKKOSBLAS_ZTRTRI_BLAS(Kokkos::LayoutLeft, false) -//KOKKOSBLAS_ZTRTRI_BLAS(Kokkos::LayoutRight, Kokkos::LayoutRight, true) +KOKKOSBLAS_ZTRTRI_BLAS(Kokkos::LayoutRight, true) KOKKOSBLAS_ZTRTRI_BLAS(Kokkos::LayoutRight, false) -//KOKKOSBLAS_CTRTRI_BLAS(Kokkos::LayoutLeft, Kokkos::LayoutLeft, true) +KOKKOSBLAS_CTRTRI_BLAS(Kokkos::LayoutLeft, true) KOKKOSBLAS_CTRTRI_BLAS(Kokkos::LayoutLeft, false) -//KOKKOSBLAS_CTRTRI_BLAS(Kokkos::LayoutRight, Kokkos::LayoutRight, true) +KOKKOSBLAS_CTRTRI_BLAS(Kokkos::LayoutRight, true) KOKKOSBLAS_CTRTRI_BLAS(Kokkos::LayoutRight, false) } // namespace Impl diff --git a/unit_test/batched/Test_Batched_SerialTrmm.hpp b/unit_test/batched/Test_Batched_SerialTrmm.hpp new file mode 100644 index 0000000000..b444de788b --- /dev/null +++ b/unit_test/batched/Test_Batched_SerialTrmm.hpp @@ -0,0 +1,307 @@ +#include "gtest/gtest.h" +#include "Kokkos_Core.hpp" +#include "Kokkos_Random.hpp" + +#include "KokkosBatched_Trmm_Decl.hpp" +#include "KokkosBatched_Trmm_Serial_Impl.hpp" + +#include "KokkosKernels_TestUtils.hpp" + +using namespace KokkosBatched; + +namespace Test { + + template + struct UnitDiagTRMM { + ViewTypeA A_; + using ScalarA = typename ViewTypeA::value_type; + + UnitDiagTRMM (const ViewTypeA& A) : A_(A) {} + + KOKKOS_INLINE_FUNCTION + void operator() (const int& i) const { + A_(i,i) = ScalarA(1); + } + }; + template + struct NonUnitDiagTRMM { + ViewTypeA A_; + using ScalarA = typename ViewTypeA::value_type; + + NonUnitDiagTRMM (const ViewTypeA& A) : A_(A) {} + + KOKKOS_INLINE_FUNCTION + void operator() (const int& i) const { + A_(i,i) = A_(i,i)+10; + } + }; + template + struct VanillaGEMM { + bool A_t, B_t, A_c, B_c; + int N,K; + ViewTypeA A; + ViewTypeB B; + ViewTypeC C; + + typedef typename ViewTypeA::value_type ScalarA; + typedef typename ViewTypeB::value_type ScalarB; + typedef typename ViewTypeC::value_type ScalarC; + typedef Kokkos::Details::ArithTraits APT; + typedef typename APT::mag_type mag_type; + ScalarA alpha; + ScalarC beta; + + KOKKOS_INLINE_FUNCTION + void operator() (const typename Kokkos::TeamPolicy::member_type& team) const { +// GNU COMPILER BUG WORKAROUND +#if defined(KOKKOS_COMPILER_GNU) && !defined(__CUDA_ARCH__) + int i = team.league_rank(); +#else + const int i = team.league_rank(); +#endif + Kokkos::parallel_for(Kokkos::TeamThreadRange(team,N), [&] (const int& j) { + ScalarC C_ij = 0.0; + + // GNU 5.3, 5.4 and 6.1 (and maybe more) crash with another nested lambda here + +#if defined(KOKKOS_COMPILER_GNU) && !defined(KOKKOS_COMPILER_NVCC) + for(int k=0; k + struct ParamTag { + typedef S side; + typedef U uplo; + typedef T trans; + typedef D diag; + }; + + template + struct Functor_TestBatchedSerialTrmm { + ViewType _a, _b; + + ScalarType _alpha; + + KOKKOS_INLINE_FUNCTION + Functor_TestBatchedSerialTrmm(const ScalarType alpha, + const ViewType &a, + const ViewType &b) + : _a(a), _b(b), _alpha(alpha) {} + + KOKKOS_INLINE_FUNCTION + void operator()(const ParamTagType &, const int k) const { + auto aa = Kokkos::subview(_a, k, Kokkos::ALL(), Kokkos::ALL()); + auto bb = Kokkos::subview(_b, k, Kokkos::ALL(), Kokkos::ALL()); + + SerialTrmm:: + invoke(_alpha, aa, bb); + } + + inline + void run() { + typedef typename ViewType::value_type value_type; + std::string name_region("KokkosBatched::Test::SerialTrmm"); + std::string name_value_type = ( std::is_same::value ? "::Float" : + std::is_same::value ? "::Double" : + std::is_same >::value ? "::ComplexFloat" : + std::is_same >::value ? "::ComplexDouble" : "::UnknownValueType" ); + std::string name = name_region + name_value_type; + Kokkos::Profiling::pushRegion( name.c_str() ); + Kokkos::RangePolicy policy(0, _a.extent(0)); + Kokkos::parallel_for(name.c_str(), policy, *this); + Kokkos::Profiling::popRegion(); + } + }; + + template + void impl_test_batched_trmm(const int N, const int nRows, const int nCols, const char *trans) { + typedef typename ViewType::value_type value_type; + typedef typename DeviceType::execution_space execution_space; + typedef Kokkos::Details::ArithTraits ats; + + ScalarType alpha(1.0); + ScalarType beta(0.0); + + const bool is_side_right = std::is_same::value; + const bool is_A_lower = std::is_same::value; + const int K = is_side_right ? nCols : nRows; + ViewType + A("A", N, K, K), + B_actual("B_actual", N, nRows, nCols), + B_expected("B_expected", N, nRows, nCols); + typename ViewType::HostMirror A_host = Kokkos::create_mirror_view(A); + typename ViewType::HostMirror B_actual_host = Kokkos::create_mirror_view(B_actual); + typename ViewType::HostMirror B_expected_host = Kokkos::create_mirror_view(B_expected); + uint64_t seed = Kokkos::Impl::clock_tic(); + + using ViewTypeSubA = decltype(Kokkos::subview(A, 0, Kokkos::ALL(), Kokkos::ALL())); + using ViewTypeSubB = decltype(Kokkos::subview(B_actual, 0, Kokkos::ALL(), Kokkos::ALL())); + + Kokkos::Random_XorShift64_Pool rand_pool(seed); + + if(std::is_same::value) { + // Initialize A with deterministic random numbers + Kokkos::fill_random(A, rand_pool, Kokkos::rand, ScalarType>::max()); + using functor_type = UnitDiagTRMM; + for (int k = 0; k < N; ++k) { + functor_type udtrmm(Kokkos::subview(A, k, Kokkos::ALL(), Kokkos::ALL())); + // Initialize As diag with 1s + Kokkos::parallel_for("KokkosBlas::Test::UnitDiagTRMM", Kokkos::RangePolicy(0,K), udtrmm); + } + } else {//(diag[0]=='N')||(diag[0]=='n') + // Initialize A with random numbers + Kokkos::fill_random(A, rand_pool, Kokkos::rand, ScalarType>::max()); + using functor_type = NonUnitDiagTRMM; + for (int k = 0; k < N; ++k) { + functor_type nudtrmm(Kokkos::subview(A, k, Kokkos::ALL(), Kokkos::ALL())); + // Initialize As diag with A(i,i)+10 + Kokkos::parallel_for("KokkosBlas::Test::NonUnitDiagTRMM", Kokkos::RangePolicy(0,K), nudtrmm); + } + } + Kokkos::fill_random(B_actual, rand_pool, Kokkos::rand, ScalarType>::max()); + Kokkos::fence(); + + Kokkos::deep_copy(B_expected, B_actual); + Kokkos::fence(); + + Kokkos::deep_copy(A_host, A); + // Make A_host a lower triangle + for (int k = 0; k < N; k++) { + if (is_A_lower) { + for (int i = 0; i < K-1; i++) + for (int j = i+1; j < K; j++) + A_host(k,i,j) = ScalarType(0); + } + else { + // Make A_host a upper triangle + for (int i = 1; i < K; i++) + for (int j = 0; j < i; j++) + A_host(k,i,j) = ScalarType(0); + } + } + Kokkos::deep_copy(A, A_host); + + if (!is_side_right){ + // B_expected = alpha * op(A) * B + beta * C = 1 * op(A) * B + 0 * C + struct VanillaGEMM vgemm; + vgemm.A_t = (trans[0]!='N') && (trans[0]!='n'); vgemm.B_t = false; + vgemm.A_c = (trans[0]=='C') || (trans[0]=='c'); vgemm.B_c = false; + vgemm.N = nCols; vgemm.K = K; + vgemm.alpha = alpha; + vgemm.beta = beta; + for (int i = 0; i < N; i++) { + vgemm.A = Kokkos::subview(A, i, Kokkos::ALL(), Kokkos::ALL()); + vgemm.B = Kokkos::subview(B_actual, i, Kokkos::ALL(), Kokkos::ALL());; + vgemm.C = Kokkos::subview(B_expected, i, Kokkos::ALL(), Kokkos::ALL());; + Kokkos::parallel_for("KokkosBlas::Test::VanillaGEMM", Kokkos::TeamPolicy(nRows,Kokkos::AUTO,16), vgemm); + } + } + else { + // B_expected = alpha * B * op(A) + beta * C = 1 * B * op(A) + 0 * C + struct VanillaGEMM vgemm; + vgemm.A_t = false; vgemm.B_t = (trans[0]!='N') && (trans[0]!='n'); + vgemm.A_c = false; vgemm.B_c = (trans[0]=='C') || (trans[0]=='c'); + vgemm.N = nCols; vgemm.K = K; + vgemm.alpha = alpha; + vgemm.beta = beta; + for (int i = 0; i < N; i++) { + vgemm.A = Kokkos::subview(B_actual, i, Kokkos::ALL(), Kokkos::ALL()); + vgemm.B = Kokkos::subview(A, i, Kokkos::ALL(), Kokkos::ALL());; + vgemm.C = Kokkos::subview(B_expected, i, Kokkos::ALL(), Kokkos::ALL());; + Kokkos::parallel_for("KokkosBlas::Test::VanillaGEMM", Kokkos::TeamPolicy(nRows,Kokkos::AUTO,16), vgemm); + } + } + + Functor_TestBatchedSerialTrmm(alpha, A, B_actual).run(); + + Kokkos::fence(); + + Kokkos::deep_copy(B_actual_host, B_actual); + Kokkos::deep_copy(B_expected_host, B_expected); + + Kokkos::fence(); + + // eps is ~ 10^-13 for double + typedef typename ats::mag_type mag_type; + const mag_type eps = 1.0e8 * ats::epsilon(); + bool fail_flag = false; + + for (int k=0;k eps) { + //printf(" Error: eps ( %g ), abs_result( %.15lf ) != abs_solution( %.15lf ) (abs result-solution %g) at (k %d, i %d, j %d)\n", eps, ats::abs(B_actual_host(k,i,j)), ats::abs(B_expected_host(k,i,j)), ats::abs(B_actual_host(k,i,j) - B_expected_host(k,i,j)), k, i, j); + fail_flag = true; + } + } + } + } + + ASSERT_EQ( fail_flag, false ); + } +} + + +template +int test_batched_trmm() { + char trans = std::is_same::value ? 'N' : + std::is_same::value ? 'T' : + std::is_same::value ? 'C' : 'E'; +#if defined(KOKKOSKERNELS_INST_LAYOUTLEFT) + { + typedef Kokkos::View ViewType; + Test::impl_test_batched_trmm( 0, 10, 4, &trans); + for (int i=0;i<10;++i) { + //printf("Testing: LayoutLeft, Blksize %d\n", i); + Test::impl_test_batched_trmm(1024, i, 4, &trans); + Test::impl_test_batched_trmm(1024, i, 1, &trans); + } + } +#endif +#if defined(KOKKOSKERNELS_INST_LAYOUTRIGHT) + { + typedef Kokkos::View ViewType; + Test::impl_test_batched_trmm( 0, 10, 4, &trans); + for (int i=0;i<10;++i) { + //printf("Testing: LayoutRight, Blksize %d\n", i); + Test::impl_test_batched_trmm(1024, i, 4, &trans); + Test::impl_test_batched_trmm(1024, i, 1, &trans); + } + } +#endif + + return 0; +} + diff --git a/unit_test/batched/Test_Batched_SerialTrmm_Complex.hpp b/unit_test/batched/Test_Batched_SerialTrmm_Complex.hpp new file mode 100644 index 0000000000..0ba532df24 --- /dev/null +++ b/unit_test/batched/Test_Batched_SerialTrmm_Complex.hpp @@ -0,0 +1,230 @@ + +#if defined(KOKKOSKERNELS_INST_COMPLEX_FLOAT) +// NO TRANSPOSE +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_nt_u_scomplex_scomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_nt_n_scomplex_scomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_nt_u_scomplex_scomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_nt_n_scomplex_scomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_nt_u_scomplex_scomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_nt_n_scomplex_scomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +// TRANSPOSE +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_t_u_scomplex_scomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_t_n_scomplex_scomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_t_u_scomplex_scomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_t_n_scomplex_scomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_t_u_scomplex_scomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_t_n_scomplex_scomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +// CONJUGATE TRANSPOSE +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_ct_u_scomplex_scomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_ct_n_scomplex_scomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_ct_u_scomplex_scomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_ct_n_scomplex_scomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_ct_u_scomplex_scomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_ct_n_scomplex_scomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +#endif + + +#if defined(KOKKOSKERNELS_INST_COMPLEX_DOUBLE) +// NO TRANSPOSE +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_nt_u_dcomplex_dcomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_nt_n_dcomplex_dcomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_nt_u_dcomplex_dcomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_nt_n_dcomplex_dcomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_nt_u_dcomplex_dcomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_nt_n_dcomplex_dcomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +// TRANSPOSE +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_t_u_dcomplex_dcomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_t_n_dcomplex_dcomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_t_u_dcomplex_dcomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_t_n_dcomplex_dcomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_t_u_dcomplex_dcomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_t_n_dcomplex_dcomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +// CONJUGATE TRANSPOSE +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_ct_u_dcomplex_dcomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_ct_n_dcomplex_dcomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_ct_u_dcomplex_dcomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_ct_n_dcomplex_dcomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_ct_u_dcomplex_dcomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_ct_n_dcomplex_dcomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +#endif + diff --git a/unit_test/batched/Test_Batched_SerialTrmm_Real.hpp b/unit_test/batched/Test_Batched_SerialTrmm_Real.hpp new file mode 100644 index 0000000000..1603af3971 --- /dev/null +++ b/unit_test/batched/Test_Batched_SerialTrmm_Real.hpp @@ -0,0 +1,230 @@ + +#if defined(KOKKOSKERNELS_INST_FLOAT) +// NO TRANSPOSE +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_nt_u_float_float ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_nt_n_float_float ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_nt_u_float_float ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_nt_n_float_float ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_nt_u_float_float ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_nt_n_float_float ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +// TRANSPOSE +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_t_u_float_float ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_t_n_float_float ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_t_u_float_float ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_t_n_float_float ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_t_u_float_float ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_t_n_float_float ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +// CONJUGATE TRANSPOSE +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_ct_u_float_float ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_ct_n_float_float ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_ct_u_float_float ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_ct_n_float_float ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_ct_u_float_float ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_ct_n_float_float ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +#endif + + +#if defined(KOKKOSKERNELS_INST_DOUBLE) +// NO TRANSPOSE +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_nt_u_double_double ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_nt_n_double_double ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_nt_u_double_double ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_nt_n_double_double ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_nt_u_double_double ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_nt_n_double_double ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +// TRANSPOSE +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_t_u_double_double ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_t_n_double_double ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_t_u_double_double ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_t_n_double_double ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_t_u_double_double ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_t_n_double_double ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +// CONJUGATE TRANSPOSE +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_ct_u_double_double ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_l_ct_n_double_double ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_ct_u_double_double ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_l_u_ct_n_double_double ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_ct_u_double_double ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +TEST_F( TestCategory, batched_scalar_serial_trmm_r_u_ct_n_double_double ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trmm::Unblocked algo_tag_type; + + test_batched_trmm(); +} +#endif + diff --git a/unit_test/batched/Test_Batched_SerialTrtri.hpp b/unit_test/batched/Test_Batched_SerialTrtri.hpp new file mode 100644 index 0000000000..089b10e732 --- /dev/null +++ b/unit_test/batched/Test_Batched_SerialTrtri.hpp @@ -0,0 +1,334 @@ +#include "gtest/gtest.h" +#include "Kokkos_Core.hpp" +#include "Kokkos_Random.hpp" + +#include "KokkosBatched_Trtri_Decl.hpp" +#include "KokkosBatched_Trtri_Serial_Impl.hpp" + +#include "KokkosKernels_TestUtils.hpp" + +#define PRINT_MAT 0 + +using namespace KokkosBatched; + +namespace Test { + + template + struct UnitDiagTRTRI { + ViewTypeA A_; + using ScalarA = typename ViewTypeA::value_type; + + UnitDiagTRTRI (const ViewTypeA& A) : A_(A) {} + + KOKKOS_INLINE_FUNCTION + void operator() (const int& i) const { + A_(i,i) = ScalarA(1); + } + }; + template + struct NonUnitDiagTRTRI { + ViewTypeA A_; + using ScalarA = typename ViewTypeA::value_type; + + NonUnitDiagTRTRI (const ViewTypeA& A) : A_(A) {} + + KOKKOS_INLINE_FUNCTION + void operator() (const int& i) const { + A_(i,i) = A_(i,i)+10; + } + }; + template + struct VanillaGEMM { + bool A_t, B_t, A_c, B_c; + int N,K; + ViewTypeA A; + ViewTypeB B; + ViewTypeC C; + + typedef typename ViewTypeA::value_type ScalarA; + typedef typename ViewTypeB::value_type ScalarB; + typedef typename ViewTypeC::value_type ScalarC; + typedef Kokkos::Details::ArithTraits APT; + typedef typename APT::mag_type mag_type; + ScalarA alpha; + ScalarC beta; + + KOKKOS_INLINE_FUNCTION + void operator() (const typename Kokkos::TeamPolicy::member_type& team) const { +// GNU COMPILER BUG WORKAROUND +#if defined(KOKKOS_COMPILER_GNU) && !defined(__CUDA_ARCH__) + int i = team.league_rank(); +#else + const int i = team.league_rank(); +#endif + Kokkos::parallel_for(Kokkos::TeamThreadRange(team,N), [&] (const int& j) { + ScalarC C_ij = 0.0; + + // GNU 5.3, 5.4 and 6.1 (and maybe more) crash with another nested lambda here + +#if defined(KOKKOS_COMPILER_GNU) && !defined(KOKKOS_COMPILER_NVCC) + for(int k=0; k + struct ParamTag { + typedef U uplo; + typedef D diag; + }; + + template + struct Functor_TestBatchedSerialTrtri { + ViewType _a; + + KOKKOS_INLINE_FUNCTION + Functor_TestBatchedSerialTrtri(const ViewType &a) + : _a(a) {} + + KOKKOS_INLINE_FUNCTION + void operator()(const ParamTagType &, const int k) const { + auto aa = Kokkos::subview(_a, k, Kokkos::ALL(), Kokkos::ALL()); + + SerialTrtri::invoke(aa); + } + + inline + void run() { + typedef typename ViewType::value_type value_type; + std::string name_region("KokkosBatched::Test::SerialTrtri"); + std::string name_value_type = ( std::is_same::value ? "::Float" : + std::is_same::value ? "::Double" : + std::is_same >::value ? "::ComplexFloat" : + std::is_same >::value ? "::ComplexDouble" : "::UnknownValueType" ); + std::string name = name_region + name_value_type; + Kokkos::Profiling::pushRegion( name.c_str() ); + Kokkos::RangePolicy policy(0,_a.extent(0)); + Kokkos::parallel_for("Functor_TestBatchedSerialTrtri", policy, *this); + Kokkos::Profiling::popRegion(); + } + }; + + template + void impl_test_batched_trtri(const int N, const int K) { + typedef typename ViewType::value_type value_type; + typedef typename DeviceType::execution_space execution_space; + typedef Kokkos::Details::ArithTraits ats; + + ScalarType alpha(1.0); + ScalarType beta(0.0); + + // eps is ~ 10^-13 for double + typedef typename ats::mag_type mag_type; + const mag_type eps = 1.0e8 * ats::epsilon(); + bool fail_flag = false; + ScalarType cur_check_val; // Either 1 or 0, to check A_I + + const bool is_A_lower = std::is_same::value; + ViewType A("A", N, K, K); + ViewType A_original("A_original", N, K, K); + ViewType A_I("A_I", N, K, K); + + typename ViewType::HostMirror I_host = Kokkos::create_mirror_view(A_I); + typename ViewType::HostMirror A_host = Kokkos::create_mirror_view(A); + + uint64_t seed = Kokkos::Impl::clock_tic(); + + using ViewTypeSubA = decltype(Kokkos::subview(A, 0, Kokkos::ALL(), Kokkos::ALL())); + + Kokkos::Random_XorShift64_Pool rand_pool(seed); + + if(std::is_same::value) { + // Initialize A with deterministic random numbers + Kokkos::fill_random(A, rand_pool, Kokkos::rand, ScalarType>::max()); + using functor_type = UnitDiagTRTRI; + for (int k = 0; k < N; ++k) { + functor_type udtrtri(Kokkos::subview(A, k, Kokkos::ALL(), Kokkos::ALL())); + // Initialize As diag with 1s + Kokkos::parallel_for("KokkosBlas::Test::UnitDiagTRTRI", Kokkos::RangePolicy(0,K), udtrtri); + } + } else {//(diag[0]=='N')||(diag[0]=='n') + // Initialize A with random numbers + Kokkos::fill_random(A, rand_pool, Kokkos::rand, ScalarType>::max()); + using functor_type = NonUnitDiagTRTRI; + for (int k = 0; k < N; ++k) { + functor_type nudtrtri(Kokkos::subview(A, k, Kokkos::ALL(), Kokkos::ALL())); + // Initialize As diag with A(i,i)+10 + Kokkos::parallel_for("KokkosBlas::Test::NonUnitDiagTRTRI", Kokkos::RangePolicy(0,K), nudtrtri); + } + } + Kokkos::fence(); + + Kokkos::deep_copy(A_host, A); + // Make A_host a lower triangle + for (int k = 0; k < N; k++) { + if (is_A_lower) { + for (int i = 0; i < K-1; i++) + for (int j = i+1; j < K; j++) + A_host(k,i,j) = ScalarType(0); + } + else { + // Make A_host a upper triangle + for (int i = 1; i < K; i++) + for (int j = 0; j < i; j++) + A_host(k,i,j) = ScalarType(0); + } + } + Kokkos::deep_copy(A, A_host); + Kokkos::deep_copy(A_original, A); + Kokkos::fence(); + + #if PRINT_MAT + printf("A_original:\n"); + for (int k = 0; k < N; ++k) { + for (int i = 0; i < K; i++) { + for (int j = 0; j < K; j++) { + printf("%*.13lf ", 20, A_original(k,i,j)); + } + printf("\n"); + } + } + #endif + + #if PRINT_MAT + printf("A:\n"); + for (int k = 0; k < N; ++k) { + for (int i = 0; i < K; i++) { + for (int j = 0; j < K; j++) { + printf("%*.13lf ", 20, A(k,i,j)); + } + printf("\n"); + } + } + #endif + + Functor_TestBatchedSerialTrtri(A).run(); + + #if PRINT_MAT + printf("A_original:\n"); + for (int k = 0; k < N; ++k) { + for (int i = 0; i < K; i++) { + for (int j = 0; j < K; j++) { + printf("%*.13lf ", 20, A_original(k,i,j)); + } + printf("\n"); + } + } + #endif + + #if PRINT_MAT + printf("A:\n"); + for (int k = 0; k < N; ++k) { + for (int i = 0; i < K; i++) { + for (int j = 0; j < K; j++) { + printf("%*.13lf ", 20, A(k,i,j)); + } + printf("\n"); + } + } + #endif + + Kokkos::fence(); + + struct VanillaGEMM vgemm; + vgemm.A_t = false; vgemm.B_t = false; + vgemm.A_c = false; vgemm.B_c = false; + vgemm.N = K; vgemm.K = K; + vgemm.alpha = alpha; + vgemm.beta = beta; + for (int i = 0; i < N; i++) { + vgemm.A = Kokkos::subview(A, i, Kokkos::ALL(), Kokkos::ALL()); + vgemm.B = Kokkos::subview(A_original, i, Kokkos::ALL(), Kokkos::ALL());; + vgemm.C = Kokkos::subview(A_I, i, Kokkos::ALL(), Kokkos::ALL());; + Kokkos::parallel_for("KokkosBlas::Test::VanillaGEMM", Kokkos::TeamPolicy(K,Kokkos::AUTO,16), vgemm); + } + + Kokkos::fence(); + Kokkos::deep_copy(I_host, A_I); + Kokkos::fence(); + + #if PRINT_MAT + printf("I_host:\n"); + for (int k = 0; k < N; ++k) { + for (int i = 0; i < K; i++) { + for (int j = 0; j < K; j++) { + printf("%*.13lf ", 20, I_host(k,i,j)); + } + printf("\n"); + } + } + #endif + + for (int k=0;k eps) { + fail_flag = true; + //printf(" Error: eps ( %g ), I_host ( %.15f ) != cur_check_val (%.15f) (abs result-cur_check_val %g) at (k %d, i %d, j %d)\n", + //eps, I_host(k,i,j), cur_check_val, ats::abs(I_host(k,i,j) - cur_check_val), k, i, j); + } + } + } + } + + ASSERT_EQ( fail_flag, false ); + } +} + + +template +int test_batched_trtri() { +#if defined(KOKKOSKERNELS_INST_LAYOUTLEFT) + { + typedef Kokkos::View ViewType; + Test::impl_test_batched_trtri( 0, 10); + //Test::impl_test_batched_trtri( 1, 2); + for (int i=0;i<10;++i) { + //printf("Testing: LayoutLeft, Blksize %d\n", i); + Test::impl_test_batched_trtri(1024, i); + Test::impl_test_batched_trtri(1024, i); + } + } +#endif +#if defined(KOKKOSKERNELS_INST_LAYOUTRIGHT) + { + typedef Kokkos::View ViewType; + Test::impl_test_batched_trtri( 0, 10); + for (int i=0;i<10;++i) { + //printf("Testing: LayoutRight, Blksize %d\n", i); + Test::impl_test_batched_trtri(1024, i); + Test::impl_test_batched_trtri(1024, i); + } + } +#endif + + return 0; +} + diff --git a/unit_test/batched/Test_Batched_SerialTrtri_Complex.hpp b/unit_test/batched/Test_Batched_SerialTrtri_Complex.hpp new file mode 100644 index 0000000000..fc8b51bdaf --- /dev/null +++ b/unit_test/batched/Test_Batched_SerialTrtri_Complex.hpp @@ -0,0 +1,58 @@ + +#if defined(KOKKOSKERNELS_INST_COMPLEX_FLOAT) +// NO TRANSPOSE +TEST_F( TestCategory, batched_scalar_serial_trtri_u_n_scomplex_scomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trtri::Unblocked algo_tag_type; + + test_batched_trtri,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trtri_u_u_scomplex_scomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trtri::Unblocked algo_tag_type; + + test_batched_trtri,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trtri_l_n_scomplex_scomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trtri::Unblocked algo_tag_type; + + test_batched_trtri,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trtri_l_u_scomplex_scomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trtri::Unblocked algo_tag_type; + + test_batched_trtri,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +#endif + + +#if defined(KOKKOSKERNELS_INST_COMPLEX_DOUBLE) +// NO TRANSPOSE +TEST_F( TestCategory, batched_scalar_serial_trtri_u_n_dcomplex_dcomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trtri::Unblocked algo_tag_type; + + test_batched_trtri,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trtri_u_u_dcomplex_dcomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trtri::Unblocked algo_tag_type; + + test_batched_trtri,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trtri_l_n_dcomplex_dcomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trtri::Unblocked algo_tag_type; + + test_batched_trtri,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +TEST_F( TestCategory, batched_scalar_serial_trtri_l_u_dcomplex_dcomplex ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trtri::Unblocked algo_tag_type; + + test_batched_trtri,Kokkos::complex,param_tag_type,algo_tag_type>(); +} +#endif + diff --git a/unit_test/batched/Test_Batched_SerialTrtri_Real.hpp b/unit_test/batched/Test_Batched_SerialTrtri_Real.hpp new file mode 100644 index 0000000000..6ba687b94f --- /dev/null +++ b/unit_test/batched/Test_Batched_SerialTrtri_Real.hpp @@ -0,0 +1,58 @@ + +#if defined(KOKKOSKERNELS_INST_FLOAT) +// NO TRANSPOSE +TEST_F( TestCategory, batched_scalar_serial_trtri_u_n_float_float ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trtri::Unblocked algo_tag_type; + + test_batched_trtri(); +} +TEST_F( TestCategory, batched_scalar_serial_trtri_u_u_float_float ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trtri::Unblocked algo_tag_type; + + test_batched_trtri(); +} +TEST_F( TestCategory, batched_scalar_serial_trtri_l_n_float_float ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trtri::Unblocked algo_tag_type; + + test_batched_trtri(); +} +TEST_F( TestCategory, batched_scalar_serial_trtri_l_u_float_float ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trtri::Unblocked algo_tag_type; + + test_batched_trtri(); +} +#endif + + +#if defined(KOKKOSKERNELS_INST_DOUBLE) +// NO TRANSPOSE +TEST_F( TestCategory, batched_scalar_serial_trtri_u_n_double_double ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trtri::Unblocked algo_tag_type; + + test_batched_trtri(); +} +TEST_F( TestCategory, batched_scalar_serial_trtri_u_u_double_double ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trtri::Unblocked algo_tag_type; + + test_batched_trtri(); +} +TEST_F( TestCategory, batched_scalar_serial_trtri_l_n_double_double ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trtri::Unblocked algo_tag_type; + + test_batched_trtri(); +} +TEST_F( TestCategory, batched_scalar_serial_trtri_l_u_double_double ) { + typedef ::Test::ParamTag param_tag_type; + typedef Algo::Trtri::Unblocked algo_tag_type; + + test_batched_trtri(); +} +#endif + diff --git a/unit_test/blas/Test_Blas3_trmm.hpp b/unit_test/blas/Test_Blas3_trmm.hpp index 4c380553ea..74fd49b988 100644 --- a/unit_test/blas/Test_Blas3_trmm.hpp +++ b/unit_test/blas/Test_Blas3_trmm.hpp @@ -192,7 +192,7 @@ int test_trmm(const char* mode, ScalarA alpha) { Test::impl_test_trmm(&mode[0],&mode[1],&mode[2],&mode[3],0,0,alpha); Test::impl_test_trmm(&mode[0],&mode[1],&mode[2],&mode[3],101,19,alpha); Test::impl_test_trmm(&mode[0],&mode[1],&mode[2],&mode[3],19,101,alpha); - Test::impl_test_trmm(&mode[0],&mode[1],&mode[2],&mode[3],1031,731,alpha); + Test::impl_test_trmm(&mode[0],&mode[1],&mode[2],&mode[3],12,731,alpha); #endif #if defined(KOKKOSKERNELS_INST_LAYOUTRIGHT) || (!defined(KOKKOSKERNELS_ETI_ONLY) && !defined(KOKKOSKERNELS_IMPL_CHECK_ETI_CALLS)) @@ -201,7 +201,7 @@ int test_trmm(const char* mode, ScalarA alpha) { Test::impl_test_trmm(&mode[0],&mode[1],&mode[2],&mode[3],0,0,alpha); Test::impl_test_trmm(&mode[0],&mode[1],&mode[2],&mode[3],101,19,alpha); Test::impl_test_trmm(&mode[0],&mode[1],&mode[2],&mode[3],19,101,alpha); - Test::impl_test_trmm(&mode[0],&mode[1],&mode[2],&mode[3],1031,731,alpha); + Test::impl_test_trmm(&mode[0],&mode[1],&mode[2],&mode[3],12,731,alpha); #endif return 1; diff --git a/unit_test/blas/Test_Blas_trtri.hpp b/unit_test/blas/Test_Blas_trtri.hpp index aa31ed2ef6..f939b87b31 100644 --- a/unit_test/blas/Test_Blas_trtri.hpp +++ b/unit_test/blas/Test_Blas_trtri.hpp @@ -258,10 +258,7 @@ int test_trtri(const char* mode) { EXPECT_EQ(ret, 0); // Rounding errors with randomly generated matrices begin here where M>100, so we pass in A=I - ret = Test::impl_test_trtri(bad_diag_idx, &mode[0], &mode[1], 473, 473); - EXPECT_EQ(ret, 0); - - ret = Test::impl_test_trtri(bad_diag_idx, &mode[0], &mode[1], 1002, 1002); + ret = Test::impl_test_trtri(bad_diag_idx, &mode[0], &mode[1], 273, 273); EXPECT_EQ(ret, 0); // Only non-unit matrices could be singular. @@ -293,10 +290,7 @@ int test_trtri(const char* mode) { EXPECT_EQ(ret, 0); // Rounding errors with randomly generated matrices begin here where M>100, so we pass in A=I - ret = Test::impl_test_trtri(bad_diag_idx, &mode[0], &mode[1], 473, 473); - EXPECT_EQ(ret, 0); - - ret = Test::impl_test_trtri(bad_diag_idx, &mode[0], &mode[1], 1002, 1002); + ret = Test::impl_test_trtri(bad_diag_idx, &mode[0], &mode[1], 273, 273); EXPECT_EQ(ret, 0); // Only non-unit matrices could be singular. diff --git a/unit_test/cuda/Test_Cuda_Batched_SerialTrmm_Complex.cpp b/unit_test/cuda/Test_Cuda_Batched_SerialTrmm_Complex.cpp new file mode 100644 index 0000000000..aa35495ca3 --- /dev/null +++ b/unit_test/cuda/Test_Cuda_Batched_SerialTrmm_Complex.cpp @@ -0,0 +1,3 @@ +#include "Test_Cuda.hpp" +#include "Test_Batched_SerialTrmm.hpp" +#include "Test_Batched_SerialTrmm_Complex.hpp" diff --git a/unit_test/cuda/Test_Cuda_Batched_SerialTrmm_Real.cpp b/unit_test/cuda/Test_Cuda_Batched_SerialTrmm_Real.cpp new file mode 100644 index 0000000000..ea0fe7d36c --- /dev/null +++ b/unit_test/cuda/Test_Cuda_Batched_SerialTrmm_Real.cpp @@ -0,0 +1,3 @@ +#include "Test_Cuda.hpp" +#include "Test_Batched_SerialTrmm.hpp" +#include "Test_Batched_SerialTrmm_Real.hpp" diff --git a/unit_test/cuda/Test_Cuda_Batched_SerialTrtri_Complex.cpp b/unit_test/cuda/Test_Cuda_Batched_SerialTrtri_Complex.cpp new file mode 100644 index 0000000000..df492e6231 --- /dev/null +++ b/unit_test/cuda/Test_Cuda_Batched_SerialTrtri_Complex.cpp @@ -0,0 +1,3 @@ +#include "Test_Cuda.hpp" +#include "Test_Batched_SerialTrtri.hpp" +#include "Test_Batched_SerialTrtri_Complex.hpp" diff --git a/unit_test/cuda/Test_Cuda_Batched_SerialTrtri_Real.cpp b/unit_test/cuda/Test_Cuda_Batched_SerialTrtri_Real.cpp new file mode 100644 index 0000000000..e6917691f6 --- /dev/null +++ b/unit_test/cuda/Test_Cuda_Batched_SerialTrtri_Real.cpp @@ -0,0 +1,3 @@ +#include "Test_Cuda.hpp" +#include "Test_Batched_SerialTrtri.hpp" +#include "Test_Batched_SerialTrtri_Real.hpp" diff --git a/unit_test/cuda/Test_Cuda_Blas3_trmm.cpp b/unit_test/cuda/Test_Cuda_Blas3_trmm.cpp index 55f80d2365..eb627d4751 100644 --- a/unit_test/cuda/Test_Cuda_Blas3_trmm.cpp +++ b/unit_test/cuda/Test_Cuda_Blas3_trmm.cpp @@ -1,5 +1,2 @@ #include -// Remove this ifdef once we have a fall back implementation. -#ifdef KOKKOSKERNELS_ENABLE_TPL_CUBLAS #include -#endif // KOKKOSKERNELS_ENABLE_TPL_CUBLAS diff --git a/unit_test/cuda/Test_Cuda_Blas_trtri.cpp b/unit_test/cuda/Test_Cuda_Blas_trtri.cpp index bcafa97c5f..1ffd98693e 100644 --- a/unit_test/cuda/Test_Cuda_Blas_trtri.cpp +++ b/unit_test/cuda/Test_Cuda_Blas_trtri.cpp @@ -1,4 +1,2 @@ #include -#ifdef KOKKOSKERNELS_ENABLE_TPL_MAGMA #include -#endif // KOKKOSKERNELS_ENABLE_TPL_MAGMA diff --git a/unit_test/openmp/Test_OpenMP_Batched_SerialTrmm_Complex.cpp b/unit_test/openmp/Test_OpenMP_Batched_SerialTrmm_Complex.cpp new file mode 100644 index 0000000000..1976225105 --- /dev/null +++ b/unit_test/openmp/Test_OpenMP_Batched_SerialTrmm_Complex.cpp @@ -0,0 +1,3 @@ +#include "Test_OpenMP.hpp" +#include "Test_Batched_SerialTrmm.hpp" +#include "Test_Batched_SerialTrmm_Complex.hpp" diff --git a/unit_test/openmp/Test_OpenMP_Batched_SerialTrmm_Real.cpp b/unit_test/openmp/Test_OpenMP_Batched_SerialTrmm_Real.cpp new file mode 100644 index 0000000000..1976225105 --- /dev/null +++ b/unit_test/openmp/Test_OpenMP_Batched_SerialTrmm_Real.cpp @@ -0,0 +1,3 @@ +#include "Test_OpenMP.hpp" +#include "Test_Batched_SerialTrmm.hpp" +#include "Test_Batched_SerialTrmm_Complex.hpp" diff --git a/unit_test/openmp/Test_OpenMP_Batched_SerialTrtri_Complex.cpp b/unit_test/openmp/Test_OpenMP_Batched_SerialTrtri_Complex.cpp new file mode 100644 index 0000000000..93caef3e8b --- /dev/null +++ b/unit_test/openmp/Test_OpenMP_Batched_SerialTrtri_Complex.cpp @@ -0,0 +1,3 @@ +#include "Test_OpenMP.hpp" +#include "Test_Batched_SerialTrtri.hpp" +#include "Test_Batched_SerialTrtri_Complex.hpp" diff --git a/unit_test/openmp/Test_OpenMP_Batched_SerialTrtri_Real.cpp b/unit_test/openmp/Test_OpenMP_Batched_SerialTrtri_Real.cpp new file mode 100644 index 0000000000..25a74f7867 --- /dev/null +++ b/unit_test/openmp/Test_OpenMP_Batched_SerialTrtri_Real.cpp @@ -0,0 +1,3 @@ +#include "Test_OpenMP.hpp" +#include "Test_Batched_SerialTrtri.hpp" +#include "Test_Batched_SerialTrtri_Real.hpp" diff --git a/unit_test/openmp/Test_OpenMP_Blas3_trmm.cpp b/unit_test/openmp/Test_OpenMP_Blas3_trmm.cpp index 106e4404c0..ec8e46d4f0 100644 --- a/unit_test/openmp/Test_OpenMP_Blas3_trmm.cpp +++ b/unit_test/openmp/Test_OpenMP_Blas3_trmm.cpp @@ -1,5 +1,2 @@ #include -// Remove this ifdef once we have a fall back implementation. -#ifdef KOKKOSKERNELS_ENABLE_TPL_BLAS #include -#endif // KOKKOSKERNELS_ENABLE_TPL_BLAS diff --git a/unit_test/openmp/Test_OpenMP_Blas_trtri.cpp b/unit_test/openmp/Test_OpenMP_Blas_trtri.cpp index 821cc89bdf..d2d05c58ec 100644 --- a/unit_test/openmp/Test_OpenMP_Blas_trtri.cpp +++ b/unit_test/openmp/Test_OpenMP_Blas_trtri.cpp @@ -1,4 +1,2 @@ #include -#ifdef KOKKOSKERNELS_ENABLE_TPL_BLAS #include -#endif // KOKKOSKERNELS_ENABLE_TPL_BLAS diff --git a/unit_test/serial/Test_Serial_Batched_SerialTrmm_Complex.cpp b/unit_test/serial/Test_Serial_Batched_SerialTrmm_Complex.cpp new file mode 100644 index 0000000000..94b0e16d92 --- /dev/null +++ b/unit_test/serial/Test_Serial_Batched_SerialTrmm_Complex.cpp @@ -0,0 +1,3 @@ +#include "Test_Serial.hpp" +#include "Test_Batched_SerialTrmm.hpp" +#include "Test_Batched_SerialTrmm_Complex.hpp" diff --git a/unit_test/serial/Test_Serial_Batched_SerialTrmm_Real.cpp b/unit_test/serial/Test_Serial_Batched_SerialTrmm_Real.cpp new file mode 100644 index 0000000000..df7fd8cec9 --- /dev/null +++ b/unit_test/serial/Test_Serial_Batched_SerialTrmm_Real.cpp @@ -0,0 +1,3 @@ +#include "Test_Serial.hpp" +#include "Test_Batched_SerialTrmm.hpp" +#include "Test_Batched_SerialTrmm_Real.hpp" diff --git a/unit_test/serial/Test_Serial_Batched_SerialTrtri_Complex.cpp b/unit_test/serial/Test_Serial_Batched_SerialTrtri_Complex.cpp new file mode 100644 index 0000000000..4a92854e8a --- /dev/null +++ b/unit_test/serial/Test_Serial_Batched_SerialTrtri_Complex.cpp @@ -0,0 +1,3 @@ +#include "Test_Serial.hpp" +#include "Test_Batched_SerialTrtri.hpp" +#include "Test_Batched_SerialTrtri_Complex.hpp" diff --git a/unit_test/serial/Test_Serial_Batched_SerialTrtri_Real.cpp b/unit_test/serial/Test_Serial_Batched_SerialTrtri_Real.cpp new file mode 100644 index 0000000000..a44bb8cb68 --- /dev/null +++ b/unit_test/serial/Test_Serial_Batched_SerialTrtri_Real.cpp @@ -0,0 +1,3 @@ +#include "Test_Serial.hpp" +#include "Test_Batched_SerialTrtri.hpp" +#include "Test_Batched_SerialTrtri_Real.hpp" diff --git a/unit_test/serial/Test_Serial_Blas3_trmm.cpp b/unit_test/serial/Test_Serial_Blas3_trmm.cpp new file mode 100644 index 0000000000..5ced17ecbb --- /dev/null +++ b/unit_test/serial/Test_Serial_Blas3_trmm.cpp @@ -0,0 +1,2 @@ +#include +#include diff --git a/unit_test/serial/Test_Serial_Blas_trtri.cpp b/unit_test/serial/Test_Serial_Blas_trtri.cpp new file mode 100644 index 0000000000..bb08b3a139 --- /dev/null +++ b/unit_test/serial/Test_Serial_Blas_trtri.cpp @@ -0,0 +1,2 @@ +#include +#include