From f09db6a3af971ab7d9bbc7ba574a8dc0c10b2940 Mon Sep 17 00:00:00 2001 From: Jerry-Ge Date: Tue, 14 Jan 2025 13:41:08 -0800 Subject: [PATCH] [TOSA] Add Tosa_Shape type and ConstShapeOp (#122547) Adds: 1. tosa shape type to Tosa dialect e.g., !tosa.shape<4> is a type for rank-4 shape values (size-4 array of index values) 2. const_shape operator 3. trait TosaShapeOperator, added to tosa shape operators, and a verifier that all operands and results of operator are tosa shapes 4. trait TosaResolvableShapeOperands, added to all tosa operators, and a verifier that every tosa shape operand is produced by a tosa shape operator (indicated by trait TosaShapeOperator) 5. trait TosaShapeOperatorWithSameRanks, added to Tosa_ElementwiseShapeOp and a verifier that all operands and result shapes have same ranks 5. changed TileOp's multiples from attribute to input, of !tosa.shape type. 6. add folder for tosa ConstShape operator This patch was originally authored by Tai Ly Signed-off-by: Jerry Ge Signed-off-by: Tai Ly --- .../mlir/Dialect/Tosa/IR/CMakeLists.txt | 3 +- .../mlir/Dialect/Tosa/IR/TosaOpBase.td | 12 +- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h | 41 +++++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 8 +- .../mlir/Dialect/Tosa/IR/TosaShapeOps.td | 77 ++++++++++ .../mlir/Dialect/Tosa/IR/TosaTypesBase.td | 65 ++++++++ .../Conversion/TosaToLinalg/TosaToLinalg.cpp | 4 +- .../TosaToLinalg/TosaToLinalgPass.cpp | 1 + .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 19 ++- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 141 ++++++++++++++++-- .../Tosa/Transforms/TosaValidation.cpp | 2 + .../TosaToLinalg/tosa-to-linalg.mlir | 15 +- mlir/test/Dialect/Tosa/canonicalize.mlir | 9 +- mlir/test/Dialect/Tosa/invalid.mlir | 33 +++- mlir/test/Dialect/Tosa/level_check.mlir | 4 +- mlir/test/Dialect/Tosa/ops.mlir | 10 +- mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 6 +- .../llvm-project-overlay/mlir/BUILD.bazel | 8 + 18 files changed, 425 insertions(+), 33 deletions(-) create mode 100644 mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td diff --git a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt index 1ee105f0ceb98..cc8d5ed9b0044 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt @@ -3,6 +3,8 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc) add_mlir_interface(TosaInterfaces) set(LLVM_TARGET_DEFINITIONS TosaOps.td) +mlir_tablegen(TosaOpsTypesBase.h.inc -gen-typedef-decls -typedefs-dialect=tosa) +mlir_tablegen(TosaOpsTypesBase.cpp.inc -gen-typedef-defs -typedefs-dialect=tosa) mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa) mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa) add_public_tablegen_target(MLIRTosaAttributesIncGen) @@ -10,4 +12,3 @@ add_public_tablegen_target(MLIRTosaAttributesIncGen) set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td) mlir_tablegen(TosaDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Tosa") add_public_tablegen_target(MLIRTosaDialectBytecodeIncGen) - diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index d3f12c34421b0..47cda3c9f481e 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -45,6 +45,7 @@ def Tosa_Dialect : Dialect { let cppNamespace = "mlir::tosa"; let hasConstantMaterializer = 1; let useDefaultAttributePrinterParser = 1; + let useDefaultTypePrinterParser = 1; } //===----------------------------------------------------------------------===// @@ -217,12 +218,21 @@ def TosaElementwiseOperator : NativeOpTrait<"TosaElementwiseOperator"> { let cppNamespace = "mlir::OpTrait::tosa"; } +//===----------------------------------------------------------------------===// +// TOSA Operator Trait. +//===----------------------------------------------------------------------===// +// Op operands with TOSA shape types must be compile time resolvable +def TosaResolvableShapeOperands : NativeOpTrait<"TosaResolvableShapeOperands"> { + let cppNamespace = "mlir::OpTrait::tosa"; +} + //===----------------------------------------------------------------------===// // TOSA Operator Class. //===----------------------------------------------------------------------===// class Tosa_Op traits = []> : - Op { + Op { } class Tosa_ElementwiseOp traits = []> : diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h index 66512cbe350ec..e4f5d09064cd7 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -90,14 +90,55 @@ template class TosaElementwiseOperator : public TraitBase {}; +LogicalResult verifyTosaResolvableShapeOperands(Operation *op); +/// This class verifies that tosa shape operands are compile time resolvable +template +class TosaResolvableShapeOperands + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return verifyTosaResolvableShapeOperands(op); + } +}; + +LogicalResult verifyTosaShapeOperator(Operation *op); +/// This class indicates that op operates on tosa shape types +template +class TosaShapeOperator : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return verifyTosaShapeOperator(op); + } +}; + +LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op); +/// This class indicates that op operates on tosa shape types +template +class TosaShapeOperatorWithSameRanks + : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return verifyTosaShapeOperatorWithSameRanks(op); + } +}; + } // namespace tosa } // namespace OpTrait +namespace tosa { + +bool isa_tosa_shape_type(mlir::Type t); + +} // namespace tosa + } // namespace mlir #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/Tosa/IR/TosaAttributes.h.inc" +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.h.inc" + #define GET_OP_CLASSES #include "mlir/Dialect/Tosa/IR/TosaOps.h.inc" diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 6b43c9a259b10..e1efa7a3001b9 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1689,12 +1689,16 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> { let arguments = (ins Tosa_Tensor:$input1, - DenseI64ArrayAttr:$multiples); + Tosa_Shape:$multiples); let results = (outs Tosa_Tensor:$output ); + let extraClassDeclaration = [{ + LogicalResult getConstantMultiples(llvm::SmallVector &multiples); + }]; + let hasFolder = 1; let hasVerifier = 1; } @@ -2106,4 +2110,6 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [ include "mlir/Dialect/Tosa/IR/TosaUtilOps.td" +include "mlir/Dialect/Tosa/IR/TosaShapeOps.td" + #endif // TOSA_OPS diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td new file mode 100644 index 0000000000000..597dc32e84402 --- /dev/null +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td @@ -0,0 +1,77 @@ +//===-- TosaShapeOps.td - TOSA dialect utility operations --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines shape operators for the TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TOSA_SHAPE_OPS +#define TOSA_SHAPE_OPS + +include "mlir/IR/OpBase.td" + +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/LoopLikeInterface.td" +include "mlir/Dialect/Tosa/IR/TosaInterfaces.td" + +include "mlir/Dialect/Tosa/IR/TosaTypesBase.td" +include "mlir/Dialect/Tosa/IR/TosaOpBase.td" + +// Op trait: operator has operands and results with TOSA shape type +def TosaShapeOperator : NativeOpTrait<"TosaShapeOperator"> { + let cppNamespace = "mlir::OpTrait::tosa"; +} + +class Tosa_ShapeOp traits = []> + : Tosa_Op { + + let assemblyFormat = + "operands attr-dict `:` functional-type(operands, results)"; + + let hasFolder = 1; +} + +// op trait: shape operator has same ranks for operands and results +def TosaShapeOperatorWithSameRanks + : NativeOpTrait<"TosaShapeOperatorWithSameRanks"> { + let cppNamespace = "mlir::OpTrait::tosa"; +} + +class Tosa_ElementwiseShapeOp traits = []> + : Tosa_ShapeOp { +} + + +//===----------------------------------------------------------------------===// +// Operator: ConstShape +//===----------------------------------------------------------------------===// +def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike, Pure]> { + let summary = "Constant Shape op."; + + let description = [{ + A node containing constant data for use as the input to an shape operation. May + hold data only in index data type. + + Example: + + ```mlir + // Generic form + %out = "tosa.const_shape"() {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> + ``` + }]; + + let arguments = (ins IndexElementsAttr : $value); + + let results = (outs Tosa_Shape : $output); + + let hasVerifier = 1; +} + +#endif // TOSA_SHAPE_OPS diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index d3cc6e92bac22..13325fb0ab9a2 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -13,8 +13,11 @@ #ifndef TOSA_TYPES_BASE #define TOSA_TYPES_BASE +include "mlir/IR/AttrTypeBase.td" include "mlir/IR/OpBase.td" +include "mlir/Dialect/Tosa/IR/TosaOpBase.td" + //===----------------------------------------------------------------------===// // Tosa Type Definitions. //===----------------------------------------------------------------------===// @@ -215,4 +218,66 @@ def Tosa_Buffer : MemRefOf<[Tosa_AnyNumber]>; def Tosa_TupleBuffer : NestedTupleOf<[Tosa_Buffer]>; def Tosa_BufOrTuple : AnyTypeOf<[Tosa_Buffer, Tosa_TupleBuffer]>; +//===----------------------------------------------------------------------===// +// Tosa Type Definitions. +//===----------------------------------------------------------------------===// + +// The base class for Tosa dialect types. +class Tosa_Type traits = []> + : TypeDef { + let mnemonic = typeMnemonic; +} + +//===----------------------------------------------------------------------===// +// ShapeType +//===----------------------------------------------------------------------===// +def Tosa_Shape : Tosa_Type<"shape", "shape"> { + let summary = "Shape with static rank and Index element type"; + let description = [{ + Syntax: + + ``` shape - type :: = `shape` `<` rank `>` + ``` Values with shape type represents a shape with a fixed rank and a list + of dimensions + .Rank must be zero or a positive integer + .Each dimension is represented by the builtin + Index type. + + Examples: + + ```mlir + // Shape with rank of four, for example, [1, 1, 8, 16]: + !tosa + .shape<4> + + // Shape with rank of one, for example, [16]: + !tosa + .shape<1> + + // Shape with rank zero, for example, [] (i.e., shape of scalar values): + !tosa.shape<0> + ``` + }]; + let parameters = (ins "int" : $rank); + let builders = [TypeBuilder<(ins "int" : $rank)>]; + let assemblyFormat = "`<` $rank `>`"; + + let genVerifyDecl = 1; +} + +def IsTosaShapeType : CPred<"mlir::tosa::isa_tosa_shape_type($_self)">; + +// Whether a Tosa Shape type has a rank equal to the specified rank. +class IsTosaShapeOfRankPred : And<[ + IsTosaShapeType, + CPred<[{::llvm::cast<::mlir::tosa::shapeType>($_self).getRank() == }] # rank> +]>; + +class TosaShapeOfRank + : Type, "Tosa shape type of rank " #rank>; + +def Rank1TosaShape : TosaShapeOfRank<1>; +def Rank2TosaShape : TosaShapeOfRank<2>; +def Rank4TosaShape : TosaShapeOfRank<4>; + #endif // TOSA_TYPES_BASE diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 1d7ead16e8b63..9295afd36e3ab 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1886,7 +1886,9 @@ struct TileConverter : public OpConversionPattern { auto elementTy = inputTy.getElementType(); int64_t rank = inputTy.getRank(); - ArrayRef multiples = op.getMultiples(); + SmallVector multiples; + if (failed(op.getConstantMultiples(multiples))) + return failure(); // Broadcast the newly added dimensions to their appropriate multiple. SmallVector genericShape; diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp index 06a7262c46742..8dfa55bef74fc 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -55,6 +55,7 @@ struct TosaToLinalg : public impl::TosaToLinalgBase { target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); + target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index f51c3dbce6eef..f7a596f1ccb19 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -808,6 +808,8 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) { OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } +OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } + #define REDUCE_FOLDER(OP) \ OpFoldResult OP::fold(FoldAdaptor adaptor) { \ ShapedType inputTy = llvm::cast(getInput().getType()); \ @@ -985,9 +987,20 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) { } OpFoldResult TileOp::fold(FoldAdaptor adaptor) { - bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; }); - if (allOnes && getInput1().getType() == getType()) - return getInput1(); + if (getInput1().getType() == getType()) { + if (auto multiples = llvm::dyn_cast_if_present( + adaptor.getMultiples())) { + if (multiples.isSplat() && + multiples.getSplatValue().getSExtValue() == 1) + return getInput1(); + if (auto int_array_attr = + llvm::dyn_cast(multiples)) { + if (llvm::all_of(int_array_attr.getValues(), + [](APInt v) { return v.getSExtValue() == 1; })) + return getInput1(); + } + } + } return {}; } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 764a5db48e078..83cf4a9415fe6 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -130,6 +130,10 @@ SmallVector tosa::WhileOp::getLoopRegions() { return {&getBody()}; } //===----------------------------------------------------------------------===// void TosaDialect::initialize() { + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc" + >(); addOperations< #define GET_OP_LIST #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" @@ -153,6 +157,10 @@ Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { // Tosa dialect constants only support ElementsAttr unlike standard dialect // constant which supports all attributes. + if (llvm::isa(type) && llvm::isa(value)) { + return builder.create( + loc, type, llvm::cast(value)); + } if (llvm::isa(value)) return builder.create(loc, type, llvm::cast(value)); @@ -962,11 +970,30 @@ LogicalResult tosa::TableOp::verify() { return success(); } +LogicalResult +tosa::TileOp::getConstantMultiples(SmallVector &multiples) { + // Multiples must be constants. + DenseIntElementsAttr multiplesAttr; + if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr))) + return failure(); + multiples = llvm::to_vector( + llvm::map_range(multiplesAttr.getValues(), + [](const APInt &val) { return val.getSExtValue(); })); + return success(); +} + LogicalResult tosa::TileOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, TileOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { - ArrayRef multiples = adaptor.getMultiples(); + DenseIntElementsAttr multiplesAttr; + if (!matchPattern(adaptor.getMultiples(), m_Constant(&multiplesAttr))) + return failure(); + + SmallVector multiples = llvm::to_vector( + llvm::map_range(multiplesAttr.getValues(), + [](const APInt &val) { return val.getSExtValue(); })); + ShapeAdaptor inputShape(adaptor.getInput1().getType()); SmallVector outputShape; if (!inputShape.hasRank()) { @@ -992,20 +1019,25 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents( LogicalResult tosa::TileOp::verify() { ShapedType inputType = llvm::cast(getInput1().getType()); ShapedType outputType = llvm::cast(getType()); - auto multiples = getMultiples(); + + shapeType multiplesType = + llvm::cast(getMultiples().getType()); + + auto multiplesRank = multiplesType.getRank(); if (inputType.hasRank()) { - if (static_cast(inputType.getRank()) != multiples.size()) - return emitOpError("expect 'multiples' array to have length ") - << inputType.getRank() << " but got " << multiples.size() << "."; + if (inputType.getRank() != multiplesRank) + return emitOpError("expect 'multiples' to have rank ") + << inputType.getRank() << " but got " << multiplesRank << "."; if (outputType.hasRank() && inputType.getRank() != outputType.getRank()) return emitOpError("expect same input and output tensor rank."); - } else if (outputType.hasRank() && - static_cast(outputType.getRank()) != multiples.size()) + } else if (outputType.hasRank() && outputType.getRank() != multiplesRank) return emitOpError("expect 'multiples' array to have length ") - << outputType.getRank() << " but got " << multiples.size() << "."; + << outputType.getRank() << " but got " << multiplesRank << "."; - if (llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; })) + SmallVector multiples; + if (getConstantMultiples(multiples).succeeded() && + llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; })) return emitOpError( "expect element of 'multiples' to be positive integer or -1."); @@ -2146,6 +2178,91 @@ void WhileOp::print(OpAsmPrinter &parser) { parser.printOptionalAttrDictWithKeyword((*this)->getAttrs()); } +//===----------------------------------------------------------------------===// +// TOSA Shape and Shape Operators Helper functions. +//===----------------------------------------------------------------------===// + +bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) { + return mlir::isa(t); +} + +LogicalResult +mlir::tosa::shapeType::verify(function_ref emitError, + int rank) { + if (rank < 0) + return emitError() << "invalid rank (must be >= 0): " << rank; + return success(); +} + +LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) { + for (auto v : op->getOperands()) { + if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) { + Operation *definingOp = v.getDefiningOp(); + if (!definingOp || !definingOp->hasTrait()) { + return op->emitOpError("shape operand is not compile time resolvable"); + } + } + } + return success(); +} + +LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) { + for (auto type : op->getOperandTypes()) { + if (!mlir::isa(type)) { + return op->emitOpError("must have operands with tosa shape type"); + } + } + for (auto type : op->getResultTypes()) { + if (!mlir::isa(type)) { + return op->emitOpError("must have result with tosa shape type"); + } + } + return success(); +} + +LogicalResult +OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) { + if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) || + failed(verifyTosaShapeOperator(op))) + return failure(); + + // delegate function that returns rank of shape type + auto getRank = [](const Type type) { + return mlir::cast(type).getRank(); + }; + auto operandTypes = op->getOperandTypes(); + auto resultTypes = op->getResultTypes(); + + auto rank = getRank(*op->getOperandTypes().begin()); + for (auto type : operandTypes) { + if (getRank(type) != rank) { + return op->emitOpError("operands don't have matching ranks"); + } + } + for (auto type : resultTypes) { + if (getRank(type) != rank) { + return op->emitOpError("result shape has different rank than operands"); + } + } + return success(); +} + +//===----------------------------------------------------------------------===// +// TOSA Shape Operators verify functions. +//===----------------------------------------------------------------------===// + +LogicalResult tosa::ConstShapeOp::verify() { + // check that number of elements in value attr equal to rank of result shape + auto count = getValue().getNumElements(); + auto rank = (cast(getResult().getType())).getRank(); + if (!(count == rank || (count == 1 && rank == 0))) { + return emitOpError("expect number of elements in attribute value (") + << count << ") to be equal to the rank (" << rank + << ") for the result shape type"; + } + return success(); +} + //===----------------------------------------------------------------------===// // TOSA Attribute Definitions. //===----------------------------------------------------------------------===// @@ -2153,6 +2270,12 @@ void WhileOp::print(OpAsmPrinter &parser) { #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc" +//===----------------------------------------------------------------------===// +// TOSA Type Definitions. +//===----------------------------------------------------------------------===// +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc" + //===----------------------------------------------------------------------===// // TOSA Operator Definitions. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 8588c878bfe4f..a49870687fdc6 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -536,6 +536,8 @@ bool TosaValidation::isValidElementType(Type type) { return true; } } + } else if (mlir::isa(type)) { + return true; } return false; } diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index c840fb8648d7b..1d235092b71d5 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1378,21 +1378,24 @@ func.func @tile(%arg0 : tensor<2x3xi8>) -> () { // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8 // CHECK: linalg.yield %[[ARG1]] : i8 // CHECK: tosa.reshape [[GENERIC]] {new_shape = array} - %0 = tosa.tile %arg0 {multiples = array} : (tensor<2x3xi8>) -> tensor<4x3xi8> + %cst21 = tosa.const_shape { value = dense<[2, 1]> : tensor<2xindex> } : () -> !tosa.shape<2> + %0 = tosa.tile %arg0, %cst21: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<4x3xi8> // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<1x2x2x3xi8>) // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8 // CHECK: linalg.yield %[[ARG1]] : i8 // CHECK: tosa.reshape [[GENERIC]] {new_shape = array} - %1 = tosa.tile %arg0 {multiples = array} : (tensor<2x3xi8>) -> tensor<2x6xi8> + %cst12 = tosa.const_shape { value = dense<[1, 2]> : tensor<2xindex> } : () -> !tosa.shape<2> + %1 = tosa.tile %arg0, %cst12: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<2x6xi8> // CHECK: [[INIT:%.+]] = tensor.empty() // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<2x3xi8>) outs([[INIT]] : tensor<5x2x7x3xi8>) // CHECK: ^bb0(%[[ARG1:[0-9a-zA-Z_]+]]: i8 // CHECK: linalg.yield %[[ARG1]] : i8 // CHECK: tosa.reshape [[GENERIC]] {new_shape = array} - %2 = tosa.tile %arg0 {multiples = array} : (tensor<2x3xi8>) -> tensor<10x21xi8> + %cst57 = tosa.const_shape { value = dense<[5, 7]> : tensor<2xindex> } : () -> !tosa.shape<2> + %2 = tosa.tile %arg0, %cst57: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<10x21xi8> return } @@ -1412,7 +1415,8 @@ func.func @tile_dyn_input(%arg0 : tensor) -> () { // CHECK: ^bb0(%[[ARG1:.+]]: i8, // CHECK: linalg.yield %[[ARG1]] : i8 // CHECK: tosa.reshape %[[GENERIC]] {new_shape = array} - %0 = tosa.tile %arg0 {multiples = array} : (tensor) -> tensor + %cst21 = tosa.const_shape { value = dense<[2, 1]> : tensor<2xindex> } : () -> !tosa.shape<2> + %0 = tosa.tile %arg0, %cst21: (tensor, !tosa.shape<2>) -> tensor return } @@ -1432,7 +1436,8 @@ func.func @tile_dyn_multiples(%arg0 : tensor<2x3xi8>) -> () { // CHECK: ^bb0(%[[ARG1:.+]]: i8, // CHECK: linalg.yield %[[ARG1]] : i8 // CHECK: tosa.reshape %[[GENERIC]] {new_shape = array} - %0 = tosa.tile %arg0 {multiples = array} : (tensor<2x3xi8>) -> tensor<2x?xi8> + %cst = tosa.const_shape { value = dense<[2, -1]> : tensor<2xindex> } : () -> !tosa.shape<2> + %0 = tosa.tile %arg0, %cst: (tensor<2x3xi8>, !tosa.shape<2>) -> tensor<2x?xi8> return } diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 60121bb0ea2f1..889e2eda9e5b8 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -588,7 +588,8 @@ func.func @slice_nofold(%arg0: tensor) -> tensor { // CHECK-LABEL: @tile_fold func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> { // CHECK: return %arg0 - %0 = tosa.tile %arg0 { multiples = array }: (tensor<3x4xf32>) -> tensor<3x4xf32> + %cst = tosa.const_shape { value = dense<1> : tensor<2xindex> } : () -> !tosa.shape<2> + %0 = tosa.tile %arg0, %cst: (tensor<3x4xf32>, !tosa.shape<2>) -> tensor<3x4xf32> return %0 : tensor<3x4xf32> } @@ -597,7 +598,8 @@ func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> { // CHECK-LABEL: @tile_nofold func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> { // CHECK: tosa.tile - %0 = tosa.tile %arg0 { multiples = array }: (tensor<3x4xf32>) -> tensor<3x8xf32> + %cst = tosa.const_shape { value = dense<[1, 2]> : tensor<2xindex> } : () -> !tosa.shape<2> + %0 = tosa.tile %arg0, %cst: (tensor<3x4xf32>, !tosa.shape<2>) -> tensor<3x8xf32> return %0 : tensor<3x8xf32> } @@ -763,7 +765,8 @@ func.func @fold_reduce_rank_zero() { func.func nested @fold_tile_rank_zero() -> tensor { // CHECK-NOT: tosa.tile %0 = tensor.empty() : tensor - %1 = tosa.tile %0 {multiples = array} : (tensor) -> tensor + %cst = tosa.const_shape { value = dense<> : tensor<0xindex> } : () -> !tosa.shape<0> + %1 = tosa.tile %0, %cst : (tensor, !tosa.shape<0>) -> tensor return %1 : tensor } diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index a6d57f8a2f61f..cc7fd009f01fa 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -621,8 +621,9 @@ func.func @test_slice_invalid_size() { func.func @test_tile_invalid_multiples() { %0 = tensor.empty() : tensor<4x31x31xf32> - // expected-error@+1 {{'tosa.tile' op expect 'multiples' array to have length 3 but got 0.}} - %1 = tosa.tile %0 {multiples = array} : (tensor<4x31x31xf32>) -> tensor<4x31x31xf32> + %cst = tosa.const_shape { value = dense<1> : tensor<1xindex> } : () -> !tosa.shape<1> + // expected-error@+1 {{'tosa.tile' op expect 'multiples' to have rank 3 but got 1.}} + %1 = tosa.tile %0, %cst: (tensor<4x31x31xf32>, !tosa.shape<1>) -> tensor<4x31x31xf32> return } @@ -630,8 +631,9 @@ func.func @test_tile_invalid_multiples() { func.func @test_tile_invalid_multiples_value() { %0 = tensor.empty() : tensor<4x31xf32> + %multiples = tosa.const_shape { value = dense<[2, -2]> : tensor<2xindex> } : () -> !tosa.shape<2> // expected-error@+1 {{'tosa.tile' op expect element of 'multiples' to be positive integer or -1.}} - %1 = tosa.tile %0 {multiples = array} : (tensor<4x31xf32>) -> tensor<4x31xf32> + %1 = tosa.tile %0, %multiples : (tensor<4x31xf32>, !tosa.shape<2>) -> tensor<4x31xf32> return } @@ -639,8 +641,9 @@ func.func @test_tile_invalid_multiples_value() { func.func @test_tile_io_rank_mismatch() { %0 = tensor.empty() : tensor<4x31xf32> + %multiples = tosa.const_shape { value = dense<[2, 2]> : tensor<2xindex> } : () -> !tosa.shape<2> // expected-error@+1 {{'tosa.tile' op expect same input and output tensor rank.}} - %1 = tosa.tile %0 {multiples = array} : (tensor<4x31xf32>) -> tensor<4x31x31xf32> + %1 = tosa.tile %0, %multiples : (tensor<4x31xf32>, !tosa.shape<2>) -> tensor<4x31x31xf32> return } @@ -993,3 +996,25 @@ func.func @test_non_tosa_ops() { %2 = tensor.empty(%0) : tensor return } + +// ----- + +// expected-error@+1 {{invalid rank (must be >= 0): -1}} +func.func @test_shape_type(%arg0: !tosa.shape<-1>) -> !tosa.shape<-1> { + return %arg0 : !tosa.shape<-1> +} + +// ----- +func.func @test_const_shape() -> !tosa.shape<4> { + // expected-error@+1 {{'tosa.const_shape' op attribute 'value' failed to satisfy constraint: index elements attribute}} + %cst = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xi32>} : () -> !tosa.shape<4> + return %cst : !tosa.shape<4> +} + +// ----- + +func.func @test_const_shape_value() -> !tosa.shape<5> { + // expected-error@+1 {{'tosa.const_shape' op expect number of elements in attribute value (4) to be equal to the rank (5) for the result shape type}} + %cst = tosa.const_shape {value = dense<[1, 2, 3, 4]> : tensor<4xindex>} : () -> !tosa.shape<5> + return %cst : !tosa.shape<5> +} diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index ba8ed8a1e5f50..0fe35d88f0e73 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -95,8 +95,9 @@ func.func @test_slice(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x4x11 // ----- // CHECK-LABEL: tile func.func @test_tile(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x39x21x6xf32> { + %cst = tosa.const_shape { value = dense<[1, 1, 1, 1, 3, 1, 2]> : tensor<7xindex> } : () -> !tosa.shape<7> // expected-error@+1 {{'tosa.tile' op failed level check: operand rank(shape) <= MAX_RANK}} - %0 = "tosa.tile"(%arg0) {multiples = array} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x39x21x6xf32> + %0 = tosa.tile %arg0, %cst : (tensor<1x1x1x1x13x21x3xf32>, !tosa.shape<7>) -> tensor<1x1x1x1x39x21x6xf32> return %0 : tensor<1x1x1x1x39x21x6xf32> } @@ -740,4 +741,3 @@ func.func @test_unranked_tensor(%arg0: tensor<*xf32>) { (tensor<*xf32>) -> tensor<*xf32> return } - diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index f2e1cff72ab28..690e208af1e5f 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -562,7 +562,8 @@ func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> { // ----- // CHECK-LABEL: tile func.func @test_tile(%arg0: tensor<13x21x3xf32>) -> tensor<39x21x6xf32> { - %0 = tosa.tile %arg0 {multiples = array} : (tensor<13x21x3xf32>) -> tensor<39x21x6xf32> + %cst = tosa.const_shape { value = dense<[3, 1, 2]> : tensor<3xindex> } : () -> !tosa.shape<3> + %0 = tosa.tile %arg0, %cst: (tensor<13x21x3xf32>, !tosa.shape<3>) -> tensor<39x21x6xf32> return %0 : tensor<39x21x6xf32> } @@ -692,3 +693,10 @@ func.func @test_custom(%arg0: tensor<10xi32>) -> tensor<10xi32> { %0 = tosa.custom %arg0 {operator_name="custom_test", domain_name="tosa.mlir_test", implementation_attrs="" } : (tensor<10xi32>) -> (tensor<10xi32>) return %0 : tensor<10xi32> } + +// ----- +// CHECK-LABEL: const_shape +func.func @test_const_shape() -> !tosa.shape<4> { + %cst = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> + return %cst : !tosa.shape<4> +} diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index 82f3e22a38722..f4da66ef561b2 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -543,8 +543,10 @@ func.func @test_slice_dynamic(%arg0 : tensor<10x?x2xf32>) -> () { // CHECK-LABEL: @test_tile func.func @test_tile(%arg0 : tensor<2x3x?xi32>) -> () { - // CHECK: tosa.tile %arg0 {multiples = array} : (tensor<2x3x?xi32>) -> tensor<4x3x?xi32> - %0 = tosa.tile %arg0 {multiples = array} : (tensor<2x3x?xi32>) -> tensor + // CHECK: %[[CST:.*]] = tosa.const_shape {value = dense<[2, 1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3> + // CHECK: tosa.tile %arg0, %[[CST]] : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<4x3x?xi32> + %cst = tosa.const_shape {value = dense<[2, 1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3> + %0 = tosa.tile %arg0, %cst : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor return } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 5c2a77ca67fd4..d3f3697903d72 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -12115,6 +12115,14 @@ gentbl_cc_library( ["-gen-dialect-defs"], "include/mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc", ), + ( + ["-gen-typedef-decls"], + "include/mlir/Dialect/Tosa/IR/TosaOpsTypesBase.h.inc", + ), + ( + ["-gen-typedef-defs"], + "include/mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc", + ), ( ["-gen-attrdef-decls"], "include/mlir/Dialect/Tosa/IR/TosaAttributes.h.inc",