Skip to content

Commit

Permalink
Refactor lapack bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
asadchev committed Nov 16, 2020
1 parent 49fe50e commit 9a1e84f
Show file tree
Hide file tree
Showing 8 changed files with 408 additions and 208 deletions.
3 changes: 1 addition & 2 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ TiledArray/array_impl.cpp
TiledArray/dist_array.cpp
TiledArray/util/backtrace.cpp
TiledArray/util/bug.cpp
TiledArray/algebra/lapack/lapack.cc
)

# the list of libraries on which TiledArray depends on, will be cached later
Expand Down Expand Up @@ -302,5 +303,3 @@ install(
FILES_MATCHING PATTERN "*.h"
PATTERN "CMakeFiles" EXCLUDE
)


2 changes: 1 addition & 1 deletion src/TiledArray/algebra/chol.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ auto cholesky_lsolve(TransposeFlag transpose, const Array& A, const Array& B,
x_trange);
else
#endif
return lapack::cholesky_solve<Array>(transpose, A, B, l_trange, x_trange);
return lapack::cholesky_lsolve<Array>(transpose, A, B, l_trange, x_trange);
}

} // namespace TiledArray
Expand Down
66 changes: 7 additions & 59 deletions src/TiledArray/algebra/lapack/chol.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,35 +24,16 @@
#ifndef TILEDARRAY_ALGEBRA_LAPACK_CHOL_H__INCLUDED
#define TILEDARRAY_ALGEBRA_LAPACK_CHOL_H__INCLUDED

#include <TiledArray/algebra/lapack/util.h>
#include <TiledArray/config.h>
#include <TiledArray/algebra/lapack/lapack.h>
#include <TiledArray/algebra/lapack/util.h>
#include <TiledArray/conversions/eigen.h>

