Skip to content

Commit

Permalink
fixup! Add additional test in verify_onnx for different datatypes usi…
Browse files Browse the repository at this point in the history
…ng parsed in mod protobuf
  • Loading branch information
TedThemistokleous committed Jul 15, 2022
1 parent ca630d4 commit 5419584
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/include/migraphx/op/mod.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ struct mod : binary<mod>
std::string point_function() const { return "mod"; }
auto apply() const
{
return [](auto x, auto y) { return std::fmod((std::abs(std::fmod(x, y)) + y), y); };
return [](auto x, auto y) { return std::fmod((std::remainder(x, y)) + y, y); };
}
};

Expand Down
34 changes: 18 additions & 16 deletions test/onnx/verify_onnx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -638,12 +638,13 @@ TEST_CASE(mod_test)

migraphx::shape s{migraphx::shape::float_type, {3, 3, 3}};

std::vector<float> a = {1.0, -2.0, 3.0, 4.0, -5.0, 6.0, 7.0, -8.0, 9.0,
10.0, 11.0, 12.0, 13.0, -14.0, 15.0, 16.0, 17.0, 18.0,
19.0, 20.0, 21.0, -22.0, 23.0, -24.0, 25.0, 26.0, 27.0};
std::vector<float> a = {-4.0, 7.0, 5.0, 4.0, -7.0, 8.0, -4.0, 7.0, 5.0,
4.0, -7.0, 8.0, -4.0, 7.0, 5.0, 4.0, -7.0, 8.0,
-4.0, 7.0, 5.0, 4.0, -7.0, 8.0, -4.0, 7.0, 5.0};

std::vector<float> b = {30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17,
16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4};
std::vector<float> b = {2.0, -3.0, 8.0, -2.0, 3.0, 5.0, 2.0, -3.0, 8.0,
-2.0, 3.0, 5.0, 2.0, -3.0, 8.0, -2.0, 3.0, 5.0,
2.0, -3.0, 8.0, -2.0, 3.0, 5.0, 2.0, -3.0, 8.0};

migraphx::parameter_map p_map;
p_map["0"] = migraphx::argument(s, a.data());
Expand All @@ -653,9 +654,10 @@ TEST_CASE(mod_test)
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

std::vector<float> gold{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0,
10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 1.0, 3.0, 5.0,
7.0, 9.0, 1.0, 4.0, 7.0, 3.0, 1.0, 1.0, 3.0};
std::vector<float> gold = {0.0, -2.0, 5.0, 0.0, 2.0, 3.0, 0.0, -2.0, 5.0,
0.0, 2.0, 3.0, 0.0, -2.0, 5.0, 0.0, 2.0, 3.0,
0.0, -2.0, 5.0, 0.0, 2.0, 3.0, 0.0, -2.0, 5.0};

EXPECT(migraphx::verify_range(result_vector, gold));
}

Expand All @@ -667,12 +669,12 @@ TEST_CASE(mod_test_different_types)
migraphx::shape s_float{migraphx::shape::float_type, {3, 3, 3}};
migraphx::shape s_int{migraphx::shape::int32_type, {3, 3, 3}};

std::vector<float> a = {1.0, -2.0, 3.0, 4.0, -5.0, 6.0, 7.0, -8.0, 9.0,
10.0, 11.0, 12.0, 13.0, -14.0, 15.0, 16.0, 17.0, 18.0,
19.0, 20.0, 21.0, -22.0, 23.0, -24.0, 25.0, 26.0, 27.0};
std::vector<float> a = {-4.0, 7.0, 5.0, 4.0, -7.0, 8.0, -4.0, 7.0, 5.0,
4.0, -7.0, 8.0, -4.0, 7.0, 5.0, 4.0, -7.0, 8.0,
-4.0, 7.0, 5.0, 4.0, -7.0, 8.0, -4.0, 7.0, 5.0};

std::vector<int32_t> b = {30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17,
16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4};
std::vector<int32_t> b = {2, -3, 8, -2, 3, 5, 2, -3, 8, -2, 3, 5, 2, -3,
8, -2, 3, 5, 2, -3, 8, -2, 3, 5, 2, -3, 8};

migraphx::parameter_map p_map;
p_map["0"] = migraphx::argument(s_float, a.data());
Expand All @@ -682,9 +684,9 @@ TEST_CASE(mod_test_different_types)
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

std::vector<float> gold{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0,
10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 1.0, 3.0, 5.0,
7.0, 9.0, 1.0, 4.0, 7.0, 3.0, 1.0, 1.0, 3.0};
std::vector<float> gold = {0.0, -2.0, 5.0, 0.0, 2.0, 3.0, 0.0, -2.0, 5.0,
0.0, 2.0, 3.0, 0.0, -2.0, 5.0, 0.0, 2.0, 3.0,
0.0, -2.0, 5.0, 0.0, 2.0, 3.0, 0.0, -2.0, 5.0};

EXPECT(migraphx::verify_range(result_vector, gold));
}
Expand Down

0 comments on commit 5419584

Please sign in to comment.