Skip to content

Commit 5e03b5b

Browse files
Jerry-GeTai78641lhutton1
committed
[mlir][tosa] Add more verifiers for the following operators
For ConcatOp this commit also enhances the verifier by checking 4 another conditions: - The input list is not empty - The axis value is within range of the input shapes - All inputs have the same rank - All non concatenate axis dims have the same value For MatmulOp: - Checked input a, bs tensor type, element types For the following operators, added the verifySameElementTypes check. - PadOp - SliceOp - TileOp - ReshapeOp - TransposeOp - GatherOp - ScatterOp - MaxPool2dOp - ReverseOp - SelectOp Change-Id: I1e8a1017f21f617443bc40bae42189915048c750 Co-authored-by: Tai Ly <[email protected]> Co-authored-by: Luke Hutton <[email protected]> Signed-off-by: Jerry Ge <[email protected]>
1 parent 47c255b commit 5e03b5b

File tree

3 files changed

+253
-13
lines changed

3 files changed

+253
-13
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
310310
];
311311

312312
let builders = [Tosa_MatMulOpQuantInfoBuilder];
313+
let hasVerifier = 1;
313314
}
314315

315316
//===----------------------------------------------------------------------===//
@@ -344,6 +345,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
344345
];
345346

346347
let hasCanonicalizer = 1;
348+
let hasVerifier = 1;
347349
}
348350

349351
//===----------------------------------------------------------------------===//
@@ -1470,6 +1472,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
14701472

14711473
let hasCanonicalizeMethod = 1;
14721474
let hasFolder = 1;
1475+
let hasVerifier = 1;
14731476

14741477
let assemblyFormat = [{
14751478
operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
@@ -1845,6 +1848,7 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
18451848

18461849
let hasCanonicalizer = 1;
18471850
let hasFolder = 1;
1851+
let hasVerifier = 1;
18481852

18491853
let extraClassDeclaration = [{
18501854
/// Returns true when two result types are compatible for this op;
@@ -2101,6 +2105,8 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
21012105
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
21022106
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
21032107
];
2108+
2109+
let hasVerifier = 1;
21042110
}
21052111

21062112
//===----------------------------------------------------------------------===//
@@ -2134,6 +2140,8 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
21342140
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
21352141
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
21362142
];
2143+
2144+
let hasVerifier = 1;
21372145
}
21382146

21392147
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 208 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,75 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
871871
return success();
872872
}
873873

874+
LogicalResult tosa::ConcatOp::verify() {
875+
// check that each input has same element type as output
876+
auto outType = getOutput().getType();
877+
const Operation::operand_range inputList = getInput1();
878+
879+
// Check there is at least one input
880+
if (inputList.empty())
881+
return emitOpError("expect at least one input");
882+
883+
if (!llvm::all_of(inputList, [&](auto input) {
884+
return succeeded(verifySameElementTypes(
885+
*this, /* inType = */ input.getType(), outType));
886+
})) {
887+
return failure();
888+
}
889+
890+
const int32_t axis = getAxis();
891+
ShapeAdaptor firstRankedInputShape = nullptr;
892+
for (auto input : inputList) {
893+
const Type inputType = input.getType();
894+
ShapeAdaptor currShape(inputType);
895+
if (currShape.hasRank()) {
896+
firstRankedInputShape = currShape;
897+
// Check axis is in expected range
898+
if (axis < 0 || axis >= firstRankedInputShape.getRank())
899+
return emitOpError("expect axis to be within range 0 < axis < "
900+
"rank(input1[firstRankedTensorIdx]), got ")
901+
<< axis;
902+
break;
903+
}
904+
}
905+
906+
const auto allOperandsHasRank = [](const Value input) {
907+
return ShapeAdaptor(input.getType()).hasRank();
908+
};
909+
if (llvm::all_of(inputList, allOperandsHasRank)) {
910+
const int64_t firstInputRank = firstRankedInputShape.getRank();
911+
912+
for (const auto [index, input] : llvm::enumerate(inputList.drop_front())) {
913+
const ShapeAdaptor inputShape(input.getType());
914+
const int64_t inputRank = inputShape.getRank();
915+
const size_t operandNum = index + 1;
916+
917+
// Check that each operand has the same rank
918+
if (inputRank != firstInputRank)
919+
return emitOpError(
920+
"expect all operands to have the same rank, but got ")
921+
<< firstInputRank << " vs " << inputRank << " on operands 0 and "
922+
<< operandNum;
923+
924+
// Check non-axis dims match
925+
for (int i = 0; i < inputRank; i++) {
926+
const int64_t inputDim = inputShape.getDimSize(i);
927+
const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
928+
if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
929+
inputShape.isDynamicDim(i))
930+
continue;
931+
if (inputDim != firstInputDim)
932+
return emitOpError("expect all operand shapes to have the same sizes "
933+
"on non-axis dimensions, but got ")
934+
<< inputDim << " vs " << firstInputDim << " at index " << i
935+
<< " on operands 0 and " << operandNum;
936+
}
937+
}
938+
}
939+
940+
return success();
941+
}
942+
874943
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
875944
MLIRContext *context, ::std::optional<Location> location,
876945
ValueShapeRange operands, DictionaryAttr attributes,
@@ -920,6 +989,57 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
920989
return success();
921990
}
922991

