@@ -871,6 +871,75 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
871
871
return success ();
872
872
}
873
873
874
+ LogicalResult tosa::ConcatOp::verify () {
875
+ // check that each input has same element type as output
876
+ auto outType = getOutput ().getType ();
877
+ const Operation::operand_range inputList = getInput1 ();
878
+
879
+ // Check there is at least one input
880
+ if (inputList.empty ())
881
+ return emitOpError (" expect at least one input" );
882
+
883
+ if (!llvm::all_of (inputList, [&](auto input) {
884
+ return succeeded (verifySameElementTypes (
885
+ *this , /* inType = */ input.getType (), outType));
886
+ })) {
887
+ return failure ();
888
+ }
889
+
890
+ const int32_t axis = getAxis ();
891
+ ShapeAdaptor firstRankedInputShape = nullptr ;
892
+ for (auto input : inputList) {
893
+ const Type inputType = input.getType ();
894
+ ShapeAdaptor currShape (inputType);
895
+ if (currShape.hasRank ()) {
896
+ firstRankedInputShape = currShape;
897
+ // Check axis is in expected range
898
+ if (axis < 0 || axis >= firstRankedInputShape.getRank ())
899
+ return emitOpError (" expect axis to be within range 0 < axis < "
900
+ " rank(input1[firstRankedTensorIdx]), got " )
901
+ << axis;
902
+ break ;
903
+ }
904
+ }
905
+
906
+ const auto allOperandsHasRank = [](const Value input) {
907
+ return ShapeAdaptor (input.getType ()).hasRank ();
908
+ };
909
+ if (llvm::all_of (inputList, allOperandsHasRank)) {
910
+ const int64_t firstInputRank = firstRankedInputShape.getRank ();
911
+
912
+ for (const auto [index , input] : llvm::enumerate (inputList.drop_front ())) {
913
+ const ShapeAdaptor inputShape (input.getType ());
914
+ const int64_t inputRank = inputShape.getRank ();
915
+ const size_t operandNum = index + 1 ;
916
+
917
+ // Check that each operand has the same rank
918
+ if (inputRank != firstInputRank)
919
+ return emitOpError (
920
+ " expect all operands to have the same rank, but got " )
921
+ << firstInputRank << " vs " << inputRank << " on operands 0 and "
922
+ << operandNum;
923
+
924
+ // Check non-axis dims match
925
+ for (int i = 0 ; i < inputRank; i++) {
926
+ const int64_t inputDim = inputShape.getDimSize (i);
927
+ const int64_t firstInputDim = firstRankedInputShape.getDimSize (i);
928
+ if (i == axis || firstRankedInputShape.isDynamicDim (i) ||
929
+ inputShape.isDynamicDim (i))
930
+ continue ;
931
+ if (inputDim != firstInputDim)
932
+ return emitOpError (" expect all operand shapes to have the same sizes "
933
+ " on non-axis dimensions, but got " )
934
+ << inputDim << " vs " << firstInputDim << " at index " << i
935
+ << " on operands 0 and " << operandNum;
936
+ }
937
+ }
938
+ }
939
+
940
+ return success ();
941
+ }
942
+
874
943
LogicalResult tosa::EqualOp::inferReturnTypeComponents (
875
944
MLIRContext *context, ::std::optional<Location> location,
876
945
ValueShapeRange operands, DictionaryAttr attributes,
@@ -920,6 +989,57 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
920
989
return success ();
921
990
}
922
991
992
+ LogicalResult MatMulOp::verify () {
993
+ auto aType = llvm::dyn_cast<ShapedType>(getA ().getType ());
994
+ auto bType = llvm::dyn_cast<ShapedType>(getB ().getType ());
995
+
996
+ // Must be shaped tensor types
997
+ if (!aType) {
998
+ emitOpError (" expect a shaped tensor for input a, got " ) << getA ().getType ();
999
+ return failure ();
1000
+ }
1001
+ if (!bType) {
1002
+ emitOpError (" expect a shaped tensor for input b, got " ) << getB ().getType ();
1003
+ return failure ();
1004
+ }
1005
+
1006
+ auto aElementType = aType.getElementType ();
1007
+ auto bElementType = bType.getElementType ();
1008
+
1009
+ auto aQuantizedEType =
1010
+ llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1011
+ auto bQuantizedEType =
1012
+ llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1013
+
1014
+ if (aQuantizedEType || bQuantizedEType) {
1015
+ if (!aQuantizedEType || !bQuantizedEType) {
1016
+ emitOpError (
1017
+ " expect operands to be both quantized or both not quantized, got " )
1018
+ << aElementType << " and " << bElementType;
1019
+ return failure ();
1020
+ }
1021
+ // both a and b have quantized element types
1022
+ auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth ();
1023
+ auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth ();
1024
+ if (aQuantWidth != bQuantWidth) {
1025
+ emitOpError (" expect quantized operands to have same widths, got " )
1026
+ << aQuantWidth << " and " << bQuantWidth;
1027
+ return failure ();
1028
+ }
1029
+
1030
+ return success ();
1031
+ }
1032
+
1033
+ // non-quantized element types
1034
+ if (aElementType != bElementType) {
1035
+ emitOpError (" expect same element type for inputs a and b, got " )
1036
+ << aElementType << " and " << bElementType;
1037
+ return failure ();
1038
+ }
1039
+
1040
+ return success ();
1041
+ }
1042
+
923
1043
LogicalResult tosa::PadOp::inferReturnTypeComponents (
924
1044
MLIRContext *context, ::std::optional<Location> location,
925
1045
PadOp::Adaptor adaptor,
@@ -968,6 +1088,20 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
968
1088
}
969
1089
970
1090
LogicalResult tosa::PadOp::verify () {
1091
+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1092
+ /* outType = */ getOutput ().getType ())
1093
+ .failed ()) {
1094
+ return failure ();
1095
+ }
1096
+
1097
+ if (auto padConst = getPadConst ()) {
1098
+ if (verifySameElementTypes (*this , /* inType = */ padConst.getType (),
1099
+ /* outType = */ getOutput ().getType ())
1100
+ .failed ()) {
1101
+ return failure ();
1102
+ }
1103
+ }
1104
+
971
1105
RankedTensorType inputType = getInput1 ().getType ();
972
1106
RankedTensorType outputType = getOutput ().getType ();
973
1107
auto paddingRank = cast<tosa::shapeType>(getPadding ().getType ()).getRank ();
@@ -1041,21 +1175,23 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
1041
1175
}
1042
1176
1043
1177
LogicalResult tosa::SliceOp::verify() {
  // SLICE verifier: element types of input and output must match, and the
  // start/size shape operands must each have one entry per input dimension.
  if (failed(verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
                                    /* outType = */ getOutput().getType())))
    return failure();

  // Rank checks only apply when the input is ranked.
  auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  if (!inputType)
    return success();

  const auto inputRank = inputType.getRank();

  const auto startShapeRank =
      llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
  if (inputRank != startShapeRank)
    return emitOpError("length of start is not equal to rank of input shape");

  const auto sizeShapeRank =
      llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
  if (inputRank != sizeShapeRank)
    return emitOpError("length of size is not equal to rank of input shape");

  return success();
}
@@ -1260,6 +1396,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
1260
1396
}
1261
1397
1262
1398
LogicalResult tosa::TileOp::verify () {
1399
+ if (verifySameElementTypes (*this , /* intype = */ getInput1 ().getType (),
1400
+ /* outType = */ getOutput ().getType ())
1401
+ .failed ()) {
1402
+ return failure ();
1403
+ }
1263
1404
ShapedType inputType = llvm::cast<ShapedType>(getInput1 ().getType ());
1264
1405
ShapedType outputType = llvm::cast<ShapedType>(getType ());
1265
1406
@@ -1341,6 +1482,11 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
1341
1482
}
1342
1483
1343
1484
llvm::LogicalResult tosa::ReshapeOp::verify () {
1485
+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1486
+ /* outType = */ getOutput ().getType ())
1487
+ .failed ()) {
1488
+ return failure ();
1489
+ }
1344
1490
TensorType inputType = getInput1 ().getType ();
1345
1491
RankedTensorType outputType = getType ();
1346
1492
@@ -1528,6 +1674,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
1528
1674
}
1529
1675
1530
1676
LogicalResult tosa::TransposeOp::verify () {
1677
+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1678
+ /* outType = */ getOutput ().getType ())
1679
+ .failed ()) {
1680
+ return failure ();
1681
+ }
1531
1682
TensorType inputType = getInput1 ().getType ();
1532
1683
TensorType outputType = getOutput ().getType ();
1533
1684
const llvm::ArrayRef<int32_t > constantPerms = getPerms ();
@@ -1628,6 +1779,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
1628
1779
return success ();
1629
1780
}
1630
1781
1782
+ LogicalResult tosa::GatherOp::verify () {
1783
+ return verifySameElementTypes (*this , /* inType = */ getValues ().getType (),
1784
+ /* outType = */ getOutput ().getType ());
1785
+ }
1786
+
1631
1787
LogicalResult tosa::ResizeOp::inferReturnTypeComponents (
1632
1788
MLIRContext *context, ::std::optional<Location> location,
1633
1789
ResizeOp::Adaptor adaptor,
@@ -1789,6 +1945,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
1789
1945
return success ();
1790
1946
}
1791
1947
1948
+ LogicalResult tosa::ScatterOp::verify () {
1949
+ if (verifySameElementTypes (*this , /* inType = */ getValuesIn ().getType (),
1950
+ /* outType = */ getValuesOut ().getType ())
1951
+ .failed () ||
1952
+ verifySameElementTypes (*this , /* inType = */ getInput ().getType (),
1953
+ /* outType = */ getValuesOut ().getType ())
1954
+ .failed ()) {
1955
+ return failure ();
1956
+ }
1957
+ return success ();
1958
+ }
1959
+
1792
1960
static LogicalResult ReduceInferReturnTypes (
1793
1961
ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
1794
1962
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2244,6 +2412,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
2244
2412
inferredReturnShapes);
2245
2413
}
2246
2414
2415
+ LogicalResult MaxPool2dOp::verify () {
2416
+ return verifySameElementTypes (*this , /* intype = */ getInput ().getType (),
2417
+ /* outType = */ getOutput ().getType ());
2418
+ }
2419
+
2247
2420
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents (
2248
2421
MLIRContext *context, ::std::optional<Location> location,
2249
2422
DepthwiseConv2DOp::Adaptor adaptor,
@@ -2544,6 +2717,10 @@ void IfOp::print(OpAsmPrinter &p) {
2544
2717
}
2545
2718
2546
2719
LogicalResult ReverseOp::verify () {
2720
+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
2721
+ /* outType = */ getOutput ().getType ())
2722
+ .failed ())
2723
+ return failure ();
2547
2724
TensorType inputType = getInput1 ().getType ();
2548
2725
TensorType outputType = getOutput ().getType ();
2549
2726
int32_t reverseAxis = getAxis ();
@@ -2572,6 +2749,33 @@ LogicalResult ReverseOp::verify() {
2572
2749
return success ();
2573
2750
}
2574
2751
2752
+ LogicalResult tosa::SelectOp::verify () {
2753
+ // verify input2 and input3 have same element type as output
2754
+ if (verifySameElementTypes (*this , /* inType = */ getInput2 ().getType (),
2755
+ /* outType = */ getOutput ().getType ())
2756
+ .failed () ||
2757
+ verifySameElementTypes (*this , /* inType = */ getInput3 ().getType (),
2758
+ /* outType = */ getOutput ().getType ())
2759
+ .failed ()) {
2760
+ return failure ();
2761
+ }
2762
+ // verify input1 has element type of bool
2763
+ auto predicateType = llvm::dyn_cast<ShapedType>(getInput1 ().getType ());
2764
+ if (!predicateType) {
2765
+ emitOpError (" expect shaped tensor for input1, got " )
2766
+ << getInput1 ().getType ();
2767
+ return failure ();
2768
+ }
2769
+ auto predicateElementType = predicateType.getElementType ();
2770
+ if (!predicateElementType.isInteger (1 )) {
2771
+ emitOpError (" expect element type of bool for input1, got " )
2772
+ << predicateElementType;
2773
+ return failure ();
2774
+ }
2775
+
2776
+ return success ();
2777
+ }
2778
+
2575
2779
// parse and print of WhileOp refer to the implementation of SCF dialect.
2576
2780
ParseResult WhileOp::parse (OpAsmParser &parser, OperationState &result) {
2577
2781
SmallVector<OpAsmParser::Argument, 4 > regionArgs;
0 commit comments