diff --git a/include/swift/Basic/LangOptions.h b/include/swift/Basic/LangOptions.h index a4680b9ab4dd8..5796147aa2346 100644 --- a/include/swift/Basic/LangOptions.h +++ b/include/swift/Basic/LangOptions.h @@ -679,7 +679,7 @@ namespace swift { unsigned SolverShrinkUnsolvedThreshold = 10; /// Disable the shrink phase of the expression type checker. - bool SolverDisableShrink = false; + bool SolverDisableShrink = true; /// Enable experimental operator designated types feature. bool EnableOperatorDesignatedTypes = false; diff --git a/include/swift/Sema/ConstraintSystem.h b/include/swift/Sema/ConstraintSystem.h index 352cd4b950500..2391932512291 100644 --- a/include/swift/Sema/ConstraintSystem.h +++ b/include/swift/Sema/ConstraintSystem.h @@ -2251,6 +2251,7 @@ class ConstraintSystem { friend class SplitterStep; friend class ComponentStep; friend class TypeVariableStep; + friend class DisjunctionStep; friend class ConjunctionStep; friend class ConjunctionElement; friend class RequirementFailure; @@ -2402,6 +2403,10 @@ class ConstraintSystem { /// the current constraint system. llvm::MapVector DisjunctionChoices; + /// The stack of all disjunctions selected during current path in order. + /// This stack is managed by the \c DisjunctionStep. + llvm::SmallVector SelectedDisjunctions; + /// A map from applied disjunction constraints to the corresponding /// argument function type. llvm::SmallMapVector @@ -5023,6 +5028,9 @@ class ConstraintSystem { /// /// \returns The selected disjunction. Constraint *selectDisjunction(); + /// Select the best possible disjunction for solver to attempt + /// based on the given list. + Constraint *selectBestDisjunction(ArrayRef disjunctions); /// Pick a conjunction from the InactiveConstraints list. /// @@ -5622,6 +5630,10 @@ class DisjunctionChoice { bool isSymmetricOperator() const; bool isUnaryOperator() const; + bool isGenericUnaryOperator() const { + return isGenericOperator() && isUnaryOperator(); + } + void print(llvm::raw_ostream &Out, SourceManager *SM) const { Out << "disjunction choice "; Choice->print(Out, SM); @@ -5839,8 +5851,6 @@ class DisjunctionChoiceProducer : public BindingProducer { unsigned Index = 0; - bool needsGenericOperatorOrdering = true; - public: using Element = DisjunctionChoice; @@ -5858,10 +5868,6 @@ class DisjunctionChoiceProducer : public BindingProducer { partitionDisjunction(Ordering, PartitionBeginning); } - void setNeedsGenericOperatorOrdering(bool flag) { - needsGenericOperatorOrdering = flag; - } - Optional operator()() override { if (isExhausted()) return None; @@ -5874,18 +5880,6 @@ class DisjunctionChoiceProducer : public BindingProducer { ++Index; - auto choice = DisjunctionChoice(CS, currIndex, Choices[Ordering[currIndex]], - IsExplicitConversion, isBeginningOfPartition); - // Partition the generic operators before producing the first generic - // operator disjunction choice. - if (needsGenericOperatorOrdering && choice.isGenericOperator()) { - unsigned nextPartitionIndex = (PartitionIndex < PartitionBeginning.size() ? - PartitionBeginning[PartitionIndex] : Ordering.size()); - partitionGenericOperators(Ordering.begin() + currIndex, - Ordering.begin() + nextPartitionIndex); - needsGenericOperatorOrdering = false; - } - return DisjunctionChoice(CS, currIndex, Choices[Ordering[currIndex]], IsExplicitConversion, isBeginningOfPartition); } @@ -5900,12 +5894,6 @@ class DisjunctionChoiceProducer : public BindingProducer { // have to visit all of the options. void partitionDisjunction(SmallVectorImpl &Ordering, SmallVectorImpl &PartitionBeginning); - - /// Partition the choices in the range \c first to \c last into groups and - /// order the groups in the best order to attempt based on the argument - /// function type that the operator is applied to. - void partitionGenericOperators(SmallVectorImpl::iterator first, - SmallVectorImpl::iterator last); }; class ConjunctionElementProducer : public BindingProducer { diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index fec7d7acd7cca..c04f2303b5b59 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -73,314 +73,6 @@ static bool mergeRepresentativeEquivalenceClasses(ConstraintSystem &CS, } namespace { - - /// Internal struct for tracking information about types within a series - /// of "linked" expressions. (Such as a chain of binary operator invocations.) - struct LinkedTypeInfo { - bool hasLiteral = false; - - llvm::SmallSet collectedTypes; - llvm::SmallVector binaryExprs; - }; - - /// Walks an expression sub-tree, and collects information about expressions - /// whose types are mutually dependent upon one another. - class LinkedExprCollector : public ASTWalker { - - llvm::SmallVectorImpl &LinkedExprs; - ConstraintSystem &CS; - - public: - LinkedExprCollector(llvm::SmallVectorImpl &linkedExprs, - ConstraintSystem &cs) - : LinkedExprs(linkedExprs), CS(cs) {} - - std::pair walkToExprPre(Expr *expr) override { - - if (CS.shouldReusePrecheckedType() && - !CS.getType(expr)->hasTypeVariable()) { - return { false, expr }; - } - - if (isa(expr)) - return {false, expr}; - - // Store top-level binary exprs for further analysis. - if (isa(expr) || - - // Literal exprs are contextually typed, so store them off as well. - isa(expr) || - - // We'd like to look at the elements of arrays and dictionaries. - isa(expr) || - isa(expr) || - - // assignment expression can involve anonymous closure parameters - // as source and destination, so it's beneficial for diagnostics if - // we look at the assignment. - isa(expr)) { - LinkedExprs.push_back(expr); - return {false, expr}; - } - - return { true, expr }; - } - - Expr *walkToExprPost(Expr *expr) override { - return expr; - } - - /// Ignore statements. - std::pair walkToStmtPre(Stmt *stmt) override { - return { false, stmt }; - } - - /// Ignore declarations. - bool walkToDeclPre(Decl *decl) override { return false; } - - /// Ignore patterns. - std::pair walkToPatternPre(Pattern *pat) override { - return { false, pat }; - } - - /// Ignore types. - bool walkToTypeReprPre(TypeRepr *T) override { return false; } - }; - - /// Given a collection of "linked" expressions, analyzes them for - /// commonalities regarding their types. This will help us compute a - /// "best common type" from the expression types. - class LinkedExprAnalyzer : public ASTWalker { - - LinkedTypeInfo <I; - ConstraintSystem &CS; - - public: - - LinkedExprAnalyzer(LinkedTypeInfo <i, ConstraintSystem &cs) : - LTI(lti), CS(cs) {} - - std::pair walkToExprPre(Expr *expr) override { - - if (CS.shouldReusePrecheckedType() && - !CS.getType(expr)->hasTypeVariable()) { - return { false, expr }; - } - - if (isa(expr)) { - LTI.hasLiteral = true; - return { false, expr }; - } - - if (isa(expr)) { - return { true, expr }; - } - - if (auto UDE = dyn_cast(expr)) { - - if (CS.hasType(UDE)) - LTI.collectedTypes.insert(CS.getType(UDE).getPointer()); - - // Don't recurse into the base expression. - return { false, expr }; - } - - - if (isa(expr)) { - return {false, expr}; - } - - if (auto FVE = dyn_cast(expr)) { - LTI.collectedTypes.insert(CS.getType(FVE).getPointer()); - return { false, expr }; - } - - if (auto DRE = dyn_cast(expr)) { - if (auto varDecl = dyn_cast(DRE->getDecl())) { - if (CS.hasType(DRE)) { - LTI.collectedTypes.insert(CS.getType(DRE).getPointer()); - } - return { false, expr }; - } - } - - // In the case of a function application, we would have already captured - // the return type during constraint generation, so there's no use in - // looking any further. - if (isa(expr) && - !(isa(expr) || isa(expr) || - isa(expr))) { - return { false, expr }; - } - - if (auto *binaryExpr = dyn_cast(expr)) { - LTI.binaryExprs.push_back(binaryExpr); - } - - if (auto favoredType = CS.getFavoredType(expr)) { - LTI.collectedTypes.insert(favoredType); - - return { false, expr }; - } - - // Optimize branches of a conditional expression separately. - if (auto IE = dyn_cast(expr)) { - CS.optimizeConstraints(IE->getCondExpr()); - CS.optimizeConstraints(IE->getThenExpr()); - CS.optimizeConstraints(IE->getElseExpr()); - return { false, expr }; - } - - // For exprs of a tuple, avoid favoring. (We need to allow for cases like - // (Int, Int32).) - if (isa(expr)) { - return { false, expr }; - } - - // Coercion exprs have a rigid type, so there's no use in gathering info - // about them. - if (auto *coercion = dyn_cast(expr)) { - // Let's not collect information about types initialized by - // coercions just like we don't for regular initializer calls, - // because that might lead to overly eager type variable merging. - if (!coercion->isLiteralInit()) - LTI.collectedTypes.insert(CS.getType(expr).getPointer()); - return { false, expr }; - } - - // Don't walk into subscript expressions - to do so would risk factoring - // the index expression into edge contraction. (We don't want to do this - // if the index expression is a literal type that differs from the return - // type of the subscript operation.) - if (isa(expr) || isa(expr)) { - return { false, expr }; - } - - // Don't walk into unresolved member expressions - we avoid merging type - // variables inside UnresolvedMemberExpr and those outside, since they - // should be allowed to behave independently in CS. - if (isa(expr)) { - return {false, expr }; - } - - return { true, expr }; - } - - /// Ignore statements. - std::pair walkToStmtPre(Stmt *stmt) override { - return { false, stmt }; - } - - /// Ignore declarations. - bool walkToDeclPre(Decl *decl) override { return false; } - - /// Ignore patterns. - std::pair walkToPatternPre(Pattern *pat) override { - return { false, pat }; - } - - /// Ignore types. - bool walkToTypeReprPre(TypeRepr *T) override { return false; } - }; - - /// For a given expression, given information that is global to the - /// expression, attempt to derive a favored type for it. - void computeFavoredTypeForExpr(Expr *expr, ConstraintSystem &CS) { - LinkedTypeInfo lti; - - expr->walk(LinkedExprAnalyzer(lti, CS)); - - // Check whether we can proceed with favoring. - if (llvm::any_of(lti.binaryExprs, [](const BinaryExpr *op) { - auto *ODRE = dyn_cast(op->getFn()); - if (!ODRE) - return false; - - // Attempting to favor based on operand types is wrong for - // nil-coalescing operator. - auto identifier = ODRE->getDecls().front()->getBaseIdentifier(); - return identifier.isNilCoalescingOperator(); - })) { - return; - } - - if (lti.collectedTypes.size() == 1) { - // TODO: Compute the BCT. - - // It's only useful to favor the type instead of - // binding it directly to arguments/result types, - // which means in case it has been miscalculated - // solver can still make progress. - auto favoredTy = (*lti.collectedTypes.begin())->getWithoutSpecifierType(); - CS.setFavoredType(expr, favoredTy.getPointer()); - - // If we have a chain of identical binop expressions with homogeneous - // argument types, we can directly simplify the associated constraint - // graph. - auto simplifyBinOpExprTyVars = [&]() { - // Don't attempt to do linking if there are - // literals intermingled with other inferred types. - if (lti.hasLiteral) - return; - - for (auto binExp1 : lti.binaryExprs) { - for (auto binExp2 : lti.binaryExprs) { - if (binExp1 == binExp2) - continue; - - auto fnTy1 = CS.getType(binExp1)->getAs(); - auto fnTy2 = CS.getType(binExp2)->getAs(); - - if (!(fnTy1 && fnTy2)) - return; - - auto ODR1 = dyn_cast(binExp1->getFn()); - auto ODR2 = dyn_cast(binExp2->getFn()); - - if (!(ODR1 && ODR2)) - return; - - // TODO: We currently limit this optimization to known arithmetic - // operators, but we should be able to broaden this out to - // logical operators as well. - if (!isArithmeticOperatorDecl(ODR1->getDecls()[0])) - return; - - if (ODR1->getDecls()[0]->getBaseName() != - ODR2->getDecls()[0]->getBaseName()) - return; - - // All things equal, we can merge the tyvars for the function - // types. - auto rep1 = CS.getRepresentative(fnTy1); - auto rep2 = CS.getRepresentative(fnTy2); - - if (rep1 != rep2) { - CS.mergeEquivalenceClasses(rep1, rep2, - /*updateWorkList*/ false); - } - - auto odTy1 = CS.getType(ODR1)->getAs(); - auto odTy2 = CS.getType(ODR2)->getAs(); - - if (odTy1 && odTy2) { - auto odRep1 = CS.getRepresentative(odTy1); - auto odRep2 = CS.getRepresentative(odTy2); - - // Since we'll be choosing the same overload, we can merge - // the overload tyvar as well. - if (odRep1 != odRep2) - CS.mergeEquivalenceClasses(odRep1, odRep2, - /*updateWorkList*/ false); - } - } - } - }; - - simplifyBinOpExprTyVars(); - } - } - /// Determine whether the given parameter type and argument should be /// "favored" because they match exactly. bool isFavoredParamAndArg(ConstraintSystem &CS, Type paramTy, Type argTy, @@ -4230,18 +3922,7 @@ ConstraintSystem::applyPropertyWrapperToParameter( void ConstraintSystem::optimizeConstraints(Expr *e) { if (getASTContext().TypeCheckerOpts.DisableConstraintSolverPerformanceHacks) return; - - SmallVector linkedExprs; - - // Collect any linked expressions. - LinkedExprCollector collector(linkedExprs, *this); - e->walk(collector); - - // Favor types, as appropriate. - for (auto linkedExpr : linkedExprs) { - computeFavoredTypeForExpr(linkedExpr, *this); - } - + // Optimize the constraints. ConstraintOptimizer optimizer(*this); e->walk(optimizer); diff --git a/lib/Sema/CSSimplify.cpp b/lib/Sema/CSSimplify.cpp index 58b7d97beab2c..bb99176dec4f3 100644 --- a/lib/Sema/CSSimplify.cpp +++ b/lib/Sema/CSSimplify.cpp @@ -10405,6 +10405,24 @@ bool ConstraintSystem::simplifyAppliedOverloadsImpl( } } + // Disabled overloads need special handling depending mode. + if (constraint->isDisabled()) { + // In diagnostic mode, invalidate previous common result type if + // current overload choice has a fix to make sure that we produce + // the best diagnostics possible. + if (shouldAttemptFixes()) { + if (constraint->getFix()) + commonResultType = ErrorType::get(getASTContext()); + return true; + } + + // In performance mode, let's skip the disabled overload choice + // and continue - this would make sure that common result type + // could be found if one exists among the overloads the solver + // is actually going to attempt. + return false; + } + // Determine the type that this choice will have. Type choiceType = getEffectiveOverloadType( constraint->getLocator(), choice, /*allowMembers=*/true, @@ -10414,6 +10432,34 @@ bool ConstraintSystem::simplifyAppliedOverloadsImpl( return true; } + // This is the situation where a property has the same name + // as a method e.g. + // + // protocol P { + // var test: String { get } + // } + // + // extension P { + // var test: String { get { return "" } } + // } + // + // struct S : P { + // func test() -> Int { 42 } + // } + // + // var s = S() + // s.test() <- disjunction would have two choices here, one + // for the property from `P` and one for the method of `S`. + // + // In cases like this, let's exclude property overload from common + // result determination because it cannot be applied. + // + // Note that such overloads cannot be disabled, because they still + // have to be checked in diagnostic mode and there is (currently) + // no way to re-enable them for diagnostics. + if (!choiceType->lookThroughAllOptionalTypes()->is()) + return true; + // If types lined up exactly, let's favor this overload choice. if (Type(argFnType)->isEqual(choiceType)) constraint->setFavored(); diff --git a/lib/Sema/CSSolver.cpp b/lib/Sema/CSSolver.cpp index 7e547d4e030a7..cefc775226ab9 100644 --- a/lib/Sema/CSSolver.cpp +++ b/lib/Sema/CSSolver.cpp @@ -1662,8 +1662,9 @@ ConstraintSystem::filterDisjunction( // right-hand side of a conversion constraint, since having a concrete // type that we're converting to can make it possible to split the // constraint system into multiple ones. -static Constraint *selectBestBindingDisjunction( - ConstraintSystem &cs, SmallVectorImpl &disjunctions) { +static Constraint * +selectBestBindingDisjunction(ConstraintSystem &cs, + ArrayRef disjunctions) { if (disjunctions.empty()) return nullptr; @@ -1932,118 +1933,6 @@ static void existingOperatorBindingsForDisjunction(ConstraintSystem &CS, } } -void DisjunctionChoiceProducer::partitionGenericOperators( - SmallVectorImpl::iterator first, - SmallVectorImpl::iterator last) { - auto *argFnType = CS.getAppliedDisjunctionArgumentFunction(Disjunction); - if (!isOperatorDisjunction(Disjunction) || !argFnType) - return; - - auto operatorName = Choices[0]->getOverloadChoice().getName(); - if (!operatorName.getBaseIdentifier().isArithmeticOperator()) - return; - - SmallVector concreteOverloads; - SmallVector numericOverloads; - SmallVector sequenceOverloads; - SmallVector simdOverloads; - SmallVector otherGenericOverloads; - - auto refinesOrConformsTo = [&](NominalTypeDecl *nominal, KnownProtocolKind kind) -> bool { - if (!nominal) - return false; - - auto *protocol = - TypeChecker::getProtocol(CS.getASTContext(), SourceLoc(), kind); - - if (auto *refined = dyn_cast(nominal)) - return refined->inheritsFrom(protocol); - - return (bool)TypeChecker::conformsToProtocol(nominal->getDeclaredType(), protocol, - CS.DC->getParentModule()); - }; - - // Gather Numeric and Sequence overloads into separate buckets. - for (auto iter = first; iter != last; ++iter) { - unsigned index = *iter; - auto *decl = Choices[index]->getOverloadChoice().getDecl(); - auto *nominal = decl->getDeclContext()->getSelfNominalTypeDecl(); - - if (isSIMDOperator(decl)) { - simdOverloads.push_back(index); - } else if (!decl->getInterfaceType()->is()) { - concreteOverloads.push_back(index); - } else if (refinesOrConformsTo(nominal, KnownProtocolKind::AdditiveArithmetic)) { - numericOverloads.push_back(index); - } else if (refinesOrConformsTo(nominal, KnownProtocolKind::Sequence)) { - sequenceOverloads.push_back(index); - } else { - otherGenericOverloads.push_back(index); - } - } - - auto sortPartition = [&](SmallVectorImpl &partition) { - llvm::sort(partition, [&](unsigned lhs, unsigned rhs) -> bool { - auto *declA = - dyn_cast(Choices[lhs]->getOverloadChoice().getDecl()); - auto *declB = - dyn_cast(Choices[rhs]->getOverloadChoice().getDecl()); - - return TypeChecker::isDeclRefinementOf(declA, declB); - }); - }; - - // Sort sequence overloads so that refinements are attempted first. - // If the solver finds a solution with an overload, it can then skip - // subsequent choices that the successful choice is a refinement of. - sortPartition(sequenceOverloads); - - // Attempt concrete overloads first. - first = std::copy(concreteOverloads.begin(), concreteOverloads.end(), first); - - // Check if any of the known argument types conform to one of the standard - // arithmetic protocols. If so, the sovler should attempt the corresponding - // overload choices first. - for (auto arg : argFnType->getParams()) { - auto argType = arg.getPlainType(); - argType = CS.getFixedTypeRecursive(argType, /*wantRValue=*/true); - - if (argType->isTypeVariableOrMember()) - continue; - - if (TypeChecker::conformsToKnownProtocol( - argType, KnownProtocolKind::AdditiveArithmetic, - CS.DC->getParentModule())) { - first = - std::copy(numericOverloads.begin(), numericOverloads.end(), first); - numericOverloads.clear(); - break; - } - - if (TypeChecker::conformsToKnownProtocol( - argType, KnownProtocolKind::Sequence, - CS.DC->getParentModule())) { - first = - std::copy(sequenceOverloads.begin(), sequenceOverloads.end(), first); - sequenceOverloads.clear(); - break; - } - - if (TypeChecker::conformsToKnownProtocol( - argType, KnownProtocolKind::SIMD, - CS.DC->getParentModule())) { - first = std::copy(simdOverloads.begin(), simdOverloads.end(), first); - simdOverloads.clear(); - break; - } - } - - first = std::copy(otherGenericOverloads.begin(), otherGenericOverloads.end(), first); - first = std::copy(numericOverloads.begin(), numericOverloads.end(), first); - first = std::copy(sequenceOverloads.begin(), sequenceOverloads.end(), first); - first = std::copy(simdOverloads.begin(), simdOverloads.end(), first); -} - void DisjunctionChoiceProducer::partitionDisjunction( SmallVectorImpl &Ordering, SmallVectorImpl &PartitionBeginning) { @@ -2080,17 +1969,18 @@ void DisjunctionChoiceProducer::partitionDisjunction( // First collect some things that we'll generally put near the beginning or // end of the partitioning. SmallVector favored; - SmallVector everythingElse; + // start - Operator section + SmallVector concreteOperators; + SmallVector partiallySpecializedOperators; + SmallVector numericOperators; + SmallVector sequenceOperators; SmallVector simdOperators; + SmallVector genericOperators; + // end - operator section + SmallVector everythingElse; SmallVector disabled; SmallVector unavailable; - // Add existing operator bindings to the main partition first. This often - // helps the solver find a solution fast. - existingOperatorBindingsForDisjunction(CS, Choices, everythingElse); - for (auto index : everythingElse) - taken.insert(Choices[index]); - // First collect disabled and favored constraints. forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool { if (constraint->isDisabled()) { @@ -2125,17 +2015,96 @@ void DisjunctionChoiceProducer::partitionDisjunction( }); } - // Partition SIMD operators. - if (isOperatorDisjunction(Disjunction) && - !Choices[0]->getOverloadChoice().getName().getBaseIdentifier().isArithmeticOperator()) { - forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool { - if (isSIMDOperator(constraint->getOverloadChoice().getDecl())) { - simdOperators.push_back(index); - return true; + bool isArithmeticOperator = false; + + if (isOperatorDisjunction(Disjunction)) { + auto operatorName = Choices[0]->getOverloadChoice().getName(); + isArithmeticOperator = + operatorName.getBaseIdentifier().isArithmeticOperator(); + } + + if (isArithmeticOperator) { + auto refinesOrConformsTo = [&](NominalTypeDecl *nominal, + KnownProtocolKind kind) -> bool { + if (!nominal) + return false; + + auto *protocol = + TypeChecker::getProtocol(CS.getASTContext(), SourceLoc(), kind); + + if (auto *refined = dyn_cast(nominal)) + return refined->inheritsFrom(protocol); + + return (bool)TypeChecker::conformsToProtocol( + nominal->getDeclaredType(), protocol, CS.DC->getParentModule()); + }; + + auto isPartiallySpecialized = [&](ValueDecl *choice) -> bool { + auto choiceType = choice->getInterfaceType(); + + auto *fnType = choiceType->getAs(); + if (!(fnType && fnType->is())) + return false; + + if (choice->getDeclContext()->getSelfNominalTypeDecl()) + fnType = fnType->getResult()->castTo(); + + // Type has to be either bound generic e.g. `S`, + // or unbound generic e.g. `Array`, or concrete. + auto isAcceptableType = [&](Type type) { + if (auto *UGT = type->getAs()) + return isa(UGT->getDecl()); + + return type->is() || + !(type->hasTypeParameter() || type->hasDependentMember()); + }; + + if (llvm::all_of(fnType->getParams(), + [&](const AnyFunctionType::Param ¶m) { + return isAcceptableType(param.getPlainType()); + })) { + return isAcceptableType(fnType->getResult()); } return false; + }; + + forEachChoice(Choices, [&](unsigned index, Constraint *choice) -> bool { + auto *decl = choice->getOverloadChoice().getDecl(); + auto *nominal = decl->getDeclContext()->getSelfNominalTypeDecl(); + + if (isSIMDOperator(decl)) { + simdOperators.push_back(index); + } else if (!decl->getInterfaceType()->is()) { + concreteOperators.push_back(index); + } else if (isPartiallySpecialized(decl)) { + partiallySpecializedOperators.push_back(index); + } else if (refinesOrConformsTo(nominal, + KnownProtocolKind::AdditiveArithmetic)) { + numericOperators.push_back(index); + } else if (refinesOrConformsTo(nominal, KnownProtocolKind::Sequence)) { + sequenceOperators.push_back(index); + } else { + genericOperators.push_back(index); + } + return true; }); + + auto sortPartition = [&](SmallVectorImpl &partition) { + llvm::sort(partition, [&](unsigned lhs, unsigned rhs) -> bool { + auto *declA = + dyn_cast(Choices[lhs]->getOverloadChoice().getDecl()); + auto *declB = + dyn_cast(Choices[rhs]->getOverloadChoice().getDecl()); + + return TypeChecker::isDeclRefinementOf(declA, declB); + }); + }; + + // Sort sequence overloads so that refinements are attempted first. + // If the solver finds a solution with an overload, it can then skip + // subsequent choices that the successful choice is a refinement of. + sortPartition(sequenceOperators); } // Gather the remaining options. @@ -2155,8 +2124,54 @@ void DisjunctionChoiceProducer::partitionDisjunction( }; appendPartition(favored); + + if (isArithmeticOperator) { + appendPartition(concreteOperators); + appendPartition(partiallySpecializedOperators); + + if (auto *argFnType = CS.getAppliedDisjunctionArgumentFunction(Disjunction)) { + // Check if any of the known argument types conform to one of the standard + // arithmetic protocols. If so, the solver should attempt the + // corresponding overload choices first. + for (auto arg : argFnType->getParams()) { + auto argType = arg.getPlainType(); + argType = CS.getFixedTypeRecursive(argType, /*wantRValue=*/true); + + if (argType->isTypeVariableOrMember()) + continue; + + if (TypeChecker::conformsToKnownProtocol( + argType, KnownProtocolKind::AdditiveArithmetic, + CS.DC->getParentModule())) { + appendPartition(numericOperators); + numericOperators.clear(); + break; + } + + if (TypeChecker::conformsToKnownProtocol(argType, + KnownProtocolKind::Sequence, + CS.DC->getParentModule())) { + appendPartition(sequenceOperators); + sequenceOperators.clear(); + break; + } + + if (TypeChecker::conformsToKnownProtocol( + argType, KnownProtocolKind::SIMD, CS.DC->getParentModule())) { + appendPartition(simdOperators); + simdOperators.clear(); + break; + } + } + } + + appendPartition(genericOperators); + appendPartition(numericOperators); + appendPartition(sequenceOperators); + appendPartition(simdOperators); + } + appendPartition(everythingElse); - appendPartition(simdOperators); appendPartition(unavailable); appendPartition(disabled); @@ -2170,6 +2185,175 @@ Constraint *ConstraintSystem::selectDisjunction() { if (disjunctions.empty()) return nullptr; + // If there are only a few disjunctions available, + // let's just use selection alogirthm. + if (disjunctions.size() <= 2) + return selectBestDisjunction(disjunctions); + + if (SelectedDisjunctions.empty()) + return selectBestDisjunction(disjunctions); + + auto *lastDisjunction = SelectedDisjunctions.back()->getLocator(); + + // First, let's built a dictionary of all disjunctions accessible + // via their anchoring expressions. + llvm::SmallDenseMap anchoredDisjunctions; + for (auto *disjunction : disjunctions) { + if (auto anchor = simplifyLocatorToAnchor(disjunction->getLocator())) + anchoredDisjunctions.insert({anchor, disjunction}); + } + + auto lookupDisjunctionInCache = + [&anchoredDisjunctions](Expr *expr) -> Constraint * { + auto disjunction = anchoredDisjunctions.find(expr); + return disjunction != anchoredDisjunctions.end() ? disjunction->second + : nullptr; + }; + + auto findDisjunction = [&](Expr *expr) -> Constraint * { + if (!expr || !(isa(expr) || isa(expr))) + return nullptr; + + // For applications i.e. calls, let's match their function first. + if (auto *apply = dyn_cast(expr)) { + if (auto disjunction = lookupDisjunctionInCache(apply->getFn())) + return disjunction; + } + + return lookupDisjunctionInCache(expr); + }; + + auto findClosestDisjunction = [&](Expr *expr) -> Constraint * { + Constraint *selectedDisjunction = nullptr; + expr->forEachChildExpr([&](Expr *expr) -> Expr * { + if (auto *disjunction = findDisjunction(expr)) { + selectedDisjunction = disjunction; + return nullptr; + } + return expr; + }); + return selectedDisjunction; + }; + + if (auto *expr = getAsExpr(lastDisjunction->getAnchor())) { + // If this disjunction is derived from an overload set expression, + // let's look one level higher since its immediate parent is an + // operator application. + if (isa(expr)) + expr = getParentExpr(expr); + + bool isMemberRef = isa(expr); + + // Implicit `.init` calls need some special handling. + if (lastDisjunction->isLastElement()) { + if (auto *call = dyn_cast(expr)) { + expr = call->getFn(); + isMemberRef = true; + } + } + + if (isMemberRef) { + auto *parent = getParentExpr(expr); + // If this is a member application e.g. `.test(...)`, + // then let's see whether one of the arguments is a + // closure and if so, select the "best" disjunction + // from its body to be attempted next. This helps to + // type-check operator chains in a freshly resolved + // closure before moving to the next member in that + // chain of expressions for example: + // + // arr.map { $0 + 1 * 3 ... }.filter { ... }.reduce(0, +) + // + // Attempting to solve the body of the `.map` right after + // it has been selected helps to split up the constraint + // system. + if (auto *call = dyn_cast_or_null(parent)) { + if (auto *arguments = call->getArgs()) { + for (const auto &argument : *arguments) { + auto *argExpr = argument.getExpr()->getSemanticsProvidingExpr(); + auto *closure = dyn_cast(argExpr); + // Even if the body of this closure participates in type-check + // it would be handled one statement at a time via a conjunction + // constraint. + if (!(closure && closure->hasSingleExpressionBody())) + continue; + + // Note that it's important that we select the best possible + // disjunction from the body of the closure first, it helps + // to prune the solver space. + SmallVector innerDisjunctions; + + for (auto *disjunction : disjunctions) { + auto *choice = disjunction->getNestedConstraints()[0]; + if (choice->getKind() == ConstraintKind::BindOverload && + choice->getOverloadUseDC() == closure) + innerDisjunctions.push_back(disjunction); + } + + if (!innerDisjunctions.empty()) + return selectBestDisjunction(innerDisjunctions); + } + } + } + } + + // First, let's see whether there is a direct parent in scope, since + // parent is the one which is going to use the result type of the + // last disjunction. + if (auto *parent = getParentExpr(expr)) { + if (isMemberRef && isa(parent)) + parent = getParentExpr(parent); + + if (auto disjunction = findDisjunction(parent)) + return disjunction; + + // If parent is a tuple, let's collect disjunctions associated + // with its elements and run selection algorithm on them. + if (auto *tuple = dyn_cast_or_null(parent)) { + auto *elementExpr = expr; + + // If current element has any unsolved disjunctions, let's + // attempt the closest to keep solving the local element. + if (auto disjunction = findClosestDisjunction(elementExpr)) + return disjunction; + + SmallVector tupleDisjunctions; + // Find all of the disjunctions that are nested inside of + // the current tuple for selection. + for (auto *disjunction : disjunctions) { + auto anchor = disjunction->getLocator()->getAnchor(); + if (auto *expr = getAsExpr(anchor)) { + while ((expr = getParentExpr(expr))) { + if (expr == tuple) { + tupleDisjunctions.push_back(disjunction); + break; + } + } + } + } + + // Let's use a pool of all disjunctions associated with + // this tuple. Picking the best one, regardless of the + // element would stir solving towards solving everything + // in a particular element. + if (!tupleDisjunctions.empty()) + return selectBestDisjunction(tupleDisjunctions); + } + } + + // If parent is not available (e.g. because it's already bound), + // let's look into the arguments, and find the closest unbound one. + if (auto *closestDisjunction = findClosestDisjunction(expr)) + return closestDisjunction; + } + + return selectBestDisjunction(disjunctions); +} + +Constraint * +ConstraintSystem::selectBestDisjunction(ArrayRef disjunctions) { + assert(!disjunctions.empty()); + if (auto *disjunction = selectBestBindingDisjunction(*this, disjunctions)) return disjunction; @@ -2182,8 +2366,15 @@ Constraint *ConstraintSystem::selectDisjunction() { unsigned firstFavored = first->countFavoredNestedConstraints(); unsigned secondFavored = second->countFavoredNestedConstraints(); - if (!isOperatorDisjunction(first) || !isOperatorDisjunction(second)) + if (!isOperatorDisjunction(first) || !isOperatorDisjunction(second)) { + // If one of the sides has favored overloads, let's prefer it + // since it's a strong enough signal that there is something + // known about the arguments associated with the call. + if (firstFavored == 0 || secondFavored == 0) + return firstFavored > secondFavored; + return firstActive < secondActive; + } if (firstFavored == secondFavored) { // Look for additional choices that are "favored" diff --git a/lib/Sema/CSStep.cpp b/lib/Sema/CSStep.cpp index d534fbec95144..760015b8db0e7 100644 --- a/lib/Sema/CSStep.cpp +++ b/lib/Sema/CSStep.cpp @@ -546,12 +546,6 @@ StepResult DisjunctionStep::resume(bool prevFailed) { if (!choice.isGenericOperator() && choice.isSymmetricOperator()) { if (!BestNonGenericScore || score < BestNonGenericScore) { BestNonGenericScore = score; - if (shouldSkipGenericOperators()) { - // The disjunction choice producer shouldn't do the work - // to partition the generic operator choices if generic - // operators are going to be skipped. - Producer.setNeedsGenericOperatorOrdering(false); - } } } @@ -703,16 +697,30 @@ bool DisjunctionStep::shouldStopAt(const DisjunctionChoice &choice) const { bool hasUnavailableOverloads = delta.Data[SK_Unavailable] > 0; bool hasFixes = delta.Data[SK_Fix] > 0; bool hasAsyncMismatch = delta.Data[SK_AsyncInSyncMismatch] > 0; - auto isBeginningOfPartition = choice.isBeginningOfPartition(); + bool isBeginningOfPartition = choice.isBeginningOfPartition(); // Attempt to short-circuit evaluation of this disjunction only // if the disjunction choice we are comparing to did not involve: // 1. selecting unavailable overloads // 2. result in fixes being applied to reach a solution // 3. selecting an overload that results in an async/sync mismatch - return !hasUnavailableOverloads && !hasFixes && !hasAsyncMismatch && - (isBeginningOfPartition || - shortCircuitDisjunctionAt(choice, lastChoice)); + if (hasUnavailableOverloads || hasFixes || hasAsyncMismatch) + return false; + + // Similar to \c shouldSkip - don't stop at the beginning of generic partition + // for unary operators when there was a solution with concrete operator choice + // that required an implicit CGFloat<->Double conversion, because not all of + // such operators have concrete `CGFloat` overloads. + if (isBeginningOfPartition && choice.isGenericUnaryOperator()) { + if (BestNonGenericScore) { + auto &score = BestNonGenericScore->Data; + if (score[SK_ImplicitValueConversion] > 0) + return false; + } + } + + return isBeginningOfPartition || + shortCircuitDisjunctionAt(choice, lastChoice); } bool swift::isSIMDOperator(ValueDecl *value) { diff --git a/lib/Sema/CSStep.h b/lib/Sema/CSStep.h index a08b1a16b5c03..c0a37d5ac8b0f 100644 --- a/lib/Sema/CSStep.h +++ b/lib/Sema/CSStep.h @@ -653,6 +653,7 @@ class DisjunctionStep final : public BindingStep { assert(Disjunction->getKind() == ConstraintKind::Disjunction); pruneOverloadSet(Disjunction); ++cs.solverState->NumDisjunctions; + cs.SelectedDisjunctions.push_back(Disjunction); } ~DisjunctionStep() override { @@ -663,6 +664,8 @@ class DisjunctionStep final : public BindingStep { // Re-enable previously disabled overload choices. for (auto *choice : DisabledChoices) choice->setEnabled(); + + CS.SelectedDisjunctions.pop_back(); } StepResult resume(bool prevFailed) override; diff --git a/test/IDE/complete_ambiguous.swift b/test/IDE/complete_ambiguous.swift index 628e333dcf1c4..62f1177c48fe3 100644 --- a/test/IDE/complete_ambiguous.swift +++ b/test/IDE/complete_ambiguous.swift @@ -448,8 +448,8 @@ struct Struct123: Equatable { } func testBestSolutionFilter() { let a = Struct123(); - let b = [Struct123]().first(where: { $0 == a && 1 + 90 * 5 / 8 == 45 * -10 })?.structMem != .#^BEST_SOLUTION_FILTER?xfail=rdar73282163^# - let c = min(10.3, 10 / 10.4) < 6 / 7 ? true : Optional(a)?.structMem != .#^BEST_SOLUTION_FILTER2?check=BEST_SOLUTION_FILTER;xfail=rdar73282163^# + let b = [Struct123]().first(where: { $0 == a && 1 + 90 * 5 / 8 == 45 * -10 })?.structMem != .#^BEST_SOLUTION_FILTER^# + min(10.3, 10 / 10.4) < 6 / 7 ? true : Optional(a)?.structMem != .#^BEST_SOLUTION_FILTER2?check=BEST_SOLUTION_FILTER^# } // BEST_SOLUTION_FILTER: Begin completions diff --git a/validation-test/Sema/type_checker_perf/fast/mixed_double_and_float_operators.swift b/validation-test/Sema/type_checker_perf/fast/mixed_double_and_float_operators.swift new file mode 100644 index 0000000000000..666033c8eb7ab --- /dev/null +++ b/validation-test/Sema/type_checker_perf/fast/mixed_double_and_float_operators.swift @@ -0,0 +1,10 @@ +// RUN: %target-typecheck-verify-swift -solver-expression-time-threshold=1 -solver-disable-shrink +// REQUIRES: OS=macosx,tools-release,no_asan + +import Foundation + +var r: Float = 0 +var x: Double = 0 +var y: Double = 0 + +let _ = (1.0 - 1.0 / (1.0 + exp(-5.0 * (r - 0.05)/0.01))) * Float(x) + Float(y) diff --git a/validation-test/Sema/type_checker_perf/slow/mixed_string_array_addition.swift b/validation-test/Sema/type_checker_perf/fast/mixed_string_array_addition.swift similarity index 76% rename from validation-test/Sema/type_checker_perf/slow/mixed_string_array_addition.swift rename to validation-test/Sema/type_checker_perf/fast/mixed_string_array_addition.swift index aa0c3979b5ef2..75c7d1050eb4d 100644 --- a/validation-test/Sema/type_checker_perf/slow/mixed_string_array_addition.swift +++ b/validation-test/Sema/type_checker_perf/fast/mixed_string_array_addition.swift @@ -3,7 +3,6 @@ func method(_ arg: String, body: () -> [String]) {} func test(str: String, properties: [String]) { - // expected-error@+1 {{the compiler is unable to type-check this expression in reasonable time}} method(str + "" + str + "") { properties.map { param in "" + param + "" + param + "" + param + "" diff --git a/validation-test/Sema/type_checker_perf/fast/property_and_methods_with_same_name.swift.gyb b/validation-test/Sema/type_checker_perf/fast/property_and_methods_with_same_name.swift.gyb new file mode 100644 index 0000000000000..e102c55553f24 --- /dev/null +++ b/validation-test/Sema/type_checker_perf/fast/property_and_methods_with_same_name.swift.gyb @@ -0,0 +1,30 @@ +// RUN: %scale-test --begin 1 --end 10 --step 1 --select NumLeafScopes %s +// REQUIRES: asserts,no_asan + +func test(_: [A]) {} + +class A {} + +protocol P { + var arr: Int { get } +} + +extension P { + var arr: Int { get { return 42 } } +} + +class S : P { + func arr() -> [A] { [] } + func arr(_: Int = 42) -> [A] { [] } +} + +// There is a clash between `arr` property and `arr` methods +// returning `[A]` which shouldn't prevent "common result" +// determination. +func run_test(s: S) { + test(s.arr() + + %for i in range(0, N): + s.arr() + + %end + s.arr()) +} diff --git a/validation-test/Sema/type_checker_perf/slow/rdar23682605.swift b/validation-test/Sema/type_checker_perf/fast/rdar23682605.swift similarity index 91% rename from validation-test/Sema/type_checker_perf/slow/rdar23682605.swift rename to validation-test/Sema/type_checker_perf/fast/rdar23682605.swift index b7bc757fe1989..4be53b4d2ef83 100644 --- a/validation-test/Sema/type_checker_perf/slow/rdar23682605.swift +++ b/validation-test/Sema/type_checker_perf/fast/rdar23682605.swift @@ -14,7 +14,6 @@ func memoize( body: @escaping ((T)->U, T)->U ) -> (T)->U { } let fibonacci = memoize { - // expected-error@-1 {{reasonable time}} fibonacci, n in n < 2 ? Double(n) : fibonacci(n - 1) + fibonacci(n - 2) } diff --git a/validation-test/Sema/type_checker_perf/slow/rdar46713933_literal_arg.swift b/validation-test/Sema/type_checker_perf/fast/rdar46713933_literal_arg.swift similarity index 85% rename from validation-test/Sema/type_checker_perf/slow/rdar46713933_literal_arg.swift rename to validation-test/Sema/type_checker_perf/fast/rdar46713933_literal_arg.swift index a0628335b9c36..5256a92a787c7 100644 --- a/validation-test/Sema/type_checker_perf/slow/rdar46713933_literal_arg.swift +++ b/validation-test/Sema/type_checker_perf/fast/rdar46713933_literal_arg.swift @@ -8,5 +8,4 @@ func wrap(_ key: String, _ value: T) -> T { retur func wrapped() -> Int { return wrap("1", 1) + wrap("1", 1) + wrap("1", 1) + wrap("1", 1) + wrap("1", 1) + wrap("1", 1) - // expected-error@-1 {{the compiler is unable to type-check this expression in reasonable time}} } diff --git a/validation-test/Sema/type_checker_perf/fast/sr10130.swift b/validation-test/Sema/type_checker_perf/fast/sr10130.swift new file mode 100644 index 0000000000000..91e48d96eac91 --- /dev/null +++ b/validation-test/Sema/type_checker_perf/fast/sr10130.swift @@ -0,0 +1,16 @@ +// RUN: %target-typecheck-verify-swift -solver-expression-time-threshold=1 +// REQUIRES: tools-release,no_asan + +import Foundation + +let itemsPerRow = 10 +let size: CGFloat = 20 +let margin: CGFloat = 10 + +let _ = (0..<100).map { (row: CGFloat($0 / itemsPerRow), col: CGFloat($0 % itemsPerRow)) } + .map { + CGRect(x: $0.col * (size + margin) + margin, + y: $0.row * (size + margin) + margin, + width: size, + height: size) + }