Skip to content

Commit

Permalink
Merge branch 'main' into dlpack_xpu_jax
Browse files Browse the repository at this point in the history
  • Loading branch information
wozna authored Feb 23, 2024
2 parents f04c3a4 + 0d636c1 commit 373cdc8
Show file tree
Hide file tree
Showing 277 changed files with 4,406 additions and 1,984 deletions.
1 change: 1 addition & 0 deletions build_tools/configure/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ def main():
DiscoverablePathsAndVersions(
clang_path=args.clang_path,
gcc_path=args.gcc_path,
lld_path=args.lld_path,
ld_library_path=args.ld_library_path,
cublas_version=args.cublas_version,
cuda_compute_capabilities=args.cuda_compute_capabilities,
Expand Down
45 changes: 45 additions & 0 deletions third_party/llvm/generated.patch
Original file line number Diff line number Diff line change
@@ -1 +1,46 @@
Auto generated patch. Do not edit or delete it, even if empty.
diff -ruN --strip-trailing-cr a/clang/test/CodeGen/aarch64-sme-inline-streaming-attrs.c b/clang/test/CodeGen/aarch64-sme-inline-streaming-attrs.c
--- a/clang/test/CodeGen/aarch64-sme-inline-streaming-attrs.c
+++ b/clang/test/CodeGen/aarch64-sme-inline-streaming-attrs.c
@@ -1,7 +1,7 @@
-// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -target-feature +sme -verify -DTEST_NONE %s
-// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -target-feature +sme -verify -DTEST_COMPATIBLE %s
-// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -target-feature +sme -verify -DTEST_STREAMING %s
-// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -target-feature +sme -verify -DTEST_LOCALLY %s
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -o /dev/null -target-feature +sme -verify -DTEST_NONE %s
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -o /dev/null -target-feature +sme -verify -DTEST_COMPATIBLE %s
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -o /dev/null -target-feature +sme -verify -DTEST_STREAMING %s
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -o /dev/null -target-feature +sme -verify -DTEST_LOCALLY %s

#define __ai __attribute__((always_inline))
__ai void inlined_fn(void) {}
diff -ruN --strip-trailing-cr a/llvm/test/Transforms/LoopVectorize/X86/pr72969.ll b/llvm/test/Transforms/LoopVectorize/X86/pr72969.ll
--- a/llvm/test/Transforms/LoopVectorize/X86/pr72969.ll
+++ b/llvm/test/Transforms/LoopVectorize/X86/pr72969.ll
@@ -1,5 +1,6 @@
; RUN: not --crash opt -mtriple=x86_64 -mattr=-avx,-avx2,-avx512f,+sse,-sse2,-sse3,-sse4.2 -passes=loop-vectorize -S < %s
; RUN: not --crash opt -mtriple=x86_64 -mattr=-avx,-avx2,-avx512f,+sse,-sse2,-sse3,-sse4.2 -passes=loop-vectorize -force-vector-width=4 -S < %s
+; REQUIRES: asserts

@h = global i64 0

diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel
@@ -613,14 +613,15 @@
libc_support_library(
name = "__support_fixed_point",
hdrs = [
- "src/__support/fixed_point/fx_rep.h",
"src/__support/fixed_point/fx_bits.h",
+ "src/__support/fixed_point/fx_rep.h",
],
deps = [
":__support_cpp_bit",
":__support_cpp_type_traits",
":__support_macros_attributes",
":__support_macros_optimization",
+ ":__support_math_extras",
":llvm_libc_macros_stdfix_macros",
],
)
4 changes: 2 additions & 2 deletions third_party/llvm/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")

def repo(name):
"""Imports LLVM."""
LLVM_COMMIT = "2e29c91b96832504b9008be5e095f7dd640cdea0"
LLVM_SHA256 = "35edce994621f4a8e4413d1d2b833805ab8f203d38ca153dd45844e998bf6b2d"
LLVM_COMMIT = "e630a451b457e4d8d071a2b4f102b342bbea2d02"
LLVM_SHA256 = "184e7622a47609d960295e5e363466e9e60e6d9dbc20d554b3e1118ffd9f1bfb"

tf_http_archive(
name = name,
Expand Down
51 changes: 0 additions & 51 deletions third_party/stablehlo/temporary.patch
Original file line number Diff line number Diff line change
Expand Up @@ -163,27 +163,6 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt

#-------------------------------------------------------------------------------
# Directory setup
diff --ruN a/stablehlo/MODULE.bazel b/stablehlo/MODULE.bazel
--- stablehlo/MODULE.bazel
+++ stablehlo/MODULE.bazel
@@ -1,3 +1,17 @@
+# Copyright 2024 The StableHLO Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
###############################################################################
# Bazel now uses Bzlmod by default to manage external dependencies.
# Please consider migrating your external dependencies from WORKSPACE to MODULE.bazel.
diff --ruN a/stablehlo/MODULE.bazel.lock b/stablehlo/MODULE.bazel.lock
--- stablehlo/MODULE.bazel.lock
+++ stablehlo/MODULE.bazel.lock
Expand Down Expand Up @@ -3541,39 +3520,9 @@ diff --ruN a/stablehlo/stablehlo/tests/verify_reduce.mlir b/stablehlo/stablehlo/
%0 = stablehlo.reduce(%arg0 init: %arg1) applies stablehlo.reshape across dimensions = [1] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32> loc("foo")
func.return %0 : tensor<?xf32>
}
diff --ruN a/stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td b/stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td
--- stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td
+++ stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td
@@ -1,6 +1,6 @@
// Copyright 2020 The IREE Authors
//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// Licensed under the Apache License, Version 2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

diff --ruN a/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp
--- stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp
+++ stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp
@@ -1,6 +1,6 @@
// Copyright 2020 The IREE Authors
//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// Licensed under the Apache License, Version 2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

diff --ruN a/stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cpp b/stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cpp
--- stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cpp
+++ stablehlo/stablehlo/transforms/StablehloAggressiveSimplification.cpp
@@ -1,6 +1,6 @@
// Copyright 2023 The IREE Authors
//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// Licensed under the Apache License, Version 2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

@@ -126,9 +126,8 @@

// The canonical form has the constant operand as the RHS.
Expand Down
4 changes: 2 additions & 2 deletions third_party/stablehlo/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
# LINT.IfChange
STABLEHLO_COMMIT = "f3b20b3a0558187ee27a9428bc1d3c2a3ba459cf"
STABLEHLO_SHA256 = "8023a23de405fcffc70cca457389c4e637ee9ecab7fe2aaf292a69fc118edb5a"
STABLEHLO_COMMIT = "e708c82502982697540886738a307f72f9e9a7ff"
STABLEHLO_SHA256 = "3fecbe7779bee0801af746d974738748f7b461df54a4f610b32bb75647b32125"
# LINT.ThenChange(Google-internal path)

tf_http_archive(
Expand Down
45 changes: 45 additions & 0 deletions third_party/tsl/third_party/llvm/generated.patch
Original file line number Diff line number Diff line change
@@ -1 +1,46 @@
Auto generated patch. Do not edit or delete it, even if empty.
diff -ruN --strip-trailing-cr a/clang/test/CodeGen/aarch64-sme-inline-streaming-attrs.c b/clang/test/CodeGen/aarch64-sme-inline-streaming-attrs.c
--- a/clang/test/CodeGen/aarch64-sme-inline-streaming-attrs.c
+++ b/clang/test/CodeGen/aarch64-sme-inline-streaming-attrs.c
@@ -1,7 +1,7 @@
-// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -target-feature +sme -verify -DTEST_NONE %s
-// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -target-feature +sme -verify -DTEST_COMPATIBLE %s
-// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -target-feature +sme -verify -DTEST_STREAMING %s
-// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -target-feature +sme -verify -DTEST_LOCALLY %s
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -o /dev/null -target-feature +sme -verify -DTEST_NONE %s
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -o /dev/null -target-feature +sme -verify -DTEST_COMPATIBLE %s
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -o /dev/null -target-feature +sme -verify -DTEST_STREAMING %s
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -S -o /dev/null -target-feature +sme -verify -DTEST_LOCALLY %s

#define __ai __attribute__((always_inline))
__ai void inlined_fn(void) {}
diff -ruN --strip-trailing-cr a/llvm/test/Transforms/LoopVectorize/X86/pr72969.ll b/llvm/test/Transforms/LoopVectorize/X86/pr72969.ll
--- a/llvm/test/Transforms/LoopVectorize/X86/pr72969.ll
+++ b/llvm/test/Transforms/LoopVectorize/X86/pr72969.ll
@@ -1,5 +1,6 @@
; RUN: not --crash opt -mtriple=x86_64 -mattr=-avx,-avx2,-avx512f,+sse,-sse2,-sse3,-sse4.2 -passes=loop-vectorize -S < %s
; RUN: not --crash opt -mtriple=x86_64 -mattr=-avx,-avx2,-avx512f,+sse,-sse2,-sse3,-sse4.2 -passes=loop-vectorize -force-vector-width=4 -S < %s
+; REQUIRES: asserts

@h = global i64 0

diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel
@@ -613,14 +613,15 @@
libc_support_library(
name = "__support_fixed_point",
hdrs = [
- "src/__support/fixed_point/fx_rep.h",
"src/__support/fixed_point/fx_bits.h",
+ "src/__support/fixed_point/fx_rep.h",
],
deps = [
":__support_cpp_bit",
":__support_cpp_type_traits",
":__support_macros_attributes",
":__support_macros_optimization",
+ ":__support_math_extras",
":llvm_libc_macros_stdfix_macros",
],
)
4 changes: 2 additions & 2 deletions third_party/tsl/third_party/llvm/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")

def repo(name):
"""Imports LLVM."""
LLVM_COMMIT = "2e29c91b96832504b9008be5e095f7dd640cdea0"
LLVM_SHA256 = "35edce994621f4a8e4413d1d2b833805ab8f203d38ca153dd45844e998bf6b2d"
LLVM_COMMIT = "e630a451b457e4d8d071a2b4f102b342bbea2d02"
LLVM_SHA256 = "184e7622a47609d960295e5e363466e9e60e6d9dbc20d554b3e1118ffd9f1bfb"

tf_http_archive(
name = name,
Expand Down
4 changes: 2 additions & 2 deletions third_party/tsl/third_party/tf_runtime/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ def repo():
"""Imports TFRT."""

# Attention: tools parse and update these lines.
TFRT_COMMIT = "614d0da37224866a99ca4feb76214f8d10ecc33d"
TFRT_SHA256 = "8d70e5a746aed992ffb24e4ba2d0f1584251b30fd8362d1e5b8f6880d5d3f186"
TFRT_COMMIT = "0aeefb1660d7e37964b2bb71b1f518096bda9a25"
TFRT_SHA256 = "a3df827d7896774cb1d80bf4e1c79ab05c268f29bd4d3db1fb5a4b9c2079d8e3"

tf_http_archive(
name = "tf_runtime",
Expand Down
1 change: 1 addition & 0 deletions third_party/tsl/tsl/profiler/lib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ cc_library(
cc_library(
name = "nvtx_utils",
hdrs = ["nvtx_utils.h"],
local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
visibility = ["//visibility:public"],
deps = [
"//tsl/platform:logging",
Expand Down
11 changes: 10 additions & 1 deletion third_party/tsl/tsl/profiler/lib/nvtx_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.

#if GOOGLE_CUDA
#include "nvtx3/nvToolsExt.h"
#include "nvtx3/nvToolsExtPayload.h"
#else
// Some typedef to help build without NVTX.
typedef void* nvtxDomainHandle_t;
Expand Down Expand Up @@ -49,6 +50,7 @@ inline bool RangesEnabled() {
#endif
}

// Older/simpler version; NVTX implementation copies a C-style string each time
inline void RangePush(nvtxDomainHandle_t domain, const char* ascii) {
#if GOOGLE_CUDA
nvtxEventAttributes_t attrs{};
Expand All @@ -63,13 +65,20 @@ inline void RangePush(nvtxDomainHandle_t domain, const std::string& str) {
RangePush(domain, str.c_str());
}

inline void RangePush(nvtxDomainHandle_t domain, nvtxStringHandle_t handle) {
// More powerful version: pass a registered string instead of a C-style string,
// and attach a generic payload. The Annotation type must implement a method
// called NvtxSchemaId() that allows the NVTX backend to interpret the payload.
template <typename Annotation>
void RangePush(nvtxDomainHandle_t domain, nvtxStringHandle_t handle,
const Annotation& annotation) {
#if GOOGLE_CUDA
nvtxEventAttributes_t attrs{};
attrs.version = NVTX_VERSION;
attrs.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE;
attrs.messageType = NVTX_MESSAGE_TYPE_REGISTERED;
attrs.message.registered = handle;
NVTX_PAYLOAD_EVTATTR_SET(attrs, annotation.NvtxSchemaId(), &annotation,
sizeof(Annotation));
::nvtxDomainRangePushEx(domain, &attrs);
#endif
}
Expand Down
1 change: 1 addition & 0 deletions xla/backends/interpreter/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ cc_library(
"//xla/stream_executor:stream_executor_internal",
"//xla/stream_executor/host:host_stream",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/log",
"@com_google_absl//absl/types:span",
],
)
11 changes: 6 additions & 5 deletions xla/backends/interpreter/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ namespace {

// Handles custom_call ops during evaluation by routing them through the global
// CPU registry used by other CPU-based backends.
StatusOr<Literal> HandleEvaluatorCustomCall(
absl::StatusOr<Literal> HandleEvaluatorCustomCall(
const HloInstruction* custom_call, absl::Span<const Literal*> operands) {
// Find the target C function in the global registry.
auto* registry = CustomCallTargetRegistry::Global();
Expand Down Expand Up @@ -110,15 +110,15 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
return pipeline.Run(hlo_module).status();
}

StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses(
absl::StatusOr<std::unique_ptr<HloModule>> InterpreterCompiler::RunHloPasses(
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* /*stream_exec*/,
const CompileOptions& /*options*/) {
VLOG(1) << "Run hlo passes on graph " << hlo_module->name();
TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get()));
return std::move(hlo_module);
}

StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
absl::StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
const CompileOptions& /*options*/) {
TF_RET_CHECK(stream_exec != nullptr);
Expand Down Expand Up @@ -147,7 +147,8 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend(
return std::move(executable);
}

StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile(
absl::StatusOr<std::vector<std::unique_ptr<Executable>>>
InterpreterCompiler::Compile(
std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
const CompileOptions& options) {
Expand All @@ -171,7 +172,7 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> InterpreterCompiler::Compile(
return std::move(ret);
}

StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
absl::StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
InterpreterCompiler::CompileAheadOfTime(
std::unique_ptr<HloModuleGroup> module_group,
const AotCompilationOptions& aot_options) {
Expand Down
8 changes: 4 additions & 4 deletions xla/backends/interpreter/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,18 @@ class InterpreterCompiler : public Compiler {
InterpreterCompiler() {}
~InterpreterCompiler() override {}

StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
absl::StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
const CompileOptions& options) override;
StatusOr<std::unique_ptr<Executable>> RunBackend(
absl::StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> hlo_module, se::StreamExecutor* stream_exec,
const CompileOptions& options) override;
StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
absl::StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
std::unique_ptr<HloModuleGroup> module_group,
std::vector<std::vector<se::StreamExecutor*>> stream_exec,
const CompileOptions& options) override;

StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
absl::StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
const AotCompilationOptions& aot_options) override;

Expand Down
2 changes: 1 addition & 1 deletion xla/backends/interpreter/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ InterpreterExecutable::InterpreterExecutable(
}
}

StatusOr<Literal> InterpreterExecutable::Evaluate(
absl::StatusOr<Literal> InterpreterExecutable::Evaluate(
const ServiceExecutableRunOptions* run_options,
const HloComputation& computation, absl::Span<const Literal> arg_literals) {
// Execute the graph using the HloEvaluator.
Expand Down
7 changes: 4 additions & 3 deletions xla/backends/interpreter/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ class InterpreterExecutable : public InterpreterExecutableBase {
static int64_t ShapeSizeBytes(const Shape& shape);

protected:
StatusOr<Literal> Evaluate(const ServiceExecutableRunOptions* run_options,
const HloComputation& computation,
absl::Span<const Literal> arg_literals) override
absl::StatusOr<Literal> Evaluate(
const ServiceExecutableRunOptions* run_options,
const HloComputation& computation,
absl::Span<const Literal> arg_literals) override
ABSL_LOCKS_EXCLUDED(evaluator_lock_);

// The interpreter interprets executables with an HloEvaluator.
Expand Down
4 changes: 2 additions & 2 deletions xla/backends/interpreter/executable_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ InterpreterExecutableBase::InterpreterExecutableBase(
: Executable(std::move(hlo_module), /*hlo_profile_printer_data=*/nullptr,
/*hlo_profile_index_map=*/nullptr) {}

StatusOr<ExecutionOutput> InterpreterExecutableBase::ExecuteAsyncOnStream(
absl::StatusOr<ExecutionOutput> InterpreterExecutableBase::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
std::vector<ExecutionInput> arguments,
HloExecutionProfile* hlo_execution_profile) {
Expand Down Expand Up @@ -150,7 +150,7 @@ StatusOr<ExecutionOutput> InterpreterExecutableBase::ExecuteAsyncOnStream(
return std::move(result);
}

StatusOr<ExecutionOutput>
absl::StatusOr<ExecutionOutput>
InterpreterExecutableBase::AllocateOutputMemoryWithInputReuse(
const Shape& shape, const HloInputOutputAliasConfig& alias_config,
se::DeviceMemoryAllocator* allocator,
Expand Down
Loading

0 comments on commit 373cdc8

Please sign in to comment.