Skip to content

[mlir][tosa] Add verifiers for FFT2d and RFFT2d #129273

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 3, 2025

Conversation

lhutton1
Copy link
Contributor

Adds checks for element types and input/output shapes.

Adds checks for element types and input/output shapes.

Signed-off-by: Luke Hutton <[email protected]>
Change-Id: Ib40928027f5b9d75306aa662c4627e3263db7de7
@llvmbot
Copy link
Member

llvmbot commented Feb 28, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir-linalg

Author: Luke Hutton (lhutton1)

Changes

Adds checks for element types and input/output shapes.


Patch is 21.37 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129273.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+11-2)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+79-1)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (-9)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+12-12)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+80)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+18-18)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ddfec2c9bfcd3..9e14daf0d014c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -240,7 +240,10 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
 //===----------------------------------------------------------------------===//
 // Operator: fft2d
 //===----------------------------------------------------------------------===//
-def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
+def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d", [
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultShape,
+    ResultsAreFloatLike]> {
   let summary = "Performs FFT2D operation on the input.";
 
   let description = [{
@@ -279,6 +282,8 @@ def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
     $input_real `,` $input_imag attr-dict `:` `(` type($input_real) `,`
     type($input_imag) `)` `->` `(` type($output_real) `,` type($output_imag) `)`
   }];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -349,7 +354,9 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
 //===----------------------------------------------------------------------===//
 // Operator: rfft2d
 //===----------------------------------------------------------------------===//
-def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> {
+def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d", [
+    SameOperandsAndResultElementType,
+    ResultsAreFloatLike]> {
   let summary = "Performs RFFT2D operation on the input.";
 
   let description = [{
@@ -385,6 +392,8 @@ def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> {
   let assemblyFormat = [{
     $input attr-dict `:` `(` type($input) `)` `->` `(` type($output_real) `,` type($output_imag) `)`
   }];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 7b50eceb081dd..63afaed22d7ab 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -789,7 +789,7 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
   int64_t inWidth = inputShape.getDimSize(2);
 
   // Note that we can support this calculation symbolically
-  // in the future e.g. [x, y, z] -> [x, y, z / 2 - 1]
+  // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
   if (inWidth != ShapedType::kDynamic)
     outputShape[2] = inWidth / 2 + 1;
 
@@ -799,6 +799,57 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
   return success();
 }
 
+static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
+                                           const llvm::StringRef dimName) {
+  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
+  if (!isPowerOfTwo)
+    return op->emitOpError("expected ")
+           << dimName << " to be a power of two, got " << dimSize;
+
+  return success();
+}
+
+LogicalResult tosa::RFFT2dOp::verify() {
+  const auto outputTypes = getResultTypes();
+  if (failed(verifyCompatibleShapes(outputTypes)))
+    return emitOpError("expected output shapes to match, got ") << outputTypes;
+
+  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
+  if (!inputType)
+    return success();
+
+  const int64_t height = inputType.getDimSize(1);
+  if (!ShapedType::isDynamic(height) &&
+      failed(verifyDimIsPowerOfTwo(*this, height, "height")))
+    return failure();
+
+  const int64_t width = inputType.getDimSize(2);
+  if (!ShapedType::isDynamic(width) &&
+      failed(verifyDimIsPowerOfTwo(*this, width, "width")))
+    return failure();
+
+  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
+  if (!outputType)
+    return success();
+
+  // Batch and height input/output dimensions should match
+  if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
+                                   outputType.getShape().drop_back())))
+    return emitOpError("expected batch and height dimensions of input/output "
+                       "to match, got input=")
+           << inputType << " output=" << outputType;
+
+  // Output width dimension expected to be input_width / 2 + 1
+  const int64_t outputWidth = outputType.getDimSize(2);
+  if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&
+      (outputWidth - 1) * 2 != width)
+    return emitOpError(
+               "expected output width to be equal to input_width / 2 + 1, got ")
+           << outputWidth;
+
+  return success();
+}
+
 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     FFT2dOp::Adaptor adaptor,
