Skip to content

UNFIXED RELAND "[LoopVectorizer] Add support for chaining partial reductions (#120272)" #124199

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

vitalybuka
Copy link
Collaborator

Reverts #124198

@llvmbot
Copy link
Member

llvmbot commented Jan 23, 2025

@llvm/pr-subscribers-vectorizers

@llvm/pr-subscribers-llvm-transforms

Author: Vitaly Buka (vitalybuka)

Changes

Reverts llvm/llvm-project#124198


Patch is 79.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/124199.diff

4 Files Affected:

  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+42-22)
  • (modified) llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h (+2-2)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+4-1)
  • (added) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll (+1025)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 7167e2179af535..dec7a87ba9c50b 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8682,12 +8682,12 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
 /// are valid so recipes can be formed later.
 void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
   // Find all possible partial reductions.
-  SmallVector<std::pair<PartialReductionChain, unsigned>, 1>
+  SmallVector<std::pair<PartialReductionChain, unsigned>>
       PartialReductionChains;
-  for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
-    if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
-            getScaledReduction(Phi, RdxDesc, Range))
-      PartialReductionChains.push_back(*Pair);
+  for (const auto &[Phi, RdxDesc] : Legal->getReductionVars()) {
+    if (auto SR = getScaledReduction(Phi, RdxDesc.getLoopExitInstr(), Range))
+      PartialReductionChains.append(*SR);
+  }
 
   // A partial reduction is invalid if any of its extends are used by
   // something that isn't another partial reduction. This is because the
@@ -8715,26 +8715,44 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
   }
 }
 
-std::optional<std::pair<PartialReductionChain, unsigned>>
-VPRecipeBuilder::getScaledReduction(PHINode *PHI,
-                                    const RecurrenceDescriptor &Rdx,
+std::optional<SmallVector<std::pair<PartialReductionChain, unsigned>>>
+VPRecipeBuilder::getScaledReduction(Instruction *PHI, Instruction *RdxExitInstr,
                                     VFRange &Range) {
+
+  if (!CM.TheLoop->contains(RdxExitInstr))
+    return std::nullopt;
+
   // TODO: Allow scaling reductions when predicating. The select at
   // the end of the loop chooses between the phi value and most recent
   // reduction result, both of which have different VFs to the active lane
   // mask when scaling.
-  if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr()->getParent()))
+  if (CM.blockNeedsPredicationForAnyReason(RdxExitInstr->getParent()))
     return std::nullopt;
 
-  auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr());
+  auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr);
   if (!Update)
     return std::nullopt;
 
   Value *Op = Update->getOperand(0);
   Value *PhiOp = Update->getOperand(1);
-  if (Op == PHI) {
-    Op = Update->getOperand(1);
-    PhiOp = Update->getOperand(0);
+  if (Op == PHI)
+    std::swap(Op, PhiOp);
+
+  SmallVector<std::pair<PartialReductionChain, unsigned>> Chains;
+
+  // Try and get a scaled reduction from the first non-phi operand.
+  // If one is found, we use the discovered reduction instruction in
+  // place of the accumulator for costing.
+  if (auto *OpInst = dyn_cast<Instruction>(Op)) {
+    if (auto SR0 = getScaledReduction(PHI, OpInst, Range)) {
+      Chains.append(*SR0);
+      PHI = SR0->rbegin()->first.Reduction;
+
+      Op = Update->getOperand(0);
+      PhiOp = Update->getOperand(1);
+      if (Op == PHI)
+        std::swap(Op, PhiOp);
+    }
   }
   if (PhiOp != PHI)
     return std::nullopt;
@@ -8757,7 +8775,7 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
   TTI::PartialReductionExtendKind OpBExtend =
       TargetTransformInfo::getPartialReductionExtendKind(ExtB);
 
-  PartialReductionChain Chain(Rdx.getLoopExitInstr(), ExtA, ExtB, BinOp);
+  PartialReductionChain Chain(RdxExitInstr, ExtA, ExtB, BinOp);
 
   unsigned TargetScaleFactor =
       PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
@@ -8772,9 +8790,9 @@ VPRecipeBuilder::getScaledReduction(PHINode *PHI,
             return Cost.isValid();
           },
           Range))
-    return std::make_pair(Chain, TargetScaleFactor);
+    Chains.push_back(std::make_pair(Chain, TargetScaleFactor));
 
-  return std::nullopt;
+  return Chains;
 }
 
 VPRecipeBase *
@@ -8869,12 +8887,14 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
          "Unexpected number of operands for partial reduction");
 
   VPValue *BinOp = Operands[0];
