Skip to content

Commit 97ba7a5

Browse files
Jerry-Ge, Tai78641, lhutton1
committed
[mlir][tosa] Add more verifiers for the following operators
For ConcatOp this commit also enhances the verifier by checking four additional conditions: - The input list is not empty - The axis value is within the range of the input shapes - All inputs have the same rank - All non-concatenation-axis dims have the same value For MatMulOp: - Checked input a's and b's tensor types and element types For the following operators, added the verifySameElementTypes check. - PadOp - SliceOp - TileOp - ReshapeOp - TransposeOp - GatherOp - ScatterOp - MaxPool2dOp - ReverseOp - SelectOp Change-Id: I1e8a1017f21f617443bc40bae42189915048c750 Co-authored-by: Tai Ly <[email protected]> Co-authored-by: Luke Hutton <[email protected]> Signed-off-by: Jerry Ge <[email protected]>
1 parent 5e4938a commit 97ba7a5

File tree

3 files changed

+224
-13
lines changed

3 files changed

+224
-13
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
286286
];
287287

288288
let builders = [Tosa_MatMulOpQuantInfoBuilder];
289+
let hasVerifier = 1;
289290
}
290291

291292
//===----------------------------------------------------------------------===//
@@ -320,6 +321,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
320321
];
321322

322323
let hasCanonicalizer = 1;
324+
let hasVerifier = 1;
323325
}
324326

325327
//===----------------------------------------------------------------------===//
@@ -1439,6 +1441,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
14391441

14401442
let hasCanonicalizeMethod = 1;
14411443
let hasFolder = 1;
1444+
let hasVerifier = 1;
14421445

14431446
let assemblyFormat = [{
14441447
operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
@@ -1814,6 +1817,7 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
18141817

18151818
let hasCanonicalizer = 1;
18161819
let hasFolder = 1;
1820+
let hasVerifier = 1;
18171821

18181822
let extraClassDeclaration = [{
18191823
/// Returns true when two result types are compatible for this op;
@@ -2070,6 +2074,8 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
20702074
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
20712075
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
20722076
];
2077+
2078+
let hasVerifier = 1;
20732079
}
20742080

20752081
//===----------------------------------------------------------------------===//
@@ -2103,6 +2109,8 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
21032109
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
21042110
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
21052111
];
2112+
2113+
let hasVerifier = 1;
21062114
}
21072115

21082116
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 202 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,71 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
850850
return success();
851851
}
852852

853+
LogicalResult tosa::ConcatOp::verify() {
  // Verifies the concat invariants: a non-empty input list whose element
  // types all match the output, an in-range concatenation axis, equal ranks
  // across operands, and matching sizes on every non-axis dimension.
  const Operation::operand_range inputList = getInput1();

  // Check there is at least one input.
  if (inputList.empty())
    return emitOpError("expect at least one input");

  // Check that each input has the same element type as the output.
  const auto outType = getOutput().getType();
  if (!llvm::all_of(inputList, [&](auto input) {
        return succeeded(verifySameElementTypes(
            *this, /* inType = */ input.getType(), outType));
      })) {
    return failure();
  }

  const Type firstInputType = inputList.front().getType();
  const ShapeAdaptor firstInputShape(firstInputType);
  const int32_t axis = getAxis();

  if (firstInputShape.hasRank()) {
    // Check axis is in expected range. Valid axes satisfy
    // 0 <= axis < rank, matching the condition below (the previous message
    // said "0 < axis", which wrongly implied axis 0 is rejected).
    if (axis < 0 || axis >= firstInputShape.getRank())
      return emitOpError("expect axis to be within range 0 <= axis < "
                         "rank(input1[0]), got ")
             << axis;
  }

  // Rank and dimension checks only apply once every operand is ranked.
  const auto allOperandsHasRank = [](const Value input) {
    return ShapeAdaptor(input.getType()).hasRank();
  };
  if (llvm::all_of(inputList, allOperandsHasRank)) {
    const int64_t firstInputRank = firstInputShape.getRank();

    for (const auto [index, input] : llvm::enumerate(inputList.drop_front())) {
      const ShapeAdaptor inputShape(input.getType());
      const int64_t inputRank = inputShape.getRank();
      const size_t operandNum = index + 1;

      // Check that each operand has the same rank as the first one.
      if (inputRank != firstInputRank)
        return emitOpError(
                   "expect all operands to have the same rank, but got ")
               << firstInputRank << " vs " << inputRank
               << " on operands 0 and " << operandNum;

      // Check non-axis dims match. Dynamic dimensions are allowed to differ
      // until shape inference refines them.
      for (int i = 0; i < inputRank; i++) {
        const int64_t inputDim = inputShape.getDimSize(i);
        const int64_t firstInputDim = firstInputShape.getDimSize(i);
        if (i == axis || firstInputShape.isDynamicDim(i) ||
            inputShape.isDynamicDim(i))
          continue;
        if (inputDim != firstInputDim)
          return emitOpError("expect all operand shapes to have the same sizes "
                             "on non-axis dimensions, but got ")
                 << inputDim << " vs " << firstInputDim << " at index " << i
                 << " on operands 0 and " << operandNum;
      }
    }
  }

  return success();
}
917+
853918
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
854919
MLIRContext *context, ::std::optional<Location> location,
855920
ValueShapeRange operands, DictionaryAttr attributes,
@@ -899,6 +964,57 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
899964
return success();
900965
}
901966

