Skip to content

Commit 967e3a9

Browse files
committed
Revert "[mlir][Transforms] Add 1:N matchAndRewrite overload (llvm#116470)"
This reverts commit 9df63b2.
1 parent e9c68c6 commit 967e3a9

File tree

12 files changed

+455
-494
lines changed

12 files changed

+455
-494
lines changed

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,6 @@ template <typename SourceOp>
143143
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
144144
public:
145145
using OpAdaptor = typename SourceOp::Adaptor;
146-
using OneToNOpAdaptor =
147-
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
148146

149147
explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
150148
PatternBenefit benefit = 1)
@@ -155,29 +153,17 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
155153
/// Wrappers around the RewritePattern methods that pass the derived op type.
156154
void rewrite(Operation *op, ArrayRef<Value> operands,
157155
ConversionPatternRewriter &rewriter) const final {
158-
auto sourceOp = cast<SourceOp>(op);
159-
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
160-
}
161-
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
162-
ConversionPatternRewriter &rewriter) const final {
163-
auto sourceOp = cast<SourceOp>(op);
164-
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
156+
rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
157+
rewriter);
165158
}
166159
LogicalResult match(Operation *op) const final {
167160
return match(cast<SourceOp>(op));
168161
}
169162
LogicalResult
170163
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
171164
ConversionPatternRewriter &rewriter) const final {
172-
auto sourceOp = cast<SourceOp>(op);
173-
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
174-
}
175-
LogicalResult
176-
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
177-
ConversionPatternRewriter &rewriter) const final {
178-
auto sourceOp = cast<SourceOp>(op);
179-
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
180-
rewriter);
165+
return matchAndRewrite(cast<SourceOp>(op),
166+
OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
181167
}
182168

183169
/// Rewrite and Match methods that operate on the SourceOp type. These must be
@@ -189,12 +175,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
189175
ConversionPatternRewriter &rewriter) const {
190176
llvm_unreachable("must override rewrite or matchAndRewrite");
191177
}
192-
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
193-
ConversionPatternRewriter &rewriter) const {
194-
SmallVector<Value> oneToOneOperands =
195-
getOneToOneAdaptorOperands(adaptor.getOperands());
196-
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
197-
}
198178
virtual LogicalResult
199179
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
200180
ConversionPatternRewriter &rewriter) const {
@@ -203,13 +183,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
203183
rewrite(op, adaptor, rewriter);
204184
return success();
205185
}
206-
virtual LogicalResult
207-
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
208-
ConversionPatternRewriter &rewriter) const {
209-
SmallVector<Value> oneToOneOperands =
210-
getOneToOneAdaptorOperands(adaptor.getOperands());
211-
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
212-
}
213186

214187
private:
215188
using ConvertToLLVMPattern::match;

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -538,15 +538,8 @@ class ConversionPattern : public RewritePattern {
538538
ConversionPatternRewriter &rewriter) const {
539539
llvm_unreachable("unimplemented rewrite");
540540
}
541-
virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
542-
ConversionPatternRewriter &rewriter) const {
543-
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
544-
}
545541

