Skip to content

Update foldFMulReassoc to respect absent fast-math flags #88589

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions llvm/include/llvm/IR/InstrTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "llvm/IR/Attributes.h"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/FMF.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/LLVMContext.h"
Expand Down Expand Up @@ -311,6 +312,32 @@ class BinaryOperator : public Instruction {
return BO;
}

static BinaryOperator *CreateWithFMF(BinaryOps Opc, Value *V1, Value *V2,
FastMathFlags FMF,
const Twine &Name = "",
Instruction *InsertBefore = nullptr) {
BinaryOperator *BO = Create(Opc, V1, V2, Name, InsertBefore);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you added a new operand/overload to Create, can you avoid the new include?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so. The include is needed because I'm passing FastMathFlags by value. It gets passed as a single integer, so it doesn't really make sense to pass by pointer.

BO->setFastMathFlags(FMF);
return BO;
}

static BinaryOperator *CreateFAddFMF(Value *V1, Value *V2, FastMathFlags FMF,
const Twine &Name = "") {
return CreateWithFMF(Instruction::FAdd, V1, V2, FMF, Name);
}
static BinaryOperator *CreateFSubFMF(Value *V1, Value *V2, FastMathFlags FMF,
const Twine &Name = "") {
return CreateWithFMF(Instruction::FSub, V1, V2, FMF, Name);
}
static BinaryOperator *CreateFMulFMF(Value *V1, Value *V2, FastMathFlags FMF,
const Twine &Name = "") {
return CreateWithFMF(Instruction::FMul, V1, V2, FMF, Name);
}
static BinaryOperator *CreateFDivFMF(Value *V1, Value *V2, FastMathFlags FMF,
const Twine &Name = "") {
return CreateWithFMF(Instruction::FDiv, V1, V2, FMF, Name);
}

static BinaryOperator *CreateFAddFMF(Value *V1, Value *V2,
Instruction *FMFSource,
const Twine &Name = "") {
Expand Down
38 changes: 26 additions & 12 deletions llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -631,31 +631,38 @@ Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
Value *Op1 = I.getOperand(1);
Value *X, *Y;
Constant *C;
BinaryOperator *Op0BinOp;

// Reassociate constant RHS with another constant to form constant
// expression.
if (match(Op1, m_Constant(C)) && C->isFiniteNonZeroFP()) {
if (match(Op1, m_Constant(C)) && C->isFiniteNonZeroFP() &&
match(Op0, m_AllowReassoc(m_BinOp(Op0BinOp)))) {
// Everything in this scope folds I with Op0, intersecting their FMF.
FastMathFlags FMF = I.getFastMathFlags() & Op0BinOp->getFastMathFlags();
IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
Builder.setFastMathFlags(FMF);
Constant *C1;
if (match(Op0, m_OneUse(m_FDiv(m_Constant(C1), m_Value(X))))) {
// (C1 / X) * C --> (C * C1) / X
Constant *CC1 =
ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL);
if (CC1 && CC1->isNormalFP())
return BinaryOperator::CreateFDivFMF(CC1, X, &I);
return BinaryOperator::CreateFDivFMF(CC1, X, FMF);
}
if (match(Op0, m_FDiv(m_Value(X), m_Constant(C1)))) {
// FIXME: This seems like it should also be checking for arcp
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose the question here is if the transformation looks like (A / B) * C => (A * recip(B)) * C => A * (recip(B) * C) => A * (C / B) (with the creation/deletion of the internal recip operation requiring arcp), or if it's just a straightforward application of the associativity law. We're not entirely consistent with regards to whether or not reassociation of division requires both reassoc and arcp, although requiring arcp does seem to be the plurality opinion.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My thinking is that division is not associative, and you can only apply the associative property if you replace division with multiplication by the reciprocal. I agree that we're not consistent about this, but I don't see how skipping the arcp check can be justified.

FWIW, gcc thinks this requires reciprocal math: https://godbolt.org/z/T7xTTshbf

// (X / C1) * C --> X * (C / C1)
Constant *CDivC1 =
ConstantFoldBinaryOpOperands(Instruction::FDiv, C, C1, DL);
if (CDivC1 && CDivC1->isNormalFP())
return BinaryOperator::CreateFMulFMF(X, CDivC1, &I);
return BinaryOperator::CreateFMulFMF(X, CDivC1, FMF);

// If the constant was a denormal, try reassociating differently.
// (X / C1) * C --> X / (C1 / C)
Constant *C1DivC =
ConstantFoldBinaryOpOperands(Instruction::FDiv, C1, C, DL);
if (C1DivC && Op0->hasOneUse() && C1DivC->isNormalFP())
return BinaryOperator::CreateFDivFMF(X, C1DivC, &I);
return BinaryOperator::CreateFDivFMF(X, C1DivC, FMF);
}

// We do not need to match 'fadd C, X' and 'fsub X, C' because they are
Expand All @@ -665,26 +672,33 @@ Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
// (X + C1) * C --> (X * C) + (C * C1)
if (Constant *CC1 =
ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL)) {
Value *XC = Builder.CreateFMulFMF(X, C, &I);
return BinaryOperator::CreateFAddFMF(XC, CC1, &I);
Value *XC = Builder.CreateFMul(X, C);
return BinaryOperator::CreateFAddFMF(XC, CC1, FMF);
}
}
if (match(Op0, m_OneUse(m_FSub(m_Constant(C1), m_Value(X))))) {
// (C1 - X) * C --> (C * C1) - (X * C)
if (Constant *CC1 =
ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL)) {
Value *XC = Builder.CreateFMulFMF(X, C, &I);
return BinaryOperator::CreateFSubFMF(CC1, XC, &I);
Value *XC = Builder.CreateFMul(X, C);
return BinaryOperator::CreateFSubFMF(CC1, XC, FMF);
}
}
}

