Skip to content

Commit 84b2e22

Browse files
committed
Add tests for call directly to builtin
Add more robustness to SemaChecking
1 parent fad7e6e commit 84b2e22

File tree

6 files changed

+376
-83
lines changed

6 files changed

+376
-83
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4526,7 +4526,7 @@ def HLSLCreateHandle : LangBuiltin<"HLSL_LANG"> {
45264526

45274527
def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
45284528
let Spellings = ["__builtin_hlsl_dot"];
4529-
let Attributes = [NoThrow, Const, CustomTypeChecking];
4529+
let Attributes = [NoThrow, Const];
45304530
let Prototype = "void(...)";
45314531
}
45324532

clang/include/clang/Sema/Sema.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14123,6 +14123,8 @@ class Sema final {
1412314123

1412414124
bool CheckPPCMMAType(QualType Type, SourceLocation TypeLoc);
1412514125

14126+
bool SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res);
14127+
bool SemaBuiltinVectorToScalarMath(CallExpr *TheCall);
1412614128
bool SemaBuiltinElementwiseMath(CallExpr *TheCall);
1412714129
bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall);
1412814130
bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall);

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17977,26 +17977,28 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
1797717977
llvm::Type *T1 = Op1->getType();
1797817978
if (!T0->isVectorTy() && !T1->isVectorTy()) {
1797917979
if (T0->isFloatingPointTy()) {
17980-
return Builder.CreateFMul(Op0, Op1, "dx.dot");
17980+
return Builder.CreateFMul(Op0, Op1, "dx.dot");
1798117981
}
1798217982

1798317983
if (T0->isIntegerTy()) {
17984-
return Builder.CreateMul(Op0, Op1, "dx.dot");
17984+
return Builder.CreateMul(Op0, Op1, "dx.dot");
1798517985
}
17986+
// Bools should have been promoted
1798617987
assert(
1798717988
false &&
1798817989
"Dot product on a scalar is only supported on integers and floats.");
1798917990
}
17991+
// A VectorSplat should have happened
1799017992
assert(T0->isVectorTy() && T1->isVectorTy() &&
1799117993
"Dot product of vector and scalar is not supported.");
1799217994

17993-
// NOTE: this assert will need to be revisited after overload resoltion
17994-
// PR merges.
17995+
// A vector sext or sitofp should have happened
1799517996
assert(T0->getScalarType() == T1->getScalarType() &&
1799617997
"Dot product of vectors need the same element types.");
1799717998

1799817999
auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
1799918000
auto *VecTy1 = E->getArg(1)->getType()->getAs<VectorType>();
18001+
// A HLSLVectorTruncation should have happend
1800018002
assert(VecTy0->getNumElements() == VecTy1->getNumElements() &&
1800118003
"Dot product requires vectors to be of the same size.");
1800218004

clang/lib/Sema/SemaChecking.cpp

Lines changed: 172 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -5166,69 +5166,167 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
51665166
return false;
51675167
}
51685168