@@ -810,6 +861,33 @@ LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::FFT2dOp::verify() {
+  const auto inputRealType =
+      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
+  const auto inputImagType =
+      llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
+  if (!inputRealType || !inputImagType)
+    return success();
+
+  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
+    return ShapedType::isDynamic(a) ? a : b;
+  };
+
+  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
+                                            inputImagType.getDimSize(1));
+  if (!ShapedType::isDynamic(height) &&
+      failed(verifyDimIsPowerOfTwo(*this, height, "height")))
+    return failure();
+
+  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
+                                           inputImagType.getDimSize(2));
+  if (!ShapedType::isDynamic(width) &&
+      failed(verifyDimIsPowerOfTwo(*this, width, "width")))
+    return failure();
+
+  return success();
+}
+
 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ConcatOp::Adaptor adaptor,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 5db3f56cf459e..71fcd4129a618 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -31,15 +31,6 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
 
 // -----
 
-// CHECK-LABEL: @rfft2d_with_non_float_type
-func.func @rfft2d_with_non_float_type(%arg0 : tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>) {
-  // expected-error@+1 {{failed to legalize operation 'tosa.rfft2d'}}
-  %real, %imag = tosa.rfft2d %arg0 : (tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>)
-  return %real, %imag : tensor<1x1x1xi32>, tensor<1x1x1xi32>
-}
-
-// -----
-
 // CHECK-LABEL: @rescale_unsupported_type
 func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
   // expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 78f2e173d7cb1..e68783f779063 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1720,20 +1720,20 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
 // CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 
 // CHECK-LABEL:   func.func @test_static_rfft2d(
-// CHECK-SAME:                                  %[[VAL_0:.*]]: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK-SAME:                                  %[[VAL_0:.*]]: tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_2:.*]] = arith.constant 2 : index
 // CHECK:           %[[VAL_3:.*]] = arith.constant 8 : index
 // CHECK:           %[[VAL_4:.*]] = arith.constant 4 : index
 // CHECK:           %[[VAL_5:.*]] = arith.constant 5 : index
-// CHECK:           %[[VAL_6:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK:           %[[VAL_6:.*]] = tensor.empty() : tensor<5x4x5xf32>
 // CHECK:           %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_6]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
-// CHECK:           %[[VAL_9:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK:           %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_6]] : tensor<5x4x5xf32>) -> tensor<5x4x5xf32>
+// CHECK:           %[[VAL_9:.*]] = tensor.empty() : tensor<5x4x5xf32>
 // CHECK:           %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_10]] : f32) outs(%[[VAL_9]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
+// CHECK:           %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_10]] : f32) outs(%[[VAL_9]] : tensor<5x4x5xf32>) -> tensor<5x4x5xf32>
 // CHECK:           %[[VAL_12:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_13:.*]] = arith.constant 5 : index
+// CHECK:           %[[VAL_13:.*]] = arith.constant 4 : index
 // CHECK:           %[[VAL_14:.*]] = arith.constant 2 : index
 // CHECK:           %[[VAL_15:.*]] = arith.constant 8 : index
 // CHECK:           %[[VAL_16:.*]] = arith.constant 6.28318548 : f32
@@ -1741,7 +1741,7 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
 // CHECK:           %[[VAL_18:.*]] = arith.uitofp %[[VAL_17]] : i32 to f32
 // CHECK:           %[[VAL_19:.*]] = arith.index_castui %[[VAL_15]] : index to i32
 // CHECK:           %[[VAL_20:.*]] = arith.uitofp %[[VAL_19]] : i32 to f32
-// CHECK:           %[[VAL_21:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<5x5x8xf32>) outs(%[[VAL_8]], %[[VAL_11]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK:           %[[VAL_21:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<5x4x8xf32>) outs(%[[VAL_8]], %[[VAL_11]] : tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
 // CHECK:           ^bb0(%[[VAL_22:.*]]: f32, %[[VAL_23:.*]]: f32, %[[VAL_24:.*]]: f32):
 // CHECK:             %[[VAL_25:.*]] = linalg.index 1 : index
 // CHECK:             %[[VAL_26:.*]] = linalg.index 2 : index
