diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e80805a302e3b..db25c8396b5d1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -325,6 +325,7 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
   ];
 
   let builders = [Tosa_MatMulOpQuantInfoBuilder];
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -359,6 +360,7 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
   ];
 
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1491,6 +1493,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
 
   let hasCanonicalizeMethod = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let assemblyFormat = [{
     operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
@@ -1866,6 +1869,7 @@ def Tosa_ConcatOp : Tosa_InferTensorTypeOp<"concat"> {
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
@@ -2122,6 +2126,8 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
     Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
     Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
   ];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -2155,6 +2161,8 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
     Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
     Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
   ];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 800968e6f4766..322e8837d08d9 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -978,6 +978,75 @@ LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::ConcatOp::verify() {
+  // check that each input has same element type as output
+  auto outType = getOutput().getType();
+  const Operation::operand_range inputList = getInput1();
+
+  // Check there is at least one input
+  if (inputList.empty())
+    return emitOpError("expect at least one input");
+
+  if (!llvm::all_of(inputList, [&](auto input) {
+        return succeeded(verifySameElementTypes(
+            *this, /* inType = */ input.getType(), outType));
+      })) {
+    return failure();
+  }
+
+  const int32_t axis = getAxis();
+  ShapeAdaptor firstRankedInputShape = nullptr;
+  for (const auto &input : inputList) {
+    const Type inputType = input.getType();
+    ShapeAdaptor currShape(inputType);
+    if (currShape.hasRank()) {
+      firstRankedInputShape = currShape;
+      // Check axis is in expected range
+      if (axis < 0 || axis >= firstRankedInputShape.getRank())
+        return emitOpError("expect axis to be within range 0 < axis < "
+                           "rank(input1[firstRankedTensorIdx]), got ")
+               << axis;
+      break;
+    }
+  }
+
+  const auto allOperandsHasRank = [](const Value input) {
+    return ShapeAdaptor(input.getType()).hasRank();
+  };
+  if (llvm::all_of(inputList, allOperandsHasRank)) {
+    const int64_t firstInputRank = firstRankedInputShape.getRank();
+
+    for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
+      const ShapeAdaptor inputShape(input.getType());
+      const int64_t inputRank = inputShape.getRank();
+      const size_t operandNum = index + 1;
+
+      // Check that each operand has the same rank
+      if (inputRank != firstInputRank)
+        return emitOpError(
+                   "expect all operands to have the same rank, but got ")
+               << firstInputRank << " vs " << inputRank << " on operands 0 and "
+               << operandNum;
+
+      // Check non-axis dims match
+      for (int i = 0; i < inputRank; i++) {
+        const int64_t inputDim = inputShape.getDimSize(i);
+        const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
+        if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
+            inputShape.isDynamicDim(i))
+          continue;
+        if (inputDim != firstInputDim)
+          return emitOpError("expect all operand shapes to have the same sizes "
+                             "on non-axis dimensions, but got ")
+                 << inputDim << " vs " << firstInputDim << " at index " << i
+                 << " on operands 0 and " << operandNum;
+      }
+    }
+  }
+
+  return success();
+}
+
 LogicalResult tosa::EqualOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes,
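Note that the shape checks in `ConcatOp::verify` only run once every operand is ranked, and the non-axis size comparison skips any dimension that is dynamic on either side. A hypothetical example (not part of this patch) that the verifier therefore accepts:

```mlir
func.func @test_concat_dynamic_ok(%arg0: tensor<2x?xf32>, %arg1: tensor<1x3xf32>) -> tensor<3x3xf32> {
  // Dim 1 of %arg0 is dynamic, so the non-axis size check is skipped for it.
  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<2x?xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
  return %0 : tensor<3x3xf32>
}
```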
@@ -1027,6 +1096,53 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult MatMulOp::verify() {
+  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
+  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
+
+  // Must be shaped tensor types
+  if (!aType)
+    return emitOpError("expect a shaped tensor for input a, got ")
+           << getA().getType();
+
+  if (!bType)
+    return emitOpError("expect a shaped tensor for input b, got ")
+           << getB().getType();
+
+  auto aElementType = aType.getElementType();
+  auto bElementType = bType.getElementType();
+
+  auto aQuantizedEType =
+      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
+  auto bQuantizedEType =
+      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);
+
+  if (aQuantizedEType || bQuantizedEType) {
+    if (!aQuantizedEType || !bQuantizedEType) {
+      return emitOpError("expect operands to be both quantized or both not "
+                         "quantized, got ")
+             << aElementType << " and " << bElementType;
+    }
+    // both a and b have quantized element types
+    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
+    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
+    if (aQuantWidth != bQuantWidth) {
+      return emitOpError("expect quantized operands to have same widths, got ")
+             << aQuantWidth << " and " << bQuantWidth;
+    }
+
+    return success();
+  }
+
+  // non-quantized element types
+  if (aElementType != bElementType) {
+    return emitOpError("expect same element type for inputs a and b, got ")
+           << aElementType << " and " << bElementType;
+  }
+
+  return success();
+}
+
 LogicalResult tosa::PadOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     PadOp::Adaptor adaptor,
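For the non-quantized path, a sketch of IR that `MatMulOp::verify` now rejects; the function below is hypothetical and assumes the plain two-operand `tosa.matmul` assembly form:

```mlir
func.func @test_matmul_element_type_mismatch(%a: tensor<1x14x19xf32>, %b: tensor<1x19x28xi8>) -> tensor<1x14x28xf32> {
  // expected-error@+1 {{'tosa.matmul' op expect same element type for inputs a and b, got 'f32' and 'i8'}}
  %0 = tosa.matmul %a, %b : (tensor<1x14x19xf32>, tensor<1x19x28xi8>) -> tensor<1x14x28xf32>
  return %0 : tensor<1x14x28xf32>
}
```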
@@ -1075,6 +1191,20 @@
 }
 
 LogicalResult tosa::PadOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
+
+  if (auto padConst = getPadConst()) {
+    if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
+                               /* outType = */ getOutput().getType())
+            .failed()) {
+      return failure();
+    }
+  }
+
   RankedTensorType inputType = getInput1().getType();
   RankedTensorType outputType = getOutput().getType();
   auto paddingRank = cast<tosa::shapeType>(getPadding().getType()).getRank();
@@ -1148,6 +1278,10 @@ LogicalResult tosa::SliceOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::SliceOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed())
+    return failure();
   auto inputType = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
   if (!inputType)
     return success();
@@ -1155,14 +1289,12 @@
   auto startShapeRank =
       llvm::cast<tosa::shapeType>(getStart().getType()).getRank();
   if (inputType.getRank() != startShapeRank)
-    return emitOpError(
-        "length of start attribute is not equal rank of input shape");
+    return emitOpError("length of start is not equal to rank of input shape");
 
   auto sizeShapeRank =
       llvm::cast<tosa::shapeType>(getSize().getType()).getRank();
   if (inputType.getRank() != sizeShapeRank)
-    return emitOpError(
-        "length of size attribute is not equal rank of input shape");
+    return emitOpError("length of size is not equal to rank of input shape");
 
   return success();
 }
@@ -1367,6 +1499,11 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::TileOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
   ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
   ShapedType outputType = llvm::cast<ShapedType>(getType());
 
@@ -1448,6 +1585,11 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
 }
 
 llvm::LogicalResult tosa::ReshapeOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
   TensorType inputType = getInput1().getType();
   RankedTensorType outputType = getType();
 
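Pad, slice, tile, and reshape all gain the same element-type check ahead of their existing shape checks. A hypothetical slice test (not in this patch), assuming `verifySameElementTypes` emits the same message as in the reshape test below:

```mlir
func.func @test_slice_element_type_mismatch(%arg0: tensor<4x4xf32>) -> tensor<1x1xi32> {
  %start = tosa.const_shape {value = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2>
  %size = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2>
  // expected-error@+1 {{'tosa.slice' op expect input and output to have same element type, got 'f32' and 'i32'}}
  %0 = tosa.slice %arg0, %start, %size : (tensor<4x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x1xi32>
  return %0 : tensor<1x1xi32>
}
```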
@@ -1626,6 +1768,11 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::TransposeOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
   TensorType inputType = getInput1().getType();
   TensorType outputType = getOutput().getType();
   const llvm::ArrayRef<int32_t> constantPerms = getPerms();
@@ -1726,6 +1873,11 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::GatherOp::verify() {
+  return verifySameElementTypes(*this, /* inType = */ getValues().getType(),
+                                /* outType = */ getOutput().getType());
+}
+
 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ResizeOp::Adaptor adaptor,
@@ -1887,6 +2039,18 @@ LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::ScatterOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
+                             /* outType = */ getValuesOut().getType())
+          .failed() ||
+      verifySameElementTypes(*this, /* inType = */ getInput().getType(),
+                             /* outType = */ getValuesOut().getType())
+          .failed()) {
+    return failure();
+  }
+  return success();
+}
+
 static LogicalResult ReduceInferReturnTypes(
     ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
@@ -2342,6 +2506,11 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
                                  inferredReturnShapes);
 }
 
+LogicalResult MaxPool2dOp::verify() {
+  return verifySameElementTypes(*this, /* inType = */ getInput().getType(),
+                                /* outType = */ getOutput().getType());
+}
+
 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     DepthwiseConv2DOp::Adaptor adaptor,
@@ -2642,6 +2811,10 @@ void IfOp::print(OpAsmPrinter &p) {
 }
 
 LogicalResult ReverseOp::verify() {
+  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed())
+    return failure();
   TensorType inputType = getInput1().getType();
   TensorType outputType = getOutput().getType();
   int32_t reverseAxis = getAxis();
@@ -2670,6 +2843,31 @@ LogicalResult ReverseOp::verify() {
   return success();
 }
 
+LogicalResult tosa::SelectOp::verify() {
+  // verify input2 and input3 have same element type as output
+  if (verifySameElementTypes(*this, /* inType = */ getInput2().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed() ||
+      verifySameElementTypes(*this, /* inType = */ getInput3().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
+  // verify input1 has element type of bool
+  auto predicateType = llvm::dyn_cast<ShapedType>(getInput1().getType());
+  if (!predicateType) {
+    return emitOpError("expect shaped tensor for input1, got ")
+           << getInput1().getType();
+  }
+  auto predicateElementType = predicateType.getElementType();
+  if (!predicateElementType.isInteger(1)) {
+    return emitOpError("expect element type of bool for input1, got ")
+           << predicateElementType;
+  }
+
+  return success();
+}
+
 // parse and print of WhileOp refer to the implementation of SCF dialect.
 ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::Argument, 4> regionArgs;
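`SelectOp::verify` adds two checks: input2/input3 must match the output element type, and input1 must be a boolean tensor. A hypothetical example of the first check (not part of this patch), using the assembly format declared above:

```mlir
func.func @test_select_element_type_mismatch(%c: tensor<2xi1>, %a: tensor<2xf32>, %b: tensor<2xf32>) -> tensor<2xi32> {
  // expected-error@+1 {{'tosa.select' op expect input and output to have same element type, got 'f32' and 'i32'}}
  %0 = tosa.select %c, %a, %b : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xi32>
  return %0 : tensor<2xi32>
}
```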
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index e665510ff0143..85dc80e651269 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -193,8 +193,7 @@ func.func @test_conv2d_quant_any_result(%arg0: tensor<1x4x4x4x!quant.any<i8
 // -----
 
 func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tensor<?x?xf32> {
-  // expected-error@+2 {{failed to infer returned types}}
-  // expected-error@+1 {{Cannot concat tensors with different sizes on the non-axis dimension 1}}
+  // expected-error@+1 {{'tosa.concat' op expect all operand shapes to have the same sizes on non-axis dimensions, but got 2 vs 1 at index 1 on operands 0 and 1}}
   %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<2x1xf32>, tensor<2x2xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
@@ -202,14 +201,36 @@ func.func @test_concat(%arg0 : tensor<2x1xf32>, %arg1 : tensor<2x2xf32>) -> tens
 // -----
 
 func.func @test_concat_element_type_mismatch(%arg0 : tensor<1x2xf32>, %arg1 : tensor<2x2xf32>) -> tensor<3x2xi8> {
-  // expected-error@+2 {{failed to infer returned types}}
-  // expected-error@+1 {{'tosa.concat' op inferred type(s) 'tensor<3x2xf32>' are incompatible with return type(s) of operation 'tensor<3x2xi8>'}}
+  // expected-error@+1 {{'tosa.concat' op expect input and output to have same element type, got 'f32' and 'i8'}}
   %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<3x2xi8>
   return %0 : tensor<3x2xi8>
 }
 
 // -----
 
+func.func @test_concat_zero_inputs() {
+  // expected-error@+1 {{'tosa.concat' op expect at least one input}}
+  %0 = tosa.concat {axis = 0 : i32} : () -> tensor<*xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_negative(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got -1}}
+  %0 = tosa.concat %arg0, %arg1 {axis = -1 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+func.func @test_concat_axis_out_of_range(%arg0: tensor<1x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  // expected-error@+1 {{'tosa.concat' op expect axis to be within range 0 < axis < rank(input1[firstRankedTensorIdx]), got 3}}
+  %0 = tosa.concat %arg0, %arg1 {axis = 3 : i32} : (tensor<1x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}
+
+// -----
+
 func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: !tosa.shape<6>) -> tensor<13x21x3xf32> {
   // expected-error@+1 {{'tosa.pad' op shape operand is not compile time resolvable}}
   %0 = tosa.pad %arg0, %arg1 : (tensor<13x21x3xf32>, !tosa.shape<6>) -> tensor<13x21x3xf32>
@@ -236,6 +257,14 @@ func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>) {
 
 // -----
 
+func.func @test_concat_input_rank_mismatch(%arg0: tensor<1x2x3xf32>, %arg1: tensor<1x2xf32>) -> tensor<2x2x3xf32> {
+  // expected-error@+1 {{'tosa.concat' op expect all operands to have the same rank, but got 3 vs 2 on operands 0 and 1}}
+  %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2x3xf32>, tensor<1x2xf32>) -> tensor<2x2x3xf32>
+  return %0 : tensor<2x2x3xf32>
+}
+
+// -----
+
 func.func @test_pad_invalid_padding_rank(%arg0: tensor<13x21xf32>) {
   %0 = tosa.const_shape {value = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
   // expected-error@+1 {{'tosa.pad' op expected padding tensor dim 0 to have size 4 (2*rank(shape1)) but got size 6}}
@@ -430,8 +459,7 @@ func.func @test_reduce_min_invalid_output_rank(%arg0 : tensor
 
 func.func @test_reshape_type_mismatch(%arg0 : tensor<13x21x3xf32>) -> () {
   %1 = tosa.const_shape {value = dense<[13, 21, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
-  // expected-error@+2 {{failed to infer returned types}}
-  // expected-error@+1 {{'tosa.reshape' op inferred type(s) 'tensor<13x21x3x1xf32>' are incompatible with return type(s) of operation 'tensor<13x21x3x1xi32>'}}
+  // expected-error@+1 {{'tosa.reshape' op expect input and output to have same element type, got 'f32' and 'i32'}}
   %0 = tosa.reshape %arg0, %1 : (tensor<13x21x3xf32>, !tosa.shape<4>) -> tensor<13x21x3x1xi32>
   return
 }
@@ -521,7 +549,7 @@ func.func @test_reshape_invalid_tensor_dim(%arg0 : tensor<4x?xf32>) -> () {
 
 func.func @test_reverse_axis_out_of_range(%arg0 : tensor<13x21x3xf32>) -> () {
   // expected-error@+1 {{'tosa.reverse' op expect input tensor rank (3) to be larger than reverse axis (5)}}
-  %0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xi32>
+  %0 = tosa.reverse %arg0 {axis = 5 : i32} : (tensor<13x21x3xf32>) -> tensor<?x?x?xf32>
   return
 }
@@ -634,7 +662,7 @@ func.func @test_slice_invalid_start() {
   %0 = tensor.empty() : tensor<4x31x31xf32>
   %start = tosa.const_shape {value = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
   %size = tosa.const_shape {value = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
-  // expected-error@+1 {{'tosa.slice' op length of start attribute is not equal rank of input shape}}
+  // expected-error@+1 {{'tosa.slice' op length of start is not equal to rank of input shape}}
   %3 = tosa.slice %0, %start, %size : (tensor<4x31x31xf32>, !tosa.shape<2>, !tosa.shape<3>) -> tensor<*xf32>
   return
 }
@@ -645,7 +673,7 @@ func.func @test_slice_invalid_size() {
   %0 = tensor.empty() : tensor<4x31x31xf32>
   %start = tosa.const_shape {value = dense<[1, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
   %size = tosa.const_shape {value = dense<[1]> : tensor<1xindex>} : () -> !tosa.shape<1>
-  // expected-error@+1 {{'tosa.slice' op length of size attribute is not equal rank of input shape}}
+  // expected-error@+1 {{'tosa.slice' op length of size is not equal to rank of input shape}}
   %3 = tosa.slice %0, %start, %size : (tensor<4x31x31xf32>, !tosa.shape<3>, !tosa.shape<1>) -> tensor<*xf32>
   return
 }
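The gather, scatter, and max_pool2d verifiers are pure element-type checks, so the existing shape tests stay untouched; a hypothetical gather case (not in this patch) would look like:

```mlir
func.func @test_gather_element_type_mismatch(%values: tensor<13x21x3xf32>, %indices: tensor<13x26xi32>) -> tensor<13x26x3xi32> {
  // expected-error@+1 {{'tosa.gather' op expect input and output to have same element type, got 'f32' and 'i32'}}
  %0 = tosa.gather %values, %indices : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x26x3xi32>
  return %0 : tensor<13x26x3xi32>
}
```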