Skip to content

[AArch64][LoopIdiom] Generalize AArch64LoopIdiomTransform into LoopIdiomVectorize #94081

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 7, 2024

Conversation

mshockwave
Copy link
Member

@mshockwave mshockwave commented May 31, 2024

To facilitate sharing LoopIdiomTransform between AArch64 and RISC-V, this first patch moves AArch64LoopIdiomTransform from lib/Target/AArch64 to lib/Transforms/Vectorize and renames it to LoopIdiomVectorize. The next patch (#94082) will teach LoopIdiomVectorize how to generate VP intrinsics (in addition to the current masked vector style) in favor of RVV. The key component that dictates the vectorization style, of which SVE and RVV differ, is factored out in this patch as well.

@llvmbot
Copy link
Member

llvmbot commented May 31, 2024

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-backend-aarch64

Author: Min-Yih Hsu (mshockwave)

Changes

To facilitate sharing LoopIdiomTransform between AArch64 and RISC-V, this first patch moves AArch64LoopIdiomTransform from lib/Target/AArch64 to lib/Transforms/Vectorize. The next patch will teach LoopIdiomTransform how to generate VP intrinsics (in addition to the current masked vector style) in favor of RVV. The key component that dictates the vectorization style, of which SVE and RVV differ, is factored out in this patch as well.


Patch is 98.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/94081.diff

11 Files Affected:

  • (renamed) llvm/include/llvm/Transforms/Vectorize/LoopIdiomTransform.h (+5-9)
  • (modified) llvm/lib/Passes/PassBuilder.cpp (+1)
  • (modified) llvm/lib/Passes/PassRegistry.def (+1)
  • (modified) llvm/lib/Target/AArch64/AArch64.h (-1)
  • (removed) llvm/lib/Target/AArch64/AArch64PassRegistry.def (-20)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetMachine.cpp (+2-6)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetMachine.h (-1)
  • (modified) llvm/lib/Target/AArch64/CMakeLists.txt (+1-1)
  • (modified) llvm/lib/Transforms/Vectorize/CMakeLists.txt (+1)
  • (renamed) llvm/lib/Transforms/Vectorize/LoopIdiomTransform.cpp (+190-253)
  • (modified) llvm/test/Transforms/LoopIdiom/AArch64/byte-compare-index.ll (+201-204)
diff --git a/llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.h b/llvm/include/llvm/Transforms/Vectorize/LoopIdiomTransform.h
similarity index 60%
rename from llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.h
rename to llvm/include/llvm/Transforms/Vectorize/LoopIdiomTransform.h
index cc68425bb68b5..a97dcc7ae3a3f 100644
--- a/llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.h
+++ b/llvm/include/llvm/Transforms/Vectorize/LoopIdiomTransform.h
@@ -1,4 +1,4 @@
-//===- AArch64LoopIdiomTransform.h --------------------------------------===//
+//===----------LoopIdiomTransform.h -----------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,20 +6,16 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef LLVM_LIB_TARGET_AARCH64_AARCH64LOOPIDIOMTRANSFORM_H
-#define LLVM_LIB_TARGET_AARCH64_AARCH64LOOPIDIOMTRANSFORM_H
+#ifndef LLVM_LIB_TRANSFORMS_VECTORIZE_LOOPIDIOMTRANSFORM_H
+#define LLVM_LIB_TRANSFORMS_VECTORIZE_LOOPIDIOMTRANSFORM_H
 
 #include "llvm/IR/PassManager.h"
 #include "llvm/Transforms/Scalar/LoopPassManager.h"
 
 namespace llvm {
-
-struct AArch64LoopIdiomTransformPass
-    : PassInfoMixin<AArch64LoopIdiomTransformPass> {
+struct LoopIdiomTransformPass : PassInfoMixin<LoopIdiomTransformPass> {
   PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM,
                         LoopStandardAnalysisResults &AR, LPMUpdater &U);
 };
-
 } // namespace llvm
-
-#endif // LLVM_LIB_TARGET_AARCH64_AARCH64LOOPIDIOMTRANSFORM_H
+#endif // LLVM_LIB_TRANSFORMS_VECTORIZE_LOOPIDIOMTRANSFORM_H
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 734ca4d5deec9..bf11146a05e5a 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -295,6 +295,7 @@
 #include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
 #include "llvm/Transforms/Utils/UnifyLoopExits.h"
 #include "llvm/Transforms/Vectorize/LoadStoreVectorizer.h"
+#include "llvm/Transforms/Vectorize/LoopIdiomTransform.h"
 #include "llvm/Transforms/Vectorize/LoopVectorize.h"
 #include "llvm/Transforms/Vectorize/SLPVectorizer.h"
 #include "llvm/Transforms/Vectorize/VectorCombine.h"
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 50682ca4970f1..714058f91bfc6 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -621,6 +621,7 @@ LOOP_PASS("invalidate<all>", InvalidateAllAnalysesPass())
 LOOP_PASS("loop-bound-split", LoopBoundSplitPass())
 LOOP_PASS("loop-deletion", LoopDeletionPass())
 LOOP_PASS("loop-idiom", LoopIdiomRecognizePass())
+LOOP_PASS("loop-idiom-transform", LoopIdiomTransformPass())
 LOOP_PASS("loop-instsimplify", LoopInstSimplifyPass())
 LOOP_PASS("loop-predication", LoopPredicationPass())
 LOOP_PASS("loop-reduce", LoopStrengthReducePass())
diff --git a/llvm/lib/Target/AArch64/AArch64.h b/llvm/lib/Target/AArch64/AArch64.h
index b70fbe42fe5fc..19e0d1e2f5960 100644
--- a/llvm/lib/Target/AArch64/AArch64.h
+++ b/llvm/lib/Target/AArch64/AArch64.h
@@ -90,7 +90,6 @@ void initializeAArch64DeadRegisterDefinitionsPass(PassRegistry&);
 void initializeAArch64ExpandPseudoPass(PassRegistry &);
 void initializeAArch64GlobalsTaggingPass(PassRegistry &);
 void initializeAArch64LoadStoreOptPass(PassRegistry&);
-void initializeAArch64LoopIdiomTransformLegacyPassPass(PassRegistry &);
 void initializeAArch64LowerHomogeneousPrologEpilogPass(PassRegistry &);
 void initializeAArch64MIPeepholeOptPass(PassRegistry &);
 void initializeAArch64O0PreLegalizerCombinerPass(PassRegistry &);
diff --git a/llvm/lib/Target/AArch64/AArch64PassRegistry.def b/llvm/lib/Target/AArch64/AArch64PassRegistry.def
deleted file mode 100644
index ca944579f93a9..0000000000000
--- a/llvm/lib/Target/AArch64/AArch64PassRegistry.def
+++ /dev/null
@@ -1,20 +0,0 @@
-//===- AArch64PassRegistry.def - Registry of AArch64 passes -----*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file is used as the registry of passes that are part of the
-// AArch64 backend.
-//
-//===----------------------------------------------------------------------===//
-
-// NOTE: NO INCLUDE GUARD DESIRED!
-
-#ifndef LOOP_PASS
-#define LOOP_PASS(NAME, CREATE_PASS)
-#endif
-LOOP_PASS("aarch64-lit", AArch64LoopIdiomTransformPass())
-#undef LOOP_PASS
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
index 945ab5cf1f303..a6e26501541f3 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
@@ -11,7 +11,6 @@
 
 #include "AArch64TargetMachine.h"
 #include "AArch64.h"
-#include "AArch64LoopIdiomTransform.h"
 #include "AArch64MachineFunctionInfo.h"
 #include "AArch64MachineScheduler.h"
 #include "AArch64MacroFusion.h"
@@ -52,6 +51,7 @@
 #include "llvm/TargetParser/Triple.h"
 #include "llvm/Transforms/CFGuard.h"
 #include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Vectorize/LoopIdiomTransform.h"
 #include <memory>
 #include <optional>
 #include <string>
@@ -234,7 +234,6 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeAArch64Target() {
   initializeAArch64DeadRegisterDefinitionsPass(*PR);
   initializeAArch64ExpandPseudoPass(*PR);
   initializeAArch64LoadStoreOptPass(*PR);
-  initializeAArch64LoopIdiomTransformLegacyPassPass(*PR);
   initializeAArch64MIPeepholeOptPass(*PR);
   initializeAArch64SIMDInstrOptPass(*PR);
   initializeAArch64O0PreLegalizerCombinerPass(*PR);
@@ -553,12 +552,9 @@ class AArch64PassConfig : public TargetPassConfig {
 void AArch64TargetMachine::registerPassBuilderCallbacks(
     PassBuilder &PB, bool PopulateClassToPassNames) {
 
-#define GET_PASS_REGISTRY "AArch64PassRegistry.def"
-#include "llvm/Passes/TargetPassRegistry.inc"
-
   PB.registerLateLoopOptimizationsEPCallback(
       [=](LoopPassManager &LPM, OptimizationLevel Level) {
-        LPM.addPass(AArch64LoopIdiomTransformPass());
+        LPM.addPass(LoopIdiomTransformPass());
       });
 }
 
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.h b/llvm/lib/Target/AArch64/AArch64TargetMachine.h
index 8fb68b06f1378..e396d9204716a 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.h
@@ -14,7 +14,6 @@
 #define LLVM_LIB_TARGET_AARCH64_AARCH64TARGETMACHINE_H
 
 #include "AArch64InstrInfo.h"
-#include "AArch64LoopIdiomTransform.h"
 #include "AArch64Subtarget.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/Target/TargetMachine.h"
diff --git a/llvm/lib/Target/AArch64/CMakeLists.txt b/llvm/lib/Target/AArch64/CMakeLists.txt
index 8e76f6c9279e7..639bc0707dff2 100644
--- a/llvm/lib/Target/AArch64/CMakeLists.txt
+++ b/llvm/lib/Target/AArch64/CMakeLists.txt
@@ -65,7 +65,6 @@ add_llvm_target(AArch64CodeGen
   AArch64ISelLowering.cpp
   AArch64InstrInfo.cpp
   AArch64LoadStoreOptimizer.cpp
-  AArch64LoopIdiomTransform.cpp
   AArch64LowerHomogeneousPrologEpilog.cpp
   AArch64MachineFunctionInfo.cpp
   AArch64MachineScheduler.cpp
@@ -112,6 +111,7 @@ add_llvm_target(AArch64CodeGen
   Target
   TargetParser
   TransformUtils
+  Vectorize
 
   ADD_TO_COMPONENT
   AArch64
diff --git a/llvm/lib/Transforms/Vectorize/CMakeLists.txt b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
index 9674094024b9e..3ca5c404d020f 100644
--- a/llvm/lib/Transforms/Vectorize/CMakeLists.txt
+++ b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_llvm_component_library(LLVMVectorize
   LoadStoreVectorizer.cpp
+  LoopIdiomTransform.cpp
   LoopVectorizationLegality.cpp
   LoopVectorize.cpp
   SLPVectorizer.cpp
diff --git a/llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.cpp b/llvm/lib/Transforms/Vectorize/LoopIdiomTransform.cpp
similarity index 71%
rename from llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.cpp
rename to llvm/lib/Transforms/Vectorize/LoopIdiomTransform.cpp
index a9bd8d877fb2e..5af1d6aa3b61e 100644
--- a/llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopIdiomTransform.cpp
@@ -1,4 +1,4 @@
-//===- AArch64LoopIdiomTransform.cpp - Loop idiom recognition -------------===//
+//===-------- LoopIdiomTransform.cpp - Loop idiom recognition -------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -35,7 +35,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "AArch64LoopIdiomTransform.h"
+#include "llvm/Transforms/Vectorize/LoopIdiomTransform.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Analysis/DomTreeUpdater.h"
 #include "llvm/Analysis/LoopPass.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
@@ -44,47 +45,46 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/PatternMatch.h"
-#include "llvm/InitializePasses.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 
 using namespace llvm;
 using namespace PatternMatch;
 
-#define DEBUG_TYPE "aarch64-loop-idiom-transform"
+#define DEBUG_TYPE "loop-idiom-transform"
 
-static cl::opt<bool>
-    DisableAll("disable-aarch64-lit-all", cl::Hidden, cl::init(false),
-               cl::desc("Disable AArch64 Loop Idiom Transform Pass."));
-
-static cl::opt<bool> DisableByteCmp(
-    "disable-aarch64-lit-bytecmp", cl::Hidden, cl::init(false),
-    cl::desc("Proceed with AArch64 Loop Idiom Transform Pass, but do "
-             "not convert byte-compare loop(s)."));
-
-static cl::opt<bool> VerifyLoops(
-    "aarch64-lit-verify", cl::Hidden, cl::init(false),
-    cl::desc("Verify loops generated AArch64 Loop Idiom Transform Pass."));
+static cl::opt<bool> DisableAll("disable-loop-idiom-transform-all", cl::Hidden,
+                                cl::init(false),
+                                cl::desc("Disable Loop Idiom Transform Pass."));
 
-namespace llvm {
-
-void initializeAArch64LoopIdiomTransformLegacyPassPass(PassRegistry &);
-Pass *createAArch64LoopIdiomTransformPass();
+static cl::opt<bool>
+    DisableByteCmp("disable-loop-idiom-transform-bytecmp", cl::Hidden,
+                   cl::init(false),
+                   cl::desc("Proceed with Loop Idiom Transform Pass, but do "
+                            "not convert byte-compare loop(s)."));
 
-} // end namespace llvm
+static cl::opt<bool>
+    VerifyLoops("verify-loop-idiom-transform", cl::Hidden, cl::init(false),
+                cl::desc("Verify loops generated Loop Idiom Transform Pass."));
 
 namespace {
-
-class AArch64LoopIdiomTransform {
+class LoopIdiomTransform {
   Loop *CurLoop = nullptr;
   DominatorTree *DT;
   LoopInfo *LI;
   const TargetTransformInfo *TTI;
   const DataLayout *DL;
 
+  // Blocks that will be used for inserting vectorized code.
+  BasicBlock *EndBlock = nullptr;
+  BasicBlock *VectorLoopPreheaderBlock = nullptr;
+  BasicBlock *VectorLoopStartBlock = nullptr;
+  BasicBlock *VectorLoopMismatchBlock = nullptr;
+  BasicBlock *VectorLoopIncBlock = nullptr;
+
 public:
-  explicit AArch64LoopIdiomTransform(DominatorTree *DT, LoopInfo *LI,
-                                     const TargetTransformInfo *TTI,
-                                     const DataLayout *DL)
+  explicit LoopIdiomTransform(DominatorTree *DT, LoopInfo *LI,
+                              const TargetTransformInfo *TTI,
+                              const DataLayout *DL)
       : DT(DT), LI(LI), TTI(TTI), DL(DL) {}
 
   bool run(Loop *L);
@@ -98,83 +98,32 @@ class AArch64LoopIdiomTransform {
                       SmallVectorImpl<BasicBlock *> &ExitBlocks);
 
   bool recognizeByteCompare();
+
   Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
                             GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
                             Instruction *Index, Value *Start, Value *MaxLen);
+
+  Value *createMaskedFindMismatch(IRBuilder<> &Builder, GetElementPtrInst *GEPA,
+                                  GetElementPtrInst *GEPB, Value *ExtStart,
+                                  Value *ExtEnd);
+
   void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
                             PHINode *IndPhi, Value *MaxLen, Instruction *Index,
                             Value *Start, bool IncIdx, BasicBlock *FoundBB,
                             BasicBlock *EndBB);
   /// @}
 };
+} // anonymous namespace
 
-class AArch64LoopIdiomTransformLegacyPass : public LoopPass {
-public:
-  static char ID;
-
-  explicit AArch64LoopIdiomTransformLegacyPass() : LoopPass(ID) {
-    initializeAArch64LoopIdiomTransformLegacyPassPass(
-        *PassRegistry::getPassRegistry());
-  }
-
-  StringRef getPassName() const override {
-    return "Transform AArch64-specific loop idioms";
-  }
-
-  void getAnalysisUsage(AnalysisUsage &AU) const override {
-    AU.addRequired<LoopInfoWrapperPass>();
-    AU.addRequired<DominatorTreeWrapperPass>();
-    AU.addRequired<TargetTransformInfoWrapperPass>();
-  }
-
-  bool runOnLoop(Loop *L, LPPassManager &LPM) override;
-};
-
-bool AArch64LoopIdiomTransformLegacyPass::runOnLoop(Loop *L,
-                                                    LPPassManager &LPM) {
-
-  if (skipLoop(L))
-    return false;
-
-  auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
-  auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
-  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
-      *L->getHeader()->getParent());
-  return AArch64LoopIdiomTransform(
-             DT, LI, &TTI, &L->getHeader()->getModule()->getDataLayout())
-      .run(L);
-}
-
-} // end anonymous namespace
-
-char AArch64LoopIdiomTransformLegacyPass::ID = 0;
-
-INITIALIZE_PASS_BEGIN(
-    AArch64LoopIdiomTransformLegacyPass, "aarch64-lit",
-    "Transform specific loop idioms into optimized vector forms", false, false)
-INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
-INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
-INITIALIZE_PASS_END(
-    AArch64LoopIdiomTransformLegacyPass, "aarch64-lit",
-    "Transform specific loop idioms into optimized vector forms", false, false)
-
-Pass *llvm::createAArch64LoopIdiomTransformPass() {
-  return new AArch64LoopIdiomTransformLegacyPass();
-}
-
-PreservedAnalyses
-AArch64LoopIdiomTransformPass::run(Loop &L, LoopAnalysisManager &AM,
-                                   LoopStandardAnalysisResults &AR,
-                                   LPMUpdater &) {
+PreservedAnalyses LoopIdiomTransformPass::run(Loop &L, LoopAnalysisManager &AM,
+                                              LoopStandardAnalysisResults &AR,
+                                              LPMUpdater &) {
   if (DisableAll)
     return PreservedAnalyses::all();
 
   const auto *DL = &L.getHeader()->getModule()->getDataLayout();
 
-  AArch64LoopIdiomTransform LIT(&AR.DT, &AR.LI, &AR.TTI, DL);
+  LoopIdiomTransform LIT(&AR.DT, &AR.LI, &AR.TTI, DL);
   if (!LIT.run(&L))
     return PreservedAnalyses::all();
 
@@ -183,11 +132,11 @@ AArch64LoopIdiomTransformPass::run(Loop &L, LoopAnalysisManager &AM,
 
 //===----------------------------------------------------------------------===//
 //
-//          Implementation of AArch64LoopIdiomTransform
+//          Implementation of LoopIdiomTransform
 //
 //===----------------------------------------------------------------------===//
 
-bool AArch64LoopIdiomTransform::run(Loop *L) {
+bool LoopIdiomTransform::run(Loop *L) {
   CurLoop = L;
 
   Function &F = *L->getHeader()->getParent();
@@ -211,7 +160,7 @@ bool AArch64LoopIdiomTransform::run(Loop *L) {
   return recognizeByteCompare();
 }
 
-bool AArch64LoopIdiomTransform::recognizeByteCompare() {
+bool LoopIdiomTransform::recognizeByteCompare() {
   // Currently the transformation only works on scalable vector types, although
   // there is no fundamental reason why it cannot be made to work for fixed
   // width too.
@@ -224,7 +173,7 @@ bool AArch64LoopIdiomTransform::recognizeByteCompare() {
 
   BasicBlock *Header = CurLoop->getHeader();
 
-  // In AArch64LoopIdiomTransform::run we have already checked that the loop
+  // In LoopIdiomTransform::run we have already checked that the loop
   // has a preheader so we can assume it's in a canonical form.
   if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
     return false;
@@ -242,8 +191,7 @@ bool AArch64LoopIdiomTransform::recognizeByteCompare() {
   //   %cmp.not = icmp eq i32 %inc, %n
   //   br i1 %cmp.not, label %while.end, label %while.body
   //
-  auto CondBBInsts = LoopBlocks[0]->instructionsWithoutDebug();
-  if (std::distance(CondBBInsts.begin(), CondBBInsts.end()) > 4)
+  if (LoopBlocks[0]->sizeWithoutDebug() > 4)
     return false;
 
   // The second block should contain 7 instructions, e.g.
@@ -257,8 +205,7 @@ bool AArch64LoopIdiomTransform::recognizeByteCompare() {
   //   %cmp.not.ld = icmp eq i8 %load.a, %load.b
   //   br i1 %cmp.not.ld, label %while.cond, label %while.end
   //
-  auto LoopBBInsts = LoopBlocks[1]->instructionsWithoutDebug();
-  if (std::distance(LoopBBInsts.begin(), LoopBBInsts.end()) > 7)
+  if (LoopBlocks[1]->sizeWithoutDebug() > 7)
     return false;
 
   // The incoming value to the PHI node from the loop should be an add of 1.
@@ -393,7 +340,109 @@ bool AArch64LoopIdiomTransform::recognizeByteCompare() {
   return true;
 }
 
-Value *AArch64LoopIdiomTransform::expandFindMismatch(
+Value *LoopIdiomTransform::createMaskedFindMismatch(IRBuilder<> &Builder,
+                                                    GetElementPtrInst *GEPA,
+                                                    GetElementPtrInst *GEPB,
+                                                    Value *ExtStart,
+                                                    Value *ExtEnd) {
+  Type *I64Type = Builder.getInt64Ty();
+  Type *ResType = Builder.getInt32Ty();
+  Type *LoadType = Builder.getInt8Ty();
+  Value *PtrA = GEPA->getPointerOperand();
+  Value *PtrB = GEPB->getPointerOperand();
+
+  // At this point we know two things must be true:
+  //  1. Start <= End
+  //  2. ExtMaxLen <= MinPageSize due to the page checks.
+  // Therefore, we know that we can use a 64-bit induction variable that
+  // starts from 0 -> ExtMaxLen and it will not overflow.
+  ScalableVectorType *PredVTy =
+      ScalableVectorType::get(Builder.getInt1Ty(), 16);
+
+  Value *InitialPred = Builder.CreateIntrinsic(
+      Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
+
+  Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
+  VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "",
+                             /*HasNUW=*/true, /*HasNSW=*/true);
+
+  Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
+                                            Builder.getInt1(false));
+
+  BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
+  Builder.Insert(JumpToVectorLoop);
+
+  // Set up the first vector loop block by creating the PHIs, doing the vector
+  // loads and comparing the vectors.
+  Builder.SetInsertPoint(VectorLoopStartBlock);
+  PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred");
+  LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
+  PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
+  VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
+  Type *VectorLoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16);
+  Value *Passthru = ConstantInt::getNullValue(VectorLoadType);
+
+  Value *VectorLhsGep = Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi);
+  if (GEPA->isInBounds())
+    cast<GetElementPtrInst>(VectorLhsGep)->setIsInBounds(true);
+  Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep,
+                                                  Align(1), LoopPred, Passthru);
+
+  Value *VectorRhsGep = Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi);
+  if (GEPB->isInBounds())
+    cast<GetElemen...
[truncated]

Copy link

github-actions bot commented May 31, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff dbc3e26c25587e5460ae12caed84cb09197c4ed7 267750f58422e248c92f8d21870cd5086c6151f2 -- llvm/lib/Passes/PassBuilder.cpp llvm/lib/Target/AArch64/AArch64.h llvm/lib/Target/AArch64/AArch64TargetMachine.cpp llvm/lib/Target/AArch64/AArch64TargetMachine.h llvm/include/llvm/Transforms/Vectorize/LoopIdiomVectorize.h llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/Target/AArch64/AArch64.h b/llvm/lib/Target/AArch64/AArch64.h
index 6f2aeb83a4..8c6be3e312 100644
--- a/llvm/lib/Target/AArch64/AArch64.h
+++ b/llvm/lib/Target/AArch64/AArch64.h
@@ -89,7 +89,7 @@ void initializeAArch64DAGToDAGISelLegacyPass(PassRegistry &);
 void initializeAArch64DeadRegisterDefinitionsPass(PassRegistry&);
 void initializeAArch64ExpandPseudoPass(PassRegistry &);
 void initializeAArch64GlobalsTaggingPass(PassRegistry &);
-void initializeAArch64LoadStoreOptPass(PassRegistry&);
+void initializeAArch64LoadStoreOptPass(PassRegistry &);
 void initializeAArch64LowerHomogeneousPrologEpilogPass(PassRegistry &);
 void initializeAArch64MIPeepholeOptPass(PassRegistry &);
 void initializeAArch64O0PreLegalizerCombinerPass(PassRegistry &);

@nikic
Copy link
Contributor

nikic commented Jun 1, 2024

Is it possible to make the pass name more meaningful? Otherwise we end up with two passes called LoopIdiomRecognize and LoopIdiomTransform (loop-idiom and loop-idiom-transform), and it's not at all obvious how they differ. If I understand correctly, the relevant difference is that the latter is targeted at vector idioms?

(Alternatively, if this is no longer AArch64 specific, can it be merged into the general LoopIdiom pass? But that's probably a more involved change.)

@preames
Copy link
Collaborator

preames commented Jun 3, 2024

Can we please split this into two changes, one which does the move and rename, and a second which adds the new functionality? This diff is slightly confusing as a reader.

Naming wise, maybe LoopIdiomVectorizer? And update the comment to be really explicit about the goal being to eventually kill this off as the vectorizer grows more powerful?

@mshockwave
Copy link
Member Author

mshockwave commented Jun 3, 2024

Can we please split this into two changes, one which does the move and rename, and a second which adds the new functionality? This diff is slightly confusing as a reader.

I don't think I add any new functionalities in this patch, the support for RVV-style vectorization is in #94082 . Or are you referring to factoring out the part that will differ between SVE and RVV? That change is NFC.

Naming wise, maybe LoopIdiomVectorizer? And update the comment to be really explicit about the goal being to eventually kill this off as the vectorizer grows more powerful?

I second for LoopIdiomVectorizer.

@preames
Copy link
Collaborator

preames commented Jun 4, 2024

Can we please split this into two changes, one which does the move and rename, and a second which adds the new functionality? This diff is slightly confusing as a reader.

I don't think I add any new functionalities in this patch, the support for RVV-style vectorization is in #94082 . Or are you referring to factoring out the part that will differ between SVE and RVV? That change is NFC.

Apologies, it looks like I misread the patch description.

@mshockwave
Copy link
Member Author

(Alternatively, if this is no longer AArch64 specific, can it be merged into the general LoopIdiom pass? But that's probably a more involved change.)

Sorry I missed this question: I tried this but doesn't feel like what LoopIdiom is doing now (pattern matching and turn into libcalls) is somewhat related but not directed related to what this Pass does (pattern matching and turn into vector instructions). And yes, it's a more involved change. So I'm more inclined to not do the merge, at least for now.

@mshockwave mshockwave changed the title [AArch64][LoopIdiom] Generalize AArch64LoopIdiomTransform into LoopIdiomTransform [AArch64][LoopIdiom] Generalize AArch64LoopIdiomTransform into LoopIdiomVectorize Jun 4, 2024
@mshockwave
Copy link
Member Author

I've renamed the Pass to LoopIdiomVectorize and add a statement in the file header saying that our ultimately goal is to do this in LV instead.

@nikic
Copy link
Contributor

nikic commented Jun 5, 2024

I'm happy with the new naming, thanks! It looks like this PR needs a rebase due to some conflicts.

I'd generally recommend to split the move and rename off from the other code changes you do here. I assume that those are all NFC refactorings, but they look non-trivial enough that they probably shouldn't be part of the code move.

@@ -742,6 +682,11 @@ void AArch64LoopIdiomTransform::transformByteCompare(
DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc());

// Safeguard to check if we build the correct DomTree with DTU.
auto CheckDTU = llvm::make_scope_exit([&]() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be a separate patch?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. These DTU tree changes are now reverted.

…iomTransform

To facilitate sharing LoopIdiomTransform between AArch64 and RISC-V,
this patch first moves AArch64LoopIdiomTransform from lib/Target/AArch64
to lib/Transforms/Vectorize. In addition, key component that is subject
to differ from RVV's vectorization style is factored out preemptively
in this patch.
@mshockwave mshockwave force-pushed the patch/aarch64-lit-generalize branch from 5552be5 to 267750f Compare June 5, 2024 23:50
@mshockwave
Copy link
Member Author

I have now reverted all the intrusive refactoring. I'm planning to create two additional PRs:

  1. Factoring out the part that differs between masked v.s. VP vectorization. Will be NFC.
  2. I found that some of the DominatorTreeUpdates are redundant. I will put such simplifications in a separate PR. But it's still effectively a NFC.


class AArch64LoopIdiomTransformLegacyPass : public LoopPass {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just want to note that the legacy Pass was not used anywhere. I think we simply forgot to remove it.

@david-arm david-arm requested a review from kmclaughlin-arm June 6, 2024 08:18
Copy link
Collaborator

@preames preames left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@mshockwave
Copy link
Member Author

I'll wait for another day or so to see if there are any other comments, as I saw a reviewer from ARM has being tagged.

I have now reverted all the intrusive refactoring. I'm planning to create two additional PRs:

  1. Factoring out the part that differs between masked v.s. VP vectorization. Will be NFC.

This is now in #94682

  1. I found that some of the DominatorTreeUpdates are redundant. I will put such simplifications in a separate PR. But it's still effectively a NFC.

This is now in #94681

@mshockwave mshockwave merged commit 37e309f into llvm:main Jun 7, 2024
6 of 7 checks passed
@mshockwave mshockwave deleted the patch/aarch64-lit-generalize branch June 7, 2024 21:06
keith added a commit to keith/llvm-project that referenced this pull request Jun 7, 2024
keith added a commit that referenced this pull request Jun 7, 2024
nekoshirro pushed a commit to nekoshirro/Alchemist-LLVM that referenced this pull request Jun 9, 2024
…iomVectorize (llvm#94081)

To facilitate sharing LoopIdiomTransform between AArch64 and RISC-V,
this first patch moves AArch64LoopIdiomTransform from lib/Target/AArch64
to lib/Transforms/Vectorize and renames it to LoopIdiomVectorize. The
following patch (llvm#94082) will teach LoopIdiomVectorize how to generate VP
intrinsics (in addition to the current masked vector style) in favor of
RVV.

Signed-off-by: Hafidz Muzakky <[email protected]>
nekoshirro pushed a commit to nekoshirro/Alchemist-LLVM that referenced this pull request Jun 9, 2024
@HerrCai0907 HerrCai0907 mentioned this pull request Jun 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants