@@ -563,11 +563,14 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
563
563
// p = (lo+hi)/2 // pivot index
564
564
// i = lo
565
565
// j = hi-1
566
- // while (i < j ) do {
566
+ // while (true ) do {
567
567
// while (xs[i] < xs[p]) i ++;
568
568
// i_eq = (xs[i] == xs[p]);
569
569
// while (xs[j] > xs[p]) j --;
570
570
// j_eq = (xs[j] == xs[p]);
571
+ //
572
+ // if (i >= j) return j + 1;
573
+ //
571
574
// if (i < j) {
572
575
// swap(xs[i], xs[j])
573
576
// if (i == p) {
@@ -581,8 +584,7 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
581
584
// }
582
585
// }
583
586
// }
584
- // return p
585
- // }
587
+ // }
586
588
static void createPartitionFunc (OpBuilder &builder, ModuleOp module,
587
589
func::FuncOp func, uint64_t nx, uint64_t ny,
588
590
bool isCoo, uint32_t nTrailingP = 0 ) {
@@ -605,22 +607,22 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
605
607
Value i = lo;
606
608
Value j = builder.create <arith::SubIOp>(loc, hi, c1);
607
609
createChoosePivot (builder, module, func, nx, ny, isCoo, i, j, p, args);
608
- SmallVector<Value, 3 > operands{i, j, p}; // Exactly three values.
609
- SmallVector<Type, 3 > types{i.getType (), j.getType (), p.getType ()};
610
+ Value trueVal = constantI1 (builder, loc, true ); // The value for while (true)
611
+ SmallVector<Value, 4 > operands{i, j, p, trueVal}; // Exactly four values.
612
+ SmallVector<Type, 4 > types{i.getType (), j.getType (), p.getType (),
613
+ trueVal.getType ()};
610
614
scf::WhileOp whileOp = builder.create <scf::WhileOp>(loc, types, operands);
611
615
612
616
// The before-region of the WhileOp.
613
- Block *before =
614
- builder. createBlock (&whileOp. getBefore (), {}, types, { loc, loc, loc});
617
+ Block *before = builder. createBlock (&whileOp. getBefore (), {}, types,
618
+ {loc, loc, loc, loc});
615
619
builder.setInsertionPointToEnd (before);
616
- Value cond = builder.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
617
- before->getArgument (0 ),
618
- before->getArgument (1 ));
619
- builder.create <scf::ConditionOp>(loc, cond, before->getArguments ());
620
+ builder.create <scf::ConditionOp>(loc, before->getArgument (3 ),
621
+ before->getArguments ());
620
622
621
623
// The after-region of the WhileOp.
622
624
Block *after =
623
- builder.createBlock (&whileOp.getAfter (), {}, types, {loc, loc, loc});
625
+ builder.createBlock (&whileOp.getAfter (), {}, types, {loc, loc, loc, loc });
624
626
builder.setInsertionPointToEnd (after);
625
627
i = after->getArgument (0 );
626
628
j = after->getArgument (1 );
@@ -637,7 +639,8 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
637
639
j = jresult;
638
640
639
641
// If i < j:
640
- cond = builder.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
642
+ Value cond =
643
+ builder.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
641
644
scf::IfOp ifOp = builder.create <scf::IfOp>(loc, types, cond, /* else=*/ true );
642
645
builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
643
646
SmallVector<Value> swapOperands{i, j};
@@ -675,11 +678,15 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
675
678
builder.setInsertionPointAfter (ifOp2);
676
679
builder.create <scf::YieldOp>(
677
680
loc,
678
- ValueRange{ifOp2.getResult (0 ), ifOp2.getResult (1 ), ifOpI.getResult (0 )});
681
+ ValueRange{ifOp2.getResult (0 ), ifOp2.getResult (1 ), ifOpI.getResult (0 ),
682
+ /* cont=*/ constantI1 (builder, loc, true )});
679
683
680
- // False branch for if i < j:
684
+ // False branch for if i < j (i.e., i >= j) :
681
685
builder.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
682
- builder.create <scf::YieldOp>(loc, ValueRange{i, j, p});
686
+ p = builder.create <arith::AddIOp>(loc, j,
687
+ constantOne (builder, loc, j.getType ()));
688
+ builder.create <scf::YieldOp>(
689
+ loc, ValueRange{i, j, p, /* cont=*/ constantI1 (builder, loc, false )});
683
690
684
691
// Return for the whileOp.
685
692
builder.setInsertionPointAfter (ifOp);
@@ -927,6 +934,8 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
927
934
Location loc = func.getLoc ();
928
935
Value lo = args[loIdx];
929
936
Value hi = args[hiIdx];
937
+ SmallVector<Type, 2 > types (2 , lo.getType ()); // Only two types.
938
+
930
939
FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc (
931
940
builder, func, {IndexType::get (context)}, kPartitionFuncNamePrefix , nx,
932
941
ny, isCoo, args.drop_back (nTrailingP), createPartitionFunc);
@@ -935,14 +944,25 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
935
944
TypeRange{IndexType::get (context)},
936
945
args.drop_back (nTrailingP))
937
946
.getResult (0 );
938
- Value pP1 =
939
- builder.create <arith::AddIOp>(loc, p, constantIndex (builder, loc, 1 ));
947
+
940
948
Value lenLow = builder.create <arith::SubIOp>(loc, p, lo);
941
949
Value lenHigh = builder.create <arith::SubIOp>(loc, hi, p);
950
+ // Partition already sorts array with len <= 2
951
+ Value c2 = constantIndex (builder, loc, 2 );
952
+ Value len = builder.create <arith::SubIOp>(loc, hi, lo);
953
+ Value lenGtTwo =
954
+ builder.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, len, c2);
955
+ scf::IfOp ifLenGtTwo =
956
+ builder.create <scf::IfOp>(loc, types, lenGtTwo, /* else=*/ true );
957
+ builder.setInsertionPointToStart (&ifLenGtTwo.getElseRegion ().front ());
958
+ // Returns an empty range to mark the entire region is fully sorted.
959
+ builder.create <scf::YieldOp>(loc, ValueRange{lo, lo});
960
+
961
+ // Else len > 2, need recursion.
962
+ builder.setInsertionPointToStart (&ifLenGtTwo.getThenRegion ().front ());
942
963
Value cond = builder.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
943
964
lenLow, lenHigh);
944
965
945
- SmallVector<Type, 2 > types (2 , lo.getType ()); // Only two types.
946
966
scf::IfOp ifOp = builder.create <scf::IfOp>(loc, types, cond, /* else=*/ true );
947
967
948
968
Value c0 = constantIndex (builder, loc, 0 );
@@ -961,14 +981,17 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
961
981
// the bigger partition to be processed by the enclosed while-loop.
962
982
builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
963
983
mayRecursion (lo, p, lenLow);
964
- builder.create <scf::YieldOp>(loc, ValueRange{pP1 , hi});
984
+ builder.create <scf::YieldOp>(loc, ValueRange{p , hi});
965
985
966
986
builder.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
967
- mayRecursion (pP1 , hi, lenHigh);
987
+ mayRecursion (p , hi, lenHigh);
968
988
builder.create <scf::YieldOp>(loc, ValueRange{lo, p});
969
989
970
990
builder.setInsertionPointAfter (ifOp);
971
- return std::make_pair (ifOp.getResult (0 ), ifOp.getResult (1 ));
991
+ builder.create <scf::YieldOp>(loc, ifOp.getResults ());
992
+
993
+ builder.setInsertionPointAfter (ifLenGtTwo);
994
+ return std::make_pair (ifLenGtTwo.getResult (0 ), ifLenGtTwo.getResult (1 ));
972
995
}
973
996
974
997
// / Creates a function to perform insertion sort on the values in the range of
0 commit comments