petar-avramovic (Collaborator)

Add IDs for bit widths that cover multiple LLTs: B32, B64, etc.
Add a "Predicate" wrapper class for bool predicate functions, used to
write readable rules; predicates can be combined using &&, || and !.
Add lowering for splitting and widening loads.
Write rules for loads so that existing MIR tests from the old
regbankselect do not change.
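
To make the "combined using &&, || and !" part concrete, here is a minimal
standalone sketch of such a wrapper. It is not the Predicate class added in
this patch (which encodes the formula as a jump table, see the diff below);
it only models the same composition idea, and all names in it are illustrative.

#include <functional>
#include <utility>

// Minimal combinable predicate wrapper (illustrative sketch only).
class Predicate {
  std::function<bool(int)> Pred; // the real class tests a MachineInstr

public:
  explicit Predicate(std::function<bool(int)> P) : Pred(std::move(P)) {}

  bool operator()(int X) const { return Pred(X); }

  Predicate operator&&(const Predicate &RHS) const {
    Predicate LHS = *this;
    return Predicate([LHS, RHS](int X) { return LHS(X) && RHS(X); });
  }
  Predicate operator||(const Predicate &RHS) const {
    Predicate LHS = *this;
    return Predicate([LHS, RHS](int X) { return LHS(X) || RHS(X); });
  }
  Predicate operator!() const {
    Predicate Self = *this;
    return Predicate([Self](int X) { return !Self(X); });
  }
};

int main() {
  Predicate IsSmall([](int X) { return X < 64; });
  Predicate IsEven([](int X) { return X % 2 == 0; });
  Predicate R = IsSmall && !IsEven; // reads like the intended rule syntax
  return R(33) ? 0 : 1;             // 33 is small and odd, so R(33) is true
}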

Lower G_ instructions that cannot be instruction-selected with the register
bank assignment that RBSelect made based on uniformity analysis:
- Lower the instruction so it can be performed on the assigned register bank
- Put a uniform value in a vgpr when no SALU instruction is available
- Execute a divergent instruction on the SALU via a "waterfall loop"

Given the LLTs on all operands after the legalizer, some register bank
assignments require lowering while others do not.
Note: cases where all register bank assignments would require lowering
are handled in the legalizer.
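
For example, a uniform 64-bit G_AND assigned to sgprs needs no lowering
(s_and_b64 exists), while a divergent one assigned to vgprs has no 64-bit
VALU equivalent and must be split. A rough sketch of that split, using the
MachineIRBuilder API but not the exact code from this patch, might look like
this (the function name and the assumption that operands already live in
vgprs are mine):

#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
using namespace llvm;

// Illustrative sketch: lower a divergent 64-bit AND into two 32-bit ANDs.
static void lowerDivergentAndS64(MachineIRBuilder &B, Register Dst,
                                 Register Src0, Register Src1) {
  LLT S32 = LLT::scalar(32);
  auto Src0Parts = B.buildUnmerge(S32, Src0); // split s64 into two s32 halves
  auto Src1Parts = B.buildUnmerge(S32, Src1);
  auto Lo = B.buildAnd(S32, Src0Parts.getReg(0), Src1Parts.getReg(0));
  auto Hi = B.buildAnd(S32, Src0Parts.getReg(1), Src1Parts.getReg(1));
  B.buildMergeLikeInstr(Dst, {Lo.getReg(0), Hi.getReg(0)}); // rebuild the s64
}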

RBLegalize goals:
- Define Rules: when and how to perform lowering.
- The goal of defining Rules is to provide a high-level, table-like
  overview of how to lower generic instructions based on available
  target features and uniformity info (uniform vs divergent).
- Fast search of Rules; speed depends on how complicated Rule.Predicate is.
- For some opcodes there would be too many Rules that are essentially
  all the same, just for different combinations of types and banks.
  For those, write a custom function that handles all cases.
- Rules are made from enum IDs that correspond to each operand.
  The names of the IDs are meant to briefly describe what lowering does
  for each operand or for the whole instruction (see the standalone
  sketch after this list).
- RBLegalizeHelper implements the lowering algorithms and handles all IDs.
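
The standalone sketch below (plain C++, deliberately not the real
RBLegalizeRules/RBLegalizeHelper classes) shows the general shape this
describes: per-operand predicate IDs say what a rule matches, per-operand
apply IDs say what the helper should do, and lookup is a match over the
table. All enum values and names here are illustrative.

#include <optional>
#include <vector>

// Toy IDs; the real enums in AMDGPURBLegalizeRules.h are much richer.
enum PredicateID { UniB64, DivB64 };
enum ApplyID { SgprB64, VgprB64, SplitTo32 };

struct Rule {
  std::vector<PredicateID> Match; // what each operand must look like
  std::vector<ApplyID> Apply;     // what lowering does with each operand
};

// Toy lookup: the first rule whose per-operand predicates all match wins.
inline std::optional<Rule> findRule(const std::vector<Rule> &Rules,
                                    const std::vector<PredicateID> &Operands) {
  for (const Rule &R : Rules)
    if (R.Match == Operands)
      return R;
  return std::nullopt;
}

// Table for a hypothetical two-operand 64-bit opcode: uniform values stay on
// sgprs as-is, divergent values go to vgprs and get split into 32-bit pieces.
const std::vector<Rule> ExampleRules = {
    {{UniB64, UniB64}, {SgprB64, SgprB64}},
    {{DivB64, DivB64}, {SplitTo32, SplitTo32}},
};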

Since this is the first patch that actually enables -new-reg-bank-select,
here is a summary of the regression tests that were added earlier:
- if the instruction is uniform, always select a SALU instruction when available
- eliminate back-to-back vgpr-to-sgpr-to-vgpr copies of uniform values
- fast rules: small differences between standard and vector instructions
- enabling a Rule based on a target feature - salu_float
- how to specify a lowering algorithm - vgpr S64 AND lowered to S32
- for G_TRUNC in reg, it is up to the user to deal with the truncated bits;
  G_TRUNC in reg is treated as a no-op
- dealing with truncated high bits - ABS S16 to S32
- sgpr S1 phi lowering
- new opcodes for vcc-to-scc and scc-to-vcc copies
- lowering for a vgpr S1-to-vcc copy (formally this is a vgpr-to-vcc G_TRUNC)
- S1 zext and sext lowering to select (see the sketch after this list)
- uniform and divergent S1 AND (OR and XOR) lowering - instruction-selected
  into SALU instructions
- divergent phi with uniform inputs
- divergent instruction with a temporal divergent use, where the source
  instruction is defined as uniform (RBSelect) - missing temporal divergence
  lowering
- uniform phi that, because of an undef incoming value, is assigned to vgpr;
  will be fixed in RBSelect via another fix in machine uniformity analysis
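
For the S1 zext/sext bullet above, here is a minimal sketch of the idea (not
the exact code from the earlier patches): extending an s1 value becomes a
select between two constants. The function name and the assumption that the
s1 condition and s32 destination registers are already set up are mine.

#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
using namespace llvm;

// Illustrative sketch: lower zext/sext of an s1 condition to a G_SELECT.
static void lowerS1ExtToSelect(MachineIRBuilder &B, Register Dst,
                               Register Cond, bool IsSigned) {
  LLT S32 = LLT::scalar(32);
  auto True = B.buildConstant(S32, IsSigned ? -1 : 1); // sext -> -1, zext -> 1
  auto False = B.buildConstant(S32, 0);
  B.buildSelect(Dst, Cond, True, False); // Dst = Cond ? True : False
}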

petar-avramovic commented Oct 18, 2024

Warning

This pull request is not mergeable via GitHub because a downstack PR is open. Once all requirements are satisfied, merge this PR as a stack on Graphite.

This stack of pull requests is managed by Graphite.

llvmbot commented Oct 18, 2024

@llvm/pr-subscribers-llvm-globalisel

@llvm/pr-subscribers-backend-amdgpu

Author: Petar Avramovic (petar-avramovic)

Changes

Add IDs for bit widths that cover multiple LLTs: B32, B64, etc.
Add a "Predicate" wrapper class for bool predicate functions, used to
write readable rules; predicates can be combined using &&, || and !.
Add lowering for splitting and widening loads.
Write rules for loads so that existing MIR tests from the old
regbankselect do not change.


Patch is 81.23 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/112865.diff

6 Files Affected:

  • (modified) llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp (+297-5)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h (+4-3)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp (+300-7)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.h (+63-2)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect-load.mir (+271-49)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect-zextload.mir (+7-2)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp
index a0f6ecedab7a83..f58f0a315096d2 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.cpp
@@ -37,6 +37,97 @@ bool RegBankLegalizeHelper::findRuleAndApplyMapping(MachineInstr &MI) {
   return true;
 }
 
+void RegBankLegalizeHelper::splitLoad(MachineInstr &MI,
+                                      ArrayRef<LLT> LLTBreakdown, LLT MergeTy) {
+  MachineFunction &MF = B.getMF();
+  assert(MI.getNumMemOperands() == 1);
+  MachineMemOperand &BaseMMO = **MI.memoperands_begin();
+  Register Dst = MI.getOperand(0).getReg();
+  const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
+  Register BasePtrReg = MI.getOperand(1).getReg();
+  LLT PtrTy = MRI.getType(BasePtrReg);
+  const RegisterBank *PtrRB = MRI.getRegBankOrNull(BasePtrReg);
+  LLT OffsetTy = LLT::scalar(PtrTy.getSizeInBits());
+  SmallVector<Register, 4> LoadPartRegs;
+
+  unsigned ByteOffset = 0;
+  for (LLT PartTy : LLTBreakdown) {
+    Register BasePtrPlusOffsetReg;
+    if (ByteOffset == 0) {
+      BasePtrPlusOffsetReg = BasePtrReg;
+    } else {
+      BasePtrPlusOffsetReg = MRI.createVirtualRegister({PtrRB, PtrTy});
+      Register OffsetReg = MRI.createVirtualRegister({PtrRB, OffsetTy});
+      B.buildConstant(OffsetReg, ByteOffset);
+      B.buildPtrAdd(BasePtrPlusOffsetReg, BasePtrReg, OffsetReg);
+    }
+    MachineMemOperand *BasePtrPlusOffsetMMO =
+        MF.getMachineMemOperand(&BaseMMO, ByteOffset, PartTy);
+    Register PartLoad = MRI.createVirtualRegister({DstRB, PartTy});
+    B.buildLoad(PartLoad, BasePtrPlusOffsetReg, *BasePtrPlusOffsetMMO);
+    LoadPartRegs.push_back(PartLoad);
+    ByteOffset += PartTy.getSizeInBytes();
+  }
+
+  if (!MergeTy.isValid()) {
+    // Loads are of same size, concat or merge them together.
+    B.buildMergeLikeInstr(Dst, LoadPartRegs);
+  } else {
+    // Load(s) are not all of same size, need to unmerge them to smaller pieces
+    // of MergeTy type, then merge them all together in Dst.
+    SmallVector<Register, 4> MergeTyParts;
+    for (Register Reg : LoadPartRegs) {
+      if (MRI.getType(Reg) == MergeTy) {
+        MergeTyParts.push_back(Reg);
+      } else {
+        auto Unmerge = B.buildUnmerge(MergeTy, Reg);
+        for (unsigned i = 0; i < Unmerge->getNumOperands() - 1; ++i) {
+          Register UnmergeReg = Unmerge->getOperand(i).getReg();
+          MRI.setRegBank(UnmergeReg, *DstRB);
+          MergeTyParts.push_back(UnmergeReg);
+        }
+      }
+    }
+    B.buildMergeLikeInstr(Dst, MergeTyParts);
+  }
+  MI.eraseFromParent();
+}
+
+void RegBankLegalizeHelper::widenLoad(MachineInstr &MI, LLT WideTy,
+                                      LLT MergeTy) {
+  MachineFunction &MF = B.getMF();
+  assert(MI.getNumMemOperands() == 1);
+  MachineMemOperand &BaseMMO = **MI.memoperands_begin();
+  Register Dst = MI.getOperand(0).getReg();
+  const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
+  Register BasePtrReg = MI.getOperand(1).getReg();
+
+  Register BasePtrPlusOffsetReg;
+  BasePtrPlusOffsetReg = BasePtrReg;
+
+  MachineMemOperand *BasePtrPlusOffsetMMO =
+      MF.getMachineMemOperand(&BaseMMO, 0, WideTy);
+  Register WideLoad = MRI.createVirtualRegister({DstRB, WideTy});
+  B.buildLoad(WideLoad, BasePtrPlusOffsetReg, *BasePtrPlusOffsetMMO);
+
+  if (WideTy.isScalar()) {
+    B.buildTrunc(Dst, WideLoad);
+  } else {
+    SmallVector<Register, 4> MergeTyParts;
+    unsigned NumEltsMerge =
+        MRI.getType(Dst).getSizeInBits() / MergeTy.getSizeInBits();
+    auto Unmerge = B.buildUnmerge(MergeTy, WideLoad);
+    for (unsigned i = 0; i < Unmerge->getNumOperands() - 1; ++i) {
+      Register UnmergeReg = Unmerge->getOperand(i).getReg();
+      MRI.setRegBank(UnmergeReg, *DstRB);
+      if (i < NumEltsMerge)
+        MergeTyParts.push_back(UnmergeReg);
+    }
+    B.buildMergeLikeInstr(Dst, MergeTyParts);
+  }
+  MI.eraseFromParent();
+}
+
 void RegBankLegalizeHelper::lower(MachineInstr &MI,
                                   const RegBankLLTMapping &Mapping,
                                   SmallSet<Register, 4> &WaterfallSGPRs) {
@@ -119,6 +210,53 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
     MI.eraseFromParent();
     break;
   }
+  case SplitLoad: {
+    LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+    LLT V8S16 = LLT::fixed_vector(8, S16);
+    LLT V4S32 = LLT::fixed_vector(4, S32);
+    LLT V2S64 = LLT::fixed_vector(2, S64);
+
+    if (DstTy == LLT::fixed_vector(8, S32))
+      splitLoad(MI, {V4S32, V4S32});
+    else if (DstTy == LLT::fixed_vector(16, S32))
+      splitLoad(MI, {V4S32, V4S32, V4S32, V4S32});
+    else if (DstTy == LLT::fixed_vector(4, S64))
+      splitLoad(MI, {V2S64, V2S64});
+    else if (DstTy == LLT::fixed_vector(8, S64))
+      splitLoad(MI, {V2S64, V2S64, V2S64, V2S64});
+    else if (DstTy == LLT::fixed_vector(16, S16))
+      splitLoad(MI, {V8S16, V8S16});
+    else if (DstTy == LLT::fixed_vector(32, S16))
+      splitLoad(MI, {V8S16, V8S16, V8S16, V8S16});
+    else if (DstTy == LLT::scalar(256))
+      splitLoad(MI, {LLT::scalar(128), LLT::scalar(128)});
+    else if (DstTy == LLT::scalar(96))
+      splitLoad(MI, {S64, S32}, S32);
+    else if (DstTy == LLT::fixed_vector(3, S32))
+      splitLoad(MI, {LLT::fixed_vector(2, S32), S32}, S32);
+    else if (DstTy == LLT::fixed_vector(6, S16))
+      splitLoad(MI, {LLT::fixed_vector(4, S16), LLT::fixed_vector(2, S16)},
+                LLT::fixed_vector(2, S16));
+    else {
+      MI.dump();
+      llvm_unreachable("SplitLoad type not supported\n");
+    }
+    break;
+  }
+  case WidenLoad: {
+    LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+    if (DstTy == LLT::scalar(96))
+      widenLoad(MI, LLT::scalar(128));
+    else if (DstTy == LLT::fixed_vector(3, S32))
+      widenLoad(MI, LLT::fixed_vector(4, S32), S32);
+    else if (DstTy == LLT::fixed_vector(6, S16))
+      widenLoad(MI, LLT::fixed_vector(8, S16), LLT::fixed_vector(2, S16));
+    else {
+      MI.dump();
+      llvm_unreachable("WidenLoad type not supported\n");
+    }
+    break;
+  }
   }
 
   // TODO: executeInWaterfallLoop(... WaterfallSGPRs)
@@ -142,13 +280,73 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMapingApplyID ID) {
   case Sgpr64:
   case Vgpr64:
     return LLT::scalar(64);
-
+  case SgprP1:
+  case VgprP1:
+    return LLT::pointer(1, 64);
+  case SgprP3:
+  case VgprP3:
+    return LLT::pointer(3, 32);
+  case SgprP4:
+  case VgprP4:
+    return LLT::pointer(4, 64);
+  case SgprP5:
+  case VgprP5:
+    return LLT::pointer(5, 32);
   case SgprV4S32:
   case VgprV4S32:
   case UniInVgprV4S32:
     return LLT::fixed_vector(4, 32);
-  case VgprP1:
-    return LLT::pointer(1, 64);
+  default:
+    return LLT();
+  }
+}
+
+LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMapingApplyID ID, LLT Ty) {
+  switch (ID) {
+  case SgprB32:
+  case VgprB32:
+  case UniInVgprB32:
+    if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
+        Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
+        Ty == LLT::pointer(6, 32))
+      return Ty;
+    return LLT();
+  case SgprB64:
+  case VgprB64:
+  case UniInVgprB64:
+    if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
+        Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
+        Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64))
+      return Ty;
+    return LLT();
+  case SgprB96:
+  case VgprB96:
+  case UniInVgprB96:
+    if (Ty == LLT::scalar(96) || Ty == LLT::fixed_vector(3, 32) ||
+        Ty == LLT::fixed_vector(6, 16))
+      return Ty;
+    return LLT();
+  case SgprB128:
+  case VgprB128:
+  case UniInVgprB128:
+    if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
+        Ty == LLT::fixed_vector(2, 64))
+      return Ty;
+    return LLT();
+  case SgprB256:
+  case VgprB256:
+  case UniInVgprB256:
+    if (Ty == LLT::scalar(256) || Ty == LLT::fixed_vector(8, 32) ||
+        Ty == LLT::fixed_vector(4, 64) || Ty == LLT::fixed_vector(16, 16))
+      return Ty;
+    return LLT();
+  case SgprB512:
+  case VgprB512:
+  case UniInVgprB512:
+    if (Ty == LLT::scalar(512) || Ty == LLT::fixed_vector(16, 32) ||
+        Ty == LLT::fixed_vector(8, 64))
+      return Ty;
+    return LLT();
   default:
     return LLT();
   }
@@ -163,10 +361,26 @@ RegBankLegalizeHelper::getRBFromID(RegBankLLTMapingApplyID ID) {
   case Sgpr16:
   case Sgpr32:
   case Sgpr64:
+  case SgprP1:
+  case SgprP3:
+  case SgprP4:
+  case SgprP5:
   case SgprV4S32:
+  case SgprB32:
+  case SgprB64:
+  case SgprB96:
+  case SgprB128:
+  case SgprB256:
+  case SgprB512:
   case UniInVcc:
   case UniInVgprS32:
   case UniInVgprV4S32:
+  case UniInVgprB32:
+  case UniInVgprB64:
+  case UniInVgprB96:
+  case UniInVgprB128:
+  case UniInVgprB256:
+  case UniInVgprB512:
   case Sgpr32Trunc:
   case Sgpr32AExt:
   case Sgpr32AExtBoolInReg:
@@ -176,7 +390,16 @@ RegBankLegalizeHelper::getRBFromID(RegBankLLTMapingApplyID ID) {
   case Vgpr32:
   case Vgpr64:
   case VgprP1:
+  case VgprP3:
+  case VgprP4:
+  case VgprP5:
   case VgprV4S32:
+  case VgprB32:
+  case VgprB64:
+  case VgprB96:
+  case VgprB128:
+  case VgprB256:
+  case VgprB512:
     return VgprRB;
 
   default:
@@ -202,17 +425,42 @@ void RegBankLegalizeHelper::applyMappingDst(
     case Sgpr16:
     case Sgpr32:
     case Sgpr64:
+    case SgprP1:
+    case SgprP3:
+    case SgprP4:
+    case SgprP5:
     case SgprV4S32:
     case Vgpr32:
     case Vgpr64:
     case VgprP1:
+    case VgprP3:
+    case VgprP4:
+    case VgprP5:
     case VgprV4S32: {
       assert(Ty == getTyFromID(MethodIDs[OpIdx]));
       assert(RB == getRBFromID(MethodIDs[OpIdx]));
       break;
     }
 
-    // uniform in vcc/vgpr: scalars and vectors
+    // sgpr and vgpr B-types
+    case SgprB32:
+    case SgprB64:
+    case SgprB96:
+    case SgprB128:
+    case SgprB256:
+    case SgprB512:
+    case VgprB32:
+    case VgprB64:
+    case VgprB96:
+    case VgprB128:
+    case VgprB256:
+    case VgprB512: {
+      assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
+      assert(RB == getRBFromID(MethodIDs[OpIdx]));
+      break;
+    }
+
+    // uniform in vcc/vgpr: scalars, vectors and B-types
     case UniInVcc: {
       assert(Ty == S1);
       assert(RB == SgprRB);
@@ -229,6 +477,17 @@ void RegBankLegalizeHelper::applyMappingDst(
       AMDGPU::buildReadAnyLaneDst(B, MI, RBI);
       break;
     }
+    case UniInVgprB32:
+    case UniInVgprB64:
+    case UniInVgprB96:
+    case UniInVgprB128:
+    case UniInVgprB256:
+    case UniInVgprB512: {
+      assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
+      assert(RB == SgprRB);
+      AMDGPU::buildReadAnyLaneDst(B, MI, RBI);
+      break;
+    }
 
     // sgpr trunc
     case Sgpr32Trunc: {
@@ -279,16 +538,34 @@ void RegBankLegalizeHelper::applyMappingSrc(
     case Sgpr16:
     case Sgpr32:
     case Sgpr64:
+    case SgprP1:
+    case SgprP3:
+    case SgprP4:
+    case SgprP5:
     case SgprV4S32: {
       assert(Ty == getTyFromID(MethodIDs[i]));
       assert(RB == getRBFromID(MethodIDs[i]));
       break;
     }
+    // sgpr B-types
+    case SgprB32:
+    case SgprB64:
+    case SgprB96:
+    case SgprB128:
+    case SgprB256:
+    case SgprB512: {
+      assert(Ty == getBTyFromID(MethodIDs[i], Ty));
+      assert(RB == getRBFromID(MethodIDs[i]));
+      break;
+    }
 
     // vgpr scalars, pointers and vectors
     case Vgpr32:
     case Vgpr64:
     case VgprP1:
+    case VgprP3:
+    case VgprP4:
+    case VgprP5:
     case VgprV4S32: {
       assert(Ty == getTyFromID(MethodIDs[i]));
       if (RB != VgprRB) {
@@ -298,6 +575,21 @@ void RegBankLegalizeHelper::applyMappingSrc(
       }
       break;
     }
+    // vgpr B-types
+    case VgprB32:
+    case VgprB64:
+    case VgprB96:
+    case VgprB128:
+    case VgprB256:
+    case VgprB512: {
+      assert(Ty == getBTyFromID(MethodIDs[i], Ty));
+      if (RB != VgprRB) {
+        auto CopyToVgpr =
+            B.buildCopy(createVgpr(getBTyFromID(MethodIDs[i], Ty)), Reg);
+        Op.setReg(CopyToVgpr.getReg(0));
+      }
+      break;
+    }
 
     // sgpr and vgpr scalars with extend
     case Sgpr32AExt: {
@@ -372,7 +664,7 @@ void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
   // We accept all types that can fit in some register class.
   // Uniform G_PHIs have all sgpr registers.
   // Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr.
-  if (Ty == LLT::scalar(32)) {
+  if (Ty == LLT::scalar(32) || Ty == LLT::pointer(4, 64)) {
     return;
   }
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h
index 84042205ebf901..843d3d523a4a6a 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeHelper.h
@@ -92,6 +92,7 @@ class RegBankLegalizeHelper {
                               SmallSet<Register, 4> &SGPROperandRegs);
 
   LLT getTyFromID(RegBankLLTMapingApplyID ID);
+  LLT getBTyFromID(RegBankLLTMapingApplyID ID, LLT Ty);
 
   const RegisterBank *getRBFromID(RegBankLLTMapingApplyID ID);
 
@@ -104,9 +105,9 @@ class RegBankLegalizeHelper {
                   const SmallVectorImpl<RegBankLLTMapingApplyID> &MethodIDs,
                   SmallSet<Register, 4> &SGPRWaterfallOperandRegs);
 
-  unsigned setBufferOffsets(MachineIRBuilder &B, Register CombinedOffset,
-                            Register &VOffsetReg, Register &SOffsetReg,
-                            int64_t &InstOffsetVal, Align Alignment);
+  void splitLoad(MachineInstr &MI, ArrayRef<LLT> LLTBreakdown,
+                 LLT MergeTy = LLT());
+  void widenLoad(MachineInstr &MI, LLT WideTy, LLT MergeTy = LLT());
 
   void lower(MachineInstr &MI, const RegBankLLTMapping &Mapping,
              SmallSet<Register, 4> &SGPRWaterfallOperandRegs);
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp
index 1266f99c79c395..895a596cf84f40 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURBLegalizeRules.cpp
@@ -14,9 +14,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "AMDGPURBLegalizeRules.h"
+#include "AMDGPUInstrInfo.h"
 #include "GCNSubtarget.h"
 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
 #include "llvm/IR/IntrinsicsAMDGPU.h"
+#include "llvm/Support/AMDGPUAddrSpace.h"
 
 using namespace llvm;
 using namespace AMDGPU;
@@ -47,6 +49,24 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
     return MRI.getType(Reg) == LLT::scalar(64);
   case P1:
     return MRI.getType(Reg) == LLT::pointer(1, 64);
+  case P3:
+    return MRI.getType(Reg) == LLT::pointer(3, 32);
+  case P4:
+    return MRI.getType(Reg) == LLT::pointer(4, 64);
+  case P5:
+    return MRI.getType(Reg) == LLT::pointer(5, 32);
+  case B32:
+    return MRI.getType(Reg).getSizeInBits() == 32;
+  case B64:
+    return MRI.getType(Reg).getSizeInBits() == 64;
+  case B96:
+    return MRI.getType(Reg).getSizeInBits() == 96;
+  case B128:
+    return MRI.getType(Reg).getSizeInBits() == 128;
+  case B256:
+    return MRI.getType(Reg).getSizeInBits() == 256;
+  case B512:
+    return MRI.getType(Reg).getSizeInBits() == 512;
 
   case UniS1:
     return MRI.getType(Reg) == LLT::scalar(1) && MUI.isUniform(Reg);
@@ -56,6 +76,26 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
     return MRI.getType(Reg) == LLT::scalar(32) && MUI.isUniform(Reg);
   case UniS64:
     return MRI.getType(Reg) == LLT::scalar(64) && MUI.isUniform(Reg);
+  case UniP1:
+    return MRI.getType(Reg) == LLT::pointer(1, 64) && MUI.isUniform(Reg);
+  case UniP3:
+    return MRI.getType(Reg) == LLT::pointer(3, 32) && MUI.isUniform(Reg);
+  case UniP4:
+    return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isUniform(Reg);
+  case UniP5:
+    return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isUniform(Reg);
+  case UniB32:
+    return MRI.getType(Reg).getSizeInBits() == 32 && MUI.isUniform(Reg);
+  case UniB64:
+    return MRI.getType(Reg).getSizeInBits() == 64 && MUI.isUniform(Reg);
+  case UniB96:
+    return MRI.getType(Reg).getSizeInBits() == 96 && MUI.isUniform(Reg);
+  case UniB128:
+    return MRI.getType(Reg).getSizeInBits() == 128 && MUI.isUniform(Reg);
+  case UniB256:
+    return MRI.getType(Reg).getSizeInBits() == 256 && MUI.isUniform(Reg);
+  case UniB512:
+    return MRI.getType(Reg).getSizeInBits() == 512 && MUI.isUniform(Reg);
 
   case DivS1:
     return MRI.getType(Reg) == LLT::scalar(1) && MUI.isDivergent(Reg);
@@ -65,6 +105,24 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
     return MRI.getType(Reg) == LLT::scalar(64) && MUI.isDivergent(Reg);
   case DivP1:
     return MRI.getType(Reg) == LLT::pointer(1, 64) && MUI.isDivergent(Reg);
+  case DivP3:
+    return MRI.getType(Reg) == LLT::pointer(3, 32) && MUI.isDivergent(Reg);
+  case DivP4:
+    return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isDivergent(Reg);
+  case DivP5:
+    return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isDivergent(Reg);
+  case DivB32:
+    return MRI.getType(Reg).getSizeInBits() == 32 && MUI.isDivergent(Reg);
+  case DivB64:
+    return MRI.getType(Reg).getSizeInBits() == 64 && MUI.isDivergent(Reg);
+  case DivB96:
+    return MRI.getType(Reg).getSizeInBits() == 96 && MUI.isDivergent(Reg);
+  case DivB128:
+    return MRI.getType(Reg).getSizeInBits() == 128 && MUI.isDivergent(Reg);
+  case DivB256:
+    return MRI.getType(Reg).getSizeInBits() == 256 && MUI.isDivergent(Reg);
+  case DivB512:
+    return MRI.getType(Reg).getSizeInBits() == 512 && MUI.isDivergent(Reg);
 
   case _:
     return true;
@@ -124,6 +182,22 @@ UniformityLLTOpPredicateID LLTToId(LLT Ty) {
   return _;
 }
 
+UniformityLLTOpPredicateID LLTToBId(LLT Ty) {
+  if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
+      Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
+      Ty == LLT::pointer(6, 32))
+    return B32;
+  if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
+      Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(1, 64) ||
+      Ty == LLT::pointer(4, 64))
+    return B64;
+  if (Ty == LLT::fixed_vector(3, 32))
+    return B96;
+  if (Ty == LLT::fixed_vector(4, 32))
+    return B128;
+  return _;
+}
+
 const RegBankLLTMapping &
 SetOfRulesForOpcode::findMappingForMI(const MachineInstr &MI,
                                       const MachineRegisterInfo &MRI,
@@ -134,7 +208,12 @@ SetOfRulesForOpcode::findMappingForMI(const MachineInstr &MI,
   // returned which results in failure, does not search "Slow Rules".
   if (FastTypes != No) {
     Register Reg = MI.getOperand(0).getReg();
-    int Slot = getFastPredicateSlot(LLTToId(MRI.getType(Reg)));
+    int Slot;
+    if (FastTypes == StandardB)
+      Slot = getFastPredicateSlot(LLTToBId(MRI.getType(Reg)));
+    else
+      Slot = getFastPredicateSlot(LLTToId(MRI.getType(Reg)));
+
     if (Slot != -1) {
       if (MUI.isUniform(Reg))
         return Uni[Slot];
@@ -184,6 +263,19 @@ int SetOfRulesForOpcode::getFastPredicateSlot(
     default:
       return -1;
     }
+  case StandardB:
+    switch (Ty) {
+    case B32:
+      return 0;
+    case B64:
+      return 1;
+    case B96:
+      return 2;
+    case B128:
+      return 3;
+    default:
+      return -1;
+    }
   case Vector:
     switch (Ty) {
     case S32:
@@ -236,6 +328,127 @@ RegBankLegalizeRules::getRulesForOpc(MachineInstr &MI) const {
   return GRules.at(GRulesAlias.at(Opc));
 }
 
+// Syntactic sugar wrapper for predicate lambda that enables '&&', '||' and '!'.
+class Predicate {
+public:
+  struct Elt {
+    // Save formula composed of Pred, '&&', '||' and '!' as a jump table.
+    // Sink ! to Pred. For example !((A && !B) || C) -> (!A || B) && !C
+    // Sequences of && and || will be represented by jumps, for example:
+    // (A && B && ... X) or (A && B && ... X) || Y
+    //   A == true jump to B
+    //   A == false jump to end or Y, result is A(false) or Y
+    // (A || B || ... X) or (A || B || ... X) && Y
+    //   A == true jump to end or Y, result is B(true) or Y
+    //   A == false jump B
+    // Notice that when negating expression, we apply simply flip Neg on each
+    // Pred and swap TJumpOffset and FJumpOffset (&& becomes ||, || becomes &&)....
[truncated]

petar-avramovic force-pushed the users/petar-avramovic/new-rbs-rb-legalize branch from 3bd3785 to fbaf393 on October 18, 2024 at 11:11
petar-avramovic deleted the users/petar-avramovic/new-rbs-rb-load-rules_ branch on October 18, 2024 at 11:16
petar-avramovic (Collaborator, Author)

Ignore this, see #112882
