Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOSA] Add Tosa_Shape type and ConstShapeOp #122547

Merged
merged 1 commit into from
Jan 14, 2025
Merged

[TOSA] Add Tosa_Shape type and ConstShapeOp #122547

merged 1 commit into from
Jan 14, 2025

Conversation

Jerry-Ge
Copy link
Member

Adds:

  1. tosa shape type to Tosa dialect e.g., !tosa.shape<4> is a type for rank-4 shape values (size-4 array of index values)
  2. const_shape operator
  3. trait TosaShapeOperator, added to tosa shape operators, and a verifier that all operands and results of operator are tosa shapes
  4. trait TosaResolvableShapeOperands, added to all tosa operators, and a verifier that every tosa shape operand is produced by a tosa shape operator (indicated by trait TosaShapeOperator)
  5. trait TosaShapeOperatorWithSameRanks, added to Tosa_ElementwiseShapeOp and a verifier that all operands and result shapes have same ranks
  6. changed TileOp's multiples from attribute to input, of !tosa.shape type.
  7. add folder for tosa ConstShape operator

This patch was originally authored by Tai Ly [email protected]

Signed-off-by: Jerry Ge [email protected]
Signed-off-by: Tai Ly [email protected]

Change-Id: I0213f99f5816b648f732b01fe8bd196956f1dfc8

@llvmbot llvmbot added mlir:linalg mlir mlir:tosa bazel "Peripheral" support tier build system: utils/bazel labels Jan 10, 2025
@Jerry-Ge
Copy link
Member Author

cc @Tai78641 @sjarus @eric-k256

@llvmbot
Copy link
Member

llvmbot commented Jan 10, 2025

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Jerry-Ge (Jerry-Ge)

Changes

Adds:

  1. tosa shape type to Tosa dialect e.g., !tosa.shape<4> is a type for rank-4 shape values (size-4 array of index values)
  2. const_shape operator
  3. trait TosaShapeOperator, added to tosa shape operators, and a verifier that all operands and results of operator are tosa shapes
  4. trait TosaResolvableShapeOperands, added to all tosa operators, and a verifier that every tosa shape operand is produced by a tosa shape operator (indicated by trait TosaShapeOperator)
  5. trait TosaShapeOperatorWithSameRanks, added to Tosa_ElementwiseShapeOp and a verifier that all operands and result shapes have same ranks
  6. changed TileOp's multiples from attribute to input, of !tosa.shape type.
  7. add folder for tosa ConstShape operator

This patch was originally authored by Tai Ly <[email protected]>

Signed-off-by: Jerry Ge <[email protected]>
Signed-off-by: Tai Ly <[email protected]>

Change-Id: I0213f99f5816b648f732b01fe8bd196956f1dfc8


Patch is 33.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122547.diff

18 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt (+2-1)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+1)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h (+43)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+8-1)
  • (added) mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td (+79)
  • (added) mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.td (+87)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+3-1)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp (+1)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+16-3)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+134-9)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+2)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+10-5)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+6-3)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+29-4)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+2-2)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+9-1)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+4-2)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+8)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
index 1ee105f0ceb98b..81c0f2ef159e82 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
@@ -3,6 +3,8 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc)
 add_mlir_interface(TosaInterfaces)
 
 set(LLVM_TARGET_DEFINITIONS TosaOps.td)
+mlir_tablegen(TosaOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=tosa)
+mlir_tablegen(TosaOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=tosa)
 mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa)
 mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa)
 add_public_tablegen_target(MLIRTosaAttributesIncGen)
