@@ -5163,69 +5163,167 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
5163
5163
return false;
5164
5164
}
5165
5165
5166
- // Note: returning true in this case results in CheckBuiltinFunctionCall
5167
- // returning an ExprError
5168
- bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
5169
- switch (BuiltinID) {
5170
- case Builtin::BI__builtin_hlsl_dot: {
5171
- if (checkArgCount(*this, TheCall, 2)) {
5166
+ bool PromoteBoolsToInt(Sema *S, CallExpr *TheCall) {
5167
+ unsigned NumArgs = TheCall->getNumArgs();
5168
+
5169
+ for (unsigned i = 0; i < NumArgs; ++i) {
5170
+ ExprResult A = TheCall->getArg(i);
5171
+ if (!A.get()->getType()->isBooleanType())
5172
+ return false;
5173
+ }
5174
+ // if we got here all args are bool
5175
+ for (unsigned i = 0; i < NumArgs; ++i) {
5176
+ ExprResult A = TheCall->getArg(i);
5177
+ ExprResult ResA = S->PerformImplicitConversion(A.get(), S->Context.IntTy,
5178
+ Sema::AA_Converting);
5179
+ if (ResA.isInvalid())
5172
5180
return true;
5173
- }
5174
- Expr *Arg0 = TheCall->getArg(0);
5175
- QualType ArgTy0 = Arg0->getType();
5181
+ TheCall->setArg(0, ResA.get());
5182
+ }
5183
+ return false;
5184
+ }
5176
5185
5177
- Expr *Arg1 = TheCall->getArg(1);
5178
- QualType ArgTy1 = Arg1->getType();
5186
+ int overloadOrder(Sema *S, QualType ArgTyA) {
5187
+ auto kind = ArgTyA->getAs<BuiltinType>()->getKind();
5188
+ switch (kind) {
5189
+ case BuiltinType::Short:
5190
+ case BuiltinType::UShort:
5191
+ return 1;
5192
+ case BuiltinType::Int:
5193
+ case BuiltinType::UInt:
5194
+ return 2;
5195
+ case BuiltinType::Long:
5196
+ case BuiltinType::ULong:
5197
+ return 3;
5198
+ case BuiltinType::LongLong:
5199
+ case BuiltinType::ULongLong:
5200
+ return 4;
5201
+ case BuiltinType::Float16:
5202
+ case BuiltinType::Half:
5203
+ return 5;
5204
+ case BuiltinType::Float:
5205
+ return 6;
5206
+ default:
5207
+ break;
5208
+ }
5209
+ return 0;
5210
+ }
5179
5211
5180
- auto *VecTy0 = ArgTy0->getAs<VectorType>();
5181
- auto *VecTy1 = ArgTy1->getAs<VectorType>();
5182
- SourceLocation BuiltinLoc = TheCall->getBeginLoc();
5212
+ QualType getVecLargestBitness(Sema *S, QualType ArgTyA, QualType ArgTyB) {
5213
+ auto *VecTyA = ArgTyA->getAs<VectorType>();
5214
+ auto *VecTyB = ArgTyB->getAs<VectorType>();
5215
+ QualType VecTyAElem = VecTyA->getElementType();
5216
+ QualType VecTyBElem = VecTyB->getElementType();
5217
+ int vecAElemWidth = overloadOrder(S, VecTyAElem);
5218
+ int vecBElemWidth = overloadOrder(S, VecTyBElem);
5219
+ return vecAElemWidth > vecBElemWidth ? ArgTyA : ArgTyB;
5220
+ }
5183
5221
5184
- // if arg0 is bool then call Diag with err_builtin_invalid_arg_type
5185
- if (checkMathBuiltinElementType(*this, Arg0->getBeginLoc(), ArgTy0, 1)) {
5186
- return true;
5187
- }
5222
+ void PromoteVectorArgTruncation(Sema *S, CallExpr *TheCall) {
5223
+ assert(TheCall->getNumArgs() > 1);
5224
+ ExprResult A = TheCall->getArg(0);
5225
+ ExprResult B = TheCall->getArg(1);
5226
+ QualType ArgTyA = A.get()->getType();
5227
+ QualType ArgTyB = B.get()->getType();
5188
5228
5189
- // if arg1 is bool then call Diag with err_builtin_invalid_arg_type
5190
- if (checkMathBuiltinElementType(*this, Arg1->getBeginLoc(), ArgTy1, 2)) {
5191
- return true;
5229
+ auto *VecTyA = ArgTyA->getAs<VectorType>();
5230
+ auto *VecTyB = ArgTyB->getAs<VectorType>();
5231
+ if (VecTyA == nullptr && VecTyB == nullptr)
5232
+ return;
5233
+ if (VecTyA == nullptr || VecTyB == nullptr)
5234
+ return;
5235
+ if (VecTyA->getNumElements() == VecTyB->getNumElements())
5236
+ return;
5237
+
5238
+ Expr *LargerArg = B.get();
5239
+ Expr *SmallerArg = A.get();
5240
+ int largerIndex = 1;
5241
+ if (VecTyA->getNumElements() > VecTyB->getNumElements()) {
5242
+ LargerArg = A.get();
5243
+ SmallerArg = B.get();
5244
+ largerIndex = 0;
5245
+ }
5246
+ S->Diag(TheCall->getExprLoc(), diag::warn_hlsl_impcast_vector_truncation)
5247
+ << LargerArg->getType() << SmallerArg->getType()
5248
+ << LargerArg->getSourceRange() << SmallerArg->getSourceRange();
5249
+ ExprResult ResLargerArg = S->ImpCastExprToType(
5250
+ LargerArg, SmallerArg->getType(), CK_HLSLVectorTruncation);
5251
+ TheCall->setArg(largerIndex, ResLargerArg.get());
5252
+ return;
5253
+ }
5254
+
5255
+ bool PromoteVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
5256
+ assert(TheCall->getNumArgs() > 1);
5257
+ ExprResult A = TheCall->getArg(0);
5258
+ ExprResult B = TheCall->getArg(1);
5259
+ QualType ArgTyA = A.get()->getType();
5260
+ QualType ArgTyB = B.get()->getType();
5261
+
5262
+ auto *VecTyA = ArgTyA->getAs<VectorType>();
5263
+ auto *VecTyB = ArgTyB->getAs<VectorType>();
5264
+ if (VecTyA == nullptr && VecTyB == nullptr)
5265
+ return false;
5266
+ if (VecTyA && VecTyB) {
5267
+ if (VecTyA->getElementType() == VecTyB->getElementType()) {
5268
+ TheCall->setType(VecTyA->getElementType());
5269
+ return false;
5270
+ }
5271
+ SourceLocation BuiltinLoc = TheCall->getBeginLoc();
5272
+ QualType CastType = getVecLargestBitness(S, ArgTyA, ArgTyB);
5273
+ if (CastType == ArgTyA) {
5274
+ ExprResult ResB = S->SemaConvertVectorExpr(
5275
+ B.get(), S->Context.CreateTypeSourceInfo(ArgTyA), BuiltinLoc,
5276
+ B.get()->getBeginLoc());
5277
+ TheCall->setArg(1, ResB.get());
5278
+ TheCall->setType(VecTyA->getElementType());
5279
+ return false;
5192
5280
}
5193
5281
5194
- if (VecTy0 == nullptr && VecTy1 == nullptr) {
5195
- if (ArgTy0 != ArgTy1) {
5196
- return true;
5197
- } else {
5198
- return false;
5199
- }
5282
+ if (CastType == ArgTyB) {
5283
+ ExprResult ResA = S->SemaConvertVectorExpr(
5284
+ A.get(), S->Context.CreateTypeSourceInfo(ArgTyB), BuiltinLoc,
5285
+ A.get()->getBeginLoc());
5286
+ TheCall->setArg(0, ResA.get());
5287
+ TheCall->setType(VecTyB->getElementType());
5288
+ return false;
5200
5289
}
5290
+ return false;
5291
+ }
5201
5292
5202
- if ((VecTy0 == nullptr && VecTy1 != nullptr) ||
5203
- (VecTy0 != nullptr && VecTy1 == nullptr)) {
5293
+ if (VecTyB) {
5294
+ // Convert to the vector result type
5295
+ ExprResult ResA = A;
5296
+ if (VecTyB->getElementType() != ArgTyA)
5297
+ ResA = S->ImpCastExprToType(ResA.get(), VecTyB->getElementType(),
5298
+ CK_FloatingCast);
5299
+ ResA = S->ImpCastExprToType(ResA.get(), ArgTyB, CK_VectorSplat);
5300
+ TheCall->setArg(0, ResA.get());
5301
+ }
5302
+ if (VecTyA) {
5303
+ ExprResult ResB = B;
5304
+ if (VecTyA->getElementType() != ArgTyB)
5305
+ ResB = S->ImpCastExprToType(ResB.get(), VecTyA->getElementType(),
5306
+ CK_FloatingCast);
5307
+ ResB = S->ImpCastExprToType(ResB.get(), ArgTyA, CK_VectorSplat);
5308
+ TheCall->setArg(1, ResB.get());
5309
+ }
5310
+ return false;
5311
+ }
5204
5312
5205
- Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
5206
- << TheCall->getDirectCallee()
5207
- << SourceRange(TheCall->getArg(0)->getBeginLoc(),
5208
- TheCall->getArg(1)->getEndLoc());
5313
+ // Note: returning true in this case results in CheckBuiltinFunctionCall
5314
+ // returning an ExprError
5315
+ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
5316
+ switch (BuiltinID) {
5317
+ case Builtin::BI__builtin_hlsl_dot: {
5318
+ if (checkArgCount(*this, TheCall, 2))
5209
5319
return true;
5210
- }
5211
-
5212
- if (VecTy0->getElementType() != VecTy1->getElementType()) {
5213
- // Note: This case should never happen. If type promotion occurs
5214
- // then element types won't be different. This diag error is here
5215
- // b\c EmitHLSLBuiltinExpr asserts on this case.
5216
- Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
5217
- << TheCall->getDirectCallee()
5218
- << SourceRange(TheCall->getArg(0)->getBeginLoc(),
5219
- TheCall->getArg(1)->getEndLoc());
5320
+ if (PromoteBoolsToInt(this, TheCall))
5220
5321
return true;
5221
- }
5222
- if (VecTy0->getNumElements() != VecTy1->getNumElements()) {
5223
- Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_size)
5224
- << TheCall->getDirectCallee()
5225
- << SourceRange(TheCall->getArg(0)->getBeginLoc(),
5226
- TheCall->getArg(1)->getEndLoc());
5322
+ if (PromoteVectorElementCallArgs(this, TheCall))
5323
+ return true;
5324
+ PromoteVectorArgTruncation(this, TheCall);
5325
+ if (SemaBuiltinVectorToScalarMath(TheCall))
5227
5326
return true;
5228
- }
5229
5327
break;
5230
5328
}
5231
5329
}
@@ -19669,15 +19767,37 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
19669
19767
}
19670
19768
19671
19769
bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
19770
+ QualType Res;
19771
+ bool result = SemaBuiltinVectorMath(TheCall, Res);
19772
+ if (result)
19773
+ return true;
19774
+ TheCall->setType(Res);
19775
+ return false;
19776
+ }
19777
+
19778
+ bool Sema::SemaBuiltinVectorToScalarMath(CallExpr *TheCall) {
19779
+ QualType Res;
19780
+ bool result = SemaBuiltinVectorMath(TheCall, Res);
19781
+ if (result)
19782
+ return true;
19783
+
19784
+ if (auto *VecTy0 = Res->getAs<VectorType>()) {
19785
+ TheCall->setType(VecTy0->getElementType());
19786
+ } else {
19787
+ TheCall->setType(Res);
19788
+ }
19789
+ return false;
19790
+ }
19791
+
19792
+ bool Sema::SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res) {
19672
19793
if (checkArgCount(*this, TheCall, 2))
19673
19794
return true;
19674
19795
19675
19796
ExprResult A = TheCall->getArg(0);
19676
19797
ExprResult B = TheCall->getArg(1);
19677
19798
// Do standard promotions between the two arguments, returning their common
19678
19799
// type.
19679
- QualType Res =
19680
- UsualArithmeticConversions(A, B, TheCall->getExprLoc(), ACK_Comparison);
19800
+ Res = UsualArithmeticConversions(A, B, TheCall->getExprLoc(), ACK_Comparison);
19681
19801
if (A.isInvalid() || B.isInvalid())
19682
19802
return true;
19683
19803
@@ -19694,7 +19814,6 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
19694
19814
19695
19815
TheCall->setArg(0, A.get());
19696
19816
TheCall->setArg(1, B.get());
19697
- TheCall->setType(Res);
19698
19817
return false;
19699
19818
}
19700
19819
0 commit comments