diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index c683857363720..bf91912245262 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -18237,41 +18237,21 @@ bool RISCVTargetLowering::isDesirableToCommuteWithShift( // LD/ST will optimize constant Offset extraction, so when AddNode is used by // LD/ST, it can still complete the folding optimization operation performed // above. - auto isUsedByLdSt = [&]() { - bool CanOptAlways = false; - if (N0->getOpcode() == ISD::ADD && !N0->hasOneUse()) { - for (SDNode *Use : N0->uses()) { - // This use is the one we're on right now. Skip it - if (Use == N || Use->getOpcode() == ISD::SELECT) - continue; - if (!isa(Use) && !isa(Use)) { - CanOptAlways = false; - break; - } - CanOptAlways = true; - } - } - - if (N0->getOpcode() == ISD::SIGN_EXTEND && - !N0->getOperand(0)->hasOneUse()) { - for (SDNode *Use : N0->getOperand(0)->uses()) { - // This use is the one we're on right now. Skip it - if (Use == N0.getNode() || Use->getOpcode() == ISD::SELECT) - continue; - if (!isa(Use) && !isa(Use)) { - CanOptAlways = false; - break; - } - CanOptAlways = true; - } + auto isUsedByLdSt = [](const SDNode *X, const SDNode *User) { + for (SDNode *Use : X->uses()) { + // This use is the one we're on right now. Skip it + if (Use == User || Use->getOpcode() == ISD::SELECT) + continue; + if (!isa(Use) && !isa(Use)) + return false; } - return CanOptAlways; + return true; }; if (Ty.isScalarInteger() && (N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR)) { if (N0.getOpcode() == ISD::ADD && !N0->hasOneUse()) - return isUsedByLdSt(); + return isUsedByLdSt(N0.getNode(), N); auto *C1 = dyn_cast(N0->getOperand(1)); auto *C2 = dyn_cast(N->getOperand(1)); @@ -18314,7 +18294,7 @@ bool RISCVTargetLowering::isDesirableToCommuteWithShift( if (N0->getOpcode() == ISD::SIGN_EXTEND && N0->getOperand(0)->getOpcode() == ISD::ADD && !N0->getOperand(0)->hasOneUse()) - return isUsedByLdSt(); + return isUsedByLdSt(N0->getOperand(0).getNode(), N0.getNode()); return true; }