
[mlir][tosa] Add more verifiers for the following operators #127923

Merged: 1 commit, Mar 5, 2025
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (8 additions, 0 deletions)
@@ -325,6 +325,7 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
   ];
 
   let builders = [Tosa_MatMulOpQuantInfoBuilder];
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -359,6 +360,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
   ];
 
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1491,6 +1493,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {

   let hasCanonicalizeMethod = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let assemblyFormat = [{
     operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
@@ -1866,6 +1869,7 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {

   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
@@ -2122,6 +2126,8 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
     Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
     Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
   ];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -2155,6 +2161,8 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
     Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
     Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
   ];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (202 additions, 4 deletions)
@@ -978,6 +978,75 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::ConcatOp::verify() {
+  // Check that each input has the same element type as the output.
+  auto outType = getOutput().getType();
+  const Operation::operand_range inputList = getInput1();
+
+  // Check there is at least one input.
+  if (inputList.empty())
+    return emitOpError("expect at least one input");
+
+  if (!llvm::all_of(inputList, [&](auto input) {
+        return succeeded(verifySameElementTypes(
+            *this, /* inType = */ input.getType(), outType));
+      })) {
+    return failure();
+  }
+
+  const int32_t axis = getAxis();
+  ShapeAdaptor firstRankedInputShape = nullptr;
+  for (const auto &input : inputList) {
+    const Type inputType = input.getType();
+    ShapeAdaptor currShape(inputType);
+    if (currShape.hasRank()) {
+      firstRankedInputShape = currShape;
+      // Check that the axis is within the expected range.
+      if (axis < 0 || axis >= firstRankedInputShape.getRank())
+        return emitOpError("expect axis to be within range 0 <= axis < "
+                           "rank(input1[firstRankedTensorIdx]), got ")
+               << axis;
+      break;
+    }
+  }
+
+  const auto allOperandsHasRank = [](const Value input) {
+    return ShapeAdaptor(input.getType()).hasRank();
+  };
+  if (llvm::all_of(inputList, allOperandsHasRank)) {
+    const int64_t firstInputRank = firstRankedInputShape.getRank();
+
+    for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
+      const ShapeAdaptor inputShape(input.getType());
+      const int64_t inputRank = inputShape.getRank();
+      const size_t operandNum = index + 1;
+
+      // Check that each operand has the same rank.
+      if (inputRank != firstInputRank)
+        return emitOpError(
+                   "expect all operands to have the same rank, but got ")
+               << firstInputRank << " vs " << inputRank << " on operands 0 and "
+               << operandNum;
+
+      // Check that the non-axis dimensions match.
+      for (int i = 0; i < inputRank; i++) {
+        const int64_t inputDim = inputShape.getDimSize(i);
+        const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
+        if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
+            inputShape.isDynamicDim(i))
+          continue;
+        if (inputDim != firstInputDim)
+          return emitOpError("expect all operand shapes to have the same sizes "
+                             "on non-axis dimensions, but got ")
+                 << inputDim << " vs " << firstInputDim << " at index " << i
+                 << " on operands 0 and " << operandNum;
+      }
+    }
+  }
+
+  return success();
+}
+
 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes,
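
For illustration (not part of this diff): a sketch of the invalid IR the new ConcatOp verifier rejects, written in the style of the dialect's invalid.mlir tests. The function name and shapes are invented; the expected-error text is taken from the verifier above.

  func.func @test_concat_rank_mismatch(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2x2xf32>) -> tensor<2x2xf32> {
    // expected-error@+1 {{'tosa.concat' op expect all operands to have the same rank, but got 2 vs 3 on operands 0 and 1}}
    %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<1x2x2xf32>) -> tensor<2x2xf32>
    return %0 : tensor<2x2xf32>
  }
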
@@ -1027,6 +1096,53 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult MatMulOp::verify() {
+  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
+  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
+
+  // Must be shaped tensor types.
+  if (!aType)
+    return emitOpError("expect a shaped tensor for input a, got ")
+           << getA().getType();
+
+  if (!bType)
+    return emitOpError("expect a shaped tensor for input b, got ")
+           << getB().getType();
+
+  auto aElementType = aType.getElementType();
+  auto bElementType = bType.getElementType();
+
+  auto aQuantizedEType =
+      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
+  auto bQuantizedEType =
+      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
+
+  if (aQuantizedEType || bQuantizedEType) {
+    if (!aQuantizedEType || !bQuantizedEType) {
+      return emitOpError("expect operands to be both quantized or both not "
+                         "quantized, got ")
+             << aElementType << " and " << bElementType;
+    }
+    // Both a and b have quantized element types.
+    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
+    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
+    if (aQuantWidth != bQuantWidth) {
+      return emitOpError("expect quantized operands to have same widths, got ")
+             << aQuantWidth << " and " << bQuantWidth;
+    }
+
+    return success();
+  }
+
+  // Non-quantized element types.
+  if (aElementType != bElementType) {
+    return emitOpError("expect same element type for inputs a and b, got ")
+           << aElementType << " and " << bElementType;
+  }
+
+  return success();
+}
+
 LogicalResult tosa::PadOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     PadOp::Adaptor adaptor,
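
For illustration (not part of this diff): an element-type mismatch the new MatMulOp verifier reports. This sketch assumes the plain two-operand tosa.matmul form; names and shapes are invented, and newer dialect revisions may take additional zero-point operands.

  func.func @test_matmul_element_type_mismatch(%a: tensor<1x4x8xf32>, %b: tensor<1x8x16xf16>) -> tensor<1x4x16xf32> {
    // expected-error@+1 {{'tosa.matmul' op expect same element type for inputs a and b, got 'f32' and 'f16'}}
    %0 = tosa.matmul %a, %b : (tensor<1x4x8xf32>, tensor<1x8x16xf16>) -> tensor<1x4x16xf32>
    return %0 : tensor<1x4x16xf32>
  }
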
@@ -1075,6 +1191,20 @@ LogicalResult tosa::PadOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::PadOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
+
+  if (auto padConst = getPadConst()) {
+    if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
+                               /* outType = */ getOutput().getType())
+            .failed()) {
+      return failure();
+    }
+  }
+
   RankedTensorType inputType = getInput1().getType();
   RankedTensorType outputType = getOutput().getType();
   auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
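
For illustration (not part of this diff): a hypothetical pad whose pad_const element type disagrees with the output, which the extended PadOp verifier now flags. Operands are passed as function arguments to keep the sketch self-contained; the exact diagnostic comes from the shared verifySameElementTypes helper and is omitted.

  func.func @test_pad_const_type_mismatch(%in: tensor<2x3xf32>, %padding: !tosa.shape<4>, %pad_const: tensor<1xi8>) -> tensor<3x4xf32> {
    // Rejected: the i8 pad_const does not match the f32 output element type.
    %0 = tosa.pad %in, %padding, %pad_const : (tensor<2x3xf32>, !tosa.shape<4>, tensor<1xi8>) -> tensor<3x4xf32>
    return %0 : tensor<3x4xf32>
  }
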
@@ -1148,21 +1278,23 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::SliceOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed())
+    return failure();
   auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
   if (!inputType)
     return success();
 
   auto startShapeRank =
       llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
   if (inputType.getRank() != startShapeRank)
-    return emitOpError(
-        "length of start attribute is not equal rank of input shape");
+    return emitOpError("length of start is not equal to rank of input shape");
 
   auto sizeShapeRank =
       llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
   if (inputType.getRank() != sizeShapeRank)
-    return emitOpError(
-        "length of size attribute is not equal rank of input shape");
+    return emitOpError("length of size is not equal to rank of input shape");
 
   return success();
 }
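
For illustration (not part of this diff): a sketch that trips the reworded rank check. The tosa.shape operands are passed as function arguments; names and shapes are invented, and the expected-error text is from the verifier above.

  func.func @test_slice_start_rank_mismatch(%in: tensor<4x4x4xf32>, %start: !tosa.shape<2>, %size: !tosa.shape<2>) -> tensor<2x2x4xf32> {
    // Rejected: rank-3 input, but start describes only two dimensions.
    // expected-error@+1 {{'tosa.slice' op length of start is not equal to rank of input shape}}
    %0 = tosa.slice %in, %start, %size : (tensor<4x4x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<2x2x4xf32>
    return %0 : tensor<2x2x4xf32>
  }
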
@@ -1367,6 +1499,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::TileOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
   ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
   ShapedType outputType = llvm::cast<ShapedType>(getType());
 
@@ -1448,6 +1585,11 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
 }
 
 llvm::LogicalResult tosa::ReshapeOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
   TensorType inputType = getInput1().getType();
   RankedTensorType outputType = getType();
 
@@ -1626,6 +1768,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::TransposeOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
   TensorType inputType = getInput1().getType();
   TensorType outputType = getOutput().getType();
   const llvm::ArrayRef<int32_t> constantPerms = getPerms();
@@ -1726,6 +1873,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::GatherOp::verify() {
+  return verifySameElementTypes(*this, /* inType = */ getValues().getType(),
+                                /* outType = */ getOutput().getType());
+}
+
 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ResizeOp::Adaptor adaptor,
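
For illustration (not part of this diff): a hypothetical gather whose values element type differs from the output. The diagnostic text comes from the shared verifySameElementTypes helper and is omitted here; shapes follow the usual [N, K, C] values and [N, W] indices convention.

  func.func @test_gather_element_type_mismatch(%values: tensor<1x8x3xf32>, %indices: tensor<1x4xi32>) -> tensor<1x4x3xi32> {
    // Rejected: values are f32 but the output element type is i32.
    %0 = tosa.gather %values, %indices : (tensor<1x8x3xf32>, tensor<1x4xi32>) -> tensor<1x4x3xi32>
    return %0 : tensor<1x4x3xi32>
  }
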
@@ -1887,6 +2039,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::ScatterOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
+                             /* outType = */ getValuesOut().getType())
+          .failed() ||
+      verifySameElementTypes(*this, /* inType = */ getInput().getType(),
+                             /* outType = */ getValuesOut().getType())
+          .failed()) {
+    return failure();
+  }
+  return success();
+}
+
 static LogicalResult ReduceInferReturnTypes(
     ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2342,6 +2506,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
                                  inferredReturnShapes);
 }
 
+LogicalResult MaxPool2dOp::verify() {
+  return verifySameElementTypes(*this, /* inType = */ getInput().getType(),
+                                /* outType = */ getOutput().getType());
+}
+
 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     DepthwiseConv2DOp::Adaptor adaptor,
@@ -2642,6 +2811,10 @@ void IfOp::print(OpAsmPrinter &p) {
 }
 
 LogicalResult ReverseOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed())
+    return failure();
   TensorType inputType = getInput1().getType();
   TensorType outputType = getOutput().getType();
   int32_t reverseAxis = getAxis();
@@ -2670,6 +2843,31 @@ LogicalResult ReverseOp::verify() {
   return success();
 }
 
+LogicalResult tosa::SelectOp::verify() {
+  // Verify that input2 and input3 have the same element type as the output.
+  if (verifySameElementTypes(*this, /* inType = */ getInput2().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed() ||
+      verifySameElementTypes(*this, /* inType = */ getInput3().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
+  // Verify that input1 has a boolean element type.
+  auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
+  if (!predicateType) {
+    return emitOpError("expect shaped tensor for input1, got ")
+           << getInput1().getType();
+  }
+  auto predicateElementType = predicateType.getElementType();
+  if (!predicateElementType.isInteger(1)) {
+    return emitOpError("expect element type of bool for input1, got ")
+           << predicateElementType;
+  }
+
+  return success();
+}
+
 // parse and print of WhileOp refer to the implementation of SCF dialect.
 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::Argument, 4> regionArgs;
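
For illustration (not part of this diff): a sketch of a select with a non-boolean predicate, which the new SelectOp verifier rejects. Names and shapes are invented; the expected-error text is from the verifier above.

  func.func @test_select_non_bool_predicate(%pred: tensor<4xi8>, %a: tensor<4xf32>, %b: tensor<4xf32>) -> tensor<4xf32> {
    // expected-error@+1 {{'tosa.select' op expect element type of bool for input1, got 'i8'}}
    %0 = tosa.select %pred, %a, %b : (tensor<4xi8>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
    return %0 : tensor<4xf32>
  }
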