@@ -10,4 +12,3 @@ add_public_tablegen_target(MLIRTosaAttributesIncGen)
 set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)
 mlir_tablegen(TosaDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Tosa")
 add_public_tablegen_target(MLIRTosaDialectBytecodeIncGen)
-
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index d3f12c34421b06..a66c8b1975e5ba 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -45,6 +45,7 @@ def Tosa_Dialect : Dialect {
   let cppNamespace = "mlir::tosa";
   let hasConstantMaterializer = 1;
   let useDefaultAttributePrinterParser = 1;
+  let useDefaultTypePrinterParser = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 66512cbe350ec8..f00cb4c282db88 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -90,14 +90,57 @@ template <typename ConcreteType>
 class TosaElementwiseOperator
     : public TraitBase<ConcreteType, TosaElementwiseOperator> {};
 
+LogicalResult verifyTosaResolvableShapeOperands(Operation *op);
+/// This class verifies that tosa shape operands are compile time resolvable
+template <typename ConcreteType>
+class TosaResolvableShapeOperands
+    : public TraitBase<ConcreteType, TosaResolvableShapeOperands> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return verifyTosaResolvableShapeOperands(op);
+  }
+};
+
+
+LogicalResult verifyTosaShapeOperator(Operation *op);
+/// This class indicates that op operates on tosa shape types
+template <typename ConcreteType>
+class TosaShapeOperator : public TraitBase<ConcreteType, TosaShapeOperator> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return verifyTosaShapeOperator(op);
+  }
+};
+
+
+LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op);
+/// This class indicates that op operates on tosa shape types
+template <typename ConcreteType>
+class TosaShapeOperatorWithSameRanks
+    : public TraitBase<ConcreteType, TosaShapeOperatorWithSameRanks> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return verifyTosaShapeOperatorWithSameRanks(op);
+  }
+};
+
 } // namespace tosa
 } // namespace OpTrait
 
+namespace tosa {
+
+bool isa_tosa_shape_type(mlir::Type t);
+
+} // namespace tosa
+
 } // namespace mlir
 
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/Tosa/IR/TosaAttributes.h.inc"
 
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Tosa/IR/TosaOpsTypes.h.inc"
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 6b43c9a259b108..718ca361c05469 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -23,6 +23,7 @@ include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
 
 include "mlir/Dialect/Tosa/IR/TosaTypesBase.td"
 include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
+include "mlir/Dialect/Tosa/IR/TosaTypes.td"
 
 //===----------------------------------------------------------------------===//
 // TOSA Spec Section 2.2
@@ -1689,12 +1690,16 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {
 
   let arguments = (ins
     Tosa_Tensor:$input1,
-    DenseI64ArrayAttr:$multiples);
+    Tosa_Shape:$multiples);
 
   let results = (outs
     Tosa_Tensor:$output
   );
 
+  let extraClassDeclaration = [{
+    LogicalResult getConstantMultiples(llvm::SmallVector<int64_t> &multiples);
+  }];
+
   let hasFolder = 1;
   let hasVerifier = 1;
 }
@@ -2106,4 +2111,6 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
 
 include "mlir/Dialect/Tosa/IR/TosaUtilOps.td"
 
