Skip to content

Commit 76a8541

Browse files
replace with multiple
Apply suggestions from code review Co-authored-by: Markus Böck <[email protected]> address comments [WIP] 1:N conversion pattern update test cases Update mlir/lib/Transforms/Utils/DialectConversion.cpp Co-authored-by: Markus Böck <[email protected]> Update mlir/lib/Transforms/Utils/DialectConversion.cpp Co-authored-by: Markus Böck <[email protected]> address comments rollback unresolved materializations properly
1 parent c894d3a commit 76a8541

File tree

12 files changed

+518
-337
lines changed

12 files changed

+518
-337
lines changed

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ template <typename SourceOp>
143143
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
144144
public:
145145
using OpAdaptor = typename SourceOp::Adaptor;
146+
using OneToNOpAdaptor =
147+
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
146148

147149
explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
148150
PatternBenefit benefit = 1)
@@ -153,17 +155,29 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
153155
/// Wrappers around the RewritePattern methods that pass the derived op type.
154156
void rewrite(Operation *op, ArrayRef<Value> operands,
155157
ConversionPatternRewriter &rewriter) const final {
156-
rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
157-
rewriter);
158+
auto sourceOp = cast<SourceOp>(op);
159+
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
160+
}
161+
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
162+
ConversionPatternRewriter &rewriter) const final {
163+
auto sourceOp = cast<SourceOp>(op);
164+
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
158165
}
159166
LogicalResult match(Operation *op) const final {
160167
return match(cast<SourceOp>(op));
161168
}
162169
LogicalResult
163170
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
164171
ConversionPatternRewriter &rewriter) const final {
165-
return matchAndRewrite(cast<SourceOp>(op),
166-
OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
172+
auto sourceOp = cast<SourceOp>(op);
173+
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
174+
}
175+
LogicalResult
176+
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
177+
ConversionPatternRewriter &rewriter) const final {
178+
auto sourceOp = cast<SourceOp>(op);
179+
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
180+
rewriter);
167181
}
168182

169183
/// Rewrite and Match methods that operate on the SourceOp type. These must be
@@ -175,6 +189,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
175189
ConversionPatternRewriter &rewriter) const {
176190
llvm_unreachable("must override rewrite or matchAndRewrite");
177191
}
192+
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
193+
ConversionPatternRewriter &rewriter) const {
194+
SmallVector<Value> oneToOneOperands =
195+
getOneToOneAdaptorOperands(adaptor.getOperands());
196+
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
197+
}
178198
virtual LogicalResult
179199
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
180200
ConversionPatternRewriter &rewriter) const {
@@ -183,6 +203,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
183203
rewrite(op, adaptor, rewriter);
184204
return success();
185205
}
206+
virtual LogicalResult
207+
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
208+
ConversionPatternRewriter &rewriter) const {
209+
SmallVector<Value> oneToOneOperands =
210+
getOneToOneAdaptorOperands(adaptor.getOperands());
211+
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
212+
}
186213

187214
private:
188215
using ConvertToLLVMPattern::match;

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,8 +537,15 @@ class ConversionPattern : public RewritePattern {
537537
ConversionPatternRewriter &rewriter) const {
538538
llvm_unreachable("unimplemented rewrite");
539539
}
540+
virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
541+
ConversionPatternRewriter &rewriter) const {
542+
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
543+
}
540544

