diff --git a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp index 2744c25d1bc75..52354281cdd7e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp @@ -17,6 +17,8 @@ #include "SPIRVSubtarget.h" #include "SPIRVTargetMachine.h" #include "SPIRVUtils.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/CodeGen/IntrinsicLowering.h" #include "llvm/IR/CFG.h" @@ -71,7 +73,7 @@ class SPIRVMergeRegionExitTargets : public FunctionPass { /// terminator will take. llvm::Value *createExitVariable( BasicBlock *BB, - const std::unordered_map &TargetToValue) { + const DenseMap &TargetToValue) { auto *T = BB->getTerminator(); if (isa(T)) return nullptr; @@ -103,7 +105,7 @@ class SPIRVMergeRegionExitTargets : public FunctionPass { /// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|. void replaceBranchTargets(BasicBlock *BB, - const std::unordered_set ToReplace, + const SmallPtrSet &ToReplace, BasicBlock *NewTarget) { auto *T = BB->getTerminator(); if (isa(T)) @@ -133,7 +135,7 @@ class SPIRVMergeRegionExitTargets : public FunctionPass { bool runOnConvergenceRegionNoRecurse(LoopInfo &LI, const SPIRV::ConvergenceRegion *CR) { // Gather all the exit targets for this region. - std::unordered_set ExitTargets; + SmallPtrSet ExitTargets; for (BasicBlock *Exit : CR->Exits) { for (BasicBlock *Target : gatherSuccessors(Exit)) { if (CR->Blocks.count(Target) == 0) @@ -164,9 +166,10 @@ class SPIRVMergeRegionExitTargets : public FunctionPass { // Creating one constant per distinct exit target. This will be route to the // correct target. - std::unordered_map TargetToValue; + DenseMap TargetToValue; for (BasicBlock *Target : SortedExitTargets) - TargetToValue.emplace(Target, Builder.getInt32(TargetToValue.size())); + TargetToValue.insert( + std::make_pair(Target, Builder.getInt32(TargetToValue.size()))); // Creating one variable per exit node, set to the constant matching the // targeted external block. @@ -184,12 +187,12 @@ class SPIRVMergeRegionExitTargets : public FunctionPass { } // Creating the switch to jump to the correct exit target. - std::vector> CasesList( - TargetToValue.begin(), TargetToValue.end()); - llvm::SwitchInst *Sw = - Builder.CreateSwitch(node, CasesList[0].first, CasesList.size() - 1); - for (size_t i = 1; i < CasesList.size(); i++) - Sw->addCase(CasesList[i].second, CasesList[i].first); + llvm::SwitchInst *Sw = Builder.CreateSwitch(node, SortedExitTargets[0], + SortedExitTargets.size() - 1); + for (size_t i = 1; i < SortedExitTargets.size(); i++) { + BasicBlock *BB = SortedExitTargets[i]; + Sw->addCase(TargetToValue[BB], BB); + } // Fix exit branches to redirect to the new exit. for (auto Exit : CR->Exits) diff --git a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll index b3fcdc978625f..e7b1b441405f6 100644 --- a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll +++ b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll @@ -66,7 +66,7 @@ while.end: ; CHECK: %[[#new_end]] = OpLabel ; CHECK: %[[#route:]] = OpPhi %[[#int_ty]] %[[#int_1]] %[[#while_cond]] %[[#int_0]] %[[#while_body]] -; CHECK: OpSwitch %[[#route]] %[[#while_end_loopexit]] 0 %[[#if_then]] +; CHECK: OpSwitch %[[#route]] %[[#if_then]] 1 %[[#while_end_loopexit]] } declare token @llvm.experimental.convergence.entry() #2 diff --git a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll index a67c58fdd5749..593e3631c02b9 100644 --- a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll +++ b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll @@ -75,7 +75,7 @@ while.end: ; CHECK: %[[#new_end]] = OpLabel ; CHECK: %[[#route:]] = OpPhi %[[#int_ty]] %[[#int_0]] %[[#while_cond]] %[[#int_1]] %[[#tail]] -; CHECK: OpSwitch %[[#route]] %[[#while_end]] 0 %[[#while_end_loopexit]] +; CHECK: OpSwitch %[[#route]] %[[#while_end_loopexit]] 1 %[[#while_end]] } declare token @llvm.experimental.convergence.entry() #2 diff --git a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll index 32a97553df05e..9806dd7955468 100644 --- a/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll +++ b/llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll @@ -85,7 +85,7 @@ while.end: ; CHECK: %[[#new_end]] = OpLabel ; CHECK: %[[#route:]] = OpPhi %[[#int_ty]] %[[#int_2]] %[[#while_cond]] %[[#int_0]] %[[#while_body]] %[[#int_1]] %[[#if_end]] -; CHECK: OpSwitch %[[#route]] %[[#while_end_loopexit]] 1 %[[#if_then2]] 0 %[[#if_then]] +; CHECK: OpSwitch %[[#route]] %[[#if_then]] 1 %[[#if_then2]] 2 %[[#while_end_loopexit]] } declare token @llvm.experimental.convergence.entry() #2