967+
LogicalResult MatMulOp::verify() {
  // Verifies that both matmul operands are shaped tensors and that their
  // element types are compatible (both quantized with equal storage widths,
  // or both the same non-quantized type).
  const auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
  if (!aType)
    return emitOpError("expect a shaped tensor for input a, got ")
           << getA().getType();

  const auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
  if (!bType)
    return emitOpError("expect a shaped tensor for input b, got ")
           << getB().getType();

  const Type aElemType = aType.getElementType();
  const Type bElemType = bType.getElementType();

  const auto aQuantType =
      llvm::dyn_cast<quant::UniformQuantizedType>(aElemType);
  const auto bQuantType =
      llvm::dyn_cast<quant::UniformQuantizedType>(bElemType);

  // Mixing a quantized with a non-quantized operand is invalid.
  if (static_cast<bool>(aQuantType) != static_cast<bool>(bQuantType))
    return emitOpError(
               "expect operands to be both quantized or both not quantized, "
               "got ")
           << aElemType << " and " << bElemType;

  if (aQuantType && bQuantType) {
    // Both quantized: the storage types must have identical bit widths.
    const auto aWidth = aQuantType.getStorageTypeIntegralWidth();
    const auto bWidth = bQuantType.getStorageTypeIntegralWidth();
    if (aWidth != bWidth)
      return emitOpError("expect quantized operands to have same widths, got ")
             << aWidth << " and " << bWidth;
    return success();
  }

  // Both non-quantized: element types must match exactly.
  if (aElemType != bElemType)
    return emitOpError("expect same element type for inputs a and b, got ")
           << aElemType << " and " << bElemType;

  return success();
}
1017+
9021018
LogicalResult tosa::PadOp::inferReturnTypeComponents(
9031019
MLIRContext *context, ::std::optional<Location> location,
9041020
PadOp::Adaptor adaptor,
@@ -947,6 +1063,18 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
9471063
}
9481064

9491065
LogicalResult tosa::PadOp::verify() {
1066+
if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1067+
/* outType = */ getOutput().getType())
1068+
.failed()) {
1069+
return failure();
1070+
}
1071+
if (auto padConst = getPadConst()) {
1072+
if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
1073+
/* outType = */ getOutput().getType())
1074+
.failed()) {
1075+
return failure();
1076+
}
1077+
}
9501078
RankedTensorType inputType = getInput1().getType();
9511079
RankedTensorType outputType = getOutput().getType();
9521080
auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
@@ -1020,21 +1148,23 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
10201148
}
10211149

10221150
LogicalResult tosa::SliceOp::verify() {
  // Verifies that the slice input/output element types agree and that the
  // start/size shape operands each supply one entry per input dimension.
  if (failed(verifySameElementTypes(*this,
                                    /* inType = */ getInput1().getType(),
                                    /* outType = */ getOutput().getType())))
    return failure();

  // Rank-based checks require a ranked input; otherwise defer to later
  // verification once shapes are known.
  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  if (!inputType)
    return success();

  const int64_t inputRank = inputType.getRank();

  // `start` must have exactly one element per input dimension.
  if (llvm::cast<tosa::shapeType>(getStart().getType()).getRank() != inputRank)
    return emitOpError("length of start is not equal to rank of input shape");

  // `size` must have exactly one element per input dimension.
  if (llvm::cast<tosa::shapeType>(getSize().getType()).getRank() != inputRank)
    return emitOpError("length of size is not equal to rank of input shape");

  return success();
}
@@ -1239,6 +1369,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
12391369
}
12401370

12411371
LogicalResult tosa::TileOp::verify() {
1372+
if (verifySameElementTypes(*this, /* intype = */ getInput1().getType(),
1373+
/* outType = */ getOutput().getType())
1374+
.failed()) {
1375+
return failure();
1376+
}
12421377
ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
12431378
ShapedType outputType = llvm::cast<ShapedType>(getType());
12441379

@@ -1320,6 +1455,11 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
13201455
}
13211456

