Skip to content

Commit 450ac01

Browse files
authored
[mlir][MemRef] Add ExtractStridedMetadataOpCollapseShapeFolder (#89954)
This PR adds a new pattern to the set of patterns used to resolve the offset, sizes and stride of a memref. Similar to `ExtractStridedMetadataOpSubviewFolder`, the new pattern resolves strided_metadata(collapse_shape) directly, without introduce a reshape_cast op.
1 parent d74e42a commit 450ac01

File tree

2 files changed

+165
-60
lines changed

2 files changed

+165
-60
lines changed

mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp

Lines changed: 142 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,89 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
550550
return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap,
551551
groupStrides)};
552552
}
553+
554+
/// From `reshape_like(memref, subSizes, subStrides))` compute
555+
///
556+
/// \verbatim
557+
/// baseBuffer, baseOffset, baseSizes, baseStrides =
558+
/// extract_strided_metadata(memref)
559+
/// strides#i = baseStrides#i * subStrides#i
560+
/// sizes = subSizes
561+
/// \endverbatim
562+
///
563+
/// and return {baseBuffer, baseOffset, sizes, strides}
564+
template <typename ReassociativeReshapeLikeOp>
565+
static FailureOr<StridedMetadata> resolveReshapeStridedMetadata(
566+
RewriterBase &rewriter, ReassociativeReshapeLikeOp reshape,
567+
function_ref<SmallVector<OpFoldResult>(
568+
ReassociativeReshapeLikeOp, OpBuilder &,
569+
ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/)>
570+
getReshapedSizes,
571+
function_ref<SmallVector<OpFoldResult>(
572+
ReassociativeReshapeLikeOp, OpBuilder &,
573+
ArrayRef<OpFoldResult> /*origSizes*/,
574+
ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
575+
getReshapedStrides) {
576+
// Build a plain extract_strided_metadata(memref) from
577+
// extract_strided_metadata(reassociative_reshape_like(memref)).
578+
Location origLoc = reshape.getLoc();
579+
Value source = reshape.getSrc();
580+
auto sourceType = cast<MemRefType>(source.getType());
581+
unsigned sourceRank = sourceType.getRank();
582+
583+
auto newExtractStridedMetadata =
584+
rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
585+
586+
// Collect statically known information.
587+
auto [strides, offset] = getStridesAndOffset(sourceType);
588+
MemRefType reshapeType = reshape.getResultType();
589+
unsigned reshapeRank = reshapeType.getRank();
590+
591+
OpFoldResult offsetOfr =
592+
ShapedType::isDynamic(offset)
593+
? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
594+
: rewriter.getIndexAttr(offset);
595+
596+
// Get the special case of 0-D out of the way.
597+
if (sourceRank == 0) {
598+
SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
599+
return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
600+
/*sizes=*/ones, /*strides=*/ones};
601+
}
602+
603+
SmallVector<OpFoldResult> finalSizes;
604+
finalSizes.reserve(reshapeRank);
605+
SmallVector<OpFoldResult> finalStrides;
606+
finalStrides.reserve(reshapeRank);
607+
608+
// Compute the reshaped strides and sizes from the base strides and sizes.
609+
SmallVector<OpFoldResult> origSizes =
610+
getAsOpFoldResult(newExtractStridedMetadata.getSizes());
611+
SmallVector<OpFoldResult> origStrides =
612+
getAsOpFoldResult(newExtractStridedMetadata.getStrides());
613+
unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
614+
for (; idx != endIdx; ++idx) {
615+
SmallVector<OpFoldResult> reshapedSizes =
616+
getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
617+
SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
618+
reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);
619+
620+
unsigned groupSize = reshapedSizes.size();
621+
for (unsigned i = 0; i < groupSize; ++i) {
622+
finalSizes.push_back(reshapedSizes[i]);
623+
finalStrides.push_back(reshapedStrides[i]);
624+
}
625+
}
626+
assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
627+
(isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
628+
"We should have visited all the input dimensions");
629+
assert(finalSizes.size() == reshapeRank &&
630+
"We should have populated all the values");
631+
632+
return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
633+
finalSizes, finalStrides};
634+
}
635+
553636
/// Replace `baseBuffer, offset, sizes, strides =
554637
/// extract_strided_metadata(reshapeLike(memref))`
555638
/// With
@@ -580,68 +663,65 @@ struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {
580663

581664
LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
582665
PatternRewriter &rewriter) const override {
583-
// Build a plain extract_strided_metadata(memref) from
584-
// extract_strided_metadata(reassociative_reshape_like(memref)).
585-
Location origLoc = reshape.getLoc();
586-
Value source = reshape.getSrc();
587-
auto sourceType = cast<MemRefType>(source.getType());
588-
unsigned sourceRank = sourceType.getRank();
589-
590-
auto newExtractStridedMetadata =
591-
rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
592-
593-
// Collect statically known information.
594-
auto [strides, offset] = getStridesAndOffset(sourceType);
595-
MemRefType reshapeType = reshape.getResultType();
596-
unsigned reshapeRank = reshapeType.getRank();
597-
598-
OpFoldResult offsetOfr =
599-
ShapedType::isDynamic(offset)
600-
? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
601-
: rewriter.getIndexAttr(offset);
602-
603-
// Get the special case of 0-D out of the way.
604-
if (sourceRank == 0) {
605-
SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
606-
auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
607-
origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
608-
offsetOfr, /*sizes=*/ones, /*strides=*/ones);
609-
rewriter.replaceOp(reshape, memrefDesc.getResult());
610-
return success();
666+
FailureOr<StridedMetadata> stridedMetadata =
667+
resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp>(
668+
rewriter, reshape, getReshapedSizes, getReshapedStrides);
669+
if (failed(stridedMetadata)) {
670+
return rewriter.notifyMatchFailure(reshape,
671+
"failed to resolve reshape metadata");
611672
}
612673

613-
SmallVector<OpFoldResult> finalSizes;
614-
finalSizes.reserve(reshapeRank);
615-
SmallVector<OpFoldResult> finalStrides;
616-
finalStrides.reserve(reshapeRank);
617-
618-
// Compute the reshaped strides and sizes from the base strides and sizes.
619-
SmallVector<OpFoldResult> origSizes =
620-
getAsOpFoldResult(newExtractStridedMetadata.getSizes());
621-
SmallVector<OpFoldResult> origStrides =
622-
getAsOpFoldResult(newExtractStridedMetadata.getStrides());
623-
unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
624-
for (; idx != endIdx; ++idx) {
625-
SmallVector<OpFoldResult> reshapedSizes =
626-
getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
627-
SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
628-
reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);
629-
630-
unsigned groupSize = reshapedSizes.size();
631-
for (unsigned i = 0; i < groupSize; ++i) {
632-
finalSizes.push_back(reshapedSizes[i]);
633-
finalStrides.push_back(reshapedStrides[i]);
634-
}
674+
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
675+
reshape, reshape.getType(), stridedMetadata->basePtr,
676+
stridedMetadata->offset, stridedMetadata->sizes,
677+
stridedMetadata->strides);
678+
return success();
679+
}
680+
};
681+
682+
/// Pattern to replace `extract_strided_metadata(collapse_shape)`
683+
/// With
684+
///
685+
/// \verbatim
686+
/// baseBuffer, baseOffset, baseSizes, baseStrides =
687+
/// extract_strided_metadata(memref)
688+
/// strides#i = baseStrides#i * subSizes#i
689+
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
690+
/// sizes = subSizes
691+
/// \verbatim
692+
///
693+
/// with `baseBuffer`, `offset`, `sizes` and `strides` being
694+
/// the replacements for the original `extract_strided_metadata`.
695+
struct ExtractStridedMetadataOpCollapseShapeFolder
696+
: OpRewritePattern<memref::ExtractStridedMetadataOp> {
697+
using OpRewritePattern::OpRewritePattern;
698+
699+
LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
700+
PatternRewriter &rewriter) const override {
701+
auto collapseShapeOp =
702+
op.getSource().getDefiningOp<memref::CollapseShapeOp>();
703+
if (!collapseShapeOp)
704+
return failure();
705+
706+
FailureOr<StridedMetadata> stridedMetadata =
707+
resolveReshapeStridedMetadata<memref::CollapseShapeOp>(
708+
rewriter, collapseShapeOp, getCollapsedSize, getCollapsedStride);
709+
if (failed(stridedMetadata)) {
710+
return rewriter.notifyMatchFailure(
711+
op,
712+
"failed to resolve metadata in terms of source collapse_shape op");
635713
}
636-
assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
637-
(isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
638-
"We should have visited all the input dimensions");
639-
assert(finalSizes.size() == reshapeRank &&
640-
"We should have populated all the values");
641-
auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
642-
origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
643-
offsetOfr, finalSizes, finalStrides);
644-
rewriter.replaceOp(reshape, memrefDesc.getResult());
714+
715+
Location loc = collapseShapeOp.getLoc();
716+
SmallVector<Value> results;
717+
results.push_back(stridedMetadata->basePtr);
718+
results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
719+
stridedMetadata->offset));
720+
results.append(
721+
getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
722+
results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
723+
stridedMetadata->strides));
724+
rewriter.replaceOp(op, results);
645725
return success();
646726
}
647727
};
@@ -1018,9 +1098,11 @@ void memref::populateExpandStridedMetadataPatterns(
10181098
getCollapsedStride>,
10191099
ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
10201100
ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
1101+
ExtractStridedMetadataOpCollapseShapeFolder,
10211102
ExtractStridedMetadataOpGetGlobalFolder,
10221103
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
10231104
ExtractStridedMetadataOpReinterpretCastFolder,
1105+
ExtractStridedMetadataOpSubviewFolder,
10241106
ExtractStridedMetadataOpCastFolder,
10251107
ExtractStridedMetadataOpExtractStridedMetadataFolder>(
10261108
patterns.getContext());
@@ -1030,6 +1112,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
10301112
RewritePatternSet &patterns) {
10311113
patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
10321114
ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
1115+
ExtractStridedMetadataOpCollapseShapeFolder,
10331116
ExtractStridedMetadataOpGetGlobalFolder,
10341117
ExtractStridedMetadataOpSubviewFolder,
10351118
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,

mlir/test/Dialect/MemRef/expand-strided-metadata.mlir

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1513,4 +1513,26 @@ func.func @zero_sized_memred(%arg0: f32) -> (memref<f16, 3>, index,index,index)
15131513
%sizes, %strides :
15141514
memref<f16,3>, index,
15151515
index, index
1516-
}
1516+
}
1517+
1518+
// -----
1519+
1520+
func.func @extract_strided_metadata_of_collapse_shape(%base: memref<5x4xf32>)
1521+
-> (memref<f32>, index, index, index) {
1522+
1523+
%collapse = memref.collapse_shape %base[[0, 1]] :
1524+
memref<5x4xf32> into memref<20xf32>
1525+
1526+
%base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %collapse :
1527+
memref<20xf32> -> memref<f32>, index, index, index
1528+
1529+
return %base_buffer, %offset, %size, %stride :
1530+
memref<f32>, index, index, index
1531+
}
1532+
1533+
// CHECK-LABEL: func @extract_strided_metadata_of_collapse_shape
1534+
// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index
1535+
// CHECK-DAG: %[[SIZE:.*]] = arith.constant 20 : index
1536+
// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index
1537+
// CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata
1538+
// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZE]], %[[STEP]] : memref<f32>, index, index, index

0 commit comments

Comments
 (0)