@@ -850,6 +850,71 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
850
850
return success ();
851
851
}
852
852
853
+ LogicalResult tosa::ConcatOp::verify () {
854
+ // check that each input has same element type as output
855
+ auto outType = getOutput ().getType ();
856
+ const Operation::operand_range inputList = getInput1 ();
857
+
858
+ if (!llvm::all_of (inputList, [&](auto input) {
859
+ return succeeded (verifySameElementTypes (
860
+ *this , /* inType = */ input.getType (), outType));
861
+ })) {
862
+ return failure ();
863
+ }
864
+
865
+ // Check there is at least one input
866
+ if (inputList.empty ())
867
+ return emitOpError (" expect at least one input" );
868
+
869
+ const Type firstInputType = inputList.front ().getType ();
870
+ const ShapeAdaptor firstInputShape (firstInputType);
871
+ const int32_t axis = getAxis ();
872
+
873
+ if (firstInputShape.hasRank ()) {
874
+ // Check axis is in expected range
875
+ if (axis < 0 || axis >= firstInputShape.getRank ())
876
+ return emitOpError (" expect axis to be within range 0 < axis < "
877
+ " rank(input1[0]), got " )
878
+ << axis;
879
+ }
880
+
881
+ const auto allOperandsHasRank = [](const Value input) {
882
+ return ShapeAdaptor (input.getType ()).hasRank ();
883
+ };
884
+ if (llvm::all_of (inputList, allOperandsHasRank)) {
885
+ const int64_t firstInputRank = firstInputShape.getRank ();
886
+
887
+ for (const auto [index , input] : llvm::enumerate (inputList.drop_front ())) {
888
+ const ShapeAdaptor inputShape (input.getType ());
889
+ const int64_t inputRank = inputShape.getRank ();
890
+ const size_t operandNum = index + 1 ;
891
+
892
+ // Check that each operand has the same rank
893
+ if (inputRank != firstInputRank)
894
+ return emitOpError (
895
+ " expect all operands to have the same rank, but got " )
896
+ << firstInputRank << " vs " << inputRank << " on operands 0 and "
897
+ << operandNum;
898
+
899
+ // Check non-axis dims match
900
+ for (int i = 0 ; i < inputRank; i++) {
901
+ const int64_t inputDim = inputShape.getDimSize (i);
902
+ const int64_t firstInputDim = firstInputShape.getDimSize (i);
903
+ if (i == axis || firstInputShape.isDynamicDim (i) ||
904
+ inputShape.isDynamicDim (i))
905
+ continue ;
906
+ if (inputDim != firstInputDim)
907
+ return emitOpError (" expect all operand shapes to have the same sizes "
908
+ " on non-axis dimensions, but got " )
909
+ << inputDim << " vs " << firstInputDim << " at index " << i
910
+ << " on operands 0 and " << operandNum;
911
+ }
912
+ }
913
+ }
914
+
915
+ return success ();
916
+ }
917
+
853
918
LogicalResult tosa::EqualOp::inferReturnTypeComponents (
854
919
MLIRContext *context, ::std::optional<Location> location,
855
920
ValueShapeRange operands, DictionaryAttr attributes,
@@ -899,6 +964,57 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
899
964
return success ();
900
965
}
901
966
967
+ LogicalResult MatMulOp::verify () {
968
+ auto aType = llvm::dyn_cast<ShapedType>(getA ().getType ());
969
+ auto bType = llvm::dyn_cast<ShapedType>(getB ().getType ());
970
+
971
+ // Must be shaped tensor types
972
+ if (!aType) {
973
+ emitOpError (" expect a shaped tensor for input a, got " ) << getA ().getType ();
974
+ return failure ();
975
+ }
976
+ if (!bType) {
977
+ emitOpError (" expect a shaped tensor for input b, got " ) << getB ().getType ();
978
+ return failure ();
979
+ }
980
+
981
+ auto aElementType = aType.getElementType ();
982
+ auto bElementType = bType.getElementType ();
983
+
984
+ auto aQuantizedEType =
985
+ llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
986
+ auto bQuantizedEType =
987
+ llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
988
+
989
+ if (aQuantizedEType || bQuantizedEType) {
990
+ if (!aQuantizedEType || !bQuantizedEType) {
991
+ emitOpError (
992
+ " expect operands to be both quantized or both not quantized, got " )
993
+ << aElementType << " and " << bElementType;
994
+ return failure ();
995
+ }
996
+ // both a and b have quantized element types
997
+ auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth ();
998
+ auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth ();
999
+ if (aQuantWidth != bQuantWidth) {
1000
+ emitOpError (" expect quantized operands to have same widths, got " )
1001
+ << aQuantWidth << " and " << bQuantWidth;
1002
+ return failure ();
1003
+ }
1004
+
1005
+ return success ();
1006
+ }
1007
+
1008
+ // non-quantized element types
1009
+ if (aElementType != bElementType) {
1010
+ emitOpError (" expect same element type for inputs a and b, got " )
1011
+ << aElementType << " and " << bElementType;
1012
+ return failure ();
1013
+ }
1014
+
1015
+ return success ();
1016
+ }
1017
+
902
1018
LogicalResult tosa::PadOp::inferReturnTypeComponents (
903
1019
MLIRContext *context, ::std::optional<Location> location,
904
1020
PadOp::Adaptor adaptor,
@@ -947,6 +1063,18 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
947
1063
}
948
1064
949
1065
LogicalResult tosa::PadOp::verify () {
1066
+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1067
+ /* outType = */ getOutput ().getType ())
1068
+ .failed ()) {
1069
+ return failure ();
1070
+ }
1071
+ if (auto padConst = getPadConst ()) {
1072
+ if (verifySameElementTypes (*this , /* inType = */ padConst.getType (),
1073
+ /* outType = */ getOutput ().getType ())
1074
+ .failed ()) {
1075
+ return failure ();
1076
+ }
1077
+ }
950
1078
RankedTensorType inputType = getInput1 ().getType ();
951
1079
RankedTensorType outputType = getOutput ().getType ();
952
1080
auto paddingRank = cast<tosa::shapeType>(getPadding ().getType ()).getRank ();
@@ -1020,21 +1148,23 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
1020
1148
}
1021
1149
1022
1150
LogicalResult tosa::SliceOp::verify() {
  // Slicing never changes the element type.
  if (failed(verifySameElementTypes(*this, /*inType=*/getInput1().getType(),
                                    /*outType=*/getOutput().getType())))
    return failure();

  // Rank agreement between the input and the start/size shape operands can
  // only be checked for ranked inputs.
  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  if (!inputType)
    return success();

  const int64_t inputRank = inputType.getRank();

  if (llvm::cast<tosa::shapeType>(getStart().getType()).getRank() != inputRank)
    return emitOpError("length of start is not equal to rank of input shape");

  if (llvm::cast<tosa::shapeType>(getSize().getType()).getRank() != inputRank)
    return emitOpError("length of size is not equal to rank of input shape");

  return success();
}
@@ -1239,6 +1369,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
1239
1369
}
1240
1370
1241
1371
LogicalResult tosa::TileOp::verify () {
1372
+ if (verifySameElementTypes (*this , /* intype = */ getInput1 ().getType (),
1373
+ /* outType = */ getOutput ().getType ())
1374
+ .failed ()) {
1375
+ return failure ();
1376
+ }
1242
1377
ShapedType inputType = llvm::cast<ShapedType>(getInput1 ().getType ());
1243
1378
ShapedType outputType = llvm::cast<ShapedType>(getType ());
1244
1379
@@ -1320,6 +1455,11 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
1320
1455
}
1321
1456
1322
1457
llvm::LogicalResult tosa::ReshapeOp::verify () {
1458
+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1459
+ /* outType = */ getOutput ().getType ())
1460
+ .failed ()) {
1461
+ return failure ();
1462
+ }
1323
1463
TensorType inputType = getInput1 ().getType ();
1324
1464
RankedTensorType outputType = getType ();
1325
1465
@@ -1434,6 +1574,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
1434
1574
}
1435
1575
1436
1576
LogicalResult tosa::TransposeOp::verify () {
1577
+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1578
+ /* outType = */ getOutput ().getType ())
1579
+ .failed ()) {
1580
+ return failure ();
1581
+ }
1437
1582
TensorType inputType = getInput1 ().getType ();
1438
1583
TensorType outputType = getOutput ().getType ();
1439
1584
const llvm::ArrayRef<int32_t > constantPerms = getPerms ();
@@ -1534,6 +1679,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
1534
1679
return success ();
1535
1680
}
1536
1681
1682
+ LogicalResult tosa::GatherOp::verify () {
1683
+ return verifySameElementTypes (*this , /* inType = */ getValues ().getType (),
1684
+ /* outType = */ getOutput ().getType ());
1685
+ }
1686
+
1537
1687
LogicalResult tosa::ResizeOp::inferReturnTypeComponents (
1538
1688
MLIRContext *context, ::std::optional<Location> location,
1539
1689
ResizeOp::Adaptor adaptor,
@@ -1702,6 +1852,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
1702
1852
return success ();
1703
1853
}
1704
1854
1855
+ LogicalResult tosa::ScatterOp::verify () {
1856
+ if (verifySameElementTypes (*this , /* inType = */ getValuesIn ().getType (),
1857
+ /* outType = */ getValuesOut ().getType ())
1858
+ .failed () ||
1859
+ verifySameElementTypes (*this , /* inType = */ getInput ().getType (),
1860
+ /* outType = */ getValuesOut ().getType ())
1861
+ .failed ()) {
1862
+ return failure ();
1863
+ }
1864
+ return success ();
1865
+ }
1866
+
1705
1867
static LogicalResult ReduceInferReturnTypes (
1706
1868
ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
1707
1869
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2066,6 +2228,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
2066
2228
inferredReturnShapes);
2067
2229
}
2068
2230
2231
+ LogicalResult MaxPool2dOp::verify () {
2232
+ return verifySameElementTypes (*this , /* intype = */ getInput ().getType (),
2233
+ /* outType = */ getOutput ().getType ());
2234
+ }
2235
+
2069
2236
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents (
2070
2237
MLIRContext *context, ::std::optional<Location> location,
2071
2238
DepthwiseConv2DOp::Adaptor adaptor,
@@ -2368,6 +2535,10 @@ void IfOp::print(OpAsmPrinter &p) {
2368
2535
}
2369
2536
2370
2537
LogicalResult ReverseOp::verify () {
2538
+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
2539
+ /* outType = */ getOutput ().getType ())
2540
+ .failed ())
2541
+ return failure ();
2371
2542
TensorType inputType = getInput1 ().getType ();
2372
2543
TensorType outputType = getOutput ().getType ();
2373
2544
int32_t reverseAxis = getAxis ();
@@ -2396,6 +2567,33 @@ LogicalResult ReverseOp::verify() {
2396
2567
return success ();
2397
2568
}
2398
2569
2570
+ LogicalResult tosa::SelectOp::verify () {
2571
+ // verify input2 and input3 have same element type as output
2572
+ if (verifySameElementTypes (*this , /* inType = */ getInput2 ().getType (),
2573
+ /* outType = */ getOutput ().getType ())
2574
+ .failed () ||
2575
+ verifySameElementTypes (*this , /* inType = */ getInput3 ().getType (),
2576
+ /* outType = */ getOutput ().getType ())
2577
+ .failed ()) {
2578
+ return failure ();
2579
+ }
2580
+ // verify input1 has element type of bool
2581
+ auto predicateType = llvm::dyn_cast<ShapedType>(getInput1 ().getType ());
2582
+ if (!predicateType) {
2583
+ emitOpError (" expect shaped tensor for input1, got " )
2584
+ << getInput1 ().getType ();
2585
+ return failure ();
2586
+ }
2587
+ auto predicateElementType = predicateType.getElementType ();
2588
+ if (!predicateElementType.isInteger (1 )) {
2589
+ emitOpError (" expect element type of bool for input1, got " )
2590
+ << predicateElementType;
2591
+ return failure ();
2592
+ }
2593
+
2594
+ return success ();
2595
+ }
2596
+
2399
2597
// parse and print of WhileOp refer to the implementation of SCF dialect.
2400
2598
ParseResult WhileOp::parse (OpAsmParser &parser, OperationState &result) {
2401
2599
SmallVector<OpAsmParser::Argument, 4 > regionArgs;
0 commit comments