Skip to content

Commit 18dd8fb

Browse files
committed
[InstCombine] Transform high latency, dependent FSQRT/FDIV into FMUL
The proposed patch, in general, tries to transform the below code sequence: x = 1.0 / sqrt (a); r1 = x * x; // same as 1.0 / a r2 = a * x; // same as sqrt (a) TO (If x, r1 and r2 are all used further in the code) tmp1 = 1.0 / a tmp2 = sqrt (a) tmp3 = tmp1 * tmp2 x = tmp3 r1 = tmp1 r2 = tmp2 The transform tries to make high latency sqrt and div operations independent and also saves on one multiplication. The patch was tested with SPEC17 suite with cpu=neoverse-v2. The performance uplift achieved was: 544.nab_r ~4% No other regressions were observed. Also, no compile time differences were observed with the patch. Closes #54652
1 parent e05c1b4 commit 18dd8fb

File tree

2 files changed

+602
-3
lines changed

2 files changed

+602
-3
lines changed

llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp

Lines changed: 196 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,150 @@ Instruction *InstCombinerImpl::foldPowiReassoc(BinaryOperator &I) {
626626
return nullptr;
627627
}
628628

629+
bool isFSqrtDivToFMulLegal(Instruction *X, SmallSetVector<Instruction *, 2> &R1,
630+
SmallSetVector<Instruction *, 2> &R2) {
631+
632+
BasicBlock *BBx = X->getParent();
633+
BasicBlock *BBr1 = R1[0]->getParent();
634+
BasicBlock *BBr2 = R2[0]->getParent();
635+
636+
auto IsStrictFP = [](Instruction *I) {
637+
IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
638+
if (II && II->isStrictFP())
639+
return true;
640+
return false;
641+
};
642+
643+
// check if X and instructions in R1/R2 satisfy basic block constraints
644+
auto BBConstraintsSatisfied = [BBx, BBr1, BBr2]() {
645+
// div instruction and one of the multiplications must reside in the same
646+
// block. If not, the optimized code may execute more ops than before and
647+
// this may hamper the performance
648+
if (!(BBx == BBr1 || BBx == BBr2))
649+
return false;
650+
return true;
651+
};
652+
653+
// Check the constaints on instruction X
654+
auto XConstraintsSatisfied = [X, &IsStrictFP]() {
655+
// X must have 3 uses in R1/R2 inclusive and 1 more use if the replacement
656+
// for X should not get dead code eliminated. If X has less than 4 uses, the
657+
// changes to R1 and R2 are anyway done as part of other transforms.R1 and
658+
// R2 must either be global or must have single local use.
659+
if (!X->hasNUsesOrMore(4))
660+
return false;
661+
// check if reciprocalFP is enabled
662+
bool RecipFPMath = dyn_cast<FPMathOperator>(X)->hasAllowReciprocal();
663+
if (!RecipFPMath)
664+
return false;
665+
if (IsStrictFP(X))
666+
return false;
667+
return true;
668+
};
669+
if (!XConstraintsSatisfied())
670+
return false;
671+
672+
// check the constraints on instructions in R1
673+
auto R1ConstraintsSatisfied = [BBr1, &IsStrictFP](Instruction *I) {
674+
// when you have multiple instructions residing in R1 and R2 respectively,
675+
// its difficult to generate combination of (R1,R2) and then check if we
676+
// have the required pattern. So, for now, just be conservative.
677+
if (I->getParent() != BBr1)
678+
return false;
679+
if (!I->hasNUsesOrMore(1))
680+
return false;
681+
if (IsStrictFP(I))
682+
return false;
683+
// The optimization tries to convert
684+
// R1 = div * div where, div = 1/sqrt(a)
685+
// to
686+
// R1 = 1/a
687+
// Now, this simplication does not work because sqrt(a)=NaN when a<0
688+
if (!I->hasNoNaNs())
689+
return false;
690+
// sqrt(-0.0) = -0.0, and doing this simplication would change the sign of
691+
// the result.
692+
if (!I->hasNoSignedZeros())
693+
return false;
694+
return true;
695+
};
696+
if (!std::all_of(R1.begin(), R1.end(), R1ConstraintsSatisfied))
697+
return false;
698+
699+
// check the constraints on instructions in R2
700+
auto R2ConstraintsSatisfied = [BBr2, &IsStrictFP](Instruction *I) {
701+
// when you have multiple instructions residing in R1 and R2 respectively,
702+
// its
703+
// difficult to generate combination of (R1,R2) and then check if we have
704+
// the required pattern. So, for now, just be conservative.
705+
if (I->getParent() != BBr2)
706+
return false;
707+
if (!I->hasNUsesOrMore(1))
708+
return false;
709+
if (IsStrictFP(I))
710+
return false;
711+
// This simplication changes
712+
// R2 = a * 1/sqrt(a)
713+
// to
714+
// R2 = sqrt(a)
715+
// Now, sqrt(-0.0) = -0.0 and doing this simplication would produce -0.0
716+
// instead of NaN.
717+
if (!I->hasNoSignedZeros())
718+
return false;
719+
return true;
720+
};
721+
if (!std::all_of(R2.begin(), R2.end(), R2ConstraintsSatisfied))
722+
return false;
723+
724+
auto XR1R2ConstraintsSatisfied = [=]() {
725+
const Function *F = X->getFunction();
726+
bool UnsafeFPMath = F->getFnAttribute("unsafe-fp-math").getValueAsBool();
727+
if (!UnsafeFPMath)
728+
return false;
729+
if (!BBConstraintsSatisfied())
730+
return false;
731+
return true;
732+
};
733+
if (!XR1R2ConstraintsSatisfied())
734+
return false;
735+
736+
return true;
737+
}
738+
739+
void getFSqrtDivOptPattern(Value *Div, SmallSetVector<Instruction *, 2> &R1,
740+
SmallSetVector<Instruction *, 2> &R2) {
741+
Value *A;
742+
if (match(Div, m_FDiv(m_FPOne(), m_Sqrt(m_Value(A))))) {
743+
for (auto U : Div->users()) {
744+
745+
if (match(U, m_FMul(m_Specific(Div), m_Specific(Div)))) {
746+
R1.insert(static_cast<Instruction *>(U));
747+
continue;
748+
}
749+
750+
Value *X;
751+
if (match(U, m_FMul(m_Specific(Div), m_Value(X))) && X == A) {
752+
R2.insert(static_cast<Instruction *>(U));
753+
continue;
754+
}
755+
756+
if (match(U, m_FMul(m_Value(X), m_Specific(Div))) && X == A) {
757+
R2.insert(static_cast<Instruction *>(U));
758+
continue;
759+
}
760+
}
761+
}
762+
}
763+
764+
bool delayFMulSqrtTransform(Value *Div) {
765+
SmallSetVector<Instruction *, 2> R1, R2;
766+
getFSqrtDivOptPattern(Div, R1, R2);
767+
if (R1.size() && R2.size() &&
768+
isFSqrtDivToFMulLegal(static_cast<Instruction *>(Div), R1, R2))
769+
return true;
770+
return false;
771+
}
772+
629773
Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
630774
Value *Op0 = I.getOperand(0);
631775
Value *Op1 = I.getOperand(1);
@@ -705,19 +849,20 @@ Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
705849
// has the necessary (reassoc) fast-math-flags.
706850
if (I.hasNoSignedZeros() &&
707851
match(Op0, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
708-
match(Y, m_Sqrt(m_Value(X))) && Op1 == X)
852+
match(Y, m_Sqrt(m_Value(X))) && Op1 == X && !delayFMulSqrtTransform(Op0))
709853
return BinaryOperator::CreateFDivFMF(X, Y, &I);
710854
if (I.hasNoSignedZeros() &&
711855
match(Op1, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) &&
712-
match(Y, m_Sqrt(m_Value(X))) && Op0 == X)
856+
match(Y, m_Sqrt(m_Value(X))) && Op0 == X && !delayFMulSqrtTransform(Op1))
713857
return BinaryOperator::CreateFDivFMF(X, Y, &I);
714858

715859
// Like the similar transform in instsimplify, this requires 'nsz' because
716860
// sqrt(-0.0) = -0.0, and -0.0 * -0.0 does not simplify to -0.0.
717861
if (I.hasNoNaNs() && I.hasNoSignedZeros() && Op0 == Op1 && Op0->hasNUses(2)) {
718862
// Peek through fdiv to find squaring of square root:
719863
// (X / sqrt(Y)) * (X / sqrt(Y)) --> (X * X) / Y
720-
if (match(Op0, m_FDiv(m_Value(X), m_Sqrt(m_Value(Y))))) {
864+
if (match(Op0, m_FDiv(m_Value(X), m_Sqrt(m_Value(Y)))) &&
865+
!delayFMulSqrtTransform(Op0)) {
721866
Value *XX = Builder.CreateFMulFMF(X, X, &I);
722867
return BinaryOperator::CreateFDivFMF(XX, Y, &I);
723868
}
@@ -1796,6 +1941,30 @@ static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
17961941
return BinaryOperator::CreateFMulFMF(Op0, NewSqrt, &I);
17971942
}
17981943

1944+
Value *convertFSqrtDivIntoFMul(CallInst *CI, Instruction *X,
1945+
SmallSetVector<Instruction *, 2> &R1,
1946+
SmallSetVector<Instruction *, 2> &R2,
1947+
Value *SqrtOp, InstCombiner::BuilderTy &B) {
1948+
1949+
// 1. synthesize tmp1 = 1/a and replace uses of r1
1950+
B.SetInsertPoint(X);
1951+
Value *Tmp1 =
1952+
B.CreateFDivFMF(ConstantFP::get(R1[0]->getType(), 1.0), SqrtOp, R1[0]);
1953+
for (auto *I : R1)
1954+
I->replaceAllUsesWith(Tmp1);
1955+
1956+
// 2. No need of synthesizing Tmp2 again. In this scenario, tmp2 = CI. Replace
1957+
// uses of r2 with tmp2
1958+
for (auto *I : R2)
1959+
I->replaceAllUsesWith(CI);
1960+
1961+
// 3. synthesize tmp3 = tmp1 * tmp2 . Replace uses of 'x' with tmp3
1962+
B.SetInsertPoint(X->getNextNode());
1963+
Value *Tmp3 = B.CreateFMulFMF(Tmp1, CI, X);
1964+
1965+
return Tmp3;
1966+
}
1967+
17991968
Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
18001969
Module *M = I.getModule();
18011970

@@ -1820,6 +1989,30 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
18201989
return R;
18211990

18221991
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
1992+
1993+
// Convert
1994+
// x = 1.0/sqrt(a)
1995+
// r1 = x * x;
1996+
// r2 = a * x;
1997+
//
1998+
// TO
1999+
//
2000+
// tmp1 = 1.0 / a
2001+
// tmp2 = sqrt (a)
2002+
// tmp3 = tmp1 * tmp2
2003+
// x = tmp3
2004+
// r1 = tmp1
2005+
// r2 = tmp2
2006+
SmallSetVector<Instruction *, 2> R1, R2;
2007+
getFSqrtDivOptPattern(&I, R1, R2);
2008+
if (R1.size() && R2.size() &&
2009+
isFSqrtDivToFMulLegal(static_cast<Instruction *>(&I), R1, R2)) {
2010+
CallInst *CI = static_cast<CallInst *>((&I)->getOperand(1));
2011+
Value *SqrtOp = CI->getArgOperand(0);
2012+
if (Value *D = convertFSqrtDivIntoFMul(CI, &I, R1, R2, SqrtOp, Builder))
2013+
return replaceInstUsesWith(I, D);
2014+
}
2015+
18232016
if (isa<Constant>(Op0))
18242017
if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
18252018
if (Instruction *R = FoldOpIntoSelect(I, SI))

0 commit comments

Comments
 (0)