Skip to content

Commit

Permalink
Merge pull request #460 from ValeevGroup/gaudel/fix/more_tot_corner_case
Browse files Browse the repository at this point in the history
Tests and fixes one more corner case of ToT x ToT evaluation.
  • Loading branch information
bimalgaudel authored Jul 8, 2024
2 parents e15fcdf + 6994163 commit 253d59c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
19 changes: 16 additions & 3 deletions src/TiledArray/einsum/tiledarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -525,9 +525,22 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
// contracted indices
auto i = (a & b) - h;

// no Hadamard indices => standard contraction (or even outer product)
// same a, b, and c => pure Hadamard
if (!h || (h && !(i || e))) {
//
// *) Pure Hadamard indices: (h && !(i || e)) is true implies
// the evaluation can be delegated to the expression layer
// for distarrays of both nested and non-nested tensor tiles.
// *) If no Hadamard indices are present (!h) the evaluation
// can be delegated to the expression _only_ for distarrays with
// non-nested tensor tiles.
// This is because even if Hadamard indices are not present, a contracted
// index might be present pertinent to the outer tensor in case of a
// nested-tile distarray, which is especially handled within this
// function because expression layer cannot handle that yet.
//
if ((h && !(i || e)) // pure Hadamard
|| (IsArrayToT<ArrayC> && !(i || h)) // ToT result from outer-product
|| (IsArrayT<ArrayC> && !h)) // T from general product without Hadamard
{
ArrayC C;
C(std::string(c) + inner.c) = A * B;
return C;
Expand Down
13 changes: 13 additions & 0 deletions tests/einsum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,19 @@ BOOST_AUTO_TEST_CASE(corner_cases) {
{{0, 4, 8}, {0, 4}}, //
{{0, 4, 8}, {0, 4}}, //
{8})));

BOOST_REQUIRE(check_manual_eval<ArrayToT>("il;bae,il;e->li;ab", //
{{0, 2}, {0, 4}}, //
{{0, 2}, {0, 4}}, //
{4, 2, 3}, //
{3}));

BOOST_REQUIRE(
check_manual_eval<ArrayToT>("ijkl;abecdf,k;e->ijl;bafdc", //
{{0, 2}, {0, 3}, {0, 4}, {0, 5}}, //
{{0, 4}}, //
{2, 3, 6, 4, 5, 7}, //
{6}));
}

BOOST_AUTO_TEST_SUITE_END()
Expand Down

0 comments on commit 253d59c

Please sign in to comment.