Skip to content

Commit

Permalink
[TOSA] Add Tosa_Shape type and ConstShapeOp (llvm#122547)
Browse files Browse the repository at this point in the history
Adds:
1. tosa shape type to Tosa dialect e.g., !tosa.shape<4> is a type for
rank-4 shape values (size-4 array of index values)
 2. const_shape operator
3. trait TosaShapeOperator, added to tosa shape operators, and a
verifier that all operands and results of operator are tosa shapes
4. trait TosaResolvableShapeOperands, added to all tosa operators, and a
verifier that every tosa shape operand is produced by a tosa shape
operator (indicated by trait TosaShapeOperator)
5. trait TosaShapeOperatorWithSameRanks, added to
Tosa_ElementwiseShapeOp and a verifier that all operands and result
shapes have same ranks
5. changed TileOp's multiples from attribute to input, of !tosa.shape
type.
 6. add folder for tosa ConstShape operator

This patch was originally authored by Tai Ly <[email protected]>

Signed-off-by: Jerry Ge <[email protected]>
Signed-off-by: Tai Ly <[email protected]>
  • Loading branch information
Jerry-Ge authored Jan 14, 2025
1 parent 31249e2 commit f09db6a
Show file tree
Hide file tree
Showing 18 changed files with 425 additions and 33 deletions.
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc)
add_mlir_interface(TosaInterfaces)

set(LLVM_TARGET_DEFINITIONS TosaOps.td)
mlir_tablegen(TosaOpsTypesBase.h.inc -gen-typedef-decls -typedefs-dialect=tosa)
mlir_tablegen(TosaOpsTypesBase.cpp.inc -gen-typedef-defs -typedefs-dialect=tosa)
mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa)
mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa)
add_public_tablegen_target(MLIRTosaAttributesIncGen)

set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)
mlir_tablegen(TosaDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Tosa")
add_public_tablegen_target(MLIRTosaDialectBytecodeIncGen)

12 changes: 11 additions & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def Tosa_Dialect : Dialect {
let cppNamespace = "mlir::tosa";
let hasConstantMaterializer = 1;
let useDefaultAttributePrinterParser = 1;
let useDefaultTypePrinterParser = 1;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -217,12 +218,21 @@ def TosaElementwiseOperator : NativeOpTrait<"TosaElementwiseOperator"> {
let cppNamespace = "mlir::OpTrait::tosa";
}

//===----------------------------------------------------------------------===//
// TOSA Operator Trait.
//===----------------------------------------------------------------------===//
// Op operands with TOSA shape types must be compile time resolvable
def TosaResolvableShapeOperands : NativeOpTrait<"TosaResolvableShapeOperands"> {
let cppNamespace = "mlir::OpTrait::tosa";
}

//===----------------------------------------------------------------------===//
// TOSA Operator Class.
//===----------------------------------------------------------------------===//

class Tosa_Op<string mnemonic, list<Trait> traits = []> :
Op<Tosa_Dialect, mnemonic, !listconcat(traits, [TosaOpInterface])> {
Op<Tosa_Dialect, mnemonic, !listconcat(traits, [TosaOpInterface,
TosaResolvableShapeOperands])> {
}

class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
Expand Down
41 changes: 41 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,55 @@ template <typename ConcreteType>
class TosaElementwiseOperator
: public TraitBase<ConcreteType, TosaElementwiseOperator> {};

LogicalResult verifyTosaResolvableShapeOperands(Operation *op);
/// This class verifies that tosa shape operands are compile time resolvable
template <typename ConcreteType>
class TosaResolvableShapeOperands
: public TraitBase<ConcreteType, TosaResolvableShapeOperands> {
public:
static LogicalResult verifyTrait(Operation *op) {
return verifyTosaResolvableShapeOperands(op);
}
};

LogicalResult verifyTosaShapeOperator(Operation *op);
/// This class indicates that op operates on tosa shape types
template <typename ConcreteType>
class TosaShapeOperator : public TraitBase<ConcreteType, TosaShapeOperator> {
public:
static LogicalResult verifyTrait(Operation *op) {
return verifyTosaShapeOperator(op);
}
};

LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op);
/// This class indicates that op operates on tosa shape types
template <typename ConcreteType>
class TosaShapeOperatorWithSameRanks
: public TraitBase<ConcreteType, TosaShapeOperatorWithSameRanks> {
public:
static LogicalResult verifyTrait(Operation *op) {
return verifyTosaShapeOperatorWithSameRanks(op);
}
};

} // namespace tosa
} // namespace OpTrait

