Skip to content

Commit 2f0d12a

Browse files
committed
[InstCombine] Transform high latency, dependent FSQRT/FDIV into FMUL
The proposed patch, in general, tries to transform the following code sequence:

    x = 1.0 / sqrt (a);
    r1 = x * x;  // same as 1.0 / a
    r2 = a * x;  // same as sqrt (a)

into (if x, r1 and r2 are all used further in the code):

    tmp1 = 1.0 / a
    tmp2 = sqrt (a)
    tmp3 = tmp1 * tmp2
    x = tmp3
    r1 = tmp1
    r2 = tmp2

The transform makes the high-latency sqrt and div operations independent of each other and also saves one multiplication. The patch was tested with the SPEC17 suite with cpu=neoverse-v2. The performance uplift achieved was: 544.nab_r ~4%. No regressions were observed. Also, no compile-time differences were observed with the patch. Closes #54652
1 parent e05c1b4 commit 2f0d12a

File tree

2 files changed

+637
-3
lines changed

2 files changed

+637
-3
lines changed

llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp

Lines changed: 174 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,127 @@ Instruction *InstCombinerImpl::foldPowiReassoc(BinaryOperator &I) {
626626
return nullptr;
627627
}
628628

629+
/// Decide whether the sequence
///   x  = 1.0 / sqrt(a)
///   r1 = x * x
///   r2 = a * x
/// may legally (and profitably) be rewritten so that r1 = 1/a, r2 = sqrt(a)
/// and x = r1 * r2.
///
/// \param X  the fdiv instruction computing x.
/// \param R1 the `x * x` fmul users of X (see getFSqrtDivOptPattern).
/// \param R2 the `a * x` fmul users of X (see getFSqrtDivOptPattern).
/// \returns true only if all constraints on X, R1 and R2 checked below hold;
///          false if either set is empty.
static bool isFSqrtDivToFMulLegal(Instruction *X,
                                  SmallSetVector<Instruction *, 2> &R1,
                                  SmallSetVector<Instruction *, 2> &R2) {
  // Element 0 of each set is inspected below, so both must be non-empty.
  if (R1.empty() || R2.empty())
    return false;

  BasicBlock *BBx = X->getParent();
  BasicBlock *BBr1 = R1[0]->getParent();
  BasicBlock *BBr2 = R2[0]->getParent();

  auto IsStrictFP = [](Instruction *I) {
    auto *II = dyn_cast<IntrinsicInst>(I);
    return II && II->isStrictFP();
  };

  // Check the constraints on instruction X.
  auto XConstraintsSatisfied = [X, &IsStrictFP]() {
    if (IsStrictFP(X))
      return false;
    // X must have at least 4 uses.
    // 3 uses are part of
    //   r1 = x * x
    //   r2 = a * x
    // Post-transform, r1/r2 no longer use 'x', so rewriting 'x' only pays off
    // if 'x' has at least one more use.
    if (!X->hasNUsesOrMore(4))
      return false;
    // The reciprocal fast-math flag must be set on X. Guard the dyn_cast so a
    // non-FP operator is rejected instead of dereferencing a null pointer.
    auto *FPOp = dyn_cast<FPMathOperator>(X);
    return FPOp && FPOp->hasAllowReciprocal();
  };
  if (!XConstraintsSatisfied())
    return false;

  // Check the constraints on instructions in R1.
  auto R1ConstraintsSatisfied = [BBr1, &IsStrictFP](Instruction *I) {
    if (IsStrictFP(I))
      return false;
    // When multiple instructions reside in R1 and R2 respectively, it is
    // difficult to generate combinations of (R1, R2) and then check whether we
    // have the required pattern. So, for now, just be conservative and require
    // all of them to live in one block.
    if (I->getParent() != BBr1)
      return false;
    if (!I->hasNUsesOrMore(1))
      return false;
    // The optimization converts
    //   R1 = div * div, where div = 1/sqrt(a)
    // to
    //   R1 = 1/a
    // This is only valid when NaNs can be ignored, because sqrt(a) = NaN when
    // a < 0 while 1/a is not.
    if (!I->hasNoNaNs())
      return false;
    // sqrt(-0.0) = -0.0, and doing this simplification would change the sign
    // of the result.
    return I->hasNoSignedZeros();
  };
  if (!all_of(R1, R1ConstraintsSatisfied))
    return false;

  // Check the constraints on instructions in R2.
  auto R2ConstraintsSatisfied = [BBr2, &IsStrictFP](Instruction *I) {
    if (IsStrictFP(I))
      return false;
    // Same conservative single-block restriction as for R1.
    if (I->getParent() != BBr2)
      return false;
    if (!I->hasNUsesOrMore(1))
      return false;
    // This simplification changes
    //   R2 = a * 1/sqrt(a)
    // to
    //   R2 = sqrt(a)
    // Now, sqrt(-0.0) = -0.0, so doing this simplification would produce -0.0
    // instead of NaN.
    return I->hasNoSignedZeros();
  };
  if (!all_of(R2, R2ConstraintsSatisfied))
    return false;

  // Check the constraints on X, R1 and R2 combined.
  // The fdiv instruction and one of the multiplications must reside in the
  // same block. If not, the optimized code may execute more ops than before
  // and this may hamper the performance.
  return BBx == BBr1 || BBx == BBr2;
}
713+
714+
/// Collect the fmul users of \p Div that take part in the 1/sqrt(a) pattern.
///
/// If \p Div is `1.0 / sqrt(a)` or `-1.0 / sqrt(a)`, every fmul user of the
/// form `Div * Div` is inserted into \p R1 and every fmul user of the form
/// `Div * a` or `a * Div` is inserted into \p R2. If \p Div does not have that
/// shape, both sets are left untouched.
static void getFSqrtDivOptPattern(Value *Div,
                                  SmallSetVector<Instruction *, 2> &R1,
                                  SmallSetVector<Instruction *, 2> &R2) {
  Value *A;
  if (!match(Div, m_FDiv(m_FPOne(), m_Sqrt(m_Value(A)))) &&
      !match(Div, m_FDiv(m_SpecificFP(-1.0), m_Sqrt(m_Value(A)))))
    return;

  for (User *U : Div->users()) {
    auto *I = dyn_cast<Instruction>(U);
    if (!I)
      continue;

    // r1 = x * x. This must be tested first: it is the only fmul shape in
    // which both operands are Div.
    if (match(I, m_FMul(m_Specific(Div), m_Specific(Div)))) {
      R1.insert(I);
      continue;
    }

    // r2 = x * a or a * x (commutative match against the sqrt operand).
    if (match(I, m_c_FMul(m_Specific(Div), m_Specific(A))))
      R2.insert(I);
  }
}
742+
743+
/// Return true if the fmul folds involving \p Div should be skipped because
/// \p Div heads a `1/sqrt(a)`, `x * x`, `a * x` pattern that
/// isFSqrtDivToFMulLegal accepts.
static bool delayFMulSqrtTransform(Value *Div) {
  SmallSetVector<Instruction *, 2> R1, R2;
  getFSqrtDivOptPattern(Div, R1, R2);
  // Both sets are only populated when Div matched an fdiv, so the checked
  // cast below is evaluated only when Div is an instruction.
  if (R1.empty() || R2.empty())
    return false;
  return isFSqrtDivToFMulLegal(cast<Instruction>(Div), R1, R2);
}
749+
629750
Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
630751
Value *Op0 = I.getOperand(0);
631752
Value *Op1 = I.getOperand(1);
@@ -705,19 +826,20 @@ Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
705826
// has the necessary (reassoc) fast-math-flags.
706827
if (I.hasNoSignedZeros() &&
707828
match(Op0, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
708-
match(Y, m_Sqrt(m_Value(X))) && Op1 == X)
829+
match(Y, m_Sqrt(m_Value(X))) && Op1 == X && !delayFMulSqrtTransform(Op0))
709830
return BinaryOperator::CreateFDivFMF(X, Y, &I);
710831
if (I.hasNoSignedZeros() &&
711832
match(Op1, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
712-
match(Y, m_Sqrt(m_Value(X))) && Op0 == X)
833+
match(Y, m_Sqrt(m_Value(X))) && Op0 == X && !delayFMulSqrtTransform(Op1))
713834
return BinaryOperator::CreateFDivFMF(X, Y, &I);
714835

715836
// Like the similar transform in instsimplify, this requires 'nsz' because
716837
// sqrt(-0.0) = -0.0, and -0.0 * -0.0 does not simplify to -0.0.
717838
if (I.hasNoNaNs() && I.hasNoSignedZeros() && Op0 == Op1 && Op0->hasNUses(2)) {
718839
// Peek through fdiv to find squaring of square root:
719840
// (X / sqrt(Y)) * (X / sqrt(Y)) --> (X * X) / Y
720-
if (match(Op0, m_FDiv(m_Value(X), m_Sqrt(m_Value(Y))))) {
841+
if (match(Op0, m_FDiv(m_Value(X), m_Sqrt(m_Value(Y)))) &&
842+
!delayFMulSqrtTransform(Op0)) {
721843
Value *XX = Builder.CreateFMulFMF(X, X, &I);
722844
return BinaryOperator::CreateFDivFMF(XX, Y, &I);
723845
}
@@ -1796,6 +1918,35 @@ static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
17961918
return BinaryOperator::CreateFMulFMF(Op0, NewSqrt, &I);
17971919
}
17981920

1921+
/// Rewrite
///   x  = 1.0 / sqrt(a)   (or -1.0 / sqrt(a))
///   r1 = x * x
///   r2 = a * x
/// into
///   tmp1 = 1.0 / a
///   tmp2 = sqrt(a)       (the existing call \p CI)
///   tmp3 = tmp1 * tmp2   (negated when x was -1.0 / sqrt(a))
///
/// All uses of the instructions in \p R1 are replaced with tmp1 and all uses
/// of the instructions in \p R2 with \p CI. The uses of \p X itself are NOT
/// replaced here; that is left to the caller.
///
/// \param CI     the sqrt call that is the divisor of \p X.
/// \param X      the fdiv instruction being rewritten; new instructions are
///               inserted before it.
/// \param R1     the `x * x` users of \p X; must be non-empty.
/// \param R2     the `a * x` users of \p X.
/// \param SqrtOp the operand `a` of the sqrt call.
/// \param B      the builder used to create the new instructions.
/// \returns tmp3, the replacement value for \p X.
static Value *convertFSqrtDivIntoFMul(CallInst *CI, Instruction *X,
                                      SmallSetVector<Instruction *, 2> &R1,
                                      SmallSetVector<Instruction *, 2> &R2,
                                      Value *SqrtOp,
                                      InstCombiner::BuilderTy &B) {
  assert(!R1.empty() && "expected at least one 'x * x' user");

  // 1. Synthesize tmp1 = 1/a and replace uses of r1.
  B.SetInsertPoint(X);

  // tmp1 stands in for *every* member of R1, so it may only carry the
  // fast-math flags that all of them share. Taking the flags of R1[0] alone
  // could attach a flag that some other member does not have.
  FastMathFlags CommonFMF = R1[0]->getFastMathFlags();
  for (Instruction *I : R1)
    CommonFMF &= I->getFastMathFlags();

  Value *Tmp1 =
      B.CreateFDivFMF(ConstantFP::get(R1[0]->getType(), 1.0), SqrtOp, R1[0]);
  // The builder may have folded the division to a constant, in which case
  // there are no flags to adjust.
  if (auto *Tmp1Inst = dyn_cast<Instruction>(Tmp1))
    Tmp1Inst->copyFastMathFlags(CommonFMF);
  for (Instruction *I : R1)
    I->replaceAllUsesWith(Tmp1);

  // 2. No need to synthesize tmp2 again. In this scenario, tmp2 = CI. Replace
  // uses of r2 with tmp2.
  for (Instruction *I : R2)
    I->replaceAllUsesWith(CI);

  // 3. Synthesize tmp3 = tmp1 * tmp2. The caller replaces uses of 'x' with it.
  // If x = -1/sqrt(a) initially, then tmp3 = -(tmp1 * tmp2).
  if (match(X, m_FDiv(m_SpecificFP(-1.0), m_Specific(CI)))) {
    // NOTE(review): this inner fmul carries no fast-math flags, unlike the
    // fmul in the non-negated path below; confirm whether that is intended.
    Value *Mul = B.CreateFMul(Tmp1, CI);
    return B.CreateFNegFMF(Mul, X);
  }
  return B.CreateFMulFMF(Tmp1, CI, X);
}
1949+
17991950
Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
18001951
Module *M = I.getModule();
18011952

@@ -1820,6 +1971,26 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
18201971
return R;
18211972

18221973
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
1974+
1975+
// Convert
1976+
// x = 1.0/sqrt(a)
1977+
// r1 = x * x;
1978+
// r2 = a * x;
1979+
//
1980+
// TO
1981+
//
1982+
// r1 = 1/a
1983+
// r2 = sqrt(a)
1984+
// x = r1 * r2
1985+
SmallSetVector<Instruction *, 2> R1, R2;
1986+
getFSqrtDivOptPattern(&I, R1, R2);
1987+
if (!(R1.empty() || R2.empty()) && isFSqrtDivToFMulLegal(&I, R1, R2)) {
1988+
CallInst *CI = (CallInst *)((&I)->getOperand(1));
1989+
Value *SqrtOp = CI->getArgOperand(0);
1990+
if (Value *D = convertFSqrtDivIntoFMul(CI, &I, R1, R2, SqrtOp, Builder))
1991+
return replaceInstUsesWith(I, D);
1992+
}
1993+
18231994
if (isa<Constant>(Op0))
18241995
if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
18251996
if (Instruction *R = FoldOpIntoSelect(I, SI))

0 commit comments

Comments
 (0)