diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir index 303f841e8a828..621baef82319f 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir @@ -110,12 +110,12 @@ func.func @transfer_read_dims_mismatch_non_zero_indices( func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices( %arg : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, - %idx0 : index, - %idx1 : index) -> vector<2x2xf32> { + %idx_1 : index, + %idx_2 : index) -> vector<2x2xf32> { %c0 = arith.constant 0 : index %cst_1 = arith.constant 0.000000e+00 : f32 - %8 = vector.transfer_read %arg[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} : + %8 = vector.transfer_read %arg[%c0, %idx_1, %idx_2, %c0], %cst_1 {in_bounds = [true, true]} : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32> return %8 : vector<2x2xf32> } @@ -123,7 +123,8 @@ func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices( // CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)> // CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices( -// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>> +// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] +// CHECK-SAME: : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>> // CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]() // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices( @@ -131,10 +132,42 @@ func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices( // ----- -// The input memref has a dynamic trailing shape and hence is not flattened. -// TODO: This case could be supported via memref.dim +// The leading dynamic shapes don't affect whether this example is flattenable +// or not. Indeed, those dynamic shapes are not candidates for flattening anyway. -func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes( +func.func @transfer_read_leading_dynamic_dims( + %arg : memref>, + %idx_1 : index, + %idx_2 : index) -> vector<8x4xi8> { + + %c0_i8 = arith.constant 0 : i8 + %c0 = arith.constant 0 : index + %result = vector.transfer_read %arg[%idx_1, %idx_2, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : + memref>, vector<8x4xi8> + return %result : vector<8x4xi8> +} + +// CHECK-LABEL: func @transfer_read_leading_dynamic_dims +// CHECK-SAME: %[[ARG0:.+]]: memref, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index +// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8 +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}} +// CHECK-SAME: : memref into memref +// CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]] +// CHECK-SAME: [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]] +// CHECK-SAME: {in_bounds = [true]} +// CHECK-SAME: : memref, vector<32xi8> +// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8> +// CHECK: return %[[VEC2D]] : vector<8x4xi8> + +// CHECK-128B-LABEL: func @transfer_read_leading_dynamic_dims +// CHECK-128B: memref.collapse_shape + +// ----- + +// One of the dims to be flattened is dynamic - not supported ATM. + +func.func @negative_transfer_read_dynamic_dim_to_flatten( %idx_1: index, %idx_2: index, %m_in: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> { @@ -146,11 +179,11 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes( return %v : vector<1x2x6xi32> } -// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes( +// CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten // CHECK-NOT: memref.collapse_shape // CHECK-NOT: vector.shape_cast -// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes( +// CHECK-128B-LABEL: func @negative_transfer_read_dynamic_dim_to_flatten // CHECK-128B-NOT: memref.collapse_shape // ----- @@ -326,11 +359,11 @@ func.func @transfer_write_dims_mismatch_non_zero_indices( func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices( %value : vector<2x2xf32>, %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, - %idx0 : index, - %idx1 : index) { + %idx_1 : index, + %idx_2 : index) { %c0 = arith.constant 0 : index - vector.transfer_write %value, %subview[%c0, %idx0, %idx1, %c0] {in_bounds = [true, true]} : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> + vector.transfer_write %value, %subview[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true]} : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> return } @@ -345,10 +378,40 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices( // ----- -// The input memref has a dynamic trailing shape and hence is not flattened. -// TODO: This case could be supported via memref.dim +// The leading dynamic shapes don't affect whether this example is flattenable +// or not. Indeed, those dynamic shapes are not candidates for flattening anyway. + +func.func @transfer_write_leading_dynamic_dims( + %vec : vector<8x4xi8>, + %arg : memref>, + %idx_1 : index, + %idx_2 : index) { + + %c0 = arith.constant 0 : index + vector.transfer_write %vec, %arg[%idx_1, %idx_2, %c0, %c0] {in_bounds = [true, true]} : + vector<8x4xi8>, memref> + return +} + +// CHECK-LABEL: func @transfer_write_leading_dynamic_dims +// CHECK-SAME: %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index +// CHECK: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}} +// CHECK-SAME: : memref into memref +// CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8> +// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]] +// CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]] +// CHECK-SAME: {in_bounds = [true]} +// CHECK-SAME: : vector<32xi8>, memref + +// CHECK-128B-LABEL: func @transfer_write_leading_dynamic_dims +// CHECK-128B: memref.collapse_shape -func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes( +// ----- + +// One of the dims to be flattened is dynamic - not supported ATM. + +func.func @negative_transfer_write_dynamic_to_flatten( %idx_1: index, %idx_2: index, %vec : vector<1x2x6xi32>, @@ -361,11 +424,11 @@ func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes( return } -// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes( +// CHECK-LABEL: func.func @negative_transfer_write_dynamic_to_flatten // CHECK-NOT: memref.collapse_shape // CHECK-NOT: vector.shape_cast -// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes( +// CHECK-128B-LABEL: func @negative_transfer_write_dynamic_to_flatten // CHECK-128B-NOT: memref.collapse_shape // ----- @@ -434,56 +497,10 @@ func.func @transfer_write_non_contiguous_src( // ----- ///---------------------------------------------------------------------------------------- -/// TODO: Categorize + re-format +/// [Pattern: DropUnitDimFromElementwiseOps] +/// TODO: Move to a dedicated file - there's no "flattening" in the following tests ///---------------------------------------------------------------------------------------- -func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memref>, %arg1 : index, %arg2 : index) -> vector<8x4xi8> { - %c0_i8 = arith.constant 0 : i8 - %c0 = arith.constant 0 : index - %result = vector.transfer_read %arg0[%arg1, %arg2, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref>, vector<8x4xi8> - return %result : vector<8x4xi8> -} - -// CHECK-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices -// CHECK-SAME: %[[ARG0:.+]]: memref, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index -// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8 -// CHECK: %[[C0:.+]] = arith.constant 0 : index -// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}} -// CHECK-SAME: : memref into memref -// CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]] -// CHECK-SAME: [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]] -// CHECK-SAME: {in_bounds = [true]} -// CHECK-SAME: : memref, vector<32xi8> -// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8> -// CHECK: return %[[VEC2D]] : vector<8x4xi8> - -// CHECK-128B-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices( -// CHECK-128B: memref.collapse_shape - -// ----- - -func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vector<8x4xi8>, %dst : memref>, %arg1 : index, %arg2 : index) { - %c0 = arith.constant 0 : index - vector.transfer_write %vec, %dst[%arg1, %arg2, %c0, %c0] {in_bounds = [true, true]} : vector<8x4xi8>, memref> - return -} - -// CHECK-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices -// CHECK-SAME: %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index -// CHECK: %[[C0:.+]] = arith.constant 0 : index -// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}} -// CHECK-SAME: : memref into memref -// CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8> -// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]] -// CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]] -// CHECK-SAME: {in_bounds = [true]} -// CHECK-SAME: : vector<32xi8>, memref - -// CHECK-128B-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices( -// CHECK-128B: memref.collapse_shape - -// ----- - func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> { %add = arith.addi %arg0, %arg0 : vector<1x8xi32> return %add : vector<1x8xi32>