@@ -871,6 +871,71 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
871
871
return success ();
872
872
}
873
873
874
+ LogicalResult tosa::ConcatOp::verify () {
875
+ // check that each input has same element type as output
876
+ auto outType = getOutput ().getType ();
877
+ const Operation::operand_range inputList = getInput1 ();
878
+
879
+ // Check there is at least one input
880
+ if (inputList.empty ())
881
+ return emitOpError (" expect at least one input" );
882
+
883
+ if (!llvm::all_of (inputList, [&](auto input) {
884
+ return succeeded (verifySameElementTypes (
885
+ *this , /* inType = */ input.getType (), outType));
886
+ })) {
887
+ return failure ();
888
+ }
889
+
890
+ const Type firstInputType = inputList.front ().getType ();
891
+ const ShapeAdaptor firstInputShape (firstInputType);
892
+ const int32_t axis = getAxis ();
893
+
894
+ if (firstInputShape.hasRank ()) {
895
+ // Check axis is in expected range
896
+ if (axis < 0 || axis >= firstInputShape.getRank ())
897
+ return emitOpError (" expect axis to be within range 0 < axis < "
898
+ " rank(input1[0]), got " )
899
+ << axis;
900
+ }
901
+
902
+ const auto allOperandsHasRank = [](const Value input) {
903
+ return ShapeAdaptor (input.getType ()).hasRank ();
904
+ };
905
+ if (llvm::all_of (inputList, allOperandsHasRank)) {
906
+ const int64_t firstInputRank = firstInputShape.getRank ();
907
+
908
+ for (const auto [index , input] : llvm::enumerate (inputList.drop_front ())) {
909
+ const ShapeAdaptor inputShape (input.getType ());
910
+ const int64_t inputRank = inputShape.getRank ();
911
+ const size_t operandNum = index + 1 ;
912
+
913
+ // Check that each operand has the same rank
914
+ if (inputRank != firstInputRank)
915
+ return emitOpError (
916
+ " expect all operands to have the same rank, but got " )
917
+ << firstInputRank << " vs " << inputRank << " on operands 0 and "
918
+ << operandNum;
919
+
920
+ // Check non-axis dims match
921
+ for (int i = 0 ; i < inputRank; i++) {
922
+ const int64_t inputDim = inputShape.getDimSize (i);
923
+ const int64_t firstInputDim = firstInputShape.getDimSize (i);
924
+ if (i == axis || firstInputShape.isDynamicDim (i) ||
925
+ inputShape.isDynamicDim (i))
926
+ continue ;
927
+ if (inputDim != firstInputDim)
928
+ return emitOpError (" expect all operand shapes to have the same sizes "
929
+ " on non-axis dimensions, but got " )
930
+ << inputDim << " vs " << firstInputDim << " at index " << i
931
+ << " on operands 0 and " << operandNum;
932
+ }
933
+ }
934
+ }
935
+
936
+ return success ();
937
+ }
938
+
874
939
LogicalResult tosa::EqualOp::inferReturnTypeComponents (
875
940
MLIRContext *context, ::std::optional<Location> location,
876
941
ValueShapeRange operands, DictionaryAttr attributes,
@@ -920,6 +985,57 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
920
985
return success ();
921
986
}
922
987
988
+ LogicalResult MatMulOp::verify () {
989
+ auto aType = llvm::dyn_cast<ShapedType>(getA ().getType ());
990
+ auto bType = llvm::dyn_cast<ShapedType>(getB ().getType ());
991
+
992
+ // Must be shaped tensor types
993
+ if (!aType) {
994
+ emitOpError (" expect a shaped tensor for input a, got " ) << getA ().getType ();
995
+ return failure ();
996
+ }
997
+ if (!bType) {
998
+ emitOpError (" expect a shaped tensor for input b, got " ) << getB ().getType ();
999
+ return failure ();
1000
+ }
1001
+
1002
+ auto aElementType = aType.getElementType ();
1003
+ auto bElementType = bType.getElementType ();
1004
+
1005
+ auto aQuantizedEType =
1006
+ llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1007
+ auto bQuantizedEType =
1008
+ llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1009
+
1010
+ if (aQuantizedEType || bQuantizedEType) {
1011
+ if (!aQuantizedEType || !bQuantizedEType) {
1012
+ emitOpError (
1013
+ " expect operands to be both quantized or both not quantized, got " )
1014
+ << aElementType << " and " << bElementType;
1015
+ return failure ();
1016
+ }
1017
+ // both a and b have quantized element types
1018
+ auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth ();
1019
+ auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth ();
1020
+ if (aQuantWidth != bQuantWidth) {
1021
+ emitOpError (" expect quantized operands to have same widths, got " )
1022
+ << aQuantWidth << " and " << bQuantWidth;
1023
+ return failure ();
1024
+ }
1025
+
1026
+ return success ();
1027
+ }
1028
+
1029
+ // non-quantized element types
1030
+ if (aElementType != bElementType) {
1031
+ emitOpError (" expect same element type for inputs a and b, got " )
1032
+ << aElementType << " and " << bElementType;
1033
+ return failure ();
1034
+ }
1035
+
1036
+ return success ();
1037
+ }
1038
+
923
1039
LogicalResult tosa::PadOp::inferReturnTypeComponents (
924
1040
MLIRContext *context, ::std::optional<Location> location,
925
1041
PadOp::Adaptor adaptor,
@@ -968,6 +1084,20 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
968
1084
}
969
1085
970
1086
LogicalResult tosa::PadOp::verify () {
1087
+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1088
+ /* outType = */ getOutput ().getType ())
1089
+ .failed ()) {
1090
+ return failure ();
1091
+ }
1092
+
1093
+ if (auto padConst = getPadConst ()) {
1094
+ if (verifySameElementTypes (*this , /* inType = */ padConst.getType (),
1095
+ /* outType = */ getOutput ().getType ())
1096
+ .failed ()) {
1097
+ return failure ();
1098
+ }
1099
+ }
1100
+
971
1101
RankedTensorType inputType = getInput1 ().getType ();
972
1102
RankedTensorType outputType = getOutput ().getType ();
973
1103
auto paddingRank = cast<tosa::shapeType>(getPadding ().getType ()).getRank ();
@@ -1041,21 +1171,23 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
1041
1171
}
1042
1172
1043
1173
LogicalResult tosa::SliceOp::verify() {
  // Input and output must agree on element type.
  if (failed(verifySameElementTypes(*this, /*inType=*/getInput1().getType(),
                                    /*outType=*/getOutput().getType())))
    return failure();

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  if (!inputType)
    return success(); // Unranked input: nothing further to validate.

  const int64_t inputRank = inputType.getRank();

  // `start` must provide one offset per input dimension.
  const auto startShapeRank =
      llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
  if (startShapeRank != inputRank)
    return emitOpError("length of start is not equal to rank of input shape");

  // `size` must provide one extent per input dimension.
  const auto sizeShapeRank =
      llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
  if (sizeShapeRank != inputRank)
    return emitOpError("length of size is not equal to rank of input shape");

  return success();
}
@@ -1260,6 +1392,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
1260
1392
}
1261
1393
1262
1394
LogicalResult tosa::TileOp::verify () {
1395
+ if (verifySameElementTypes (*this , /* intype = */ getInput1 ().getType (),
1396
+ /* outType = */ getOutput ().getType ())
1397
+ .failed ()) {
1398
+ return failure ();
1399
+ }
1263
1400
ShapedType inputType = llvm::cast<ShapedType>(getInput1 ().getType ());
1264
1401
ShapedType outputType = llvm::cast<ShapedType>(getType ());
1265
1402
@@ -1341,6 +1478,11 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
1341
1478
}
1342
1479
1343
1480
llvm::LogicalResult tosa::ReshapeOp::verify () {
1481
+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1482
+ /* outType = */ getOutput ().getType ())
1483
+ .failed ()) {
1484
+ return failure ();
1485
+ }
1344
1486
TensorType inputType = getInput1 ().getType ();
1345
1487
RankedTensorType outputType = getType ();
1346
1488
@@ -1528,6 +1670,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
1528
1670
}
1529
1671
1530
1672
LogicalResult tosa::TransposeOp::verify () {
1673
+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
1674
+ /* outType = */ getOutput ().getType ())
1675
+ .failed ()) {
1676
+ return failure ();
1677
+ }
1531
1678
TensorType inputType = getInput1 ().getType ();
1532
1679
TensorType outputType = getOutput ().getType ();
1533
1680
const llvm::ArrayRef<int32_t > constantPerms = getPerms ();
@@ -1628,6 +1775,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
1628
1775
return success ();
1629
1776
}
1630
1777
1778
+ LogicalResult tosa::GatherOp::verify () {
1779
+ return verifySameElementTypes (*this , /* inType = */ getValues ().getType (),
1780
+ /* outType = */ getOutput ().getType ());
1781
+ }
1782
+
1631
1783
LogicalResult tosa::ResizeOp::inferReturnTypeComponents (
1632
1784
MLIRContext *context, ::std::optional<Location> location,
1633
1785
ResizeOp::Adaptor adaptor,
@@ -1789,6 +1941,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
1789
1941
return success ();
1790
1942
}
1791
1943
1944
+ LogicalResult tosa::ScatterOp::verify () {
1945
+ if (verifySameElementTypes (*this , /* inType = */ getValuesIn ().getType (),
1946
+ /* outType = */ getValuesOut ().getType ())
1947
+ .failed () ||
1948
+ verifySameElementTypes (*this , /* inType = */ getInput ().getType (),
1949
+ /* outType = */ getValuesOut ().getType ())
1950
+ .failed ()) {
1951
+ return failure ();
1952
+ }
1953
+ return success ();
1954
+ }
1955
+
1792
1956
static LogicalResult ReduceInferReturnTypes (
1793
1957
ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
1794
1958
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2244,6 +2408,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
2244
2408
inferredReturnShapes);
2245
2409
}
2246
2410
2411
+ LogicalResult MaxPool2dOp::verify () {
2412
+ return verifySameElementTypes (*this , /* intype = */ getInput ().getType (),
2413
+ /* outType = */ getOutput ().getType ());
2414
+ }
2415
+
2247
2416
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents (
2248
2417
MLIRContext *context, ::std::optional<Location> location,
2249
2418
DepthwiseConv2DOp::Adaptor adaptor,
@@ -2546,6 +2715,10 @@ void IfOp::print(OpAsmPrinter &p) {
2546
2715
}
2547
2716
2548
2717
LogicalResult ReverseOp::verify () {
2718
+ if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
2719
+ /* outType = */ getOutput ().getType ())
2720
+ .failed ())
2721
+ return failure ();
2549
2722
TensorType inputType = getInput1 ().getType ();
2550
2723
TensorType outputType = getOutput ().getType ();
2551
2724
int32_t reverseAxis = getAxis ();
@@ -2574,6 +2747,33 @@ LogicalResult ReverseOp::verify() {
2574
2747
return success ();
2575
2748
}
2576
2749
2750
+ LogicalResult tosa::SelectOp::verify () {
2751
+ // verify input2 and input3 have same element type as output
2752
+ if (verifySameElementTypes (*this , /* inType = */ getInput2 ().getType (),
2753
+ /* outType = */ getOutput ().getType ())
2754
+ .failed () ||
2755
+ verifySameElementTypes (*this , /* inType = */ getInput3 ().getType (),
2756
+ /* outType = */ getOutput ().getType ())
2757
+ .failed ()) {
2758
+ return failure ();
2759
+ }
2760
+ // verify input1 has element type of bool
2761
+ auto predicateType = llvm::dyn_cast<ShapedType>(getInput1 ().getType ());
2762
+ if (!predicateType) {
2763
+ emitOpError (" expect shaped tensor for input1, got " )
2764
+ << getInput1 ().getType ();
2765
+ return failure ();
2766
+ }
2767
+ auto predicateElementType = predicateType.getElementType ();
2768
+ if (!predicateElementType.isInteger (1 )) {
2769
+ emitOpError (" expect element type of bool for input1, got " )
2770
+ << predicateElementType;
2771
+ return failure ();
2772
+ }
2773
+
2774
+ return success ();
2775
+ }
2776
+
2577
2777
// parse and print of WhileOp refer to the implementation of SCF dialect.
2578
2778
ParseResult WhileOp::parse (OpAsmParser &parser, OperationState &result) {
2579
2779
SmallVector<OpAsmParser::Argument, 4 > regionArgs;
0 commit comments