Value *Z;
if (match(&I,
m_c_FMul(m_OneUse(m_FDiv(m_Value(X), m_Value(Y))), m_Value(Z)))) {
// Sink division: (X / Y) * Z --> (X * Z) / Y
Value *NewFMul = Builder.CreateFMulFMF(X, Z, &I);
return BinaryOperator::CreateFDivFMF(NewFMul, Y, &I);
m_c_FMul(m_AllowReassoc(m_OneUse(m_FDiv(m_Value(X), m_Value(Y)))),
m_Value(Z)))) {
BinaryOperator *DivOp = cast<BinaryOperator>(((Z == Op0) ? Op1 : Op0));
FastMathFlags FMF = I.getFastMathFlags() & DivOp->getFastMathFlags();
if (FMF.allowReassoc()) {
// Sink division: (X / Y) * Z --> (X * Z) / Y
IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
Builder.setFastMathFlags(FMF);
auto *NewFMul = Builder.CreateFMul(X, Z);
return BinaryOperator::CreateFDivFMF(NewFMul, Y, FMF);
}
}

// sqrt(X) * sqrt(Y) -> sqrt(X * Y)
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/Transforms/InstCombine/fast-math.ll
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ define float @fdiv1(float %x) {
; CHECK-NEXT: [[DIV1:%.*]] = fmul fast float [[X:%.*]], 0x3FD7303B60000000
; CHECK-NEXT: ret float [[DIV1]]
;
%div = fdiv float %x, 0x3FF3333340000000
%div = fdiv fast float %x, 0x3FF3333340000000
%div1 = fdiv fast float %div, 0x4002666660000000
ret float %div1
; 0x3FF3333340000000 = 1.2f
Expand Down Expand Up @@ -603,7 +603,7 @@ define float @fdiv3(float %x) {
; CHECK-NEXT: [[DIV1:%.*]] = fdiv fast float [[TMP1]], 0x47EFFFFFE0000000
; CHECK-NEXT: ret float [[DIV1]]
;
%div = fdiv float %x, 0x47EFFFFFE0000000
%div = fdiv fast float %x, 0x47EFFFFFE0000000
%div1 = fdiv fast float %div, 0x4002666660000000
ret float %div1
}
Expand Down
30 changes: 15 additions & 15 deletions llvm/test/Transforms/InstCombine/fmul-pow.ll
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ define double @pow_ab_recip_a_reassoc(double %a, double %b) {
; CHECK-NEXT: [[M:%.*]] = call reassoc double @llvm.pow.f64(double [[A:%.*]], double [[TMP1]])
; CHECK-NEXT: ret double [[M]]
;
%r = fdiv double 1.0, %a
%p = call double @llvm.pow.f64(double %a, double %b)
%r = fdiv reassoc double 1.0, %a
%p = call reassoc double @llvm.pow.f64(double %a, double %b)
%m = fmul reassoc double %r, %p
ret double %m
}
Expand All @@ -99,8 +99,8 @@ define double @pow_ab_recip_a_reassoc_commute(double %a, double %b) {
; CHECK-NEXT: [[M:%.*]] = call reassoc double @llvm.pow.f64(double [[A:%.*]], double [[TMP1]])
; CHECK-NEXT: ret double [[M]]
;
%r = fdiv double 1.0, %a
%p = call double @llvm.pow.f64(double %a, double %b)
%r = fdiv reassoc double 1.0, %a
%p = call reassoc double @llvm.pow.f64(double %a, double %b)
%m = fmul reassoc double %p, %r
ret double %m
}
Expand All @@ -109,14 +109,14 @@ define double @pow_ab_recip_a_reassoc_commute(double %a, double %b) {

define double @pow_ab_recip_a_reassoc_use1(double %a, double %b) {
; CHECK-LABEL: @pow_ab_recip_a_reassoc_use1(
; CHECK-NEXT: [[R:%.*]] = fdiv double 1.000000e+00, [[A:%.*]]
; CHECK-NEXT: [[P:%.*]] = call double @llvm.pow.f64(double [[A]], double [[B:%.*]])
; CHECK-NEXT: [[R:%.*]] = fdiv reassoc double 1.000000e+00, [[A:%.*]]
; CHECK-NEXT: [[P:%.*]] = call reassoc double @llvm.pow.f64(double [[A]], double [[B:%.*]])
; CHECK-NEXT: [[M:%.*]] = fmul reassoc double [[R]], [[P]]
; CHECK-NEXT: call void @use(double [[R]])
; CHECK-NEXT: ret double [[M]]
;
%r = fdiv double 1.0, %a
%p = call double @llvm.pow.f64(double %a, double %b)
%r = fdiv reassoc double 1.0, %a
%p = call reassoc double @llvm.pow.f64(double %a, double %b)
%m = fmul reassoc double %r, %p
call void @use(double %r)
ret double %m
Expand All @@ -126,13 +126,13 @@ define double @pow_ab_recip_a_reassoc_use1(double %a, double %b) {

define double @pow_ab_recip_a_reassoc_use2(double %a, double %b) {
; CHECK-LABEL: @pow_ab_recip_a_reassoc_use2(
; CHECK-NEXT: [[P:%.*]] = call double @llvm.pow.f64(double [[A:%.*]], double [[B:%.*]])
; CHECK-NEXT: [[P:%.*]] = call reassoc double @llvm.pow.f64(double [[A:%.*]], double [[B:%.*]])
; CHECK-NEXT: [[M:%.*]] = fdiv reassoc double [[P]], [[A]]
; CHECK-NEXT: call void @use(double [[P]])
; CHECK-NEXT: ret double [[M]]
;
%r = fdiv double 1.0, %a
%p = call double @llvm.pow.f64(double %a, double %b)
%r = fdiv reassoc double 1.0, %a
%p = call reassoc double @llvm.pow.f64(double %a, double %b)
%m = fmul reassoc double %r, %p
call void @use(double %p)
ret double %m
Expand All @@ -142,15 +142,15 @@ define double @pow_ab_recip_a_reassoc_use2(double %a, double %b) {

define double @pow_ab_recip_a_reassoc_use3(double %a, double %b) {
; CHECK-LABEL: @pow_ab_recip_a_reassoc_use3(
; CHECK-NEXT: [[R:%.*]] = fdiv double 1.000000e+00, [[A:%.*]]
; CHECK-NEXT: [[P:%.*]] = call double @llvm.pow.f64(double [[A]], double [[B:%.*]])
; CHECK-NEXT: [[R:%.*]] = fdiv reassoc double 1.000000e+00, [[A:%.*]]
; CHECK-NEXT: [[P:%.*]] = call reassoc double @llvm.pow.f64(double [[A]], double [[B:%.*]])
; CHECK-NEXT: [[M:%.*]] = fmul reassoc double [[R]], [[P]]
; CHECK-NEXT: call void @use(double [[R]])
; CHECK-NEXT: call void @use(double [[P]])
; CHECK-NEXT: ret double [[M]]
;
%r = fdiv double 1.0, %a
%p = call double @llvm.pow.f64(double %a, double %b)
%r = fdiv reassoc double 1.0, %a
%p = call reassoc double @llvm.pow.f64(double %a, double %b)
%m = fmul reassoc double %r, %p
call void @use(double %r)
call void @use(double %p)
Expand Down
Loading
Loading