Commit aae7571
[mlir][linalg] Upgrade vectorisation of tensor.extract
This PR removes the assumption that reading from a dynamic tensor is always a gather load:

```mlir
%extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
```

That assumption was originally introduced to simplify the implementation and to reduce the number of cases to consider. Now that the vectorisation of `tensor.extract` has been around for > 1 year and has been quite stable, we can safely relax it.

This is a relatively small change: rather than using the parent linalg Op to infer the target output shape (not possible with dynamic shapes), the vectorizer will use the (previously constructed) output vector shape instead.

As expected, the following test required updating (`vector.gather` -> `vector.transfer_read`):

* @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous

A similar test for scalable vectors is also added.
1 parent 2ba3fe7 commit aae7571
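For context, the updated test below (reproduced from this commit's diff) shows the pattern in question: a `tensor.extract` from a dynamic tensor whose trailing index advances linearly via `linalg.index`. Before this patch the dynamic shape alone forced a `vector.gather`; after it, the access is recognised as contiguous and lowered to a masked `vector.transfer_read`.

```mlir
func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
    %src: tensor<?x?xf32>,
    %output : tensor<?x?xf32>,
    %idx: index) -> tensor<?x?xf32> {

  %c79 = arith.constant 79 : index
  %1 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } outs(%output : tensor<?x?xf32>) {
  ^bb0(%out: f32):
    %2 = linalg.index 1 : index
    // Trailing index = linalg.index plus a loop-invariant offset -> contiguous.
    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
    %extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
    linalg.yield %extracted : f32
  } -> tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}
```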

3 files changed (+123, -73 lines)

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 25 additions & 33 deletions
```diff
@@ -808,14 +808,13 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
 
 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
 
-/// Checks whether /p val can be used for calculating a loop invariant index.
-static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
+/// Checks whether `val` can be used for calculating a loop invariant index.
+static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val, VectorType resType) {
 
-  auto targetShape = linalgOp.getStaticLoopRanges();
-  assert(((llvm::count_if(targetShape,
+  assert(((llvm::count_if(resType.getShape(),
                           [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
          "n-D vectors are not yet supported");
-  assert(targetShape.back() != 1 &&
+  assert(resType.getShape().back() != 1 &&
          "1-D vectors with the trailing dim eqaual 1 are not yet supported");
 
   // Blocks outside _this_ linalg.generic are effectively loop invariant.
@@ -849,7 +848,7 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
 
   bool result = true;
   for (auto op : ancestor->getOperands())
-    result &= isLoopInvariantIdx(linalgOp, op);
+    result &= isLoopInvariantIdx(linalgOp, op, resType);
 
   return result;
 }
@@ -871,13 +870,12 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
 /// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
 /// updated to `true` when such an op is found.
 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
-                                bool &foundIndexOp) {
+                                bool &foundIndexOp, VectorType resType) {
 
-  auto targetShape = linalgOp.getStaticLoopRanges();
-  assert(((llvm::count_if(targetShape,
+  assert(((llvm::count_if(resType.getShape(),
                           [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
          "n-D vectors are not yet supported");
-  assert(targetShape.back() != 1 &&
+  assert(resType.getShape().back() != 1 &&
          "1-D vectors with the trailing dim 1 are not yet supported");
 
   // Blocks outside _this_ linalg.generic are effectively loop invariant.
@@ -912,46 +910,40 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
 
   bool result = false;
   for (auto op : ancestor->getOperands())
-    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
+    result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
 
   return result;
 }
 
 /// Infer the memory access pattern for the input ExtractOp
 ///
-/// Based on the operation shapes and indices (usually based on the iteration
-/// space of the parent `linalgOp` operation), decides whether the input
-/// ExtractOp is a contiguous load (including a broadcast of a scalar) or a
-/// gather load.
+/// Based on the ExtractOp result shape and the access indices, decides whether
+/// this Op corresponds to a contiguous load (including a broadcast of a scalar)
+/// or a gather load. When analysing the ExtractOp indices (to identify
+/// contiguous loads), this method looks for "loop" invariant indices (e.g.
+/// block arguments) and indices that change linearly (e.g. via `linalg.index`
+/// Op).
 ///
 /// Note that it is always safe to use gather load operations for contiguous
 /// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
 /// that `extractOp` is a gather load.
 static VectorMemoryAccessKind
 getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
-                                    LinalgOp &linalgOp) {
+                                    LinalgOp &linalgOp, VectorType resType) {
 
-  auto targetShape = linalgOp.getStaticLoopRanges();
   auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
 
-  // 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast.
+  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
   if (inputShape.getShape().empty())
     return VectorMemoryAccessKind::ScalarBroadcast;
 
-  // 0.2 In the case of dynamic shapes just bail-out and assume that it's a
-  // gather load.
-  // TODO: Relax this condition.
-  if (linalgOp.hasDynamicShape())
-    return VectorMemoryAccessKind::Gather;
-
   // 1. Assume that it's a gather load when reading _into_:
-  //    * an n-D "vector", like `tensor<1x2x4xi32` or `tensor<2x1x4xi32>`, or
-  //    * a 1-D "vector" with the trailing dim equal 1, e.g. `tensor<1x4x1xi32`.
+  //    * an n-D "vector", like `vector<1x2x4xi32` or `vector<2x1x4xi32>`, or
+  //    * a 1-D "vector" with the trailing dim equal 1, e.g. `vector<1x4x1xi32>`.
   // TODO: Relax these conditions.
-  // FIXME: This condition assumes non-dynamic sizes.
-  if ((llvm::count_if(targetShape,
+  if ((llvm::count_if(resType.getShape(),
                       [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
-      targetShape.back() == 1)
+      resType.getShape().back() == 1)
     return VectorMemoryAccessKind::Gather;
 
   // 2. Assume that it's a gather load when reading _from_ a tensor for which
@@ -972,7 +964,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
     if (inputShape.getShape()[i] == 1)
       continue;
 
-    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
+    leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
   }
 
   if (!leadingIdxsLoopInvariant) {
@@ -989,7 +981,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   // 4a. Scalar broadcast load
   // If the trailing index is loop invariant then this is a scalar load.
   if (leadingIdxsLoopInvariant &&
-      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
+      isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
     LDBG("Found scalar broadcast load: " << extractOp);
 
     return VectorMemoryAccessKind::ScalarBroadcast;
@@ -1001,7 +993,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   // This is what the following bool captures.
   bool foundIndexOp = false;
   bool isContiguousLoad =
-      isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
+      isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp, resType);
   isContiguousLoad &= foundIndexOp;
 
   if (isContiguousLoad) {
@@ -1042,7 +1034,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
       rewriter.create<arith::ConstantIndexOp>(loc, 0));
 
   VectorMemoryAccessKind memAccessKind =
-      getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
+      getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
 
   // 1. Handle gather access
   if (memAccessKind == VectorMemoryAccessKind::Gather) {
```
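For contrast with the contiguous case exercised by the updated tests, here is a hedged sketch (illustrative IR, not part of this patch) of an access that the relaxed analysis still classifies as a gather load: the trailing index is read from another tensor, so it is neither loop invariant nor linear in `linalg.index`, and the vectorizer keeps the safe `vector.gather` default.

```mlir
func.func @gather_example(%src: tensor<?x?xf32>, %indices: tensor<?xindex>,
                          %output: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c79 = arith.constant 79 : index
  %0 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } outs(%output : tensor<?x?xf32>) {
  ^bb0(%out: f32):
    %i = linalg.index 1 : index
    // The trailing index is data-dependent, so contiguity cannot be proven:
    // VectorMemoryAccessKind::Gather is the (safe) fallback.
    %j = tensor.extract %indices[%i] : tensor<?xindex>
    %extracted = tensor.extract %src[%c79, %j] : tensor<?x?xf32>
    linalg.yield %extracted : f32
  } -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
```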

mlir/test/Dialect/Linalg/vectorization-scalable.mlir

Lines changed: 3 additions & 6 deletions
```diff
@@ -162,17 +162,14 @@ func.func @vectorize_linalg_index(%arg0: tensor<3x3x?xf32>, %arg1: tensor<1x1x?x
 
 // CHECK-LABEL: @vectorize_linalg_index
 // CHECK-SAME: %[[SRC:.*]]: tensor<3x3x?xf32>, %[[DST:.*]]: tensor<1x1x?xf32>
-// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x[4]xf32>
-// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x[4]xi1>
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 // CHECK: %[[DST_DIM2:.*]] = tensor.dim %[[DST]], %[[C2]] : tensor<1x1x?xf32>
-// CHECK: %[[DST_MASK:.*]] = vector.create_mask %[[C1]], %[[C1]], %[[DST_DIM2]] : vector<1x1x[4]xi1>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[C1]], %[[C1]], %[[DST_DIM2]] : vector<1x1x[4]xi1>
 // CHECK: %[[INDEX_VEC:.*]] = vector.step : vector<[4]xindex>
-// CHECK: %[[INDEX_VEC_BCAST:.*]] = vector.broadcast %[[INDEX_VEC]] : vector<[4]xindex> to vector<1x1x[4]xindex>
-// CHECK: %[[GATHER:.*]] = vector.mask %[[DST_MASK]] { vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {{\[}}%[[INDEX_VEC_BCAST]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3x?xf32>, vector<1x1x[4]xindex>, vector<1x1x[4]xi1>, vector<1x1x[4]xf32> into vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
-// CHECK: %[[OUT:.*]] = vector.mask %[[DST_MASK]] { vector.transfer_write %[[GATHER]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]][%c0, %c0, %2], %cst {in_bounds = [true, true, true]} : tensor<3x3x?xf32>, vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
+// CHECK: %[[OUT:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
 // CHECK: return %[[OUT]] : tensor<1x1x?xf32>
 
 module attributes {transform.with_named_sequence} {
```
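The transform sequence that drives this test falls outside the hunk. A plausible sketch of what it looks like, modelled on the sequence added in the test file below; the vector sizes [1, 1, [4]] are an assumption inferred from the vector<1x1x[4]xf32> type in the CHECK lines:

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // Assumed sizes: [1, 1, [4]] to match the vector<1x1x[4]xf32> type above.
    transform.structured.vectorize %0 vector_sizes [1, 1, [4]] {vectorize_nd_extract} : !transform.any_op
    transform.yield
  }
}
```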

mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir

Lines changed: 95 additions & 34 deletions
```diff
@@ -120,52 +120,54 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<?x?xf32>, %arg0: index, %extracted_slice : tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
+    %src: tensor<?x?xf32>,
+    %output : tensor<?x?xf32>,
+    %idx: index) -> tensor<?x?xf32> {
+
   %c79 = arith.constant 79 : index
   %1 = linalg.generic {
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]
-  } outs(%extracted_slice : tensor<?x?xf32>) {
+  } outs(%output : tensor<?x?xf32>) {
   ^bb0(%out: f32):
     %2 = linalg.index 1 : index
-    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg0)
-    %extracted = tensor.extract %6[%c79, %3] : tensor<?x?xf32>
+    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+    %extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
     linalg.yield %extracted : f32
   } -> tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
 
 // CHECK-LABEL: func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: index,
-// CHECK-SAME: %[[VAL_2:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 79 : index
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_2]], %[[VAL_4]] : tensor<?x?xf32>
-// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_2]], %[[VAL_6]] : tensor<?x?xf32>
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_7]] : vector<1x4xi1>
-// CHECK: %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK: %[[VAL_12:.*]] = vector.step : vector<4xindex>
-// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
-// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : vector<4xindex>
-// CHECK-DAG: %[[VAL_15:.*]] = arith.constant dense<true> : vector<1x4xi1>
-// CHECK-DAG: %[[VAL_16:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
-// CHECK-DAG: %[[VAL_17:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_18:.*]] = arith.constant dense<79> : vector<1x4xindex>
-// CHECK-DAG: %[[VAL_19:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_20:.*]] = tensor.dim %[[VAL_0]], %[[VAL_19]] : tensor<?x?xf32>
-// CHECK: %[[VAL_21:.*]] = vector.broadcast %[[VAL_20]] : index to vector<1x4xindex>
-// CHECK: %[[VAL_22:.*]] = arith.muli %[[VAL_18]], %[[VAL_21]] : vector<1x4xindex>
-// CHECK: %[[VAL_23:.*]] = vector.broadcast %[[VAL_14]] : vector<4xindex> to vector<1x4xindex>
-// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_22]] : vector<1x4xindex>
-// CHECK: %[[VAL_25:.*]] = vector.mask %[[VAL_10]] { vector.gather %[[VAL_0]]{{\[}}%[[VAL_17]], %[[VAL_17]]] {{\[}}%[[VAL_24]]], %[[VAL_15]], %[[VAL_16]] : tensor<?x?xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK: %[[VAL_26:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_27:.*]] = vector.mask %[[VAL_10]] { vector.transfer_write %[[VAL_25]], %[[VAL_2]]{{\[}}%[[VAL_26]], %[[VAL_26]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<?x?xf32> } : vector<1x4xi1> -> tensor<?x?xf32>
-// CHECK: return %[[VAL_27]] : tensor<?x?xf32>
-// CHECK: }
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index)
+
+/// Create the mask
+// CHECK: %[[C79:.*]] = arith.constant 79 : index
+// CHECK: %[[DIM_0_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_0_IDX]] : tensor<?x?xf32>
+// CHECK: %[[DIM_1_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_1_IDX]] : tensor<?x?xf32>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x4xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+
+/// Calculate the index vector
+// CHECK: %[[STEP:.*]] = vector.step : vector<4xindex>
+// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<4xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<4xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<4xindex> to vector<4xindex>
+
+/// Extract the starting point from the index vector
+// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<4xindex>
+
+// Final read and write
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+// CHECK: %[[VAL_24:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : vector<1x4xf32>, tensor<?x?xf32> } : vector<1x4xi1> -> tensor<?x?xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -177,6 +179,65 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
+    %src: tensor<?x?xf32>,
+    %output : tensor<?x?xf32>,
+    %idx: index) -> tensor<?x?xf32> {
+
+  %c79 = arith.constant 79 : index
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%output : tensor<?x?xf32>) {
+  ^bb0(%out: f32):
+    %2 = linalg.index 1 : index
+    %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+    %extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index)
+
+/// Create the mask
+// CHECK: %[[C79:.*]] = arith.constant 79 : index
+// CHECK: %[[DIM_0_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_0_IDX]] : tensor<?x?xf32>
+// CHECK: %[[DIM_1_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_1_IDX]] : tensor<?x?xf32>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x[4]xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+
+/// Calculate the index vector
+// CHECK: %[[STEP:.*]] = vector.step : vector<[4]xindex>
+// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<[4]xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<[4]xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<[4]xindex> to vector<[4]xindex>
+
+/// Extract the starting point from the index vector
+// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<[4]xindex>
+
+// Final read and write
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+// CHECK: %[[VAL_24:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : vector<1x[4]xf32>, tensor<?x?xf32> } : vector<1x[4]xi1> -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, [4]] {vectorize_nd_extract} : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @masked_vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16xf32>, %arg0: index, %extracted_slice : tensor<1x3xf32>) -> tensor<1x3xf32> {
   %c16 = arith.constant 16 : index
   %1 = linalg.generic {
```