Skip to content

Reland "[NVPTX] Cleanup/Refactoring in NVPTX AsmPrinter and RegisterInfo (NFC)" #127089

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

AlexMaclean
Copy link
Member

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Feb 13, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

Patch is 39.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/127089.diff

4 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp (+214-302)
  • (modified) llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h (+10-12)
  • (modified) llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp (+11-12)
  • (modified) llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h (+3-3)
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 75d930d9f7b6f..0538b33530470 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -27,6 +27,7 @@
 #include "cl_common_defines.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SmallString.h"
@@ -47,6 +48,7 @@
 #include "llvm/CodeGen/TargetRegisterInfo.h"
 #include "llvm/CodeGen/ValueTypes.h"
 #include "llvm/CodeGenTypes/MachineValueType.h"
+#include "llvm/IR/Argument.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constant.h"
@@ -93,20 +95,19 @@ using namespace llvm;
 
 #define DEPOTNAME "__local_depot"
 
-/// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
+/// discoverDependentGlobals - Return a set of GlobalVariables on which \p V
 /// depends.
 static void
-DiscoverDependentGlobals(const Value *V,
+discoverDependentGlobals(const Value *V,
                          DenseSet<const GlobalVariable *> &Globals) {
-  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
+  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) {
     Globals.insert(GV);
-  else {
-    if (const User *U = dyn_cast<User>(V)) {
-      for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
-        DiscoverDependentGlobals(U->getOperand(i), Globals);
-      }
-    }
+    return;
   }
+
+  if (const User *U = dyn_cast<User>(V))
+    for (const auto &O : U->operands())
+      discoverDependentGlobals(O, Globals);
 }
 
 /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
@@ -127,8 +128,8 @@ VisitGlobalVariableForEmission(const GlobalVariable *GV,
 
   // Make sure we visit all dependents first
   DenseSet<const GlobalVariable *> Others;
-  for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
-    DiscoverDependentGlobals(GV->getOperand(i), Others);
+  for (const auto &O : GV->operands())
+    discoverDependentGlobals(O, Others);
 
   for (const GlobalVariable *GV : Others)
     VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);
@@ -623,9 +624,8 @@ static bool usedInGlobalVarDef(const Constant *C) {
   if (!C)
     return false;
 
-  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
+  if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C))
     return GV->getName() != "llvm.used";
-  }
 
   for (const User *U : C->users())
     if (const Constant *C = dyn_cast<Constant>(U))
@@ -635,25 +635,23 @@ static bool usedInGlobalVarDef(const Constant *C) {
   return false;
 }
 
-static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
-  if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
-    if (othergv->getName() == "llvm.used")
+static bool usedInOneFunc(const User *U, Function const *&OneFunc) {
+  if (const GlobalVariable *OtherGV = dyn_cast<GlobalVariable>(U))
+    if (OtherGV->getName() == "llvm.used")
       return true;
-  }
 
