@@ -390,34 +390,32 @@ func.func @do_not_fuse_alias(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
390
390
391
391
// -----
392
392
393
- func.func @fuse_reductions (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) -> (f32 , f32 ) {
393
+ func.func @fuse_reductions_two (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) -> (f32 , f32 ) {
394
394
%c2 = arith.constant 2 : index
395
395
%c0 = arith.constant 0 : index
396
396
%c1 = arith.constant 1 : index
397
397
%init1 = arith.constant 1.0 : f32
398
398
%init2 = arith.constant 2.0 : f32
399
399
%res1 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init1 ) -> f32 {
400
400
%A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
401
- scf.reduce (%A_elem ) : f32 {
401
+ scf.reduce (%A_elem : f32 ) {
402
402
^bb0 (%lhs: f32 , %rhs: f32 ):
403
403
%1 = arith.addf %lhs , %rhs : f32
404
404
scf.reduce.return %1 : f32
405
405
}
406
- scf.yield
407
406
}
408
407
%res2 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init2 ) -> f32 {
409
408
%B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
410
- scf.reduce (%B_elem ) : f32 {
409
+ scf.reduce (%B_elem : f32 ) {
411
410
^bb0 (%lhs: f32 , %rhs: f32 ):
412
411
%1 = arith.mulf %lhs , %rhs : f32
413
412
scf.reduce.return %1 : f32
414
413
}
415
- scf.yield
416
414
}
417
415
return %res1 , %res2 : f32 , f32
418
416
}
419
417
420
- // CHECK-LABEL: func @fuse_reductions
418
+ // CHECK-LABEL: func @fuse_reductions_two
421
419
// CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>) -> (f32, f32)
422
420
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
423
421
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
@@ -428,44 +426,105 @@ func.func @fuse_reductions(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -> (f32, f3
428
426
// CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
429
427
// CHECK-SAME: init (%[[INIT1]], %[[INIT2]]) -> (f32, f32)
430
428
// CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
431
- // CHECK: scf.reduce(%[[VAL_A]]) : f32 {
429
+ // CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
430
+ // CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]] : f32, f32) {
432
431
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
433
432
// CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
434
433
// CHECK: scf.reduce.return %[[R]] : f32
435
434
// CHECK: }
436
- // CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
437
- // CHECK: scf.reduce(%[[VAL_B]]) : f32 {
438
435
// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
439
436
// CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
440
437
// CHECK: scf.reduce.return %[[R]] : f32
441
438
// CHECK: }
442
- // CHECK: scf.yield
443
439
// CHECK: return %[[RES]]#0, %[[RES]]#1 : f32, f32
444
440
445
441
// -----
446
442
443
+ func.func @fuse_reductions_three (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >, %C: memref <2 x2 xf32 >) -> (f32 , f32 , f32 ) {
444
+ %c2 = arith.constant 2 : index
445
+ %c0 = arith.constant 0 : index
446
+ %c1 = arith.constant 1 : index
447
+ %init1 = arith.constant 1.0 : f32
448
+ %init2 = arith.constant 2.0 : f32
449
+ %init3 = arith.constant 3.0 : f32
450
+ %res1 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init1 ) -> f32 {
451
+ %A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
452
+ scf.reduce (%A_elem : f32 ) {
453
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
454
+ %1 = arith.addf %lhs , %rhs : f32
455
+ scf.reduce.return %1 : f32
456
+ }
457
+ }
458
+ %res2 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init2 ) -> f32 {
459
+ %B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
460
+ scf.reduce (%B_elem : f32 ) {
461
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
462
+ %1 = arith.mulf %lhs , %rhs : f32
463
+ scf.reduce.return %1 : f32
464
+ }
465
+ }
466
+ %res3 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init3 ) -> f32 {
467
+ %A_elem = memref.load %C [%i , %j ] : memref <2 x2 xf32 >
468
+ scf.reduce (%A_elem : f32 ) {
469
+ ^bb0 (%lhs: f32 , %rhs: f32 ):
470
+ %1 = arith.addf %lhs , %rhs : f32
471
+ scf.reduce.return %1 : f32
472
+ }
473
+ }
474
+ return %res1 , %res2 , %res3 : f32 , f32 , f32
475
+ }
476
+
477
+ // CHECK-LABEL: func @fuse_reductions_three
478
+ // CHECK-SAME: (%[[A:.*]]: memref<2x2xf32>, %[[B:.*]]: memref<2x2xf32>, %[[C:.*]]: memref<2x2xf32>) -> (f32, f32, f32)
479
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
480
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
481
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
482
+ // CHECK-DAG: %[[INIT1:.*]] = arith.constant 1.000000e+00 : f32
483
+ // CHECK-DAG: %[[INIT2:.*]] = arith.constant 2.000000e+00 : f32
484
+ // CHECK-DAG: %[[INIT3:.*]] = arith.constant 3.000000e+00 : f32
485
+ // CHECK: %[[RES:.*]]:3 = scf.parallel (%[[I:.*]], %[[J:.*]]) = (%[[C0]], %[[C0]])
486
+ // CHECK-SAME: to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]])
487
+ // CHECK-SAME: init (%[[INIT1]], %[[INIT2]], %[[INIT3]]) -> (f32, f32, f32)
488
+ // CHECK: %[[VAL_A:.*]] = memref.load %[[A]][%[[I]], %[[J]]]
489
+ // CHECK: %[[VAL_B:.*]] = memref.load %[[B]][%[[I]], %[[J]]]
490
+ // CHECK: %[[VAL_C:.*]] = memref.load %[[C]][%[[I]], %[[J]]]
491
+ // CHECK: scf.reduce(%[[VAL_A]], %[[VAL_B]], %[[VAL_C]] : f32, f32, f32) {
492
+ // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
493
+ // CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
494
+ // CHECK: scf.reduce.return %[[R]] : f32
495
+ // CHECK: }
496
+ // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
497
+ // CHECK: %[[R:.*]] = arith.mulf %[[LHS]], %[[RHS]] : f32
498
+ // CHECK: scf.reduce.return %[[R]] : f32
499
+ // CHECK: }
500
+ // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
501
+ // CHECK: %[[R:.*]] = arith.addf %[[LHS]], %[[RHS]] : f32
502
+ // CHECK: scf.reduce.return %[[R]] : f32
503
+ // CHECK: }
504
+ // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : f32, f32, f32
505
+
506
+ // -----
507
+
447
508
func.func @reductions_use_res (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >) -> (f32 , f32 ) {
448
509
%c2 = arith.constant 2 : index
449
510
%c0 = arith.constant 0 : index
450
511
%c1 = arith.constant 1 : index
451
512
%init1 = arith.constant 1.0 : f32
452
513
%res1 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init1 ) -> f32 {
453
514
%A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
454
- scf.reduce (%A_elem ) : f32 {
515
+ scf.reduce (%A_elem : f32 ) {
455
516
^bb0 (%lhs: f32 , %rhs: f32 ):
456
517
%1 = arith.addf %lhs , %rhs : f32
457
518
scf.reduce.return %1 : f32
458
519
}
459
- scf.yield
460
520
}
461
521
%res2 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%res1 ) -> f32 {
462
522
%B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
463
- scf.reduce (%B_elem ) : f32 {
523
+ scf.reduce (%B_elem : f32 ) {
464
524
^bb0 (%lhs: f32 , %rhs: f32 ):
465
525
%1 = arith.mulf %lhs , %rhs : f32
466
526
scf.reduce.return %1 : f32
467
527
}
468
- scf.yield
469
528
}
470
529
return %res1 , %res2 : f32 , f32
471
530
}
@@ -485,22 +544,20 @@ func.func @reductions_use_res_inside(%A: memref<2x2xf32>, %B: memref<2x2xf32>) -
485
544
%init2 = arith.constant 2.0 : f32
486
545
%res1 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init1 ) -> f32 {
487
546
%A_elem = memref.load %A [%i , %j ] : memref <2 x2 xf32 >
488
- scf.reduce (%A_elem ) : f32 {
547
+ scf.reduce (%A_elem : f32 ) {
489
548
^bb0 (%lhs: f32 , %rhs: f32 ):
490
549
%1 = arith.addf %lhs , %rhs : f32
491
550
scf.reduce.return %1 : f32
492
551
}
493
- scf.yield
494
552
}
495
553
%res2 = scf.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) init (%init2 ) -> f32 {
496
554
%B_elem = memref.load %B [%i , %j ] : memref <2 x2 xf32 >
497
555
%sum = arith.addf %B_elem , %res1 : f32
498
- scf.reduce (%sum ) : f32 {
556
+ scf.reduce (%sum : f32 ) {
499
557
^bb0 (%lhs: f32 , %rhs: f32 ):
500
558
%1 = arith.mulf %lhs , %rhs : f32
501
559
scf.reduce.return %1 : f32
502
560
}
503
- scf.yield
504
561
}
505
562
return %res1 , %res2 : f32 , f32
506
563
}
0 commit comments