541545
/// Hook for derived classes to implement combined matching and rewriting.
546+
/// This overload supports only 1:1 replacements. The 1:N overload is called
547+
/// by the driver. By default, it calls this 1:1 overload or reports a fatal
548+
/// error if 1:N replacements were found.
542549
virtual LogicalResult
543550
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
544551
ConversionPatternRewriter &rewriter) const {
@@ -548,6 +555,14 @@ class ConversionPattern : public RewritePattern {
548555
return success();
549556
}
550557

558+
/// Hook for derived classes to implement combined matching and rewriting.
559+
/// This overload supports 1:N replacements.
560+
virtual LogicalResult
561+
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
562+
ConversionPatternRewriter &rewriter) const {
563+
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
564+
}
565+
551566
/// Attempt to match and rewrite the IR root at the specified operation.
552567
LogicalResult matchAndRewrite(Operation *op,
553568
PatternRewriter &rewriter) const final;
@@ -574,6 +589,15 @@ class ConversionPattern : public RewritePattern {
574589
: RewritePattern(std::forward<Args>(args)...),
575590
typeConverter(&typeConverter) {}
576591

592+
/// Given an array of value ranges, which are the inputs to a 1:N adaptor,
593+
/// try to extract the single value of each range to construct a the inputs
594+
/// for a 1:1 adaptor.
595+
///
596+
/// This function produces a fatal error if at least one range has 0 or
597+
/// more than 1 value: "pattern 'name' does not support 1:N conversion"
598+
SmallVector<Value>
599+
getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands) const;
600+
577601
protected:
578602
/// An optional type converter for use by this pattern.
579603
const TypeConverter *typeConverter = nullptr;
@@ -589,6 +613,8 @@ template <typename SourceOp>
589613
class OpConversionPattern : public ConversionPattern {
590614
public:
591615
using OpAdaptor = typename SourceOp::Adaptor;
616+
using OneToNOpAdaptor =
617+
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
592618

593619
OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
594620
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -607,12 +633,24 @@ class OpConversionPattern : public ConversionPattern {
607633
auto sourceOp = cast<SourceOp>(op);
608634
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
609635
}
636+
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
637+
ConversionPatternRewriter &rewriter) const final {
638+
auto sourceOp = cast<SourceOp>(op);
639+
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
640+
}
610641
LogicalResult
611642
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
612643
ConversionPatternRewriter &rewriter) const final {
613644
auto sourceOp = cast<SourceOp>(op);
614645
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
615646
}
647+
LogicalResult
648+
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
649+
ConversionPatternRewriter &rewriter) const final {
650+
auto sourceOp = cast<SourceOp>(op);
651+
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
652+
rewriter);
653+
}
616654

617655
/// Rewrite and Match methods that operate on the SourceOp type. These must be
618656
/// overridden by the derived pattern class.
@@ -623,6 +661,12 @@ class OpConversionPattern : public ConversionPattern {
623661
ConversionPatternRewriter &rewriter) const {
624662
llvm_unreachable("must override matchAndRewrite or a rewrite method");
625663
}
664+
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
665+
ConversionPatternRewriter &rewriter) const {
666+
SmallVector<Value> oneToOneOperands =
667+
getOneToOneAdaptorOperands(adaptor.getOperands());
668+
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
669+
}
626670
virtual LogicalResult
627671
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
628672
ConversionPatternRewriter &rewriter) const {
@@ -631,6 +675,13 @@ class OpConversionPattern : public ConversionPattern {
631675
rewrite(op, adaptor, rewriter);
632676
return success();
633677
}
678+
virtual LogicalResult
679+
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
680+
ConversionPatternRewriter &rewriter) const {
681+
SmallVector<Value> oneToOneOperands =
682+
getOneToOneAdaptorOperands(adaptor.getOperands());
683+
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
684+
}
634685

635686
private:
636687
using ConversionPattern::matchAndRewrite;
@@ -656,18 +707,31 @@ class OpInterfaceConversionPattern : public ConversionPattern {
656707
ConversionPatternRewriter &rewriter) const final {
657708
rewrite(cast<SourceOp>(op), operands, rewriter);
658709
}
710+
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
711+
ConversionPatternRewriter &rewriter) const final {
712+
rewrite(cast<SourceOp>(op), operands, rewriter);
713+
}
659714
LogicalResult
660715
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
661716
ConversionPatternRewriter &rewriter) const final {
662717
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
663718
}
719+
LogicalResult
720+
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
721+
ConversionPatternRewriter &rewriter) const final {
722+
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
723+
}
664724

665725
/// Rewrite and Match methods that operate on the SourceOp type. These must be
666726
/// overridden by the derived pattern class.
667727
virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
668728
ConversionPatternRewriter &rewriter) const {
669729
llvm_unreachable("must override matchAndRewrite or a rewrite method");
670730
}
731+
virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
732+
ConversionPatternRewriter &rewriter) const {
733+
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
734+
}
671735
virtual LogicalResult
672736
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
673737
ConversionPatternRewriter &rewriter) const {
@@ -676,6 +740,11 @@ class OpInterfaceConversionPattern : public ConversionPattern {
676740
rewrite(op, operands, rewriter);
677741
return success();
678742
}
743+
virtual LogicalResult
744+
matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
745+
ConversionPatternRewriter &rewriter) const {
746+
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
747+
}
679748

680749
private:
681750
using ConversionPattern::matchAndRewrite;

mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp

Lines changed: 6 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -13,40 +13,6 @@
1313
using namespace mlir;
1414
using namespace mlir::func;
1515

