Support cuda 12.x #10367

Merged · 6 commits · Dec 22, 2023
2 changes: 1 addition & 1 deletion oneflow/core/device/cuda_pseudo_bfloat16.h
@@ -24,7 +24,7 @@ limitations under the License.
#include <cuda_bf16.h>
#endif

-#if CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+#if CUDA_VERSION >= 11000 && CUDA_VERSION <= 12010 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

#define DEFINE_CUDA_PSEUDO_BFLOAT16_ARITHMETIC_BINARY_OPERATOR(op) \
__device__ __forceinline__ __nv_bfloat16 operator op(const __nv_bfloat16& lh, \
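A note on the new upper bound (inferred, not stated in the diff): the pseudo-bfloat16 operators in this header emulate __nv_bfloat16 arithmetic through float for devices older than sm_80, where the native instructions are unavailable. CUDA 12.2 and later ship equivalent fallback operators in cuda_bf16.h itself, so keeping OneFlow's copies enabled there would redefine them; the added CUDA_VERSION <= 12010 check limits the workaround to toolkits that still need it. Roughly what one generated operator looks like, as a sketch of the macro's expansion for `+` (an assumption about the expansion; the exact definition lives further down in this header):

// Sketch only: emulated __nv_bfloat16 addition for __CUDA_ARCH__ < 800, routed through float.
__device__ __forceinline__ __nv_bfloat16 operator+(const __nv_bfloat16& lh, const __nv_bfloat16& rh) {
  return __float2bfloat16(__bfloat162float(lh) + __bfloat162float(rh));
}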
18 changes: 17 additions & 1 deletion oneflow/core/ep/test/primitive/fill_test.cpp
@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
+#include <type_traits>
#include <gtest/gtest.h>
#include "oneflow/core/ep/test/primitive/primitive_test.h"
#include "oneflow/core/ep/include/primitive/memset.h"
@@ -66,7 +67,22 @@ void TestFill(DeviceManagerRegistry* registry, const std::set<DeviceType>& devic
fill->Launch(stream.stream(), device_mem.ptr(), Scalar(0), n);
d2h->Launch(stream.stream(), host_mem.ptr(), device_mem.ptr(), vector_size);
CHECK_JUST(stream.stream()->Sync());
-for (size_t i = 0; i < n; ++i) { ASSERT_EQ(*reinterpret_cast<T*>(host_mem.ptr<T>() + i), 0); }
+
+for (size_t i = 0; i < n; ++i) {
+#ifdef WITH_CUDA
+  if constexpr (std::is_same_v<T, half>) {
+    ASSERT_EQ(*reinterpret_cast<T*>(host_mem.ptr<T>() + i), __float2half(0.0));
+#if CUDA_VERSION >= 11000
+  } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
+    ASSERT_EQ(*reinterpret_cast<T*>(host_mem.ptr<T>() + i), __float2bfloat16(0.0));
+#endif // CUDA_VERSION >= 11000
+  } else {
+    ASSERT_EQ(*reinterpret_cast<T*>(host_mem.ptr<T>() + i), static_cast<T>(0));
+  }
+#else
+  ASSERT_EQ(*reinterpret_cast<T*>(host_mem.ptr<T>() + i), static_cast<T>(0));
+#endif // WITH_CUDA
+}
}
}

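Why the per-type zeros (an inference from the change, not stated in the PR): with CUDA 12 host toolchains, comparing a half or nv_bfloat16 element against the plain integer literal 0 no longer resolves cleanly, so the test now builds a zero of the matching element type via __float2half / __float2bfloat16 and keeps static_cast<T>(0) for ordinary arithmetic types. The same dispatch, pulled out into a hypothetical helper for clarity (the test inlines it with if constexpr):

#include <type_traits>
#ifdef WITH_CUDA
#include <cuda.h>       // CUDA_VERSION
#include <cuda_fp16.h>  // half, __float2half
#if CUDA_VERSION >= 11000
#include <cuda_bf16.h>  // nv_bfloat16, __float2bfloat16
#endif
#endif

// Hypothetical helper (not part of the diff): a zero value that compares cleanly with T.
template<typename T>
T ZeroOf() {
#ifdef WITH_CUDA
  if constexpr (std::is_same_v<T, half>) {
    return __float2half(0.0f);
#if CUDA_VERSION >= 11000
  } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
    return __float2bfloat16(0.0f);
#endif
  } else {
    return static_cast<T>(0);
  }
#else
  return static_cast<T>(0);
#endif
}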
11 changes: 5 additions & 6 deletions oneflow/core/ndarray/ndarray_reduce_impl.cu
@@ -51,17 +51,16 @@ struct NanSum {
}
};

-template<>
-OF_DEVICE_FUNC cuComplex cub::Sum::operator()(const cuComplex& a, const cuComplex& b) const {
+} // namespace cub
+
+__host__ __device__ __forceinline__ cuComplex operator+(const cuComplex& a, const cuComplex& b) {
return cuComplex{a.x + b.x, a.y + b.y};
}

-template<>
-OF_DEVICE_FUNC cuDoubleComplex cub::Sum::operator()(const cuDoubleComplex& a,
-                                                    const cuDoubleComplex& b) const {
+__host__ __device__ __forceinline__ cuDoubleComplex operator+(const cuDoubleComplex& a,
+                                                               const cuDoubleComplex& b) {
return cuDoubleComplex{a.x + b.x, a.y + b.y};
}
-} // namespace cub

namespace oneflow {

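Why the specializations were replaced (hedged): the CUB bundled with CUDA 12.x implements its Sum functor generically in terms of operator+, and specializing cub::Sum::operator() for cuComplex / cuDoubleComplex inside namespace cub no longer builds against it. Free __host__ __device__ operator+ overloads, as added above, give the generic functor what it needs on both old and new toolkits. A small usage sketch (the cub::Sum call below is illustrative, not taken from this file):

#include <cuComplex.h>
#include <cub/thread/thread_operators.cuh>

// Same overload as in the diff above.
__host__ __device__ __forceinline__ cuComplex operator+(const cuComplex& a, const cuComplex& b) {
  return cuComplex{a.x + b.x, a.y + b.y};
}

// With operator+ visible, the generic cub::Sum (roughly "return a + b;")
// reduces cuComplex values without touching namespace cub.
__device__ cuComplex SumTwo(const cuComplex& a, const cuComplex& b) {
  cub::Sum sum;
  return sum(a, b);
}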
6 changes: 3 additions & 3 deletions oneflow/user/kernels/arange_kernel.cpp
@@ -84,7 +84,7 @@ class ArangeKernel final : public OpKernel, public CudaGraphSupport {
start = static_cast<T>(static_cast<double>(ctx->Attr<int64_t>("integer_start")));
delta = static_cast<T>(static_cast<double>(ctx->Attr<int64_t>("integer_delta")));
limit = static_cast<T>(static_cast<double>(ctx->Attr<int64_t>("integer_limit")));
-arange_elem_cnt = std::ceil(static_cast<double>(limit - start) / delta);
+arange_elem_cnt = std::ceil(static_cast<double>(limit - start) / static_cast<double>(delta));
} else {
// If we use static_cast<T>(start, delta, limit) and std::ceil to calculate arange_elem_cnt,
// it will cause rounding error.
@@ -102,8 +102,8 @@ class ArangeKernel final : public OpKernel, public CudaGraphSupport {
} else {
const auto* arange_cache = dynamic_cast<const ArangeOpKernelCache*>(cache);
auto arange_len = arange_cache->upper() - arange_cache->lower();
-ArangeFunctor<device_type, T>()(ctx->stream(),
-                                static_cast<T>(start + delta * arange_cache->lower()), delta,
+auto lower = static_cast<T>(static_cast<float>(arange_cache->lower()));
+ArangeFunctor<device_type, T>()(ctx->stream(), static_cast<T>(start + delta * lower), delta,
arange_len, output);
}
}
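Why the extra casts (an assumption; the PR does not spell it out): this kernel is also instantiated for half and bfloat16, and under CUDA 12 the mixed expressions double / T and T * int64_t stop resolving for those element types. The element count therefore casts delta to double explicitly, and the cache's int64_t lower bound is converted to T through float before it enters start + delta * lower. A host-side reduction of the two expressions with T = half (names match the diff; the failure mode is inferred):

#include <cmath>
#include <cstdint>
#include <cuda_fp16.h>  // half, __float2half

// Element count: divide by delta as a double instead of relying on a half-to-double path.
inline int64_t ArangeElemCnt(double limit, double start, half delta) {
  return static_cast<int64_t>(std::ceil((limit - start) / static_cast<double>(delta)));
}

// Cache lower bound: int64_t -> float -> half, mirroring static_cast<T>(static_cast<float>(lower())).
inline half LowerAsHalf(int64_t cache_lower) {
  return __float2half(static_cast<float>(cache_lower));
}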