Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 130 additions & 59 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,78 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap,
groupStrides)};
}

template <typename ReassociativeReshapeLikeOp,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you move the comment explaining what this does?

SmallVector<OpFoldResult> (*getReshapedSizes)(
ReassociativeReshapeLikeOp, OpBuilder &,
ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/),
SmallVector<OpFoldResult> (*getReshapedStrides)(
ReassociativeReshapeLikeOp, OpBuilder &,
ArrayRef<OpFoldResult> /*origSizes*/,
ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A fly-by nit: is there an observable benefit to using these template arguments as opposed to passing in function_ref callbacks?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably not. I mostly followed the style of the caller

static FailureOr<StridedMetadata>
resolveReshapeStridedMetadata(RewriterBase &rewriter,
ReassociativeReshapeLikeOp reshape) {
// Build a plain extract_strided_metadata(memref) from
// extract_strided_metadata(reassociative_reshape_like(memref)).
Location origLoc = reshape.getLoc();
Value source = reshape.getSrc();
auto sourceType = cast<MemRefType>(source.getType());
unsigned sourceRank = sourceType.getRank();

auto newExtractStridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);

// Collect statically known information.
auto [strides, offset] = getStridesAndOffset(sourceType);
MemRefType reshapeType = reshape.getResultType();
unsigned reshapeRank = reshapeType.getRank();

OpFoldResult offsetOfr =
ShapedType::isDynamic(offset)
? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
: rewriter.getIndexAttr(offset);

// Get the special case of 0-D out of the way.
if (sourceRank == 0) {
SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
/*sizes=*/ones, /*strides=*/ones};
}

SmallVector<OpFoldResult> finalSizes;
finalSizes.reserve(reshapeRank);
SmallVector<OpFoldResult> finalStrides;
finalStrides.reserve(reshapeRank);

// Compute the reshaped strides and sizes from the base strides and sizes.
SmallVector<OpFoldResult> origSizes =
getAsOpFoldResult(newExtractStridedMetadata.getSizes());
SmallVector<OpFoldResult> origStrides =
getAsOpFoldResult(newExtractStridedMetadata.getStrides());
unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
for (; idx != endIdx; ++idx) {
SmallVector<OpFoldResult> reshapedSizes =
getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);

unsigned groupSize = reshapedSizes.size();
for (unsigned i = 0; i < groupSize; ++i) {
finalSizes.push_back(reshapedSizes[i]);
finalStrides.push_back(reshapedStrides[i]);
}
}
assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
(isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
"We should have visited all the input dimensions");
assert(finalSizes.size() == reshapeRank &&
"We should have populated all the values");

return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
finalSizes, finalStrides};
}

/// Replace `baseBuffer, offset, sizes, strides =
/// extract_strided_metadata(reshapeLike(memref))`
/// With
Expand Down Expand Up @@ -580,68 +652,66 @@ struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {

LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
PatternRewriter &rewriter) const override {
// Build a plain extract_strided_metadata(memref) from
// extract_strided_metadata(reassociative_reshape_like(memref)).
Location origLoc = reshape.getLoc();
Value source = reshape.getSrc();
auto sourceType = cast<MemRefType>(source.getType());
unsigned sourceRank = sourceType.getRank();

auto newExtractStridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);

// Collect statically known information.
auto [strides, offset] = getStridesAndOffset(sourceType);
MemRefType reshapeType = reshape.getResultType();
unsigned reshapeRank = reshapeType.getRank();

OpFoldResult offsetOfr =
ShapedType::isDynamic(offset)
? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
: rewriter.getIndexAttr(offset);

// Get the special case of 0-D out of the way.
if (sourceRank == 0) {
SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
offsetOfr, /*sizes=*/ones, /*strides=*/ones);
rewriter.replaceOp(reshape, memrefDesc.getResult());
return success();
FailureOr<StridedMetadata> stridedMetadata =
resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp,
getReshapedSizes, getReshapedStrides>(
rewriter, reshape);
if (failed(stridedMetadata)) {
return rewriter.notifyMatchFailure(reshape,
"failed to resolve reshape metadata");
}

SmallVector<OpFoldResult> finalSizes;
finalSizes.reserve(reshapeRank);
SmallVector<OpFoldResult> finalStrides;
finalStrides.reserve(reshapeRank);

// Compute the reshaped strides and sizes from the base strides and sizes.
SmallVector<OpFoldResult> origSizes =
getAsOpFoldResult(newExtractStridedMetadata.getSizes());
SmallVector<OpFoldResult> origStrides =
getAsOpFoldResult(newExtractStridedMetadata.getStrides());
unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
for (; idx != endIdx; ++idx) {
SmallVector<OpFoldResult> reshapedSizes =
getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);

unsigned groupSize = reshapedSizes.size();
for (unsigned i = 0; i < groupSize; ++i) {
finalSizes.push_back(reshapedSizes[i]);
finalStrides.push_back(reshapedStrides[i]);
}
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
reshape, reshape.getType(), stridedMetadata->basePtr,
stridedMetadata->offset, stridedMetadata->sizes,
stridedMetadata->strides);
return success();
}
};

/// Pattern to replace `extract_strided_metadata(collapse_shape)`
/// With
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
/// extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subSizes#i
/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
/// sizes = subSizes
/// \verbatim
///
/// with `baseBuffer`, `offset`, `sizes` and `strides` being
/// the replacements for the original `extract_strided_metadata`.
struct ExtractStridedMetadataOpCollapseShapeFolder
: OpRewritePattern<memref::ExtractStridedMetadataOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
PatternRewriter &rewriter) const override {
auto collapseShapeOp =
op.getSource().getDefiningOp<memref::CollapseShapeOp>();
if (!collapseShapeOp)
return failure();

FailureOr<StridedMetadata> stridedMetadata =
resolveReshapeStridedMetadata<memref::CollapseShapeOp, getCollapsedSize,
getCollapsedStride>(rewriter,
collapseShapeOp);
if (failed(stridedMetadata)) {
return rewriter.notifyMatchFailure(
op, "failed to resolve metadata in terms of source collapse_shape op");
}
assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
(isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
"We should have visited all the input dimensions");
assert(finalSizes.size() == reshapeRank &&
"We should have populated all the values");
auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
offsetOfr, finalSizes, finalStrides);
rewriter.replaceOp(reshape, memrefDesc.getResult());

Location loc = collapseShapeOp.getLoc();
SmallVector<Value> results;
results.push_back(stridedMetadata->basePtr);
results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
stridedMetadata->offset));
results.append(
getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
stridedMetadata->strides));
rewriter.replaceOp(op, results);
return success();
}
};
Expand Down Expand Up @@ -1030,6 +1100,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
RewritePatternSet &patterns) {
patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
ExtractStridedMetadataOpCollapseShapeFolder,
ExtractStridedMetadataOpGetGlobalFolder,
ExtractStridedMetadataOpSubviewFolder,
RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
Expand Down