Skip to content

Commit bb978ee

Browse files
Jerry-Ge, Tai78641 and lhutton1
committed
[mlir][tosa] Add more verifiers for the following operators
For ConcatOp this commit also enhances the verifier by checking 4 additional conditions:
- The input list is not empty
- The axis value is within range of the input shapes
- All inputs have the same rank
- All non-concatenation-axis dims have the same value

For MatmulOp:
- Checked input a and b tensor types and element types

For the following operators, added the verifySameElementTypes check.
- PadOp
- SliceOp
- TileOp
- ReshapeOp
- TransposeOp
- GatherOp
- ScatterOp
- MaxPool2dOp
- ReverseOp
- SelectOp

Change-Id: I1e8a1017f21f617443bc40bae42189915048c750
Co-authored-by: Tai Ly <[email protected]>
Co-authored-by: Luke Hutton <[email protected]>
Signed-off-by: Jerry Ge <[email protected]>
1 parent 440ea3e commit bb978ee

File tree

3 files changed

+249
-13
lines changed

3 files changed

+249
-13
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
310310
];
311311

312312
let builders = [Tosa_MatMulOpQuantInfoBuilder];
313+
let hasVerifier = 1;
313314
}
314315

315316
//===----------------------------------------------------------------------===//
@@ -344,6 +345,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
344345
];
345346

346347
let hasCanonicalizer = 1;
348+
let hasVerifier = 1;
347349
}
348350

349351
//===----------------------------------------------------------------------===//
@@ -1471,6 +1473,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
14711473

14721474
let hasCanonicalizeMethod = 1;
14731475
let hasFolder = 1;
1476+
let hasVerifier = 1;
14741477

14751478
let assemblyFormat = [{
14761479
operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
@@ -1846,6 +1849,7 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
18461849

18471850
let hasCanonicalizer = 1;
18481851
let hasFolder = 1;
1852+
let hasVerifier = 1;
18491853

18501854
let extraClassDeclaration = [{
18511855
/// Returns true when two result types are compatible for this op;
@@ -2102,6 +2106,8 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
21022106
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
21032107
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
21042108
];
2109+
2110+
let hasVerifier = 1;
21052111
}
21062112

21072113
//===----------------------------------------------------------------------===//
@@ -2135,6 +2141,8 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
21352141
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
21362142
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
21372143
];
2144+
2145+
let hasVerifier = 1;
21382146
}
21392147

21402148
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 204 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,71 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
871871
return success();
872872
}
873873

874+
LogicalResult tosa::ConcatOp::verify() {
  // Check that each input has the same element type as the output.
  auto outType = getOutput().getType();
  const Operation::operand_range inputList = getInput1();

  // Check there is at least one input.
  if (inputList.empty())
    return emitOpError("expect at least one input");

  if (!llvm::all_of(inputList, [&](auto input) {
        return succeeded(verifySameElementTypes(
            *this, /* inType = */ input.getType(), outType));
      })) {
    return failure();
  }

  const Type firstInputType = inputList.front().getType();
  const ShapeAdaptor firstInputShape(firstInputType);
  const int32_t axis = getAxis();

  if (firstInputShape.hasRank()) {
    // Check axis is in the expected range: 0 <= axis < rank(input1[0]).
    // NOTE: the message previously read "0 < axis <", which wrongly implied
    // axis == 0 is invalid even though the check accepts it.
    if (axis < 0 || axis >= firstInputShape.getRank())
      return emitOpError("expect axis to be within range 0 <= axis < "
                         "rank(input1[0]), got ")
             << axis;
  }

  const auto allOperandsHasRank = [](const Value input) {
    return ShapeAdaptor(input.getType()).hasRank();
  };
  if (llvm::all_of(inputList, allOperandsHasRank)) {
    const int64_t firstInputRank = firstInputShape.getRank();

    for (const auto [index, input] : llvm::enumerate(inputList.drop_front())) {
      const ShapeAdaptor inputShape(input.getType());
      const int64_t inputRank = inputShape.getRank();
      const size_t operandNum = index + 1;

      // Check that each operand has the same rank.
      if (inputRank != firstInputRank)
        return emitOpError(
                   "expect all operands to have the same rank, but got ")
               << firstInputRank << " vs " << inputRank << " on operands 0 and "
               << operandNum;

      // Check that all non-axis static dimensions match; dynamic dims and the
      // concatenation axis itself are exempt.
      for (int i = 0; i < inputRank; i++) {
        const int64_t inputDim = inputShape.getDimSize(i);
        const int64_t firstInputDim = firstInputShape.getDimSize(i);
        if (i == axis || firstInputShape.isDynamicDim(i) ||
            inputShape.isDynamicDim(i))
          continue;
        if (inputDim != firstInputDim)
          return emitOpError("expect all operand shapes to have the same sizes "
                             "on non-axis dimensions, but got ")
                 << inputDim << " vs " << firstInputDim << " at index " << i
                 << " on operands 0 and " << operandNum;
      }
    }
  }

  return success();
}
938+
874939
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
875940
MLIRContext *context, ::std::optional<Location> location,
876941
ValueShapeRange operands, DictionaryAttr attributes,
@@ -920,6 +985,57 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
920985
return success();
921986
}
922987

