diff --git a/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp b/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp index fddbaa97d0638..2089f5dda6fe5 100644 --- a/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp +++ b/llvm/lib/Target/RISCV/RISCVFoldMasks.cpp @@ -47,10 +47,13 @@ class RISCVFoldMasks : public MachineFunctionPass { StringRef getPassName() const override { return "RISC-V Fold Masks"; } private: - bool convertToUnmasked(MachineInstr &MI, MachineInstr *MaskDef) const; - bool convertVMergeToVMv(MachineInstr &MI, MachineInstr *MaskDef) const; + bool convertToUnmasked(MachineInstr &MI) const; + bool convertVMergeToVMv(MachineInstr &MI) const; - bool isAllOnesMask(MachineInstr *MaskDef) const; + bool isAllOnesMask(const MachineInstr *MaskDef) const; + + /// Maps uses of V0 to the corresponding def of V0. + DenseMap V0Defs; }; } // namespace @@ -59,10 +62,9 @@ char RISCVFoldMasks::ID = 0; INITIALIZE_PASS(RISCVFoldMasks, DEBUG_TYPE, "RISC-V Fold Masks", false, false) -bool RISCVFoldMasks::isAllOnesMask(MachineInstr *MaskDef) const { - if (!MaskDef) - return false; - assert(MaskDef->isCopy() && MaskDef->getOperand(0).getReg() == RISCV::V0); +bool RISCVFoldMasks::isAllOnesMask(const MachineInstr *MaskDef) const { + assert(MaskDef && MaskDef->isCopy() && + MaskDef->getOperand(0).getReg() == RISCV::V0); Register SrcReg = TRI->lookThruCopyLike(MaskDef->getOperand(1).getReg(), MRI); if (!SrcReg.isVirtual()) return false; @@ -89,8 +91,7 @@ bool RISCVFoldMasks::isAllOnesMask(MachineInstr *MaskDef) const { // Transform (VMERGE_VVM_ false, false, true, allones, vl, sew) to // (VMV_V_V_ false, true, vl, sew). It may decrease uses of VMSET. -bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI, - MachineInstr *V0Def) const { +bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI) const { #define CASE_VMERGE_TO_VMV(lmul) \ case RISCV::PseudoVMERGE_VVM_##lmul: \ NewOpc = RISCV::PseudoVMV_V_V_##lmul; \ @@ -116,7 +117,7 @@ bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI, return false; assert(MI.getOperand(4).isReg() && MI.getOperand(4).getReg() == RISCV::V0); - if (!isAllOnesMask(V0Def)) + if (!isAllOnesMask(V0Defs.lookup(&MI))) return false; MI.setDesc(TII->get(NewOpc)); @@ -133,14 +134,13 @@ bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI, return true; } -bool RISCVFoldMasks::convertToUnmasked(MachineInstr &MI, - MachineInstr *MaskDef) const { +bool RISCVFoldMasks::convertToUnmasked(MachineInstr &MI) const { const RISCV::RISCVMaskedPseudoInfo *I = RISCV::getMaskedPseudoInfo(MI.getOpcode()); if (!I) return false; - if (!isAllOnesMask(MaskDef)) + if (!isAllOnesMask(V0Defs.lookup(&MI))) return false; // There are two classes of pseudos in the table - compares and @@ -198,20 +198,26 @@ bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) { // $v0:vr = COPY %mask:vr // %x:vr = Pseudo_MASK %a:vr, %b:br, $v0:vr // - // Because $v0 isn't in SSA, keep track of it so we can check the mask operand - // on each pseudo. - MachineInstr *CurrentV0Def; - for (MachineBasicBlock &MBB : MF) { - CurrentV0Def = nullptr; - for (MachineInstr &MI : MBB) { - Changed |= convertToUnmasked(MI, CurrentV0Def); - Changed |= convertVMergeToVMv(MI, CurrentV0Def); + // Because $v0 isn't in SSA, keep track of its definition at each use so we + // can check mask operands. + for (const MachineBasicBlock &MBB : MF) { + const MachineInstr *CurrentV0Def = nullptr; + for (const MachineInstr &MI : MBB) { + if (MI.readsRegister(RISCV::V0, TRI)) + V0Defs[&MI] = CurrentV0Def; if (MI.definesRegister(RISCV::V0, TRI)) CurrentV0Def = &MI; } } + for (MachineBasicBlock &MBB : MF) { + for (MachineInstr &MI : MBB) { + Changed |= convertToUnmasked(MI); + Changed |= convertVMergeToVMv(MI); + } + } + return Changed; }