Skip to content

Commit 62ec3db

Browse files
committed
[CSOptimizer] Rework inferTypeOfArithmeticOperatorChain
- Expand the inference to include prefix and postfix unary operators - Recognize previously resolved declaration and member references in argument positions and record their types. - Expand reconciliation logic from Double<->Int to include other floating-point types and `CGFloat`.
1 parent c9e3ebe commit 62ec3db

File tree

1 file changed

+56
-28
lines changed

1 file changed

+56
-28
lines changed

lib/Sema/CSOptimizer.cpp

Lines changed: 56 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -451,27 +451,28 @@ inferTypeFromInitializerResultType(ConstraintSystem &cs,
451451
}
452452

453453
/// If the given expression represents a chain of operators that have
454-
/// only literals as arguments, attempt to deduce a potential type of the
455-
/// chain. For example if chain has only integral literals it's going to
456-
/// be `Int`, if there are some floating-point literals mixed in - it's going
457-
/// to be `Double`.
458-
static Type inferTypeOfArithmeticOperatorChain(DeclContext *dc, ASTNode node) {
459-
auto binaryOp = getAsExpr<BinaryExpr>(node);
460-
if (!binaryOp)
461-
return Type();
462-
454+
/// only declaration/member references and/or literals as arguments,
455+
/// attempt to deduce a potential type of the chain. For example if
456+
/// chain has only integral literals it's going to be `Int`, if there
457+
/// are some floating-point literals mixed in - it's going to be `Double`.
458+
static Type inferTypeOfArithmeticOperatorChain(ConstraintSystem &cs,
459+
ASTNode node) {
463460
class OperatorChainAnalyzer : public ASTWalker {
464461
ASTContext &C;
465462
DeclContext *DC;
463+
ConstraintSystem &CS;
466464

467-
llvm::SmallPtrSet<Type, 2> literals;
465+
llvm::SmallPtrSet<llvm::PointerIntPair<Type, 1>, 2> candidates;
468466

469467
bool unsupported = false;
470468

471469
PreWalkResult<Expr *> walkToExprPre(Expr *expr) override {
472470
if (isa<BinaryExpr>(expr))
473471
return Action::Continue(expr);
474472

473+
if (isa<PrefixUnaryExpr>(expr) || isa<PostfixUnaryExpr>(expr))
474+
return Action::Continue(expr);
475+
475476
if (isa<ParenExpr>(expr))
476477
return Action::Continue(expr);
477478

@@ -487,40 +488,67 @@ static Type inferTypeOfArithmeticOperatorChain(DeclContext *dc, ASTNode node) {
487488
if (auto *LE = dyn_cast<LiteralExpr>(expr)) {
488489
if (auto *P = TypeChecker::getLiteralProtocol(C, LE)) {
489490
if (auto defaultTy = TypeChecker::getDefaultType(P, DC)) {
490-
if (defaultTy->isInt()) {
491-
// Don't add `Int` if `Double` is already in the list.
492-
if (literals.contains(C.getDoubleType()))
493-
return Action::Continue(expr);
494-
} else if (defaultTy->isDouble()) {
495-
// A single use of a floating-point literal flips the
496-
// type of the entire chain to `Double`.
497-
(void)literals.erase(C.getIntType());
498-
}
499-
500-
literals.insert(defaultTy);
491+
addCandidateType(defaultTy, /*literal=*/true);
501492
// String interpolation expressions have `TapExpr`
502493
// as their children, no reason to walk them.
503494
return Action::SkipChildren(expr);
504495
}
505496
}
506497
}
507498

499+
if (auto *UDE = dyn_cast<UnresolvedDotExpr>(expr)) {
500+
auto memberTy = CS.getType(UDE);
501+
if (!memberTy->hasTypeVariable()) {
502+
addCandidateType(memberTy, /*literal=*/false);
503+
return Action::SkipChildren(expr);
504+
}
505+
}
506+
507+
if (auto *DRE = dyn_cast<DeclRefExpr>(expr)) {
508+
auto declTy = CS.getType(DRE);
509+
if (!declTy->hasTypeVariable()) {
510+
addCandidateType(declTy, /*literal=*/false);
511+
return Action::SkipChildren(expr);
512+
}
513+
}
514+
508515
unsupported = true;
509516
return Action::Stop();
510517
}
511518

519+
void addCandidateType(Type type, bool literal) {
520+
if (literal) {
521+
if (type->isInt()) {
522+
// Floating-point types always subsume Int in operator chains.
523+
if (llvm::any_of(candidates, [](const auto &candidate) {
524+
auto ty = candidate.getPointer();
525+
return isFloatType(ty) || ty->isCGFloat();
526+
}))
527+
return;
528+
} else if (isFloatType(type) || type->isCGFloat()) {
529+
// A single use of a floating-point literal flips the
530+
// type of the entire chain to it.
531+
(void)candidates.erase({C.getIntType(), /*literal=*/true});
532+
}
533+
}
534+
535+
candidates.insert({type, literal});
536+
}
537+
512538
public:
513-
OperatorChainAnalyzer(DeclContext *DC) : C(DC->getASTContext()), DC(DC) {}
539+
OperatorChainAnalyzer(ConstraintSystem &CS)
540+
: C(CS.getASTContext()), DC(CS.DC), CS(CS) {}
514541

515542
Type chainType() const {
516543
if (unsupported)
517544
return Type();
518-
return literals.size() != 1 ? Type() : *literals.begin();
545+
return candidates.size() != 1 ? Type()
546+
: (*candidates.begin()).getPointer();
519547
}
520548
};
521549

522-
OperatorChainAnalyzer analyzer(dc);
523-
binaryOp->walk(analyzer);
550+
OperatorChainAnalyzer analyzer(cs);
551+
node.walk(analyzer);
524552

525553
return analyzer.chainType();
526554
}
@@ -695,7 +723,7 @@ static std::optional<DisjunctionInfo> preserveFavoringOfUnlabeledUnaryArgument(
695723
// For chains like `1 + 2 * 3` it's easy to deduce the type because
696724
// we know what literal types are preferred.
697725
if (isa<BinaryExpr>(argument)) {
698-
auto chainTy = inferTypeOfArithmeticOperatorChain(cs.DC, argument);
726+
auto chainTy = inferTypeOfArithmeticOperatorChain(cs, argument);
699727
if (!chainTy)
700728
return DisjunctionInfo::none();
701729

@@ -1008,7 +1036,7 @@ static void determineBestChoicesInContext(
10081036
auto *resultLoc = typeVar->getImpl().getLocator();
10091037

10101038
if (auto type = inferTypeOfArithmeticOperatorChain(
1011-
cs.DC, resultLoc->getAnchor())) {
1039+
cs, resultLoc->getAnchor())) {
10121040
types.push_back({type, /*fromLiteral=*/true});
10131041
}
10141042

@@ -1824,7 +1852,7 @@ ConstraintSystem::selectDisjunction() {
18241852

18251853
// Not all of the non-operator disjunctions are supported by the
18261854
// ranking algorithm, so to prevent eager selection of operators
1827-
// when anything concrete is known about them, let's reset the score
1855+
// when nothing concrete is known about them, let's reset the score
18281856
// and compare purely based on number of choices.
18291857
if (isFirstOperator != isSecondOperator) {
18301858
if (isFirstOperator && isFirstSpeculative)

0 commit comments

Comments
 (0)