13221457
llvm::LogicalResult tosa::ReshapeOp::verify() {
1458+
if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1459+
/* outType = */ getOutput().getType())
1460+
.failed()) {
1461+
return failure();
1462+
}
13231463
TensorType inputType = getInput1().getType();
13241464
RankedTensorType outputType = getType();
13251465

@@ -1434,6 +1574,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
14341574
}
14351575

14361576
LogicalResult tosa::TransposeOp::verify() {
1577+
if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1578+
/* outType = */ getOutput().getType())
1579+
.failed()) {
1580+
return failure();
1581+
}
14371582
TensorType inputType = getInput1().getType();
14381583
TensorType outputType = getOutput().getType();
14391584
const llvm::ArrayRef<int32_t> constantPerms = getPerms();
@@ -1534,6 +1679,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
15341679
return success();
15351680
}
15361681

1682+
LogicalResult tosa::GatherOp::verify() {
  // The gathered values and the output must share one element type.
  const Type valuesType = getValues().getType();
  const Type outputType = getOutput().getType();
  return verifySameElementTypes(*this, /* inType = */ valuesType,
                                /* outType = */ outputType);
}
1686+
15371687
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
15381688
MLIRContext *context, ::std::optional<Location> location,
15391689
ResizeOp::Adaptor adaptor,
@@ -1702,6 +1852,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
17021852
return success();
17031853
}
17041854

1855+
LogicalResult tosa::ScatterOp::verify() {
  // values_in, input, and values_out must all carry the same element type;
  // values_out serves as the reference type for both checks.
  const Type outType = getValuesOut().getType();

  if (failed(verifySameElementTypes(*this,
                                    /* inType = */ getValuesIn().getType(),
                                    /* outType = */ outType)))
    return failure();

  if (failed(verifySameElementTypes(*this,
                                    /* inType = */ getInput().getType(),
                                    /* outType = */ outType)))
    return failure();

  return success();
}
1866+
17051867
static LogicalResult ReduceInferReturnTypes(
17061868
ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
17071869
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2066,6 +2228,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
20662228
inferredReturnShapes);
20672229
}
20682230

2231+
LogicalResult MaxPool2dOp::verify() {
  // Input and output of max_pool2d must share one element type.
  // (Comment marker normalized from "intype" to "inType" for consistency
  // with the other verifiers.)
  return verifySameElementTypes(*this, /* inType = */ getInput().getType(),
                                /* outType = */ getOutput().getType());
}
2235+
20692236
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
20702237
MLIRContext *context, ::std::optional<Location> location,
20712238
DepthwiseConv2DOp::Adaptor adaptor,
@@ -2368,6 +2535,10 @@ void IfOp::print(OpAsmPrinter &p) {
23682535
}
23692536

23702537
LogicalResult ReverseOp::verify() {
2538+
if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2539+
/* outType = */ getOutput().getType())
2540+
.failed())
2541+
return failure();
23712542
TensorType inputType = getInput1().getType();
23722543
TensorType outputType = getOutput().getType();
23732544
int32_t reverseAxis = getAxis();
@@ -2396,6 +2567,33 @@ LogicalResult ReverseOp::verify() {
23962567
return success();
23972568
}
23982569

2570+
LogicalResult tosa::SelectOp::verify() {
  // Verifies select: both value operands (input2/input3) share the output's
  // element type, and the predicate (input1) is a shaped tensor of i1.
  const Type outType = getOutput().getType();

  if (failed(verifySameElementTypes(*this,
                                    /* inType = */ getInput2().getType(),
                                    /* outType = */ outType)) ||
      failed(verifySameElementTypes(*this,
                                    /* inType = */ getInput3().getType(),
                                    /* outType = */ outType)))
    return failure();

  // The predicate must be a shaped tensor.
  const auto predType = llvm::dyn_cast<ShapedType>(getInput1().getType());
  if (!predType)
    return emitOpError("expect shaped tensor for input1, got ")
           << getInput1().getType();

  // ... whose element type is bool (i1).
  const Type predElemType = predType.getElementType();
  if (!predElemType.isInteger(1))
    return emitOpError("expect element type of bool for input1, got ")
           << predElemType;

  return success();
}
2596+
23992597
// parse and print of WhileOp refer to the implementation of SCF dialect.
24002598
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
24012599
SmallVector<OpAsmParser::Argument, 4> regionArgs;

0 commit comments

Comments
 (0)