@@ -49785,25 +49785,47 @@ static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
49785
49785
}
49786
49786
}
49787
49787
49788
- // If we also broadcast this as a subvector to a wider type, then just extract
49789
- // the lowest subvector.
49788
+ // If we also broadcast this to a wider type, then just extract the lowest
49789
+ // subvector.
49790
49790
if (Ext == ISD::NON_EXTLOAD && Subtarget.hasAVX() && Ld->isSimple() &&
49791
49791
(RegVT.is128BitVector() || RegVT.is256BitVector())) {
49792
49792
SDValue Ptr = Ld->getBasePtr();
49793
49793
SDValue Chain = Ld->getChain();
49794
- for (SDNode *User : Ptr->uses()) {
49795
- if (User != N && User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD &&
49796
- cast<MemIntrinsicSDNode>(User)->getBasePtr() == Ptr &&
49794
+ for (SDNode *User : Chain->uses()) {
49795
+ if (User != N &&
49796
+ (User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD ||
49797
+ User->getOpcode() == X86ISD::VBROADCAST_LOAD) &&
49797
49798
cast<MemIntrinsicSDNode>(User)->getChain() == Chain &&
49798
- cast<MemIntrinsicSDNode>(User)->getMemoryVT().getSizeInBits() ==
49799
- MemVT.getSizeInBits() &&
49800
49799
!User->hasAnyUseOfValue(1) &&
49801
49800
User->getValueSizeInBits(0).getFixedValue() >
49802
49801
RegVT.getFixedSizeInBits()) {
49803
- SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, SDLoc(N),
49804
- RegVT.getSizeInBits());
49805
- Extract = DAG.getBitcast(RegVT, Extract);
49806
- return DCI.CombineTo(N, Extract, SDValue(User, 1));
49802
+ if (User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD &&
49803
+ cast<MemIntrinsicSDNode>(User)->getBasePtr() == Ptr &&
49804
+ cast<MemIntrinsicSDNode>(User)->getMemoryVT().getSizeInBits() ==
49805
+ MemVT.getSizeInBits()) {
49806
+ SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, SDLoc(N),
49807
+ RegVT.getSizeInBits());
49808
+ Extract = DAG.getBitcast(RegVT, Extract);
49809
+ return DCI.CombineTo(N, Extract, SDValue(User, 1));
49810
+ }
49811
+ if (User->getOpcode() == X86ISD::VBROADCAST_LOAD &&
49812
+ getTargetConstantFromBasePtr(Ptr)) {
49813
+ // See if we are loading a constant that has also been broadcast.
49814
+ APInt Undefs, UserUndefs;
49815
+ SmallVector<APInt> Bits, UserBits;
49816
+ if (getTargetConstantBitsFromNode(SDValue(N, 0), 8, Undefs, Bits) &&
49817
+ getTargetConstantBitsFromNode(SDValue(User, 0), 8, UserUndefs,
49818
+ UserBits)) {
49819
+ UserUndefs = UserUndefs.trunc(Undefs.getBitWidth());
49820
+ UserBits.truncate(Bits.size());
49821
+ if (Bits == UserBits && UserUndefs.isSubsetOf(Undefs)) {
49822
+ SDValue Extract = extractSubVector(
49823
+ SDValue(User, 0), 0, DAG, SDLoc(N), RegVT.getSizeInBits());
49824
+ Extract = DAG.getBitcast(RegVT, Extract);
49825
+ return DCI.CombineTo(N, Extract, SDValue(User, 1));
49826
+ }
49827
+ }
49828
+ }
49807
49829
}
49808
49830
}
49809
49831
}
0 commit comments