-  if (const Instruction *instr = dyn_cast<Instruction>(U)) {
-    if (instr->getParent() && instr->getParent()->getParent()) {
-      const Function *curFunc = instr->getParent()->getParent();
-      if (oneFunc && (curFunc != oneFunc))
+  if (const Instruction *I = dyn_cast<Instruction>(U)) {
+    if (const Function *CurFunc = I->getFunction()) {
+      if (OneFunc && (CurFunc != OneFunc))
         return false;
-      oneFunc = curFunc;
+      OneFunc = CurFunc;
       return true;
-    } else
-      return false;
+    }
+    return false;
   }
 
   for (const User *UU : U->users())
-    if (!usedInOneFunc(UU, oneFunc))
+    if (!usedInOneFunc(UU, OneFunc))
       return false;
 
   return true;
@@ -666,16 +664,15 @@ static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
  * 2. Does it have local linkage?
  * 3. Is the global variable referenced only in one function?
  */
-static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
-  if (!gv->hasLocalLinkage())
+static bool canDemoteGlobalVar(const GlobalVariable *GV, Function const *&f) {
+  if (!GV->hasLocalLinkage())
     return false;
-  PointerType *Pty = gv->getType();
-  if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
+  if (GV->getAddressSpace() != ADDRESS_SPACE_SHARED)
     return false;
 
   const Function *oneFunc = nullptr;
 
-  bool flag = usedInOneFunc(gv, oneFunc);
+  bool flag = usedInOneFunc(GV, oneFunc);
   if (!flag)
     return false;
   if (!oneFunc)
@@ -685,27 +682,22 @@ static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
 }
 
 static bool useFuncSeen(const Constant *C,
-                        DenseMap<const Function *, bool> &seenMap) {
+                        const SmallPtrSetImpl<const Function *> &SeenSet) {
   for (const User *U : C->users()) {
     if (const Constant *cu = dyn_cast<Constant>(U)) {
-      if (useFuncSeen(cu, seenMap))
+      if (useFuncSeen(cu, SeenSet))
         return true;
     } else if (const Instruction *I = dyn_cast<Instruction>(U)) {
-      const BasicBlock *bb = I->getParent();
-      if (!bb)
-        continue;
-      const Function *caller = bb->getParent();
-      if (!caller)
-        continue;
-      if (seenMap.contains(caller))
-        return true;
+      if (const Function *Caller = I->getFunction())
+        if (SeenSet.contains(Caller))
+          return true;
     }
   }
   return false;
 }
 
 void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
-  DenseMap<const Function *, bool> seenMap;
+  SmallPtrSet<const Function *, 32> SeenSet;
   for (const Function &F : M) {
     if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
       emitDeclaration(&F, O);
@@ -731,7 +723,7 @@ void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
         }
         // Emit a declaration of this function if the function that
         // uses this constant expr has already been seen.
-        if (useFuncSeen(C, seenMap)) {
+        if (useFuncSeen(C, SeenSet)) {
           emitDeclaration(&F, O);
           break;
         }
@@ -739,23 +731,19 @@ void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
 
       if (!isa<Instruction>(U))
         continue;
-      const Instruction *instr = cast<Instruction>(U);
-      const BasicBlock *bb = instr->getParent();
-      if (!bb)
-        continue;
-      const Function *caller = bb->getParent();
-      if (!caller)
+      const Function *Caller = cast<Instruction>(U)->getFunction();
+      if (!Caller)
         continue;
 
       // If a caller has already been seen, then the caller is
       // appearing in the module before the callee. so print out
       // a declaration for the callee.
-      if (seenMap.contains(caller)) {
+      if (SeenSet.contains(Caller)) {
         emitDeclaration(&F, O);
         break;
       }
     }
-    seenMap[&F] = true;
+    SeenSet.insert(&F);
   }
   for (const GlobalAlias &GA : M.aliases())
     emitAliasDeclaration(&GA, O);
@@ -818,7 +806,7 @@ void NVPTXAsmPrinter::emitGlobals(const Module &M) {
 
   // Print out module-level global variables in proper order
   for (const GlobalVariable *GV : Globals)
-    printModuleLevelGV(GV, OS2, /*processDemoted=*/false, STI);
+    printModuleLevelGV(GV, OS2, /*ProcessDemoted=*/false, STI);
 
   OS2 << '\n';
 
@@ -839,16 +827,14 @@ void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) {
 
 void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
                                  const NVPTXSubtarget &STI) {
-  O << "//\n";
-  O << "// Generated by LLVM NVPTX Back-End\n";
-  O << "//\n";
-  O << "\n";
+  const unsigned PTXVersion = STI.getPTXVersion();
 
-  unsigned PTXVersion = STI.getPTXVersion();
-  O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";
-
-  O << ".target ";
-  O << STI.getTargetName();
+  O << "//\n"
+       "// Generated by LLVM NVPTX Back-End\n"
+       "//\n"
+       "\n"
+    << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n"
+    << ".target " << STI.getTargetName();
 
   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
   if (NTM.getDrvInterface() == NVPTX::NVCL)
@@ -871,16 +857,9 @@ void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
   if (HasFullDebugInfo)
     O << ", debug";
 
-  O << "\n";
-
-  O << ".address_size ";
-  if (NTM.is64Bit())
-    O << "64";
-  else
-    O << "32";
-  O << "\n";
-
-  O << "\n";
+  O << "\n"
+    << ".address_size " << (NTM.is64Bit() ? "64" : "32") << "\n"
+    << "\n";
 }
 
 bool NVPTXAsmPrinter::doFinalization(Module &M) {
@@ -928,41 +907,28 @@ void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
                                            raw_ostream &O) {
   if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
     if (V->hasExternalLinkage()) {
-      if (isa<GlobalVariable>(V)) {
-        const GlobalVariable *GVar = cast<GlobalVariable>(V);
-        if (GVar) {
-          if (GVar->hasInitializer())
-            O << ".visible ";
-          else
-            O << ".extern ";
-        }
-      } else if (V->isDeclaration())
+      if (const auto *GVar = dyn_cast<GlobalVariable>(V))
+        O << (GVar->hasInitializer() ? ".visible " : ".extern ");
+      else if (V->isDeclaration())
         O << ".extern ";
       else
         O << ".visible ";
     } else if (V->hasAppendingLinkage()) {
-      std::string msg;
-      msg.append("Error: ");
-      msg.append("Symbol ");
-      if (V->hasName())
-        msg.append(std::string(V->getName()));
-      msg.append("has unsupported appending linkage type");
-      llvm_unreachable(msg.c_str());
-    } else if (!V->hasInternalLinkage() &&
-               !V->hasPrivateLinkage()) {
+      report_fatal_error("Symbol '" + (V->hasName() ? V->getName() : "") +
+                         "' has unsupported appending linkage type");
+    } else if (!V->hasInternalLinkage() && !V->hasPrivateLinkage()) {
       O << ".weak ";
     }
   }
 }
 
 void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
