Skip to content

Commit

Permalink
[skip ci] einsum unit test for
Browse files Browse the repository at this point in the history
ij;mn * kj;mn -> ijk;mn
  • Loading branch information
bimalgaudel committed Dec 15, 2023
1 parent 7b7dbb8 commit 02a7db7
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions tests/einsum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,40 @@ BOOST_AUTO_TEST_CASE(ij_mn_eq_ij_mn_times_ji_mn) {
BOOST_CHECK(are_equal);
}

BOOST_AUTO_TEST_CASE(ijk_mn_eq_ij_mn_times_kj_mn) {
using dist_array_t = DistArray<Tensor<Tensor<double>>, DensePolicy>;
using matrix_il = TiledArray::detail::matrix_il<Tensor<double>>;
auto& world = TiledArray::get_default_world();

auto random_tot = [](TA::Range const& rng) {
TA::Range inner_rng{7,14};
TA::Tensor<double> t{inner_rng};
TA::Tensor<TA::Tensor<double>> result{rng};
for (auto& e: result) e = t;
return result;
};

auto random_tot_darr = [&random_tot](World& world,
TiledRange const& tr) {
dist_array_t result(world, tr);
for (auto it = result.begin(); it != result.end(); ++it) {
auto tile =
TA::get_default_world().taskq.add(random_tot, it.make_range());
*it = tile;
}
return result;
};

TiledRange lhs_trange{{0, 2, 4}, {0, 5}};
auto lhs = random_tot_darr(world, lhs_trange);

TiledRange rhs_trange{{0, 2, 4, 6}, {0, 5}};
auto rhs = random_tot_darr(world, rhs_trange);
dist_array_t result;
BOOST_REQUIRE_NO_THROW(
result = einsum(lhs("i,j;m,n"), rhs("k,j;m,n"), "i,j,k;m,n"));
}

BOOST_AUTO_TEST_CASE(xxx) {
using dist_array_t = DistArray<Tensor<Tensor<double>>, DensePolicy>;
using matrix_il = TiledArray::detail::matrix_il<Tensor<double>>;
Expand Down Expand Up @@ -1328,6 +1362,13 @@ BOOST_AUTO_TEST_CASE(einsum_tiledarray_hji_jih_hj) {
"hji,jih->hj");
}

BOOST_AUTO_TEST_CASE(einsum_tiledarray_ik_jk_ijk) {
einsum_tiledarray_check<2, 2, 3>(random<SparsePolicy>(7, 5),
random<SparsePolicy>(14, 5), "ik,jk->ijk");
einsum_tiledarray_check<2, 2, 3>(sparse_zero(7, 5), sparse_zero(14, 5),
"ik,jk->ijk");
}

BOOST_AUTO_TEST_CASE(einsum_tiledarray_replicated) {
einsum_tiledarray_check<3, 3, 3>(replicated(random<DensePolicy>(7, 14, 3)),
random<DensePolicy>(7, 15, 3),
Expand Down

0 comments on commit 02a7db7

Please sign in to comment.