Skip to content

Commit db70d76

Browse files
Jerry-Ge, Tai78641, and lhutton1
authored
[mlir][tosa] Add more verifiers for the following operators (llvm#127923)
For ConcatOp, this commit also enhances the verifier by checking four additional conditions:
- The input list is not empty
- The axis value is within the range of the input shapes
- All inputs have the same rank
- All non-concatenation-axis dimensions have the same value

For MatMulOp:
- Checked that inputs a and b have valid tensor types and compatible element types

For the following operators, added the verifySameElementTypes check:
- PadOp
- SliceOp
- TileOp
- ReshapeOp
- TransposeOp
- GatherOp
- ScatterOp
- MaxPool2dOp
- ReverseOp
- SelectOp

Signed-off-by: Jerry Ge <[email protected]>
Co-authored-by: Tai Ly <[email protected]>
Co-authored-by: Luke Hutton <[email protected]>
1 parent e51331c commit db70d76

File tree

3 files changed

+247
-13
lines changed

3 files changed

+247
-13
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

+8
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,7 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
325325
];
326326

327327
let builders = [Tosa_MatMulOpQuantInfoBuilder];
328+
let hasVerifier = 1;
328329
}
329330

330331
//===----------------------------------------------------------------------===//
@@ -359,6 +360,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
359360
];
360361

361362
let hasCanonicalizer = 1;
363+
let hasVerifier = 1;
362364
}
363365

364366
//===----------------------------------------------------------------------===//
@@ -1491,6 +1493,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
14911493

14921494
let hasCanonicalizeMethod = 1;
14931495
let hasFolder = 1;
1496+
let hasVerifier = 1;
14941497

14951498
let assemblyFormat = [{
14961499
operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
@@ -1866,6 +1869,7 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
18661869

18671870
let hasCanonicalizer = 1;
18681871
let hasFolder = 1;
1872+
let hasVerifier = 1;
18691873

18701874
let extraClassDeclaration = [{
18711875
/// Returns true when two result types are compatible for this op;
@@ -2119,6 +2123,8 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
21192123
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
21202124
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
21212125
];
2126+
2127+
let hasVerifier = 1;
21222128
}
21232129

21242130
//===----------------------------------------------------------------------===//
@@ -2152,6 +2158,8 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
21522158
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
21532159
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
21542160
];
2161+
2162+
let hasVerifier = 1;
21552163
}
21562164

21572165
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

+202-4
Original file line numberDiff line numberDiff line change
@@ -978,6 +978,75 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
978978
return success();
979979
}
980980

981+
// Verifies tosa.concat: every input shares the output's element type, the
// input list is non-empty, `axis` is a valid dimension index for the first
// ranked input, all inputs agree on rank (when all are ranked), and every
// static non-axis dimension matches across inputs.
LogicalResult tosa::ConcatOp::verify() {
  // check that each input has same element type as output
  auto outType = getOutput().getType();
  const Operation::operand_range inputList = getInput1();

  // Check there is at least one input
  if (inputList.empty())
    return emitOpError("expect at least one input");

  if (!llvm::all_of(inputList, [&](auto input) {
        return succeeded(verifySameElementTypes(
            *this, /* inType = */ input.getType(), outType));
      })) {
    return failure();
  }

  const int32_t axis = getAxis();
  ShapeAdaptor firstRankedInputShape = nullptr;
  // Find the first ranked input; its rank bounds the legal axis values.
  // The break below means only the first ranked operand is inspected here.
  for (const auto &input : inputList) {
    const Type inputType = input.getType();
    ShapeAdaptor currShape(inputType);
    if (currShape.hasRank()) {
      firstRankedInputShape = currShape;
      // Check axis is in expected range
      // NOTE(review): the condition accepts axis == 0, so the message's
      // "0 < axis" really means "0 <= axis" — confirm against the lit
      // tests before rewording the diagnostic.
      if (axis < 0 || axis >= firstRankedInputShape.getRank())
        return emitOpError("expect axis to be within range 0 < axis < "
                           "rank(input1[firstRankedTensorIdx]), got ")
               << axis;
      break;
    }
  }

  const auto allOperandsHasRank = [](const Value input) {
    return ShapeAdaptor(input.getType()).hasRank();
  };
  // Rank and shape agreement is only enforced when every operand is ranked;
  // if any operand is unranked these checks are deferred.
  if (llvm::all_of(inputList, allOperandsHasRank)) {
    // Safe: the loop above assigned firstRankedInputShape from the first
    // operand, since all operands are ranked here.
    const int64_t firstInputRank = firstRankedInputShape.getRank();

    for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
      const ShapeAdaptor inputShape(input.getType());
      const int64_t inputRank = inputShape.getRank();
      const size_t operandNum = index + 1; // enumerate starts after operand 0

      // Check that each operand has the same rank
      if (inputRank != firstInputRank)
        return emitOpError(
                   "expect all operands to have the same rank, but got ")
               << firstInputRank << " vs " << inputRank << " on operands 0 and "
               << operandNum;

      // Check non-axis dims match
      for (int i = 0; i < inputRank; i++) {
        const int64_t inputDim = inputShape.getDimSize(i);
        const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
        // The concat axis may differ, and dynamic dims match anything.
        if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
            inputShape.isDynamicDim(i))
          continue;
        if (inputDim != firstInputDim)
          return emitOpError("expect all operand shapes to have the same sizes "
                             "on non-axis dimensions, but got ")
                 << inputDim << " vs " << firstInputDim << " at index " << i
                 << " on operands 0 and " << operandNum;
      }
    }
  }

  return success();
}
1049+
9811050
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
9821051
MLIRContext *context, ::std::optional<Location> location,
9831052
ValueShapeRange operands, DictionaryAttr attributes,
@@ -1027,6 +1096,53 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
10271096
return success();
10281097
}
10291098

1099+
// Verifies tosa.matmul operand types: both inputs must be shaped tensors
// whose element types agree — either identical non-quantized types, or both
// uniformly quantized with the same storage width.
LogicalResult MatMulOp::verify() {
  const auto shapedA = llvm::dyn_cast<ShapedType>(getA().getType());
  const auto shapedB = llvm::dyn_cast<ShapedType>(getB().getType());

  // Both operands must be shaped tensor types.
  if (!shapedA)
    return emitOpError("expect a shaped tensor for input a, got ")
           << getA().getType();
  if (!shapedB)
    return emitOpError("expect a shaped tensor for input b, got ")
           << getB().getType();

  const Type elemA = shapedA.getElementType();
  const Type elemB = shapedB.getElementType();

  const auto quantA = llvm::dyn_cast<quant::UniformQuantizedType>(elemA);
  const auto quantB = llvm::dyn_cast<quant::UniformQuantizedType>(elemB);

  if (!quantA && !quantB) {
    // Neither side is quantized: element types must match exactly.
    if (elemA != elemB)
      return emitOpError("expect same element type for inputs a and b, got ")
             << elemA << " and " << elemB;
    return success();
  }

  // Mixing a quantized operand with a non-quantized one is rejected.
  if (!quantA || !quantB)
    return emitOpError("expect operands to be both quantized or both not "
                       "quantized, got ")
           << elemA << " and " << elemB;

  // Both operands are quantized: their storage widths must agree.
  const auto widthA = quantA.getStorageTypeIntegralWidth();
  const auto widthB = quantB.getStorageTypeIntegralWidth();
  if (widthA != widthB)
    return emitOpError("expect quantized operands to have same widths, got ")
           << widthA << " and " << widthB;

  return success();
}
1145+
10301146
LogicalResult tosa::PadOp::inferReturnTypeComponents(
10311147
MLIRContext *context, ::std::optional<Location> location,
10321148
PadOp::Adaptor adaptor,
@@ -1075,6 +1191,20 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
10751191
}
10761192

10771193
LogicalResult tosa::PadOp::verify() {
1194+
if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1195+
/* outType = */ getOutput().getType())
1196+
.failed()) {
1197+
return failure();
1198+
}
1199+
1200+
if (auto padConst = getPadConst()) {
1201+
if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
1202+
/* outType = */ getOutput().getType())
1203+
.failed()) {
1204+
return failure();
1205+
}
1206+
}
1207+
10781208
RankedTensorType inputType = getInput1().getType();
10791209
RankedTensorType outputType = getOutput().getType();
10801210
auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
@@ -1148,21 +1278,23 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
11481278
}
11491279

11501280
// Verifies tosa.slice: input and output element types must match, and for
// ranked inputs the `start` and `size` shape operands must each carry one
// entry per input dimension.
LogicalResult tosa::SliceOp::verify() {
  if (failed(verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
                                    /* outType = */ getOutput().getType())))
    return failure();

  const auto rankedInput =
      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  // Unranked inputs have nothing further to check here.
  if (!rankedInput)
    return success();

  const int64_t inputRank = rankedInput.getRank();

  // `start` must provide one offset per input dimension.
  if (llvm::cast<tosa::shapeType>(getStart().getType()).getRank() != inputRank)
    return emitOpError("length of start is not equal to rank of input shape");

  // `size` must provide one extent per input dimension.
  if (llvm::cast<tosa::shapeType>(getSize().getType()).getRank() != inputRank)
    return emitOpError("length of size is not equal to rank of input shape");

  return success();
}
@@ -1367,6 +1499,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
13671499
}
13681500

