Skip to content

Commit 3524651

Browse files
committed
Add tests for call directly to builtin
Add more robustness to SemaChecking
1 parent 595ce41 commit 3524651

File tree

6 files changed

+376
-83
lines changed

6 files changed

+376
-83
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4526,7 +4526,7 @@ def HLSLCreateHandle : LangBuiltin<"HLSL_LANG"> {
45264526

45274527
def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> {
45284528
let Spellings = ["__builtin_hlsl_dot"];
4529-
let Attributes = [NoThrow, Const, CustomTypeChecking];
4529+
let Attributes = [NoThrow, Const];
45304530
let Prototype = "void(...)";
45314531
}
45324532

clang/include/clang/Sema/Sema.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14121,6 +14121,8 @@ class Sema final {
1412114121

1412214122
bool CheckPPCMMAType(QualType Type, SourceLocation TypeLoc);
1412314123

14124+
bool SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res);
14125+
bool SemaBuiltinVectorToScalarMath(CallExpr *TheCall);
1412414126
bool SemaBuiltinElementwiseMath(CallExpr *TheCall);
1412514127
bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall);
1412614128
bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall);

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17914,26 +17914,28 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
1791417914
llvm::Type *T1 = Op1->getType();
1791517915
if (!T0->isVectorTy() && !T1->isVectorTy()) {
1791617916
if (T0->isFloatingPointTy()) {
17917-
return Builder.CreateFMul(Op0, Op1, "dx.dot");
17917+
return Builder.CreateFMul(Op0, Op1, "dx.dot");
1791817918
}
1791917919

1792017920
if (T0->isIntegerTy()) {
17921-
return Builder.CreateMul(Op0, Op1, "dx.dot");
17921+
return Builder.CreateMul(Op0, Op1, "dx.dot");
1792217922
}
17923+
// Bools should have been promoted
1792317924
assert(
1792417925
false &&
1792517926
"Dot product on a scalar is only supported on integers and floats.");
1792617927
}
17928+
// A VectorSplat should have happened
1792717929
assert(T0->isVectorTy() && T1->isVectorTy() &&
1792817930
"Dot product of vector and scalar is not supported.");
1792917931

17930-
// NOTE: this assert will need to be revisited after overload resoltion
17931-
// PR merges.
17932+
// A vector sext or sitofp should have happened
1793217933
assert(T0->getScalarType() == T1->getScalarType() &&
1793317934
"Dot product of vectors need the same element types.");
1793417935

1793517936
auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>();
1793617937
auto *VecTy1 = E->getArg(1)->getType()->getAs<VectorType>();
17938+
// A HLSLVectorTruncation should have happend
1793717939
assert(VecTy0->getNumElements() == VecTy1->getNumElements() &&
1793817940
"Dot product requires vectors to be of the same size.");
1793917941

clang/lib/Sema/SemaChecking.cpp

Lines changed: 172 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -5163,69 +5163,167 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
51635163
return false;
51645164
}
51655165

