Skip to content

Commit

Permalink
fixes for complex add and sub and fixes for their unit tests (#339)
Browse files Browse the repository at this point in the history
* fixes for complex add and sub and fixes for their unit tests

* missed host tests
  • Loading branch information
tylera-nvidia authored Dec 1, 2022
1 parent f2c6edd commit a134ff9
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 16 deletions.
20 changes: 8 additions & 12 deletions include/matx/operators/scalar_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,24 +307,22 @@ template <typename T1, typename T2> struct AddF {
if constexpr (is_complex_half_v<T1>) {
return (T1){v1.real() + static_cast<typename T1::value_type>(
static_cast<float>(v2)),
v1.imag() + static_cast<typename T1::value_type>(
static_cast<float>(v2))};
v1.imag() };
}
else {
return (T1){v1.real() + static_cast<typename T1::value_type>(v2),
v1.imag() + static_cast<typename T1::value_type>(v2)};
v1.imag() };
}
}
else if constexpr (is_complex_v<T2> && std::is_arithmetic_v<T1>) {
if constexpr (is_complex_half_v<T2>) {
return (T2){v2.real() + static_cast<typename T2::value_type>(
static_cast<float>(v1)),
v2.imag() + static_cast<typename T2::value_type>(
static_cast<float>(v1))};
v2.imag() };
}
else {
return (T2){v2.real() + static_cast<typename T2::value_type>(v1),
v2.imag() + static_cast<typename T2::value_type>(v1)};
v2.imag() };
}
}
else {
Expand All @@ -347,24 +345,22 @@ template <typename T1, typename T2> struct SubF {
if constexpr (is_complex_half_v<T1>) {
return (T1){v1.real() - static_cast<typename T1::value_type>(
static_cast<float>(v2)),
v1.imag() - static_cast<typename T1::value_type>(
static_cast<float>(v2))};
v1.imag() };
}
else {
return (T1){v1.real() - static_cast<typename T1::value_type>(v2),
v1.imag() - static_cast<typename T1::value_type>(v2)};
v1.imag() };
}
}
else if constexpr (is_complex_v<T2> && std::is_arithmetic_v<T1>) {
if constexpr (is_complex_half_v<T2>) {
return (T2){v2.real() - static_cast<typename T2::value_type>(
static_cast<float>(v1)),
v2.imag() - static_cast<typename T2::value_type>(
static_cast<float>(v1))};
v2.imag() };
}
else {
return (T2){v2.real() - static_cast<typename T2::value_type>(v1),
v2.imag() - static_cast<typename T2::value_type>(v1)};
v2.imag() };
}
}
else {
Expand Down
4 changes: 2 additions & 2 deletions test/00_host/OperatorTests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1466,7 +1466,7 @@ TYPED_TEST(HostOperatorTestsComplex, ComplexTypeCompatibility)
ASSERT_EQ(static_cast<detail::value_promote_t<TypeParam>>(dview(i).real()),
static_cast<detail::value_promote_t<TypeParam>>(i + i));
ASSERT_EQ(static_cast<detail::value_promote_t<TypeParam>>(dview(i).imag()),
static_cast<detail::value_promote_t<TypeParam>>(i + i));
static_cast<detail::value_promote_t<TypeParam>>(i));
}

// Subtract scalar
Expand All @@ -1482,7 +1482,7 @@ TYPED_TEST(HostOperatorTestsComplex, ComplexTypeCompatibility)
ASSERT_EQ(static_cast<detail::value_promote_t<TypeParam>>(dview(i).real()),
static_cast<detail::value_promote_t<TypeParam>>(-1));
ASSERT_EQ(static_cast<detail::value_promote_t<TypeParam>>(dview(i).imag()),
static_cast<detail::value_promote_t<TypeParam>>(-1));
static_cast<detail::value_promote_t<TypeParam>>(i));
}

MATX_EXIT_HANDLER();
Expand Down
5 changes: 3 additions & 2 deletions test/00_operators/OperatorTests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1620,6 +1620,7 @@ TYPED_TEST(OperatorTestsComplex, ComplexTypeCompatibility)
dview(i) = {static_cast<detail::value_promote_t<TypeParam>>(i),
static_cast<detail::value_promote_t<TypeParam>>(i)};
}


(dview = dview + fview).run();
cudaDeviceSynchronize();
Expand All @@ -1628,7 +1629,7 @@ TYPED_TEST(OperatorTestsComplex, ComplexTypeCompatibility)
ASSERT_EQ(static_cast<detail::value_promote_t<TypeParam>>(dview(i).real()),
static_cast<detail::value_promote_t<TypeParam>>(i + i));
ASSERT_EQ(static_cast<detail::value_promote_t<TypeParam>>(dview(i).imag()),
static_cast<detail::value_promote_t<TypeParam>>(i + i));
static_cast<detail::value_promote_t<TypeParam>>(i));
}

// Subtract scalar
Expand All @@ -1645,7 +1646,7 @@ TYPED_TEST(OperatorTestsComplex, ComplexTypeCompatibility)
ASSERT_EQ(static_cast<detail::value_promote_t<TypeParam>>(dview(i).real()),
static_cast<detail::value_promote_t<TypeParam>>(-1));
ASSERT_EQ(static_cast<detail::value_promote_t<TypeParam>>(dview(i).imag()),
static_cast<detail::value_promote_t<TypeParam>>(-1));
static_cast<detail::value_promote_t<TypeParam>>(i));
}

MATX_EXIT_HANDLER();
Expand Down

0 comments on commit a134ff9

Please sign in to comment.