Skip to content

Commit 9f3c10f

Browse files
authored
Fix SPIR-V friendly LLVM IR for conversion functions (#899)
In most cases convert functions are translated to set of LLVM IR instructions but in case of saturated conversion or if conversion has rounding mode, the translation should go through SPIR-V friendly LLVM IR.
1 parent d1213fe commit 9f3c10f

File tree

9 files changed

+228
-43
lines changed

9 files changed

+228
-43
lines changed

lib/SPIRV/SPIRVInternal.h

+11
Original file line numberDiff line numberDiff line change
@@ -944,6 +944,17 @@ std::string mangleBuiltin(StringRef UniqName, ArrayRef<Type *> ArgTypes,
944944
std::string getSPIRVFriendlyIRFunctionName(OCLExtOpKind ExtOpId,
945945
ArrayRef<Type *> ArgTys);
946946

947+
/// Mangle a function in SPIR-V friendly IR manner
948+
/// \param UniqName full unmangled name of the SPIR-V built-in function that
949+
/// contains possible postfixes that depend not on opcode but on decorations or
950+
/// return type, for example __spirv_UConvert_Rint_sat.
951+
/// \param OC opcode of corresponding built-in instruction. Used to gather info
952+
/// for unsigned/constant arguments.
953+
/// \param Types of arguments of SPIR-V built-in function
954+
/// \return IA64 mangled name.
955+
std::string getSPIRVFriendlyIRFunctionName(const std::string &UniqName,
956+
spv::Op OC, ArrayRef<Type *> ArgTys);
957+
947958
/// Remove cast from a value.
948959
Value *removeCast(Value *V);
949960

lib/SPIRV/SPIRVReader.cpp

+35-16
Original file line numberDiff line numberDiff line change
@@ -2385,13 +2385,6 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
23852385
BV->getName(), BB));
23862386
}
23872387

2388-
case OpImageQuerySize:
2389-
case OpImageQuerySizeLod: {
2390-
return mapValue(
2391-
BV, transSPIRVBuiltinFromInst(static_cast<SPIRVInstruction *>(BV), BB,
2392-
/*AddRetTypePostfix=*/true));
2393-
}
2394-
23952388
case OpBitReverse: {
23962389
auto *BR = static_cast<SPIRVUnary *>(BV);
23972390
auto Ty = transType(BV->getType());
@@ -2670,7 +2663,7 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
26702663
auto BI = static_cast<SPIRVInstruction *>(BV);
26712664
Value *Inst = nullptr;
26722665
if (BI->hasFPRoundingMode() || BI->isSaturatedConversion())
2673-
Inst = transOCLBuiltinFromInst(BI, BB);
2666+
Inst = transSPIRVBuiltinFromInst(BI, BB);
26742667
else
26752668
Inst = transConvertInst(BV, F, BB);
26762669
return mapValue(BV, Inst);
@@ -3252,10 +3245,16 @@ Instruction *SPIRVToLLVM::transBuiltinFromInst(const std::string &FuncName,
32523245
HasFuncPtrArg = true;
32533246
}
32543247
}
3255-
if (!HasFuncPtrArg)
3256-
mangleOpenClBuiltin(FuncName, ArgTys, MangledName);
3257-
else
3248+
if (!HasFuncPtrArg) {
3249+
if (BM->getDesiredBIsRepresentation() != BIsRepresentation::SPIRVFriendlyIR)
3250+
mangleOpenClBuiltin(FuncName, ArgTys, MangledName);
3251+
else
3252+
MangledName =
3253+
getSPIRVFriendlyIRFunctionName(FuncName, BI->getOpCode(), ArgTys);
3254+
3255+
} else {
32583256
MangledName = decorateSPIRVFunction(FuncName);
3257+
}
32593258
Function *Func = M->getFunction(MangledName);
32603259
FunctionType *FT = FunctionType::get(RetTy, ArgTys, false);
32613260
// ToDo: Some intermediate functions have duplicate names with
@@ -3399,22 +3398,42 @@ std::string getSPIRVFuncSuffix(SPIRVInstruction *BI) {
33993398
break;
34003399
}
34013400
}
3401+
if (BI->hasDecorate(DecorationSaturatedConversion)) {
3402+
Suffix += kSPIRVPostfix::Divider;
3403+
Suffix += kSPIRVPostfix::Sat;
3404+
}
3405+
SPIRVFPRoundingModeKind Kind;
3406+
if (BI->hasFPRoundingMode(&Kind)) {
3407+
Suffix += kSPIRVPostfix::Divider;
3408+
Suffix += SPIRSPIRVFPRoundingModeMap::rmap(Kind);
3409+
}
34023410
return Suffix;
34033411
}
34043412