@@ -1766,12 +1766,12 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
 // CHECK:             %[[VAL_45:.*]] = arith.addf %[[VAL_23]], %[[VAL_43]] : f32
 // CHECK:             %[[VAL_46:.*]] = arith.subf %[[VAL_24]], %[[VAL_44]] : f32
 // CHECK:             linalg.yield %[[VAL_45]], %[[VAL_46]] : f32, f32
-// CHECK:           } -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
-// CHECK:           return %[[VAL_47:.*]]#0, %[[VAL_47]]#1 : tensor<5x5x5xf32>, tensor<5x5x5xf32>
+// CHECK:           } -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>)
+// CHECK:           return %[[VAL_47:.*]]#0, %[[VAL_47]]#1 : tensor<5x4x5xf32>, tensor<5x4x5xf32>
 // CHECK:         }
-func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
-  %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
-  return %output_real, %output_imag : tensor<5x5x5xf32>, tensor<5x5x5xf32>
+func.func @test_static_rfft2d(%arg0: tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
+  %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>)
+  return %output_real, %output_imag : tensor<5x4x5xf32>, tensor<5x4x5xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 123c65e1b4fcd..fb7a9222947f6 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1276,3 +1276,83 @@ func.func @test_conv2d_invalid_bias_size(%arg0: tensor<1x4x4x4xf32>, %arg1: tens
     : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<7xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
   return %0 : tensor<1x4x4x8xf32>
 }
+
+// -----
+
+func.func @test_fft2d_same_operands_and_result_element_type(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf16>, tensor<1x4x8xf16>) {
+  // expected-error@+1 {{'tosa.fft2d' op requires the same element type for all operands and results}}
+  %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf16>, tensor<1x4x8xf16>)
+  return %0, %1 : tensor<1x4x8xf16>, tensor<1x4x8xf16>
+}
+
+// -----
+
+func.func @test_fft2d_same_operands_and_result_shape(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x7xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
+  // expected-error@+1 {{'tosa.fft2d' op requires the same shape for all operands and results}}
+  %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x7xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
+  return %0, %1 : tensor<1x4x8xf32>, tensor<1x4x8xf32>
+}
+
+// -----
+
+func.func @test_fft2d_invalid_type(%arg0: tensor<1x4x8xi8>, %arg1: tensor<1x4x8xi8>) -> (tensor<1x4x8xi8>, tensor<1x4x8xi8>) {
+  // expected-error@+1 {{'tosa.fft2d' op requires a floating point type}}
+  %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xi8>, tensor<1x4x8xi8>) -> (tensor<1x4x8xi8>, tensor<1x4x8xi8>)
+  return %0, %1 : tensor<1x4x8xi8>, tensor<1x4x8xi8>
+}
+
+// -----
+
+func.func @test_fft2d_height_non_power_of_two(%arg0: tensor<1x5x8xf32>, %arg1: tensor<1x5x8xf32>) -> (tensor<1x5x8xf32>, tensor<1x5x8xf32>) {
+  // expected-error@+1 {{'tosa.fft2d' op expected height to be a power of two, got 5}}
+  %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x5x8xf32>, tensor<1x5x8xf32>) -> (tensor<1x5x8xf32>, tensor<1x5x8xf32>)
+  return %0, %1 : tensor<1x5x8xf32>, tensor<1x5x8xf32>
+}
+
+// -----
+
+func.func @test_rfft2d_same_operands_and_result_element_type(%arg0: tensor<1x4x8xf32>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op requires the same element type for all operands and results}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf32>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>)
+  return %0, %1 : tensor<1x4x5xf16>, tensor<1x4x5xf16>
+}
+
+// -----
+
+func.func @test_rfft2d_same_results_shape(%arg0: tensor<1x4x8xf32>) -> (tensor<1x4x6xf32>, tensor<1x4x5xf32>) {
+  // expected-error@+1 {{'tosa.rfft2d' op expected output shapes to match, got 'tensor<1x4x6xf32>', 'tensor<1x4x5xf32>'}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf32>) -> (tensor<1x4x6xf32>, tensor<1x4x5xf32>)
+  return %0, %1 : tensor<1x4x6xf32>, tensor<1x4x5xf32>
+}
+
+// -----
+
+func.func @test_rfft2d_invalid_type(%arg0: tensor<1x4x8xi16>) -> (tensor<1x4x5xi16>, tensor<1x4x5xi16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op requires a floating point type}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xi16>) -> (tensor<1x4x5xi16>, tensor<1x4x5xi16>)
+  return %0, %1 : tensor<1x4x5xi16>, tensor<1x4x5xi16>
+}
+
+// -----
+
+func.func @test_rfft2d_width_power_of_two(%arg0: tensor<1x4x9xf16>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op expected width to be a power of two, got 9}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x9xf16>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>)
+  return %0, %1 : tensor<1x4x5xf16>, tensor<1x4x5xf16>
+}
+
+// -----
+
+func.func @test_rfft2d_batch_input_output_match(%arg0: tensor<1x4x8xf16>) -> (tensor<2x4x5xf16>, tensor<2x4x5xf16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op expected batch and height dimensions of input/output to match, got input='tensor<1x4x8xf16>' output='tensor<2x4x5xf16>'}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<2x4x5xf16>, tensor<2x4x5xf16>)
+  return %0, %1 : tensor<2x4x5xf16>, tensor<2x4x5xf16>
+}
+
+// -----
+
+func.func @test_rfft2d_width_input_output_match(%arg0: tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op expected output width to be equal to input_width / 2 + 1, got 3}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>)
+  return %0, %1 : tensor<1x4x3xf16>, tensor<1x4x3xf16>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index d2958efe1bb24..19584103d40c7 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -477,38 +477,38 @@ func.func @test_depthwise_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: t
 
 // -----
 
