diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index 69946ab363743..88fe8a2b42ed3 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -3821,6 +3821,7 @@ bool ConstraintSystem::generateConstraints(StmtCondition condition, return true; } + Type boolTy = boolDecl->getDeclaredType(); for (const auto &condElement : condition) { switch (condElement.getKind()) { case StmtConditionElement::CK_Availability: @@ -3829,6 +3830,8 @@ bool ConstraintSystem::generateConstraints(StmtCondition condition, case StmtConditionElement::CK_Boolean: { Expr *condExpr = condElement.getBoolean(); + setContextualType(condExpr, TypeLoc::withoutLoc(boolTy), CTP_Condition); + condExpr = generateConstraints(condExpr, dc); if (!condExpr) { return true; @@ -3836,8 +3839,9 @@ bool ConstraintSystem::generateConstraints(StmtCondition condition, addConstraint(ConstraintKind::Conversion, getType(condExpr), - boolDecl->getDeclaredType(), - getConstraintLocator(condExpr)); + boolTy, + getConstraintLocator(condExpr, + LocatorPathElt::ContextualType())); continue; } diff --git a/lib/Sema/CSSolver.cpp b/lib/Sema/CSSolver.cpp index 263fe0217946f..94b00f9e3a435 100644 --- a/lib/Sema/CSSolver.cpp +++ b/lib/Sema/CSSolver.cpp @@ -168,6 +168,10 @@ Solution ConstraintSystem::finalize() { solution.addedNodeTypes.insert(nodeType); } + // Remember contextual types. + solution.contextualTypes.assign( + contextualTypes.begin(), contextualTypes.end()); + for (auto &e : CheckedConformances) solution.Conformances.push_back({e.first, e.second}); @@ -232,6 +236,14 @@ void ConstraintSystem::applySolution(const Solution &solution) { setType(nodeType.first, nodeType.second); } + // Add the contextual types. + for (const auto &contextualType : solution.contextualTypes) { + if (!getContextualTypeInfo(contextualType.first)) { + setContextualType(contextualType.first, contextualType.second.typeLoc, + contextualType.second.purpose); + } + } + // Register the conformances checked along the way to arrive to solution. for (auto &conformance : solution.Conformances) CheckedConformances.push_back(conformance); diff --git a/lib/Sema/ConstraintSystem.h b/lib/Sema/ConstraintSystem.h index 020e19ccdd8b6..1cd0c4724e7a4 100644 --- a/lib/Sema/ConstraintSystem.h +++ b/lib/Sema/ConstraintSystem.h @@ -787,6 +787,15 @@ using OpenedType = std::pair; using OpenedTypeMap = llvm::DenseMap; +/// Describes contextual type information about a particular expression +/// within a constraint system. +struct ContextualTypeInfo { + TypeLoc typeLoc; + ContextualTypePurpose purpose; + + Type getType() const { return typeLoc.getType(); } +}; + /// A complete solution to a constraint system. /// /// A solution to a constraint system consists of type variable bindings to @@ -849,6 +858,9 @@ class Solution { /// The node -> type mappings introduced by this solution. llvm::MapVector addedNodeTypes; + /// Contextual types introduced by this solution. + std::vector> contextualTypes; + std::vector> Conformances; @@ -1300,13 +1312,6 @@ class ConstraintSystem { llvm::DenseMap, TypeBase *> KeyPathComponentTypes; - struct ContextualTypeInfo { - TypeLoc typeLoc; - ContextualTypePurpose purpose; - - Type getType() const { return typeLoc.getType(); } - }; - /// Contextual type information for expressions that are part of this /// constraint system. llvm::MapVector contextualTypes; @@ -2174,35 +2179,36 @@ class ConstraintSystem { return E; } - void setContextualType(Expr *expr, TypeLoc T, ContextualTypePurpose purpose) { + void setContextualType( + const Expr *expr, TypeLoc T, ContextualTypePurpose purpose) { assert(expr != nullptr && "Expected non-null expression!"); assert(contextualTypes.count(expr) == 0 && "Already set this contextual type"); contextualTypes[expr] = { T, purpose }; } - Optional getContextualTypeInfo(Expr *expr) const { + Optional getContextualTypeInfo(const Expr *expr) const { auto known = contextualTypes.find(expr); if (known == contextualTypes.end()) return None; return known->second; } - Type getContextualType(Expr *expr) const { + Type getContextualType(const Expr *expr) const { auto result = getContextualTypeInfo(expr); if (result) return result->typeLoc.getType(); return Type(); } - TypeLoc getContextualTypeLoc(Expr *expr) const { + TypeLoc getContextualTypeLoc(const Expr *expr) const { auto result = getContextualTypeInfo(expr); if (result) return result->typeLoc; return TypeLoc(); } - ContextualTypePurpose getContextualTypePurpose(Expr *expr) const { + ContextualTypePurpose getContextualTypePurpose(const Expr *expr) const { auto result = getContextualTypeInfo(expr); if (result) return result->purpose; diff --git a/test/Constraints/function_builder_diags.swift b/test/Constraints/function_builder_diags.swift index 70e15160d4865..8769496a8bb7b 100644 --- a/test/Constraints/function_builder_diags.swift +++ b/test/Constraints/function_builder_diags.swift @@ -257,7 +257,7 @@ struct MyTuplifiedStruct { } } -// Check that we're performing syntactic use diagnostics/ +// Check that we're performing syntactic use diagnostics. func acceptMetatype(_: T.Type) -> Bool { true } func syntacticUses(_: T) { @@ -269,3 +269,18 @@ func syntacticUses(_: T) { } } } + +// Check custom diagnostics within "if" conditions. +struct HasProperty { + var property: Bool = false +} + +func checkConditions(cond: Bool) { + var x = HasProperty() + + tuplify(cond) { value in + if x.property = value { // expected-error{{use of '=' in a boolean context, did you mean '=='?}} + "matched it" + } + } +}