Skip to content

Commit 922c2d5

Browse files
akshayrdeodhar authored and frasercrmck committed
[NVPTX] Basic support for "grid_constant" (#96125)
- Adds a helper function for checking whether an argument is a [grid_constant](https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#supported-properties).
- Adds support for cvta.param using changes from llvm/llvm-project#95289.
- Supports escaped grid_constant pointers conservatively, by casting all uses to the generic address space with cvta.param.
1 parent 58e60a5 commit 922c2d5

File tree

6 files changed

+297
-86
lines changed

6 files changed

+297
-86
lines changed

llvm/include/llvm/IR/IntrinsicsNVVM.td

+6
Original file line number | Diff line number | Diff line change
@@ -1698,6 +1698,12 @@ def int_nvvm_ptr_gen_to_param: Intrinsic<[llvm_anyptr_ty],
16981698
[IntrNoMem, IntrSpeculatable, IntrNoCallback],
16991699
"llvm.nvvm.ptr.gen.to.param">;
17001700

1701+
// sm70+, PTX7.7+
1702+
def int_nvvm_ptr_param_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty],
1703+
[llvm_anyptr_ty],
1704+
[IntrNoMem, IntrSpeculatable, IntrNoCallback],
1705+
"llvm.nvvm.ptr.param.to.gen">;
1706+
17011707
// Move intrinsics, used in nvvm internally
17021708

17031709
def int_nvvm_move_i16 : Intrinsic<[llvm_i16_ty], [llvm_i16_ty], [IntrNoMem],

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

+1
Original file line number | Diff line number | Diff line change
@@ -2684,6 +2684,7 @@ defm cvta_local : NG_TO_G<"local", int_nvvm_ptr_local_to_gen, useShortPtrLocal>
26842684
defm cvta_shared : NG_TO_G<"shared", int_nvvm_ptr_shared_to_gen, useShortPtrShared>;
26852685
defm cvta_global : NG_TO_G<"global", int_nvvm_ptr_global_to_gen, False>;
26862686
defm cvta_const : NG_TO_G<"const", int_nvvm_ptr_constant_to_gen, useShortPtrConst>;
2687+
defm cvta_param : NG_TO_G<"param", int_nvvm_ptr_param_to_gen, False>;
26872688

26882689
defm cvta_to_local : G_TO_NG<"local", int_nvvm_ptr_gen_to_local, useShortPtrLocal>;
26892690
defm cvta_to_shared : G_TO_NG<"shared", int_nvvm_ptr_gen_to_shared, useShortPtrShared>;

llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp

+56-21
Original file line number | Diff line number | Diff line change
@@ -95,7 +95,9 @@
9595
#include "llvm/Analysis/ValueTracking.h"
9696
#include "llvm/CodeGen/TargetPassConfig.h"
9797
#include "llvm/IR/Function.h"
98+
#include "llvm/IR/IRBuilder.h"
9899
#include "llvm/IR/Instructions.h"
100+
#include "llvm/IR/IntrinsicsNVPTX.h"
99101
#include "llvm/IR/Module.h"
100102
#include "llvm/IR/Type.h"
101103
#include "llvm/InitializePasses.h"
@@ -336,8 +338,9 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
336338
while (!ValuesToCheck.empty()) {
337339
Value *V = ValuesToCheck.pop_back_val();
338340
if (!IsALoadChainInstr(V)) {
339-
LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *V
340-
<< "\n");
341+
LLVM_DEBUG(dbgs() << "Need a "
342+
<< (isParamGridConstant(*Arg) ? "cast " : "copy ")
343+
<< "of " << *Arg << " because of " << *V << "\n");
341344
(void)Arg;
342345
return false;
343346
}
@@ -366,27 +369,59 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
366369
return;
367370
}
368371