-func.func @test_fft2d_real_h(%arg0: tensor<32x8193x32xf32>, %arg1: tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_real_h(%arg0: tensor<32x16384x32xf32>, %arg1: tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) {
   // expected-error@+1 {{'tosa.fft2d' op failed level check: H <= MAX_KERNEL}}
   %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
-            (tensor<32x8193x32xf32>, tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
-  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+            (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>)
+  return %0, %1 : tensor<32x16384x32xf32>, tensor<32x16384x32xf32>
 }
 
 // -----
 
-func.func @test_fft2d_real_w(%arg0: tensor<32x32x8193xf32>, %arg1: tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_real_w(%arg0: tensor<32x32x16384xf32>, %arg1: tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) {
   // expected-error@+1 {{'tosa.fft2d' op failed level check: W <= MAX_KERNEL}}
   %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
-            (tensor<32x32x8193xf32>, tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
-  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+            (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>)
+  return %0, %1 : tensor<32x32x16384xf32>, tensor<32x32x16384xf32>
 }
 
 // -----
 
-func.func @test_fft2d_imag_h(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x8193x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_imag_h(%arg0: tensor<32x16384x32xf32>, %arg1: tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) {
   // expected-error@+1 {{'tosa.fft2d' op failed level check: H <= MAX_KERNEL}}
   %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
-            (tensor<32x32x32xf32>, tensor<32x8193x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
-  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+            (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>)
+  return %0, %1 : tensor<32x16384x32xf32>, tensor<32x16384x32xf32>
 }
 
 // -----
 
-func.func @test_fft2d_imag_w(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x32x8193xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_imag_w(%arg0: tensor<32x32x16384xf32>, %arg1: tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) {
   // expected-error@+1 {{'tosa.fft2d' op failed level check: W <= MAX_KERNEL}}
   %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
-            (tensor<32x32x32xf32>, tensor<32x32x8193xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
-  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+            (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>)
+  return %0, %1 : tensor<32x32x16384xf32>, tensor<32x32x16384xf32>
 }
 
 // -----
@@ -577,18 +577,18 @@ func.func @test_maxpool2d_p...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Feb 28, 2025

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

Changes

Adds checks for element types and input/output shapes.


Patch is 21.37 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129273.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+11-2)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+79-1)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (-9)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+12-12)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+80)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+18-18)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ddfec2c9bfcd3..9e14daf0d014c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -240,7 +240,10 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
 //===----------------------------------------------------------------------===//
 // Operator: fft2d
 //===----------------------------------------------------------------------===//
-def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
+def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d", [
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultShape,
+    ResultsAreFloatLike]> {
   let summary = "Performs FFT2D operation on the input.";
 
   let description = [{
@@ -279,6 +282,8 @@ def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
     $input_real `,` $input_imag attr-dict `:` `(` type($input_real) `,`
     type($input_imag) `)` `->` `(` type($output_real) `,` type($output_imag) `)`
   }];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -349,7 +354,9 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
 //===----------------------------------------------------------------------===//
 // Operator: rfft2d
 //===----------------------------------------------------------------------===//
-def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> {
+def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d", [
+    SameOperandsAndResultElementType,
+    ResultsAreFloatLike]> {
   let summary = "Performs RFFT2D operation on the input.";
 
   let description = [{
@@ -385,6 +392,8 @@ def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> {
   let assemblyFormat = [{
     $input attr-dict `:` `(` type($input) `)` `->` `(` type($output_real) `,` type($output_imag) `)`
   }];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 7b50eceb081dd..63afaed22d7ab 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -789,7 +789,7 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
   int64_t inWidth = inputShape.getDimSize(2);
 
   // Note that we can support this calculation symbolically
-  // in the future e.g. [x, y, z] -> [x, y, z / 2 - 1]
+  // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
   if (inWidth != ShapedType::kDynamic)
     outputShape[2] = inWidth / 2 + 1;
 
@@ -799,6 +799,57 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
   return success();
 }
 
+static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
+                                           const llvm::StringRef dimName) {
+  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
+  if (!isPowerOfTwo)
+    return op->emitOpError("expected ")
+           << dimName << " to be a power of two, got " << dimSize;
+
+  return success();
+}
+
+LogicalResult tosa::RFFT2dOp::verify() {
+  const auto outputTypes = getResultTypes();
+  if (failed(verifyCompatibleShapes(outputTypes)))
+    return emitOpError("expected output shapes to match, got ") << outputTypes;
+
+  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
+  if (!inputType)
+    return success();
+
+  const int64_t height = inputType.getDimSize(1);
+  if (!ShapedType::isDynamic(height) &&
+      failed(verifyDimIsPowerOfTwo(*this, height, "height")))
+    return failure();
+
+  const int64_t width = inputType.getDimSize(2);
+  if (!ShapedType::isDynamic(width) &&
+      failed(verifyDimIsPowerOfTwo(*this, width, "width")))
+    return failure();
+
+  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
+  if (!outputType)
+    return success();
+
+  // Batch and height input/output dimensions should match
+  if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
+                                   outputType.getShape().drop_back())))
+    return emitOpError("expected batch and height dimensions of input/output "
+                       "to match, got input=")
+           << inputType << " output=" << outputType;
+
+  // Output width dimension expected to be input_width / 2 + 1
+  const int64_t outputWidth = outputType.getDimSize(2);
+  if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&
+      (outputWidth - 1) * 2 != width)
+    return emitOpError(
+               "expected output width to be equal to input_width / 2 + 1, got ")
+           << outputWidth;
+
+  return success();
+}
+
 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     FFT2dOp::Adaptor adaptor,
@@ -810,6 +861,33 @@ LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::FFT2dOp::verify() {
+  const auto inputRealType =
+      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
+  const auto inputImagType =
+      llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
+  if (!inputRealType || !inputImagType)
+    return success();
+
+  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
+    return ShapedType::isDynamic(a) ? a : b;
+  };
+
+  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
+                                            inputImagType.getDimSize(1));
+  if (!ShapedType::isDynamic(height) &&
+      failed(verifyDimIsPowerOfTwo(*this, height, "height")))
+    return failure();
+
+  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
+                                           inputImagType.getDimSize(2));
+  if (!ShapedType::isDynamic(width) &&
+      failed(verifyDimIsPowerOfTwo(*this, width, "width")))
+    return failure();
+
+  return success();
+}
+
 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ConcatOp::Adaptor adaptor,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 5db3f56cf459e..71fcd4129a618 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -31,15 +31,6 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
 
 // -----
 
