
[mlir][tosa] Add more verifiers for the following operators #127923

Merged: 1 commit into llvm:main on Mar 5, 2025

Conversation

@Jerry-Ge (Member) commented Feb 20, 2025

For ConcatOp, this commit also enhances the verifier by checking four additional conditions (illustrated below):

  • The input list is not empty
  • The axis value is within the range of the input shapes
  • All inputs have the same rank
  • All dimensions other than the concatenation axis have the same size
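
For illustration, here is the case from this patch's updated invalid.mlir test that the enhanced verifier now rejects, because the two operands disagree on the non-axis dimension 1:

  func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> {
    // error: 'tosa.concat' op expect all operand shapes to have the same sizes
    // on non-axis dimensions, but got 2 vs 1 at index 1 on operands 0 and 1
    %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
    return %0 : tensor<?x?xf32>
  }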

For MatmulOp (see the sketch below):

  • Checked the tensor types and element types of inputs a and b, and the corresponding result element type
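
A minimal sketch of one result-type rule the new MatMulOp verifier enforces; the operand names and shapes here are hypothetical, and the assembly format is assumed from the dialect's tests:

  // i8 operands must accumulate into an i32 result...
  %good = tosa.matmul %a, %b : (tensor<1x4x8xi8>, tensor<1x8x16xi8>) -> tensor<1x4x16xi32>
  // ...otherwise verification fails with
  // 'tosa.matmul' op expect result element type to be i32, got 'i8'
  %bad = tosa.matmul %a, %b : (tensor<1x4x8xi8>, tensor<1x8x16xi8>) -> tensor<1x4x16xi8>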

For the following operators, this commit adds the verifySameElementTypes check (see the sketch after this list):

  • PadOp
  • SliceOp
  • TileOp
  • ReshapeOp
  • TransposeOp
  • GatherOp
  • ScatterOp
  • MaxPool2dOp
  • ReverseOp
  • SelectOp
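
As a sketch of the kind of IR these checks reject (the op and element types below are chosen for illustration; the diagnostic follows the message format this patch uses):

  // error: 'tosa.reverse' op expect input and output to have same element type, got 'f32' and 'i8'
  %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<4xf32>) -> tensor<4xi8>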

Co-authored-by: Tai Ly [email protected]
Co-authored-by: Luke Hutton [email protected]
Signed-off-by: Jerry Ge [email protected]

@llvmbot (Member) commented Feb 20, 2025

@llvm/pr-subscribers-mlir

Author: Jerry-Ge (Jerry-Ge)

Changes

…owing operators

  • CastOp
  • ConcatOp
  • MatMulOp
  • PadOp
  • SliceOp
  • TileOp
  • ReshapeOp
  • TransposeOp
  • GatherOp
  • ScatterOp
  • MaxPool2dOp
  • ReverseOp
  • SelectOp

Co-authored with: Tai Ly and Luke Hutton


Patch is 22.97 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/127923.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+9)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+345-4)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+14-9)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 7cdf79f4dc59d..a9f6f56532aeb 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -248,6 +248,7 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
   );
 
   let builders = [Tosa_MatMulOpQuantInfoBuilder];
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -277,6 +278,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
   );
 
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1200,6 +1202,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
   );
   let hasCanonicalizeMethod = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let assemblyFormat = [{
     operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
@@ -1528,6 +1531,7 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
@@ -1750,6 +1754,8 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
   let results = (outs
     Tosa_Tensor3D:$output
   );
+  
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1772,6 +1778,8 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
   let results = (outs
     Tosa_Tensor3D:$values_out
   );
+  
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1860,6 +1868,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index d21e218308df7..154a792552fd2 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -469,6 +469,104 @@ LogicalResult tosa::AvgPool2dOp::verify() {
   return emitOpError("input/output element types are incompatible.");
 }
 
+LogicalResult tosa::CastOp::verify() {
+  mlir::Type inputETy =
+      llvm::cast<ShapedType>(getInput().getType()).getElementType();
+  if (auto inputQuantType =
+          llvm::dyn_cast<mlir::quant::QuantizedType>(inputETy)) {
+    inputETy = inputQuantType.getStorageType();
+  }
+  mlir::Type outputETy =
+      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
+  if (auto outputQuantType =
+          llvm::dyn_cast<mlir::quant::QuantizedType>(outputETy)) {
+    outputETy = outputQuantType.getStorageType();
+  }
+
+  // input element type: bool
+  if (inputETy.isInteger(1)) {
+    if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+        outputETy.isInteger(32)) {
+      return success();
+    }
+  }
+  // input element type: int8
+  if (inputETy.isInteger(8)) {
+    if (outputETy.isInteger(1) || outputETy.isInteger(16) ||
+        outputETy.isInteger(32) || outputETy.isF16() || outputETy.isBF16() ||
+        outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: int16
+  if (inputETy.isInteger(16)) {
+    if (outputETy.isInteger(1) || outputETy.isInteger(8) ||
+        outputETy.isInteger(32) || outputETy.isF16() || outputETy.isBF16() ||
+        outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: int32
+  if (inputETy.isInteger(32)) {
+    if (outputETy.isInteger(1) || outputETy.isInteger(8) ||
+        outputETy.isInteger(16) || outputETy.isF16() || outputETy.isBF16() ||
+        outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: bf16 or fp16
+  if (inputETy.isBF16() || inputETy.isF16()) {
+    if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+        outputETy.isInteger(32) || llvm::isa<Float8E5M2Type>(outputETy) ||
+        llvm::isa<Float8E4M3FNType>(outputETy) || outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: f8e4m3 or f8e5m2
+  if (llvm::isa<Float8E4M3FNType>(inputETy) ||
+      llvm::isa<Float8E5M2Type>(inputETy)) {
+    if (outputETy.isF16() || outputETy.isBF16() || outputETy.isF32()) {
+      return success();
+    }
+  }
+  // input element type: fp32
+  if (inputETy.isF32()) {
+    if (outputETy.isInteger(8) || outputETy.isInteger(16) ||
+        outputETy.isInteger(32) || llvm::isa<Float8E5M2Type>(outputETy) ||
+        llvm::isa<Float8E4M3FNType>(outputETy) || outputETy.isF16() ||
+        outputETy.isBF16()) {
+      return success();
+    }
+  }
+
+  // following are outside of TOSA Spec
+
+  // allow casting to same type, for quantization/dequantization
+  if (inputETy == outputETy) {
+    return success();
+  }
+
+  // allow casting float to bool, for tosa_to_linalg testing
+  if (inputETy.isF32() && outputETy.isInteger(1)) {
+    return success();
+  }
+
+  // special case for I64
+  if (inputETy.isInteger(64) || outputETy.isInteger(64)) {
+    // be forgiving of casting to and from I64
+    return success();
+  }
+
+  // special case for fp64
+  if (inputETy.isF64() || outputETy.isF64()) {
+    // be forgiving of casting to and from F64
+    return success();
+  }
+
+  return emitOpError("input/output element types are incompatible: ")
+         << inputETy << " and " << outputETy;
+}
+
 LogicalResult tosa::ClampOp::verify() {
   mlir::Type inputETy =
       llvm::cast<ShapedType>(getInput().getType()).getElementType();
@@ -849,6 +947,65 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::ConcatOp::verify() {
+  // check that each input has same element type as output
+  auto outType = getOutput().getType();
+  const Operation::operand_range inputList = getInput1();
+
+  for (auto input : inputList) {
+    if (verifySameElementTypes(*this, /* inType = */ input.getType(), outType)
+            .failed()) {
+      return failure();
+    }
+  }
+
+  // Check there is at least one input
+  if (inputList.empty())
+    return emitOpError("expect at least one input");
+
+  const Type firstInputType = inputList.front().getType();
+  const ShapeAdaptor firstInputShape(firstInputType);
+  const int32_t axis = getAxis();
+
+  if (firstInputShape.hasRank()) {
+    // Check axis is in expected range
+    if (axis < 0 || axis >= firstInputShape.getRank())
+      return emitOpError("expect axis to be within range 0 < axis < rank(input1[0]), got ")
+        << axis;
+  }
+
+  const auto allOperandsHasRank = [](const Value input) {
+    return ShapeAdaptor(input.getType()).hasRank();
+  };
+  if (llvm::all_of(inputList, allOperandsHasRank)) {
+    const int64_t firstInputRank = firstInputShape.getRank();
+
+    for (const auto [index, input] : llvm::enumerate(inputList.drop_front())) {
+      const ShapeAdaptor inputShape(input.getType());
+      const int64_t inputRank = inputShape.getRank();
+      const size_t operandNum = index + 1;
+
+      // Check that each operand has the same rank
+      if (inputRank != firstInputRank)
+        return emitOpError("expect all operands to have the same rank, but got ")
+          << firstInputRank << " vs " << inputRank << " on operands 0 and " << operandNum;
+
+      // Check non-axis dims match
+      for (int i = 0; i < inputRank; i++) {
+        const int64_t inputDim = inputShape.getDimSize(i);
+        const int64_t firstInputDim = firstInputShape.getDimSize(i);
+        if (i == axis || firstInputShape.isDynamicDim(i) || inputShape.isDynamicDim(i))
+          continue;
+        if (inputDim != firstInputDim)
+          return emitOpError("expect all operand shapes to have the same sizes on non-axis dimensions, but got ")
+            << inputDim << " vs " << firstInputDim << " at index " << i << " on operands 0 and " << operandNum;
+      }
+    }
+  }
+
+  return success();
+}
+
 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes,
@@ -898,6 +1055,107 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult MatMulOp::verify() {
+  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
+  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
+  auto resultEType =
+      llvm::cast<ShapedType>(getResult().getType()).getElementType();
+
+  // Must be shaped tensor types
+  if (!aType) {
+    emitOpError("expect a shaped tensor for input a, got ") << getA().getType();
+    return failure();
+  }
+  if (!bType) {
+    emitOpError("expect a shaped tensor for input b, got ") << getB().getType();
+    return failure();
+  }
+
+  auto aElementType = aType.getElementType();
+  auto bElementType = bType.getElementType();
+
+  auto aQuantizedEType =
+      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
+  auto bQuantizedEType =
+      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
+
+  if (aQuantizedEType || bQuantizedEType) {
+    if (!aQuantizedEType || !bQuantizedEType) {
+      emitOpError(
+          "expect operands to be both quantized or both not quantized, got ")
+          << aElementType << " and " << bElementType;
+      return failure();
+    }
+    // both a and b have quantized element types
+    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
+    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
+    if (aQuantWidth != bQuantWidth) {
+      emitOpError("expect quantized operands to have same widths, got ")
+          << aQuantWidth << " and " << bQuantWidth;
+      return failure();
+    }
+
+    if (aQuantWidth != 8 && aQuantWidth != 16) {
+      emitOpError("only support quantized types with width of 8 or 16, got ")
+          << aQuantWidth;
+      return failure();
+    }
+
+    // check result types
+    if (aQuantWidth == 8 && !resultEType.isInteger(32)) {
+      emitOpError("expect result element type to be i32, got ") << resultEType;
+      return failure();
+    }
+
+    // check result types
+    if (aQuantWidth == 16 && !resultEType.isInteger(48)) {
+      emitOpError("expect result element type to be i48, got ") << resultEType;
+      return failure();
+    }
+
+    return success();
+  }
+
+  // non-quantized element types
+
+  if (aElementType != bElementType) {
+    emitOpError("expect same element type for inputs a and b, got ")
+        << aElementType << " and " << bElementType;
+    return failure();
+  }
+  if (llvm::isa<Float8E5M2Type>(aElementType) ||
+      llvm::isa<Float8E4M3FNType>(aElementType)) {
+    if (!resultEType.isF16()) {
+      emitOpError("expect result element type to be f16, got ") << resultEType;
+      return failure();
+    }
+  }
+
+  if (aElementType.isInteger(8) && !resultEType.isInteger(32)) {
+    emitOpError("expect result element type to be i32, got ") << resultEType;
+    return failure();
+  }
+  if (aElementType.isInteger(16) && !resultEType.isInteger(48)) {
+    emitOpError("expect result element type to be i48, got ") << resultEType;
+    return failure();
+  }
+  if (aElementType.isF16() && !(resultEType.isF16() || resultEType.isF32())) {
+    emitOpError("expect result element type to be f16 or f32, got ")
+        << resultEType;
+    return failure();
+  }
+  if (aElementType.isBF16() && !resultEType.isF32()) {
+    emitOpError("expect result element type to be f32, got ") << resultEType;
+    return failure();
+  }
+  if (aElementType.isF32() && !resultEType.isF32()) {
+    emitOpError("expect result element type to be f32, got ") << resultEType;
+    return failure();
+  }
+
+  return success();
+}
+
 LogicalResult tosa::PadOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     PadOp::Adaptor adaptor,
@@ -946,6 +1204,18 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::PadOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
+  if (auto padConst = getPadConst()) {
+    if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
+                               /* outType = */ getOutput().getType())
+            .failed()) {
+      return failure();
+    }
+  }
   RankedTensorType inputType = getInput1().getType();
   RankedTensorType outputType = getOutput().getType();
   auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
@@ -1019,6 +1289,10 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::SliceOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed())
+    return failure();
   auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
   if (!inputType)
     return success();
@@ -1026,14 +1300,13 @@ LogicalResult tosa::SliceOp::verify() {
   auto startShapeRank =
       llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
   if (inputType.getRank() != startShapeRank)
-    return emitOpError(
-        "length of start attribute is not equal rank of input shape");
+    return emitOpError("length of start is not equal to rank of input shape");
+
 
   auto sizeShapeRank =
       llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
   if (inputType.getRank() != sizeShapeRank)
-    return emitOpError(
-        "length of size attribute is not equal rank of input shape");
+    return emitOpError("length of size is not equal to rank of input shape");
 
   return success();
 }
@@ -1238,6 +1511,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::TileOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
   ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
   ShapedType outputType = llvm::cast<ShapedType>(getType());
 
@@ -1319,6 +1597,11 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
 }
 
 llvm::LogicalResult tosa::ReshapeOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
   TensorType inputType = getInput1().getType();
   RankedTensorType outputType = getType();
 
@@ -1463,6 +1746,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::TransposeOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
   TensorType inputType = getInput1().getType();
   TensorType permType = getPerms().getType();
   TensorType outputType = getOutput().getType();
@@ -1578,6 +1866,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::GatherOp::verify() {
+  return verifySameElementTypes(*this, /* inType = */ getValues().getType(),
+                                /* outType = */ getOutput().getType());
+}
+
 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ResizeOp::Adaptor adaptor,
@@ -1746,6 +2039,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::ScatterOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
+                             /* outType = */ getValuesOut().getType())
+          .failed() ||
+      verifySameElementTypes(*this, /* inType = */ getInput().getType(),
+                             /* outType = */ getValuesOut().getType())
+          .failed()) {
+    return failure();
+  }
+  return success();
+}
+
 static LogicalResult ReduceInferReturnTypes(
     ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2110,6 +2415,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
                                  inferredReturnShapes);
 }
 
+LogicalResult MaxPool2dOp::verify() {
+  return verifySameElementTypes(*this, /* inType = */ getInput().getType(),
+                                /* outType = */ getOutput().getType());
+}
+
 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     DepthwiseConv2DOp::Adaptor adaptor,
@@ -2412,6 +2722,10 @@ void IfOp::print(OpAsmPrinter &p) {
 }
 
 LogicalResult ReverseOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed())
+    return failure();
   TensorType inputType = getInput1().getType();
   TensorType outputType = getOutput().getType();
   int32_t reverseAxis = getAxis();
@@ -2440,6 +2754,33 @@ LogicalResult ReverseOp::verify() {
   return success();
 }
 
+LogicalResult tosa::SelectOp::verify() {
+  // verify input2 and input3 have same element type as output
+  if (verifySameElementTypes(*this, /* inType = */ getInput2().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed() ||
+      verifySameElementTypes(*this, /* inType = */ getInput3().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
+  // verify input1 has element type of bool
+  auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
+  if (!predicateType) {
+    emitOpError("expect shaped tensor for input1, got ")
+        << getInput1().getType();
+    return failure();
+  }
+  auto predicateElementType = predicateType.getElementType();
+  if (!predicateElementType.isInteger(1)) {
+    emitOpError("expect element type of bool for input1, got ")
+        << predicateElementType;
+    return failure();
+  }
+
+  return success();
+}
+
 // parse and print of WhileOp refer to the implementation of SCF dialect.
 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::Argument, 4> regionArgs;
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 1307da88d1e64..4b3525aa005bc 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -165,8 +165,7 @@ func.func @test_conv2d_quant_any_result(%arg0: tensor<1x4x4x4x!quant.any<i8<-8:7
 // -----
 
 func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> {
-  // expected-error@+2 {{failed to infer returned types}}
-  // expected-error@+1 {{Cannot concat tensors with different sizes on the non-axis dimension 1}}
+  // expected-error@+1 {{'tosa.concat' op expect all operand shapes to have the same sizes on non-axis dimensions, but got 2 vs 1 at index 1 on operands 0 and 1}}
   %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
@@ -174,8 +173,7 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
 // -----
 
 func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xi8> {
-  // expected-error@+2 {{failed to infer returned types}}
-  // expected-error@+1 {{'tosa.concat' op inferred type(s) 'tensor<3x2xf32>' are incompatible with return type(s) of operation 'tensor<?x?xi8>}}
+  // expected-error@+1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
   %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<?x?xi8>
   return %0 : tensor<?x?xi8>
 }
@@ -208,6 +206,14 @@ func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>) {
 
 // -----
 
+func.func @test_concat_input_rank_mismatch(%arg0: tensor<1x...
[truncated]


github-actions (bot) commented Feb 20, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@Jerry-Ge force-pushed the verifiers branch 3 times, most recently from 5c2f632 to d8b401e on February 20, 2025 at 22:24
@lhutton1 (Contributor) left a comment

As a user of the TOSA dialect that stays within the specification, I'm okay with these data type checks. However, we might want to leave this PR open for a bit so other users of TOSA can confirm these checks aren't too restrictive for their workflows.

@Jerry-Ge force-pushed the verifiers branch 2 times, most recently from 2fe9dd3 to 97ba7a5 on February 25, 2025 at 22:13
@Jerry-Ge changed the title from "[mlir][tosa] Add additional input output dtype verifiers for the foll…" to "[mlir][tosa] Add more verifiers for the following operators" on Feb 25, 2025
@Jerry-Ge (Member, Author) commented Feb 25, 2025

Removed the verifier checks for Cast input/output types; those checks will need to move into the validation pass. cc @lhutton1

@Jerry-Ge force-pushed the verifiers branch 2 times, most recently from fc2791d to bb978ee on February 27, 2025 at 22:12
@GeorgeARM (Contributor) commented:

By the way, for future reference please try to split into separate patches where possible.

@lhutton1 (Contributor) left a comment

LGTM modulo @FranklandJack's comments. I won't explicitly approve since I authored a portion of this patch.

For ConcatOp this commit also enhances the verifier by
checking four additional conditions:
- The input list is not empty
- The axis value is within range of the input shapes
- All inputs have the same rank
- All non-axis dimensions have the same size

For MatmulOp:
- Checked the tensor types and element types of inputs a and b

For the following operators, added the verifySameElementTypes check.
- PadOp
- SliceOp
- TileOp
- ReshapeOp
- TransposeOp
- GatherOp
- ScatterOp
- MaxPool2dOp
- ReverseOp
- SelectOp

Change-Id: I1e8a1017f21f617443bc40bae42189915048c750
Co-authored-by: Tai Ly <[email protected]>
Co-authored-by: Luke Hutton <[email protected]>
Signed-off-by: Jerry Ge <[email protected]>
@Jerry-Ge merged commit db70d76 into llvm:main on Mar 5, 2025
11 checks passed