From 2a78d8bb9f9c5ce23088223bbf9330c061c824a1 Mon Sep 17 00:00:00 2001 From: Han-Kuan Chen Date: Wed, 8 Jan 2025 05:36:42 -0800 Subject: [PATCH 1/6] rename IBase to MainOp --- llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index c4582df89213d..1a062a61761e0 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -916,6 +916,7 @@ static InstructionsState getSameOpcode(ArrayRef VL, return InstructionsState::invalid(); Value *V = *It; + Instruction *MainOp = cast(V); unsigned InstCnt = std::count_if(It, VL.end(), IsaPred); if ((VL.size() > 2 && !isa(V) && InstCnt < VL.size() / 2) || (VL.size() == 2 && InstCnt < 2)) @@ -955,10 +956,9 @@ static InstructionsState getSameOpcode(ArrayRef VL, }(); // Check for one alternate opcode from another BinaryOperator. // TODO - generalize to support all operators (types, calls etc.). - auto *IBase = cast(V); Intrinsic::ID BaseID = 0; SmallVector BaseMappings; - if (auto *CallBase = dyn_cast(IBase)) { + if (auto *CallBase = dyn_cast(MainOp)) { BaseID = getVectorIntrinsicIDForCall(CallBase, &TLI); BaseMappings = VFDatabase(*CallBase).getMappings(*CallBase); if (!isTriviallyVectorizable(BaseID) && BaseMappings.empty()) @@ -986,7 +986,7 @@ static InstructionsState getSameOpcode(ArrayRef VL, continue; } } else if (IsCastOp && isa(I)) { - Value *Op0 = IBase->getOperand(0); + Value *Op0 = MainOp->getOperand(0); Type *Ty0 = Op0->getType(); Value *Op1 = I->getOperand(0); Type *Ty1 = Op1->getType(); @@ -1045,17 +1045,17 @@ static InstructionsState getSameOpcode(ArrayRef VL, "CastInst."); if (auto *Gep = dyn_cast(I)) { if (Gep->getNumOperands() != 2 || - Gep->getOperand(0)->getType() != IBase->getOperand(0)->getType()) + Gep->getOperand(0)->getType() != MainOp->getOperand(0)->getType()) return InstructionsState::invalid(); } else if (auto *EI = dyn_cast(I)) { if (!isVectorLikeInstWithConstOps(EI)) return InstructionsState::invalid(); } else if (auto *LI = dyn_cast(I)) { - auto *BaseLI = cast(IBase); + auto *BaseLI = cast(MainOp); if (!LI->isSimple() || !BaseLI->isSimple()) return InstructionsState::invalid(); } else if (auto *Call = dyn_cast(I)) { - auto *CallBase = cast(IBase); + auto *CallBase = cast(MainOp); if (Call->getCalledFunction() != CallBase->getCalledFunction()) return InstructionsState::invalid(); if (Call->hasOperandBundles() && From 13628159e420214125edd0c51790a47c6b6d19e3 Mon Sep 17 00:00:00 2001 From: Han-Kuan Chen Date: Wed, 8 Jan 2025 05:37:56 -0800 Subject: [PATCH 2/6] rename V to MainOp --- .../Transforms/Vectorize/SLPVectorizer.cpp | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 1a062a61761e0..ddd155af16057 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -915,19 +915,18 @@ static InstructionsState getSameOpcode(ArrayRef VL, if (It == VL.end()) return InstructionsState::invalid(); - Value *V = *It; - Instruction *MainOp = cast(V); + Instruction *MainOp = cast(*It); unsigned InstCnt = std::count_if(It, VL.end(), IsaPred); - if ((VL.size() > 2 && !isa(V) && InstCnt < VL.size() / 2) || + if ((VL.size() > 2 && !isa(MainOp) && InstCnt < VL.size() / 2) || (VL.size() == 2 && InstCnt < 2)) return InstructionsState::invalid(); - bool IsCastOp = isa(V); - bool IsBinOp = isa(V); - bool IsCmpOp = isa(V); - CmpInst::Predicate BasePred = - IsCmpOp ? cast(V)->getPredicate() : CmpInst::BAD_ICMP_PREDICATE; - unsigned Opcode = cast(V)->getOpcode(); + bool IsCastOp = isa(MainOp); + bool IsBinOp = isa(MainOp); + bool IsCmpOp = isa(MainOp); + CmpInst::Predicate BasePred = IsCmpOp ? cast(MainOp)->getPredicate() + : CmpInst::BAD_ICMP_PREDICATE; + unsigned Opcode = MainOp->getOpcode(); unsigned AltOpcode = Opcode; unsigned AltIndex = std::distance(VL.begin(), It); @@ -1003,7 +1002,7 @@ static InstructionsState getSameOpcode(ArrayRef VL, } } } else if (auto *Inst = dyn_cast(VL[Cnt]); Inst && IsCmpOp) { - auto *BaseInst = cast(V); + auto *BaseInst = cast(MainOp); Type *Ty0 = BaseInst->getOperand(0)->getType(); Type *Ty1 = Inst->getOperand(0)->getType(); if (Ty0 == Ty1) { @@ -1085,8 +1084,7 @@ static InstructionsState getSameOpcode(ArrayRef VL, return InstructionsState::invalid(); } - return InstructionsState(cast(V), - cast(VL[AltIndex])); + return InstructionsState(MainOp, cast(VL[AltIndex])); } /// \returns true if all of the values in \p VL have the same type or false From de4aee23bbb97f53ea1b8fcc8846f7beaf78f645 Mon Sep 17 00:00:00 2001 From: Han-Kuan Chen Date: Wed, 8 Jan 2025 05:50:53 -0800 Subject: [PATCH 3/6] move IsCmpOp out of lambda --- llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index ddd155af16057..041b11868e987 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -930,9 +930,7 @@ static InstructionsState getSameOpcode(ArrayRef VL, unsigned AltOpcode = Opcode; unsigned AltIndex = std::distance(VL.begin(), It); - bool SwappedPredsCompatible = [&]() { - if (!IsCmpOp) - return false; + bool SwappedPredsCompatible = IsCmpOp && [&]() { SetVector UniquePreds, UniqueNonSwappedPreds; UniquePreds.insert(BasePred); UniqueNonSwappedPreds.insert(BasePred); From caacc2e01b00e7d7714e960892696225fb9e7bbb Mon Sep 17 00:00:00 2001 From: Han-Kuan Chen Date: Thu, 9 Jan 2025 01:33:00 -0800 Subject: [PATCH 4/6] replace AltIndex with AltOp --- llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 041b11868e987..b9999055bf011 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -926,9 +926,9 @@ static InstructionsState getSameOpcode(ArrayRef VL, bool IsCmpOp = isa(MainOp); CmpInst::Predicate BasePred = IsCmpOp ? cast(MainOp)->getPredicate() : CmpInst::BAD_ICMP_PREDICATE; + Instruction *AltOp = MainOp; unsigned Opcode = MainOp->getOpcode(); unsigned AltOpcode = Opcode; - unsigned AltIndex = std::distance(VL.begin(), It); bool SwappedPredsCompatible = IsCmpOp && [&]() { SetVector UniquePreds, UniqueNonSwappedPreds; @@ -979,7 +979,7 @@ static InstructionsState getSameOpcode(ArrayRef VL, if (Opcode == AltOpcode && isValidForAlternation(InstOpcode) && isValidForAlternation(Opcode)) { AltOpcode = InstOpcode; - AltIndex = Cnt; + AltOp = I; continue; } } else if (IsCastOp && isa(I)) { @@ -995,7 +995,7 @@ static InstructionsState getSameOpcode(ArrayRef VL, isValidForAlternation(InstOpcode) && "Cast isn't safe for alternation, logic needs to be updated!"); AltOpcode = InstOpcode; - AltIndex = Cnt; + AltOp = I; continue; } } @@ -1020,15 +1020,15 @@ static InstructionsState getSameOpcode(ArrayRef VL, if (isCmpSameOrSwapped(BaseInst, Inst, TLI)) continue; - auto *AltInst = cast(VL[AltIndex]); - if (AltIndex) { + auto *AltInst = cast(AltOp); + if (MainOp != AltOp) { if (isCmpSameOrSwapped(AltInst, Inst, TLI)) continue; } else if (BasePred != CurrentPred) { assert( isValidForAlternation(InstOpcode) && "CmpInst isn't safe for alternation, logic needs to be updated!"); - AltIndex = Cnt; + AltOp = I; continue; } CmpInst::Predicate AltPred = AltInst->getPredicate(); @@ -1082,7 +1082,7 @@ static InstructionsState getSameOpcode(ArrayRef VL, return InstructionsState::invalid(); } - return InstructionsState(MainOp, cast(VL[AltIndex])); + return InstructionsState(MainOp, AltOp); } /// \returns true if all of the values in \p VL have the same type or false From 37511d3bdae86dfa310d814dd014a02639506091 Mon Sep 17 00:00:00 2001 From: Han-Kuan Chen Date: Wed, 8 Jan 2025 06:10:38 -0800 Subject: [PATCH 5/6] skip the beginning PoisonValue and MainOp --- llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index b9999055bf011..97780cf8dbbbf 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -962,8 +962,9 @@ static InstructionsState getSameOpcode(ArrayRef VL, return InstructionsState::invalid(); } bool AnyPoison = InstCnt != VL.size(); - for (int Cnt = 0, E = VL.size(); Cnt < E; Cnt++) { - auto *I = dyn_cast(VL[Cnt]); + // Skip MainOp. + while (++It != VL.end()) { + auto *I = dyn_cast(*It); if (!I) continue; @@ -999,7 +1000,7 @@ static InstructionsState getSameOpcode(ArrayRef VL, continue; } } - } else if (auto *Inst = dyn_cast(VL[Cnt]); Inst && IsCmpOp) { + } else if (auto *Inst = dyn_cast(I); Inst && IsCmpOp) { auto *BaseInst = cast(MainOp); Type *Ty0 = BaseInst->getOperand(0)->getType(); Type *Ty1 = Inst->getOperand(0)->getType(); @@ -1014,7 +1015,7 @@ static InstructionsState getSameOpcode(ArrayRef VL, CmpInst::Predicate SwappedCurrentPred = CmpInst::getSwappedPredicate(CurrentPred); - if ((E == 2 || SwappedPredsCompatible) && + if ((VL.size() == 2 || SwappedPredsCompatible) && (BasePred == CurrentPred || BasePred == SwappedCurrentPred)) continue; From a076a51658978fd251fbbdf730c966732dd1a92f Mon Sep 17 00:00:00 2001 From: Han-Kuan Chen Date: Thu, 9 Jan 2025 08:02:19 -0800 Subject: [PATCH 6/6] apply comment --- llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 97780cf8dbbbf..5241a599e0b25 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -963,8 +963,8 @@ static InstructionsState getSameOpcode(ArrayRef VL, } bool AnyPoison = InstCnt != VL.size(); // Skip MainOp. - while (++It != VL.end()) { - auto *I = dyn_cast(*It); + for (Value *V : iterator_range(It + 1, VL.end())) { + auto *I = dyn_cast(V); if (!I) continue;