+include "mlir/Dialect/Tosa/IR/TosaShapeOps.td"
+
 #endif // TOSA_OPS
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
new file mode 100644
index 00000000000000..aacb04f77ce0e9
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
@@ -0,0 +1,79 @@
+//===-- TosaShapeOps.td - TOSA dialect utility operations --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines shape operators for the TOSA dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TOSA_SHAPE_OPS
+#define TOSA_SHAPE_OPS
+
+include "mlir/IR/OpBase.td"
+
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
+include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
+
+include "mlir/Dialect/Tosa/IR/TosaTypesBase.td"
+include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
+include "mlir/Dialect/Tosa/IR/TosaTypes.td"
+
+// Op trait: operator has operands and results with TOSA shape type
+def TosaShapeOperator : NativeOpTrait<"TosaShapeOperator"> {
+  let cppNamespace = "mlir::OpTrait::tosa";
+}
+
+class Tosa_ShapeOp<string mnemonic, list<Trait> traits = []>
+    : Tosa_Op<mnemonic, !listconcat(traits, [TosaShapeOperator, Pure])> {
+
+  let assemblyFormat =
+      "operands attr-dict `:` functional-type(operands, results)";
+
+  let hasFolder = 1;
+}
+
+// op trait: shape operator has same ranks for operands and results
+def TosaShapeOperatorWithSameRanks : NativeOpTrait<"TosaShapeOperatorWithSameRanks"> {
+  let cppNamespace = "mlir::OpTrait::tosa";
+}
+
+class Tosa_ElementwiseShapeOp<string mnemonic, list<Trait> traits = []>
+    : Tosa_ShapeOp<mnemonic, !listconcat(traits, [TosaShapeOperatorWithSameRanks])> {
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: ConstShape
+//===----------------------------------------------------------------------===//
+def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike, Pure]> {
+  let summary = "Constant Shape op.";
+
+  let description = [{
+    A node containing constant data for use as the input to an shape operation. May
+    hold data only in index data type.
+
+    Example:
+
+    ```mlir
+    // Generic form
+    %out = "tosa.const_shape"() {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+    ```
+  }];
+
+  let arguments = (ins
+    IndexElementsAttr:$value
+  );
+
+  let results = (outs
+    Tosa_Shape:$output
+  );
+
+  let hasVerifier = 1;
+}
+
+#endif // TOSA_SHAPE_OPS
\ No newline at end of file
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.td
new file mode 100644
index 00000000000000..480248a3216af7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.td
@@ -0,0 +1,87 @@
+//===-- TosaTypes.td - TOSA type definitions ---------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the type definitions for the TOSA dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TOSA_TYPES
+#define TOSA_TYPES
+
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/OpBase.td"
+
+include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Tosa Type Definitions.
+//===----------------------------------------------------------------------===//
+
+// The base class for Tosa dialect types.
+class Tosa_Type<string name, string typeMnemonic, list<Trait> traits = []>
+    : TypeDef<Tosa_Dialect, name, traits> {
+  let mnemonic = typeMnemonic;
+}
+
+//===----------------------------------------------------------------------===//
+// ShapeType
+//===----------------------------------------------------------------------===//
+def Tosa_Shape : Tosa_Type<"shape", "shape"> {
+  let summary = "Shape with static rank and Index element type";
+  let description = [{
+    Syntax:
+
+    ```
+    shape-type ::= `shape` `<` rank `>`
+    ```
+    Values with shape type represents a shape with a fixed rank and a list of dimensions.
+    Rank must be zero or a positive integer.
+    Each dimension is represented by the builtin Index type.
+
+    Examples:
+
+    ```mlir
+    // Shape with rank of four, for example, [1, 1, 8, 16]:
+    !tosa.shape<4>
+
+    // Shape with rank of one, for example, [16]:
+    !tosa.shape<1>
+
+    // Shape with rank zero, for example, [] (i.e., shape of scalar values):
+    !tosa.shape<0>
+    ```
+  }];
+  let parameters = (ins
+    "int":$rank
+  );
+  let builders = [
+   TypeBuilder<(ins "int":$rank)>
+  ];
+  let assemblyFormat = "`<` $rank `>`";
+
+  let genVerifyDecl = 1;
+}
+
+def IsTosaShapeType : CPred<"mlir::tosa::isa_tosa_shape_type($_self)">;
+
+// Whether a Tosa Shape type has a rank equal to the specified rank.
+class IsTosaShapeOfRankPred<int rank> : And<[
+    IsTosaShapeType,
+    CPred<[{::llvm::cast<::mlir::tosa::shapeType>($_self).getRank() == }] # rank>
+]>;
+
+class TosaShapeOfRank<int rank> :
+    Type<IsTosaShapeOfRankPred<rank>,
+         "Tosa shape type of rank " # rank
+         >;
+
+def Rank1TosaShape : TosaShapeOfRank<1>;
+def Rank2TosaShape : TosaShapeOfRank<2>;
+def Rank4TosaShape : TosaShapeOfRank<4>;
+
+#endif // TOSA_TYPES
\ No newline at end of file
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 1d7ead16e8b631..9295afd36e3ab1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1886,7 +1886,9 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
     auto elementTy = inputTy.getElementType();
     int64_t rank = inputTy.getRank();
 