16-
//===----------------------------------------------------------------------===//
17-
// Helper functions
18-
//===----------------------------------------------------------------------===//
19-
20-
/// If the given value can be decomposed with the type converter, decompose it.
21-
/// Otherwise, return the given value.
22-
// TODO: Value decomposition should happen automatically through a 1:N adaptor.
23-
// This function will disappear when the 1:1 and 1:N drivers are merged.
24-
static SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc,
25-
Value value,
26-
const TypeConverter *converter) {
27-
// Try to convert the given value's type. If that fails, just return the
28-
// given value.
29-
SmallVector<Type> convertedTypes;
30-
if (failed(converter->convertType(value.getType(), convertedTypes)))
31-
return {value};
32-
if (convertedTypes.empty())
33-
return {};
34-
35-
// If the given value's type is already legal, just return the given value.
36-
TypeRange convertedTypeRange(convertedTypes);
37-
if (convertedTypeRange == TypeRange(value.getType()))
38-
return {value};
39-
40-
// Try to materialize a target conversion. If the materialization did not
41-
// produce values of the requested type, the materialization failed. Just
42-
// return the given value in that case.
43-
SmallVector<Value> result = converter->materializeTargetConversion(
44-
builder, loc, convertedTypeRange, value);
45-
if (result.empty())
46-
return {value};
47-
return result;
48-
}
49-
5016
//===----------------------------------------------------------------------===//
5117
// DecomposeCallGraphTypesForFuncArgs
5218
//===----------------------------------------------------------------------===//
@@ -102,16 +68,11 @@ struct DecomposeCallGraphTypesForReturnOp
10268
using OpConversionPattern::OpConversionPattern;
10369

10470
LogicalResult
105-
matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
71+
matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
10672
ConversionPatternRewriter &rewriter) const final {
10773
SmallVector<Value, 2> newOperands;
108-
for (Value operand : adaptor.getOperands()) {
109-
// TODO: We can directly take the values from the adaptor once this is a
110-
// 1:N conversion pattern.
111-
llvm::append_range(newOperands,
112-
decomposeValue(rewriter, operand.getLoc(), operand,
113-
getTypeConverter()));
114-
}
74+
for (ValueRange operand : adaptor.getOperands())
75+
llvm::append_range(newOperands, operand);
11576
rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
11677
return success();
11778
}
@@ -128,18 +89,13 @@ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
12889
using OpConversionPattern::OpConversionPattern;
12990

13091
LogicalResult
131-
matchAndRewrite(CallOp op, OpAdaptor adaptor,
92+
matchAndRewrite(CallOp op, OneToNOpAdaptor adaptor,
13293
ConversionPatternRewriter &rewriter) const final {
13394

13495
// Create the operands list of the new `CallOp`.
13596
SmallVector<Value, 2> newOperands;
136-
for (Value operand : adaptor.getOperands()) {
137-
// TODO: We can directly take the values from the adaptor once this is a
138-
// 1:N conversion pattern.
139-
llvm::append_range(newOperands,
140-
decomposeValue(rewriter, operand.getLoc(), operand,
141-
getTypeConverter()));
142-
}
97+
for (ValueRange operand : adaptor.getOperands())
98+
llvm::append_range(newOperands, operand);
14399

144100
// Create the new result types for the new `CallOp` and track the number of
145101
// replacement types for each original op result.

mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@
1313
using namespace mlir;
1414
using namespace mlir::func;
1515

16+
/// Flatten the given value ranges into a single vector of values.
17+
static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
18+
SmallVector<Value> result;
19+
for (const auto &vals : values)
20+
llvm::append_range(result, vals);
21+
return result;
22+
}
23+
1624
namespace {
1725
/// Converts the operand and result types of the CallOp, used together with the
1826
/// FuncOpSignatureConversion.
@@ -21,7 +29,7 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
2129

2230
/// Hook for derived classes to implement combined matching and rewriting.
2331
LogicalResult
24-
matchAndRewrite(CallOp callOp, OpAdaptor adaptor,
32+
matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor,
2533
ConversionPatternRewriter &rewriter) const override {
2634
// Convert the original function results. Keep track of how many result
2735
// types an original result type is converted into.
@@ -38,9 +46,9 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
3846

3947
// Substitute with the new result types from the corresponding FuncType
4048
// conversion.
41-
auto newCallOp =
42-
rewriter.create<CallOp>(callOp.getLoc(), callOp.getCallee(),
43-
convertedResults, adaptor.getOperands());
49+
auto newCallOp = rewriter.create<CallOp>(
50+
callOp.getLoc(), callOp.getCallee(), convertedResults,
51+
flattenValues(adaptor.getOperands()));
4452
SmallVector<ValueRange> replacements;
4553
size_t offset = 0;
4654
for (int i = 0, e = callOp->getNumResults(); i < e; ++i) {

0 commit comments

Comments
 (0)