5166-
// Note: returning true in this case results in CheckBuiltinFunctionCall
5167-
// returning an ExprError
5168-
bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
5169-
switch (BuiltinID) {
5170-
case Builtin::BI__builtin_hlsl_dot: {
5171-
if (checkArgCount(*this, TheCall, 2)) {
5166+
bool PromoteBoolsToInt(Sema *S, CallExpr *TheCall) {
5167+
unsigned NumArgs = TheCall->getNumArgs();
5168+
5169+
for (unsigned i = 0; i < NumArgs; ++i) {
5170+
ExprResult A = TheCall->getArg(i);
5171+
if (!A.get()->getType()->isBooleanType())
5172+
return false;
5173+
}
5174+
// if we got here all args are bool
5175+
for (unsigned i = 0; i < NumArgs; ++i) {
5176+
ExprResult A = TheCall->getArg(i);
5177+
ExprResult ResA = S->PerformImplicitConversion(A.get(), S->Context.IntTy,
5178+
Sema::AA_Converting);
5179+
if (ResA.isInvalid())
51725180
return true;
5173-
}
5174-
Expr *Arg0 = TheCall->getArg(0);
5175-
QualType ArgTy0 = Arg0->getType();
5181+
TheCall->setArg(0, ResA.get());
5182+
}
5183+
return false;
5184+
}
51765185

5177-
Expr *Arg1 = TheCall->getArg(1);
5178-
QualType ArgTy1 = Arg1->getType();
5186+
int overloadOrder(Sema *S, QualType ArgTyA) {
5187+
auto kind = ArgTyA->getAs<BuiltinType>()->getKind();
5188+
switch (kind) {
5189+
case BuiltinType::Short:
5190+
case BuiltinType::UShort:
5191+
return 1;
5192+
case BuiltinType::Int:
5193+
case BuiltinType::UInt:
5194+
return 2;
5195+
case BuiltinType::Long:
5196+
case BuiltinType::ULong:
5197+
return 3;
5198+
case BuiltinType::LongLong:
5199+
case BuiltinType::ULongLong:
5200+
return 4;
5201+
case BuiltinType::Float16:
5202+
case BuiltinType::Half:
5203+
return 5;
5204+
case BuiltinType::Float:
5205+
return 6;
5206+
default:
5207+
break;
5208+
}
5209+
return 0;
5210+
}
51795211

5180-
auto *VecTy0 = ArgTy0->getAs<VectorType>();
5181-
auto *VecTy1 = ArgTy1->getAs<VectorType>();
5182-
SourceLocation BuiltinLoc = TheCall->getBeginLoc();
5212+
QualType getVecLargestBitness(Sema *S, QualType ArgTyA, QualType ArgTyB) {
5213+
auto *VecTyA = ArgTyA->getAs<VectorType>();
5214+
auto *VecTyB = ArgTyB->getAs<VectorType>();
5215+
QualType VecTyAElem = VecTyA->getElementType();
5216+
QualType VecTyBElem = VecTyB->getElementType();
5217+
int vecAElemWidth = overloadOrder(S, VecTyAElem);
5218+
int vecBElemWidth = overloadOrder(S, VecTyBElem);
5219+
return vecAElemWidth > vecBElemWidth ? ArgTyA : ArgTyB;
5220+
}
51835221

5184-
// if arg0 is bool then call Diag with err_builtin_invalid_arg_type
5185-
if (checkMathBuiltinElementType(*this, Arg0->getBeginLoc(), ArgTy0, 1)) {
5186-
return true;
5187-
}
5222+
void PromoteVectorArgTruncation(Sema *S, CallExpr *TheCall) {
5223+
assert(TheCall->getNumArgs() > 1);
5224+
ExprResult A = TheCall->getArg(0);
5225+
ExprResult B = TheCall->getArg(1);
5226+
QualType ArgTyA = A.get()->getType();
5227+
QualType ArgTyB = B.get()->getType();
51885228

5189-
// if arg1 is bool then call Diag with err_builtin_invalid_arg_type
5190-
if (checkMathBuiltinElementType(*this, Arg1->getBeginLoc(), ArgTy1, 2)) {
5191-
return true;
5229+
auto *VecTyA = ArgTyA->getAs<VectorType>();
5230+
auto *VecTyB = ArgTyB->getAs<VectorType>();
5231+
if (VecTyA == nullptr && VecTyB == nullptr)
5232+
return;
5233+
if (VecTyA == nullptr || VecTyB == nullptr)
5234+
return;
5235+
if (VecTyA->getNumElements() == VecTyB->getNumElements())
5236+
return;
5237+
5238+
Expr *LargerArg = B.get();
5239+
Expr *SmallerArg = A.get();
5240+
int largerIndex = 1;
5241+
if (VecTyA->getNumElements() > VecTyB->getNumElements()) {
5242+
LargerArg = A.get();
5243+
SmallerArg = B.get();
5244+
largerIndex = 0;
5245+
}
5246+
S->Diag(TheCall->getExprLoc(), diag::warn_hlsl_impcast_vector_truncation)
5247+
<< LargerArg->getType() << SmallerArg->getType()
5248+
<< LargerArg->getSourceRange() << SmallerArg->getSourceRange();
5249+
ExprResult ResLargerArg = S->ImpCastExprToType(
5250+
LargerArg, SmallerArg->getType(), CK_HLSLVectorTruncation);
5251+
TheCall->setArg(largerIndex, ResLargerArg.get());
5252+
return;
5253+
}
5254+
5255+
bool PromoteVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
5256+
assert(TheCall->getNumArgs() > 1);
5257+
ExprResult A = TheCall->getArg(0);
5258+
ExprResult B = TheCall->getArg(1);
5259+
QualType ArgTyA = A.get()->getType();
5260+
QualType ArgTyB = B.get()->getType();
5261+
5262+
auto *VecTyA = ArgTyA->getAs<VectorType>();
5263+
auto *VecTyB = ArgTyB->getAs<VectorType>();
5264+
if (VecTyA == nullptr && VecTyB == nullptr)
5265+
return false;
5266+
if (VecTyA && VecTyB) {
5267+
if (VecTyA->getElementType() == VecTyB->getElementType()) {
5268+
TheCall->setType(VecTyA->getElementType());
5269+
return false;
5270+
}
5271+
SourceLocation BuiltinLoc = TheCall->getBeginLoc();
5272+
QualType CastType = getVecLargestBitness(S, ArgTyA, ArgTyB);
5273+
if (CastType == ArgTyA) {
5274+
ExprResult ResB = S->SemaConvertVectorExpr(
5275+
B.get(), S->Context.CreateTypeSourceInfo(ArgTyA), BuiltinLoc,
5276+
B.get()->getBeginLoc());
5277+
TheCall->setArg(1, ResB.get());
5278+
TheCall->setType(VecTyA->getElementType());
5279+
return false;
51925280
}
51935281

5194-
if (VecTy0 == nullptr && VecTy1 == nullptr) {
5195-
if (ArgTy0 != ArgTy1) {
5196-
return true;
5197-
} else {
5198-
return false;
5199-
}
5282+
if (CastType == ArgTyB) {
5283+
ExprResult ResA = S->SemaConvertVectorExpr(
5284+
A.get(), S->Context.CreateTypeSourceInfo(ArgTyB), BuiltinLoc,
5285+
A.get()->getBeginLoc());
5286+
TheCall->setArg(0, ResA.get());
5287+
TheCall->setType(VecTyB->getElementType());
5288+
return false;
52005289
}
5290+
return false;
5291+
}
52015292

5202-
if ((VecTy0 == nullptr && VecTy1 != nullptr) ||
5203-
(VecTy0 != nullptr && VecTy1 == nullptr)) {
5293+
if (VecTyB) {
5294+
// Convert to the vector result type
5295+
ExprResult ResA = A;
5296+
if (VecTyB->getElementType() != ArgTyA)
5297+
ResA = S->ImpCastExprToType(ResA.get(), VecTyB->getElementType(),
5298+
CK_FloatingCast);
5299+
ResA = S->ImpCastExprToType(ResA.get(), ArgTyB, CK_VectorSplat);
5300+
TheCall->setArg(0, ResA.get());
5301+
}
5302+
if (VecTyA) {
5303+
ExprResult ResB = B;
5304+
if (VecTyA->getElementType() != ArgTyB)
5305+
ResB = S->ImpCastExprToType(ResB.get(), VecTyA->getElementType(),
5306+
CK_FloatingCast);
5307+
ResB = S->ImpCastExprToType(ResB.get(), ArgTyA, CK_VectorSplat);
5308+
TheCall->setArg(1, ResB.get());
5309+
}
5310+
return false;
5311+
}
52045312

5205-
Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
5206-
<< TheCall->getDirectCallee()
5207-
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
5208-
TheCall->getArg(1)->getEndLoc());
5313+
// Note: returning true in this case results in CheckBuiltinFunctionCall
5314+
// returning an ExprError
5315+
bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
5316+
switch (BuiltinID) {
5317+
case Builtin::BI__builtin_hlsl_dot: {
5318+
if (checkArgCount(*this, TheCall, 2))
52095319
return true;
5210-
}
5211-
5212-
if (VecTy0->getElementType() != VecTy1->getElementType()) {
5213-
// Note: This case should never happen. If type promotion occurs
5214-
// then element types won't be different. This diag error is here
5215-
// b\c EmitHLSLBuiltinExpr asserts on this case.
5216-
Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
5217-
<< TheCall->getDirectCallee()
5218-
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
5219-
TheCall->getArg(1)->getEndLoc());
5320+
if (PromoteBoolsToInt(this, TheCall))
52205321
return true;
5221-
}
5222-
if (VecTy0->getNumElements() != VecTy1->getNumElements()) {
5223-
Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_size)
5224-
<< TheCall->getDirectCallee()
5225-
<< SourceRange(TheCall->getArg(0)->getBeginLoc(),
5226-
TheCall->getArg(1)->getEndLoc());
5322+
if (PromoteVectorElementCallArgs(this, TheCall))
5323+
return true;
5324+
PromoteVectorArgTruncation(this, TheCall);
5325+
if (SemaBuiltinVectorToScalarMath(TheCall))
52275326
return true;
5228-
}
52295327
break;
52305328
}
52315329
}
@@ -19669,15 +19767,37 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
1966919767
}
1967019768