-    ArrayRef<int64_t> multiples = op.getMultiples();
+    SmallVector<int64_t> multiples;
+    if (failed(op.getConstantMultiples(multiples)))
+      return failure();
 
     // Broadcast the newly added dimensions to their appropriate multiple.
     SmallVector<int64_t, 2> genericShape;
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 06a7262c467421..8dfa55bef74fc4 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -55,6 +55,7 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
     target.addLegalOp<tosa::ApplyScaleOp>();
     target.addLegalOp<tosa::IfOp>();
     target.addLegalOp<tosa::ConstOp>();
+    target.addLegalOp<tosa::ConstShapeOp>();
     target.addLegalOp<tosa::WhileOp>();
     target.addLegalOp<tosa::ConcatOp>();
     target.addLegalOp<tosa::SliceOp>();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index f51c3dbce6eefe..f7a596f1ccb192 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -808,6 +808,8 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
 
+OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
+
 #define REDUCE_FOLDER(OP)                                                      \
   OpFoldResult OP::fold(FoldAdaptor adaptor) {                                 \
     ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType());         \
@@ -985,9 +987,20 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
 }
 
 OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
-  bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
-  if (allOnes && getInput1().getType() == getType())
-    return getInput1();
+  if (getInput1().getType() == getType()) {
+    if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
+            adaptor.getMultiples())) {
+      if (multiples.isSplat() &&
+          multiples.getSplatValue<APInt>().getSExtValue() == 1)
+        return getInput1();
+      if (auto int_array_attr =
+              llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
+        if (llvm::all_of(int_array_attr.getValues<APInt>(),
+                         [](APInt v) { return v.getSExtValue() == 1; }))
+          return getInput1();
+      }
+    }
+  }
   return {};
 }
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 764a5db48e0787..3354f6908c55cd 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -130,6 +130,10 @@ SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }
 //===----------------------------------------------------------------------===//
 
 void TosaDialect::initialize() {
+  addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/Tosa/IR/TosaOpsTypes.cpp.inc"
+      >();
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
@@ -153,6 +157,10 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                             Type type, Location loc) {
   // Tosa dialect constants only support ElementsAttr unlike standard dialect
   // constant which supports all attributes.
+  if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
+    return builder.create<tosa::ConstShapeOp>(
+        loc, type, llvm::cast<DenseIntElementsAttr>(value));
+  }
   if (llvm::isa<ElementsAttr>(value))
     return builder.create<tosa::ConstOp>(loc, type,
                                          llvm::cast<ElementsAttr>(value));
@@ -962,11 +970,32 @@ LogicalResult tosa::TableOp::verify() {
   return success();
 }
 
+LogicalResult
+tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
+  // Multiples must be constants.
+  DenseIntElementsAttr multiplesAttr;
+  if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
+    return failure();
+  multiples = llvm::to_vector(
+      llvm::map_range(multiplesAttr.getValues<APInt>(),
+                      [](const APInt &val) { return val.getSExtValue(); }));
+  return success();
+}
+
 LogicalResult tosa::TileOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     TileOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  ArrayRef<int64_t> multiples = adaptor.getMultiples();
