diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index a6e6d2e1127c8..bf76601c85549 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -1324,7 +1324,6 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/mlir_hlo", - "//xla/mlir_hlo:lhlo", "//xla/service:buffer_assignment", "//xla/service:hlo_parser", "//xla/service/llvm_ir:buffer_assignment_util", @@ -1338,6 +1337,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", @@ -1366,7 +1366,6 @@ xla_cc_test( "//xla:util", "//xla/hlo/ir:hlo", "//xla/mlir_hlo", - "//xla/mlir_hlo:lhlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/translate/hlo_to_mhlo:hlo_utils", diff --git a/xla/service/gpu/fusions/fusions.cc b/xla/service/gpu/fusions/fusions.cc index 90445221113f6..187d750863626 100644 --- a/xla/service/gpu/fusions/fusions.cc +++ b/xla/service/gpu/fusions/fusions.cc @@ -76,45 +76,6 @@ bool IsDynamicUpdateSliceFusion(const HloFusionAnalysis& analysis) { } // namespace -std::optional>> -LmhloFusionInfo::GetCopyFusion() const { - auto params = GetHloOperands(fusion_op_); - auto outputs = GetHloOutputs(fusion_op_); - std::vector srcs; - srcs.reserve(outputs.size()); - - for (auto* root : analysis().fusion_roots()) { - if (root->opcode() != HloOpcode::kCopy || - root->operand(0)->opcode() != HloOpcode::kParameter || - !LayoutUtil::Equal(root->operand(0)->shape().layout(), - root->shape().layout())) { - return std::nullopt; - } - - mlir::Value src = params[root->operand(0)->parameter_number()]; - if (!GetAllocationSlice(src, allocations_).ok()) return std::nullopt; - - srcs.emplace_back(src); - } - - auto dsts = std::vector(outputs.begin(), outputs.end()); - DCHECK(srcs.size() == dsts.size()); - std::vector src_buffers; - std::vector dst_buffers; - for (int i = 0; i < srcs.size(); ++i) { - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice src_buffer, - GetAllocationSlice(srcs[i], allocations_)); - src_buffers.push_back(src_buffer); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dst_buffer, - GetAllocationSlice(dsts[i], allocations_)); - dst_buffers.push_back(dst_buffer); - } - - return std::make_unique(std::move(src_buffers), - std::move(dst_buffers), std::move(srcs), - std::move(dsts)); -} - std::optional>> HloFusionInfo::GetCopyFusion() const { std::vector src_buffers; @@ -154,10 +115,6 @@ HloFusionInfo::GetCopyFusion() const { /*dsts=*/std::vector()); } -bool LmhloFusionInfo::CanEmitDynamicUpdateSliceInPlace() const { - return CanEmitFusedDynamicUpdateSliceInPlaceForGpu(fusion_op_, allocations_); -} - bool HloFusionInfo::CanEmitDynamicUpdateSliceInPlace() const { auto ret = CanEmitFusedDynamicUpdateSliceInPlaceForGpu( instr_, buffer_assignment_, analysis().fusion_roots()); diff --git a/xla/service/gpu/fusions/fusions.h b/xla/service/gpu/fusions/fusions.h index 6379329fc8943..f91ac049acb4a 100644 --- a/xla/service/gpu/fusions/fusions.h +++ b/xla/service/gpu/fusions/fusions.h @@ -18,12 +18,10 @@ limitations under the License. #include #include -#include "absl/types/span.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/statusor.h" namespace xla { namespace gpu { @@ -51,24 +49,6 @@ class FusionInfo { const HloFusionAnalysis& analysis_; }; -class LmhloFusionInfo : public FusionInfo { - public: - LmhloFusionInfo(const HloFusionAnalysis& analysis, - mlir::lmhlo::FusionOp fusion_op, - absl::Span allocations) - : FusionInfo(analysis), - fusion_op_(fusion_op), - allocations_(allocations) {} - - bool CanEmitDynamicUpdateSliceInPlace() const override; - std::optional>> - GetCopyFusion() const override; - - private: - mlir::lmhlo::FusionOp fusion_op_; - absl::Span allocations_; -}; - class HloFusionInfo : public FusionInfo { public: HloFusionInfo(const HloFusionAnalysis& analysis, diff --git a/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h b/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h index b458fa7a6037c..7caa0ddc1468c 100644 --- a/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h +++ b/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h @@ -20,8 +20,10 @@ limitations under the License. #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "mlir/IR/AffineMap.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/fusions/fusion_emitter.h" diff --git a/xla/service/gpu/ir_emission_utils.cc b/xla/service/gpu/ir_emission_utils.cc index 55e49f3905399..c6349da06abc3 100644 --- a/xla/service/gpu/ir_emission_utils.cc +++ b/xla/service/gpu/ir_emission_utils.cc @@ -15,11 +15,8 @@ limitations under the License. #include "xla/service/gpu/ir_emission_utils.h" -#include #include -#include #include -#include #include #include #include @@ -35,8 +32,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/IR/Attributes.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/FPEnv.h" #include "llvm/IR/IRBuilder.h" @@ -47,16 +44,10 @@ limitations under the License. #include "llvm/IR/Verifier.h" #include "llvm/Support/raw_ostream.h" #include "llvm/TargetParser/Triple.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project @@ -69,8 +60,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/primitive_util.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/hlo_traversal.h" @@ -82,15 +71,12 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status.h" -#include "xla/status_macros.h" #include "xla/statusor.h" #include "xla/translate/mhlo_to_hlo/location_exporter.h" #include "xla/translate/mhlo_to_hlo/type_to_shape.h" -#include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" -#include "tsl/platform/ml_dtypes.h" #include "tsl/platform/statusor.h" namespace xla { @@ -314,65 +300,6 @@ llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) { return b->CreateAnd(is_thread0, is_block0); } -// Given an LMHLO op, returns the operand index of the first output operand. -// -// Notice that an operand alised to an output isn't an output, even though in -// that case WritesMlirBuffer() returns true on that operand. -// -// An operand is !WritesMlirBuffer() || equals (aliases) to a later operand. An -// output is the opposite, being both WritesMlirBuffer() and does not equal to -// any later operand. -int PartitionLmhloOperandsAndOutputs(mlir::Operation* op) { - CHECK(op->getDialect() == op->getContext()->getLoadedDialect("lmhlo")); - - int i; - for (i = op->getOperands().size() - 1; i >= 0; i--) { - const bool aliased = - std::find(op->getOperands().begin() + i + 1, op->getOperands().end(), - op->getOperand(i)) != op->getOperands().end(); - if (!WritesMlirBuffer(op, op->getOperand(i)) || aliased) { - break; - } - } - return i + 1; -} - -llvm::SmallVector GetHloOperands(mlir::Operation* op) { - if (auto fusion = mlir::dyn_cast(op)) { - return fusion.getInputBuffers(); - } - if (op->getDialect() == op->getContext()->getLoadedDialect("lmhlo")) { - int output_start = PartitionLmhloOperandsAndOutputs(op); - llvm::SmallVector operands; - for (int i = 0; i < output_start; i++) { - operands.push_back(op->getOperand(i)); - } - return operands; - } - if (op->getDialect() == op->getContext()->getLoadedDialect("mhlo")) { - return op->getOperands(); - } - LOG(FATAL) << "Unexpected op: " << llvm_ir::DumpToString(op); -} - -llvm::SmallVector GetHloOutputs(mlir::Operation* op) { - if (auto fusion = mlir::dyn_cast(op)) { - return fusion.getOutputBuffers(); - } - if (op->getDialect() == op->getContext()->getLoadedDialect("lmhlo")) { - int output_start = PartitionLmhloOperandsAndOutputs(op); - llvm::SmallVector outputs; - for (int i = output_start; i < op->getNumOperands(); i++) { - outputs.push_back(op->getOperand(i)); - } - return outputs; - } - if (op->getDialect() == op->getContext()->getLoadedDialect("mhlo")) { - return op->getResults(); - } - LOG(FATAL) << "Unexpected op: " << llvm_ir::DumpToString(op); -} - bool WritesMlirBuffer(mlir::Operation* op, mlir::Value operand) { llvm::SmallVector effects; mlir::cast(op).getEffectsOnValue(operand, @@ -397,81 +324,6 @@ static int64_t GetMemRefSizeInBytes(mlir::MemRefType type) { } } -static int64_t GetAllocationIndex(mlir::BlockArgument func_arg, - std::string* constant_name) { - auto func_op = - mlir::cast(func_arg.getParentRegion()->getParentOp()); - if (constant_name) { - if (auto constant_name_attr = func_op.getArgAttrOfType( - func_arg.getArgNumber(), "lmhlo.constant_name")) { - *constant_name = constant_name_attr.getValue().str(); - } - } - return func_arg.getArgNumber(); -} - -absl::StatusOr GetAllocationSlice( - mlir::Value v, absl::Span allocations, - std::string* constant_name) { - if (constant_name) { - constant_name->clear(); - } - - int64_t size = GetMemRefSizeInBytes(v.getType().cast()); - - // We match the following patterns here: - // base := ViewOp(arg) | get_global_memref (global_memref) | arg - // root := base | MemRefReinterpretCastOp(base) | CollapseShapeOp(base) - - if (auto cast = mlir::dyn_cast_or_null( - v.getDefiningOp())) { - v = cast.getViewSource(); - } - if (auto collapse_shape = - mlir::dyn_cast_or_null( - v.getDefiningOp())) { - v = collapse_shape.getSrc(); - } - - if (auto view = - mlir::dyn_cast_or_null(v.getDefiningOp())) { - TF_RET_CHECK(view.getSource().isa()); - - const BufferAllocation* allocation = allocations[GetAllocationIndex( - view.getSource().cast(), constant_name)]; - return BufferAllocation::Slice( - allocation, - mlir::cast(view.getByteShift().getDefiningOp()) - .getValue() - .cast() - .getValue() - .getSExtValue(), - size); - } - if (auto get_global = mlir::dyn_cast_or_null( - v.getDefiningOp())) { - auto module = get_global->getParentOfType(); - if (constant_name) { - *constant_name = get_global.getName().str(); - } - auto global = mlir::cast( - module.lookupSymbol(get_global.getName())); - int64_t index = - global->getAttrOfType("lmhlo.alloc").getInt(); - - return BufferAllocation::Slice(allocations[index], 0, - allocations[index]->size()); - } - if (auto arg = v.dyn_cast()) { - return BufferAllocation::Slice( - allocations[GetAllocationIndex(arg, constant_name)], 0, size); - } - - return Unimplemented( - "Operand has to be in the form of ViewOp(arg) or " - "StaticMemRefCastOp(ViewOp(arg)) or arg"); -} - absl::StatusOr GetAllocationSlice( const BufferAssignment& buffer_assignment, const HloInstruction* instr, const ShapeIndex& index) { @@ -480,8 +332,6 @@ absl::StatusOr GetAllocationSlice( std::vector GetOutputDefiningDynamicUpdateSlices( const std::vector& roots) { - // Same as GetOutputDefiningDynamicUpdateSliceOps but on a HLO fusion - // computation instead of a LMHLO FusionOp. std::vector dus_ops; for (const HloInstruction* root : roots) { while (root->opcode() == HloOpcode::kBitcast) { @@ -492,33 +342,6 @@ std::vector GetOutputDefiningDynamicUpdateSlices( dus_ops.push_back(root); } } - - return dus_ops; -} - -std::vector -GetOutputDefiningDynamicUpdateSliceOps(mlir::lmhlo::FusionOp fusion) { - std::vector dus_ops; - - auto fusion_results = fusion.getFusionResults(); - for (const auto& fusion_result : fusion_results) { - // A dynamic slice update is said to be "defining" of a result if that - // result is the output of a dynamic slice update, or if that result is - // the output of a bitcast of a dynamic slice update---since a bitcast may - // be handled here as a no-op. - if (auto dus = mlir::dyn_cast( - fusion_result.getDefiningOp())) { - dus_ops.push_back(dus); - } - - if (auto bitcast = mlir::dyn_cast( - fusion_result.getDefiningOp())) { - if (auto dus = mlir::dyn_cast( - bitcast.getOperand().getDefiningOp())) { - dus_ops.push_back(dus); - } - } - } return dus_ops; } @@ -566,7 +389,6 @@ absl::StatusOr CanEmitFusedDynamicUpdateSliceInPlaceForGpu( Shape update_shape = dus_instrs[0]->operand(1)->shape(); - // TODO(anlunx): Reuse this code in both HLO and LMHLO path. for (int i = 0; i < dus_instrs.size(); ++i) { auto* dus = Cast(dus_instrs[i]); @@ -682,135 +504,6 @@ absl::StatusOr CanEmitFusedDynamicUpdateSliceInPlaceForGpu( return true; } -bool CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - mlir::lmhlo::FusionOp fusion, - absl::Span allocations) { - std::vector dus_ops = - GetOutputDefiningDynamicUpdateSliceOps(fusion); - - // This check could probably be relaxed: if code generation is made to use a - // separate parallel loop for each dynamic slice update, then it shouldn't be - // necessary for every output to be a dynamic slice update, nor to have the - // same shape. - if (dus_ops.size() != fusion.getFusionResults().size()) { - return false; - } - - auto output_buffers = fusion.getOutputBuffers(); - CHECK_GE(output_buffers.size(), 1); - CHECK_EQ(dus_ops.size(), output_buffers.size()); - - auto update_shape = - dus_ops[0].getUpdate().getType().cast().getShape(); - - // We can safely assume here that the slices being updated do not overlap, as - // constructing a fusion with them would not be safe otherwise. - for (auto [dus, output_buffer] : llvm::zip(dus_ops, output_buffers)) { - // Dynamic slice updates should have a single path to the root---this to - // avoid allowing a dynamic slice update to depend on another, as this would - // not be guaranteed to work with the current codegen. - if (!dus->hasOneUse()) { - return false; - } - - // Since the direct consumer of an output dynamic slice update may be a - // bitcast, we also check that this bitcast is used a single time. - // This property is also important because reads and writes on the parameter - // to be updated are done using the shape and layout of the dynamic slice - // update. This is a valid approach only if a subsequent bitcast is not read - // by any other op within the fusion---as this may result in codegen - // accessing elements using the wrong physical layout. - auto dus_user = *dus->user_begin(); - if (auto bitcast = mlir::dyn_cast(dus_user)) { - if (!bitcast->hasOneUse()) { - return false; - } - dus_user = *bitcast->user_begin(); - } - if (!mlir::isa(dus_user)) { - return false; - } - auto operand = dus.getOperand(); - // A bitcast separating a fusion input from a dynamic slice update can be - // treated as a no-op. - if (auto bitcast = - mlir::dyn_cast(operand.getDefiningOp())) { - operand = bitcast.getOperand(); - } - - auto parameter = mlir::dyn_cast( - operand.getDefiningOp()); - - if (!parameter) { - return false; - } - - // We require that the parameter being updated is only read at the same - // index positions by all users, since we otherwise risk a race condition - // when updating the parameter inplace. - std::queue q; - absl::flat_hash_set visited; - q.push(parameter); - visited.insert(parameter); - // We have already checked above that the DUS only has one user: a - // (possibly bitcasted) MaterializeInDestinationOp. So we don't need to - // visit it during the breadth-first search. - visited.insert(dus); - while (!q.empty()) { - auto op = q.front(); - q.pop(); - for (auto user : op->getUsers()) { - if (mlir::isa(user) && - dus->getOperand(0) == user->getOperand(0) && - update_shape == user->getResult(0) - .getType() - .cast() - .getShape()) { - // We can still emit in-place in this case if the same slice is - // accessed by the DUS and the DS. If they don't access the same - // slice, the two slices might partially overlap and read/write the - // same index at different times, and then we cannot guarantee that we - // read before it is overwritten. However if both access only a single - // element, there also can be no race condition. - if (mlir::ShapedType::getNumElements(update_shape) != 1 && - dus.getStartIndices() != - mlir::dyn_cast(user) - .getStartIndices()) { - return false; - } - } else if (user != dus && - !user->hasTrait() && - !mlir::isa( - user)) { - return false; - } - if (visited.insert(user).second) { - q.push(user); - } - } - } - - // This check could probably be relaxed: if code generation is made to use a - // separate parallel loop for each dynamic slice update, then it shouldn't - // be necessary for the shape to be the same for all the dynamic slice - // updates. Note that this equality check purposefully ignores the element - // type. - if (dus.getUpdate().getType().cast().getShape() != - update_shape) { - return false; - } - - auto maybe_lhs = GetAllocationSlice(parameter.getMemref(), allocations); - auto maybe_rhs = GetAllocationSlice(output_buffer, allocations); - - if (!(maybe_lhs.ok() && maybe_rhs.ok() && *maybe_lhs == *maybe_rhs)) { - return false; - } - } - - return true; -} - Shape GetShape(mlir::Value value) { Shape shape; if (value.getType().isa()) { @@ -1113,60 +806,6 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, return b->getInt32Ty(); } -llvm::Type* GetIndexTypeForKernel(mlir::Operation* op, int64_t launch_size, - llvm::IRBuilder<>* b) { - auto shape_in_range = [&](const Shape& s) { - bool in_range = true; - ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape, - const ShapeIndex& /*index*/) { - if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) { - in_range = false; - } - }); - - return in_range; - }; - - llvm::Type* i64_ty = b->getInt64Ty(); - // Check launch dimension - if (!IsInt32(launch_size)) { - return i64_ty; - } - - // Check the size of result tensors - for (auto result : GetHloOutputs(op)) { - if (!shape_in_range(GetShape(result))) { - return i64_ty; - } - } - - auto hlo_shape_in_range = [&](mlir::Value operand) -> bool { - return shape_in_range(GetShape(operand)); - }; - - // Check the size of input tensors - if (!absl::c_all_of(op->getOperands(), hlo_shape_in_range)) { - return i64_ty; - } - - // Check the size of the internal result tensors - if (auto fusion = mlir::dyn_cast(op)) { - auto result = fusion.getRegion().walk([&](mlir::Operation* op) { - for (mlir::Value result : op->getResults()) { - if (!hlo_shape_in_range(result)) { - return mlir::WalkResult::interrupt(); - } - } - return mlir::WalkResult::advance(); - }); - if (result.wasInterrupted()) { - return i64_ty; - } - } - - return b->getInt32Ty(); -} - std::string GetIrNameFromLoc(mlir::Location loc) { return llvm_ir::SanitizeConstantName( mlir::mhlo::GetDebugNameFromLocation(loc)); diff --git a/xla/service/gpu/ir_emission_utils.h b/xla/service/gpu/ir_emission_utils.h index 433546f7a61ec..8b3b3e7381713 100644 --- a/xla/service/gpu/ir_emission_utils.h +++ b/xla/service/gpu/ir_emission_utils.h @@ -23,18 +23,21 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/literal.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/hlo_traversal.h" -#include "xla/statusor.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/util.h" namespace xla { namespace gpu { @@ -116,7 +119,6 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, // block 0 of the kernel. llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b); -int PartitionLmhloOperandsAndOutputs(mlir::Operation* op); llvm::SmallVector GetHloOperands(mlir::Operation* op); llvm::SmallVector GetHloOutputs(mlir::Operation* op); @@ -130,10 +132,6 @@ absl::StatusOr GetAllocationSlice( const BufferAssignment& buffer_assignment, const HloInstruction* instr, const ShapeIndex& index); -bool CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - mlir::lmhlo::FusionOp fusion, - absl::Span allocations); - absl::StatusOr CanEmitFusedDynamicUpdateSliceInPlaceForGpu( const HloFusionInstruction* fusion, const BufferAssignment* buffer_assignment, @@ -147,14 +145,6 @@ absl::StatusOr CanEmitFusedDynamicUpdateSliceInPlaceForGpu( std::vector GetOutputDefiningDynamicUpdateSlices( const std::vector& roots); -// Returns the DynamicUpdateSliceOp(s) defining the results of a fusion node. -// A dynamic slice update is said to be "defining" of a result if that result is -// the output of a dynamic slice update, or if that result is the output of a -// bitcast of a dynamic slice update---since such bitcast may be handled as a -// no-op. -std::vector -GetOutputDefiningDynamicUpdateSliceOps(mlir::lmhlo::FusionOp fusion); - Shape GetShape(mlir::Value value); // `is_boundary` returns `true` for edges that are on the boundary of the diff --git a/xla/service/gpu/ir_emission_utils_test.cc b/xla/service/gpu/ir_emission_utils_test.cc index d53c2142ffb8f..346b91b502d04 100644 --- a/xla/service/gpu/ir_emission_utils_test.cc +++ b/xla/service/gpu/ir_emission_utils_test.cc @@ -16,29 +16,14 @@ limitations under the License. #include "xla/service/gpu/ir_emission_utils.h" #include -#include #include #include -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/DialectRegistry.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/Parser/Parser.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/mlir_hlo/lhlo/IR/lhlo_ops.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/tests/hlo_test_base.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/types.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -47,63 +32,6 @@ namespace gpu { class IrEmissionUtilsTest : public HloTestBase {}; -TEST_F(IrEmissionUtilsTest, TestOperandPartitionNoAlias) { - mlir::DialectRegistry registry; - registry.insert(); - registry.insert(); - mlir::MLIRContext context(registry); - - auto module = mlir::parseSourceString(R"( - func.func @foo(%arg0 : memref, %arg1 : memref, %arg2 : memref) { - "lmhlo.add" (%arg0, %arg1, %arg2) : (memref, memref, memref) -> () - "lmhlo.terminator" () : () -> () - } - )", - &context); - mlir::func::FuncOp func = - mlir::cast(module->lookupSymbol("foo")); - mlir::Operation* op = &func.getBody().front().front(); - EXPECT_EQ(2, PartitionLmhloOperandsAndOutputs(op)); -} - -TEST_F(IrEmissionUtilsTest, TestOperandPartitionWithAlias0) { - mlir::DialectRegistry registry; - registry.insert(); - registry.insert(); - mlir::MLIRContext context(registry); - - auto module = mlir::parseSourceString(R"( - func.func @foo(%arg0 : memref, %arg1 : memref, %arg2 : memref) { - "lmhlo.add" (%arg0, %arg1, %arg0) : (memref, memref, memref) -> () - "lmhlo.terminator" () : () -> () - } - )", - &context); - mlir::func::FuncOp func = - mlir::cast(module->lookupSymbol("foo")); - mlir::Operation* op = &func.getBody().front().front(); - EXPECT_EQ(2, PartitionLmhloOperandsAndOutputs(op)); -} - -TEST_F(IrEmissionUtilsTest, TestOperandPartitionWithAlias1) { - mlir::DialectRegistry registry; - registry.insert(); - registry.insert(); - mlir::MLIRContext context(registry); - - auto module = mlir::parseSourceString(R"( - func.func @foo(%arg0 : memref, %arg1 : memref, %arg2 : memref) { - "lmhlo.add" (%arg0, %arg1, %arg1) : (memref, memref, memref) -> () - "lmhlo.terminator" () : () -> () - } - )", - &context); - mlir::func::FuncOp func = - mlir::cast(module->lookupSymbol("foo")); - mlir::Operation* op = &func.getBody().front().front(); - EXPECT_EQ(2, PartitionLmhloOperandsAndOutputs(op)); -} - TEST_F(IrEmissionUtilsTest, FindTiledLogicalTranspose) { const char* hlo = R"( HloModule module diff --git a/xla/service/gpu/nccl_p2p_thunk_common.h b/xla/service/gpu/nccl_p2p_thunk_common.h index 97da5f03f678e..7b3c8f31f56b0 100644 --- a/xla/service/gpu/nccl_p2p_thunk_common.h +++ b/xla/service/gpu/nccl_p2p_thunk_common.h @@ -61,75 +61,6 @@ struct NcclP2PConfig { absl::StatusOr>> GetSourceTargetPairs( mlir::DictionaryAttr frontend_attributes); -// Returns the GroupMode for Send and Recv. -template -std::enable_if_t || - std::is_same_v, - CollectiveOpGroupMode> -GetGroupModeForSendRecv(OpT op) { - return GetCollectiveOpGroupMode(op.getChannelHandle().getHandle() > 0, - std::nullopt) - .value(); -} - -// Constructs the NcclP2PConfig for Send and Recv. -template -std::enable_if_t || - std::is_same_v, - NcclP2PConfig> -GetNcclP2PConfigForSendRecv(OpT op, int64_t replica_count, - int64_t partition_count) { - NcclP2PConfig p2p_config; - auto& config = p2p_config.config; - - config.operand_count = 1; - const Shape shape = GetShape(op.getOperand(0)); - config.operand_element_type.push_back(shape.element_type()); - - const int64_t channel_id = op.getChannelHandle().getHandle(); - config.group_mode = GetGroupModeForSendRecv(op); - // Emulate SetCollectiveOpKindAndID. - // Send and Recv ops have a non-optional channel id while collective-permute - // has an optional channel id. We use 0 to encode the send/recv transformed - // from collective-permute without a channel id. - if (channel_id >= 1) { - config.collective_op_kind = RendezvousKey::kCrossModule; - config.op_id = channel_id; - } else { - config.collective_op_kind = RendezvousKey::kCrossReplica; - mlir::ModuleOp parent = op->template getParentOfType(); - mlir::IntegerAttr unique_id = - parent->getAttrOfType("hlo.unique_id"); - config.op_id = static_cast(unique_id.getInt()); - } - - // All execution instances of a send/recv together form a replica group. - const int64_t num_participants = - config.group_mode == CollectiveOpGroupMode::kCrossReplica - ? replica_count - : partition_count; - config.replica_groups.emplace_back(); - ReplicaGroup& replica_group = config.replica_groups.front(); - for (int i = 0; i < num_participants; ++i) { - replica_group.add_replica_ids(i); - } - - auto source_target_pairs = GetSourceTargetPairs(op.getFrontendAttributes()); - TF_CHECK_OK(source_target_pairs.status()); - for (const std::pair& source_target : - *source_target_pairs) { - int64_t source = source_target.first; - int64_t target = source_target.second; - - p2p_config.id_to_source_target.insert({target, {}}).first->second.source = - source; - p2p_config.id_to_source_target.insert({source, {}}).first->second.target = - target; - } - - return p2p_config; -} - // Constructs the NcclP2PConfig for an HLO Send or Recv instruction. NcclP2PConfig GetNcclP2PConfigForSendRecv(const HloSendRecvInstruction* instr, const Shape& shape,