
Commit 079557c

[mlir][tosa] Add verifiers for FFT2d and RFFT2d (#129273)
Adds checks for element types and input/output shapes.

Signed-off-by: Luke Hutton <[email protected]>
1 parent ab30df4 commit 079557c
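
For orientation before the per-file diffs, here is a minimal standalone sketch, in plain C++ with hypothetical names and no MLIR dependencies, of the power-of-two rule both new verifiers apply to static height and width dimensions (the RFFT2D output-width rule is sketched after the TosaOps.cpp diff below):

#include <cassert>
#include <cstdint>

// A static FFT2D/RFFT2D spatial dimension passes when it is positive and has
// exactly one set bit; otherwise the verifier reports, for example,
// "expected width to be a power of two, got 9". Dynamic dimensions are skipped.
static bool isPowerOfTwo(int64_t dimSize) {
  return dimSize > 0 && (dimSize & (dimSize - 1)) == 0;
}

int main() {
  assert(isPowerOfTwo(4) && isPowerOfTwo(8) && isPowerOfTwo(16384));
  assert(!isPowerOfTwo(0) && !isPowerOfTwo(5) && !isPowerOfTwo(9));
  return 0;
}

The element-type and shape agreement requirements come from the ODS traits added in TosaOps.td, while the dimension checks live in the hand-written verifiers enabled by hasVerifier = 1.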

6 files changed: +200 -42 lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 11 additions & 2 deletions
@@ -240,7 +240,10 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
 //===----------------------------------------------------------------------===//
 // Operator: fft2d
 //===----------------------------------------------------------------------===//
-def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
+def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d", [
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultShape,
+    ResultsAreFloatLike]> {
   let summary = "Performs FFT2D operation on the input.";
 
   let description = [{
@@ -279,6 +282,8 @@ def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
     $input_real `,` $input_imag attr-dict `:` `(` type($input_real) `,`
     type($input_imag) `)` `->` `(` type($output_real) `,` type($output_imag) `)`
   }];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -349,7 +354,9 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
 //===----------------------------------------------------------------------===//
 // Operator: rfft2d
 //===----------------------------------------------------------------------===//
-def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> {
+def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d", [
+    SameOperandsAndResultElementType,
+    ResultsAreFloatLike]> {
   let summary = "Performs RFFT2D operation on the input.";
 
   let description = [{
@@ -385,6 +392,8 @@ def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> {
   let assemblyFormat = [{
     $input attr-dict `:` `(` type($input) `)` `->` `(` type($output_real) `,` type($output_imag) `)`
   }];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 79 additions & 1 deletion
@@ -789,7 +789,7 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
   int64_t inWidth = inputShape.getDimSize(2);
 
   // Note that we can support this calculation symbolically
-  // in the future e.g. [x, y, z] -> [x, y, z / 2 - 1]
+  // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
   if (inWidth != ShapedType::kDynamic)
     outputShape[2] = inWidth / 2 + 1;
 
@@ -799,6 +799,57 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
   return success();
 }
 
+static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
+                                           const llvm::StringRef dimName) {
+  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
+  if (!isPowerOfTwo)
+    return op->emitOpError("expected ")
+           << dimName << " to be a power of two, got " << dimSize;
+
+  return success();
+}
+
+LogicalResult tosa::RFFT2dOp::verify() {
+  const auto outputTypes = getResultTypes();
+  if (failed(verifyCompatibleShapes(outputTypes)))
+    return emitOpError("expected output shapes to match, got ") << outputTypes;
+
+  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
+  if (!inputType)
+    return success();
+
+  const int64_t height = inputType.getDimSize(1);
+  if (!ShapedType::isDynamic(height) &&
+      failed(verifyDimIsPowerOfTwo(*this, height, "height")))
+    return failure();
+
+  const int64_t width = inputType.getDimSize(2);
+  if (!ShapedType::isDynamic(width) &&
+      failed(verifyDimIsPowerOfTwo(*this, width, "width")))
+    return failure();
+
+  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
+  if (!outputType)
+    return success();
+
+  // Batch and height input/output dimensions should match
+  if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
+                                   outputType.getShape().drop_back())))
+    return emitOpError("expected batch and height dimensions of input/output "
+                       "to match, got input=")
+           << inputType << " output=" << outputType;
+
+  // Output width dimension expected to be input_width / 2 + 1
+  const int64_t outputWidth = outputType.getDimSize(2);
+  if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&
+      (outputWidth - 1) * 2 != width)
+    return emitOpError(
+               "expected output width to be equal to input_width / 2 + 1, got ")
+           << outputWidth;
+
+  return success();
+}
+
 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     FFT2dOp::Adaptor adaptor,
@@ -810,6 +861,33 @@ LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult tosa::FFT2dOp::verify() {
+  const auto inputRealType =
+      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
+  const auto inputImagType =
+      llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
+  if (!inputRealType || !inputImagType)
+    return success();
+
+  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
+    return ShapedType::isDynamic(a) ? a : b;
+  };
+
+  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
+                                            inputImagType.getDimSize(1));
+  if (!ShapedType::isDynamic(height) &&
+      failed(verifyDimIsPowerOfTwo(*this, height, "height")))
+    return failure();
+
+  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
+                                           inputImagType.getDimSize(2));
+  if (!ShapedType::isDynamic(width) &&
+      failed(verifyDimIsPowerOfTwo(*this, width, "width")))
+    return failure();
+
+  return success();
+}
+
 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ConcatOp::Adaptor adaptor,
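
As a companion to RFFT2dOp::verify and the corrected shape-inference comment above, here is a small sketch of the static input/output shape relation being checked; the helper name is hypothetical and only the fully static rank-3 case is covered (the real verifier skips dynamic dimensions):

#include <array>
#include <cassert>
#include <cstdint>

// RFFT2D keeps the batch and height dimensions and returns only the
// non-redundant half of the spectrum along the width, so an input of shape
// [N, H, W] produces two outputs (real and imaginary) of shape [N, H, W/2 + 1].
static std::array<int64_t, 3>
expectedRfft2dOutputShape(const std::array<int64_t, 3> &input) {
  return {input[0], input[1], input[2] / 2 + 1};
}

int main() {
  // Matches the updated @test_static_rfft2d lowering test below:
  // tensor<5x4x8xf32> lowers to two tensor<5x4x5xf32> results.
  const auto out = expectedRfft2dOutputShape({5, 4, 8});
  assert(out[0] == 5 && out[1] == 4 && out[2] == 5);
  return 0;
}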

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Lines changed: 0 additions & 9 deletions
@@ -31,15 +31,6 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
 
 // -----
 
-// CHECK-LABEL: @rfft2d_with_non_float_type
-func.func @rfft2d_with_non_float_type(%arg0 : tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>) {
-  // expected-error@+1 {{failed to legalize operation 'tosa.rfft2d'}}
-  %real, %imag = tosa.rfft2d %arg0 : (tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>)
-  return %real, %imag : tensor<1x1x1xi32>, tensor<1x1x1xi32>
-}
-
-// -----
-
 // CHECK-LABEL: @rescale_unsupported_type
 func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
   // expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 12 additions & 12 deletions
@@ -1720,28 +1720,28 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
 // CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 
 // CHECK-LABEL: func.func @test_static_rfft2d(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
 // CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
 // CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
 // CHECK: %[[VAL_3:.*]] = arith.constant 8 : index
 // CHECK: %[[VAL_4:.*]] = arith.constant 4 : index
 // CHECK: %[[VAL_5:.*]] = arith.constant 5 : index
-// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<5x4x5xf32>
 // CHECK: %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_6]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
-// CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_6]] : tensor<5x4x5xf32>) -> tensor<5x4x5xf32>
+// CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<5x4x5xf32>
 // CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_10]] : f32) outs(%[[VAL_9]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
+// CHECK: %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_10]] : f32) outs(%[[VAL_9]] : tensor<5x4x5xf32>) -> tensor<5x4x5xf32>
 // CHECK: %[[VAL_12:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_13:.*]] = arith.constant 5 : index
+// CHECK: %[[VAL_13:.*]] = arith.constant 4 : index
 // CHECK: %[[VAL_14:.*]] = arith.constant 2 : index
 // CHECK: %[[VAL_15:.*]] = arith.constant 8 : index
 // CHECK: %[[VAL_16:.*]] = arith.constant 6.28318548 : f32
 // CHECK: %[[VAL_17:.*]] = arith.index_castui %[[VAL_13]] : index to i32
 // CHECK: %[[VAL_18:.*]] = arith.uitofp %[[VAL_17]] : i32 to f32
 // CHECK: %[[VAL_19:.*]] = arith.index_castui %[[VAL_15]] : index to i32
 // CHECK: %[[VAL_20:.*]] = arith.uitofp %[[VAL_19]] : i32 to f32
-// CHECK: %[[VAL_21:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<5x5x8xf32>) outs(%[[VAL_8]], %[[VAL_11]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK: %[[VAL_21:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<5x4x8xf32>) outs(%[[VAL_8]], %[[VAL_11]] : tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
 // CHECK: ^bb0(%[[VAL_22:.*]]: f32, %[[VAL_23:.*]]: f32, %[[VAL_24:.*]]: f32):
 // CHECK: %[[VAL_25:.*]] = linalg.index 1 : index
 // CHECK: %[[VAL_26:.*]] = linalg.index 2 : index
@@ -1766,12 +1766,12 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
 // CHECK: %[[VAL_45:.*]] = arith.addf %[[VAL_23]], %[[VAL_43]] : f32
 // CHECK: %[[VAL_46:.*]] = arith.subf %[[VAL_24]], %[[VAL_44]] : f32
 // CHECK: linalg.yield %[[VAL_45]], %[[VAL_46]] : f32, f32
-// CHECK: } -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
-// CHECK: return %[[VAL_47:.*]]#0, %[[VAL_47]]#1 : tensor<5x5x5xf32>, tensor<5x5x5xf32>
+// CHECK: } -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>)
+// CHECK: return %[[VAL_47:.*]]#0, %[[VAL_47]]#1 : tensor<5x4x5xf32>, tensor<5x4x5xf32>
 // CHECK: }
-func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
-  %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
-  return %output_real, %output_imag : tensor<5x5x5xf32>, tensor<5x5x5xf32>
+func.func @test_static_rfft2d(%arg0: tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
+  %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>)
+  return %output_real, %output_imag : tensor<5x4x5xf32>, tensor<5x4x5xf32>
 }
 
 // -----

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 80 additions & 0 deletions
@@ -1268,3 +1268,83 @@ func.func @test_conv2d_invalid_bias_size(%arg0: tensor<1x4x4x4xf32>, %arg1: tens
     : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<7xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
   return %0 : tensor<1x4x4x8xf32>
 }
+
+// -----
+
+func.func @test_fft2d_same_operands_and_result_element_type(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf16>, tensor<1x4x8xf16>) {
+  // expected-error@+1 {{'tosa.fft2d' op requires the same element type for all operands and results}}
+  %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf16>, tensor<1x4x8xf16>)
+  return %0, %1 : tensor<1x4x8xf16>, tensor<1x4x8xf16>
+}
+
+// -----
+
+func.func @test_fft2d_same_operands_and_result_shape(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x7xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
+  // expected-error@+1 {{'tosa.fft2d' op requires the same shape for all operands and results}}
+  %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x7xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
+  return %0, %1 : tensor<1x4x8xf32>, tensor<1x4x8xf32>
+}
+
+// -----
+
+func.func @test_fft2d_invalid_type(%arg0: tensor<1x4x8xi8>, %arg1: tensor<1x4x8xi8>) -> (tensor<1x4x8xi8>, tensor<1x4x8xi8>) {
+  // expected-error@+1 {{'tosa.fft2d' op requires a floating point type}}
+  %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xi8>, tensor<1x4x8xi8>) -> (tensor<1x4x8xi8>, tensor<1x4x8xi8>)
+  return %0, %1 : tensor<1x4x8xi8>, tensor<1x4x8xi8>
+}
+
+// -----
+
+func.func @test_fft2d_height_non_power_of_two(%arg0: tensor<1x5x8xf32>, %arg1: tensor<1x5x8xf32>) -> (tensor<1x5x8xf32>, tensor<1x5x8xf32>) {
+  // expected-error@+1 {{'tosa.fft2d' op expected height to be a power of two, got 5}}
+  %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x5x8xf32>, tensor<1x5x8xf32>) -> (tensor<1x5x8xf32>, tensor<1x5x8xf32>)
+  return %0, %1 : tensor<1x5x8xf32>, tensor<1x5x8xf32>
+}
+
+// -----
+
+func.func @test_rfft2d_same_operands_and_result_element_type(%arg0: tensor<1x4x8xf32>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op requires the same element type for all operands and results}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf32>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>)
+  return %0, %1 : tensor<1x4x5xf16>, tensor<1x4x5xf16>
+}
+
+// -----
+
+func.func @test_rfft2d_same_results_shape(%arg0: tensor<1x4x8xf32>) -> (tensor<1x4x6xf32>, tensor<1x4x5xf32>) {
+  // expected-error@+1 {{'tosa.rfft2d' op expected output shapes to match, got 'tensor<1x4x6xf32>', 'tensor<1x4x5xf32>'}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf32>) -> (tensor<1x4x6xf32>, tensor<1x4x5xf32>)
+  return %0, %1 : tensor<1x4x6xf32>, tensor<1x4x5xf32>
+}
+
+// -----
+
+func.func @test_rfft2d_invalid_type(%arg0: tensor<1x4x8xi16>) -> (tensor<1x4x5xi16>, tensor<1x4x5xi16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op requires a floating point type}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xi16>) -> (tensor<1x4x5xi16>, tensor<1x4x5xi16>)
+  return %0, %1 : tensor<1x4x5xi16>, tensor<1x4x5xi16>
+}
+
+// -----
+
+func.func @test_rfft2d_width_power_of_two(%arg0: tensor<1x4x9xf16>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op expected width to be a power of two, got 9}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x9xf16>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>)
+  return %0, %1 : tensor<1x4x5xf16>, tensor<1x4x5xf16>
+}
+
+// -----
+
+func.func @test_rfft2d_batch_input_output_match(%arg0: tensor<1x4x8xf16>) -> (tensor<2x4x5xf16>, tensor<2x4x5xf16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op expected batch and height dimensions of input/output to match, got input='tensor<1x4x8xf16>' output='tensor<2x4x5xf16>'}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<2x4x5xf16>, tensor<2x4x5xf16>)
+  return %0, %1 : tensor<2x4x5xf16>, tensor<2x4x5xf16>
+}
+
+// -----
+
+func.func @test_rfft2d_width_input_output_match(%arg0: tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op expected output width to be equal to input_width / 2 + 1, got 3}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>)
+  return %0, %1 : tensor<1x4x3xf16>, tensor<1x4x3xf16>
+}

mlir/test/Dialect/Tosa/level_check.mlir

Lines changed: 18 additions & 18 deletions
@@ -477,38 +477,38 @@ func.func @test_depthwise_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: t
 
 // -----
 
-func.func @test_fft2d_real_h(%arg0: tensor<32x8193x32xf32>, %arg1: tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_real_h(%arg0: tensor<32x16384x32xf32>, %arg1: tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) {
   // expected-error@+1 {{'tosa.fft2d' op failed level check: H <= MAX_KERNEL}}
   %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
-            (tensor<32x8193x32xf32>, tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
-  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+            (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>)
+  return %0, %1 : tensor<32x16384x32xf32>, tensor<32x16384x32xf32>
 }
 
 // -----
 
-func.func @test_fft2d_real_w(%arg0: tensor<32x32x8193xf32>, %arg1: tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_real_w(%arg0: tensor<32x32x16384xf32>, %arg1: tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) {
   // expected-error@+1 {{'tosa.fft2d' op failed level check: W <= MAX_KERNEL}}
   %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
-            (tensor<32x32x8193xf32>, tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
-  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+            (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>)
+  return %0, %1 : tensor<32x32x16384xf32>, tensor<32x32x16384xf32>
 }
 
 // -----
 
-func.func @test_fft2d_imag_h(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x8193x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_imag_h(%arg0: tensor<32x16384x32xf32>, %arg1: tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) {
   // expected-error@+1 {{'tosa.fft2d' op failed level check: H <= MAX_KERNEL}}
   %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
-            (tensor<32x32x32xf32>, tensor<32x8193x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
-  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+            (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>)
+  return %0, %1 : tensor<32x16384x32xf32>, tensor<32x16384x32xf32>
 }
 
 // -----
 
-func.func @test_fft2d_imag_w(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x32x8193xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_imag_w(%arg0: tensor<32x32x16384xf32>, %arg1: tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) {
   // expected-error@+1 {{'tosa.fft2d' op failed level check: W <= MAX_KERNEL}}
   %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
-            (tensor<32x32x32xf32>, tensor<32x32x8193xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
-  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+            (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>)
+  return %0, %1 : tensor<32x32x16384xf32>, tensor<32x32x16384xf32>
 }
 
 // -----
@@ -577,18 +577,18 @@ func.func @test_maxpool2d_pad_right(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32
 
 // -----
 
-func.func @test_rfft2d_input_h(%arg0: tensor<13x8193x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {
+func.func @test_rfft2d_input_h(%arg0: tensor<13x16384x16xf32>) -> (tensor<13x16384x9xf32>, tensor<13x16384x9xf32>) {
   // expected-error@+1 {{'tosa.rfft2d' op failed level check: H <= MAX_KERNEL}}
-  %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x8193x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>)
-  return %0, %1 : tensor<13x8x9xf32>, tensor<13x8x9xf32>
+  %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x16384x16xf32>) -> (tensor<13x16384x9xf32>, tensor<13x16384x9xf32>)
+  return %0, %1 : tensor<13x16384x9xf32>, tensor<13x16384x9xf32>
 }
 
 // -----
 
-func.func @test_rfft2d_input_w(%arg0: tensor<13x8x8193xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {
+func.func @test_rfft2d_input_w(%arg0: tensor<13x8x16384xf32>) -> (tensor<13x8x8193xf32>, tensor<13x8x8193xf32>) {
   // expected-error@+1 {{'tosa.rfft2d' op failed level check: W <= MAX_KERNEL}}
-  %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x8x8193xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>)
-  return %0, %1 : tensor<13x8x9xf32>, tensor<13x8x9xf32>
+  %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x8x16384xf32>) -> (tensor<13x8x8193xf32>, tensor<13x8x8193xf32>)
+  return %0, %1 : tensor<13x8x8193xf32>, tensor<13x8x8193xf32>
}
 
 // -----
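
The dimension updates in these level-check tests follow from the new verifier: a size such as 8193 still exceeds MAX_KERNEL but is no longer a valid FFT dimension, so the op verifier would reject it before the level check runs and the expected diagnostic would never fire. The tests therefore use the next power of two above the limit (assuming MAX_KERNEL is 8192 at this level), as the small hypothetical helper below illustrates:

#include <cassert>
#include <cstdint>

// Smallest power of two strictly greater than `limit`; shown here only to
// explain why 16384 replaces 8193 in the oversized-dimension tests.
static int64_t nextPowerOfTwoAbove(int64_t limit) {
  int64_t p = 1;
  while (p <= limit)
    p <<= 1;
  return p;
}

int main() {
  assert(nextPowerOfTwoAbove(8192) == 16384);
  return 0;
}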
