99// ===----------------------------------------------------------------------===//
1010
1111#include " clang/Sema/SemaHLSL.h"
12- #include " clang/AST/ASTContext.h"
1312#include " clang/AST/Decl.h"
1413#include " clang/AST/DeclBase.h"
1514#include " clang/AST/DeclCXX.h"
2928#include " llvm/ADT/SmallVector.h"
3029#include " llvm/ADT/StringExtras.h"
3130#include " llvm/ADT/StringRef.h"
32- #include " llvm/IR/DerivedTypes.h"
33- #include " llvm/IR/Type.h"
3431#include " llvm/Support/Casting.h"
3532#include " llvm/Support/DXILABI.h"
3633#include " llvm/Support/ErrorHandling.h"
@@ -1473,6 +1470,25 @@ bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
14731470 return true ;
14741471}
14751472
1473+ bool CheckArgTypeWithoutImplicits (
1474+ Sema *S, Expr *Arg, QualType ExpectedType,
1475+ llvm::function_ref<bool (clang::QualType PassedType)> Check) {
1476+
1477+ QualType ArgTy = Arg->IgnoreImpCasts ()->getType ();
1478+
1479+ clang::QualType BaseType =
1480+ ArgTy->isVectorType ()
1481+ ? ArgTy->getAs <clang::VectorType>()->getElementType ()
1482+ : ArgTy;
1483+
1484+ if (Check (BaseType)) {
1485+ S->Diag (Arg->getBeginLoc (), diag::err_typecheck_convert_incompatible)
1486+ << ArgTy << ExpectedType << 1 << 0 << 0 ;
1487+ return true ;
1488+ }
1489+ return false ;
1490+ }
1491+
14761492bool CheckArgsTypesAreCorrect (
14771493 Sema *S, CallExpr *TheCall, QualType ExpectedType,
14781494 llvm::function_ref<bool (clang::QualType PassedType)> Check) {
@@ -1499,6 +1515,14 @@ bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
14991515 checkAllFloatTypes);
15001516}
15011517
1518+ bool CheckArgIsFloatOrIntWithoutImplicits (Sema *S, Expr *Arg) {
1519+ auto checkFloat = [](clang::QualType PassedType) -> bool {
1520+ return !PassedType->isFloat32Type () && !PassedType->isIntegerType ();
1521+ };
1522+
1523+ return CheckArgTypeWithoutImplicits (S, Arg, S->Context .FloatTy , checkFloat);
1524+ }
1525+
15021526bool CheckFloatOrHalfRepresentations (Sema *S, CallExpr *TheCall) {
15031527 auto checkFloatorHalf = [](clang::QualType PassedType) -> bool {
15041528 clang::QualType BaseType =
@@ -1760,16 +1784,9 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
17601784 if (SemaRef.checkArgCount (TheCall, 1 ))
17611785 return true ;
17621786
1763- ExprResult A = TheCall->getArg (0 );
1764- QualType ArgTyA = A.get ()->getType ();
1765-
1766- if (ArgTyA->isVectorType ()){
1767- auto VecTy = TheCall->getArg (0 )->getType ()->getAs <VectorType>();
1768- auto ReturnType = this ->getASTContext ().getVectorType (TheCall->getCallReturnType (this ->getASTContext ()), VecTy->getNumElements (),
1769- VectorKind::Generic);
1770-
1771- TheCall->setType (ReturnType);
1772- }
1787+ Expr *Arg = TheCall->getArg (0 );
1788+ if (CheckArgIsFloatOrIntWithoutImplicits (&SemaRef, Arg))
1789+ return true ;
17731790
17741791 break ;
17751792 }
0 commit comments