diff --git a/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h b/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h index bffc03ed0187e..1b094d9d9fe77 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h @@ -366,6 +366,10 @@ class IRTranslator : public MachineFunctionPass { BranchProbability BranchProbToNext, Register Reg, SwitchCG::BitTestCase &B, MachineBasicBlock *SwitchBB); + void splitWorkItem(SwitchCG::SwitchWorkList &WorkList, + const SwitchCG::SwitchWorkListItem &W, Value *Cond, + MachineBasicBlock *SwitchMBB, MachineIRBuilder &MIB); + bool lowerJumpTableWorkItem( SwitchCG::SwitchWorkListItem W, MachineBasicBlock *SwitchMBB, MachineBasicBlock *CurMBB, MachineBasicBlock *DefaultMBB, diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp index 9c11113902a24..6708f2baa5ed5 100644 --- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp +++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp @@ -751,16 +751,91 @@ bool IRTranslator::translateSwitch(const User &U, MachineIRBuilder &MIB) { auto DefaultProb = getEdgeProbability(SwitchMBB, DefaultMBB); WorkList.push_back({SwitchMBB, First, Last, nullptr, nullptr, DefaultProb}); - // FIXME: At the moment we don't do any splitting optimizations here like - // SelectionDAG does, so this worklist only has one entry. while (!WorkList.empty()) { SwitchWorkListItem W = WorkList.pop_back_val(); + + unsigned NumClusters = W.LastCluster - W.FirstCluster + 1; + // For optimized builds, lower large ranges as a balanced binary tree. 
+ if (NumClusters > 3 && + MF->getTarget().getOptLevel() != CodeGenOptLevel::None && + !DefaultMBB->getParent()->getFunction().hasMinSize()) { + splitWorkItem(WorkList, W, SI.getCondition(), SwitchMBB, MIB); + continue; + } + if (!lowerSwitchWorkItem(W, SI.getCondition(), SwitchMBB, DefaultMBB, MIB)) return false; } return true; } +void IRTranslator::splitWorkItem(SwitchCG::SwitchWorkList &WorkList, + const SwitchCG::SwitchWorkListItem &W, + Value *Cond, MachineBasicBlock *SwitchMBB, + MachineIRBuilder &MIB) { + using namespace SwitchCG; + assert(W.FirstCluster->Low->getValue().slt(W.LastCluster->Low->getValue()) && + "Clusters not sorted?"); + assert(W.LastCluster - W.FirstCluster + 1 >= 2 && "Too small to split!"); + + auto [LastLeft, FirstRight, LeftProb, RightProb] = + SL->computeSplitWorkItemInfo(W); + + // Use the first cluster on the right as the pivot since the branch emitted + // below is a signed less-than comparison against it. + CaseClusterIt PivotCluster = FirstRight; + assert(PivotCluster > W.FirstCluster); + assert(PivotCluster <= W.LastCluster); + + CaseClusterIt FirstLeft = W.FirstCluster; + CaseClusterIt LastRight = W.LastCluster; + + const ConstantInt *Pivot = PivotCluster->Low; + + // New blocks will be inserted immediately after W.MBB. + MachineFunction::iterator BBI(W.MBB); + ++BBI; + + // We will branch to the LHS if Value < Pivot. If LHS is a single range + // cluster, we can branch to its destination directly if it's squeezed + // exactly in between the known lower bound (W.GE) and Pivot - 1. 
+ MachineBasicBlock *LeftMBB; + if (FirstLeft == LastLeft && FirstLeft->Kind == CC_Range && + FirstLeft->Low == W.GE && + (FirstLeft->High->getValue() + 1LL) == Pivot->getValue()) { + LeftMBB = FirstLeft->MBB; + } else { + LeftMBB = FuncInfo.MF->CreateMachineBasicBlock(W.MBB->getBasicBlock()); + FuncInfo.MF->insert(BBI, LeftMBB); + WorkList.push_back( + {LeftMBB, FirstLeft, LastLeft, W.GE, Pivot, W.DefaultProb / 2}); + } + + // Similarly, we will branch to the RHS if Value >= Pivot. If RHS is a + // single range cluster, RHS.Low == Pivot, and we can branch to its + // destination directly if RHS.High + 1 equals the upper bound W.LT. + MachineBasicBlock *RightMBB; + if (FirstRight == LastRight && FirstRight->Kind == CC_Range && W.LT && + (FirstRight->High->getValue() + 1ULL) == W.LT->getValue()) { + RightMBB = FirstRight->MBB; + } else { + RightMBB = FuncInfo.MF->CreateMachineBasicBlock(W.MBB->getBasicBlock()); + FuncInfo.MF->insert(BBI, RightMBB); + WorkList.push_back( + {RightMBB, FirstRight, LastRight, Pivot, W.LT, W.DefaultProb / 2}); + } + + // Create the CaseBlock record (signed Value < Pivot) that lowers the branch. 
+ CaseBlock CB(ICmpInst::Predicate::ICMP_SLT, false, Cond, Pivot, nullptr, + LeftMBB, RightMBB, W.MBB, MIB.getDebugLoc(), LeftProb, + RightProb); + + if (W.MBB == SwitchMBB) + emitSwitchCase(CB, SwitchMBB, MIB); + else + SL->SwitchCases.push_back(CB); +} + void IRTranslator::emitJumpTable(SwitchCG::JumpTable &JT, MachineBasicBlock *MBB) { // Emit the code for the jump table diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/irtranslator-switch-split.ll b/llvm/test/CodeGen/AArch64/GlobalISel/irtranslator-switch-split.ll index 54c8eb913d5d4..55cf48ed2245f 100644 --- a/llvm/test/CodeGen/AArch64/GlobalISel/irtranslator-switch-split.ll +++ b/llvm/test/CodeGen/AArch64/GlobalISel/irtranslator-switch-split.ll @@ -17,31 +17,33 @@ define i32 @scanfile(i32 %call148) { ; CHECK-NEXT: .cfi_offset w30, -8 ; CHECK-NEXT: .cfi_offset w29, -16 ; CHECK-NEXT: mov w8, w0 +; CHECK-NEXT: cmp w0, #1 ; CHECK-NEXT: mov w0, wzr -; CHECK-NEXT: cbz w8, LBB0_7 +; CHECK-NEXT: b.ge LBB0_3 ; CHECK-NEXT: ; %bb.1: ; %entry -; CHECK-NEXT: cmp w8, #1 -; CHECK-NEXT: b.eq LBB0_7 -; CHECK-NEXT: ; %bb.2: ; %entry +; CHECK-NEXT: cbnz w8, LBB0_7 +; CHECK-NEXT: LBB0_2: ; %common.ret1 +; CHECK-NEXT: ldp x29, x30, [sp], #16 ; 16-byte Folded Reload +; CHECK-NEXT: ret +; CHECK-NEXT: LBB0_3: ; %entry +; CHECK-NEXT: b.eq LBB0_2 +; CHECK-NEXT: ; %bb.4: ; %entry ; CHECK-NEXT: cmp w8, #2 -; CHECK-NEXT: b.eq LBB0_4 -; CHECK-NEXT: ; %bb.3: ; %entry +; CHECK-NEXT: b.eq LBB0_6 +; CHECK-NEXT: ; %bb.5: ; %entry ; CHECK-NEXT: cmp w8, #3 -; CHECK-NEXT: b.ne LBB0_5 -; CHECK-NEXT: LBB0_4: ; %sw.bb300 +; CHECK-NEXT: b.ne LBB0_2 +; CHECK-NEXT: LBB0_6: ; %sw.bb300 ; CHECK-NEXT: bl _logg ; CHECK-NEXT: ldp x29, x30, [sp], #16 ; 16-byte Folded Reload ; CHECK-NEXT: ret -; CHECK-NEXT: LBB0_5: ; %entry +; CHECK-NEXT: LBB0_7: ; %entry ; CHECK-NEXT: cmn w8, #2 -; CHECK-NEXT: b.eq LBB0_8 -; CHECK-NEXT: ; %bb.6: ; %entry +; CHECK-NEXT: b.eq LBB0_9 +; CHECK-NEXT: ; %bb.8: ; %entry ; CHECK-NEXT: cmn w8, #1 -; CHECK-NEXT: b.eq LBB0_8 -; 
CHECK-NEXT: LBB0_7: ; %common.ret1 -; CHECK-NEXT: ldp x29, x30, [sp], #16 ; 16-byte Folded Reload -; CHECK-NEXT: ret -; CHECK-NEXT: LBB0_8: ; %sw.bb150 +; CHECK-NEXT: b.ne LBB0_2 +; CHECK-NEXT: LBB0_9: ; %sw.bb150 ; CHECK-NEXT: bl _logg ; CHECK-NEXT: brk #0x1 entry: