@@ -978,6 +978,75 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
978
978
return success ();
979
979
}
980
980
981
+ LogicalResult tosa::ConcatOp::verify () {
982
+ // check that each input has same element type as output
983
+ auto outType = getOutput ().getType ();
984
+ const Operation::operand_range inputList = getInput1 ();
985
+
986
+ // Check there is at least one input
987
+ if (inputList.empty ())
988
+ return emitOpError (" expect at least one input" );
989
+
990
+ if (!llvm::all_of (inputList, [&](auto input) {
991
+ return succeeded (verifySameElementTypes (
992
+ *this , /* inType = */ input.getType (), outType));
993
+ })) {
994
+ return failure ();
995
+ }
996
+
997
+ const int32_t axis = getAxis ();
998
+ ShapeAdaptor firstRankedInputShape = nullptr ;
999
+ for (const auto &input : inputList) {
1000
+ const Type inputType = input.getType ();
1001
+ ShapeAdaptor currShape (inputType);
1002
+ if (currShape.hasRank ()) {
1003
+ firstRankedInputShape = currShape;
1004
+ // Check axis is in expected range
1005
+ if (axis < 0 || axis >= firstRankedInputShape.getRank ())
1006
+ return emitOpError (" expect axis to be within range 0 < axis < "
1007
+ " rank(input1[firstRankedTensorIdx]), got " )
1008
+ << axis;
1009
+ break ;
1010
+ }
1011
+ }
1012
+
1013
+ const auto allOperandsHasRank = [](const Value input) {
1014
+ return ShapeAdaptor (input.getType ()).hasRank ();
1015
+ };
1016
+ if (llvm::all_of (inputList, allOperandsHasRank)) {
1017
+ const int64_t firstInputRank = firstRankedInputShape.getRank ();
1018
+
1019
+ for (const auto &[index , input] : llvm::enumerate (inputList.drop_front ())) {
1020
+ const ShapeAdaptor inputShape (input.getType ());
1021
+ const int64_t inputRank = inputShape.getRank ();
1022
+ const size_t operandNum = index + 1 ;
1023
+
1024
+ // Check that each operand has the same rank
1025
+ if (inputRank != firstInputRank)
1026
+ return emitOpError (
1027
+ " expect all operands to have the same rank, but got " )
1028
+ << firstInputRank << " vs " << inputRank << " on operands 0 and "
1029
+ << operandNum;
1030
+
1031
+ // Check non-axis dims match
1032
+ for (int i = 0 ; i < inputRank; i++) {
1033
+ const int64_t inputDim = inputShape.getDimSize (i);
1034
+ const int64_t firstInputDim = firstRankedInputShape.getDimSize (i);
1035
+ if (i == axis || firstRankedInputShape.isDynamicDim (i) ||
1036
+ inputShape.isDynamicDim (i))
1037
+ continue ;
1038
+ if (inputDim != firstInputDim)
1039
+ return emitOpError (" expect all operand shapes to have the same sizes "
1040
+ " on non-axis dimensions, but got " )
1041
+ << inputDim << " vs " << firstInputDim << " at index " << i
1042
+ << " on operands 0 and " << operandNum;
1043
+ }
1044
+ }
1045
+ }
1046
+
1047
+ return success ();
1048
+ }
1049
+
981
1050
LogicalResult tosa::EqualOp::inferReturnTypeComponents (
982
1051
MLIRContext *context, ::std::optional<Location> location,
983
1052
ValueShapeRange operands, DictionaryAttr attributes,
@@ -1027,6 +1096,53 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1027
1096
return success ();
1028
1097
}
1029
1098
1099
+ LogicalResult MatMulOp::verify () {
1100
+ auto aType = llvm::dyn_cast<ShapedType>(getA ().getType ());
1101
+ auto bType = llvm::dyn_cast<ShapedType>(getB ().getType ());
1102
+
1103
+ // Must be shaped tensor types
1104
+ if (!aType)
1105
+ return emitOpError (" expect a shaped tensor for input a, got " )
1106
+ << getA ().getType ();
1107
+
1108
+ if (!bType)
1109
+ return emitOpError (" expect a shaped tensor for input b, got " )
1110
+ << getB ().getType ();
1111
+
1112
+ auto aElementType = aType.getElementType ();
1113
+ auto bElementType = bType.getElementType ();
1114
+
1115
+ auto aQuantizedEType =
1116
+ llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1117
+ auto bQuantizedEType =
1118
+ llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1119
+
1120
+ if (aQuantizedEType || bQuantizedEType) {
1121
+ if (!aQuantizedEType || !bQuantizedEType) {
1122
+ return emitOpError (" expect operands to be both quantized or both not "
1123
+ " quantized, got " )
1124
+ << aElementType << " and " << bElementType;
1125
+ }
1126
+ // both a and b have quantized element types
1127
+ auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth ();
1128
+ auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth ();
1129
+ if (aQuantWidth != bQuantWidth) {
1130
+ return emitOpError (" expect quantized operands to have same widths, got " )
1131
+ << aQuantWidth << " and " << bQuantWidth;
1132
+ }
1133
+
1134
+ return success ();
1135
+ }
1136
+
1137
+ // non-quantized element types
1138
+ if (aElementType != bElementType) {
1139
+ return emitOpError (" expect same element type for inputs a and b, got " )
1140
+ << aElementType << " and " << bElementType;
1141
+ }
1142
+
1143
+ return success ();
1144
+ }
1145
+
1030
1146
LogicalResult tosa::PadOp::inferReturnTypeComponents (
1031
1147
MLIRContext *context, ::std::optional<Location> location,
1032
1148
PadOp::Adaptor adaptor,
@@ -1075,6 +1191,20 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
1075
1191
}
1076
1192
1077
1193
LogicalResult tosa::PadOp::verify () {
1194
+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1195
+ /* outType = */ getOutput ().getType ())
1196
+ .failed ()) {
1197
+ return failure ();
1198
+ }
1199
+
1200
+ if (auto padConst = getPadConst ()) {
1201
+ if (verifySameElementTypes (*this , /* inType = */ padConst.getType (),
1202
+ /* outType = */ getOutput ().getType ())
1203
+ .failed ()) {
1204
+ return failure ();
1205
+ }
1206
+ }
1207
+
1078
1208
RankedTensorType inputType = getInput1 ().getType ();
1079
1209
RankedTensorType outputType = getOutput ().getType ();
1080
1210
auto paddingRank = cast<tosa::shapeType>(getPadding ().getType ()).getRank ();
@@ -1148,21 +1278,23 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
1148
1278
}
1149
1279
1150
1280
LogicalResult tosa::SliceOp::verify() {
  // Verifies a tosa.slice: input/output element types match, and the
  // `start`/`size` shape operands each have one entry per input dimension.
  if (failed(verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
                                    /* outType = */ getOutput().getType())))
    return failure();

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  // Unranked input: rank-dependent checks do not apply.
  if (!inputType)
    return success();

  const int64_t inputRank = inputType.getRank();

  const auto startRank =
      llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
  if (inputRank != startRank)
    return emitOpError("length of start is not equal to rank of input shape");

  const auto sizeRank =
      llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
  if (inputRank != sizeRank)
    return emitOpError("length of size is not equal to rank of input shape");

  return success();
}
@@ -1367,6 +1499,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
1367
1499
}
1368
1500
1369
1501
LogicalResult tosa::TileOp::verify () {
1502
+ if (verifySameElementTypes (*this , /* intype = */ getInput1 ().getType (),
1503
+ /* outType = */ getOutput ().getType ())
1504
+ .failed ()) {
1505
+ return failure ();
1506
+ }
1370
1507
ShapedType inputType = llvm::cast<ShapedType>(getInput1 ().getType ());
1371
1508
ShapedType outputType = llvm::cast<ShapedType>(getType ());
1372
1509
@@ -1448,6 +1585,11 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
1448
1585
}
1449
1586
1450
1587
llvm::LogicalResult tosa::ReshapeOp::verify () {
1588
+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1589
+ /* outType = */ getOutput ().getType ())
1590
+ .failed ()) {
1591
+ return failure ();
1592
+ }
1451
1593
TensorType inputType = getInput1 ().getType ();
1452
1594
RankedTensorType outputType = getType ();
1453
1595
@@ -1626,6 +1768,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
1626
1768
}
1627
1769
1628
1770
LogicalResult tosa::TransposeOp::verify () {
1771
+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1772
+ /* outType = */ getOutput ().getType ())
1773
+ .failed ()) {
1774
+ return failure ();
1775
+ }
1629
1776
TensorType inputType = getInput1 ().getType ();
1630
1777
TensorType outputType = getOutput ().getType ();
1631
1778
const llvm::ArrayRef<int32_t > constantPerms = getPerms ();
@@ -1726,6 +1873,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
1726
1873
return success ();
1727
1874
}
1728
1875
1876
+ LogicalResult tosa::GatherOp::verify () {
1877
+ return verifySameElementTypes (*this , /* inType = */ getValues ().getType (),
1878
+ /* outType = */ getOutput ().getType ());
1879
+ }
1880
+
1729
1881
LogicalResult tosa::ResizeOp::inferReturnTypeComponents (
1730
1882
MLIRContext *context, ::std::optional<Location> location,
1731
1883
ResizeOp::Adaptor adaptor,
@@ -1887,6 +2039,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
1887
2039
return success ();
1888
2040
}
1889
2041
2042
+ LogicalResult tosa::ScatterOp::verify () {
2043
+ if (verifySameElementTypes (*this , /* inType = */ getValuesIn ().getType (),
2044
+ /* outType = */ getValuesOut ().getType ())
2045
+ .failed () ||
2046
+ verifySameElementTypes (*this , /* inType = */ getInput ().getType (),
2047
+ /* outType = */ getValuesOut ().getType ())
2048
+ .failed ()) {
2049
+ return failure ();
2050
+ }
2051
+ return success ();
2052
+ }
2053
+
1890
2054
static LogicalResult ReduceInferReturnTypes (
1891
2055
ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
1892
2056
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2342,6 +2506,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
2342
2506
inferredReturnShapes);
2343
2507
}
2344
2508
2509
+ LogicalResult MaxPool2dOp::verify () {
2510
+ return verifySameElementTypes (*this , /* intype = */ getInput ().getType (),
2511
+ /* outType = */ getOutput ().getType ());
2512
+ }
2513
+
2345
2514
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents (
2346
2515
MLIRContext *context, ::std::optional<Location> location,
2347
2516
DepthwiseConv2DOp::Adaptor adaptor,
@@ -2642,6 +2811,10 @@ void IfOp::print(OpAsmPrinter &p) {
2642
2811
}
2643
2812
2644
2813
LogicalResult ReverseOp::verify () {
2814
+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
2815
+ /* outType = */ getOutput ().getType ())
2816
+ .failed ())
2817
+ return failure ();
2645
2818
TensorType inputType = getInput1 ().getType ();
2646
2819
TensorType outputType = getOutput ().getType ();
2647
2820
int32_t reverseAxis = getAxis ();
@@ -2670,6 +2843,31 @@ LogicalResult ReverseOp::verify() {
2670
2843
return success ();
2671
2844
}
2672
2845
2846
+ LogicalResult tosa::SelectOp::verify () {
2847
+ // verify input2 and input3 have same element type as output
2848
+ if (verifySameElementTypes (*this , /* inType = */ getInput2 ().getType (),
2849
+ /* outType = */ getOutput ().getType ())
2850
+ .failed () ||
2851
+ verifySameElementTypes (*this , /* inType = */ getInput3 ().getType (),
2852
+ /* outType = */ getOutput ().getType ())
2853
+ .failed ()) {
2854
+ return failure ();
2855
+ }
2856
+ // verify input1 has element type of bool
2857
+ auto predicateType = llvm::dyn_cast<ShapedType>(getInput1 ().getType ());
2858
+ if (!predicateType) {
2859
+ return emitOpError (" expect shaped tensor for input1, got " )
2860
+ << getInput1 ().getType ();
2861
+ }
2862
+ auto predicateElementType = predicateType.getElementType ();
2863
+ if (!predicateElementType.isInteger (1 )) {
2864
+ return emitOpError (" expect element type of bool for input1, got " )
2865
+ << predicateElementType;
2866
+ }
2867
+
2868
+ return success ();
2869
+ }
2870
+
2673
2871
// parse and print of WhileOp refer to the implementation of SCF dialect.
2674
2872
ParseResult WhileOp::parse (OpAsmParser &parser, OperationState &result) {
2675
2873
SmallVector<OpAsmParser::Argument, 4 > regionArgs;
0 commit comments