-                                         raw_ostream &O, bool processDemoted,
+                                         raw_ostream &O, bool ProcessDemoted,
                                          const NVPTXSubtarget &STI) {
   // Skip meta data
-  if (GVar->hasSection()) {
+  if (GVar->hasSection())
     if (GVar->getSection() == "llvm.metadata")
       return;
-  }
 
   // Skip LLVM intrinsic global variables
   if (GVar->getName().starts_with("llvm.") ||
@@ -1069,20 +1035,20 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
   }
 
   if (GVar->hasPrivateLinkage()) {
-    if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
+    if (GVar->getName().starts_with("unrollpragma"))
       return;
 
     // FIXME - need better way (e.g. Metadata) to avoid generating this global
-    if (strncmp(GVar->getName().data(), "filename", 8) == 0)
+    if (GVar->getName().starts_with("filename"))
       return;
     if (GVar->use_empty())
       return;
   }
 
-  const Function *demotedFunc = nullptr;
-  if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
+  const Function *DemotedFunc = nullptr;
+  if (!ProcessDemoted && canDemoteGlobalVar(GVar, DemotedFunc)) {
     O << "// " << GVar->getName() << " has been demoted\n";
-    localDecls[demotedFunc].push_back(GVar);
+    localDecls[DemotedFunc].push_back(GVar);
     return;
   }
 
@@ -1090,17 +1056,14 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
   emitPTXAddressSpace(GVar->getAddressSpace(), O);
 
   if (isManaged(*GVar)) {
-    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
+    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30)
       report_fatal_error(
           ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
-    }
     O << " .attribute(.managed)";
   }
 
-  if (MaybeAlign A = GVar->getAlign())
-    O << " .align " << A->value();
-  else
-    O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
+  O << " .align "
+    << GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();
 
   if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
       (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
@@ -1137,8 +1100,6 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
       }
     }
   } else {
-    uint64_t ElementSize = 0;
-
     // Although PTX has direct support for struct type and array type and
     // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
     // targets that support these high level field accesses. Structs, arrays
@@ -1147,8 +1108,8 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
     case Type::IntegerTyID: // Integers larger than 64 bits
     case Type::StructTyID:
     case Type::ArrayTyID:
-    case Type::FixedVectorTyID:
-      ElementSize = DL.getTypeStoreSize(ETy);
+    case Type::FixedVectorTyID: {
+      const uint64_t ElementSize = DL.getTypeStoreSize(ETy);
       // Ptx allows variable initilization only for constant and
       // global state spaces.
       if (((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
@@ -1159,7 +1120,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
           AggBuffer aggBuffer(ElementSize, *this);
           bufferAggregateConstant(Initializer, &aggBuffer);
           if (aggBuffer.numSymbols()) {
-            unsigned int ptrSize = MAI->getCodePointerSize();
+            const unsigned int ptrSize = MAI->getCodePointerSize();
             if (ElementSize % ptrSize ||
                 !aggBuffer.allSymbolsAligned(ptrSize)) {
               // Print in bytes and use the mask() operator for pointers.
@@ -1190,22 +1151,17 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
         } else {
           O << " .b8 ";
           getSymbol(GVar)->print(O, MAI);
-          if (ElementSize) {
-            O << "[";
-            O << ElementSize;
-            O << "]";
-          }
+          if (ElementSize)
+            O << "[" << ElementSize << "]";
         }
       } else {
         O << " .b8 ";
         getSymbol(GVar)->print(O, MAI);
-        if (ElementSize) {
-          O << "[";
-          O << ElementSize;
-          O << "]";
-        }
+        if (ElementSize)
+          O << "[" << ElementSize << "]";
       }
       break;
+    }
     default:
       llvm_unreachable("type not supported yet");
     }
@@ -1229,7 +1185,7 @@ void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
       Name->print(os, AP.MAI);
     }
   } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
-    const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false);
+    const MCExpr *Expr = AP.lowerConstantForGV(CExpr, false);
     AP.printMCExpr(*Expr, os);
   } else
     llvm_unreachable("symbol type unknown");