34053413
Instruction *SPIRVToLLVM::transSPIRVBuiltinFromInst(SPIRVInstruction *BI,
3406-
BasicBlock *BB,
3407-
bool AddRetTypePostfix) {
3414+
BasicBlock *BB) {
34083415
assert(BB && "Invalid BB");
3416+
const auto OC = BI->getOpCode();
3417+
bool AddRetTypePostfix = false;
3418+
if (OC == OpImageQuerySizeLod || OC == OpImageQuerySize)
3419+
AddRetTypePostfix = true;
3420+
3421+
bool IsRetSigned = false;
3422+
if (isCvtOpCode(OC)) {
3423+
AddRetTypePostfix = true;
3424+
if (OC == OpConvertUToF || OC == OpSatConvertUToS)
3425+
IsRetSigned = true;
3426+
}
3427+
34093428
if (AddRetTypePostfix) {
34103429
const Type *RetTy =
34113430
BI->hasType() ? transType(BI->getType()) : Type::getVoidTy(*Context);
3412-
return transBuiltinFromInst(getSPIRVFuncName(BI->getOpCode(), RetTy) +
3431+
return transBuiltinFromInst(getSPIRVFuncName(OC, RetTy, IsRetSigned) +
34133432
getSPIRVFuncSuffix(BI),
34143433
BI, BB);
34153434
}
3416-
return transBuiltinFromInst(
3417-
getSPIRVFuncName(BI->getOpCode(), getSPIRVFuncSuffix(BI)), BI, BB);
3435+
return transBuiltinFromInst(getSPIRVFuncName(OC, getSPIRVFuncSuffix(BI)), BI,
3436+
BB);
34183437
}
34193438

34203439
bool SPIRVToLLVM::translate() {

lib/SPIRV/SPIRVReader.h

+1-2
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,7 @@ class SPIRVToLLVM {
122122
Instruction *transBuiltinFromInst(const std::string &FuncName,
123123
SPIRVInstruction *BI, BasicBlock *BB);
124124
Instruction *transOCLBuiltinFromInst(SPIRVInstruction *BI, BasicBlock *BB);
125-
Instruction *transSPIRVBuiltinFromInst(SPIRVInstruction *BI, BasicBlock *BB,
126-
bool AddRetTypePostfix = false);
125+
Instruction *transSPIRVBuiltinFromInst(SPIRVInstruction *BI, BasicBlock *BB);
127126
void transOCLVectorLoadStore(std::string &UnmangledName,
128127
std::vector<SPIRVWord> &BArgs);
129128

lib/SPIRV/SPIRVToOCL.cpp

+31
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ void SPIRVToOCL::visitCallInst(CallInst &CI) {
8989
visitCallSPIRVImageMediaBlockBuiltin(&CI, OC);
9090
return;
9191
}
92+
if (isCvtOpCode(OC)) {
93+
visitCallSPIRVCvtBuiltin(&CI, OC, DemangledName);
94+
return;
95+
}
9296
if (OCLSPIRVBuiltinMap::rfind(OC))
9397
visitCallSPIRVBuiltin(&CI, OC);
9498
}
@@ -498,6 +502,33 @@ void SPIRVToOCL::visitCallSPIRVImageMediaBlockBuiltin(CallInst *CI, Op OC) {
498502
&Attrs);
499503
}
500504

505+
void SPIRVToOCL::visitCallSPIRVCvtBuiltin(CallInst *CI, Op OC,
506+
StringRef DemangledName) {
507+
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
508+
mutateCallInstOCL(
509+
M, CI,
510+
[=](CallInst *Call, std::vector<Value *> &Args) {
511+
std::string CastBuiltInName;
512+
if (isCvtFromUnsignedOpCode(OC))
513+
CastBuiltInName = "u";
514+
CastBuiltInName += kOCLBuiltinName::ConvertPrefix;
515+
Type *DstTy = Call->getType();
516+
CastBuiltInName +=
517+
mapLLVMTypeToOCLType(DstTy, !isCvtToUnsignedOpCode(OC));
518+
if (DemangledName.find("_sat") != StringRef::npos || isSatCvtOpCode(OC))
519+
CastBuiltInName += "_sat";
520+
Value *Src = Call->getOperand(0);
521+
assert(Src && "Invalid SPIRV convert builtin call");
522+
Type *SrcTy = Src->getType();
523+
auto Loc = DemangledName.find("_rt");
524+
if (Loc != StringRef::npos &&
525+
!(isa<IntegerType>(SrcTy) && isa<IntegerType>(DstTy)))
526+
CastBuiltInName += DemangledName.substr(Loc, 4).str();
527+
return CastBuiltInName;
528+
},
529+
&Attrs);
530+
}
531+
501532
void SPIRVToOCL::visitCallSPIRVBuiltin(CallInst *CI, Op OC) {
502533
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
503534
mutateCallInstOCL(

lib/SPIRV/SPIRVToOCL.h

+6
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@ class SPIRVToOCL : public ModulePass, public InstVisitor<SPIRVToOCL> {
9898
/// intel_sub_group_media_block_write
9999
void visitCallSPIRVImageMediaBlockBuiltin(CallInst *CI, Op OC);
100100

101+
/// Transform __spirv_*Convert_R{ReturnType}{_sat}{_rtp|_rtn|_rtz|_rte} to
102+
/// convert_{ReturnType}_{sat}{_rtp|_rtn|_rtz|_rte}
103+
/// example: <2 x i8> __spirv_SatConvertUToS(<2 x i32>) =>
104+
/// convert_uchar2_sat(int2)
105+
void visitCallSPIRVCvtBuiltin(CallInst *CI, Op OC, StringRef DemangledName);
106+
101107
/// Transform __spirv_* builtins to OCL 2.0 builtins.
102108
/// No change with arguments.
103109
void visitCallSPIRVBuiltin(CallInst *CI, Op OC);

lib/SPIRV/SPIRVUtil.cpp

+54-1
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ std::string getSPIRVFuncName(Op OC, StringRef PostFix) {
379379

380380
std::string getSPIRVFuncName(Op OC, const Type *PRetTy, bool IsSigned) {
381381
return prefixSPIRVName(getName(OC) + kSPIRVPostfix::Divider +
382-
getPostfixForReturnType(PRetTy, false));
382+
getPostfixForReturnType(PRetTy, IsSigned));
383383
}
384384

385385
std::string getSPIRVExtFuncName(SPIRVExtInstSetKind Set, unsigned ExtOp,
@@ -1597,6 +1597,52 @@ bool checkTypeForSPIRVExtendedInstLowering(IntrinsicInst *II, SPIRVModule *BM) {
15971597
} // namespace SPIRV
15981598

15991599
namespace {
1600+
class SPIRVFriendlyIRMangleInfo : public BuiltinFuncMangleInfo {
1601+
public:
1602+
SPIRVFriendlyIRMangleInfo(spv::Op OC, ArrayRef<Type *> ArgTys)
1603+
: OC(OC), ArgTys(ArgTys) {}
1604+
1605+
void init(StringRef UniqUnmangledName) override {
1606+
UnmangledName = UniqUnmangledName.str();
1607+
switch (OC) {
1608+
case OpConvertUToF:
1609+
LLVM_FALLTHROUGH;
1610+
case OpUConvert:
1611+
LLVM_FALLTHROUGH;
1612+
case OpSatConvertUToS:
1613+
// Treat all arguments as unsigned
1614+
addUnsignedArg(-1);
1615+
break;
1616+
case OpSubgroupShuffleINTEL:
1617+
LLVM_FALLTHROUGH;
1618+
case OpSubgroupShuffleXorINTEL:
1619+
addUnsignedArg(1);
1620+
break;
1621+
case OpSubgroupShuffleDownINTEL:
1622+
LLVM_FALLTHROUGH;
1623+
case OpSubgroupShuffleUpINTEL:
1624+
addUnsignedArg(2);
1625+
break;
1626+
case OpSubgroupBlockWriteINTEL:
1627+
addUnsignedArg(0);
1628+
addUnsignedArg(1);
1629+
break;
1630+
case OpSubgroupImageBlockWriteINTEL:
1631+
addUnsignedArg(2);
1632+
break;
1633+
case OpSubgroupBlockReadINTEL:
1634+
setArgAttr(0, SPIR::ATTR_CONST);
1635+
addUnsignedArg(0);
1636+
break;
1637+
default:;
1638+
// No special handling is needed
1639+
}
1640+
}
1641+
1642+
private:
1643+
spv::Op OC;
1644+
ArrayRef<Type *> ArgTys;
1645+
};
16001646
class OpenCLStdToSPIRVFriendlyIRMangleInfo : public BuiltinFuncMangleInfo {
16011647
public:
16021648
OpenCLStdToSPIRVFriendlyIRMangleInfo(OCLExtOpKind ExtOpId,
@@ -1660,4 +1706,11 @@ std::string getSPIRVFriendlyIRFunctionName(OCLExtOpKind ExtOpId,
16601706
return mangleBuiltin(MangleInfo.getUnmangledName(), ArgTys, &MangleInfo);
16611707
}
16621708

1709+
std::string getSPIRVFriendlyIRFunctionName(const std::string &UniqName,
1710+
spv::Op OC,
1711+
ArrayRef<Type *> ArgTys) {
1712+
SPIRVFriendlyIRMangleInfo MangleInfo(OC, ArgTys);
1713+
return mangleBuiltin(UniqName, ArgTys, &MangleInfo);
1714+
}
1715+
16631716
} // namespace SPIRV

lib/SPIRV/libSPIRV/SPIRVOpCode.h

+4
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ inline bool isCvtFromUnsignedOpCode(Op OpCode) {
115115
OpCode == OpSatConvertUToS;
116116
}
117117

118+
inline bool isSatCvtOpCode(Op OpCode) {
119+
return OpCode == OpSatConvertUToS || OpCode == OpSatConvertSToU;
120+
}
121+
118122
inline bool isOpaqueGenericTypeOpCode(Op OpCode) {
119123
return ((unsigned)OpCode >= OpTypeEvent && (unsigned)OpCode <= OpTypeQueue) ||
120124
OpCode == OpTypeSampler;

test/transcoding/SatConvert.cl

-24
This file was deleted.
+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// RUN: %clang_cc1 -triple spir-unknown-unknown -O1 -cl-std=CL2.0 -fdeclare-opencl-builtins -finclude-default-header -emit-llvm-bc %s -o %t.bc
2+
// RUN: llvm-spirv %t.bc -spirv-text -o - | FileCheck %s --check-prefix=CHECK-SPIRV
3+
// RUN: llvm-spirv %t.bc -o %t.spv
4+
// RUN: spirv-val %t.spv
5+
// RUN: llvm-spirv -r %t.spv -o %t.rev.bc
6+
// RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
7+
// RUN: llvm-spirv -r --spirv-target-env=SPV-IR %t.spv -o %t.rev.bc
8+
// RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-SPV-IR
9+
10+
// CHECK-SPIRV: SatConvertSToU
11+
12+
// CHECK-LLVM-LABEL: @testSToU
13+
// CHECK-LLVM: call spir_func <2 x i8> @_Z18convert_uchar2_satDv2_i
14+
15+
// CHECK-SPV-IR-LABEL: @testSToU
16+
// CHECK-SPV-IR: call spir_func <2 x i8> @_Z30__spirv_SatConvertSToU_Ruchar2Dv2_i
17+
18+
kernel void testSToU(global int2 *a, global uchar2 *res) {
19+
res[0] = convert_uchar2_sat(*a);
20+
}
21+
22+
// CHECK-SPIRV: SatConvertUToS
23+
24+
// CHECK-LLVM-LABEL: @testUToS
25+
// CHECK-LLVM: call spir_func <2 x i8> @_Z17convert_char2_satDv2_j
26+
27+
// CHECK-SPV-IR-LABEL: @testUToS
28+
// CHECK-SPV-IR: call spir_func <2 x i8> @_Z29__spirv_SatConvertUToS_Rchar2Dv2_j
29+
kernel void testUToS(global uint2 *a, global char2 *res) {
30+
res[0] = convert_char2_sat(*a);
31+
}
32+
33+
// CHECK-SPIRV: ConvertUToF
34+
35+
// CHECK-LLVM-LABEL: @testUToF
36+
// CHECK-LLVM: call spir_func <2 x float> @_Z18convert_float2_rtzDv2_j
37+
38+
// CHECK-SPV-IR-LABEL: @testUToF
39+
// CHECK-SPV-IR: call spir_func <2 x float> @_Z31__spirv_ConvertUToF_Rfloat2_rtzDv2_j
40+
kernel void testUToF(global uint2 *a, global float2 *res) {
41+
res[0] = convert_float2_rtz(*a);
42+
}
43+
44+
// CHECK-SPIRV: ConvertFToU
45+
46+
// CHECK-LLVM-LABEL: @testFToUSat
47+
// CHECK-LLVM: call spir_func <2 x i32> @_Z21convert_uint2_sat_rtnDv2_f
48+
49+
// CHECK-SPV-IR-LABEL: @testFToUSat
50+
// CHECK-SPV-IR: call spir_func <2 x i32> @_Z34__spirv_ConvertFToU_Ruint2_sat_rtnDv2_f
51+
kernel void testFToUSat(global float2 *a, global uint2 *res) {
52+
res[0] = convert_uint2_sat_rtn(*a);
53+
}
54+
55+
// CHECK-SPIRV: UConvert
56+
57+
// CHECK-LLVM-LABEL: @testUToUSat
58+
// CHECK-LLVM: call spir_func i32 @_Z16convert_uint_sath
59+
60+
// CHECK-SPV-IR-LABEL: @testUToUSat
61+
// CHECK-SPV-IR: call spir_func i32 @_Z26__spirv_UConvert_Ruint_sath
62+
kernel void testUToUSat(global uchar *a, global uint *res) {
63+
res[0] = convert_uint_sat(*a);
64+
}
65+
66+
// CHECK-SPIRV: UConvert
67+
68+
// CHECK-LLVM-LABEL: @testUToUSat1
69+
// CHECK-LLVM: call spir_func i8 @_Z17convert_uchar_satj
70+
71+
// CHECK-SPV-IR-LABEL: @testUToUSat1
72+
// CHECK-SPV-IR: call spir_func i8 @_Z27__spirv_UConvert_Ruchar_satj
73+
kernel void testUToUSat1(global uint *a, global uchar *res) {
74+
res[0] = convert_uchar_sat(*a);
75+
}
76+
77+
// CHECK-SPIRV: ConvertFToU
78+
79+
// CHECK-LLVM-LABEL: @testFToU
80+
// CHECK-LLVM: call spir_func <3 x i32> @_Z17convert_uint3_rtpDv3_f
81+
82+
// CHECK-SPV-IR-LABEL: @testFToU
83+
// CHECK-SPV-IR: call spir_func <3 x i32> @_Z30__spirv_ConvertFToU_Ruint3_rtpDv3_f
84+
kernel void testFToU(global float3 *a, global uint3 *res) {
85+
res[0] = convert_uint3_rtp(*a);
86+
}

0 commit comments

Comments
 (0)