diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 87b6914f8a0ee..40550d96a5b3d 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1873,7 +1873,7 @@ class SelectionDAG {
   /// chain to the token factor. This ensures that the new memory node will have
   /// the same relative memory dependency position as the old load. Returns the
   /// new merged load chain.
-  SDValue makeEquivalentMemoryOrdering(LoadSDNode *OldLoad, SDValue NewMemOp);
+  SDValue makeEquivalentMemoryOrdering(MemSDNode *OldLoad, SDValue NewMemOp);

   /// Topological-sort the AllNodes list and a
   /// assign a unique node id for each node in the DAG based on their
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index e6a7d092b7b79..1c1445f9f44b7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -12236,7 +12236,7 @@ SDValue SelectionDAG::makeEquivalentMemoryOrdering(SDValue OldChain,
   return TokenFactor;
 }

-SDValue SelectionDAG::makeEquivalentMemoryOrdering(LoadSDNode *OldLoad,
+SDValue SelectionDAG::makeEquivalentMemoryOrdering(MemSDNode *OldLoad,
                                                    SDValue NewMemOp) {
   assert(isa<MemSDNode>(NewMemOp.getNode()) && "Expected a memop node");
   SDValue OldChain = SDValue(OldLoad, 1);
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 1fc50a36fda72..5ce8d83feb0dd 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -7193,15 +7193,19 @@ static SDValue LowerAsSplatVectorLoad(SDValue SrcOp, MVT VT, const SDLoc &dl,
 }

 // Recurse to find a LoadSDNode source and the accumulated ByteOffest.
-static bool findEltLoadSrc(SDValue Elt, LoadSDNode *&Ld, int64_t &ByteOffset) {
-  if (ISD::isNON_EXTLoad(Elt.getNode())) {
-    auto *BaseLd = cast<LoadSDNode>(Elt);
-    if (!BaseLd->isSimple())
-      return false;
+static bool findEltLoadSrc(SDValue Elt, MemSDNode *&Ld, int64_t &ByteOffset) {
+  if (auto *BaseLd = dyn_cast<AtomicSDNode>(Elt)) {
     Ld = BaseLd;
     ByteOffset = 0;
     return true;
-  }
+  } else if (auto *BaseLd = dyn_cast<LoadSDNode>(Elt))
+    if (ISD::isNON_EXTLoad(Elt.getNode())) {
+      if (!BaseLd->isSimple())
+        return false;
+      Ld = BaseLd;
+      ByteOffset = 0;
+      return true;
+    }

   switch (Elt.getOpcode()) {
   case ISD::BITCAST:
@@ -7254,7 +7258,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
   APInt ZeroMask = APInt::getZero(NumElems);
   APInt UndefMask = APInt::getZero(NumElems);

-  SmallVector<LoadSDNode *, 8> Loads(NumElems, nullptr);
+  SmallVector<MemSDNode *, 8> Loads(NumElems, nullptr);
   SmallVector<int64_t, 8> ByteOffsets(NumElems, 0);

   // For each element in the initializer, see if we've found a load, zero or an
@@ -7304,7 +7308,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
   EVT EltBaseVT = EltBase.getValueType();
   assert(EltBaseVT.getSizeInBits() == EltBaseVT.getStoreSizeInBits() &&
          "Register/Memory size mismatch");
-  LoadSDNode *LDBase = Loads[FirstLoadedElt];
+  MemSDNode *LDBase = Loads[FirstLoadedElt];
   assert(LDBase && "Did not find base load for merging consecutive loads");
   unsigned BaseSizeInBits = EltBaseVT.getStoreSizeInBits();
   unsigned BaseSizeInBytes = BaseSizeInBits / 8;
@@ -7318,15 +7322,18 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,

   // Check to see if the element's load is consecutive to the base load
   // or offset from a previous (already checked) load.
-  auto CheckConsecutiveLoad = [&](LoadSDNode *Base, int EltIdx) {
-    LoadSDNode *Ld = Loads[EltIdx];
+  auto CheckConsecutiveLoad = [&](MemSDNode *Base, int EltIdx) {
+    MemSDNode *Ld = Loads[EltIdx];
     int64_t ByteOffset = ByteOffsets[EltIdx];
     if (ByteOffset && (ByteOffset % BaseSizeInBytes) == 0) {
       int64_t BaseIdx = EltIdx - (ByteOffset / BaseSizeInBytes);
       return (0 <= BaseIdx && BaseIdx < (int)NumElems && LoadMask[BaseIdx] &&
               Loads[BaseIdx] == Ld && ByteOffsets[BaseIdx] == 0);
     }
-    return DAG.areNonVolatileConsecutiveLoads(Ld, Base, BaseSizeInBytes,
+    auto *L = dyn_cast<LoadSDNode>(Ld);
+    auto *B = dyn_cast<LoadSDNode>(Base);
+    return L && B &&
+           DAG.areNonVolatileConsecutiveLoads(L, B, BaseSizeInBytes,
                                               EltIdx - FirstLoadedElt);
   };

@@ -7347,7 +7354,7 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
     }
   }

-  auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, LoadSDNode *LDBase) {
+  auto CreateLoad = [&DAG, &DL, &Loads](EVT VT, MemSDNode *LDBase) {
     auto MMOFlags = LDBase->getMemOperand()->getFlags();
     assert(LDBase->isSimple() &&
            "Cannot merge volatile or atomic loads.");
@@ -60539,6 +60546,35 @@ static SDValue combineINTRINSIC_VOID(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }

+static SDValue combineVZEXT_LOAD(SDNode *N, SelectionDAG &DAG,
+                                 TargetLowering::DAGCombinerInfo &DCI) {
+  // Find the TokenFactor to locate the associated AtomicLoad.
+  SDNode *ALD = nullptr;
+  for (auto &TF : N->uses())
+    if (TF.getUser()->getOpcode() == ISD::TokenFactor) {
+      SDValue L = TF.getUser()->getOperand(0);
+      SDValue R = TF.getUser()->getOperand(1);
+      if (L.getNode() == N)
+        ALD = R.getNode();
+      else if (R.getNode() == N)
+        ALD = L.getNode();
+    }
+
+  if (!ALD)
+    return SDValue();
+  if (!isa<AtomicSDNode>(ALD))
+    return SDValue();
+
+  // Replace the VZEXT_LOAD with the AtomicLoad.
+  SDLoc dl(N);
+  SDValue SV =
+      DAG.getNode(ISD::SCALAR_TO_VECTOR, dl,
+                  N->getValueType(0).changeTypeToInteger(), SDValue(ALD, 0));
+  SDValue BC = DAG.getNode(ISD::BITCAST, dl, N->getValueType(0), SV);
+  BC = DCI.CombineTo(N, BC, SDValue(ALD, 1));
+  return BC;
+}
+
 SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
                                              DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
@@ -60735,6 +60771,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::INTRINSIC_VOID: return combineINTRINSIC_VOID(N, DAG, DCI);
   case ISD::FP_TO_SINT_SAT:
   case ISD::FP_TO_UINT_SAT: return combineFP_TO_xINT_SAT(N, DAG, Subtarget);
+  case X86ISD::VZEXT_LOAD: return combineVZEXT_LOAD(N, DAG, DCI);
     // clang-format on
   }