Skip to content

Commit 7d3519d

Browse files
replace with multiple
Apply suggestions from code review Co-authored-by: Markus Böck <[email protected]> address comments [WIP] 1:N conversion pattern update test cases Update mlir/lib/Transforms/Utils/DialectConversion.cpp Co-authored-by: Markus Böck <[email protected]> Update mlir/lib/Transforms/Utils/DialectConversion.cpp Co-authored-by: Markus Böck <[email protected]> address comments rollback unresolved materializations properly
1 parent b22cc5a commit 7d3519d

File tree

12 files changed

+497
-325
lines changed

12 files changed

+497
-325
lines changed

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ template <typename SourceOp>
143143
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
144144
public:
145145
using OpAdaptor = typename SourceOp::Adaptor;
146+
using OneToNOpAdaptor =
147+
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
146148

147149
explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
148150
PatternBenefit benefit = 1)
@@ -153,17 +155,29 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
153155
/// Wrappers around the RewritePattern methods that pass the derived op type.
154156
void rewrite(Operation *op, ArrayRef<Value> operands,
155157
ConversionPatternRewriter &rewriter) const final {
156-
rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
157-
rewriter);
158+
auto sourceOp = cast<SourceOp>(op);
159+
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
160+
}
161+
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
162+
ConversionPatternRewriter &rewriter) const final {
163+
auto sourceOp = cast<SourceOp>(op);
164+
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
158165
}
159166
LogicalResult match(Operation *op) const final {
160167
return match(cast<SourceOp>(op));
161168
}
162169
LogicalResult
163170
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
164171
ConversionPatternRewriter &rewriter) const final {
165-
return matchAndRewrite(cast<SourceOp>(op),
166-
OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
172+
auto sourceOp = cast<SourceOp>(op);
173+
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
174+
}
175+
LogicalResult
176+
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
177+
ConversionPatternRewriter &rewriter) const final {
178+
auto sourceOp = cast<SourceOp>(op);
179+
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
180+
rewriter);
167181
}
168182

169183
/// Rewrite and Match methods that operate on the SourceOp type. These must be
@@ -175,6 +189,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
175189
ConversionPatternRewriter &rewriter) const {
176190
llvm_unreachable("must override rewrite or matchAndRewrite");
177191
}
192+
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
193+
ConversionPatternRewriter &rewriter) const {
194+
SmallVector<Value> oneToOneOperands =
195+
getOneToOneAdaptorOperands(adaptor.getOperands());
196+
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
197+
}
178198
virtual LogicalResult
179199
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
180200
ConversionPatternRewriter &rewriter) const {
@@ -183,6 +203,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
183203
rewrite(op, adaptor, rewriter);
184204
return success();
185205
}
206+
virtual LogicalResult
207+
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
208+
ConversionPatternRewriter &rewriter) const {
209+
SmallVector<Value> oneToOneOperands =
210+
getOneToOneAdaptorOperands(adaptor.getOperands());
211+
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
212+
}
186213

187214
private:
188215
using ConvertToLLVMPattern::match;

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,8 +538,15 @@ class ConversionPattern : public RewritePattern {
538538
ConversionPatternRewriter &rewriter) const {
539539
llvm_unreachable("unimplemented rewrite");
540540
}
541+
virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
542+
ConversionPatternRewriter &rewriter) const {
543+
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
544+
}
541545

