From 7946aacada17302d360f26cfd25a50b3674b97f8 Mon Sep 17 00:00:00 2001 From: pramenku <7664080+pramenku@users.noreply.github.com> Date: Mon, 12 Aug 2024 22:07:07 +0530 Subject: [PATCH] [dev-upstream] Update llvm path + Add wrap interface (#2577) * [dev-upstream] Update llvm path after compiler change Added the logic to add new path from ROCm >=6.3.0 To fix, https://ontrack-internal.amd.com/browse/SWDEV-470815 Pushing https://github.com/ROCm/tensorflow-upstream/pull/2575 to all branches as need. * Update rocm_configure.bzl * Update rocblas_wrapper.h Merging back https://github.com/ROCm/tensorflow-upstream/pull/2572/files * Update rocm_blas.cc --- third_party/gpus/rocm_configure.bzl | 4 +++ .../stream_executor/rocm/rocblas_wrapper.h | 28 +++++++++++-------- .../xla/xla/stream_executor/rocm/rocm_blas.cc | 6 ++-- 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index 234b38e41e0dbc..f8d99affb45fe4 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -210,6 +210,10 @@ def _rocm_include_path(repository_ctx, rocm_config, bash_bin): inc_dirs.append(rocm_config.llvm_path + "/lib/clang/17/include") inc_dirs.append(rocm_config.llvm_path + "/lib/clang/18/include") inc_dirs.append(rocm_config.llvm_path + "/lib/clang/19/include") + if int(rocm_config.rocm_version_number) >= 60200: + inc_dirs.append(rocm_toolkit_path + "/lib/llvm/lib/clang/17/include") + inc_dirs.append(rocm_toolkit_path + "/lib/llvm/lib/clang/18/include") + inc_dirs.append(rocm_toolkit_path + "/lib/llvm/lib/clang/19/include") # Support hcc based off clang 10.0.0 (for ROCm 3.3) rocm_toolkit_path = realpath(repository_ctx, rocm_config.rocm_toolkit_path, bash_bin) diff --git a/third_party/xla/xla/stream_executor/rocm/rocblas_wrapper.h b/third_party/xla/xla/stream_executor/rocm/rocblas_wrapper.h index a077824fafabb7..3e399914837042 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocblas_wrapper.h +++ b/third_party/xla/xla/stream_executor/rocm/rocblas_wrapper.h @@ -262,18 +262,22 @@ using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle; __macro(rocblas_gemm_batched_ex_get_solutions) \ __macro(rocblas_gemm_batched_ex_get_solutions_by_type) \ __macro(rocblas_gemm_strided_batched_ex_get_solutions) \ - __macro(rocblas_is_managing_device_memory) \ - __macro(rocblas_is_user_managing_device_memory) \ - __macro(rocblas_set_workspace) \ - __macro(rocblas_strsm_batched) \ - __macro(rocblas_dtrsm_batched) \ - __macro(rocblas_ctrsm_batched) \ - __macro(rocblas_ztrsm_batched) \ - __macro(rocblas_create_handle) \ - __macro(rocblas_destroy_handle) \ - __macro(rocblas_get_stream) \ - __macro(rocblas_set_stream) \ - __macro(rocblas_set_atomics_mode) + __macro(rocblas_strsm_batched) \ + __macro(rocblas_dtrsm_batched) \ + __macro(rocblas_ctrsm_batched) \ + __macro(rocblas_ztrsm_batched) \ + __macro(rocblas_create_handle) \ + __macro(rocblas_destroy_handle) \ + __macro(rocblas_get_stream) \ + __macro(rocblas_set_stream) \ + __macro(rocblas_set_atomics_mode) \ + __macro(rocblas_get_version_string) \ + __macro(rocblas_get_version_string_size) \ + __macro(rocblas_is_managing_device_memory) \ + __macro(rocblas_is_user_managing_device_memory) \ + __macro(rocblas_set_workspace) \ + __macro(rocblas_create_handle) + // clang-format on diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc b/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc index 7b16d6ac586625..a97d35650f2447 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc @@ -1259,16 +1259,16 @@ IMPL_DoBlasGemmBatched(float, wrap::rocblas_sgemm_strided_batched) } absl::Status ROCMBlas::GetVersion(std::string *version) { -#if TF_ROCM_VERSION >= 60300 // Not yet available in ROCM-6.1 +#if TF_ROCM_VERSION > 60100 // Not available in ROCM-6.1 absl::MutexLock lock{&mu_}; size_t len = 0; - if (auto res = rocblas_get_version_string_size(&len); + if (auto res = wrap::rocblas_get_version_string_size(&len); res != rocblas_status_success) { return absl::InternalError( absl::StrCat("GetVersion failed with: ", ToString(res))); } std::vector buf(len + 1); - if (auto res = rocblas_get_version_string(buf.data(), len); + if (auto res = wrap::rocblas_get_version_string(buf.data(), len); res != rocblas_status_success) { return absl::InternalError( absl::StrCat("GetVersion failed with: ", ToString(res)));