Support cuda 12.x (#10367)
mosout authored Dec 22, 2023
1 parent 559f5ec commit a6feab0
Showing 4 changed files with 26 additions and 11 deletions.
2 changes: 1 addition & 1 deletion oneflow/core/device/cuda_pseudo_bfloat16.h
@@ -24,7 +24,7 @@ limitations under the License.
#include <cuda_bf16.h>
#endif

-#if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+#if CUDA_VERSION >= 11000 && CUDA_VERSION <= 12010 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

#define DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_OPERATOR(op) \
__device__ __forceinline__ __nv_bfloat16 operator op(const __nv_bfloat16& lh, \
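Note on this hunk: the added upper bound keeps OneFlow's float-emulated __nv_bfloat16 operators from being defined on toolkits newer than 12.1, presumably because those toolkits already provide the arithmetic operators for pre-Ampere targets and a second definition would collide. A minimal standalone sketch of the guard pattern (one representative operator only; not the project's actual header):

// bf16_fallback_sketch.cu -- compile with nvcc.
#include <cuda.h>       // CUDA_VERSION
#include <cuda_bf16.h>  // __nv_bfloat16, conversion intrinsics

#if CUDA_VERSION >= 11000 && CUDA_VERSION <= 12010 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// Older toolkit + pre-Ampere device pass: emulate bfloat16 addition through float.
__device__ __forceinline__ __nv_bfloat16 operator+(const __nv_bfloat16& lh,
                                                    const __nv_bfloat16& rh) {
  return __float2bfloat16(__bfloat162float(lh) + __bfloat162float(rh));
}
#endif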
18 changes: 17 additions & 1 deletion oneflow/core/ep/test/primitive/fill_test.cpp
@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
+#include <type_traits>
#include <gtest/gtest.h>
#include "oneflow/core/ep/test/primitive/primitive_test.h"
#include "oneflow/core/ep/include/primitive/memset.h"
@@ -66,7 +67,22 @@ void TestFill(DeviceManagerRegistry* registry, const std::set<DeviceType>& devic
fill->Launch(stream.stream(), device_mem.ptr(), Scalar(0), n);
d2h->Launch(stream.stream(), host_mem.ptr(), device_mem.ptr(), vector_size);
CHECK_JUST(stream.stream()->Sync());
-for (size_t i = 0; i < n; ++i) { ASSERT_EQ(*reinterpret_cast<T*>(host_mem.ptr<T>() + i), 0); }
+
+for (size_t i = 0; i < n; ++i) {
+#ifdef WITH_CUDA
+  if constexpr (std::is_same_v<T, half>) {
+    ASSERT_EQ(*reinterpret_cast<T*>(host_mem.ptr<T>() + i), __float2half(0.0));
+#if CUDA_VERSION >= 11000
+  } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
+    ASSERT_EQ(*reinterpret_cast<T*>(host_mem.ptr<T>() + i), __float2bfloat16(0.0));
+#endif // CUDA_VERSION >= 11000
+  } else {
+    ASSERT_EQ(*reinterpret_cast<T*>(host_mem.ptr<T>() + i), static_cast<T>(0));
+  }
+#else
+  ASSERT_EQ(*reinterpret_cast<T*>(host_mem.ptr<T>() + i), static_cast<T>(0));
+#endif // WITH_CUDA
+}
}
}

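Note on this hunk: the expected value is now built per element type — __float2half(0.0) for half, __float2bfloat16(0.0) for nv_bfloat16, and static_cast<T>(0) otherwise — instead of comparing every type against a plain 0 literal; the likely motivation (inferred from the change, not stated in the commit) is that the half/bfloat16 host types no longer compare cleanly against an integer 0 under CUDA 12.x. A standalone sketch of the same dispatch, where ExpectedZero is a hypothetical helper, not code from the repository:

// expected_zero_sketch.cu -- compile with nvcc.
#include <type_traits>
#include <cuda.h>       // CUDA_VERSION
#include <cuda_fp16.h>  // half, __float2half
#if CUDA_VERSION >= 11000
#include <cuda_bf16.h>  // nv_bfloat16, __float2bfloat16
#endif

// Returns a correctly typed "zero" to compare filled buffers against,
// e.g. ASSERT_EQ(host_value, ExpectedZero<T>()) inside a gtest body.
template<typename T>
T ExpectedZero() {
  if constexpr (std::is_same_v<T, half>) {
    return __float2half(0.0f);
#if CUDA_VERSION >= 11000
  } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
    return __float2bfloat16(0.0f);
#endif
  } else {
    return static_cast<T>(0);
  }
}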
11 changes: 5 additions & 6 deletions oneflow/core/ndarray/ndarray_reduce_impl.cu
@@ -51,17 +51,16 @@ struct NanSum {
}
};

-template<>
-OF_DEVICE_FUNC cuComplex cub::Sum::operator()(const cuComplex& a, const cuComplex& b) const {
+} // namespace cub
+
+__host__ __device__ __forceinline__ cuComplex operator+(const cuComplex& a, const cuComplex& b) {
return cuComplex{a.x + b.x, a.y + b.y};
}

-template<>
-OF_DEVICE_FUNC cuDoubleComplex cub::Sum::operator()(const cuDoubleComplex& a,
-                                                    const cuDoubleComplex& b) const {
+__host__ __device__ __forceinline__ cuDoubleComplex operator+(const cuDoubleComplex& a,
+                                                               const cuDoubleComplex& b) {
return cuDoubleComplex{a.x + b.x, a.y + b.y};
}
-} // namespace cub

namespace oneflow {

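Note on this hunk: the explicit specializations of cub::Sum::operator() for cuComplex and cuDoubleComplex are replaced with free __host__ __device__ operator+ overloads declared outside namespace cub. The apparent reason (inferred, not quoted from CUB's changelog) is that in CUDA 12.x cub::Sum is no longer a plain struct whose call operator can be specialized per type; it forwards to a generic `a + b`, so providing operator+ for the complex structs is enough. A standalone sketch of that pattern, with GenericSum standing in for the real cub::Sum:

// complex_sum_sketch.cu -- compile with nvcc.
#include <cuComplex.h>

// User-provided addition for the CUDA complex value types.
__host__ __device__ __forceinline__ cuComplex operator+(const cuComplex& a, const cuComplex& b) {
  return cuComplex{a.x + b.x, a.y + b.y};
}

// Stand-in for a generic reduction functor such as cub::Sum: it only needs `a + b` to exist.
struct GenericSum {
  template<typename T>
  __host__ __device__ T operator()(const T& a, const T& b) const {
    return a + b;  // picks up the overload above when T is cuComplex
  }
};

__host__ __device__ cuComplex AddTwo(const cuComplex& a, const cuComplex& b) {
  return GenericSum{}(a, b);
}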
6 changes: 3 additions & 3 deletions oneflow/user/kernels/arange_kernel.cpp
@@ -84,7 +84,7 @@ class ArangeKernel final : public OpKernel, public CudaGraphSupport {
start = static_cast<T>(static_cast<double>(ctx->Attr<int64_t>("integer_start")));
delta = static_cast<T>(static_cast<double>(ctx->Attr<int64_t>("integer_delta")));
limit = static_cast<T>(static_cast<double>(ctx->Attr<int64_t>("integer_limit")));
-arange_elem_cnt = std::ceil(static_cast<double>(limit - start) / delta);
+arange_elem_cnt = std::ceil(static_cast<double>(limit - start) / static_cast<double>(delta));
} else {
// If we use static_cast<T>(start, delta, limit) and std::ceil to calculate arange_elem_cnt,
// it will cause rounding error.
@@ -102,8 +102,8 @@ class ArangeKernel final : public OpKernel, public CudaGraphSupport {
} else {
const auto* arange_cache = dynamic_cast<const ArangeOpKernelCache*>(cache);
auto arange_len = arange_cache->upper() - arange_cache->lower();
-ArangeFunctor<device_type, T>()(ctx->stream(),
-                                static_cast<T>(start + delta * arange_cache->lower()), delta,
+auto lower = static_cast<T>(static_cast<float>(arange_cache->lower()));
+ArangeFunctor<device_type, T>()(ctx->stream(), static_cast<T>(start + delta * lower), delta,
                                arange_len, output);
}
}
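Note on these two hunks: both fixes pin the mixed-type arithmetic to a single type before T gets involved — the element count is computed as double / double instead of double / T, and the cache's int64_t lower bound is converted to T (via float) before being multiplied by delta. The likely issue being avoided (inferred from the change) is that CUDA 12.x's half/bfloat16 types make such mixed expressions ambiguous or ill-formed. A hedged sketch of the element-count computation; ArangeElemCnt is a hypothetical helper, not the kernel's real code:

// arange_count_sketch.cu -- host-side code, kept in CUDA C++ like the other sketches.
#include <cmath>
#include <cstdint>

template<typename T>
int64_t ArangeElemCnt(T start, T delta, T limit) {
  // Cast both operands so the division is double / double; leaving `delta` as T
  // would form `double / T`, which can hit ambiguous operator/ overloads when T
  // is an extended floating-point type such as half.
  return static_cast<int64_t>(
      std::ceil(static_cast<double>(limit - start) / static_cast<double>(delta)));
}

// Example: ArangeElemCnt<float>(0.0f, 0.3f, 1.0f) == 4  -> elements 0.0, 0.3, 0.6, 0.9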
