Skip to content

[SPIR-V] Cast ptr kernel args to i8* when used as Store's value operand #78603

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/lib/Target/SPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ add_llvm_target(SPIRVCodeGen
SPIRVISelLowering.cpp
SPIRVLegalizerInfo.cpp
SPIRVMCInstLower.cpp
SPIRVMetadata.cpp
SPIRVModuleAnalysis.cpp
SPIRVPreLegalizer.cpp
SPIRVPrepareFunctions.cpp
Expand Down
63 changes: 5 additions & 58 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "SPIRVBuiltins.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVISelLowering.h"
#include "SPIRVMetadata.h"
#include "SPIRVRegisterInfo.h"
#include "SPIRVSubtarget.h"
#include "SPIRVUtils.h"
Expand Down Expand Up @@ -117,64 +118,12 @@ static FunctionType *getOriginalFunctionType(const Function &F) {
return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
}

static MDString *getKernelArgAttribute(const Function &KernelFunction,
unsigned ArgIdx,
const StringRef AttributeName) {
assert(KernelFunction.getCallingConv() == CallingConv::SPIR_KERNEL &&
"Kernel attributes are attached/belong only to kernel functions");

// Lookup the argument attribute in metadata attached to the kernel function.
MDNode *Node = KernelFunction.getMetadata(AttributeName);
if (Node && ArgIdx < Node->getNumOperands())
return cast<MDString>(Node->getOperand(ArgIdx));

// Sometimes metadata containing kernel attributes is not attached to the
// function, but can be found in the named module-level metadata instead.
// For example:
// !opencl.kernels = !{!0}
// !0 = !{void ()* @someKernelFunction, !1, ...}
// !1 = !{!"kernel_arg_addr_space", ...}
// In this case the actual index of searched argument attribute is ArgIdx + 1,
// since the first metadata node operand is occupied by attribute name
// ("kernel_arg_addr_space" in the example above).
unsigned MDArgIdx = ArgIdx + 1;
NamedMDNode *OpenCLKernelsMD =
KernelFunction.getParent()->getNamedMetadata("opencl.kernels");
if (!OpenCLKernelsMD || OpenCLKernelsMD->getNumOperands() == 0)
return nullptr;

// KernelToMDNodeList contains kernel function declarations followed by
// corresponding MDNodes for each attribute. Search only MDNodes "belonging"
// to the currently lowered kernel function.
MDNode *KernelToMDNodeList = OpenCLKernelsMD->getOperand(0);
bool FoundLoweredKernelFunction = false;
for (const MDOperand &Operand : KernelToMDNodeList->operands()) {
ValueAsMetadata *MaybeValue = dyn_cast<ValueAsMetadata>(Operand);
if (MaybeValue && dyn_cast<Function>(MaybeValue->getValue())->getName() ==
KernelFunction.getName()) {
FoundLoweredKernelFunction = true;
continue;
}
if (MaybeValue && FoundLoweredKernelFunction)
return nullptr;

MDNode *MaybeNode = dyn_cast<MDNode>(Operand);
if (FoundLoweredKernelFunction && MaybeNode &&
cast<MDString>(MaybeNode->getOperand(0))->getString() ==
AttributeName &&
MDArgIdx < MaybeNode->getNumOperands())
return cast<MDString>(MaybeNode->getOperand(MDArgIdx));
}
return nullptr;
}

static SPIRV::AccessQualifier::AccessQualifier
getArgAccessQual(const Function &F, unsigned ArgIdx) {
if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
return SPIRV::AccessQualifier::ReadWrite;

MDString *ArgAttribute =
getKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual");
MDString *ArgAttribute = getOCLKernelArgAccessQual(F, ArgIdx);
if (!ArgAttribute)
return SPIRV::AccessQualifier::ReadWrite;

Expand All @@ -186,9 +135,8 @@ getArgAccessQual(const Function &F, unsigned ArgIdx) {
}

static std::vector<SPIRV::Decoration::Decoration>
getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) {
MDString *ArgAttribute =
getKernelArgAttribute(KernelFunction, ArgIdx, "kernel_arg_type_qual");
getKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
MDString *ArgAttribute = getOCLKernelArgTypeQual(F, ArgIdx);
if (ArgAttribute && ArgAttribute->getString().compare("volatile") == 0)
return {SPIRV::Decoration::Volatile};
return {};
Expand All @@ -209,8 +157,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
isSpecialOpaqueType(OriginalArgType))
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);

MDString *MDKernelArgType =
getKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
MDString *MDKernelArgType = getOCLKernelArgType(F, ArgIdx);
if (!MDKernelArgType || (!MDKernelArgType->getString().ends_with("*") &&
!MDKernelArgType->getString().ends_with("_t")))
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
Expand Down
34 changes: 28 additions & 6 deletions llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//

