diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 3d0d6abf702d7..63dcda78d0f2b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -814,11 +814,9 @@ enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
 
   auto targetShape = linalgOp.getStaticLoopRanges();
-  assert(((llvm::count_if(targetShape,
-                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
+  assert(llvm::count_if(targetShape,
+                        [](int64_t dimSize) { return dimSize > 1; }) == 1 &&
          "n-D vectors are not yet supported");
-  assert(targetShape.back() != 1 &&
-         "1-D vectors with the trailing dim eqaual 1 are not yet supported");
 
   // Blocks outside _this_ linalg.generic are effectively loop invariant.
   // However, analysing block arguments for _this_ linalg.generic Op is a bit
@@ -879,8 +877,6 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
   assert(((llvm::count_if(targetShape,
                           [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
          "n-D vectors are not yet supported");
-  assert(targetShape.back() != 1 &&
-         "1-D vectors with the trailing dim 1 are not yet supported");
 
   // Blocks outside _this_ linalg.generic are effectively loop invariant.
   // However, analysing block arguments for _this_ linalg.generic Op is a bit
@@ -946,27 +942,22 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   if (linalgOp.hasDynamicShape())
     return VectorMemoryAccessKind::Gather;
 
-  // 1. Assume that it's a gather load when reading _into_:
-  //    * an n-D "vector", like `tensor<1x2x4xi32` or `tensor<2x1x4xi32>`, or
-  //    * a 1-D "vector" with the trailing dim equal 1, e.g. `tensor<1x4x1xi32`.
-  // TODO: Relax these conditions.
-  // FIXME: This condition assumes non-dynamic sizes.
-  if ((llvm::count_if(targetShape,
-                      [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
-      targetShape.back() == 1)
-    return VectorMemoryAccessKind::Gather;
+  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
+  // otherwise.
+  bool isOutput1DVector = (llvm::count_if(targetShape, [](int64_t dimSize) {
+                             return dimSize > 1;
+                           }) == 1);
 
-  // 2. Assume that it's a gather load when reading _from_ a tensor for which
-  //    the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
-  // TODO: Relax this condition.
-  if (inputShape.getShape().back() == 1)
+  // 1. Assume that it's a gather load when reading a non-1D vector.
+  if (!isOutput1DVector)
     return VectorMemoryAccessKind::Gather;
 
   bool leadingIdxsLoopInvariant = true;
 
-  // 3. Analyze the leading indices of `extractOp`.
+  // 2. Analyze the leading indices of `extractOp`.
   // Look at the way each index is calculated and decide whether it is suitable
-  // for a contiguous load, i.e. whether it's loop invariant.
+  // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
+  // gather load.
   auto indices = extractOp.getIndices();
   auto leadIndices = indices.drop_back(1);
 
@@ -982,13 +973,13 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
     return VectorMemoryAccessKind::Gather;
   }
 
-  // 4. Analyze the trailing index for `extractOp`.
+  // 3. Analyze the trailing index for `extractOp`.
   // At this point we know that the leading indices are loop invariant. This
   // means that is potentially a scalar or a contiguous load. We can decide
   // based on the trailing idx.
   auto extractOpTrailingIdx = indices.back();
 
-  // 4a. Scalar broadcast load
+  // 3a. Scalar broadcast load
   // If the trailing index is loop invariant then this is a scalar load.
   if (leadingIdxsLoopInvariant &&
       isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
@@ -997,7 +988,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
     return VectorMemoryAccessKind::ScalarBroadcast;
   }
 
-  // 4b. Contiguous loads
+  // 3b. Contiguous loads
   // The trailing `extractOp` index should increment with every loop iteration.
   // This effectively means that it must be based on the trailing loop index.
   // This is what the following bool captures.
@@ -1011,7 +1002,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
     return VectorMemoryAccessKind::Contiguous;
   }
 
-  // 5. Fallback case - gather load.
+  // 4. Fallback case - gather load.
   LDBG("Found gather load: " << extractOp);
   return VectorMemoryAccessKind::Gather;
 }
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index 85e1c56dd45a0..bdaa20c3bf971 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -37,6 +37,7 @@ module attributes {transform.with_named_sequence} {
 }
 
 // -----
+
 #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
   %c0 = arith.constant 1 : index
@@ -74,20 +75,24 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @vectorize_nd_tensor_extract_transfer_read_basic(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
-  %1 = linalg.generic {
-    indexing_maps = [#map1],
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @vectorize_nd_tensor_extract_transfer_read_basic(
+    %arg0: tensor<3x3x3xf32>,
+    %arg1: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+
+  %res = linalg.generic {
+    indexing_maps = [#map],
     iterator_types = ["parallel", "parallel", "parallel"]
-  } outs(%arg2 : tensor<1x1x3xf32>) {
-  ^bb0(%arg4: f32):
-    %2 = linalg.index 0 : index
-    %3 = linalg.index 1 : index
-    %4 = linalg.index 2 : index
-    %5 = tensor.extract %arg0[%2, %3, %4] : tensor<3x3x3xf32>
-    linalg.yield %5 : f32
+  } outs(%arg1 : tensor<1x1x3xf32>) {
+  ^bb0(%out: f32):
+    %1 = linalg.index 0 : index
+    %2 = linalg.index 1 : index
+    %3 = linalg.index 2 : index
+    %4 = tensor.extract %arg0[%1, %2, %3] : tensor<3x3x3xf32>
+    linalg.yield %4 : f32
   } -> tensor<1x1x3xf32>
-  return %1 : tensor<1x1x3xf32>
+
+  return %res : tensor<1x1x3xf32>
 }
 
 // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_basic
@@ -104,6 +109,38 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(%arg0: tensor<3x3x3xf
 // CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[C0:.*]]], %[[CST_0]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
 // CHECK: vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
 
+// Same as the example above, but reading into a column tensor. Note that
+// after vectorization, the `TransferOpReduceRank` pattern will replace
+// `vector.transfer_read` with `tensor.extract -> scalar`.
+
+// TODO: Currently this fails to vectorize when the indices are non-constant.
+
+func.func @vectorize_nd_tensor_extract_transfer_read_basic_column(
+    %input: tensor<3x3x3xf32>,
+    %output: tensor<3x1x1xf32>) -> tensor<3x1x1xf32> {
+
+  %c0 = arith.constant 0 : index
+  %res = linalg.generic {
+    indexing_maps = [#map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } outs(%output : tensor<3x1x1xf32>) {
+  ^bb0(%out: f32):
+    %5 = tensor.extract %input[%c0, %c0, %c0] : tensor<3x3x3xf32>
+    linalg.yield %5 : f32
+  } -> tensor<3x1x1xf32>
+
+  return %res : tensor<3x1x1xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_basic_column(
+// CHECK-SAME: %[[INPUT:.*]]: tensor<3x3x3xf32>,
+// CHECK-SAME: %[[OUTPUT:.*]]: tensor<3x1x1xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : tensor<3x3x3xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : f32 to vector<3x1x1xf32>
+// CHECK: %[[RES:.*]] = vector.transfer_write %[[BCAST]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<3x1x1xf32>, tensor<3x1x1xf32>
+// CHECK: return %[[RES]] : tensor<3x1x1xf32>
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
@@ -595,3 +632,59 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+
+// -----
+
+func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
+
+  %out = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} outs(%in : tensor<1x1x4xi32>) {
+  ^bb0(%out: i32):
+    %8 = linalg.index 0 : index
+    %idx_0 = linalg.index 0 : index
+    %extracted = tensor.extract %cst[%idx_0, %c0] : tensor<15x1xi32>
+    linalg.yield %extracted : i32
+  } -> tensor<1x1x4xi32>
+
+  return %out : tensor<1x1x4xi32>
+}
+
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
+// CHECK-LABEL: func.func @vectorize_scalar_broadcast_column_tensor(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_7]], %[[VAL_7]]], %[[VAL_8]] : tensor<1x1x4xi32>, vector<1x1x4xi32>
+// CHECK: %[[VAL_10:.*]] = vector.step : vector<1xindex>
+// CHECK: %[[VAL_11:.*]] = vector.broadcast %[[VAL_10]] : vector<1xindex> to vector<4x1x1xindex>
+// CHECK: %[[VAL_12:.*]] = vector.transpose %[[VAL_11]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
+// CHECK: %[[VAL_13:.*]] = vector.step : vector<1xindex>
+// CHECK: %[[VAL_14:.*]] = vector.broadcast %[[VAL_13]] : vector<1xindex> to vector<4x1x1xindex>
+// CHECK: %[[VAL_15:.*]] = vector.transpose %[[VAL_14]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
+// CHECK: %[[VAL_16:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
+// CHECK: %[[VAL_17:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
+// CHECK: %[[VAL_18:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_19:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK: %[[VAL_21:.*]] = vector.extractelement %[[VAL_20]]{{\[}}%[[VAL_19]] : i32] : vector<4xindex>
+// CHECK: %[[VAL_22:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_23:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_21]], %[[VAL_2]]], %[[VAL_22]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_1]]} : tensor<15x1xi32>, vector<1x1x4xi32>
+// CHECK: %[[VAL_24:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_25:.*]] = vector.transfer_write %[[VAL_23]], %[[VAL_0]]{{\[}}%[[VAL_24]], %[[VAL_24]], %[[VAL_24]]] : vector<1x1x4xi32>, tensor<1x1x4xi32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, 1, 4] {vectorize_nd_extract} : !transform.any_op
+    transform.yield
+  }
+}