1967119769
bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
19770+
QualType Res;
19771+
bool result = SemaBuiltinVectorMath(TheCall, Res);
19772+
if (result)
19773+
return true;
19774+
TheCall->setType(Res);
19775+
return false;
19776+
}
19777+
19778+
bool Sema::SemaBuiltinVectorToScalarMath(CallExpr *TheCall) {
19779+
QualType Res;
19780+
bool result = SemaBuiltinVectorMath(TheCall, Res);
19781+
if (result)
19782+
return true;
19783+
19784+
if (auto *VecTy0 = Res->getAs<VectorType>()) {
19785+
TheCall->setType(VecTy0->getElementType());
19786+
} else {
19787+
TheCall->setType(Res);
19788+
}
19789+
return false;
19790+
}
19791+
19792+
bool Sema::SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res) {
1967219793
if (checkArgCount(*this, TheCall, 2))
1967319794
return true;
1967419795

1967519796
ExprResult A = TheCall->getArg(0);
1967619797
ExprResult B = TheCall->getArg(1);
1967719798
// Do standard promotions between the two arguments, returning their common
1967819799
// type.
19679-
QualType Res =
19680-
UsualArithmeticConversions(A, B, TheCall->getExprLoc(), ACK_Comparison);
19800+
Res = UsualArithmeticConversions(A, B, TheCall->getExprLoc(), ACK_Comparison);
1968119801
if (A.isInvalid() || B.isInvalid())
1968219802
return true;
1968319803

@@ -19694,7 +19814,6 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
1969419814

1969519815
TheCall->setArg(0, A.get());
1969619816
TheCall->setArg(1, B.get());
19697-
TheCall->setType(Res);
1969819817
return false;
1969919818
}
1970019819

0 commit comments

Comments
 (0)