@@ -54,6 +54,28 @@ static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets,
54
54
return slicedIndices;
55
55
}
56
56
57
+ // Compute the new indices by adding `offsets` to `originalIndices`.
58
+ // If m < n (m = offsets.size(), n = originalIndices.size()),
59
+ // then only the trailing m values in `originalIndices` are updated.
60
+ static SmallVector<Value> sliceLoadStoreIndices (PatternRewriter &rewriter,
61
+ Location loc,
62
+ OperandRange originalIndices,
63
+ ArrayRef<int64_t > offsets) {
64
+ assert (offsets.size () <= originalIndices.size () &&
65
+ " Offsets should not exceed the number of original indices" );
66
+ SmallVector<Value> indices (originalIndices);
67
+
68
+ auto start = indices.size () - offsets.size ();
69
+ for (auto [i, offset] : llvm::enumerate (offsets)) {
70
+ if (offset != 0 ) {
71
+ indices[start + i] = rewriter.create <arith::AddIOp>(
72
+ loc, originalIndices[start + i],
73
+ rewriter.create <arith::ConstantIndexOp>(loc, offset));
74
+ }
75
+ }
76
+ return indices;
77
+ }
78
+
57
79
// Clones `op` into a new operation that takes `operands` and returns
58
80
// `resultTypes`.
59
81
static Operation *cloneOpWithOperandsAndTypes (OpBuilder &builder, Location loc,
@@ -631,6 +653,90 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
631
653
vector::UnrollVectorOptions options;
632
654
};
633
655
656
+ struct UnrollLoadPattern : public OpRewritePattern <vector::LoadOp> {
657
+ UnrollLoadPattern (MLIRContext *context,
658
+ const vector::UnrollVectorOptions &options,
659
+ PatternBenefit benefit = 1 )
660
+ : OpRewritePattern<vector::LoadOp>(context, benefit), options(options) {}
661
+
662
+ LogicalResult matchAndRewrite (vector::LoadOp loadOp,
663
+ PatternRewriter &rewriter) const override {
664
+ VectorType vecType = loadOp.getVectorType ();
665
+
666
+ auto targetShape = getTargetShape (options, loadOp);
667
+ if (!targetShape)
668
+ return failure ();
669
+
670
+ Location loc = loadOp.getLoc ();
671
+ ArrayRef<int64_t > originalShape = vecType.getShape ();
672
+ SmallVector<int64_t > strides (targetShape->size (), 1 );
673
+
674
+ Value result = rewriter.create <arith::ConstantOp>(
675
+ loc, vecType, rewriter.getZeroAttr (vecType));
676
+
677
+ SmallVector<int64_t > loopOrder =
678
+ getUnrollOrder (originalShape.size (), loadOp, options);
679
+
680
+ auto targetVecType =
681
+ VectorType::get (*targetShape, vecType.getElementType ());
682
+
683
+ for (SmallVector<int64_t > offsets :
684
+ StaticTileOffsetRange (originalShape, *targetShape, loopOrder)) {
685
+ SmallVector<Value> indices =
686
+ sliceLoadStoreIndices (rewriter, loc, loadOp.getIndices (), offsets);
687
+ Value slicedLoad = rewriter.create <vector::LoadOp>(
688
+ loc, targetVecType, loadOp.getBase (), indices);
689
+ result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
690
+ loc, slicedLoad, result, offsets, strides);
691
+ }
692
+ rewriter.replaceOp (loadOp, result);
693
+ return success ();
694
+ }
695
+
696
+ private:
697
+ vector::UnrollVectorOptions options;
698
+ };
699
+
700
+ struct UnrollStorePattern : public OpRewritePattern <vector::StoreOp> {
701
+ UnrollStorePattern (MLIRContext *context,
702
+ const vector::UnrollVectorOptions &options,
703
+ PatternBenefit benefit = 1 )
704
+ : OpRewritePattern<vector::StoreOp>(context, benefit), options(options) {}
705
+
706
+ LogicalResult matchAndRewrite (vector::StoreOp storeOp,
707
+ PatternRewriter &rewriter) const override {
708
+ VectorType vecType = storeOp.getVectorType ();
709
+
710
+ auto targetShape = getTargetShape (options, storeOp);
711
+ if (!targetShape)
712
+ return failure ();
713
+
714
+ Location loc = storeOp.getLoc ();
715
+ ArrayRef<int64_t > originalShape = vecType.getShape ();
716
+ SmallVector<int64_t > strides (targetShape->size (), 1 );
717
+
718
+ Value base = storeOp.getBase ();
719
+ Value vector = storeOp.getValueToStore ();
720
+
721
+ SmallVector<int64_t > loopOrder =
722
+ getUnrollOrder (originalShape.size (), storeOp, options);
723
+
724
+ for (SmallVector<int64_t > offsets :
725
+ StaticTileOffsetRange (originalShape, *targetShape, loopOrder)) {
726
+ SmallVector<Value> indices =
727
+ sliceLoadStoreIndices (rewriter, loc, storeOp.getIndices (), offsets);
728
+ Value slice = rewriter.createOrFold <vector::ExtractStridedSliceOp>(
729
+ loc, vector, offsets, *targetShape, strides);
730
+ rewriter.create <vector::StoreOp>(loc, slice, base, indices);
731
+ }
732
+ rewriter.eraseOp (storeOp);
733
+ return success ();
734
+ }
735
+
736
+ private:
737
+ vector::UnrollVectorOptions options;
738
+ };
739
+
634
740
struct UnrollBroadcastPattern : public OpRewritePattern <vector::BroadcastOp> {
635
741
UnrollBroadcastPattern (MLIRContext *context,
636
742
const vector::UnrollVectorOptions &options,
@@ -699,10 +805,10 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
699
805
void mlir::vector::populateVectorUnrollPatterns (
700
806
RewritePatternSet &patterns, const UnrollVectorOptions &options,
701
807
PatternBenefit benefit) {
702
- patterns
703
- . add <UnrollTransferReadPattern, UnrollTransferWritePattern ,
704
- UnrollContractionPattern, UnrollElementwisePattern ,
705
- UnrollReductionPattern, UnrollMultiReductionPattern ,
706
- UnrollTransposePattern, UnrollGatherPattern , UnrollBroadcastPattern>(
707
- patterns.getContext (), options, benefit);
808
+ patterns. add <UnrollTransferReadPattern, UnrollTransferWritePattern,
809
+ UnrollContractionPattern, UnrollElementwisePattern ,
810
+ UnrollReductionPattern, UnrollMultiReductionPattern ,
811
+ UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern ,
812
+ UnrollStorePattern , UnrollBroadcastPattern>(
813
+ patterns.getContext (), options, benefit);
708
814
}
0 commit comments