@@ -1298,18 +1254,18 @@ void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
   }
 }
 
-void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
-  auto It = localDecls.find(f);
+void NVPTXAsmPrinter::emitDemotedVars(const Function *F, raw_ostream &O) {
+  auto It = localDecls.find(F);
   if (It == localDecls.end())
     return;
 
-  std::vector<const GlobalVariable *> &gvars = It->second;
+  ArrayRef<const GlobalVariable *> GVars = It->second;
 
   const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
   const NVPTXSubtarget &STI =
       *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
 
-  for (const GlobalVariable *GV : gvars) {
+  for (const GlobalVariable *GV : GVars) {
     O << "\t// demoted variable\n\t";
     printModuleLevelGV(GV, O, /*processDemoted=*/true, STI);
   }
@@ -1344,13 +1300,11 @@ NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
     unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
     if (NumBits == 1)
       return "pred";
-    else if (NumBits <= 64) {
+    if (NumBits <= 64) {
       std::string name = "u";
       return name + utostr(NumBits);
-    } else {
-      llvm_unreachable("Integer too large");
-      break;
     }
+    llvm_unreachable("Integer too large");
     break;
   }
   case Type::BFloatTyID:
@@ -1393,16 +1347,14 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
   O << ".";
   emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
   if (isManaged(*GVar)) {
-    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
+    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30)
       report_fatal_error(
           ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
-    }
+
     O << " .attribute(.managed)";
   }
-  if (MaybeAlign A = GVar->getAlign())
-    O << " .align " << A->value();
-  else
-    O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
+  O << " .align "
+    << GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value();
 
   // Special case for i128
   if (ETy->isIntegerTy(128)) {
@@ -1413,9 +1365,7 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
   }
 
   if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
-    O << " .";
-    O << getPTXFundamentalTypeStr(ETy);
-    O << " ";
+    O << " ." << getPTXFundamentalTypeStr(ETy) << " ";
     getSymbol(GVar)->print(O, MAI);
     return;
   }
@@ -1446,16 +1396,13 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
 
 void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
   const DataLayout &DL = getDataLayout();
-  const AttributeList &PAL = F->getAttributes();
   const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
   const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
   const NVPTXMachineFunctionInfo *MFI =
       MF ? MF->getInfo<NVPTXMachineFunctionInfo>() : nullptr;
 
