diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir index 40a8b7e5e0737..3a5041fca53fc 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir @@ -131,25 +131,6 @@ func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices( // ----- -func.func @transfer_read_dims_mismatch_non_contiguous( - %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> { - - %c0 = arith.constant 0 : index - %cst = arith.constant 0 : i8 - %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : - memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8> - return %v : vector<2x1x2x2xi8> -} - -// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous -// CHECK-NOT: memref.collapse_shape -// CHECK-NOT: vector.shape_cast - -// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous( -// CHECK-128B-NOT: memref.collapse_shape - -// ----- - // The input memref has a dynamic trailing shape and hence is not flattened. // TODO: This case could be supported via memref.dim @@ -214,6 +195,28 @@ func.func @transfer_read_0d( // ----- +// Strides make the input memref non-contiguous, hence non-flattenable. + +func.func @transfer_read_non_contiguous_src( + %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> { + + %c0 = arith.constant 0 : index + %cst = arith.constant 0 : i8 + %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : + memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8> + return %v : vector<5x4x3x2xi8> +} + +// CHECK-LABEL: func.func @transfer_read_non_contiguous_src +// CHECK-NOT: memref.collapse_shape +// CHECK-NOT: vector.shape_cast + +// CHECK-128B-LABEL: func @transfer_read_non_contiguous_src +// CHECK-128B-NOT: memref.collapse_shape +// CHECK-128B-NOT: vector.shape_cast + +// ----- + ///---------------------------------------------------------------------------------------- /// vector.transfer_write /// [Pattern: FlattenContiguousRowMajorTransferWritePattern] @@ -342,25 +345,6 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices( // ----- -func.func @transfer_write_dims_mismatch_non_contiguous( - %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, - %vec : vector<2x1x2x2xi8>) { - - %c0 = arith.constant 0 : index - vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : - vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> - return -} - -// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous -// CHECK-NOT: memref.collapse_shape -// CHECK-NOT: vector.shape_cast - -// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous( -// CHECK-128B-NOT: memref.collapse_shape - -// ----- - // The input memref has a dynamic trailing shape and hence is not flattened. // TODO: This case could be supported via memref.dim @@ -427,6 +411,28 @@ func.func @transfer_write_0d( // ----- +// The strides make the input memref non-contiguous, hence non-flattenable. + +func.func @transfer_write_non_contiguous_src( + %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, + %vec : vector<5x4x3x2xi8>) { + + %c0 = arith.constant 0 : index + vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] : + vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>> + return +} + +// CHECK-LABEL: func.func @transfer_write_non_contiguous_src +// CHECK-NOT: memref.collapse_shape +// CHECK-NOT: vector.shape_cast + +// CHECK-128B-LABEL: func @transfer_write_non_contiguous_src +// CHECK-128B-NOT: memref.collapse_shape +// CHECK-128B-NOT: vector.shape_cast + +// ----- + ///---------------------------------------------------------------------------------------- /// TODO: Categorize + re-format ///---------------------------------------------------------------------------------------- @@ -478,40 +484,6 @@ func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vecto // ----- -func.func @transfer_read_flattenable_negative( - %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x2x2x2xi8> { - %c0 = arith.constant 0 : index - %cst = arith.constant 0 : i8 - %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : - memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x2x2x2xi8> - return %v : vector<2x2x2x2xi8> -} - -// CHECK-LABEL: func @transfer_read_flattenable_negative -// CHECK: vector.transfer_read {{.*}} vector<2x2x2x2xi8> - -// CHECK-128B-LABEL: func @transfer_read_flattenable_negative( -// CHECK-128B-NOT: memref.collapse_shape - -// ----- - -func.func @transfer_read_flattenable_negative2( - %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> { - %c0 = arith.constant 0 : index - %cst = arith.constant 0 : i8 - %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : - memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8> - return %v : vector<5x4x3x2xi8> -} - -// CHECK-LABEL: func @transfer_read_flattenable_negative2 -// CHECK: vector.transfer_read {{.*}} vector<5x4x3x2xi8> - -// CHECK-128B-LABEL: func @transfer_read_flattenable_negative2( -// CHECK-128B-NOT: memref.collapse_shape - -// ----- - func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> { %add = arith.addi %arg0, %arg0 : vector<1x8xi32> return %add : vector<1x8xi32>