diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp index c85bd27d256b2..e4593e7db90e8 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -1386,6 +1386,11 @@ static bool generateSampleImageInst(const StringRef DemangledCall, ReturnType = ReturnType.substr(0, ReturnType.find('(')); } SPIRVType *Type = GR->getOrCreateSPIRVTypeByName(ReturnType, MIRBuilder); + if (!Type) { + std::string DiagMsg = + "Unable to recognize SPIRV type name: " + ReturnType; + report_fatal_error(DiagMsg.c_str()); + } MRI->setRegClass(Call->Arguments[0], &SPIRV::IDRegClass); MRI->setRegClass(Call->Arguments[1], &SPIRV::IDRegClass); MRI->setRegClass(Call->Arguments[3], &SPIRV::IDRegClass); diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp index 62c08bab46eee..97b25147ffb34 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -157,22 +157,22 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx, isSpecialOpaqueType(OriginalArgType)) return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual); - MDString *MDKernelArgType = getOCLKernelArgType(F, ArgIdx); - if (!MDKernelArgType || (!MDKernelArgType->getString().ends_with("*") && - !MDKernelArgType->getString().ends_with("_t"))) - return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual); - - if (MDKernelArgType->getString().ends_with("*")) - return GR->getOrCreateSPIRVTypeByName( - MDKernelArgType->getString(), MIRBuilder, - addressSpaceToStorageClass(OriginalArgType->getPointerAddressSpace())); - - if (MDKernelArgType->getString().ends_with("_t")) - return GR->getOrCreateSPIRVTypeByName( - "opencl." + MDKernelArgType->getString().str(), MIRBuilder, - SPIRV::StorageClass::Function, ArgAccessQual); - - llvm_unreachable("Unable to recognize argument type name."); + SPIRVType *ResArgType = nullptr; + if (MDString *MDKernelArgType = getOCLKernelArgType(F, ArgIdx)) { + StringRef MDTypeStr = MDKernelArgType->getString(); + if (MDTypeStr.ends_with("*")) + ResArgType = GR->getOrCreateSPIRVTypeByName( + MDTypeStr, MIRBuilder, + addressSpaceToStorageClass( + OriginalArgType->getPointerAddressSpace())); + else if (MDTypeStr.ends_with("_t")) + ResArgType = GR->getOrCreateSPIRVTypeByName( + "opencl." + MDTypeStr.str(), MIRBuilder, + SPIRV::StorageClass::Function, ArgAccessQual); + } + return ResArgType ? ResArgType + : GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, + ArgAccessQual); } static bool isEntryPoint(const Function &F) { diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index 6c009b9e8ddef..f2c27467c34b4 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -443,8 +443,9 @@ Register SPIRVGlobalRegistry::buildConstantSampler( SPIRVType *SampTy; if (SpvType) SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder); - else - SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t", MIRBuilder); + else if ((SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t", + MIRBuilder)) == nullptr) + report_fatal_error("Unable to recognize SPIRV type name: opencl.sampler_t"); auto Sampler = ResReg.isValid() @@ -941,6 +942,7 @@ SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD, return nullptr; } +// Returns nullptr if unable to recognize SPIRV type name // TODO: maybe use tablegen to implement this. SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName( StringRef TypeStr, MachineIRBuilder &MIRBuilder, @@ -992,8 +994,10 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName( } else if (TypeStr.starts_with("double")) { Ty = Type::getDoubleTy(Ctx); TypeStr = TypeStr.substr(strlen("double")); - } else - llvm_unreachable("Unable to recognize SPIRV type name."); + } else { + // Unable to recognize SPIRV type name + return nullptr; + } auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ); diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index 60967bfb68a87..f3280928c25df 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -138,6 +138,7 @@ class SPIRVGlobalRegistry { // Either generate a new OpTypeXXX instruction or return an existing one // corresponding to the given string containing the name of the builtin type. + // Return nullptr if unable to recognize SPIRV type name from `TypeStr`. SPIRVType *getOrCreateSPIRVTypeByName( StringRef TypeStr, MachineIRBuilder &MIRBuilder, SPIRV::StorageClass::StorageClass SC = SPIRV::StorageClass::Function, diff --git a/llvm/test/CodeGen/SPIRV/pointers/custom-kernel-arg-type.ll b/llvm/test/CodeGen/SPIRV/pointers/custom-kernel-arg-type.ll new file mode 100644 index 0000000000000..4593fad783c60 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/pointers/custom-kernel-arg-type.ll @@ -0,0 +1,34 @@ +; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK: %[[TyInt:.*]] = OpTypeInt 8 0 +; CHECK: %[[TyPtr:.*]] = OpTypePointer {{[a-zA-Z]+}} %[[TyInt]] +; CHECK: OpFunctionParameter %[[TyPtr]] +; CHECK: OpFunctionParameter %[[TyPtr]] + +%struct.my_kernel_data = type { i32, i32, i32, i32, i32 } +%struct.my_struct = type { i32, i32 } + +define spir_kernel void @test(ptr addrspace(1) %in, ptr addrspace(1) %outData) !kernel_arg_type !5 { +entry: + ret void +} + +!llvm.module.flags = !{!0} +!opencl.enable.FP_CONTRACT = !{} +!opencl.ocl.version = !{!1} +!opencl.spir.version = !{!2} +!opencl.used.extensions = !{!3} +!opencl.used.optional.core.features = !{!3} +!opencl.compiler.options = !{!3} +!llvm.ident = !{!4} +!opencl.kernels = !{!6} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{i32 1, i32 0} +!2 = !{i32 1, i32 2} +!3 = !{} +!4 = !{!"clang version 6.0.0"} +!5 = !{!"my_kernel_data*", !"struct my_struct*"} +!6 = !{ptr @test} +