13691501
LogicalResult tosa::TileOp::verify() {
1502+
if (verifySameElementTypes(*this, /* intype = */ getInput1().getType(),
1503+
/* outType = */ getOutput().getType())
1504+
.failed()) {
1505+
return failure();
1506+
}
13701507
ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
13711508
ShapedType outputType = llvm::cast<ShapedType>(getType());
13721509

@@ -1448,6 +1585,11 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
14481585
}
14491586

14501587
llvm::LogicalResult tosa::ReshapeOp::verify() {
1588+
if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1589+
/* outType = */ getOutput().getType())
1590+
.failed()) {
1591+
return failure();
1592+
}
14511593
TensorType inputType = getInput1().getType();
14521594
RankedTensorType outputType = getType();
14531595

@@ -1626,6 +1768,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
16261768
}
16271769

16281770
LogicalResult tosa::TransposeOp::verify() {
1771+
if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1772+
/* outType = */ getOutput().getType())
1773+
.failed()) {
1774+
return failure();
1775+
}
16291776
TensorType inputType = getInput1().getType();
16301777
TensorType outputType = getOutput().getType();
16311778
const llvm::ArrayRef<int32_t> constantPerms = getPerms();
@@ -1726,6 +1873,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
17261873
return success();
17271874
}
17281875

1876+
// Gather only rearranges elements of `values`, so the op is valid exactly
// when the values and output element types agree.
LogicalResult tosa::GatherOp::verify() {
  const Type valuesType = getValues().getType();
  const Type resultType = getOutput().getType();
  return verifySameElementTypes(*this, /* inType = */ valuesType,
                                /* outType = */ resultType);
}
1880+
17291881
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
17301882
MLIRContext *context, ::std::optional<Location> location,
17311883
ResizeOp::Adaptor adaptor,
@@ -1887,6 +2039,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
18872039
return success();
18882040
}
18892041

2042+
// Verifies tosa.scatter: the destination tensor (values_in) and the scattered
// values (input) must both share the element type of values_out.
LogicalResult tosa::ScatterOp::verify() {
  const Type outType = getValuesOut().getType();
  if (failed(verifySameElementTypes(
          *this, /* inType = */ getValuesIn().getType(),
          /* outType = */ outType)) ||
      failed(verifySameElementTypes(*this, /* inType = */ getInput().getType(),
                                    /* outType = */ outType)))
    return failure();
  return success();
}
2053+
18902054
static LogicalResult ReduceInferReturnTypes(
18912055
ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
18922056
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2342,6 +2506,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
23422506
inferredReturnShapes);
23432507
}
23442508

2509+
// Max pooling selects existing elements of the input, so input and output
// must share an element type; shape constraints are handled by shape
// inference elsewhere in this file.
LogicalResult MaxPool2dOp::verify() {
  return verifySameElementTypes(*this, /* inType = */ getInput().getType(),
                                /* outType = */ getOutput().getType());
}
2513+
23452514
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
23462515
MLIRContext *context, ::std::optional<Location> location,
23472516
DepthwiseConv2DOp::Adaptor adaptor,
@@ -2642,6 +2811,10 @@ void IfOp::print(OpAsmPrinter &p) {
26422811
}
26432812

26442813
LogicalResult ReverseOp::verify() {
2814+
if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2815+
/* outType = */ getOutput().getType())
2816+
.failed())
2817+
return failure();
26452818
TensorType inputType = getInput1().getType();
26462819
TensorType outputType = getOutput().getType();
26472820
int32_t reverseAxis = getAxis();
@@ -2670,6 +2843,31 @@ LogicalResult ReverseOp::verify() {
26702843
return success();
26712844
}
26722845

2846+
// Verifies tosa.select: the two selectable operands (input2/input3) must
// share the output's element type, and the predicate (input1) must be a
// shaped tensor of i1.
LogicalResult tosa::SelectOp::verify() {
  const Type outType = getOutput().getType();
  // verify input2 and input3 have same element type as output
  if (failed(verifySameElementTypes(*this, /* inType = */ getInput2().getType(),
                                    /* outType = */ outType)) ||
      failed(verifySameElementTypes(*this, /* inType = */ getInput3().getType(),
                                    /* outType = */ outType)))
    return failure();

  // verify input1 has element type of bool
  const auto predType = llvm::dyn_cast<ShapedType>(getInput1().getType());
  if (!predType)
    return emitOpError("expect shaped tensor for input1, got ")
           << getInput1().getType();
  if (!predType.getElementType().isInteger(1))
    return emitOpError("expect element type of bool for input1, got ")
           << predType.getElementType();

  return success();
}
2870+
26732871
// parse and print of WhileOp refer to the implementation of SCF dialect.
26742872
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
26752873
SmallVector<OpAsmParser::Argument, 4> regionArgs;

0 commit comments

Comments
 (0)