Skip to content

[InstCombine] Transform high latency, dependent FSQRT/FDIV into FMUL #87474

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 176 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "InstCombineInternal.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/ValueTracking.h"
Expand Down Expand Up @@ -666,6 +667,94 @@ Instruction *InstCombinerImpl::foldPowiReassoc(BinaryOperator &I) {
return nullptr;
}

// If we have the following pattern,
// X = 1.0/sqrt(a)
// R1 = X * X
// R2 = a/sqrt(a)
// then this method collects all the instructions that match R1 and R2.
static bool getFSqrtDivOptPattern(Instruction *Div,
SmallPtrSetImpl<Instruction *> &R1,
SmallPtrSetImpl<Instruction *> &R2) {
Value *A;
if (match(Div, m_FDiv(m_FPOne(), m_Sqrt(m_Value(A)))) ||
match(Div, m_FDiv(m_SpecificFP(-1.0), m_Sqrt(m_Value(A))))) {
for (User *U : Div->users()) {
Instruction *I = cast<Instruction>(U);
if (match(I, m_FMul(m_Specific(Div), m_Specific(Div))))
R1.insert(I);
}

CallInst *CI = cast<CallInst>(Div->getOperand(1));
for (User *U : CI->users()) {
Instruction *I = cast<Instruction>(U);
if (match(I, m_FDiv(m_Specific(A), m_Sqrt(m_Specific(A)))))
R2.insert(I);
}
}
return !R1.empty() && !R2.empty();
}

// Check legality for transforming
// x = 1.0/sqrt(a)
// r1 = x * x;
// r2 = a/sqrt(a);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excuse me if this was covered in one of the many resolved conversations, but it's not clear to me why you're transforming a/sqrt(a) as part of this pattern. Is it because you need to hoist R2 into the same block as X?

I see that in InstCombinerImpl::foldFMulReassoc() we are transforming a number of patterns into x/sqrt(x) with a comment that the backend is expected to transform that into sqrt(x) if the necessary fast-math flags are present. I'm not sure why that's being left to the backend, but I don't see any reason to perform the transformation here. If you want InstCombine to do that, it could just as easily be an independent transformation.

It may be that this is a case where we decided we needed the "unsafe-fp-math" function attribute because none of the individual fast-math flags clearly allows it. @jcranmer-intel has been working on clarifying the semantics of these flags and might have more to say on this. I think it's definitely a transformation we want to allow, but I would argue that it requires more than just reassoc.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but it's not clear to me why you're transforming a/sqrt(a) as part of this pattern

we would like to express x = r1 * r2 where r1 and r2 are in suitable form. If r2 is not in required form, there is no point in doing this transformation(i.e. we wont be saving on 1 division in the backend).
We cant wait for the backend to transform r2 here.

I think it's definitely a transformation we want to allow, but I would argue that it requires more than just reassoc.

as far as I remember, I had this discussion with @jcranmer-intel on this PR itself. This is just considered algebraic-rewrite and hence, the reassoc flag. But if there any other flags required, coming post acceptance of his proposal, I am not sure if we should wait.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The semantics of the 'reassoc' flag don't explicitly allow algebraic transformations other than reassociation. It's often treated that way, but that's not what the language reference says. This is one of the things @jcranmer-intel is hoping to correct long-term. For now, we don't have another flag that clearly allows this, so I suppose we'll need to rely on 'reassoc' here.

