From 8ee1ad74a623e035a5a78b967eb2c54eb801bff9 Mon Sep 17 00:00:00 2001 From: Daming Feng Date: Wed, 6 Dec 2023 19:23:12 -0600 Subject: [PATCH] Fix the f8 reference kernel issue that failed CI (#2586) --- src/kernels/hip_f8_impl.hpp | 11 ++++++----- test/gtest/conv_f8_bwd.cpp | 9 ++++++++- test/gtest/conv_f8_fwd.cpp | 7 ++++++- test/gtest/conv_f8_wrw.cpp | 9 ++++++++- 4 files changed, 28 insertions(+), 8 deletions(-) diff --git a/src/kernels/hip_f8_impl.hpp b/src/kernels/hip_f8_impl.hpp index c7a62f9f72..03b7f901bf 100644 --- a/src/kernels/hip_f8_impl.hpp +++ b/src/kernels/hip_f8_impl.hpp @@ -202,12 +202,13 @@ MIOPEN_HIP_HOST_DEVICE uint8_t cast_to_f8(T _x, bool stoch, uint32_t rng) } mantissa += (1 << mfmt); // Add the implicit 1 into mantissa } - const long tmp = (mfmt - wm + exponent_diff); - if(tmp == 33) - printf("Gotcha"); - bool midpoint = (mantissa & ((static_cast(1) << (mfmt - wm + exponent_diff)) - 1)) == - (static_cast(1) << (mfmt - wm + exponent_diff - 1)); + bool midpoint; + if(exponent_diff <= wm) + midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == + (1 << (mfmt - wm + exponent_diff - 1)); + else + midpoint = false; /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we shift right as shift right could rip off some residual part and make something not midpoint look like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger diff --git a/test/gtest/conv_f8_bwd.cpp b/test/gtest/conv_f8_bwd.cpp index 097d68abb9..156260f875 100644 --- a/test/gtest/conv_f8_bwd.cpp +++ b/test/gtest/conv_f8_bwd.cpp @@ -34,7 +34,14 @@ std::vector ConvTestConfigs() { // g n c d h w k z y x pad_x pad_y pad_z stri_x stri_y stri_z dia_x dia_y dia_z - return {{1, 16, 16, 1, 14, 14, 16, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}}; + return {{1, 16, 16, 1, 14, 14, 16, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {1, 128, 64, 1, 28, 28, 64, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {1, 64, 32, 1, 28, 28, 32, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {32, 128, 32, 1, 28, 28, 32, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {16, 128, 16, 1, 28, 28, 16, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {8, 128, 8, 1, 28, 28, 8, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {4, 128, 4, 1, 28, 28, 4, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {2, 128, 2, 1, 28, 28, 2, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}}; } template diff --git a/test/gtest/conv_f8_fwd.cpp b/test/gtest/conv_f8_fwd.cpp index b94feeeffb..433f9f7fa7 100644 --- a/test/gtest/conv_f8_fwd.cpp +++ b/test/gtest/conv_f8_fwd.cpp @@ -34,7 +34,12 @@ std::vector ConvTestConfigs() { // g n c d h w k z y x pad_x pad_y pad_z stri_x stri_y stri_z dia_x dia_y dia_z - return {{1, 16, 16, 1, 14, 14, 16, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}}; + return {{1, 16, 16, 1, 14, 14, 16, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {1, 64, 64, 1, 14, 14, 64, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {1, 64, 32, 1, 28, 28, 32, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {2, 128, 32, 1, 28, 28, 32, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {32, 128, 32, 1, 28, 28, 32, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {5, 120, 60, 1, 28, 28, 60, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}}; } template diff --git a/test/gtest/conv_f8_wrw.cpp b/test/gtest/conv_f8_wrw.cpp index bae9cc3d99..925590833f 100644 --- a/test/gtest/conv_f8_wrw.cpp +++ b/test/gtest/conv_f8_wrw.cpp @@ -35,7 +35,14 @@ std::vector ConvTestConfigs() { // g n c d h w k z y x pad_x pad_y pad_z stri_x stri_y stri_z dia_x dia_y dia_z - return {{1, 16, 16, 1, 14, 14, 16, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}}; + return {{1, 16, 16, 1, 14, 14, 16, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {1, 64, 128, 1, 28, 3, 128, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {1, 64, 64, 1, 28, 3, 64, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {1, 32, 64, 1, 14, 14, 64, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {1, 32, 32, 1, 14, 14, 32, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {1, 64, 32, 1, 14, 14, 32, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {1, 128, 64, 1, 7, 7, 64, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}, + {1, 128, 32, 1, 7, 7, 32, 1, 3, 3, 1, 1, 0, 1, 1, 1, 1, 1, 1, miopenConvolution}}; } template