|
| 1 | +// Copyright 2025 The IREE Authors |
| 2 | +// |
| 3 | +// Licensed under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===--------------- ConvertUnsupportedFloatArithPass.cpp ----------------===// |
| 8 | +// |
// Emulate arith and vector floating point operations that use float types
// which are unsupported on a target by inserting extf/truncf pairs around all
// such operations in order to produce arithmetic that can be performed while
// preserving the original rounding behavior.
| 13 | +// |
| 14 | +//===---------------------------------------------------------------------===// |
| 15 | + |
| 16 | +#include "iree/compiler/Codegen/Common/Passes.h" |
| 17 | +#include "iree/compiler/Codegen/Common/Transforms.h" |
| 18 | +#include "iree/compiler/Codegen/Utils/GPUUtils.h" |
| 19 | +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" |
| 20 | +#include "mlir/Dialect/Arith/IR/Arith.h" |
| 21 | +#include "mlir/Dialect/Arith/Transforms/Passes.h" |
| 22 | +#include "mlir/IR/BuiltinTypes.h" |
| 23 | +#include "mlir/Interfaces/FunctionInterfaces.h" |
| 24 | + |
| 25 | +#define DEBUG_TYPE "iree-convert-unsupported-float-arith" |
| 26 | + |
| 27 | +namespace mlir::iree_compiler { |
| 28 | + |
| 29 | +#define GEN_PASS_DEF_CONVERTUNSUPPORTEDFLOATARITHPASS |
| 30 | +#include "iree/compiler/Codegen/Common/Passes.h.inc" |
| 31 | + |
| 32 | +namespace { |
| 33 | + |
| 34 | +struct ConvertUnsupportedFloatArithPass final |
| 35 | + : public impl::ConvertUnsupportedFloatArithPassBase< |
| 36 | + ConvertUnsupportedFloatArithPass> { |
| 37 | + void runOnOperation() override; |
| 38 | + using Base::Base; |
| 39 | +}; |
| 40 | + |
| 41 | +} // namespace |
| 42 | + |
| 43 | +// Populates source and target conversion types based on the target |
| 44 | +// architecture. |
| 45 | +// TODO(pashu123): Refine the patterns based on the target arch. |
| 46 | +static void populateSourceAndTargetType(MLIRContext *ctx, Operation *op, |
| 47 | + SmallVectorImpl<Type> &sourceTypes, |
| 48 | + Type &targetType) { |
| 49 | + auto gpuAttr = getGPUTargetAttr(op); |
| 50 | + if (!gpuAttr) { |
| 51 | + return; |
| 52 | + } |
| 53 | + StringRef chipset = gpuAttr.getArch(); |
| 54 | + FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset); |
| 55 | + if (failed(maybeChipset)) { |
| 56 | + LLVM_DEBUG(llvm::dbgs() << "Invalid chip name"); |
| 57 | + return; |
| 58 | + } |
| 59 | + // Add source and target conversion types for gfx94{*} series. |
| 60 | + if (maybeChipset->majorVersion == 9 && maybeChipset->minorVersion == 4) { |
| 61 | + sourceTypes.insert(sourceTypes.end(), {Float8E4M3FNUZType::get(ctx), |
| 62 | + Float8E5M2FNUZType::get(ctx)}); |
| 63 | + targetType = Float32Type::get(ctx); |
| 64 | + } |
| 65 | + return; |
| 66 | +} |
| 67 | + |
| 68 | +void ConvertUnsupportedFloatArithPass::runOnOperation() { |
| 69 | + MLIRContext *context = &getContext(); |
| 70 | + FunctionOpInterface funcOp = getOperation(); |
| 71 | + SmallVector<Type> sourceTypes; |
| 72 | + Type targetType = nullptr; |
| 73 | + |
| 74 | + populateSourceAndTargetType(context, funcOp, sourceTypes, targetType); |
| 75 | + |
| 76 | + if (sourceTypes.empty() || !targetType) { |
| 77 | + LLVM_DEBUG(llvm::dbgs() << "no source or target type specified, float " |
| 78 | + "emulation will do nothing\n"); |
| 79 | + return; |
| 80 | + } |
| 81 | + |
| 82 | + if (llvm::is_contained(sourceTypes, targetType)) { |
| 83 | + funcOp->emitError() << " target type cannot be an unsupported source type"; |
| 84 | + return signalPassFailure(); |
| 85 | + } |
| 86 | + |
| 87 | + TypeConverter converter; |
| 88 | + arith::populateEmulateUnsupportedFloatsConversions(converter, sourceTypes, |
| 89 | + targetType); |
| 90 | + RewritePatternSet patterns(context); |
| 91 | + arith::populateEmulateUnsupportedFloatsPatterns(patterns, converter); |
| 92 | + ConversionTarget target(*context); |
| 93 | + arith::populateEmulateUnsupportedFloatsLegality(target, converter); |
| 94 | + |
| 95 | + if (failed(applyPartialConversion(funcOp, target, std::move(patterns)))) { |
| 96 | + signalPassFailure(); |
| 97 | + } |
| 98 | +} |
| 99 | + |
| 100 | +} // namespace mlir::iree_compiler |
0 commit comments