542546
/// Hook for derived classes to implement combined matching and rewriting.
547+
/// This overload supports only 1:1 replacements. The 1:N overload is called
548+
/// by the driver. By default, it calls this 1:1 overload or reports a fatal
549+
/// error if 1:N replacements were found.
543550
virtual LogicalResult
544551
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
545552
ConversionPatternRewriter &rewriter) const {
@@ -549,6 +556,14 @@ class ConversionPattern : public RewritePattern {
549556
return success();
550557
}
551558

559+
/// Hook for derived classes to implement combined matching and rewriting.
560+
/// This overload supports 1:N replacements.
561+
virtual LogicalResult
562+
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
563+
ConversionPatternRewriter &rewriter) const {
564+
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
565+
}
566+
552567
/// Attempt to match and rewrite the IR root at the specified operation.
553568
LogicalResult matchAndRewrite(Operation *op,
554569
PatternRewriter &rewriter) const final;
@@ -575,6 +590,15 @@ class ConversionPattern : public RewritePattern {
575590
: RewritePattern(std::forward<Args>(args)...),
576591
typeConverter(&typeConverter) {}
577592

593+
/// Given an array of value ranges, which are the inputs to a 1:N adaptor,
594+
/// try to extract the single value of each range to construct a the inputs
595+
/// for a 1:1 adaptor.
596+
///
597+
/// This function produces a fatal error if at least one range has 0 or
598+
/// more than 1 value: "pattern 'name' does not support 1:N conversion"
599+
SmallVector<Value>
600+
getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands) const;
601+
578602
protected:
579603
/// An optional type converter for use by this pattern.
580604
const TypeConverter *typeConverter = nullptr;
@@ -590,6 +614,8 @@ template <typename SourceOp>
590614
class OpConversionPattern : public ConversionPattern {
591615
public:
592616
using OpAdaptor = typename SourceOp::Adaptor;
617+
using OneToNOpAdaptor =
618+
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
593619

594620
OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
595621
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -608,12 +634,24 @@ class OpConversionPattern : public ConversionPattern {
608634
auto sourceOp = cast<SourceOp>(op);
609635
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
610636
}
637+
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
638+
ConversionPatternRewriter &rewriter) const final {
639+
auto sourceOp = cast<SourceOp>(op);
640+
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
641+
}
611642
LogicalResult
612643
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
613644
ConversionPatternRewriter &rewriter) const final {
614645
auto sourceOp = cast<SourceOp>(op);
615646
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
616647
}
648+
LogicalResult
649+
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
650+
ConversionPatternRewriter &rewriter) const final {
651+
auto sourceOp = cast<SourceOp>(op);
652+
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
653+
rewriter);
654+
}
617655

618656
/// Rewrite and Match methods that operate on the SourceOp type. These must be
619657
/// overridden by the derived pattern class.
@@ -624,6 +662,12 @@ class OpConversionPattern : public ConversionPattern {
624662
ConversionPatternRewriter &rewriter) const {
625663
llvm_unreachable("must override matchAndRewrite or a rewrite method");
626664
}
665+
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
666+
ConversionPatternRewriter &rewriter) const {
667+
SmallVector<Value> oneToOneOperands =
668+
getOneToOneAdaptorOperands(adaptor.getOperands());
669+
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
670+
}
627671
virtual LogicalResult
628672
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
629673
ConversionPatternRewriter &rewriter) const {
@@ -632,6 +676,13 @@ class OpConversionPattern : public ConversionPattern {
632676
rewrite(op, adaptor, rewriter);
633677
return success();
634678
}
679+
virtual LogicalResult
680+
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
681+
ConversionPatternRewriter &rewriter) const {
682+
SmallVector<Value> oneToOneOperands =
683+
getOneToOneAdaptorOperands(adaptor.getOperands());
684+
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
685+
}
635686

636687
private:
637688
using ConversionPattern::matchAndRewrite;
@@ -657,18 +708,31 @@ class OpInterfaceConversionPattern : public ConversionPattern {
657708
ConversionPatternRewriter &rewriter) const final {
658709
rewrite(cast<SourceOp>(op), operands, rewriter);
659710
}
711+
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
712+
ConversionPatternRewriter &rewriter) const final {
713+
rewrite(cast<SourceOp>(op), operands, rewriter);
714+
}
660715
LogicalResult
661716
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
662717
ConversionPatternRewriter &rewriter) const final {
663718
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
664719
}
720+
LogicalResult
721+
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
722+
ConversionPatternRewriter &rewriter) const final {
723+
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
724+
}
665725

666726
/// Rewrite and Match methods that operate on the SourceOp type. These must be
667727
/// overridden by the derived pattern class.
668728
virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
669729
ConversionPatternRewriter &rewriter) const {
670730
llvm_unreachable("must override matchAndRewrite or a rewrite method");
671731
}
732+
virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
733+
ConversionPatternRewriter &rewriter) const {
734+
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
735+
}
672736
virtual LogicalResult
673737
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
674738
ConversionPatternRewriter &rewriter) const {
@@ -677,6 +741,11 @@ class OpInterfaceConversionPattern : public ConversionPattern {
677741
rewrite(op, operands, rewriter);
678742
return success();
679743
}
744+
virtual LogicalResult
745+
matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
746+
ConversionPatternRewriter &rewriter) const {
747+
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
748+
}
680749

681750
private:
682751
using ConversionPattern::matchAndRewrite;

mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp

Lines changed: 6 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -13,40 +13,6 @@
1313
using namespace mlir;
1414
using namespace mlir::func;
1515

