@@ -538,8 +538,15 @@ class ConversionPattern : public RewritePattern {
538
538
ConversionPatternRewriter &rewriter) const {
539
539
llvm_unreachable (" unimplemented rewrite" );
540
540
}
541
+ virtual void rewrite (Operation *op, ArrayRef<ValueRange> operands,
542
+ ConversionPatternRewriter &rewriter) const {
543
+ rewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
544
+ }
541
545
542
546
// / Hook for derived classes to implement combined matching and rewriting.
547
+ // / This overload supports only 1:1 replacements. The 1:N overload is called
548
+ // / by the driver. By default, it calls this 1:1 overload or reports a fatal
549
+ // / error if 1:N replacements were found.
543
550
virtual LogicalResult
544
551
matchAndRewrite (Operation *op, ArrayRef<Value> operands,
545
552
ConversionPatternRewriter &rewriter) const {
@@ -549,6 +556,14 @@ class ConversionPattern : public RewritePattern {
549
556
return success ();
550
557
}
551
558
559
+ // / Hook for derived classes to implement combined matching and rewriting.
560
+ // / This overload supports 1:N replacements.
561
+ virtual LogicalResult
562
+ matchAndRewrite (Operation *op, ArrayRef<ValueRange> operands,
563
+ ConversionPatternRewriter &rewriter) const {
564
+ return matchAndRewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
565
+ }
566
+
552
567
// / Attempt to match and rewrite the IR root at the specified operation.
553
568
LogicalResult matchAndRewrite (Operation *op,
554
569
PatternRewriter &rewriter) const final ;
@@ -575,6 +590,15 @@ class ConversionPattern : public RewritePattern {
575
590
: RewritePattern(std::forward<Args>(args)...),
576
591
typeConverter (&typeConverter) {}
577
592
593
+ // / Given an array of value ranges, which are the inputs to a 1:N adaptor,
594
+ // / try to extract the single value of each range to construct a the inputs
595
+ // / for a 1:1 adaptor.
596
+ // /
597
+ // / This function produces a fatal error if at least one range has 0 or
598
+ // / more than 1 value: "pattern 'name' does not support 1:N conversion"
599
+ SmallVector<Value>
600
+ getOneToOneAdaptorOperands (ArrayRef<ValueRange> operands) const ;
601
+
578
602
protected:
579
603
// / An optional type converter for use by this pattern.
580
604
const TypeConverter *typeConverter = nullptr ;
@@ -590,6 +614,8 @@ template <typename SourceOp>
590
614
class OpConversionPattern : public ConversionPattern {
591
615
public:
592
616
using OpAdaptor = typename SourceOp::Adaptor;
617
+ using OneToNOpAdaptor =
618
+ typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
593
619
594
620
OpConversionPattern (MLIRContext *context, PatternBenefit benefit = 1 )
595
621
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -608,12 +634,24 @@ class OpConversionPattern : public ConversionPattern {
608
634
auto sourceOp = cast<SourceOp>(op);
609
635
rewrite (sourceOp, OpAdaptor (operands, sourceOp), rewriter);
610
636
}
637
+ void rewrite (Operation *op, ArrayRef<ValueRange> operands,
638
+ ConversionPatternRewriter &rewriter) const final {
639
+ auto sourceOp = cast<SourceOp>(op);
640
+ rewrite (sourceOp, OneToNOpAdaptor (operands, sourceOp), rewriter);
641
+ }
611
642
LogicalResult
612
643
matchAndRewrite (Operation *op, ArrayRef<Value> operands,
613
644
ConversionPatternRewriter &rewriter) const final {
614
645
auto sourceOp = cast<SourceOp>(op);
615
646
return matchAndRewrite (sourceOp, OpAdaptor (operands, sourceOp), rewriter);
616
647
}
648
+ LogicalResult
649
+ matchAndRewrite (Operation *op, ArrayRef<ValueRange> operands,
650
+ ConversionPatternRewriter &rewriter) const final {
651
+ auto sourceOp = cast<SourceOp>(op);
652
+ return matchAndRewrite (sourceOp, OneToNOpAdaptor (operands, sourceOp),
653
+ rewriter);
654
+ }
617
655
618
656
// / Rewrite and Match methods that operate on the SourceOp type. These must be
619
657
// / overridden by the derived pattern class.
@@ -624,6 +662,12 @@ class OpConversionPattern : public ConversionPattern {
624
662
ConversionPatternRewriter &rewriter) const {
625
663
llvm_unreachable (" must override matchAndRewrite or a rewrite method" );
626
664
}
665
+ virtual void rewrite (SourceOp op, OneToNOpAdaptor adaptor,
666
+ ConversionPatternRewriter &rewriter) const {
667
+ SmallVector<Value> oneToOneOperands =
668
+ getOneToOneAdaptorOperands (adaptor.getOperands ());
669
+ rewrite (op, OpAdaptor (oneToOneOperands, adaptor), rewriter);
670
+ }
627
671
virtual LogicalResult
628
672
matchAndRewrite (SourceOp op, OpAdaptor adaptor,
629
673
ConversionPatternRewriter &rewriter) const {
@@ -632,6 +676,13 @@ class OpConversionPattern : public ConversionPattern {
632
676
rewrite (op, adaptor, rewriter);
633
677
return success ();
634
678
}
679
+ virtual LogicalResult
680
+ matchAndRewrite (SourceOp op, OneToNOpAdaptor adaptor,
681
+ ConversionPatternRewriter &rewriter) const {
682
+ SmallVector<Value> oneToOneOperands =
683
+ getOneToOneAdaptorOperands (adaptor.getOperands ());
684
+ return matchAndRewrite (op, OpAdaptor (oneToOneOperands, adaptor), rewriter);
685
+ }
635
686
636
687
private:
637
688
using ConversionPattern::matchAndRewrite;
@@ -657,18 +708,31 @@ class OpInterfaceConversionPattern : public ConversionPattern {
657
708
ConversionPatternRewriter &rewriter) const final {
658
709
rewrite (cast<SourceOp>(op), operands, rewriter);
659
710
}
711
+ void rewrite (Operation *op, ArrayRef<ValueRange> operands,
712
+ ConversionPatternRewriter &rewriter) const final {
713
+ rewrite (cast<SourceOp>(op), operands, rewriter);
714
+ }
660
715
LogicalResult
661
716
matchAndRewrite (Operation *op, ArrayRef<Value> operands,
662
717
ConversionPatternRewriter &rewriter) const final {
663
718
return matchAndRewrite (cast<SourceOp>(op), operands, rewriter);
664
719
}
720
+ LogicalResult
721
+ matchAndRewrite (Operation *op, ArrayRef<ValueRange> operands,
722
+ ConversionPatternRewriter &rewriter) const final {
723
+ return matchAndRewrite (cast<SourceOp>(op), operands, rewriter);
724
+ }
665
725
666
726
// / Rewrite and Match methods that operate on the SourceOp type. These must be
667
727
// / overridden by the derived pattern class.
668
728
virtual void rewrite (SourceOp op, ArrayRef<Value> operands,
669
729
ConversionPatternRewriter &rewriter) const {
670
730
llvm_unreachable (" must override matchAndRewrite or a rewrite method" );
671
731
}
732
+ virtual void rewrite (SourceOp op, ArrayRef<ValueRange> operands,
733
+ ConversionPatternRewriter &rewriter) const {
734
+ rewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
735
+ }
672
736
virtual LogicalResult
673
737
matchAndRewrite (SourceOp op, ArrayRef<Value> operands,
674
738
ConversionPatternRewriter &rewriter) const {
@@ -677,6 +741,11 @@ class OpInterfaceConversionPattern : public ConversionPattern {
677
741
rewrite (op, operands, rewriter);
678
742
return success ();
679
743
}
744
+ virtual LogicalResult
745
+ matchAndRewrite (SourceOp op, ArrayRef<ValueRange> operands,
746
+ ConversionPatternRewriter &rewriter) const {
747
+ return matchAndRewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
748
+ }
680
749
681
750
private:
682
751
using ConversionPattern::matchAndRewrite;
0 commit comments