@@ -39803,6 +39803,65 @@ static SDValue combineANDXORWithAllOnesIntoANDNP(SDNode *N, SelectionDAG &DAG) {
39803
39803
return DAG.getNode(X86ISD::ANDNP, SDLoc(N), VT, X, Y);
39804
39804
}
39805
39805
39806
+ // Try to widen AND, OR and XOR nodes to VT in order to remove casts around
39807
+ // logical operations, like in the example below.
39808
+ // or (and (truncate x, truncate y)),
39809
+ // (xor (truncate z, build_vector (constants)))
39810
+ // Given a target type \p VT, we generate
39811
+ // or (and x, y), (xor z, zext(build_vector (constants)))
39812
+ // given x, y and z are of type \p VT. We can do so, if operands are either
39813
+ // truncates from VT types, the second operand is a vector of constants or can
39814
+ // be recursively promoted.
39815
+ static SDValue PromoteMaskArithmetic(SDNode *N, EVT VT, SelectionDAG &DAG,
39816
+ unsigned Depth) {
39817
+ // Limit recursion to avoid excessive compile times.
39818
+ if (Depth >= SelectionDAG::MaxRecursionDepth)
39819
+ return SDValue();
39820
+
39821
+ if (N->getOpcode() != ISD::XOR && N->getOpcode() != ISD::AND &&
39822
+ N->getOpcode() != ISD::OR)
39823
+ return SDValue();
39824
+
39825
+ SDValue N0 = N->getOperand(0);
39826
+ SDValue N1 = N->getOperand(1);
39827
+ SDLoc DL(N);
39828
+
39829
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
39830
+ if (!TLI.isOperationLegalOrPromote(N->getOpcode(), VT))
39831
+ return SDValue();
39832
+
39833
+ if (SDValue NN0 = PromoteMaskArithmetic(N0.getNode(), VT, DAG, Depth + 1))
39834
+ N0 = NN0;
39835
+ else {
39836
+ // The Left side has to be a trunc.
39837
+ if (N0.getOpcode() != ISD::TRUNCATE)
39838
+ return SDValue();
39839
+
39840
+ // The type of the truncated inputs.
39841
+ if (N0.getOperand(0).getValueType() != VT)
39842
+ return SDValue();
39843
+
39844
+ N0 = N0.getOperand(0);
39845
+ }
39846
+
39847
+ if (SDValue NN1 = PromoteMaskArithmetic(N1.getNode(), VT, DAG, Depth + 1))
39848
+ N1 = NN1;
39849
+ else {
39850
+ // The right side has to be a 'trunc' or a constant vector.
39851
+ bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE &&
39852
+ N1.getOperand(0).getValueType() == VT;
39853
+ if (!RHSTrunc && !ISD::isBuildVectorOfConstantSDNodes(N1.getNode()))
39854
+ return SDValue();
39855
+
39856
+ if (RHSTrunc)
39857
+ N1 = N1.getOperand(0);
39858
+ else
39859
+ N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N1);
39860
+ }
39861
+
39862
+ return DAG.getNode(N->getOpcode(), DL, VT, N0, N1);
39863
+ }
39864
+
39806
39865
// On AVX/AVX2 the type v8i1 is legalized to v8i16, which is an XMM sized
39807
39866
// register. In most cases we actually compare or select YMM-sized registers
39808
39867
// and mixing the two types creates horrible code. This method optimizes
@@ -39814,53 +39873,19 @@ static SDValue PromoteMaskArithmetic(SDNode *N, SelectionDAG &DAG,
39814
39873
EVT VT = N->getValueType(0);
39815
39874
assert(VT.isVector() && "Expected vector type");
39816
39875
39876
+ SDLoc DL(N);
39817
39877
assert((N->getOpcode() == ISD::ANY_EXTEND ||
39818
39878
N->getOpcode() == ISD::ZERO_EXTEND ||
39819
39879
N->getOpcode() == ISD::SIGN_EXTEND) && "Invalid Node");
39820
39880
39821
39881
SDValue Narrow = N->getOperand(0);
39822
39882
EVT NarrowVT = Narrow.getValueType();
39823
39883
39824
- if (Narrow->getOpcode() != ISD::XOR &&
39825
- Narrow->getOpcode() != ISD::AND &&
39826
- Narrow->getOpcode() != ISD::OR)
39827
- return SDValue();
39828
-
39829
- SDValue N0 = Narrow->getOperand(0);
39830
- SDValue N1 = Narrow->getOperand(1);
39831
- SDLoc DL(Narrow);
39832
-
39833
- // The Left side has to be a trunc.
39834
- if (N0.getOpcode() != ISD::TRUNCATE)
39835
- return SDValue();
39836
-
39837
- // The type of the truncated inputs.
39838
- if (N0.getOperand(0).getValueType() != VT)
39839
- return SDValue();
39840
-
39841
- // The right side has to be a 'trunc' or a constant vector.
39842
- bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE &&
39843
- N1.getOperand(0).getValueType() == VT;
39844
- if (!RHSTrunc &&
39845
- !ISD::isBuildVectorOfConstantSDNodes(N1.getNode()))
39846
- return SDValue();
39847
-
39848
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
39849
-
39850
- if (!TLI.isOperationLegalOrPromote(Narrow->getOpcode(), VT))
39851
- return SDValue();
39852
-
39853
- // Set N0 and N1 to hold the inputs to the new wide operation.
39854
- N0 = N0.getOperand(0);
39855
- if (RHSTrunc)
39856
- N1 = N1.getOperand(0);
39857
- else
39858
- N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N1);
39859
-
39860
39884
// Generate the wide operation.
39861
- SDValue Op = DAG.getNode(Narrow->getOpcode(), DL, VT, N0, N1);
39862
- unsigned Opcode = N->getOpcode();
39863
- switch (Opcode) {
39885
+ SDValue Op = PromoteMaskArithmetic(Narrow.getNode(), VT, DAG, 0);
39886
+ if (!Op)
39887
+ return SDValue();
39888
+ switch (N->getOpcode()) {
39864
39889
default: llvm_unreachable("Unexpected opcode");
39865
39890
case ISD::ANY_EXTEND:
39866
39891
return Op;
0 commit comments