Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial TRMM & TRTRI eti specializations as serial batched routines #697

Merged
merged 28 commits into from
May 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
2176370
blas/trmm: Initial fall-back support.
e10harvey Apr 9, 2020
a8fd429
blas/trmm: Fix serial LLNU implementation.
e10harvey Apr 9, 2020
2c94f39
blas/trmm: Add LLNN implementation.
e10harvey Apr 9, 2020
a0d6fde
blas/trmm: Added all left implementations.
e10harvey Apr 13, 2020
d9d37cd
blas/trmm: Fix conjugate-transpose for left implementation.
e10harvey Apr 14, 2020
94485a0
blas/trmm: Rely on loop unswitch for conjugate-transpose.
e10harvey Apr 14, 2020
4fb463d
blas/trmm: Implemented right lower.
e10harvey Apr 14, 2020
7887736
blas/trmm: Implemented right upper.
e10harvey Apr 14, 2020
2af70ba
blas/trmm: Cleanup.
e10harvey Apr 14, 2020
b3ef010
blas/trtri: Add eti specializations.
e10harvey Apr 15, 2020
9b2c858
blas/trtri: Add initial trtri fall-back skeleton.
e10harvey Apr 15, 2020
2133cb9
blas/trtri: Remove const from template param type for A for deep_copy.
e10harvey Apr 15, 2020
adf50de
blas/trtri: Initial TRTRI fall-back implementation.
e10harvey Apr 17, 2020
233dd36
blas/trtri: Let compiler unswitch loop.
e10harvey Apr 17, 2020
a8a2abd
blas/trmm: Add remaining SerialTrmm permutations.
e10harvey Apr 17, 2020
4c62089
blas/trmm: Add remaining SerialTrtri permutations.
e10harvey Apr 17, 2020
da5d3c8
blas/{trmm,trtri}: Cleanup.
e10harvey Apr 17, 2020
6c67e7f
blas/{trmm,trtri}: Fix spot check compile errors.
e10harvey Apr 20, 2020
8e45391
unit_test/{trmm,trtri}: Reduce test matrix sizes.
e10harvey Apr 20, 2020
1aaeb15
unit_test/{trmm,trtri}: Remove TPL enable if defs.
e10harvey Apr 20, 2020
77d7d68
blas/trtri: Fix fence post error in SerialTrtriInternalLower.
e10harvey Apr 21, 2020
d962955
batched/unit_test: Add Trmm unit test and rename batched files.
e10harvey Apr 27, 2020
a716507
batched/unit_test: Fix trmm serial tests.
e10harvey Apr 28, 2020
f4c03a5
batched/unit_test: Add complex trmm serial tests.
e10harvey Apr 28, 2020
a1ba5c6
batched/unit_test: Add trmm batched tests to cuda and openmp.
e10harvey Apr 28, 2020
1257bd9
batched/serial: add trtri unit test
e10harvey Apr 28, 2020
17db75a
Fix spot check.
e10harvey Apr 28, 2020
ff42e48
blas/tr{mm,tri}: Prefer TPL spec when ETI is available.
e10harvey Apr 29, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,18 @@ KOKKOSKERNELS_GENERATE_ETI(Blas3_trsm trsm
TYPE_LISTS FLOATS LAYOUTS DEVICES
)

# ETI generation for the trmm kernel (tag Blas3_trmm), over the FLOATS,
# LAYOUTS and DEVICES type lists.
# NOTE(review): presumably the generated files are appended to the HEADERS and
# SOURCES list variables -- confirm against the macro definition.
# NOTE(review): this call passes HEADER_LIST HEADERS, whereas the
# Sparse_sptrsv_solve call below passes HEADER_LIST ETI_HEADERS -- confirm
# which list variable is intended here.
KOKKOSKERNELS_GENERATE_ETI(Blas3_trmm trmm
HEADER_LIST HEADERS
SOURCE_LIST SOURCES
TYPE_LISTS FLOATS LAYOUTS DEVICES
)

