diff --git a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp index 54ca8ccd8d9e9..66d26bf5b11e2 100644 --- a/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp +++ b/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp @@ -1292,53 +1292,60 @@ std::optional RISCVVLOptimizer::checkUsers(MachineInstr &MI) { return CommonVL; } -bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) { - SetVector Worklist; - Worklist.insert(&OrigMI); +bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) { + LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n"); - bool MadeChange = false; - while (!Worklist.empty()) { - MachineInstr &MI = *Worklist.pop_back_val(); - LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n"); + if (!isVectorRegClass(MI.getOperand(0).getReg(), MRI)) + return false; - if (!isVectorRegClass(MI.getOperand(0).getReg(), MRI)) - continue; + auto CommonVL = checkUsers(MI); + if (!CommonVL) + return false; - auto CommonVL = checkUsers(MI); - if (!CommonVL) - continue; + assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) && + "Expected VL to be an Imm or virtual Reg"); - assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) && - "Expected VL to be an Imm or virtual Reg"); + unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc()); + MachineOperand &VLOp = MI.getOperand(VLOpNum); - unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc()); - MachineOperand &VLOp = MI.getOperand(VLOpNum); + if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) { + LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n"); + return false; + } - if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) { - LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n"); - continue; - } + if (CommonVL->isImm()) { + LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to " + << CommonVL->getImm() << " for " << MI << "\n"); + VLOp.ChangeToImmediate(CommonVL->getImm()); + return true; + } + const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg()); + if (!MDT->dominates(VLMI, &MI)) + return false; + LLVM_DEBUG( + dbgs() << " Reduce VL from " << VLOp << " to " + << printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo()) + << " for " << MI << "\n"); - if (CommonVL->isImm()) { - LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to " - << CommonVL->getImm() << " for " << MI << "\n"); - VLOp.ChangeToImmediate(CommonVL->getImm()); - } else { - const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg()); - if (!MDT->dominates(VLMI, &MI)) - continue; - LLVM_DEBUG( - dbgs() << " Reduce VL from " << VLOp << " to " - << printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo()) - << " for " << MI << "\n"); + // All our checks passed. We can reduce VL. + VLOp.ChangeToRegister(CommonVL->getReg(), false); + return true; +} - // All our checks passed. We can reduce VL. - VLOp.ChangeToRegister(CommonVL->getReg(), false); - } +bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) { + if (skipFunction(MF.getFunction())) + return false; + + MRI = &MF.getRegInfo(); + MDT = &getAnalysis().getDomTree(); - MadeChange = true; + const RISCVSubtarget &ST = MF.getSubtarget(); + if (!ST.hasVInstructions()) + return false; - // Now add all inputs to this instruction to the worklist. + SetVector Worklist; + auto PushOperands = [this, &Worklist](MachineInstr &MI, + bool IgnoreSameBlock) { for (auto &Op : MI.operands()) { if (!Op.isReg() || !Op.isUse() || !Op.getReg().isVirtual()) continue; @@ -1351,34 +1358,40 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) { if (!isCandidate(*DefMI)) continue; + if (IgnoreSameBlock && DefMI->getParent() == MI.getParent()) + continue; + Worklist.insert(DefMI); } - } - - return MadeChange; -} - -bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) { - if (skipFunction(MF.getFunction())) - return false; - - MRI = &MF.getRegInfo(); - MDT = &getAnalysis().getDomTree(); - - const RISCVSubtarget &ST = MF.getSubtarget(); - if (!ST.hasVInstructions()) - return false; + }; + // Do a first pass eagerly rewriting in roughly reverse instruction + // order, populate the worklist with any instructions we might need to + // revisit. We avoid adding definitions to the worklist if they're + // in the same block - we're about to visit them anyways. bool MadeChange = false; for (MachineBasicBlock &MBB : MF) { - // Visit instructions in reverse order. + // Avoid unreachable blocks as they have degenerate dominance + if (!MDT->isReachableFromEntry(&MBB)) + continue; + for (auto &MI : make_range(MBB.rbegin(), MBB.rend())) { if (!isCandidate(MI)) continue; - - MadeChange |= tryReduceVL(MI); + if (!tryReduceVL(MI)) + continue; + MadeChange = true; + PushOperands(MI, /*IgnoreSameBlock*/ true); } } + while (!Worklist.empty()) { + assert(MadeChange); + MachineInstr &MI = *Worklist.pop_back_val(); + if (!tryReduceVL(MI)) + continue; + PushOperands(MI, /*IgnoreSameBlock*/ false); + } + return MadeChange; }