988+
LogicalResult MatMulOp::verify() {
  // Verify that inputs a and b are shaped tensors with compatible element
  // types: either both quantized with equal storage widths, or both holding
  // the same non-quantized element type.
  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());

  // Must be shaped tensor types.
  if (!aType)
    return emitOpError("expect a shaped tensor for input a, got ")
           << getA().getType();
  if (!bType)
    return emitOpError("expect a shaped tensor for input b, got ")
           << getB().getType();

  auto aElementType = aType.getElementType();
  auto bElementType = bType.getElementType();

  auto aQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
  auto bQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);

  if (aQuantizedEType || bQuantizedEType) {
    // Mixed quantized / non-quantized operands are rejected.
    if (!aQuantizedEType || !bQuantizedEType)
      return emitOpError(
                 "expect operands to be both quantized or both not quantized, got ")
             << aElementType << " and " << bElementType;

    // Both a and b have quantized element types; their storage widths must
    // agree.
    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
    if (aQuantWidth != bQuantWidth)
      return emitOpError("expect quantized operands to have same widths, got ")
             << aQuantWidth << " and " << bQuantWidth;

    return success();
  }

  // Non-quantized element types must match exactly.
  if (aElementType != bElementType)
    return emitOpError("expect same element type for inputs a and b, got ")
           << aElementType << " and " << bElementType;

  return success();
}
1038+
9231039
LogicalResult tosa::PadOp::inferReturnTypeComponents(
9241040
MLIRContext *context, ::std::optional<Location> location,
9251041
PadOp::Adaptor adaptor,
@@ -968,6 +1084,20 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
9681084
}
9691085

9701086
LogicalResult tosa::PadOp::verify() {
1087+
if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1088+
/* outType = */ getOutput().getType())
1089+
.failed()) {
1090+
return failure();
1091+
}
1092+
1093+
if (auto padConst = getPadConst()) {
1094+
if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
1095+
/* outType = */ getOutput().getType())
1096+
.failed()) {
1097+
return failure();
1098+
}
1099+
}
1100+
9711101
RankedTensorType inputType = getInput1().getType();
9721102
RankedTensorType outputType = getOutput().getType();
9731103
auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
@@ -1041,21 +1171,23 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
10411171
}
10421172

10431173
LogicalResult tosa::SliceOp::verify() {
  // Input and output must share the same element type.
  if (failed(verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
                                    /* outType = */ getOutput().getType())))
    return failure();

  // Shape checks only apply when the input rank is known.
  const auto rankedInput =
      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  if (!rankedInput)
    return success();

  const int64_t inputRank = rankedInput.getRank();

  // The start shape operand must provide one entry per input dimension.
  if (llvm::cast<tosa::shapeType>(getStart().getType()).getRank() != inputRank)
    return emitOpError("length of start is not equal to rank of input shape");

  // The size shape operand must provide one entry per input dimension.
  if (llvm::cast<tosa::shapeType>(getSize().getType()).getRank() != inputRank)
    return emitOpError("length of size is not equal to rank of input shape");

  return success();
}
@@ -1260,6 +1392,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
12601392
}
12611393

12621394
LogicalResult tosa::TileOp::verify() {
1395+
if (verifySameElementTypes(*this, /* intype = */ getInput1().getType(),
1396+
/* outType = */ getOutput().getType())
1397+
.failed()) {
1398+
return failure();
1399+
}
12631400
ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
12641401
ShapedType outputType = llvm::cast<ShapedType>(getType());
12651402

@@ -1341,6 +1478,11 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
13411478
}
13421479

13431480
llvm::LogicalResult tosa::ReshapeOp::verify() {
1481+
if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1482+
/* outType = */ getOutput().getType())
1483+
.failed()) {
1484+
return failure();
1485+
}
13441486
TensorType inputType = getInput1().getType();
13451487
RankedTensorType outputType = getType();
13461488

@@ -1528,6 +1670,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
15281670
}
15291671

15301672
LogicalResult tosa::TransposeOp::verify() {
1673+
if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1674+
/* outType = */ getOutput().getType())
1675+
.failed()) {
1676+
return failure();
1677+
}
15311678
TensorType inputType = getInput1().getType();
15321679
TensorType outputType = getOutput().getType();
15331680
const llvm::ArrayRef<int32_t> constantPerms = getPerms();
@@ -1628,6 +1775,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
16281775
return success();
16291776
}
16301777

1778+
LogicalResult tosa::GatherOp::verify() {
  // The values tensor and the output must share one element type.
  if (failed(verifySameElementTypes(*this, /* inType = */ getValues().getType(),
                                    /* outType = */ getOutput().getType())))
    return failure();
  return success();
}
1782+
16311783
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
16321784
MLIRContext *context, ::std::optional<Location> location,
16331785
ResizeOp::Adaptor adaptor,
@@ -1789,6 +1941,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
17891941
return success();
17901942
}
17911943

1944+
LogicalResult tosa::ScatterOp::verify() {
  // values_in, input, and values_out must all carry the same element type.
  const Type valuesOutType = getValuesOut().getType();
  if (failed(verifySameElementTypes(
          *this, /* inType = */ getValuesIn().getType(), valuesOutType)))
    return failure();
  return verifySameElementTypes(*this, /* inType = */ getInput().getType(),
                                valuesOutType);
}
1955+
17921956
static LogicalResult ReduceInferReturnTypes(
17931957
ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
17941958
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2244,6 +2408,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
22442408
inferredReturnShapes);
22452409
}
22462410

2411+
LogicalResult MaxPool2dOp::verify() {
  // Max pooling only selects existing values, so the input and output
  // element types must be identical.
  if (failed(verifySameElementTypes(*this, /* inType = */ getInput().getType(),
                                    /* outType = */ getOutput().getType())))
    return failure();
  return success();
}
2415+
22472416
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
22482417
MLIRContext *context, ::std::optional<Location> location,
22492418
DepthwiseConv2DOp::Adaptor adaptor,
@@ -2546,6 +2715,10 @@ void IfOp::print(OpAsmPrinter &p) {
25462715
}
25472716

25482717
LogicalResult ReverseOp::verify() {
2718+
if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2719+
/* outType = */ getOutput().getType())
2720+
.failed())
2721+
return failure();
25492722
TensorType inputType = getInput1().getType();
25502723
TensorType outputType = getOutput().getType();
25512724
int32_t reverseAxis = getAxis();
@@ -2574,6 +2747,33 @@ LogicalResult ReverseOp::verify() {
25742747
return success();
25752748
}
25762749

2750+
LogicalResult tosa::SelectOp::verify() {
  // input2 and input3 (the selected values) must both have the output's
  // element type.
  const auto outType = getOutput().getType();
  if (failed(verifySameElementTypes(*this, /* inType = */ getInput2().getType(),
                                    outType)) ||
      failed(verifySameElementTypes(*this, /* inType = */ getInput3().getType(),
                                    outType)))
    return failure();

  // input1 is the predicate: it must be a shaped tensor of i1 (bool).
  auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
  if (!predicateType) {
    emitOpError("expect shaped tensor for input1, got ")
        << getInput1().getType();
    return failure();
  }
  auto predicateElementType = predicateType.getElementType();
  if (!predicateElementType.isInteger(1)) {
    emitOpError("expect element type of bool for input1, got ")
        << predicateElementType;
    return failure();
  }

  return success();
}
2776+
25772777
// parse and print of WhileOp refer to the implementation of SCF dialect.
25782778
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
25792779
SmallVector<OpAsmParser::Argument, 4> regionArgs;

0 commit comments

Comments
 (0)