+  DenseIntElementsAttr multiplesAttr;
+  if (!matchPattern(adaptor.getMultiples(), m_Constant(&multiplesAttr)))
+    return failure();
+
+
+  // ArrayRef<int64_t> multiples = adaptor.getMultiples();
+  SmallVector<int64_t> multiples = llvm::to_vector(
+      llvm::map_range(multiplesAttr.getValues<APInt>(),
+                      [](const APInt &val) { return val.getSExtValue(); }));
+
   ShapeAdaptor inputShape(adaptor.getInput1().getType());
   SmallVector<int64_t> outputShape;
   if (!inputShape.hasRank()) {
@@ -992,20 +1021,25 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
 LogicalResult tosa::TileOp::verify() {
   ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
   ShapedType outputType = llvm::cast<ShapedType>(getType());
-  auto multiples = getMultiples();
+
+  shapeType multiplesType =
+      llvm::cast<tosa::shapeType>(getMultiples().getType());
+
+  auto multiplesRank = multiplesType.getRank();
 
   if (inputType.hasRank()) {
-    if (static_cast<size_t>(inputType.getRank()) != multiples.size())
-      return emitOpError("expect 'multiples' array to have length ")
-             << inputType.getRank() << " but got " << multiples.size() << ".";
+    if (inputType.getRank() != multiplesRank)
+      return emitOpError("expect 'multiples' to have rank ")
+             << inputType.getRank() << " but got " << multiplesRank << ".";
     if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
       return emitOpError("expect same input and output tensor rank.");
-  } else if (outputType.hasRank() &&
-             static_cast<size_t>(outputType.getRank()) != multiples.size())
+  } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
     return emitOpError("expect 'multiples' array to have length ")
-           << outputType.getRank() << " but got " << multiples.size() << ".";
+           << outputType.getRank() << " but got " << multiplesRank << ".";
 
-  if (llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
+  SmallVector<int64_t> multiples;
+  if (getConstantMultiples(multiples).succeeded() &&
+      llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
     return emitOpError(
         "expect element of 'multiples' to be positive integer or -1.");
 
@@ -2146,6 +2180,91 @@ void WhileOp::print(OpAsmPrinter &parser) {
   parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
 }
 
+//===----------------------------------------------------------------------===//
+// TOSA Shape and Shape Operators Helper functions.
+//===----------------------------------------------------------------------===//
+
+bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
+  return mlir::isa<tosa::shapeType>(t);
+}
+
+LogicalResult
+mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
+                              int rank) {
+  if (rank < 0)
+    return emitError() << "invalid rank (must be >= 0): " << rank;
+  return success();
+}
+
+LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
+  for (auto v : op->getOperands()) {
+    if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
+      Operation *definingOp = v.getDefiningOp();
+      if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
+        return op->emitOpError("shape operand is not compile time resolvable");
+      }
+    }
+  }
+  return success();
+}
+
+LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
+  for (auto type : op->getOperandTypes()) {
+    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
+      return op->emitOpError("must have operands with tosa shape type");
+    }
+  }
+  for (auto type : op->getResultTypes()) {
+    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
+      return op->emitOpError("must have result with tosa shape type");
+    }
+  }
+  return success();
+}
+
+LogicalResult
+OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
+  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
+      failed(verifyTosaShapeOperator(op)))
+    return failure();
+
+  // delegate function that returns rank of shape type
+  auto getRank = [](const Type type) {
+    return mlir::cast<mlir::tosa::shapeType>(type).getRank();
+  };
+  auto operandTypes = op->getOperandTypes();
+  auto resultTypes = op->getResultTypes();
+
+  auto rank = getRank(*op->getOperandTypes().begin());
+  for (auto type : operandTypes) {
+    if (getRank(type) != rank) {
+      return op->emitOpError("operands don't have matching ranks");
+    }
+  }
+  for (auto type : resultTypes) {
+    if (getRank(type) != rank) {
+      return op->emitOpError("result shape has different rank than operands");
+    }
+  }
+  ret...
[truncated]

@Jerry-Ge Jerry-Ge requested review from eric-k256 and sjarus January 10, 2025 23:09
Copy link

github-actions bot commented Jan 10, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@eric-k256 eric-k256 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Jerry-Ge
Copy link
Member Author

Pushed a new patch to refactor TosaTypes.td content into TosaTypesBase.td

Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR @Jerry-Ge! I had a look and left a few comments below, largely they're just questions along with some requests for better test coverage

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td Show resolved Hide resolved
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp Outdated Show resolved Hide resolved
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp Show resolved Hide resolved
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h Show resolved Hide resolved
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp Show resolved Hide resolved
@Jerry-Ge
Copy link
Member Author

Thanks for the PR @Jerry-Ge! I had a look and left a few comments below, largely they're just questions along with some requests for better test coverage

Thanks Luke! This patch is primarily for aligning the internal and upstream branches. To keep things clean, I will achieve this in multiple patches.

