diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 31c4775696b31..9d16aa46a9f2a 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -9,18 +9,18 @@
 ///----------------------------------------------------------------------------------------
 
 func.func @transfer_read_dims_match_contiguous(
-    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0 : i8
-  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
     memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
-  return %v : vector<5x4x3x2xi8>
+  return %res : vector<5x4x3x2xi8>
 }
 
 // CHECK-LABEL: func @transfer_read_dims_match_contiguous
-// CHECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
-// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
+// CHECK-SAME:    %[[MEM:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{.}}[0, 1, 2, 3]
 // CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
 // CHECK:         %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
 // CHECK:         return %[[VEC2D]]
 
@@ -31,18 +31,18 @@ func.func @transfer_read_dims_match_contiguous(
 // -----
 
 func.func @transfer_read_dims_match_contiguous_empty_stride(
-    %arg : memref<5x4x3x2xi8>) -> vector<5x4x3x2xi8> {
+    %mem : memref<5x4x3x2xi8>) -> vector<5x4x3x2xi8> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0 : i8
-  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
     memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
-  return %v : vector<5x4x3x2xi8>
+  return %res : vector<5x4x3x2xi8>
 }
 
 // CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
-// CHECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
-// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
+// CHECK-SAME:    %[[MEM:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{.}}[0, 1, 2, 3]
 // CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
 // CHECK:         %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
 // CHECK:         return %[[VEC2D]]
 
@@ -56,20 +56,20 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
 // contiguous subset of the memref, so "flattenable".
 
 func.func @transfer_read_dims_mismatch_contiguous(
-    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0 : i8
-  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
     memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
-  return %v : vector<1x1x2x2xi8>
+  return %res : vector<1x1x2x2xi8>
 }
 
 // CHECK-LABEL:   func.func @transfer_read_dims_mismatch_contiguous(
-// CHECK-SAME:      %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+// CHECK-SAME:      %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 0 : i8
 // CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
+// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
 // CHECK:           %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
 // CHECK:           %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
 // CHECK:           return %[[VAL_5]] : vector<1x1x2x2xi8>
 
@@ -82,23 +82,24 @@ func.func @transfer_read_dims_mismatch_contiguous(
 
 func.func @transfer_read_dims_mismatch_non_zero_indices(
     %idx_1: index,
     %idx_2: index,
-    %arg: memref<1x43x4x6xi32>) -> vector<1x2x6xi32>{
+    %mem: memref<1x43x4x6xi32>) -> vector<1x2x6xi32>{
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %v = vector.transfer_read %arg[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
-    memref<1x43x4x6xi32>, vector<1x2x6xi32>
-  return %v : vector<1x2x6xi32>
+  %res = vector.transfer_read %mem[%c0, %idx_1, %idx_2, %c0], %c0_i32 {
+    in_bounds = [true, true, true]
+  } : memref<1x43x4x6xi32>, vector<1x2x6xi32>
+  return %res : vector<1x2x6xi32>
 }
 
 // CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
 
 // CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
-// CHECK-SAME:      %[[M_IN:.*]]: memref<1x43x4x6xi32>
+// CHECK-SAME:      %[[MEM:.*]]: memref<1x43x4x6xi32>
 // CHECK:           %[[C_0:.*]] = arith.constant 0 : i32
 // CHECK:           %[[C_0_IDX:.*]] = arith.constant 0 : index
-// CHECK:           %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[M_IN]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK:           %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
 // CHECK:           %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
 // CHECK:           %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
 
@@ -111,15 +112,16 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
 // the output vector is to be read _is_ contiguous. Hence the flattening works fine.
 
 func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
-    %arg : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
+    %mem : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
     %idx_1 : index,
     %idx_2 : index) -> vector<2x2xf32> {
   %c0 = arith.constant 0 : index
   %cst_1 = arith.constant 0.000000e+00 : f32
-  %8 = vector.transfer_read %arg[%c0, %idx_1, %idx_2, %c0], %cst_1 {in_bounds = [true, true]} :
-    memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
-  return %8 : vector<2x2xf32>
+  %res = vector.transfer_read %mem[%c0, %idx_1, %idx_2, %c0], %cst_1 {
+    in_bounds = [true, true]
+  } : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
+  return %res : vector<2x2xf32>
 }
 
 // CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
 
@@ -138,29 +140,30 @@ func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
 // or not. Indeed, those dynamic shapes are not candidates for flattening anyway.
 
 func.func @transfer_read_leading_dynamic_dims(
-    %arg : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>,
+    %mem : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>,
     %idx_1 : index,
     %idx_2 : index) -> vector<8x4xi8> {
   %c0_i8 = arith.constant 0 : i8
   %c0 = arith.constant 0 : index
-  %result = vector.transfer_read %arg[%idx_1, %idx_2, %c0, %c0], %c0_i8 {in_bounds = [true, true]} :
-    memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, vector<8x4xi8>
-  return %result : vector<8x4xi8>
+  %res = vector.transfer_read %mem[%idx_1, %idx_2, %c0, %c0], %c0_i8 {
+    in_bounds = [true, true]
+  } : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, vector<8x4xi8>
+  return %res : vector<8x4xi8>
 }
 
 // CHECK-LABEL: func @transfer_read_leading_dynamic_dims
-// CHECK-SAME:    %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
+// CHECK-SAME:    %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index
 // CHECK:         %[[C0_I8:.+]] = arith.constant 0 : i8
 // CHECK:         %[[C0:.+]] = arith.constant 0 : index
-// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1], [2, 3]{{\]}}
 // CHECK-SAME:      : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
 // CHECK:         %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
-// CHECK-SAME:      [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
+// CHECK-SAME:      [%[[IDX_1]], %[[IDX_2]], %[[C0]]], %[[C0_I8]]
 // CHECK-SAME:      {in_bounds = [true]}
 // CHECK-SAME:      : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
-// CHECK:         %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
-// CHECK:         return %[[VEC2D]] : vector<8x4xi8>
+// CHECK:         %[[RES:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
+// CHECK:         return %[[RES]] : vector<8x4xi8>
 
 // CHECK-128B-LABEL: func @transfer_read_leading_dynamic_dims
 // CHECK-128B:       memref.collapse_shape
 
@@ -172,13 +175,14 @@ func.func @transfer_read_leading_dynamic_dims(
 
 func.func @negative_transfer_read_dynamic_dim_to_flatten(
     %idx_1: index,
     %idx_2: index,
-    %m_in: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
+    %mem: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %v = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
-    memref<1x?x4x6xi32>, vector<1x2x6xi32>
-  return %v : vector<1x2x6xi32>
+  %res = vector.transfer_read %mem[%c0, %idx_1, %idx_2, %c0], %c0_i32 {
+    in_bounds = [true, true, true]
+  } : memref<1x?x4x6xi32>, vector<1x2x6xi32>
+  return %res : vector<1x2x6xi32>
 }
 
 // CHECK-LABEL: func.func @negative_transfer_read_dynamic_dim_to_flatten
 
@@ -194,13 +198,13 @@ func.func @negative_transfer_read_dynamic_dim_to_flatten(
 // memref.
 
 func.func @transfer_read_dims_mismatch_non_contiguous_slice(
-    %arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {
+    %mem : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {
  %c0 = arith.constant 0 : index
   %cst = arith.constant 0 : i8
-  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
     memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
-  return %v : vector<2x1x2x2xi8>
+  return %res : vector<2x1x2x2xi8>
 }
 
 // CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_contiguous_slice(
@@ -213,11 +217,11 @@ func.func @transfer_read_dims_mismatch_non_contiguous_slice(
 // -----
 
 func.func @transfer_read_0d(
-    %arg : memref<i8>) -> vector<i8> {
+    %mem : memref<i8>) -> vector<i8> {
   %cst = arith.constant 0 : i8
-  %0 = vector.transfer_read %arg[], %cst : memref<i8>, vector<i8>
-  return %0 : vector<i8>
+  %res = vector.transfer_read %mem[], %cst : memref<i8>, vector<i8>
+  return %res : vector<i8>
 }
 
 // CHECK-LABEL: func.func @transfer_read_0d
@@ -233,13 +237,13 @@ func.func @transfer_read_0d(
 // Strides make the input memref non-contiguous, hence non-flattenable.
 
 func.func @transfer_read_non_contiguous_src(
-    %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+    %mem : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0 : i8
-  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
     memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
-  return %v : vector<5x4x3x2xi8>
+  return %res : vector<5x4x3x2xi8>
 }
 
 // CHECK-LABEL:   func.func @transfer_read_non_contiguous_src
@@ -258,19 +262,19 @@ func.func @transfer_read_non_contiguous_src(
 ///----------------------------------------------------------------------------------------
 
 func.func @transfer_write_dims_match_contiguous(
-    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
     %vec : vector<5x4x3x2xi8>) {
   %c0 = arith.constant 0 : index
-  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+  vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
     vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
   return
 }
 
 // CHECK-LABEL: func @transfer_write_dims_match_contiguous(
-// CHECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME:    %[[MEM:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK-SAME:    %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
-// CHECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
+// CHECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
 // CHECK-DAG:     %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
 // CHECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
 
@@ -280,19 +284,19 @@ func.func @transfer_write_dims_match_contiguous(
 // -----
 
 func.func @transfer_write_dims_match_contiguous_empty_stride(
-    %arg : memref<5x4x3x2xi8>,
+    %mem : memref<5x4x3x2xi8>,
     %vec : vector<5x4x3x2xi8>) {
   %c0 = arith.constant 0 : index
-  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+  vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
     vector<5x4x3x2xi8>, memref<5x4x3x2xi8>
   return
 }
 
 // CHECK-LABEL: func @transfer_write_dims_match_contiguous_empty_stride(
-// CHECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME:    %[[MEM:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK-SAME:    %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
-// CHECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8> into memref<120xi8>
+// CHECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8> into memref<120xi8>
 // CHECK-DAG:     %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
 // CHECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
 
@@ -302,21 +306,21 @@ func.func @transfer_write_dims_match_contiguous_empty_stride(
 // -----
 
 func.func @transfer_write_dims_mismatch_contiguous(
-    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
     %vec : vector<1x1x2x2xi8>) {
   %c0 = arith.constant 0 : index
-  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+  vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
     vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
   return
 }
 
 // CHECK-LABEL:   func.func @transfer_write_dims_mismatch_contiguous
-// CHECK-SAME:      %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
-// CHECK-SAME:      %[[VAL_1:.*]]: vector<1x1x2x2xi8>) {
+// CHECK-SAME:      %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+// CHECK-SAME:      %[[VEC:.*]]: vector<1x1x2x2xi8>) {
 // CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
-// CHECK:           %[[VAL_4:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x1x2x2xi8> to vector<4xi8>
+// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
+// CHECK:           %[[VAL_4:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
 // CHECK:           vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
 
 // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
@@ -327,12 +331,12 @@ func.func @transfer_write_dims_mismatch_contiguous(
 
 func.func @transfer_write_dims_mismatch_non_zero_indices(
     %idx_1: index,
     %idx_2: index,
-    %arg: memref<1x43x4x6xi32>,
+    %mem: memref<1x43x4x6xi32>,
     %vec: vector<1x2x6xi32>) {
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  vector.transfer_write %vec, %arg[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} :
+  vector.transfer_write %vec, %mem[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} :
     vector<1x2x6xi32>, memref<1x43x4x6xi32>
   return
 }
@@ -341,11 +345,11 @@ func.func @transfer_write_dims_mismatch_non_zero_indices(
 
 // CHECK-LABEL:   func.func @transfer_write_dims_mismatch_non_zero_indices(
 // CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
-// CHECK-SAME:      %[[ARG:.*]]: memref<1x43x4x6xi32>,
+// CHECK-SAME:      %[[MEM:.*]]: memref<1x43x4x6xi32>,
 // CHECK-SAME:      %[[VEC:.*]]: vector<1x2x6xi32>) {
 // CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[IDX:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
-// CHECK-DAG:       %[[CS:.*]] = memref.collapse_shape %[[ARG]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK-DAG:       %[[CS:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
 // CHECK:           %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
 // CHECK:           vector.transfer_write %[[SC]], %[[CS]]{{\[}}%[[C0]], %[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1x1032xi32>
 
@@ -359,13 +363,13 @@ func.func @transfer_write_dims_mismatch_non_zero_indices(
 // flattening works fine.
 
 func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
-    %value : vector<2x2xf32>,
-    %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
+    %vec : vector<2x2xf32>,
+    %mem : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
     %idx_1 : index,
     %idx_2 : index) {
   %c0 = arith.constant 0 : index
-  vector.transfer_write %value, %subview[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true]} : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>
+  vector.transfer_write %vec, %mem[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true]} : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>
   return
 }
 
@@ -385,22 +389,22 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
 
 func.func @transfer_write_leading_dynamic_dims(
     %vec : vector<8x4xi8>,
-    %arg : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>,
+    %mem : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>,
     %idx_1 : index,
     %idx_2 : index) {
   %c0 = arith.constant 0 : index
-  vector.transfer_write %vec, %arg[%idx_1, %idx_2, %c0, %c0] {in_bounds = [true, true]} :
+  vector.transfer_write %vec, %mem[%idx_1, %idx_2, %c0, %c0] {in_bounds = [true, true]} :
     vector<8x4xi8>, memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>
   return
 }
 
 // CHECK-LABEL: func @transfer_write_leading_dynamic_dims
-// CHECK-SAME:    %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
+// CHECK-SAME:    %[[VEC:.+]]: vector<8x4xi8>, %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
 // CHECK:         %[[C0:.+]] = arith.constant 0 : index
-// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1], [2, 3]{{\]}}
 // CHECK-SAME:      : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
-// CHECK:         %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
+// CHECK:         %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<8x4xi8> to vector<32xi8>
 // CHECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
 // CHECK-SAME:      [%[[ARG2]], %[[ARG3]], %[[C0]]]
 // CHECK-SAME:      {in_bounds = [true]}
 
@@ -417,11 +421,11 @@ func.func @negative_transfer_write_dynamic_to_flatten(
     %idx_1: index,
     %idx_2: index,
     %vec : vector<1x2x6xi32>,
-    %arg: memref<1x?x4x6xi32>) {
+    %mem: memref<1x?x4x6xi32>) {
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  vector.transfer_write %vec, %arg[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} :
+  vector.transfer_write %vec, %mem[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} :
     vector<1x2x6xi32>, memref<1x?x4x6xi32>
   return
 }
 
@@ -439,12 +443,12 @@ func.func @negative_transfer_write_dynamic_to_flatten(
 // memref.
 
 func.func @transfer_write_dims_mismatch_non_contiguous_slice(
-    %arg : memref<5x4x3x2xi8>,
+    %mem : memref<5x4x3x2xi8>,
     %vec : vector<2x1x2x2xi8>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0 : i8
-  vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] :
+  vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] :
     vector<2x1x2x2xi8>, memref<5x4x3x2xi8>
   return
 }
@@ -459,10 +463,10 @@ func.func @transfer_write_dims_mismatch_non_contiguous_slice(
 // -----
 
 func.func @transfer_write_0d(
-    %arg : memref<i8>,
+    %mem : memref<i8>,
     %vec : vector<i8>) {
-  vector.transfer_write %vec, %arg[] : vector<i8>, memref<i8>
+  vector.transfer_write %vec, %mem[] : vector<i8>, memref<i8>
   return
 }
@@ -479,11 +483,11 @@ func.func @transfer_write_0d(
 // The strides make the input memref non-contiguous, hence non-flattenable.
 
 func.func @transfer_write_non_contiguous_src(
-    %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>,
+    %mem : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>,
     %vec : vector<5x4x3x2xi8>) {
   %c0 = arith.constant 0 : index
-  vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] :
+  vector.transfer_write %vec, %mem[%c0, %c0, %c0, %c0] :
     vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>
   return
 }
@@ -503,9 +507,9 @@ func.func @transfer_write_non_contiguous_src(
 /// TODO: Move to a dedicated file - there's no "flattening" in the following tests
 ///----------------------------------------------------------------------------------------
 
-func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {
-  %add = arith.addi %arg0, %arg0 : vector<1x8xi32>
-  return %add : vector<1x8xi32>
+func.func @fold_unit_dim_add_basic(%vec : vector<1x8xi32>) -> vector<1x8xi32> {
+  %res = arith.addi %vec, %vec : vector<1x8xi32>
+  return %res : vector<1x8xi32>
 }
 
 // CHECK-LABEL:   func.func @fold_unit_dim_add_basic(
 // CHECK-SAME:      %[[VAL_0:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
@@ -520,9 +524,9 @@ func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {
 // -----
 
-func.func @fold_unit_dim_add_leading_and_trailing(%arg0 : vector<1x8x1xi32>) -> vector<1x8x1xi32> {
-  %add = arith.addi %arg0, %arg0 : vector<1x8x1xi32>
-  return %add : vector<1x8x1xi32>
+func.func @fold_unit_dim_add_leading_and_trailing(%vec : vector<1x8x1xi32>) -> vector<1x8x1xi32> {
+  %res = arith.addi %vec, %vec : vector<1x8x1xi32>
+  return %res : vector<1x8x1xi32>
 }
 
 // CHECK-LABEL:   func.func @fold_unit_dim_add_leading_and_trailing(
 // CHECK-SAME:      %[[VAL_0:.*]]: vector<1x8x1xi32>) -> vector<1x8x1xi32> {
@@ -537,10 +541,10 @@ func.func @fold_unit_dim_add_leading_and_trailing(%arg0 : vector<1x8x1xi32>) ->
 // -----
 
-func.func @fold_unit_dim_add(%arg0 : vector<8x1xi32>,
-                             %arg1 : vector<1x8xi32>) -> vector<8xi32> {
-  %sc_arg0 = vector.shape_cast %arg0 : vector<8x1xi32> to vector<1x8xi32>
-  %add = arith.addi %sc_arg0, %arg1 : vector<1x8xi32>
+func.func @fold_unit_dim_add(%vec_0 : vector<8x1xi32>,
+                             %vec_1 : vector<1x8xi32>) -> vector<8xi32> {
+  %sc_vec_0 = vector.shape_cast %vec_0 : vector<8x1xi32> to vector<1x8xi32>
+  %add = arith.addi %sc_vec_0, %vec_1 : vector<1x8xi32>
   %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
   return %res : vector<8xi32>
 }
@@ -558,10 +562,10 @@ func.func @fold_unit_dim_add(%arg0 : vector<8x1xi32>,
 // -----
 
-func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
-                              %arg1 : vector<1x8x[2]xf32>) -> vector<8x[2]xf32> {
-  %sc_arg0 = vector.shape_cast %arg0 : vector<8x[2]x1xf32> to vector<1x8x[2]xf32>
-  %add = arith.mulf %sc_arg0, %arg1 : vector<1x8x[2]xf32>
+func.func @fold_unit_dim_mulf(%vec_0 : vector<8x[2]x1xf32>,
+                              %vec_1 : vector<1x8x[2]xf32>) -> vector<8x[2]xf32> {
+  %sc_vec_0 = vector.shape_cast %vec_0 : vector<8x[2]x1xf32> to vector<1x8x[2]xf32>
+  %add = arith.mulf %sc_vec_0, %vec_1 : vector<1x8x[2]xf32>
   %res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32>
   return %res : vector<8x[2]xf32>
 }
@@ -579,9 +583,9 @@ func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
 // -----
 
-func.func @fold_unit_dim_sitofp(%arg0 : vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
-  %sc_arg0 = vector.shape_cast %arg0 : vector<8x[2]x1xi8> to vector<1x8x[2]xi8>
-  %add = arith.sitofp %sc_arg0 : vector<1x8x[2]xi8> to vector<1x8x[2]xf32>
+func.func @fold_unit_dim_sitofp(%vec : vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {
+  %sc_vec_0 = vector.shape_cast %vec : vector<8x[2]x1xi8> to vector<1x8x[2]xi8>
+  %add = arith.sitofp %sc_vec_0 : vector<1x8x[2]xi8> to vector<1x8x[2]xf32>
   %res = vector.shape_cast %add : vector<1x8x[2]xf32> to vector<8x[2]xf32>
   return %res : vector<8x[2]xf32>
 }
@@ -599,14 +603,14 @@ func.func @fold_unit_dim_sitofp(%arg0 : vector<8x[2]x1xi8>) -> vector<8x[2]xf32>
 
 // All shape casts are folded away
 
-func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
-                                   %arg1 : vector<8xi32>,
-                                   %arg2 : vector<8xi32>) -> vector<8xi32> {
-  %sc_arg0 = vector.shape_cast %arg0 : vector<8xi32> to vector<1x8xi32>
-  %sc_arg1 = vector.shape_cast %arg1 : vector<8xi32> to vector<1x8xi32>
-  %sc_arg2 = vector.shape_cast %arg2 : vector<8xi32> to vector<1x8xi32>
-  %mul = arith.muli %sc_arg0, %sc_arg1 : vector<1x8xi32>
-  %add = arith.addi %mul, %sc_arg2 : vector<1x8xi32>
+func.func @fold_unit_dims_entirely(%vec_0 : vector<8xi32>,
+                                   %vec_1 : vector<8xi32>,
+                                   %vec_2 : vector<8xi32>) -> vector<8xi32> {
+  %sc_vec_0 = vector.shape_cast %vec_0 : vector<8xi32> to vector<1x8xi32>
+  %sc_vec_1 = vector.shape_cast %vec_1 : vector<8xi32> to vector<1x8xi32>
+  %sc_vec_2 = vector.shape_cast %vec_2 : vector<8xi32> to vector<1x8xi32>
+  %mul = arith.muli %sc_vec_0, %sc_vec_1 : vector<1x8xi32>
+  %add = arith.addi %mul, %sc_vec_2 : vector<1x8xi32>
   %res = vector.shape_cast %add : vector<1x8xi32> to vector<8xi32>
   return %res : vector<8xi32>
 }
@@ -623,10 +627,10 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 // -----
 
-func.func @fold_inner_unit_dim(%arg0 : vector<8x1x3xf128>,
-                               %arg1 : vector<1x8x3xf128>) -> vector<8x3xf128> {
-  %sc_arg1 = vector.shape_cast %arg1 : vector<1x8x3xf128> to vector<8x1x3xf128>
-  %mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x3xf128>
+func.func @fold_inner_unit_dim(%vec_0 : vector<8x1x3xf128>,
+                               %vec_1 : vector<1x8x3xf128>) -> vector<8x3xf128> {
+  %sc_vec_1 = vector.shape_cast %vec_1 : vector<1x8x3xf128> to vector<8x1x3xf128>
+  %mul = arith.mulf %vec_0, %sc_vec_1 : vector<8x1x3xf128>
   %res = vector.shape_cast %mul : vector<8x1x3xf128> to vector<8x3xf128>
   return %res : vector<8x3xf128>
 }
@@ -641,10 +645,10 @@ func.func @fold_inner_unit_dim(%arg0 : vector<8x1x3xf128>,
 // -----
 
-func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
-                                        %arg1 : vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
-  %sc_arg1 = vector.shape_cast %arg1 : vector<1x8x[1]x3xf128> to vector<8x1x[1]x3xf128>
-  %mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x[1]x3xf128>
+func.func @fold_inner_unit_dim_scalable(%vec_0 : vector<8x1x[1]x3xf128>,
+                                        %vec_1 : vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
+  %sc_vec_1 = vector.shape_cast %vec_1 : vector<1x8x[1]x3xf128> to vector<8x1x[1]x3xf128>
+  %mul = arith.mulf %vec_0, %sc_vec_1 : vector<8x1x[1]x3xf128>
   %res = vector.shape_cast %mul : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
   return %res : vector<8x[1]x3xf128>
 }
@@ -659,8 +663,8 @@ func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
 // -----
 
-func.func @fold_all_unit_dims(%arg0: vector<1x1xf32>) -> vector<1xf32> {
-  %0 = arith.mulf %arg0, %arg0 : vector<1x1xf32>
+func.func @fold_all_unit_dims(%vec: vector<1x1xf32>) -> vector<1xf32> {
+  %0 = arith.mulf %vec, %vec : vector<1x1xf32>
   %res = vector.shape_cast %0 : vector<1x1xf32> to vector<1xf32>
   return %res : vector<1xf32>
 }
@@ -675,12 +679,12 @@ func.func @fold_all_unit_dims(%arg0: vector<1x1xf32>) -> vector<1xf32> {
 // -----
 
 func.func @negative_out_of_bound_transfer_read(
-    %arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+    %mem : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0 : i8
-  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst {in_bounds = [false, true, true, true]} :
+  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst {in_bounds = [false, true, true, true]} :
     memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
-  return %v : vector<5x4x3x2xi8>
+  return %res : vector<5x4x3x2xi8>
 }
 // CHECK:     func.func @negative_out_of_bound_transfer_read
 // CHECK-NOT:   memref.collapse_shape
@@ -688,9 +692,9 @@ func.func @negative_out_of_bound_transfer_read(
 // -----
 
 func.func @negative_out_of_bound_transfer_write(
-    %arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x3x2xi8>) {
+    %mem : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x3x2xi8>) {
   %c0 = arith.constant 0 : index
-  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] {in_bounds = [false, true, true, true]} :
+  vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] {in_bounds = [false, true, true, true]} :
     vector<1x1x3x2xi8>, memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
   return
 }
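
For context (not part of the patch itself): the renames above are mechanical, but the rewrite these tests exercise can be summarised with a minimal before/after sketch. The function name and shapes below are illustrative only, not taken from the test file; the pattern shown is the one checked by @transfer_read_dims_match_contiguous, where a contiguous n-D vector.transfer_read is flattened via memref.collapse_shape, a rank-1 transfer_read, and a vector.shape_cast back to the original shape.

// Before: a 2-D read covering the whole of a contiguous memref.
func.func @example(%mem : memref<4x8xi8>) -> vector<4x8xi8> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i8
  %res = vector.transfer_read %mem[%c0, %c0], %pad : memref<4x8xi8>, vector<4x8xi8>
  return %res : vector<4x8xi8>
}

// After flattening: collapse both dims to rank 1, read 32 contiguous
// elements in-bounds, then cast the 1-D result back to 2-D.
func.func @example_flattened(%mem : memref<4x8xi8>) -> vector<4x8xi8> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i8
  %collapsed = memref.collapse_shape %mem [[0, 1]] : memref<4x8xi8> into memref<32xi8>
  %read1d = vector.transfer_read %collapsed[%c0], %pad {in_bounds = [true]} : memref<32xi8>, vector<32xi8>
  %res = vector.shape_cast %read1d : vector<32xi8> to vector<4x8xi8>
  return %res : vector<4x8xi8>
}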