diff --git a/xls/contrib/mlir/testdata/arith_to_xls.mlir b/xls/contrib/mlir/testdata/arith_to_xls.mlir index 255f88f62c..b906e50146 100644 --- a/xls/contrib/mlir/testdata/arith_to_xls.mlir +++ b/xls/contrib/mlir/testdata/arith_to_xls.mlir @@ -112,9 +112,9 @@ func.func @ext(%arg0: i32) -> (i64, i64) attributes { "xls" = true } { // CHECK-LABEL: @extf( // CHECK: call_dslx // CHECK-SAME: "ext" -func.func @extf(%arg0: bf16) -> f32 attributes { "xls" = true } { - %0 = arith.extf %arg0 : bf16 to f32 - return %0 : f32 +func.func @extf(%arg0: tensor<3x3xbf16>) -> tensor<3x3xf32> attributes { "xls" = true } { + %0 = arith.extf %arg0 : tensor<3x3xbf16> to tensor<3x3xf32> + return %0 : tensor<3x3xf32> } // CHECK-LABEL: @truncf( diff --git a/xls/contrib/mlir/transforms/arith_to_xls_patterns.td b/xls/contrib/mlir/transforms/arith_to_xls_patterns.td index 26a1276d69..ec0bb1be3f 100644 --- a/xls/contrib/mlir/transforms/arith_to_xls_patterns.td +++ b/xls/contrib/mlir/transforms/arith_to_xls_patterns.td @@ -9,6 +9,9 @@ def FloatLib : NativeCodeCall<"getFloatLib($0.getType())">; // Shorthand for a constant string attribute. class CS : ConstantStrAttr; +class ScalarOrTensorOf : + AnyTypeOf<[element, TensorOf<[element]>]>; + class BinaryOpPat : Pat<(a $a, $b), (b $a, $b)>; class BinaryOpOverflowPat : Pat<(a $a, $b, $_), (b $a, $b)>; class BinaryVariadicOpPat : Pat<(a $a, $b), (b (variadic $a, $b))>; @@ -131,19 +134,19 @@ def : Pat<(Arith_TruncFOp:$op $a, /*RoundingMode=*/$_, /*FastMathFlags=*/$_), def : Pat<(Arith_SIToFPOp:$op I32:$a), (Xls_CallDslxOp (FloatLib $op), CS<"from_int32">, (variadic $a), ConstUnitAttr), - [(F32 $op)]>; + [(ScalarOrTensorOf $op)]>; def : Pat<(Arith_SIToFPOp:$op I8:$a), (Xls_CallDslxOp (FloatLib $op), CS<"from_int8">, (variadic $a), ConstUnitAttr), - [(BF16 $op)]>; + [(ScalarOrTensorOf $op)]>; def : Pat<(Arith_FPToSIOp:$op F32:$a), (Xls_CallDslxOp (FloatLib $a), CS<"to_int32">, (variadic $a), ConstUnitAttr), - [(I32 $op)]>; + [(ScalarOrTensorOf $op)]>; def : Pat<(Arith_FPToSIOp:$op BF16:$a), (Xls_CallDslxOp (FloatLib $a), CS<"to_int16">, (variadic $a), ConstUnitAttr), - [(I16 $op)]>; + [(ScalarOrTensorOf $op)]>; // TODO(jmolloy): to_int8 doesn't exist, so truncating the result of to_int16 // seems like a reasonable approximation but I don't know if it's bit accurate. @@ -151,7 +154,7 @@ def : Pat<(Arith_FPToSIOp:$op BF16:$a), (Arith_TruncIOp (Xls_CallDslxOp (FloatLib $a), CS<"to_int16">, (variadic $a), ConstUnitAttr, (returnType "$_builder.getI16Type()"))), - [(I8 $op)]>; + [(ScalarOrTensorOf $op)]>; // The expansion is a little tricky to read due to the one-hot select with the // default case being the first argument.