369-
// Otherwise we have to create a temporary copy.
370372
const DataLayout &DL = Func->getParent()->getDataLayout();
371373
unsigned AS = DL.getAllocaAddrSpace();
372-
AllocaInst *AllocA = new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
373-
// Set the alignment to alignment of the byval parameter. This is because,
374-
// later load/stores assume that alignment, and we are going to replace
375-
// the use of the byval parameter with this alloca instruction.
376-
AllocA->setAlignment(Func->getParamAlign(Arg->getArgNo())
377-
.value_or(DL.getPrefTypeAlign(StructType)));
378-
Arg->replaceAllUsesWith(AllocA);
379-
380-
Value *ArgInParam = new AddrSpaceCastInst(
381-
Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
382-
FirstInst);
383-
// Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
384-
// addrspacecast preserves alignment. Since params are constant, this load is
385-
// definitely not volatile.
386-
LoadInst *LI =
387-
new LoadInst(StructType, ArgInParam, Arg->getName(),
388-
/*isVolatile=*/false, AllocA->getAlign(), FirstInst);
389-
new StoreInst(LI, AllocA, FirstInst);
374+
if (isParamGridConstant(*Arg)) {
375+
// Writes to a grid constant are undefined behaviour. We do not need a
376+
// temporary copy. When a pointer might have escaped, conservatively replace
377+
// all of its uses (which might include a device function call) with a cast
378+
// to the generic address space.
379+
// TODO: only cast byval grid constant parameters at use points that need
380+
// generic address (e.g., merging parameter pointers with other address
381+
// space, or escaping to call-sites, inline-asm, memory), and use the
382+
// parameter address space for normal loads.
383+
IRBuilder<> IRB(&Func->getEntryBlock().front());
384+
385+
// Cast argument to param address space
386+
auto *CastToParam =
387+
cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
388+
Arg, IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getName() + ".param"));
389+
390+
// Cast param address to generic address space. We do not use an
391+
// addrspacecast to generic here, because, LLVM considers `Arg` to be in the
392+
// generic address space, and a `generic -> param` cast followed by a `param
393+
// -> generic` cast will be folded away. The `param -> generic` intrinsic
394+
// will be correctly lowered to `cvta.param`.
395+
Value *CvtToGenCall = IRB.CreateIntrinsic(
396+
IRB.getPtrTy(ADDRESS_SPACE_GENERIC), Intrinsic::nvvm_ptr_param_to_gen,
397+
CastToParam, nullptr, CastToParam->getName() + ".gen");
398+
399+
Arg->replaceAllUsesWith(CvtToGenCall);
400+
401+
// Do not replace Arg in the cast to param space
402+
CastToParam->setOperand(0, Arg);
403+
} else {
404+
// Otherwise we have to create a temporary copy.
405+
AllocaInst *AllocA =
406+
new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
407+
// Set the alignment to alignment of the byval parameter. This is because,
408+
// later load/stores assume that alignment, and we are going to replace
409+
// the use of the byval parameter with this alloca instruction.
410+
AllocA->setAlignment(Func->getParamAlign(Arg->getArgNo())
411+
.value_or(DL.getPrefTypeAlign(StructType)));
412+
Arg->replaceAllUsesWith(AllocA);
413+
414+
Value *ArgInParam = new AddrSpaceCastInst(
415+
Arg, PointerType::get(Arg->getContext(), ADDRESS_SPACE_PARAM),
416+
Arg->getName(), FirstInst);
417+
// Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
418+
// addrspacecast preserves alignment. Since params are constant, this load
419+
// is definitely not volatile.
420+
LoadInst *LI =
421+
new LoadInst(StructType, ArgInParam, Arg->getName(),
422+
/*isVolatile=*/false, AllocA->getAlign(), FirstInst);
423+
new StoreInst(LI, AllocA, FirstInst);
424+
}
390425
}
391426

