Skip to content

Commit 9c1ce31

Browse files
authored
[mlir][vector] Add unroll patterns for vector.load and vector.store (#143420)
This PR adds unroll patterns for vector.load and vector.store. It is a follow-up to #137558.
1 parent b6445ac commit 9c1ce31

File tree

5 files changed

+170
-9
lines changed

5 files changed

+170
-9
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1736,7 +1736,9 @@ def Vector_TransferWriteOp :
17361736
let hasVerifier = 1;
17371737
}
17381738

1739-
def Vector_LoadOp : Vector_Op<"load"> {
1739+
def Vector_LoadOp : Vector_Op<"load", [
1740+
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
1741+
]> {
17401742
let summary = "reads an n-D slice of memory into an n-D vector";
17411743
let description = [{
17421744
The 'vector.load' operation reads an n-D slice of memory into an n-D
@@ -1822,7 +1824,9 @@ def Vector_LoadOp : Vector_Op<"load"> {
18221824
"$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
18231825
}
18241826

1825-
def Vector_StoreOp : Vector_Op<"store"> {
1827+
def Vector_StoreOp : Vector_Op<"store", [
1828+
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
1829+
]> {
18261830
let summary = "writes an n-D vector to an n-D slice of memory";
18271831
let description = [{
18281832
The 'vector.store' operation writes an n-D vector to an n-D slice of memory.

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5371,6 +5371,10 @@ OpFoldResult LoadOp::fold(FoldAdaptor) {
53715371
return OpFoldResult();
53725372
}
53735373

5374+
// Reports the shape of this op's vector type as the shape to unroll.
std::optional<SmallVector<int64_t, 4>> LoadOp::getShapeForUnroll() {
  ArrayRef<int64_t> shape = getVectorType().getShape();
  return SmallVector<int64_t, 4>(shape.begin(), shape.end());
}
5377+
53745378
//===----------------------------------------------------------------------===//
53755379
// StoreOp
53765380
//===----------------------------------------------------------------------===//
@@ -5406,6 +5410,10 @@ LogicalResult StoreOp::fold(FoldAdaptor adaptor,
54065410
return memref::foldMemRefCast(*this);
54075411
}
54085412

5413+
// Reports the shape of this op's vector type as the shape to unroll.
std::optional<SmallVector<int64_t, 4>> StoreOp::getShapeForUnroll() {
  ArrayRef<int64_t> shape = getVectorType().getShape();
  return SmallVector<int64_t, 4>(shape.begin(), shape.end());
}
5416+
54095417
//===----------------------------------------------------------------------===//
54105418
// MaskedLoadOp
54115419
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 112 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,28 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
5454
return slicedIndices;
5555
}
5656

57+
// Compute the new indices by adding `offsets` to `originalIndices`.
58+
// If m < n (m = offsets.size(), n = originalIndices.size()),
59+
// then only the trailing m values in `originalIndices` are updated.
60+
static SmallVector<Value> sliceLoadStoreIndices(PatternRewriter &rewriter,
61+
Location loc,
62+
OperandRange originalIndices,
63+
ArrayRef<int64_t> offsets) {
64+
assert(offsets.size() <= originalIndices.size() &&
65+
"Offsets should not exceed the number of original indices");
66+
SmallVector<Value> indices(originalIndices);
67+
68+
auto start = indices.size() - offsets.size();
69+
for (auto [i, offset] : llvm::enumerate(offsets)) {
70+
if (offset != 0) {
71+
indices[start + i] = rewriter.create<arith::AddIOp>(
72+
loc, originalIndices[start + i],
73+
rewriter.create<arith::ConstantIndexOp>(loc, offset));
74+
}
75+
}
76+
return indices;
77+
}
78+
5779
// Clones `op` into a new operations that takes `operands` and returns
5880
// `resultTypes`.
5981
static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc,
@@ -631,6 +653,90 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
631653
vector::UnrollVectorOptions options;
632654
};
633655

656+
// Unrolls a `vector.load` into loads of the target shape selected by
// `options` and reassembles the pieces into a vector of the original type.
//
// One `vector.load` is emitted per tile offset produced by
// `StaticTileOffsetRange`; each loaded tile is inserted at its offsets into an
// accumulator that starts out as an all-zero constant of the original vector
// type. The match fails when no target shape is available for the op, or when
// the base memref stores vector elements.
struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
  UnrollLoadPattern(MLIRContext *context,
                    const vector::UnrollVectorOptions &options,
                    PatternBenefit benefit = 1)
      : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = loadOp.getVectorType();

    // When the memref element type is itself a vector, the loaded vector type
    // has to equal that element type. A sliced load would have a smaller
    // result type and therefore be invalid, so leave such ops untouched.
    if (isa<VectorType>(loadOp.getMemRefType().getElementType()))
      return rewriter.notifyMatchFailure(
          loadOp, "cannot unroll a load from a memref of vector elements");

    auto targetShape = getTargetShape(options, loadOp);
    if (!targetShape)
      return failure();

    Location loc = loadOp.getLoc();
    ArrayRef<int64_t> originalShape = vecType.getShape();
    SmallVector<int64_t> strides(targetShape->size(), 1);

    // Accumulator that the loaded tiles are inserted into one by one.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, vecType, rewriter.getZeroAttr(vecType));

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalShape.size(), loadOp, options);

    auto targetVecType =
        VectorType::get(*targetShape, vecType.getElementType());

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceLoadStoreIndices(rewriter, loc, loadOp.getIndices(), offsets);
      // NOTE(review): the sliced load is built from the result type, base and
      // indices only; nothing else from the original op is forwarded. Confirm
      // that this is intended.
      Value slicedLoad = rewriter.create<vector::LoadOp>(
          loc, targetVecType, loadOp.getBase(), indices);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, slicedLoad, result, offsets, strides);
    }
    rewriter.replaceOp(loadOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
699+
700+
// Unrolls a `vector.store` into stores of the target shape selected by
// `options`.
//
// For every tile offset produced by `StaticTileOffsetRange`, the matching
// slice is extracted from the stored vector and written with its own
// `vector.store`; the original op is erased afterwards. The match fails when
// no target shape is available for the op, or when the base memref stores
// vector elements.
struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
  UnrollStorePattern(MLIRContext *context,
                     const vector::UnrollVectorOptions &options,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    VectorType vecType = storeOp.getVectorType();

    // When the memref element type is itself a vector, the stored vector type
    // has to equal that element type. A sliced store would write a smaller
    // vector type and therefore be invalid, so leave such ops untouched.
    if (isa<VectorType>(storeOp.getMemRefType().getElementType()))
      return rewriter.notifyMatchFailure(
          storeOp, "cannot unroll a store to a memref of vector elements");

    auto targetShape = getTargetShape(options, storeOp);
    if (!targetShape)
      return failure();

    Location loc = storeOp.getLoc();
    ArrayRef<int64_t> originalShape = vecType.getShape();
    SmallVector<int64_t> strides(targetShape->size(), 1);

    Value base = storeOp.getBase();
    Value vector = storeOp.getValueToStore();

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalShape.size(), storeOp, options);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalShape, *targetShape, loopOrder)) {
      SmallVector<Value> indices =
          sliceLoadStoreIndices(rewriter, loc, storeOp.getIndices(), offsets);
      Value slice = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
          loc, vector, offsets, *targetShape, strides);
      // NOTE(review): the sliced store is built from the value, base and
      // indices only; nothing else from the original op is forwarded. Confirm
      // that this is intended.
      rewriter.create<vector::StoreOp>(loc, slice, base, indices);
    }
    rewriter.eraseOp(storeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};
739+
634740
struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
635741
UnrollBroadcastPattern(MLIRContext *context,
636742
const vector::UnrollVectorOptions &options,
@@ -699,10 +805,10 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
699805
// Adds the vector unrolling patterns to `patterns`; every pattern is
// constructed with the pattern set's context, `options` and `benefit`.
// This function is shown as a diff: the lines following a `NNN-` marker are
// the previous registration, the lines following a `NNN+` marker are the new
// one, which additionally registers UnrollLoadPattern and UnrollStorePattern.
void mlir::vector::populateVectorUnrollPatterns(
700806
RewritePatternSet &patterns, const UnrollVectorOptions &options,
701807
PatternBenefit benefit) {
702-
patterns
703-
.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
704-
UnrollContractionPattern, UnrollElementwisePattern,
705-
UnrollReductionPattern, UnrollMultiReductionPattern,
706-
UnrollTransposePattern, UnrollGatherPattern, UnrollBroadcastPattern>(
707-
patterns.getContext(), options, benefit);
808+
patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
809+
UnrollContractionPattern, UnrollElementwisePattern,
810+
UnrollReductionPattern, UnrollMultiReductionPattern,
811+
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
812+
UnrollStorePattern, UnrollBroadcastPattern>(
813+
patterns.getContext(), options, benefit);
708814
}

mlir/test/Dialect/Vector/vector-unroll-options.mlir

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,3 +378,45 @@ func.func @vector_broadcast_with_tailing_unit_dim(%v: vector<4x1xf32>) -> vector
378378
// CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<2x1xf32> to vector<2x2xf32>
379379
// CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
380380
// CHECK: return [[r3]] : vector<4x4xf32>
381+
382+
383+
// Test input for load unrolling: reads a whole 4x4 f16 vector from %mem
// starting at [0, 0]. The CHECK lines after this function expect four 2x2
// loads that are recombined with vector.insert_strided_slice.
func.func @vector_load_2D(%mem: memref<4x4xf16>) -> vector<4x4xf16> {
384+
%c0 = arith.constant 0 : index
385+
%0 = vector.load %mem[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
386+
return %0 : vector<4x4xf16>
387+
}
388+
389+
// CHECK-LABEL: func.func @vector_load_2D(
390+
// CHECK-SAME: %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> {
391+
// CHECK: %[[C2:.*]] = arith.constant 2 : index
392+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
393+
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf16>
394+
// CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
395+
// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
396+
// CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C0]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
397+
// CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
398+
// CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
399+
// CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
400+
// CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
401+
// CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf16> into vector<4x4xf16>
402+
// CHECK: return %[[V7]] : vector<4x4xf16>
403+
404+
405+
// Test input for store unrolling: writes a whole 4x4 f16 vector %v into %mem
// starting at [0, 0]. The CHECK lines after this function expect four 2x2
// vector.extract_strided_slice ops, each followed by a 2x2 vector.store.
func.func @vector_store_2D(%mem: memref<4x4xf16>, %v: vector<4x4xf16>) {
406+
%c0 = arith.constant 0 : index
407+
vector.store %v, %mem[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
408+
return
409+
}
410+
411+
// CHECK-LABEL: func.func @vector_store_2D(
412+
// CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) {
413+
// CHECK: %[[C2:.*]] = arith.constant 2 : index
414+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
415+
// CHECK: %[[V0:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
416+
// CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
417+
// CHECK: %[[V1:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
418+
// CHECK: vector.store %[[V1]], %[[ARG0]][%[[C0]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>
419+
// CHECK: %[[V2:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
420+
// CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<2x2xf16>
421+
// CHECK: %[[V3:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf16> to vector<2x2xf16>
422+
// CHECK: vector.store %[[V3]], %[[ARG0]][%[[C2]], %[[C2]]] : memref<4x4xf16>, vector<2x2xf16>

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,8 @@ struct TestVectorUnrollingPatterns
163163
.setFilterConstraint([](Operation *op) {
164164
return success(
165165
isa<arith::AddFOp, vector::FMAOp, vector::MultiDimReductionOp,
166-
vector::BroadcastOp>(op));
166+
vector::BroadcastOp, vector::LoadOp, vector::StoreOp>(
167+
op));
167168
}));
168169
populateVectorUnrollPatterns(
169170
patterns, UnrollVectorOptions()

0 commit comments

Comments
 (0)