[mlir][tosa] Add more verifiers for the following operators #127923
Conversation
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-tosa

Author: Jerry-Ge (Jerry-Ge)

Changes: [mlir][tosa] Add more verifiers for the following operators

Co-authored with:

Patch is 22.97 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/127923.diff

3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 7cdf79f4dc59d..a9f6f56532aeb 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -248,6 +248,7 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
);
let builders = [Tosa_MatMulOpQuantInfoBuilder];
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -277,6 +278,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
);
let hasCanonicalizer = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -1200,6 +1202,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
);
let hasCanonicalizeMethod = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
let assemblyFormat = [{
operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
@@ -1528,6 +1531,7 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
let hasCanonicalizer = 1;
let hasFolder = 1;
+ let hasVerifier = 1;
let extraClassDeclaration = [{
/// Returns true when two result types are compatible for this op;
@@ -1750,6 +1754,8 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
let results = (outs
Tosa_Tensor3D:$output
);
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -1772,6 +1778,8 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
let results = (outs
Tosa_Tensor3D:$values_out
);
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -1860,6 +1868,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
let hasFolder = 1;
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index d21e218308df7..154a792552fd2 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -469,6 +469,104 @@ LogicalResult tosa::AvgPool2dOp::verify() {
return emitOpError("input/output element types are incompatible.");
}
+LogicalResult tosa::CastOp::verify() {
+ mlir::Type inputETy =
+ llvm::cast<ShapedType>(getInput().getType()).getElementType();
+ if (auto inputQuantType =
+ llvm::dyn_cast<mlir::quant::QuantizedType>(inputETy)) {
+ inputETy = inputQuantType.getStorageType();
+ }
+ mlir::Type outputETy =
+ llvm::cast<ShapedType>(getOutput().getType()).getElementType();
+ if (auto outputQuantType =
+ llvm::dyn_cast<mlir::quant::QuantizedType>(outputETy)) {
+ outputETy = outputQuantType.getStorageType();
+ }
+
+ // input element type: bool
+ if (inputETy.isInteger(1)) {
+ if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+ outputETy.isInteger(32)) {
+ return success();
+ }
+ }
+ // input element type: int8
+ if (inputETy.isInteger(8)) {
+ if (outputETy.isInteger(1) || outputETy.isInteger(16) ||
+ outputETy.isInteger(32) || outputETy.isF16() || outputETy.isBF16() ||
+ outputETy.isF32()) {
+ return success();
+ }
+ }
+ // input element type: int16
+ if (inputETy.isInteger(16)) {
+ if (outputETy.isInteger(1) || outputETy.isInteger(8) ||
+ outputETy.isInteger(32) || outputETy.isF16() || outputETy.isBF16() ||
+ outputETy.isF32()) {
+ return success();
+ }
+ }
+ // input element type: int32
+ if (inputETy.isInteger(32)) {
+ if (outputETy.isInteger(1) || outputETy.isInteger(8) ||
+ outputETy.isInteger(16) || outputETy.isF16() || outputETy.isBF16() ||
+ outputETy.isF32()) {
+ return success();
+ }
+ }
+ // input element type: bf16 or fp16
+ if (inputETy.isBF16() || inputETy.isF16()) {
+ if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+ outputETy.isInteger(32) || llvm::isa<Float8E5M2Type>(outputETy) ||
+ llvm::isa<Float8E4M3FNType>(outputETy) || outputETy.isF32()) {
+ return success();
+ }
+ }
+ // input element type: f8e4m3 or f8e5m2
+ if (llvm::isa<Float8E4M3FNType>(inputETy) ||
+ llvm::isa<Float8E5M2Type>(inputETy)) {
+ if (outputETy.isF16() || outputETy.isBF16() || outputETy.isF32()) {
+ return success();
+ }
+ }
+ // input element type: fp32
+ if (inputETy.isF32()) {
+ if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+ outputETy.isInteger(32) || llvm::isa<Float8E5M2Type>(outputETy) ||
+ llvm::isa<Float8E4M3FNType>(outputETy) || outputETy.isF16() ||
+ outputETy.isBF16()) {
+ return success();
+ }
+ }
+
+  // The following cases are outside of the TOSA spec.
+
+  // Allow casting to the same type, for quantization/dequantization.
+ if (inputETy == outputETy) {
+ return success();
+ }
+
+ // allow casting float to bool, for tosa_to_linalg testing
+ if (inputETy.isF32() && outputETy.isInteger(1)) {
+ return success();
+ }
+
+ // special case for I64
+ if (inputETy.isInteger(64) || outputETy.isInteger(64)) {
+    // be forgiving of casting to and from I64
+ return success();
+ }
+
+ // special case for fp64
+ if (inputETy.isF64() || outputETy.isF64()) {
+ // be forgiving of casting to and from F64
+ return success();
+ }
+
+ return emitOpError("input/output element types are incompatible: ")
+ << inputETy << " and " << outputETy;
+}
+
LogicalResult tosa::ClampOp::verify() {
mlir::Type inputETy =
llvm::cast<ShapedType>(getInput().getType()).getElementType();
@@ -849,6 +947,65 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
return success();
}
+LogicalResult tosa::ConcatOp::verify() {
+ // check that each input has same element type as output
+ auto outType = getOutput().getType();
+ const Operation::operand_range inputList = getInput1();
+
+ for (auto input : inputList) {
+ if (verifySameElementTypes(*this, /* inType = */ input.getType(), outType)
+ .failed()) {
+ return failure();
+ }
+ }
+
+ // Check there is at least one input
+ if (inputList.empty())
+ return emitOpError("expect at least one input");
+
+ const Type firstInputType = inputList.front().getType();
+ const ShapeAdaptor firstInputShape(firstInputType);
+ const int32_t axis = getAxis();
+
+ if (firstInputShape.hasRank()) {
+ // Check axis is in expected range
+ if (axis < 0 || axis >= firstInputShape.getRank())
+      return emitOpError("expect axis to be within range 0 <= axis < "
+                         "rank(input1[0]), got ")
+             << axis;
+ }
+
+ const auto allOperandsHasRank = [](const Value input) {
+ return ShapeAdaptor(input.getType()).hasRank();
+ };
+ if (llvm::all_of(inputList, allOperandsHasRank)) {
+ const int64_t firstInputRank = firstInputShape.getRank();
+
+ for (const auto [index, input] : llvm::enumerate(inputList.drop_front())) {
+ const ShapeAdaptor inputShape(input.getType());
+ const int64_t inputRank = inputShape.getRank();
+ const size_t operandNum = index + 1;
+
+ // Check that each operand has the same rank
+ if (inputRank != firstInputRank)
+ return emitOpError("expect all operands to have the same rank, but got ")
+ << firstInputRank << " vs " << inputRank << " on operands 0 and " << operandNum;
+
+ // Check non-axis dims match
+ for (int i = 0; i < inputRank; i++) {
+ const int64_t inputDim = inputShape.getDimSize(i);
+ const int64_t firstInputDim = firstInputShape.getDimSize(i);
+ if (i == axis || firstInputShape.isDynamicDim(i) || inputShape.isDynamicDim(i))
+ continue;
+ if (inputDim != firstInputDim)
+ return emitOpError("expect all operand shapes to have the same sizes on non-axis dimensions, but got ")
+ << inputDim << " vs " << firstInputDim << " at index " << i << " on operands 0 and " << operandNum;
+ }
+ }
+ }
+
+ return success();
+}
+
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
@@ -898,6 +1055,107 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
return success();
}
+LogicalResult MatMulOp::verify() {
+ auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
+ auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
+ auto resultEType =
+ llvm::cast<ShapedType>(getResult().getType()).getElementType();
+
+ // Must be shaped tensor types
+ if (!aType) {
+ emitOpError("expect a shaped tensor for input a, got ") << getA().getType();
+ return failure();
+ }
+ if (!bType) {
+ emitOpError("expect a shaped tensor for input b, got ") << getB().getType();
+ return failure();
+ }
+
+ auto aElementType = aType.getElementType();
+ auto bElementType = bType.getElementType();
+
+ auto aQuantizedEType =
+ llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
+ auto bQuantizedEType =
+ llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
+
+ if (aQuantizedEType || bQuantizedEType) {
+ if (!aQuantizedEType || !bQuantizedEType) {
+ emitOpError(
+ "expect operands to be both quantized or both not quantized, got ")
+ << aElementType << " and " << bElementType;
+ return failure();
+ }
+ // both a and b have quantized element types
+ auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
+ auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
+ if (aQuantWidth != bQuantWidth) {
+ emitOpError("expect quantized operands to have same widths, got ")
+ << aQuantWidth << " and " << bQuantWidth;
+ return failure();
+ }
+
+ if (aQuantWidth != 8 && aQuantWidth != 16) {
+ emitOpError("only support quantized types with width of 8 or 16, got ")
+ << aQuantWidth;
+ return failure();
+ }
+
+    // check result element type for 8-bit quantized operands
+ if (aQuantWidth == 8 && !resultEType.isInteger(32)) {
+ emitOpError("expect result element type to be i32, got ") << resultEType;
+ return failure();
+ }
+
+    // check result element type for 16-bit quantized operands
+ if (aQuantWidth == 16 && !resultEType.isInteger(48)) {
+ emitOpError("expect result element type to be i48, got ") << resultEType;
+ return failure();
+ }
+
+ return success();
+ }
+
+ // non-quantized element types
+
+ if (aElementType != bElementType) {
+ emitOpError("expect same element type for inputs a and b, got ")
+ << aElementType << " and " << bElementType;
+ return failure();
+ }
+ if (llvm::isa<Float8E5M2Type>(aElementType) ||
+ llvm::isa<Float8E4M3FNType>(aElementType)) {
+ if (!resultEType.isF16()) {
+ emitOpError("expect result element type to be f16, got ") << resultEType;
+ return failure();
+ }
+ }
+
+ if (aElementType.isInteger(8) && !resultEType.isInteger(32)) {
+ emitOpError("expect result element type to be i32, got ") << resultEType;
+ return failure();
+ }
+ if (aElementType.isInteger(16) && !resultEType.isInteger(48)) {
+ emitOpError("expect result element type to be i48, got ") << resultEType;
+ return failure();
+ }
+ if (aElementType.isF16() && !(resultEType.isF16() || resultEType.isF32())) {
+ emitOpError("expect result element type to be f16 or f32, got ")
+ << resultEType;
+ return failure();
+ }
+ if (aElementType.isBF16() && !resultEType.isF32()) {
+ emitOpError("expect result element type to be f32, got ") << resultEType;
+ return failure();
+ }
+ if (aElementType.isF32() && !resultEType.isF32()) {
+ emitOpError("expect result element type to be f32, got ") << resultEType;
+ return failure();
+ }
+
+ return success();
+}
+
LogicalResult tosa::PadOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
PadOp::Adaptor adaptor,
@@ -946,6 +1204,18 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
}
LogicalResult tosa::PadOp::verify() {
+ if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) {
+ return failure();
+ }
+ if (auto padConst = getPadConst()) {
+ if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) {
+ return failure();
+ }
+ }
RankedTensorType inputType = getInput1().getType();
RankedTensorType outputType = getOutput().getType();
auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
@@ -1019,6 +1289,10 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
}
LogicalResult tosa::SliceOp::verify() {
+ if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+ /* outType = */ getOutput().getType())
+ .failed())
+ return failure();
auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
if (!inputType)
return success();
@@ -1026,14 +1300,13 @@ LogicalResult tosa::SliceOp::verify() {
auto startShapeRank =
llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
if (inputType.getRank() != startShapeRank)
- return emitOpError(
- "length of start attribute is not equal rank of input shape");
+ return emitOpError("length of start is not equal to rank of input shape");
+
auto sizeShapeRank =
llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
if (inputType.getRank() != sizeShapeRank)
- return emitOpError(
- "length of size attribute is not equal rank of input shape");
+ return emitOpError("length of size is not equal to rank of input shape");
return success();
}
@@ -1238,6 +1511,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
}
LogicalResult tosa::TileOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) {
+ return failure();
+ }
ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
ShapedType outputType = llvm::cast<ShapedType>(getType());
@@ -1319,6 +1597,11 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
}
llvm::LogicalResult tosa::ReshapeOp::verify() {
+ if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) {
+ return failure();
+ }
TensorType inputType = getInput1().getType();
RankedTensorType outputType = getType();
@@ -1463,6 +1746,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
}
LogicalResult tosa::TransposeOp::verify() {
+ if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) {
+ return failure();
+ }
TensorType inputType = getInput1().getType();
TensorType permType = getPerms().getType();
TensorType outputType = getOutput().getType();
@@ -1578,6 +1866,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
return success();
}
+LogicalResult tosa::GatherOp::verify() {
+ return verifySameElementTypes(*this, /* inType = */ getValues().getType(),
+ /* outType = */ getOutput().getType());
+}
+
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ResizeOp::Adaptor adaptor,
@@ -1746,6 +2039,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
return success();
}
+LogicalResult tosa::ScatterOp::verify() {
+ if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
+ /* outType = */ getValuesOut().getType())
+ .failed() ||
+ verifySameElementTypes(*this, /* inType = */ getInput().getType(),
+ /* outType = */ getValuesOut().getType())
+ .failed()) {
+ return failure();
+ }
+ return success();
+}
+
static LogicalResult ReduceInferReturnTypes(
ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2110,6 +2415,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
inferredReturnShapes);
}
+LogicalResult MaxPool2dOp::verify() {
+  return verifySameElementTypes(*this, /* inType = */ getInput().getType(),
+ /* outType = */ getOutput().getType());
+}
+
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
DepthwiseConv2DOp::Adaptor adaptor,
@@ -2412,6 +2722,10 @@ void IfOp::print(OpAsmPrinter &p) {
}
LogicalResult ReverseOp::verify() {
+ if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+ /* outType = */ getOutput().getType())
+ .failed())
+ return failure();
TensorType inputType = getInput1().getType();
TensorType outputType = getOutput().getType();
int32_t reverseAxis = getAxis();
@@ -2440,6 +2754,33 @@ LogicalResult ReverseOp::verify() {
return success();
}
+LogicalResult tosa::SelectOp::verify() {
+ // verify input2 and input3 have same element type as output
+ if (verifySameElementTypes(*this, /* inType = */ getInput2().getType(),
+ /* outType = */ getOutput().getType())
+ .failed() ||
+ verifySameElementTypes(*this, /* inType = */ getInput3().getType(),
+ /* outType = */ getOutput().getType())
+ .failed()) {
+ return failure();
+ }
+ // verify input1 has element type of bool
+ auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
+ if (!predicateType) {
+ emitOpError("expect shaped tensor for input1, got ")
+ << getInput1().getType();
+ return failure();
+ }
+ auto predicateElementType = predicateType.getElementType();
+ if (!predicateElementType.isInteger(1)) {
+ emitOpError("expect element type of bool for input1, got ")
+ << predicateElementType;
+ return failure();
+ }
+
+ return success();
+}
+
// parse and print of WhileOp refer to the implementation of SCF dialect.
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::Argument, 4> regionArgs;
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 1307da88d1e64..4b3525aa005bc 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -165,8 +165,7 @@ func.func @test_conv2d_quant_any_result(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7
// -----
func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> {
- // expected-error@+2 {{failed to infer returned types}}
- // expected-error@+1 {{Cannot concat tensors with different sizes on the non-axis dimension 1}}
+ // expected-error@+1 {{'tosa.concat' op expect all operand shapes to have the same sizes on non-axis dimensions, but got 2 vs 1 at index 1 on operands 0 and 1}}
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
@@ -174,8 +173,7 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
// -----
func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
- // expected-error@+2 {{failed to infer returned types}}
- // expected-error@+1 {{'tosa.concat' op inferred type(s) 'tensor<3x2xf32>' are incompatible with return type(s) of operation 'tensor<?x?xi8>}}
+ // expected-error@+1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
return %0 : tensor<?x?xi8>
}
@@ -208,6 +206,14 @@ func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>) {
// -----
+func.func @test_concat_input_rank_mismatch(%arg0: tensor<1x...
[truncated]
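
To make the new diagnostics concrete, here is a small illustrative sketch in the style of the invalid.mlir tests above. These cases are not taken from the patch: the function names are hypothetical, the expected messages simply mirror the error strings added by the verifiers, and the select case assumes input1 is not already constrained to i1 at the ODS level.

func.func @test_concat_bad_axis(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<?x?xf32> {
  // Axis 2 is out of range for rank-2 inputs, so the new ConcatOp verifier rejects it.
  // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 <= axis < rank(input1[0]), got 2}}
  %0 = tosa.concat %arg0, %arg1 {axis = 2 : i32} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// -----

func.func @test_select_bad_predicate(%arg0: tensor<2xi32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<2xf32> {
  // The predicate operand must have bool (i1) elements, per the new SelectOp verifier.
  // expected-error@+1 {{'tosa.select' op expect element type of bool for input1, got 'i32'}}
  %0 = tosa.select %arg0, %arg1, %arg2 : (tensor<2xi32>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}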
✅ With the latest revision this PR passed the C/C++ code formatter.
Force-pushed from 5c2f632 to d8b401e.
As a user of the TOSA dialect that stays within the specification, I'm okay with these data type checks. However, we might want to leave this PR open for a bit so other users of TOSA can confirm these checks aren't too restrictive for their workflows.
Force-pushed from 2fe9dd3 to 97ba7a5.
Removed the verifier checks for Cast input/output types; those checks will need to move into the validation pass. cc @lhutton1
Force-pushed from fc2791d to bb978ee.
Btw, for future reference, please try to split this into separate patches where possible.
LGTM modulo @FranklandJack's comments. I won't explicitly approve since I authored a portion of this patch.
For ConcatOp, this commit also enhances the verifier by checking four additional conditions:

- The input list is not empty
- The axis value is within the range of the input shapes
- All inputs have the same rank
- All non-concatenation-axis dims have the same value

For MatMulOp:

- Checked input a and b tensor types and element types

For the following operators, added the verifySameElementTypes check:

- PadOp
- SliceOp
- TileOp
- ReshapeOp
- TransposeOp
- GatherOp
- ScatterOp
- MaxPool2dOp
- ReverseOp
- SelectOp

Change-Id: I1e8a1017f21f617443bc40bae42189915048c750
Co-authored-by: Tai Ly <[email protected]>
Co-authored-by: Luke Hutton <[email protected]>
Signed-off-by: Jerry Ge <[email protected]>
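
To illustrate the shared verifySameElementTypes check listed above, a minimal hypothetical example in the same style as the dialect's invalid tests (not part of this patch; the expected message mirrors the error string the helper emits):

func.func @test_reverse_type_mismatch(%arg0: tensor<3x4xf32>) -> tensor<3x4xi16> {
  // Input and output element types differ, so the new ReverseOp check fails.
  // expected-error@+1 {{'tosa.reverse' op expect input and output to have same element type, got 'f32' and 'i16'}}
  %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<3x4xf32>) -> tensor<3x4xi16>
  return %0 : tensor<3x4xi16>
}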