From 8d47c2d577e0066a85c3b1f681e7d1ad7b3f5dff Mon Sep 17 00:00:00 2001 From: Alex MacLean Date: Wed, 12 Feb 2025 00:23:36 -0800 Subject: [PATCH 1/2] [NVPTX] Cleanup/Refactoring in NVPTX AsmPrinter and RegisterInfo (NFC) (#126800) --- llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 516 ++++++++------------ llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h | 22 +- llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp | 23 +- llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h | 6 +- 4 files changed, 238 insertions(+), 329 deletions(-) diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp index 75d930d9f7b6f..5b60151c14cc4 100644 --- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp @@ -27,6 +27,7 @@ #include "cl_common_defines.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SmallString.h" @@ -47,6 +48,7 @@ #include "llvm/CodeGen/TargetRegisterInfo.h" #include "llvm/CodeGen/ValueTypes.h" #include "llvm/CodeGenTypes/MachineValueType.h" +#include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" @@ -93,20 +95,19 @@ using namespace llvm; #define DEPOTNAME "__local_depot" -/// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V +/// discoverDependentGlobals - Return a set of GlobalVariables on which \p V /// depends. static void -DiscoverDependentGlobals(const Value *V, +discoverDependentGlobals(const Value *V, DenseSet &Globals) { - if (const GlobalVariable *GV = dyn_cast(V)) + if (const GlobalVariable *GV = dyn_cast(V)) { Globals.insert(GV); - else { - if (const User *U = dyn_cast(V)) { - for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) { - DiscoverDependentGlobals(U->getOperand(i), Globals); - } - } + return; } + + if (const User *U = dyn_cast(V)) + for (const auto &O : U->operands()) + discoverDependentGlobals(O, Globals); } /// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable @@ -127,8 +128,8 @@ VisitGlobalVariableForEmission(const GlobalVariable *GV, // Make sure we visit all dependents first DenseSet Others; - for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i) - DiscoverDependentGlobals(GV->getOperand(i), Others); + for (const auto &O : GV->operands()) + discoverDependentGlobals(O, Others); for (const GlobalVariable *GV : Others) VisitGlobalVariableForEmission(GV, Order, Visited, Visiting); @@ -623,9 +624,8 @@ static bool usedInGlobalVarDef(const Constant *C) { if (!C) return false; - if (const GlobalVariable *GV = dyn_cast(C)) { + if (const GlobalVariable *GV = dyn_cast(C)) return GV->getName() != "llvm.used"; - } for (const User *U : C->users()) if (const Constant *C = dyn_cast(U)) @@ -635,25 +635,23 @@ static bool usedInGlobalVarDef(const Constant *C) { return false; } -static bool usedInOneFunc(const User *U, Function const *&oneFunc) { - if (const GlobalVariable *othergv = dyn_cast(U)) { - if (othergv->getName() == "llvm.used") +static bool usedInOneFunc(const User *U, Function const *&OneFunc) { + if (const GlobalVariable *OtherGV = dyn_cast(U)) + if (OtherGV->getName() == "llvm.used") return true; - } - if (const Instruction *instr = dyn_cast(U)) { - if (instr->getParent() && instr->getParent()->getParent()) { - const Function *curFunc = instr->getParent()->getParent(); - if (oneFunc && (curFunc != oneFunc)) + if (const Instruction *I = dyn_cast(U)) { + if (const Function *CurFunc = I->getFunction()) { + if (OneFunc && (CurFunc != OneFunc)) return false; - oneFunc = curFunc; + OneFunc = CurFunc; return true; - } else - return false; + } + return false; } for (const User *UU : U->users()) - if (!usedInOneFunc(UU, oneFunc)) + if (!usedInOneFunc(UU, OneFunc)) return false; return true; @@ -666,16 +664,15 @@ static bool usedInOneFunc(const User *U, Function const *&oneFunc) { * 2. Does it have local linkage? * 3. Is the global variable referenced only in one function? */ -static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) { - if (!gv->hasLocalLinkage()) +static bool canDemoteGlobalVar(const GlobalVariable *GV, Function const *&f) { + if (!GV->hasLocalLinkage()) return false; - PointerType *Pty = gv->getType(); - if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED) + if (GV->getAddressSpace() != ADDRESS_SPACE_SHARED) return false; const Function *oneFunc = nullptr; - bool flag = usedInOneFunc(gv, oneFunc); + bool flag = usedInOneFunc(GV, oneFunc); if (!flag) return false; if (!oneFunc) @@ -685,27 +682,22 @@ static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) { } static bool useFuncSeen(const Constant *C, - DenseMap &seenMap) { + const SmallPtrSetImpl &SeenSet) { for (const User *U : C->users()) { if (const Constant *cu = dyn_cast(U)) { - if (useFuncSeen(cu, seenMap)) + if (useFuncSeen(cu, SeenSet)) return true; } else if (const Instruction *I = dyn_cast(U)) { - const BasicBlock *bb = I->getParent(); - if (!bb) - continue; - const Function *caller = bb->getParent(); - if (!caller) - continue; - if (seenMap.contains(caller)) - return true; + if (const Function *Caller = I->getFunction()) + if (SeenSet.contains(Caller)) + return true; } } return false; } void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) { - DenseMap seenMap; + SmallPtrSet SeenSet; for (const Function &F : M) { if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) { emitDeclaration(&F, O); @@ -731,7 +723,7 @@ void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) { } // Emit a declaration of this function if the function that // uses this constant expr has already been seen. - if (useFuncSeen(C, seenMap)) { + if (useFuncSeen(C, SeenSet)) { emitDeclaration(&F, O); break; } @@ -739,23 +731,19 @@ void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) { if (!isa(U)) continue; - const Instruction *instr = cast(U); - const BasicBlock *bb = instr->getParent(); - if (!bb) - continue; - const Function *caller = bb->getParent(); - if (!caller) + const Function *Caller = cast(U)->getFunction(); + if (!Caller) continue; // If a caller has already been seen, then the caller is // appearing in the module before the callee. so print out // a declaration for the callee. - if (seenMap.contains(caller)) { + if (SeenSet.contains(Caller)) { emitDeclaration(&F, O); break; } } - seenMap[&F] = true; + SeenSet.insert(&F); } for (const GlobalAlias &GA : M.aliases()) emitAliasDeclaration(&GA, O); @@ -818,7 +806,7 @@ void NVPTXAsmPrinter::emitGlobals(const Module &M) { // Print out module-level global variables in proper order for (const GlobalVariable *GV : Globals) - printModuleLevelGV(GV, OS2, /*processDemoted=*/false, STI); + printModuleLevelGV(GV, OS2, /*ProcessDemoted=*/false, STI); OS2 << '\n'; @@ -839,16 +827,14 @@ void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) { void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O, const NVPTXSubtarget &STI) { - O << "//\n"; - O << "// Generated by LLVM NVPTX Back-End\n"; - O << "//\n"; - O << "\n"; + const unsigned PTXVersion = STI.getPTXVersion(); - unsigned PTXVersion = STI.getPTXVersion(); - O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n"; - - O << ".target "; - O << STI.getTargetName(); + O << "//\n" + "// Generated by LLVM NVPTX Back-End\n" + "//\n" + "\n" + << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n" + << ".target " << STI.getTargetName(); const NVPTXTargetMachine &NTM = static_cast(TM); if (NTM.getDrvInterface() == NVPTX::NVCL) @@ -871,16 +857,9 @@ void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O, if (HasFullDebugInfo) O << ", debug"; - O << "\n"; - - O << ".address_size "; - if (NTM.is64Bit()) - O << "64"; - else - O << "32"; - O << "\n"; - - O << "\n"; + O << "\n" + << ".address_size " << (NTM.is64Bit() ? "64" : "32") << "\n" + << "\n"; } bool NVPTXAsmPrinter::doFinalization(Module &M) { @@ -928,41 +907,28 @@ void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V, raw_ostream &O) { if (static_cast(TM).getDrvInterface() == NVPTX::CUDA) { if (V->hasExternalLinkage()) { - if (isa(V)) { - const GlobalVariable *GVar = cast(V); - if (GVar) { - if (GVar->hasInitializer()) - O << ".visible "; - else - O << ".extern "; - } - } else if (V->isDeclaration()) + if (const auto *GVar = dyn_cast(V)) + O << (GVar->hasInitializer() ? ".visible " : ".extern "); + else if (V->isDeclaration()) O << ".extern "; else O << ".visible "; } else if (V->hasAppendingLinkage()) { - std::string msg; - msg.append("Error: "); - msg.append("Symbol "); - if (V->hasName()) - msg.append(std::string(V->getName())); - msg.append("has unsupported appending linkage type"); - llvm_unreachable(msg.c_str()); - } else if (!V->hasInternalLinkage() && - !V->hasPrivateLinkage()) { + report_fatal_error("Symbol '" + llvm::Twine(V->getNameOrAsOperand()) + + "' has unsupported appending linkage type"); + } else if (!V->hasInternalLinkage() && !V->hasPrivateLinkage()) { O << ".weak "; } } } void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar, - raw_ostream &O, bool processDemoted, + raw_ostream &O, bool ProcessDemoted, const NVPTXSubtarget &STI) { // Skip meta data - if (GVar->hasSection()) { + if (GVar->hasSection()) if (GVar->getSection() == "llvm.metadata") return; - } // Skip LLVM intrinsic global variables if (GVar->getName().starts_with("llvm.") || @@ -1069,20 +1035,20 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar, } if (GVar->hasPrivateLinkage()) { - if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0) + if (GVar->getName().starts_with("unrollpragma")) return; // FIXME - need better way (e.g. Metadata) to avoid generating this global - if (strncmp(GVar->getName().data(), "filename", 8) == 0) + if (GVar->getName().starts_with("filename")) return; if (GVar->use_empty()) return; } - const Function *demotedFunc = nullptr; - if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) { + const Function *DemotedFunc = nullptr; + if (!ProcessDemoted && canDemoteGlobalVar(GVar, DemotedFunc)) { O << "// " << GVar->getName() << " has been demoted\n"; - localDecls[demotedFunc].push_back(GVar); + localDecls[DemotedFunc].push_back(GVar); return; } @@ -1090,17 +1056,14 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar, emitPTXAddressSpace(GVar->getAddressSpace(), O); if (isManaged(*GVar)) { - if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) { + if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) report_fatal_error( ".attribute(.managed) requires PTX version >= 4.0 and sm_30"); - } O << " .attribute(.managed)"; } - if (MaybeAlign A = GVar->getAlign()) - O << " .align " << A->value(); - else - O << " .align " << (int)DL.getPrefTypeAlign(ETy).value(); + O << " .align " + << GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value(); if (ETy->isFloatingPointTy() || ETy->isPointerTy() || (ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) { @@ -1137,8 +1100,6 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar, } } } else { - uint64_t ElementSize = 0; - // Although PTX has direct support for struct type and array type and // LLVM IR is very similar to PTX, the LLVM CodeGen does not support for // targets that support these high level field accesses. Structs, arrays @@ -1147,8 +1108,8 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar, case Type::IntegerTyID: // Integers larger than 64 bits case Type::StructTyID: case Type::ArrayTyID: - case Type::FixedVectorTyID: - ElementSize = DL.getTypeStoreSize(ETy); + case Type::FixedVectorTyID: { + const uint64_t ElementSize = DL.getTypeStoreSize(ETy); // Ptx allows variable initilization only for constant and // global state spaces. if (((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) || @@ -1159,7 +1120,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar, AggBuffer aggBuffer(ElementSize, *this); bufferAggregateConstant(Initializer, &aggBuffer); if (aggBuffer.numSymbols()) { - unsigned int ptrSize = MAI->getCodePointerSize(); + const unsigned int ptrSize = MAI->getCodePointerSize(); if (ElementSize % ptrSize || !aggBuffer.allSymbolsAligned(ptrSize)) { // Print in bytes and use the mask() operator for pointers. @@ -1190,22 +1151,17 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar, } else { O << " .b8 "; getSymbol(GVar)->print(O, MAI); - if (ElementSize) { - O << "["; - O << ElementSize; - O << "]"; - } + if (ElementSize) + O << "[" << ElementSize << "]"; } } else { O << " .b8 "; getSymbol(GVar)->print(O, MAI); - if (ElementSize) { - O << "["; - O << ElementSize; - O << "]"; - } + if (ElementSize) + O << "[" << ElementSize << "]"; } break; + } default: llvm_unreachable("type not supported yet"); } @@ -1229,7 +1185,7 @@ void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) { Name->print(os, AP.MAI); } } else if (const ConstantExpr *CExpr = dyn_cast(v0)) { - const MCExpr *Expr = AP.lowerConstantForGV(cast(CExpr), false); + const MCExpr *Expr = AP.lowerConstantForGV(CExpr, false); AP.printMCExpr(*Expr, os); } else llvm_unreachable("symbol type unknown"); @@ -1298,18 +1254,18 @@ void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) { } } -void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) { - auto It = localDecls.find(f); +void NVPTXAsmPrinter::emitDemotedVars(const Function *F, raw_ostream &O) { + auto It = localDecls.find(F); if (It == localDecls.end()) return; - std::vector &gvars = It->second; + ArrayRef GVars = It->second; const NVPTXTargetMachine &NTM = static_cast(TM); const NVPTXSubtarget &STI = *static_cast(NTM.getSubtargetImpl()); - for (const GlobalVariable *GV : gvars) { + for (const GlobalVariable *GV : GVars) { O << "\t// demoted variable\n\t"; printModuleLevelGV(GV, O, /*processDemoted=*/true, STI); } @@ -1344,13 +1300,11 @@ NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const { unsigned NumBits = cast(Ty)->getBitWidth(); if (NumBits == 1) return "pred"; - else if (NumBits <= 64) { + if (NumBits <= 64) { std::string name = "u"; return name + utostr(NumBits); - } else { - llvm_unreachable("Integer too large"); - break; } + llvm_unreachable("Integer too large"); break; } case Type::BFloatTyID: @@ -1393,16 +1347,14 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar, O << "."; emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O); if (isManaged(*GVar)) { - if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) { + if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) report_fatal_error( ".attribute(.managed) requires PTX version >= 4.0 and sm_30"); - } + O << " .attribute(.managed)"; } - if (MaybeAlign A = GVar->getAlign()) - O << " .align " << A->value(); - else - O << " .align " << (int)DL.getPrefTypeAlign(ETy).value(); + O << " .align " + << GVar->getAlign().value_or(DL.getPrefTypeAlign(ETy)).value(); // Special case for i128 if (ETy->isIntegerTy(128)) { @@ -1413,9 +1365,7 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar, } if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) { - O << " ."; - O << getPTXFundamentalTypeStr(ETy); - O << " "; + O << " ." << getPTXFundamentalTypeStr(ETy) << " "; getSymbol(GVar)->print(O, MAI); return; } @@ -1446,16 +1396,13 @@ void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar, void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) { const DataLayout &DL = getDataLayout(); - const AttributeList &PAL = F->getAttributes(); const NVPTXSubtarget &STI = TM.getSubtarget(*F); const auto *TLI = cast(STI.getTargetLowering()); const NVPTXMachineFunctionInfo *MFI = MF ? MF->getInfo() : nullptr; - Function::const_arg_iterator I, E; - unsigned paramIndex = 0; - bool first = true; - bool isKernelFunc = isKernelFunction(*F); + bool IsFirst = true; + const bool IsKernelFunc = isKernelFunction(*F); if (F->arg_empty() && !F->isVarArg()) { O << "()"; @@ -1464,161 +1411,143 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) { O << "(\n"; - for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) { - Type *Ty = I->getType(); + for (const Argument &Arg : F->args()) { + Type *Ty = Arg.getType(); + const std::string ParamSym = TLI->getParamName(F, Arg.getArgNo()); - if (!first) + if (!IsFirst) O << ",\n"; - first = false; + IsFirst = false; // Handle image/sampler parameters - if (isKernelFunc) { - if (isSampler(*I) || isImage(*I)) { - std::string ParamSym; - raw_string_ostream ParamStr(ParamSym); - ParamStr << F->getName() << "_param_" << paramIndex; - ParamStr.flush(); - bool EmitImagePtr = !MFI || !MFI->checkImageHandleSymbol(ParamSym); - if (isImage(*I)) { - if (isImageWriteOnly(*I) || isImageReadWrite(*I)) { - if (EmitImagePtr) - O << "\t.param .u64 .ptr .surfref "; - else - O << "\t.param .surfref "; - O << TLI->getParamName(F, paramIndex); - } - else { // Default image is read_only - if (EmitImagePtr) - O << "\t.param .u64 .ptr .texref "; - else - O << "\t.param .texref "; - O << TLI->getParamName(F, paramIndex); - } - } else { - if (EmitImagePtr) - O << "\t.param .u64 .ptr .samplerref "; - else - O << "\t.param .samplerref "; - O << TLI->getParamName(F, paramIndex); - } + if (IsKernelFunc) { + const bool IsSampler = isSampler(Arg); + const bool IsTexture = !IsSampler && isImageReadOnly(Arg); + const bool IsSurface = !IsSampler && !IsTexture && + (isImageReadWrite(Arg) || isImageWriteOnly(Arg)); + if (IsSampler || IsTexture || IsSurface) { + const bool EmitImgPtr = !MFI || !MFI->checkImageHandleSymbol(ParamSym); + O << "\t.param "; + if (EmitImgPtr) + O << ".u64 .ptr "; + + if (IsSampler) + O << ".samplerref "; + else if (IsTexture) + O << ".texref "; + else // IsSurface + O << ".samplerref "; + O << ParamSym; continue; } } - auto getOptimalAlignForParam = [TLI, &DL, &PAL, F, - paramIndex](Type *Ty) -> Align { + auto GetOptimalAlignForParam = [TLI, &DL, F, &Arg](Type *Ty) -> Align { if (MaybeAlign StackAlign = - getAlign(*F, paramIndex + AttributeList::FirstArgIndex)) + getAlign(*F, Arg.getArgNo() + AttributeList::FirstArgIndex)) return StackAlign.value(); Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL); - MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex); + MaybeAlign ParamAlign = + Arg.hasByValAttr() ? Arg.getParamAlign() : MaybeAlign(); return std::max(TypeAlign, ParamAlign.valueOrOne()); }; - if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) { - if (ShouldPassAsArray(Ty)) { - // Just print .param .align .b8 .param[size]; - // = optimal alignment for the element type; always multiple of - // PAL.getParamAlignment - // size = typeallocsize of element type - Align OptimalAlign = getOptimalAlignForParam(Ty); + if (Arg.hasByValAttr()) { + // param has byVal attribute. + Type *ETy = Arg.getParamByValType(); + assert(ETy && "Param should have byval type"); + + // Print .param .align .b8 .param[size]; + // = optimal alignment for the element type; always multiple of + // PAL.getParamAlignment + // size = typeallocsize of element type + const Align OptimalAlign = + IsKernelFunc ? GetOptimalAlignForParam(ETy) + : TLI->getFunctionByValParamAlign( + F, ETy, Arg.getParamAlign().valueOrOne(), DL); + + O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym + << "[" << DL.getTypeAllocSize(ETy) << "]"; + continue; + } - O << "\t.param .align " << OptimalAlign.value() << " .b8 "; - O << TLI->getParamName(F, paramIndex); - O << "[" << DL.getTypeAllocSize(Ty) << "]"; + if (ShouldPassAsArray(Ty)) { + // Just print .param .align .b8 .param[size]; + // = optimal alignment for the element type; always multiple of + // PAL.getParamAlignment + // size = typeallocsize of element type + Align OptimalAlign = GetOptimalAlignForParam(Ty); - continue; - } - // Just a scalar - auto *PTy = dyn_cast(Ty); - unsigned PTySizeInBits = 0; - if (PTy) { - PTySizeInBits = - TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits(); - assert(PTySizeInBits && "Invalid pointer size"); - } + O << "\t.param .align " << OptimalAlign.value() << " .b8 " << ParamSym + << "[" << DL.getTypeAllocSize(Ty) << "]"; - if (isKernelFunc) { - if (PTy) { - O << "\t.param .u" << PTySizeInBits << " .ptr"; - - switch (PTy->getAddressSpace()) { - default: - break; - case ADDRESS_SPACE_GLOBAL: - O << " .global"; - break; - case ADDRESS_SPACE_SHARED: - O << " .shared"; - break; - case ADDRESS_SPACE_CONST: - O << " .const"; - break; - case ADDRESS_SPACE_LOCAL: - O << " .local"; - break; - } + continue; + } + // Just a scalar + auto *PTy = dyn_cast(Ty); + unsigned PTySizeInBits = 0; + if (PTy) { + PTySizeInBits = + TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits(); + assert(PTySizeInBits && "Invalid pointer size"); + } - O << " .align " << I->getParamAlign().valueOrOne().value(); - O << " " << TLI->getParamName(F, paramIndex); - continue; + if (IsKernelFunc) { + if (PTy) { + O << "\t.param .u" << PTySizeInBits << " .ptr"; + + switch (PTy->getAddressSpace()) { + default: + break; + case ADDRESS_SPACE_GLOBAL: + O << " .global"; + break; + case ADDRESS_SPACE_SHARED: + O << " .shared"; + break; + case ADDRESS_SPACE_CONST: + O << " .const"; + break; + case ADDRESS_SPACE_LOCAL: + O << " .local"; + break; } - // non-pointer scalar to kernel func - O << "\t.param ."; - // Special case: predicate operands become .u8 types - if (Ty->isIntegerTy(1)) - O << "u8"; - else - O << getPTXFundamentalTypeStr(Ty); - O << " "; - O << TLI->getParamName(F, paramIndex); + O << " .align " << Arg.getParamAlign().valueOrOne().value() << " " + << ParamSym; continue; } - // Non-kernel function, just print .param .b for ABI - // and .reg .b for non-ABI - unsigned sz = 0; - if (isa(Ty)) { - sz = cast(Ty)->getBitWidth(); - sz = promoteScalarArgumentSize(sz); - } else if (PTy) { - assert(PTySizeInBits && "Invalid pointer size"); - sz = PTySizeInBits; - } else - sz = Ty->getPrimitiveSizeInBits(); - O << "\t.param .b" << sz << " "; - O << TLI->getParamName(F, paramIndex); + + // non-pointer scalar to kernel func + O << "\t.param ."; + // Special case: predicate operands become .u8 types + if (Ty->isIntegerTy(1)) + O << "u8"; + else + O << getPTXFundamentalTypeStr(Ty); + O << " " << ParamSym; continue; } - - // param has byVal attribute. - Type *ETy = PAL.getParamByValType(paramIndex); - assert(ETy && "Param should have byval type"); - - // Print .param .align .b8 .param[size]; - // = optimal alignment for the element type; always multiple of - // PAL.getParamAlignment - // size = typeallocsize of element type - Align OptimalAlign = - isKernelFunc - ? getOptimalAlignForParam(ETy) - : TLI->getFunctionByValParamAlign( - F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL); - - unsigned sz = DL.getTypeAllocSize(ETy); - O << "\t.param .align " << OptimalAlign.value() << " .b8 "; - O << TLI->getParamName(F, paramIndex); - O << "[" << sz << "]"; + // Non-kernel function, just print .param .b for ABI + // and .reg .b for non-ABI + unsigned Size; + if (auto *ITy = dyn_cast(Ty)) { + Size = promoteScalarArgumentSize(ITy->getBitWidth()); + } else if (PTy) { + assert(PTySizeInBits && "Invalid pointer size"); + Size = PTySizeInBits; + } else + Size = Ty->getPrimitiveSizeInBits(); + O << "\t.param .b" << Size << " " << ParamSym; } if (F->isVarArg()) { - if (!first) + if (!IsFirst) O << ",\n"; - O << "\t.param .align " << STI.getMaxRequiredAlignment(); - O << " .b8 "; - O << TLI->getParamName(F, /* vararg */ -1) << "[]"; + O << "\t.param .align " << STI.getMaxRequiredAlignment() << " .b8 " + << TLI->getParamName(F, /* vararg */ -1) << "[]"; } O << "\n)"; @@ -1641,11 +1570,11 @@ void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters( O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t" << DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n"; if (static_cast(MF.getTarget()).is64Bit()) { - O << "\t.reg .b64 \t%SP;\n"; - O << "\t.reg .b64 \t%SPL;\n"; + O << "\t.reg .b64 \t%SP;\n" + << "\t.reg .b64 \t%SPL;\n"; } else { - O << "\t.reg .b32 \t%SP;\n"; - O << "\t.reg .b32 \t%SPL;\n"; + O << "\t.reg .b32 \t%SP;\n" + << "\t.reg .b32 \t%SPL;\n"; } } @@ -1662,29 +1591,16 @@ void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters( regmap.insert(std::make_pair(vr, n + 1)); } - // Emit register declarations - // @TODO: Extract out the real register usage - // O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n"; - // O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n"; - // O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n"; - // O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n"; - // O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n"; - // O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n"; - // O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n"; - // Emit declaration of the virtual registers or 'physical' registers for // each register class - for (unsigned i=0; i< TRI->getNumRegClasses(); i++) { - const TargetRegisterClass *RC = TRI->getRegClass(i); - DenseMap ®map = VRegMapping[RC]; - std::string rcname = getNVPTXRegClassName(RC); - std::string rcStr = getNVPTXRegClassStr(RC); - int n = regmap.size(); + for (const TargetRegisterClass *RC : TRI->regclasses()) { + const unsigned N = VRegMapping[RC].size(); // Only declare those registers that may be used. - if (n) { - O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1) - << ">;\n"; + if (N) { + const StringRef RCName = getNVPTXRegClassName(RC); + const StringRef RCStr = getNVPTXRegClassStr(RC); + O << "\t.reg " << RCName << " \t" << RCStr << "<" << (N + 1) << ">;\n"; } } @@ -1711,7 +1627,8 @@ void NVPTXAsmPrinter::encodeDebugInfoRegisterNumbers( } } -void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) { +void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, + raw_ostream &O) const { APFloat APF = APFloat(Fp->getValueAPF()); // make a copy bool ignored; unsigned int numHex; @@ -1746,10 +1663,7 @@ void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) { return; } if (const GlobalValue *GVar = dyn_cast(CPV)) { - bool IsNonGenericPointer = false; - if (GVar->getType()->getAddressSpace() != 0) { - IsNonGenericPointer = true; - } + const bool IsNonGenericPointer = GVar->getAddressSpace() != 0; if (EmitGeneric && !isa(CPV) && !IsNonGenericPointer) { O << "generic("; getSymbol(GVar)->print(O, MAI); @@ -1798,7 +1712,7 @@ void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes, switch (CPV->getType()->getTypeID()) { case Type::IntegerTyID: - if (const auto CI = dyn_cast(CPV)) { + if (const auto *CI = dyn_cast(CPV)) { AddIntToBuffer(CI->getValue()); break; } @@ -1912,7 +1826,8 @@ void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV, /// expressions that are representable in PTX and create /// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions. const MCExpr * -NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) { +NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, + bool ProcessingGeneric) const { MCContext &Ctx = OutContext; if (CV->isNullValue() || isa(CV)) @@ -1922,13 +1837,10 @@ NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) return MCConstantExpr::create(CI->getZExtValue(), Ctx); if (const GlobalValue *GV = dyn_cast(CV)) { - const MCSymbolRefExpr *Expr = - MCSymbolRefExpr::create(getSymbol(GV), Ctx); - if (ProcessingGeneric) { + const MCSymbolRefExpr *Expr = MCSymbolRefExpr::create(getSymbol(GV), Ctx); + if (ProcessingGeneric) return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx); - } else { - return Expr; - } + return Expr; } const ConstantExpr *CE = dyn_cast(CV); @@ -2041,7 +1953,7 @@ NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) } // Copy of MCExpr::print customized for NVPTX -void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) { +void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) const { switch (Expr.getKind()) { case MCExpr::Target: return cast(&Expr)->printImpl(OS, MAI); diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h index f58b4bdc40474..f7c3fda332eff 100644 --- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h +++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h @@ -101,15 +101,13 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter { // SymbolsBeforeStripping[i]. SmallVector SymbolsBeforeStripping; unsigned curpos; - NVPTXAsmPrinter &AP; - bool EmitGeneric; + const NVPTXAsmPrinter &AP; + const bool EmitGeneric; public: - AggBuffer(unsigned size, NVPTXAsmPrinter &AP) - : size(size), buffer(size), AP(AP) { - curpos = 0; - EmitGeneric = AP.EmitGeneric; - } + AggBuffer(unsigned size, const NVPTXAsmPrinter &AP) + : size(size), buffer(size), curpos(0), AP(AP), + EmitGeneric(AP.EmitGeneric) {} // Copy Num bytes from Ptr. // if Bytes > Num, zero fill up to Bytes. @@ -155,7 +153,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter { StringRef getPassName() const override { return "NVPTX Assembly Printer"; } const Function *F; - std::string CurrentFnName; void emitStartOfAsmFile(Module &M) override; void emitBasicBlockStart(const MachineBasicBlock &MBB) override; @@ -190,8 +187,9 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter { bool PrintAsmMemoryOperand(const MachineInstr *MI, unsigned OpNo, const char *ExtraCode, raw_ostream &) override; - const MCExpr *lowerConstantForGV(const Constant *CV, bool ProcessingGeneric); - void printMCExpr(const MCExpr &Expr, raw_ostream &OS); + const MCExpr *lowerConstantForGV(const Constant *CV, + bool ProcessingGeneric) const; + void printMCExpr(const MCExpr &Expr, raw_ostream &OS) const; protected: bool doInitialization(Module &M) override; @@ -217,7 +215,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter { void emitPTXAddressSpace(unsigned int AddressSpace, raw_ostream &O) const; std::string getPTXFundamentalTypeStr(Type *Ty, bool = true) const; void printScalarConstant(const Constant *CPV, raw_ostream &O); - void printFPConstant(const ConstantFP *Fp, raw_ostream &O); + void printFPConstant(const ConstantFP *Fp, raw_ostream &O) const; void bufferLEByte(const Constant *CPV, int Bytes, AggBuffer *aggBuffer); void bufferAggregateConstant(const Constant *CV, AggBuffer *aggBuffer); @@ -245,7 +243,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter { // Since the address value should always be generic in CUDA C and always // be specific in OpenCL, we use this simple control here. // - bool EmitGeneric; + const bool EmitGeneric; public: NVPTXAsmPrinter(TargetMachine &TM, std::unique_ptr Streamer) diff --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp index d1b136429d3a4..229c438edf723 100644 --- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp @@ -24,7 +24,7 @@ using namespace llvm; #define DEBUG_TYPE "nvptx-reg-info" namespace llvm { -std::string getNVPTXRegClassName(TargetRegisterClass const *RC) { +StringRef getNVPTXRegClassName(TargetRegisterClass const *RC) { if (RC == &NVPTX::Float32RegsRegClass) return ".f32"; if (RC == &NVPTX::Float64RegsRegClass) @@ -62,7 +62,7 @@ std::string getNVPTXRegClassName(TargetRegisterClass const *RC) { return "INTERNAL"; } -std::string getNVPTXRegClassStr(TargetRegisterClass const *RC) { +StringRef getNVPTXRegClassStr(TargetRegisterClass const *RC) { if (RC == &NVPTX::Float32RegsRegClass) return "%f"; if (RC == &NVPTX::Float64RegsRegClass) @@ -81,7 +81,7 @@ std::string getNVPTXRegClassStr(TargetRegisterClass const *RC) { return "!Special!"; return "INTERNAL"; } -} +} // namespace llvm NVPTXRegisterInfo::NVPTXRegisterInfo() : NVPTXGenRegisterInfo(0), StrPool(StrAlloc) {} @@ -144,11 +144,10 @@ void NVPTXRegisterInfo::clearDebugRegisterMap() const { debugRegisterMap.clear(); } -static uint64_t encodeRegisterForDwarf(std::string registerName) { - if (registerName.length() > 8) { +static uint64_t encodeRegisterForDwarf(StringRef RegisterName) { + if (RegisterName.size() > 8) // The name is more than 8 characters long, and so won't fit into 64 bits. return 0; - } // Encode the name string into a DWARF register number using cuda-gdb's // encoding. See cuda_check_dwarf2_reg_ptx_virtual_register in cuda-tdep.c, @@ -157,14 +156,14 @@ static uint64_t encodeRegisterForDwarf(std::string registerName) { // number, which is stored in ULEB128, but in practice must be no more than 8 // bytes (excluding null terminator, which is not included). uint64_t result = 0; - for (unsigned char c : registerName) + for (unsigned char c : RegisterName) result = (result << 8) | c; return result; } void NVPTXRegisterInfo::addToDebugRegisterMap( - uint64_t preEncodedVirtualRegister, std::string registerName) const { - uint64_t mapped = encodeRegisterForDwarf(registerName); + uint64_t preEncodedVirtualRegister, StringRef RegisterName) const { + uint64_t mapped = encodeRegisterForDwarf(RegisterName); if (mapped == 0) return; debugRegisterMap.insert({preEncodedVirtualRegister, mapped}); @@ -172,13 +171,13 @@ void NVPTXRegisterInfo::addToDebugRegisterMap( int64_t NVPTXRegisterInfo::getDwarfRegNum(MCRegister RegNum, bool isEH) const { if (RegNum.isPhysical()) { - std::string name = NVPTXInstPrinter::getRegisterName(RegNum.id()); + StringRef Name = NVPTXInstPrinter::getRegisterName(RegNum.id()); // In NVPTXFrameLowering.cpp, we do arrange for %Depot to be accessible from // %SP. Using the %Depot register doesn't provide any debug info in // cuda-gdb, but switching it to %SP does. if (RegNum.id() == NVPTX::VRDepot) - name = "%SP"; - return encodeRegisterForDwarf(name); + Name = "%SP"; + return encodeRegisterForDwarf(Name); } uint64_t lookup = debugRegisterMap.lookup(RegNum.id()); if (lookup) diff --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h index d2f6d257d6b07..cfec7377fd634 100644 --- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h +++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h @@ -69,13 +69,13 @@ class NVPTXRegisterInfo : public NVPTXGenRegisterInfo { // here, because the proper encoding for debug registers is available only // temporarily during ASM emission. void addToDebugRegisterMap(uint64_t preEncodedVirtualRegister, - std::string registerName) const; + StringRef RegisterName) const; void clearDebugRegisterMap() const; int64_t getDwarfRegNum(MCRegister RegNum, bool isEH) const override; }; -std::string getNVPTXRegClassName(const TargetRegisterClass *RC); -std::string getNVPTXRegClassStr(const TargetRegisterClass *RC); +StringRef getNVPTXRegClassName(const TargetRegisterClass *RC); +StringRef getNVPTXRegClassStr(const TargetRegisterClass *RC); } // end namespace llvm From 32aaedf851b3e29acd77331dd3ba7e01be238114 Mon Sep 17 00:00:00 2001 From: Alex Maclean Date: Thu, 13 Feb 2025 16:28:11 +0000 Subject: [PATCH 2/2] fixup for release builds --- llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp index 5b60151c14cc4..0538b33530470 100644 --- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp @@ -914,7 +914,7 @@ void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V, else O << ".visible "; } else if (V->hasAppendingLinkage()) { - report_fatal_error("Symbol '" + llvm::Twine(V->getNameOrAsOperand()) + + report_fatal_error("Symbol '" + (V->hasName() ? V->getName() : "") + "' has unsupported appending linkage type"); } else if (!V->hasInternalLinkage() && !V->hasPrivateLinkage()) { O << ".weak ";