diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 7485fc48f4132..d805a76754c71 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -39580,7 +39580,7 @@ static bool matchBinaryPermuteShuffle( static SDValue combineX86ShuffleChainWithExtract( ArrayRef Inputs, SDValue Root, ArrayRef BaseMask, int Depth, - bool HasVariableMask, bool AllowVariableCrossLaneMask, + ArrayRef SrcNodes, bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask, SelectionDAG &DAG, const X86Subtarget &Subtarget); @@ -39595,7 +39595,7 @@ static SDValue combineX86ShuffleChainWithExtract( /// instruction but should only be used to replace chains over a certain depth. static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, ArrayRef BaseMask, int Depth, - bool HasVariableMask, + ArrayRef SrcNodes, bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask, SelectionDAG &DAG, @@ -40064,6 +40064,10 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, if (Depth < 1) return SDValue(); + bool HasVariableMask = llvm::any_of(SrcNodes, [](const SDNode *N) { + return isTargetShuffleVariableMask(N->getOpcode()); + }); + // Depth threshold above which we can efficiently use variable mask shuffles. int VariableCrossLaneShuffleDepth = Subtarget.hasFastVariableCrossLaneShuffle() ? 1 : 2; @@ -40134,9 +40138,8 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, // If that failed and either input is extracted then try to combine as a // shuffle with the larger type. if (SDValue WideShuffle = combineX86ShuffleChainWithExtract( - Inputs, Root, BaseMask, Depth, HasVariableMask, - AllowVariableCrossLaneMask, AllowVariablePerLaneMask, DAG, - Subtarget)) + Inputs, Root, BaseMask, Depth, SrcNodes, AllowVariableCrossLaneMask, + AllowVariablePerLaneMask, DAG, Subtarget)) return WideShuffle; // If we have a dual input lane-crossing shuffle then lower to VPERMV3, @@ -40307,8 +40310,8 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, // If that failed and either input is extracted then try to combine as a // shuffle with the larger type. if (SDValue WideShuffle = combineX86ShuffleChainWithExtract( - Inputs, Root, BaseMask, Depth, HasVariableMask, - AllowVariableCrossLaneMask, AllowVariablePerLaneMask, DAG, Subtarget)) + Inputs, Root, BaseMask, Depth, SrcNodes, AllowVariableCrossLaneMask, + AllowVariablePerLaneMask, DAG, Subtarget)) return WideShuffle; // If we have a dual input shuffle then lower to VPERMV3, @@ -40346,7 +40349,7 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, // extract_subvector(shuffle(x,y,m2),0) static SDValue combineX86ShuffleChainWithExtract( ArrayRef Inputs, SDValue Root, ArrayRef BaseMask, int Depth, - bool HasVariableMask, bool AllowVariableCrossLaneMask, + ArrayRef SrcNodes, bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask, SelectionDAG &DAG, const X86Subtarget &Subtarget) { unsigned NumMaskElts = BaseMask.size(); @@ -40475,7 +40478,7 @@ static SDValue combineX86ShuffleChainWithExtract( if (SDValue WideShuffle = combineX86ShuffleChain(WideInputs, WideRoot, WideMask, Depth, - HasVariableMask, AllowVariableCrossLaneMask, + SrcNodes, AllowVariableCrossLaneMask, AllowVariablePerLaneMask, DAG, Subtarget)) { WideShuffle = extractSubVector(WideShuffle, 0, DAG, SDLoc(Root), RootSizeInBits); @@ -40698,7 +40701,7 @@ static SDValue canonicalizeShuffleMaskWithHorizOp( // TODO: Extend this to merge multiple constant Ops and update the mask. static SDValue combineX86ShufflesConstants(MVT VT, ArrayRef Ops, ArrayRef Mask, - bool HasVariableMask, + ArrayRef SrcNodes, SelectionDAG &DAG, const SDLoc &DL, const X86Subtarget &Subtarget) { unsigned SizeInBits = VT.getSizeInBits(); @@ -40720,6 +40723,9 @@ static SDValue combineX86ShufflesConstants(MVT VT, ArrayRef Ops, // only used once or the combined shuffle has included a variable mask // shuffle, this is to avoid constant pool bloat. bool IsOptimizingSize = DAG.shouldOptForSize(); + bool HasVariableMask = llvm::any_of(SrcNodes, [](const SDNode *N) { + return isTargetShuffleVariableMask(N->getOpcode()); + }); if (IsOptimizingSize && !HasVariableMask && llvm::none_of(Ops, [](SDValue SrcOp) { return SrcOp->hasOneUse(); })) return SDValue(); @@ -40821,7 +40827,7 @@ namespace llvm { static SDValue combineX86ShufflesRecursively( ArrayRef SrcOps, int SrcOpIndex, SDValue Root, ArrayRef RootMask, ArrayRef SrcNodes, unsigned Depth, - unsigned MaxDepth, bool HasVariableMask, bool AllowVariableCrossLaneMask, + unsigned MaxDepth, bool AllowVariableCrossLaneMask, bool AllowVariablePerLaneMask, SelectionDAG &DAG, const X86Subtarget &Subtarget) { assert(!RootMask.empty() && @@ -40877,7 +40883,6 @@ static SDValue combineX86ShufflesRecursively( SmallVector OpMask; SmallVector OpInputs; APInt OpUndef, OpZero; - bool IsOpVariableMask = isTargetShuffleVariableMask(Op.getOpcode()); if (getTargetShuffleInputs(Op, OpDemandedElts, OpInputs, OpMask, OpUndef, OpZero, DAG, Depth, false)) { // Shuffle inputs must not be larger than the shuffle result. @@ -41092,7 +41097,6 @@ static SDValue combineX86ShufflesRecursively( return getOnesVector(RootVT, DAG, DL); assert(!Ops.empty() && "Shuffle with no inputs detected"); - HasVariableMask |= IsOpVariableMask; // Update the list of shuffle nodes that have been combined so far. SmallVector CombinedNodes(SrcNodes); @@ -41121,15 +41125,14 @@ static SDValue combineX86ShufflesRecursively( } if (SDValue Res = combineX86ShufflesRecursively( Ops, i, Root, ResolvedMask, CombinedNodes, Depth + 1, MaxDepth, - HasVariableMask, AllowCrossLaneVar, AllowPerLaneVar, DAG, - Subtarget)) + AllowCrossLaneVar, AllowPerLaneVar, DAG, Subtarget)) return Res; } } // Attempt to constant fold all of the constant source ops. if (SDValue Cst = combineX86ShufflesConstants( - RootVT, Ops, Mask, HasVariableMask, DAG, DL, Subtarget)) + RootVT, Ops, Mask, CombinedNodes, DAG, DL, Subtarget)) return Cst; // If constant fold failed and we only have constants - then we have @@ -41231,7 +41234,7 @@ static SDValue combineX86ShufflesRecursively( // Try to combine into a single shuffle instruction. if (SDValue Shuffle = combineX86ShuffleChain( - Ops, Root, Mask, Depth, HasVariableMask, AllowVariableCrossLaneMask, + Ops, Root, Mask, Depth, CombinedNodes, AllowVariableCrossLaneMask, AllowVariablePerLaneMask, DAG, Subtarget)) return Shuffle; @@ -41250,7 +41253,7 @@ static SDValue combineX86ShufflesRecursively( // If that failed and any input is extracted then try to combine as a // shuffle with the larger type. return combineX86ShuffleChainWithExtract( - Ops, Root, Mask, Depth, HasVariableMask, AllowVariableCrossLaneMask, + Ops, Root, Mask, Depth, CombinedNodes, AllowVariableCrossLaneMask, AllowVariablePerLaneMask, DAG, Subtarget); } @@ -41259,7 +41262,6 @@ static SDValue combineX86ShufflesRecursively(SDValue Op, SelectionDAG &DAG, const X86Subtarget &Subtarget) { return combineX86ShufflesRecursively( {Op}, 0, Op, {0}, {}, /*Depth*/ 0, X86::MaxShuffleCombineDepth, - /*HasVarMask*/ false, /*AllowCrossLaneVarMask*/ true, /*AllowPerLaneVarMask*/ true, DAG, Subtarget); } @@ -41897,7 +41899,7 @@ static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL, if (SDValue Res = combineX86ShufflesRecursively( {BC}, 0, BC, DemandedMask, {}, /*Depth*/ 0, X86::MaxShuffleCombineDepth, - /*HasVarMask*/ false, /*AllowCrossLaneVarMask*/ true, + /*AllowCrossLaneVarMask*/ true, /*AllowPerLaneVarMask*/ true, DAG, Subtarget)) return DAG.getNode(X86ISD::VBROADCAST, DL, VT, DAG.getBitcast(SrcVT, Res)); @@ -42236,7 +42238,7 @@ static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL, llvm::narrowShuffleMaskElts(EltBits / 8, Mask, ByteMask); if (SDValue NewMask = combineX86ShufflesConstants( ShufVT, {MaskLHS, MaskRHS}, ByteMask, - /*HasVariableMask=*/true, DAG, DL, Subtarget)) { + {LHS.getNode(), RHS.getNode()}, DAG, DL, Subtarget)) { SDValue NewLHS = DAG.getNode(X86ISD::PSHUFB, DL, ShufVT, LHS.getOperand(0), NewMask); SDValue NewRHS = DAG.getNode(X86ISD::PSHUFB, DL, ShufVT, @@ -43871,7 +43873,6 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode( SDValue NewShuffle = combineX86ShufflesRecursively( {Op}, 0, Op, DemandedMask, {}, 0, X86::MaxShuffleCombineDepth - Depth, - /*HasVarMask*/ false, /*AllowCrossLaneVarMask*/ true, /*AllowPerLaneVarMask*/ true, TLO.DAG, Subtarget); if (NewShuffle) @@ -51430,7 +51431,7 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG, if (SDValue Shuffle = combineX86ShufflesRecursively( {SrcVec}, 0, SrcVec, ShuffleMask, {}, /*Depth*/ 1, X86::MaxShuffleCombineDepth, - /*HasVarMask*/ false, /*AllowVarCrossLaneMask*/ true, + /*AllowVarCrossLaneMask*/ true, /*AllowVarPerLaneMask*/ true, DAG, Subtarget)) return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, VT, Shuffle, N0.getOperand(1));