@@ -5166,69 +5166,167 @@ bool Sema::CheckPPCMMAType(QualType Type, SourceLocation TypeLoc) {
5166
5166
return false;
5167
5167
}
5168
5168
5169
- // Note: returning true in this case results in CheckBuiltinFunctionCall
5170
- // returning an ExprError
5171
- bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
5172
- switch (BuiltinID) {
5173
- case Builtin::BI__builtin_hlsl_dot: {
5174
- if (checkArgCount(*this, TheCall, 2)) {
5169
+ bool PromoteBoolsToInt(Sema *S, CallExpr *TheCall) {
5170
+ unsigned NumArgs = TheCall->getNumArgs();
5171
+
5172
+ for (unsigned i = 0; i < NumArgs; ++i) {
5173
+ ExprResult A = TheCall->getArg(i);
5174
+ if (!A.get()->getType()->isBooleanType())
5175
+ return false;
5176
+ }
5177
+ // if we got here all args are bool
5178
+ for (unsigned i = 0; i < NumArgs; ++i) {
5179
+ ExprResult A = TheCall->getArg(i);
5180
+ ExprResult ResA = S->PerformImplicitConversion(A.get(), S->Context.IntTy,
5181
+ Sema::AA_Converting);
5182
+ if (ResA.isInvalid())
5175
5183
return true;
5176
- }
5177
- Expr *Arg0 = TheCall->getArg(0);
5178
- QualType ArgTy0 = Arg0->getType();
5184
+ TheCall->setArg(0, ResA.get());
5185
+ }
5186
+ return false;
5187
+ }
5179
5188
5180
- Expr *Arg1 = TheCall->getArg(1);
5181
- QualType ArgTy1 = Arg1->getType();
5189
+ int overloadOrder(Sema *S, QualType ArgTyA) {
5190
+ auto kind = ArgTyA->getAs<BuiltinType>()->getKind();
5191
+ switch (kind) {
5192
+ case BuiltinType::Short:
5193
+ case BuiltinType::UShort:
5194
+ return 1;
5195
+ case BuiltinType::Int:
5196
+ case BuiltinType::UInt:
5197
+ return 2;
5198
+ case BuiltinType::Long:
5199
+ case BuiltinType::ULong:
5200
+ return 3;
5201
+ case BuiltinType::LongLong:
5202
+ case BuiltinType::ULongLong:
5203
+ return 4;
5204
+ case BuiltinType::Float16:
5205
+ case BuiltinType::Half:
5206
+ return 5;
5207
+ case BuiltinType::Float:
5208
+ return 6;
5209
+ default:
5210
+ break;
5211
+ }
5212
+ return 0;
5213
+ }
5182
5214
5183
- auto *VecTy0 = ArgTy0->getAs<VectorType>();
5184
- auto *VecTy1 = ArgTy1->getAs<VectorType>();
5185
- SourceLocation BuiltinLoc = TheCall->getBeginLoc();
5215
+ QualType getVecLargestBitness(Sema *S, QualType ArgTyA, QualType ArgTyB) {
5216
+ auto *VecTyA = ArgTyA->getAs<VectorType>();
5217
+ auto *VecTyB = ArgTyB->getAs<VectorType>();
5218
+ QualType VecTyAElem = VecTyA->getElementType();
5219
+ QualType VecTyBElem = VecTyB->getElementType();
5220
+ int vecAElemWidth = overloadOrder(S, VecTyAElem);
5221
+ int vecBElemWidth = overloadOrder(S, VecTyBElem);
5222
+ return vecAElemWidth > vecBElemWidth ? ArgTyA : ArgTyB;
5223
+ }
5186
5224
5187
- // if arg0 is bool then call Diag with err_builtin_invalid_arg_type
5188
- if (checkMathBuiltinElementType(*this, Arg0->getBeginLoc(), ArgTy0, 1)) {
5189
- return true;
5190
- }
5225
+ void PromoteVectorArgTruncation(Sema *S, CallExpr *TheCall) {
5226
+ assert(TheCall->getNumArgs() > 1);
5227
+ ExprResult A = TheCall->getArg(0);
5228
+ ExprResult B = TheCall->getArg(1);
5229
+ QualType ArgTyA = A.get()->getType();
5230
+ QualType ArgTyB = B.get()->getType();
5191
5231
5192
- // if arg1 is bool then call Diag with err_builtin_invalid_arg_type
5193
- if (checkMathBuiltinElementType(*this, Arg1->getBeginLoc(), ArgTy1, 2)) {
5194
- return true;
5232
+ auto *VecTyA = ArgTyA->getAs<VectorType>();
5233
+ auto *VecTyB = ArgTyB->getAs<VectorType>();
5234
+ if (VecTyA == nullptr && VecTyB == nullptr)
5235
+ return;
5236
+ if (VecTyA == nullptr || VecTyB == nullptr)
5237
+ return;
5238
+ if (VecTyA->getNumElements() == VecTyB->getNumElements())
5239
+ return;
5240
+
5241
+ Expr *LargerArg = B.get();
5242
+ Expr *SmallerArg = A.get();
5243
+ int largerIndex = 1;
5244
+ if (VecTyA->getNumElements() > VecTyB->getNumElements()) {
5245
+ LargerArg = A.get();
5246
+ SmallerArg = B.get();
5247
+ largerIndex = 0;
5248
+ }
5249
+ S->Diag(TheCall->getExprLoc(), diag::warn_hlsl_impcast_vector_truncation)
5250
+ << LargerArg->getType() << SmallerArg->getType()
5251
+ << LargerArg->getSourceRange() << SmallerArg->getSourceRange();
5252
+ ExprResult ResLargerArg = S->ImpCastExprToType(
5253
+ LargerArg, SmallerArg->getType(), CK_HLSLVectorTruncation);
5254
+ TheCall->setArg(largerIndex, ResLargerArg.get());
5255
+ return;
5256
+ }
5257
+
5258
+ bool PromoteVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
5259
+ assert(TheCall->getNumArgs() > 1);
5260
+ ExprResult A = TheCall->getArg(0);
5261
+ ExprResult B = TheCall->getArg(1);
5262
+ QualType ArgTyA = A.get()->getType();
5263
+ QualType ArgTyB = B.get()->getType();
5264
+
5265
+ auto *VecTyA = ArgTyA->getAs<VectorType>();
5266
+ auto *VecTyB = ArgTyB->getAs<VectorType>();
5267
+ if (VecTyA == nullptr && VecTyB == nullptr)
5268
+ return false;
5269
+ if (VecTyA && VecTyB) {
5270
+ if (VecTyA->getElementType() == VecTyB->getElementType()) {
5271
+ TheCall->setType(VecTyA->getElementType());
5272
+ return false;
5273
+ }
5274
+ SourceLocation BuiltinLoc = TheCall->getBeginLoc();
5275
+ QualType CastType = getVecLargestBitness(S, ArgTyA, ArgTyB);
5276
+ if (CastType == ArgTyA) {
5277
+ ExprResult ResB = S->SemaConvertVectorExpr(
5278
+ B.get(), S->Context.CreateTypeSourceInfo(ArgTyA), BuiltinLoc,
5279
+ B.get()->getBeginLoc());
5280
+ TheCall->setArg(1, ResB.get());
5281
+ TheCall->setType(VecTyA->getElementType());
5282
+ return false;
5195
5283
}
5196
5284
5197
- if (VecTy0 == nullptr && VecTy1 == nullptr) {
5198
- if (ArgTy0 != ArgTy1) {
5199
- return true;
5200
- } else {
5201
- return false;
5202
- }
5285
+ if (CastType == ArgTyB) {
5286
+ ExprResult ResA = S->SemaConvertVectorExpr(
5287
+ A.get(), S->Context.CreateTypeSourceInfo(ArgTyB), BuiltinLoc,
5288
+ A.get()->getBeginLoc());
5289
+ TheCall->setArg(0, ResA.get());
5290
+ TheCall->setType(VecTyB->getElementType());
5291
+ return false;
5203
5292
}
5293
+ return false;
5294
+ }
5204
5295
5205
- if ((VecTy0 == nullptr && VecTy1 != nullptr) ||
5206
- (VecTy0 != nullptr && VecTy1 == nullptr)) {
5296
+ if (VecTyB) {
5297
+ // Convert to the vector result type
5298
+ ExprResult ResA = A;
5299
+ if (VecTyB->getElementType() != ArgTyA)
5300
+ ResA = S->ImpCastExprToType(ResA.get(), VecTyB->getElementType(),
5301
+ CK_FloatingCast);
5302
+ ResA = S->ImpCastExprToType(ResA.get(), ArgTyB, CK_VectorSplat);
5303
+ TheCall->setArg(0, ResA.get());
5304
+ }
5305
+ if (VecTyA) {
5306
+ ExprResult ResB = B;
5307
+ if (VecTyA->getElementType() != ArgTyB)
5308
+ ResB = S->ImpCastExprToType(ResB.get(), VecTyA->getElementType(),
5309
+ CK_FloatingCast);
5310
+ ResB = S->ImpCastExprToType(ResB.get(), ArgTyA, CK_VectorSplat);
5311
+ TheCall->setArg(1, ResB.get());
5312
+ }
5313
+ return false;
5314
+ }
5207
5315
5208
- Diag(BuiltinLoc, diag::err_vec_builtin_non_vector)
5209
- << TheCall->getDirectCallee()
5210
- << SourceRange(TheCall->getArg(0)->getBeginLoc(),
5211
- TheCall->getArg(1)->getEndLoc());
5316
+ // Note: returning true in this case results in CheckBuiltinFunctionCall
5317
+ // returning an ExprError
5318
+ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
5319
+ switch (BuiltinID) {
5320
+ case Builtin::BI__builtin_hlsl_dot: {
5321
+ if (checkArgCount(*this, TheCall, 2))
5212
5322
return true;
5213
- }
5214
-
5215
- if (VecTy0->getElementType() != VecTy1->getElementType()) {
5216
- // Note: This case should never happen. If type promotion occurs
5217
- // then element types won't be different. This diag error is here
5218
- // b\c EmitHLSLBuiltinExpr asserts on this case.
5219
- Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_vector)
5220
- << TheCall->getDirectCallee()
5221
- << SourceRange(TheCall->getArg(0)->getBeginLoc(),
5222
- TheCall->getArg(1)->getEndLoc());
5323
+ if (PromoteBoolsToInt(this, TheCall))
5223
5324
return true;
5224
- }
5225
- if (VecTy0->getNumElements() != VecTy1->getNumElements()) {
5226
- Diag(BuiltinLoc, diag::err_vec_builtin_incompatible_size)
5227
- << TheCall->getDirectCallee()
5228
- << SourceRange(TheCall->getArg(0)->getBeginLoc(),
5229
- TheCall->getArg(1)->getEndLoc());
5325
+ if (PromoteVectorElementCallArgs(this, TheCall))
5326
+ return true;
5327
+ PromoteVectorArgTruncation(this, TheCall);
5328
+ if (SemaBuiltinVectorToScalarMath(TheCall))
5230
5329
return true;
5231
- }
5232
5330
break;
5233
5331
}
5234
5332
}
@@ -19676,15 +19774,37 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
19676
19774
}
19677
19775
19678
19776
bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
19777
+ QualType Res;
19778
+ bool result = SemaBuiltinVectorMath(TheCall, Res);
19779
+ if (result)
19780
+ return true;
19781
+ TheCall->setType(Res);
19782
+ return false;
19783
+ }
19784
+
19785
+ bool Sema::SemaBuiltinVectorToScalarMath(CallExpr *TheCall) {
19786
+ QualType Res;
19787
+ bool result = SemaBuiltinVectorMath(TheCall, Res);
19788
+ if (result)
19789
+ return true;
19790
+
19791
+ if (auto *VecTy0 = Res->getAs<VectorType>()) {
19792
+ TheCall->setType(VecTy0->getElementType());
19793
+ } else {
19794
+ TheCall->setType(Res);
19795
+ }
19796
+ return false;
19797
+ }
19798
+
19799
+ bool Sema::SemaBuiltinVectorMath(CallExpr *TheCall, QualType &Res) {
19679
19800
if (checkArgCount(*this, TheCall, 2))
19680
19801
return true;
19681
19802
19682
19803
ExprResult A = TheCall->getArg(0);
19683
19804
ExprResult B = TheCall->getArg(1);
19684
19805
// Do standard promotions between the two arguments, returning their common
19685
19806
// type.
19686
- QualType Res =
19687
- UsualArithmeticConversions(A, B, TheCall->getExprLoc(), ACK_Comparison);
19807
+ Res = UsualArithmeticConversions(A, B, TheCall->getExprLoc(), ACK_Comparison);
19688
19808
if (A.isInvalid() || B.isInvalid())
19689
19809
return true;
19690
19810
@@ -19701,7 +19821,6 @@ bool Sema::SemaBuiltinElementwiseMath(CallExpr *TheCall) {
19701
19821
19702
19822
TheCall->setArg(0, A.get());
19703
19823
TheCall->setArg(1, B.get());
19704
- TheCall->setType(Res);
19705
19824
return false;
19706
19825
}
19707
19826
0 commit comments