//
// TO
//
// r1 = 1/a
// r2 = sqrt(a)
// x = r1 * r2
// This transform works only when 'a' is known positive.
static bool isFSqrtDivToFMulLegal(Instruction *X,
SmallPtrSetImpl<Instruction *> &R1,
SmallPtrSetImpl<Instruction *> &R2) {
// Check if the required pattern for the transformation exists.
if (!getFSqrtDivOptPattern(X, R1, R2))
return false;

BasicBlock *BBx = X->getParent();
BasicBlock *BBr1 = (*R1.begin())->getParent();
BasicBlock *BBr2 = (*R2.begin())->getParent();

CallInst *FSqrt = cast<CallInst>(X->getOperand(1));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like the use of cast<> here. I understand that you've verified in getFSqrtDivOptPattern() that this will be a call, but the cast here creates a tight coupling between the functions that isn't enforced by the function semantics. That is, someone could call this function without having called the other. I think it makes more sense to combine them or to call getFSqrtDivOptPattern from here and not require it to be called separately by users of this function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Firstly, any attempt to call the func directly where the cast<> fails should be immediately visible in the debug build.

Second, This part is keep different just to increase readability and seperate functionally different things. The scenario you have mentioned is bound to happen everywhere. For the same reason, I have mentioned in the description of the function that how should x/r1/r2 look like. If you are not satisfied with the cast, maybe I can add an assert.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cast contains an assert, so there's no need to add one. I understand your point about readability. I think it's better to have code that avoids failures than to have code that fails in obvious ways if misused, but I'm willing to leave that up to you.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok sure. Will make the necessary change.

if (!FSqrt->hasAllowReassoc() || !FSqrt->hasNoNaNs() ||
!FSqrt->hasNoSignedZeros() || !FSqrt->hasNoInfs())
Comment on lines +720 to +721
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this require nnan? For a nan input a, r1, r2 and x are trivially nan in the input and output. For a < 0, in the input, x = nan, r1 = nan, r2 = nan. In the output, r1 = non-nan, but this is OK if the single use is the multiply to x. This would also be OK if the multiply had nnan instead

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The optimization is valid only for positive normals. Now, we cant put restrictions on values of a so I had to put constraints on call instruction since this is used for all x/r1/r2. Also, x/r1/r2 can have multiple uses and hence, their values before/after transform need to be matched

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The nnan restriction severely limits the usefulness of this transformation, so I think it's worth trying to find alternative ways for the transformation to trigger. I understand that you want it to trigger for cases where x, r1, and r2 have an arbitrary number of users, but if the nnan condition isn't met, you could still perform the transformation if r1 is the only user or x. In addition, you could check isKnownNonNegative() for a.

On the other hand, the ninf requirement is similarly restrictive, so maybe it's just necessary to accept that the limitations on this transformation. Since you mentioned CPU2017 in your description, I would mention that ninf can't be used with all CPU2017 benchmarks. In particular, it breaks povray. That's not to say this transformation isn't general enough. I'm just highlighting the limitation.

Copy link
Contributor Author

@sushgokh sushgokh Jan 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This transformation is valid only when a is positive normal. So, we need to have these checks enforced on X or a.

So, yes, this is limitation when the transformation is applicable.

return false;

// We change x = 1/sqrt(a) to x = sqrt(a) * 1/a . This change isn't allowed
// by recip fp as it is strictly meant to transform ops of type a/b to
// a * 1/b. So, this can be considered as algebraic rewrite and reassoc flag
// has been used(rather abused)in the past for algebraic rewrites.
if (!X->hasAllowReassoc() || !X->hasAllowReciprocal() || !X->hasNoInfs())
return false;

// Check the constraints on X, R1 and R2 combined.
// fdiv instruction and one of the multiplications must reside in the same
// block. If not, the optimized code may execute more ops than before and
// this may hamper the performance.
if (BBx != BBr1 && BBx != BBr2)
return false;

// Check the constraints on instructions in R1.
if (any_of(R1, [BBr1](Instruction *I) {
// When you have multiple instructions residing in R1 and R2
// respectively, it's difficult to generate combinations of (R1,R2) and
// then check if we have the required pattern. So, for now, just be
// conservative.
return (I->getParent() != BBr1 || !I->hasAllowReassoc());
}))
return false;

// Check the constraints on instructions in R2.
return all_of(R2, [BBr2](Instruction *I) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason to use any_of above and all_of here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really. Initially I had all_of at all the places and someone suggested to have any_of. I can make it consistent at all the places i.e use all_of. My understanding is it wont matter

// When you have multiple instructions residing in R1 and R2
// respectively, it's difficult to generate combination of (R1,R2) and
// then check if we have the required pattern. So, for now, just be
// conservative.
return (I->getParent() == BBr2 && I->hasAllowReassoc());
});
}

Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
Value *Op0 = I.getOperand(0);
Value *Op1 = I.getOperand(1);
Expand Down Expand Up @@ -1917,6 +2006,75 @@ static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
return BinaryOperator::CreateFMulFMF(Op0, NewSqrt, &I);
}