5169-
// Note: returning true in this case results in CheckBuiltinFunctionCall
5170-
// returning an ExprError
5171-
bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
5172-
switch (BuiltinID) {
5173-
case Builtin::BI__builtin_hlsl_dot: {
5174-
if (checkArgCount(*this, TheCall, 2)) {
5169+
bool PromoteBoolsToInt(Sema *S, CallExpr *TheCall) {
5170+
unsigned NumArgs = TheCall->getNumArgs();
5171+
5172+
for (unsigned i = 0; i < NumArgs; ++i) {
5173+
ExprResult A = TheCall->getArg(i);
5174+
if (!A.get()->getType()->isBooleanType())
5175+
return false;
5176+
}
5177+
// if we got here all args are bool
5178+
for (unsigned i = 0; i < NumArgs; ++i) {
5179+
ExprResult A = TheCall->getArg(i);
5180+
ExprResult ResA = S->PerformImplicitConversion(A.get(), S->Context.IntTy,
5181+
Sema::AA_Converting);
5182+
if (ResA.isInvalid())
51755183
return true;
5176-
}
5177-
Expr *Arg0 = TheCall->getArg(0);
5178-
QualType ArgTy0 = Arg0->getType();
5184+
TheCall->setArg(0, ResA.get());
5185+
}
5186+
return false;
5187+
}
51795188

5180-
Expr *Arg1 = TheCall->getArg(1);
5181-
QualType ArgTy1 = Arg1->getType();
5189+
int overloadOrder(Sema *S, QualType ArgTyA) {
5190+
auto kind = ArgTyA->getAs<BuiltinType>()->getKind();
5191+
switch (kind) {
5192+
case BuiltinType::Short:
5193+
case BuiltinType::UShort:
5194+
return 1;
5195+
case BuiltinType::Int:
5196+
case BuiltinType::UInt:
5197+
return 2;
5198+
case BuiltinType::Long:
5199+
case BuiltinType::ULong:
5200+
return 3;
5201+
case BuiltinType::LongLong:
5202+
case BuiltinType::ULongLong:
5203+
return 4;
5204+
case BuiltinType::Float16:
5205+
case BuiltinType::Half:
5206+
return 5;
5207+
case BuiltinType::Float:
5208+
return 6;
5209+
default:
5210+
break;
5211+
}
5212+
return 0;
5213+
}
51825214

5183-
auto *VecTy0 = ArgTy0->getAs<VectorType>();
5184-
auto *VecTy1 = ArgTy1->getAs<VectorType>();
5185-
SourceLocation BuiltinLoc = TheCall->getBeginLoc();
5215+
QualType getVecLargestBitness(Sema *S, QualType ArgTyA, QualType ArgTyB) {
5216+
auto *VecTyA = ArgTyA->getAs<VectorType>();
5217+
auto *VecTyB = ArgTyB->getAs<VectorType>();
5218+
QualType VecTyAElem = VecTyA->getElementType();
5219+
QualType VecTyBElem = VecTyB->getElementType();
5220+
int vecAElemWidth = overloadOrder(S, VecTyAElem);
5221+
int vecBElemWidth = overloadOrder(S, VecTyBElem);
5222+
return vecAElemWidth > vecBElemWidth ? ArgTyA : ArgTyB;
5223+
}
51865224

5187-
// if arg0 is bool then call Diag with err_builtin_invalid_arg_type
5188-
if (checkMathBuiltinElementType(*this, Arg0->getBeginLoc(), ArgTy0, 1)) {
5189-
return true;
5190-
}
5225+
void PromoteVectorArgTruncation(Sema *S, CallExpr *TheCall) {
5226+
assert(TheCall->getNumArgs() > 1);
5227+
ExprResult A = TheCall->getArg(0);
5228+
ExprResult B = TheCall->getArg(1);
5229+
QualType ArgTyA = A.get()->getType();
5230+
QualType ArgTyB = B.get()->getType();
51915231

5192-
// if arg1 is bool then call Diag with err_builtin_invalid_arg_type
5193-
if (checkMathBuiltinElementType(*this, Arg1->getBeginLoc(), ArgTy1, 2)) {
5194-
return true;
5232+
auto *VecTyA = ArgTyA->getAs<VectorType>();
5233+
auto *VecTyB = ArgTyB->getAs<VectorType>();
5234+
if (VecTyA == nullptr && VecTyB == nullptr)
5235+
return;
5236+
if (VecTyA == nullptr || VecTyB == nullptr)
5237+
return;
5238+
if (VecTyA->getNumElements() == VecTyB->getNumElements())
5239+
return;
5240+
5241+
Expr *LargerArg = B.get();
5242+
Expr *SmallerArg = A.get();
5243+
int largerIndex = 1;
5244+
if (VecTyA->getNumElements() > VecTyB->getNumElements()) {
5245+
LargerArg = A.get();
5246+
SmallerArg = B.get();
5247+
largerIndex = 0;
5248+
}
5249+
S->Diag(TheCall->getExprLoc(), diag::warn_hlsl_impcast_vector_truncation)
5250+
<< LargerArg->getType() << SmallerArg->getType()
5251+
<< LargerArg->getSourceRange() << SmallerArg->getSourceRange();
5252+
ExprResult ResLargerArg = S->ImpCastExprToType(
5253+
LargerArg, SmallerArg->getType(), CK_HLSLVectorTruncation);
5254+
TheCall->setArg(largerIndex, ResLargerArg.get());
5255+
return;
5256+
}
5257+
5258+
bool PromoteVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
5259+
assert(TheCall->getNumArgs() > 1);
5260+
ExprResult A = TheCall->getArg(0);
5261+
ExprResult B = TheCall->getArg(1);
5262+
QualType ArgTyA = A.get()->getType();
5263+
QualType ArgTyB = B.get()->getType();
5264+
5265+
auto *VecTyA = ArgTyA->getAs<VectorType>();
5266+
auto *VecTyB = ArgTyB->getAs<VectorType>();
5267+
if (VecTyA == nullptr && VecTyB == nullptr)
5268+
return false;
5269+
if (VecTyA && VecTyB) {
5270+
if (VecTyA->getElementType() == VecTyB->getElementType()) {
5271+
TheCall->setType(VecTyA->getElementType());
5272+
return false;
5273+
}
5274+
SourceLocation BuiltinLoc = TheCall->getBeginLoc();
5275+
QualType CastType = getVecLargestBitness(S, ArgTyA, ArgTyB);
5276+
if (CastType == ArgTyA) {
5277+
ExprResult ResB = S->SemaConvertVectorExpr(
5278+
B.get(), S->Context.CreateTypeSourceInfo(ArgTyA), BuiltinLoc,
5279+
B.get()->getBeginLoc());
5280+
TheCall->setArg(1, ResB.get());
5281+
TheCall->setType(VecTyA->getElementType());
5282+
return false;
51955283
}
51965284

5197-
if (VecTy0 == nullptr && VecTy1 == nullptr) {
5198-
if (ArgTy0 != ArgTy1) {
5199-
return true;
5200-
} else {
5201-
return false;
5202-
}
5285+
if (CastType == ArgTyB) {
5286+
ExprResult ResA = S->SemaConvertVectorExpr(
5287+
A.get(), S->Context.CreateTypeSourceInfo(ArgTyB), BuiltinLoc,
5288+
A.get()->getBeginLoc());
5289+
TheCall->setArg(0, ResA.get());
5290+
TheCall->setType(VecTyB->getElementType());
5291+
return false;
52035292
}
5293+
return false;
5294+
}
52045295

5205-
if ((VecTy0 == nullptr && VecTy1 != nullptr) ||
5206-
(VecTy0 != nullptr && VecTy1 == nullptr)) {
5296+
if (VecTyB) {
5297+
// Convert to the vector result type
5298+
ExprResult ResA = A;
5299+
if (VecTyB->getElementType() != ArgTyA)
5300+
ResA = S->ImpCastExprToType(ResA.get(), VecTyB->getElementType(),
5301+
CK_FloatingCast);
5302+
ResA = S->ImpCastExprToType(ResA.get(), ArgTyB, CK_VectorSplat);
5303+
TheCall->setArg(0, ResA.get());
5304+
}
5305+
if (VecTyA) {
5306+
ExprResult ResB = B;
5307+
if (VecTyA->getElementType() != ArgTyB)
5308+
ResB = S->ImpCastExprToType(ResB.get(), VecTyA->getElementType(),
5309+
CK_FloatingCast);
5310+
ResB = S->ImpCastExprToType(ResB.get(), ArgTyA, CK_VectorSplat);
5311+
TheCall->setArg(1, ResB.get());
5312+
}
5313+
return false;
5314+
}
52075315

5208-
Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
5209-
<< TheCall->getDirectCallee()
5210-
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
5211-
TheCall->getArg(1)->getEndLoc());
5316+
// Note: returning true in this case results in CheckBuiltinFunctionCall
5317+
// returning an ExprError
5318+
bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
5319+
switch (BuiltinID) {
5320+
case Builtin::BI__builtin_hlsl_dot: {
5321+
if (checkArgCount(*this, TheCall, 2))
52125322
return true;
5213-
}
5214-
5215-
if (VecTy0->getElementType() != VecTy1->getElementType()) {
5216-
// Note: This case should never happen. If type promotion occurs
5217-
// then element types won't be different. This diag error is here
5218-
// b\c EmitHLSLBuiltinExpr asserts on this case.
5219-
Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
5220-
<< TheCall->getDirectCallee()
5221-
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
5222-
TheCall->getArg(1)->getEndLoc());
5323+
if (PromoteBoolsToInt(this, TheCall))
52235324
return true;
5224-
}
5225-
if (VecTy0->getNumElements() != VecTy1->getNumElements()) {
5226-
Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_size)
5227-
<< TheCall->getDirectCallee()
5228-
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
5229-
TheCall->getArg(1)->getEndLoc());
5325+
if (PromoteVectorElementCallArgs(this, TheCall))
5326+
return true;
5327+
PromoteVectorArgTruncation(this, TheCall);
5328+
if (SemaBuiltinVectorToScalarMath(TheCall))
52305329
return true;
5231-
}
52325330
break;
52335331
}
52345332
}
@@ -19676,15 +19774,37 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
1967619774
}
1967719775

