diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index c131fde517f80..4c93d3841bf87 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -568,6 +568,7 @@ namespace { /// memref.collapse_shape on the source so that the resulting /// vector.transfer_read has a 1D source. Requires the source shape to be /// already reduced i.e. without unit dims. +/// /// If `targetVectorBitwidth` is provided, the flattening will only happen if /// the trailing dimension of the vector read is smaller than the provided /// bitwidth. @@ -617,7 +618,7 @@ class FlattenContiguousRowMajorTransferReadPattern Value collapsedSource = collapseInnerDims(rewriter, loc, source, firstDimToCollapse); MemRefType collapsedSourceType = - dyn_cast(collapsedSource.getType()); + cast(collapsedSource.getType()); int64_t collapsedRank = collapsedSourceType.getRank(); assert(collapsedRank == firstDimToCollapse + 1); @@ -658,6 +659,10 @@ class FlattenContiguousRowMajorTransferReadPattern /// memref.collapse_shape on the source so that the resulting /// vector.transfer_write has a 1D source. Requires the source shape to be /// already reduced i.e. without unit dims. +/// +/// If `targetVectorBitwidth` is provided, the flattening will only happen if +/// the trailing dimension of the vector read is smaller than the provided +/// bitwidth. class FlattenContiguousRowMajorTransferWritePattern : public OpRewritePattern { public: @@ -674,9 +679,12 @@ class FlattenContiguousRowMajorTransferWritePattern VectorType vectorType = cast(vector.getType()); Value source = transferWriteOp.getSource(); MemRefType sourceType = dyn_cast(source.getType()); + + // 0. Check pre-conditions // Contiguity check is valid on tensors only. if (!sourceType) return failure(); + // If this is already 0D/1D, there's nothing to do. if (vectorType.getRank() <= 1) // Already 0D/1D, nothing to do. return failure(); @@ -688,7 +696,6 @@ class FlattenContiguousRowMajorTransferWritePattern return failure(); if (!vector::isContiguousSlice(sourceType, vectorType)) return failure(); - int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank(); // TODO: generalize this pattern, relax the requirements here. if (transferWriteOp.hasOutOfBoundsDim()) return failure(); @@ -697,10 +704,9 @@ class FlattenContiguousRowMajorTransferWritePattern if (transferWriteOp.getMask()) return failure(); - SmallVector collapsedIndices = - getCollapsedIndices(rewriter, loc, sourceType.getShape(), - transferWriteOp.getIndices(), firstDimToCollapse); + int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank(); + // 1. Collapse the source memref Value collapsedSource = collapseInnerDims(rewriter, loc, source, firstDimToCollapse); MemRefType collapsedSourceType = @@ -708,11 +714,20 @@ class FlattenContiguousRowMajorTransferWritePattern int64_t collapsedRank = collapsedSourceType.getRank(); assert(collapsedRank == firstDimToCollapse + 1); + // 2. Generate input args for a new vector.transfer_read that will read + // from the collapsed memref. + // 2.1. New dim exprs + affine map SmallVector dimExprs{ getAffineDimExpr(firstDimToCollapse, rewriter.getContext())}; auto collapsedMap = AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext()); + // 2.2 New indices + SmallVector collapsedIndices = + getCollapsedIndices(rewriter, loc, sourceType.getShape(), + transferWriteOp.getIndices(), firstDimToCollapse); + + // 3. Create new vector.transfer_write that writes to the collapsed memref VectorType flatVectorType = VectorType::get({vectorType.getNumElements()}, vectorType.getElementType()); Value flatVector = @@ -721,6 +736,9 @@ class FlattenContiguousRowMajorTransferWritePattern rewriter.create( loc, flatVector, collapsedSource, collapsedIndices, collapsedMap); flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true})); + + // 4. Replace the old transfer_write with the new one writing the + // collapsed shape rewriter.eraseOp(transferWriteOp); return success(); } diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir index d7365d25d21b4..e83306a5089ff 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir @@ -1,17 +1,23 @@ // RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s // RUN: mlir-opt %s -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128 -split-input-file | FileCheck %s --check-prefix=CHECK-128B +///---------------------------------------------------------------------------------------- +/// vector.transfer_read +/// [Pattern: FlattenContiguousRowMajorTransferReadPattern] +///---------------------------------------------------------------------------------------- + func.func @transfer_read_dims_match_contiguous( - %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> { - %c0 = arith.constant 0 : index - %cst = arith.constant 0 : i8 - %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : - memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8> - return %v : vector<5x4x3x2xi8> + %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> { + + %c0 = arith.constant 0 : index + %cst = arith.constant 0 : i8 + %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : + memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8> + return %v : vector<5x4x3x2xi8> } // CHECK-LABEL: func @transfer_read_dims_match_contiguous -// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8 +// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8 // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3] // CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]] // CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8> @@ -24,11 +30,12 @@ func.func @transfer_read_dims_match_contiguous( func.func @transfer_read_dims_match_contiguous_empty_stride( %arg : memref<5x4x3x2xi8>) -> vector<5x4x3x2xi8> { - %c0 = arith.constant 0 : index - %cst = arith.constant 0 : i8 - %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : - memref<5x4x3x2xi8>, vector<5x4x3x2xi8> - return %v : vector<5x4x3x2xi8> + + %c0 = arith.constant 0 : index + %cst = arith.constant 0 : i8 + %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : + memref<5x4x3x2xi8>, vector<5x4x3x2xi8> + return %v : vector<5x4x3x2xi8> } // CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride( @@ -47,16 +54,17 @@ func.func @transfer_read_dims_match_contiguous_empty_stride( // contiguous subset of the memref, so "flattenable". func.func @transfer_read_dims_mismatch_contiguous( - %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> { - %c0 = arith.constant 0 : index - %cst = arith.constant 0 : i8 - %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : - memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8> - return %v : vector<1x1x2x2xi8> + %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> { + + %c0 = arith.constant 0 : index + %cst = arith.constant 0 : i8 + %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : + memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8> + return %v : vector<1x1x2x2xi8> } // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous( -// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> { +// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> { // CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8 // CHECK: %[[VAL_2:.*]] = arith.constant 0 : index // CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>> @@ -70,135 +78,160 @@ func.func @transfer_read_dims_mismatch_contiguous( // ----- func.func @transfer_read_dims_mismatch_non_zero_indices( - %idx_1: index, - %idx_2: index, - %m_in: memref<1x43x4x6xi32>, - %m_out: memref<1x2x6xi32>) { + %idx_1: index, + %idx_2: index, + %arg: memref<1x43x4x6xi32>) -> vector<1x2x6xi32>{ + %c0 = arith.constant 0 : index %c0_i32 = arith.constant 0 : i32 - %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} : + %v = vector.transfer_read %arg[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} : memref<1x43x4x6xi32>, vector<1x2x6xi32> - vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} : - vector<1x2x6xi32>, memref<1x2x6xi32> - return + return %v : vector<1x2x6xi32> } // CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)> // CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices( // CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index, -// CHECK-SAME: %[[M_IN:.*]]: memref<1x43x4x6xi32>, -// CHECK-SAME: %[[M_OUT:.*]]: memref<1x2x6xi32>) { +// CHECK-SAME: %[[M_IN:.*]]: memref<1x43x4x6xi32> // CHECK: %[[C_0:.*]] = arith.constant 0 : i32 // CHECK: %[[C_0_IDX:.*]] = arith.constant 0 : index // CHECK: %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[M_IN]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32> // CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]] // CHECK: %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32> -// CHECK: %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32> -// CHECK: vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32> // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices( // CHECK-128B-NOT: memref.collapse_shape // ----- +// Overall, the source memref is non-contiguous. However, the slice from which +// the output vector is to be read _is_ contiguous. Hence the flattening works fine. + func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices( - %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, - %idx0 : index, %idx1 : index) -> vector<2x2xf32> { + %arg : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, + %idx0 : index, + %idx1 : index) -> vector<2x2xf32> { + %c0 = arith.constant 0 : index %cst_1 = arith.constant 0.000000e+00 : f32 - %8 = vector.transfer_read %subview[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32> + %8 = vector.transfer_read %arg[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} : + memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32> return %8 : vector<2x2xf32> } -// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)> +// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)> + // CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices( -// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>> -// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]() +// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>> +// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]() // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices( // CHECK-128B: memref.collapse_shape // ----- +func.func @transfer_read_dims_mismatch_non_contiguous( + %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> { + + %c0 = arith.constant 0 : index + %cst = arith.constant 0 : i8 + %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : + memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8> + return %v : vector<2x1x2x2xi8> +} + +// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous +// CHECK-NOT: memref.collapse_shape +// CHECK-NOT: vector.shape_cast + +// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous( +// CHECK-128B-NOT: memref.collapse_shape + +// ----- + // The input memref has a dynamic trailing shape and hence is not flattened. // TODO: This case could be supported via memref.dim func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes( - %idx_1: index, - %idx_2: index, - %m_in: memref<1x?x4x6xi32>, - %m_out: memref<1x2x6xi32>) { + %idx_1: index, + %idx_2: index, + %m_in: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> { + %c0 = arith.constant 0 : index %c0_i32 = arith.constant 0 : i32 - %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} : + %v = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} : memref<1x?x4x6xi32>, vector<1x2x6xi32> - vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} : - vector<1x2x6xi32>, memref<1x2x6xi32> - return + return %v : vector<1x2x6xi32> } -// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes( -// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index, -// CHECK-SAME: %[[M_IN:.*]]: memref<1x?x4x6xi32>, -// CHECK-SAME: %[[M_OUT:.*]]: memref<1x2x6xi32>) { -// CHECK: %[[READ:.*]] = vector.transfer_read %[[M_IN]]{{.*}} : memref<1x?x4x6xi32>, vector<1x2x6xi32> -// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[M_OUT]]{{.*}} : memref<1x2x6xi32> into memref<12xi32> -// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<1x2x6xi32> to vector<12xi32> -// CHECK: vector.transfer_write %[[SC]], %[[COLLAPSED]]{{.*}} : vector<12xi32>, memref<12xi32> +// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes( +// CHECK-NOT: memref.collapse_shape +// CHECK-NOT: vector.shape_cast // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes( // CHECK-128B-NOT: memref.collapse_shape // ----- -func.func @transfer_read_dims_mismatch_non_contiguous( - %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> { - %c0 = arith.constant 0 : index - %cst = arith.constant 0 : i8 - %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : - memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8> - return %v : vector<2x1x2x2xi8> +// The vector to be read represents a _non-contiguous_ slice of the input +// memref. + +func.func @transfer_read_dims_mismatch_non_contiguous_slice( + %arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> { + + %c0 = arith.constant 0 : index + %cst = arith.constant 0 : i8 + %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : + memref<5x4x3x2xi8>, vector<2x1x2x2xi8> + return %v : vector<2x1x2x2xi8> } -// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous +// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_slice( // CHECK-NOT: memref.collapse_shape // CHECK-NOT: vector.shape_cast -// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous( +// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_slice( // CHECK-128B-NOT: memref.collapse_shape // ----- -func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride( - %arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> { - %c0 = arith.constant 0 : index - %cst = arith.constant 0 : i8 - %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : - memref<5x4x3x2xi8>, vector<2x1x2x2xi8> - return %v : vector<2x1x2x2xi8> +func.func @transfer_read_0d( + %arg : memref) -> vector { + + %cst = arith.constant 0 : i8 + %0 = vector.transfer_read %arg[], %cst : memref, vector + return %0 : vector } -// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride( +// CHECK-LABEL: func.func @transfer_read_0d // CHECK-NOT: memref.collapse_shape // CHECK-NOT: vector.shape_cast -// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_empty_stride( +// CHECK-128B-LABEL: func @transfer_read_0d( // CHECK-128B-NOT: memref.collapse_shape +// CHECK-128B-NOT: vector.shape_cast // ----- +///---------------------------------------------------------------------------------------- +/// vector.transfer_write +/// [Pattern: FlattenContiguousRowMajorTransferWritePattern] +///---------------------------------------------------------------------------------------- + func.func @transfer_write_dims_match_contiguous( - %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) { - %c0 = arith.constant 0 : index - vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : - vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> - return + %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, + %vec : vector<5x4x3x2xi8>) { + + %c0 = arith.constant 0 : index + vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : + vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> + return } // CHECK-LABEL: func @transfer_write_dims_match_contiguous( -// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8 -// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8> +// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8 +// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8> // CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}> // CHECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8> // CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]] @@ -208,42 +241,101 @@ func.func @transfer_write_dims_match_contiguous( // ----- +func.func @transfer_write_dims_match_contiguous_empty_stride( + %arg : memref<5x4x3x2xi8>, + %vec : vector<5x4x3x2xi8>) { + + %c0 = arith.constant 0 : index + vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : + vector<5x4x3x2xi8>, memref<5x4x3x2xi8> + return +} + +// CHECK-LABEL: func @transfer_write_dims_match_contiguous_empty_stride( +// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8 +// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8> +// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8> into memref<120xi8> +// CHECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8> +// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]] + +// CHECK-128B-LABEL: func @transfer_write_dims_match_contiguous_empty_stride( +// CHECK-128B: memref.collapse_shape + +// ----- + func.func @transfer_write_dims_mismatch_contiguous( - %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x2x2xi8>) { - %c0 = arith.constant 0 : index - vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : - vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> - return + %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, + %vec : vector<1x1x2x2xi8>) { + + %c0 = arith.constant 0 : index + vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : + vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> + return } // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous -// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, -// CHECK-SAME: %[[VAL_1:.*]]: vector<1x1x2x2xi8>) { +// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, +// CHECK-SAME: %[[VAL_1:.*]]: vector<1x1x2x2xi8>) { // CHECK: %[[VAL_2:.*]] = arith.constant 0 : index // CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>> // CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x1x2x2xi8> to vector<4xi8> // CHECK: vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>> -// CHECK: return -// CHECK: } // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous( // CHECK-128B: memref.collapse_shape // ----- +func.func @transfer_write_dims_mismatch_non_zero_indices( + %idx_1: index, + %idx_2: index, + %arg: memref<1x43x4x6xi32>, + %vec: vector<1x2x6xi32>) { + + %c0 = arith.constant 0 : index + %c0_i32 = arith.constant 0 : i32 + vector.transfer_write %vec, %arg[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} : + vector<1x2x6xi32>, memref<1x43x4x6xi32> + return +} + +// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)> + +// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices( +// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index, +// CHECK-SAME: %[[ARG:.*]]: memref<1x43x4x6xi32>, +// CHECK-SAME: %[[VEC:.*]]: vector<1x2x6xi32>) { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[IDX:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[IDX_1]], %[[IDX_2]]] +// CHECK-DAG: %[[CS:.*]] = memref.collapse_shape %[[ARG]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32> +// CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32> +// CHECK: vector.transfer_write %[[SC]], %[[CS]]{{\[}}%[[C0]], %[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1x1032xi32> + +// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices( +// CHECK-128B-NOT: memref.collapse_shape + +// ----- + +// Overall, the destination memref is non-contiguous. However, the slice to +// which the input vector is to be written _is_ contiguous. Hence the +// flattening works fine. + func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices( %value : vector<2x2xf32>, %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, - %idx0 : index, %idx1 : index) { + %idx0 : index, + %idx1 : index) { + %c0 = arith.constant 0 : index vector.transfer_write %value, %subview[%c0, %idx0, %idx1, %c0] {in_bounds = [true, true]} : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> return } -// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)> +// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)> + // CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices( -// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]() -// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>> +// CHECK-DAG: %[[APPLY:.*]] = affine.apply #[[$MAP]]() +// CHECK-DAG: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>> // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices( // CHECK-128B: memref.collapse_shape @@ -251,11 +343,13 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices( // ----- func.func @transfer_write_dims_mismatch_non_contiguous( - %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<2x1x2x2xi8>) { - %c0 = arith.constant 0 : index - vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : - vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> - return + %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, + %vec : vector<2x1x2x2xi8>) { + + %c0 = arith.constant 0 : index + vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : + vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> + return } // CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous @@ -267,37 +361,76 @@ func.func @transfer_write_dims_mismatch_non_contiguous( // ----- -func.func @transfer_write_0d(%arg : memref, %vec : vector) { - vector.transfer_write %vec, %arg[] : vector, memref - return +// The input memref has a dynamic trailing shape and hence is not flattened. +// TODO: This case could be supported via memref.dim + +func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes( + %idx_1: index, + %idx_2: index, + %vec : vector<1x2x6xi32>, + %arg: memref<1x?x4x6xi32>) { + + %c0 = arith.constant 0 : index + %c0_i32 = arith.constant 0 : i32 + vector.transfer_write %vec, %arg[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} : + vector<1x2x6xi32>, memref<1x?x4x6xi32> + return } -// CHECK-LABEL: func.func @transfer_write_0d +// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes( // CHECK-NOT: memref.collapse_shape // CHECK-NOT: vector.shape_cast -// CHECK-128B-LABEL: func @transfer_write_0d( +// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes( // CHECK-128B-NOT: memref.collapse_shape -// CHECK-128B-NOT: vector.shape_cast // ----- -func.func @transfer_read_0d(%arg : memref) -> vector { - %cst = arith.constant 0 : i8 - %0 = vector.transfer_read %arg[], %cst : memref, vector - return %0 : vector +// The vector to be written represents a _non-contiguous_ slice of the output +// memref. + +func.func @transfer_write_dims_mismatch_non_contiguous_slice( + %arg : memref<5x4x3x2xi8>, + %vec : vector<2x1x2x2xi8>) { + + %c0 = arith.constant 0 : index + %cst = arith.constant 0 : i8 + vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] : + vector<2x1x2x2xi8>, memref<5x4x3x2xi8> + return } -// CHECK-LABEL: func.func @transfer_read_0d +// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous_slice( // CHECK-NOT: memref.collapse_shape // CHECK-NOT: vector.shape_cast -// CHECK-128B-LABEL: func @transfer_read_0d( +// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous_slice( +// CHECK-128B-NOT: memref.collapse_shape + +// ----- + +func.func @transfer_write_0d( + %arg : memref, + %vec : vector) { + + vector.transfer_write %vec, %arg[] : vector, memref + return +} + +// CHECK-LABEL: func.func @transfer_write_0d +// CHECK-NOT: memref.collapse_shape +// CHECK-NOT: vector.shape_cast + +// CHECK-128B-LABEL: func @transfer_write_0d( // CHECK-128B-NOT: memref.collapse_shape // CHECK-128B-NOT: vector.shape_cast // ----- +///---------------------------------------------------------------------------------------- +/// TODO: Categorize + re-format +///---------------------------------------------------------------------------------------- + func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memref>, %arg1 : index, %arg2 : index) -> vector<8x4xi8> { %c0_i8 = arith.constant 0 : i8 %c0 = arith.constant 0 : index