-// CHECK-LABEL: @rfft2d_with_non_float_type
-func.func @rfft2d_with_non_float_type(%arg0 : tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>) {
-  // expected-error@+1 {{failed to legalize operation 'tosa.rfft2d'}}
-  %real, %imag = tosa.rfft2d %arg0 : (tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>)
-  return %real, %imag : tensor<1x1x1xi32>, tensor<1x1x1xi32>
-}
-
-// -----
-
 // CHECK-LABEL: @rescale_unsupported_type
 func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
   // expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 78f2e173d7cb1..e68783f779063 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1720,20 +1720,20 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
 // CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 
 // CHECK-LABEL:   func.func @test_static_rfft2d(
-// CHECK-SAME:                                  %[[VAL_0:.*]]: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK-SAME:                                  %[[VAL_0:.*]]: tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_2:.*]] = arith.constant 2 : index
 // CHECK:           %[[VAL_3:.*]] = arith.constant 8 : index
 // CHECK:           %[[VAL_4:.*]] = arith.constant 4 : index
 // CHECK:           %[[VAL_5:.*]] = arith.constant 5 : index
-// CHECK:           %[[VAL_6:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK:           %[[VAL_6:.*]] = tensor.empty() : tensor<5x4x5xf32>
 // CHECK:           %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_6]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
