Skip to content

Commit 582dacd

Browse files
committed
[Codegen] Add f8 to f32 pass for arith.negf
Signed-off-by: Chi Liu<[email protected]>
1 parent 0781072 commit 582dacd

File tree

3 files changed

+89
-0
lines changed

3 files changed

+89
-0
lines changed

compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ iree_cc_library(
8989
"ConvertAccGEMMToGEMMPass.cpp"
9090
"ConvertBf16ArithToF32.cpp"
9191
"ConvertBf16ToUInt16Buffers.cpp"
92+
"BubbleUpOrdinalOps.cpp"
9293
"ConvertToDestinationPassingStylePass.cpp"
9394
"ConvolutionToIGEMM.cpp"
9495
"DecomposeAffineOpsPass.cpp"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// Copyright 2023 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//===- ConvertF8ArithToF32Pass.cpp ---------------------------------------===//
7+
//
8+
// This pass rewrites `arith.negf` on f8E4M3FNUZ values so that the negation
9+
// is carried out in f32: the operand is extended with `arith.extf`, negated,
10+
// and the result is truncated back to f8E4M3FNUZ with `arith.truncf`.
11+
//
12+
//===---------------------------------------------------------------------===//
13+
14+
#include "iree/compiler/Codegen/Common/Passes.h"
15+
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
16+
#include "mlir/Dialect/Arith/IR/Arith.h"
17+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18+
19+
namespace mlir::iree_compiler {
20+
21+
#define GEN_PASS_DEF_CONVERTF8ARITHTOF32PASS
22+
#include "iree/compiler/Codegen/Common/Passes.h.inc"
23+
24+
namespace {
25+
26+
/// Replace the following sequence
27+
///
28+
/// ```mlir
29+
/// %1 = arith.negf %input : vector<1x2x1x1x1x1xf8E4M3FNUZ>
30+
/// ```
31+
///
32+
/// with
33+
///
34+
/// ```mlir
35+
/// %0 = arith.extf %input : f8E4M3FNUZ to f32
36+
/// %1 = arith.negf %0 : vector<1x2x1x1x1x1xf32>
37+
/// %2 = arith.truncf %1 : vector<1x2x1x1x1x1xf8E4M3FNUZ>
38+
/// ```
39+
///
40+
/// to make all the uses flow through `flow.dispatch.workload.ordinal` ops.
41+
template <typename OpTy>
42+
struct F8ArithToF32CastOp : public OpRewritePattern<OpTy> {
43+
using OpRewritePattern::OpRewritePattern;
44+
LogicalResult matchAndRewrite(OpTy op,
45+
PatternRewriter &rewriter) const override {
46+
47+
auto inputType = op.getOperand().getType().cast<VectorType>();
48+
if (inputType.getElementType().isF8E4M3FNUZ()) {
49+
// Extend the input to f32
50+
auto extendedType = inputType.clone(rewriter.getF32Type());
51+
auto extended = rewriter.create<arith::ExtFOp>(op.getLoc(), extendedType,
52+
op.getOperand());
53+
54+
// Negate the extended value
55+
auto negated = rewriter.create<OpTy>(op.getLoc(), extended);
56+
57+
// Truncate back to f8E4M3FNUZ
58+
auto truncated =
59+
rewriter.create<arith::TruncFOp>(op.getLoc(), inputType, negated);
60+
61+
// Replace the original operation
62+
rewriter.replaceOp(op, truncated.getResult());
63+
return success();
64+
}
65+
return failure();
66+
}
67+
};
68+
69+
struct ConvertF8ArithToF32Pass final
70+
: impl::ConvertF8ArithToF32PassBase<ConvertF8ArithToF32Pass> {
71+
void runOnOperation() override;
72+
};
73+
} // namespace
74+
75+
void ConvertF8ArithToF32Pass::runOnOperation() {
76+
MLIRContext *context = &getContext();
77+
RewritePatternSet patterns(context);
78+
patterns.insert<F8ArithToF32CastOp<arith::NegFOp>>(context);
79+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
80+
return signalPassFailure();
81+
}
82+
}
83+
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/Common/Passes.td

+5
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ def ConvertBf16ToUInt16BuffersPass :
8080
let summary = "Convert BF16 buffer ops and conversions to simulated behavior with uint16.";
8181
}
8282

83+
// Rewrites f8E4M3FNUZ arith ops so the computation happens in f32.
// The pass argument keeps its original spelling so existing pipelines and
// command lines that reference it keep working.
def ConvertF8ArithToF32Pass :
    Pass<"iree-codegen-convert-f8-to-f32-buffers", ""> {
  let summary = "Convert f8E4M3FNUZ arith ops (currently arith.negf) to f32 arithmetic wrapped in arith.extf/arith.truncf.";
}
87+
8388
def ConvertToDestinationPassingStylePass :
8489
InterfacePass<"iree-codegen-convert-to-destination-passing-style", "mlir::FunctionOpInterface"> {
8590
let summary =

0 commit comments

Comments
 (0)