diff --git a/oneflow/core/device/cuda_pseudo_bfloat16.h b/oneflow/core/device/cuda_pseudo_bfloat16.h index 8443f856610..f06d592a17e 100644 --- a/oneflow/core/device/cuda_pseudo_bfloat16.h +++ b/oneflow/core/device/cuda_pseudo_bfloat16.h @@ -24,7 +24,7 @@ limitations under the License. #include #endif -#if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +#if CUDA_VERSION >= 11000 && CUDA_VERSION <= 12010 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #define DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_OPERATOR(op) \ __device__ __forceinline__ __nv_bfloat16 operator op(const __nv_bfloat16& lh, \ diff --git a/oneflow/core/ep/test/primitive/fill_test.cpp b/oneflow/core/ep/test/primitive/fill_test.cpp index a9863154fcb..f30302ee4fa 100644 --- a/oneflow/core/ep/test/primitive/fill_test.cpp +++ b/oneflow/core/ep/test/primitive/fill_test.cpp @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include #include "oneflow/core/ep/test/primitive/primitive_test.h" #include "oneflow/core/ep/include/primitive/memset.h" @@ -66,7 +67,22 @@ void TestFill(DeviceManagerRegistry* registry, const std::set& devic fill->Launch(stream.stream(), device_mem.ptr(), Scalar(0), n); d2h->Launch(stream.stream(), host_mem.ptr(), device_mem.ptr(), vector_size); CHECK_JUST(stream.stream()->Sync()); - for (size_t i = 0; i < n; ++i) { ASSERT_EQ(*reinterpret_cast(host_mem.ptr() + i), 0); } + + for (size_t i = 0; i < n; ++i) { +#ifdef WITH_CUDA + if constexpr (std::is_same_v) { + ASSERT_EQ(*reinterpret_cast(host_mem.ptr() + i), __float2half(0.0)); +#if CUDA_VERSION >= 11000 + } else if constexpr (std::is_same_v) { + ASSERT_EQ(*reinterpret_cast(host_mem.ptr() + i), __float2bfloat16(0.0)); +#endif // CUDA_VERSION >= 11000 + } else { + ASSERT_EQ(*reinterpret_cast(host_mem.ptr() + i), static_cast(0)); + } +#else + ASSERT_EQ(*reinterpret_cast(host_mem.ptr() + i), static_cast(0)); +#endif // WITH_CUDA + } } } diff --git a/oneflow/core/ndarray/ndarray_reduce_impl.cu b/oneflow/core/ndarray/ndarray_reduce_impl.cu index e436f0c32f2..d2ad50b0524 100644 --- a/oneflow/core/ndarray/ndarray_reduce_impl.cu +++ b/oneflow/core/ndarray/ndarray_reduce_impl.cu @@ -51,17 +51,16 @@ struct NanSum { } }; -template<> -OF_DEVICE_FUNC cuComplex cub::Sum::operator()(const cuComplex& a, const cuComplex& b) const { +} // namespace cub + +__host__ __device__ __forceinline__ cuComplex operator+(const cuComplex& a, const cuComplex& b) { return cuComplex{a.x + b.x, a.y + b.y}; } -template<> -OF_DEVICE_FUNC cuDoubleComplex cub::Sum::operator()(const cuDoubleComplex& a, - const cuDoubleComplex& b) const { +__host__ __device__ __forceinline__ cuDoubleComplex operator+(const cuDoubleComplex& a, + const cuDoubleComplex& b) { return cuDoubleComplex{a.x + b.x, a.y + b.y}; } -} // namespace cub namespace oneflow { diff --git a/oneflow/user/kernels/arange_kernel.cpp b/oneflow/user/kernels/arange_kernel.cpp index 2cfa9c3e4f3..2771bec508e 100644 --- a/oneflow/user/kernels/arange_kernel.cpp +++ b/oneflow/user/kernels/arange_kernel.cpp @@ -84,7 +84,7 @@ class ArangeKernel final : public OpKernel, public CudaGraphSupport { start = static_cast(static_cast(ctx->Attr("integer_start"))); delta = static_cast(static_cast(ctx->Attr("integer_delta"))); limit = static_cast(static_cast(ctx->Attr("integer_limit"))); - arange_elem_cnt = std::ceil(static_cast(limit - start) / delta); + arange_elem_cnt = std::ceil(static_cast(limit - start) / static_cast(delta)); } else { // If we use static_cast(start, delta, limit) and std::ceil to calculate arange_elem_cnt, // it will cause rounding error. @@ -102,8 +102,8 @@ class ArangeKernel final : public OpKernel, public CudaGraphSupport { } else { const auto* arange_cache = dynamic_cast(cache); auto arange_len = arange_cache->upper() - arange_cache->lower(); - ArangeFunctor()(ctx->stream(), - static_cast(start + delta * arange_cache->lower()), delta, + auto lower = static_cast(static_cast(arange_cache->lower())); + ArangeFunctor()(ctx->stream(), static_cast(start + delta * lower), delta, arange_len, output); } }