@@ -238,8 +238,8 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
238238 case CK_DerivedToBaseMemberPointer: {
239239 assert (classifyPrim (CE->getType ()) == PT_MemberPtr);
240240 assert (classifyPrim (SubExpr->getType ()) == PT_MemberPtr);
241- const auto *FromMP = SubExpr->getType ()->getAs <MemberPointerType>();
242- const auto *ToMP = CE->getType ()->getAs <MemberPointerType>();
241+ const auto *FromMP = SubExpr->getType ()->castAs <MemberPointerType>();
242+ const auto *ToMP = CE->getType ()->castAs <MemberPointerType>();
243243
244244 unsigned DerivedOffset =
245245 Ctx.collectBaseOffset (ToMP->getMostRecentCXXRecordDecl (),
@@ -254,8 +254,8 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
254254 case CK_BaseToDerivedMemberPointer: {
255255 assert (classifyPrim (CE) == PT_MemberPtr);
256256 assert (classifyPrim (SubExpr) == PT_MemberPtr);
257- const auto *FromMP = SubExpr->getType ()->getAs <MemberPointerType>();
258- const auto *ToMP = CE->getType ()->getAs <MemberPointerType>();
257+ const auto *FromMP = SubExpr->getType ()->castAs <MemberPointerType>();
258+ const auto *ToMP = CE->getType ()->castAs <MemberPointerType>();
259259
260260 unsigned DerivedOffset =
261261 Ctx.collectBaseOffset (FromMP->getMostRecentCXXRecordDecl (),
@@ -320,37 +320,34 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
320320 }
321321
322322 case CK_IntegralToFloating: {
323- std::optional<PrimType> FromT = classify (SubExpr->getType ());
324- if (!FromT)
325- return false ;
326-
327323 if (!this ->visit (SubExpr))
328324 return false ;
329-
330325 const auto *TargetSemantics = &Ctx.getFloatSemantics (CE->getType ());
331- return this ->emitCastIntegralFloating (*FromT, TargetSemantics,
332- getFPOptions (CE), CE);
326+ return this ->emitCastIntegralFloating (
327+ classifyPrim (SubExpr), TargetSemantics, getFPOptions (CE), CE);
333328 }
334329
335- case CK_FloatingToBoolean:
336- case CK_FloatingToIntegral: {
337-
338- std::optional<PrimType> ToT = classify (CE->getType ());
339-
340- if (!ToT)
330+ case CK_FloatingToBoolean: {
331+ assert (classifyPrim (CE) == PT_Bool);
332+ if (const auto *FL = dyn_cast<FloatingLiteral>(SubExpr))
333+ return this ->emitConstBool (FL->getValue ().isNonZero (), CE);
334+ if (!this ->visit (SubExpr))
341335 return false ;
336+ return this ->emitCastFloatingIntegralBool (getFPOptions (CE), CE);
337+ }
342338
339+ case CK_FloatingToIntegral: {
343340 if (!this ->visit (SubExpr))
344341 return false ;
345-
342+ PrimType ToT = classifyPrim (CE);
346343 if (ToT == PT_IntAP)
347344 return this ->emitCastFloatingIntegralAP (Ctx.getBitWidth (CE->getType ()),
348345 getFPOptions (CE), CE);
349346 if (ToT == PT_IntAPS)
350347 return this ->emitCastFloatingIntegralAPS (Ctx.getBitWidth (CE->getType ()),
351348 getFPOptions (CE), CE);
352349
353- return this ->emitCastFloatingIntegral (* ToT, getFPOptions (CE), CE);
350+ return this ->emitCastFloatingIntegral (ToT, getFPOptions (CE), CE);
354351 }
355352
356353 case CK_NullToPointer:
@@ -395,9 +392,7 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
395392 case CK_ArrayToPointerDecay: {
396393 if (!this ->visit (SubExpr))
397394 return false ;
398- if (!this ->emitArrayDecay (CE))
399- return false ;
400- return true ;
395+ return this ->emitArrayDecay (CE);
401396 }
402397
403398 case CK_IntegralToPointer: {
@@ -480,47 +475,67 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
480475 return this ->emitBuiltinBitCast (CE);
481476
482477 case CK_IntegralToBoolean:
483- case CK_FixedPointToBoolean:
478+ case CK_FixedPointToBoolean: {
479+ // HLSL uses this to cast to one-element vectors.
480+ if (!CE->getType ()->isBooleanType ())
481+ return false ;
482+
483+ assert (classifyPrim (CE) == PT_Bool);
484+ if (const auto *IL = dyn_cast<IntegerLiteral>(SubExpr))
485+ return this ->emitConstBool (!IL->getValue ().isZero (), CE);
486+
487+ if (!this ->visit (SubExpr))
488+ return false ;
489+ PrimType FromT = classifyPrim (SubExpr->getType ());
490+ if (FromT == PT_Bool)
491+ return true ;
492+ return this ->emitCast (FromT, PT_Bool, CE);
493+ }
494+
484495 case CK_BooleanToSignedIntegral:
485496 case CK_IntegralCast: {
486497 std::optional<PrimType> FromT = classify (SubExpr->getType ());
487- std::optional<PrimType> ToT = classify (CE->getType ());
488-
489- if (!FromT || !ToT)
498+ if (!FromT)
490499 return false ;
500+ PrimType ToT = classifyPrim (CE->getType ());
491501
492- if (!this ->visit (SubExpr))
493- return false ;
502+ // Try to emit a casted known constant value directly.
503+ if (const auto *IL = dyn_cast<IntegerLiteral>(SubExpr)) {
504+ if (ToT != PT_IntAP && ToT != PT_IntAPS &&
505+ !CE->getType ()->isEnumeralType ())
506+ return this ->emitConst (IL->getValue (), CE);
507+ if (!this ->emitConst (IL->getValue (), SubExpr))
508+ return false ;
509+ } else {
510+ if (!this ->visit (SubExpr))
511+ return false ;
512+ }
494513
495514 // Possibly diagnose casts to enum types if the target type does not
496515 // have a fixed size.
497516 if (Ctx.getLangOpts ().CPlusPlus && CE->getType ()->isEnumeralType ()) {
498- if (const auto *ET = CE->getType ().getCanonicalType ()->getAs <EnumType>();
499- ET && !ET->getDecl ()->isFixed ()) {
517+ if (const auto *ET = CE->getType ().getCanonicalType ()->castAs <EnumType>();
518+ !ET->getDecl ()->isFixed ()) {
500519 if (!this ->emitCheckEnumValue (*FromT, ET->getDecl (), CE))
501520 return false ;
502521 }
503522 }
504523
505- auto maybeNegate = [&]() -> bool {
506- if (CE->getCastKind () == CK_BooleanToSignedIntegral)
507- return this ->emitNeg (*ToT, CE);
508- return true ;
509- };
510-
511- if (ToT == PT_IntAP)
512- return this ->emitCastAP (*FromT, Ctx.getBitWidth (CE->getType ()), CE) &&
513- maybeNegate ();
514- if (ToT == PT_IntAPS)
515- return this ->emitCastAPS (*FromT, Ctx.getBitWidth (CE->getType ()), CE) &&
516- maybeNegate ();
517-
518- if (FromT == ToT)
519- return true ;
520- if (!this ->emitCast (*FromT, *ToT, CE))
521- return false ;
522-
523- return maybeNegate ();
524+ if (ToT == PT_IntAP) {
525+ if (!this ->emitCastAP (*FromT, Ctx.getBitWidth (CE->getType ()), CE))
526+ return false ;
527+ } else if (ToT == PT_IntAPS) {
528+ if (!this ->emitCastAPS (*FromT, Ctx.getBitWidth (CE->getType ()), CE))
529+ return false ;
530+ } else {
531+ if (FromT == ToT)
532+ return true ;
533+ if (!this ->emitCast (*FromT, ToT, CE))
534+ return false ;
535+ }
536+ if (CE->getCastKind () == CK_BooleanToSignedIntegral)
537+ return this ->emitNeg (ToT, CE);
538+ return true ;
524539 }
525540
526541 case CK_PointerToBoolean:
0 commit comments