-// CHECK:           %[[VAL_9:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK:           %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_6]] : tensor<5x4x5xf32>) -> tensor<5x4x5xf32>
+// CHECK:           %[[VAL_9:.*]] = tensor.empty() : tensor<5x4x5xf32>
 // CHECK:           %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_10]] : f32) outs(%[[VAL_9]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
+// CHECK:           %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_10]] : f32) outs(%[[VAL_9]] : tensor<5x4x5xf32>) -> tensor<5x4x5xf32>
 // CHECK:           %[[VAL_12:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_13:.*]] = arith.constant 5 : index
+// CHECK:           %[[VAL_13:.*]] = arith.constant 4 : index
 // CHECK:           %[[VAL_14:.*]] = arith.constant 2 : index
 // CHECK:           %[[VAL_15:.*]] = arith.constant 8 : index
 // CHECK:           %[[VAL_16:.*]] = arith.constant 6.28318548 : f32
@@ -1741,7 +1741,7 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
 // CHECK:           %[[VAL_18:.*]] = arith.uitofp %[[VAL_17]] : i32 to f32
 // CHECK:           %[[VAL_19:.*]] = arith.index_castui %[[VAL_15]] : index to i32
 // CHECK:           %[[VAL_20:.*]] = arith.uitofp %[[VAL_19]] : i32 to f32
-// CHECK:           %[[VAL_21:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<5x5x8xf32>) outs(%[[VAL_8]], %[[VAL_11]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK:           %[[VAL_21:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<5x4x8xf32>) outs(%[[VAL_8]], %[[VAL_11]] : tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
 // CHECK:           ^bb0(%[[VAL_22:.*]]: f32, %[[VAL_23:.*]]: f32, %[[VAL_24:.*]]: f32):
 // CHECK:             %[[VAL_25:.*]] = linalg.index 1 : index
 // CHECK:             %[[VAL_26:.*]] = linalg.index 2 : index
@@ -1766,12 +1766,12 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
 // CHECK:             %[[VAL_45:.*]] = arith.addf %[[VAL_23]], %[[VAL_43]] : f32
 // CHECK:             %[[VAL_46:.*]] = arith.subf %[[VAL_24]], %[[VAL_44]] : f32
 // CHECK:             linalg.yield %[[VAL_45]], %[[VAL_46]] : f32, f32