546542
/// Hook for derived classes to implement combined matching and rewriting.
547-
/// This overload supports only 1:1 replacements. The 1:N overload is called
548-
/// by the driver. By default, it calls this 1:1 overload or reports a fatal
549-
/// error if 1:N replacements were found.
550543
virtual LogicalResult
551544
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
552545
ConversionPatternRewriter &rewriter) const {
@@ -556,14 +549,6 @@ class ConversionPattern : public RewritePattern {
556549
return success();
557550
}
558551

559-
/// Hook for derived classes to implement combined matching and rewriting.
560-
/// This overload supports 1:N replacements.
561-
virtual LogicalResult
562-
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
563-
ConversionPatternRewriter &rewriter) const {
564-
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
565-
}
566-
567552
/// Attempt to match and rewrite the IR root at the specified operation.
568553
LogicalResult matchAndRewrite(Operation *op,
569554
PatternRewriter &rewriter) const final;
@@ -590,15 +575,6 @@ class ConversionPattern : public RewritePattern {
590575
: RewritePattern(std::forward<Args>(args)...),
591576
typeConverter(&typeConverter) {}
592577

593-
/// Given an array of value ranges, which are the inputs to a 1:N adaptor,
594-
/// try to extract the single value of each range to construct a the inputs
595-
/// for a 1:1 adaptor.
596-
///
597-
/// This function produces a fatal error if at least one range has 0 or
598-
/// more than 1 value: "pattern 'name' does not support 1:N conversion"
599-
SmallVector<Value>
600-
getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands) const;
601-
602578
protected:
603579
/// An optional type converter for use by this pattern.
604580
const TypeConverter *typeConverter = nullptr;
@@ -614,8 +590,6 @@ template <typename SourceOp>
614590
class OpConversionPattern : public ConversionPattern {
615591
public:
616592
using OpAdaptor = typename SourceOp::Adaptor;
617-
using OneToNOpAdaptor =
618-
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
619593

620594
OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
621595
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -634,24 +608,12 @@ class OpConversionPattern : public ConversionPattern {
634608
auto sourceOp = cast<SourceOp>(op);
635609
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
636610
}
637-
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
638-
ConversionPatternRewriter &rewriter) const final {
639-
auto sourceOp = cast<SourceOp>(op);
640-
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
641-
}
642611
LogicalResult
643612
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
644613
ConversionPatternRewriter &rewriter) const final {
645614
auto sourceOp = cast<SourceOp>(op);
646615
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
647616
}
648-
LogicalResult
649-
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
650-
ConversionPatternRewriter &rewriter) const final {
651-
auto sourceOp = cast<SourceOp>(op);
652-
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
653-
rewriter);
654-
}
655617

656618
/// Rewrite and Match methods that operate on the SourceOp type. These must be
657619
/// overridden by the derived pattern class.
@@ -662,12 +624,6 @@ class OpConversionPattern : public ConversionPattern {
662624
ConversionPatternRewriter &rewriter) const {
663625
llvm_unreachable("must override matchAndRewrite or a rewrite method");
664626
}
665-
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
666-
ConversionPatternRewriter &rewriter) const {
667-
SmallVector<Value> oneToOneOperands =
668-
getOneToOneAdaptorOperands(adaptor.getOperands());
669-
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
670-
}
671627
virtual LogicalResult
672628
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
673629
ConversionPatternRewriter &rewriter) const {
@@ -676,13 +632,6 @@ class OpConversionPattern : public ConversionPattern {
676632
rewrite(op, adaptor, rewriter);
677633
return success();
678634
}
679-
virtual LogicalResult
680-
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
681-
ConversionPatternRewriter &rewriter) const {
682-
SmallVector<Value> oneToOneOperands =
683-
getOneToOneAdaptorOperands(adaptor.getOperands());
684-
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
685-
}
686635

687636
private:
688637
using ConversionPattern::matchAndRewrite;
@@ -708,31 +657,18 @@ class OpInterfaceConversionPattern : public ConversionPattern {
708657
ConversionPatternRewriter &rewriter) const final {
709658
rewrite(cast<SourceOp>(op), operands, rewriter);
710659
}
711-
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
712-
ConversionPatternRewriter &rewriter) const final {
713-
rewrite(cast<SourceOp>(op), operands, rewriter);
714-
}
715660
LogicalResult
716661
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
717662
ConversionPatternRewriter &rewriter) const final {
718663
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
719664
}
720-
LogicalResult
721-
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
722-
ConversionPatternRewriter &rewriter) const final {
723-
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
724-
}
725665

726666
/// Rewrite and Match methods that operate on the SourceOp type. These must be
727667
/// overridden by the derived pattern class.
728668
virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
729669
ConversionPatternRewriter &rewriter) const {
730670
llvm_unreachable("must override matchAndRewrite or a rewrite method");
731671
}
732-
virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
733-
ConversionPatternRewriter &rewriter) const {
734-
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
735-
}
736672
virtual LogicalResult
737673
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
738674
ConversionPatternRewriter &rewriter) const {
@@ -741,11 +677,6 @@ class OpInterfaceConversionPattern : public ConversionPattern {
741677
rewrite(op, operands, rewriter);
742678
return success();
743679
}
744-
virtual LogicalResult
745-
matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
746-
ConversionPatternRewriter &rewriter) const {
747-
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
748-
}
749680

750681
private:
751682
using ConversionPattern::matchAndRewrite;

0 commit comments

Comments
 (0)