16-
//===----------------------------------------------------------------------===//
17-
// Helper functions
18-
//===----------------------------------------------------------------------===//
19-
20-
/// If the given value can be decomposed with the type converter, decompose it.
21-
/// Otherwise, return the given value.
22-
// TODO: Value decomposition should happen automatically through a 1:N adaptor.
23-
// This function will disappear when the 1:1 and 1:N drivers are merged.
24-
static SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc,
25-
Value value,
26-
const TypeConverter *converter) {
27-
// Try to convert the given value's type. If that fails, just return the
28-
// given value.
29-
SmallVector<Type> convertedTypes;
30-
if (failed(converter->convertType(value.getType(), convertedTypes)))
31-
return {value};
32-
if (convertedTypes.empty())
33-
return {};
34-
35-
// If the given value's type is already legal, just return the given value.
36-
TypeRange convertedTypeRange(convertedTypes);
37-
if (convertedTypeRange == TypeRange(value.getType()))
38-
return {value};
39-
40-
// Try to materialize a target conversion. If the materialization did not
41-
// produce values of the requested type, the materialization failed. Just
42-
// return the given value in that case.
43-
SmallVector<Value> result = converter->materializeTargetConversion(
44-
builder, loc, convertedTypeRange, value);
45-
if (result.empty())
46-
return {value};
47-
return result;
48-
}
49-
5016
//===----------------------------------------------------------------------===//
5117
// DecomposeCallGraphTypesForFuncArgs
5218
//===----------------------------------------------------------------------===//
@@ -102,16 +68,11 @@ struct DecomposeCallGraphTypesForReturnOp
10268
using OpConversionPattern::OpConversionPattern;
10369

10470
LogicalResult
105-
matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
71+
matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
10672
ConversionPatternRewriter &rewriter) const final {
10773
SmallVector<Value, 2> newOperands;
108-
for (Value operand : adaptor.getOperands()) {
109-
// TODO: We can directly take the values from the adaptor once this is a
110-
// 1:N conversion pattern.
111-
llvm::append_range(newOperands,
112-
decomposeValue(rewriter, operand.getLoc(), operand,
113-
getTypeConverter()));
114-
}
74+
for (ValueRange operand : adaptor.getOperands())
75+
llvm::append_range(newOperands, operand);
11576
rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
11677
return success();
11778
}
@@ -128,18 +89,13 @@ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
12889
using OpConversionPattern::OpConversionPattern;
12990

13091
LogicalResult
131-
matchAndRewrite(CallOp op, OpAdaptor adaptor,
92+
matchAndRewrite(CallOp op, OneToNOpAdaptor adaptor,
13293
ConversionPatternRewriter &rewriter) const final {
13394

13495
// Create the operands list of the new `CallOp`.
13596
SmallVector<Value, 2> newOperands;
136-
for (Value operand : adaptor.getOperands()) {
137-
// TODO: We can directly take the values from the adaptor once this is a
138-
// 1:N conversion pattern.
139-
llvm::append_range(newOperands,
140-
decomposeValue(rewriter, operand.getLoc(), operand,
141-
getTypeConverter()));
142-
}
97+
for (ValueRange operand : adaptor.getOperands())
98+
llvm::append_range(newOperands, operand);
14399

144100
// Create the new result types for the new `CallOp` and track the number of
145101
// replacement types for each original op result.

mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@
1313
using namespace mlir;
1414
using namespace mlir::func;
1515

16+
/// Flatten the given value ranges into a single vector of values.
17+
static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
18+
SmallVector<Value> result;
19+
for (const auto &vals : values)
20+
llvm::append_range(result, vals);
21+
return result;
22+
}
23+
1624
namespace {
1725
/// Converts the operand and result types of the CallOp, used together with the
1826
/// FuncOpSignatureConversion.
@@ -21,7 +29,7 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
2129

2230
/// Hook for derived classes to implement combined matching and rewriting.
2331
LogicalResult
24-
matchAndRewrite(CallOp callOp, OpAdaptor adaptor,
32+
matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor,
2533
ConversionPatternRewriter &rewriter) const override {
2634
// Convert the original function results. Keep track of how many result
2735
// types an original result type is converted into.
@@ -38,9 +46,9 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
3846

3947
// Substitute with the new result types from the corresponding FuncType
4048
// conversion.
41-
auto newCallOp =
42-
rewriter.create<CallOp>(callOp.getLoc(), callOp.getCallee(),
43-
convertedResults, adaptor.getOperands());
49+
auto newCallOp = rewriter.create<CallOp>(
50+
callOp.getLoc(), callOp.getCallee(), convertedResults,
51+
flattenValues(adaptor.getOperands()));
4452
SmallVector<ValueRange> replacements;
4553
size_t offset = 0;
4654
for (int i = 0, e = callOp->getNumResults(); i < e; ++i) {

0 commit comments

Comments
 (0)