-// CHECK:           } -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
-// CHECK:           return %[[VAL_47:.*]]#0, %[[VAL_47]]#1 : tensor<5x5x5xf32>, tensor<5x5x5xf32>
+// CHECK:           } -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>)
+// CHECK:           return %[[VAL_47:.*]]#0, %[[VAL_47]]#1 : tensor<5x4x5xf32>, tensor<5x4x5xf32>
 // CHECK:         }
-func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
-  %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
-  return %output_real, %output_imag : tensor<5x5x5xf32>, tensor<5x5x5xf32>
+func.func @test_static_rfft2d(%arg0: tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
+  %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>)
+  return %output_real, %output_imag : tensor<5x4x5xf32>, tensor<5x4x5xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 123c65e1b4fcd..fb7a9222947f6 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1276,3 +1276,83 @@ func.func @test_conv2d_invalid_bias_size(%arg0: tensor<1x4x4x4xf32>, %arg1: tens
     : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<7xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
   return %0 : tensor<1x4x4x8xf32>
 }
+
+// -----
+
+func.func @test_fft2d_same_operands_and_result_element_type(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf16>, tensor<1x4x8xf16>) {
+  // expected-error@+1 {{'tosa.fft2d' op requires the same element type for all operands and results}}
+  %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf16>, tensor<1x4x8xf16>)
+  return %0, %1 : tensor<1x4x8xf16>, tensor<1x4x8xf16>
+}
+
+// -----
+
+func.func @test_fft2d_same_operands_and_result_shape(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x7xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
+  // expected-error@+1 {{'tosa.fft2d' op requires the same shape for all operands and results}}
+  %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x7xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
+  return %0, %1 : tensor<1x4x8xf32>, tensor<1x4x8xf32>
+}
+
+// -----
+
+func.func @test_fft2d_invalid_type(%arg0: tensor<1x4x8xi8>, %arg1: tensor<1x4x8xi8>) -> (tensor<1x4x8xi8>, tensor<1x4x8xi8>) {
+  // expected-error@+1 {{'tosa.fft2d' op requires a floating point type}}
+  %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xi8>, tensor<1x4x8xi8>) -> (tensor<1x4x8xi8>, tensor<1x4x8xi8>)
+  return %0, %1 : tensor<1x4x8xi8>, tensor<1x4x8xi8>
+}
+
+// -----
+
+func.func @test_fft2d_height_non_power_of_two(%arg0: tensor<1x5x8xf32>, %arg1: tensor<1x5x8xf32>) -> (tensor<1x5x8xf32>, tensor<1x5x8xf32>) {
+  // expected-error@+1 {{'tosa.fft2d' op expected height to be a power of two, got 5}}
+  %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x5x8xf32>, tensor<1x5x8xf32>) -> (tensor<1x5x8xf32>, tensor<1x5x8xf32>)
+  return %0, %1 : tensor<1x5x8xf32>, tensor<1x5x8xf32>
+}
+
+// -----
+
+func.func @test_rfft2d_same_operands_and_result_element_type(%arg0: tensor<1x4x8xf32>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op requires the same element type for all operands and results}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf32>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>)
+  return %0, %1 : tensor<1x4x5xf16>, tensor<1x4x5xf16>
+}
+
+// -----
+
+func.func @test_rfft2d_same_results_shape(%arg0: tensor<1x4x8xf32>) -> (tensor<1x4x6xf32>, tensor<1x4x5xf32>) {
+  // expected-error@+1 {{'tosa.rfft2d' op expected output shapes to match, got 'tensor<1x4x6xf32>', 'tensor<1x4x5xf32>'}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf32>) -> (tensor<1x4x6xf32>, tensor<1x4x5xf32>)
+  return %0, %1 : tensor<1x4x6xf32>, tensor<1x4x5xf32>
+}
+
+// -----
+
+func.func @test_rfft2d_invalid_type(%arg0: tensor<1x4x8xi16>) -> (tensor<1x4x5xi16>, tensor<1x4x5xi16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op requires a floating point type}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xi16>) -> (tensor<1x4x5xi16>, tensor<1x4x5xi16>)
+  return %0, %1 : tensor<1x4x5xi16>, tensor<1x4x5xi16>
+}
+
+// -----
+
+func.func @test_rfft2d_width_power_of_two(%arg0: tensor<1x4x9xf16>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op expected width to be a power of two, got 9}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x9xf16>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>)
+  return %0, %1 : tensor<1x4x5xf16>, tensor<1x4x5xf16>
+}
+
+// -----
+
+func.func @test_rfft2d_batch_input_output_match(%arg0: tensor<1x4x8xf16>) -> (tensor<2x4x5xf16>, tensor<2x4x5xf16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op expected batch and height dimensions of input/output to match, got input='tensor<1x4x8xf16>' output='tensor<2x4x5xf16>'}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<2x4x5xf16>, tensor<2x4x5xf16>)
+  return %0, %1 : tensor<2x4x5xf16>, tensor<2x4x5xf16>
+}
+
+// -----
+
+func.func @test_rfft2d_width_input_output_match(%arg0: tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op expected output width to be equal to input_width / 2 + 1, got 3}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>)
+  return %0, %1 : tensor<1x4x3xf16>, tensor<1x4x3xf16>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index d2958efe1bb24..19584103d40c7 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -477,38 +477,38 @@ func.func @test_depthwise_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: t
 
 // -----
 