#include "SPIRV.h"
#include "SPIRVMetadata.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/IR/IRBuilder.h"
Expand Down Expand Up @@ -282,7 +283,26 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {
Value *Pointer;
Type *ExpectedElementType;
unsigned OperandToReplace;
if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
bool AllowCastingToChar = false;

StoreInst *SI = dyn_cast<StoreInst>(I);
if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
SI->getValueOperand()->getType()->isPointerTy() &&
isa<Argument>(SI->getValueOperand())) {
Argument *Arg = cast<Argument>(SI->getValueOperand());
MDString *ArgType = getOCLKernelArgType(*Arg->getParent(), Arg->getArgNo());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can Arg be nullptr?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should be able to safely use cast<> here instead of dyn_cast since you have the isa check.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @sudonatalie! Fixed!

if (!ArgType || ArgType->getString().starts_with("uchar*"))
return;

// Handle special case when StoreInst's value operand is a kernel argument
// of a pointer type. Since these arguments could have either a basic
// element type (e.g. float*) or OpenCL builtin type (sampler_t), bitcast
// the StoreInst's value operand to default pointer element type (i8).
Pointer = Arg;
ExpectedElementType = IntegerType::getInt8Ty(F->getContext());
OperandToReplace = 0;
AllowCastingToChar = true;
} else if (SI) {
Pointer = SI->getPointerOperand();
ExpectedElementType = SI->getValueOperand()->getType();
OperandToReplace = 1;
Expand Down Expand Up @@ -364,13 +384,15 @@ void SPIRVEmitIntrinsics::insertPtrCastInstr(Instruction *I) {

// Do not emit spv_ptrcast if it would cast to the default pointer element
// type (i8) of the same address space.
if (ExpectedElementType->isIntegerTy(8))
if (ExpectedElementType->isIntegerTy(8) && !AllowCastingToChar)
return;

// If this would be the first spv_ptrcast and there is no spv_assign_ptr_type
// for this pointer before, do not emit spv_ptrcast but emit
// spv_assign_ptr_type instead.
if (FirstPtrCastOrAssignPtrType && isa<Instruction>(Pointer)) {
// If this would be the first spv_ptrcast, the pointer's defining instruction
// requires spv_assign_ptr_type and does not already have one, do not emit
// spv_ptrcast and emit spv_assign_ptr_type instead.
Instruction *PointerDefInst = dyn_cast<Instruction>(Pointer);
if (FirstPtrCastOrAssignPtrType && PointerDefInst &&
requireAssignPtrType(PointerDefInst)) {
buildIntrWithMD(Intrinsic::spv_assign_ptr_type, {Pointer->getType()},
ExpectedElementTypeConst, Pointer,
{IRB->getInt32(AddressSpace)});
Expand Down
92 changes: 92 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVMetadata.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
//===--- SPIRVMetadata.cpp ---- IR Metadata Parsing Funcs -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains functions needed for parsing LLVM IR metadata relevant
// to the SPIR-V target.
//
//===----------------------------------------------------------------------===//

#include "SPIRVMetadata.h"

using namespace llvm;

static MDString *getOCLKernelArgAttribute(const Function &F, unsigned ArgIdx,
const StringRef AttributeName) {
assert(
F.getCallingConv() == CallingConv::SPIR_KERNEL &&
"Kernel attributes are attached/belong only to OpenCL kernel functions");

// Lookup the argument attribute in metadata attached to the kernel function.
MDNode *Node = F.getMetadata(AttributeName);
if (Node && ArgIdx < Node->getNumOperands())
return cast<MDString>(Node->getOperand(ArgIdx));

// Sometimes metadata containing kernel attributes is not attached to the
// function, but can be found in the named module-level metadata instead.
// For example:
// !opencl.kernels = !{!0}
// !0 = !{void ()* @someKernelFunction, !1, ...}
// !1 = !{!"kernel_arg_addr_space", ...}
// In this case the actual index of searched argument attribute is ArgIdx + 1,
// since the first metadata node operand is occupied by attribute name
// ("kernel_arg_addr_space" in the example above).
unsigned MDArgIdx = ArgIdx + 1;
NamedMDNode *OpenCLKernelsMD =
F.getParent()->getNamedMetadata("opencl.kernels");
if (!OpenCLKernelsMD || OpenCLKernelsMD->getNumOperands() == 0)
return nullptr;

// KernelToMDNodeList contains kernel function declarations followed by
// corresponding MDNodes for each attribute. Search only MDNodes "belonging"
// to the currently lowered kernel function.
MDNode *KernelToMDNodeList = OpenCLKernelsMD->getOperand(0);
bool FoundLoweredKernelFunction = false;
for (const MDOperand &Operand : KernelToMDNodeList->operands()) {
ValueAsMetadata *MaybeValue = dyn_cast<ValueAsMetadata>(Operand);
if (MaybeValue &&
dyn_cast<Function>(MaybeValue->getValue())->getName() == F.getName()) {
FoundLoweredKernelFunction = true;
continue;
}
if (MaybeValue && FoundLoweredKernelFunction)
return nullptr;

MDNode *MaybeNode = dyn_cast<MDNode>(Operand);
if (FoundLoweredKernelFunction && MaybeNode &&
cast<MDString>(MaybeNode->getOperand(0))->getString() ==
AttributeName &&
MDArgIdx < MaybeNode->getNumOperands())
return cast<MDString>(MaybeNode->getOperand(MDArgIdx));
}
return nullptr;
}

namespace llvm {

MDString *getOCLKernelArgAccessQual(const Function &F, unsigned ArgIdx) {
assert(
F.getCallingConv() == CallingConv::SPIR_KERNEL &&
"Kernel attributes are attached/belong only to OpenCL kernel functions");
return getOCLKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual");
}

MDString *getOCLKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
assert(
F.getCallingConv() == CallingConv::SPIR_KERNEL &&
"Kernel attributes are attached/belong only to OpenCL kernel functions");
return getOCLKernelArgAttribute(F, ArgIdx, "kernel_arg_type_qual");
}

MDString *getOCLKernelArgType(const Function &F, unsigned ArgIdx) {
assert(
F.getCallingConv() == CallingConv::SPIR_KERNEL &&
"Kernel attributes are attached/belong only to OpenCL kernel functions");
return getOCLKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
}

} // namespace llvm
31 changes: 31 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVMetadata.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//===--- SPIRVMetadata.h ---- IR Metadata Parsing Funcs ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains functions needed for parsing LLVM IR metadata relevant
// to the SPIR-V target.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVMETADATA_H
#define LLVM_LIB_TARGET_SPIRV_SPIRVMETADATA_H

#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"

namespace llvm {

//===----------------------------------------------------------------------===//
// OpenCL Metadata
//

MDString *getOCLKernelArgAccessQual(const Function &F, unsigned ArgIdx);
MDString *getOCLKernelArgTypeQual(const Function &F, unsigned ArgIdx);
MDString *getOCLKernelArgType(const Function &F, unsigned ArgIdx);

} // namespace llvm
#endif // LLVM_LIB_TARGET_SPIRV_METADATA_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK-DAG: %[[#CHAR:]] = OpTypeInt 8
; CHECK-DAG: %[[#GLOBAL_PTR_CHAR:]] = OpTypePointer CrossWorkgroup %[[#CHAR]]

define spir_kernel void @foo(ptr addrspace(1) %arg) {
ret void
}

; CHECK: %[[#]] = OpFunctionParameter %[[#GLOBAL_PTR_CHAR]]
14 changes: 14 additions & 0 deletions llvm/test/CodeGen/SPIRV/pointers/kernel-argument-ptr-no-bitcast.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK-DAG: %[[#CHAR:]] = OpTypeInt 8
; CHECK-DAG: %[[#GLOBAL_PTR_CHAR:]] = OpTypePointer CrossWorkgroup %[[#CHAR]]

define spir_kernel void @foo(i8 %a, ptr addrspace(1) %p) {
store i8 %a, ptr addrspace(1) %p
ret void
}

; CHECK: %[[#A:]] = OpFunctionParameter %[[#CHAR]]
; CHECK: %[[#P:]] = OpFunctionParameter %[[#GLOBAL_PTR_CHAR]]
; CHECK: OpStore %[[#P]] %[[#A]]
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK-DAG: %[[#CHAR:]] = OpTypeInt 8
; CHECK-DAG: %[[#GLOBAL_PTR_CHAR:]] = OpTypePointer CrossWorkgroup %[[#CHAR]]

define spir_kernel void @foo(ptr addrspace(1) %arg) !kernel_arg_addr_space !1 !kernel_arg_access_qual !2 !kernel_arg_type !3 !kernel_arg_base_type !3 !kernel_arg_type_qual !4 {
%var = alloca ptr addrspace(1), align 8
; CHECK: %[[#]] = OpFunctionParameter %[[#GLOBAL_PTR_CHAR]]
; CHECK-NOT: %[[#]] = OpBitcast %[[#]] %[[#]]
store ptr addrspace(1) %arg, ptr %var, align 8
ret void
}

!1 = !{i32 1}
!2 = !{!"none"}
!3 = !{!"char*"}
!4 = !{!""}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

define spir_kernel void @foo(ptr addrspace(1) %arg) !kernel_arg_addr_space !1 !kernel_arg_access_qual !2 !kernel_arg_type !3 !kernel_arg_base_type !3 !kernel_arg_type_qual !4 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have a test that covers the case you mentioned where the arguments can be found in the module-level metadata?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we do have a test in llvm/test/CodeGen/SPIRV/opencl/metadata/kernel_arg_type_module_metadata.ll

%var = alloca ptr addrspace(1), align 8
; CHECK: %[[#VAR:]] = OpVariable %[[#]] Function
store ptr addrspace(1) %arg, ptr %var, align 8
; The test itends to verify that OpStore uses OpVariable result directly (without a bitcast).
; Other type checking is done by spirv-val.
; CHECK: OpStore %[[#VAR]] %[[#]] Aligned 8
%lod = load ptr addrspace(1), ptr %var, align 8
%idx = getelementptr inbounds i64, ptr addrspace(1) %lod, i64 0
ret void
}

!1 = !{i32 1}
!2 = !{!"none"}
!3 = !{!"ulong*"}
!4 = !{!""}