-  Function::const_arg_iterator I, E;
-  unsigned paramIndex = 0;
-  bool first = true;
-  bool isKernelFunc = isKernelFunction(*F);
+  bool IsFirst = true;
+  const bool IsKernelFunc = isKernelFunction(*F);
 
   if (F->arg_empty() && !F->isVarArg()) {
     O << "()";
@@ -1464,161 +1411,143 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
 
   O << "(\n";
 
-  for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
-    Type *Ty = I->getType();
+  for (const Argument &Arg : F->args()) {
+    Type *Ty = Arg.getType();
+    const std::string ParamSym = TLI->getParamName(F, Arg.getArgNo());
 
-    if (!first)
+    if (!IsFirst)
       O << ",\n";
 
-    first = false;
+    IsFirst = false;
 
     // Handle image/sampler parameters
-    if (isKernelFunc) {
-      if (isSampler(*I) || isImage(*I)) {
-        std::string ParamSym;
-        raw_string_ostream ParamStr(ParamSym);
-        ParamStr << F->getName() << "_param_" << paramIndex;
-        ParamStr.flush();
-        bool EmitImagePtr = !MFI || !MFI->checkImageHandleSymbol(ParamSym);
-        if (isImage(*I)) {
-          if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
-            if (EmitImagePtr)
-              O << "\t.param .u64 .ptr .surfref ";
-            else
-              O << "\t.param .surfref ";
-            O << TLI->getParamName(F, paramIndex);
-          }
-          else { // Default image is read_only
-            if (EmitImagePtr)
-              O << "\t.param .u64 .ptr .texref ";
-            else
-              O << "\t.param .texref ";
-            O << TLI->getParamName(F, paramIndex);
-          }
-        } else {
-          if (EmitImagePtr)
-            O << "\t.param .u64 .ptr .samplerref ";
-          else
-            O << "\t.param .samplerref ";
-          O << TLI->getParamName(F, paramIndex);
-        }
+    if (IsKernelFunc) {
+      const bool IsSampler = isSampler(Arg);
+      const bool IsTexture = !IsSampler && isImageReadOnly(Arg);
+      const bool IsSurface = !IsSampler && !IsTexture &&
+                             (isImageReadWrite(Arg) || isImageWriteOnly(Arg));
+      if (IsSampler || IsTexture || IsSurface) {
+        const bool EmitImgPtr = !MFI || !MFI->checkImageHandleSymbol(ParamSym);
+        O << "\t.param ";
+        if (EmitImgPtr)
+          O << ".u64 .ptr ";
+
+        if (IsSampler)...
[truncated]

Copy link
Member

@Artem-B Artem-B left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. The patch is good to go.

My rant on getNameOrAsOperand is for that function only and does not need to be addressed here.

@@ -914,7 +914,7 @@ void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
else
O << ".visible ";
} else if (V->hasAppendingLinkage()) {
report_fatal_error("Symbol '" + llvm::Twine(V->getNameOrAsOperand()) +
report_fatal_error("Symbol '" + (V->hasName() ? V->getName() : "") +
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO, it would make sense to make getNameOrAsOperand() available in general, but that's unrelated to your patch.

I think the initial commit that introduced getNameOrAsOperand was somewhat misguided to put that function under NDEBUG. IMO debug/nodebug should not change availability of class members. It's OK to change member function implementation, if necessary. E.g how it's done here:

void assertModuleIsMaterialized() const {
#ifndef NDEBUG
assertModuleIsMaterializedImpl();
#endif
}

It just adds unnecessary headache without buying us anything.

That said, that getNameOrAsOperand is a terrible name. Your code as is makes more sense. Let's keep it.

@AlexMaclean AlexMaclean merged commit ecdfa36 into llvm:main Feb 13, 2025
8 of 10 checks passed
metaflow added a commit that referenced this pull request Feb 14, 2025
…egisterInfo (NFC)" (#127089)"

This reverts commit ecdfa36.

That introduced a breaking change in printer.

E.g. llvm/test/CodeGen/NVPTX/surf-read.ll started to output `.param .samplerref foo_param_0,` instead of `.param .surfref foo_param_0,`.

Looks like we don't have ptxas llvm build bot to catch this sort of NFC.
@metaflow
Copy link
Contributor

metaflow commented Feb 14, 2025

Sorry, I have reverted this as this not a complete NFC.

It broke llvm/test/CodeGen/NVPTX/surf-read.ll and llvm/test/CodeGen/NVPTX/surf-write.ll tests. Diff of llc output is

12c12
<       .param .surfref foo_param_0,
---
>       .param .samplerref foo_param_0,

and ptxas complains

ptxas /tmp/tmpxft_0007f20a_00000000-0_stdin, line 23; error   : Argument 1 of instruction 'suld.b': .surfref or .u64 register expected

@metaflow
Copy link
Contributor

it also shows that we don't seem to have any ptxas build bots

joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
…egisterInfo (NFC)" (llvm#127089)"

This reverts commit ecdfa36.

That introduced a breaking change in printer.

E.g. llvm/test/CodeGen/NVPTX/surf-read.ll started to output `.param .samplerref foo_param_0,` instead of `.param .surfref foo_param_0,`.

Looks like we don't have ptxas llvm build bot to catch this sort of NFC.
AlexMaclean added a commit to AlexMaclean/llvm-project that referenced this pull request Feb 14, 2025
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
…egisterInfo (NFC)" (llvm#127089)"

This reverts commit ecdfa36.

That introduced a breaking change in printer.

E.g. llvm/test/CodeGen/NVPTX/surf-read.ll started to output `.param .samplerref foo_param_0,` instead of `.param .surfref foo_param_0,`.

Looks like we don't have ptxas llvm build bot to catch this sort of NFC.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants