@@ -24,6 +24,32 @@ func.func @fuse_empty_loops() {
24
24
25
25
// -----
26
26
27
+ func.func @fuse_ops_between (%A: f32 , %B: f32 ) -> f32 {
28
+ %c2 = arith.constant 2 : index
29
+ %c0 = arith.constant 0 : index
30
+ %c1 = arith.constant 1 : index
31
+ scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
32
+ scf.reduce
33
+ }
34
+ %res = arith.addf %A , %B : f32
35
+ scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
36
+ scf.reduce
37
+ }
38
+ return %res : f32
39
+ }
40
+ // CHECK-LABEL: func @fuse_ops_between
41
+ // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
42
+ // CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
43
+ // CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
44
+ // CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f32
45
+ // CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
46
+ // CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
47
+ // CHECK: scf.reduce
48
+ // CHECK: }
49
+ // CHECK-NOT: scf.parallel
50
+
51
+ // -----
52
+
27
53
func.func @fuse_two (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) {
28
54
%c2 = arith.constant 2 : index
29
55
%c0 = arith.constant 0 : index
@@ -89,7 +115,7 @@ func.func @fuse_three(%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
89
115
memref.store %product_elem , %prod [%i , %j ] : memref <2 x2 xf32 >
90
116
scf.reduce
91
117
}
92
- scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
118
+ scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
93
119
%A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
94
120
%res_elem = arith.addf %A_elem , %c2fp : f32
95
121
memref.store %res_elem , %B [%i , %j ] : memref <2 x2 xf32 >
@@ -575,3 +601,215 @@ func.func @do_not_fuse_affine_apply_to_non_ind_var(
575
601
// CHECK-NEXT: }
576
602
// CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<2x3xf32>
577
603
// CHECK-NEXT: return
604
+
605
+ // -----
606
+
607
+ func.func @fuse_reductions_two (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) -> (f32 , f32 ) {
608
+ %c2 = arith.constant 2 : index
609
+ %c0 = arith.constant 0 : index
610
+ %c1 = arith.constant 1 : index
611
+ %init1 = arith.constant 1.0 : f32
612
+ %init2 = arith.constant 2.0 : f32
613
+ %res1 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init1 ) -> f32 {
614
+ %A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
615
+ scf.reduce (%A_elem : f32 ) {
616
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
617
+ %1 = arith.addf %lhs , %rhs : f32
618
+ scf.reduce.return %1 : f32
619
+ }
620
+ }
621
+ %res2 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init2 ) -> f32 {
622
+ %B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
623
+ scf.reduce (%B_elem : f32 ) {
624
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
625
+ %1 = arith.mulf %lhs , %rhs : f32
626
+ scf.reduce.return %1 : f32
627
+ }
628
+ }
629
+ return %res1 , %res2 : f32 , f32
630
+ }
631
+
632
+ // CHECK-LABEL: func @fuse_reductions_two
633
+ // CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32)
634
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
635
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
636
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
637
+ // CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
638
+ // CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
639
+ // CHECK: %[[RES:.*]]:2 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
640
+ // CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
641
+ // CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
642
+ // CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
643
+ // CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
644
+ // CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]] : f32, f32) {
645
+ // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
646
+ // CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
647
+ // CHECK: scf.reduce.return %[[R]] : f32
648
+ // CHECK: }
649
+ // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
650
+ // CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
651
+ // CHECK: scf.reduce.return %[[R]] : f32
652
+ // CHECK: }
653
+ // CHECK: return %[[RES]]#0, %[[RES]]#1 : f32, f32
654
+
655
+ // -----
656
+
657
+ func.func @fuse_reductions_three (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >, %C: memref <2 x2 xf32 >) -> (f32 , f32 , f32 ) {
658
+ %c2 = arith.constant 2 : index
659
+ %c0 = arith.constant 0 : index
660
+ %c1 = arith.constant 1 : index
661
+ %init1 = arith.constant 1.0 : f32
662
+ %init2 = arith.constant 2.0 : f32
663
+ %init3 = arith.constant 3.0 : f32
664
+ %res1 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init1 ) -> f32 {
665
+ %A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
666
+ scf.reduce (%A_elem : f32 ) {
667
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
668
+ %1 = arith.addf %lhs , %rhs : f32
669
+ scf.reduce.return %1 : f32
670
+ }
671
+ }
672
+ %res2 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init2 ) -> f32 {
673
+ %B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
674
+ scf.reduce (%B_elem : f32 ) {
675
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
676
+ %1 = arith.mulf %lhs , %rhs : f32
677
+ scf.reduce.return %1 : f32
678
+ }
679
+ }
680
+ %res3 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init3 ) -> f32 {
681
+ %A_elem = memref.load %C [%i , %j ] : memref <2 x2 xf32 >
682
+ scf.reduce (%A_elem : f32 ) {
683
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
684
+ %1 = arith.addf %lhs , %rhs : f32
685
+ scf.reduce.return %1 : f32
686
+ }
687
+ }
688
+ return %res1 , %res2 , %res3 : f32 , f32 , f32
689
+ }
690
+
691
+ // CHECK-LABEL: func @fuse_reductions_three
692
+ // CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>, %[[C:.*]]: memref<2x2xf32>) -> (f32, f32, f32)
693
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
694
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
695
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
696
+ // CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
697
+ // CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
698
+ // CHECK-DAG: %[[INIT3:.*]] = arith.constant 3.000000e+00 : f32
699
+ // CHECK: %[[RES:.*]]:3 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
700
+ // CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
701
+ // CHECK-SAME: init (%[[INIT1]], %[[INIT2]], %[[INIT3]]) -> (f32, f32, f32)
702
+ // CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
703
+ // CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
704
+ // CHECK: %[[VAL_C:.*]] = memref.load %[[C]][%[[I]], %[[J]]]
705
+ // CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]], %[[VAL_C]] : f32, f32, f32) {
706
+ // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
707
+ // CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
708
+ // CHECK: scf.reduce.return %[[R]] : f32
709
+ // CHECK: }
710
+ // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
711
+ // CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
712
+ // CHECK: scf.reduce.return %[[R]] : f32
713
+ // CHECK: }
714
+ // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
715
+ // CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
716
+ // CHECK: scf.reduce.return %[[R]] : f32
717
+ // CHECK: }
718
+ // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : f32, f32, f32
719
+
720
+ // -----
721
+
722
+ func.func @reductions_use_res (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) -> (f32 , f32 ) {
723
+ %c2 = arith.constant 2 : index
724
+ %c0 = arith.constant 0 : index
725
+ %c1 = arith.constant 1 : index
726
+ %init1 = arith.constant 1.0 : f32
727
+ %res1 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init1 ) -> f32 {
728
+ %A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
729
+ scf.reduce (%A_elem : f32 ) {
730
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
731
+ %1 = arith.addf %lhs , %rhs : f32
732
+ scf.reduce.return %1 : f32
733
+ }
734
+ }
735
+ %res2 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%res1 ) -> f32 {
736
+ %B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
737
+ scf.reduce (%B_elem : f32 ) {
738
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
739
+ %1 = arith.mulf %lhs , %rhs : f32
740
+ scf.reduce.return %1 : f32
741
+ }
742
+ }
743
+ return %res1 , %res2 : f32 , f32
744
+ }
745
+
746
+ // %res1 is used as second scf.parallel arg, cannot fuse
747
+ // CHECK-LABEL: func @reductions_use_res
748
+ // CHECK: scf.parallel
749
+ // CHECK: scf.parallel
750
+
751
+ // -----
752
+
753
+ func.func @reductions_use_res_inside (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) -> (f32 , f32 ) {
754
+ %c2 = arith.constant 2 : index
755
+ %c0 = arith.constant 0 : index
756
+ %c1 = arith.constant 1 : index
757
+ %init1 = arith.constant 1.0 : f32
758
+ %init2 = arith.constant 2.0 : f32
759
+ %res1 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init1 ) -> f32 {
760
+ %A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
761
+ scf.reduce (%A_elem : f32 ) {
762
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
763
+ %1 = arith.addf %lhs , %rhs : f32
764
+ scf.reduce.return %1 : f32
765
+ }
766
+ }
767
+ %res2 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init2 ) -> f32 {
768
+ %B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
769
+ %sum = arith.addf %B_elem , %res1 : f32
770
+ scf.reduce (%sum : f32 ) {
771
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
772
+ %1 = arith.mulf %lhs , %rhs : f32
773
+ scf.reduce.return %1 : f32
774
+ }
775
+ }
776
+ return %res1 , %res2 : f32 , f32
777
+ }
778
+
779
+ // %res1 is used inside second scf.parallel, cannot fuse
780
+ // CHECK-LABEL: func @reductions_use_res_inside
781
+ // CHECK: scf.parallel
782
+ // CHECK: scf.parallel
783
+
784
+ // -----
785
+
786
+ func.func @reductions_use_res_between (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) -> (f32 , f32 , f32 ) {
787
+ %c2 = arith.constant 2 : index
788
+ %c0 = arith.constant 0 : index
789
+ %c1 = arith.constant 1 : index
790
+ %init1 = arith.constant 1.0 : f32
791
+ %init2 = arith.constant 2.0 : f32
792
+ %res1 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init1 ) -> f32 {
793
+ %A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
794
+ scf.reduce (%A_elem : f32 ) {
795
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
796
+ %1 = arith.addf %lhs , %rhs : f32
797
+ scf.reduce.return %1 : f32
798
+ }
799
+ }
800
+ %res3 = arith.addf %res1 , %init2 : f32
801
+ %res2 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init2 ) -> f32 {
802
+ %B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
803
+ scf.reduce (%B_elem : f32 ) {
804
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
805
+ %1 = arith.mulf %lhs , %rhs : f32
806
+ scf.reduce.return %1 : f32
807
+ }
808
+ }
809
+ return %res1 , %res2 , %res3 : f32 , f32 , f32
810
+ }
811
+
812
+ // instruction in between the loops uses the first loop result
813
+ // CHECK-LABEL: func @reductions_use_res_between
814
+ // CHECK: scf.parallel
815
+ // CHECK: scf.parallel
0 commit comments