Skip to content

Commit

Permalink
Test outer pure-Hadamard with inner tensors contraction.
Browse files Browse the repository at this point in the history
  • Loading branch information
bimalgaudel committed Jan 18, 2024
1 parent 3990ede commit df0b808
Showing 1 changed file with 10 additions and 53 deletions.
63 changes: 10 additions & 53 deletions tests/einsum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -792,63 +792,20 @@ BOOST_AUTO_TEST_CASE(xxx) {
}

BOOST_AUTO_TEST_CASE(ij_mn_eq_ij_mo_times_ji_on) {
using Array = TA::DistArray<TA::Tensor<TA::Tensor<int>>, TA::DensePolicy>;
auto& world = TA::get_default_world();

TA::Range const inner_rng{2, 7};
TA::Range const inner_rng_perm{7, 2};
TA::TiledRange lhs_trng{{0, 2, 4}, {0, 2}};
TA::TiledRange rhs_trng{{0, 2}, {0, 2, 4}};
auto lhs = random_array<Array>(lhs_trng, inner_rng);
auto rhs = random_array<Array>(rhs_trng, inner_rng_perm);

//
// manual evaluation: 'ij;mn = ij;mo * ji;on'
//
Array ref{world, lhs_trng};
{
lhs.make_replicated();
rhs.make_replicated();
world.gop.fence();

auto make_tile = [lhs, rhs](TA::Range const& rng) {
typename Array::value_type result_tile{rng};
for (auto&& res_ix : result_tile.range()) {
auto i = res_ix[0];
auto j = res_ix[1];

auto lhs_tile_ix = lhs.trange().element_to_tile({i, j});
auto lhs_tile = lhs.find_local(lhs_tile_ix).get(/* dowork = */ false);

auto rhs_tile_ix = rhs.trange().element_to_tile({j, i});
auto rhs_tile = rhs.find_local(rhs_tile_ix).get(/* dowork = */ false);

auto& res_el = result_tile({i, j});
auto const& lhs_el = lhs_tile({i, j});
auto const& rhs_el = rhs_tile({j, i});
using namespace std::string_literals;
res_el =
TA::detail::tensor_contract(lhs_el, "mo"s, rhs_el, "on"s, "mn"s);
}
return result_tile;
};
using std::begin;
using std::end;
using Array = TA::DistArray<TA::Tensor<TA::Tensor<int>>, TA::DensePolicy>;
using Perm = TA::Permutation;

for (auto it = begin(ref); it != end(ref); ++it)
if (ref.is_local(it.index())) {
auto tile = world.taskq.add(make_tile, it.make_range());
*it = tile;
}
}
TA::TiledRange lhs_trng{{0, 2, 3}, {0, 2, 4}};
TA::TiledRange rhs_trng{{0, 2, 4}, {0, 2, 3}};
TA::Range lhs_inner_rng{1, 1};
TA::Range rhs_inner_rng{1, 1};

auto out = einsum(lhs("i,j;m,o"), rhs("j,i;o,n"), "i,j;m,n");
std::cerr << "TODO: ij;mo * ji;on -> ij;mn using expression layer does not "
"produce the same result compared to manual evaluation."
<< '\n';
// bool are_equal = ToTArrayFixture::are_equal<ShapeComp::False>(ref, out);
// std::cout << out << '\n' << ref << '\n';
// BOOST_CHECK(are_equal);
auto lhs = random_array<Array>(lhs_trng, lhs_inner_rng);
auto rhs = random_array<Array>(rhs_trng, rhs_inner_rng);
Array out;
BOOST_REQUIRE_NO_THROW(out("i,j;m,n") = lhs("i,j;m,o") * rhs("j,i;o,n"));
}

BOOST_AUTO_TEST_CASE(ij_mn_eq_ijk_mo_times_ijk_no) {
Expand Down

0 comments on commit df0b808

Please sign in to comment.