Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,7 @@ class FIROpConversion : public ConvertFIRToLLVMPattern {
virtual llvm::LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
llvm::SmallVector<mlir::Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
return dispatchTo1To1(*this, op, adaptor, rewriter);
}

private:
Expand Down
6 changes: 2 additions & 4 deletions mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,7 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
return dispatchTo1To1(*this, op, adaptor, rewriter);
}

private:
Expand Down Expand Up @@ -276,7 +274,7 @@ class ConvertOpInterfaceToLLVMPattern : public ConvertToLLVMPattern {
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const {
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
return dispatchTo1To1(*this, op, operands, rewriter);
}

private:
Expand Down
62 changes: 52 additions & 10 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -521,8 +521,8 @@ class ConversionPattern : public RewritePattern {

/// Hook for derived classes to implement combined matching and rewriting.
/// This overload supports only 1:1 replacements. The 1:N overload is called
/// by the driver. By default, it calls this 1:1 overload or reports a fatal
/// error if 1:N replacements were found.
/// by the driver. By default, it calls this 1:1 overload or fails to match
/// if 1:N replacements were found.
virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
Expand All @@ -534,7 +534,7 @@ class ConversionPattern : public RewritePattern {
virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const {
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
return dispatchTo1To1(*this, op, operands, rewriter);
}

/// Attempt to match and rewrite the IR root at the specified operation.
Expand Down Expand Up @@ -567,11 +567,26 @@ class ConversionPattern : public RewritePattern {
/// try to extract the single value of each range to construct a the inputs
/// for a 1:1 adaptor.
///
/// This function produces a fatal error if at least one range has 0 or
/// more than 1 value: "pattern 'name' does not support 1:N conversion"
SmallVector<Value>
/// Returns failure if at least one range has 0 or more than 1 value.
FailureOr<SmallVector<Value>>
getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands) const;

/// Overloaded method used to dispatch to the 1:1 'matchAndRewrite' method
/// if possible and emit diagnostic with a failure return value otherwise.
/// 'self' should be '*this' of the derived-pattern and is used to dispatch
/// to the correct 'matchAndRewrite' method in the derived pattern.
template <typename SelfPattern, typename SourceOp>
static LogicalResult dispatchTo1To1(const SelfPattern &self, SourceOp op,
ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter);

/// Same as above, but accepts an adaptor as operand.
template <typename SelfPattern, typename SourceOp>
static LogicalResult dispatchTo1To1(
const SelfPattern &self, SourceOp op,
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> adaptor,
ConversionPatternRewriter &rewriter);

protected:
/// An optional type converter for use by this pattern.
const TypeConverter *typeConverter = nullptr;
Expand Down Expand Up @@ -620,9 +635,7 @@ class OpConversionPattern : public ConversionPattern {
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
return dispatchTo1To1(*this, op, adaptor, rewriter);
}

private:
Expand Down Expand Up @@ -666,7 +679,7 @@ class OpInterfaceConversionPattern : public ConversionPattern {
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const {
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
return dispatchTo1To1(*this, op, operands, rewriter);
}

private:
Expand Down Expand Up @@ -865,6 +878,35 @@ class ConversionPatternRewriter final : public PatternRewriter {
std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
};

template <typename SelfPattern, typename SourceOp>
LogicalResult
ConversionPattern::dispatchTo1To1(const SelfPattern &self, SourceOp op,
ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) {
FailureOr<SmallVector<Value>> oneToOneOperands =
self.getOneToOneAdaptorOperands(operands);
if (failed(oneToOneOperands))
return rewriter.notifyMatchFailure(op,
"pattern '" + self.getDebugName() +
"' does not support 1:N conversion");
return self.matchAndRewrite(op, *oneToOneOperands, rewriter);
}

template <typename SelfPattern, typename SourceOp>
LogicalResult ConversionPattern::dispatchTo1To1(
const SelfPattern &self, SourceOp op,
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> adaptor,
ConversionPatternRewriter &rewriter) {
FailureOr<SmallVector<Value>> oneToOneOperands =
self.getOneToOneAdaptorOperands(adaptor.getOperands());
if (failed(oneToOneOperands))
return rewriter.notifyMatchFailure(op,
"pattern '" + self.getDebugName() +
"' does not support 1:N conversion");
return self.matchAndRewrite(
op, typename SourceOp::Adaptor(*oneToOneOperands, adaptor), rewriter);
}

//===----------------------------------------------------------------------===//
// ConversionTarget
//===----------------------------------------------------------------------===//
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2244,17 +2244,17 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
// ConversionPattern
//===----------------------------------------------------------------------===//

SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands(
FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
ArrayRef<ValueRange> operands) const {
SmallVector<Value> oneToOneOperands;
oneToOneOperands.reserve(operands.size());
for (ValueRange operand : operands) {
if (operand.size() != 1)
llvm::report_fatal_error("pattern '" + getDebugName() +
"' does not support 1:N conversion");
return failure();

oneToOneOperands.push_back(operand.front());
}
return oneToOneOperands;
return std::move(oneToOneOperands);
}

LogicalResult
Expand Down
21 changes: 21 additions & 0 deletions mlir/test/Transforms/test-legalizer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -439,3 +439,24 @@ func.func @test_lookup_without_converter() {
// expected-remark@+1 {{op 'func.return' is not legalizable}}
return
}

// -----
// expected-remark@-1 {{applyPartialConversion failed}}

func.func @test_skip_1to1_pattern(%arg0: f32) {
// expected-error@+1 {{failed to legalize operation 'test.type_consumer'}}
"test.type_consumer"(%arg0) : (f32) -> ()
return
}

// -----

// Demonstrate that the pattern generally works, but only for 1:1 type
// conversions.

// CHECK-LABEL: @test_working_1to1_pattern(
func.func @test_working_1to1_pattern(%arg0: f16) {
// CHECK-NEXT: "test.return"() : () -> ()
"test.type_consumer"(%arg0) : (f16) -> ()
"test.return"() : () -> ()
}
21 changes: 19 additions & 2 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1386,6 +1386,23 @@ class TestMultiple1ToNReplacement : public ConversionPattern {
}
};

/// Pattern that erases 'test.type_consumers' iff the input operand is the
/// result of a 1:1 type conversion.
/// Used to test correct skipping of 1:1 patterns in the 1:N case.
class TestTypeConsumerOpPattern
: public OpConversionPattern<TestTypeConsumerOp> {
public:
TestTypeConsumerOpPattern(MLIRContext *ctx, const TypeConverter &converter)
: OpConversionPattern<TestTypeConsumerOp>(converter, ctx) {}

LogicalResult
matchAndRewrite(TestTypeConsumerOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const final {
rewriter.eraseOp(op);
return success();
}
};

/// Test unambiguous overload resolution of replaceOpWithMultiple. This
/// function is just to trigger compiler errors. It is never executed.
[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
Expand Down Expand Up @@ -1497,8 +1514,8 @@ struct TestLegalizePatternDriver
TestRepetitive1ToNConsumer>(&getContext());
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
TestBlockArgReplace, TestReplaceWithValidConsumer>(
&getContext(), converter);
TestBlockArgReplace, TestReplaceWithValidConsumer,
TestTypeConsumerOpPattern>(&getContext(), converter);
patterns.add<TestConvertBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
Expand Down