Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1386,6 +1386,11 @@ static bool generateSampleImageInst(const StringRef DemangledCall,
ReturnType = ReturnType.substr(0, ReturnType.find('('));
}
SPIRVType *Type = GR->getOrCreateSPIRVTypeByName(ReturnType, MIRBuilder);
if (!Type) {
std::string DiagMsg =
"Unable to recognize SPIRV type name: " + ReturnType;
report_fatal_error(DiagMsg.c_str());
}
MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass);
MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass);
MRI->setRegClass(Call->Arguments[3], &SPIRV::IDRegClass);
Expand Down
32 changes: 16 additions & 16 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,22 +157,22 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
isSpecialOpaqueType(OriginalArgType))
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);

MDString *MDKernelArgType = getOCLKernelArgType(F, ArgIdx);
if (!MDKernelArgType || (!MDKernelArgType->getString().ends_with("*") &&
!MDKernelArgType->getString().ends_with("_t")))
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);

if (MDKernelArgType->getString().ends_with("*"))
return GR->getOrCreateSPIRVTypeByName(
MDKernelArgType->getString(), MIRBuilder,
addressSpaceToStorageClass(OriginalArgType->getPointerAddressSpace()));

if (MDKernelArgType->getString().ends_with("_t"))
return GR->getOrCreateSPIRVTypeByName(
"opencl." + MDKernelArgType->getString().str(), MIRBuilder,
SPIRV::StorageClass::Function, ArgAccessQual);

llvm_unreachable("Unable to recognize argument type name.");
SPIRVType *ResArgType = nullptr;
if (MDString *MDKernelArgType = getOCLKernelArgType(F, ArgIdx)) {
StringRef MDTypeStr = MDKernelArgType->getString();
if (MDTypeStr.ends_with("*"))
ResArgType = GR->getOrCreateSPIRVTypeByName(
MDTypeStr, MIRBuilder,
addressSpaceToStorageClass(
OriginalArgType->getPointerAddressSpace()));
else if (MDTypeStr.ends_with("_t"))
ResArgType = GR->getOrCreateSPIRVTypeByName(
"opencl." + MDTypeStr.str(), MIRBuilder,
SPIRV::StorageClass::Function, ArgAccessQual);
}
return ResArgType ? ResArgType
: GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder,
ArgAccessQual);
}

static bool isEntryPoint(const Function &F) {
Expand Down
12 changes: 8 additions & 4 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,8 +443,9 @@ Register SPIRVGlobalRegistry::buildConstantSampler(
SPIRVType *SampTy;
if (SpvType)
SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder);
else
SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t", MIRBuilder);
else if ((SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t",
MIRBuilder)) == nullptr)
report_fatal_error("Unable to recognize SPIRV type name: opencl.sampler_t");

auto Sampler =
ResReg.isValid()
Expand Down Expand Up @@ -941,6 +942,7 @@ SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
return nullptr;
}

// Returns nullptr if unable to recognize SPIRV type name
// TODO: maybe use tablegen to implement this.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
StringRef TypeStr, MachineIRBuilder &MIRBuilder,
Expand Down Expand Up @@ -992,8 +994,10 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
} else if (TypeStr.starts_with("double")) {
Ty = Type::getDoubleTy(Ctx);
TypeStr = TypeStr.substr(strlen("double"));
} else
llvm_unreachable("Unable to recognize SPIRV type name.");
} else {
// Unable to recognize SPIRV type name
return nullptr;
}

auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ);

Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class SPIRVGlobalRegistry {

// Either generate a new OpTypeXXX instruction or return an existing one
// corresponding to the given string containing the name of the builtin type.
// Return nullptr if unable to recognize SPIRV type name from `TypeStr`.
SPIRVType *getOrCreateSPIRVTypeByName(
StringRef TypeStr, MachineIRBuilder &MIRBuilder,
SPIRV::StorageClass::StorageClass SC = SPIRV::StorageClass::Function,
Expand Down
34 changes: 34 additions & 0 deletions llvm/test/CodeGen/SPIRV/pointers/custom-kernel-arg-type.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK: %[[TyInt:.*]] = OpTypeInt 8 0
; CHECK: %[[TyPtr:.*]] = OpTypePointer {{[a-zA-Z]+}} %[[TyInt]]
; CHECK: OpFunctionParameter %[[TyPtr]]
; CHECK: OpFunctionParameter %[[TyPtr]]

%struct.my_kernel_data = type { i32, i32, i32, i32, i32 }
%struct.my_struct = type { i32, i32 }

define spir_kernel void @test(ptr addrspace(1) %in, ptr addrspace(1) %outData) !kernel_arg_type !5 {
entry:
ret void
}

!llvm.module.flags = !{!0}
!opencl.enable.FP_CONTRACT = !{}
!opencl.ocl.version = !{!1}
!opencl.spir.version = !{!2}
!opencl.used.extensions = !{!3}
!opencl.used.optional.core.features = !{!3}
!opencl.compiler.options = !{!3}
!llvm.ident = !{!4}
!opencl.kernels = !{!6}

!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{i32 1, i32 0}
!2 = !{i32 1, i32 2}
!3 = !{}
!4 = !{!"clang version 6.0.0"}
!5 = !{!"my_kernel_data*", !"struct my_struct*"}
!6 = !{ptr @test}