namespace TiledArray {
namespace lapack {

namespace detail {

template <typename Scalar, int RowsAtCompileTime, int ColsAtCompileTime,
int Options, int MaxRowsAtCompileTime, int MaxColsAtCompileTime>
void chol_eig(
Eigen::Matrix<Scalar, RowsAtCompileTime, ColsAtCompileTime, Options,
MaxRowsAtCompileTime, MaxColsAtCompileTime>& A) {
using numeric_type = Scalar;
char uplo = 'L';
integer n = A.rows();
numeric_type* a = A.data();
integer lda = n;
integer info = 0;
#if defined(MADNESS_LINALG_USE_LAPACKE)
MADNESS_DISPATCH_LAPACK_FN(potrf, &uplo, &n, a, &lda, &info);
#else
MADNESS_DISPATCH_LAPACK_FN(potrf, &uplo, &n, a, &lda, &info, sizeof(char));
#endif

if (info != 0) TA_EXCEPTION("LAPACK::potrf failed");
}

template <typename Tile, typename Policy>
auto make_L_eig(const DistArray<Tile, Policy>& A) {
using Array = DistArray<Tile, Policy>;
Expand Down Expand Up @@ -121,7 +102,7 @@ template <typename ContiguousTensor,
TiledArray::detail::is_contiguous_tensor_v<ContiguousTensor>>>
auto cholesky(const ContiguousTensor& A) {
auto A_eig = detail::to_eigen(A);
detail::chol_eig(A_eig);
lapack::cholesky(A_eig);
detail::zero_out_upper_triangle(A_eig);
return detail::from_eigen<ContiguousTensor>(A_eig, A.range());
}
Expand Down Expand Up @@ -156,21 +137,11 @@ auto cholesky_linv(const Array& A, TiledRange l_trange = TiledRange()) {

// if need to return L use its copy to compute inverse
decltype(L_eig) L_inv_eig;
if (RetL && world.rank() == 0) L_inv_eig = L_eig;

if (world.rank() == 0) {
if (RetL) L_inv_eig = L_eig;
auto& L_inv_eig_ref = RetL ? L_inv_eig : L_eig;

char uplo = 'L';
char diag = 'N';
integer n = L_eig.rows();
using numeric_type = typename Array::numeric_type;
numeric_type* l = L_inv_eig_ref.data();
integer lda = n;
integer info = 0;
MADNESS_DISPATCH_LAPACK_FN(trtri, &uplo, &diag, &n, l, &lda, &info);
if (info != 0) TA_EXCEPTION("LAPACK::trtri failed");

cholesky_linv(L_inv_eig_ref);
detail::zero_out_upper_triangle(L_inv_eig_ref);
}
world.gop.broadcast_serializable(RetL ? L_inv_eig : L_eig, 0);
Expand All @@ -196,16 +167,7 @@ auto cholesky_solve(const Array& A, const Array& B,
auto X_eig = detail::to_eigen(B);
World& world = A.world();
if (world.rank() == 0) {
char uplo = 'L';
integer n = A_eig.rows();
integer nrhs = X_eig.cols();
numeric_type* a = A_eig.data();
numeric_type* b = X_eig.data();
integer lda = n;
integer ldb = n;
integer info = 0;
MADNESS_DISPATCH_LAPACK_FN(posv, &uplo, &n, &nrhs, a, &lda, b, &ldb, &info);
if (info != 0) TA_EXCEPTION("LAPACK::posv failed");
cholesky_solve(A_eig, X_eig);
}
world.gop.broadcast_serializable(X_eig, 0);
if (x_trange.rank() == 0) x_trange = B.trange();
Expand All @@ -228,21 +190,7 @@ auto cholesky_lsolve(TransposeFlag transpose, const Array& A, const Array& B,

auto X_eig = detail::to_eigen(B);
if (world.rank() == 0) {
char uplo = 'L';
char trans = transpose == TransposeFlag::Transpose
? 'T'
: (transpose == TransposeFlag::NoTranspose ? 'N' : 'C');
char diag = 'N';
integer n = L_eig.rows();
integer nrhs = X_eig.cols();
numeric_type* a = L_eig.data();
numeric_type* b = X_eig.data();
integer lda = n;
integer ldb = n;
integer info = 0;
MADNESS_DISPATCH_LAPACK_FN(trtrs, &uplo, &trans, &diag, &n, &nrhs, a, &lda,
b, &ldb, &info);
if (info != 0) TA_EXCEPTION("LAPACK::trtrs failed");
cholesky_lsolve(transpose, L_eig, X_eig);
}
world.gop.broadcast_serializable(X_eig, 0);
if (l_trange.rank() == 0) l_trange = A.trange();
Expand Down
107 changes: 54 additions & 53 deletions src/TiledArray/algebra/lapack/heig.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,60 +61,61 @@ auto heig(const Array& A, TiledRange evec_trange = TiledRange()) {
World& world = A.world();
auto A_eig = detail::to_eigen(A);
std::vector<scalar_type> evals;
if (world.rank() == 0) {
char jobz = 'V';
char uplo = 'L';
integer n = A_eig.rows();
numeric_type* a = A_eig.data();
integer lda = n;
integer info = 0;
evals.resize(n);
integer lwork = -1;
std::vector<numeric_type> work(1);
// run once to query, then to compute
while (lwork != static_cast<integer>(work.size())) {
if (lwork > 0) {
work.resize(lwork);
}
if constexpr (is_real) {
#if defined(MADNESS_LINALG_USE_LAPACKE)
MADNESS_DISPATCH_LAPACK_FN(syev, &jobz, &uplo, &n, a, &lda,
evals.data(), work.data(), &lwork, &info);
#else
MADNESS_DISPATCH_LAPACK_FN(syev, &jobz, &uplo, &n, a, &lda,
evals.data(), work.data(), &lwork, &info,
sizeof(char), sizeof(char));
#endif
} else {
std::vector<scalar_type> rwork;
if (lwork == static_cast<integer>(work.size())) rwork.resize(3 * n - 2);
#if defined(MADNESS_LINALG_USE_LAPACKE)
MADNESS_DISPATCH_LAPACK_FN(heev, &jobz, &uplo, &n, a, &lda,
evals.data(), work.data(), &lwork,
&rwork.data(), &info);
#else
MADNESS_DISPATCH_LAPACK_FN(
heev, &jobz, &uplo, &n, a, &lda, evals.data(), work.data(), &lwork,
&rwork.data(), &info, sizeof(char), sizeof(char));
#endif
}
if (lwork == -1) {
if constexpr (is_real) {
lwork = static_cast<integer>(work[0]);
} else {
lwork = static_cast<integer>(work[0].real());
}
TA_ASSERT(lwork > 1);
}
};
// if (world.rank() == 0) {
// char jobz = 'V';
// char uplo = 'L';
// integer n = A_eig.rows();
// numeric_type* a = A_eig.data();
// integer lda = n;
// integer info = 0;
// evals.resize(n);
// integer lwork = -1;
// std::vector<numeric_type> work(1);
// // run once to query, then to compute
// while (lwork != static_cast<integer>(work.size())) {
// if (lwork > 0) {
// work.resize(lwork);
// }
// if constexpr (is_real) {
// #if defined(MADNESS_LINALG_USE_LAPACKE)
// MADNESS_DISPATCH_LAPACK_FN(syev, &jobz, &uplo, &n, a, &lda,
// evals.data(), work.data(), &lwork, &info);
// #else
// MADNESS_DISPATCH_LAPACK_FN(syev, &jobz, &uplo, &n, a, &lda,
// evals.data(), work.data(), &lwork, &info,
// sizeof(char), sizeof(char));
// #endif
// } else {
// std::vector<scalar_type> rwork;
// if (lwork == static_cast<integer>(work.size())) rwork.resize(3 * n - 2);
// #if defined(MADNESS_LINALG_USE_LAPACKE)
// MADNESS_DISPATCH_LAPACK_FN(heev, &jobz, &uplo, &n, a, &lda,
// evals.data(), work.data(), &lwork,
// &rwork.data(), &info);
// #else
// MADNESS_DISPATCH_LAPACK_FN(
// heev, &jobz, &uplo, &n, a, &lda, evals.data(), work.data(), &lwork,
// &rwork.data(), &info, sizeof(char), sizeof(char));
// #endif
// }
// if (lwork == -1) {
// if constexpr (is_real) {
// lwork = static_cast<integer>(work[0]);
// } else {
// lwork = static_cast<integer>(work[0].real());
// }
// TA_ASSERT(lwork > 1);
// }
// };

// if (info != 0) {
// if (is_real)
// TA_EXCEPTION("LAPACK::syev failed");
// else
// TA_EXCEPTION("LAPACK::heev failed");
// }
// }

if (info != 0) {
if (is_real)
TA_EXCEPTION("LAPACK::syev failed");
else
TA_EXCEPTION("LAPACK::heev failed");
}
}
world.gop.broadcast_serializable(A_eig, 0);
world.gop.broadcast_serializable(evals, 0);
if (evec_trange.rank() == 0) evec_trange = A.trange();
Expand Down
Loading

0 comments on commit 9a1e84f

Please sign in to comment.