Skip to content

Commit

Permalink
Fix matvec output dims to match A rather than B (#523)
Browse files Browse the repository at this point in the history
For matvecs, the batch dimensions for A and B should match
and the final output dimension should match dim Rank-1 from A.
Also generalize batching support so that the size of out_dims_
is based on the output rank.
  • Loading branch information
tbensonatl authored Dec 4, 2023
1 parent 17fdfc9 commit 1ef3e2e
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
7 changes: 4 additions & 3 deletions include/matx/operators/matvec.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ namespace matx
OpB b_;
float alpha_;
float beta_;
std::array<index_t, 2> out_dims_;
mutable matx::tensor_t<typename OpA::scalar_type, 2> tmp_out_;
static constexpr int RANK = remove_cvref_t<OpB>::Rank();
std::array<index_t, RANK> out_dims_;
mutable matx::tensor_t<typename OpA::scalar_type, RANK> tmp_out_;

public:
using matxop = bool;
Expand All @@ -65,7 +66,7 @@ namespace matx
a_(A), b_(B), alpha_(alpha), beta_(beta) {

for (int r = 0; r < Rank(); r++) {
out_dims_[r] = b_.Size(r);
out_dims_[r] = a_.Size(r);
}
}

Expand Down
26 changes: 26 additions & 0 deletions test/00_transform/MatMul.cu
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,12 @@ TYPED_TEST(MatMulTestFloatTypes, MediumMatVec)
(cs = matvec(a, bs)).run();
// example-end matvec-test-1

// Test the rank/size of the matvec operator
auto a_times_bs = matvec(a, bs);
ASSERT_EQ(a_times_bs.Rank(), 1);
ASSERT_EQ(a_times_bs.Size(0), m);
ASSERT_EQ(cs.Size(0), m);

MATX_TEST_ASSERT_COMPARE(this->pb, c, "c", this->thresh);

// Test also with rank-1 tensors rather than just slices
Expand All @@ -693,6 +699,26 @@ TYPED_TEST(MatMulTestFloatTypes, MediumMatVec)

MATX_TEST_ASSERT_COMPARE(this->pb, c, "c", this->thresh);

// Test with batching
constexpr index_t batch1 = 5;
constexpr index_t batch2 = 9;
auto a_batch = clone<4>(a, {batch1, batch2, matxKeepDim, matxKeepDim});
auto b_batch = clone<3>(bs, {batch1, batch2, matxKeepDim});
auto batched_matvec = matvec(a_batch, b_batch);
ASSERT_EQ(batched_matvec.Rank(), 3);
ASSERT_EQ(batched_matvec.Size(0), batch1);
ASSERT_EQ(batched_matvec.Size(1), batch2);
ASSERT_EQ(batched_matvec.Size(2), m);
auto result = make_tensor<TypeParam>(batched_matvec.Shape());
(result = batched_matvec).run();
for (index_t i = 0; i < batch1; i++) {
for (index_t j = 0; j < batch2; j++) {
auto rs = slice<1>(result, {i,j,0}, {matxDropDim,matxDropDim,matxEnd});
auto rsc = clone<2>(rs, {matxKeepDim,1});
MATX_TEST_ASSERT_COMPARE(this->pb, rsc, "c", this->thresh);
}
}

MATX_EXIT_HANDLER();
}

Expand Down

0 comments on commit 1ef3e2e

Please sign in to comment.