|
| 1 | +// Copyright 2023 The IREE Authors |
| 2 | +// |
| 3 | +// Licensed under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
//===- ConvertF8ArithToF32Pass.cpp ---------------------------------------===//
//
// Many targets lack native support for arithmetic on f8 element types. This
// pass rewrites f8E4M3FNUZ arithmetic ops to extend their operands to f32,
// perform the computation in f32, and truncate the result back to
// f8E4M3FNUZ.
//
//===---------------------------------------------------------------------===//
| 13 | + |
| 14 | +#include "iree/compiler/Codegen/Common/Passes.h" |
| 15 | +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" |
| 16 | +#include "mlir/Dialect/Arith/IR/Arith.h" |
| 17 | +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 18 | + |
| 19 | +namespace mlir::iree_compiler { |
| 20 | + |
| 21 | +#define GEN_PASS_DEF_CONVERTF8ARITHTOF32PASS |
| 22 | +#include "iree/compiler/Codegen/Common/Passes.h.inc" |
| 23 | + |
| 24 | +namespace { |
| 25 | + |
/// Replaces a unary f8E4M3FNUZ arithmetic op such as
///
/// ```mlir
/// %1 = arith.negf %input : vector<1x2x1x1x1x1xf8E4M3FNUZ>
/// ```
///
/// with the same computation performed in f32:
///
/// ```mlir
/// %0 = arith.extf %input : vector<1x2x1x1x1x1xf8E4M3FNUZ> to vector<1x2x1x1x1x1xf32>
/// %1 = arith.negf %0 : vector<1x2x1x1x1x1xf32>
/// %2 = arith.truncf %1 : vector<1x2x1x1x1x1xf32> to vector<1x2x1x1x1x1xf8E4M3FNUZ>
/// ```
| 41 | +template <typename OpTy> |
| 42 | +struct F8ArithToF32CastOp : public OpRewritePattern<OpTy> { |
| 43 | + using OpRewritePattern::OpRewritePattern; |
| 44 | + LogicalResult matchAndRewrite(OpTy op, |
| 45 | + PatternRewriter &rewriter) const override { |
| 46 | + |
| 47 | + auto inputType = op.getOperand().getType().cast<VectorType>(); |
| 48 | + if (inputType.getElementType().isF8E4M3FNUZ()) { |
| 49 | + // Extend the input to f32 |
| 50 | + auto extendedType = inputType.clone(rewriter.getF32Type()); |
| 51 | + auto extended = rewriter.create<arith::ExtFOp>(op.getLoc(), extendedType, |
| 52 | + op.getOperand()); |
| 53 | + |
| 54 | + // Negate the extended value |
| 55 | + auto negated = rewriter.create<OpTy>(op.getLoc(), extended); |
| 56 | + |
| 57 | + // Truncate back to f8E4M3FNUZ |
| 58 | + auto truncated = |
| 59 | + rewriter.create<arith::TruncFOp>(op.getLoc(), inputType, negated); |
| 60 | + |
| 61 | + // Replace the original operation |
| 62 | + rewriter.replaceOp(op, truncated.getResult()); |
| 63 | + return success(); |
| 64 | + } |
| 65 | + return failure(); |
| 66 | + } |
| 67 | +}; |
| 68 | + |
// Pass that applies the f8-to-f32 arithmetic rewrite patterns greedily over
// the operation it is run on. Body is defined out-of-line below.
struct ConvertF8ArithToF32Pass final
    : impl::ConvertF8ArithToF32PassBase<ConvertF8ArithToF32Pass> {
  void runOnOperation() override;
};
} // namespace
| 74 | + |
| 75 | +void ConvertF8ArithToF32Pass::runOnOperation() { |
| 76 | + MLIRContext *context = &getContext(); |
| 77 | + RewritePatternSet patterns(context); |
| 78 | + patterns.insert<F8ArithToF32CastOp<arith::NegFOp>>(context); |
| 79 | + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { |
| 80 | + return signalPassFailure(); |
| 81 | + } |
| 82 | +} |
| 83 | +} // namespace mlir::iree_compiler |