[DAGCombine] Remove oneuse restrictions for RISCV in folding (shl (add_nsw x, c1), c2) and folding (shl (sext (add x, c1)), c2) in some scenarios #101294

Merged · 9 commits · Dec 10, 2024
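For orientation, here is a minimal IR-level sketch of the two patterns named in the title (constants are hypothetical; the combine itself runs on the SelectionDAG, and the nsw flag is what gates the sext variant):

; (shl (add nsw x, c1), c2) can be rewritten as (add (shl x, c2), c1 << c2)
define i64 @shl_add_nsw(i64 %x) {
  %add = add nsw i64 %x, 5     ; c1 = 5
  %shl = shl i64 %add, 3       ; c2 = 3, i.e. (x << 3) + 40
  ret i64 %shl
}

; (shl (sext (add nsw x, c1)), c2): the same fold performed through a sign extension
define i64 @shl_sext_add_nsw(i32 %x) {
  %add = add nsw i32 %x, 5
  %ext = sext i32 %add to i64
  %shl = shl i64 %ext, 3
  ret i64 %shl
}

Previously both folds were blocked whenever the add (or the sext) had more than one use; this patch moves the one-use checks into the isDesirableToCommuteWithShift hook so RISC-V can still fold when the extra users are loads/stores that can absorb the constant offset.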
6 changes: 6 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4305,6 +4305,12 @@ class TargetLowering : public TargetLoweringBase {
/// @param Level the current DAGCombine legalization level.
virtual bool isDesirableToCommuteWithShift(const SDNode *N,
CombineLevel Level) const {
SDValue ShiftLHS = N->getOperand(0);
if (!ShiftLHS->hasOneUse())
return false;
if (ShiftLHS.getOpcode() == ISD::SIGN_EXTEND &&
!ShiftLHS.getOperand(0)->hasOneUse())
return false;
return true;
}

5 changes: 2 additions & 3 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10230,7 +10230,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
// Variant of version done on multiply, except mul by a power of 2 is turned
// into a shift.
if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) &&
- N0->hasOneUse() && TLI.isDesirableToCommuteWithShift(N, Level)) {
+ TLI.isDesirableToCommuteWithShift(N, Level)) {
SDValue N01 = N0.getOperand(1);
if (SDValue Shl1 =
DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N1), VT, {N01, N1})) {
@@ -10249,8 +10249,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
// TODO: Should we limit this with isLegalAddImmediate?
if (N0.getOpcode() == ISD::SIGN_EXTEND &&
N0.getOperand(0).getOpcode() == ISD::ADD &&
- N0.getOperand(0)->getFlags().hasNoSignedWrap() && N0->hasOneUse() &&
- N0.getOperand(0)->hasOneUse() &&
+ N0.getOperand(0)->getFlags().hasNoSignedWrap() &&
TLI.isDesirableToCommuteWithShift(N, Level)) {
SDValue Add = N0.getOperand(0);
SDLoc DL(N0);
7 changes: 7 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17931,6 +17931,13 @@ AArch64TargetLowering::isDesirableToCommuteWithShift(const SDNode *N,
SDValue ShiftLHS = N->getOperand(0);
EVT VT = N->getValueType(0);

if (!ShiftLHS->hasOneUse())
Member: please don't make unnecessary change

Contributor Author: removed

return false;

if (ShiftLHS.getOpcode() == ISD::SIGN_EXTEND &&
!ShiftLHS.getOperand(0)->hasOneUse())
return false;

// If ShiftLHS is unsigned bit extraction: ((x >> C) & mask), then do not
// combine it with shift 'N' to let it be lowered to UBFX except:
// ((x >> C) & mask) << C.
9 changes: 9 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -1072,6 +1072,15 @@ bool AMDGPUTargetLowering::isDesirableToCommuteWithShift(
assert((N->getOpcode() == ISD::SHL || N->getOpcode() == ISD::SRA ||
N->getOpcode() == ISD::SRL) &&
"Expected shift op");

SDValue ShiftLHS = N->getOperand(0);
if (!ShiftLHS->hasOneUse())
return false;

if (ShiftLHS.getOpcode() == ISD::SIGN_EXTEND &&
!ShiftLHS.getOperand(0)->hasOneUse())
return false;

// Always commute pre-type legalization and right shifts.
// We're looking for shl(or(x,y),z) patterns.
if (Level < CombineLevel::AfterLegalizeTypes ||
8 changes: 8 additions & 0 deletions llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -13880,6 +13880,14 @@ ARMTargetLowering::isDesirableToCommuteWithShift(const SDNode *N,
N->getOpcode() == ISD::SRL) &&
"Expected shift op");

SDValue ShiftLHS = N->getOperand(0);
if (!ShiftLHS->hasOneUse())
return false;

if (ShiftLHS.getOpcode() == ISD::SIGN_EXTEND &&
!ShiftLHS.getOperand(0)->hasOneUse())
return false;

if (Level == BeforeLegalizeTypes)
return true;

47 changes: 47 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18151,8 +18151,46 @@ bool RISCVTargetLowering::isDesirableToCommuteWithShift(
// (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
SDValue N0 = N->getOperand(0);
EVT Ty = N0.getValueType();

// LD/ST can fold a constant offset into their addressing mode, so when the
// remaining users of AddNode are only LD/ST, the folding above can still be
// performed.
auto isUsedByLdSt = [&]() {
bool CanOptAlways = false;
if (N0->getOpcode() == ISD::ADD && !N0->hasOneUse()) {
for (SDNode *Use : N0->uses()) {
// This use is the one we're on right now. Skip it
if (Use == N || Use->getOpcode() == ISD::SELECT)
continue;
if (!isa<StoreSDNode>(Use) && !isa<LoadSDNode>(Use)) {
Collaborator: Do we need to check that c1 << c2 fits in 12 bits so it will fold into the load/store?

Contributor Author: There may be no need for a check. The address calculation will generate LUI + addi to ensure that the LD/ST offset falls within the 12-bit range.
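
As a rough illustration, with a hypothetical byte offset of 8196 (in the spirit of the addexceedsign12 tests below) that does not fit the signed 12-bit store immediate:

define void @offset_example(ptr %p, i32 %v) {
  ; 8196 exceeds simm12 (max 2047), so the address is materialized with
  ; lui + add and the store keeps a small offset, roughly:
  ;   lui  a2, 2          ; a2 = 8192
  ;   add  a0, a0, a2
  ;   sw   a1, 4(a0)      ; 8192 + 4 = 8196
  %q = getelementptr i8, ptr %p, i64 8196
  store i32 %v, ptr %q
  ret void
}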

CanOptAlways = false;
break;
}
CanOptAlways = true;
}
}

if (N0->getOpcode() == ISD::SIGN_EXTEND &&
!N0->getOperand(0)->hasOneUse()) {
for (SDNode *Use : N0->getOperand(0)->uses()) {
// This use is the one we're on right now. Skip it
if (Use == N0.getNode() || Use->getOpcode() == ISD::SELECT)
continue;
if (!isa<StoreSDNode>(Use) && !isa<LoadSDNode>(Use)) {
CanOptAlways = false;
break;
}
CanOptAlways = true;
}
}
return CanOptAlways;
};

if (Ty.isScalarInteger() &&
(N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR)) {
if (N0.getOpcode() == ISD::ADD && !N0->hasOneUse())
return isUsedByLdSt();

auto *C1 = dyn_cast<ConstantSDNode>(N0->getOperand(1));
auto *C2 = dyn_cast<ConstantSDNode>(N->getOperand(1));
if (C1 && C2) {
@@ -18187,6 +18225,15 @@ bool RISCVTargetLowering::isDesirableToCommuteWithShift(
return false;
}
}

if (!N0->hasOneUse())
return false;

if (N0->getOpcode() == ISD::SIGN_EXTEND &&
N0->getOperand(0)->getOpcode() == ISD::ADD &&
!N0->getOperand(0)->hasOneUse())
return isUsedByLdSt();

return true;
}

182 changes: 182 additions & 0 deletions llvm/test/CodeGen/RISCV/add_sext_shl_constant.ll
@@ -0,0 +1,182 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 3
; RUN: llc -mtriple=riscv64 < %s | FileCheck -check-prefix=RV64 %s

define void @add_sext_shl_moreOneUse_add(ptr %array1, i32 %a, i32 %b) {
; RV64-LABEL: add_sext_shl_moreOneUse_add:
; RV64: # %bb.0: # %entry
; RV64-NEXT: addi a3, a1, 5
; RV64-NEXT: sext.w a1, a1
; RV64-NEXT: slli a1, a1, 2
; RV64-NEXT: add a0, a1, a0
; RV64-NEXT: sw a2, 20(a0)
; RV64-NEXT: sw a2, 24(a0)
; RV64-NEXT: sw a3, 140(a0)
; RV64-NEXT: ret
entry:
%add = add nsw i32 %a, 5
%idxprom = sext i32 %add to i64
%arrayidx = getelementptr inbounds i32, ptr %array1, i64 %idxprom
store i32 %b, ptr %arrayidx
%add3 = add nsw i32 %a, 6
%idxprom4 = sext i32 %add3 to i64
%arrayidx5 = getelementptr inbounds i32, ptr %array1, i64 %idxprom4
store i32 %b, ptr %arrayidx5
%add6 = add nsw i32 %a, 35
%idxprom7 = sext i32 %add6 to i64
%arrayidx8 = getelementptr inbounds i32, ptr %array1, i64 %idxprom7
store i32 %add, ptr %arrayidx8
ret void
}

define void @add_sext_shl_moreOneUse_addexceedsign12(ptr %array1, i32 %a, i32 %b) {
; RV64-LABEL: add_sext_shl_moreOneUse_addexceedsign12:
; RV64: # %bb.0: # %entry
; RV64-NEXT: addi a3, a1, 2047
; RV64-NEXT: lui a4, 2
; RV64-NEXT: sext.w a1, a1
; RV64-NEXT: addi a3, a3, 1
; RV64-NEXT: slli a1, a1, 2
; RV64-NEXT: add a0, a0, a4
; RV64-NEXT: add a0, a0, a1
; RV64-NEXT: sw a2, 0(a0)
; RV64-NEXT: sw a3, 4(a0)
; RV64-NEXT: sw a2, 120(a0)
; RV64-NEXT: ret
entry:
%add = add nsw i32 %a, 2048
%idxprom = sext i32 %add to i64
%arrayidx = getelementptr inbounds i32, ptr %array1, i64 %idxprom
store i32 %b, ptr %arrayidx
%0 = sext i32 %a to i64
%1 = getelementptr i32, ptr %array1, i64 %0
%arrayidx3 = getelementptr i8, ptr %1, i64 8196
store i32 %add, ptr %arrayidx3
%arrayidx6 = getelementptr i8, ptr %1, i64 8312
store i32 %b, ptr %arrayidx6
ret void
}

define void @add_sext_shl_moreOneUse_sext(ptr %array1, i32 %a, i32 %b) {
; RV64-LABEL: add_sext_shl_moreOneUse_sext:
; RV64: # %bb.0: # %entry
; RV64-NEXT: sext.w a1, a1
; RV64-NEXT: addi a3, a1, 5
; RV64-NEXT: slli a1, a1, 2
; RV64-NEXT: add a0, a1, a0
; RV64-NEXT: sw a2, 20(a0)
; RV64-NEXT: sw a2, 24(a0)
; RV64-NEXT: sd a3, 140(a0)
; RV64-NEXT: ret
entry:
%add = add nsw i32 %a, 5
%idxprom = sext i32 %add to i64
%arrayidx = getelementptr inbounds i32, ptr %array1, i64 %idxprom
store i32 %b, ptr %arrayidx
%add3 = add nsw i32 %a, 6
%idxprom4 = sext i32 %add3 to i64
%arrayidx5 = getelementptr inbounds i32, ptr %array1, i64 %idxprom4
store i32 %b, ptr %arrayidx5
%add6 = add nsw i32 %a, 35
%idxprom7 = sext i32 %add6 to i64
%arrayidx8 = getelementptr inbounds i32, ptr %array1, i64 %idxprom7
store i64 %idxprom, ptr %arrayidx8
ret void
}

; Test with a branch/select: the add has an extra use in the select, and the fold can still be applied.
define void @add_sext_shl_moreOneUse_add_inSelect(ptr %array1, i32 signext %a, i32 %b, i32 signext %x) {
; RV64-LABEL: add_sext_shl_moreOneUse_add_inSelect:
; RV64: # %bb.0: # %entry
; RV64-NEXT: addi a4, a1, 5
; RV64-NEXT: mv a5, a4
; RV64-NEXT: bgtz a3, .LBB3_2
; RV64-NEXT: # %bb.1: # %entry
; RV64-NEXT: mv a5, a2
; RV64-NEXT: .LBB3_2: # %entry
; RV64-NEXT: slli a1, a1, 2
; RV64-NEXT: add a0, a1, a0
; RV64-NEXT: sw a5, 20(a0)
; RV64-NEXT: sw a5, 24(a0)
; RV64-NEXT: sw a4, 140(a0)
; RV64-NEXT: ret
entry:
%add = add nsw i32 %a, 5
%cmp = icmp sgt i32 %x, 0
%idxprom = sext i32 %add to i64
%arrayidx = getelementptr inbounds i32, ptr %array1, i64 %idxprom
%add.b = select i1 %cmp, i32 %add, i32 %b
store i32 %add.b, ptr %arrayidx
%add5 = add nsw i32 %a, 6
%idxprom6 = sext i32 %add5 to i64
%arrayidx7 = getelementptr inbounds i32, ptr %array1, i64 %idxprom6
store i32 %add.b, ptr %arrayidx7
%add8 = add nsw i32 %a, 35
%idxprom9 = sext i32 %add8 to i64
%arrayidx10 = getelementptr inbounds i32, ptr %array1, i64 %idxprom9
store i32 %add, ptr %arrayidx10
ret void
}

define void @add_sext_shl_moreOneUse_add_inSelect_addexceedsign12(ptr %array1, i32 signext %a, i32 %b, i32 signext %x) {
; RV64-LABEL: add_sext_shl_moreOneUse_add_inSelect_addexceedsign12:
; RV64: # %bb.0: # %entry
; RV64-NEXT: addi a4, a1, 2047
; RV64-NEXT: lui a5, 2
; RV64-NEXT: slli a6, a1, 2
; RV64-NEXT: addi a1, a4, 1
; RV64-NEXT: add a0, a0, a6
; RV64-NEXT: add a0, a0, a5
; RV64-NEXT: mv a4, a1
; RV64-NEXT: bgtz a3, .LBB4_2
; RV64-NEXT: # %bb.1: # %entry
; RV64-NEXT: mv a4, a2
; RV64-NEXT: .LBB4_2: # %entry
; RV64-NEXT: sw a4, 0(a0)
; RV64-NEXT: sw a4, 4(a0)
; RV64-NEXT: sw a1, 120(a0)
; RV64-NEXT: ret
entry:
%add = add nsw i32 %a, 2048
%cmp = icmp sgt i32 %x, 0
%idxprom = sext i32 %add to i64
%arrayidx = getelementptr inbounds i32, ptr %array1, i64 %idxprom
%add.b = select i1 %cmp, i32 %add, i32 %b
store i32 %add.b, ptr %arrayidx
%0 = sext i32 %a to i64
%1 = getelementptr i32, ptr %array1, i64 %0
%arrayidx7 = getelementptr i8, ptr %1, i64 8196
store i32 %add.b, ptr %arrayidx7
%arrayidx10 = getelementptr i8, ptr %1, i64 8312
store i32 %add, ptr %arrayidx10
ret void
}

define void @add_shl_moreOneUse_inSelect(ptr %array1, i64 %a, i64 %b, i64 %x) {
; RV64-LABEL: add_shl_moreOneUse_inSelect:
; RV64: # %bb.0: # %entry
; RV64-NEXT: addi a4, a1, 5
; RV64-NEXT: mv a5, a4
; RV64-NEXT: bgtz a3, .LBB5_2
; RV64-NEXT: # %bb.1: # %entry
; RV64-NEXT: mv a5, a2
; RV64-NEXT: .LBB5_2: # %entry
; RV64-NEXT: slli a1, a1, 3
; RV64-NEXT: add a0, a1, a0
; RV64-NEXT: sd a5, 40(a0)
; RV64-NEXT: sd a5, 48(a0)
; RV64-NEXT: sd a4, 280(a0)
; RV64-NEXT: ret
entry:
%add = add nsw i64 %a, 5
%cmp = icmp sgt i64 %x, 0
%spec.select = select i1 %cmp, i64 %add, i64 %b
%0 = getelementptr inbounds i64, ptr %array1, i64 %add
store i64 %spec.select, ptr %0
%add3 = add nsw i64 %a, 6
%arrayidx4 = getelementptr inbounds i64, ptr %array1, i64 %add3
store i64 %spec.select, ptr %arrayidx4
%add5 = add nsw i64 %a, 35
%arrayidx6 = getelementptr inbounds i64, ptr %array1, i64 %add5
store i64 %add, ptr %arrayidx6
ret void
}