992+
LogicalResult MatMulOp::verify() {
993+
auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
994+
auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
995+
996+
// Must be shaped tensor types
997+
if (!aType) {
998+
emitOpError("expect a shaped tensor for input a, got ") << getA().getType();
999+
return failure();
1000+
}
1001+
if (!bType) {
1002+
emitOpError("expect a shaped tensor for input b, got ") << getB().getType();
1003+
return failure();
1004+
}
1005+
1006+
auto aElementType = aType.getElementType();
1007+
auto bElementType = bType.getElementType();
1008+
1009+
auto aQuantizedEType =
1010+
llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
1011+
auto bQuantizedEType =
1012+
llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
1013+
1014+
if (aQuantizedEType || bQuantizedEType) {
1015+
if (!aQuantizedEType || !bQuantizedEType) {
1016+
emitOpError(
1017+
"expect operands to be both quantized or both not quantized, got ")
1018+
<< aElementType << " and " << bElementType;
1019+
return failure();
1020+
}
1021+
// both a and b have quantized element types
1022+
auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
1023+
auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
1024+
if (aQuantWidth != bQuantWidth) {
1025+
emitOpError("expect quantized operands to have same widths, got ")
1026+
<< aQuantWidth << " and " << bQuantWidth;
1027+
return failure();
1028+
}
1029+
1030+
return success();
1031+
}
1032+
1033+
// non-quantized element types
1034+
if (aElementType != bElementType) {
1035+
emitOpError("expect same element type for inputs a and b, got ")
1036+
<< aElementType << " and " << bElementType;
1037+
return failure();
1038+
}
1039+
1040+
return success();
1041+
}
1042+
9231043
LogicalResult tosa::PadOp::inferReturnTypeComponents(
9241044
MLIRContext *context, ::std::optional<Location> location,
9251045
PadOp::Adaptor adaptor,
@@ -968,6 +1088,20 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
9681088
}
9691089

9701090
LogicalResult tosa::PadOp::verify() {
1091+
if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1092+
/* outType = */ getOutput().getType())
1093+
.failed()) {
1094+
return failure();
1095+
}
1096+
1097+
if (auto padConst = getPadConst()) {
1098+
if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
1099+
/* outType = */ getOutput().getType())
1100+
.failed()) {
1101+
return failure();
1102+
}
1103+
}
1104+
9711105
RankedTensorType inputType = getInput1().getType();
9721106
RankedTensorType outputType = getOutput().getType();
9731107
auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
@@ -1041,21 +1175,23 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
10411175
}
10421176

10431177
LogicalResult tosa::SliceOp::verify() {
1178+
if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1179+
/* outType = */ getOutput().getType())
1180+
.failed())
1181+
return failure();
10441182
auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
10451183
if (!inputType)
10461184
return success();
10471185

10481186
auto startShapeRank =
10491187
llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
10501188
if (inputType.getRank() != startShapeRank)
1051-
return emitOpError(
1052-
"length of start attribute is not equal rank of input shape");
1189+
return emitOpError("length of start is not equal to rank of input shape");
10531190

10541191
auto sizeShapeRank =
10551192
llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
10561193
if (inputType.getRank() != sizeShapeRank)
1057-
return emitOpError(
1058-
"length of size attribute is not equal rank of input shape");
1194+
return emitOpError("length of size is not equal to rank of input shape");
10591195

10601196
return success();
10611197
}
@@ -1260,6 +1396,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
12601396
}
12611397

