[mlir][linalg] Vectorisation of tensor.extract - dynamic shapes #100582

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
57 changes: 25 additions & 32 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -810,12 +810,12 @@ static Value calculateGatherOffset(RewriterBase &rewriter,

enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };

-/// Checks whether /p val can be used for calculating a loop invariant index.
-static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
+/// Checks whether `val` can be used for calculating a loop invariant index.
+static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
+                               VectorType resType) {

-  auto targetShape = linalgOp.getStaticLoopRanges();
-  assert(llvm::count_if(targetShape,
-                        [](int64_t dimSize) { return dimSize > 1; }) == 1 &&
+  assert(((llvm::count_if(resType.getShape(),
+                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
"n-D vectors are not yet supported");

// Blocks outside _this_ linalg.generic are effectively loop invariant.
@@ -849,7 +849,7 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {

bool result = true;
for (auto op : ancestor->getOperands())
-    result &= isLoopInvariantIdx(linalgOp, op);
+    result &= isLoopInvariantIdx(linalgOp, op, resType);

return result;
}
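For illustration (a hypothetical example, not part of this patch): the kind of index that `isLoopInvariantIdx` accepts. Both extract indices are defined outside the `linalg.generic`, so every vector lane reads the same element and the access can be classified as `ScalarBroadcast`. The function name and shapes below are assumptions made for the sketch.

```mlir
// A minimal sketch, assuming vectorization with vector sizes [4].
func.func @invariant_idx(%src: tensor<8x8xf32>, %i: index, %j: index,
                         %init: tensor<4xf32>) -> tensor<4xf32> {
  %res = linalg.generic {
    indexing_maps = [affine_map<(d0) -> (d0)>],
    iterator_types = ["parallel"]
  } outs(%init : tensor<4xf32>) {
  ^bb0(%out: f32):
    // %i and %j are block arguments defined outside this linalg.generic,
    // i.e. loop invariant -> every lane extracts the same scalar.
    %v = tensor.extract %src[%i, %j] : tensor<8x8xf32>
    linalg.yield %v : f32
  } -> tensor<4xf32>
  return %res : tensor<4xf32>
}
```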
@@ -871,10 +871,9 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
/// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
/// updated to `true` when such an op is found.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
-                                bool &foundIndexOp) {
+                                bool &foundIndexOp, VectorType resType) {

-  auto targetShape = linalgOp.getStaticLoopRanges();
-  assert(((llvm::count_if(targetShape,
+  assert(((llvm::count_if(resType.getShape(),
[](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
"n-D vectors are not yet supported");

@@ -910,44 +909,38 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,

bool result = false;
for (auto op : ancestor->getOperands())
-    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
+    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);

return result;
}
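Again for illustration only (hypothetical, simplified): the pattern `isContiguousLoadIdx` is designed to recognise. The trailing extract index comes from `linalg.index` over the trailing loop dimension, so consecutive lanes read consecutive elements:

```mlir
// A minimal sketch, assuming vectorization with vector sizes [1, 4].
func.func @contiguous_idx(%src: tensor<8x8xf32>,
                          %init: tensor<1x4xf32>) -> tensor<1x4xf32> {
  %c0 = arith.constant 0 : index
  %res = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } outs(%init : tensor<1x4xf32>) {
  ^bb0(%out: f32):
    // The trailing index advances linearly with the trailing loop dim,
    // so a contiguous (transfer_read style) load can be generated.
    %j = linalg.index 1 : index
    %v = tensor.extract %src[%c0, %j] : tensor<8x8xf32>
    linalg.yield %v : f32
  } -> tensor<1x4xf32>
  return %res : tensor<1x4xf32>
}
```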

/// Infer the memory access pattern for the input ExtractOp
///
-/// Based on the operation shapes and indices (usually based on the iteration
-/// space of the parent `linalgOp` operation), decides whether the input
-/// ExtractOp is a contiguous load (including a broadcast of a scalar) or a
-/// gather load.
+/// Based on the ExtractOp result shape and the access indices, decides whether
+/// this Op corresponds to a contiguous load (including a broadcast of a scalar)
+/// or a gather load. When analysing the ExtractOp indices (to identify
+/// contiguous loads), this method looks for "loop" invariant indices (e.g.
+/// block arguments) and indices that change linearly (e.g. via the
+/// `linalg.index` Op).
///
/// Note that it is always safe to use gather load operations for contiguous
/// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
/// that `extractOp` is a gather load.
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
-                                    LinalgOp &linalgOp) {
+                                    LinalgOp &linalgOp, VectorType resType) {

-  auto targetShape = linalgOp.getStaticLoopRanges();
auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());

-  // 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast.
+  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
if (inputShape.getShape().empty())
return VectorMemoryAccessKind::ScalarBroadcast;

-  // 0.2 In the case of dynamic shapes just bail-out and assume that it's a
-  // gather load.
-  // TODO: Relax this condition.
-  if (linalgOp.hasDynamicShape())
-    return VectorMemoryAccessKind::Gather;

// True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
// otherwise.
-  bool isOutput1DVector = (llvm::count_if(targetShape, [](int64_t dimSize) {
-                             return dimSize > 1;
-                           }) == 1);

+  bool isOutput1DVector =
+      (llvm::count_if(resType.getShape(),
+                      [](int64_t dimSize) { return dimSize > 1; }) == 1);
// 1. Assume that it's a gather load when reading non-1D vector.
if (!isOutput1DVector)
return VectorMemoryAccessKind::Gather;
@@ -965,7 +958,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
if (inputShape.getShape()[i] == 1)
continue;

-    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
+    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
}

if (!leadingIdxsLoopInvariant) {
@@ -982,7 +975,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
// 3a. Scalar broadcast load
// If the trailing index is loop invariant then this is a scalar load.
if (leadingIdxsLoopInvariant &&
-      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
+      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
LDBG("Found scalar broadcast load: " << extractOp);

return VectorMemoryAccessKind::ScalarBroadcast;
@@ -993,8 +986,8 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
// This effectively means that it must be based on the trailing loop index.
// This is what the following bool captures.
bool foundIndexOp = false;
-  bool isContiguousLoad =
-      isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
+  bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
+                                              foundIndexOp, resType);
isContiguousLoad &= foundIndexOp;

if (isContiguousLoad) {
@@ -1035,7 +1028,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
rewriter.create<arith::ConstantIndexOp>(loc, 0));

VectorMemoryAccessKind memAccessKind =
-      getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
+      getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);

// 1. Handle gather access
if (memAccessKind == VectorMemoryAccessKind::Gather) {
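The deleted `hasDynamicShape()` bail-out above is the crux of this patch: dynamic shapes no longer force a gather, because the classification now relies on the result vector type rather than the static loop ranges. For contrast, a sketch (hypothetical, not from the patch) of an access that must still be classified as `Gather`, because the trailing index is data dependent rather than loop invariant or linear:

```mlir
// A minimal sketch; names and shapes are assumptions.
func.func @data_dependent_idx(%src: tensor<8x8xf32>,
                              %indices: tensor<1x4xindex>,
                              %init: tensor<1x4xf32>) -> tensor<1x4xf32> {
  %c0 = arith.constant 0 : index
  %res = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } ins(%indices : tensor<1x4xindex>) outs(%init : tensor<1x4xf32>) {
  ^bb0(%idx: index, %out: f32):
    // The trailing index is loaded from another tensor: it is neither
    // loop invariant nor linear in the trailing loop dim -> gather.
    %v = tensor.extract %src[%c0, %idx] : tensor<8x8xf32>
    linalg.yield %v : f32
  } -> tensor<1x4xf32>
  return %res : tensor<1x4xf32>
}
```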
9 changes: 3 additions & 6 deletions mlir/test/Dialect/Linalg/vectorization-scalable.mlir
@@ -162,17 +162,14 @@ func.func @vectorize_linalg_index(%arg0: tensor<3x3x?xf32>, %arg1: tensor<1x1x?x

// CHECK-LABEL: @vectorize_linalg_index
// CHECK-SAME: %[[SRC:.*]]: tensor<3x3x?xf32>, %[[DST:.*]]: tensor<1x1x?xf32>
-// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x[4]xf32>
-// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x[4]xi1>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[DST_DIM2:.*]] = tensor.dim %[[DST]], %[[C2]] : tensor<1x1x?xf32>
-// CHECK: %[[DST_MASK:.*]] = vector.create_mask %[[C1]], %[[C1]], %[[DST_DIM2]] : vector<1x1x[4]xi1>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[C1]], %[[C1]], %[[DST_DIM2]] : vector<1x1x[4]xi1>
// CHECK: %[[INDEX_VEC:.*]] = vector.step : vector<[4]xindex>
// CHECK: %[[INDEX_VEC_BCAST:.*]] = vector.broadcast %[[INDEX_VEC]] : vector<[4]xindex> to vector<1x1x[4]xindex>
-// CHECK: %[[GATHER:.*]] = vector.mask %[[DST_MASK]] { vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {{\[}}%[[INDEX_VEC_BCAST]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3x?xf32>, vector<1x1x[4]xindex>, vector<1x1x[4]xi1>, vector<1x1x[4]xf32> into vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
-// CHECK: %[[OUT:.*]] = vector.mask %[[DST_MASK]] { vector.transfer_write %[[GATHER]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]][%c0, %c0, %2], %cst {in_bounds = [true, true, true]} : tensor<3x3x?xf32>, vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
+// CHECK: %[[OUT:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
// CHECK: return %[[OUT]] : tensor<1x1x?xf32>

module attributes {transform.with_named_sequence} {
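To make the CHECK changes above concrete, here is a side-by-side sketch (hypothetical values and types, not from the test) of the two read styles: the old lowering gathered with a per-lane index vector, whereas the new lowering issues a single contiguous read from a scalar start index:

```mlir
// A minimal sketch contrasting gather vs. contiguous reads.
func.func @read_styles(%src: tensor<3x3x?xf32>, %start: index,
                       %idx_vec: vector<1x1x[4]xindex>,
                       %mask: vector<1x1x[4]xi1>,
                       %pass: vector<1x1x[4]xf32>)
    -> (vector<1x1x[4]xf32>, vector<1x1x[4]xf32>) {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f32
  // Old style: one index per lane; no contiguity assumed.
  %g = vector.gather %src[%c0, %c0, %c0] [%idx_vec], %mask, %pass
    : tensor<3x3x?xf32>, vector<1x1x[4]xindex>, vector<1x1x[4]xi1>,
      vector<1x1x[4]xf32> into vector<1x1x[4]xf32>
  // New style: a scalar start index per dim; lanes read adjacent elements
  // (the dynamic trailing dim is handled by masking/padding).
  %r = vector.transfer_read %src[%c0, %c0, %start], %pad
    {in_bounds = [true, true, false]}
    : tensor<3x3x?xf32>, vector<1x1x[4]xf32>
  return %g, %r : vector<1x1x[4]xf32>, vector<1x1x[4]xf32>
}
```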
129 changes: 95 additions & 34 deletions mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
@@ -120,52 +120,54 @@ module attributes {transform.with_named_sequence} {

// -----

-func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<?x?xf32>, %arg0: index, %extracted_slice : tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
+    %src: tensor<?x?xf32>,
+    %output : tensor<?x?xf32>,
+    %idx: index) -> tensor<?x?xf32> {

  %c79 = arith.constant 79 : index
  %1 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
-  } outs(%extracted_slice : tensor<?x?xf32>) {
+  } outs(%output : tensor<?x?xf32>) {
  ^bb0(%out: f32):
    %2 = linalg.index 1 : index
-    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg0)
-    %extracted = tensor.extract %6[%c79, %3] : tensor<?x?xf32>
+    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+    %extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
    linalg.yield %extracted : f32
  } -> tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}

// CHECK-LABEL: func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: index,
-// CHECK-SAME: %[[VAL_2:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 79 : index
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_2]], %[[VAL_4]] : tensor<?x?xf32>
-// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_2]], %[[VAL_6]] : tensor<?x?xf32>
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_7]] : vector<1x4xi1>
-// CHECK: %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK: %[[VAL_12:.*]] = vector.step : vector<4xindex>
-// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
-// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : vector<4xindex>
-// CHECK-DAG: %[[VAL_15:.*]] = arith.constant dense<true> : vector<1x4xi1>
-// CHECK-DAG: %[[VAL_16:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
-// CHECK-DAG: %[[VAL_17:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_18:.*]] = arith.constant dense<79> : vector<1x4xindex>
-// CHECK-DAG: %[[VAL_19:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_20:.*]] = tensor.dim %[[VAL_0]], %[[VAL_19]] : tensor<?x?xf32>
-// CHECK: %[[VAL_21:.*]] = vector.broadcast %[[VAL_20]] : index to vector<1x4xindex>
-// CHECK: %[[VAL_22:.*]] = arith.muli %[[VAL_18]], %[[VAL_21]] : vector<1x4xindex>
-// CHECK: %[[VAL_23:.*]] = vector.broadcast %[[VAL_14]] : vector<4xindex> to vector<1x4xindex>
-// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_22]] : vector<1x4xindex>
-// CHECK: %[[VAL_25:.*]] = vector.mask %[[VAL_10]] { vector.gather %[[VAL_0]]{{\[}}%[[VAL_17]], %[[VAL_17]]] {{\[}}%[[VAL_24]]], %[[VAL_15]], %[[VAL_16]] : tensor<?x?xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK: %[[VAL_26:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_27:.*]] = vector.mask %[[VAL_10]] { vector.transfer_write %[[VAL_25]], %[[VAL_2]]{{\[}}%[[VAL_26]], %[[VAL_26]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<?x?xf32> } : vector<1x4xi1> -> tensor<?x?xf32>
-// CHECK: return %[[VAL_27]] : tensor<?x?xf32>
-// CHECK: }
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index)
+
+/// Create the mask
+// CHECK: %[[C79:.*]] = arith.constant 79 : index
+// CHECK: %[[DIM_0_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_0_IDX]] : tensor<?x?xf32>
+// CHECK: %[[DIM_1_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_1_IDX]] : tensor<?x?xf32>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x4xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+
+/// Calculate the index vector
+// CHECK: %[[STEP:.*]] = vector.step : vector<4xindex>
+// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<4xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<4xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<4xindex> to vector<4xindex>
+
+/// Extract the starting point from the index vector
+// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<4xindex>
+
+/// Final read and write
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+// CHECK: %[[VAL_24:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : vector<1x4xf32>, tensor<?x?xf32> } : vector<1x4xi1> -> tensor<?x?xf32>

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -177,6 +179,65 @@ module attributes {transform.with_named_sequence} {

// -----

+func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
+    %src: tensor<?x?xf32>,
+    %output : tensor<?x?xf32>,
+    %idx: index) -> tensor<?x?xf32> {
+
+  %c79 = arith.constant 79 : index
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%output : tensor<?x?xf32>) {
+  ^bb0(%out: f32):
+    %2 = linalg.index 1 : index
+    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+    %extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index)
+
+/// Create the mask
+// CHECK: %[[C79:.*]] = arith.constant 79 : index
+// CHECK: %[[DIM_0_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_0_IDX]] : tensor<?x?xf32>
+// CHECK: %[[DIM_1_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_1_IDX]] : tensor<?x?xf32>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x[4]xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+
+/// Calculate the index vector
+// CHECK: %[[STEP:.*]] = vector.step : vector<[4]xindex>
+// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<[4]xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<[4]xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<[4]xindex> to vector<[4]xindex>
+
+/// Extract the starting point from the index vector
+// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<[4]xindex>
+
+/// Final read and write
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+// CHECK: %[[VAL_24:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : vector<1x[4]xf32>, tensor<?x?xf32> } : vector<1x[4]xi1> -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, [4]] {vectorize_nd_extract} : !transform.any_op
+    transform.yield
+  }
+}

// -----

func.func @masked_vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16xf32>, %arg0: index, %extracted_slice : tensor<1x3xf32>) -> tensor<1x3xf32> {
%c16 = arith.constant 16 : index
%1 = linalg.generic {
Expand Down