Skip to content

[SYCL][NVPTX] Emit 'grid_constant' annotations for by-val kernel params #14332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 10, 2024
64 changes: 64 additions & 0 deletions clang/lib/CodeGen/Targets/NVPTX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "ABIInfoImpl.h"
#include "TargetInfo.h"
#include "clang/Basic/Cuda.h"
#include "llvm/IR/IntrinsicsNVPTX.h"

using namespace clang;
Expand Down Expand Up @@ -80,6 +81,9 @@ class NVPTXTargetCodeGenInfo : public TargetCodeGenInfo {
static void addNVVMMetadata(llvm::GlobalValue *GV, StringRef Name,
int Operand);

static void addNVVMMetadata(llvm::GlobalValue *GV, StringRef Name,
const std::vector<int> &Operands);

private:
static void emitBuiltinSurfTexDeviceCopy(CodeGenFunction &CGF, LValue Dst,
LValue Src) {
Expand Down Expand Up @@ -218,6 +222,28 @@ RValue NVPTXABIInfo::EmitVAArg(CodeGenFunction &CGF, Address VAListAddr,
llvm_unreachable("NVPTX does not support varargs");
}

// Look up the CUDA architecture selected for the current target.
// Returns CudaArch::UNKNOWN when the target lacks the "ptx" feature, or when
// none of the enabled target features names a recognised architecture.
// Copied from CGOpenMPRuntimeGPU
static CudaArch getCudaArch(CodeGenModule &CGM) {
  const auto &Target = CGM.getTarget();
  if (Target.hasFeature("ptx")) {
    for (const auto &Feature : Target.getTargetOpts().FeatureMap) {
      // Disabled features cannot select an architecture.
      if (!Feature.getValue())
        continue;
      // Feature names that are not architectures map to UNKNOWN; keep looking.
      const CudaArch Arch = StringToCudaArch(Feature.getKey());
      if (Arch != CudaArch::UNKNOWN)
        return Arch;
    }
  }
  return CudaArch::UNKNOWN;
}

// Report whether grid_constant annotations should be emitted for Arch: true
// for SM_70 and every enumerator ordered after it, false otherwise.
static bool supportsGridConstant(CudaArch Arch) {
  // The ordered comparison below relies on UNKNOWN sorting before SM_70, so
  // that an unknown architecture is reported as unsupported.
  static_assert(CudaArch::UNKNOWN < CudaArch::SM_70);
  assert((Arch == CudaArch::UNKNOWN || IsNVIDIAGpuArch(Arch)) &&
         "Unexpected architecture");
  return !(Arch < CudaArch::SM_70);
}

void NVPTXTargetCodeGenInfo::setTargetAttributes(
const Decl *D, llvm::GlobalValue *GV, CodeGen::CodeGenModule &M) const {
if (GV->isDeclaration())
Expand Down Expand Up @@ -248,6 +274,22 @@ void NVPTXTargetCodeGenInfo::setTargetAttributes(
addNVVMMetadata(F, "kernel", 1);
// And kernel functions are not subject to inlining
F->addFnAttr(llvm::Attribute::NoInline);

if (M.getLangOpts().SYCLIsDevice &&
    supportsGridConstant(getCudaArch(M))) {
  // In SYCL every by-val kernel parameter is guaranteed to be
  // "grid_constant", so gather the positions of all such parameters and
  // record them in a single annotation on the kernel.
  std::vector<int> GridConstantParamIdxs;
  // Note - the parameter indices are numbered from 1.
  int ParamNo = 0;
  for (const auto &Arg : F->args()) {
    ++ParamNo;
    if (!Arg.getType()->isPointerTy() || !Arg.hasByValAttr())
      continue;
    GridConstantParamIdxs.push_back(ParamNo);
  }
  // Emit nothing at all when the kernel has no by-val parameters.
  if (!GridConstantParamIdxs.empty())
    addNVVMMetadata(F, "grid_constant", GridConstantParamIdxs);
}
}
bool HasMaxWorkGroupSize = false;
bool HasMinWorkGroupPerCU = false;
Expand Down Expand Up @@ -329,6 +371,28 @@ void NVPTXTargetCodeGenInfo::addNVVMMetadata(llvm::GlobalValue *GV,
MD->addOperand(llvm::MDNode::get(Ctx, MDVals));
}

// Append an annotation of the form {GV, Name, {Operands...}} to the
// "nvvm.annotations" named metadata of the module that owns GV. Each entry of
// Operands becomes an i32 constant inside a nested metadata list.
void NVPTXTargetCodeGenInfo::addNVVMMetadata(llvm::GlobalValue *GV,
                                             StringRef Name,
                                             const std::vector<int> &Operands) {
  llvm::Module *ParentModule = GV->getParent();
  llvm::LLVMContext &Context = ParentModule->getContext();

  // Look up "nvvm.annotations", creating it if the module has none yet.
  llvm::NamedMDNode *Annotations =
      ParentModule->getOrInsertNamedMetadata("nvvm.annotations");

  // Wrap every integer operand as i32 constant metadata.
  llvm::Type *Int32Ty = llvm::Type::getInt32Ty(Context);
  llvm::SmallVector<llvm::Metadata *, 8> OperandMDs;
  for (const int Value : Operands)
    OperandMDs.push_back(
        llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(Int32Ty, Value)));
  llvm::MDNode *OperandList = llvm::MDNode::get(Context, OperandMDs);

  // The annotation itself: the annotated global, the annotation name, and the
  // nested operand list.
  llvm::Metadata *AnnotationFields[] = {llvm::ConstantAsMetadata::get(GV),
                                        llvm::MDString::get(Context, Name),
                                        OperandList};
  Annotations->addOperand(llvm::MDNode::get(Context, AnnotationFields));
}

// Unconditionally reports false for the NVPTX target.
bool NVPTXTargetCodeGenInfo::shouldEmitStaticExternCAliases() const {
return false;
}
Expand Down
34 changes: 34 additions & 0 deletions clang/test/CodeGenSYCL/nvvm-annotations.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// RUN: %clang_cc1 -fno-sycl-force-inline-kernel-lambda -fsycl-is-device -internal-isystem %S/Inputs -triple nvptx-nvidia-cuda -target-cpu sm_70 -disable-llvm-passes -sycl-std=2020 -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,GRIDCONST
// RUN: %clang_cc1 -fno-sycl-force-inline-kernel-lambda -fsycl-is-device -internal-isystem %S/Inputs -triple nvptx64-nvidia-cuda -target-cpu sm_70 -disable-llvm-passes -sycl-std=2020 -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,GRIDCONST

// RUN: %clang_cc1 -fno-sycl-force-inline-kernel-lambda -fsycl-is-device -internal-isystem %S/Inputs -triple nvptx-nvidia-cuda -target-cpu sm_60 -disable-llvm-passes -sycl-std=2020 -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,NOGRIDCONST
// RUN: %clang_cc1 -fno-sycl-force-inline-kernel-lambda -fsycl-is-device -internal-isystem %S/Inputs -triple nvptx64-nvidia-cuda -target-cpu sm_60 -disable-llvm-passes -sycl-std=2020 -emit-llvm -o - %s | FileCheck %s --check-prefixes=CHECK,NOGRIDCONST

// Tests that certain SYCL kernel parameters are annotated with "grid_constant" for supported microarchitectures.

#include "sycl.hpp"

using namespace sycl;

int main() {
queue q;

struct S {
int a;
} s;

q.submit([&](handler &h) {
// CHECK: define{{.*}} void @[[FUNC1:.*kernel_grid_const_params]](ptr noundef byval(%struct.S) align 4 %_arg_s)
h.single_task<class kernel_grid_const_params>([=]() { (void) s;});
});

return 0;
}

// Don't emit grid_constant annotations for older architectures.
// NOGRIDCONST-NOT: "grid_constant"

// This isn't stable in general, as it depends on the order of the captured
// parameters, but in this case there's only one parameter so we know it's 1.
// GRIDCONST-DAG: = !{ptr @[[FUNC1]], !"grid_constant", [[MD:\![0-9]+]]}
// GRIDCONST-DAG: [[MD]] = !{i32 1}
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ class DeadArgumentEliminationPass
bool removeDeadArgumentsFromCallers(Function &F);
void propagateVirtMustcallLiveness(const Module &M);

void UpdateNVPTXMetadata(Module &M, Function *F, Function *NF);
void UpdateNVPTXMetadata(Module &M, Function *F, Function *NF,
const SmallVectorImpl<bool> &ArgAlive);
llvm::DenseSet<Function *> NVPTXKernelSet;

bool IsNVPTXKernel(const Function *F) { return NVPTXKernelSet.contains(F); };
Expand Down
74 changes: 70 additions & 4 deletions llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1171,7 +1171,7 @@ bool DeadArgumentEliminationPass::removeDeadStuffFromFunction(Function *F) {
NF->addMetadata(KindID, *Node);

if (IsNVPTXKernel(F))
UpdateNVPTXMetadata(*(F->getParent()), F, NF);
UpdateNVPTXMetadata(*(F->getParent()), F, NF, ArgAlive);

// If either the return value(s) or argument(s) are removed, then probably the
// function does not follow standard calling conventions anymore. Hence, add
Expand Down Expand Up @@ -1249,8 +1249,9 @@ PreservedAnalyses DeadArgumentEliminationPass::run(Module &M,
return PreservedAnalyses::none();
}

void DeadArgumentEliminationPass::UpdateNVPTXMetadata(Module &M, Function *F,
Function *NF) {
void DeadArgumentEliminationPass::UpdateNVPTXMetadata(
Module &M, Function *F, Function *NF,
const SmallVectorImpl<bool> &ArgAlive) {

auto *NvvmMetadata = M.getNamedMetadata("nvvm.annotations");
if (!NvvmMetadata)
Expand All @@ -1260,13 +1261,78 @@ void DeadArgumentEliminationPass::UpdateNVPTXMetadata(Module &M, Function *F,
const auto &FuncOperand = MetadataNode->getOperand(0);
if (!FuncOperand)
continue;
auto FuncConstant = dyn_cast<ConstantAsMetadata>(FuncOperand);
auto *FuncConstant = dyn_cast<ConstantAsMetadata>(FuncOperand);
if (!FuncConstant)
continue;
auto *Func = dyn_cast<Function>(FuncConstant->getValue());
if (Func != F)
continue;
// Update the metadata with the new function
MetadataNode->replaceOperandWith(0, llvm::ConstantAsMetadata::get(NF));

// Carefully update any and all grid_constant annotations on this node: they
// hold 1-based parameter indices, which may have changed now that dead
// arguments have been removed. The operands following the function are
// visited as (name, value) pairs, hence the stride of two.
for (unsigned I = 1; I + 1 < MetadataNode->getNumOperands(); I += 2) {
  const auto *AnnotationName =
      dyn_cast<MDString>(MetadataNode->getOperand(I));
  if (!AnnotationName || AnnotationName->getString() != "grid_constant")
    continue;

  LLVMContext &Ctx = NF->getContext();
  LLVM_DEBUG(dbgs() << "DeadArgumentEliminationPass - updating nvvm "
                       "grid_constant annotations for fn: "
                    << NF->getName() << "\n");

  // The 'value' operand is expected to be a list of integers denoting
  // parameter indices. If it is anything else the annotation is malformed;
  // leave it untouched instead of dereferencing a null pointer, which is
  // what an assert-only check would do in a build without assertions.
  const auto *OldGridConstParamIdxs =
      dyn_cast<MDNode>(MetadataNode->getOperand(I + 1));
  if (!OldGridConstParamIdxs)
    continue;

  // For each parameter that's identified as a grid_constant, count how many
  // arguments before that position are dead, and shift the index down by
  // that amount.
  // Note that the parameter indices come in no guaranteed order, so the
  // count is recomputed for every index rather than accumulated in one pass.
  SmallVector<Metadata *, 8> NewGridConstParamOps;
  for (const auto &Op : OldGridConstParamIdxs->operands()) {
    auto *ParamIdx = mdconst::dyn_extract<ConstantInt>(Op);
    // If the operand's not a constant, or its constant value is not within
    // the range of the old function's parameter list (note - it counts from
    // 1), it's not well-defined. Just strip it out for safety.
    if (!ParamIdx || ParamIdx->isZero() ||
        ParamIdx->getZExtValue() > F->arg_size())
      continue;

    // Zero-based position in the old parameter list; the range check above
    // guarantees it is smaller than F->arg_size().
    const size_t OldParamIdx = ParamIdx->getZExtValue() - 1;
    // If the parameter is no longer alive, it's definitely not a
    // grid_constant. Strip it out.
    if (!ArgAlive[OldParamIdx])
      continue;

    // Every dead argument in front of this one moves it one slot down.
    unsigned NumDeadBefore = 0;
    for (size_t ArgNo = 0; ArgNo < OldParamIdx; ++ArgNo)
      if (!ArgAlive[ArgNo])
        ++NumDeadBefore;

    // Convert back to the 1-based numbering used by the annotation.
    NewGridConstParamOps.push_back(ConstantAsMetadata::get(ConstantInt::get(
        Type::getInt32Ty(Ctx), OldParamIdx - NumDeadBefore + 1)));
  }

  // Update the metadata with the new grid_constant information
  MDNode *NewGridConstParamIdxs = MDNode::get(Ctx, NewGridConstParamOps);

  LLVM_DEBUG(dbgs() << " * updating old annotation {";
             auto PrintList =
                 [](const MDNode *MD) {
                   for (const auto &O : MD->operands())
                     if (const auto *ParamNo =
                             mdconst::dyn_extract<ConstantInt>(O))
                       dbgs() << ParamNo->getZExtValue() << ",";
                 };
             PrintList(OldGridConstParamIdxs);
             dbgs() << "} to new annotation {";
             PrintList(NewGridConstParamIdxs); dbgs() << "}\n";);

  MetadataNode->replaceOperandWith(I + 1, NewGridConstParamIdxs);
}
}
}
113 changes: 113 additions & 0 deletions llvm/test/Transforms/DeadArgElim/nvvm-annotations.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals all --version 5
; RUN: opt < %s -passes=deadargelim -S | FileCheck %s

define internal void @test1(i32 %v, ptr byval(i32) %DEADARG1, ptr %p) {
; CHECK-LABEL: define internal void @test1(
; CHECK-SAME: i32 [[V:%.*]], ptr [[P:%.*]]) {
; CHECK-NEXT: store i32 [[V]], ptr [[P]], align 4
; CHECK-NEXT: ret void
;
store i32 %v, ptr %p
ret void
}

define internal void @test2(ptr byval(i32) %DEADARG1, ptr byval(i32) %p) {
; CHECK-LABEL: define internal void @test2(
; CHECK-SAME: ptr byval(i32) [[P:%.*]]) {
; CHECK-NEXT: store i32 0, ptr [[P]], align 4
; CHECK-NEXT: ret void
;
store i32 0, ptr %p
ret void
}

define internal void @test3(ptr byval(i32) %DEADARG1, i32 %v, ptr byval(i32) %p) {
; CHECK-LABEL: define internal void @test3(
; CHECK-SAME: i32 [[V:%.*]], ptr byval(i32) [[P:%.*]]) {
; CHECK-NEXT: store i32 [[V]], ptr [[P]], align 4
; CHECK-NEXT: ret void
;
store i32 %v, ptr %p
ret void
}

define internal void @test4(ptr byval(i32) %p, i32 %v, ptr byval(i32) %DEADARG) {
; CHECK-LABEL: define internal void @test4(
; CHECK-SAME: ptr byval(i32) [[P:%.*]], i32 [[V:%.*]]) {
; CHECK-NEXT: store i32 [[V]], ptr [[P]], align 4
; CHECK-NEXT: ret void
;
store i32 %v, ptr %p
ret void
}

define internal void @test5(ptr byval(i32) %p, i32 %x, ptr byval(i32) %DEADARG1, ptr byval(i32) %DEADARG2, i32 %y, ptr byval(i32) %q) {
; CHECK-LABEL: define internal void @test5(
; CHECK-SAME: ptr byval(i32) [[P:%.*]], i32 [[X:%.*]], i32 [[Y:%.*]], ptr byval(i32) [[Q:%.*]]) {
; CHECK-NEXT: [[T:%.*]] = add i32 [[X]], [[Y]]
; CHECK-NEXT: store i32 [[T]], ptr [[P]], align 4
; CHECK-NEXT: store i32 [[T]], ptr [[Q]], align 4
; CHECK-NEXT: ret void
;
%t = add i32 %x, %y
store i32 %t, ptr %p
store i32 %t, ptr %q
ret void
}

!nvvm.annotations = !{
!0, !1,
!3, !4, !6,
!8, !9, !11,
!13, !14, !16,
!18, !19
}

; Note - also test various permutations of the parameter lists, as they are not
; specified to be in any particular order (e.g., consecutive).
!0 = !{ptr @test1, !"kernel", i32 1}
!1 = !{ptr @test1, !"grid_constant", !2}
!2 = !{i32 2}

!3 = !{ptr @test2, !"kernel", i32 1}
!4 = !{ptr @test2, !"grid_constant", !5}
!5 = !{i32 1, i32 2}
!6 = !{ptr @test2, !"grid_constant", !7}
!7 = !{i32 2, i32 1}

!8 = !{ptr @test3, !"kernel", i32 1}
!9 = !{ptr @test3, !"grid_constant", !10}
!10 = !{i32 1, i32 3}
!11 = !{ptr @test3, !"grid_constant", !12}
!12 = !{i32 3, i32 1}

!13 = !{ptr @test4, !"kernel", i32 1}
!14 = !{ptr @test4, !"grid_constant", !15}
!15 = !{i32 1, i32 3}
!16 = !{ptr @test4, !"grid_constant", !17}
!17 = !{i32 3, i32 1}

!18 = !{ptr @test5, !"kernel", i32 1}
!19 = !{ptr @test5, !"grid_constant", !20, !"grid_constant", !21, !"grid_constant", !22}
!20 = !{i32 1, i32 3, i32 4, i32 6}
!21 = !{i32 3, i32 1, i32 4, i32 6}
!22 = !{i32 3, i32 1, i32 6, i32 4}
;.
; CHECK: [[META0:![0-9]+]] = !{ptr @test1, !"kernel", i32 1}
; CHECK: [[META1:![0-9]+]] = !{ptr @test1, !"grid_constant", [[META2:![0-9]+]]}
; CHECK: [[META2]] = !{}
; CHECK: [[META3:![0-9]+]] = !{ptr @test2, !"kernel", i32 1}
; CHECK: [[META4:![0-9]+]] = !{ptr @test2, !"grid_constant", [[META5:![0-9]+]]}
; CHECK: [[META5]] = !{i32 1}
; CHECK: [[META6:![0-9]+]] = distinct !{ptr @test2, !"grid_constant", [[META5]]}
; CHECK: [[META7:![0-9]+]] = !{ptr @test3, !"kernel", i32 1}
; CHECK: [[META8:![0-9]+]] = !{ptr @test3, !"grid_constant", [[META9:![0-9]+]]}
; CHECK: [[META9]] = !{i32 2}
; CHECK: [[META10:![0-9]+]] = distinct !{ptr @test3, !"grid_constant", [[META9]]}
; CHECK: [[META11:![0-9]+]] = !{ptr @test4, !"kernel", i32 1}
; CHECK: [[META12:![0-9]+]] = !{ptr @test4, !"grid_constant", [[META5]]}
; CHECK: [[META13:![0-9]+]] = distinct !{ptr @test4, !"grid_constant", [[META5]]}
; CHECK: [[META14:![0-9]+]] = !{ptr @test5, !"kernel", i32 1}
; CHECK: [[META15:![0-9]+]] = !{ptr @test5, !"grid_constant", [[META16:![0-9]+]], !"grid_constant", [[META16]], !"grid_constant", [[META16]]}
; CHECK: [[META16]] = !{i32 1, i32 4}
;.
Loading