@@ -537,8 +537,15 @@ class ConversionPattern : public RewritePattern {
537
537
ConversionPatternRewriter &rewriter) const {
538
538
llvm_unreachable (" unimplemented rewrite" );
539
539
}
540
+ virtual void rewrite (Operation *op, ArrayRef<ValueRange> operands,
541
+ ConversionPatternRewriter &rewriter) const {
542
+ rewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
543
+ }
540
544
541
545
// / Hook for derived classes to implement combined matching and rewriting.
546
+ // / This overload supports only 1:1 replacements. The 1:N overload is called
547
+ // / by the driver. By default, it calls this 1:1 overload or reports a fatal
548
+ // / error if 1:N replacements were found.
542
549
virtual LogicalResult
543
550
matchAndRewrite (Operation *op, ArrayRef<Value> operands,
544
551
ConversionPatternRewriter &rewriter) const {
@@ -548,6 +555,14 @@ class ConversionPattern : public RewritePattern {
548
555
return success ();
549
556
}
550
557
558
+ // / Hook for derived classes to implement combined matching and rewriting.
559
+ // / This overload supports 1:N replacements.
560
+ virtual LogicalResult
561
+ matchAndRewrite (Operation *op, ArrayRef<ValueRange> operands,
562
+ ConversionPatternRewriter &rewriter) const {
563
+ return matchAndRewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
564
+ }
565
+
551
566
// / Attempt to match and rewrite the IR root at the specified operation.
552
567
LogicalResult matchAndRewrite (Operation *op,
553
568
PatternRewriter &rewriter) const final ;
@@ -574,6 +589,15 @@ class ConversionPattern : public RewritePattern {
574
589
: RewritePattern(std::forward<Args>(args)...),
575
590
typeConverter (&typeConverter) {}
576
591
592
+ // / Given an array of value ranges, which are the inputs to a 1:N adaptor,
593
+ // / try to extract the single value of each range to construct a the inputs
594
+ // / for a 1:1 adaptor.
595
+ // /
596
+ // / This function produces a fatal error if at least one range has 0 or
597
+ // / more than 1 value: "pattern 'name' does not support 1:N conversion"
598
+ SmallVector<Value>
599
+ getOneToOneAdaptorOperands (ArrayRef<ValueRange> operands) const ;
600
+
577
601
protected:
578
602
// / An optional type converter for use by this pattern.
579
603
const TypeConverter *typeConverter = nullptr ;
@@ -589,6 +613,8 @@ template <typename SourceOp>
589
613
class OpConversionPattern : public ConversionPattern {
590
614
public:
591
615
using OpAdaptor = typename SourceOp::Adaptor;
616
+ using OneToNOpAdaptor =
617
+ typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
592
618
593
619
OpConversionPattern (MLIRContext *context, PatternBenefit benefit = 1 )
594
620
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -607,12 +633,24 @@ class OpConversionPattern : public ConversionPattern {
607
633
auto sourceOp = cast<SourceOp>(op);
608
634
rewrite (sourceOp, OpAdaptor (operands, sourceOp), rewriter);
609
635
}
636
+ void rewrite (Operation *op, ArrayRef<ValueRange> operands,
637
+ ConversionPatternRewriter &rewriter) const final {
638
+ auto sourceOp = cast<SourceOp>(op);
639
+ rewrite (sourceOp, OneToNOpAdaptor (operands, sourceOp), rewriter);
640
+ }
610
641
LogicalResult
611
642
matchAndRewrite (Operation *op, ArrayRef<Value> operands,
612
643
ConversionPatternRewriter &rewriter) const final {
613
644
auto sourceOp = cast<SourceOp>(op);
614
645
return matchAndRewrite (sourceOp, OpAdaptor (operands, sourceOp), rewriter);
615
646
}
647
+ LogicalResult
648
+ matchAndRewrite (Operation *op, ArrayRef<ValueRange> operands,
649
+ ConversionPatternRewriter &rewriter) const final {
650
+ auto sourceOp = cast<SourceOp>(op);
651
+ return matchAndRewrite (sourceOp, OneToNOpAdaptor (operands, sourceOp),
652
+ rewriter);
653
+ }
616
654
617
655
// / Rewrite and Match methods that operate on the SourceOp type. These must be
618
656
// / overridden by the derived pattern class.
@@ -623,6 +661,12 @@ class OpConversionPattern : public ConversionPattern {
623
661
ConversionPatternRewriter &rewriter) const {
624
662
llvm_unreachable (" must override matchAndRewrite or a rewrite method" );
625
663
}
664
+ virtual void rewrite (SourceOp op, OneToNOpAdaptor adaptor,
665
+ ConversionPatternRewriter &rewriter) const {
666
+ SmallVector<Value> oneToOneOperands =
667
+ getOneToOneAdaptorOperands (adaptor.getOperands ());
668
+ rewrite (op, OpAdaptor (oneToOneOperands, adaptor), rewriter);
669
+ }
626
670
virtual LogicalResult
627
671
matchAndRewrite (SourceOp op, OpAdaptor adaptor,
628
672
ConversionPatternRewriter &rewriter) const {
@@ -631,6 +675,13 @@ class OpConversionPattern : public ConversionPattern {
631
675
rewrite (op, adaptor, rewriter);
632
676
return success ();
633
677
}
678
+ virtual LogicalResult
679
+ matchAndRewrite (SourceOp op, OneToNOpAdaptor adaptor,
680
+ ConversionPatternRewriter &rewriter) const {
681
+ SmallVector<Value> oneToOneOperands =
682
+ getOneToOneAdaptorOperands (adaptor.getOperands ());
683
+ return matchAndRewrite (op, OpAdaptor (oneToOneOperands, adaptor), rewriter);
684
+ }
634
685
635
686
private:
636
687
using ConversionPattern::matchAndRewrite;
@@ -656,18 +707,31 @@ class OpInterfaceConversionPattern : public ConversionPattern {
656
707
ConversionPatternRewriter &rewriter) const final {
657
708
rewrite (cast<SourceOp>(op), operands, rewriter);
658
709
}
710
+ void rewrite (Operation *op, ArrayRef<ValueRange> operands,
711
+ ConversionPatternRewriter &rewriter) const final {
712
+ rewrite (cast<SourceOp>(op), operands, rewriter);
713
+ }
659
714
LogicalResult
660
715
matchAndRewrite (Operation *op, ArrayRef<Value> operands,
661
716
ConversionPatternRewriter &rewriter) const final {
662
717
return matchAndRewrite (cast<SourceOp>(op), operands, rewriter);
663
718
}
719
+ LogicalResult
720
+ matchAndRewrite (Operation *op, ArrayRef<ValueRange> operands,
721
+ ConversionPatternRewriter &rewriter) const final {
722
+ return matchAndRewrite (cast<SourceOp>(op), operands, rewriter);
723
+ }
664
724
665
725
// / Rewrite and Match methods that operate on the SourceOp type. These must be
666
726
// / overridden by the derived pattern class.
667
727
virtual void rewrite (SourceOp op, ArrayRef<Value> operands,
668
728
ConversionPatternRewriter &rewriter) const {
669
729
llvm_unreachable (" must override matchAndRewrite or a rewrite method" );
670
730
}
731
+ virtual void rewrite (SourceOp op, ArrayRef<ValueRange> operands,
732
+ ConversionPatternRewriter &rewriter) const {
733
+ rewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
734
+ }
671
735
virtual LogicalResult
672
736
matchAndRewrite (SourceOp op, ArrayRef<Value> operands,
673
737
ConversionPatternRewriter &rewriter) const {
@@ -676,6 +740,11 @@ class OpInterfaceConversionPattern : public ConversionPattern {
676
740
rewrite (op, operands, rewriter);
677
741
return success ();
678
742
}
743
+ virtual LogicalResult
744
+ matchAndRewrite (SourceOp op, ArrayRef<ValueRange> operands,
745
+ ConversionPatternRewriter &rewriter) const {
746
+ return matchAndRewrite (op, getOneToOneAdaptorOperands (operands), rewriter);
747
+ }
679
748
680
749
private:
681
750
using ConversionPattern::matchAndRewrite;
0 commit comments