diff --git a/include/swift/AST/ASTBridging.h b/include/swift/AST/ASTBridging.h index 9aef3597e0aad..abdd3eccd818e 100644 --- a/include/swift/AST/ASTBridging.h +++ b/include/swift/AST/ASTBridging.h @@ -2432,13 +2432,14 @@ BridgedFallthroughStmt_createParsed(swift::SourceLoc loc, BridgedDeclContext cDC); SWIFT_NAME("BridgedForEachStmt.createParsed(_:labelInfo:forLoc:tryLoc:awaitLoc:" - "unsafeLoc:pattern:inLoc:sequence:whereLoc:whereExpr:body:)") + "unsafeLoc:pattern:inLoc:sequence:whereLoc:whereExpr:body:declContext:)") BridgedForEachStmt BridgedForEachStmt_createParsed( BridgedASTContext cContext, BridgedLabeledStmtInfo cLabelInfo, swift::SourceLoc forLoc, swift::SourceLoc tryLoc, swift::SourceLoc awaitLoc, swift::SourceLoc unsafeLoc, BridgedPattern cPat, swift::SourceLoc inLoc, BridgedExpr cSequence, swift::SourceLoc whereLoc, - BridgedNullableExpr cWhereExpr, BridgedBraceStmt cBody); + BridgedNullableExpr cWhereExpr, BridgedBraceStmt cBody, + BridgedDeclContext cDeclContext); SWIFT_NAME("BridgedGuardStmt.createParsed(_:guardLoc:conds:body:)") BridgedGuardStmt BridgedGuardStmt_createParsed(BridgedASTContext cContext, diff --git a/include/swift/AST/Expr.h b/include/swift/AST/Expr.h index 12a1946d35bc8..633189993f110 100644 --- a/include/swift/AST/Expr.h +++ b/include/swift/AST/Expr.h @@ -6724,6 +6724,27 @@ class MacroExpansionExpr final : public Expr, } }; +/// OpaqueExpr - created to serve as an indirection to a ForEachStmt's sequence +/// expr and where clause to avoid visiting it twice in the ASTWalker after +/// having desugared the loop. This will only be processed in SILGen to emit +/// the underlying expression. +class OpaqueExpr final : public Expr { + Expr *OriginalExpr; + +public: + OpaqueExpr(Expr* originalExpr) + : Expr(ExprKind::Opaque, /*implicit*/ true, Type()), + OriginalExpr(originalExpr) {} + + Expr *getOriginalExpr() const { return OriginalExpr; } + SourceLoc getStartLoc() const { return OriginalExpr->getStartLoc(); } + SourceLoc getEndLoc() const { return OriginalExpr->getEndLoc(); } + + static bool classof(const Expr *E) { + return E->getKind() == ExprKind::Opaque; + } +}; + inline bool Expr::isInfixOperator() const { return isa(this) || isa(this) || isa(this) || isa(this); diff --git a/include/swift/AST/ExprNodes.def b/include/swift/AST/ExprNodes.def index a193b8d699ec0..87ea039ae97f4 100644 --- a/include/swift/AST/ExprNodes.def +++ b/include/swift/AST/ExprNodes.def @@ -218,8 +218,9 @@ EXPR(Tap, Expr) UNCHECKED_EXPR(TypeJoin, Expr) EXPR(MacroExpansion, Expr) EXPR(TypeValue, Expr) +EXPR(Opaque, Expr) // Don't forget to update the LAST_EXPR below when adding a new Expr here. -LAST_EXPR(TypeValue) +LAST_EXPR(Opaque) #undef EXPR_RANGE #undef LITERAL_EXPR diff --git a/include/swift/AST/Pattern.h b/include/swift/AST/Pattern.h index ab6e94ff01b20..a80f13e4ba5d8 100644 --- a/include/swift/AST/Pattern.h +++ b/include/swift/AST/Pattern.h @@ -48,6 +48,7 @@ enum : unsigned { NumPatternKindBits = countBitsUsed(static_cast(PatternKind::Last_Pattern)) }; enum class DescriptivePatternKind : uint8_t { + Opaque, Paren, Tuple, Named, @@ -255,6 +256,24 @@ class alignas(8) Pattern : public ASTAllocated { Pattern *walk(ASTWalker &&walker) { return walk(walker); } }; +class OpaquePattern : public Pattern { + Pattern* SubPattern = nullptr; + +public: + OpaquePattern(Pattern *p) + : Pattern(PatternKind::Opaque), SubPattern(p) {} + + SourceLoc getLoc() const { return SubPattern->getLoc(); } + SourceRange getSourceRange() const { return SubPattern->getSourceRange(); } + + Pattern* getSubPattern() const { return SubPattern; } + void setSubPattern(Pattern *p) { SubPattern = p; } + + static bool classof(const Pattern *P) { + return P->getKind() == PatternKind::Opaque; + } +}; + /// A pattern consisting solely of grouping parentheses around a /// different pattern. class ParenPattern : public Pattern { @@ -870,6 +889,8 @@ inline Pattern *Pattern::getSemanticsProvidingPattern() { return tp->getSubPattern()->getSemanticsProvidingPattern(); if (auto *vp = dyn_cast(this)) return vp->getSubPattern()->getSemanticsProvidingPattern(); + if (auto *op = dyn_cast(this)) + return op->getSubPattern()->getSemanticsProvidingPattern(); return this; } diff --git a/include/swift/AST/PatternNodes.def b/include/swift/AST/PatternNodes.def index 29cd6070657db..8e880aa946089 100644 --- a/include/swift/AST/PatternNodes.def +++ b/include/swift/AST/PatternNodes.def @@ -33,6 +33,7 @@ #endif // Metavars: x (variable binding), pat (pattern), e (expression) +PATTERN(Opaque, Pattern) // Wrapper around pat PATTERN(Paren, Pattern) // (pat) PATTERN(Tuple, Pattern) // (pat1, ..., patN), N >= 1 PATTERN(Named, Pattern) // let pat, var pat diff --git a/include/swift/AST/Stmt.h b/include/swift/AST/Stmt.h index ac40b5d206fa0..24f00e6536d8c 100644 --- a/include/swift/AST/Stmt.h +++ b/include/swift/AST/Stmt.h @@ -1007,21 +1007,22 @@ class ForEachStmt : public LabeledStmt { SourceLoc WhereLoc; Expr *WhereExpr = nullptr; BraceStmt *Body; + DeclContext *DC = nullptr; // Set by Sema: ProtocolConformanceRef sequenceConformance = ProtocolConformanceRef(); Type sequenceType; - PatternBindingDecl *iteratorVar = nullptr; - Expr *nextCall = nullptr; - OpaqueValueExpr *elementExpr = nullptr; + BraceStmt *desugaredStmt = nullptr; Expr *convertElementExpr = nullptr; + LabeledStmt *continueTarget = nullptr; + LabeledStmt *breakTarget = nullptr; public: ForEachStmt(LabeledStmtInfo LabelInfo, SourceLoc ForLoc, SourceLoc TryLoc, SourceLoc AwaitLoc, SourceLoc UnsafeLoc, Pattern *Pat, SourceLoc InLoc, Expr *Sequence, SourceLoc WhereLoc, Expr *WhereExpr, BraceStmt *Body, - std::optional implicit = std::nullopt) + DeclContext *DC, std::optional implicit = std::nullopt) : LabeledStmt(StmtKind::ForEach, getDefaultImplicitFlag(implicit, ForLoc), LabelInfo), ForLoc(ForLoc), TryLoc(TryLoc), AwaitLoc(AwaitLoc), UnsafeLoc(UnsafeLoc), @@ -1030,15 +1031,6 @@ class ForEachStmt : public LabeledStmt { setPattern(Pat); } - void setIteratorVar(PatternBindingDecl *var) { iteratorVar = var; } - PatternBindingDecl *getIteratorVar() const { return iteratorVar; } - - void setNextCall(Expr *next) { nextCall = next; } - Expr *getNextCall() const { return nextCall; } - - void setElementExpr(OpaqueValueExpr *expr) { elementExpr = expr; } - OpaqueValueExpr *getElementExpr() const { return elementExpr; } - void setConvertElementExpr(Expr *expr) { convertElementExpr = expr; } Expr *getConvertElementExpr() const { return convertElementExpr; } @@ -1080,20 +1072,29 @@ class ForEachStmt : public LabeledStmt { Expr *getParsedSequence() const { return Sequence; } void setParsedSequence(Expr *S) { Sequence = S; } - /// Type-checked version of the sequence or nullptr if this statement - /// yet to be type-checked. - Expr *getTypeCheckedSequence() const; - /// getBody - Retrieve the body of the loop. BraceStmt *getBody() const { return Body; } void setBody(BraceStmt *B) { Body = B; } SourceLoc getStartLoc() const { return getLabelLocOrKeywordLoc(ForLoc); } SourceLoc getEndLoc() const { return Body->getEndLoc(); } + + DeclContext *getDeclContext() const { return DC; } + void setDeclContext(DeclContext *newDC) { DC = newDC; } static bool classof(const Stmt *S) { return S->getKind() == StmtKind::ForEach; } + + BraceStmt* desugar(); + BraceStmt* getDesugaredStmt() const { return desugaredStmt; } + void setDesugaredStmt(BraceStmt* newStmt) { desugaredStmt = newStmt; } + + void setContinueTarget(LabeledStmt *target) { continueTarget = target; } + LabeledStmt* getContinueTarget() { return continueTarget; } + + void setBreakTarget(LabeledStmt *target) { breakTarget = target; } + LabeledStmt* getBreakTarget() { return breakTarget; } }; /// A pattern and an optional guard expression used in a 'case' statement. @@ -1545,6 +1546,31 @@ class DoCatchStmt final } }; +/// OpaqueStmt - created to serve as an indirection to a ForEachStmt's body +/// to avoid visiting it twice in the ASTWalker after having desugared the loop. +/// This ensures we only visit the body once, and this OpaqueStmt will only be +/// visited to emit the underlying statement in SILGen. +class OpaqueStmt final : public Stmt { + SourceLoc StartLoc; + SourceLoc EndLoc; + BraceStmt *Body; // FIXME: should I just use Stmt * so that this is more versatile? + // If not, should the class be renamed to be more specific? + public: + OpaqueStmt(BraceStmt* body, SourceLoc startLoc, SourceLoc endLoc) + : Stmt(StmtKind::Opaque, true /*always implicit*/), + StartLoc(startLoc), EndLoc(endLoc), Body(body) {} + + SourceLoc getLoc() const { return StartLoc; } + SourceLoc getStartLoc() const { return StartLoc; } + SourceLoc getEndLoc() const { return EndLoc; } + + BraceStmt* getUnderlyingStmt() { return Body; } + + static bool classof(const Stmt *S) { + return S->getKind() == StmtKind::Opaque; + } +}; + /// BreakStmt - The "break" and "break label" statement. class BreakStmt : public Stmt { SourceLoc Loc; diff --git a/include/swift/AST/StmtNodes.def b/include/swift/AST/StmtNodes.def index b35149ad7f437..a3da49f053814 100644 --- a/include/swift/AST/StmtNodes.def +++ b/include/swift/AST/StmtNodes.def @@ -61,6 +61,7 @@ ABSTRACT_STMT(Labeled, Stmt) LABELED_STMT(ForEach, LabeledStmt) LABELED_STMT(Switch, LabeledStmt) STMT_RANGE(Labeled, If, Switch) +STMT(Opaque, Stmt) STMT(Case, Stmt) STMT(Break, Stmt) STMT(Continue, Stmt) diff --git a/include/swift/AST/TypeCheckRequests.h b/include/swift/AST/TypeCheckRequests.h index 6125c326c1295..146fd21031cfa 100644 --- a/include/swift/AST/TypeCheckRequests.h +++ b/include/swift/AST/TypeCheckRequests.h @@ -5591,6 +5591,25 @@ class IsCustomAvailabilityDomainPermanentlyEnabled } }; +class DesugarForEachStmtRequest + : public SimpleRequest { +public: + using SimpleRequest::SimpleRequest; + +private: + friend SimpleRequest; + + // Evaluation. + BraceStmt *evaluate(Evaluator &evaluator, ForEachStmt *FES) const; + +public: + bool isCached() const { return true; } + std::optional getCachedResult() const; + void cacheResult(BraceStmt *stmt) const; +}; + #define SWIFT_TYPEID_ZONE TypeChecker #define SWIFT_TYPEID_HEADER "swift/AST/TypeCheckerTypeIDZone.def" #include "swift/Basic/DefineTypeIDZone.h" diff --git a/include/swift/AST/TypeCheckerTypeIDZone.def b/include/swift/AST/TypeCheckerTypeIDZone.def index 4ebd381402475..f68bdafff3ebc 100644 --- a/include/swift/AST/TypeCheckerTypeIDZone.def +++ b/include/swift/AST/TypeCheckerTypeIDZone.def @@ -674,3 +674,7 @@ SWIFT_REQUEST(TypeChecker, IsCustomAvailabilityDomainPermanentlyEnabled, SWIFT_REQUEST(TypeChecker, EmitPerformanceHints, evaluator::SideEffect(SourceFile *), Cached, NoLocationInfo) + +SWIFT_REQUEST(TypeChecker, DesugarForEachStmtRequest, + Stmt*(const ForEachStmt*), + Cached, NoLocationInfo) diff --git a/include/swift/Sema/Constraint.h b/include/swift/Sema/Constraint.h index 0d49cd2376fbc..6f0d376a5a30a 100644 --- a/include/swift/Sema/Constraint.h +++ b/include/swift/Sema/Constraint.h @@ -213,6 +213,8 @@ enum class ConstraintKind : char { MaterializePackExpansion, /// The first type is a l-value type whose object type is the second type. LValueObject, + /// The first type of the sequence. The second type is the element type. + ForEachElement, }; /// Classification of the different kinds of constraints. @@ -712,6 +714,7 @@ class Constraint final : public llvm::ilist_node, case ConstraintKind::ApplicableFunction: case ConstraintKind::DynamicCallableApplicableFunction: case ConstraintKind::BindOverload: + case ConstraintKind::ForEachElement: case ConstraintKind::OptionalObject: case ConstraintKind::OneWayEqual: case ConstraintKind::FallbackType: diff --git a/include/swift/Sema/ConstraintLocator.h b/include/swift/Sema/ConstraintLocator.h index fdea2f5e50966..6d1b0079c8e32 100644 --- a/include/swift/Sema/ConstraintLocator.h +++ b/include/swift/Sema/ConstraintLocator.h @@ -83,6 +83,8 @@ enum ContextualTypePurpose : uint8_t { CTP_ExprPattern, ///< `~=` operator application associated with expression /// pattern. + + CTP_ForEachElement, ///< Element expression associated with `for-in` loop. }; namespace constraints { diff --git a/include/swift/Sema/ConstraintSystem.h b/include/swift/Sema/ConstraintSystem.h index d2dc0bb82618a..4d1804fadc353 100644 --- a/include/swift/Sema/ConstraintSystem.h +++ b/include/swift/Sema/ConstraintSystem.h @@ -5022,6 +5022,12 @@ class ConstraintSystem { DeclContext *useDC, FunctionRefInfo functionRefInfo, TypeMatchOptions flags, ConstraintLocatorBuilder locator); + /// Attempt to simplify the ForEachElement constraint. + SolutionKind simplifyForEachElementConstraint( + Type first, Type second, + TypeMatchOptions flags, + ConstraintLocatorBuilder locator); + /// Attempt to simplify the optional object constraint. SolutionKind simplifyOptionalObjectConstraint( Type first, Type second, diff --git a/include/swift/Sema/SyntacticElementTarget.h b/include/swift/Sema/SyntacticElementTarget.h index 47d084bf2bb98..2f5559404b3bd 100644 --- a/include/swift/Sema/SyntacticElementTarget.h +++ b/include/swift/Sema/SyntacticElementTarget.h @@ -36,17 +36,8 @@ struct SequenceIterationInfo { /// The type of the sequence. Type sequenceType; - /// The type of an element in the sequence. - Type elementType; - /// The type of the pattern that matches the elements. Type initType; - - /// Implicit `$iterator = .makeIterator()` - PatternBindingDecl *makeIteratorVar; - - /// Implicit `$iterator.next()` call. - Expr *nextCall; }; /// Describes information about a for-in loop over a pack that needs to be @@ -605,6 +596,7 @@ class SyntacticElementTarget { case CTP_Initialization: case CTP_ForEachSequence: case CTP_ExprPattern: + case CTP_ForEachElement: break; default: assert(false && "Unexpected contextual type purpose"); diff --git a/lib/AST/ASTDumper.cpp b/lib/AST/ASTDumper.cpp index bef1826bc5c5c..4a560d1ea6500 100644 --- a/lib/AST/ASTDumper.cpp +++ b/lib/AST/ASTDumper.cpp @@ -2001,6 +2001,11 @@ namespace { printField(P->getValue(), Label::always("value")); printFoot(); } + void visitOpaquePattern(OpaquePattern *P, Label label){ + printCommon(P, "pattern_opaque", label); + printRec(P->getSubPattern(), Label::optional("sub_pattern")); + printFoot(); + } }; @@ -3252,6 +3257,10 @@ class PrintStmt : public StmtVisitor, printFlag(S->TrailingSemiLoc.isValid(), "trailing_semi"); } + void visitOpaqueStmt(OpaqueStmt *S, Label label){ + visitBraceStmt(S->getUnderlyingStmt(), label); + } + void visitBraceStmt(BraceStmt *S, Label label) { printCommon(S, "brace_stmt", label); printList(S->getElements(), [&](auto &Elt, Label label) { @@ -3332,20 +3341,15 @@ class PrintStmt : public StmtVisitor, printRec(S->getWhere(), Label::always("where")); } printRec(S->getParsedSequence(), Label::optional("parsed_sequence")); - if (S->getIteratorVar()) { - printRec(S->getIteratorVar(), Label::optional("iterator_var")); - } - if (S->getNextCall()) { - printRec(S->getNextCall(), Label::optional("next_call")); - } if (S->getConvertElementExpr()) { printRec(S->getConvertElementExpr(), Label::optional("convert_element_expr")); } - if (S->getElementExpr()) { - printRec(S->getElementExpr(), Label::optional("element_expr")); - } + printRec(S->getBody(), Label::optional("body")); + + printRec(S->getDesugaredStmt(), Label::optional("desugared_loop")); + printFoot(); } void visitBreakStmt(BreakStmt *S, Label label) { @@ -4237,6 +4241,10 @@ class PrintExpr : public ExprVisitor, printFoot(); } + void visitOpaqueExpr(OpaqueExpr *E, Label label){ + visit(E->getOriginalExpr(), label); + } + void visitPropertyWrapperValuePlaceholderExpr( PropertyWrapperValuePlaceholderExpr *E, Label label) { printCommon(E, "property_wrapper_value_placeholder_expr", label); diff --git a/lib/AST/ASTMangler.cpp b/lib/AST/ASTMangler.cpp index 9f55cb3a46ebe..ad1183b9e7486 100644 --- a/lib/AST/ASTMangler.cpp +++ b/lib/AST/ASTMangler.cpp @@ -2664,6 +2664,9 @@ namespace { VarDecl *visitAnyPattern(AnyPattern *P) { return nullptr; } + VarDecl *visitOpaquePattern(OpaquePattern *P) { + return visit(P->getSubPattern()); + } // Refutable patterns shouldn't ever come up. #define REFUTABLE_PATTERN(ID, BASE) \ diff --git a/lib/AST/ASTPrinter.cpp b/lib/AST/ASTPrinter.cpp index a135b5c1f8468..1db2c98c7d40a 100644 --- a/lib/AST/ASTPrinter.cpp +++ b/lib/AST/ASTPrinter.cpp @@ -1537,6 +1537,9 @@ static PrintNameContext getTypeMemberPrintNameContext(const Decl *d) { void PrintAST::printPattern(const Pattern *pattern) { switch (pattern->getKind()) { + case PatternKind::Opaque: + printPattern(cast(pattern)->getSubPattern()); + break; case PatternKind::Any: Printer << "_"; break; @@ -5633,6 +5636,16 @@ void PrintAST::visitTypeValueExpr(TypeValueExpr *expr) { expr->getType()->print(Printer, Options); } +void PrintAST::visitOpaqueExpr(OpaqueExpr *expr) { + // FIXME: unsure about this, maybe do nothing? + visit(expr->getOriginalExpr()); +} + +void PrintAST::visitOpaqueStmt(OpaqueStmt *stmt) { + // FIXME: unsure about this, maybe do nothing? + printBraceStmt(stmt->getUnderlyingStmt()); +} + void PrintAST::visitBraceStmt(BraceStmt *stmt) { printBraceStmt(stmt); } @@ -5810,7 +5823,7 @@ void PrintAST::visitForEachStmt(ForEachStmt *stmt) { printPattern(stmt->getPattern()); Printer << " " << tok::kw_in << " "; // FIXME: print container - if (auto *seq = stmt->getTypeCheckedSequence()) { + if (auto *seq = stmt->getParsedSequence()) { // Look through the call to '.makeIterator()' if (auto *CE = dyn_cast(seq)) { diff --git a/lib/AST/ASTScopeCreation.cpp b/lib/AST/ASTScopeCreation.cpp index b947c44dcc8d8..d7b63c679cab5 100644 --- a/lib/AST/ASTScopeCreation.cpp +++ b/lib/AST/ASTScopeCreation.cpp @@ -414,6 +414,7 @@ class NodeAdder VISIT_AND_IGNORE(ContinueStmt) VISIT_AND_IGNORE(FallthroughStmt) VISIT_AND_IGNORE(FailStmt) + VISIT_AND_IGNORE(OpaqueStmt) #undef VISIT_AND_IGNORE diff --git a/lib/AST/ASTVerifier.cpp b/lib/AST/ASTVerifier.cpp index 1f4d3eced7d6f..3223c399c761e 100644 --- a/lib/AST/ASTVerifier.cpp +++ b/lib/AST/ASTVerifier.cpp @@ -802,11 +802,6 @@ class Verifier : public ASTWalker { ForEachPatternSequences.insert(expansion); } - if (!S->getElementExpr()) - return true; - - assert(!OpaqueValues.count(S->getElementExpr())); - OpaqueValues[S->getElementExpr()] = 0; return true; } @@ -819,12 +814,6 @@ class Verifier : public ASTWalker { // Clean up for real. cleanup(expansion); } - - if (!S->getElementExpr()) - return; - - assert(OpaqueValues.count(S->getElementExpr())); - OpaqueValues.erase(S->getElementExpr()); } bool shouldVerify(InterpolatedStringLiteralExpr *expr) { diff --git a/lib/AST/ASTWalker.cpp b/lib/AST/ASTWalker.cpp index e39f3c4c3c7ed..70453e006d29f 100644 --- a/lib/AST/ASTWalker.cpp +++ b/lib/AST/ASTWalker.cpp @@ -644,6 +644,8 @@ class Traversal : public ASTVisitorgetOpaqueValuePlaceholder()) { @@ -1896,6 +1898,11 @@ Stmt *Traversal::visitPoundAssertStmt(PoundAssertStmt *S) { return S; } +Stmt* Traversal::visitOpaqueStmt(OpaqueStmt* OS){ + // We do not want to visit it. + return OS; +} + Stmt *Traversal::visitBraceStmt(BraceStmt *BS) { for (auto &Elem : BS->getElements()) { if (auto *SubExpr = Elem.dyn_cast()) { @@ -2066,28 +2073,11 @@ Stmt *Traversal::visitForEachStmt(ForEachStmt *S) { return nullptr; } - // The iterator decl is built directly on top of the sequence - // expression, so don't visit both. - // - // If for-in is already type-checked, the type-checked version - // of the sequence is going to be visited as part of `iteratorVar`. - if (auto IteratorVar = S->getIteratorVar()) { - if (doIt(IteratorVar)) - return nullptr; - - if (auto NextCall = S->getNextCall()) { - if ((NextCall = doIt(NextCall))) - S->setNextCall(NextCall); - else - return nullptr; - } - } else { - if (Expr *Sequence = S->getParsedSequence()) { + if (Expr *Sequence = S->getParsedSequence()) { if ((Sequence = doIt(Sequence))) S->setParsedSequence(Sequence); else return nullptr; - } } if (Expr *Where = S->getWhere()) { @@ -2111,6 +2101,13 @@ Stmt *Traversal::visitForEachStmt(ForEachStmt *S) { return nullptr; } + if (Stmt *Desugared = S->getDesugaredStmt()) { + if ((Desugared = doIt(Desugared))) + S->setDesugaredStmt(cast(Desugared)); + else + return nullptr; + } + return S; } @@ -2242,6 +2239,14 @@ Pattern *Traversal::visitExprPattern(ExprPattern *P) { return nullptr; } +Pattern *Traversal::visitOpaquePattern(OpaquePattern *P) { + if (Pattern *newSub = doIt(P->getSubPattern())) { + P->setSubPattern(newSub); + return P; + } + return nullptr; +} + Pattern *Traversal::visitBindingPattern(BindingPattern *P) { if (Pattern *newSub = doIt(P->getSubPattern())) { P->setSubPattern(newSub); diff --git a/lib/AST/Bridging/StmtBridging.cpp b/lib/AST/Bridging/StmtBridging.cpp index e8c11bb98d0cf..b285b3d120c12 100644 --- a/lib/AST/Bridging/StmtBridging.cpp +++ b/lib/AST/Bridging/StmtBridging.cpp @@ -191,11 +191,11 @@ BridgedForEachStmt BridgedForEachStmt_createParsed( SourceLoc forLoc, SourceLoc tryLoc, SourceLoc awaitLoc, SourceLoc unsafeLoc, BridgedPattern cPat, SourceLoc inLoc, BridgedExpr cSequence, SourceLoc whereLoc, BridgedNullableExpr cWhereExpr, - BridgedBraceStmt cBody) { + BridgedBraceStmt cBody, BridgedDeclContext cDeclContext) { return new (cContext.unbridged()) ForEachStmt(cLabelInfo.unbridged(), forLoc, tryLoc, awaitLoc, unsafeLoc, cPat.unbridged(), inLoc, cSequence.unbridged(), whereLoc, - cWhereExpr.unbridged(), cBody.unbridged()); + cWhereExpr.unbridged(), cBody.unbridged(), cDeclContext.unbridged()); } BridgedGuardStmt BridgedGuardStmt_createParsed(BridgedASTContext cContext, diff --git a/lib/AST/Expr.cpp b/lib/AST/Expr.cpp index b1e96674cbf59..da233cd0e4e97 100644 --- a/lib/AST/Expr.cpp +++ b/lib/AST/Expr.cpp @@ -466,6 +466,7 @@ ConcreteDeclRef Expr::getReferencedDecl(bool stopAtParenExpr) const { NO_REFERENCE(TypeJoin); SIMPLE_REFERENCE(MacroExpansion, getMacroRef); NO_REFERENCE(TypeValue); + NO_REFERENCE(Opaque); #undef SIMPLE_REFERENCE #undef NO_REFERENCE @@ -840,6 +841,7 @@ bool Expr::canAppendPostfixExpression(bool appendingPostfixOperator) const { case ExprKind::MacroExpansion: case ExprKind::CurrentContextIsolation: + case ExprKind::Opaque: /* FIXME: unsure about this */ return true; } @@ -1044,6 +1046,7 @@ bool Expr::isValidParentOfTypeExpr(Expr *typeExpr) const { case ExprKind::ActorIsolationErasure: case ExprKind::ExtractFunctionIsolation: case ExprKind::UnsafeCast: + case ExprKind::Opaque: return false; } diff --git a/lib/AST/Pattern.cpp b/lib/AST/Pattern.cpp index 89619f45c9895..8df1b5a53db49 100644 --- a/lib/AST/Pattern.cpp +++ b/lib/AST/Pattern.cpp @@ -50,6 +50,7 @@ DescriptivePatternKind Pattern::getDescriptiveKind() const { TRIVIAL_PATTERN_KIND(OptionalSome); TRIVIAL_PATTERN_KIND(Bool); TRIVIAL_PATTERN_KIND(Expr); + TRIVIAL_PATTERN_KIND(Opaque); case PatternKind::Binding: switch (cast(this)->getIntroducer()) { @@ -91,6 +92,7 @@ StringRef Pattern::getDescriptivePatternKindName(DescriptivePatternKind K) { ENTRY(Expr, "expression pattern"); ENTRY(Var, "'var' binding pattern"); ENTRY(Let, "'let' binding pattern"); + ENTRY(Opaque, "opaque pattern"); } #undef ENTRY llvm_unreachable("bad DescriptivePatternKind"); @@ -262,6 +264,7 @@ void Pattern::forEachVariable(llvm::function_ref fn) const { case PatternKind::Paren: case PatternKind::Typed: case PatternKind::Binding: + case PatternKind::Opaque: return getSemanticsProvidingPattern()->forEachVariable(fn); case PatternKind::Tuple: @@ -311,6 +314,8 @@ void Pattern::forEachNode(llvm::function_ref f) { return cast(this)->getSubPattern()->forEachNode(f); case PatternKind::Binding: return cast(this)->getSubPattern()->forEachNode(f); + case PatternKind::Opaque: + return cast(this)->getSubPattern()->forEachNode(f); case PatternKind::Tuple: for (auto elt : cast(this)->getElements()) @@ -789,6 +794,7 @@ Pattern::getOwnership( USE_SUBPATTERN(Paren) USE_SUBPATTERN(Typed) USE_SUBPATTERN(Binding) + USE_SUBPATTERN(Opaque) #undef USE_SUBPATTERN void visitTuplePattern(TuplePattern *p) { for (auto &element : p->getElements()) { diff --git a/lib/AST/Stmt.cpp b/lib/AST/Stmt.cpp index 835f8c736eb44..706f03fcb1292 100644 --- a/lib/AST/Stmt.cpp +++ b/lib/AST/Stmt.cpp @@ -89,6 +89,8 @@ StringRef Stmt::getDescriptiveKindName(StmtKind K) { return "discard"; case StmtKind::PoundAssert: return "#assert"; + case StmtKind::Opaque: + return "opaque"; } llvm_unreachable("Unhandled case in switch!"); } @@ -453,13 +455,6 @@ void ForEachStmt::setPattern(Pattern *p) { Pat->markOwnedByStatement(this); } -Expr *ForEachStmt::getTypeCheckedSequence() const { - if (auto *expansion = dyn_cast(getParsedSequence())) - return expansion; - - return iteratorVar ? iteratorVar->getInit(/*index=*/0) : nullptr; -} - DoCatchStmt *DoCatchStmt::create(DeclContext *dc, LabeledStmtInfo labelInfo, SourceLoc doLoc, SourceLoc throwsLoc, TypeLoc thrownType, Stmt *body, @@ -486,6 +481,13 @@ bool DoCatchStmt::isSyntacticallyExhaustive() const { return false; } +BraceStmt *ForEachStmt::desugar() { + auto &ctx = this->getDeclContext()->getASTContext(); + return evaluateOrDefault(ctx.evaluator, + DesugarForEachStmtRequest{this}, + nullptr); +} + Type DoCatchStmt::getExplicitCaughtType() const { ASTContext &ctx = DC->getASTContext(); return CatchNode(const_cast(this)).getExplicitCaughtType(ctx); diff --git a/lib/AST/TypeCheckRequests.cpp b/lib/AST/TypeCheckRequests.cpp index fbb490e0932ed..90367c754f694 100644 --- a/lib/AST/TypeCheckRequests.cpp +++ b/lib/AST/TypeCheckRequests.cpp @@ -2890,3 +2890,20 @@ void IsCustomAvailabilityDomainPermanentlyEnabled::cacheResult( domain->flags.isPermanentlyEnabledComputed = true; domain->flags.isPermanentlyEnabled = isPermanentlyEnabled; } + +//----------------------------------------------------------------------------// +// DesugarForEachStmtRequest computation. +//----------------------------------------------------------------------------// +std::optional DesugarForEachStmtRequest::getCachedResult() const { + auto *fes = std::get<0>(getStorage()); + auto* desugaredStmt = fes->getDesugaredStmt(); + if (!desugaredStmt){ + return std::nullopt; + } + return desugaredStmt; +} + +void DesugarForEachStmtRequest::cacheResult(BraceStmt *stmt) const { + auto *fes = std::get<0>(getStorage()); + fes->setDesugaredStmt(stmt); +} diff --git a/lib/ASTGen/Sources/ASTGen/Stmts.swift b/lib/ASTGen/Sources/ASTGen/Stmts.swift index db3bbc465f17c..9d4f4eaab9041 100644 --- a/lib/ASTGen/Sources/ASTGen/Stmts.swift +++ b/lib/ASTGen/Sources/ASTGen/Stmts.swift @@ -389,7 +389,8 @@ extension ASTGenVisitor { sequence: self.generate(expr: node.sequence), whereLoc: self.generateSourceLoc(node.whereClause?.whereKeyword), whereExpr: self.generate(expr: node.whereClause?.condition), - body: self.generate(codeBlock: node.body) + body: self.generate(codeBlock: node.body), + declContext: self.declContext ) } diff --git a/lib/Parse/ParseStmt.cpp b/lib/Parse/ParseStmt.cpp index 921a865c297e7..df9a665e37e4a 100644 --- a/lib/Parse/ParseStmt.cpp +++ b/lib/Parse/ParseStmt.cpp @@ -2506,7 +2506,7 @@ ParserResult Parser::parseStmtForEach(LabeledStmtInfo LabelInfo) { new (Context) ForEachStmt(LabelInfo, ForLoc, TryLoc, AwaitLoc, UnsafeLoc, pattern.get(), InLoc, Container.get(), WhereLoc, Where.getPtrOrNull(), - Body.get())); + Body.get(), CurDeclContext)); } /// diff --git a/lib/SILGen/ASTVisitor.h b/lib/SILGen/ASTVisitor.h index 919bf08904ce7..28f8b127dc5dc 100644 --- a/lib/SILGen/ASTVisitor.h +++ b/lib/SILGen/ASTVisitor.h @@ -56,6 +56,10 @@ class ASTVisitor : public swift::ASTVisitorgetSubPattern()); } + void visitOpaquePattern(OpaquePattern *P) { + return visit(P->getSubPattern()); + } void visitTuplePattern(TuplePattern *P) { path.push_back(0); for (unsigned i : indices(P->getElements())) { diff --git a/lib/SILGen/SILGenConstructor.cpp b/lib/SILGen/SILGenConstructor.cpp index 2c096c46c986f..583b56d9b85d6 100644 --- a/lib/SILGen/SILGenConstructor.cpp +++ b/lib/SILGen/SILGenConstructor.cpp @@ -1457,6 +1457,9 @@ emitMemberInit(SILGenFunction &SGF, VarDecl *selfDecl, Pattern *pattern) { return emitMemberInit(SGF, selfDecl, cast(pattern)->getSubPattern()); + case PatternKind::Opaque: + return emitMemberInit(SGF, selfDecl, + cast(pattern)->getSubPattern()); #define PATTERN(Name, Parent) #define REFUTABLE_PATTERN(Name, Parent) case PatternKind::Name: #include "swift/AST/PatternNodes.def" diff --git a/lib/SILGen/SILGenDecl.cpp b/lib/SILGen/SILGenDecl.cpp index 0efd5c0e56c69..02b4baf634b02 100644 --- a/lib/SILGen/SILGenDecl.cpp +++ b/lib/SILGen/SILGenDecl.cpp @@ -1475,6 +1475,9 @@ struct InitializationForPattern InitializationPtr visitBindingPattern(BindingPattern *P) { return visit(P->getSubPattern()); } + InitializationPtr visitOpaquePattern(OpaquePattern *P) { + return visit(P->getSubPattern()); + } // AnyPatterns (i.e, _) don't require any storage. Any value bound here will // just be dropped. diff --git a/lib/SILGen/SILGenExpr.cpp b/lib/SILGen/SILGenExpr.cpp index 6aad3a9b994de..45cd06ea4fae8 100644 --- a/lib/SILGen/SILGenExpr.cpp +++ b/lib/SILGen/SILGenExpr.cpp @@ -612,6 +612,7 @@ namespace { RValue visitMacroExpansionExpr(MacroExpansionExpr *E, SGFContext C); RValue visitCurrentContextIsolationExpr(CurrentContextIsolationExpr *E, SGFContext C); RValue visitTypeValueExpr(TypeValueExpr *E, SGFContext C); + RValue visitOpaqueExpr(OpaqueExpr *E, SGFContext C); }; } // end anonymous namespace @@ -6080,6 +6081,7 @@ namespace { USE_SUBPATTERN(Paren) USE_SUBPATTERN(Typed) USE_SUBPATTERN(Binding) + USE_SUBPATTERN(Opaque) #undef USE_SUBPATTERN #define PATTERN(Kind, Parent) @@ -6738,6 +6740,10 @@ RValue RValueEmitter::visitMakeTemporarilyEscapableExpr( return rvalue; } +RValue RValueEmitter::visitOpaqueExpr(OpaqueExpr *E, SGFContext C) { + return visit(E->getOriginalExpr()); +} + RValue RValueEmitter::visitOpaqueValueExpr(OpaqueValueExpr *E, SGFContext C) { auto found = SGF.OpaqueValues.find(E); assert(found != SGF.OpaqueValues.end()); diff --git a/lib/SILGen/SILGenGlobalVariable.cpp b/lib/SILGen/SILGenGlobalVariable.cpp index d1af47fcd48e7..7f8033bbd781c 100644 --- a/lib/SILGen/SILGenGlobalVariable.cpp +++ b/lib/SILGen/SILGenGlobalVariable.cpp @@ -187,6 +187,9 @@ struct GenGlobalAccessors : public PatternVisitor void visitBindingPattern(BindingPattern *P) { return visit(P->getSubPattern()); } + void visitOpaquePattern(OpaquePattern *P) { + return visit(P->getSubPattern()); + } void visitTuplePattern(TuplePattern *P) { for (auto &elt : P->getElements()) visit(elt.getPattern()); diff --git a/lib/SILGen/SILGenPattern.cpp b/lib/SILGen/SILGenPattern.cpp index 4b4a98ec11279..65f14524f0ca3 100644 --- a/lib/SILGen/SILGenPattern.cpp +++ b/lib/SILGen/SILGenPattern.cpp @@ -56,6 +56,9 @@ static void dumpPattern(const Pattern *p, llvm::raw_ostream &os) { } p = p->getSemanticsProvidingPattern(); switch (p->getKind()) { + case PatternKind::Opaque: + dumpPattern(cast(p)->getSubPattern(), os); + return; case PatternKind::Any: os << '_'; return; @@ -110,6 +113,7 @@ static bool isDirectlyRefutablePattern(const Pattern *p) { if (!p) return false; switch (p->getKind()) { + case PatternKind::Any: case PatternKind::Named: case PatternKind::Expr: @@ -130,6 +134,7 @@ static bool isDirectlyRefutablePattern(const Pattern *p) { case PatternKind::Paren: case PatternKind::Typed: case PatternKind::Binding: + case PatternKind::Opaque: return isDirectlyRefutablePattern(p->getSemanticsProvidingPattern()); } llvm_unreachable("bad pattern"); @@ -191,6 +196,7 @@ static unsigned getNumSpecializationsRecursive(const Pattern *p, unsigned n) { case PatternKind::Paren: case PatternKind::Typed: case PatternKind::Binding: + case PatternKind::Opaque: return getNumSpecializationsRecursive(p->getSemanticsProvidingPattern(), n); } llvm_unreachable("bad pattern"); @@ -232,6 +238,7 @@ static bool isWildcardPattern(const Pattern *p) { case PatternKind::Paren: case PatternKind::Typed: case PatternKind::Binding: + case PatternKind::Opaque: return isWildcardPattern(p->getSemanticsProvidingPattern()); } @@ -291,6 +298,9 @@ static Pattern *getSimilarSpecializingPattern(Pattern *p, Pattern *first) { } return nullptr; } + + case PatternKind::Opaque: + return getSimilarSpecializingPattern(cast(p)->getSubPattern(), first); case PatternKind::Paren: case PatternKind::Binding: @@ -1236,6 +1246,8 @@ bindRefutablePatterns(const ClauseRow &row, ArgArray args, case PatternKind::Named: break; + case PatternKind::Opaque: + case PatternKind::Expr: { ExprPattern *exprPattern = cast(pattern); DebugLocOverrideRAII LocOverride{SGF.B, @@ -1585,6 +1597,8 @@ void PatternMatchEmission::emitSpecializedDispatch(ClauseMatrix &clauses, case PatternKind::Paren: case PatternKind::Typed: case PatternKind::Binding: + // FIXME: unsure about this but don't know what else to do + case PatternKind::Opaque: llvm_unreachable("non-semantic pattern kind!"); case PatternKind::Tuple: @@ -2752,6 +2766,9 @@ void PatternMatchEmission::emitDestructiveCaseBlocks() { void visitTypedPattern(TypedPattern *P, ManagedValue mv) { return visit(P->getSubPattern(), mv); } + void visitOpaquePattern(OpaquePattern *P, ManagedValue mv) { + return visit(P->getSubPattern(), mv); + } }; // Now we can start destructively binding the value. diff --git a/lib/SILGen/SILGenStmt.cpp b/lib/SILGen/SILGenStmt.cpp index 78ab99c3aa731..8d1fd0e2f21c6 100644 --- a/lib/SILGen/SILGenStmt.cpp +++ b/lib/SILGen/SILGenStmt.cpp @@ -1404,10 +1404,14 @@ void StmtEmitter::visitRepeatWhileStmt(RepeatWhileStmt *S) { SGF.BreakContinueDestStack.pop_back(); } +void StmtEmitter::visitOpaqueStmt(OpaqueStmt *S) { + visitBraceStmt(S->getUnderlyingStmt()); +} + void StmtEmitter::visitForEachStmt(ForEachStmt *S) { - if (auto *expansion = - dyn_cast(S->getTypeCheckedSequence())) { + if (auto *expansion = + dyn_cast(S->getParsedSequence())) { auto formalPackType = dyn_cast( PackType::get(SGF.getASTContext(), expansion->getType()) ->getCanonicalType()); @@ -1442,170 +1446,9 @@ void StmtEmitter::visitForEachStmt(ForEachStmt *S) { return; } - // Emit the 'iterator' variable that we'll be using for iteration. - LexicalScope OuterForScope(SGF, CleanupLocation(S)); - SGF.emitPatternBinding(S->getIteratorVar(), - /*index=*/0, /*debuginfo*/ true); - - // If we ever reach an unreachable point, stop emitting statements. - // This will need revision if we ever add goto. - if (!SGF.B.hasValidInsertionPoint()) return; - - // If generator's optional result is address-only, create a stack allocation - // to hold the results. This will be initialized on every entry into the loop - // header and consumed by the loop body. On loop exit, the terminating value - // will be in the buffer. - CanType optTy = S->getNextCall()->getType()->getCanonicalType(); - auto &optTL = SGF.getTypeLowering(optTy); - - SILValue addrOnlyBuf; - bool nextResultTyIsAddressOnly = - optTL.isAddressOnly() && SGF.silConv.useLoweredAddresses(); - - if (nextResultTyIsAddressOnly) - addrOnlyBuf = SGF.emitTemporaryAllocation(S, optTL.getLoweredType()); - - // Create a new basic block and jump into it. - JumpDest loopDest = createJumpDest(S->getBody()); - SGF.B.emitBlock(loopDest.getBlock(), S); - - // Set the destinations for 'break' and 'continue'. - JumpDest endDest = createJumpDest(S->getBody()); - SGF.BreakContinueDestStack.push_back({ S, endDest, loopDest }); - - bool hasElementConversion = S->getElementExpr(); - auto buildElementRValue = [&](SGFContext ctx) { - RValue result; - result = SGF.emitRValue(S->getNextCall(), - hasElementConversion ? SGFContext() : ctx); - return result; - }; - - ManagedValue nextBufOrElement; - // Then emit the loop destination block. - // - // Advance the generator. Use a scope to ensure that any temporary stack - // allocations in the subexpression are immediately released. - if (nextResultTyIsAddressOnly) { - // Create the initialization outside of the innerForScope so that the - // innerForScope doesn't clean it up. - auto nextInit = SGF.useBufferAsTemporary(addrOnlyBuf, optTL); - { - ArgumentScope innerForScope(SGF, SILLocation(S)); - SILLocation loc = SILLocation(S); - RValue result = buildElementRValue(SGFContext(nextInit.get())); - if (!result.isInContext()) { - ArgumentSource(SILLocation(S->getTypeCheckedSequence()), - std::move(result).ensurePlusOne(SGF, loc)) - .forwardInto(SGF, nextInit.get()); - } - innerForScope.pop(); - } - nextBufOrElement = nextInit->getManagedAddress(); - } else { - ArgumentScope innerForScope(SGF, SILLocation(S)); - nextBufOrElement = innerForScope.popPreservingValue( - buildElementRValue(SGFContext()) - .getAsSingleValue(SGF, SILLocation(S))); - } - - SILBasicBlock *failExitingBlock = createBasicBlock(); - SwitchEnumBuilder switchEnumBuilder(SGF.B, S, nextBufOrElement); - - auto convertElementRValue = [&](ManagedValue inputValue, SGFContext ctx) -> ManagedValue { - SILGenFunction::OpaqueValueRAII pushOpaqueValue(SGF, S->getElementExpr(), - inputValue); - return SGF.emitRValue(S->getConvertElementExpr(), ctx) - .getAsSingleValue(SGF, SILLocation(S)); - }; - - switchEnumBuilder.addOptionalSomeCase( - createBasicBlock(), loopDest.getBlock(), - [&](ManagedValue inputValue, SwitchCaseFullExpr &&scope) { - SGF.emitProfilerIncrement(S->getBody()); - - // Emit the loop body. - // The declared variable(s) for the current element are destroyed - // at the end of each loop iteration. - { - Scope innerForScope(SGF.Cleanups, CleanupLocation(S->getBody())); - // Emit the initialization for the pattern. If any of the bound - // patterns - // fail (because this is a 'for case' pattern with a refutable - // pattern, - // the code should jump to the continue block. - InitializationPtr initLoopVars = - SGF.emitPatternBindingInitialization(S->getPattern(), loopDest); - - // If we had a loadable "next" generator value, we know it is present. - // Get the value out of the optional, and wrap it up with a cleanup so - // that any exits out of this scope properly clean it up. - // - // *NOTE* If we do not have an address only value, then inputValue is - // *already properly unwrapped. - SGFContext loopVarCtx{initLoopVars.get()}; - if (nextResultTyIsAddressOnly) { - inputValue = SGF.emitUncheckedGetOptionalValueFrom( - S, inputValue, optTL, - hasElementConversion ? SGFContext() : loopVarCtx); - } - - CanType optConvertedTy = optTy; - if (hasElementConversion) { - inputValue = convertElementRValue(inputValue, loopVarCtx); - optConvertedTy = - OptionalType::get(S->getConvertElementExpr()->getType()) - ->getCanonicalType(); - } - if (!inputValue.isInContext()) - RValue(SGF, S, optConvertedTy.getOptionalObjectType(), inputValue) - .forwardInto(SGF, S->getBody(), initLoopVars.get()); - - // Now that the pattern has been initialized, check any where - // condition. - // If it fails, loop around as if 'continue' happened. - if (auto *Where = S->getWhere()) { - auto cond = SGF.emitCondition(Where, /*invert*/ true); - // If self is null, branch to the epilog. - cond.enterTrue(SGF); - SGF.Cleanups.emitBranchAndCleanups(loopDest, Where, {}); - cond.exitTrue(SGF); - cond.complete(SGF); - } - - visit(S->getBody()); - } - - // If we emitted an unreachable in the body, we will not have a valid - // insertion point. Just return early. - if (!SGF.B.hasValidInsertionPoint()) { - scope.unreachableExit(); - return; - } - - // Otherwise, associate the loop body's closing brace with this branch. - RegularLocation L(S->getBody()); - L.pointToEnd(); - scope.exitAndBranch(L); - }, - SGF.loadProfilerCount(S->getBody())); - - // We add loop fail block, just to be defensive about intermediate - // transformations performing cleanups at scope.exit(). We still jump to the - // contBlock. - switchEnumBuilder.addOptionalNoneCase( - createBasicBlock(), failExitingBlock, - [&](ManagedValue inputValue, SwitchCaseFullExpr &&scope) { - assert(!inputValue && "None should not be passed an argument!"); - scope.exitAndBranch(S); - }, - SGF.loadProfilerCount(S)); - - std::move(switchEnumBuilder).emit(); - - SGF.B.emitBlock(failExitingBlock); - emitOrDeleteBlock(SGF, endDest, S); - SGF.BreakContinueDestStack.pop_back(); + auto* braceStmt = S->getDesugaredStmt(); + if (braceStmt) + visitBraceStmt(braceStmt); } void StmtEmitter::visitBreakStmt(BreakStmt *S) { @@ -1616,6 +1459,9 @@ void StmtEmitter::visitBreakStmt(BreakStmt *S) { void SILGenFunction::emitBreakOutOf(SILLocation loc, Stmt *target) { // Find the target JumpDest based on the target that sema filled into the // stmt. + if (auto *forEachStmt = dyn_cast(target)) + if (auto *breakTarget = forEachStmt->getBreakTarget()) + target = breakTarget; for (auto &elt : BreakContinueDestStack) { if (target == elt.Target) { Cleanups.emitBranchAndCleanups(elt.BreakDest, loc); @@ -1628,10 +1474,15 @@ void SILGenFunction::emitBreakOutOf(SILLocation loc, Stmt *target) { void StmtEmitter::visitContinueStmt(ContinueStmt *S) { assert(S->getTarget() && "Sema didn't fill in continue target?"); + auto* target = S->getTarget(); + if (auto *forEachStmt = dyn_cast(target)) + if (auto *continueTarget = forEachStmt->getContinueTarget()) + target = continueTarget; + // Find the target JumpDest based on the target that sema filled into the // stmt. for (auto &elt : SGF.BreakContinueDestStack) { - if (S->getTarget() == elt.Target) { + if (target == elt.Target) { SGF.Cleanups.emitBranchAndCleanups(elt.ContinueDest, S); return; } diff --git a/lib/SILOptimizer/Mandatory/MoveOnlyDiagnostics.cpp b/lib/SILOptimizer/Mandatory/MoveOnlyDiagnostics.cpp index 514391ada4e0c..82b2a99b71e15 100644 --- a/lib/SILOptimizer/Mandatory/MoveOnlyDiagnostics.cpp +++ b/lib/SILOptimizer/Mandatory/MoveOnlyDiagnostics.cpp @@ -187,6 +187,7 @@ void DiagnosticEmitter::emitMissingConsumeInDiscardingContext( case StmtKind::Case: case StmtKind::Fallthrough: case StmtKind::Discard: + case StmtKind::Opaque: return false; }; } diff --git a/lib/Sema/BuilderTransform.cpp b/lib/Sema/BuilderTransform.cpp index 74a74fb2b591e..23d398d11d4e2 100644 --- a/lib/Sema/BuilderTransform.cpp +++ b/lib/Sema/BuilderTransform.cpp @@ -735,7 +735,7 @@ class ResultBuilderTransform forEachStmt->getParsedSequence(), forEachStmt->getWhereLoc(), forEachStmt->getWhere(), cloneBraceWith(forEachStmt->getBody(), newBody), - forEachStmt->isImplicit()); + forEachStmt->getDeclContext(), forEachStmt->isImplicit()); // For a body of new `do` statement that holds updated `for-in` loop // and epilog that consists of a call to `buildArray` that forms the @@ -771,6 +771,7 @@ class ResultBuilderTransform UNSUPPORTED_STMT(Fail) UNSUPPORTED_STMT(PoundAssert) UNSUPPORTED_STMT(Case) + UNSUPPORTED_STMT(Opaque) #undef UNSUPPORTED_STMT diff --git a/lib/Sema/CSApply.cpp b/lib/Sema/CSApply.cpp index c28f85d68ba38..5e43ee6bee1eb 100644 --- a/lib/Sema/CSApply.cpp +++ b/lib/Sema/CSApply.cpp @@ -2758,6 +2758,11 @@ namespace { return expr; } + Expr *visitOpaqueExpr(OpaqueExpr *expr) { + // Do nothing with opaque expressions. + return expr; + } + Expr *visitCodeCompletionExpr(CodeCompletionExpr *expr) { // Do nothing with code completion expressions. auto toType = simplifyType(cs.getType(expr)); @@ -9336,107 +9341,18 @@ applySolutionToForEachStmtPreamble(ForEachStmt *stmt, auto &ctx = cs.getASTContext(); auto *parsedSequence = stmt->getParsedSequence(); - bool isAsync = stmt->getAwaitLoc().isValid(); // Simplify the various types. info.sequenceType = solution.simplifyType(info.sequenceType); - info.elementType = solution.simplifyType(info.elementType); info.initType = solution.simplifyType(info.initType); - // First, let's apply the solution to the expression. - auto *makeIteratorVar = info.makeIteratorVar; - - auto makeIteratorTarget = *cs.getTargetFor({makeIteratorVar, /*index=*/0}); + auto sequenceTarget = *cs.getTargetFor(parsedSequence); - auto rewrittenTarget = rewriter.rewriteTarget(makeIteratorTarget); + auto rewrittenTarget = rewriter.rewriteTarget(sequenceTarget); if (!rewrittenTarget) return std::nullopt; - // Set type-checked initializer and mark it as such. - { - makeIteratorVar->setInit(/*index=*/0, rewrittenTarget->getAsExpr()); - makeIteratorVar->setInitializerChecked(/*index=*/0); - } - - stmt->setIteratorVar(makeIteratorVar); - - // Now, `$iterator.next()` call. - { - auto nextTarget = *cs.getTargetFor(info.nextCall); - - auto rewrittenTarget = rewriter.rewriteTarget(nextTarget); - if (!rewrittenTarget) - return std::nullopt; - - Expr *nextCall = rewrittenTarget->getAsExpr(); - // Wrap a call to `next()` into `try await` since `AsyncIteratorProtocol` - // witness could be `async throws`. - if (isAsync) { - // Cannot use `forEachChildExpr` here because we need to - // to wrap a call in `try` and then stop immediately after. - struct TryInjector : ASTWalker { - ASTContext &C; - const Solution &S; - - bool ShouldStop = false; - - TryInjector(ASTContext &ctx, const Solution &solution) - : C(ctx), S(solution) {} - - MacroWalking getMacroWalkingBehavior() const override { - return MacroWalking::Expansion; - } - - PreWalkResult walkToExprPre(Expr *E) override { - if (ShouldStop) - return Action::Stop(); - - if (auto *call = dyn_cast(E)) { - // There is a single call expression in `nextCall`. - ShouldStop = true; - - auto nextRefType = - S.getResolvedType(call->getFn())->castTo(); - - // If the inferred witness is throwing, we need to wrap the call - // into `try` expression. - if (nextRefType->isThrowing()) { - auto *tryExpr = TryExpr::createImplicit( - C, /*tryLoc=*/call->getStartLoc(), call, call->getType()); - // Cannot stop here because we need to make sure that - // the new expression gets injected into AST. - return Action::SkipNode(tryExpr); - } - } - - return Action::Continue(E); - } - }; - - nextCall->walk(TryInjector(ctx, solution)); - } - - stmt->setNextCall(nextCall); - } - - // Convert that std::optional value to the type of the pattern. - auto optPatternType = OptionalType::get(info.initType); - Type nextResultType = OptionalType::get(info.elementType); - if (!optPatternType->isEqual(nextResultType)) { - OpaqueValueExpr *elementExpr = new (ctx) OpaqueValueExpr( - stmt->getInLoc(), nextResultType->getOptionalObjectType(), - /*isPlaceholder=*/false); - cs.cacheExprTypes(elementExpr); - - auto *loc = cs.getConstraintLocator(parsedSequence, - ConstraintLocator::SequenceElementType); - auto *convertExpr = solution.coerceToType(elementExpr, info.initType, loc); - if (!convertExpr) - return std::nullopt; - - stmt->setElementExpr(elementExpr); - stmt->setConvertElementExpr(convertExpr); - } + stmt->setParsedSequence(rewrittenTarget->getAsExpr()); // Get the conformance of the sequence type to the Sequence protocol. auto sequenceProto = TypeChecker::getProtocol( @@ -9590,6 +9506,7 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) { case CTP_Condition: case CTP_WrappedProperty: case CTP_SingleValueStmtBranch: + case CTP_ForEachElement: result.setExpr(rewrittenExpr); break; } diff --git a/lib/Sema/CSBindings.cpp b/lib/Sema/CSBindings.cpp index 2946181cf41f1..5443fbf56b99d 100644 --- a/lib/Sema/CSBindings.cpp +++ b/lib/Sema/CSBindings.cpp @@ -2059,6 +2059,7 @@ void PotentialBindings::infer(ConstraintSystem &CS, case ConstraintKind::Conversion: case ConstraintKind::ArgumentConversion: case ConstraintKind::OperatorArgumentConversion: + case ConstraintKind::ForEachElement: case ConstraintKind::OptionalObject: case ConstraintKind::UnresolvedMemberChainBase: case ConstraintKind::LValueObject: { diff --git a/lib/Sema/CSDiagnostics.cpp b/lib/Sema/CSDiagnostics.cpp index 80abf67a9db35..aaf5b9891a021 100644 --- a/lib/Sema/CSDiagnostics.cpp +++ b/lib/Sema/CSDiagnostics.cpp @@ -868,6 +868,7 @@ GenericArgumentsMismatchFailure::getDiagnosticFor( case CTP_EnumCaseRawValue: case CTP_ExprPattern: case CTP_SingleValueStmtBranch: + case CTP_ForEachElement: break; } return std::nullopt; @@ -2963,6 +2964,7 @@ getContextualNilDiagnostic(ContextualTypePurpose CTP) { case CTP_WrappedProperty: case CTP_ExprPattern: case CTP_SingleValueStmtBranch: + case CTP_ForEachElement: return std::nullopt; case CTP_EnumCaseRawValue: @@ -3748,6 +3750,7 @@ ContextualFailure::getDiagnosticFor(ContextualTypePurpose context, case CTP_Unused: case CTP_YieldByReference: case CTP_ExprPattern: + case CTP_ForEachElement: break; } return std::nullopt; diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index ed18eb6b96617..83ea5e63aa6e7 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -2718,6 +2718,13 @@ namespace { }; switch (pattern->getKind()) { + case PatternKind::Opaque: { + auto *opaque = cast(pattern); + auto *ogPattern = opaque->getSubPattern(); + auto underlyingType = ogPattern->getType(); + return setType(underlyingType); + + } case PatternKind::Paren: { auto *paren = cast(pattern); @@ -4243,6 +4250,10 @@ namespace { return resultType; } + virtual Type visitOpaqueExpr(OpaqueExpr *E) { + return E->getOriginalExpr()->getType(); + } + static bool isTriggerFallbackDiagnosticBuiltin(UnresolvedDotExpr *UDE, ASTContext &Context) { auto *DRE = dyn_cast(UDE->getBase()); @@ -4656,18 +4667,9 @@ static std::optional generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc, ForEachStmt *stmt, Pattern *typeCheckedPattern, bool shouldBindPatternVarsOneWay) { - ASTContext &ctx = cs.getASTContext(); bool isAsync = stmt->getAwaitLoc().isValid(); auto *sequenceExpr = stmt->getParsedSequence(); - // If we have an unsafe expression for the sequence, lift it out of the - // sequence expression. We'll put it back after we've introduced the - // various calls. - UnsafeExpr *unsafeExpr = dyn_cast(sequenceExpr); - if (unsafeExpr) { - sequenceExpr = unsafeExpr->getSubExpr(); - } - auto contextualLocator = cs.getConstraintLocator( sequenceExpr, LocatorPathElt::ContextualType(CTP_ForEachSequence)); auto elementLocator = cs.getConstraintLocator( @@ -4682,164 +4684,29 @@ generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc, if (!sequenceProto) return std::nullopt; - std::string name; - { - if (auto np = dyn_cast_or_null(stmt->getPattern())) - name = "$"+np->getBoundName().str().str(); - name += "$generator"; - } - - auto *makeIteratorVar = new (ctx) - VarDecl(/*isStatic=*/false, VarDecl::Introducer::Var, - sequenceExpr->getStartLoc(), ctx.getIdentifier(name), dc); - makeIteratorVar->setImplicit(); - - // FIXME: Apply `nonisolated(unsafe)` to async iterators. - // - // Async iterators are not `Sendable`; they're only meant to be used from - // the isolation domain that creates them. But the `next()` method runs on - // the generic executor, so calling it from an actor-isolated context passes - // non-`Sendable` state across the isolation boundary. `next()` should - // inherit the isolation of the caller, but for now, use the opt out. - if (isAsync) { - auto *nonisolated = - NonisolatedAttr::createImplicit(ctx, NonIsolatedModifier::Unsafe); - makeIteratorVar->addAttribute(nonisolated); - } - - // First, let's form a call from sequence to `.makeIterator()` and save - // that in a special variable which is going to be used by SILGen. - { - FuncDecl *makeIterator = isAsync ? ctx.getAsyncSequenceMakeAsyncIterator() - : ctx.getSequenceMakeIterator(); - - auto *makeIteratorRef = new (ctx) UnresolvedDotExpr( - sequenceExpr, SourceLoc(), DeclNameRef(makeIterator->getName()), - DeclNameLoc(stmt->getForLoc()), /*implicit=*/true); - makeIteratorRef->setFunctionRefInfo(FunctionRefInfo::singleBaseNameApply()); - - Expr *makeIteratorCall = - CallExpr::createImplicitEmpty(ctx, makeIteratorRef); - - // Swap in the 'unsafe' expression. - if (unsafeExpr) { - unsafeExpr->setSubExpr(makeIteratorCall); - makeIteratorCall = unsafeExpr; - } - - Pattern *pattern = NamedPattern::createImplicit(ctx, makeIteratorVar); - auto *PB = PatternBindingDecl::createImplicit( - ctx, StaticSpellingKind::None, pattern, makeIteratorCall, dc); - auto makeIteratorTarget = SyntacticElementTarget::forInitialization( - makeIteratorCall, /*patternType=*/Type(), PB, /*index=*/0, - /*shouldBindPatternsOneWay=*/false); + ContextualTypeInfo contextInfo(sequenceProto->getDeclaredInterfaceType(), + CTP_ForEachSequence); + cs.setContextualInfo(sequenceExpr, contextInfo); - ContextualTypeInfo contextInfo(sequenceProto->getDeclaredInterfaceType(), - CTP_ForEachSequence); - cs.setContextualInfo(sequenceExpr, contextInfo); + auto seqExprTarget = SyntacticElementTarget(sequenceExpr, + dc, contextInfo, false); - if (cs.generateConstraints(makeIteratorTarget)) - return std::nullopt; - - sequenceIterationInfo.makeIteratorVar = PB; - - // Type of sequence expression has to conform to Sequence protocol. - // - // Note that the following emulates having `$generator` separately - // type-checked by introducing a `TVO_PrefersSubtypeBinding` type - // variable that would make sure that result of `.makeIterator` would - // get ranked standalone. - { - auto *externalIteratorType = cs.createTypeVariable( - cs.getConstraintLocator(sequenceExpr), TVO_PrefersSubtypeBinding); - - cs.addConstraint(ConstraintKind::Equal, externalIteratorType, - cs.getType(sequenceExpr), - externalIteratorType->getImpl().getLocator()); - - cs.addConstraint(ConstraintKind::ConformsTo, externalIteratorType, - sequenceProto->getDeclaredInterfaceType(), - contextualLocator); - - sequenceIterationInfo.sequenceType = cs.getType(sequenceExpr); - } - - cs.setTargetFor({PB, /*index=*/0}, makeIteratorTarget); - } - - // Now, result type of `.makeIterator()` is used to form a call to - // `.next()`. `next()` is called on each iteration of the loop. - { - FuncDecl *nextFn = - TypeChecker::getForEachIteratorNextFunction(dc, stmt->getForLoc(), isAsync); - Identifier nextId = nextFn ? nextFn->getName().getBaseIdentifier() - : ctx.Id_next; - TinyPtrVector labels; - if (nextFn && nextFn->getParameters()->size() == 1) - labels.push_back(ctx.Id_isolation); - auto *makeIteratorVarRef = - new (ctx) DeclRefExpr(makeIteratorVar, DeclNameLoc(stmt->getForLoc()), - /*Implicit=*/true); - auto *nextRef = new (ctx) - UnresolvedDotExpr(makeIteratorVarRef, SourceLoc(), - DeclNameRef(DeclName(ctx, nextId, labels)), - DeclNameLoc(stmt->getForLoc()), /*implicit=*/true); - nextRef->setFunctionRefInfo(FunctionRefInfo::singleBaseNameApply()); - - ArgumentList *nextArgs; - if (nextFn && nextFn->getParameters()->size() == 1) { - auto isolationArg = - new (ctx) CurrentContextIsolationExpr(stmt->getForLoc(), Type()); - nextArgs = ArgumentList::createImplicit( - ctx, {Argument(SourceLoc(), ctx.Id_isolation, isolationArg)}); - } else { - nextArgs = ArgumentList::createImplicit(ctx, {}); - } - Expr *nextCall = CallExpr::createImplicit(ctx, nextRef, nextArgs); - - // `next` is always async but witness might not be throwing - if (isAsync) { - nextCall = - AwaitExpr::createImplicit(ctx, nextCall->getLoc(), nextCall); - } - - // Wrap the 'next' call in 'unsafe', if the for..in loop has that - // effect or if the loop is async (in which case the iterator variable - // is nonisolated(unsafe). - if (stmt->getUnsafeLoc().isValid() || - (isAsync && - ctx.LangOpts.StrictConcurrencyLevel == StrictConcurrency::Complete)) { - SourceLoc loc = stmt->getUnsafeLoc(); - bool implicit = stmt->getUnsafeLoc().isInvalid(); - if (loc.isInvalid()) - loc = stmt->getForLoc(); - nextCall = new (ctx) UnsafeExpr(loc, nextCall, Type(), implicit); - } - - // The iterator type must conform to IteratorProtocol. - { - ProtocolDecl *iteratorProto = TypeChecker::getProtocol( - cs.getASTContext(), stmt->getForLoc(), - isAsync ? KnownProtocolKind::AsyncIteratorProtocol - : KnownProtocolKind::IteratorProtocol); - if (!iteratorProto) - return std::nullopt; - - ContextualTypeInfo contextInfo(iteratorProto->getDeclaredInterfaceType(), - CTP_ForEachSequence); - cs.setContextualInfo(nextRef->getBase(), contextInfo); - } - - SyntacticElementTarget nextTarget(nextCall, dc, CTP_Unused, - /*contextualType=*/Type(), - /*isDiscarded=*/false); - if (cs.generateConstraints(nextTarget, FreeTypeVariableBinding::Disallow)) - return std::nullopt; + if (cs.generateConstraints(seqExprTarget)) + return std::nullopt; + cs.setTargetFor(sequenceExpr, seqExprTarget); + auto seqType = cs.getType(sequenceExpr); + // Type of sequence expression has to conform to Sequence protocol. + // + // Note that the following emulates having `$generator` separately + // type-checked by introducing a `TVO_PrefersSubtypeBinding` type + // variable that would make sure that result of `.makeIterator` would + // get ranked standalone. + cs.addConstraint(ConstraintKind::ConformsTo, seqType, + sequenceProto->getDeclaredInterfaceType(), + contextualLocator); - sequenceIterationInfo.nextCall = nextTarget.getAsExpr(); - cs.setTargetFor(sequenceIterationInfo.nextCall, nextTarget); - } + sequenceIterationInfo.sequenceType = seqType; // Generate constraints for the pattern. Type initType = @@ -4848,24 +4715,11 @@ generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc, if (!initType) return std::nullopt; - // Add a conversion constraint between the element type of the sequence - // and the type of the element pattern. - auto *elementTypeLoc = cs.getConstraintLocator( - elementLocator, ConstraintLocator::OptionalInjection); - auto elementType = cs.createTypeVariable(elementTypeLoc, - /*flags=*/0); - { - auto nextType = cs.getType(sequenceIterationInfo.nextCall); - cs.addConstraint(ConstraintKind::OptionalObject, nextType, elementType, - elementTypeLoc); - cs.addConstraint(ConstraintKind::Conversion, elementType, initType, - elementLocator); - } - - // Populate all of the information for a for-each loop. - sequenceIterationInfo.elementType = elementType; sequenceIterationInfo.initType = initType; + cs.addConstraint(ConstraintKind::ForEachElement, seqType, initType, + cs.getConstraintLocator(stmt)); + return sequenceIterationInfo; } diff --git a/lib/Sema/CSSimplify.cpp b/lib/Sema/CSSimplify.cpp index 3701f7f6c93a2..c58a8371c7e40 100644 --- a/lib/Sema/CSSimplify.cpp +++ b/lib/Sema/CSSimplify.cpp @@ -2186,6 +2186,7 @@ ConstraintSystem::matchTupleTypes(TupleType *tuple1, TupleType *tuple2, case ConstraintKind::KeyPath: case ConstraintKind::KeyPathApplication: case ConstraintKind::LiteralConformsTo: + case ConstraintKind::ForEachElement: case ConstraintKind::OptionalObject: case ConstraintKind::UnresolvedValueMember: case ConstraintKind::ValueMember: @@ -2550,6 +2551,7 @@ static bool matchFunctionRepresentations(FunctionType::ExtInfo einfo1, case ConstraintKind::KeyPath: case ConstraintKind::KeyPathApplication: case ConstraintKind::LiteralConformsTo: + case ConstraintKind::ForEachElement: case ConstraintKind::OptionalObject: case ConstraintKind::UnresolvedValueMember: case ConstraintKind::ValueMember: @@ -3283,6 +3285,7 @@ ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2, case ConstraintKind::KeyPath: case ConstraintKind::KeyPathApplication: case ConstraintKind::LiteralConformsTo: + case ConstraintKind::ForEachElement: case ConstraintKind::OptionalObject: case ConstraintKind::UnresolvedValueMember: case ConstraintKind::ValueMember: @@ -7361,6 +7364,7 @@ ConstraintSystem::matchTypes(Type type1, Type type2, ConstraintKind kind, case ConstraintKind::KeyPath: case ConstraintKind::KeyPathApplication: case ConstraintKind::LiteralConformsTo: + case ConstraintKind::ForEachElement: case ConstraintKind::OptionalObject: case ConstraintKind::UnresolvedValueMember: case ConstraintKind::ValueMember: @@ -9822,6 +9826,65 @@ ConstraintSystem::simplifyCheckedCastConstraint( llvm_unreachable("Unhandled CheckedCastKind in switch."); } +ConstraintSystem::SolutionKind +ConstraintSystem::simplifyForEachElementConstraint( + Type first, Type second, + TypeMatchOptions flags, + ConstraintLocatorBuilder locator) { + Type seqTy = getFixedTypeRecursive(first, flags, /*wantRValue=*/true); + + if (seqTy->isTypeVariableOrMember()) { + if (flags.contains(TMF_GenerateConstraints)) { + addUnsolvedConstraint( + Constraint::create(*this, ConstraintKind::ForEachElement, first, + second, getConstraintLocator(locator))); + return SolutionKind::Solved; + } + + return SolutionKind::Unsolved; + } + + auto *externalSequenceType = createTypeVariable( + getConstraintLocator(locator), TVO_PrefersSubtypeBinding); + + addConstraint(ConstraintKind::Bind, externalSequenceType, seqTy, + locator); + + bool isAsync = false; + auto anchor = locator.getAnchor(); + if (auto *stmt = getAsStmt(anchor)) { + isAsync = stmt->getAwaitLoc().isValid(); + } + + auto sequenceProto = isAsync + ? Context.getProtocol(KnownProtocolKind::AsyncSequence) + : Context.getProtocol(KnownProtocolKind::Sequence); + + if (!sequenceProto) { + return SolutionKind::Error; + } + + auto *elementAssocType = sequenceProto->getAssociatedType(Context.Id_Element); + auto elementType = DependentMemberType::get(externalSequenceType, elementAssocType); + + Type resultElementType = elementType; + if (seqTy->isExistentialType()) + { + resultElementType = typeEraseOpenedExistentialReference( + elementType, seqTy, externalSequenceType, + TypePosition::Covariant); + if (!resultElementType) { + recordPotentialHole(externalSequenceType); + resultElementType = PlaceholderType::get(Context, externalSequenceType); + increaseScore(SK_Hole, locator); + } + } + + addConstraint(ConstraintKind::Conversion, resultElementType, second, + locator); + return SolutionKind::Solved; +} + ConstraintSystem::SolutionKind ConstraintSystem::simplifyOptionalObjectConstraint( Type first, Type second, @@ -16304,6 +16367,9 @@ ConstraintSystem::addConstraintImpl(ConstraintKind kind, Type first, case ConstraintKind::CheckedCast: return simplifyCheckedCastConstraint(first, second, subflags, locator); + case ConstraintKind::ForEachElement: + return simplifyForEachElementConstraint(first, second, subflags, locator); + case ConstraintKind::OptionalObject: return simplifyOptionalObjectConstraint(first, second, subflags, locator); @@ -16706,6 +16772,7 @@ void ConstraintSystem::addContextualConversionConstraint( case CTP_WrappedProperty: case CTP_ExprPattern: case CTP_SingleValueStmtBranch: + case CTP_ForEachElement: break; } @@ -16941,6 +17008,11 @@ ConstraintSystem::simplifyConstraint(const Constraint &constraint) { return result; } + case ConstraintKind::ForEachElement: + return simplifyForEachElementConstraint( + constraint.getFirstType(), constraint.getSecondType(), + /*flags*/ std::nullopt, constraint.getLocator()); + case ConstraintKind::OptionalObject: return simplifyOptionalObjectConstraint( constraint.getFirstType(), constraint.getSecondType(), diff --git a/lib/Sema/CSSyntacticElement.cpp b/lib/Sema/CSSyntacticElement.cpp index 503cd35d783a9..78fe5f51b1def 100644 --- a/lib/Sema/CSSyntacticElement.cpp +++ b/lib/Sema/CSSyntacticElement.cpp @@ -866,6 +866,7 @@ class SyntacticElementConstraintGenerator } // These statements don't require any type-checking. + void visitOpaqueStmt(OpaqueStmt *opaqueStmt) {} void visitBreakStmt(BreakStmt *breakStmt) {} void visitContinueStmt(ContinueStmt *continueStmt) {} void visitDeferStmt(DeferStmt *deferStmt) {} @@ -1816,6 +1817,10 @@ class SyntacticElementSolutionApplication rewriter.addLocalDeclToTypeCheck(decl); } + ASTNode visitOpaqueStmt(OpaqueStmt *opaqueStmt) { + return opaqueStmt; + } + ASTNode visitBreakStmt(BreakStmt *breakStmt) { // Force the target to be computed in case it produces diagnostics. (void)breakStmt->getTarget(); diff --git a/lib/Sema/Constraint.cpp b/lib/Sema/Constraint.cpp index a7675254f31d3..b029cc108728e 100644 --- a/lib/Sema/Constraint.cpp +++ b/lib/Sema/Constraint.cpp @@ -84,6 +84,7 @@ Constraint::Constraint(ConstraintKind Kind, Type First, Type Second, case ConstraintKind::DynamicTypeOf: case ConstraintKind::EscapableFunctionOf: case ConstraintKind::OpenedExistentialOf: + case ConstraintKind::ForEachElement: case ConstraintKind::OptionalObject: case ConstraintKind::OneWayEqual: case ConstraintKind::UnresolvedMemberChainBase: @@ -167,6 +168,7 @@ Constraint::Constraint(ConstraintKind Kind, Type First, Type Second, Type Third, case ConstraintKind::DynamicTypeOf: case ConstraintKind::EscapableFunctionOf: case ConstraintKind::OpenedExistentialOf: + case ConstraintKind::ForEachElement: case ConstraintKind::OptionalObject: case ConstraintKind::ApplicableFunction: case ConstraintKind::DynamicCallableApplicableFunction: @@ -463,6 +465,8 @@ void Constraint::print(llvm::raw_ostream &Out, SourceManager *sm, Out << getThirdType()->getString(PO); skipSecond = true; break; + case ConstraintKind::ForEachElement: + Out << " for each element "; break; case ConstraintKind::OptionalObject: Out << " optional with object type "; break; case ConstraintKind::BindOverload: { @@ -708,6 +712,7 @@ gatherReferencedTypeVars(Constraint *constraint, case ConstraintKind::DynamicTypeOf: case ConstraintKind::EscapableFunctionOf: case ConstraintKind::OpenedExistentialOf: + case ConstraintKind::ForEachElement: case ConstraintKind::OptionalObject: case ConstraintKind::Defaultable: case ConstraintKind::SubclassOf: @@ -1157,4 +1162,4 @@ void Constraint::setPreparedOverload(PreparedOverload *preparedOverload) { preparedOverload->wasForDiagnostics())); Overload.Prepared = preparedOverload; -} \ No newline at end of file +} diff --git a/lib/Sema/MiscDiagnostics.cpp b/lib/Sema/MiscDiagnostics.cpp index 81ce8ab549a86..a40278c6d4480 100644 --- a/lib/Sema/MiscDiagnostics.cpp +++ b/lib/Sema/MiscDiagnostics.cpp @@ -4198,6 +4198,8 @@ VarDeclUsageChecker::~VarDeclUsageChecker() { // Don't try to suggest 'var' -> 'let' conversion // in case of 'for' loop because it's an implicitly // immutable context. + // FIXME: need to find out if the stmt is part of a foreach's + // desugared suggestLet = !isa(stmt); } diff --git a/lib/Sema/SyntacticElementTarget.cpp b/lib/Sema/SyntacticElementTarget.cpp index e978dddbf7743..08b37ea476f3f 100644 --- a/lib/Sema/SyntacticElementTarget.cpp +++ b/lib/Sema/SyntacticElementTarget.cpp @@ -275,6 +275,7 @@ bool SyntacticElementTarget::contextualTypeIsOnlyAHint() const { case CTP_WrappedProperty: case CTP_ExprPattern: case CTP_SingleValueStmtBranch: + case CTP_ForEachElement: return false; } llvm_unreachable("invalid contextual type"); diff --git a/lib/Sema/TypeCheckDeclPrimary.cpp b/lib/Sema/TypeCheckDeclPrimary.cpp index 9e02c74db944f..5f9d9ad18469f 100644 --- a/lib/Sema/TypeCheckDeclPrimary.cpp +++ b/lib/Sema/TypeCheckDeclPrimary.cpp @@ -1476,6 +1476,10 @@ buildDefaultInitializerString(DeclContext *dc, Pattern *pattern) { case PatternKind::Binding: return buildDefaultInitializerString( dc, cast(pattern)->getSubPattern()); + + case PatternKind::Opaque: + return buildDefaultInitializerString( + dc, cast(pattern)->getSubPattern()); } llvm_unreachable("Unhandled PatternKind in switch."); diff --git a/lib/Sema/TypeCheckEffects.cpp b/lib/Sema/TypeCheckEffects.cpp index 24bad3378c75f..f097e66c1f469 100644 --- a/lib/Sema/TypeCheckEffects.cpp +++ b/lib/Sema/TypeCheckEffects.cpp @@ -1988,15 +1988,19 @@ class ApplyClassifier { classifier.AsyncKind, /*FIXME:*/PotentialEffectReason::forApply()); } - case EffectKind::Unsafe: - llvm_unreachable("Unimplemented"); + case EffectKind::Unsafe: { + FunctionUnsafeClassifier classifier(*this); + stmt->walk(classifier); + return classifier.classification; + } } + llvm_unreachable("Bad effect"); } /// Check to see if the given for-each statement to determine if it /// throws or is async. Classification classifyForEach(ForEachStmt *stmt) { - if (!stmt->getNextCall()) + if (!stmt->getDesugaredStmt()) return Classification::forInvalidCode(); // If there is an 'await', the for-each loop is always async. @@ -2011,10 +2015,10 @@ class ApplyClassifier { } // Merge the thrown result from the next/nextElement call. - result.merge(classifyExpr(stmt->getNextCall(), EffectKind::Throws)); + result.merge(classifyStmt(stmt->getDesugaredStmt(), EffectKind::Throws)); // Merge unsafe effect from the next/nextElement call. - result.merge(classifyExpr(stmt->getNextCall(), EffectKind::Unsafe)); + result.merge(classifyStmt(stmt->getDesugaredStmt(), EffectKind::Unsafe)); return result; } @@ -3629,10 +3633,6 @@ class CheckEffectsCoverage : public EffectsHandlingWalker llvm::DenseMap> uncoveredAsync; llvm::DenseMap parentMap; - /// The next/nextElement call expressions within for-in statements we've - /// seen. - llvm::SmallDenseSet forEachNextCallExprs; - /// Expressions that are assumed to be safe because they are being /// passed directly into an explicitly `@safe` function. llvm::DenseSet assumedSafeArguments; @@ -4384,11 +4384,6 @@ class CheckEffectsCoverage : public EffectsHandlingWalker /*stopAtAutoClosure=*/false, EffectKind::Unsafe); - // We don't diagnose uncovered unsafe uses within the next/nextElement - // call, because they're handled already by the for-in loop checking. - if (forEachNextCallExprs.contains(anchor)) - break; - // Figure out a location to use if the unsafe use didn't have one. SourceLoc replacementLoc; if (anchor) @@ -4585,18 +4580,18 @@ class CheckEffectsCoverage : public EffectsHandlingWalker ShouldRecurse_t checkForEach(ForEachStmt *S) { // Reparent the type-checked sequence on the parsed sequence, so we can // find an anchor. - if (auto typeCheckedExpr = S->getTypeCheckedSequence()) { + if (auto typeCheckedExpr = S->getParsedSequence()) { parentMap = typeCheckedExpr->getParentMap(); - - if (auto parsedSequence = S->getParsedSequence()) { - parentMap[typeCheckedExpr] = parsedSequence; - } } - // Note the nextCall expression. - if (auto nextCall = S->getNextCall()) { - forEachNextCallExprs.insert(nextCall); - } + // Walk everything + S->getParsedSequence()->walk(*this); + S->getBody()->walk(*this); + if (S->getWhere()) + S->getWhere()->walk(*this); + + if (S->getDesugaredStmt()) + S->getDesugaredStmt()->walk(*this); auto classification = getApplyClassifier().classifyForEach(S); @@ -4638,7 +4633,13 @@ class CheckEffectsCoverage : public EffectsHandlingWalker } } - return ShouldRecurse; + if (S->getUnsafeLoc().isValid() && !classification.hasUnsafe()){ + Ctx.Diags.diagnose(S->getUnsafeLoc(), + diag::no_unsafe_in_unsafe_for) + .fixItRemove(S->getUnsafeLoc()); + } + + return ShouldNotRecurse; } ShouldRecurse_t checkDefer(DeferStmt *S) { @@ -4704,11 +4705,6 @@ class CheckEffectsCoverage : public EffectsHandlingWalker return; } - Ctx.Diags.diagnose(E->getUnsafeLoc(), - forEachNextCallExprs.contains(E) - ? diag::no_unsafe_in_unsafe_for - : diag::no_unsafe_in_unsafe) - .fixItRemove(E->getUnsafeLoc()); } void noteLabeledConditionalStmt(LabeledConditionalStmt *stmt) { diff --git a/lib/Sema/TypeCheckPattern.cpp b/lib/Sema/TypeCheckPattern.cpp index b08c2bc5a6677..c88f3e273eadb 100644 --- a/lib/Sema/TypeCheckPattern.cpp +++ b/lib/Sema/TypeCheckPattern.cpp @@ -310,6 +310,10 @@ class ResolvePattern : public ASTVisitor(P)) SP = PP->getSubPattern(); + else if (auto* BP = dyn_cast(P)) + SP = BP->getSubPattern(); else - SP = cast(P)->getSubPattern(); + SP = cast(P)->getSubPattern(); + Type subType = TypeChecker::typeCheckPattern( pattern.forSubPattern(SP, /*retainTopLevel=*/true)); if (subType->hasError()) @@ -1150,10 +1158,26 @@ Pattern *TypeChecker::coercePatternToType( PP->setType(sub->getType()); return P; } + + case PatternKind::Opaque: { + auto VP = cast(P); + auto sub = VP->getSubPattern(); + + sub = coercePatternToType( + pattern.forSubPattern(sub, /*retainTopLevel=*/false), type, subOptions, + tryRewritePattern); + if (!sub) + return nullptr; + VP->setSubPattern(sub); + if (sub->hasType()) + VP->setType(sub->getType()); + return P; + } + case PatternKind::Binding: { auto VP = cast(P); + auto sub = VP->getSubPattern(); - Pattern *sub = VP->getSubPattern(); sub = coercePatternToType( pattern.forSubPattern(sub, /*retainTopLevel=*/false), type, subOptions, tryRewritePattern); diff --git a/lib/Sema/TypeCheckStmt.cpp b/lib/Sema/TypeCheckStmt.cpp index 9584ec27111e7..52e3ac91b16ff 100644 --- a/lib/Sema/TypeCheckStmt.cpp +++ b/lib/Sema/TypeCheckStmt.cpp @@ -137,6 +137,11 @@ namespace { CS->setDeclContext(ParentDC); if (auto *FS = dyn_cast(S)) FS->setDeclContext(ParentDC); + if (auto *FES = dyn_cast(S)) + { + FES->setDeclContext(ParentDC); + FES->desugar(); + } return Action::Continue(S); } @@ -1520,6 +1525,10 @@ class StmtChecker : public StmtVisitor { return S; } + Stmt *visitOpaqueStmt(OpaqueStmt *S) { + return S; + } + Stmt *visitBreakStmt(BreakStmt *S) { // Force the target to be computed in case it produces diagnostics. (void)S->getTarget(); @@ -3430,3 +3439,180 @@ FuncDecl *TypeChecker::getForEachIteratorNextFunction( // Fall back to AsyncIteratorProtocol.next(). return ctx.getAsyncIteratorNext(); } + +static BraceStmt *desugarForEachStmt(ForEachStmt* stmt){ + auto *parsedSequence = stmt->getParsedSequence(); + + if (isa(parsedSequence)) + return nullptr; + + if (parsedSequence->getType()->hasError() + || stmt->getPattern()->getType()->hasError()) + return nullptr; + + auto *dc = stmt->getDeclContext(); + auto &ctx = dc->getASTContext(); + bool isAsync = stmt->getAwaitLoc().isValid(); + + // If we have an unsafe expression for the sequence, lift it out of the + // sequence expression. We'll put it back after we've introduced the + // various calls. + UnsafeExpr *unsafeExpr = dyn_cast(parsedSequence); + if (unsafeExpr) { + parsedSequence = unsafeExpr->getSubExpr(); + } + + auto opaqueSeqExpr = new (ctx) OpaqueExpr(parsedSequence); + + std::string name; + { + if (auto np = dyn_cast_or_null(stmt->getPattern())) + name = "$"+np->getBoundName().str().str(); + name += "$generator"; + } + + auto *makeIteratorVar = new (ctx) + VarDecl(/*isStatic=*/false, VarDecl::Introducer::Var, + opaqueSeqExpr->getStartLoc(), + ctx.getIdentifier(name), dc); + makeIteratorVar->setImplicit(); + + // Async iterators are not `Sendable`; they're only meant to be used from + // the isolation domain that creates them. But the `next()` method runs on + // the generic executor, so calling it from an actor-isolated context passes + // non-`Sendable` state across the isolation boundary. `next()` should + // inherit the isolation of the caller, but for now, use the opt out. + if (isAsync) { + auto *nonisolated = + NonisolatedAttr::createImplicit(ctx, NonIsolatedModifier::Unsafe); + makeIteratorVar->addAttribute(nonisolated); + } + + // First, let's form a call from sequence to `.makeIterator()` and save + // that in a special variable which is going to be used by SILGen. + FuncDecl *makeIterator = isAsync ? ctx.getAsyncSequenceMakeAsyncIterator() + : ctx.getSequenceMakeIterator(); + + auto *makeIteratorRef = new (ctx) UnresolvedDotExpr( + opaqueSeqExpr, SourceLoc(), DeclNameRef(makeIterator->getName()), + DeclNameLoc(stmt->getForLoc()), /*implicit=*/true); + makeIteratorRef->setFunctionRefInfo(FunctionRefInfo::singleBaseNameApply()); + + Expr *makeIteratorCall = + CallExpr::createImplicitEmpty(ctx, makeIteratorRef); + + // Swap in the 'unsafe' expression. + if (unsafeExpr) { + unsafeExpr = UnsafeExpr::createImplicit(ctx, unsafeExpr->getUnsafeLoc(), makeIteratorCall); + makeIteratorCall = unsafeExpr; + } + + Pattern *pattern = NamedPattern::createImplicit(ctx, makeIteratorVar); + auto *PB = PatternBindingDecl::createImplicit( + ctx, StaticSpellingKind::None, pattern, makeIteratorCall, dc); + + // The result type of `.makeIterator()` is used to form a call to + // `.next()`. `next()` is called on each iteration of the loop. + FuncDecl *nextFn = + TypeChecker::getForEachIteratorNextFunction(dc, stmt->getForLoc(), isAsync); + Identifier nextId = nextFn ? nextFn->getName().getBaseIdentifier() + : ctx.Id_next; + TinyPtrVector labels; + if (nextFn && nextFn->getParameters()->size() == 1) + labels.push_back(ctx.Id_isolation); + auto *makeIteratorVarRef = + new (ctx) DeclRefExpr(makeIteratorVar, DeclNameLoc(stmt->getForLoc()), + /*Implicit=*/true); + auto *nextRef = new (ctx) + UnresolvedDotExpr(makeIteratorVarRef, SourceLoc(), + DeclNameRef(DeclName(ctx, nextId, labels)), + DeclNameLoc(stmt->getForLoc()), /*implicit=*/true); + nextRef->setFunctionRefInfo(FunctionRefInfo::singleBaseNameApply()); + + ArgumentList *nextArgs; + if (nextFn && nextFn->getParameters()->size() == 1) { + auto isolationArg = + new (ctx) CurrentContextIsolationExpr(stmt->getForLoc(), Type()); + nextArgs = ArgumentList::createImplicit( + ctx, {Argument(SourceLoc(), ctx.Id_isolation, isolationArg)}); + } else { + nextArgs = ArgumentList::createImplicit(ctx, {}); + } + Expr *nextCall = CallExpr::createImplicit(ctx, nextRef, nextArgs); + + // `next` is always async but witness might not be throwing + if (isAsync) { + nextCall = + AwaitExpr::createImplicit(ctx, nextCall->getLoc(), nextCall); + } + + if (stmt->getTryLoc().isValid()){ + nextCall = TryExpr::createImplicit(ctx, nextCall->getLoc(), nextCall); + } + + // Wrap the 'next' call in 'unsafe', if the for..in loop has that + // effect or if the loop is async (in which case the iterator variable + // is nonisolated(unsafe). + if (stmt->getUnsafeLoc().isValid() || + (isAsync && + ctx.LangOpts.StrictConcurrencyLevel == StrictConcurrency::Complete)) { + SourceLoc loc = stmt->getUnsafeLoc(); + bool implicit = stmt->getUnsafeLoc().isInvalid(); + if (loc.isInvalid()) + loc = stmt->getForLoc(); + nextCall = new (ctx) UnsafeExpr(loc, nextCall, Type(), implicit); + } + + auto elementPattern = stmt->getPattern(); + auto optPatternType = OptionalType::get(elementPattern->getType()); + + SmallVector cond; + auto *opaquePattern = new (ctx) OpaquePattern(elementPattern); + opaquePattern->setType(optPatternType); + + auto *somePattern = OptionalSomePattern::createImplicit(ctx, opaquePattern); + + auto PBI = ConditionalPatternBindingInfo::create(ctx, SourceLoc(), somePattern, nextCall); + auto conditionElement = StmtConditionElement(PBI); + cond.push_back(conditionElement); + + /* for ... in ... where cond { body } + * becomes: + * while ... { if cond then body else continue } + */ + auto* whereClause = stmt->getWhere(); + auto* forBody = stmt->getBody(); + + Stmt* whileBody = new (ctx) OpaqueStmt(forBody, SourceLoc(), SourceLoc()); + + if (whereClause) + { + SmallVector thenClause{whileBody}; + + whereClause = new (ctx) OpaqueExpr(whereClause); + + whileBody = new (ctx) IfStmt(SourceLoc(), whereClause, + BraceStmt::create(ctx, SourceLoc(), thenClause, SourceLoc()), SourceLoc(), + nullptr, /*implicit*/ true, ctx); + } + + auto* whileStmt = new (ctx) WhileStmt(stmt->getLabelInfo(), SourceLoc(), ctx.AllocateCopy(cond), whileBody, true); + stmt->setBreakTarget(whileStmt); + stmt->setContinueTarget(whileStmt); + + SmallVector stmts; + stmts.push_back(PB); + stmts.push_back(whileStmt); + + auto *braceStmt = BraceStmt::create(ctx, stmt->getStartLoc(), stmts, stmt->getEndLoc()); + + StmtChecker checker(stmt->getDeclContext()); + if (!checker.typeCheckStmt(braceStmt)) + return nullptr; + + return braceStmt; +} + +BraceStmt* DesugarForEachStmtRequest::evaluate(Evaluator &evaluator, ForEachStmt *stmt) const { + return desugarForEachStmt(stmt); +} diff --git a/lib/Sema/TypeCheckSwitchStmt.cpp b/lib/Sema/TypeCheckSwitchStmt.cpp index 48edef19f8363..11c33b2a579fe 100644 --- a/lib/Sema/TypeCheckSwitchStmt.cpp +++ b/lib/Sema/TypeCheckSwitchStmt.cpp @@ -1482,6 +1482,10 @@ namespace { auto *PP = cast(item); return projectPattern(PP->getSubPattern()); } + case PatternKind::Opaque: { + auto *opaque = cast(item); + return projectPattern(opaque->getSubPattern()); + } case PatternKind::OptionalSome: { auto *OSP = cast(item); const Identifier name = OSP->getElementDecl()->getBaseIdentifier(); diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index 22af1f5a88a4e..cbba152fa334d 100644 --- a/lib/Serialization/Serialization.cpp +++ b/lib/Serialization/Serialization.cpp @@ -4034,11 +4034,13 @@ class Serializer::DeclSerializer : public DeclVisitor { writePattern(typed->getSubPattern()); break; } + case PatternKind::Is: case PatternKind::EnumElement: case PatternKind::OptionalSome: case PatternKind::Bool: case PatternKind::Expr: + case PatternKind::Opaque: llvm_unreachable("Refutable patterns cannot be serialized"); case PatternKind::Binding: { diff --git a/test/SILGen/foreach.swift b/test/SILGen/foreach.swift index 8a16840653184..82a3940e46b4a 100644 --- a/test/SILGen/foreach.swift +++ b/test/SILGen/foreach.swift @@ -122,6 +122,9 @@ func trivialStructBreak(_ xx: [Int]) { // CHECK: [[IND_VAR:%.*]] = load [trivial] [[GET_ELT_STACK]] // CHECK: switch_enum [[IND_VAR]] : $Optional, case #Optional.some!enumelt: [[SOME_BB:bb[0-9]+]], case #Optional.none!enumelt: [[NONE_BB:bb[0-9]+]] // +// CHECK: [[NONE_BB]]: +// CHECK: br [[CONT_BLOCK:bb[0-9]+]] +// // CHECK: [[SOME_BB]]([[VAR:%.*]] : $Int): // CHECK: cond_br {{%.*}}, [[LOOP_BREAK_END_BLOCK:bb[0-9]+]], [[CONTINUE_CHECK_BLOCK:bb[0-9]+]] // @@ -143,9 +146,6 @@ func trivialStructBreak(_ xx: [Int]) { // CHECK: apply [[LOOP_BODY_FUNC]]() // CHECK: br [[LOOP_DEST]] // -// CHECK: [[NONE_BB]]: -// CHECK: br [[CONT_BLOCK]] -// // CHECK: [[CONT_BLOCK]] // CHECK: destroy_value [[ITERATOR_BOX]] : ${ var IndexingIterator> } // CHECK: [[FUNC_END_FUNC:%.*]] = function_ref @funcEnd : $@convention(thin) () -> () @@ -213,19 +213,23 @@ func existentialBreak(_ xx: [P]) { // CHECK: store [[ARRAY_COPY:%.*]] to [init] [[BORROWED_ARRAY_STACK]] // CHECK: [[MAKE_ITERATOR_FUNC:%.*]] = function_ref @$sSlss16IndexingIteratorVyxG0B0RtzrlE04makeB0ACyF : $@convention(method) <τ_0_0 where τ_0_0 : Collection, τ_0_0.Iterator == IndexingIterator<τ_0_0>> (@in τ_0_0) -> @out IndexingIterator<τ_0_0> // CHECK: apply [[MAKE_ITERATOR_FUNC]]>([[PROJECT_ITERATOR_BOX]], [[BORROWED_ARRAY_STACK]]) -// CHECK: [[ELT_STACK:%.*]] = alloc_stack $Optional // CHECK: br [[LOOP_DEST:bb[0-9]+]] // // CHECK: [[LOOP_DEST]]: +// CHECK: [[T0:%.*]] = alloc_stack [lexical] [var_decl] $any P, let, name "x" +// CHECK: [[ELT_STACK:%.*]] = alloc_stack $Optional // CHECK: [[WRITE:%.*]] = begin_access [modify] [unknown] [[PROJECT_ITERATOR_BOX]] : $*IndexingIterator> // CHECK: [[FUNC_REF:%.*]] = function_ref @$ss16IndexingIteratorV4next7ElementQzSgyF : $@convention(method) <τ_0_0 where τ_0_0 : Collection> (@inout IndexingIterator<τ_0_0>) -> @out Optional<τ_0_0.Element> // CHECK: apply [[FUNC_REF]]>([[ELT_STACK]], [[WRITE]]) // CHECK: switch_enum_addr [[ELT_STACK]] : $*Optional, case #Optional.some!enumelt: [[SOME_BB:bb[0-9]+]], case #Optional.none!enumelt: [[NONE_BB:bb[0-9]+]] // +// CHECK: [[NONE_BB]]: +// CHECK: br [[CONT_BLOCK:bb[0-9]+]] +// // CHECK: [[SOME_BB]]: -// CHECK: [[T0:%.*]] = alloc_stack [lexical] [var_decl] $any P, let, name "x" // CHECK: [[ELT_STACK_TAKE:%.*]] = unchecked_take_enum_data_addr [[ELT_STACK]] : $*Optional, #Optional.some!enumelt // CHECK: copy_addr [take] [[ELT_STACK_TAKE]] to [init] [[T0]] +// CHECK: dealloc_stack [[ELT_STACK]] // CHECK: cond_br {{%.*}}, [[LOOP_BREAK_END_BLOCK:bb[0-9]+]], [[CONTINUE_CHECK_BLOCK:bb[0-9]+]] // // CHECK: [[LOOP_BREAK_END_BLOCK]]: @@ -252,11 +256,7 @@ func existentialBreak(_ xx: [P]) { // CHECK: dealloc_stack [[T0]] // CHECK: br [[LOOP_DEST]] // -// CHECK: [[NONE_BB]]: -// CHECK: br [[CONT_BLOCK]] -// // CHECK: [[CONT_BLOCK]] -// CHECK: dealloc_stack [[ELT_STACK]] // CHECK: destroy_value [[ITERATOR_BOX]] : ${ var IndexingIterator> } // CHECK: [[FUNC_END_FUNC:%.*]] = function_ref @funcEnd : $@convention(thin) () -> () // CHECK: apply [[FUNC_END_FUNC]]() @@ -374,19 +374,23 @@ func genericStructBreak(_ xx: [GenericStruct]) { // CHECK: store [[ARRAY_COPY:%.*]] to [init] [[BORROWED_ARRAY_STACK]] // CHECK: [[MAKE_ITERATOR_FUNC:%.*]] = function_ref @$sSlss16IndexingIteratorVyxG0B0RtzrlE04makeB0ACyF : $@convention(method) <τ_0_0 where τ_0_0 : Collection, τ_0_0.Iterator == IndexingIterator<τ_0_0>> (@in τ_0_0) -> @out IndexingIterator<τ_0_0> // CHECK: apply [[MAKE_ITERATOR_FUNC]]>>([[PROJECT_ITERATOR_BOX]], [[BORROWED_ARRAY_STACK]]) -// CHECK: [[ELT_STACK:%.*]] = alloc_stack $Optional> // CHECK: br [[LOOP_DEST:bb[0-9]+]] // // CHECK: [[LOOP_DEST]]: +// CHECK: [[T0:%.*]] = alloc_stack [lexical] [var_decl] $GenericStruct, let, name "x" +// CHECK: [[ELT_STACK:%.*]] = alloc_stack $Optional> // CHECK: [[WRITE:%.*]] = begin_access [modify] [unknown] [[PROJECT_ITERATOR_BOX]] : $*IndexingIterator>> // CHECK: [[FUNC_REF:%.*]] = function_ref @$ss16IndexingIteratorV4next7ElementQzSgyF : $@convention(method) <τ_0_0 where τ_0_0 : Collection> (@inout IndexingIterator<τ_0_0>) -> @out Optional<τ_0_0.Element> // CHECK: apply [[FUNC_REF]]>>([[ELT_STACK]], [[WRITE]]) // CHECK: switch_enum_addr [[ELT_STACK]] : $*Optional>, case #Optional.some!enumelt: [[SOME_BB:bb[0-9]+]], case #Optional.none!enumelt: [[NONE_BB:bb[0-9]+]] // +// CHECK: [[NONE_BB]]: +// CHECK: br [[CONT_BLOCK:bb[0-9]+]] +// // CHECK: [[SOME_BB]]: -// CHECK: [[T0:%.*]] = alloc_stack [lexical] [var_decl] $GenericStruct, let, name "x" // CHECK: [[ELT_STACK_TAKE:%.*]] = unchecked_take_enum_data_addr [[ELT_STACK]] : $*Optional>, #Optional.some!enumelt // CHECK: copy_addr [take] [[ELT_STACK_TAKE]] to [init] [[T0]] +// CHECK: dealloc_stack [[ELT_STACK]] // CHECK: cond_br {{%.*}}, [[LOOP_BREAK_END_BLOCK:bb[0-9]+]], [[CONTINUE_CHECK_BLOCK:bb[0-9]+]] // // CHECK: [[LOOP_BREAK_END_BLOCK]]: @@ -413,11 +417,7 @@ func genericStructBreak(_ xx: [GenericStruct]) { // CHECK: dealloc_stack [[T0]] // CHECK: br [[LOOP_DEST]] // -// CHECK: [[NONE_BB]]: -// CHECK: br [[CONT_BLOCK]] -// // CHECK: [[CONT_BLOCK]] -// CHECK: dealloc_stack [[ELT_STACK]] // CHECK: destroy_value [[ITERATOR_BOX]] // CHECK: [[FUNC_END_FUNC:%.*]] = function_ref @funcEnd : $@convention(thin) () -> () // CHECK: apply [[FUNC_END_FUNC]]() @@ -481,19 +481,23 @@ func genericCollectionBreak(_ xx: T) { // CHECK: [[PROJECT_ITERATOR_BOX:%.*]] = project_box [[ITERATOR_LIFETIME]] // CHECK: [[MAKE_ITERATOR_FUNC:%.*]] = witness_method $T, #Sequence.makeIterator : // CHECK: apply [[MAKE_ITERATOR_FUNC]]([[PROJECT_ITERATOR_BOX]], [[COLLECTION_COPY:%.*]]) -// CHECK: [[ELT_STACK:%.*]] = alloc_stack $Optional // CHECK: br [[LOOP_DEST:bb[0-9]+]] // // CHECK: [[LOOP_DEST]]: +// CHECK: [[T0:%.*]] = alloc_stack [lexical] [var_decl] $T.Element, let, name "x" +// CHECK: [[ELT_STACK:%.*]] = alloc_stack $Optional // CHECK: [[WRITE:%.*]] = begin_access [modify] [unknown] [[PROJECT_ITERATOR_BOX]] : $*T.Iterator // CHECK: [[GET_NEXT_FUNC:%.*]] = witness_method $T.Iterator, #IteratorProtocol.next : (inout Self) -> () -> Self.Element? : $@convention(witness_method: IteratorProtocol) <τ_0_0 where τ_0_0 : IteratorProtocol> (@inout τ_0_0) -> @out Optional<τ_0_0.Element> // CHECK: apply [[GET_NEXT_FUNC]]([[ELT_STACK]], [[WRITE]]) // CHECK: switch_enum_addr [[ELT_STACK]] : $*Optional, case #Optional.some!enumelt: [[SOME_BB:bb[0-9]+]], case #Optional.none!enumelt: [[NONE_BB:bb[0-9]+]] // +// CHECK: [[NONE_BB]]: +// CHECK: br [[CONT_BLOCK:bb[0-9]+]] +// // CHECK: [[SOME_BB]]: -// CHECK: [[T0:%.*]] = alloc_stack [lexical] [var_decl] $T.Element, let, name "x" // CHECK: [[ELT_STACK_TAKE:%.*]] = unchecked_take_enum_data_addr [[ELT_STACK]] : $*Optional, #Optional.some!enumelt // CHECK: copy_addr [take] [[ELT_STACK_TAKE]] to [init] [[T0]] +// CHECK: dealloc_stack [[ELT_STACK]] // CHECK: cond_br {{%.*}}, [[LOOP_BREAK_END_BLOCK:bb[0-9]+]], [[CONTINUE_CHECK_BLOCK:bb[0-9]+]] // // CHECK: [[LOOP_BREAK_END_BLOCK]]: @@ -520,11 +524,7 @@ func genericCollectionBreak(_ xx: T) { // CHECK: dealloc_stack [[T0]] // CHECK: br [[LOOP_DEST]] // -// CHECK: [[NONE_BB]]: -// CHECK: br [[CONT_BLOCK]] -// // CHECK: [[CONT_BLOCK]] -// CHECK: dealloc_stack [[ELT_STACK]] // CHECK: end_borrow [[ITERATOR_LIFETIME]] // CHECK: destroy_value [[ITERATOR_BOX]] // CHECK: [[FUNC_END_FUNC:%.*]] = function_ref @funcEnd : $@convention(thin) () -> () @@ -578,9 +578,6 @@ func tupleElements(_ xx: [(C, C)]) { // CHECK: destroy_value [[A]] for (_, _) in xx {} // CHECK: bb{{.*}}([[PAYLOAD:%.*]] : @owned $(C, C)): - // CHECK: ([[A:%.*]], [[B:%.*]]) = destructure_tuple [[PAYLOAD]] - // CHECK: destroy_value [[B]] - // CHECK: destroy_value [[A]] for _ in xx {} } @@ -621,9 +618,9 @@ func genericFuncWithConversion(list : [T]) { // CHECK-LABEL: sil hidden [ossa] @$s7foreach32injectForEachElementIntoOptionalyySaySiGF // CHECK: [[NEXT_RESULT:%.*]] = load [trivial] {{.*}} : $*Optional -// CHECK: switch_enum [[NEXT_RESULT]] : $Optional, case #Optional.some!enumelt: [[BB_SOME:bb.*]], case -// CHECK: [[BB_SOME]]([[X_PRE_BINDING:%.*]] : $Int): -// CHECK: [[X_BINDING:%.*]] = enum $Optional, #Optional.some!enumelt, [[X_PRE_BINDING]] : $Int +// CHECK: [[X_BINDING:%.*]] = enum $Optional>, #Optional.some!enumelt, [[NEXT_RESULT]] : $Optional +// CHECK: switch_enum [[X_BINDING]] : $Optional>, case #Optional.some!enumelt: [[BB_SOME:bb.*]], case +// CHECK: [[BB_SOME]]([[X_BINDING:%.*]] : $Optional): // CHECK: [[MVX_BINDING:%.*]] = move_value [var_decl] [[X_BINDING]] : $Optional // CHECK: debug_value [[MVX_BINDING]] : $Optional, let, name "x" func injectForEachElementIntoOptional(_ xs: [Int]) { @@ -631,14 +628,14 @@ func injectForEachElementIntoOptional(_ xs: [Int]) { } // CHECK-LABEL: sil hidden [ossa] @$s7foreach32injectForEachElementIntoOptionalyySayxGlF -// CHECK: copy_addr [take] [[NEXT_RESULT:%.*]] to [init] [[NEXT_RESULT_COPY:%.*]] : $*Optional -// CHECK: switch_enum_addr [[NEXT_RESULT_COPY]] : $*Optional, case #Optional.some!enumelt: [[BB_SOME:bb.*]], case -// CHECK: [[BB_SOME]]: // CHECK: [[X_BINDING:%.*]] = alloc_stack [lexical] [var_decl] $Optional, let, name "x" -// CHECK: [[ADDR:%.*]] = unchecked_take_enum_data_addr [[NEXT_RESULT_COPY]] : $*Optional, #Optional.some!enumelt -// CHECK: [[X_ADDR:%.*]] = init_enum_data_addr [[X_BINDING]] : $*Optional, #Optional.some!enumelt -// CHECK: copy_addr [take] [[ADDR]] to [init] [[X_ADDR]] : $*T -// CHECK: inject_enum_addr [[X_BINDING]] : $*Optional, #Optional.some!enumelt +// CHECK: [[ADDR:%.*]] = alloc_stack $Optional> +// CHECK: [[X_ADDR:%.*]] = init_enum_data_addr [[ADDR]] : $*Optional>, #Optional.some!enumelt +// CHECK: inject_enum_addr [[ADDR]] : $*Optional>, #Optional.some!enumelt +// CHECK: switch_enum_addr [[ADDR]] : $*Optional>, case #Optional.some!enumelt: [[BB_SOME:bb.*]], case +// CHECK: [[BB_SOME]]: +// CHECK: [[RES:%.*]] = unchecked_take_enum_data_addr [[ADDR]] : $*Optional>, #Optional.some!enumelt +// CHECK: copy_addr [take] [[RES:%.*]] to [init] [[RES_COPY:%.*]] : $*Optional func injectForEachElementIntoOptional(_ xs: [T]) { for x : T? in xs {} } diff --git a/test/SILGen/sil_locations.swift b/test/SILGen/sil_locations.swift index 03889318273b0..abfab46c533f2 100644 --- a/test/SILGen/sil_locations.swift +++ b/test/SILGen/sil_locations.swift @@ -331,15 +331,15 @@ func testStringForEachStmt() { } // CHECK-LABEL: sil hidden [ossa] @$s13sil_locations21testStringForEachStmtyyF - // CHECK: br {{.*}} line:[[@LINE-8]]:3 + // CHECK: br {{.*}} // CHECK: switch_enum {{.*}} line:[[@LINE-9]]:3 // CHECK: cond_br {{.*}} line:[[@LINE-8]]:10 // Break branch: // CHECK: br {{.*}} line:[[@LINE-9]]:7 // Looping back branch: - // CHECK: br {{.*}} line:[[@LINE-9]]:3 + // CHECK: br {{.*}} // Condition is false branch: - // CHECK: br {{.*}} line:[[@LINE-16]]:3 + // CHECK: br {{.*}}