Adds:
 1. tosa shape type to Tosa dialect e.g., !tosa.shape<4> is a type for
    rank-4 shape values (size-4 array of index values)
 2. const_shape operator
 3. trait TosaShapeOperator, added to tosa shape operators, and a
    verifier that all operands and results of operator are tosa shapes
 4. trait TosaResolvableShapeOperands, added to all tosa operators, and
    a verifier that every tosa shape operand is produced by a tosa shape
    operator (indicated by trait TosaShapeOperator)
 5. trait TosaShapeOperatorWithSameRanks, added to
    Tosa_ElementwiseShapeOp and a verifier that all operands and result
    shapes have same ranks
 5. changed TileOp's multiples from attribute to input, of !tosa.shape
    type.
 6. add folder for tosa ConstShape operator

Signed-off-by: Jerry Ge <[email protected]>
Signed-off-by: Tai Ly <[email protected]>

Change-Id: I0213f99f5816b648f732b01fe8bd196956f1dfc8
@Jerry-Ge Jerry-Ge merged commit f09db6a into llvm:main Jan 14, 2025
8 checks passed
paulhuggett pushed a commit to paulhuggett/llvm-project that referenced this pull request Jan 16, 2025
Adds:
1. tosa shape type to Tosa dialect e.g., !tosa.shape<4> is a type for
rank-4 shape values (size-4 array of index values)
 2. const_shape operator
3. trait TosaShapeOperator, added to tosa shape operators, and a
verifier that all operands and results of operator are tosa shapes
4. trait TosaResolvableShapeOperands, added to all tosa operators, and a
verifier that every tosa shape operand is produced by a tosa shape
operator (indicated by trait TosaShapeOperator)
5. trait TosaShapeOperatorWithSameRanks, added to
Tosa_ElementwiseShapeOp and a verifier that all operands and result
shapes have same ranks
5. changed TileOp's multiples from attribute to input, of !tosa.shape
type.
 6. add folder for tosa ConstShape operator

This patch was originally authored by Tai Ly <[email protected]>

Signed-off-by: Jerry Ge <[email protected]>
Signed-off-by: Tai Ly <[email protected]>
@GeorgeARM
Copy link
Contributor

Thanks for the PR @Jerry-Ge! I had a look and left a few comments below, largely they're just questions along with some requests for better test coverage

Thanks Luke! This patch is primarily for aligning the internal and upstream branches. To keep things clean, I will achieve this in multiple patches.

Great work @Jerry-Ge,
It would be great though to keep patches nicely self-contained and layered. So please let's try and do this. If something can be fixed in this patch we should try and do it; not let it spill over to next ones.

@Jerry-Ge
Copy link
Member Author

Thanks for the PR @Jerry-Ge! I had a look and left a few comments below, largely they're just questions along with some requests for better test coverage

Thanks Luke! This patch is primarily for aligning the internal and upstream branches. To keep things clean, I will achieve this in multiple patches.

Great work @Jerry-Ge, It would be great though to keep patches nicely self-contained and layered. So please let's try and do this. If something can be fixed in this patch we should try and do it; not let it spill over to next ones.

Thanks for the comments. Yes, I will try my best to achieve that.

DKLoehr pushed a commit to DKLoehr/llvm-project that referenced this pull request Jan 17, 2025
Adds:
1. tosa shape type to Tosa dialect e.g., !tosa.shape<4> is a type for
rank-4 shape values (size-4 array of index values)
 2. const_shape operator
3. trait TosaShapeOperator, added to tosa shape operators, and a
verifier that all operands and results of operator are tosa shapes
4. trait TosaResolvableShapeOperands, added to all tosa operators, and a
verifier that every tosa shape operand is produced by a tosa shape
operator (indicated by trait TosaShapeOperator)
5. trait TosaShapeOperatorWithSameRanks, added to
Tosa_ElementwiseShapeOp and a verifier that all operands and result
shapes have same ranks
5. changed TileOp's multiples from attribute to input, of !tosa.shape
type.
 6. add folder for tosa ConstShape operator

This patch was originally authored by Tai Ly <[email protected]>

Signed-off-by: Jerry Ge <[email protected]>
Signed-off-by: Tai Ly <[email protected]>
nirvedhmeshram added a commit to iree-org/llvm-project that referenced this pull request Jan 20, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bazel "Peripheral" support tier build system: utils/bazel mlir:linalg mlir:tosa mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants