Skip to content

Commit b42e92e

Browse files
AlexeySachkovbader
authored andcommitted
Implementation of SPV_INTEL_function_pointers extension
The extension specification is published here https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/spirv/SPV_INTEL_function_pointers.asciidoc Overview: This extensions brings two "levels" of function pointers support added under corresponding capabilities: Two new instructions are added under FunctionPointersINTEL capability: - OpFunctionPointerINTEL to support "address of" operator for functions, - OpFunctionPointerCallINTEL to do indirect function calls. ReferencedIndirectlyINTEL decoration is added under IndirectReferencesINTEL capability. This decoration can be attached to functions which are not referenced directly in the module. These function must not be optimized out based on call graph/reachability analysis. Signed-off-by: Alexey Sachkov <[email protected]> Signed-off-by: Alexey Sotkin <[email protected]>
1 parent 9cde62e commit b42e92e

20 files changed

+773
-12
lines changed

llvm-spirv/lib/SPIRV/SPIRVReader.cpp

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -383,11 +383,12 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool IsClassMember) {
383383
case OpTypeArray:
384384
return mapType(T, ArrayType::get(transType(T->getArrayElementType()),
385385
T->getArrayLength()));
386-
case OpTypePointer:
386+
case OpTypePointer: {
387387
return mapType(
388388
T, PointerType::get(
389389
transType(T->getPointerElementType(), IsClassMember),
390390
SPIRSPIRVAddrSpaceMap::rmap(T->getPointerStorageClass())));
391+
}
391392
case OpTypeVector:
392393
return mapType(T, VectorType::get(transType(T->getVectorComponentType()),
393394
T->getVectorComponentCount()));
@@ -500,8 +501,19 @@ std::string SPIRVToLLVM::transTypeToOCLTypeName(SPIRVType *T, bool IsSigned) {
500501
break;
501502
case OpTypeArray:
502503
return "array";
503-
case OpTypePointer:
504-
return transTypeToOCLTypeName(T->getPointerElementType()) + "*";
504+
case OpTypePointer: {
505+
SPIRVType *ET = T->getPointerElementType();
506+
if (isa<OpTypeFunction>(ET)) {
507+
SPIRVTypeFunction *TF = static_cast<SPIRVTypeFunction*>(ET);
508+
std::string name = transTypeToOCLTypeName(TF->getReturnType());
509+
name += " (*)(";
510+
for(unsigned I = 0, E = TF->getNumParameters(); I < E; ++I)
511+
name += transTypeToOCLTypeName(TF->getParameterType(I)) + ',';
512+
name.back() = ')'; // replace the last comma with a closing brace.
513+
return name;
514+
}
515+
return transTypeToOCLTypeName(ET) + "*";
516+
}
505517
case OpTypeVector:
506518
return transTypeToOCLTypeName(T->getVectorComponentType()) +
507519
T->getVectorComponentCount();
@@ -1687,6 +1699,27 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
16871699
return mapValue(BV, Call);
16881700
}
16891701

1702+
case OpFunctionPointerCallINTEL: {
1703+
SPIRVFunctionPointerCallINTEL *BC =
1704+
static_cast<SPIRVFunctionPointerCallINTEL *>(BV);
1705+
auto Call = CallInst::Create(transValue(BC->getCalledValue(), F, BB),
1706+
transValue(BC->getArgumentValues(), F, BB),
1707+
BC->getName(), BB);
1708+
// Assuming we are calling a regular device function
1709+
Call->setCallingConv(CallingConv::SPIR_FUNC);
1710+
// Don't set attributes, because at translation time we don't know which
1711+
// function exactly we are calling.
1712+
return mapValue(BV, Call);
1713+
}
1714+
1715+
case OpFunctionPointerINTEL: {
1716+
SPIRVFunctionPointerINTEL *BC =
1717+
static_cast<SPIRVFunctionPointerINTEL *>(BV);
1718+
SPIRVFunction* F = BC->getFunction();
1719+
BV->setName(F->getName());
1720+
return mapValue(BV, transFunction(F));
1721+
}
1722+
16901723
case OpExtInst: {
16911724
auto *ExtInst = static_cast<SPIRVExtInst *>(BV);
16921725
switch (ExtInst->getExtSetKind()) {
@@ -1869,6 +1902,10 @@ Function *SPIRVToLLVM::transFunction(SPIRVFunction *BF) {
18691902
Function *F = cast<Function>(
18701903
mapValue(BF, Function::Create(FT, Linkage, BF->getName(), M)));
18711904
mapFunction(BF, F);
1905+
1906+
if (BF->hasDecorate(DecorationReferencedIndirectlyINTEL))
1907+
F->addFnAttr("referenced-indirectly");
1908+
18721909
if (!F->isIntrinsic()) {
18731910
F->setCallingConv(IsKernel ? CallingConv::SPIR_KERNEL
18741911
: CallingConv::SPIR_FUNC);

llvm-spirv/lib/SPIRV/SPIRVRegularizeLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ void SPIRVRegularizeLLVM::lowerFuncPtr(Module *M) {
182182
auto AI = F.arg_begin();
183183
if (hasFunctionPointerArg(&F, AI)) {
184184
auto OC = getSPIRVFuncOC(F.getName());
185-
assert(OC != OpNop && "Invalid function pointer usage");
186-
Work.push_back(std::make_pair(&F, OC));
185+
if (OC != OpNop) // not a user-defined function
186+
Work.push_back(std::make_pair(&F, OC));
187187
}
188188
}
189189
for (auto &I : Work)

llvm-spirv/lib/SPIRV/SPIRVWriter.cpp

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,19 @@ SPIRVType *LLVMToSPIRV::transType(Type *T) {
261261
// (non-pointer) image or pipe type.
262262
if (T->isPointerTy()) {
263263
auto ET = T->getPointerElementType();
264-
assert(!ET->isFunctionTy() && "Function pointer type is not allowed");
264+
bool IsFuncPtrAllowed = true;
265+
// TODO: uncomment the line below once
266+
// https://github.com/KhronosGroup/SPIRV-LLVM-Translator/pull/244 is merged.
267+
// BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_function_pointers);
268+
if (ET->isFunctionTy() && !IsFuncPtrAllowed) {
269+
std::string S;
270+
raw_string_ostream RSOS(S);
271+
T->print(RSOS);
272+
BM->getErrorLog().checkError(false, SPIRVEC_FunctionPointersDisallowed,
273+
RSOS.str());
274+
BM->setInvalid();
275+
return nullptr;
276+
}
265277
auto ST = dyn_cast<StructType>(ET);
266278
auto AddrSpc = T->getPointerAddressSpace();
267279
if (ST && !ST->isSized()) {
@@ -505,6 +517,11 @@ SPIRVFunction *LLVMToSPIRV::transFunctionDecl(Function *F) {
505517
BF->addDecorate(DecorationFuncParamAttr, FunctionParameterAttributeZext);
506518
if (Attrs.hasAttribute(AttributeList::ReturnIndex, Attribute::SExt))
507519
BF->addDecorate(DecorationFuncParamAttr, FunctionParameterAttributeSext);
520+
if (Attrs.hasFnAttribute("referenced-indirectly")) {
521+
assert(!oclIsKernel(F) &&
522+
"kernel function was marked as referenced-indirectly");
523+
BF->addDecorate(DecorationReferencedIndirectlyINTEL);
524+
}
508525
SPIRVDBG(dbgs() << "[transFunction] " << *F << " => ";
509526
spvdbgs() << *BF << '\n';)
510527
return BF;
@@ -780,9 +797,19 @@ SPIRVValue *LLVMToSPIRV::transValueWithoutDecoration(Value *V,
780797
MemoryAccess[0] |= MemoryAccessNontemporalMask;
781798
if (MemoryAccess.front() == 0)
782799
MemoryAccess.clear();
800+
801+
SPIRVValue *BSV = nullptr;
802+
if(Function *F = dyn_cast<Function>(ST->getValueOperand())) {
803+
// store of function pointer
804+
BSV = BM->addFunctionPointerINTELInst(
805+
transType(F->getType()),
806+
static_cast<SPIRVFunction *>(transValue(F, BB)), BB);
807+
} else {
808+
BSV = transValue(ST->getValueOperand(), BB);
809+
}
810+
783811
return mapValue(V, BM->addStoreInst(transValue(ST->getPointerOperand(), BB),
784-
transValue(ST->getValueOperand(), BB),
785-
MemoryAccess, BB));
812+
BSV, MemoryAccess, BB));
786813
}
787814

788815
if (LoadInst *LD = dyn_cast<LoadInst>(V)) {
@@ -911,8 +938,17 @@ SPIRVValue *LLVMToSPIRV::transValueWithoutDecoration(Value *V,
911938

912939
if (auto Phi = dyn_cast<PHINode>(V)) {
913940
std::vector<SPIRVValue *> IncomingPairs;
941+
914942
for (size_t I = 0, E = Phi->getNumIncomingValues(); I != E; ++I) {
915-
IncomingPairs.push_back(transValue(Phi->getIncomingValue(I), BB));
943+
SPIRVValue *BV = nullptr;
944+
if (Function *F = dyn_cast<Function>(Phi->getIncomingValue(I))) {
945+
BV = BM->addFunctionPointerINTELInst(
946+
transType(F->getType()),
947+
static_cast<SPIRVFunction *>(transValue(F, BB)), BB);
948+
} else {
949+
BV = transValue(Phi->getIncomingValue(I), BB);
950+
}
951+
IncomingPairs.push_back(BV);
916952
IncomingPairs.push_back(transValue(Phi->getIncomingBlock(I), nullptr));
917953
}
918954
return mapValue(
@@ -1425,6 +1461,12 @@ SPIRVValue *LLVMToSPIRV::transIntrinsicInst(IntrinsicInst *II,
14251461
}
14261462

14271463
SPIRVValue *LLVMToSPIRV::transCallInst(CallInst *CI, SPIRVBasicBlock *BB) {
1464+
if (CI->isIndirectCall())
1465+
return transIndirectCallInst(CI, BB);
1466+
return transDirectCallInst(CI, BB);
1467+
}
1468+
1469+
SPIRVValue *LLVMToSPIRV::transDirectCallInst(CallInst *CI, SPIRVBasicBlock *BB) {
14281470
SPIRVExtInstSetKind ExtSetKind = SPIRVEIS_Count;
14291471
SPIRVWord ExtOp = SPIRVWORD_MAX;
14301472
llvm::Function *F = CI->getCalledFunction();
@@ -1456,6 +1498,14 @@ SPIRVValue *LLVMToSPIRV::transCallInst(CallInst *CI, SPIRVBasicBlock *BB) {
14561498
BB);
14571499
}
14581500

1501+
SPIRVValue *LLVMToSPIRV::transIndirectCallInst(CallInst *CI,
1502+
SPIRVBasicBlock *BB) {
1503+
return BM->addIndirectCallInst(
1504+
transValue(CI->getCalledValue(), BB), transType(CI->getType()),
1505+
transArguments(CI, BB, SPIRVEntry::createUnique(OpFunctionCall).get()),
1506+
BB);
1507+
}
1508+
14591509
bool LLVMToSPIRV::transAddressingMode() {
14601510
Triple TargetTriple(M->getTargetTriple());
14611511

llvm-spirv/lib/SPIRV/SPIRVWriter.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ class LLVMToSPIRV : public ModulePass {
9595
bool transBuiltinSet();
9696
SPIRVValue *transIntrinsicInst(IntrinsicInst *Intrinsic, SPIRVBasicBlock *BB);
9797
SPIRVValue *transCallInst(CallInst *Call, SPIRVBasicBlock *BB);
98+
SPIRVValue *transDirectCallInst(CallInst *Call, SPIRVBasicBlock *BB);
99+
SPIRVValue *transIndirectCallInst(CallInst *Call, SPIRVBasicBlock *BB);
98100
bool transDecoration(Value *V, SPIRVValue *BV);
99101
SPIRVWord transFunctionControlMask(Function *);
100102
SPIRVFunction *transFunctionDecl(Function *F);

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVDecorate.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,8 @@ class SPIRVDecorate : public SPIRVDecorateGeneric {
152152
case DecorationSimpleDualPortINTEL:
153153
case DecorationMergeINTEL:
154154
return getSet(SPV_INTEL_fpga_memory_attributes);
155+
case DecorationReferencedIndirectlyINTEL:
156+
return getSet(SPV_INTEL_function_pointers);
155157
default:
156158
return SPIRVExtSet();
157159
}

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVEnum.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ enum SPIRVExtensionKind {
119119
SPV_INTEL_device_side_avc_motion_estimation,
120120
SPV_INTEL_fpga_reg,
121121
SPV_INTEL_fpga_memory_attributes,
122-
SPV_INTEL_unstructured_loop_controls
122+
SPV_INTEL_unstructured_loop_controls,
123+
SPV_INTEL_function_pointers
123124
};
124125

125126
typedef std::set<SPIRVExtensionKind> SPIRVExtSet;
@@ -133,6 +134,7 @@ template <> inline void SPIRVMap<SPIRVExtensionKind, std::string>::init() {
133134
add(SPV_KHR_no_integer_wrap_decoration, "SPV_KHR_no_integer_wrap_decoration");
134135
add(SPV_INTEL_fpga_reg, "SPV_INTEL_fpga_reg");
135136
add(SPV_INTEL_fpga_memory_attributes, "SPV_INTEL_fpga_memory_attributes");
137+
add(SPV_INTEL_function_pointers, "SPV_INTEL_function_pointers");
136138
}
137139

138140
template <> inline void SPIRVMap<SPIRVExtInstSetKind, std::string>::init() {
@@ -376,6 +378,7 @@ template <> inline void SPIRVMap<Decoration, SPIRVCapVec>::init() {
376378
ADD_VEC_INIT(DecorationSimpleDualPortINTEL,
377379
{CapabilityFPGAMemoryAttributesINTEL});
378380
ADD_VEC_INIT(DecorationMergeINTEL, {CapabilityFPGAMemoryAttributesINTEL});
381+
ADD_VEC_INIT(DecorationReferencedIndirectlyINTEL, {CapabilityIndirectReferencesINTEL});
379382
}
380383

381384
template <> inline void SPIRVMap<BuiltIn, SPIRVCapVec>::init() {

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVErrorEnum.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ _SPIRV_OP(InvalidFunctionCall, "Unexpected llvm intrinsic:")
1010
_SPIRV_OP(InvalidArraySize, "Array size must be at least 1:")
1111
_SPIRV_OP(InvalidModule, "Invalid SPIR-V module:")
1212
_SPIRV_OP(UnimplementedOpCode, "Unimplemented opcode")
13+
_SPIRV_OP(FunctionPointersDisallowed, "Can't translate the function pointer:\n")

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,31 @@ void SPIRVFunctionCall::validate() const {
119119
SPIRVFunctionCallGeneric::validate();
120120
}
121121

122+
SPIRVFunctionPointerCallINTEL::SPIRVFunctionPointerCallINTEL(
123+
SPIRVId TheId, SPIRVValue *TheCalledValue, SPIRVType *TheReturnType,
124+
const std::vector<SPIRVWord> &TheArgs, SPIRVBasicBlock *BB)
125+
: SPIRVFunctionCallGeneric(TheReturnType, TheId, TheArgs, BB),
126+
CalledValueId(TheCalledValue->getId()) {
127+
validate();
128+
}
129+
130+
void SPIRVFunctionPointerCallINTEL::validate() const {
131+
SPIRVFunctionCallGeneric::validate();
132+
}
133+
134+
SPIRVFunctionPointerINTEL::SPIRVFunctionPointerINTEL(SPIRVId TheId,
135+
SPIRVType *TheType,
136+
SPIRVFunction *TheFunction,
137+
SPIRVBasicBlock *BB)
138+
: SPIRVInstruction(FixedWordCount, OC, TheType, TheId, BB),
139+
TheFunction(TheFunction->getId()) {
140+
validate();
141+
}
142+
143+
void SPIRVFunctionPointerINTEL::validate() const {
144+
SPIRVInstruction::validate();
145+
}
146+
122147
// ToDo: Each instruction should implement this function
123148
std::vector<SPIRVValue *> SPIRVInstruction::getOperands() {
124149
std::vector<SPIRVValue *> Empty;

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1451,6 +1451,50 @@ class SPIRVFunctionCall : public SPIRVFunctionCallGeneric<OpFunctionCall, 4> {
14511451
SPIRVId FunctionId;
14521452
};
14531453

1454+
class SPIRVFunctionPointerCallINTEL
1455+
: public SPIRVFunctionCallGeneric<OpFunctionPointerCallINTEL, 4> {
1456+
public:
1457+
SPIRVFunctionPointerCallINTEL(SPIRVId TheId, SPIRVValue *TheCalledValue,
1458+
SPIRVType *TheReturnType,
1459+
const std::vector<SPIRVWord> &TheArgs,
1460+
SPIRVBasicBlock *BB);
1461+
SPIRVFunctionPointerCallINTEL() : CalledValueId(SPIRVID_INVALID) {}
1462+
SPIRVValue *getCalledValue() const { return get<SPIRVValue>(CalledValueId); }
1463+
_SPIRV_DEF_ENCDEC4(Type, Id, CalledValueId, Args)
1464+
void validate() const override;
1465+
bool isOperandLiteral(unsigned Index) const override { return false; }
1466+
SPIRVExtSet getRequiredExtensions() const override {
1467+
return getSet(SPV_INTEL_function_pointers);
1468+
}
1469+
SPIRVCapVec getRequiredCapability() const override {
1470+
return getVec(CapabilityFunctionPointersINTEL);
1471+
}
1472+
1473+
protected:
1474+
SPIRVId CalledValueId;
1475+
};
1476+
1477+
class SPIRVFunctionPointerINTEL : public SPIRVInstruction {
1478+
const static Op OC = OpFunctionPointerINTEL;
1479+
const static SPIRVWord FixedWordCount = 4;
1480+
public:
1481+
SPIRVFunctionPointerINTEL(SPIRVId TheId, SPIRVType *TheType,
1482+
SPIRVFunction *TheFunction, SPIRVBasicBlock *BB);
1483+
SPIRVFunctionPointerINTEL() : SPIRVInstruction(OC), TheFunction(SPIRVID_INVALID) {}
1484+
SPIRVFunction *getFunction() const { return get<SPIRVFunction>(TheFunction); }
1485+
_SPIRV_DEF_ENCDEC3(Type, Id, TheFunction)
1486+
void validate() const override;
1487+
bool isOperandLiteral(unsigned Index) const override { return false; }
1488+
SPIRVExtSet getRequiredExtensions() const override {
1489+
return getSet(SPV_INTEL_function_pointers);
1490+
}
1491+
SPIRVCapVec getRequiredCapability() const override {
1492+
return getVec(CapabilityFunctionPointersINTEL);
1493+
}
1494+
protected:
1495+
SPIRVId TheFunction;
1496+
};
1497+
14541498
class SPIRVExtInst : public SPIRVFunctionCallGeneric<OpExtInst, 5> {
14551499
public:
14561500
SPIRVExtInst(SPIRVType *TheType, SPIRVId TheId, SPIRVId TheBuiltinSet,

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVIsValidEnum.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,7 @@ inline bool isValid(spv::Decoration V) {
405405
case DecorationMaxPrivateCopiesINTEL:
406406
case DecorationSinglepumpINTEL:
407407
case DecorationDoublepumpINTEL:
408+
case DecorationReferencedIndirectlyINTEL:
408409
return true;
409410
default:
410411
return false;

0 commit comments

Comments
 (0)