12621398
LogicalResult tosa::TileOp::verify() {
1399+
if (verifySameElementTypes(*this, /* intype = */ getInput1().getType(),
1400+
/* outType = */ getOutput().getType())
1401+
.failed()) {
1402+
return failure();
1403+
}
12631404
ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
12641405
ShapedType outputType = llvm::cast<ShapedType>(getType());
12651406

@@ -1341,6 +1482,11 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
13411482
}
13421483

13431484
llvm::LogicalResult tosa::ReshapeOp::verify() {
1485+
if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1486+
/* outType = */ getOutput().getType())
1487+
.failed()) {
1488+
return failure();
1489+
}
13441490
TensorType inputType = getInput1().getType();
13451491
RankedTensorType outputType = getType();
13461492

@@ -1528,6 +1674,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
15281674
}
15291675

15301676
LogicalResult tosa::TransposeOp::verify() {
1677+
if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
1678+
/* outType = */ getOutput().getType())
1679+
.failed()) {
1680+
return failure();
1681+
}
15311682
TensorType inputType = getInput1().getType();
15321683
TensorType outputType = getOutput().getType();
15331684
const llvm::ArrayRef<int32_t> constantPerms = getPerms();
@@ -1628,6 +1779,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
16281779
return success();
16291780
}
16301781

1782+
LogicalResult tosa::GatherOp::verify() {
1783+
return verifySameElementTypes(*this, /* inType = */ getValues().getType(),
1784+
/* outType = */ getOutput().getType());
1785+
}
1786+
16311787
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
16321788
MLIRContext *context, ::std::optional<Location> location,
16331789
ResizeOp::Adaptor adaptor,
@@ -1789,6 +1945,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
17891945
return success();
17901946
}
17911947

1948+
LogicalResult tosa::ScatterOp::verify() {
1949+
if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
1950+
/* outType = */ getValuesOut().getType())
1951+
.failed() ||
1952+
verifySameElementTypes(*this, /* inType = */ getInput().getType(),
1953+
/* outType = */ getValuesOut().getType())
1954+
.failed()) {
1955+
return failure();
1956+
}
1957+
return success();
1958+
}
1959+
17921960
static LogicalResult ReduceInferReturnTypes(
17931961
ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
17941962
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2244,6 +2412,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
22442412
inferredReturnShapes);
22452413
}
22462414

2415+
LogicalResult MaxPool2dOp::verify() {
2416+
return verifySameElementTypes(*this, /* intype = */ getInput().getType(),
2417+
/* outType = */ getOutput().getType());
2418+
}
2419+
22472420
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
22482421
MLIRContext *context, ::std::optional<Location> location,
22492422
DepthwiseConv2DOp::Adaptor adaptor,
@@ -2544,6 +2717,10 @@ void IfOp::print(OpAsmPrinter &p) {
25442717
}
25452718

25462719
LogicalResult ReverseOp::verify() {
2720+
if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2721+
/* outType = */ getOutput().getType())
2722+
.failed())
2723+
return failure();
25472724
TensorType inputType = getInput1().getType();
25482725
TensorType outputType = getOutput().getType();
25492726
int32_t reverseAxis = getAxis();
@@ -2572,6 +2749,33 @@ LogicalResult ReverseOp::verify() {
25722749
return success();
25732750
}
25742751

2752+
LogicalResult tosa::SelectOp::verify() {
2753+
// verify input2 and input3 have same element type as output
2754+
if (verifySameElementTypes(*this, /* inType = */ getInput2().getType(),
2755+
/* outType = */ getOutput().getType())
2756+
.failed() ||
2757+
verifySameElementTypes(*this, /* inType = */ getInput3().getType(),
2758+
/* outType = */ getOutput().getType())
2759+
.failed()) {
2760+
return failure();
2761+
}
2762+
// verify input1 has element type of bool
2763+
auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
2764+
if (!predicateType) {
2765+
emitOpError("expect shaped tensor for input1, got ")
2766+
<< getInput1().getType();
2767+
return failure();
2768+
}
2769+
auto predicateElementType = predicateType.getElementType();
2770+
if (!predicateElementType.isInteger(1)) {
2771+
emitOpError("expect element type of bool for input1, got ")
2772+
<< predicateElementType;
2773+
return failure();
2774+
}
2775+
2776+
return success();
2777+
}
2778+
25752779
// parse and print of WhileOp refer to the implementation of SCF dialect.
25762780
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
25772781
SmallVector<OpAsmParser::Argument, 4> regionArgs;

0 commit comments

Comments
 (0)