diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 63dcda78d0f2b..a376afa5ddab1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -810,12 +810,12 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
 
 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
 
-/// Checks whether /p val can be used for calculating a loop invariant index.
-static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
+/// Checks whether `val` can be used for calculating a loop invariant index.
+static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
+                               VectorType resType) {
 
-  auto targetShape = linalgOp.getStaticLoopRanges();
-  assert(llvm::count_if(targetShape,
-                        [](int64_t dimSize) { return dimSize > 1; }) == 1 &&
+  assert(((llvm::count_if(resType.getShape(),
                           [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
          "n-D vectors are not yet supported");
 
   // Blocks outside _this_ linalg.generic are effectively loop invariant.
@@ -849,7 +849,7 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
 
   bool result = true;
   for (auto op : ancestor->getOperands())
-    result &= isLoopInvariantIdx(linalgOp, op);
+    result &= isLoopInvariantIdx(linalgOp, op, resType);
 
   return result;
 }
@@ -871,10 +871,9 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
 /// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
 /// updated to `true` when such an op is found.
 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
-                                bool &foundIndexOp) {
+                                bool &foundIndexOp, VectorType resType) {
 
-  auto targetShape = linalgOp.getStaticLoopRanges();
-  assert(((llvm::count_if(targetShape,
+  assert(((llvm::count_if(resType.getShape(),
           [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
          "n-D vectors are not yet supported");
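For context: both helpers now key their "effectively 1-D" assertion off the result vector type rather than the static loop ranges, which is what later lets dynamic shapes through. A hand-written example of an op that satisfies the predicate (illustrative only, not taken from the patch or its tests):

```mlir
// The extract is vectorised to vector<1x4xf32>: exactly one dim is > 1,
// so the `dimSize > 1` count over resType.getShape() equals 1.
func.func @effectively_1d(%src: tensor<8xf32>, %dest: tensor<1x4xf32>) -> tensor<1x4xf32> {
  %0 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } outs(%dest : tensor<1x4xf32>) {
  ^bb0(%out: f32):
    // Trailing-dim index that increments by 1 per iteration.
    %i = linalg.index 1 : index
    %e = tensor.extract %src[%i] : tensor<8xf32>
    linalg.yield %e : f32
  } -> tensor<1x4xf32>
  return %0 : tensor<1x4xf32>
}
```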
@@ -910,44 +909,38 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
 
   bool result = false;
   for (auto op : ancestor->getOperands())
-    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
+    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
 
   return result;
 }
 
 /// Infer the memory access pattern for the input ExtractOp
 ///
-/// Based on the operation shapes and indices (usually based on the iteration
-/// space of the parent `linalgOp` operation), decides whether the input
-/// ExtractOp is a contiguous load (including a broadcast of a scalar) or a
-/// gather load.
+/// Based on the ExtractOp result shape and the access indices, decides whether
+/// this Op corresponds to a contiguous load (including a broadcast of a scalar)
+/// or a gather load. When analysing the ExtractOp indices (to identify
+/// contiguous loads), this method looks for "loop" invariant indices (e.g.
+/// block arguments) and indices that change linearly (e.g. via `linalg.index`
+/// Op).
 ///
 /// Note that it is always safe to use gather load operations for contiguous
 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
 /// that `extractOp` is a gather load.
 static VectorMemoryAccessKind
 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
-                                    LinalgOp &linalgOp) {
+                                    LinalgOp &linalgOp, VectorType resType) {
 
-  auto targetShape = linalgOp.getStaticLoopRanges();
   auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
 
-  // 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast.
+  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
   if (inputShape.getShape().empty())
     return VectorMemoryAccessKind::ScalarBroadcast;
 
-  // 0.2 In the case of dynamic shapes just bail-out and assume that it's a
-  // gather load.
-  // TODO: Relax this condition.
-  if (linalgOp.hasDynamicShape())
-    return VectorMemoryAccessKind::Gather;
-
   // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
   // otherwise.
-  bool isOutput1DVector = (llvm::count_if(targetShape, [](int64_t dimSize) {
-    return dimSize > 1;
-  }) == 1);
-
+  bool isOutput1DVector =
+      (llvm::count_if(resType.getShape(),
+                      [](int64_t dimSize) { return dimSize > 1; }) == 1);
   // 1. Assume that it's a gather load when reading non-1D vector.
   if (!isOutput1DVector)
     return VectorMemoryAccessKind::Gather;
@@ -965,7 +958,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
     if (inputShape.getShape()[i] == 1)
       continue;
 
-    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
+    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
   }
 
   if (!leadingIdxsLoopInvariant) {
@@ -982,7 +975,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   // 3a. Scalar broadcast load
   // If the trailing index is loop invariant then this is a scalar load.
   if (leadingIdxsLoopInvariant &&
-      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
+      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
     LDBG("Found scalar broadcast load: " << extractOp);
 
     return VectorMemoryAccessKind::ScalarBroadcast;
@@ -993,8 +986,8 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   // This effectively means that it must be based on the trailing loop index.
   // This is what the following bool captures.
   bool foundIndexOp = false;
-  bool isContiguousLoad =
-      isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
+  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
+                                              foundIndexOp, resType);
   isContiguousLoad &= foundIndexOp;
 
   if (isContiguousLoad) {
@@ -1035,7 +1028,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
         rewriter.create<arith::ConstantIndexOp>(loc, 0));
 
   VectorMemoryAccessKind memAccessKind =
-      getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
+      getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
 
   // 1. Handle gather access
   if (memAccessKind == VectorMemoryAccessKind::Gather) {
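The removed `0.2` bail-out is the functional change here: dynamic shapes no longer force a gather. The non-1D early exit in step 1 still applies regardless of dynamism; a hypothetical op that keeps being classified as a gather (a sketch, not one of the patch's tests):

```mlir
// The result vector would be vector<3x4xf32> (two dims > 1), so the
// analysis returns Gather in step 1, before inspecting any indices.
func.func @still_a_gather(%src: tensor<16x16xf32>, %dest: tensor<3x4xf32>) -> tensor<3x4xf32> {
  %0 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } outs(%dest : tensor<3x4xf32>) {
  ^bb0(%out: f32):
    %i = linalg.index 0 : index
    %j = linalg.index 1 : index
    %e = tensor.extract %src[%i, %j] : tensor<16x16xf32>
    linalg.yield %e : f32
  } -> tensor<3x4xf32>
  return %0 : tensor<3x4xf32>
}
```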
diff --git a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
index 4ee3088cc3778..c3a30e3ee209e 100644
--- a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
@@ -162,17 +162,14 @@ func.func @vectorize_linalg_index(%arg0: tensor<3x3x?xf32>, %arg1: tensor<1x1x?x
 
 // CHECK-LABEL: @vectorize_linalg_index
 // CHECK-SAME: %[[SRC:.*]]: tensor<3x3x?xf32>, %[[DST:.*]]: tensor<1x1x?xf32>
-// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x[4]xf32>
-// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x[4]xi1>
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 // CHECK: %[[DST_DIM2:.*]] = tensor.dim %[[DST]], %[[C2]] : tensor<1x1x?xf32>
-// CHECK: %[[DST_MASK:.*]] = vector.create_mask %[[C1]], %[[C1]], %[[DST_DIM2]] : vector<1x1x[4]xi1>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[C1]], %[[C1]], %[[DST_DIM2]] : vector<1x1x[4]xi1>
 // CHECK: %[[INDEX_VEC:.*]] = vector.step : vector<[4]xindex>
-// CHECK: %[[INDEX_VEC_BCAST:.*]] = vector.broadcast %[[INDEX_VEC]] : vector<[4]xindex> to vector<1x1x[4]xindex>
-// CHECK: %[[GATHER:.*]] = vector.mask %[[DST_MASK]] { vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {{\[}}%[[INDEX_VEC_BCAST]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3x?xf32>, vector<1x1x[4]xindex>, vector<1x1x[4]xi1>, vector<1x1x[4]xf32> into vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
-// CHECK: %[[OUT:.*]] = vector.mask %[[DST_MASK]] { vector.transfer_write %[[GATHER]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]][%[[C0]], %[[C0]], %{{.*}}], %{{.*}} {in_bounds = [true, true, true]} : tensor<3x3x?xf32>, vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
+// CHECK: %[[OUT:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
 // CHECK: return %[[OUT]] : tensor<1x1x?xf32>
 
 module attributes {transform.with_named_sequence} {
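With the dynamic-shape bail-out gone, the `linalg.index`-based load in this test is recognised as contiguous even though the trailing dim is dynamic, so the `vector.gather` becomes a masked `vector.transfer_read`. A standalone sketch of the masked-read form now expected (hand-written, simplified from the CHECK lines above; names are illustrative):

```mlir
func.func @masked_contiguous_read(%src: tensor<3x3x?xf32>, %n: index) -> vector<1x1x[4]xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %pad = arith.constant 0.0 : f32
  // Lanes at positions >= %n are disabled by the mask, which is why the
  // transfer can still be marked in_bounds.
  %mask = vector.create_mask %c1, %c1, %n : vector<1x1x[4]xi1>
  %read = vector.mask %mask {
    vector.transfer_read %src[%c0, %c0, %c0], %pad {in_bounds = [true, true, true]}
      : tensor<3x3x?xf32>, vector<1x1x[4]xf32>
  } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
  return %read : vector<1x1x[4]xf32>
}
```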
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
index 964565620fd01..31a754d934368 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
@@ -120,52 +120,54 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<?x?xf32>, %arg0: index, %extracted_slice : tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
+    %src: tensor<?x?xf32>,
+    %output : tensor<?x?xf32>,
+    %idx: index) -> tensor<?x?xf32> {
+
   %c79 = arith.constant 79 : index
   %1 = linalg.generic {
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]
-  } outs(%extracted_slice : tensor<?x?xf32>) {
+  } outs(%output : tensor<?x?xf32>) {
   ^bb0(%out: f32):
     %2 = linalg.index 1 : index
-    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg0)
-    %extracted = tensor.extract %6[%c79, %3] : tensor<?x?xf32>
+    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+    %extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
     linalg.yield %extracted : f32
   } -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 
 // CHECK-LABEL: func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: index,
-// CHECK-SAME: %[[VAL_2:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 79 : index
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_2]], %[[VAL_4]] : tensor<?x?xf32>
-// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_2]], %[[VAL_6]] : tensor<?x?xf32>
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_7]] : vector<1x4xi1>
-// CHECK: %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK: %[[VAL_12:.*]] = vector.step : vector<4xindex>
-// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
-// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : vector<4xindex>
-// CHECK-DAG: %[[VAL_15:.*]] = arith.constant dense<true> : vector<1x4xi1>
-// CHECK-DAG: %[[VAL_16:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
-// CHECK-DAG: %[[VAL_17:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_18:.*]] = arith.constant dense<79> : vector<1x4xindex>
-// CHECK-DAG: %[[VAL_19:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_20:.*]] = tensor.dim %[[VAL_0]], %[[VAL_19]] : tensor<?x?xf32>
-// CHECK: %[[VAL_21:.*]] = vector.broadcast %[[VAL_20]] : index to vector<1x4xindex>
-// CHECK: %[[VAL_22:.*]] = arith.muli %[[VAL_18]], %[[VAL_21]] : vector<1x4xindex>
-// CHECK: %[[VAL_23:.*]] = vector.broadcast %[[VAL_14]] : vector<4xindex> to vector<1x4xindex>
-// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_22]] : vector<1x4xindex>
-// CHECK: %[[VAL_25:.*]] = vector.mask %[[VAL_10]] { vector.gather %[[VAL_0]]{{\[}}%[[VAL_17]], %[[VAL_17]]] {{\[}}%[[VAL_24]]], %[[VAL_15]], %[[VAL_16]] : tensor<?x?xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK: %[[VAL_26:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_27:.*]] = vector.mask %[[VAL_10]] { vector.transfer_write %[[VAL_25]], %[[VAL_2]]{{\[}}%[[VAL_26]], %[[VAL_26]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<?x?xf32> } : vector<1x4xi1> -> tensor<?x?xf32>
-// CHECK: return %[[VAL_27]] : tensor<?x?xf32>
-// CHECK: }
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index)
+
+/// Create the mask
+// CHECK: %[[C79:.*]] = arith.constant 79 : index
+// CHECK: %[[DIM_0_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_0_IDX]] : tensor<?x?xf32>
+// CHECK: %[[DIM_1_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_1_IDX]] : tensor<?x?xf32>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x4xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+
+/// Calculate the index vector
+// CHECK: %[[STEP:.*]] = vector.step : vector<4xindex>
+// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<4xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<4xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<4xindex> to vector<4xindex>
+
+/// Extract the starting point from the index vector
+// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<4xindex>
+
+/// Final read and write
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+// CHECK: %[[WRITE:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : vector<1x4xf32>, tensor<?x?xf32> } : vector<1x4xi1> -> tensor<?x?xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
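The `vector.step`/`vector.extractelement` pair in the CHECK lines above is the idiom the vectoriser uses to collapse a linear index vector into the scalar base offset of a contiguous read. In isolation (a hypothetical function; the identity `vector.shape_cast` from the test is elided here since it folds away):

```mlir
// Build [idx, idx+1, idx+2, idx+3], then take lane 0 as the scalar
// starting offset; the remaining lanes are implied by contiguity.
func.func @index_start(%idx: index) -> index {
  %c0 = arith.constant 0 : i32
  %step = vector.step : vector<4xindex>
  %bc = vector.broadcast %idx : index to vector<4xindex>
  %vec = arith.addi %step, %bc : vector<4xindex>
  %start = vector.extractelement %vec[%c0 : i32] : vector<4xindex>
  return %start : index
}
```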
@@ -177,6 +179,65 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
+    %src: tensor<?x?xf32>,
+    %output : tensor<?x?xf32>,
+    %idx: index) -> tensor<?x?xf32> {
+
+  %c79 = arith.constant 79 : index
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%output : tensor<?x?xf32>) {
+  ^bb0(%out: f32):
+    %2 = linalg.index 1 : index
+    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+    %extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index)
+
+/// Create the mask
+// CHECK: %[[C79:.*]] = arith.constant 79 : index
+// CHECK: %[[DIM_0_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_0_IDX]] : tensor<?x?xf32>
+// CHECK: %[[DIM_1_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_1_IDX]] : tensor<?x?xf32>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x[4]xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+
+/// Calculate the index vector
+// CHECK: %[[STEP:.*]] = vector.step : vector<[4]xindex>
+// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<[4]xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<[4]xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<[4]xindex> to vector<[4]xindex>
+
+/// Extract the starting point from the index vector
+// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<[4]xindex>
+
+/// Final read and write
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+// CHECK: %[[WRITE:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : vector<1x[4]xf32>, tensor<?x?xf32> } : vector<1x[4]xi1> -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, [4]] {vectorize_nd_extract} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
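The scalable twin above differs from the fixed-width test only in the trailing vector dimension: `vector_sizes [1, [4]]` requests `vscale * 4` lanes, and the same `vector.create_mask` keeps accesses in bounds whatever `vscale` is at runtime. A reduced sketch of that mask (hand-written, for illustration):

```mlir
// A 1 x (vscale * 4) mask: lanes with trailing index >= %n are off, so
// the masked transfer_read/transfer_write never touch out-of-bounds data.
func.func @scalable_mask(%n: index) -> vector<1x[4]xi1> {
  %c1 = arith.constant 1 : index
  %mask = vector.create_mask %c1, %n : vector<1x[4]xi1>
  return %mask : vector<1x[4]xi1>
}
```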
 func.func @masked_vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16xf32>, %arg0: index, %extracted_slice : tensor<1x3xf32>) -> tensor<1x3xf32> {
   %c16 = arith.constant 16 : index
   %1 = linalg.generic {