
Commit c17adab

[AArch64][SVE] Instcombine ptrue(all) to splat(i1)
SVE operations such as predicated loads are already canonicalized to generic LLVM masked loads; canonicalizing ptrue(all) to splat(i1 true) in the same spirit exposes further optimization opportunities to generic LLVM IR passes.
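For illustration, a minimal IR sketch of the rewrite (a hypothetical function, not one of the patch's tests; the pattern operand i32 31 is AArch64SVEPredPattern::all):

declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32)

define <vscale x 16 x i1> @all_active() {
  ; before instcombine: an explicit SVE ptrue intrinsic with the "all" pattern
  %pg = call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 31)
  ret <vscale x 16 x i1> %pg
}

; after instcombine, uses of %pg are replaced by a target-independent constant,
; so the body reduces to:
;   ret <vscale x 16 x i1> splat (i1 true)

Because splat (i1 true) is plain IR rather than a target intrinsic, generic passes can reason about the predicate, for example when it ends up as the mask of a masked load.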
1 parent d54c28b

11 files changed, +353 −530 lines

clang/test/CodeGen/AArch64/sve-acle-__ARM_FEATURE_SVE_VECTOR_OPERATORS.c

Lines changed: 2 additions & 4 deletions
@@ -52,19 +52,17 @@ vec2048 x2048 = {0, 1, 2, 3, 3 , 2 , 1, 0, 0, 1, 2, 3, 3 , 2 , 1, 0,
 typedef int8_t vec_int8 __attribute__((vector_size(N / 8)));
 // CHECK128-LABEL: define{{.*}} <16 x i8> @f2(<16 x i8> noundef %x)
 // CHECK128-NEXT: entry:
-// CHECK128-NEXT: [[TMP0:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 31)
 // CHECK128-NEXT: [[CASTSCALABLESVE:%.*]] = tail call <vscale x 16 x i8> @llvm.vector.insert.nxv16i8.v16i8(<vscale x 16 x i8> poison, <16 x i8> [[X:%.*]], i64 0)
-// CHECK128-NEXT: [[TMP1:%.*]] = tail call <vscale x 16 x i8> @llvm.aarch64.sve.asrd.nxv16i8(<vscale x 16 x i1> [[TMP0]], <vscale x 16 x i8> [[CASTSCALABLESVE]], i32 1)
+// CHECK128-NEXT: [[TMP1:%.*]] = tail call <vscale x 16 x i8> @llvm.aarch64.sve.asrd.nxv16i8(<vscale x 16 x i1> splat (i1 true), <vscale x 16 x i8> [[CASTSCALABLESVE]], i32 1)
 // CHECK128-NEXT: [[CASTFIXEDSVE:%.*]] = tail call <16 x i8> @llvm.vector.extract.v16i8.nxv16i8(<vscale x 16 x i8> [[TMP1]], i64 0)
 // CHECK128-NEXT: ret <16 x i8> [[CASTFIXEDSVE]]
 
 // CHECK-LABEL: define{{.*}} void @f2(
 // CHECK-SAME: ptr dead_on_unwind noalias writable writeonly sret(<[[#div(VBITS,8)]] x i8>) align 16 captures(none) initializes((0, [[#div(VBITS,8)]])) %agg.result, ptr noundef readonly captures(none) %0)
 // CHECK-NEXT: entry:
 // CHECK-NEXT: [[X:%.*]] = load <[[#div(VBITS,8)]] x i8>, ptr [[TMP0:%.*]], align 16, [[TBAA6:!tbaa !.*]]
-// CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 31)
 // CHECK-NEXT: [[CASTSCALABLESVE:%.*]] = tail call <vscale x 16 x i8> @llvm.vector.insert.nxv16i8.v[[#div(VBITS,8)]]i8(<vscale x 16 x i8> poison, <[[#div(VBITS,8)]] x i8> [[X]], i64 0)
-// CHECK-NEXT: [[TMP2:%.*]] = tail call <vscale x 16 x i8> @llvm.aarch64.sve.asrd.nxv16i8(<vscale x 16 x i1> [[TMP1]], <vscale x 16 x i8> [[CASTSCALABLESVE]], i32 1)
+// CHECK-NEXT: [[TMP2:%.*]] = tail call <vscale x 16 x i8> @llvm.aarch64.sve.asrd.nxv16i8(<vscale x 16 x i1> splat (i1 true), <vscale x 16 x i8> [[CASTSCALABLESVE]], i32 1)
 // CHECK-NEXT: [[CASTFIXEDSVE:%.*]] = tail call <[[#div(VBITS,8)]] x i8> @llvm.vector.extract.v[[#div(VBITS,8)]]i8.nxv16i8(<vscale x 16 x i8> [[TMP2]], i64 0)
 // CHECK-NEXT: store <[[#div(VBITS,8)]] x i8> [[CASTFIXEDSVE]], ptr [[AGG_RESULT:%.*]], align 16, [[TBAA6]]
 // CHECK-NEXT: ret void

clang/test/CodeGen/AArch64/sve-intrinsics/acle_sve_rdffr.c

Lines changed: 2 additions & 4 deletions
@@ -7,14 +7,12 @@
 
 // CHECK-LABEL: @test_svrdffr(
 // CHECK-NEXT: entry:
-// CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 31)
-// CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.rdffr.z(<vscale x 16 x i1> [[TMP0]])
+// CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.rdffr.z(<vscale x 16 x i1> splat (i1 true))
 // CHECK-NEXT: ret <vscale x 16 x i1> [[TMP1]]
 //
 // CPP-CHECK-LABEL: @_Z12test_svrdffrv(
 // CPP-CHECK-NEXT: entry:
-// CPP-CHECK-NEXT: [[TMP0:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 31)
-// CPP-CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.rdffr.z(<vscale x 16 x i1> [[TMP0]])
+// CPP-CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.rdffr.z(<vscale x 16 x i1> splat (i1 true))
 // CPP-CHECK-NEXT: ret <vscale x 16 x i1> [[TMP1]]
 //
 svbool_t test_svrdffr()

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 28 additions & 12 deletions
@@ -1492,9 +1492,17 @@ static bool isAllActivePredicate(Value *Pred) {
   if (cast<ScalableVectorType>(Pred->getType())->getMinNumElements() <=
       cast<ScalableVectorType>(UncastedPred->getType())->getMinNumElements())
     Pred = UncastedPred;
+  if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
+                      m_ConstantInt<AArch64SVEPredPattern::all>())))
+    return true;
+
+  if (Value *Splat = getSplatValue(Pred)) {
+    auto *ConstIdx = dyn_cast<ConstantInt>(Splat);
+    if (ConstIdx && ConstIdx->getZExtValue() == 1)
+      return true;
+  }
 
-  return match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
-                         m_ConstantInt<AArch64SVEPredPattern::all>()));
+  return false;
 }
 
 // Use SVE intrinsic info to eliminate redundant operands and/or canonicalise
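The practical effect of the widened check is that folds which previously demanded a literal ptrue(all) now also fire on the canonical splat form. A hedged sketch (hypothetical test, not from this patch): instCombineSVEVectorBinOp should now be able to rewrite the predicated fmul below into a plain IR fmul, because its predicate is recognized as all-active:

declare <vscale x 4 x float> @llvm.aarch64.sve.fmul.nxv4f32(<vscale x 4 x i1>, <vscale x 4 x float>, <vscale x 4 x float>)

define <vscale x 4 x float> @fmul_all_active(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
  ; predicate in the canonical splat form rather than a ptrue call
  %r = call <vscale x 4 x float> @llvm.aarch64.sve.fmul.nxv4f32(<vscale x 4 x i1> splat (i1 true), <vscale x 4 x float> %a, <vscale x 4 x float> %b)
  ret <vscale x 4 x float> %r
}

; expected after instcombine (sketch):
;   %r = fmul <vscale x 4 x float> %a, %b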
@@ -1701,14 +1709,7 @@ static std::optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
                                                         IntrinsicInst &II) {
   LLVMContext &Ctx = II.getContext();
 
-  // Check that the predicate is all active
-  auto *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
-  if (!Pg || Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
-    return std::nullopt;
-
-  const auto PTruePattern =
-      cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
-  if (PTruePattern != AArch64SVEPredPattern::all)
+  if (!isAllActivePredicate(II.getArgOperand(0)))
     return std::nullopt;
 
   // Check that we have a compare of zero..
@@ -2118,8 +2119,7 @@ instCombineSVEVectorBinOp(InstCombiner &IC, IntrinsicInst &II) {
   auto *OpPredicate = II.getOperand(0);
   auto BinOpCode = intrinsicIDToBinOpCode(II.getIntrinsicID());
   if (BinOpCode == Instruction::BinaryOpsEnd ||
-      !match(OpPredicate, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
-                              m_ConstantInt<AArch64SVEPredPattern::all>())))
+      !isAllActivePredicate(OpPredicate))
     return std::nullopt;
   auto BinOp = IC.Builder.CreateBinOpFMF(
       BinOpCode, II.getOperand(1), II.getOperand(2), II.getFastMathFlags());
@@ -2641,6 +2641,20 @@ static std::optional<Instruction *> instCombineDMB(InstCombiner &IC,
   return std::nullopt;
 }
 
+static std::optional<Instruction *> instCombinePTrue(InstCombiner &IC,
+                                                     IntrinsicInst &II) {
+  IRBuilder<> Builder(&II);
+  auto Type = cast<VectorType>(II.getType());
+  ConstantInt *Pattern;
+  if (match(II.getOperand(0), m_ConstantInt(Pattern)) &&
+      Pattern->getZExtValue() == AArch64SVEPredPattern::all) {
+    Value *One = ConstantInt::get(Builder.getInt1Ty(), APInt(1, 1));
+    Value *SplatOne = Builder.CreateVectorSplat(Type->getElementCount(), One);
+    return IC.replaceInstUsesWith(II, SplatOne);
+  }
+  return std::nullopt;
+}
+
 std::optional<Instruction *>
 AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
                                      IntrinsicInst &II) const {
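Note the new combine is deliberately narrow: only the all pattern becomes a splat; other ptrue patterns yield partially-active predicates and are left untouched. A hedged sketch of an input that is not rewritten (hypothetical test; i32 9 is assumed to be the SV_VL16 pattern):

declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32)

define <vscale x 16 x i1> @ptrue_vl16() {
  ; a fixed-length pattern is not all-active, so instCombinePTrue returns
  ; std::nullopt and the intrinsic call is kept
  %pg = call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 9)
  ret <vscale x 16 x i1> %pg
}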
@@ -2744,6 +2758,8 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
     return instCombineSVEDupqLane(IC, II);
   case Intrinsic::aarch64_sve_insr:
     return instCombineSVEInsr(IC, II);
+  case Intrinsic::aarch64_sve_ptrue:
+    return instCombinePTrue(IC, II);
   }
 
   return std::nullopt;

llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-abs-srshl.ll

Lines changed: 1 addition & 2 deletions
@@ -42,8 +42,7 @@ define <vscale x 8 x i16> @srshl_abs_positive_merge(<vscale x 8 x i16> %a, <vsca
 
 define <vscale x 8 x i16> @srshl_abs_all_active_pred(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b, <vscale x 8 x i1> %pg2) #0 {
 ; CHECK-LABEL: @srshl_abs_all_active_pred(
-; CHECK-NEXT: [[PG:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
-; CHECK-NEXT: [[ABS:%.*]] = tail call <vscale x 8 x i16> @llvm.aarch64.sve.abs.nxv8i16(<vscale x 8 x i16> [[B:%.*]], <vscale x 8 x i1> [[PG]], <vscale x 8 x i16> [[A:%.*]])
+; CHECK-NEXT: [[ABS:%.*]] = tail call <vscale x 8 x i16> @llvm.aarch64.sve.abs.nxv8i16(<vscale x 8 x i16> [[B:%.*]], <vscale x 8 x i1> splat (i1 true), <vscale x 8 x i16> [[A:%.*]])
 ; CHECK-NEXT: [[TMP1:%.*]] = call <vscale x 8 x i16> @llvm.aarch64.sve.lsl.nxv8i16(<vscale x 8 x i1> [[PG2:%.*]], <vscale x 8 x i16> [[ABS]], <vscale x 8 x i16> splat (i16 2))
 ; CHECK-NEXT: ret <vscale x 8 x i16> [[TMP1]]
 ;
