Skip to content

Commit

Permalink
[r2.17-rocm-enhanced] Cherry pick fixes for 6.2 (#2621)
Browse files Browse the repository at this point in the history
* [r2.16] Update llvm path after compiler change (#2575)

Added the logic to add new path from ROCm >=6.3.0

* [r2.16] Update llvm path after compiler change

For urgent fix for SWDEV-472803
Added the logic to add new path from ROCm >=6.2.0

* skip sub h1w1 conv2d subtest due to miopen issue

* Update rocblas_wrapper.h

to fix , http://rocm-ci.amd.com/job/wip-tensorflow-ci-builder/5/console error

* Update rocm_blas.cc

---------

Co-authored-by: pramenku <[email protected]>
Co-authored-by: Chao Chen <[email protected]>
  • Loading branch information
3 people authored Aug 12, 2024
1 parent b86ece4 commit e4a0c4e
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 6 deletions.
6 changes: 5 additions & 1 deletion tensorflow/compiler/tests/conv2d_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest
from tensorflow.python.platform import googletest, test

DATA_FORMATS = (
("_data_format_NHWC", "NHWC"),
Expand Down Expand Up @@ -543,6 +543,8 @@ def testConv2D2x2FilterStride2Same(self, data_format):
@parameterized.named_parameters(*DATA_FORMATS)
def testConv2D2x2Depth3ValidBackpropInputStride1x1Dilation2x1(
self, data_format):
if test.is_built_with_rocm():
self.skipTest('only ROCm 6.2 will skip this subtest')
self._VerifyValues(
input_sizes=[1, 3, 6, 1],
filter_sizes=[2, 2, 1, 1],
Expand Down Expand Up @@ -601,6 +603,8 @@ def testConv2D2x2Depth3ValidBackpropInputDilation2x1(self, data_format):
@parameterized.named_parameters(*DATA_FORMATS)
def testConv2DKernelSizeMatchesInputSizeBackpropInputDilation2x2(
self, data_format):
if test.is_built_with_rocm():
self.skipTest('only ROCm 6.2 will skip this subtest')
self._VerifyValues(
input_sizes=[1, 3, 3, 1],
filter_sizes=[2, 2, 1, 2],
Expand Down
6 changes: 5 additions & 1 deletion third_party/gpus/rocm_configure.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,11 @@ def _rocm_include_path(repository_ctx, rocm_config, bash_bin):
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17/include")
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/18/include")
inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/19/include")

if int(rocm_config.rocm_version_number) >= 60200:
inc_dirs.append(rocm_toolkit_path + "/lib/llvm/lib/clang/17/include")
inc_dirs.append(rocm_toolkit_path + "/lib/llvm/lib/clang/18/include")
inc_dirs.append(rocm_toolkit_path + "/lib/llvm/lib/clang/19/include")

# Support hcc based off clang 10.0.0 (for ROCm 3.3)
inc_dirs.append(rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/")
inc_dirs.append(rocm_toolkit_path + "/hcc/lib/clang/10.0.0/include")
Expand Down
4 changes: 3 additions & 1 deletion third_party/xla/xla/stream_executor/rocm/rocblas_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,9 @@ using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle;
__macro(rocblas_destroy_handle) \
__macro(rocblas_get_stream) \
__macro(rocblas_set_stream) \
__macro(rocblas_set_atomics_mode)
__macro(rocblas_set_atomics_mode) \
__macro(rocblas_get_version_string) \
__macro(rocblas_get_version_string_size)

// clang-format on

Expand Down
6 changes: 3 additions & 3 deletions third_party/xla/xla/stream_executor/rocm/rocm_blas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1251,16 +1251,16 @@ IMPL_DoBlasGemmBatched(float, wrap::rocblas_sgemm_strided_batched)
}

absl::Status ROCMBlas::GetVersion(std::string *version) {
#if TF_ROCM_VERSION >= 60300 // Not yet available in ROCM-6.1
#if TF_ROCM_VERSION > 60100 // Not available in ROCM-6.1
absl::MutexLock lock{&mu_};
size_t len = 0;
if (auto res = rocblas_get_version_string_size(&len);
if (auto res = wrap::rocblas_get_version_string_size(&len);
res != rocblas_status_success) {
return absl::InternalError(
absl::StrCat("GetVersion failed with: ", ToString(res)));
}
std::vector<char> buf(len + 1);
if (auto res = rocblas_get_version_string(buf.data(), len);
if (auto res = wrap::rocblas_get_version_string(buf.data(), len);
res != rocblas_status_success) {
return absl::InternalError(
absl::StrCat("GetVersion failed with: ", ToString(res)));
Expand Down

0 comments on commit e4a0c4e

Please sign in to comment.