5252#include " llvm/Support/Debug.h"
5353
5454#include < algorithm>
55+ #include < regex>
5556#include < set>
5657
5758using namespace llvm ;
@@ -724,6 +725,13 @@ void OCLToSPIRVBase::visitCallBarrier(CallInst *CI) {
724725
725726void OCLToSPIRVBase::visitCallConvert (CallInst *CI, StringRef MangledName,
726727 StringRef DemangledName) {
728+ // OpenCL Explicit Conversions (6.4.3) formed as below for scalars:
729+ // destType convert_destType<_sat><_roundingMode>(sourceType)
730+ // and for vector type:
731+ // destTypeN convert_destTypeN<_sat><_roundingMode>(sourceTypeN)
732+ // If the demangled name is not matching the suggested pattern and does not
733+ // meet allowed destination type restrictions - this is not an OpenCL builtin,
734+ // return from the function and translate such CallInst as a function call.
727735 if (eraseUselessConvert (CI, MangledName, DemangledName))
728736 return ;
729737 Op OC = OpNop;
@@ -734,16 +742,56 @@ void OCLToSPIRVBase::visitCallConvert(CallInst *CI, StringRef MangledName,
734742 if (auto *VecTy = dyn_cast<VectorType>(SrcTy))
735743 SrcTy = VecTy->getElementType ();
736744 auto IsTargetInt = isa<IntegerType>(TargetTy);
745+ auto TargetSigned = DemangledName[8 ] != ' u' ;
737746
738747 std::string TargetTyName (
739748 DemangledName.substr (strlen (kOCLBuiltinName ::ConvertPrefix)));
740749 auto FirstUnderscoreLoc = TargetTyName.find (' _' );
741750 if (FirstUnderscoreLoc != std::string::npos)
742751 TargetTyName = TargetTyName.substr (0 , FirstUnderscoreLoc);
752+
753+ // Validate target type name
754+ std::regex Expr (" ([a-z]+)([0-9]*)$" );
755+ std::smatch DestTyMatch;
756+ if (!std::regex_match (TargetTyName, DestTyMatch, Expr))
757+ return ;
758+
759+ // The first sub_match is the whole string; the next
760+ // sub_match is the first parenthesized expression.
761+ std::string DestTy = DestTyMatch[1 ].str ();
762+
763+ // check it's valid type name
764+ static std::unordered_set<std::string> ValidTypes = {
765+ " float" , " double" , " half" , " char" , " uchar" , " short" ,
766+ " ushort" , " int" , " uint" , " long" , " ulong" };
767+
768+ if (ValidTypes.find (DestTy) == ValidTypes.end ())
769+ return ;
770+
771+ // check that it's allowed vector size
772+ std::string VecSize = DestTyMatch[2 ].str ();
773+ if (!VecSize.empty ()) {
774+ int Size = stoi (VecSize);
775+ switch (Size) {
776+ case 2 :
777+ case 3 :
778+ case 4 :
779+ case 8 :
780+ case 16 :
781+ break ;
782+ default :
783+ return ;
784+ }
785+ }
786+ DemangledName = DemangledName.drop_front (
787+ strlen (kOCLBuiltinName ::ConvertPrefix) + TargetTyName.size ());
743788 TargetTyName = std::string (" _R" ) + TargetTyName;
744789
790+ if (!DemangledName.empty () && !DemangledName.starts_with (" _sat" ) &&
791+ !DemangledName.starts_with (" _rt" ))
792+ return ;
793+
745794 std::string Sat = DemangledName.find (" _sat" ) != StringRef::npos ? " _sat" : " " ;
746- auto TargetSigned = DemangledName[8 ] != ' u' ;
747795 if (isa<IntegerType>(SrcTy)) {
748796 bool Signed = isLastFuncParamSigned (MangledName);
749797 if (IsTargetInt) {
0 commit comments