-func.func @test_fft2d_real_h(%arg0: tensor<32x8193x32xf32>, %arg1: tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_real_h(%arg0: tensor<32x16384x32xf32>, %arg1: tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) {
   // expected-error@+1 {{'tosa.fft2d' op failed level check: H <= MAX_KERNEL}}
   %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
-            (tensor<32x8193x32xf32>, tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
-  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+            (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>)
+  return %0, %1 : tensor<32x16384x32xf32>, tensor<32x16384x32xf32>
 }
 
 // -----
 
-func.func @test_fft2d_real_w(%arg0: tensor<32x32x8193xf32>, %arg1: tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_real_w(%arg0: tensor<32x32x16384xf32>, %arg1: tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) {
   // expected-error@+1 {{'tosa.fft2d' op failed level check: W <= MAX_KERNEL}}
   %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
-            (tensor<32x32x8193xf32>, tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
-  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+            (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>)
+  return %0, %1 : tensor<32x32x16384xf32>, tensor<32x32x16384xf32>
 }
 
 // -----
 
-func.func @test_fft2d_imag_h(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x8193x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_imag_h(%arg0: tensor<32x16384x32xf32>, %arg1: tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) {
   // expected-error@+1 {{'tosa.fft2d' op failed level check: H <= MAX_KERNEL}}
   %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
-            (tensor<32x32x32xf32>, tensor<32x8193x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
-  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+            (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>)
+  return %0, %1 : tensor<32x16384x32xf32>, tensor<32x16384x32xf32>
 }
 
 // -----
 
-func.func @test_fft2d_imag_w(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x32x8193xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_imag_w(%arg0: tensor<32x32x16384xf32>, %arg1: tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) {
   // expected-error@+1 {{'tosa.fft2d' op failed level check: W <= MAX_KERNEL}}
   %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
-            (tensor<32x32x32xf32>, tensor<32x32x8193xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
-  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+            (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>)
+  return %0, %1 : tensor<32x32x16384xf32>, tensor<32x32x16384xf32>
 }
 
 // -----
@@ -577,18 +577,18 @@ func.func @test_maxpool2d_p...
[truncated]

@Jerry-Ge Jerry-Ge merged commit 079557c into llvm:main Mar 3, 2025
15 checks passed
@lhutton1 lhutton1 deleted the fft2d-rfft2d-verify branch March 5, 2025 12:06
jph-13 pushed a commit to jph-13/llvm-project that referenced this pull request Mar 21, 2025
Adds checks for element types and input/output shapes.

Signed-off-by: Luke Hutton <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants