Skip to content

[mlir][Transforms] Add 1:N matchAndRewrite overload #116470

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ template <typename SourceOp>
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
public:
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;

explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
Expand All @@ -153,17 +155,29 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
/// Wrappers around the RewritePattern methods that pass the derived op type.
void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
rewriter);
auto sourceOp = cast<SourceOp>(op);
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult match(Operation *op) const final {
return match(cast<SourceOp>(op));
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op),
OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
rewriter);
}

/// Rewrite and Match methods that operate on the SourceOp type. These must be
Expand All @@ -175,6 +189,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override rewrite or matchAndRewrite");
}
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Expand All @@ -183,6 +203,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
rewrite(op, adaptor, rewriter);
return success();
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
}

private:
using ConvertToLLVMPattern::match;
Expand Down
69 changes: 69 additions & 0 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -538,8 +538,15 @@ class ConversionPattern : public RewritePattern {
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("unimplemented rewrite");
}
virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const {
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
}

/// Hook for derived classes to implement combined matching and rewriting.
/// This overload supports only 1:1 replacements. The 1:N overload is called
/// by the driver. By default, it calls this 1:1 overload or reports a fatal
/// error if 1:N replacements were found.
virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
Expand All @@ -549,6 +556,14 @@ class ConversionPattern : public RewritePattern {
return success();
}

/// Hook for derived classes to implement combined matching and rewriting.
/// This overload supports 1:N replacements.
virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const {
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
}

/// Attempt to match and rewrite the IR root at the specified operation.
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final;
Expand All @@ -575,6 +590,15 @@ class ConversionPattern : public RewritePattern {
: RewritePattern(std::forward<Args>(args)...),
typeConverter(&typeConverter) {}

/// Given an array of value ranges, which are the inputs to a 1:N adaptor,
/// try to extract the single value of each range to construct a the inputs
/// for a 1:1 adaptor.
///
/// This function produces a fatal error if at least one range has 0 or
/// more than 1 value: "pattern 'name' does not support 1:N conversion"
SmallVector<Value>
getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands) const;

protected:
/// An optional type converter for use by this pattern.
const TypeConverter *typeConverter = nullptr;
Expand All @@ -590,6 +614,8 @@ template <typename SourceOp>
class OpConversionPattern : public ConversionPattern {
public:
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;

OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
Expand All @@ -608,12 +634,24 @@ class OpConversionPattern : public ConversionPattern {
auto sourceOp = cast<SourceOp>(op);
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
rewriter);
}

/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
Expand All @@ -624,6 +662,12 @@ class OpConversionPattern : public ConversionPattern {
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override matchAndRewrite or a rewrite method");
}
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Expand All @@ -632,6 +676,13 @@ class OpConversionPattern : public ConversionPattern {
rewrite(op, adaptor, rewriter);
return success();
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
}

private:
using ConversionPattern::matchAndRewrite;
Expand All @@ -657,18 +708,31 @@ class OpInterfaceConversionPattern : public ConversionPattern {
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), operands, rewriter);
}
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), operands, rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
}

/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override matchAndRewrite or a rewrite method");
}
virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const {
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
}
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
Expand All @@ -677,6 +741,11 @@ class OpInterfaceConversionPattern : public ConversionPattern {
rewrite(op, operands, rewriter);
return success();
}
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const {
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
}

private:
using ConversionPattern::matchAndRewrite;
Expand Down
56 changes: 6 additions & 50 deletions mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,6 @@
using namespace mlir;
using namespace mlir::func;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

/// If the given value can be decomposed with the type converter, decompose it.
/// Otherwise, return the given value.
// TODO: Value decomposition should happen automatically through a 1:N adaptor.
// This function will disappear when the 1:1 and 1:N drivers are merged.
static SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc,
Value value,
const TypeConverter *converter) {
// Try to convert the given value's type. If that fails, just return the
// given value.
SmallVector<Type> convertedTypes;
if (failed(converter->convertType(value.getType(), convertedTypes)))
return {value};
if (convertedTypes.empty())
return {};

// If the given value's type is already legal, just return the given value.
TypeRange convertedTypeRange(convertedTypes);
if (convertedTypeRange == TypeRange(value.getType()))
return {value};

// Try to materialize a target conversion. If the materialization did not
// produce values of the requested type, the materialization failed. Just
// return the given value in that case.
SmallVector<Value> result = converter->materializeTargetConversion(
builder, loc, convertedTypeRange, value);
if (result.empty())
return {value};
return result;
}

//===----------------------------------------------------------------------===//
// DecomposeCallGraphTypesForFuncArgs
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -102,16 +68,11 @@ struct DecomposeCallGraphTypesForReturnOp
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
SmallVector<Value, 2> newOperands;
for (Value operand : adaptor.getOperands()) {
// TODO: We can directly take the values from the adaptor once this is a
// 1:N conversion pattern.
llvm::append_range(newOperands,
decomposeValue(rewriter, operand.getLoc(), operand,
getTypeConverter()));
}
for (ValueRange operand : adaptor.getOperands())
llvm::append_range(newOperands, operand);
rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
return success();
}
Expand All @@ -128,18 +89,13 @@ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(CallOp op, OpAdaptor adaptor,
matchAndRewrite(CallOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {

// Create the operands list of the new `CallOp`.
SmallVector<Value, 2> newOperands;
for (Value operand : adaptor.getOperands()) {
// TODO: We can directly take the values from the adaptor once this is a
// 1:N conversion pattern.
llvm::append_range(newOperands,
decomposeValue(rewriter, operand.getLoc(), operand,
getTypeConverter()));
}
for (ValueRange operand : adaptor.getOperands())
llvm::append_range(newOperands, operand);

// Create the new result types for the new `CallOp` and track the number of
// replacement types for each original op result.
Expand Down
16 changes: 12 additions & 4 deletions mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@
using namespace mlir;
using namespace mlir::func;

/// Flatten the given value ranges into a single vector of values.
static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
SmallVector<Value> result;
for (const auto &vals : values)
llvm::append_range(result, vals);
return result;
}

namespace {
/// Converts the operand and result types of the CallOp, used together with the
/// FuncOpSignatureConversion.
Expand All @@ -21,7 +29,7 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {

/// Hook for derived classes to implement combined matching and rewriting.
LogicalResult
matchAndRewrite(CallOp callOp, OpAdaptor adaptor,
matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Convert the original function results. Keep track of how many result
// types an original result type is converted into.
Expand All @@ -38,9 +46,9 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {

// Substitute with the new result types from the corresponding FuncType
// conversion.
auto newCallOp =
rewriter.create<CallOp>(callOp.getLoc(), callOp.getCallee(),
convertedResults, adaptor.getOperands());
auto newCallOp = rewriter.create<CallOp>(
callOp.getLoc(), callOp.getCallee(), convertedResults,
flattenValues(adaptor.getOperands()));
SmallVector<ValueRange> replacements;
size_t offset = 0;
for (int i = 0, e = callOp->getNumResults(); i < e; ++i) {
Expand Down
Loading