namespace tosa {

bool isa_tosa_shape_type(mlir::Type t);

} // namespace tosa

} // namespace mlir

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.h.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.h.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"

Expand Down
8 changes: 7 additions & 1 deletion mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1689,12 +1689,16 @@ def Tosa_TileOp : Tosa_InferShapedTypeOp<"tile"> {

let arguments = (ins
Tosa_Tensor:$input1,
DenseI64ArrayAttr:$multiples);
Tosa_Shape:$multiples);

let results = (outs
Tosa_Tensor:$output
);

let extraClassDeclaration = [{
LogicalResult getConstantMultiples(llvm::SmallVector<int64_t> &multiples);
}];

let hasFolder = 1;
let hasVerifier = 1;
}
Expand Down Expand Up @@ -2106,4 +2110,6 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [

include "mlir/Dialect/Tosa/IR/TosaUtilOps.td"

include "mlir/Dialect/Tosa/IR/TosaShapeOps.td"

#endif // TOSA_OPS
77 changes: 77 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaShapeOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
//===-- TosaShapeOps.td - TOSA dialect utility operations --*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines shape operators for the TOSA dialect.
//
//===----------------------------------------------------------------------===//

#ifndef TOSA_SHAPE_OPS
#define TOSA_SHAPE_OPS

include "mlir/IR/OpBase.td"

include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"

include "mlir/Dialect/Tosa/IR/TosaTypesBase.td"
include "mlir/Dialect/Tosa/IR/TosaOpBase.td"

// Op trait: operator has operands and results with TOSA shape type
def TosaShapeOperator : NativeOpTrait<"TosaShapeOperator"> {
let cppNamespace = "mlir::OpTrait::tosa";
}

class Tosa_ShapeOp<string mnemonic, list<Trait> traits = []>
: Tosa_Op<mnemonic, !listconcat(traits, [TosaShapeOperator, Pure])> {

let assemblyFormat =
"operands attr-dict `:` functional-type(operands, results)";

let hasFolder = 1;
}

// op trait: shape operator has same ranks for operands and results
def TosaShapeOperatorWithSameRanks
: NativeOpTrait<"TosaShapeOperatorWithSameRanks"> {
let cppNamespace = "mlir::OpTrait::tosa";
}

class Tosa_ElementwiseShapeOp<string mnemonic, list<Trait> traits = []>
: Tosa_ShapeOp<mnemonic,
!listconcat(traits, [TosaShapeOperatorWithSameRanks])> {
}


//===----------------------------------------------------------------------===//
// Operator: ConstShape
//===----------------------------------------------------------------------===//
def Tosa_ConstShapeOp : Tosa_ShapeOp<"const_shape", [ConstantLike, Pure]> {
let summary = "Constant Shape op.";

let description = [{
A node containing constant data for use as the input to an shape operation. May
hold data only in index data type.

Example:

```mlir
// Generic form
%out = "tosa.const_shape"() {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
```
}];

let arguments = (ins IndexElementsAttr : $value);

let results = (outs Tosa_Shape : $output);

let hasVerifier = 1;
}

#endif // TOSA_SHAPE_OPS
65 changes: 65 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
#ifndef TOSA_TYPES_BASE
#define TOSA_TYPES_BASE

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"

include "mlir/Dialect/Tosa/IR/TosaOpBase.td"

//===----------------------------------------------------------------------===//
// Tosa Type Definitions.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -215,4 +218,66 @@ def Tosa_Buffer : MemRefOf<[Tosa_AnyNumber]>;
def Tosa_TupleBuffer : NestedTupleOf<[Tosa_Buffer]>;
def Tosa_BufOrTuple : AnyTypeOf<[Tosa_Buffer, Tosa_TupleBuffer]>;

//===----------------------------------------------------------------------===//
// Tosa Type Definitions.
//===----------------------------------------------------------------------===//

// The base class for Tosa dialect types.
class Tosa_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<Tosa_Dialect, name, traits> {
let mnemonic = typeMnemonic;
}

//===----------------------------------------------------------------------===//
// ShapeType
//===----------------------------------------------------------------------===//
def Tosa_Shape : Tosa_Type<"shape", "shape"> {
let summary = "Shape with static rank and Index element type";
let description = [{
Syntax:

``` shape - type :: = `shape` `<` rank `>`
``` Values with shape type represents a shape with a fixed rank and a list
of dimensions
.Rank must be zero or a positive integer
.Each dimension is represented by the builtin
Index type.

Examples:

```mlir
// Shape with rank of four, for example, [1, 1, 8, 16]:
!tosa
.shape<4>

// Shape with rank of one, for example, [16]:
!tosa
.shape<1>

// Shape with rank zero, for example, [] (i.e., shape of scalar values):
!tosa.shape<0>
```
}];
let parameters = (ins "int" : $rank);
let builders = [TypeBuilder<(ins "int" : $rank)>];
let assemblyFormat = "`<` $rank `>`";

let genVerifyDecl = 1;
}

def IsTosaShapeType : CPred<"mlir::tosa::isa_tosa_shape_type($_self)">;

// Whether a Tosa Shape type has a rank equal to the specified rank.
class IsTosaShapeOfRankPred<int rank> : And<[
IsTosaShapeType,
CPred<[{::llvm::cast<::mlir::tosa::shapeType>($_self).getRank() == }] # rank>
]>;

class TosaShapeOfRank<int rank>
: Type<IsTosaShapeOfRankPred<rank>, "Tosa shape type of rank " #rank>;

def Rank1TosaShape : TosaShapeOfRank<1>;
def Rank2TosaShape : TosaShapeOfRank<2>;
def Rank4TosaShape : TosaShapeOfRank<4>;

#endif // TOSA_TYPES_BASE
4 changes: 3 additions & 1 deletion mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1886,7 +1886,9 @@ struct TileConverter : public OpConversionPattern<tosa::TileOp> {
auto elementTy = inputTy.getElementType();
int64_t rank = inputTy.getRank();

ArrayRef<int64_t> multiples = op.getMultiples();
SmallVector<int64_t> multiples;
if (failed(op.getConstantMultiples(multiples)))
return failure();

// Broadcast the newly added dimensions to their appropriate multiple.
SmallVector<int64_t, 2> genericShape;
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
target.addLegalOp<tosa::ApplyScaleOp>();
target.addLegalOp<tosa::IfOp>();
target.addLegalOp<tosa::ConstOp>();
target.addLegalOp<tosa::ConstShapeOp>();
target.addLegalOp<tosa::WhileOp>();
target.addLegalOp<tosa::ConcatOp>();
target.addLegalOp<tosa::SliceOp>();
Expand Down
19 changes: 16 additions & 3 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,8 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {

OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }

OpFoldResult ConstShapeOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }

#define REDUCE_FOLDER(OP) \
OpFoldResult OP::fold(FoldAdaptor adaptor) { \
ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
Expand Down Expand Up @@ -985,9 +987,20 @@ OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
}

OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
if (allOnes && getInput1().getType() == getType())
return getInput1();
if (getInput1().getType() == getType()) {
if (auto multiples = llvm::dyn_cast_if_present<DenseElementsAttr>(
adaptor.getMultiples())) {
if (multiples.isSplat() &&
multiples.getSplatValue<APInt>().getSExtValue() == 1)
return getInput1();
if (auto int_array_attr =
llvm::dyn_cast<DenseIntElementsAttr>(multiples)) {
if (llvm::all_of(int_array_attr.getValues<APInt>(),
[](APInt v) { return v.getSExtValue() == 1; }))
return getInput1();
}
}
}
return {};
}

Expand Down
Loading

0 comments on commit f09db6a

Please sign in to comment.