1967819776
bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
19777+
QualType Res;
19778+
bool result = SemaBuiltinVectorMath(TheCall, Res);
19779+
if (result)
19780+
return true;
19781+
TheCall->setType(Res);
19782+
return false;
19783+
}
19784+
19785+
bool Sema::SemaBuiltinVectorToScalarMath(CallExpr *TheCall) {
19786+
QualType Res;
19787+
bool result = SemaBuiltinVectorMath(TheCall, Res);
19788+
if (result)
19789+
return true;
19790+
19791+
if (auto *VecTy0 = Res->getAs<VectorType>()) {
19792+
TheCall->setType(VecTy0->getElementType());
19793+
} else {
19794+
TheCall->setType(Res);
19795+
}
19796+
return false;
19797+
}
19798+
19799+
bool Sema::SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res) {
1967919800
if (checkArgCount(*this, TheCall, 2))
1968019801
return true;
1968119802

1968219803
ExprResult A = TheCall->getArg(0);
1968319804
ExprResult B = TheCall->getArg(1);
1968419805
// Do standard promotions between the two arguments, returning their common
1968519806
// type.
19686-
QualType Res =
19687-
UsualArithmeticConversions(A, B, TheCall->getExprLoc(), ACK_Comparison);
19807+
Res = UsualArithmeticConversions(A, B, TheCall->getExprLoc(), ACK_Comparison);
1968819808
if (A.isInvalid() || B.isInvalid())
1968919809
return true;
1969019810

@@ -19701,7 +19821,6 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
1970119821

1970219822
TheCall->setArg(0, A.get());
1970319823
TheCall->setArg(1, B.get());
19704-
TheCall->setType(Res);
1970519824
return false;
1970619825
}
1970719826

0 commit comments

Comments
 (0)