-  VPValue *Phi = Operands[1];
-  if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
-    std::swap(BinOp, Phi);
-
-  return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi,
-                                      Reduction);
+  VPValue *Accumulator = Operands[1];
+  VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe();
+  if (isa<VPReductionPHIRecipe>(BinOpRecipe) ||
+      isa<VPPartialReductionRecipe>(BinOpRecipe))
+    std::swap(BinOp, Accumulator);
+
+  return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp,
+                                      Accumulator, Reduction);
 }
 
 void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
diff --git a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
index 44745bfd46f891..9b1f40d0560bc2 100644
--- a/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
+++ b/llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h
@@ -142,8 +142,8 @@ class VPRecipeBuilder {
   /// Returns null if no scaled reduction was found, otherwise a pair with a
   /// struct containing reduction information and the scaling factor between the
   /// number of elements in the input and output.
-  std::optional<std::pair<PartialReductionChain, unsigned>>
-  getScaledReduction(PHINode *PHI, const RecurrenceDescriptor &Rdx,
+  std::optional<SmallVector<std::pair<PartialReductionChain, unsigned>>>
+  getScaledReduction(Instruction *PHI, Instruction *RdxExitInstr,
                      VFRange &Range);
 
 public:
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 09da5741d913ab..b52ee3c2428f3f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2455,7 +2455,10 @@ class VPPartialReductionRecipe : public VPSingleDefRecipe {
       : VPSingleDefRecipe(VPDef::VPPartialReductionSC,
                           ArrayRef<VPValue *>({Op0, Op1}), ReductionInst),
         Opcode(Opcode) {
-    assert(isa<VPReductionPHIRecipe>(getOperand(1)->getDefiningRecipe()) &&
+    [[maybe_unused]] auto *AccumulatorRecipe =
+        getOperand(1)->getDefiningRecipe();
+    assert((isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
+            isa<VPPartialReductionRecipe>(AccumulatorRecipe)) &&
            "Unexpected operand order for partial reduction recipe");
   }
   ~VPPartialReductionRecipe() override = default;
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
new file mode 100644
index 00000000000000..bedf8b6b3a9b56
--- /dev/null
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-chained.ll
@@ -0,0 +1,1025 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt --mattr=+neon,+dotprod -passes=loop-vectorize -force-vector-interleave=1 -enable-epilogue-vectorization=false -S < %s | FileCheck %s --check-prefixes=CHECK-NEON
+; RUN: opt --mattr=+sve -passes=loop-vectorize -force-vector-interleave=1 -enable-epilogue-vectorization=false -S < %s | FileCheck %s --check-prefixes=CHECK-SVE
+; RUN: opt --mattr=+sve -vectorizer-maximize-bandwidth -passes=loop-vectorize -force-vector-interleave=1 -enable-epilogue-vectorization=false -S < %s | FileCheck %s --check-prefixes=CHECK-SVE-MAXBW
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-unknown-elf"
+
+; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: read) vscale_range(1,16)
+define i32 @chained_partial_reduce_add_sub(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
+; CHECK-NEON-LABEL: define i32 @chained_partial_reduce_add_sub(
+; CHECK-NEON-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]], i32 [[N:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEON-NEXT:  entry:
+; CHECK-NEON-NEXT:    [[CMP28_NOT:%.*]] = icmp ult i32 [[N]], 2
+; CHECK-NEON-NEXT:    [[DIV27:%.*]] = lshr i32 [[N]], 1
+; CHECK-NEON-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext nneg i32 [[DIV27]] to i64
+; CHECK-NEON-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[WIDE_TRIP_COUNT]], 16
+; CHECK-NEON-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK-NEON:       vector.ph:
+; CHECK-NEON-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], 16
+; CHECK-NEON-NEXT:    [[N_VEC:%.*]] = sub i64 [[WIDE_TRIP_COUNT]], [[N_MOD_VF]]
+; CHECK-NEON-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK-NEON:       vector.body:
+; CHECK-NEON-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEON-NEXT:    [[VEC_PHI:%.*]] = phi <16 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP13:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEON-NEXT:    [[TMP0:%.*]] = add i64 [[INDEX]], 0
+; CHECK-NEON-NEXT:    [[TMP1:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[TMP0]]
+; CHECK-NEON-NEXT:    [[TMP2:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[TMP0]]
+; CHECK-NEON-NEXT:    [[TMP3:%.*]] = getelementptr inbounds nuw i8, ptr [[C]], i64 [[TMP0]]
+; CHECK-NEON-NEXT:    [[TMP4:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP1]], i32 0
+; CHECK-NEON-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP4]], align 1
+; CHECK-NEON-NEXT:    [[TMP5:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP2]], i32 0
+; CHECK-NEON-NEXT:    [[WIDE_LOAD1:%.*]] = load <16 x i8>, ptr [[TMP5]], align 1
+; CHECK-NEON-NEXT:    [[TMP6:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP3]], i32 0
+; CHECK-NEON-NEXT:    [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP6]], align 1
+; CHECK-NEON-NEXT:    [[TMP7:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
+; CHECK-NEON-NEXT:    [[TMP8:%.*]] = sext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
+; CHECK-NEON-NEXT:    [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-NEON-NEXT:    [[TMP10:%.*]] = mul nsw <16 x i32> [[TMP7]], [[TMP8]]
+; CHECK-NEON-NEXT:    [[TMP11:%.*]] = add <16 x i32> [[VEC_PHI]], [[TMP10]]
+; CHECK-NEON-NEXT:    [[TMP12:%.*]] = mul nsw <16 x i32> [[TMP7]], [[TMP9]]
+; CHECK-NEON-NEXT:    [[TMP13]] = sub <16 x i32> [[TMP11]], [[TMP12]]
+; CHECK-NEON-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
+; CHECK-NEON-NEXT:    [[TMP14:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NEON-NEXT:    br i1 [[TMP14]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK-NEON:       middle.block:
+; CHECK-NEON-NEXT:    [[TMP15:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP13]])
+; CHECK-NEON-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
+; CHECK-NEON-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
+;
+; CHECK-SVE-LABEL: define i32 @chained_partial_reduce_add_sub(
+; CHECK-SVE-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]], i32 [[N:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-SVE-NEXT:  entry:
+; CHECK-SVE-NEXT:    [[CMP28_NOT:%.*]] = icmp ult i32 [[N]], 2
+; CHECK-SVE-NEXT:    [[DIV27:%.*]] = lshr i32 [[N]], 1
+; CHECK-SVE-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext nneg i32 [[DIV27]] to i64
+; CHECK-SVE-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-SVE-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 4
+; CHECK-SVE-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[WIDE_TRIP_COUNT]], [[TMP1]]
+; CHECK-SVE-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK-SVE:       vector.ph:
+; CHECK-SVE-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-SVE-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 4
+; CHECK-SVE-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], [[TMP3]]
+; CHECK-SVE-NEXT:    [[N_VEC:%.*]] = sub i64 [[WIDE_TRIP_COUNT]], [[N_MOD_VF]]
+; CHECK-SVE-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-SVE-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 4
+; CHECK-SVE-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK-SVE:       vector.body:
+; CHECK-SVE-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-SVE-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP19:%.*]], [[VECTOR_BODY]] ]
+; CHECK-SVE-NEXT:    [[TMP6:%.*]] = add i64 [[INDEX]], 0
+; CHECK-SVE-NEXT:    [[TMP7:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[TMP6]]
+; CHECK-SVE-NEXT:    [[TMP8:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[TMP6]]
+; CHECK-SVE-NEXT:    [[TMP9:%.*]] = getelementptr inbounds nuw i8, ptr [[C]], i64 [[TMP6]]
+; CHECK-SVE-NEXT:    [[TMP10:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP7]], i32 0
+; CHECK-SVE-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i8>, ptr [[TMP10]], align 1
+; CHECK-SVE-NEXT:    [[TMP11:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP8]], i32 0
+; CHECK-SVE-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 4 x i8>, ptr [[TMP11]], align 1
+; CHECK-SVE-NEXT:    [[TMP12:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP9]], i32 0
+; CHECK-SVE-NEXT:    [[WIDE_LOAD2:%.*]] = load <vscale x 4 x i8>, ptr [[TMP12]], align 1
+; CHECK-SVE-NEXT:    [[TMP13:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD]] to <vscale x 4 x i32>
+; CHECK-SVE-NEXT:    [[TMP14:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD1]] to <vscale x 4 x i32>
+; CHECK-SVE-NEXT:    [[TMP15:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD2]] to <vscale x 4 x i32>
+; CHECK-SVE-NEXT:    [[TMP16:%.*]] = mul nsw <vscale x 4 x i32> [[TMP13]], [[TMP14]]
+; CHECK-SVE-NEXT:    [[TMP17:%.*]] = add <vscale x 4 x i32> [[VEC_PHI]], [[TMP16]]
+; CHECK-SVE-NEXT:    [[TMP18:%.*]] = mul nsw <vscale x 4 x i32> [[TMP13]], [[TMP15]]
+; CHECK-SVE-NEXT:    [[TMP19]] = sub <vscale x 4 x i32> [[TMP17]], [[TMP18]]
+; CHECK-SVE-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
+; CHECK-SVE-NEXT:    [[TMP20:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-SVE-NEXT:    br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK-SVE:       middle.block:
+; CHECK-SVE-NEXT:    [[TMP21:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[TMP19]])
+; CHECK-SVE-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
+; CHECK-SVE-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
+;
+; CHECK-SVE-MAXBW-LABEL: define i32 @chained_partial_reduce_add_sub(
+; CHECK-SVE-MAXBW-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]], i32 [[N:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-SVE-MAXBW-NEXT:  entry:
+; CHECK-SVE-MAXBW-NEXT:    [[CMP28_NOT:%.*]] = icmp ult i32 [[N]], 2
+; CHECK-SVE-MAXBW-NEXT:    [[DIV27:%.*]] = lshr i32 [[N]], 1
+; CHECK-SVE-MAXBW-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext nneg i32 [[DIV27]] to i64
+; CHECK-SVE-MAXBW-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-SVE-MAXBW-NEXT:    [[TMP1:%.*]] = mul i64 [[TMP0]], 8
+; CHECK-SVE-MAXBW-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[WIDE_TRIP_COUNT]], [[TMP1]]
+; CHECK-SVE-MAXBW-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK-SVE-MAXBW:       vector.ph:
+; CHECK-SVE-MAXBW-NEXT:    [[TMP2:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-SVE-MAXBW-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP2]], 8
+; CHECK-SVE-MAXBW-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], [[TMP3]]
+; CHECK-SVE-MAXBW-NEXT:    [[N_VEC:%.*]] = sub i64 [[WIDE_TRIP_COUNT]], [[N_MOD_VF]]
+; CHECK-SVE-MAXBW-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-SVE-MAXBW-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 8
+; CHECK-SVE-MAXBW-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK-SVE-MAXBW:       vector.body:
+; CHECK-SVE-MAXBW-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-SVE-MAXBW-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 8 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP19:%.*]], [[VECTOR_BODY]] ]
+; CHECK-SVE-MAXBW-NEXT:    [[TMP6:%.*]] = add i64 [[INDEX]], 0
+; CHECK-SVE-MAXBW-NEXT:    [[TMP7:%.*]] = getelementptr inbounds nuw i8, ptr [[A]], i64 [[TMP6]]
+; CHECK-SVE-MAXBW-NEXT:    [[TMP8:%.*]] = getelementptr inbounds nuw i8, ptr [[B]], i64 [[TMP6]]
+; CHECK-SVE-MAXBW-NEXT:    [[TMP9:%.*]] = getelementptr inbounds nuw i8, ptr [[C]], i64 [[TMP6]]
+; CHECK-SVE-MAXBW-NEXT:    [[TMP10:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP7]], i32 0
+; CHECK-SVE-MAXBW-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 8 x i8>, ptr [[TMP10]], align 1
+; CHECK-SVE-MAXBW-NEXT:    [[TMP11:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP8]], i32 0
+; CHECK-SVE-MAXBW-NEXT:    [[WIDE_LOAD1:%.*]] = load <vscale x 8 x i8>, ptr [[TMP11]], align 1
+; CHECK-SVE-MAXBW-NEXT:    [[TMP12:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP9]], i32 0
+; CHECK-SVE-MAXBW-NEXT:    [[WIDE_LOAD2:%.*]] = load <vscale x 8 x i8>, ptr [[TMP12]], align 1
+; CHECK-SVE-MAXBW-NEXT:    [[TMP13:%.*]] = sext <vscale x 8 x i8> [[WIDE_LOAD]] to <vscale x 8 x i32>
+; CHECK-SVE-MAXBW-NEXT:    [[TMP14:%.*]] = sext <vscale x 8 x i8> [[WIDE_LOAD1]] to <vscale x 8 x i32>
+; CHECK-SVE-MAXBW-NEXT:    [[TMP15:%.*]] = sext <vscale x 8 x i8> [[WIDE_LOAD2]] to <vscale x 8 x i32>
+; CHECK-SVE-MAXBW-NEXT:    [[TMP16:%.*]] = mul nsw <vscale x 8 x i32> [[TMP13]], [[TMP14]]
+; CHECK-SVE-MAXBW-NEXT:    [[TMP17:%.*]] = add <vscale x 8 x i32> [[VEC_PHI]], [[TMP16]]
+; CHECK-SVE-MAXBW-NEXT:    [[TMP18:%.*]] = mul nsw <vscale x 8 x i32> [[TMP13]], [[TMP15]]
+; CHECK-SVE-MAXBW-NEXT:    [[TMP19]] = sub <vscale x 8 x i32> [[TMP17]], [[TMP18]]
+; CHECK-SVE-MAXBW-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
+; CHECK-SVE-MAXBW-NEXT:    [[TMP20:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-SVE-MAXBW-NEXT:    br i1 [[TMP20]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK-SVE-MAXBW:       middle.block:
+; CHECK-SVE-MAXBW-NEXT:    [[TMP21:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP19]])
+; CHECK-SVE-MAXBW-NEXT:    [[CMP_N:%.*]] = icmp eq i64 [[WIDE_TRIP_COUNT]], [[N_VEC]]
+; CHECK-SVE-MAXBW-NEXT:    br i1 [[CMP_N]], label [[FOR_COND_CLEANUP:%.*]], label [[SCALAR_PH]]
+;
+entry:
+  %cmp28.not = icmp ult i32 %N, 2
+  %div27 = lshr i32 %N, 1
+  %wide.trip.count = zext nneg i32 %div27 to i64
+  br label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.cond.cleanup.loopexit, %entry
+  %res.0.lcssa = phi i32 [ %sub, %for.body ]
+  ret i32 %res.0.lcssa
+
+for.body:                                         ; preds = %for.body.preheader, %for.body
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %res = phi i32 [ 0, %entry ], [ %sub, %for.body ]
+  %a.ptr = getelementptr inbounds nuw i8, ptr %a, i64 %indvars.iv
+  %b.ptr = getelementptr inbounds nuw i8, ptr %b, i64 %indvars.iv
+  %c.ptr = getelementptr inbounds nuw i8, ptr %c, i64 %indvars.iv
+  %a.val = load i8, ptr %a.ptr, align 1
+  %b.val = load i8, ptr %b.ptr, align 1
+  %c.val = load i8, ptr %c.ptr, align 1
+  %a.ext = sext i8 %a.val to i32
+  %b.ext = sext i8 %b.val to i32
+  %c.ext = sext i8 %c.val to i32
+  %mul.ab = mul nsw i32 %a.ext, %b.ext
+  %add = add nsw i32 %res, %mul.ab
+  %mul.ac = mul nsw i32 %a.ext, %c.ext
+  %sub = sub i32 %add, %mul.ac
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, %wide.trip.count
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body, !loop !1
+}
+
+define i32 @chained_partial_reduce_add_add(ptr %a, ptr %b, ptr %c, i32 %N) #0 {
+; CHECK-NEON-LABEL: define i32 @chained_partial_reduce_add_add(
+; CHECK-NEON-SAME: ptr [[A:%.*]], ptr [[B:%.*]], ptr [[C:%.*]], i32 [[N:%.*]]) #[[ATTR0]] {
+; CHECK-NEON-NEXT:  entry:
+; CHECK-NEON-NEXT:    [[CMP28_NOT:%.*]] = icmp ult i32 [[N]], 2
+; CHECK-NEON-NEXT:    [[DIV27:%.*]] = lshr i32 [[N]], 1
+; CHECK-NEON-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext nneg i32 [[DIV27]] to i64
+; CHECK-NEON-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[WIDE_TRIP_COUNT]], 16
+; CHECK-NEON-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
+; CHECK-NEON:       vector.ph:
+; CHECK-NEON-NEXT:    [[N_MOD_VF:%.*]] = urem i64 [[WIDE_TRIP_COUNT]], 16
+; CHECK-NEON-NEXT:    [[N_VEC:%.*]] = sub i64 [[WIDE_TRIP_COUNT]], [[N_MOD_VF]]
+; CHECK-NEON-NEXT:    br label [...
[truncated]

@vitalybuka
Copy link
Collaborator Author

Just and additional heads up about revert.

@vitalybuka vitalybuka marked this pull request as draft January 23, 2025 22:07
@vitalybuka vitalybuka closed this Jan 27, 2025
Copy link
Collaborator Author

relanded NickGuy-Arm@7bd2019

@vitalybuka vitalybuka deleted the revert-124198-users/vitalybuka/spr/revert-loopvectorizer-add-support-for-chaining-partial-reductions-120272 branch January 27, 2025 00:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants