@@ -550,6 +550,89 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
550
550
return {makeComposedFoldedAffineMin (builder, collapseShape.getLoc (), minMap,
551
551
groupStrides)};
552
552
}
553
+
554
+ // / From `reshape_like(memref, subSizes, subStrides))` compute
555
+ // /
556
+ // / \verbatim
557
+ // / baseBuffer, baseOffset, baseSizes, baseStrides =
558
+ // / extract_strided_metadata(memref)
559
+ // / strides#i = baseStrides#i * subStrides#i
560
+ // / sizes = subSizes
561
+ // / \endverbatim
562
+ // /
563
+ // / and return {baseBuffer, baseOffset, sizes, strides}
564
+ template <typename ReassociativeReshapeLikeOp>
565
+ static FailureOr<StridedMetadata> resolveReshapeStridedMetadata (
566
+ RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape,
567
+ function_ref<SmallVector<OpFoldResult>(
568
+ ReassociativeReshapeLikeOp, OpBuilder &,
569
+ ArrayRef<OpFoldResult> /* origSizes*/ , unsigned /* groupId*/ )>
570
+ getReshapedSizes,
571
+ function_ref<SmallVector<OpFoldResult>(
572
+ ReassociativeReshapeLikeOp, OpBuilder &,
573
+ ArrayRef<OpFoldResult> /* origSizes*/ ,
574
+ ArrayRef<OpFoldResult> /* origStrides*/ , unsigned /* groupId*/ )>
575
+ getReshapedStrides) {
576
+ // Build a plain extract_strided_metadata(memref) from
577
+ // extract_strided_metadata(reassociative_reshape_like(memref)).
578
+ Location origLoc = reshape.getLoc ();
579
+ Value source = reshape.getSrc ();
580
+ auto sourceType = cast<MemRefType>(source.getType ());
581
+ unsigned sourceRank = sourceType.getRank ();
582
+
583
+ auto newExtractStridedMetadata =
584
+ rewriter.create <memref::ExtractStridedMetadataOp>(origLoc, source);
585
+
586
+ // Collect statically known information.
587
+ auto [strides, offset] = getStridesAndOffset (sourceType);
588
+ MemRefType reshapeType = reshape.getResultType ();
589
+ unsigned reshapeRank = reshapeType.getRank ();
590
+
591
+ OpFoldResult offsetOfr =
592
+ ShapedType::isDynamic (offset)
593
+ ? getAsOpFoldResult (newExtractStridedMetadata.getOffset ())
594
+ : rewriter.getIndexAttr (offset);
595
+
596
+ // Get the special case of 0-D out of the way.
597
+ if (sourceRank == 0 ) {
598
+ SmallVector<OpFoldResult> ones (reshapeRank, rewriter.getIndexAttr (1 ));
599
+ return StridedMetadata{newExtractStridedMetadata.getBaseBuffer (), offsetOfr,
600
+ /* sizes=*/ ones, /* strides=*/ ones};
601
+ }
602
+
603
+ SmallVector<OpFoldResult> finalSizes;
604
+ finalSizes.reserve (reshapeRank);
605
+ SmallVector<OpFoldResult> finalStrides;
606
+ finalStrides.reserve (reshapeRank);
607
+
608
+ // Compute the reshaped strides and sizes from the base strides and sizes.
609
+ SmallVector<OpFoldResult> origSizes =
610
+ getAsOpFoldResult (newExtractStridedMetadata.getSizes ());
611
+ SmallVector<OpFoldResult> origStrides =
612
+ getAsOpFoldResult (newExtractStridedMetadata.getStrides ());
613
+ unsigned idx = 0 , endIdx = reshape.getReassociationIndices ().size ();
614
+ for (; idx != endIdx; ++idx) {
615
+ SmallVector<OpFoldResult> reshapedSizes =
616
+ getReshapedSizes (reshape, rewriter, origSizes, /* groupId=*/ idx);
617
+ SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides (
618
+ reshape, rewriter, origSizes, origStrides, /* groupId=*/ idx);
619
+
620
+ unsigned groupSize = reshapedSizes.size ();
621
+ for (unsigned i = 0 ; i < groupSize; ++i) {
622
+ finalSizes.push_back (reshapedSizes[i]);
623
+ finalStrides.push_back (reshapedStrides[i]);
624
+ }
625
+ }
626
+ assert (((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
627
+ (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
628
+ " We should have visited all the input dimensions" );
629
+ assert (finalSizes.size () == reshapeRank &&
630
+ " We should have populated all the values" );
631
+
632
+ return StridedMetadata{newExtractStridedMetadata.getBaseBuffer (), offsetOfr,
633
+ finalSizes, finalStrides};
634
+ }
635
+
553
636
// / Replace `baseBuffer, offset, sizes, strides =
554
637
// / extract_strided_metadata(reshapeLike(memref))`
555
638
// / With
@@ -580,68 +663,65 @@ struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {
580
663
581
664
LogicalResult matchAndRewrite (ReassociativeReshapeLikeOp reshape,
582
665
PatternRewriter &rewriter) const override {
583
- // Build a plain extract_strided_metadata(memref) from
584
- // extract_strided_metadata(reassociative_reshape_like(memref)).
585
- Location origLoc = reshape.getLoc ();
586
- Value source = reshape.getSrc ();
587
- auto sourceType = cast<MemRefType>(source.getType ());
588
- unsigned sourceRank = sourceType.getRank ();
589
-
590
- auto newExtractStridedMetadata =
591
- rewriter.create <memref::ExtractStridedMetadataOp>(origLoc, source);
592
-
593
- // Collect statically known information.
594
- auto [strides, offset] = getStridesAndOffset (sourceType);
595
- MemRefType reshapeType = reshape.getResultType ();
596
- unsigned reshapeRank = reshapeType.getRank ();
597
-
598
- OpFoldResult offsetOfr =
599
- ShapedType::isDynamic (offset)
600
- ? getAsOpFoldResult (newExtractStridedMetadata.getOffset ())
601
- : rewriter.getIndexAttr (offset);
602
-
603
- // Get the special case of 0-D out of the way.
604
- if (sourceRank == 0 ) {
605
- SmallVector<OpFoldResult> ones (reshapeRank, rewriter.getIndexAttr (1 ));
606
- auto memrefDesc = rewriter.create <memref::ReinterpretCastOp>(
607
- origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer (),
608
- offsetOfr, /* sizes=*/ ones, /* strides=*/ ones);
609
- rewriter.replaceOp (reshape, memrefDesc.getResult ());
610
- return success ();
666
+ FailureOr<StridedMetadata> stridedMetadata =
667
+ resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp>(
668
+ rewriter, reshape, getReshapedSizes, getReshapedStrides);
669
+ if (failed (stridedMetadata)) {
670
+ return rewriter.notifyMatchFailure (reshape,
671
+ " failed to resolve reshape metadata" );
611
672
}
612
673
613
- SmallVector<OpFoldResult> finalSizes;
614
- finalSizes.reserve (reshapeRank);
615
- SmallVector<OpFoldResult> finalStrides;
616
- finalStrides.reserve (reshapeRank);
617
-
618
- // Compute the reshaped strides and sizes from the base strides and sizes.
619
- SmallVector<OpFoldResult> origSizes =
620
- getAsOpFoldResult (newExtractStridedMetadata.getSizes ());
621
- SmallVector<OpFoldResult> origStrides =
622
- getAsOpFoldResult (newExtractStridedMetadata.getStrides ());
623
- unsigned idx = 0 , endIdx = reshape.getReassociationIndices ().size ();
624
- for (; idx != endIdx; ++idx) {
625
- SmallVector<OpFoldResult> reshapedSizes =
626
- getReshapedSizes (reshape, rewriter, origSizes, /* groupId=*/ idx);
627
- SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides (
628
- reshape, rewriter, origSizes, origStrides, /* groupId=*/ idx);
629
-
630
- unsigned groupSize = reshapedSizes.size ();
631
- for (unsigned i = 0 ; i < groupSize; ++i) {
632
- finalSizes.push_back (reshapedSizes[i]);
633
- finalStrides.push_back (reshapedStrides[i]);
634
- }
674
+ rewriter.replaceOpWithNewOp <memref::ReinterpretCastOp>(
675
+ reshape, reshape.getType (), stridedMetadata->basePtr ,
676
+ stridedMetadata->offset , stridedMetadata->sizes ,
677
+ stridedMetadata->strides );
678
+ return success ();
679
+ }
680
+ };
681
+
682
+ // / Pattern to replace `extract_strided_metadata(collapse_shape)`
683
+ // / With
684
+ // /
685
+ // / \verbatim
686
+ // / baseBuffer, baseOffset, baseSizes, baseStrides =
687
+ // / extract_strided_metadata(memref)
688
+ // / strides#i = baseStrides#i * subSizes#i
689
+ // / offset = baseOffset + sum(subOffset#i * baseStrides#i)
690
+ // / sizes = subSizes
691
+ // / \verbatim
692
+ // /
693
+ // / with `baseBuffer`, `offset`, `sizes` and `strides` being
694
+ // / the replacements for the original `extract_strided_metadata`.
695
+ struct ExtractStridedMetadataOpCollapseShapeFolder
696
+ : OpRewritePattern<memref::ExtractStridedMetadataOp> {
697
+ using OpRewritePattern::OpRewritePattern;
698
+
699
+ LogicalResult matchAndRewrite (memref::ExtractStridedMetadataOp op,
700
+ PatternRewriter &rewriter) const override {
701
+ auto collapseShapeOp =
702
+ op.getSource ().getDefiningOp <memref::CollapseShapeOp>();
703
+ if (!collapseShapeOp)
704
+ return failure ();
705
+
706
+ FailureOr<StridedMetadata> stridedMetadata =
707
+ resolveReshapeStridedMetadata<memref::CollapseShapeOp>(
708
+ rewriter, collapseShapeOp, getCollapsedSize, getCollapsedStride);
709
+ if (failed (stridedMetadata)) {
710
+ return rewriter.notifyMatchFailure (
711
+ op,
712
+ " failed to resolve metadata in terms of source collapse_shape op" );
635
713
}
636
- assert (((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
637
- (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
638
- " We should have visited all the input dimensions" );
639
- assert (finalSizes.size () == reshapeRank &&
640
- " We should have populated all the values" );
641
- auto memrefDesc = rewriter.create <memref::ReinterpretCastOp>(
642
- origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer (),
643
- offsetOfr, finalSizes, finalStrides);
644
- rewriter.replaceOp (reshape, memrefDesc.getResult ());
714
+
715
+ Location loc = collapseShapeOp.getLoc ();
716
+ SmallVector<Value> results;
717
+ results.push_back (stridedMetadata->basePtr );
718
+ results.push_back (getValueOrCreateConstantIndexOp (rewriter, loc,
719
+ stridedMetadata->offset ));
720
+ results.append (
721
+ getValueOrCreateConstantIndexOp (rewriter, loc, stridedMetadata->sizes ));
722
+ results.append (getValueOrCreateConstantIndexOp (rewriter, loc,
723
+ stridedMetadata->strides ));
724
+ rewriter.replaceOp (op, results);
645
725
return success ();
646
726
}
647
727
};
@@ -1018,9 +1098,11 @@ void memref::populateExpandStridedMetadataPatterns(
1018
1098
getCollapsedStride>,
1019
1099
ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
1020
1100
ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
1101
+ ExtractStridedMetadataOpCollapseShapeFolder,
1021
1102
ExtractStridedMetadataOpGetGlobalFolder,
1022
1103
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
1023
1104
ExtractStridedMetadataOpReinterpretCastFolder,
1105
+ ExtractStridedMetadataOpSubviewFolder,
1024
1106
ExtractStridedMetadataOpCastFolder,
1025
1107
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
1026
1108
patterns.getContext ());
@@ -1030,6 +1112,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
1030
1112
RewritePatternSet &patterns) {
1031
1113
patterns.add <ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
1032
1114
ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
1115
+ ExtractStridedMetadataOpCollapseShapeFolder,
1033
1116
ExtractStridedMetadataOpGetGlobalFolder,
1034
1117
ExtractStridedMetadataOpSubviewFolder,
1035
1118
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
0 commit comments