Skip to content

Commit

Permalink
Merge pull request #416 from seoklab/python/tools/tm
Browse files Browse the repository at this point in the history
feat(python/tools/tm): provide TM-align interface
  • Loading branch information
jnooree authored Nov 29, 2024
2 parents e1d3a2e + 25d3b44 commit 4fd6e86
Show file tree
Hide file tree
Showing 10 changed files with 760 additions and 92 deletions.
24 changes: 20 additions & 4 deletions include/nuri/tools/tm.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,32 @@

namespace nuri {
namespace internal {
class AlignedXY;

template <class Pred>
void remap_helper(AlignedXY &xy, const Pred &pred) noexcept;

class AlignedXY {
public:
AlignedXY(ConstRef<Matrix3Xd> x, ConstRef<Matrix3Xd> y, const int l_min)
: x_(x), y_(y), xtm_(3, l_min), ytm_(3, l_min), y2x_(y.cols()),
l_ali_(0) { }

void remap(ArrayXi &y2x) noexcept;
void remap(ConstRef<ArrayXi> y2x) noexcept;

void remap(ArrayXi &&y2x) noexcept;

void remap_final(ConstRef<Matrix3Xd> x_aln, double score_d8sq) noexcept;

void swap_remap(ArrayXi &y2x) noexcept;

void swap_align_with(ArrayXi &y2x) noexcept;

void reset() noexcept {
y2x_.setConstant(-1);
l_ali_ = 0;
}

ConstRef<Matrix3Xd> x() const { return x_; }

auto xtm() { return xtm_.leftCols(l_ali_); }
Expand All @@ -49,6 +63,9 @@ namespace internal {
const ArrayXi &y2x() const { return y2x_; }

private:
template <class Pred>
friend void remap_helper(AlignedXY &xy, const Pred &pred) noexcept;

Eigen::Ref<const Matrix3Xd> x_;
Eigen::Ref<const Matrix3Xd> y_;

Expand Down Expand Up @@ -229,14 +246,13 @@ class TMAlign {
* @param y2x A map of the template structure to the query structure. Negative
* values indicate that the corresponding residue in the template
* structure is not aligned to any residue in the query structure.
* Will be invalidated after this call.
* @return Whether the initialization was successful.
* @note If size of y2x is not equal to the length of the template structure
* or any value of y2x is larger than or equal to the length of the
* query structure, the behavior is undefined.
*/
ABSL_MUST_USE_RESULT
bool initialize(ArrayXi &y2x);
bool initialize(ConstRef<ArrayXi> y2x);

bool initialized() const { return xy_.l_ali() > 0; }

Expand Down Expand Up @@ -394,7 +410,7 @@ tm_align(ConstRef<Matrix3Xd> query, ConstRef<Matrix3Xd> templ,
* structure, the behavior is undefined.
*/
extern TMAlignResult tm_align(ConstRef<Matrix3Xd> query,
ConstRef<Matrix3Xd> templ, ArrayXi &y2x,
ConstRef<Matrix3Xd> templ, ConstRef<ArrayXi> y2x,
int l_norm = -1, double d0 = -1);

// test utils
Expand Down
1 change: 1 addition & 0 deletions python/docs/nuri.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Submodules
nuri.core
nuri.fmt
nuri.algo
nuri.tools

-------------------
Top-level Functions
Expand Down
37 changes: 37 additions & 0 deletions python/docs/nuri.tools.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
.. Project NuriKit - Copyright 2024 SNU Compbio Lab.
SPDX-License-Identifier: Apache-2.0
==========
nuri.tools
==========

.. currentmodule:: nuri.tools

.. automodule:: nuri.tools

--------
TM-tools
--------

.. currentmodule:: nuri.tools.tm

.. code-block:: python
from nuri.tools import tm as tmtools
.. automodule:: nuri.tools.tm
:exclude-members: TMAlign

This module provides a ground-up reimplementation of the TM-align algorithm,
based on the original TM-align code (version 20220412) by Yang Zhang. This
implementation aims to reproduce the results of the original code while
providing an improved user interface and better maintainability. Refer to the
following paper for details of the algorithm. :footcite:`tm-align`

.. footbibliography::

.. autoclass:: TMAlign
:exclude-members: from_alignment

.. automethod:: __init__
.. automethod:: from_alignment
118 changes: 69 additions & 49 deletions python/include/nuri/python/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>

#include <absl/algorithm/container.h>
#include <absl/log/absl_check.h>
#include <absl/log/absl_log.h>
#include <absl/strings/str_cat.h>
Expand Down Expand Up @@ -287,6 +288,46 @@ template <Eigen::Index Rows = Eigen::Dynamic,
Eigen::Index Cols = Eigen::Dynamic, class DT = double>
class NpArrayWrapper;

/**
 * @brief Verify that the shape of a numpy array matches the compile-time
 *        dimensions of an Eigen type.
 *
 * @tparam Rows Compile-time row count of the Eigen type (or Eigen::Dynamic).
 * @tparam Cols Compile-time column count of the Eigen type (or
 *         Eigen::Dynamic).
 * @tparam DT Scalar type of the numpy array.
 * @param arr The numpy array to validate.
 * @throws py::value_error If the number of dimensions or any fixed extent of
 *         the array does not match.
 *
 * @note Vector-like types (Rows == 1 or Cols == 1) require a 1D array; all
 *       other types require a 2D array. For 2D arrays the axes are compared
 *       crosswise: the first numpy axis is checked against Cols and the
 *       second against Rows.
 */
template <Eigen::Index Rows, Eigen::Index Cols, class DT>
void numpy_to_eigen_check_compat(const py::array_t<DT> &arr) {
  constexpr bool kVectorLike = Rows == 1 || Cols == 1;

  if constexpr (kVectorLike) {
    if (arr.ndim() != 1) {
      throw py::value_error(
          absl::StrCat("expected 1D array, got ", arr.ndim(), "D"));
    }

    // The element count can only be verified when both extents are fixed.
    if constexpr (Rows != Eigen::Dynamic && Cols != Eigen::Dynamic) {
      constexpr Eigen::Index kFixedSize = Rows * Cols;

      if (arr.size() != kFixedSize) {
        throw py::value_error(absl::StrCat("expected ", kFixedSize,
                                           " elements, got ", arr.size()));
      }
    }
  } else {
    if (arr.ndim() != 2) {
      throw py::value_error(
          absl::StrCat("expected 2D array, got ", arr.ndim(), "D"));
    }

    const auto np_rows = arr.shape()[0];
    const auto np_cols = arr.shape()[1];

    // Axes are matched crosswise (numpy rows <-> Eigen columns); presumably
    // this reflects numpy's row-major vs. Eigen's column-major default
    // layout -- TODO confirm against the callers.
    if constexpr (Cols != Eigen::Dynamic) {
      if (np_rows != Cols) {
        throw py::value_error(
            absl::StrCat("expected ", Cols, " rows, got ", np_rows));
      }
    }

    if constexpr (Rows != Eigen::Dynamic) {
      if (np_cols != Rows) {
        throw py::value_error(
            absl::StrCat("expected ", Rows, " columns, got ", np_cols));
      }
    }
  }
}

template <class ML>
using NpArrayLike = NpArrayWrapper<ML::RowsAtCompileTime, ML::ColsAtCompileTime,
typename ML::Scalar>;
Expand Down Expand Up @@ -324,35 +365,50 @@ class NpArrayWrapper: private py::array_t<DT> {
Parent numpy() && { return std::move(*this); }

private:
explicit NpArrayWrapper(const Parent &arr): Parent(arr) {
explicit NpArrayWrapper(std::vector<py::ssize_t> &&shape)
: Parent(std::move(shape)) {
check_invariants();
}

explicit NpArrayWrapper(Parent &&arr): Parent(std::move(arr)) {
explicit NpArrayWrapper(const Parent &arr): Parent(arr) {
check_invariants();
}

explicit NpArrayWrapper(std::vector<py::ssize_t> &&shape)
: Parent(std::move(shape)) {
explicit NpArrayWrapper(Parent &&arr): Parent(std::move(arr)) {
check_invariants();
}

void check_invariants() const {
#ifdef NURI_DEBUG
int req_ndim = kIsVector ? 1 : 2;
constexpr int req_ndim = kIsVector ? 1 : 2;

numpy_to_eigen_check_compat<Rows, Cols, DT>(*this);

ABSL_DCHECK_EQ(this->ndim(), req_ndim);
ABSL_DCHECK_EQ(eigen_stride(*this, req_ndim - 1), 1);
#endif
const auto inner_stride = eigen_stride(*this, req_ndim - 1);
if (inner_stride != 1) {
throw std::runtime_error(
absl::StrCat("Unexpected inner stride (", inner_stride, " != 1)"));
}
}

template <Eigen::Index R, Eigen::Index C, class DU>
friend NpArrayWrapper<R, C, DU> py_array_cast(py::handle h);
friend NpArrayWrapper<R, C, DU>
empty_numpy(std::vector<py::ssize_t> &&eigen_shape);

template <class ML>
friend NpArrayLike<ML> empty_like(const ML &mat);

template <Eigen::Index R, Eigen::Index C, class DU>
friend NpArrayWrapper<R, C, DU> py_array_cast(py::handle h);
};

/**
 * @brief Create an NpArrayWrapper from a shape given in Eigen order.
 *
 * @tparam Rows Compile-time row count of the wrapped Eigen type.
 * @tparam Cols Compile-time column count of the wrapped Eigen type.
 * @tparam DT Scalar type of the array.
 * @param eigen_shape Extents in Eigen order. The vector is reversed in place
 *        and then moved into the wrapper constructor, so it must not be
 *        reused by the caller.
 * @return A wrapper constructed from the reversed shape.
 *
 * @note NOTE(review): the name suggests the contents of the new array are
 *       left uninitialized -- confirm against the pybind11 array constructor.
 */
template <Eigen::Index Rows = Eigen::Dynamic,
          Eigen::Index Cols = Eigen::Dynamic, class DT = double>
NpArrayWrapper<Rows, Cols, DT>
empty_numpy(std::vector<py::ssize_t> &&eigen_shape) {
  using Wrapper = NpArrayWrapper<Rows, Cols, DT>;

  // The wrapper constructor takes the axes in the opposite order from the
  // Eigen-ordered shape supplied by the caller.
  absl::c_reverse(eigen_shape);
  return Wrapper(std::move(eigen_shape));
}

template <class ML>
NpArrayLike<ML> empty_like(const ML &mat) {
std::vector<py::ssize_t> shape;
Expand All @@ -376,6 +432,7 @@ NpArrayWrapper<Rows, Cols, DT> py_array_cast(py::handle h) {
}

py::array_t<DT> arr = py::reinterpret_steal<py::array_t<DT>>(result);
numpy_to_eigen_check_compat<Rows, Cols, DT>(arr);

auto maybe_copy = [&arr](Eigen::Index rows, Eigen::Index cols, auto strides) {
if (strides.inner() == 1)
Expand All @@ -391,23 +448,7 @@ NpArrayWrapper<Rows, Cols, DT> py_array_cast(py::handle h) {
return wrapper;
};

constexpr bool is_vector = Rows == 1 || Cols == 1;
constexpr Eigen::Index size = Rows == Eigen::Dynamic || Cols == Eigen::Dynamic
? Eigen::Dynamic
: Rows * Cols;

if constexpr (is_vector) {
if (arr.ndim() != 1)
throw py::value_error(
absl::StrCat("expected 1D array, got ", arr.ndim(), "D"));

if constexpr (size != Eigen::Dynamic) {
if (arr.size() != size) {
throw py::value_error(
absl::StrCat("expected ", size, " elements, got ", arr.size()));
}
}

if constexpr (Rows == 1 || Cols == 1) {
Eigen::Index rows, cols;
if constexpr (Cols == 1) {
rows = arr.size();
Expand All @@ -416,31 +457,10 @@ NpArrayWrapper<Rows, Cols, DT> py_array_cast(py::handle h) {
rows = 1;
cols = arr.size();
}

return maybe_copy(rows, cols,
Eigen::InnerStride<> { eigen_stride(arr, 0) });
} else {
if (arr.ndim() != 2)
throw py::value_error(
absl::StrCat("expected 2D array, got ", arr.ndim(), "D"));

auto py_rows = arr.shape()[0], py_cols = arr.shape()[1];

if constexpr (Cols != Eigen::Dynamic) {
if (Cols != py_rows) {
throw py::value_error(
absl::StrCat("expected ", Cols, " rows, got ", py_rows));
}
}

if constexpr (Rows != Eigen::Dynamic) {
if (Rows != py_cols) {
throw py::value_error(
absl::StrCat("expected ", Rows, " columns, got ", py_cols));
}
}

return maybe_copy(py_cols, py_rows,
return maybe_copy(arr.shape()[1], arr.shape()[0],
py::EigenDStride { eigen_stride(arr, 0),
eigen_stride(arr, 1) });
}
Expand Down
4 changes: 4 additions & 0 deletions python/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,7 @@ nuri_python_add_module(fmt
nuri/fmt/fmt_module.cpp
)
target_link_libraries(nuri_python_fmt PRIVATE absl::span)

nuri_python_add_module(tm OUTPUT_DIRECTORY "tools"
nuri/tools/tm_module.cpp
)
4 changes: 4 additions & 0 deletions python/src/nuri/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#
# Project NuriKit - Copyright 2024 SNU Compbio Lab.
# SPDX-License-Identifier: Apache-2.0
#
Loading

0 comments on commit 4fd6e86

Please sign in to comment.