99// ===----------------------------------------------------------------------===//
1010
1111#include " clang/Sema/SemaHLSL.h"
12- #include " clang/AST/ASTContext.h"
1312#include " clang/AST/Decl.h"
1413#include " clang/AST/DeclBase.h"
1514#include " clang/AST/DeclCXX.h"
2827#include " llvm/ADT/SmallVector.h"
2928#include " llvm/ADT/StringExtras.h"
3029#include " llvm/ADT/StringRef.h"
31- #include " llvm/IR/DerivedTypes.h"
32- #include " llvm/IR/Type.h"
3330#include " llvm/Support/Casting.h"
3431#include " llvm/Support/DXILABI.h"
3532#include " llvm/Support/ErrorHandling.h"
@@ -1468,6 +1465,25 @@ bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
14681465 return true ;
14691466}
14701467
1468+ bool CheckArgTypeWithoutImplicits (
1469+ Sema *S, Expr *Arg, QualType ExpectedType,
1470+ llvm::function_ref<bool (clang::QualType PassedType)> Check) {
1471+
1472+ QualType ArgTy = Arg->IgnoreImpCasts ()->getType ();
1473+
1474+ clang::QualType BaseType =
1475+ ArgTy->isVectorType ()
1476+ ? ArgTy->getAs <clang::VectorType>()->getElementType ()
1477+ : ArgTy;
1478+
1479+ if (Check (BaseType)) {
1480+ S->Diag (Arg->getBeginLoc (), diag::err_typecheck_convert_incompatible)
1481+ << ArgTy << ExpectedType << 1 << 0 << 0 ;
1482+ return true ;
1483+ }
1484+ return false ;
1485+ }
1486+
14711487bool CheckArgsTypesAreCorrect (
14721488 Sema *S, CallExpr *TheCall, QualType ExpectedType,
14731489 llvm::function_ref<bool (clang::QualType PassedType)> Check) {
@@ -1494,6 +1510,14 @@ bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
14941510 checkAllFloatTypes);
14951511}
14961512
1513+ bool CheckArgIsFloatOrIntWithoutImplicits (Sema *S, Expr *Arg) {
1514+ auto checkFloat = [](clang::QualType PassedType) -> bool {
1515+ return !PassedType->isFloat32Type () && !PassedType->isIntegerType ();
1516+ };
1517+
1518+ return CheckArgTypeWithoutImplicits (S, Arg, S->Context .FloatTy , checkFloat);
1519+ }
1520+
14971521bool CheckFloatOrHalfRepresentations (Sema *S, CallExpr *TheCall) {
14981522 auto checkFloatorHalf = [](clang::QualType PassedType) -> bool {
14991523 clang::QualType BaseType =
@@ -1652,16 +1676,9 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
16521676 if (SemaRef.checkArgCount (TheCall, 1 ))
16531677 return true ;
16541678
1655- ExprResult A = TheCall->getArg (0 );
1656- QualType ArgTyA = A.get ()->getType ();
1657-
1658- if (ArgTyA->isVectorType ()){
1659- auto VecTy = TheCall->getArg (0 )->getType ()->getAs <VectorType>();
1660- auto ReturnType = this ->getASTContext ().getVectorType (TheCall->getCallReturnType (this ->getASTContext ()), VecTy->getNumElements (),
1661- VectorKind::Generic);
1662-
1663- TheCall->setType (ReturnType);
1664- }
1679+ Expr *Arg = TheCall->getArg (0 );
1680+ if (CheckArgIsFloatOrIntWithoutImplicits (&SemaRef, Arg))
1681+ return true ;
16651682
16661683 break ;
16671684 }
0 commit comments