@@ -451,27 +451,28 @@ inferTypeFromInitializerResultType(ConstraintSystem &cs,
451
451
}
452
452
453
453
// / If the given expression represents a chain of operators that have
454
- // / only literals as arguments, attempt to deduce a potential type of the
455
- // / chain. For example if chain has only integral literals it's going to
456
- // / be `Int`, if there are some floating-point literals mixed in - it's going
457
- // / to be `Double`.
458
- static Type inferTypeOfArithmeticOperatorChain (DeclContext *dc, ASTNode node) {
459
- auto binaryOp = getAsExpr<BinaryExpr>(node);
460
- if (!binaryOp)
461
- return Type ();
462
-
454
+ // / only declaration/member references and/or literals as arguments,
455
+ // / attempt to deduce a potential type of the chain. For example if
456
+ // / chain has only integral literals it's going to be `Int`, if there
457
+ // / are some floating-point literals mixed in - it's going to be `Double`.
458
+ static Type inferTypeOfArithmeticOperatorChain (ConstraintSystem &cs,
459
+ ASTNode node) {
463
460
class OperatorChainAnalyzer : public ASTWalker {
464
461
ASTContext &C;
465
462
DeclContext *DC;
463
+ ConstraintSystem &CS;
466
464
467
- llvm::SmallPtrSet<Type, 2 > literals ;
465
+ llvm::SmallPtrSet<llvm::PointerIntPair< Type, 1 >, 2 > candidates ;
468
466
469
467
bool unsupported = false ;
470
468
471
469
PreWalkResult<Expr *> walkToExprPre (Expr *expr) override {
472
470
if (isa<BinaryExpr>(expr))
473
471
return Action::Continue (expr);
474
472
473
+ if (isa<PrefixUnaryExpr>(expr) || isa<PostfixUnaryExpr>(expr))
474
+ return Action::Continue (expr);
475
+
475
476
if (isa<ParenExpr>(expr))
476
477
return Action::Continue (expr);
477
478
@@ -487,40 +488,67 @@ static Type inferTypeOfArithmeticOperatorChain(DeclContext *dc, ASTNode node) {
487
488
if (auto *LE = dyn_cast<LiteralExpr>(expr)) {
488
489
if (auto *P = TypeChecker::getLiteralProtocol (C, LE)) {
489
490
if (auto defaultTy = TypeChecker::getDefaultType (P, DC)) {
490
- if (defaultTy->isInt ()) {
491
- // Don't add `Int` if `Double` is already in the list.
492
- if (literals.contains (C.getDoubleType ()))
493
- return Action::Continue (expr);
494
- } else if (defaultTy->isDouble ()) {
495
- // A single use of a floating-point literal flips the
496
- // type of the entire chain to `Double`.
497
- (void )literals.erase (C.getIntType ());
498
- }
499
-
500
- literals.insert (defaultTy);
491
+ addCandidateType (defaultTy, /* literal=*/ true );
501
492
// String interpolation expressions have `TapExpr`
502
493
// as their children, no reason to walk them.
503
494
return Action::SkipChildren (expr);
504
495
}
505
496
}
506
497
}
507
498
499
+ if (auto *UDE = dyn_cast<UnresolvedDotExpr>(expr)) {
500
+ auto memberTy = CS.getType (UDE);
501
+ if (!memberTy->hasTypeVariable ()) {
502
+ addCandidateType (memberTy, /* literal=*/ false );
503
+ return Action::SkipChildren (expr);
504
+ }
505
+ }
506
+
507
+ if (auto *DRE = dyn_cast<DeclRefExpr>(expr)) {
508
+ auto declTy = CS.getType (DRE);
509
+ if (!declTy->hasTypeVariable ()) {
510
+ addCandidateType (declTy, /* literal=*/ false );
511
+ return Action::SkipChildren (expr);
512
+ }
513
+ }
514
+
508
515
unsupported = true ;
509
516
return Action::Stop ();
510
517
}
511
518
519
+ void addCandidateType (Type type, bool literal) {
520
+ if (literal) {
521
+ if (type->isInt ()) {
522
+ // Floating-point types always subsume Int in operator chains.
523
+ if (llvm::any_of (candidates, [](const auto &candidate) {
524
+ auto ty = candidate.getPointer ();
525
+ return isFloatType (ty) || ty->isCGFloat ();
526
+ }))
527
+ return ;
528
+ } else if (isFloatType (type) || type->isCGFloat ()) {
529
+ // A single use of a floating-point literal flips the
530
+ // type of the entire chain to it.
531
+ (void )candidates.erase ({C.getIntType (), /* literal=*/ true });
532
+ }
533
+ }
534
+
535
+ candidates.insert ({type, literal});
536
+ }
537
+
512
538
public:
513
- OperatorChainAnalyzer (DeclContext *DC) : C(DC->getASTContext ()), DC(DC) {}
539
+ OperatorChainAnalyzer (ConstraintSystem &CS)
540
+ : C(CS.getASTContext()), DC(CS.DC), CS(CS) {}
514
541
515
542
Type chainType () const {
516
543
if (unsupported)
517
544
return Type ();
518
- return literals.size () != 1 ? Type () : *literals.begin ();
545
+ return candidates.size () != 1 ? Type ()
546
+ : (*candidates.begin ()).getPointer ();
519
547
}
520
548
};
521
549
522
- OperatorChainAnalyzer analyzer (dc );
523
- binaryOp-> walk (analyzer);
550
+ OperatorChainAnalyzer analyzer (cs );
551
+ node. walk (analyzer);
524
552
525
553
return analyzer.chainType ();
526
554
}
@@ -695,7 +723,7 @@ static std::optional<DisjunctionInfo> preserveFavoringOfUnlabeledUnaryArgument(
695
723
// For chains like `1 + 2 * 3` it's easy to deduce the type because
696
724
// we know what literal types are preferred.
697
725
if (isa<BinaryExpr>(argument)) {
698
- auto chainTy = inferTypeOfArithmeticOperatorChain (cs. DC , argument);
726
+ auto chainTy = inferTypeOfArithmeticOperatorChain (cs, argument);
699
727
if (!chainTy)
700
728
return DisjunctionInfo::none ();
701
729
@@ -1008,7 +1036,7 @@ static void determineBestChoicesInContext(
1008
1036
auto *resultLoc = typeVar->getImpl ().getLocator ();
1009
1037
1010
1038
if (auto type = inferTypeOfArithmeticOperatorChain (
1011
- cs. DC , resultLoc->getAnchor ())) {
1039
+ cs, resultLoc->getAnchor ())) {
1012
1040
types.push_back ({type, /* fromLiteral=*/ true });
1013
1041
}
1014
1042
@@ -1824,7 +1852,7 @@ ConstraintSystem::selectDisjunction() {
1824
1852
1825
1853
// Not all of the non-operator disjunctions are supported by the
1826
1854
// ranking algorithm, so to prevent eager selection of operators
1827
- // when anything concrete is known about them, let's reset the score
1855
+ // when nothing concrete is known about them, let's reset the score
1828
1856
// and compare purely based on number of choices.
1829
1857
if (isFirstOperator != isSecondOperator) {
1830
1858
if (isFirstOperator && isFirstSpeculative)
0 commit comments