# ETI generation for the trtri kernel (tag Blas_trtri), over the FLOATS,
# LAYOUTS and DEVICES type lists.
# NOTE(review): presumably the generated files are appended to the HEADERS and
# SOURCES list variables -- confirm against the macro definition.
# NOTE(review): this call passes HEADER_LIST HEADERS, whereas the
# Sparse_sptrsv_solve call below passes HEADER_LIST ETI_HEADERS -- confirm
# which list variable is intended here.
KOKKOSKERNELS_GENERATE_ETI(Blas_trtri trtri
HEADER_LIST HEADERS
SOURCE_LIST SOURCES
TYPE_LISTS FLOATS LAYOUTS DEVICES
)

KOKKOSKERNELS_GENERATE_ETI(Sparse_sptrsv_solve sptrsv_solve
HEADER_LIST ETI_HEADERS
SOURCE_LIST SOURCES
Expand Down
69 changes: 69 additions & 0 deletions src/batched/KokkosBatched_Trmm_Decl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
//@HEADER
// ************************************************************************
//
// Kokkos v. 3.0
// Copyright (2020) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// 3. Neither the name of the Corporation nor the names of the
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// Questions? Contact Siva Rajamanickam (srajama@sandia.gov)
//
// ************************************************************************
//@HEADER
*/

#ifndef __KOKKOSBATCHED_TRMM_DECL_HPP__
#define __KOKKOSBATCHED_TRMM_DECL_HPP__

#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Vector.hpp"

namespace KokkosBatched {

// Interface of the serial TRMM routine.
//
// The five tag types select the variant at compile time:
//   ArgSide  - which side A is applied on
//   ArgUplo  - which triangle of A is used
//   ArgTrans - whether A is used as stored, transposed or conjugate-transposed
//   ArgDiag  - diagonal handling tag
//   ArgAlgo  - algorithm variant tag
//
// Only the signature of invoke is declared here; no definition is provided in
// this header.
// NOTE(review): presumably follows the BLAS TRMM convention and overwrites B
// with the product -- confirm against the implementation.
template<typename ArgSide, typename ArgUplo, typename ArgTrans,
         typename ArgDiag, typename ArgAlgo>
struct SerialTrmm {
  // alpha : scalar factor
  // A     : view holding the triangular operand
  // B     : view holding the second operand
  // Returns an int whose meaning is defined by the implementation.
  template<typename ScalarType, typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION
  static int invoke(const ScalarType alpha, const AViewType &A, const BViewType &B);
};
} // namespace KokkosBatched
#endif // __KOKKOSBATCHED_TRMM_DECL_HPP__
288 changes: 288 additions & 0 deletions src/batched/KokkosBatched_Trmm_Serial_Impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
/*
//@HEADER
// ************************************************************************
//
// Kokkos v. 3.0
// Copyright (2020) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// 3. Neither the name of the Corporation nor the names of the
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// Questions? Contact Siva Rajamanickam (srajama@sandia.gov)
//
// ************************************************************************
//@HEADER
*/

#ifndef __KOKKOSBATCHED_TRMM_SERIAL_IMPL_HPP__
#define __KOKKOSBATCHED_TRMM_SERIAL_IMPL_HPP__

#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Trmm_Serial_Internal.hpp"

namespace KokkosBatched {
//// Lower non-transpose ////
// Unblocked serial TRMM: A on the left, lower triangle, A used as stored.
//
// Thin adapter: unpacks the two views into extents, raw data pointers and
// strides, then delegates to SerialTrmmInternalLeftLower.
template<typename ArgDiag>
struct SerialTrmm<Side::Left,Uplo::Lower,Trans::NoTranspose,ArgDiag,Algo::Trmm::Unblocked> {
  // alpha : scalar, forwarded unchanged to the internal kernel
  // A, B  : views; only extent(), data() and stride_0()/stride_1() are used
  // Returns the value returned by the internal kernel.
  template<typename ScalarType, typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION
  static int invoke(const ScalarType alpha, const AViewType &A, const BViewType &B) {
    using Kernel = SerialTrmmInternalLeftLower<Algo::Trmm::Unblocked>;

    const auto unit_diag_flag = ArgDiag::use_unit_diag;
    // Second kernel flag: set to true only by the ConjTranspose specializations.
    const bool conj_flag = false;

    return Kernel::invoke(unit_diag_flag, conj_flag,
                          A.extent(0), A.extent(1),
                          B.extent(0), B.extent(1),
                          alpha,
                          A.data(), A.stride_0(), A.stride_1(),
                          B.data(), B.stride_0(), B.stride_1());
  }
};
// Unblocked serial TRMM: A on the right, lower triangle, A used as stored.
//
// Thin adapter: unpacks the two views into extents, raw data pointers and
// strides, then delegates to SerialTrmmInternalRightLower.
template<typename ArgDiag>
struct SerialTrmm<Side::Right,Uplo::Lower,Trans::NoTranspose,ArgDiag,Algo::Trmm::Unblocked> {
  // alpha : scalar, forwarded unchanged to the internal kernel
  // A, B  : views; only extent(), data() and stride_0()/stride_1() are used
  // Returns the value returned by the internal kernel.
  template<typename ScalarType, typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION
  static int invoke(const ScalarType alpha, const AViewType &A, const BViewType &B) {
    using Kernel = SerialTrmmInternalRightLower<Algo::Trmm::Unblocked>;

    const auto unit_diag_flag = ArgDiag::use_unit_diag;
    // Second kernel flag: set to true only by the ConjTranspose specializations.
    const bool conj_flag = false;

    return Kernel::invoke(unit_diag_flag, conj_flag,
                          A.extent(0), A.extent(1),
                          B.extent(0), B.extent(1),
                          alpha,
                          A.data(), A.stride_0(), A.stride_1(),
                          B.data(), B.stride_0(), B.stride_1());
  }
};
//// Lower transpose /////
// Unblocked serial TRMM: A on the left, lower triangle, A transposed.
//
// The transpose of a lower-triangular matrix is upper-triangular, so this
// delegates to SerialTrmmInternalLeftUpper and hands A over with its two
// extents and its two strides exchanged. The wrapper itself copies no data.
template<typename ArgDiag>
struct SerialTrmm<Side::Left,Uplo::Lower,Trans::Transpose,ArgDiag,Algo::Trmm::Unblocked> {
  // alpha : scalar, forwarded unchanged to the internal kernel
  // A, B  : views; only extent(), data() and stride_0()/stride_1() are used
  // Returns the value returned by the internal kernel.
  template<typename ScalarType, typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION
  static int invoke(const ScalarType alpha, const AViewType &A, const BViewType &B) {
    using Kernel = SerialTrmmInternalLeftUpper<Algo::Trmm::Unblocked>;

    const auto unit_diag_flag = ArgDiag::use_unit_diag;
    // Plain transpose: the conjugation flag stays off.
    const bool conj_flag = false;

    // Geometry of A as seen through the transpose (index 0 <-> index 1).
    const auto at_extent0 = A.extent(1);
    const auto at_extent1 = A.extent(0);
    const auto at_stride0 = A.stride_1();
    const auto at_stride1 = A.stride_0();

    return Kernel::invoke(unit_diag_flag, conj_flag,
                          at_extent0, at_extent1,
                          B.extent(0), B.extent(1),
                          alpha,
                          A.data(), at_stride0, at_stride1,
                          B.data(), B.stride_0(), B.stride_1());
  }
};
// Unblocked serial TRMM: A on the right, lower triangle, A transposed.
//
// The transpose of a lower-triangular matrix is upper-triangular, so this
// delegates to SerialTrmmInternalRightUpper and hands A over with its two
// extents and its two strides exchanged. The wrapper itself copies no data.
template<typename ArgDiag>
struct SerialTrmm<Side::Right,Uplo::Lower,Trans::Transpose,ArgDiag,Algo::Trmm::Unblocked> {
  // alpha : scalar, forwarded unchanged to the internal kernel
  // A, B  : views; only extent(), data() and stride_0()/stride_1() are used
  // Returns the value returned by the internal kernel.
  template<typename ScalarType, typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION
  static int invoke(const ScalarType alpha, const AViewType &A, const BViewType &B) {
    using Kernel = SerialTrmmInternalRightUpper<Algo::Trmm::Unblocked>;

    const auto unit_diag_flag = ArgDiag::use_unit_diag;
    // Plain transpose: the conjugation flag stays off.
    const bool conj_flag = false;

    // Geometry of A as seen through the transpose (index 0 <-> index 1).
    const auto at_extent0 = A.extent(1);
    const auto at_extent1 = A.extent(0);
    const auto at_stride0 = A.stride_1();
    const auto at_stride1 = A.stride_0();

    return Kernel::invoke(unit_diag_flag, conj_flag,
                          at_extent0, at_extent1,
                          B.extent(0), B.extent(1),
                          alpha,
                          A.data(), at_stride0, at_stride1,
                          B.data(), B.stride_0(), B.stride_1());
  }
};
//// Lower conjugate-transpose ////
// Unblocked serial TRMM: A on the left, lower triangle, A conjugate-transposed.
//
// Same mapping as the plain-transpose case (delegate to
// SerialTrmmInternalLeftUpper with A's extents and strides exchanged), but
// with the kernel's second flag set to true.
template<typename ArgDiag>
struct SerialTrmm<Side::Left,Uplo::Lower,Trans::ConjTranspose,ArgDiag,Algo::Trmm::Unblocked> {
  // alpha : scalar, forwarded unchanged to the internal kernel
  // A, B  : views; only extent(), data() and stride_0()/stride_1() are used
  // Returns the value returned by the internal kernel.
  template<typename ScalarType, typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION
  static int invoke(const ScalarType alpha, const AViewType &A, const BViewType &B) {
    using Kernel = SerialTrmmInternalLeftUpper<Algo::Trmm::Unblocked>;

    const auto unit_diag_flag = ArgDiag::use_unit_diag;
    // Conjugate-transpose: this is the only difference from Trans::Transpose.
    const bool conj_flag = true;

    // Geometry of A as seen through the transpose (index 0 <-> index 1).
    const auto at_extent0 = A.extent(1);
    const auto at_extent1 = A.extent(0);
    const auto at_stride0 = A.stride_1();
    const auto at_stride1 = A.stride_0();

    return Kernel::invoke(unit_diag_flag, conj_flag,
                          at_extent0, at_extent1,
                          B.extent(0), B.extent(1),
                          alpha,
                          A.data(), at_stride0, at_stride1,
                          B.data(), B.stride_0(), B.stride_1());
  }
};
// Unblocked serial TRMM: A on the right, lower triangle, A conjugate-transposed.
//
// Same mapping as the plain-transpose case (delegate to
// SerialTrmmInternalRightUpper with A's extents and strides exchanged), but
// with the kernel's second flag set to true.
template<typename ArgDiag>
struct SerialTrmm<Side::Right,Uplo::Lower,Trans::ConjTranspose,ArgDiag,Algo::Trmm::Unblocked> {
  // alpha : scalar, forwarded unchanged to the internal kernel
  // A, B  : views; only extent(), data() and stride_0()/stride_1() are used
  // Returns the value returned by the internal kernel.
  template<typename ScalarType, typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION
  static int invoke(const ScalarType alpha, const AViewType &A, const BViewType &B) {
    using Kernel = SerialTrmmInternalRightUpper<Algo::Trmm::Unblocked>;

    const auto unit_diag_flag = ArgDiag::use_unit_diag;
    // Conjugate-transpose: this is the only difference from Trans::Transpose.
    const bool conj_flag = true;

    // Geometry of A as seen through the transpose (index 0 <-> index 1).
    const auto at_extent0 = A.extent(1);
    const auto at_extent1 = A.extent(0);
    const auto at_stride0 = A.stride_1();
    const auto at_stride1 = A.stride_0();

    return Kernel::invoke(unit_diag_flag, conj_flag,
                          at_extent0, at_extent1,
                          B.extent(0), B.extent(1),
                          alpha,
                          A.data(), at_stride0, at_stride1,
                          B.data(), B.stride_0(), B.stride_1());
  }
};
//// Upper non-transpose ////
// Unblocked serial TRMM: A on the left, upper triangle, A used as stored.
//
// Thin adapter: unpacks the two views into extents, raw data pointers and
// strides, then delegates to SerialTrmmInternalLeftUpper.
template<typename ArgDiag>
struct SerialTrmm<Side::Left,Uplo::Upper,Trans::NoTranspose,ArgDiag,Algo::Trmm::Unblocked> {
  // alpha : scalar, forwarded unchanged to the internal kernel
  // A, B  : views; only extent(), data() and stride_0()/stride_1() are used
  // Returns the value returned by the internal kernel.
  template<typename ScalarType, typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION
  static int invoke(const ScalarType alpha, const AViewType &A, const BViewType &B) {
    using Kernel = SerialTrmmInternalLeftUpper<Algo::Trmm::Unblocked>;

    const auto unit_diag_flag = ArgDiag::use_unit_diag;
    // Second kernel flag: set to true only by the ConjTranspose specializations.
    const bool conj_flag = false;

    return Kernel::invoke(unit_diag_flag, conj_flag,
                          A.extent(0), A.extent(1),
                          B.extent(0), B.extent(1),
                          alpha,
                          A.data(), A.stride_0(), A.stride_1(),
                          B.data(), B.stride_0(), B.stride_1());
  }
};
// Unblocked serial TRMM: A on the right, upper triangle, A used as stored.
//
// Thin adapter: unpacks the two views into extents, raw data pointers and
// strides, then delegates to SerialTrmmInternalRightUpper.
template<typename ArgDiag>
struct SerialTrmm<Side::Right,Uplo::Upper,Trans::NoTranspose,ArgDiag,Algo::Trmm::Unblocked> {
  // alpha : scalar, forwarded unchanged to the internal kernel
  // A, B  : views; only extent(), data() and stride_0()/stride_1() are used
  // Returns the value returned by the internal kernel.
  template<typename ScalarType, typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION
  static int invoke(const ScalarType alpha, const AViewType &A, const BViewType &B) {
    using Kernel = SerialTrmmInternalRightUpper<Algo::Trmm::Unblocked>;

    const auto unit_diag_flag = ArgDiag::use_unit_diag;
    // Second kernel flag: set to true only by the ConjTranspose specializations.
    const bool conj_flag = false;

    return Kernel::invoke(unit_diag_flag, conj_flag,
                          A.extent(0), A.extent(1),
                          B.extent(0), B.extent(1),
                          alpha,
                          A.data(), A.stride_0(), A.stride_1(),
                          B.data(), B.stride_0(), B.stride_1());
  }
};
//// Upper transpose /////
// Unblocked serial TRMM: A on the left, upper triangle, A transposed.
//
// The transpose of an upper-triangular matrix is lower-triangular, so this
// delegates to SerialTrmmInternalLeftLower and hands A over with its two
// extents and its two strides exchanged. The wrapper itself copies no data.
template<typename ArgDiag>
struct SerialTrmm<Side::Left,Uplo::Upper,Trans::Transpose,ArgDiag,Algo::Trmm::Unblocked> {
  // alpha : scalar, forwarded unchanged to the internal kernel
  // A, B  : views; only extent(), data() and stride_0()/stride_1() are used
  // Returns the value returned by the internal kernel.
  template<typename ScalarType, typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION
  static int invoke(const ScalarType alpha, const AViewType &A, const BViewType &B) {
    using Kernel = SerialTrmmInternalLeftLower<Algo::Trmm::Unblocked>;

    const auto unit_diag_flag = ArgDiag::use_unit_diag;
    // Plain transpose: the conjugation flag stays off.
    const bool conj_flag = false;

    // Geometry of A as seen through the transpose (index 0 <-> index 1).
    const auto at_extent0 = A.extent(1);
    const auto at_extent1 = A.extent(0);
    const auto at_stride0 = A.stride_1();
    const auto at_stride1 = A.stride_0();

    return Kernel::invoke(unit_diag_flag, conj_flag,
                          at_extent0, at_extent1,
                          B.extent(0), B.extent(1),
                          alpha,
                          A.data(), at_stride0, at_stride1,
                          B.data(), B.stride_0(), B.stride_1());
  }
};
// Unblocked serial TRMM: A on the right, upper triangle, A transposed.
//
// The transpose of an upper-triangular matrix is lower-triangular, so this
// delegates to SerialTrmmInternalRightLower and hands A over with its two
// extents and its two strides exchanged. The wrapper itself copies no data.
template<typename ArgDiag>
struct SerialTrmm<Side::Right,Uplo::Upper,Trans::Transpose,ArgDiag,Algo::Trmm::Unblocked> {
  // alpha : scalar, forwarded unchanged to the internal kernel
  // A, B  : views; only extent(), data() and stride_0()/stride_1() are used
  // Returns the value returned by the internal kernel.
  template<typename ScalarType, typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION
  static int invoke(const ScalarType alpha, const AViewType &A, const BViewType &B) {
    using Kernel = SerialTrmmInternalRightLower<Algo::Trmm::Unblocked>;

    const auto unit_diag_flag = ArgDiag::use_unit_diag;
    // Plain transpose: the conjugation flag stays off.
    const bool conj_flag = false;

    // Geometry of A as seen through the transpose (index 0 <-> index 1).
    const auto at_extent0 = A.extent(1);
    const auto at_extent1 = A.extent(0);
    const auto at_stride0 = A.stride_1();
    const auto at_stride1 = A.stride_0();

    return Kernel::invoke(unit_diag_flag, conj_flag,
                          at_extent0, at_extent1,
                          B.extent(0), B.extent(1),
                          alpha,
                          A.data(), at_stride0, at_stride1,
                          B.data(), B.stride_0(), B.stride_1());
  }
};
//// Upper conjugate-transpose ////
// Unblocked serial TRMM: A on the left, upper triangle, A conjugate-transposed.
//
// Same mapping as the plain-transpose case (delegate to
// SerialTrmmInternalLeftLower with A's extents and strides exchanged), but
// with the kernel's second flag set to true.
template<typename ArgDiag>
struct SerialTrmm<Side::Left,Uplo::Upper,Trans::ConjTranspose,ArgDiag,Algo::Trmm::Unblocked> {
  // alpha : scalar, forwarded unchanged to the internal kernel
  // A, B  : views; only extent(), data() and stride_0()/stride_1() are used
  // Returns the value returned by the internal kernel.
  template<typename ScalarType, typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION
  static int invoke(const ScalarType alpha, const AViewType &A, const BViewType &B) {
    using Kernel = SerialTrmmInternalLeftLower<Algo::Trmm::Unblocked>;

    const auto unit_diag_flag = ArgDiag::use_unit_diag;
    // Conjugate-transpose: this is the only difference from Trans::Transpose.
    const bool conj_flag = true;

    // Geometry of A as seen through the transpose (index 0 <-> index 1).
    const auto at_extent0 = A.extent(1);
    const auto at_extent1 = A.extent(0);
    const auto at_stride0 = A.stride_1();
    const auto at_stride1 = A.stride_0();

    return Kernel::invoke(unit_diag_flag, conj_flag,
                          at_extent0, at_extent1,
                          B.extent(0), B.extent(1),
                          alpha,
                          A.data(), at_stride0, at_stride1,
                          B.data(), B.stride_0(), B.stride_1());
  }
};
// Unblocked serial TRMM: A on the right, upper triangle, A conjugate-transposed.
//
// Same mapping as the plain-transpose case (delegate to
// SerialTrmmInternalRightLower with A's extents and strides exchanged), but
// with the kernel's second flag set to true.
template<typename ArgDiag>
struct SerialTrmm<Side::Right,Uplo::Upper,Trans::ConjTranspose,ArgDiag,Algo::Trmm::Unblocked> {
  // alpha : scalar, forwarded unchanged to the internal kernel
  // A, B  : views; only extent(), data() and stride_0()/stride_1() are used
  // Returns the value returned by the internal kernel.
  template<typename ScalarType, typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION
  static int invoke(const ScalarType alpha, const AViewType &A, const BViewType &B) {
    using Kernel = SerialTrmmInternalRightLower<Algo::Trmm::Unblocked>;

    const auto unit_diag_flag = ArgDiag::use_unit_diag;
    // Conjugate-transpose: this is the only difference from Trans::Transpose.
    const bool conj_flag = true;

    // Geometry of A as seen through the transpose (index 0 <-> index 1).
    const auto at_extent0 = A.extent(1);
    const auto at_extent1 = A.extent(0);
    const auto at_stride0 = A.stride_1();
    const auto at_stride1 = A.stride_0();

    return Kernel::invoke(unit_diag_flag, conj_flag,
                          at_extent0, at_extent1,
                          B.extent(0), B.extent(1),
                          alpha,
                          A.data(), at_stride0, at_stride1,
                          B.data(), B.stride_0(), B.stride_1());
  }
};
} // namespace KokkosBatched

#endif // __KOKKOSBATCHED_TRMM_SERIAL_IMPL_HPP__
Loading