Skip to content

Commit 0ff26a7

Browse files
pashu123AmosLewis
andauthored
[Codegen] Add support to emulate unsupported float type (#19943)
This change enables the conversion of types such as f8E4M3FNUZ and f8E5M2FNUZ (emulated via the existing APIs) into f32 operations. The conversion logic is now tightly coupled with the executable target attribute, so that it is applied only for the gfx942 target. This removes the need for manual pass configuration to specify source types and aligns the behaviour with the target’s capabilities. For any new conversion, just populate the conversion target with source and target types. FIX: #19921 (comment) --------- Signed-off-by: Chi Liu<[email protected]> Co-authored-by: AmosLewis <[email protected]>
1 parent 8fab35c commit 0ff26a7

File tree

8 files changed

+159
-0
lines changed

8 files changed

+159
-0
lines changed

compiler/src/iree/compiler/Codegen/Common/BUILD.bazel

+3
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ iree_compiler_cc_library(
9898
"ConvertBf16ArithToF32.cpp",
9999
"ConvertBf16ToUInt16Buffers.cpp",
100100
"ConvertToDestinationPassingStylePass.cpp",
101+
"ConvertUnsupportedFloatArithPass.cpp",
101102
"ConvolutionToIGEMM.cpp",
102103
"DecomposeAffineOpsPass.cpp",
103104
"DecomposeConvolutionToLowerDimOps.cpp",
@@ -203,6 +204,8 @@ iree_compiler_cc_library(
203204
"//compiler/src/iree/compiler/Utils",
204205
"//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
205206
"@llvm-project//llvm:Support",
207+
"@llvm-project//mlir:AMDGPUDialect",
208+
"@llvm-project//mlir:AMDGPUUtils",
206209
"@llvm-project//mlir:AffineAnalysis",
207210
"@llvm-project//mlir:AffineDialect",
208211
"@llvm-project//mlir:AffineTransforms",

compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt

+3
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ iree_cc_library(
9191
"ConvertBf16ArithToF32.cpp"
9292
"ConvertBf16ToUInt16Buffers.cpp"
9393
"ConvertToDestinationPassingStylePass.cpp"
94+
"ConvertUnsupportedFloatArithPass.cpp"
9495
"ConvolutionToIGEMM.cpp"
9596
"DecomposeAffineOpsPass.cpp"
9697
"DecomposeConvolutionToLowerDimOps.cpp"
@@ -162,6 +163,8 @@ iree_cc_library(
162163
::PassesIncGen
163164
IREELinalgTransformDialect
164165
LLVMSupport
166+
MLIRAMDGPUDialect
167+
MLIRAMDGPUUtils
165168
MLIRAffineAnalysis
166169
MLIRAffineDialect
167170
MLIRAffineTransforms
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// Copyright 2025 The IREE Authors
2+
//
3+
// Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===--------------- ConvertUnsupportedFloatArithPass.cpp ----------------===//
8+
//
9+
// Emulate arith and vector floating point operations that use float types
10+
// which are unspported on a target by inserting extf/truncf pairs around all
11+
// such operations in order to produce arithmetic that can be performed while
12+
// preserving the original rounding behavior.
13+
//
14+
//===---------------------------------------------------------------------===//
15+
16+
#include "iree/compiler/Codegen/Common/Passes.h"
17+
#include "iree/compiler/Codegen/Common/Transforms.h"
18+
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
19+
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
20+
#include "mlir/Dialect/Arith/IR/Arith.h"
21+
#include "mlir/Dialect/Arith/Transforms/Passes.h"
22+
#include "mlir/IR/BuiltinTypes.h"
23+
#include "mlir/Interfaces/FunctionInterfaces.h"
24+
25+
#define DEBUG_TYPE "iree-convert-unsupported-float-arith"
26+
27+
namespace mlir::iree_compiler {
28+
29+
#define GEN_PASS_DEF_CONVERTUNSUPPORTEDFLOATARITHPASS
30+
#include "iree/compiler/Codegen/Common/Passes.h.inc"
31+
32+
namespace {
33+
34+
struct ConvertUnsupportedFloatArithPass final
35+
: public impl::ConvertUnsupportedFloatArithPassBase<
36+
ConvertUnsupportedFloatArithPass> {
37+
void runOnOperation() override;
38+
using Base::Base;
39+
};
40+
41+
} // namespace
42+
43+
// Populates source and target conversion types based on the target
44+
// architecture.
45+
// TODO(pashu123): Refine the patterns based on the target arch.
46+
static void populateSourceAndTargetType(MLIRContext *ctx, Operation *op,
47+
SmallVectorImpl<Type> &sourceTypes,
48+
Type &targetType) {
49+
auto gpuAttr = getGPUTargetAttr(op);
50+
if (!gpuAttr) {
51+
return;
52+
}
53+
StringRef chipset = gpuAttr.getArch();
54+
FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
55+
if (failed(maybeChipset)) {
56+
LLVM_DEBUG(llvm::dbgs() << "Invalid chip name");
57+
return;
58+
}
59+
// Add source and target conversion types for gfx94{*} series.
60+
if (maybeChipset->majorVersion == 9 && maybeChipset->minorVersion == 4) {
61+
sourceTypes.insert(sourceTypes.end(), {Float8E4M3FNUZType::get(ctx),
62+
Float8E5M2FNUZType::get(ctx)});
63+
targetType = Float32Type::get(ctx);
64+
}
65+
return;
66+
}
67+
68+
void ConvertUnsupportedFloatArithPass::runOnOperation() {
69+
MLIRContext *context = &getContext();
70+
FunctionOpInterface funcOp = getOperation();
71+
SmallVector<Type> sourceTypes;
72+
Type targetType = nullptr;
73+
74+
populateSourceAndTargetType(context, funcOp, sourceTypes, targetType);
75+
76+
if (sourceTypes.empty() || !targetType) {
77+
LLVM_DEBUG(llvm::dbgs() << "no source or target type specified, float "
78+
"emulation will do nothing\n");
79+
return;
80+
}
81+
82+
if (llvm::is_contained(sourceTypes, targetType)) {
83+
funcOp->emitError() << " target type cannot be an unsupported source type";
84+
return signalPassFailure();
85+
}
86+
87+
TypeConverter converter;
88+
arith::populateEmulateUnsupportedFloatsConversions(converter, sourceTypes,
89+
targetType);
90+
RewritePatternSet patterns(context);
91+
arith::populateEmulateUnsupportedFloatsPatterns(patterns, converter);
92+
ConversionTarget target(*context);
93+
arith::populateEmulateUnsupportedFloatsLegality(target, converter);
94+
95+
if (failed(applyPartialConversion(funcOp, target, std::move(patterns)))) {
96+
signalPassFailure();
97+
}
98+
}
99+
100+
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/Common/Passes.td

+8
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,14 @@ def ConvertBf16ToUInt16BuffersPass :
8080
let summary = "Convert BF16 buffer ops and conversions to simulated behavior with uint16.";
8181
}
8282

83+
def ConvertUnsupportedFloatArithPass
84+
: InterfacePass<"iree-convert-unsupported-float-arith",
85+
"mlir::FunctionOpInterface"> {
86+
let summary = "Convert arith operations on unsupported(source types) float "
87+
"types to the target type. Populates the source and target "
88+
"based on the target architecture.";
89+
}
90+
8391
def ConvertToDestinationPassingStylePass :
8492
InterfacePass<"iree-codegen-convert-to-destination-passing-style", "mlir::FunctionOpInterface"> {
8593
let summary =

compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ iree_lit_test_suite(
2929
"convert_bf16_to_uint16_buffers.mlir",
3030
"convert_bf16_arith_to_f32.mlir",
3131
"convert_to_destination_passing_style.mlir",
32+
"convert_unsupported_float_arith.mlir",
3233
"convolution_to_igemm.mlir",
3334
"convolutions.mlir",
3435
"erase_dead_alloc_and_stores.mlir",

compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ iree_lit_test_suite(
2525
"convert_bf16_arith_to_f32.mlir"
2626
"convert_bf16_to_uint16_buffers.mlir"
2727
"convert_to_destination_passing_style.mlir"
28+
"convert_unsupported_float_arith.mlir"
2829
"convolution_to_igemm.mlir"
2930
"convolutions.mlir"
3031
"decompose_affine_ops.mlir"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-convert-unsupported-float-arith))" %s | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @negf_f8_unsupported
4+
// CHECK-SAME: (%[[ARG0:.*]]: f8E4M3FNUZ) -> f8E4M3FNUZ
5+
// CHECK: %[[EXT:.*]] = arith.extf %[[ARG0]] {{.*}} : f8E4M3FNUZ to f32
6+
// CHECK: %[[NEG:.*]] = arith.negf %[[EXT]] : f32
7+
// CHECK: %[[TRUNC:.*]] = arith.truncf %[[NEG]] {{.*}} : f32 to f8E4M3FNUZ
8+
// CHECK: return %[[TRUNC]] : f8E4M3FNUZ
9+
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
10+
func.func @negf_f8_unsupported(%arg0 : f8E4M3FNUZ) -> f8E4M3FNUZ attributes
11+
{ hal.executable.target = #executable_target_rocm_hsaco_fb }{
12+
%0 = arith.negf %arg0 : f8E4M3FNUZ
13+
return %0 : f8E4M3FNUZ
14+
}
15+
16+
// -----
17+
18+
// CHECK-LABEL: func.func @expand_f8(
19+
// CHECK-SAME: %[[ARG0:.*]]: f8E5M2FNUZ
20+
// CHECK: %[[EXT0:.*]] = arith.extf %[[ARG0]] {{.*}} : f8E5M2FNUZ to f32
21+
// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f8E5M2FNUZ
22+
// CHECK: %[[EXT1:.*]] = arith.extf %[[CST]] {{.*}} : f8E5M2FNUZ to f32
23+
// CHECK: %[[SUM:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
24+
// CHECK: %[[TRUNC:.*]] = arith.truncf %[[SUM]] {{.*}} : f32 to f8E5M2FNUZ
25+
// CHECK: return %[[TRUNC]] : f8E5M2FNUZ
26+
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
27+
func.func @expand_f8(%x: f8E5M2FNUZ) -> f8E5M2FNUZ attributes
28+
{ hal.executable.target = #executable_target_rocm_hsaco_fb }{
29+
%c = arith.constant 1.0 : f8E5M2FNUZ
30+
%y = arith.addf %x, %c : f8E5M2FNUZ
31+
func.return %y : f8E5M2FNUZ
32+
}
33+
34+
// -----
35+
36+
// CHECK-LABEL: func.func @dont_expand_cpu_target
37+
// CHECK: %[[NEG:.*]] = arith.negf {{.*}} : f8E4M3FNUZ
38+
func.func @dont_expand_cpu_target(%arg0 : f8E4M3FNUZ) -> f8E4M3FNUZ attributes
39+
{ hal.executable.target = #hal.executable.target<"llvm-cpu", "xyz", {target_triple = "x86_64-xyz-xyz"}>}{
40+
%0 = arith.negf %arg0 : f8E4M3FNUZ
41+
return %0 : f8E4M3FNUZ
42+
}

compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1129,6 +1129,7 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager,
11291129

11301130
if (forROCDL) {
11311131
// convert to ROCDL.
1132+
funcPassManager.addPass(createConvertUnsupportedFloatArithPass);
11321133
modulePassManager.addPass(createConvertToROCDLPass());
11331134
modulePassManager.addNestedPass<LLVM::LLVMFuncOp>(
11341135
createROCDLAnnotateKernelForTranslationPass());

0 commit comments

Comments
 (0)