// Change
// X = 1/sqrt(a)
// R1 = X * X
// R2 = a * X
//
// TO
//
// FDiv = 1/a
// FSqrt = sqrt(a)
// FMul = FDiv * FSqrt
// Replace Uses Of R1 With FDiv
// Replace Uses Of R2 With FSqrt
// Replace Uses Of X With FMul
static Instruction *
convertFSqrtDivIntoFMul(CallInst *CI, Instruction *X,
const SmallPtrSetImpl<Instruction *> &R1,
const SmallPtrSetImpl<Instruction *> &R2,
InstCombiner::BuilderTy &B, InstCombinerImpl *IC) {

B.SetInsertPoint(X);

// Have an instruction that is representative of all of instructions in R1 and
// get the most common fpmath metadata and fast-math flags on it.
Value *SqrtOp = CI->getArgOperand(0);
auto *FDiv = cast<Instruction>(
B.CreateFDiv(ConstantFP::get(X->getType(), 1.0), SqrtOp));
auto *R1FPMathMDNode = (*R1.begin())->getMetadata(LLVMContext::MD_fpmath);
FastMathFlags R1FMF = (*R1.begin())->getFastMathFlags(); // Common FMF
for (Instruction *I : R1) {
R1FPMathMDNode = MDNode::getMostGenericFPMath(
R1FPMathMDNode, I->getMetadata(LLVMContext::MD_fpmath));
R1FMF &= I->getFastMathFlags();
IC->replaceInstUsesWith(*I, FDiv);
IC->eraseInstFromFunction(*I);
}
FDiv->setMetadata(LLVMContext::MD_fpmath, R1FPMathMDNode);
FDiv->copyFastMathFlags(R1FMF);

// Have a single sqrt call instruction that is representative of all of
// instructions in R2 and get the most common fpmath metadata and fast-math
// flags on it.
auto *FSqrt = cast<CallInst>(CI->clone());
FSqrt->insertBefore(CI);
auto *R2FPMathMDNode = (*R2.begin())->getMetadata(LLVMContext::MD_fpmath);
FastMathFlags R2FMF = (*R2.begin())->getFastMathFlags(); // Common FMF
for (Instruction *I : R2) {
R2FPMathMDNode = MDNode::getMostGenericFPMath(
R2FPMathMDNode, I->getMetadata(LLVMContext::MD_fpmath));
R2FMF &= I->getFastMathFlags();
IC->replaceInstUsesWith(*I, FSqrt);
IC->eraseInstFromFunction(*I);
}
FSqrt->setMetadata(LLVMContext::MD_fpmath, R2FPMathMDNode);
FSqrt->copyFastMathFlags(R2FMF);

Instruction *FMul;
// If X = -1/sqrt(a) initially,then FMul = -(FDiv * FSqrt)
if (match(X, m_FDiv(m_SpecificFP(-1.0), m_Specific(CI)))) {
Value *Mul = B.CreateFMul(FDiv, FSqrt);
FMul = cast<Instruction>(B.CreateFNeg(Mul));
} else
FMul = cast<Instruction>(B.CreateFMul(FDiv, FSqrt));
FMul->copyMetadata(*X);
FMul->copyFastMathFlags(FastMathFlags::intersectRewrite(R1FMF, R2FMF) |
FastMathFlags::unionValue(R1FMF, R2FMF));
IC->replaceInstUsesWith(*X, FMul);
return IC->eraseInstFromFunction(*X);
}

Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
Module *M = I.getModule();

Expand All @@ -1941,6 +2099,24 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
return R;

Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);

// Convert
// x = 1.0/sqrt(a)
// r1 = x * x;
// r2 = a/sqrt(a);
//
// TO
//
// r1 = 1/a
// r2 = sqrt(a)
// x = r1 * r2
SmallPtrSet<Instruction *, 2> R1, R2;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had to change from SmallVector to SmallPtrSet since invoking the users() API somehow got me duplicate users

if (isFSqrtDivToFMulLegal(&I, R1, R2)) {
CallInst *CI = cast<CallInst>(I.getOperand(1));
if (Instruction *D = convertFSqrtDivIntoFMul(CI, &I, R1, R2, Builder, this))
return D;
}

if (isa<Constant>(Op0))
if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
if (Instruction *R = FoldOpIntoSelect(I, SI))
Expand Down
Loading
Loading