[mlir][tosa] Add verifiers for FFT2d and RFFT2d #129273

Merged: 1 commit, Mar 3, 2025

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (13 changes: 11 additions & 2 deletions)

@@ -240,7 +240,10 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
//===----------------------------------------------------------------------===//
// Operator: fft2d
//===----------------------------------------------------------------------===//
-def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
+def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d", [
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultShape,
+    ResultsAreFloatLike]> {
  let summary = "Performs FFT2D operation on the input.";

  let description = [{

@@ -279,6 +282,8 @@ def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
    $input_real `,` $input_imag attr-dict `:` `(` type($input_real) `,`
    type($input_imag) `)` `->` `(` type($output_real) `,` type($output_imag) `)`
  }];
+
+  let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//

@@ -349,7 +354,9 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
//===----------------------------------------------------------------------===//
// Operator: rfft2d
//===----------------------------------------------------------------------===//
-def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> {
+def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d", [
+    SameOperandsAndResultElementType,
+    ResultsAreFloatLike]> {
  let summary = "Performs RFFT2D operation on the input.";

  let description = [{

@@ -385,6 +392,8 @@ def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> {
  let assemblyFormat = [{
    $input attr-dict `:` `(` type($input) `)` `->` `(` type($output_real) `,` type($output_imag) `)`
  }];
+
+  let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
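
The traits added above do part of the validation declaratively: SameOperandsAndResultElementType and SameOperandsAndResultShape reject operand/result mismatches before the custom verifier runs, and ResultsAreFloatLike rejects non-float element types. As an illustrative sketch (not part of the patch; the function name is hypothetical), IR like the following satisfies the traits and the new verifier, since all operand and result types match and the height (4) and width (8) are powers of two:

func.func @fft2d_valid_sketch(%real: tensor<1x4x8xf32>, %imag: tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
  // Matching element types and shapes; height and width are powers of two.
  %0, %1 = tosa.fft2d %real, %imag {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
  return %0, %1 : tensor<1x4x8xf32>, tensor<1x4x8xf32>
}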

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (80 changes: 79 additions & 1 deletion)

@@ -789,7 +789,7 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
  int64_t inWidth = inputShape.getDimSize(2);

  // Note that we can support this calculation symbolically
-  // in the future e.g. [x, y, z] -> [x, y, z / 2 - 1]
+  // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
  if (inWidth != ShapedType::kDynamic)
    outputShape[2] = inWidth / 2 + 1;

@@ -799,6 +799,57 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
  return success();
}

+static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
+                                           const llvm::StringRef dimName) {
+  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
+  if (!isPowerOfTwo)
+    return op->emitOpError("expected ")
+           << dimName << " to be a power of two, got " << dimSize;
+
+  return success();
+}
+
+LogicalResult tosa::RFFT2dOp::verify() {
+  const auto outputTypes = getResultTypes();
+  if (failed(verifyCompatibleShapes(outputTypes)))
+    return emitOpError("expected output shapes to match, got ") << outputTypes;
+
+  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
+  if (!inputType)
+    return success();
+
+  const int64_t height = inputType.getDimSize(1);
+  if (!ShapedType::isDynamic(height) &&
+      failed(verifyDimIsPowerOfTwo(*this, height, "height")))
+    return failure();
+
+  const int64_t width = inputType.getDimSize(2);
+  if (!ShapedType::isDynamic(width) &&
+      failed(verifyDimIsPowerOfTwo(*this, width, "width")))
+    return failure();
+
+  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
+  if (!outputType)
+    return success();
+
+  // Batch and height input/output dimensions should match
+  if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
+                                   outputType.getShape().drop_back())))
+    return emitOpError("expected batch and height dimensions of input/output "
+                       "to match, got input=")
+           << inputType << " output=" << outputType;
+
+  // Output width dimension expected to be input_width / 2 + 1
+  const int64_t outputWidth = outputType.getDimSize(2);
+  if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&
+      (outputWidth - 1) * 2 != width)
+    return emitOpError(
+               "expected output width to be equal to input_width / 2 + 1, got ")
+           << outputWidth;
+
+  return success();
+}
+
LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FFT2dOp::Adaptor adaptor,

@@ -810,6 +861,33 @@ LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
  return success();
}

+LogicalResult tosa::FFT2dOp::verify() {
+  const auto inputRealType =
+      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
+  const auto inputImagType =
+      llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
+  if (!inputRealType || !inputImagType)
+    return success();
+
+  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
+    return ShapedType::isDynamic(a) ? a : b;
+  };
+
+  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
+                                            inputImagType.getDimSize(1));
+  if (!ShapedType::isDynamic(height) &&
+      failed(verifyDimIsPowerOfTwo(*this, height, "height")))
+    return failure();
+
+  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
+                                           inputImagType.getDimSize(2));
+  if (!ShapedType::isDynamic(width) &&
+      failed(verifyDimIsPowerOfTwo(*this, width, "width")))
+    return failure();
+
+  return success();
+}
+
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ConcatOp::Adaptor adaptor,
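
Two details of the verifiers are worth spelling out. First, verifyDimIsPowerOfTwo relies on the standard bit trick: for dimSize > 0, dimSize & (dimSize - 1) clears the lowest set bit, so the result is zero exactly when a single bit is set. Second, RFFT2dOp::verify encodes output_width == input_width / 2 + 1 as (outputWidth - 1) * 2 == width; the corrected comment in inferReturnTypeComponents reflects the same fact, namely that a real-input FFT of width W has only W / 2 + 1 non-redundant output columns by conjugate symmetry. Both verifiers skip dynamic dimensions, so IR such as the following sketch (mine, not from the patch) still verifies; the dynamic batch is ignored, height 4 and width 8 are powers of two, and the output width is 8 / 2 + 1 = 5:

func.func @rfft2d_valid_sketch(%arg0: tensor<?x4x8xf32>) -> (tensor<?x4x5xf32>, tensor<?x4x5xf32>) {
  // The dynamic batch dimension is skipped by the static-dim checks.
  %real, %imag = tosa.rfft2d %arg0 : (tensor<?x4x8xf32>) -> (tensor<?x4x5xf32>, tensor<?x4x5xf32>)
  return %real, %imag : tensor<?x4x5xf32>, tensor<?x4x5xf32>
}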

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (9 changes: 0 additions & 9 deletions)

@@ -31,15 +31,6 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %

// -----

-// CHECK-LABEL: @rfft2d_with_non_float_type
-func.func @rfft2d_with_non_float_type(%arg0 : tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>) {
-  // expected-error@+1 {{failed to legalize operation 'tosa.rfft2d'}}
-  %real, %imag = tosa.rfft2d %arg0 : (tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>)
-  return %real, %imag : tensor<1x1x1xi32>, tensor<1x1x1xi32>
-}
-
-// -----
-
// CHECK-LABEL: @rescale_unsupported_type
func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
// expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (24 changes: 12 additions & 12 deletions)

@@ -1720,28 +1720,28 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>

// CHECK-LABEL: func.func @test_static_rfft2d(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 8 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 4 : index
// CHECK: %[[VAL_5:.*]] = arith.constant 5 : index
-// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<5x4x5xf32>
// CHECK: %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_6]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
-// CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_6]] : tensor<5x4x5xf32>) -> tensor<5x4x5xf32>
+// CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<5x4x5xf32>
// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_10]] : f32) outs(%[[VAL_9]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
+// CHECK: %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_10]] : f32) outs(%[[VAL_9]] : tensor<5x4x5xf32>) -> tensor<5x4x5xf32>
// CHECK: %[[VAL_12:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_13:.*]] = arith.constant 5 : index
+// CHECK: %[[VAL_13:.*]] = arith.constant 4 : index
// CHECK: %[[VAL_14:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_15:.*]] = arith.constant 8 : index
// CHECK: %[[VAL_16:.*]] = arith.constant 6.28318548 : f32
// CHECK: %[[VAL_17:.*]] = arith.index_castui %[[VAL_13]] : index to i32
// CHECK: %[[VAL_18:.*]] = arith.uitofp %[[VAL_17]] : i32 to f32
// CHECK: %[[VAL_19:.*]] = arith.index_castui %[[VAL_15]] : index to i32
// CHECK: %[[VAL_20:.*]] = arith.uitofp %[[VAL_19]] : i32 to f32
-// CHECK: %[[VAL_21:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<5x5x8xf32>) outs(%[[VAL_8]], %[[VAL_11]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK: %[[VAL_21:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<5x4x8xf32>) outs(%[[VAL_8]], %[[VAL_11]] : tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
// CHECK: ^bb0(%[[VAL_22:.*]]: f32, %[[VAL_23:.*]]: f32, %[[VAL_24:.*]]: f32):
// CHECK: %[[VAL_25:.*]] = linalg.index 1 : index
// CHECK: %[[VAL_26:.*]] = linalg.index 2 : index

@@ -1766,12 +1766,12 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
// CHECK: %[[VAL_45:.*]] = arith.addf %[[VAL_23]], %[[VAL_43]] : f32
// CHECK: %[[VAL_46:.*]] = arith.subf %[[VAL_24]], %[[VAL_44]] : f32
// CHECK: linalg.yield %[[VAL_45]], %[[VAL_46]] : f32, f32
-// CHECK: } -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
-// CHECK: return %[[VAL_47:.*]]#0, %[[VAL_47]]#1 : tensor<5x5x5xf32>, tensor<5x5x5xf32>
+// CHECK: } -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>)
+// CHECK: return %[[VAL_47:.*]]#0, %[[VAL_47]]#1 : tensor<5x4x5xf32>, tensor<5x4x5xf32>
// CHECK: }
-func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
-  %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
-  return %output_real, %output_imag : tensor<5x5x5xf32>, tensor<5x5x5xf32>
+func.func @test_static_rfft2d(%arg0: tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
+  %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>)
+  return %output_real, %output_imag : tensor<5x4x5xf32>, tensor<5x4x5xf32>
}

// -----
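
The lowering test above changes shape because the old input tensor<5x5x8xf32> has height 5, which is not a power of two, so the op would now be rejected by the rfft2d verifier before conversion ever runs. Roughly (a sketch of the expected diagnostic, assuming the verifier fires as written above):

func.func @old_rfft2d_shape(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
  // expected-error@+1 {{'tosa.rfft2d' op expected height to be a power of two, got 5}}
  %real, %imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
  return %real, %imag : tensor<5x5x5xf32>, tensor<5x5x5xf32>
}

The deletion from tosa-to-linalg-invalid.mlir earlier follows the same logic: an rfft2d with i32 operands is now rejected by ResultsAreFloatLike at verification time, so the legalization-failure test became unreachable.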

mlir/test/Dialect/Tosa/invalid.mlir (80 changes: 80 additions & 0 deletions)

@@ -1276,3 +1276,83 @@ func.func @test_conv2d_invalid_bias_size(%arg0: tensor<1x4x4x4xf32>, %arg1: tens
    : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<7xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
  return %0 : tensor<1x4x4x8xf32>
}

+// -----
+
+func.func @test_fft2d_same_operands_and_result_element_type(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf16>, tensor<1x4x8xf16>) {
+  // expected-error@+1 {{'tosa.fft2d' op requires the same element type for all operands and results}}
+  %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf16>, tensor<1x4x8xf16>)
+  return %0, %1 : tensor<1x4x8xf16>, tensor<1x4x8xf16>
+}
+
+// -----
+
+func.func @test_fft2d_same_operands_and_result_shape(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x7xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
+  // expected-error@+1 {{'tosa.fft2d' op requires the same shape for all operands and results}}
+  %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x7xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
+  return %0, %1 : tensor<1x4x8xf32>, tensor<1x4x8xf32>
+}
+
+// -----
+
+func.func @test_fft2d_invalid_type(%arg0: tensor<1x4x8xi8>, %arg1: tensor<1x4x8xi8>) -> (tensor<1x4x8xi8>, tensor<1x4x8xi8>) {
+  // expected-error@+1 {{'tosa.fft2d' op requires a floating point type}}
+  %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xi8>, tensor<1x4x8xi8>) -> (tensor<1x4x8xi8>, tensor<1x4x8xi8>)
+  return %0, %1 : tensor<1x4x8xi8>, tensor<1x4x8xi8>
+}
+
+// -----
+
+func.func @test_fft2d_height_non_power_of_two(%arg0: tensor<1x5x8xf32>, %arg1: tensor<1x5x8xf32>) -> (tensor<1x5x8xf32>, tensor<1x5x8xf32>) {
+  // expected-error@+1 {{'tosa.fft2d' op expected height to be a power of two, got 5}}
+  %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x5x8xf32>, tensor<1x5x8xf32>) -> (tensor<1x5x8xf32>, tensor<1x5x8xf32>)
+  return %0, %1 : tensor<1x5x8xf32>, tensor<1x5x8xf32>
+}
+
+// -----
+
+func.func @test_rfft2d_same_operands_and_result_element_type(%arg0: tensor<1x4x8xf32>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op requires the same element type for all operands and results}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf32>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>)
+  return %0, %1 : tensor<1x4x5xf16>, tensor<1x4x5xf16>
+}
+
+// -----
+
+func.func @test_rfft2d_same_results_shape(%arg0: tensor<1x4x8xf32>) -> (tensor<1x4x6xf32>, tensor<1x4x5xf32>) {
+  // expected-error@+1 {{'tosa.rfft2d' op expected output shapes to match, got 'tensor<1x4x6xf32>', 'tensor<1x4x5xf32>'}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf32>) -> (tensor<1x4x6xf32>, tensor<1x4x5xf32>)
+  return %0, %1 : tensor<1x4x6xf32>, tensor<1x4x5xf32>
+}
+
+// -----
+
+func.func @test_rfft2d_invalid_type(%arg0: tensor<1x4x8xi16>) -> (tensor<1x4x5xi16>, tensor<1x4x5xi16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op requires a floating point type}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xi16>) -> (tensor<1x4x5xi16>, tensor<1x4x5xi16>)
+  return %0, %1 : tensor<1x4x5xi16>, tensor<1x4x5xi16>
+}
+
+// -----
+
+func.func @test_rfft2d_width_power_of_two(%arg0: tensor<1x4x9xf16>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op expected width to be a power of two, got 9}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x9xf16>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>)
+  return %0, %1 : tensor<1x4x5xf16>, tensor<1x4x5xf16>
+}
+
+// -----
+
+func.func @test_rfft2d_batch_input_output_match(%arg0: tensor<1x4x8xf16>) -> (tensor<2x4x5xf16>, tensor<2x4x5xf16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op expected batch and height dimensions of input/output to match, got input='tensor<1x4x8xf16>' output='tensor<2x4x5xf16>'}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<2x4x5xf16>, tensor<2x4x5xf16>)
+  return %0, %1 : tensor<2x4x5xf16>, tensor<2x4x5xf16>
+}
+
+// -----
+
+func.func @test_rfft2d_width_input_output_match(%arg0: tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op expected output width to be equal to input_width / 2 + 1, got 3}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>)
+  return %0, %1 : tensor<1x4x3xf16>, tensor<1x4x3xf16>
+}
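
All of the new negative tests use static shapes. Since both verifiers skip dynamic dimensions, fully dynamic IR still verifies; a sketch (mine, not a test from the patch):

func.func @fft2d_dynamic_dims_sketch(%real: tensor<?x?x?xf32>, %imag: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
  // Every dimension is dynamic, so the power-of-two checks do not fire.
  %0, %1 = tosa.fft2d %real, %imag {inverse = false} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
  return %0, %1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
}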

mlir/test/Dialect/Tosa/level_check.mlir (36 changes: 18 additions & 18 deletions)

@@ -477,38 +477,38 @@ func.func @test_depthwise_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: t

// -----

-func.func @test_fft2d_real_h(%arg0: tensor<32x8193x32xf32>, %arg1: tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_real_h(%arg0: tensor<32x16384x32xf32>, %arg1: tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) {
  // expected-error@+1 {{'tosa.fft2d' op failed level check: H <= MAX_KERNEL}}
  %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
-           (tensor<32x8193x32xf32>, tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
-  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+           (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>)
+  return %0, %1 : tensor<32x16384x32xf32>, tensor<32x16384x32xf32>
}

// -----

-func.func @test_fft2d_real_w(%arg0: tensor<32x32x8193xf32>, %arg1: tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_real_w(%arg0: tensor<32x32x16384xf32>, %arg1: tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) {
  // expected-error@+1 {{'tosa.fft2d' op failed level check: W <= MAX_KERNEL}}
  %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
-           (tensor<32x32x8193xf32>, tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
-  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+           (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>)
+  return %0, %1 : tensor<32x32x16384xf32>, tensor<32x32x16384xf32>
}

// -----

-func.func @test_fft2d_imag_h(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x8193x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_imag_h(%arg0: tensor<32x16384x32xf32>, %arg1: tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) {
  // expected-error@+1 {{'tosa.fft2d' op failed level check: H <= MAX_KERNEL}}
  %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
-           (tensor<32x32x32xf32>, tensor<32x8193x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
-  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+           (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>)
+  return %0, %1 : tensor<32x16384x32xf32>, tensor<32x16384x32xf32>
}

// -----

-func.func @test_fft2d_imag_w(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x32x8193xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_imag_w(%arg0: tensor<32x32x16384xf32>, %arg1: tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) {
  // expected-error@+1 {{'tosa.fft2d' op failed level check: W <= MAX_KERNEL}}
  %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
-           (tensor<32x32x32xf32>, tensor<32x32x8193xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
-  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+           (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>)
+  return %0, %1 : tensor<32x32x16384xf32>, tensor<32x32x16384xf32>
}

// -----

@@ -577,18 +577,18 @@ func.func @test_maxpool2d_pad_right(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32

// -----

-func.func @test_rfft2d_input_h(%arg0: tensor<13x8193x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {
+func.func @test_rfft2d_input_h(%arg0: tensor<13x16384x16xf32>) -> (tensor<13x16384x9xf32>, tensor<13x16384x9xf32>) {
  // expected-error@+1 {{'tosa.rfft2d' op failed level check: H <= MAX_KERNEL}}
-  %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x8193x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>)
-  return %0, %1 : tensor<13x8x9xf32>, tensor<13x8x9xf32>
+  %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x16384x16xf32>) -> (tensor<13x16384x9xf32>, tensor<13x16384x9xf32>)
+  return %0, %1 : tensor<13x16384x9xf32>, tensor<13x16384x9xf32>
}

// -----

-func.func @test_rfft2d_input_w(%arg0: tensor<13x8x8193xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {
+func.func @test_rfft2d_input_w(%arg0: tensor<13x8x16384xf32>) -> (tensor<13x8x8193xf32>, tensor<13x8x8193xf32>) {
  // expected-error@+1 {{'tosa.rfft2d' op failed level check: W <= MAX_KERNEL}}
-  %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x8x8193xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>)
-  return %0, %1 : tensor<13x8x9xf32>, tensor<13x8x9xf32>
+  %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x8x16384xf32>) -> (tensor<13x8x8193xf32>, tensor<13x8x8193xf32>)
+  return %0, %1 : tensor<13x8x8193xf32>, tensor<13x8x8193xf32>
}

// -----