@@ -626,6 +626,150 @@ Instruction *InstCombinerImpl::foldPowiReassoc(BinaryOperator &I) {
626
626
return nullptr ;
627
627
}
628
628
629
+ bool isFSqrtDivToFMulLegal (Instruction *X, SmallSetVector<Instruction *, 2 > &R1,
630
+ SmallSetVector<Instruction *, 2 > &R2) {
631
+
632
+ BasicBlock *BBx = X->getParent ();
633
+ BasicBlock *BBr1 = R1[0 ]->getParent ();
634
+ BasicBlock *BBr2 = R2[0 ]->getParent ();
635
+
636
+ auto IsStrictFP = [](Instruction *I) {
637
+ IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
638
+ if (II && II->isStrictFP ())
639
+ return true ;
640
+ return false ;
641
+ };
642
+
643
+ // check if X and instructions in R1/R2 satisfy basic block constraints
644
+ auto BBConstraintsSatisfied = [BBx, BBr1, BBr2]() {
645
+ // div instruction and one of the multiplications must reside in the same
646
+ // block. If not, the optimized code may execute more ops than before and
647
+ // this may hamper the performance
648
+ if (!(BBx == BBr1 || BBx == BBr2))
649
+ return false ;
650
+ return true ;
651
+ };
652
+
653
+ // Check the constaints on instruction X
654
+ auto XConstraintsSatisfied = [X, &IsStrictFP]() {
655
+ // X must have 3 uses in R1/R2 inclusive and 1 more use if the replacement
656
+ // for X should not get dead code eliminated. If X has less than 4 uses, the
657
+ // changes to R1 and R2 are anyway done as part of other transforms.R1 and
658
+ // R2 must either be global or must have single local use.
659
+ if (!X->hasNUsesOrMore (4 ))
660
+ return false ;
661
+ // check if reciprocalFP is enabled
662
+ bool RecipFPMath = dyn_cast<FPMathOperator>(X)->hasAllowReciprocal ();
663
+ if (!RecipFPMath)
664
+ return false ;
665
+ if (IsStrictFP (X))
666
+ return false ;
667
+ return true ;
668
+ };
669
+ if (!XConstraintsSatisfied ())
670
+ return false ;
671
+
672
+ // check the constraints on instructions in R1
673
+ auto R1ConstraintsSatisfied = [BBr1, &IsStrictFP](Instruction *I) {
674
+ // when you have multiple instructions residing in R1 and R2 respectively,
675
+ // its difficult to generate combination of (R1,R2) and then check if we
676
+ // have the required pattern. So, for now, just be conservative.
677
+ if (I->getParent () != BBr1)
678
+ return false ;
679
+ if (!I->hasNUsesOrMore (1 ))
680
+ return false ;
681
+ if (IsStrictFP (I))
682
+ return false ;
683
+ // The optimization tries to convert
684
+ // R1 = div * div where, div = 1/sqrt(a)
685
+ // to
686
+ // R1 = 1/a
687
+ // Now, this simplication does not work because sqrt(a)=NaN when a<0
688
+ if (!I->hasNoNaNs ())
689
+ return false ;
690
+ // sqrt(-0.0) = -0.0, and doing this simplication would change the sign of
691
+ // the result.
692
+ if (!I->hasNoSignedZeros ())
693
+ return false ;
694
+ return true ;
695
+ };
696
+ if (!std::all_of (R1.begin (), R1.end (), R1ConstraintsSatisfied))
697
+ return false ;
698
+
699
+ // check the constraints on instructions in R2
700
+ auto R2ConstraintsSatisfied = [BBr2, &IsStrictFP](Instruction *I) {
701
+ // when you have multiple instructions residing in R1 and R2 respectively,
702
+ // its
703
+ // difficult to generate combination of (R1,R2) and then check if we have
704
+ // the required pattern. So, for now, just be conservative.
705
+ if (I->getParent () != BBr2)
706
+ return false ;
707
+ if (!I->hasNUsesOrMore (1 ))
708
+ return false ;
709
+ if (IsStrictFP (I))
710
+ return false ;
711
+ // This simplication changes
712
+ // R2 = a * 1/sqrt(a)
713
+ // to
714
+ // R2 = sqrt(a)
715
+ // Now, sqrt(-0.0) = -0.0 and doing this simplication would produce -0.0
716
+ // instead of NaN.
717
+ if (!I->hasNoSignedZeros ())
718
+ return false ;
719
+ return true ;
720
+ };
721
+ if (!std::all_of (R2.begin (), R2.end (), R2ConstraintsSatisfied))
722
+ return false ;
723
+
724
+ auto XR1R2ConstraintsSatisfied = [=]() {
725
+ const Function *F = X->getFunction ();
726
+ bool UnsafeFPMath = F->getFnAttribute (" unsafe-fp-math" ).getValueAsBool ();
727
+ if (!UnsafeFPMath)
728
+ return false ;
729
+ if (!BBConstraintsSatisfied ())
730
+ return false ;
731
+ return true ;
732
+ };
733
+ if (!XR1R2ConstraintsSatisfied ())
734
+ return false ;
735
+
736
+ return true ;
737
+ }
738
+
739
+ void getFSqrtDivOptPattern (Value *Div, SmallSetVector<Instruction *, 2 > &R1,
740
+ SmallSetVector<Instruction *, 2 > &R2) {
741
+ Value *A;
742
+ if (match (Div, m_FDiv (m_FPOne (), m_Sqrt (m_Value (A))))) {
743
+ for (auto U : Div->users ()) {
744
+
745
+ if (match (U, m_FMul (m_Specific (Div), m_Specific (Div)))) {
746
+ R1.insert (static_cast <Instruction *>(U));
747
+ continue ;
748
+ }
749
+
750
+ Value *X;
751
+ if (match (U, m_FMul (m_Specific (Div), m_Value (X))) && X == A) {
752
+ R2.insert (static_cast <Instruction *>(U));
753
+ continue ;
754
+ }
755
+
756
+ if (match (U, m_FMul (m_Value (X), m_Specific (Div))) && X == A) {
757
+ R2.insert (static_cast <Instruction *>(U));
758
+ continue ;
759
+ }
760
+ }
761
+ }
762
+ }
763
+
764
+ bool delayFMulSqrtTransform (Value *Div) {
765
+ SmallSetVector<Instruction *, 2 > R1, R2;
766
+ getFSqrtDivOptPattern (Div, R1, R2);
767
+ if (R1.size () && R2.size () &&
768
+ isFSqrtDivToFMulLegal (static_cast <Instruction *>(Div), R1, R2))
769
+ return true ;
770
+ return false ;
771
+ }
772
+
629
773
Instruction *InstCombinerImpl::foldFMulReassoc (BinaryOperator &I) {
630
774
Value *Op0 = I.getOperand (0 );
631
775
Value *Op1 = I.getOperand (1 );
@@ -705,19 +849,20 @@ Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
705
849
// has the necessary (reassoc) fast-math-flags.
706
850
if (I.hasNoSignedZeros () &&
707
851
match (Op0, (m_FDiv (m_SpecificFP (1.0 ), m_Value (Y)))) &&
708
- match (Y, m_Sqrt (m_Value (X))) && Op1 == X)
852
+ match (Y, m_Sqrt (m_Value (X))) && Op1 == X && ! delayFMulSqrtTransform (Op0) )
709
853
return BinaryOperator::CreateFDivFMF (X, Y, &I);
710
854
if (I.hasNoSignedZeros () &&
711
855
match (Op1, (m_FDiv (m_SpecificFP (1.0 ), m_Value (Y)))) &&
712
- match (Y, m_Sqrt (m_Value (X))) && Op0 == X)
856
+ match (Y, m_Sqrt (m_Value (X))) && Op0 == X && ! delayFMulSqrtTransform (Op1) )
713
857
return BinaryOperator::CreateFDivFMF (X, Y, &I);
714
858
715
859
// Like the similar transform in instsimplify, this requires 'nsz' because
716
860
// sqrt(-0.0) = -0.0, and -0.0 * -0.0 does not simplify to -0.0.
717
861
if (I.hasNoNaNs () && I.hasNoSignedZeros () && Op0 == Op1 && Op0->hasNUses (2 )) {
718
862
// Peek through fdiv to find squaring of square root:
719
863
// (X / sqrt(Y)) * (X / sqrt(Y)) --> (X * X) / Y
720
- if (match (Op0, m_FDiv (m_Value (X), m_Sqrt (m_Value (Y))))) {
864
+ if (match (Op0, m_FDiv (m_Value (X), m_Sqrt (m_Value (Y)))) &&
865
+ !delayFMulSqrtTransform (Op0)) {
721
866
Value *XX = Builder.CreateFMulFMF (X, X, &I);
722
867
return BinaryOperator::CreateFDivFMF (XX, Y, &I);
723
868
}
@@ -1796,6 +1941,30 @@ static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
1796
1941
return BinaryOperator::CreateFMulFMF (Op0, NewSqrt, &I);
1797
1942
}
1798
1943
1944
+ Value *convertFSqrtDivIntoFMul (CallInst *CI, Instruction *X,
1945
+ SmallSetVector<Instruction *, 2 > &R1,
1946
+ SmallSetVector<Instruction *, 2 > &R2,
1947
+ Value *SqrtOp, InstCombiner::BuilderTy &B) {
1948
+
1949
+ // 1. synthesize tmp1 = 1/a and replace uses of r1
1950
+ B.SetInsertPoint (X);
1951
+ Value *Tmp1 =
1952
+ B.CreateFDivFMF (ConstantFP::get (R1[0 ]->getType (), 1.0 ), SqrtOp, R1[0 ]);
1953
+ for (auto *I : R1)
1954
+ I->replaceAllUsesWith (Tmp1);
1955
+
1956
+ // 2. No need of synthesizing Tmp2 again. In this scenario, tmp2 = CI. Replace
1957
+ // uses of r2 with tmp2
1958
+ for (auto *I : R2)
1959
+ I->replaceAllUsesWith (CI);
1960
+
1961
+ // 3. synthesize tmp3 = tmp1 * tmp2 . Replace uses of 'x' with tmp3
1962
+ B.SetInsertPoint (X->getNextNode ());
1963
+ Value *Tmp3 = B.CreateFMulFMF (Tmp1, CI, X);
1964
+
1965
+ return Tmp3;
1966
+ }
1967
+
1799
1968
Instruction *InstCombinerImpl::visitFDiv (BinaryOperator &I) {
1800
1969
Module *M = I.getModule ();
1801
1970
@@ -1820,6 +1989,30 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
1820
1989
return R;
1821
1990
1822
1991
Value *Op0 = I.getOperand (0 ), *Op1 = I.getOperand (1 );
1992
+
1993
+ // Convert
1994
+ // x = 1.0/sqrt(a)
1995
+ // r1 = x * x;
1996
+ // r2 = a * x;
1997
+ //
1998
+ // TO
1999
+ //
2000
+ // tmp1 = 1.0 / a
2001
+ // tmp2 = sqrt (a)
2002
+ // tmp3 = tmp1 * tmp2
2003
+ // x = tmp3
2004
+ // r1 = tmp1
2005
+ // r2 = tmp2
2006
+ SmallSetVector<Instruction *, 2 > R1, R2;
2007
+ getFSqrtDivOptPattern (&I, R1, R2);
2008
+ if (R1.size () && R2.size () &&
2009
+ isFSqrtDivToFMulLegal (static_cast <Instruction *>(&I), R1, R2)) {
2010
+ CallInst *CI = static_cast <CallInst *>((&I)->getOperand (1 ));
2011
+ Value *SqrtOp = CI->getArgOperand (0 );
2012
+ if (Value *D = convertFSqrtDivIntoFMul (CI, &I, R1, R2, SqrtOp, Builder))
2013
+ return replaceInstUsesWith (I, D);
2014
+ }
2015
+
1823
2016
if (isa<Constant>(Op0))
1824
2017
if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
1825
2018
if (Instruction *R = FoldOpIntoSelect (I, SI))
0 commit comments