Skip to content

Commit 3098200

Browse files
[ISel] Propagate disjoint flag in ShrinkDemandedOp (#114560)
When trying to evaluate an expression in a narrower type, the DAGCombine should propagate the disjoint flag, as it's equally valid on the narrower expression. This helps improve better use of addressing modes for some Arm SME instructions, for example.
1 parent 9cfe302 commit 3098200

File tree

3 files changed

+23
-2
lines changed

3 files changed

+23
-2
lines changed

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -604,10 +604,15 @@ bool TargetLowering::ShrinkDemandedOp(SDValue Op, unsigned BitWidth,
604604
EVT SmallVT = EVT::getIntegerVT(*DAG.getContext(), SmallVTBits);
605605
if (isTruncateFree(VT, SmallVT) && isZExtFree(SmallVT, VT)) {
606606
// We found a type with free casts.
607+
608+
// If the operation has the 'disjoint' flag, then the
609+
// operands on the new node are also disjoint.
610+
SDNodeFlags Flags(Op->getFlags().hasDisjoint() ? SDNodeFlags::Disjoint
611+
: SDNodeFlags::None);
607612
SDValue X = DAG.getNode(
608613
Op.getOpcode(), dl, SmallVT,
609614
DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(0)),
610-
DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(1)));
615+
DAG.getNode(ISD::TRUNCATE, dl, SmallVT, Op.getOperand(1)), Flags);
611616
assert(DemandedSize <= SmallVTBits && "Narrowed below demanded bits?");
612617
SDValue Z = DAG.getNode(ISD::ANY_EXTEND, dl, VT, X);
613618
return TLO.CombineTo(Op, Z);

llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7420,7 +7420,7 @@ bool AArch64DAGToDAGISel::SelectSMETileSlice(SDValue N, unsigned MaxSize,
74207420
SDValue &Base, SDValue &Offset,
74217421
unsigned Scale) {
74227422
// Try to untangle an ADD node into a 'reg + offset'
7423-
if (N.getOpcode() == ISD::ADD)
7423+
if (CurDAG->isBaseWithConstantOffset(N))
74247424
if (auto C = dyn_cast<ConstantSDNode>(N.getOperand(1))) {
74257425
int64_t ImmOff = C->getSExtValue();
74267426
if ((ImmOff > 0 && ImmOff <= MaxSize && (ImmOff % Scale == 0))) {

llvm/test/CodeGen/AArch64/sme-intrinsics-mova-insert.ll

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,22 @@ exit:
470470
ret void
471471
}
472472

473+
define void @test_add_with_disjoint_or(i64 %idx, <vscale x 4 x i1> %pg) {
474+
; CHECK-LABEL: test_add_with_disjoint_or:
475+
; CHECK: // %bb.0:
476+
; CHECK-NEXT: mov z0.s, #0 // =0x0
477+
; CHECK-NEXT: mov x12, x0
478+
; CHECK-NEXT: mov za0h.s[w12, 0], p0/m, z0.s
479+
; CHECK-NEXT: mov za0h.s[w12, 1], p0/m, z0.s
480+
; CHECK-NEXT: ret
481+
%idx.trunc = trunc i64 %idx to i32
482+
call void @llvm.aarch64.sme.write.horiz.nxv4i32(i32 0, i32 %idx.trunc, <vscale x 4 x i1> %pg, <vscale x 4 x i32> zeroinitializer)
483+
%idx2 = or disjoint i64 %idx, 1
484+
%idx2.trunc = trunc i64 %idx2 to i32
485+
call void @llvm.aarch64.sme.write.horiz.nxv4i32(i32 0, i32 %idx2.trunc, <vscale x 4 x i1> %pg, <vscale x 4 x i32> zeroinitializer)
486+
ret void
487+
}
488+
473489
declare void @llvm.aarch64.sme.write.horiz.nxv16i8(i32, i32, <vscale x 16 x i1>, <vscale x 16 x i8>)
474490
declare void @llvm.aarch64.sme.write.horiz.nxv8i16(i32, i32, <vscale x 8 x i1>, <vscale x 8 x i16>)
475491
declare void @llvm.aarch64.sme.write.horiz.nxv8f16(i32, i32, <vscale x 8 x i1>, <vscale x 8 x half>)

0 commit comments

Comments
 (0)