-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][tosa] Add verifiers for FFT2d and RFFT2d #129273
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Adds checks for element types and input/output shapes. Signed-off-by: Luke Hutton <[email protected]> Change-Id: Ib40928027f5b9d75306aa662c4627e3263db7de7
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir-linalg Author: Luke Hutton (lhutton1) ChangesAdds checks for element types and input/output shapes. Patch is 21.37 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129273.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ddfec2c9bfcd3..9e14daf0d014c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -240,7 +240,10 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
//===----------------------------------------------------------------------===//
// Operator: fft2d
//===----------------------------------------------------------------------===//
-def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
+def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d", [
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultShape,
+ ResultsAreFloatLike]> {
let summary = "Performs FFT2D operation on the input.";
let description = [{
@@ -279,6 +282,8 @@ def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
$input_real `,` $input_imag attr-dict `:` `(` type($input_real) `,`
type($input_imag) `)` `->` `(` type($output_real) `,` type($output_imag) `)`
}];
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -349,7 +354,9 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
//===----------------------------------------------------------------------===//
// Operator: rfft2d
//===----------------------------------------------------------------------===//
-def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> {
+def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d", [
+ SameOperandsAndResultElementType,
+ ResultsAreFloatLike]> {
let summary = "Performs RFFT2D operation on the input.";
let description = [{
@@ -385,6 +392,8 @@ def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> {
let assemblyFormat = [{
$input attr-dict `:` `(` type($input) `)` `->` `(` type($output_real) `,` type($output_imag) `)`
}];
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 7b50eceb081dd..63afaed22d7ab 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -789,7 +789,7 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
int64_t inWidth = inputShape.getDimSize(2);
// Note that we can support this calculation symbolically
- // in the future e.g. [x, y, z] -> [x, y, z / 2 - 1]
+ // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
if (inWidth != ShapedType::kDynamic)
outputShape[2] = inWidth / 2 + 1;
@@ -799,6 +799,57 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
return success();
}
+static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
+ const llvm::StringRef dimName) {
+ const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
+ if (!isPowerOfTwo)
+ return op->emitOpError("expected ")
+ << dimName << " to be a power of two, got " << dimSize;
+
+ return success();
+}
+
+LogicalResult tosa::RFFT2dOp::verify() {
+ const auto outputTypes = getResultTypes();
+ if (failed(verifyCompatibleShapes(outputTypes)))
+ return emitOpError("expected output shapes to match, got ") << outputTypes;
+
+ const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
+ if (!inputType)
+ return success();
+
+ const int64_t height = inputType.getDimSize(1);
+ if (!ShapedType::isDynamic(height) &&
+ failed(verifyDimIsPowerOfTwo(*this, height, "height")))
+ return failure();
+
+ const int64_t width = inputType.getDimSize(2);
+ if (!ShapedType::isDynamic(width) &&
+ failed(verifyDimIsPowerOfTwo(*this, width, "width")))
+ return failure();
+
+ const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
+ if (!outputType)
+ return success();
+
+ // Batch and height input/output dimensions should match
+ if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
+ outputType.getShape().drop_back())))
+ return emitOpError("expected batch and height dimensions of input/output "
+ "to match, got input=")
+ << inputType << " output=" << outputType;
+
+ // Output width dimension expected to be input_width / 2 + 1
+ const int64_t outputWidth = outputType.getDimSize(2);
+ if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&
+ (outputWidth - 1) * 2 != width)
+ return emitOpError(
+ "expected output width to be equal to input_width / 2 + 1, got ")
+ << outputWidth;
+
+ return success();
+}
+
LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
FFT2dOp::Adaptor adaptor,
@@ -810,6 +861,33 @@ LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
return success();
}
+LogicalResult tosa::FFT2dOp::verify() {
+ const auto inputRealType =
+ llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
+ const auto inputImagType =
+ llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
+ if (!inputRealType || !inputImagType)
+ return success();
+
+ const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
+ return ShapedType::isDynamic(a) ? a : b;
+ };
+
+ const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
+ inputImagType.getDimSize(1));
+ if (!ShapedType::isDynamic(height) &&
+ failed(verifyDimIsPowerOfTwo(*this, height, "height")))
+ return failure();
+
+ const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
+ inputImagType.getDimSize(2));
+ if (!ShapedType::isDynamic(width) &&
+ failed(verifyDimIsPowerOfTwo(*this, width, "width")))
+ return failure();
+
+ return success();
+}
+
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ConcatOp::Adaptor adaptor,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 5db3f56cf459e..71fcd4129a618 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -31,15 +31,6 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
// -----
-// CHECK-LABEL: @rfft2d_with_non_float_type
-func.func @rfft2d_with_non_float_type(%arg0 : tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>) {
- // expected-error@+1 {{failed to legalize operation 'tosa.rfft2d'}}
- %real, %imag = tosa.rfft2d %arg0 : (tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>)
- return %real, %imag : tensor<1x1x1xi32>, tensor<1x1x1xi32>
-}
-
-// -----
-
// CHECK-LABEL: @rescale_unsupported_type
func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
// expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 78f2e173d7cb1..e68783f779063 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1720,20 +1720,20 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
// CHECK-LABEL: func.func @test_static_rfft2d(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 8 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 4 : index
// CHECK: %[[VAL_5:.*]] = arith.constant 5 : index
-// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<5x4x5xf32>
// CHECK: %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_6]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
-// CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_6]] : tensor<5x4x5xf32>) -> tensor<5x4x5xf32>
+// CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<5x4x5xf32>
// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_10]] : f32) outs(%[[VAL_9]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
+// CHECK: %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_10]] : f32) outs(%[[VAL_9]] : tensor<5x4x5xf32>) -> tensor<5x4x5xf32>
// CHECK: %[[VAL_12:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_13:.*]] = arith.constant 5 : index
+// CHECK: %[[VAL_13:.*]] = arith.constant 4 : index
// CHECK: %[[VAL_14:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_15:.*]] = arith.constant 8 : index
// CHECK: %[[VAL_16:.*]] = arith.constant 6.28318548 : f32
@@ -1741,7 +1741,7 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
// CHECK: %[[VAL_18:.*]] = arith.uitofp %[[VAL_17]] : i32 to f32
// CHECK: %[[VAL_19:.*]] = arith.index_castui %[[VAL_15]] : index to i32
// CHECK: %[[VAL_20:.*]] = arith.uitofp %[[VAL_19]] : i32 to f32
-// CHECK: %[[VAL_21:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<5x5x8xf32>) outs(%[[VAL_8]], %[[VAL_11]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK: %[[VAL_21:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<5x4x8xf32>) outs(%[[VAL_8]], %[[VAL_11]] : tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
// CHECK: ^bb0(%[[VAL_22:.*]]: f32, %[[VAL_23:.*]]: f32, %[[VAL_24:.*]]: f32):
// CHECK: %[[VAL_25:.*]] = linalg.index 1 : index
// CHECK: %[[VAL_26:.*]] = linalg.index 2 : index
@@ -1766,12 +1766,12 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
// CHECK: %[[VAL_45:.*]] = arith.addf %[[VAL_23]], %[[VAL_43]] : f32
// CHECK: %[[VAL_46:.*]] = arith.subf %[[VAL_24]], %[[VAL_44]] : f32
// CHECK: linalg.yield %[[VAL_45]], %[[VAL_46]] : f32, f32
-// CHECK: } -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
-// CHECK: return %[[VAL_47:.*]]#0, %[[VAL_47]]#1 : tensor<5x5x5xf32>, tensor<5x5x5xf32>
+// CHECK: } -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>)
+// CHECK: return %[[VAL_47:.*]]#0, %[[VAL_47]]#1 : tensor<5x4x5xf32>, tensor<5x4x5xf32>
// CHECK: }
-func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
- %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
- return %output_real, %output_imag : tensor<5x5x5xf32>, tensor<5x5x5xf32>
+func.func @test_static_rfft2d(%arg0: tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
+ %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>)
+ return %output_real, %output_imag : tensor<5x4x5xf32>, tensor<5x4x5xf32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 123c65e1b4fcd..fb7a9222947f6 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1276,3 +1276,83 @@ func.func @test_conv2d_invalid_bias_size(%arg0: tensor<1x4x4x4xf32>, %arg1: tens
: (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<7xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
return %0 : tensor<1x4x4x8xf32>
}
+
+// -----
+
+func.func @test_fft2d_same_operands_and_result_element_type(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf16>, tensor<1x4x8xf16>) {
+ // expected-error@+1 {{'tosa.fft2d' op requires the same element type for all operands and results}}
+ %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf16>, tensor<1x4x8xf16>)
+ return %0, %1 : tensor<1x4x8xf16>, tensor<1x4x8xf16>
+}
+
+// -----
+
+func.func @test_fft2d_same_operands_and_result_shape(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x7xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
+ // expected-error@+1 {{'tosa.fft2d' op requires the same shape for all operands and results}}
+ %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x7xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
+ return %0, %1 : tensor<1x4x8xf32>, tensor<1x4x8xf32>
+}
+
+// -----
+
+func.func @test_fft2d_invalid_type(%arg0: tensor<1x4x8xi8>, %arg1: tensor<1x4x8xi8>) -> (tensor<1x4x8xi8>, tensor<1x4x8xi8>) {
+ // expected-error@+1 {{'tosa.fft2d' op requires a floating point type}}
+ %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xi8>, tensor<1x4x8xi8>) -> (tensor<1x4x8xi8>, tensor<1x4x8xi8>)
+ return %0, %1 : tensor<1x4x8xi8>, tensor<1x4x8xi8>
+}
+
+// -----
+
+func.func @test_fft2d_height_non_power_of_two(%arg0: tensor<1x5x8xf32>, %arg1: tensor<1x5x8xf32>) -> (tensor<1x5x8xf32>, tensor<1x5x8xf32>) {
+ // expected-error@+1 {{'tosa.fft2d' op expected height to be a power of two, got 5}}
+ %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x5x8xf32>, tensor<1x5x8xf32>) -> (tensor<1x5x8xf32>, tensor<1x5x8xf32>)
+ return %0, %1 : tensor<1x5x8xf32>, tensor<1x5x8xf32>
+}
+
+// -----
+
+func.func @test_rfft2d_same_operands_and_result_element_type(%arg0: tensor<1x4x8xf32>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>) {
+ // expected-error@+1 {{'tosa.rfft2d' op requires the same element type for all operands and results}}
+ %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf32>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>)
+ return %0, %1 : tensor<1x4x5xf16>, tensor<1x4x5xf16>
+}
+
+// -----
+
+func.func @test_rfft2d_same_results_shape(%arg0: tensor<1x4x8xf32>) -> (tensor<1x4x6xf32>, tensor<1x4x5xf32>) {
+ // expected-error@+1 {{'tosa.rfft2d' op expected output shapes to match, got 'tensor<1x4x6xf32>', 'tensor<1x4x5xf32>'}}
+ %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf32>) -> (tensor<1x4x6xf32>, tensor<1x4x5xf32>)
+ return %0, %1 : tensor<1x4x6xf32>, tensor<1x4x5xf32>
+}
+
+// -----
+
+func.func @test_rfft2d_invalid_type(%arg0: tensor<1x4x8xi16>) -> (tensor<1x4x5xi16>, tensor<1x4x5xi16>) {
+ // expected-error@+1 {{'tosa.rfft2d' op requires a floating point type}}
+ %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xi16>) -> (tensor<1x4x5xi16>, tensor<1x4x5xi16>)
+ return %0, %1 : tensor<1x4x5xi16>, tensor<1x4x5xi16>
+}
+
+// -----
+
+func.func @test_rfft2d_width_power_of_two(%arg0: tensor<1x4x9xf16>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>) {
+ // expected-error@+1 {{'tosa.rfft2d' op expected width to be a power of two, got 9}}
+ %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x9xf16>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>)
+ return %0, %1 : tensor<1x4x5xf16>, tensor<1x4x5xf16>
+}
+
+// -----
+
+func.func @test_rfft2d_batch_input_output_match(%arg0: tensor<1x4x8xf16>) -> (tensor<2x4x5xf16>, tensor<2x4x5xf16>) {
+ // expected-error@+1 {{'tosa.rfft2d' op expected batch and height dimensions of input/output to match, got input='tensor<1x4x8xf16>' output='tensor<2x4x5xf16>'}}
+ %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<2x4x5xf16>, tensor<2x4x5xf16>)
+ return %0, %1 : tensor<2x4x5xf16>, tensor<2x4x5xf16>
+}
+
+// -----
+
+func.func @test_rfft2d_width_input_output_match(%arg0: tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>) {
+ // expected-error@+1 {{'tosa.rfft2d' op expected output width to be equal to input_width / 2 + 1, got 3}}
+ %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>)
+ return %0, %1 : tensor<1x4x3xf16>, tensor<1x4x3xf16>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index d2958efe1bb24..19584103d40c7 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -477,38 +477,38 @@ func.func @test_depthwise_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: t
// -----
-func.func @test_fft2d_real_h(%arg0: tensor<32x8193x32xf32>, %arg1: tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_real_h(%arg0: tensor<32x16384x32xf32>, %arg1: tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) {
// expected-error@+1 {{'tosa.fft2d' op failed level check: H <= MAX_KERNEL}}
%0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
- (tensor<32x8193x32xf32>, tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
- return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+ (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>)
+ return %0, %1 : tensor<32x16384x32xf32>, tensor<32x16384x32xf32>
}
// -----
-func.func @test_fft2d_real_w(%arg0: tensor<32x32x8193xf32>, %arg1: tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_real_w(%arg0: tensor<32x32x16384xf32>, %arg1: tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) {
// expected-error@+1 {{'tosa.fft2d' op failed level check: W <= MAX_KERNEL}}
%0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
- (tensor<32x32x8193xf32>, tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
- return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+ (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>)
+ return %0, %1 : tensor<32x32x16384xf32>, tensor<32x32x16384xf32>
}
// -----
-func.func @test_fft2d_imag_h(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x8193x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_imag_h(%arg0: tensor<32x16384x32xf32>, %arg1: tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) {
// expected-error@+1 {{'tosa.fft2d' op failed level check: H <= MAX_KERNEL}}
%0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
- (tensor<32x32x32xf32>, tensor<32x8193x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
- return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+ (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>)
+ return %0, %1 : tensor<32x16384x32xf32>, tensor<32x16384x32xf32>
}
// -----
-func.func @test_fft2d_imag_w(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x32x8193xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_imag_w(%arg0: tensor<32x32x16384xf32>, %arg1: tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) {
// expected-error@+1 {{'tosa.fft2d' op failed level check: W <= MAX_KERNEL}}
%0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
- (tensor<32x32x32xf32>, tensor<32x32x8193xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
- return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+ (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>)
+ return %0, %1 : tensor<32x32x16384xf32>, tensor<32x32x16384xf32>
}
// -----
@@ -577,18 +577,18 @@ func.func @test_maxpool2d_p...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Luke Hutton (lhutton1) ChangesAdds checks for element types and input/output shapes. Patch is 21.37 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129273.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ddfec2c9bfcd3..9e14daf0d014c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -240,7 +240,10 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
//===----------------------------------------------------------------------===//
// Operator: fft2d
//===----------------------------------------------------------------------===//
-def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
+def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d", [
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultShape,
+ ResultsAreFloatLike]> {
let summary = "Performs FFT2D operation on the input.";
let description = [{
@@ -279,6 +282,8 @@ def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
$input_real `,` $input_imag attr-dict `:` `(` type($input_real) `,`
type($input_imag) `)` `->` `(` type($output_real) `,` type($output_imag) `)`
}];
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -349,7 +354,9 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
//===----------------------------------------------------------------------===//
// Operator: rfft2d
//===----------------------------------------------------------------------===//
-def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> {
+def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d", [
+ SameOperandsAndResultElementType,
+ ResultsAreFloatLike]> {
let summary = "Performs RFFT2D operation on the input.";
let description = [{
@@ -385,6 +392,8 @@ def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> {
let assemblyFormat = [{
$input attr-dict `:` `(` type($input) `)` `->` `(` type($output_real) `,` type($output_imag) `)`
}];
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 7b50eceb081dd..63afaed22d7ab 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -789,7 +789,7 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
int64_t inWidth = inputShape.getDimSize(2);
// Note that we can support this calculation symbolically
- // in the future e.g. [x, y, z] -> [x, y, z / 2 - 1]
+ // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
if (inWidth != ShapedType::kDynamic)
outputShape[2] = inWidth / 2 + 1;
@@ -799,6 +799,57 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
return success();
}
+static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
+ const llvm::StringRef dimName) {
+ const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
+ if (!isPowerOfTwo)
+ return op->emitOpError("expected ")
+ << dimName << " to be a power of two, got " << dimSize;
+
+ return success();
+}
+
+LogicalResult tosa::RFFT2dOp::verify() {
+ const auto outputTypes = getResultTypes();
+ if (failed(verifyCompatibleShapes(outputTypes)))
+ return emitOpError("expected output shapes to match, got ") << outputTypes;
+
+ const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
+ if (!inputType)
+ return success();
+
+ const int64_t height = inputType.getDimSize(1);
+ if (!ShapedType::isDynamic(height) &&
+ failed(verifyDimIsPowerOfTwo(*this, height, "height")))
+ return failure();
+
+ const int64_t width = inputType.getDimSize(2);
+ if (!ShapedType::isDynamic(width) &&
+ failed(verifyDimIsPowerOfTwo(*this, width, "width")))
+ return failure();
+
+ const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
+ if (!outputType)
+ return success();
+
+ // Batch and height input/output dimensions should match
+ if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
+ outputType.getShape().drop_back())))
+ return emitOpError("expected batch and height dimensions of input/output "
+ "to match, got input=")
+ << inputType << " output=" << outputType;
+
+ // Output width dimension expected to be input_width / 2 + 1
+ const int64_t outputWidth = outputType.getDimSize(2);
+ if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&
+ (outputWidth - 1) * 2 != width)
+ return emitOpError(
+ "expected output width to be equal to input_width / 2 + 1, got ")
+ << outputWidth;
+
+ return success();
+}
+
LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
FFT2dOp::Adaptor adaptor,
@@ -810,6 +861,33 @@ LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
return success();
}
+LogicalResult tosa::FFT2dOp::verify() {
+ const auto inputRealType =
+ llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
+ const auto inputImagType =
+ llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
+ if (!inputRealType || !inputImagType)
+ return success();
+
+ const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
+ return ShapedType::isDynamic(a) ? a : b;
+ };
+
+ const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
+ inputImagType.getDimSize(1));
+ if (!ShapedType::isDynamic(height) &&
+ failed(verifyDimIsPowerOfTwo(*this, height, "height")))
+ return failure();
+
+ const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
+ inputImagType.getDimSize(2));
+ if (!ShapedType::isDynamic(width) &&
+ failed(verifyDimIsPowerOfTwo(*this, width, "width")))
+ return failure();
+
+ return success();
+}
+
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ConcatOp::Adaptor adaptor,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 5db3f56cf459e..71fcd4129a618 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -31,15 +31,6 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
// -----
-// CHECK-LABEL: @rfft2d_with_non_float_type
-func.func @rfft2d_with_non_float_type(%arg0 : tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>) {
- // expected-error@+1 {{failed to legalize operation 'tosa.rfft2d'}}
- %real, %imag = tosa.rfft2d %arg0 : (tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>)
- return %real, %imag : tensor<1x1x1xi32>, tensor<1x1x1xi32>
-}
-
-// -----
-
// CHECK-LABEL: @rescale_unsupported_type
func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
// expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 78f2e173d7cb1..e68783f779063 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1720,20 +1720,20 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
// CHECK-LABEL: func.func @test_static_rfft2d(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 8 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 4 : index
// CHECK: %[[VAL_5:.*]] = arith.constant 5 : index
-// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<5x4x5xf32>
// CHECK: %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_6]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
-// CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_6]] : tensor<5x4x5xf32>) -> tensor<5x4x5xf32>
+// CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<5x4x5xf32>
// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_10]] : f32) outs(%[[VAL_9]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
+// CHECK: %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_10]] : f32) outs(%[[VAL_9]] : tensor<5x4x5xf32>) -> tensor<5x4x5xf32>
// CHECK: %[[VAL_12:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_13:.*]] = arith.constant 5 : index
+// CHECK: %[[VAL_13:.*]] = arith.constant 4 : index
// CHECK: %[[VAL_14:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_15:.*]] = arith.constant 8 : index
// CHECK: %[[VAL_16:.*]] = arith.constant 6.28318548 : f32
@@ -1741,7 +1741,7 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
// CHECK: %[[VAL_18:.*]] = arith.uitofp %[[VAL_17]] : i32 to f32
// CHECK: %[[VAL_19:.*]] = arith.index_castui %[[VAL_15]] : index to i32
// CHECK: %[[VAL_20:.*]] = arith.uitofp %[[VAL_19]] : i32 to f32
-// CHECK: %[[VAL_21:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<5x5x8xf32>) outs(%[[VAL_8]], %[[VAL_11]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK: %[[VAL_21:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<5x4x8xf32>) outs(%[[VAL_8]], %[[VAL_11]] : tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
// CHECK: ^bb0(%[[VAL_22:.*]]: f32, %[[VAL_23:.*]]: f32, %[[VAL_24:.*]]: f32):
// CHECK: %[[VAL_25:.*]] = linalg.index 1 : index
// CHECK: %[[VAL_26:.*]] = linalg.index 2 : index
@@ -1766,12 +1766,12 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
// CHECK: %[[VAL_45:.*]] = arith.addf %[[VAL_23]], %[[VAL_43]] : f32
// CHECK: %[[VAL_46:.*]] = arith.subf %[[VAL_24]], %[[VAL_44]] : f32
// CHECK: linalg.yield %[[VAL_45]], %[[VAL_46]] : f32, f32
-// CHECK: } -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
-// CHECK: return %[[VAL_47:.*]]#0, %[[VAL_47]]#1 : tensor<5x5x5xf32>, tensor<5x5x5xf32>
+// CHECK: } -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>)
+// CHECK: return %[[VAL_47:.*]]#0, %[[VAL_47]]#1 : tensor<5x4x5xf32>, tensor<5x4x5xf32>
// CHECK: }
-func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
- %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
- return %output_real, %output_imag : tensor<5x5x5xf32>, tensor<5x5x5xf32>
+func.func @test_static_rfft2d(%arg0: tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
+ %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>)
+ return %output_real, %output_imag : tensor<5x4x5xf32>, tensor<5x4x5xf32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 123c65e1b4fcd..fb7a9222947f6 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1276,3 +1276,83 @@ func.func @test_conv2d_invalid_bias_size(%arg0: tensor<1x4x4x4xf32>, %arg1: tens
: (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<7xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
return %0 : tensor<1x4x4x8xf32>
}
+
+// -----
+
+func.func @test_fft2d_same_operands_and_result_element_type(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf16>, tensor<1x4x8xf16>) {
+ // expected-error@+1 {{'tosa.fft2d' op requires the same element type for all operands and results}}
+ %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf16>, tensor<1x4x8xf16>)
+ return %0, %1 : tensor<1x4x8xf16>, tensor<1x4x8xf16>
+}
+
+// -----
+
+func.func @test_fft2d_same_operands_and_result_shape(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x7xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
+ // expected-error@+1 {{'tosa.fft2d' op requires the same shape for all operands and results}}
+ %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x7xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
+ return %0, %1 : tensor<1x4x8xf32>, tensor<1x4x8xf32>
+}
+
+// -----
+
+func.func @test_fft2d_invalid_type(%arg0: tensor<1x4x8xi8>, %arg1: tensor<1x4x8xi8>) -> (tensor<1x4x8xi8>, tensor<1x4x8xi8>) {
+ // expected-error@+1 {{'tosa.fft2d' op requires a floating point type}}
+ %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xi8>, tensor<1x4x8xi8>) -> (tensor<1x4x8xi8>, tensor<1x4x8xi8>)
+ return %0, %1 : tensor<1x4x8xi8>, tensor<1x4x8xi8>
+}
+
+// -----
+
+func.func @test_fft2d_height_non_power_of_two(%arg0: tensor<1x5x8xf32>, %arg1: tensor<1x5x8xf32>) -> (tensor<1x5x8xf32>, tensor<1x5x8xf32>) {
+ // expected-error@+1 {{'tosa.fft2d' op expected height to be a power of two, got 5}}
+ %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x5x8xf32>, tensor<1x5x8xf32>) -> (tensor<1x5x8xf32>, tensor<1x5x8xf32>)
+ return %0, %1 : tensor<1x5x8xf32>, tensor<1x5x8xf32>
+}
+
+// -----
+
+func.func @test_rfft2d_same_operands_and_result_element_type(%arg0: tensor<1x4x8xf32>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>) {
+ // expected-error@+1 {{'tosa.rfft2d' op requires the same element type for all operands and results}}
+ %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf32>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>)
+ return %0, %1 : tensor<1x4x5xf16>, tensor<1x4x5xf16>
+}
+
+// -----
+
+func.func @test_rfft2d_same_results_shape(%arg0: tensor<1x4x8xf32>) -> (tensor<1x4x6xf32>, tensor<1x4x5xf32>) {
+ // expected-error@+1 {{'tosa.rfft2d' op expected output shapes to match, got 'tensor<1x4x6xf32>', 'tensor<1x4x5xf32>'}}
+ %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf32>) -> (tensor<1x4x6xf32>, tensor<1x4x5xf32>)
+ return %0, %1 : tensor<1x4x6xf32>, tensor<1x4x5xf32>
+}
+
+// -----
+
+func.func @test_rfft2d_invalid_type(%arg0: tensor<1x4x8xi16>) -> (tensor<1x4x5xi16>, tensor<1x4x5xi16>) {
+ // expected-error@+1 {{'tosa.rfft2d' op requires a floating point type}}
+ %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xi16>) -> (tensor<1x4x5xi16>, tensor<1x4x5xi16>)
+ return %0, %1 : tensor<1x4x5xi16>, tensor<1x4x5xi16>
+}
+
+// -----
+
+func.func @test_rfft2d_width_power_of_two(%arg0: tensor<1x4x9xf16>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>) {
+ // expected-error@+1 {{'tosa.rfft2d' op expected width to be a power of two, got 9}}
+ %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x9xf16>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>)
+ return %0, %1 : tensor<1x4x5xf16>, tensor<1x4x5xf16>
+}
+
+// -----
+
+func.func @test_rfft2d_batch_input_output_match(%arg0: tensor<1x4x8xf16>) -> (tensor<2x4x5xf16>, tensor<2x4x5xf16>) {
+ // expected-error@+1 {{'tosa.rfft2d' op expected batch and height dimensions of input/output to match, got input='tensor<1x4x8xf16>' output='tensor<2x4x5xf16>'}}
+ %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<2x4x5xf16>, tensor<2x4x5xf16>)
+ return %0, %1 : tensor<2x4x5xf16>, tensor<2x4x5xf16>
+}
+
+// -----
+
+func.func @test_rfft2d_width_input_output_match(%arg0: tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>) {
+ // expected-error@+1 {{'tosa.rfft2d' op expected output width to be equal to input_width / 2 + 1, got 3}}
+ %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>)
+ return %0, %1 : tensor<1x4x3xf16>, tensor<1x4x3xf16>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index d2958efe1bb24..19584103d40c7 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -477,38 +477,38 @@ func.func @test_depthwise_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: t
// -----
-func.func @test_fft2d_real_h(%arg0: tensor<32x8193x32xf32>, %arg1: tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_real_h(%arg0: tensor<32x16384x32xf32>, %arg1: tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) {
// expected-error@+1 {{'tosa.fft2d' op failed level check: H <= MAX_KERNEL}}
%0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
- (tensor<32x8193x32xf32>, tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
- return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+ (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>)
+ return %0, %1 : tensor<32x16384x32xf32>, tensor<32x16384x32xf32>
}
// -----
-func.func @test_fft2d_real_w(%arg0: tensor<32x32x8193xf32>, %arg1: tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_real_w(%arg0: tensor<32x32x16384xf32>, %arg1: tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) {
// expected-error@+1 {{'tosa.fft2d' op failed level check: W <= MAX_KERNEL}}
%0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
- (tensor<32x32x8193xf32>, tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
- return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+ (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>)
+ return %0, %1 : tensor<32x32x16384xf32>, tensor<32x32x16384xf32>
}
// -----
-func.func @test_fft2d_imag_h(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x8193x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_imag_h(%arg0: tensor<32x16384x32xf32>, %arg1: tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) {
// expected-error@+1 {{'tosa.fft2d' op failed level check: H <= MAX_KERNEL}}
%0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
- (tensor<32x32x32xf32>, tensor<32x8193x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
- return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+ (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>)
+ return %0, %1 : tensor<32x16384x32xf32>, tensor<32x16384x32xf32>
}
// -----
-func.func @test_fft2d_imag_w(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x32x8193xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_imag_w(%arg0: tensor<32x32x16384xf32>, %arg1: tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) {
// expected-error@+1 {{'tosa.fft2d' op failed level check: W <= MAX_KERNEL}}
%0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
- (tensor<32x32x32xf32>, tensor<32x32x8193xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
- return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+ (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>)
+ return %0, %1 : tensor<32x32x16384xf32>, tensor<32x32x16384xf32>
}
// -----
@@ -577,18 +577,18 @@ func.func @test_maxpool2d_p...
[truncated]
|
Adds checks for element types and input/output shapes. Signed-off-by: Luke Hutton <[email protected]>
Adds checks for element types and input/output shapes.