Skip to content
This repository was archived by the owner on Apr 23, 2020. It is now read-only.

Commit 7840500

Browse files
author
Elena Demikhovsky
committed
AVX512F: FMA intrinsic + FNEG - sequence optimization
The previous commit (r280368 - https://reviews.llvm.org/D23313) does not cover AVX-512F, KNL set. FNEG(x) operation is lowered to (bitcast (vpxor (bitcast x), (bitcast constfp(0x80000000)))). It happens because FP XOR is not supported for 512-bit data types on KNL and we use integer XOR instead. I added pattern match for integer XOR. Differential Revision: https://reviews.llvm.org/D24221 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@280785 91177308-0d34-0410-b5e6-96231b3b80d8
1 parent 1da68f6 commit 7840500

File tree

2 files changed

+117
-110
lines changed

2 files changed

+117
-110
lines changed

lib/Target/X86/X86ISelLowering.cpp

+102-90
Original file line numberDiff line numberDiff line change
@@ -29233,28 +29233,6 @@ static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG,
2923329233
return DAG.getNode(X86ISD::PCMPGT, SDLoc(N), VT, Shift.getOperand(0), Ones);
2923429234
}
2923529235

29236-
static SDValue combineXor(SDNode *N, SelectionDAG &DAG,
29237-
TargetLowering::DAGCombinerInfo &DCI,
29238-
const X86Subtarget &Subtarget) {
29239-
if (SDValue Cmp = foldVectorXorShiftIntoCmp(N, DAG, Subtarget))
29240-
return Cmp;
29241-
29242-
if (DCI.isBeforeLegalizeOps())
29243-
return SDValue();
29244-
29245-
if (SDValue RV = foldXorTruncShiftIntoCmp(N, DAG))
29246-
return RV;
29247-
29248-
if (Subtarget.hasCMov())
29249-
if (SDValue RV = combineIntegerAbs(N, DAG))
29250-
return RV;
29251-
29252-
if (SDValue FPLogic = convertIntLogicToFPLogic(N, DAG, Subtarget))
29253-
return FPLogic;
29254-
29255-
return SDValue();
29256-
}
29257-
2925829236
/// This function detects the AVG pattern between vectors of unsigned i8/i16,
2925929237
/// which is c = (a + b + 1) / 2, and replace this operation with the efficient
2926029238
/// X86ISD::AVG instruction.
@@ -30363,12 +30341,68 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
3036330341
return combineVectorTruncation(N, DAG, Subtarget);
3036430342
}
3036530343

30344+
/// Returns the negated value if the node \p N flips sign of FP value.
30345+
///
30346+
/// FP-negation node may have different forms: FNEG(x) or FXOR (x, 0x80000000).
30347+
/// AVX512F does not have FXOR, so FNEG is lowered as
30348+
/// (bitcast (xor (bitcast x), (bitcast ConstantFP(0x80000000)))).
30349+
/// In this case we go through all bitcasts.
30350+
static SDValue isFNEG(SDNode *N) {
30351+
if (N->getOpcode() == ISD::FNEG)
30352+
return N->getOperand(0);
30353+
30354+
SDValue Op = peekThroughBitcasts(SDValue(N, 0));
30355+
if (Op.getOpcode() != X86ISD::FXOR && Op.getOpcode() != ISD::XOR)
30356+
return SDValue();
30357+
30358+
SDValue Op1 = peekThroughBitcasts(Op.getOperand(1));
30359+
if (!Op1.getValueType().isFloatingPoint())
30360+
return SDValue();
30361+
30362+
SDValue Op0 = peekThroughBitcasts(Op.getOperand(0));
30363+
30364+
unsigned EltBits = Op1.getValueType().getScalarSizeInBits();
30365+
auto isSignBitValue = [&](const ConstantFP *C) {
30366+
return C->getValueAPF().bitcastToAPInt() == APInt::getSignBit(EltBits);
30367+
};
30368+
30369+
// There is more than one way to represent the same constant on
30370+
// the different X86 targets. The type of the node may also depend on size.
30371+
// - load scalar value and broadcast
30372+
// - BUILD_VECTOR node
30373+
// - load from a constant pool.
30374+
// We check all variants here.
30375+
if (Op1.getOpcode() == X86ISD::VBROADCAST) {
30376+
if (auto *C = getTargetConstantFromNode(Op1.getOperand(0)))
30377+
if (isSignBitValue(cast<ConstantFP>(C)))
30378+
return Op0;
30379+
30380+
} else if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Op1)) {
30381+
if (ConstantFPSDNode *CN = BV->getConstantFPSplatNode())
30382+
if (isSignBitValue(CN->getConstantFPValue()))
30383+
return Op0;
30384+
30385+
} else if (auto *C = getTargetConstantFromNode(Op1)) {
30386+
if (C->getType()->isVectorTy()) {
30387+
if (auto *SplatV = C->getSplatValue())
30388+
if (isSignBitValue(cast<ConstantFP>(SplatV)))
30389+
return Op0;
30390+
} else if (auto *FPConst = dyn_cast<ConstantFP>(C))
30391+
if (isSignBitValue(FPConst))
30392+
return Op0;
30393+
}
30394+
return SDValue();
30395+
}
30396+
3036630397
/// Do target-specific dag combines on floating point negations.
3036730398
static SDValue combineFneg(SDNode *N, SelectionDAG &DAG,
3036830399
const X86Subtarget &Subtarget) {
30369-
EVT VT = N->getValueType(0);
30400+
EVT OrigVT = N->getValueType(0);
30401+
SDValue Arg = isFNEG(N);
30402+
assert(Arg.getNode() && "N is expected to be an FNEG node");
30403+
30404+
EVT VT = Arg.getValueType();
3037030405
EVT SVT = VT.getScalarType();
30371-
SDValue Arg = N->getOperand(0);
3037230406
SDLoc DL(N);
3037330407

3037430408
// Let legalize expand this if it isn't a legal type yet.
@@ -30381,40 +30415,30 @@ static SDValue combineFneg(SDNode *N, SelectionDAG &DAG,
3038130415
if (Arg.getOpcode() == ISD::FMUL && (SVT == MVT::f32 || SVT == MVT::f64) &&
3038230416
Arg->getFlags()->hasNoSignedZeros() && Subtarget.hasAnyFMA()) {
3038330417
SDValue Zero = DAG.getConstantFP(0.0, DL, VT);
30384-
return DAG.getNode(X86ISD::FNMSUB, DL, VT, Arg.getOperand(0),
30385-
Arg.getOperand(1), Zero);
30418+
SDValue NewNode = DAG.getNode(X86ISD::FNMSUB, DL, VT, Arg.getOperand(0),
30419+
Arg.getOperand(1), Zero);
30420+
return DAG.getBitcast(OrigVT, NewNode);
3038630421
}
3038730422

3038830423
// If we're negating a FMA node, then we can adjust the
3038930424
// instruction to include the extra negation.
30425+
unsigned NewOpcode = 0;
3039030426
if (Arg.hasOneUse()) {
3039130427
switch (Arg.getOpcode()) {
30392-
case X86ISD::FMADD:
30393-
return DAG.getNode(X86ISD::FNMSUB, DL, VT, Arg.getOperand(0),
30394-
Arg.getOperand(1), Arg.getOperand(2));
30395-
case X86ISD::FMSUB:
30396-
return DAG.getNode(X86ISD::FNMADD, DL, VT, Arg.getOperand(0),
30397-
Arg.getOperand(1), Arg.getOperand(2));
30398-
case X86ISD::FNMADD:
30399-
return DAG.getNode(X86ISD::FMSUB, DL, VT, Arg.getOperand(0),
30400-
Arg.getOperand(1), Arg.getOperand(2));
30401-
case X86ISD::FNMSUB:
30402-
return DAG.getNode(X86ISD::FMADD, DL, VT, Arg.getOperand(0),
30403-
Arg.getOperand(1), Arg.getOperand(2));
30404-
case X86ISD::FMADD_RND:
30405-
return DAG.getNode(X86ISD::FNMSUB_RND, DL, VT, Arg.getOperand(0),
30406-
Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
30407-
case X86ISD::FMSUB_RND:
30408-
return DAG.getNode(X86ISD::FNMADD_RND, DL, VT, Arg.getOperand(0),
30409-
Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
30410-
case X86ISD::FNMADD_RND:
30411-
return DAG.getNode(X86ISD::FMSUB_RND, DL, VT, Arg.getOperand(0),
30412-
Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
30413-
case X86ISD::FNMSUB_RND:
30414-
return DAG.getNode(X86ISD::FMADD_RND, DL, VT, Arg.getOperand(0),
30415-
Arg.getOperand(1), Arg.getOperand(2), Arg.getOperand(3));
30428+
case X86ISD::FMADD: NewOpcode = X86ISD::FNMSUB; break;
30429+
case X86ISD::FMSUB: NewOpcode = X86ISD::FNMADD; break;
30430+
case X86ISD::FNMADD: NewOpcode = X86ISD::FMSUB; break;
30431+
case X86ISD::FNMSUB: NewOpcode = X86ISD::FMADD; break;
30432+
case X86ISD::FMADD_RND: NewOpcode = X86ISD::FNMSUB_RND; break;
30433+
case X86ISD::FMSUB_RND: NewOpcode = X86ISD::FNMADD_RND; break;
30434+
case X86ISD::FNMADD_RND: NewOpcode = X86ISD::FMSUB_RND; break;
30435+
case X86ISD::FNMSUB_RND: NewOpcode = X86ISD::FMADD_RND; break;
3041630436
}
3041730437
}
30438+
if (NewOpcode)
30439+
return DAG.getBitcast(OrigVT, DAG.getNode(NewOpcode, DL, VT,
30440+
Arg.getNode()->ops()));
30441+
3041830442
return SDValue();
3041930443
}
3042030444

@@ -30442,42 +30466,28 @@ static SDValue lowerX86FPLogicOp(SDNode *N, SelectionDAG &DAG,
3044230466
return SDValue();
3044330467
}
3044430468

30445-
/// Returns true if the node \p N is FNEG(x) or FXOR (x, 0x80000000).
30446-
bool isFNEG(const SDNode *N) {
30447-
if (N->getOpcode() == ISD::FNEG)
30448-
return true;
30469+
static SDValue combineXor(SDNode *N, SelectionDAG &DAG,
30470+
TargetLowering::DAGCombinerInfo &DCI,
30471+
const X86Subtarget &Subtarget) {
30472+
if (SDValue Cmp = foldVectorXorShiftIntoCmp(N, DAG, Subtarget))
30473+
return Cmp;
3044930474

30450-
if (N->getOpcode() == X86ISD::FXOR) {
30451-
unsigned EltBits = N->getSimpleValueType(0).getScalarSizeInBits();
30452-
SDValue Op1 = N->getOperand(1);
30475+
if (DCI.isBeforeLegalizeOps())
30476+
return SDValue();
3045330477

30454-
auto isSignBitValue = [&](const ConstantFP *C) {
30455-
return C->getValueAPF().bitcastToAPInt() == APInt::getSignBit(EltBits);
30456-
};
30478+
if (SDValue RV = foldXorTruncShiftIntoCmp(N, DAG))
30479+
return RV;
3045730480

30458-
// There is more than one way to represent the same constant on
30459-
// the different X86 targets. The type of the node may also depend on size.
30460-
// - load scalar value and broadcast
30461-
// - BUILD_VECTOR node
30462-
// - load from a constant pool.
30463-
// We check all variants here.
30464-
if (Op1.getOpcode() == X86ISD::VBROADCAST) {
30465-
if (auto *C = getTargetConstantFromNode(Op1.getOperand(0)))
30466-
return isSignBitValue(cast<ConstantFP>(C));
30481+
if (Subtarget.hasCMov())
30482+
if (SDValue RV = combineIntegerAbs(N, DAG))
30483+
return RV;
3046730484

30468-
} else if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Op1)) {
30469-
if (ConstantFPSDNode *CN = BV->getConstantFPSplatNode())
30470-
return isSignBitValue(CN->getConstantFPValue());
30485+
if (SDValue FPLogic = convertIntLogicToFPLogic(N, DAG, Subtarget))
30486+
return FPLogic;
3047130487

30472-
} else if (auto *C = getTargetConstantFromNode(Op1)) {
30473-
if (C->getType()->isVectorTy()) {
30474-
if (auto *SplatV = C->getSplatValue())
30475-
return isSignBitValue(cast<ConstantFP>(SplatV));
30476-
} else if (auto *FPConst = dyn_cast<ConstantFP>(C))
30477-
return isSignBitValue(FPConst);
30478-
}
30479-
}
30480-
return false;
30488+
if (isFNEG(N))
30489+
return combineFneg(N, DAG, Subtarget);
30490+
return SDValue();
3048130491
}
3048230492

3048330493
/// Do target-specific dag combines on X86ISD::FOR and X86ISD::FXOR nodes.
@@ -30907,18 +30917,20 @@ static SDValue combineFMA(SDNode *N, SelectionDAG &DAG,
3090730917
SDValue B = N->getOperand(1);
3090830918
SDValue C = N->getOperand(2);
3090930919

30910-
bool NegA = isFNEG(A.getNode());
30911-
bool NegB = isFNEG(B.getNode());
30912-
bool NegC = isFNEG(C.getNode());
30920+
auto invertIfNegative = [](SDValue &V) {
30921+
if (SDValue NegVal = isFNEG(V.getNode())) {
30922+
V = NegVal;
30923+
return true;
30924+
}
30925+
return false;
30926+
};
30927+
30928+
bool NegA = invertIfNegative(A);
30929+
bool NegB = invertIfNegative(B);
30930+
bool NegC = invertIfNegative(C);
3091330931

3091430932
// Negative multiplication when NegA xor NegB
3091530933
bool NegMul = (NegA != NegB);
30916-
if (NegA)
30917-
A = A.getOperand(0);
30918-
if (NegB)
30919-
B = B.getOperand(0);
30920-
if (NegC)
30921-
C = C.getOperand(0);
3092230934

3092330935
unsigned NewOpcode;
3092430936
if (!NegMul)

test/CodeGen/X86/fma-fneg-combine.ll

+15-20
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2-
; RUN: llc < %s -mtriple=x86_64-unknown-linux-gnu -mattr=+avx512bw -mattr=+avx512vl -mattr=+avx512dq | FileCheck %s
2+
; RUN: llc < %s -mtriple=x86_64-unknown-linux-gnu -mattr=+avx512bw -mattr=+avx512vl -mattr=+avx512dq | FileCheck %s --check-prefix=CHECK --check-prefix=SKX
3+
; RUN: llc < %s -mtriple=x86_64-unknown-linux-gnu -mattr=+avx512f -mattr=+fma | FileCheck %s --check-prefix=CHECK --check-prefix=KNL
34

45
; This test checks combinations of FNEG and FMA intrinsics on AVX-512 target
56
; PR28892
@@ -88,11 +89,18 @@ entry:
8889
}
8990

9091
define <8 x float> @test8(<8 x float> %a, <8 x float> %b, <8 x float> %c) {
91-
; CHECK-LABEL: test8:
92-
; CHECK: # BB#0: # %entry
93-
; CHECK-NEXT: vxorps {{.*}}(%rip){1to8}, %ymm2, %ymm2
94-
; CHECK-NEXT: vfmsub213ps %ymm2, %ymm1, %ymm0
95-
; CHECK-NEXT: retq
92+
; SKX-LABEL: test8:
93+
; SKX: # BB#0: # %entry
94+
; SKX-NEXT: vxorps {{.*}}(%rip){1to8}, %ymm2, %ymm2
95+
; SKX-NEXT: vfmsub213ps %ymm2, %ymm1, %ymm0
96+
; SKX-NEXT: retq
97+
;
98+
; KNL-LABEL: test8:
99+
; KNL: # BB#0: # %entry
100+
; KNL-NEXT: vbroadcastss {{.*}}(%rip), %ymm3
101+
; KNL-NEXT: vxorps %ymm3, %ymm2, %ymm2
102+
; KNL-NEXT: vfmsub213ps %ymm2, %ymm1, %ymm0
103+
; KNL-NEXT: retq
96104
entry:
97105
%sub.c = fsub <8 x float> <float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, %c
98106
%0 = tail call <8 x float> @llvm.x86.fma.vfmsub.ps.256(<8 x float> %a, <8 x float> %b, <8 x float> %sub.c) #2
@@ -115,22 +123,9 @@ entry:
115123

116124
declare <8 x double> @llvm.x86.avx512.mask.vfmadd.pd.512(<8 x double> %a, <8 x double> %b, <8 x double> %c, i8, i32)
117125

118-
define <4 x double> @test10(<4 x double> %a, <4 x double> %b, <4 x double> %c) {
126+
define <2 x double> @test10(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
119127
; CHECK-LABEL: test10:
120128
; CHECK: # BB#0: # %entry
121-
; CHECK-NEXT: vfnmsub213pd %ymm2, %ymm1, %ymm0
122-
; CHECK-NEXT: retq
123-
entry:
124-
%0 = tail call <4 x double> @llvm.x86.avx512.mask.vfmadd.pd.256(<4 x double> %a, <4 x double> %b, <4 x double> %c, i8 -1) #2
125-
%sub.i = fsub <4 x double> <double -0.000000e+00, double -0.000000e+00, double -0.000000e+00, double -0.000000e+00>, %0
126-
ret <4 x double> %sub.i
127-
}
128-
129-
declare <4 x double> @llvm.x86.avx512.mask.vfmadd.pd.256(<4 x double> %a, <4 x double> %b, <4 x double> %c, i8)
130-
131-
define <2 x double> @test11(<2 x double> %a, <2 x double> %b, <2 x double> %c) {
132-
; CHECK-LABEL: test11:
133-
; CHECK: # BB#0: # %entry
134129
; CHECK-NEXT: vfnmsub213sd %xmm2, %xmm0, %xmm1
135130
; CHECK-NEXT: vmovaps %xmm1, %xmm0
136131
; CHECK-NEXT: retq

0 commit comments

Comments
 (0)