392427
void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {

llvm/lib/Target/NVPTX/NVPTXUtilities.cpp

+78-65
Original file line number | Diff line number | Diff line change
@@ -52,29 +52,46 @@ void clearAnnotationCache(const Module *Mod) {
5252
AC.Cache.erase(Mod);
5353
}
5454

55-
static void cacheAnnotationFromMD(const MDNode *md, key_val_pair_t &retval) {
55+
static void readIntVecFromMDNode(const MDNode *MetadataNode,
56+
std::vector<unsigned> &Vec) {
57+
for (unsigned i = 0, e = MetadataNode->getNumOperands(); i != e; ++i) {
58+
ConstantInt *Val =
59+
mdconst::extract<ConstantInt>(MetadataNode->getOperand(i));
60+
Vec.push_back(Val->getZExtValue());
61+
}
62+
}
63+
64+
static void cacheAnnotationFromMD(const MDNode *MetadataNode,
65+
key_val_pair_t &retval) {
5666
auto &AC = getAnnotationCache();
5767
std::lock_guard<sys::Mutex> Guard(AC.Lock);
58-
assert(md && "Invalid mdnode for annotation");
59-
assert((md->getNumOperands() % 2) == 1 && "Invalid number of operands");
68+
assert(MetadataNode && "Invalid mdnode for annotation");
69+
assert((MetadataNode->getNumOperands() % 2) == 1 &&
70+
"Invalid number of operands");
6071
// start index = 1, to skip the global variable key
6172
// increment = 2, to skip the value for each property-value pairs
62-
for (unsigned i = 1, e = md->getNumOperands(); i != e; i += 2) {
73+
for (unsigned i = 1, e = MetadataNode->getNumOperands(); i != e; i += 2) {
6374
// property
64-
const MDString *prop = dyn_cast<MDString>(md->getOperand(i));
75+
const MDString *prop = dyn_cast<MDString>(MetadataNode->getOperand(i));
6576
assert(prop && "Annotation property not a string");
77+
std::string Key = prop->getString().str();
6678

6779
// value
68-
ConstantInt *Val = mdconst::dyn_extract<ConstantInt>(md->getOperand(i + 1));
69-
assert(Val && "Value operand not a constant int");
70-
71-
std::string keyname = prop->getString().str();
72-
if (retval.find(keyname) != retval.end())
73-
retval[keyname].push_back(Val->getZExtValue());
74-
else {
75-
std::vector<unsigned> tmp;
76-
tmp.push_back(Val->getZExtValue());
77-
retval[keyname] = tmp;
80+
if (ConstantInt *Val = mdconst::dyn_extract<ConstantInt>(
81+
MetadataNode->getOperand(i + 1))) {
82+
retval[Key].push_back(Val->getZExtValue());
83+
} else if (MDNode *VecMd =
84+
dyn_cast<MDNode>(MetadataNode->getOperand(i + 1))) {
85+
// note: only "grid_constant" annotations support vector MDNodes.
86+
// assert: there can only exist one unique key value pair of
87+
// the form (string key, MDNode node). Operands of such a node
88+
// shall always be unsigned ints.
89+
if (retval.find(Key) == retval.end()) {
90+
readIntVecFromMDNode(VecMd, retval[Key]);
91+
continue;
92+
}
93+
} else {
94+
llvm_unreachable("Value operand not a constant int or an mdnode");
7895
}
7996
}
8097
}
@@ -145,9 +162,9 @@ bool findAllNVVMAnnotation(const GlobalValue *gv, const std::string &prop,
145162

146163
bool isTexture(const Value &val) {
147164
if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
148-
unsigned annot;
149-
if (findOneNVVMAnnotation(gv, "texture", annot)) {
150-
assert((annot == 1) && "Unexpected annotation on a texture symbol");
165+
unsigned Annot;
166+
if (findOneNVVMAnnotation(gv, "texture", Annot)) {
167+
assert((Annot == 1) && "Unexpected annotation on a texture symbol");
151168
return true;
152169
}
153170
}
@@ -156,70 +173,67 @@ bool isTexture(const Value &val) {
156173

157174
bool isSurface(const Value &val) {
158175
if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
159-
unsigned annot;
160-
if (findOneNVVMAnnotation(gv, "surface", annot)) {
161-
assert((annot == 1) && "Unexpected annotation on a surface symbol");
176+
unsigned Annot;
177+
if (findOneNVVMAnnotation(gv, "surface", Annot)) {
178+
assert((Annot == 1) && "Unexpected annotation on a surface symbol");
162179
return true;
163180
}
164181
}
165182
return false;
166183
}
167184

168-
bool isSampler(const Value &val) {
169-
const char *AnnotationName = "sampler";
170-
171-
if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
172-
unsigned annot;
173-
if (findOneNVVMAnnotation(gv, AnnotationName, annot)) {
174-
assert((annot == 1) && "Unexpected annotation on a sampler symbol");
175-
return true;
176-
}
177-
}
178-
if (const Argument *arg = dyn_cast<Argument>(&val)) {
179-
const Function *func = arg->getParent();
180-
std::vector<unsigned> annot;
181-
if (findAllNVVMAnnotation(func, AnnotationName, annot)) {
182-
if (is_contained(annot, arg->getArgNo()))
185+
static bool argHasNVVMAnnotation(const Value &Val,
186+
const std::string &Annotation,
187+
const bool StartArgIndexAtOne = false) {
188+
if (const Argument *Arg = dyn_cast<Argument>(&Val)) {
189+
const Function *Func = Arg->getParent();
190+
std::vector<unsigned> Annot;
191+
if (findAllNVVMAnnotation(Func, Annotation, Annot)) {
192+
const unsigned BaseOffset = StartArgIndexAtOne ? 1 : 0;
193+
if (is_contained(Annot, BaseOffset + Arg->getArgNo())) {
183194
return true;
195+
}
184196
}
185197
}
186198
return false;
187199
}
188200

189-
bool isImageReadOnly(const Value &val) {
190-
if (const Argument *arg = dyn_cast<Argument>(&val)) {
191-
const Function *func = arg->getParent();
192-
std::vector<unsigned> annot;
193-
if (findAllNVVMAnnotation(func, "rdoimage", annot)) {
194-
if (is_contained(annot, arg->getArgNo()))
195-
return true;
201+
bool isParamGridConstant(const Value &V) {
202+
if (const Argument *Arg = dyn_cast<Argument>(&V)) {
203+
// "grid_constant" counts argument indices starting from 1
204+
if (Arg->hasByValAttr() &&
205+
argHasNVVMAnnotation(*Arg, "grid_constant", /*StartArgIndexAtOne*/true)) {
206+
assert(isKernelFunction(*Arg->getParent()) &&
207+
"only kernel arguments can be grid_constant");
208+
return true;
196209
}
197210
}
198211
return false;
199212
}
200213

201-
bool isImageWriteOnly(const Value &val) {
202-
if (const Argument *arg = dyn_cast<Argument>(&val)) {
203-
const Function *func = arg->getParent();
204-
std::vector<unsigned> annot;
205-
if (findAllNVVMAnnotation(func, "wroimage", annot)) {
206-
if (is_contained(annot, arg->getArgNo()))
207-
return true;
214+
bool isSampler(const Value &val) {
215+
const char *AnnotationName = "sampler";
216+
217+
if (const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
218+
unsigned Annot;
219+
if (findOneNVVMAnnotation(gv, AnnotationName, Annot)) {
220+
assert((Annot == 1) && "Unexpected annotation on a sampler symbol");
221+
return true;
208222
}
209223
}
210-
return false;
224+
return argHasNVVMAnnotation(val, AnnotationName);
225+
}
226+
227+
bool isImageReadOnly(const Value &val) {
228+
return argHasNVVMAnnotation(val, "rdoimage");
229+
}
230+
231+
bool isImageWriteOnly(const Value &val) {
232+
return argHasNVVMAnnotation(val, "wroimage");
211233
}
212234

213235
bool isImageReadWrite(const Value &val) {
214-
if (const Argument *arg = dyn_cast<Argument>(&val)) {
215-
const Function *func = arg->getParent();
216-
std::vector<unsigned> annot;
217-
if (findAllNVVMAnnotation(func, "rdwrimage", annot)) {
218-
if (is_contained(annot, arg->getArgNo()))
219-
return true;
220-
}
221-
}
222-
return false;
236+
return argHasNVVMAnnotation(val, "rdwrimage");
223237
}
224238

225239
bool isImage(const Value &val) {
@@ -228,9 +242,9 @@ bool isImage(const Value &val) {
228242

229243
bool isManaged(const Value &val) {
230244
if(const GlobalValue *gv = dyn_cast<GlobalValue>(&val)) {
231-
unsigned annot;
232-
if (findOneNVVMAnnotation(gv, "managed", annot)) {
233-
assert((annot == 1) && "Unexpected annotation on a managed symbol");
245+
unsigned Annot;
246+
if (findOneNVVMAnnotation(gv, "managed", Annot)) {
247+
assert((Annot == 1) && "Unexpected annotation on a managed symbol");
234248
return true;
235249
}
236250
}
@@ -290,8 +304,7 @@ bool getMaxNReg(const Function &F, unsigned &x) {
290304

291305
bool isKernelFunction(const Function &F) {
292306
unsigned x = 0;
293-
bool retval = findOneNVVMAnnotation(&F, "kernel", x);
294-
if (!retval) {
307+
if (!findOneNVVMAnnotation(&F, "kernel", x)) {
295308
// There is no NVVM metadata, check the calling convention
296309
return F.getCallingConv() == CallingConv::PTX_Kernel;
297310
}

llvm/lib/Target/NVPTX/NVPTXUtilities.h

+1
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,7 @@ bool getMaxClusterRank(const Function &, unsigned &);
6060
bool getMinCTASm(const Function &, unsigned &);
6161
bool getMaxNReg(const Function &, unsigned &);
6262
bool isKernelFunction(const Function &);
63+
bool isParamGridConstant(const Value &);
6364

6465
MaybeAlign getAlign(const Function &, unsigned);
6566
MaybeAlign getAlign(const CallInst &, unsigned);

0 commit comments

Comments
 (0)