diff --git a/include/swift/AST/ASTContext.h b/include/swift/AST/ASTContext.h index 46125e348f464..bc9254895dd0d 100644 --- a/include/swift/AST/ASTContext.h +++ b/include/swift/AST/ASTContext.h @@ -522,6 +522,9 @@ class ASTContext final { /// Get Sequence.makeIterator(). FuncDecl *getSequenceMakeIterator() const; + /// Get AsyncSequence.makeAsyncIterator(). + FuncDecl *getAsyncSequenceMakeAsyncIterator() const; + /// Check whether the standard library provides all the correct /// intrinsic support for Optional. /// diff --git a/include/swift/AST/ASTTypeIDZone.def b/include/swift/AST/ASTTypeIDZone.def index f6c5d333af3fb..c076ecf16d8da 100644 --- a/include/swift/AST/ASTTypeIDZone.def +++ b/include/swift/AST/ASTTypeIDZone.def @@ -31,6 +31,8 @@ SWIFT_TYPEID(PropertyWrapperTypeInfo) SWIFT_TYPEID(Requirement) SWIFT_TYPEID(ResilienceExpansion) SWIFT_TYPEID(FragileFunctionKind) +SWIFT_TYPEID(FunctionRethrowingKind) +SWIFT_TYPEID(ProtocolRethrowsRequirementList) SWIFT_TYPEID(TangentPropertyInfo) SWIFT_TYPEID(SymbolSourceMap) SWIFT_TYPEID(Type) diff --git a/include/swift/AST/ASTTypeIDs.h b/include/swift/AST/ASTTypeIDs.h index 6d217514e20b2..f647bad61928c 100644 --- a/include/swift/AST/ASTTypeIDs.h +++ b/include/swift/AST/ASTTypeIDs.h @@ -64,6 +64,8 @@ class ProtocolDecl; class Requirement; enum class ResilienceExpansion : unsigned; struct FragileFunctionKind; +enum class FunctionRethrowingKind : uint8_t; +class ProtocolRethrowsRequirementList; class SourceFile; class SymbolSourceMap; struct TangentPropertyInfo; diff --git a/include/swift/AST/Attr.def b/include/swift/AST/Attr.def index 0166ef3545cc8..a56ef80cb13ec 100644 --- a/include/swift/AST/Attr.def +++ b/include/swift/AST/Attr.def @@ -359,6 +359,10 @@ SIMPLE_DECL_ATTR(rethrows, Rethrows, RejectByParser | ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove, 57) +SIMPLE_DECL_ATTR(rethrows, AtRethrows, + OnProtocol | + ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove, + 58) DECL_ATTR(_swift_native_objc_runtime_base, SwiftNativeObjCRuntimeBase, OnClass | UserInaccessible | diff --git a/include/swift/AST/Decl.h b/include/swift/AST/Decl.h index 54ce64eb5de44..418aabfa2281f 100644 --- a/include/swift/AST/Decl.h +++ b/include/swift/AST/Decl.h @@ -3870,6 +3870,66 @@ enum class KnownDerivableProtocolKind : uint8_t { Actor, }; +class ProtocolRethrowsRequirementList { +public: + typedef std::pair Entry; + +private: + ArrayRef entries; + +public: + ProtocolRethrowsRequirementList(ArrayRef entries) : entries(entries) {} + ProtocolRethrowsRequirementList() : entries() {} + + typedef const Entry *const_iterator; + typedef const_iterator iterator; + + const_iterator begin() const { return entries.begin(); } + const_iterator end() const { return entries.end(); } + + size_t size() const { return entries.size(); } + + void print(raw_ostream &OS) const; + + SWIFT_DEBUG_DUMP; + + friend bool operator==(const ProtocolRethrowsRequirementList &lhs, + const ProtocolRethrowsRequirementList &rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + auto lhsIter = lhs.begin(); + auto rhsIter = rhs.begin(); + while (lhsIter != lhs.end() && rhsIter != rhs.end()) { + if (lhsIter->first->isEqual(rhsIter->first)) { + return false; + } + if (lhsIter->second != rhsIter->second) { + return false; + } + } + return true; + } + + friend bool operator!=(const ProtocolRethrowsRequirementList &lhs, + const ProtocolRethrowsRequirementList &rhs) { + return !(lhs == rhs); + } + + friend llvm::hash_code hash_value( + const ProtocolRethrowsRequirementList &list) { + return llvm::hash_combine(list.size()); // it is good enought for + // llvm::hash_code hash; + // for (auto entry : list) { + // hash = llvm::hash_combine(hash, entry.first->getCanonicalType()); + // hash = llvm::hash_combine(hash, entry.second); + // } + // return hash; + } +}; + +void simple_display(raw_ostream &out, const ProtocolRethrowsRequirementList reqs); + /// ProtocolDecl - A declaration of a protocol, for example: /// /// protocol Drawable { @@ -4051,6 +4111,9 @@ class ProtocolDecl final : public NominalTypeDecl { /// contain 'Self' in 'parameter' or 'other' position. bool existentialTypeSupported() const; + ProtocolRethrowsRequirementList getRethrowingRequirements() const; + bool isRethrowingProtocol() const; + private: void computeKnownProtocolKind() const; @@ -5460,6 +5523,23 @@ class ImportAsMemberStatus { } }; +enum class FunctionRethrowingKind : uint8_t { + /// The function is not throwing + None, + + /// The function rethrows by closure + ByClosure, + + /// The function rethrows by conformance + ByConformance, + + /// The function throws + Throws, + + /// The function throwing determinate is invalid + Invalid +}; + /// Base class for function-like declarations. class AbstractFunctionDecl : public GenericContext, public ValueDecl { friend class NeedsNewVTableEntryRequest; @@ -5663,6 +5743,8 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl { /// Returns true if the function body throws. bool hasThrows() const { return Bits.AbstractFunctionDecl.Throws; } + FunctionRethrowingKind getRethrowingKind() const; + // FIXME: Hack that provides names with keyword arguments for accessors. DeclName getEffectiveFullName() const; diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 5dc8ccb4d52e6..3e0f5ac5824a5 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2981,6 +2981,8 @@ ERROR(override_rethrows_with_non_rethrows,none, "be 'rethrows'", (bool)) ERROR(rethrows_without_throwing_parameter,none, "'rethrows' function must take a throwing function argument", ()) +ERROR(rethrows_attr_on_non_protocol,none, + "@rethrows may only be used on 'protocol' declarations", ()) ERROR(autoclosure_function_type,none, "@autoclosure attribute only applies to function types", @@ -4052,6 +4054,8 @@ NOTE(because_rethrows_argument_throws,none, NOTE(because_rethrows_default_argument_throws,none, "call is to 'rethrows' function, but a defaulted argument function" " can throw", ()) +NOTE(because_rethrows_default_conformance_throws,none, + "call is to 'rethrows' function, but a conformance has a throwing witness", ()) ERROR(throwing_call_in_nonthrowing_autoclosure,none, "call can throw, but it is executed in a non-throwing " diff --git a/include/swift/AST/KnownIdentifiers.def b/include/swift/AST/KnownIdentifiers.def index 4f77e9a98b409..904b47c915214 100644 --- a/include/swift/AST/KnownIdentifiers.def +++ b/include/swift/AST/KnownIdentifiers.def @@ -93,7 +93,9 @@ IDENTIFIER(KeyedEncodingContainer) IDENTIFIER(keyedBy) IDENTIFIER(keyPath) IDENTIFIER(makeIterator) +IDENTIFIER(makeAsyncIterator) IDENTIFIER(Iterator) +IDENTIFIER(AsyncIterator) IDENTIFIER(load) IDENTIFIER(main) IDENTIFIER_WITH_NAME(MainEntryPoint, "$main") diff --git a/include/swift/AST/KnownProtocols.def b/include/swift/AST/KnownProtocols.def index 1d4a8484b726c..af489b2d4a963 100644 --- a/include/swift/AST/KnownProtocols.def +++ b/include/swift/AST/KnownProtocols.def @@ -88,6 +88,9 @@ PROTOCOL(StringInterpolationProtocol) PROTOCOL(AdditiveArithmetic) PROTOCOL(Differentiable) +PROTOCOL(AsyncSequence) +PROTOCOL(AsyncIteratorProtocol) + PROTOCOL(FloatingPoint) EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByArrayLiteral, "Array", false) diff --git a/include/swift/AST/ProtocolConformanceRef.h b/include/swift/AST/ProtocolConformanceRef.h index e155092f58db6..77d67b0267278 100644 --- a/include/swift/AST/ProtocolConformanceRef.h +++ b/include/swift/AST/ProtocolConformanceRef.h @@ -170,8 +170,13 @@ class ProtocolConformanceRef { /// Get any additional requirements that are required for this conformance to /// be satisfied. ArrayRef getConditionalRequirements() const; + + bool classifyAsThrows() const; }; +void simple_display(llvm::raw_ostream &out, ProtocolConformanceRef conformanceRef); +SourceLoc extractNearestSourceLoc(const ProtocolConformanceRef conformanceRef); + } // end namespace swift #endif // LLVM_SWIFT_AST_PROTOCOLCONFORMANCEREF_H diff --git a/include/swift/AST/Stmt.h b/include/swift/AST/Stmt.h index 63907e2820ab4..029dadb177989 100644 --- a/include/swift/AST/Stmt.h +++ b/include/swift/AST/Stmt.h @@ -726,6 +726,8 @@ class RepeatWhileStmt : public LabeledStmt { /// \endcode class ForEachStmt : public LabeledStmt { SourceLoc ForLoc; + SourceLoc TryLoc; + SourceLoc AwaitLoc; Pattern *Pat; SourceLoc InLoc; Expr *Sequence; @@ -741,12 +743,12 @@ class ForEachStmt : public LabeledStmt { Expr *convertElementExpr = nullptr; public: - ForEachStmt(LabeledStmtInfo LabelInfo, SourceLoc ForLoc, Pattern *Pat, - SourceLoc InLoc, Expr *Sequence, SourceLoc WhereLoc, + ForEachStmt(LabeledStmtInfo LabelInfo, SourceLoc ForLoc, SourceLoc TryLoc, SourceLoc AwaitLoc, + Pattern *Pat, SourceLoc InLoc, Expr *Sequence, SourceLoc WhereLoc, Expr *WhereExpr, BraceStmt *Body, Optional implicit = None) : LabeledStmt(StmtKind::ForEach, getDefaultImplicitFlag(implicit, ForLoc), LabelInfo), - ForLoc(ForLoc), Pat(nullptr), InLoc(InLoc), Sequence(Sequence), + ForLoc(ForLoc), TryLoc(TryLoc), AwaitLoc(AwaitLoc), Pat(nullptr), InLoc(InLoc), Sequence(Sequence), WhereLoc(WhereLoc), WhereExpr(WhereExpr), Body(Body) { setPattern(Pat); } @@ -778,6 +780,9 @@ class ForEachStmt : public LabeledStmt { /// getWhereLoc - Retrieve the location of the 'where' keyword. SourceLoc getWhereLoc() const { return WhereLoc; } + + SourceLoc getAwaitLoc() const { return AwaitLoc; } + SourceLoc getTryLoc() const { return TryLoc; } /// getPattern - Retrieve the pattern describing the iteration variables. /// These variables will only be visible within the body of the loop. diff --git a/include/swift/AST/TypeCheckRequests.h b/include/swift/AST/TypeCheckRequests.h index 2e0f6c0696817..41bf7038006fc 100644 --- a/include/swift/AST/TypeCheckRequests.h +++ b/include/swift/AST/TypeCheckRequests.h @@ -311,6 +311,44 @@ class ExistentialTypeSupportedRequest : void cacheResult(bool value) const; }; +class ProtocolRethrowsRequirementsRequest : + public SimpleRequest { +public: + using SimpleRequest::SimpleRequest; + +private: + friend SimpleRequest; + + // Evaluation. + ProtocolRethrowsRequirementList + evaluate(Evaluator &evaluator, ProtocolDecl *decl) const; + +public: + // Caching. + bool isCached() const { return true; } +}; + +class ProtocolConformanceRefClassifyAsThrowsRequest : + public SimpleRequest { +public: + using SimpleRequest::SimpleRequest; + +private: + friend SimpleRequest; + + // Evaluation. + bool + evaluate(Evaluator &evaluator, ProtocolConformanceRef conformanceRef) const; + +public: + // Caching. + bool isCached() const { return true; } +}; + /// Determine whether the given declaration is 'final'. class IsFinalRequest : public SimpleRequest { +public: + using SimpleRequest::SimpleRequest; + +private: + friend SimpleRequest; + + // Evaluation. + FunctionRethrowingKind evaluate(Evaluator &evaluator, AbstractFunctionDecl *decl) const; + +public: + // Caching. + bool isCached() const { return true; } +}; + +void simple_display(llvm::raw_ostream &out, FunctionRethrowingKind value); + /// Request the custom attribute which attaches a result builder to the /// given declaration. class AttachedResultBuilderRequest : diff --git a/include/swift/AST/TypeCheckerTypeIDZone.def b/include/swift/AST/TypeCheckerTypeIDZone.def index 4c24a181a1888..ae3cee5930803 100644 --- a/include/swift/AST/TypeCheckerTypeIDZone.def +++ b/include/swift/AST/TypeCheckerTypeIDZone.def @@ -206,6 +206,8 @@ SWIFT_REQUEST(TypeChecker, RequiresOpaqueModifyCoroutineRequest, bool(AbstractStorageDecl *), SeparatelyCached, NoLocationInfo) SWIFT_REQUEST(TypeChecker, FragileFunctionKindRequest, FragileFunctionKind(DeclContext *), Cached, NoLocationInfo) +SWIFT_REQUEST(TypeChecker, FunctionRethrowingKindRequest, + FunctionRethrowingKind(AbstractFunctionDecl *), Cached, NoLocationInfo) SWIFT_REQUEST(TypeChecker, SelfAccessKindRequest, SelfAccessKind(FuncDecl *), SeparatelyCached, NoLocationInfo) SWIFT_REQUEST(TypeChecker, StorageImplInfoRequest, @@ -260,6 +262,12 @@ SWIFT_REQUEST(TypeChecker, ResolveImplicitMemberRequest, SWIFT_REQUEST(TypeChecker, ResolveTypeEraserTypeRequest, Type(ProtocolDecl *, TypeEraserAttr *), SeparatelyCached, NoLocationInfo) +SWIFT_REQUEST(TypeChecker, ProtocolRethrowsRequirementsRequest, + ProtocolRethrowsRequirementList(ProtocolDecl *), + Cached, NoLocationInfo) +SWIFT_REQUEST(TypeChecker, ProtocolConformanceRefClassifyAsThrowsRequest, + bool(ProtocolConformanceRef), + Cached, NoLocationInfo) SWIFT_REQUEST(TypeChecker, ResolveTypeRequest, Type (const TypeResolution *, TypeRepr *, GenericParamList *), Uncached, NoLocationInfo) diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index d07f1748654b2..4e7601e0e3bc1 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -206,6 +206,9 @@ struct ASTContext::Implementation { /// The declaration of 'Sequence.makeIterator()'. FuncDecl *MakeIterator = nullptr; + /// The declaration of 'AsyncSequence.makeAsyncIterator()'. + FuncDecl *MakeAsyncIterator = nullptr; + /// The declaration of Swift.Optional.Some. EnumElementDecl *OptionalSomeDecl = nullptr; @@ -772,6 +775,31 @@ FuncDecl *ASTContext::getSequenceMakeIterator() const { return nullptr; } +FuncDecl *ASTContext::getAsyncSequenceMakeAsyncIterator() const { + if (getImpl().MakeAsyncIterator) { + return getImpl().MakeAsyncIterator; + } + + auto proto = getProtocol(KnownProtocolKind::AsyncSequence); + if (!proto) + return nullptr; + + for (auto result : proto->lookupDirect(Id_makeAsyncIterator)) { + if (result->getDeclContext() != proto) + continue; + + if (auto func = dyn_cast(result)) { + if (func->getParameters()->size() != 0) + continue; + + getImpl().MakeAsyncIterator = func; + return func; + } + } + + return nullptr; +} + #define KNOWN_STDLIB_TYPE_DECL(NAME, DECL_CLASS, NUM_GENERIC_PARAMS) \ DECL_CLASS *ASTContext::get##NAME##Decl() const { \ if (getImpl().NAME##Decl) \ @@ -943,6 +971,8 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const { M = getLoadedModule(Id_Differentiation); break; case KnownProtocolKind::Actor: + case KnownProtocolKind::AsyncSequence: + case KnownProtocolKind::AsyncIteratorProtocol: M = getLoadedModule(Id_Concurrency); break; default: diff --git a/lib/AST/ASTVerifier.cpp b/lib/AST/ASTVerifier.cpp index 200a485eec0e4..c9c8f74ae5253 100644 --- a/lib/AST/ASTVerifier.cpp +++ b/lib/AST/ASTVerifier.cpp @@ -1803,11 +1803,25 @@ class Verifier : public ASTWalker { Out << "\n"; abort(); } else if (E->throws() && !FT->isThrowing()) { - Out << "apply expression is marked as throwing, but function operand" - "does not have a throwing function type\n"; - E->dump(Out); - Out << "\n"; - abort(); + FunctionRethrowingKind rethrowingKind = FunctionRethrowingKind::Invalid; + if (auto DRE = dyn_cast(E->getFn())) { + if (auto fnDecl = dyn_cast(DRE->getDecl())) { + rethrowingKind = fnDecl->getRethrowingKind(); + } + } else if (auto OCDRE = dyn_cast(E->getFn())) { + if (auto fnDecl = dyn_cast(OCDRE->getDecl())) { + rethrowingKind = fnDecl->getRethrowingKind(); + } + } + + if (rethrowingKind != FunctionRethrowingKind::ByConformance && + rethrowingKind != FunctionRethrowingKind::Throws) { + Out << "apply expression is marked as throwing, but function operand" + "does not have a throwing function type\n"; + E->dump(Out); + Out << "\n"; + abort(); + } } if (E->isSuper() != E->getArg()->isSuperExpr()) { diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index 448702258dab5..1b95a4af2433f 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -4929,6 +4929,26 @@ bool ProtocolDecl::existentialTypeSupported() const { ExistentialTypeSupportedRequest{const_cast(this)}, true); } +void swift::simple_display(llvm::raw_ostream &out, const ProtocolRethrowsRequirementList list) { + for (auto entry : list) { + simple_display(out, entry.first); + simple_display(out, entry.second); + } +} + + +ProtocolRethrowsRequirementList +ProtocolDecl::getRethrowingRequirements() const { + return evaluateOrDefault(getASTContext().evaluator, + ProtocolRethrowsRequirementsRequest{const_cast(this)}, + ProtocolRethrowsRequirementList()); +} + +bool +ProtocolDecl::isRethrowingProtocol() const { + return getRethrowingRequirements().size() > 0; +} + StringRef ProtocolDecl::getObjCRuntimeName( llvm::SmallVectorImpl &buffer) const { // If there is an 'objc' attribute with a name, use that name. @@ -6769,6 +6789,12 @@ bool AbstractFunctionDecl::canBeAsyncHandler() const { false); } +FunctionRethrowingKind AbstractFunctionDecl::getRethrowingKind() const { + return evaluateOrDefault(getASTContext().evaluator, + FunctionRethrowingKindRequest{const_cast(this)}, + FunctionRethrowingKind::Invalid); +} + BraceStmt *AbstractFunctionDecl::getBody(bool canSynthesize) const { if ((getBodyKind() == BodyKind::Synthesize || getBodyKind() == BodyKind::Unparsed) && diff --git a/lib/AST/ProtocolConformance.cpp b/lib/AST/ProtocolConformance.cpp index 62b4791e29c88..e80598602fb93 100644 --- a/lib/AST/ProtocolConformance.cpp +++ b/lib/AST/ProtocolConformance.cpp @@ -26,6 +26,7 @@ #include "swift/AST/TypeCheckRequests.h" #include "swift/AST/TypeWalker.h" #include "swift/AST/Types.h" +#include "swift/AST/TypeCheckRequests.h" #include "swift/Basic/Statistic.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/Statistic.h" @@ -185,6 +186,79 @@ ProtocolConformanceRef::getWitnessByName(Type type, DeclName name) const { return getConcrete()->getWitnessDeclRef(requirement); } + +static bool classifyRequirement(ModuleDecl *module, + ProtocolConformance *reqConformance, + ValueDecl *requiredFn) { + auto declRef = reqConformance->getWitnessDeclRef(requiredFn); + auto witnessDecl = cast(declRef.getDecl()); + switch (witnessDecl->getRethrowingKind()) { + case FunctionRethrowingKind::ByConformance: { + auto substitutions = reqConformance->getSubstitutions(module); + for (auto conformanceRef : substitutions.getConformances()) { + if (conformanceRef.classifyAsThrows()) { + return true; + } + } + break; + } + case FunctionRethrowingKind::None: + break; + case FunctionRethrowingKind::Throws: + return true; + default: + return true; + } + return false; +} + +// classify the type requirements of a given prottocol type with a function +// requirement as throws or not. This will detect if the signature of the +// function is throwing or not depending on associated types. +static bool classifyTypeRequirement(ModuleDecl *module, Type protoType, + ValueDecl *requiredFn, + ProtocolConformance *conformance, + ProtocolDecl *requiredProtocol) { + auto reqProtocol = cast(requiredFn->getDeclContext()); + ProtocolConformance *reqConformance; + + if(protoType->isEqual(reqProtocol->getSelfInterfaceType()) && + requiredProtocol == reqProtocol) { + reqConformance = conformance; + } else { + auto reqConformanceRef = + conformance->getAssociatedConformance(protoType, reqProtocol); + if (!reqConformanceRef.isConcrete()) { + return true; + } + reqConformance = reqConformanceRef.getConcrete(); + } + + return classifyRequirement(module, reqConformance, requiredFn); +} + +bool +ProtocolConformanceRefClassifyAsThrowsRequest::evaluate( + Evaluator &evaluator, ProtocolConformanceRef conformanceRef) const { + auto conformance = conformanceRef.getConcrete(); + auto requiredProtocol = conformanceRef.getRequirement(); + auto module = requiredProtocol->getModuleContext(); + for (auto req : requiredProtocol->getRethrowingRequirements()) { + if (classifyTypeRequirement(module, req.first, req.second, + conformance, requiredProtocol)) { + return true; + } + } + return false; +} + +bool ProtocolConformanceRef::classifyAsThrows() const { + if (!isConcrete()) { return true; } + return evaluateOrDefault(getRequirement()->getASTContext().evaluator, + ProtocolConformanceRefClassifyAsThrowsRequest{ *this }, + true); +} + void *ProtocolConformance::operator new(size_t bytes, ASTContext &context, AllocationArena arena, unsigned alignment) { @@ -1529,3 +1603,20 @@ void swift::simple_display(llvm::raw_ostream &out, const ProtocolConformance *conf) { conf->printName(out); } + +void swift::simple_display(llvm::raw_ostream &out, ProtocolConformanceRef conformanceRef) { + if (conformanceRef.isAbstract()) { + simple_display(out, conformanceRef.getAbstract()); + } else if (conformanceRef.isConcrete()) { + simple_display(out, conformanceRef.getConcrete()); + } +} + +SourceLoc swift::extractNearestSourceLoc(const ProtocolConformanceRef conformanceRef) { + if (conformanceRef.isAbstract()) { + return extractNearestSourceLoc(conformanceRef.getAbstract()); + } else if (conformanceRef.isConcrete()) { + return extractNearestSourceLoc(conformanceRef.getConcrete()->getProtocol()); + } + return SourceLoc(); +} diff --git a/lib/AST/SubstitutionMap.cpp b/lib/AST/SubstitutionMap.cpp index b9352b4314852..1d5fb19b20480 100644 --- a/lib/AST/SubstitutionMap.cpp +++ b/lib/AST/SubstitutionMap.cpp @@ -30,6 +30,7 @@ #include "swift/AST/LazyResolver.h" #include "swift/AST/Module.h" #include "swift/AST/ProtocolConformance.h" +#include "swift/AST/TypeCheckRequests.h" #include "swift/AST/Types.h" #include "swift/Basic/Defer.h" #include "llvm/Support/Debug.h" diff --git a/lib/AST/TypeCheckRequests.cpp b/lib/AST/TypeCheckRequests.cpp index 5cd9930133df2..a9d2414efbdc2 100644 --- a/lib/AST/TypeCheckRequests.cpp +++ b/lib/AST/TypeCheckRequests.cpp @@ -278,6 +278,31 @@ void ExistentialTypeSupportedRequest::cacheResult(bool value) const { decl->setCachedExistentialTypeSupported(value); } +//----------------------------------------------------------------------------// +// getRethrowingKind computation. +//----------------------------------------------------------------------------// + +void swift::simple_display(llvm::raw_ostream &out, + FunctionRethrowingKind kind) { + switch (kind) { + case FunctionRethrowingKind::None: + out << "non-throwing"; + break; + case FunctionRethrowingKind::ByClosure: + out << "by closure"; + break; + case FunctionRethrowingKind::ByConformance: + out << "by conformance"; + break; + case FunctionRethrowingKind::Throws: + out << "throws"; + break; + case FunctionRethrowingKind::Invalid: + out << "invalid"; + break; + } +} + //----------------------------------------------------------------------------// // isFinal computation. //----------------------------------------------------------------------------// diff --git a/lib/IRGen/GenMeta.cpp b/lib/IRGen/GenMeta.cpp index 2fcdf3c72cc08..4fc6a50a8a4cc 100644 --- a/lib/IRGen/GenMeta.cpp +++ b/lib/IRGen/GenMeta.cpp @@ -5080,7 +5080,9 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) { // The other known protocols aren't special at runtime. case KnownProtocolKind::Sequence: + case KnownProtocolKind::AsyncSequence: case KnownProtocolKind::IteratorProtocol: + case KnownProtocolKind::AsyncIteratorProtocol: case KnownProtocolKind::RawRepresentable: case KnownProtocolKind::Equatable: case KnownProtocolKind::Hashable: diff --git a/lib/IRGen/IRGenSIL.cpp b/lib/IRGen/IRGenSIL.cpp index 31bef947ca330..d229224b6d004 100644 --- a/lib/IRGen/IRGenSIL.cpp +++ b/lib/IRGen/IRGenSIL.cpp @@ -4360,7 +4360,7 @@ void IRGenSILFunction::visitDebugValueAddrInst(DebugValueAddrInst *i) { break; } assert(llvm::isa(Storage) && - "arg expected to be load from inside %swift.context"); + "arg expected to be load from inside swift.context"); #endif Indirection = CoroIndirectValue; } diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp index cba6397192e29..78798edf4c7a0 100644 --- a/lib/Parse/ParseDecl.cpp +++ b/lib/Parse/ParseDecl.cpp @@ -2691,7 +2691,8 @@ ParserStatus Parser::parseDeclAttribute( // If this not an identifier, the attribute is malformed. if (Tok.isNot(tok::identifier) && Tok.isNot(tok::kw_in) && - Tok.isNot(tok::kw_inout)) { + Tok.isNot(tok::kw_inout) && + Tok.isNot(tok::kw_rethrows)) { if (Tok.is(tok::code_complete)) { if (CodeCompletion) { @@ -2712,7 +2713,7 @@ ParserStatus Parser::parseDeclAttribute( // If the attribute follows the new representation, switch // over to the alternate parsing path. DeclAttrKind DK = DeclAttribute::getAttrKindFromString(Tok.getText()); - + if (DK == DAK_Rethrows) { DK = DAK_AtRethrows; } auto checkInvalidAttrName = [&](StringRef invalidName, StringRef correctName, DeclAttrKind kind, @@ -3637,7 +3638,8 @@ static void skipAttribute(Parser &P) { // Parse the attribute name, which can be qualified, have // generic arguments, and so on. do { - if (!P.consumeIf(tok::identifier) && !P.consumeIf(tok::code_complete)) + if (!(P.consumeIf(tok::identifier) || P.consumeIf(tok::kw_rethrows)) && + !P.consumeIf(tok::code_complete)) return; if (P.startsWithLess(P.Tok)) { @@ -3656,8 +3658,16 @@ static void skipAttribute(Parser &P) { } bool Parser::isStartOfSwiftDecl() { - // If this is obviously not the start of a decl, then we're done. - if (!isKeywordPossibleDeclStart(Tok)) return false; + if (Tok.is(tok::at_sign) && peekToken().is(tok::kw_rethrows)) { + // @rethrows does not follow the general rule of @ so + // it is needed to short circuit this else there will be an infinite + // loop on invalid attributes of just rethrows + } else if (!isKeywordPossibleDeclStart(Tok)) { + // If this is obviously not the start of a decl, then we're done. + return false; + } + + // When 'init' appears inside another 'init', it's likely the user wants to // invoke an initializer but forgets to prefix it with 'self.' or 'super.' diff --git a/lib/Parse/ParseStmt.cpp b/lib/Parse/ParseStmt.cpp index b0b16fdab34b7..36414a83fedfa 100644 --- a/lib/Parse/ParseStmt.cpp +++ b/lib/Parse/ParseStmt.cpp @@ -2124,6 +2124,19 @@ ParserResult Parser::parseStmtForEach(LabeledStmtInfo LabelInfo) { // lookahead to resolve what is going on. bool IsCStyleFor = isStmtForCStyle(*this); auto StartOfControl = Tok.getLoc(); + SourceLoc AwaitLoc; + SourceLoc TryLoc; + + if (shouldParseExperimentalConcurrency() && + Tok.isContextualKeyword("await")) { + AwaitLoc = consumeToken(); + } if (shouldParseExperimentalConcurrency() && + Tok.is(tok::kw_try)) { + TryLoc = consumeToken(); + if (Tok.isContextualKeyword("await")) { + AwaitLoc = consumeToken(); + } + } // Parse the pattern. This is either 'case ' or just a // normal pattern. @@ -2218,7 +2231,7 @@ ParserResult Parser::parseStmtForEach(LabeledStmtInfo LabelInfo) { return makeParserResult( Status, - new (Context) ForEachStmt(LabelInfo, ForLoc, pattern.get(), InLoc, + new (Context) ForEachStmt(LabelInfo, ForLoc, TryLoc, AwaitLoc, pattern.get(), InLoc, Container.get(), WhereLoc, Where.getPtrOrNull(), Body.get())); } diff --git a/lib/SILGen/SILGenApply.cpp b/lib/SILGen/SILGenApply.cpp index 12477f287b1a8..552b969ff76e8 100644 --- a/lib/SILGen/SILGenApply.cpp +++ b/lib/SILGen/SILGenApply.cpp @@ -4969,7 +4969,21 @@ RValue SILGenFunction::emitApplyMethod(SILLocation loc, ConcreteDeclRef declRef, .asForeign(requiresForeignEntryPoint(declRef.getDecl())); auto declRefConstant = getConstantInfo(getTypeExpansionContext(), callRef); auto subs = declRef.getSubstitutions(); - + bool throws = false; + bool markedAsRethrows = call->getAttrs().hasAttribute(); + FunctionRethrowingKind rethrowingKind = call->getRethrowingKind(); + if (rethrowingKind == FunctionRethrowingKind::ByConformance) { + for (auto conformanceRef : subs.getConformances()) { + if (conformanceRef.classifyAsThrows()) { + throws = true; + break; + } + } + } else if (markedAsRethrows && + rethrowingKind == FunctionRethrowingKind::Throws) { + throws = true; + } + // Scope any further writeback just within this operation. FormalEvaluationScope writebackScope(*this); @@ -5002,7 +5016,7 @@ RValue SILGenFunction::emitApplyMethod(SILLocation loc, ConcreteDeclRef declRef, // Form the call emission. CallEmission emission(*this, std::move(*callee), std::move(writebackScope)); emission.addSelfParam(loc, std::move(self), substFormalType.getParams()[0]); - emission.addCallSite(loc, std::move(args), /*throws*/ false); + emission.addCallSite(loc, std::move(args), throws); return emission.apply(C); } diff --git a/lib/SILGen/SILGenStmt.cpp b/lib/SILGen/SILGenStmt.cpp index 93a8cf34880ff..36e2088c70814 100644 --- a/lib/SILGen/SILGenStmt.cpp +++ b/lib/SILGen/SILGenStmt.cpp @@ -181,6 +181,8 @@ namespace { #define STMT(ID, BASE) void visit##ID##Stmt(ID##Stmt *S); #include "swift/AST/StmtNodes.def" + void visitAsyncForEachStmt(ForEachStmt *S); + ASTContext &getASTContext() { return SGF.getASTContext(); } SILBasicBlock *createBasicBlock() { return SGF.createBasicBlock(); } @@ -932,10 +934,241 @@ void StmtEmitter::visitRepeatWhileStmt(RepeatWhileStmt *S) { SGF.BreakContinueDestStack.pop_back(); } +void StmtEmitter::visitAsyncForEachStmt(ForEachStmt *S) { + + // Dig out information about the sequence conformance. + auto sequenceConformance = S->getSequenceConformance(); + Type sequenceType = S->getSequence()->getType(); + + auto asyncSequenceProto = + SGF.getASTContext().getProtocol(KnownProtocolKind::AsyncSequence); + auto sequenceSubs = SubstitutionMap::getProtocolSubstitutions( + asyncSequenceProto, sequenceType, sequenceConformance); + + // Emit the 'generator' variable that we'll be using for iteration. + LexicalScope OuterForScope(SGF, CleanupLocation(S)); + { + auto initialization = + SGF.emitInitializationForVarDecl(S->getIteratorVar(), false); + SILLocation loc = SILLocation(S->getSequence()); + + // Compute the reference to the AsyncSequence's makeAsyncSequence(). + FuncDecl *makeGeneratorReq = + SGF.getASTContext().getAsyncSequenceMakeAsyncIterator(); + ConcreteDeclRef makeGeneratorRef(makeGeneratorReq, sequenceSubs); + + // Call makeAsyncSequence(). + RValue result = SGF.emitApplyMethod( + loc, makeGeneratorRef, ArgumentSource(S->getSequence()), + PreparedArguments(ArrayRef({})), + SGFContext(initialization.get())); + if (!result.isInContext()) { + ArgumentSource(SILLocation(S->getSequence()), + std::move(result).ensurePlusOne(SGF, loc)) + .forwardInto(SGF, initialization.get()); + } + } + + // If we ever reach an unreachable point, stop emitting statements. + // This will need revision if we ever add goto. + if (!SGF.B.hasValidInsertionPoint()) return; + + // If generator's optional result is address-only, create a stack allocation + // to hold the results. This will be initialized on every entry into the loop + // header and consumed by the loop body. On loop exit, the terminating value + // will be in the buffer. + CanType optTy; + if (S->getConvertElementExpr()) { + optTy = S->getConvertElementExpr()->getType()->getCanonicalType(); + } else { + optTy = OptionalType::get(S->getSequenceConformance().getTypeWitnessByName( + S->getSequence()->getType(), + SGF.getASTContext().Id_Element)) + ->getCanonicalType(); + } + auto &optTL = SGF.getTypeLowering(optTy); + SILValue addrOnlyBuf; + ManagedValue nextBufOrValue; + + if (optTL.isAddressOnly() && SGF.silConv.useLoweredAddresses()) + addrOnlyBuf = SGF.emitTemporaryAllocation(S, optTL.getLoweredType()); + + // Create a new basic block and jump into it. + JumpDest loopDest = createJumpDest(S->getBody()); + SGF.B.emitBlock(loopDest.getBlock(), S); + + // Compute the reference to the the generator's next() && cancel(). + auto generatorProto = + SGF.getASTContext().getProtocol(KnownProtocolKind::AsyncIteratorProtocol); + ValueDecl *generatorNextReq = generatorProto->getSingleRequirement( + DeclName(SGF.getASTContext(), SGF.getASTContext().Id_next, + ArrayRef())); + auto generatorAssocType = + asyncSequenceProto->getAssociatedType(SGF.getASTContext().Id_AsyncIterator); + auto generatorMemberRef = DependentMemberType::get( + asyncSequenceProto->getSelfInterfaceType(), generatorAssocType); + auto generatorType = sequenceConformance.getAssociatedType( + sequenceType, generatorMemberRef); + auto generatorConformance = sequenceConformance.getAssociatedConformance( + sequenceType, generatorMemberRef, generatorProto); + auto generatorSubs = SubstitutionMap::getProtocolSubstitutions( + generatorProto, generatorType, generatorConformance); + ConcreteDeclRef generatorNextRef(generatorNextReq, generatorSubs); + + // Set the destinations for 'break' and 'continue'. + JumpDest endDest = createJumpDest(S->getBody()); + SGF.BreakContinueDestStack.push_back({ S, endDest, loopDest }); + + + auto buildArgumentSource = [&]() { + if (cast(generatorNextRef.getDecl())->getSelfAccessKind() == + SelfAccessKind::Mutating) { + LValue lv = + SGF.emitLValue(S->getIteratorVarRef(), SGFAccessKind::ReadWrite); + return ArgumentSource(S, std::move(lv)); + } + LValue lv = + SGF.emitLValue(S->getIteratorVarRef(), SGFAccessKind::OwnedObjectRead); + return ArgumentSource( + S, SGF.emitLoadOfLValue(S->getIteratorVarRef(), std::move(lv), + SGFContext().withFollowingSideEffects())); + }; + + auto buildElementRValue = [&](SILLocation loc, SGFContext ctx) { + RValue result; + result = SGF.emitApplyMethod( + loc, generatorNextRef, buildArgumentSource(), + PreparedArguments(ArrayRef({})), + S->getElementExpr() ? SGFContext() : ctx); + if (S->getElementExpr()) { + SILGenFunction::OpaqueValueRAII pushOpaqueValue( + SGF, S->getElementExpr(), + std::move(result).getAsSingleValue(SGF, loc)); + result = SGF.emitRValue(S->getConvertElementExpr(), ctx); + } + return result; + }; + + // Then emit the loop destination block. + // + // Advance the generator. Use a scope to ensure that any temporary stack + // allocations in the subexpression are immediately released. + if (optTL.isAddressOnly() && SGF.silConv.useLoweredAddresses()) { + // Create the initialization outside of the innerForScope so that the + // innerForScope doesn't clean it up. + auto nextInit = SGF.useBufferAsTemporary(addrOnlyBuf, optTL); + { + ArgumentScope innerForScope(SGF, SILLocation(S)); + SILLocation loc = SILLocation(S); + RValue result = buildElementRValue(loc, SGFContext(nextInit.get())); + if (!result.isInContext()) { + ArgumentSource(SILLocation(S->getSequence()), + std::move(result).ensurePlusOne(SGF, loc)) + .forwardInto(SGF, nextInit.get()); + } + innerForScope.pop(); + } + nextBufOrValue = nextInit->getManagedAddress(); + } else { + ArgumentScope innerForScope(SGF, SILLocation(S)); + nextBufOrValue = innerForScope.popPreservingValue( + buildElementRValue(SILLocation(S), SGFContext()) + .getAsSingleValue(SGF, SILLocation(S))); + } + + SILBasicBlock *failExitingBlock = createBasicBlock(); + SwitchEnumBuilder switchEnumBuilder(SGF.B, S, nextBufOrValue); + + switchEnumBuilder.addOptionalSomeCase( + createBasicBlock(), loopDest.getBlock(), + [&](ManagedValue inputValue, SwitchCaseFullExpr &&scope) { + SGF.emitProfilerIncrement(S->getBody()); + + // Emit the loop body. + // The declared variable(s) for the current element are destroyed + // at the end of each loop iteration. + { + Scope innerForScope(SGF.Cleanups, CleanupLocation(S->getBody())); + // Emit the initialization for the pattern. If any of the bound + // patterns + // fail (because this is a 'for case' pattern with a refutable + // pattern, + // the code should jump to the continue block. + InitializationPtr initLoopVars = + SGF.emitPatternBindingInitialization(S->getPattern(), loopDest); + + // If we had a loadable "next" generator value, we know it is present. + // Get the value out of the optional, and wrap it up with a cleanup so + // that any exits out of this scope properly clean it up. + // + // *NOTE* If we do not have an address only value, then inputValue is + // *already properly unwrapped. + if (optTL.isAddressOnly() && SGF.silConv.useLoweredAddresses()) { + inputValue = SGF.emitUncheckedGetOptionalValueFrom( + S, inputValue, optTL, SGFContext(initLoopVars.get())); + } + + if (!inputValue.isInContext()) + RValue(SGF, S, optTy.getOptionalObjectType(), inputValue) + .forwardInto(SGF, S, initLoopVars.get()); + + // Now that the pattern has been initialized, check any where + // condition. + // If it fails, loop around as if 'continue' happened. + if (auto *Where = S->getWhere()) { + auto cond = SGF.emitCondition(Where, /*invert*/ true); + // If self is null, branch to the epilog. + cond.enterTrue(SGF); + SGF.Cleanups.emitBranchAndCleanups(loopDest, Where, {}); + cond.exitTrue(SGF); + cond.complete(SGF); + } + + visit(S->getBody()); + } + + // If we emitted an unreachable in the body, we will not have a valid + // insertion point. Just return early. + if (!SGF.B.hasValidInsertionPoint()) { + scope.unreachableExit(); + return; + } + + // Otherwise, associate the loop body's closing brace with this branch. + RegularLocation L(S->getBody()); + L.pointToEnd(); + scope.exitAndBranch(L); + }, + SGF.loadProfilerCount(S->getBody())); + + // We add loop fail block, just to be defensive about intermediate + // transformations performing cleanups at scope.exit(). We still jump to the + // contBlock. + switchEnumBuilder.addOptionalNoneCase( + createBasicBlock(), failExitingBlock, + [&](ManagedValue inputValue, SwitchCaseFullExpr &&scope) { + assert(!inputValue && "None should not be passed an argument!"); + scope.exitAndBranch(S); + }, + SGF.loadProfilerCount(S)); + + std::move(switchEnumBuilder).emit(); + + SGF.B.emitBlock(failExitingBlock); + emitOrDeleteBlock(SGF, endDest, S); + SGF.BreakContinueDestStack.pop_back(); +} + void StmtEmitter::visitForEachStmt(ForEachStmt *S) { + if (S->getAwaitLoc().isValid()) { + visitAsyncForEachStmt(S); + return; + } + // Dig out information about the sequence conformance. auto sequenceConformance = S->getSequenceConformance(); Type sequenceType = S->getSequence()->getType(); + auto sequenceProto = SGF.getASTContext().getProtocol(KnownProtocolKind::Sequence); auto sequenceSubs = SubstitutionMap::getProtocolSubstitutions( diff --git a/lib/Sema/BuilderTransform.cpp b/lib/Sema/BuilderTransform.cpp index f833a705e452d..f461ff9105178 100644 --- a/lib/Sema/BuilderTransform.cpp +++ b/lib/Sema/BuilderTransform.cpp @@ -774,7 +774,8 @@ class BuilderClosureVisitor // take care of this. auto sequenceProto = TypeChecker::getProtocol( dc->getASTContext(), forEachStmt->getForLoc(), - KnownProtocolKind::Sequence); + forEachStmt->getAwaitLoc().isValid() ? + KnownProtocolKind::AsyncSequence : KnownProtocolKind::Sequence); if (!sequenceProto) { if (!unhandledNode) unhandledNode = forEachStmt; diff --git a/lib/Sema/CSApply.cpp b/lib/Sema/CSApply.cpp index 4581ecf8e39b7..655205f7fecb7 100644 --- a/lib/Sema/CSApply.cpp +++ b/lib/Sema/CSApply.cpp @@ -7966,7 +7966,9 @@ static Optional applySolutionToForEachStmt( // Get the conformance of the sequence type to the Sequence protocol. auto stmt = forEachStmtInfo.stmt; auto sequenceProto = TypeChecker::getProtocol( - cs.getASTContext(), stmt->getForLoc(), KnownProtocolKind::Sequence); + cs.getASTContext(), stmt->getForLoc(), + stmt->getAwaitLoc().isValid() ? + KnownProtocolKind::AsyncSequence : KnownProtocolKind::Sequence); auto sequenceConformance = TypeChecker::conformsToProtocol( forEachStmtInfo.sequenceType, sequenceProto, cs.DC); assert(!sequenceConformance.isInvalid() && diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index f4aeb8cd3849a..95de559e37a82 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -3655,6 +3655,7 @@ generateForEachStmtConstraints( ConstraintSystem &cs, SolutionApplicationTarget target, Expr *sequence) { auto forEachStmtInfo = target.getForEachStmtInfo(); ForEachStmt *stmt = forEachStmtInfo.stmt; + bool isAsync = stmt->getAwaitLoc().isValid(); auto locator = cs.getConstraintLocator(sequence); auto contextualLocator = @@ -3662,7 +3663,9 @@ generateForEachStmtConstraints( // The expression type must conform to the Sequence protocol. auto sequenceProto = TypeChecker::getProtocol( - cs.getASTContext(), stmt->getForLoc(), KnownProtocolKind::Sequence); + cs.getASTContext(), stmt->getForLoc(), + isAsync ? + KnownProtocolKind::AsyncSequence : KnownProtocolKind::Sequence); if (!sequenceProto) { return None; } @@ -3708,18 +3711,22 @@ generateForEachStmtConstraints( // Determine the iterator type. auto iteratorAssocType = - sequenceProto->getAssociatedType(cs.getASTContext().Id_Iterator); + sequenceProto->getAssociatedType(isAsync ? + cs.getASTContext().Id_AsyncIterator : cs.getASTContext().Id_Iterator); Type iteratorType = DependentMemberType::get(sequenceType, iteratorAssocType); // The iterator type must conform to IteratorProtocol. ProtocolDecl *iteratorProto = TypeChecker::getProtocol( cs.getASTContext(), stmt->getForLoc(), - KnownProtocolKind::IteratorProtocol); + isAsync ? + KnownProtocolKind::AsyncIteratorProtocol : KnownProtocolKind::IteratorProtocol); if (!iteratorProto) return None; // Reference the makeIterator witness. - FuncDecl *makeIterator = ctx.getSequenceMakeIterator(); + FuncDecl *makeIterator = isAsync ? + ctx.getAsyncSequenceMakeAsyncIterator() : ctx.getSequenceMakeIterator(); + Type makeIteratorType = cs.createTypeVariable(locator, TVO_CanBindToNoEscape); cs.addValueWitnessConstraint( diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 2b87eed617852..2116896825f91 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -215,6 +215,7 @@ class AttributeChecker : public AttributeVisitor { void visitNSCopyingAttr(NSCopyingAttr *attr); void visitRequiredAttr(RequiredAttr *attr); void visitRethrowsAttr(RethrowsAttr *attr); + void visitAtRethrowsAttr(AtRethrowsAttr *attr); void checkApplicationMainAttribute(DeclAttribute *attr, Identifier Id_ApplicationDelegate, @@ -2134,44 +2135,27 @@ void AttributeChecker::visitRequiredAttr(RequiredAttr *attr) { } } -static bool hasThrowingFunctionParameter(CanType type) { - // Only consider throwing function types. - if (auto fnType = dyn_cast(type)) { - return fnType->getExtInfo().isThrowing(); - } - - // Look through tuples. - if (auto tuple = dyn_cast(type)) { - for (auto eltType : tuple.getElementTypes()) { - if (hasThrowingFunctionParameter(eltType)) - return true; - } - return false; - } - - // Suppress diagnostics in the presence of errors. - if (type->hasError()) { - return true; - } - - return false; -} - void AttributeChecker::visitRethrowsAttr(RethrowsAttr *attr) { // 'rethrows' only applies to functions that take throwing functions // as parameters. - auto fn = cast(D); - for (auto param : *fn->getParameters()) { - if (hasThrowingFunctionParameter(param->getType() - ->lookThroughAllOptionalTypes() - ->getCanonicalType())) - return; + auto fn = dyn_cast(D); + if (fn && fn->getRethrowingKind() != FunctionRethrowingKind::Invalid) { + return; } diagnose(attr->getLocation(), diag::rethrows_without_throwing_parameter); attr->setInvalid(); } +void AttributeChecker::visitAtRethrowsAttr(AtRethrowsAttr *attr) { + if (isa(D)) { + return; + } + + diagnose(attr->getLocation(), diag::rethrows_attr_on_non_protocol); + attr->setInvalid(); +} + /// Collect all used generic parameter types from a given type. static void collectUsedGenericParameters( Type Ty, SmallPtrSetImpl &ConstrainedGenericParams) { diff --git a/lib/Sema/TypeCheckConcurrency.cpp b/lib/Sema/TypeCheckConcurrency.cpp index 8b15e80643a41..cc5181a1f361f 100644 --- a/lib/Sema/TypeCheckConcurrency.cpp +++ b/lib/Sema/TypeCheckConcurrency.cpp @@ -958,12 +958,15 @@ namespace { /// \returns true if we diagnosed the entity, \c false otherwise. bool diagnoseInOutArg(const ApplyExpr *call, const InOutExpr *arg, bool isPartialApply) { + // check that the call is actually async if (!isAsyncCall(call)) return false; Expr *subArg = arg->getSubExpr(); ValueDecl *valueDecl = nullptr; + if (auto binding = dyn_cast(subArg)) + subArg = binding->getSubExpr(); if (LookupExpr *baseArg = dyn_cast(subArg)) { while (LookupExpr *nextLayer = dyn_cast(baseArg->getBase())) baseArg = nextLayer; diff --git a/lib/Sema/TypeCheckConstraints.cpp b/lib/Sema/TypeCheckConstraints.cpp index 302040a68ad11..7356e604193d5 100644 --- a/lib/Sema/TypeCheckConstraints.cpp +++ b/lib/Sema/TypeCheckConstraints.cpp @@ -562,7 +562,9 @@ bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) { }; auto sequenceProto = TypeChecker::getProtocol( - dc->getASTContext(), stmt->getForLoc(), KnownProtocolKind::Sequence); + dc->getASTContext(), stmt->getForLoc(), + stmt->getAwaitLoc().isValid() ? + KnownProtocolKind::AsyncSequence : KnownProtocolKind::Sequence); if (!sequenceProto) return failed(); @@ -587,6 +589,45 @@ bool TypeChecker::typeCheckForEachBinding(DeclContext *dc, ForEachStmt *stmt) { if (!typeCheckExpression(target)) return failed(); + // check to see if the sequence expr is throwing (and async), if so require + // the stmt to have a try loc + if (stmt->getAwaitLoc().isValid()) { + auto Ty = sequence->getType(); + if (Ty.isNull()) { + auto DRE = dyn_cast(sequence); + if (DRE) { + Ty = DRE->getDecl()->getInterfaceType(); + } + if (Ty.isNull()) { + return failed(); + } + } + auto context = Ty->getNominalOrBoundGenericNominal(); + if (!context) { + // if no nominal type can be determined then we must consider this to be + // a potential throwing source and concequently this must have a valid try + // location to account for that potential ambiguity. + if (stmt->getTryLoc().isInvalid()) { + auto &diags = dc->getASTContext().Diags; + diags.diagnose(stmt->getAwaitLoc(), diag::throwing_call_unhandled); + return failed(); + } else { + return false; + } + + } + auto module = dc->getParentModule(); + auto conformanceRef = module->lookupConformance(Ty, sequenceProto); + + if (conformanceRef.classifyAsThrows() && + stmt->getTryLoc().isInvalid()) { + auto &diags = dc->getASTContext().Diags; + diags.diagnose(stmt->getAwaitLoc(), diag::throwing_call_unhandled); + + return failed(); + } + } + return false; } diff --git a/lib/Sema/TypeCheckDecl.cpp b/lib/Sema/TypeCheckDecl.cpp index f5b253baa1d89..c281691d82b19 100644 --- a/lib/Sema/TypeCheckDecl.cpp +++ b/lib/Sema/TypeCheckDecl.cpp @@ -696,6 +696,165 @@ ExistentialTypeSupportedRequest::evaluate(Evaluator &evaluator, return true; } +static bool hasThrowingFunctionClosureParameter(CanType type) { + // Only consider throwing function types. + if (auto fnType = dyn_cast(type)) { + return fnType->getExtInfo().isThrowing(); + } + + // Look through tuples. + if (auto tuple = dyn_cast(type)) { + for (auto eltType : tuple.getElementTypes()) { + auto elt = eltType->lookThroughAllOptionalTypes()->getCanonicalType(); + if (hasThrowingFunctionClosureParameter(elt)) + return true; + } + return false; + } + + // Suppress diagnostics in the presence of errors. + if (type->hasError()) { + return true; + } + + return false; +} + +static FunctionRethrowingKind +getTypeThrowingKind(Type interfaceTy, GenericSignature genericSig) { + if (interfaceTy->isTypeParameter()) { + for (auto proto : genericSig->getRequiredProtocols(interfaceTy)) { + if (proto->isRethrowingProtocol()) { + return FunctionRethrowingKind::ByConformance; + } + } + } else if (auto NTD = interfaceTy->getNominalOrBoundGenericNominal()) { + if (auto genericSig = NTD->getGenericSignature()) { + for (auto req : genericSig->getRequirements()) { + if (req.getKind() == RequirementKind::Conformance) { + if (req.getSecondType()->castTo() + ->getDecl() + ->isRethrowingProtocol()) { + return FunctionRethrowingKind::ByConformance; + } + } + } + } + } + return FunctionRethrowingKind::Invalid; +} + +static FunctionRethrowingKind +getParameterThrowingKind(AbstractFunctionDecl *decl, + GenericSignature genericSig) { + FunctionRethrowingKind kind = FunctionRethrowingKind::Invalid; + // check all parameters to determine if any are closures that throw + bool foundThrowingClosure = false; + for (auto param : *decl->getParameters()) { + auto interfaceTy = param->getInterfaceType(); + if (hasThrowingFunctionClosureParameter(interfaceTy + ->lookThroughAllOptionalTypes() + ->getCanonicalType())) { + foundThrowingClosure = true; + } + + if (kind == FunctionRethrowingKind::Invalid) { + kind = getTypeThrowingKind(interfaceTy, genericSig); + } + } + if (kind == FunctionRethrowingKind::Invalid && + foundThrowingClosure) { + return FunctionRethrowingKind::ByClosure; + } + return kind; +} + +ProtocolRethrowsRequirementList +ProtocolRethrowsRequirementsRequest::evaluate(Evaluator &evaluator, + ProtocolDecl *decl) const { + SmallVector, 2> found; + llvm::DenseSet checkedProtocols; + + ASTContext &ctx = decl->getASTContext(); + + // only allow rethrowing requirements to be determined from marked protocols + if (!decl->getAttrs().hasAttribute()) { + return ProtocolRethrowsRequirementList(ctx.AllocateCopy(found)); + } + + // check if immediate members of protocol are 'rethrows' + for (auto member : decl->getMembers()) { + auto fnDecl = dyn_cast(member); + // it must be a function + // it must have a rethrows attribute + // it must not have any parameters that are closures that cause rethrowing + if (!fnDecl || + !fnDecl->hasThrows()) { + continue; + } + + GenericSignature genericSig = fnDecl->getGenericSignature(); + auto kind = getParameterThrowingKind(fnDecl, genericSig); + // skip closure based rethrowing cases + if (kind == FunctionRethrowingKind::ByClosure) { + continue; + } + // we now have a protocol member that has a rethrows and no closure + // parameters contributing to it's rethrowing-ness + found.push_back( + std::pair(decl->getSelfInterfaceType(), fnDecl)); + } + checkedProtocols.insert(decl); + + // check associated conformances of associated types or inheritance + for (auto requirement : decl->getRequirementSignature()) { + if (requirement.getKind() != RequirementKind::Conformance) { + continue; + } + auto protoTy = requirement.getSecondType()->castTo(); + auto proto = protoTy->getDecl(); + if (checkedProtocols.count(proto) != 0) { + continue; + } + checkedProtocols.insert(proto); + for (auto entry : proto->getRethrowingRequirements()) { + found.emplace_back(requirement.getFirstType(), entry.second); + } + } + + return ProtocolRethrowsRequirementList(ctx.AllocateCopy(found)); +} + +FunctionRethrowingKind +FunctionRethrowingKindRequest::evaluate(Evaluator &evaluator, + AbstractFunctionDecl *decl) const { + if (decl->hasThrows()) { + auto proto = dyn_cast(decl->getDeclContext()); + bool fromRethrow = proto != nullptr ? proto->isRethrowingProtocol() : false; + bool markedRethrows = decl->getAttrs().hasAttribute(); + if (fromRethrow && !markedRethrows) { + return FunctionRethrowingKind::ByConformance; + } + if (markedRethrows) { + GenericSignature genericSig = decl->getGenericSignature(); + FunctionRethrowingKind kind = getParameterThrowingKind(decl, genericSig); + // since we have checked all arguments, if we still havent found anything + // check the self parameter + if (kind == FunctionRethrowingKind::Invalid && + decl->hasImplicitSelfDecl()) { + auto selfParam = decl->getImplicitSelfDecl(); + if (selfParam) { + auto interfaceTy = selfParam->getInterfaceType(); + kind = getTypeThrowingKind(interfaceTy, genericSig); + } + } + return kind; + } + return FunctionRethrowingKind::Throws; + } + return FunctionRethrowingKind::None; +} + bool IsFinalRequest::evaluate(Evaluator &evaluator, ValueDecl *decl) const { if (isa(decl)) diff --git a/lib/Sema/TypeCheckDeclOverride.cpp b/lib/Sema/TypeCheckDeclOverride.cpp index af8d60e4755b1..bb952521f2935 100644 --- a/lib/Sema/TypeCheckDeclOverride.cpp +++ b/lib/Sema/TypeCheckDeclOverride.cpp @@ -1515,6 +1515,8 @@ namespace { UNINTERESTING_ATTR(ActorIndependent) UNINTERESTING_ATTR(GlobalActor) UNINTERESTING_ATTR(Async) + + UNINTERESTING_ATTR(AtRethrows) #undef UNINTERESTING_ATTR void visitAvailableAttr(AvailableAttr *attr) { diff --git a/lib/Sema/TypeCheckEffects.cpp b/lib/Sema/TypeCheckEffects.cpp index 81d405bb1465f..1c8fb97dd736a 100644 --- a/lib/Sema/TypeCheckEffects.cpp +++ b/lib/Sema/TypeCheckEffects.cpp @@ -22,6 +22,7 @@ #include "swift/AST/Initializer.h" #include "swift/AST/Pattern.h" #include "swift/AST/PrettyStackTrace.h" +#include "swift/AST/ProtocolConformance.h" using namespace swift; @@ -44,33 +45,44 @@ class AbstractFunction { unsigned TheKind : 2; unsigned IsRethrows : 1; unsigned ParamCount : 2; + FunctionRethrowingKind rethrowingKind; + ConcreteDeclRef declRef; public: - explicit AbstractFunction(Kind kind, Expr *fn) + explicit AbstractFunction(Kind kind, Expr *fn, ConcreteDeclRef declRef) : TheKind(kind), IsRethrows(false), - ParamCount(1) { + ParamCount(1), + rethrowingKind(FunctionRethrowingKind::Invalid), + declRef(declRef) { TheExpr = fn; } - explicit AbstractFunction(AbstractFunctionDecl *fn) + explicit AbstractFunction(AbstractFunctionDecl *fn, ConcreteDeclRef declRef) : TheKind(Kind::Function), IsRethrows(fn->getAttrs().hasAttribute()), - ParamCount(fn->getNumCurryLevels()) { + ParamCount(fn->getNumCurryLevels()), + rethrowingKind(fn->getRethrowingKind()), + declRef(declRef) { TheFunction = fn; } - explicit AbstractFunction(AbstractClosureExpr *closure) + explicit AbstractFunction(AbstractClosureExpr *closure, + ConcreteDeclRef declRef) : TheKind(Kind::Closure), IsRethrows(false), - ParamCount(1) { + ParamCount(1), + rethrowingKind(FunctionRethrowingKind::Invalid), + declRef(declRef) { TheClosure = closure; } - explicit AbstractFunction(ParamDecl *parameter) + explicit AbstractFunction(ParamDecl *parameter, ConcreteDeclRef declRef) : TheKind(Kind::Parameter), IsRethrows(false), - ParamCount(1) { + ParamCount(1), + rethrowingKind(FunctionRethrowingKind::Invalid), + declRef(declRef) { TheParameter = parameter; } @@ -79,6 +91,8 @@ class AbstractFunction { /// Whether the function is marked 'rethrows'. bool isBodyRethrows() const { return IsRethrows; } + FunctionRethrowingKind getRethrowingKind() const { return rethrowingKind; } + unsigned getNumArgumentsForFullApply() const { return ParamCount; } @@ -116,18 +130,29 @@ class AbstractFunction { return TheExpr; } + ConcreteDeclRef getDeclRef() { + return declRef; + } + static AbstractFunction decomposeApply(ApplyExpr *apply, SmallVectorImpl &args) { Expr *fn; + ConcreteDeclRef declRef; do { args.push_back(apply->getArg()); - fn = apply->getFn()->getValueProvidingExpr(); + auto applyFn = apply->getFn(); + if (!declRef) { + if (auto DRE = dyn_cast(applyFn)) { + declRef = DRE->getDeclRef(); + } + } + fn = applyFn->getValueProvidingExpr(); } while ((apply = dyn_cast(fn))); - return decomposeFunction(fn); + return decomposeFunction(fn, declRef); } - static AbstractFunction decomposeFunction(Expr *fn) { + static AbstractFunction decomposeFunction(Expr *fn, ConcreteDeclRef declRef = ConcreteDeclRef()) { assert(fn->getValueProvidingExpr() == fn); while (true) { @@ -158,25 +183,25 @@ class AbstractFunction { // Constructor delegation. if (auto otherCtorDeclRef = dyn_cast(fn)) { - return AbstractFunction(otherCtorDeclRef->getDecl()); + return AbstractFunction(otherCtorDeclRef->getDecl(), declRef); } // Normal function references. - if (auto declRef = dyn_cast(fn)) { - ValueDecl *decl = declRef->getDecl(); + if (auto DRE = dyn_cast(fn)) { + ValueDecl *decl = DRE->getDecl(); if (auto fn = dyn_cast(decl)) { - return AbstractFunction(fn); + return AbstractFunction(fn, declRef); } else if (auto param = dyn_cast(decl)) { - return AbstractFunction(param); + return AbstractFunction(param, declRef); } // Closures. } else if (auto closure = dyn_cast(fn)) { - return AbstractFunction(closure); + return AbstractFunction(closure, declRef); } // Everything else is opaque. - return AbstractFunction(Kind::Opaque, fn); + return AbstractFunction(Kind::Opaque, fn, declRef); } }; @@ -244,6 +269,8 @@ class EffectsHandlingWalker : public ASTWalker { recurse = asImpl().checkDoCatch(doCatch); } else if (auto thr = dyn_cast(S)) { recurse = asImpl().checkThrow(thr); + } else if (auto forEach = dyn_cast(S)) { + recurse = asImpl().checkForEach(forEach); } return {bool(recurse), S}; } @@ -257,6 +284,10 @@ class EffectsHandlingWalker : public ASTWalker { } return ShouldNotRecurse; } + + ShouldRecurse_t checkForEach(ForEachStmt *S) { + return ShouldRecurse; + } }; /// A potential reason why something might throw. @@ -279,6 +310,10 @@ class PotentialThrowReason { /// The function is 'rethrows', and it was passed a default /// argument that was not rethrowing-only in this context. CallRethrowsWithDefaultThrowingArgument, + + /// The the function is 'rethrows', and it is a member that + /// is a conformance to a rethrowing protocol. + CallRethrowsWithConformance, }; static StringRef kindToString(Kind k) { @@ -290,6 +325,8 @@ class PotentialThrowReason { return "CallRethrowsWithExplicitThrowingArgument"; case Kind::CallRethrowsWithDefaultThrowingArgument: return "CallRethrowsWithDefaultThrowingArgument"; + case Kind::CallRethrowsWithConformance: + return "CallRethrowsWithConformance"; } } @@ -307,6 +344,11 @@ class PotentialThrowReason { static PotentialThrowReason forDefaultArgument() { return PotentialThrowReason(Kind::CallRethrowsWithDefaultThrowingArgument); } + static PotentialThrowReason forRethrowsConformance(Expr *E) { + PotentialThrowReason result(Kind::CallRethrowsWithConformance); + result.TheExpression = E; + return result; + } static PotentialThrowReason forThrowingApply() { return PotentialThrowReason(Kind::CallThrows); } @@ -323,7 +365,8 @@ class PotentialThrowReason { bool isThrow() const { return getKind() == Kind::Throw; } bool isRethrowsCall() const { return (getKind() == Kind::CallRethrowsWithExplicitThrowingArgument || - getKind() == Kind::CallRethrowsWithDefaultThrowingArgument); + getKind() == Kind::CallRethrowsWithDefaultThrowingArgument || + getKind() == Kind::CallRethrowsWithConformance); } /// If this was built with forRethrowsArgument, return the expression. @@ -464,11 +507,6 @@ class ApplyClassifier { if (!fnType) return Classification::forInvalidCode(); bool isAsync = fnType->isAsync() || E->implicitlyAsync(); - - // If the function doesn't throw at all, we're done here. - if (!fnType->isThrowing()) - return isAsync ? Classification::forAsync() : Classification(); - // Decompose the application. SmallVector args; auto fnRef = AbstractFunction::decomposeApply(E, args); @@ -479,6 +517,34 @@ class ApplyClassifier { return Classification::forInvalidCode(); } + if (fnRef.getRethrowingKind() == FunctionRethrowingKind::ByConformance) { + auto substitutions = fnRef.getDeclRef().getSubstitutions(); + bool classifiedAsThrows = false; + for (auto conformanceRef : substitutions.getConformances()) { + if (conformanceRef.classifyAsThrows()) { + classifiedAsThrows = true; + break; + } + } + + if (classifiedAsThrows) { + return Classification::forRethrowingOnly( + PotentialThrowReason::forRethrowsConformance(E), isAsync); + } + } else if (fnRef.isBodyRethrows() && + fnRef.getRethrowingKind() == FunctionRethrowingKind::Throws) { + return Classification::forThrow(PotentialThrowReason::forThrowingApply(), + isAsync); + } else if (fnRef.isBodyRethrows() && + fnRef.getRethrowingKind() == FunctionRethrowingKind::None) { + return isAsync ? Classification::forAsync() : Classification(); + } + + // If the function doesn't throw at all, we're done here. + if (!fnType->isThrowing()) { + return isAsync ? Classification::forAsync() : Classification(); + } + // If we're applying more arguments than the natural argument // count, then this is a call to the opaque value returned from // the function. @@ -966,7 +1032,7 @@ class Context { if (!fn) return false; - return fn->getAttrs().hasAttribute(); + return fn->getRethrowingKind() == FunctionRethrowingKind::ByClosure; } /// Whether this is an autoclosure. @@ -1142,6 +1208,9 @@ class Context { case PotentialThrowReason::Kind::CallRethrowsWithDefaultThrowingArgument: Diags.diagnose(loc, diag::because_rethrows_default_argument_throws); return; + case PotentialThrowReason::Kind::CallRethrowsWithConformance: + Diags.diagnose(loc, diag::because_rethrows_default_conformance_throws); + return; } llvm_unreachable("bad reason kind"); } @@ -2019,6 +2088,16 @@ class CheckEffectsCoverage : public EffectsHandlingWalker scope.preserveCoverageFromOptionalOrForcedTryOperand(); return ShouldNotRecurse; } + + ShouldRecurse_t checkForEach(ForEachStmt *S) { + if (S->getTryLoc().isValid() && + !Flags.has(ContextFlags::IsTryCovered)) { + checkThrowAsyncSite(S, /*requiresTry*/ false, + Classification::forThrow(PotentialThrowReason::forThrow(), + /*async*/false)); + } + return ShouldRecurse; + } }; // Find nested functions and perform effects checking on them. diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index a11cea294dcf2..39252f8b0d532 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -558,12 +558,25 @@ swift::matchWitness( return RequirementMatch(witness, MatchKind::MutatingConflict); // If the requirement is rethrows, the witness must either be - // rethrows or be non-throwing. + // rethrows or be non-throwing if the requirement is not by conformance + // else the witness can be by conformance, throwing or non throwing if (reqAttrs.hasAttribute() && - !witnessAttrs.hasAttribute() && - cast(witness)->hasThrows()) - return RequirementMatch(witness, MatchKind::RethrowsConflict); - + !witnessAttrs.hasAttribute()) { + auto reqRethrowingKind = funcReq->getRethrowingKind(); + auto witnessRethrowingKind = funcWitness->getRethrowingKind(); + if (reqRethrowingKind == FunctionRethrowingKind::ByConformance) { + switch (witnessRethrowingKind) { + case FunctionRethrowingKind::ByConformance: + case FunctionRethrowingKind::Throws: + case FunctionRethrowingKind::None: + break; + default: + return RequirementMatch(witness, MatchKind::RethrowsConflict); + } + } else if (cast(witness)->hasThrows()) { + return RequirementMatch(witness, MatchKind::RethrowsConflict); + } + } // We want to decompose the parameters to handle them separately. decomposeFunctionType = true; } else if (auto *witnessASD = dyn_cast(witness)) { diff --git a/stdlib/public/Concurrency/AsyncIteratorProtocol.swift b/stdlib/public/Concurrency/AsyncIteratorProtocol.swift new file mode 100644 index 0000000000000..7613300cc96cf --- /dev/null +++ b/stdlib/public/Concurrency/AsyncIteratorProtocol.swift @@ -0,0 +1,19 @@ +////===----------------------------------------------------------------------===// +//// +//// This source file is part of the Swift.org open source project +//// +//// Copyright (c) 2020 Apple Inc. and the Swift project authors +//// Licensed under Apache License v2.0 with Runtime Library Exception +//// +//// See https://swift.org/LICENSE.txt for license information +//// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +//// +////===----------------------------------------------------------------------===// + +import Swift + +@rethrows +public protocol AsyncIteratorProtocol { + associatedtype Element + mutating func next() async throws -> Element? +} diff --git a/stdlib/public/Concurrency/AsyncSequence.swift b/stdlib/public/Concurrency/AsyncSequence.swift new file mode 100644 index 0000000000000..070d622f9f8b8 --- /dev/null +++ b/stdlib/public/Concurrency/AsyncSequence.swift @@ -0,0 +1,20 @@ +////===----------------------------------------------------------------------===// +//// +//// This source file is part of the Swift.org open source project +//// +//// Copyright (c) 2020 Apple Inc. and the Swift project authors +//// Licensed under Apache License v2.0 with Runtime Library Exception +//// +//// See https://swift.org/LICENSE.txt for license information +//// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +//// +////===----------------------------------------------------------------------===// + +import Swift + +@rethrows +public protocol AsyncSequence { + associatedtype AsyncIterator: AsyncIteratorProtocol where AsyncIterator.Element == Element + associatedtype Element + func makeAsyncIterator() -> AsyncIterator +} diff --git a/stdlib/public/Concurrency/CMakeLists.txt b/stdlib/public/Concurrency/CMakeLists.txt index 3ed283311ad41..e826277c60e4e 100644 --- a/stdlib/public/Concurrency/CMakeLists.txt +++ b/stdlib/public/Concurrency/CMakeLists.txt @@ -38,6 +38,8 @@ add_swift_target_library(swift_Concurrency ${SWIFT_STDLIB_LIBRARY_BUILD_TYPES} I Actor.swift CheckedContinuation.swift GlobalExecutor.cpp + AsyncIteratorProtocol.swift + AsyncSequence.swift PartialAsyncTask.swift Task.cpp Task.swift diff --git a/test/IDE/complete_decl_attribute.swift b/test/IDE/complete_decl_attribute.swift index b5125597727c9..a2585d76ecfb9 100644 --- a/test/IDE/complete_decl_attribute.swift +++ b/test/IDE/complete_decl_attribute.swift @@ -239,6 +239,7 @@ struct _S { // ON_MEMBER_LAST-DAG: Keyword/None: inlinable[#Declaration Attribute#]; name=inlinable // ON_MEMBER_LAST-DAG: Keyword/None: objcMembers[#Declaration Attribute#]; name=objcMembers // ON_MEMBER_LAST-DAG: Keyword/None: NSApplicationMain[#Declaration Attribute#]; name=NSApplicationMain +// ON_MEMBER_LAST-DAG: Keyword/None: rethrows[#Declaration Attribute#]; name=rethrows // ON_MEMBER_LAST-DAG: Keyword/None: warn_unqualified_access[#Declaration Attribute#]; name=warn_unqualified_access // ON_MEMBER_LAST-DAG: Keyword/None: usableFromInline[#Declaration Attribute#]; name=usableFromInline // ON_MEMBER_LAST-DAG: Keyword/None: discardableResult[#Declaration Attribute#]; name=discardableResult @@ -286,6 +287,7 @@ func dummy2() {} // KEYWORD_LAST-NEXT: Keyword/None: inlinable[#Declaration Attribute#]; name=inlinable{{$}} // KEYWORD_LAST-NEXT: Keyword/None: objcMembers[#Declaration Attribute#]; name=objcMembers{{$}} // KEYWORD_LAST-NEXT: Keyword/None: NSApplicationMain[#Declaration Attribute#]; name=NSApplicationMain{{$}} +// KEYWORD_LAST-NEXT: Keyword/None: rethrows[#Declaration Attribute#]; name=rethrows{{$}} // KEYWORD_LAST-NEXT: Keyword/None: warn_unqualified_access[#Declaration Attribute#]; name=warn_unqualified_access // KEYWORD_LAST-NEXT: Keyword/None: usableFromInline[#Declaration Attribute#]; name=usableFromInline{{$}} // KEYWORD_LAST-NEXT: Keyword/None: discardableResult[#Declaration Attribute#]; name=discardableResult diff --git a/test/Parse/foreach_async.swift b/test/Parse/foreach_async.swift new file mode 100644 index 0000000000000..52df6387400a1 --- /dev/null +++ b/test/Parse/foreach_async.swift @@ -0,0 +1,51 @@ +// RUN: %target-typecheck-verify-swift -enable-experimental-concurrency + +struct AsyncRange: AsyncSequence, AsyncIteratorProtocol where Bound.Stride : SignedInteger { + var range: Range.Iterator + typealias Element = Bound + mutating func next() async -> Element? { return range.next() } + func cancel() { } + + func makeAsyncIterator() -> Self { return self } +} + +struct AsyncIntRange : AsyncSequence, AsyncIteratorProtocol { + typealias Element = (Int, Int) + func next() async -> (Int, Int)? {} + + func cancel() { } + + typealias AsyncIterator = AsyncIntRange + func makeAsyncIterator() -> AsyncIntRange { return self } +} + +func for_each(r: AsyncRange, iir: AsyncIntRange) async { // expected-note {{'r' declared here}} + var sum = 0 + + // Simple foreach loop, using the variable in the body + for await i in r { + sum = sum + i + } + // Check scoping of variable introduced with foreach loop + i = 0 // expected-error{{cannot find 'i' in scope; did you mean 'r'?}} + + // For-each loops with two variables and varying degrees of typedness + for await (i, j) in iir { + sum = sum + i + j + } + for await (i, j) in iir { + sum = sum + i + j + } + for await (i, j) : (Int, Int) in iir { + sum = sum + i + j + } + + // Parse errors + // FIXME: Bad diagnostics; should be just 'expected 'in' after for-each patter'. + for await i r { // expected-error {{found an unexpected second identifier in constant declaration}} + } // expected-note @-1 {{join the identifiers together}} + // expected-note @-2 {{join the identifiers together with camel-case}} + // expected-error @-3 {{expected 'in' after for-each pattern}} + // expected-error @-4 {{expected Sequence expression for for-each loop}} + for await i in r sum = sum + i; // expected-error{{expected '{' to start the body of for-each loop}} +} diff --git a/test/SILGen/foreach_async.swift b/test/SILGen/foreach_async.swift new file mode 100644 index 0000000000000..b10b1dc481631 --- /dev/null +++ b/test/SILGen/foreach_async.swift @@ -0,0 +1,223 @@ +// RUN: %target-swift-emit-silgen %s -module-name foreach_async -swift-version 5 -enable-experimental-concurrency | %FileCheck %s +// REQUIRES: concurrency + +////////////////// +// Declarations // +////////////////// + +class C {} + +@_silgen_name("loopBodyEnd") +func loopBodyEnd() -> () + +@_silgen_name("condition") +func condition() -> Bool + +@_silgen_name("loopContinueEnd") +func loopContinueEnd() -> () + +@_silgen_name("loopBreakEnd") +func loopBreakEnd() -> () + +@_silgen_name("funcEnd") +func funcEnd() -> () + +struct TrivialStruct { + var value: Int32 +} + +struct NonTrivialStruct { + var value: C +} + +struct GenericStruct { + var value: T + var value2: C +} + +protocol P {} +protocol ClassP : AnyObject {} + +protocol GenericCollection : Collection { + +} + +struct AsyncLazySequence: AsyncSequence { + typealias Element = S.Element + typealias AsyncIterator = Iterator + + struct Iterator: AsyncIteratorProtocol { + typealias Element = S.Element + + var iterator: S.Iterator? + + mutating func next() async -> S.Element? { + return iterator?.next() + } + } + + var sequence: S + + func makeAsyncIterator() -> Iterator { + return Iterator(iterator: sequence.makeIterator()) + } +} + +/////////// +// Tests // +/////////// + +//===----------------------------------------------------------------------===// +// Trivial Struct +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: sil hidden [ossa] @$s13foreach_async13trivialStructyyAA17AsyncLazySequenceVySaySiGGYF : $@convention(thin) @async (@guaranteed AsyncLazySequence>) -> () { +// CHECK: bb0([[SOURCE:%.*]] : @guaranteed $AsyncLazySequence>): +// CHECK: [[ITERATOR_BOX:%.*]] = alloc_box ${ var AsyncLazySequence>.Iterator }, var, name "$x$generator" +// CHECK: [[PROJECT_ITERATOR_BOX:%.*]] = project_box [[ITERATOR_BOX]] +// CHECK: br [[LOOP_DEST:bb[0-9]+]] + +// CHECK: [[LOOP_DEST]]: +// CHECK: [[NEXT_RESULT:%.*]] = alloc_stack $Optional +// CHECK: [[MUTATION:%.*]] = begin_access +// CHECK: [[WITNESS_METHOD:%.*]] = witness_method $AsyncLazySequence>.Iterator, #AsyncIteratorProtocol.next : (inout Self) -> () async throws -> Self.Element? : $@convention(witness_method: AsyncIteratorProtocol) @async <τ_0_0 where τ_0_0 : AsyncIteratorProtocol> (@inout τ_0_0) -> (@out Optional<τ_0_0.Element>, @error Error) +// CHECK: try_apply [[WITNESS_METHOD]].Iterator>([[NEXT_RESULT]], [[MUTATION]]) : $@convention(witness_method: AsyncIteratorProtocol) @async <τ_0_0 where τ_0_0 : AsyncIteratorProtocol> (@inout τ_0_0) -> (@out Optional<τ_0_0.Element>, @error Error), normal [[NORMAL_BB:bb[0-9]+]], error [[ERROR_BB:bb[0-9]+]] + +// CHECK: [[NORMAL_BB]]([[VAR:%.*]] : $()): +// CHECK: end_access [[MUTATION]] +// CHECK: switch_enum [[IND_VAR:%.*]] : $Optional, case #Optional.some!enumelt: [[SOME_BB:bb[0-9]+]], case #Optional.none!enumelt: [[NONE_BB:bb[0-9]+]] + +// CHECK: [[SOME_BB]]([[VAR:%.*]] : $Int): +// CHECK: loopBodyEnd +// CHECK: br [[LOOP_DEST]] + +// CHECK: [[NONE_BB]]: +// CHECK: funcEnd +// CHECK return + +// CHECK: [[ERROR_BB]]([[VAR:%.*]] : @owned $Error): +// CHECK: unreachable +// CHECK: } // end sil function '$s13foreach_async13trivialStructyyAA17AsyncLazySequenceVySaySiGGYF' +func trivialStruct(_ xx: AsyncLazySequence<[Int]>) async { + for await x in xx { + loopBodyEnd() + } + funcEnd() +} + +// CHECK-LABEL: sil hidden [ossa] @$s13foreach_async21trivialStructContinueyyAA17AsyncLazySequenceVySaySiGGYF : $@convention(thin) @async (@guaranteed AsyncLazySequence>) -> () { +// CHECK: bb0([[SOURCE:%.*]] : @guaranteed $AsyncLazySequence>): +// CHECK: [[ITERATOR_BOX:%.*]] = alloc_box ${ var AsyncLazySequence>.Iterator }, var, name "$x$generator" +// CHECK: [[PROJECT_ITERATOR_BOX:%.*]] = project_box [[ITERATOR_BOX]] +// CHECK: br [[LOOP_DEST:bb[0-9]+]] + +// CHECK: [[LOOP_DEST]]: +// CHECK: [[NEXT_RESULT:%.*]] = alloc_stack $Optional +// CHECK: [[MUTATION:%.*]] = begin_access +// CHECK: [[WITNESS_METHOD:%.*]] = witness_method $AsyncLazySequence>.Iterator, #AsyncIteratorProtocol.next : (inout Self) -> () async throws -> Self.Element? : $@convention(witness_method: AsyncIteratorProtocol) @async <τ_0_0 where τ_0_0 : AsyncIteratorProtocol> (@inout τ_0_0) -> (@out Optional<τ_0_0.Element>, @error Error) +// CHECK: try_apply [[WITNESS_METHOD]].Iterator>([[NEXT_RESULT]], [[MUTATION]]) : $@convention(witness_method: AsyncIteratorProtocol) @async <τ_0_0 where τ_0_0 : AsyncIteratorProtocol> (@inout τ_0_0) -> (@out Optional<τ_0_0.Element>, @error Error), normal [[NORMAL_BB:bb[0-9]+]], error [[ERROR_BB:bb[0-9]+]] + +// CHECK: [[NORMAL_BB]]([[VAR:%.*]] : $()): +// CHECK: end_access [[MUTATION]] +// CHECK: switch_enum [[IND_VAR:%.*]] : $Optional, case #Optional.some!enumelt: [[SOME_BB:bb[0-9]+]], case #Optional.none!enumelt: [[NONE_BB:bb[0-9]+]] + +// CHECK: [[SOME_BB]]([[VAR:%.*]] : $Int): +// CHECK: condition +// CHECK: cond_br [[VAR:%.*]], [[COND_TRUE:bb[0-9]+]], [[COND_FALSE:bb[0-9]+]] + +// CHECK: [[COND_TRUE]]: +// CHECK: loopContinueEnd +// CHECK: br [[LOOP_DEST]] + +// CHECK: [[COND_FALSE]]: +// CHECK: loopBodyEnd +// CHECK: br [[LOOP_DEST]] + +// CHECK: [[NONE_BB]]: +// CHECK: funcEnd +// CHECK return + +// CHECK: [[ERROR_BB]]([[VAR:%.*]] : @owned $Error): +// CHECK: unreachable +// CHECK: } // end sil function '$s13foreach_async21trivialStructContinueyyAA17AsyncLazySequenceVySaySiGGYF' + +func trivialStructContinue(_ xx: AsyncLazySequence<[Int]>) async { + for await x in xx { + if (condition()) { + loopContinueEnd() + continue + } + loopBodyEnd() + } + + funcEnd() +} + +// TODO: Write this test +func trivialStructBreak(_ xx: AsyncLazySequence<[Int]>) async { + for await x in xx { + if (condition()) { + loopBreakEnd() + break + } + loopBodyEnd() + } + + funcEnd() +} + +// CHECK-LABEL: sil hidden [ossa] @$s13foreach_async26trivialStructContinueBreakyyAA17AsyncLazySequenceVySaySiGGYF : $@convention(thin) @async (@guaranteed AsyncLazySequence>) -> () +// CHECK: bb0([[SOURCE:%.*]] : @guaranteed $AsyncLazySequence>): +// CHECK: [[ITERATOR_BOX:%.*]] = alloc_box ${ var AsyncLazySequence>.Iterator }, var, name "$x$generator" +// CHECK: [[PROJECT_ITERATOR_BOX:%.*]] = project_box [[ITERATOR_BOX]] +// CHECK: br [[LOOP_DEST:bb[0-9]+]] + +// CHECK: [[LOOP_DEST]]: +// CHECK: [[NEXT_RESULT:%.*]] = alloc_stack $Optional +// CHECK: [[MUTATION:%.*]] = begin_access +// CHECK: [[WITNESS_METHOD:%.*]] = witness_method $AsyncLazySequence>.Iterator, #AsyncIteratorProtocol.next : (inout Self) -> () async throws -> Self.Element? : $@convention(witness_method: AsyncIteratorProtocol) @async <τ_0_0 where τ_0_0 : AsyncIteratorProtocol> (@inout τ_0_0) -> (@out Optional<τ_0_0.Element>, @error Error) +// CHECK: try_apply [[WITNESS_METHOD]].Iterator>([[NEXT_RESULT]], [[MUTATION]]) : $@convention(witness_method: AsyncIteratorProtocol) @async <τ_0_0 where τ_0_0 : AsyncIteratorProtocol> (@inout τ_0_0) -> (@out Optional<τ_0_0.Element>, @error Error), normal [[NORMAL_BB:bb[0-9]+]], error [[ERROR_BB:bb[0-9]+]] + +// CHECK: [[NORMAL_BB]]([[VAR:%.*]] : $()): +// CHECK: end_access [[MUTATION]] +// CHECK: switch_enum [[IND_VAR:%.*]] : $Optional, case #Optional.some!enumelt: [[SOME_BB:bb[0-9]+]], case #Optional.none!enumelt: [[NONE_BB:bb[0-9]+]] + +// CHECK: [[SOME_BB]]([[VAR:%.*]] : $Int): +// CHECK: condition +// CHECK: cond_br [[VAR:%.*]], [[COND_TRUE:bb[0-9]+]], [[COND_FALSE:bb[0-9]+]] + +// CHECK: [[COND_TRUE]]: +// CHECK: loopBreakEnd +// CHECK: br [[LOOP_EXIT:bb[0-9]+]] + +// CHECK: [[COND_FALSE]]: +// CHECK: condition +// CHECK: cond_br [[VAR:%.*]], [[COND_TRUE2:bb[0-9]+]], [[COND_FALSE2:bb[0-9]+]] + +// CHECK: [[COND_TRUE2]]: +// CHECK: loopContinueEnd +// CHECK: br [[LOOP_DEST]] + +// CHECK: [[COND_FALSE2]]: +// CHECK: br [[LOOP_DEST]] + +// CHECK: [[LOOP_EXIT]]: +// CHECK: return + +// CHECK: } // end sil function '$s13foreach_async26trivialStructContinueBreakyyAA17AsyncLazySequenceVySaySiGGYF' +func trivialStructContinueBreak(_ xx: AsyncLazySequence<[Int]>) async { + for await x in xx { + if (condition()) { + loopBreakEnd() + break + } + + if (condition()) { + loopContinueEnd() + continue + } + loopBodyEnd() + } + + funcEnd() +} diff --git a/test/attr/attr_rethrows_protocol.swift b/test/attr/attr_rethrows_protocol.swift new file mode 100644 index 0000000000000..4bf4891b7b35e --- /dev/null +++ b/test/attr/attr_rethrows_protocol.swift @@ -0,0 +1,78 @@ +// RUN: %target-typecheck-verify-swift + +@rethrows +protocol RethrowingProtocol { + func source() throws +} + +struct Rethrows: RethrowingProtocol { + var other: Source + func source() rethrows { } +} + +struct Throws: RethrowingProtocol { + func source() throws { } +} + +struct ThrowsWithSource: RethrowingProtocol { + var other: Source + func source() throws { } +} + +struct NonThrows: RethrowingProtocol { + func source() { } +} + +struct NonThrowsWithSource: RethrowingProtocol { + var other: Source + func source() { } +} + +protocol InvalidRethrowingProtocol { + func source() throws +} + +struct InvalidRethrows : InvalidRethrowingProtocol { + func source() rethrows { } + // expected-error@-1{{'rethrows' function must take a throwing function argument}} +} + +func freeFloatingRethrowing(_ r: R) rethrows { } + +func freeFloatingRethrowingFromExistential(_ r: RethrowingProtocol) rethrows { } + +func invalidFreeFloatingRethrows() rethrows { + // expected-error@-1{{'rethrows' function must take a throwing function argument}} +} + +let rethrowingFromThrows = Rethrows(other: Throws()) +try rethrowingFromThrows.source() + +@rethrows +protocol HasAssociatedRethrowerWithEnclosedRethrow { + associatedtype Rethrower: RethrowingProtocol + + func source() throws +} + +@rethrows +protocol HasAssociatedRethrower { + associatedtype Rethrower: RethrowingProtocol + + func makeRethrower() -> Rethrower +} + +func freeFloatingRethrowing(_ r: R) rethrows { } + +@rethrows +protocol InheritsRethrowing: RethrowingProtocol {} + +func freeFloatingInheritedRethrowingFunction(_ r: I) rethrows { } +func freeFloatingInheritedRethrowingFunctionFromExistential(_ r: InheritsRethrowing) rethrows { } + +func closureAndRethrowing(_ r: R, _ closure: () throws -> Void) rethrows { } + +closureAndRethrowing(NonThrows()) { } +try closureAndRethrowing(NonThrows()) { } // expected-warning{{no calls to throwing functions occur within 'try' expression}} +try closureAndRethrowing(Throws()) { } +try closureAndRethrowing(NonThrows()) { () throws -> Void in } diff --git a/utils/gyb_syntax_support/StmtNodes.py b/utils/gyb_syntax_support/StmtNodes.py index f662f3975a5bf..974490bf15e09 100644 --- a/utils/gyb_syntax_support/StmtNodes.py +++ b/utils/gyb_syntax_support/StmtNodes.py @@ -74,7 +74,8 @@ Child('GuardResult', kind='Expr'), ]), - # for-in-stmt -> label? ':'? 'for' 'case'? pattern 'in' expr 'where'? + # for-in-stmt -> label? ':'? + # 'for' 'try'? 'await'? 'case'? pattern 'in' expr 'where'? # expr code-block ';'? Node('ForInStmt', kind='Stmt', traits=['WithCodeBlock', 'Labeled'], @@ -84,6 +85,11 @@ Child('LabelColon', kind='ColonToken', is_optional=True), Child('ForKeyword', kind='ForToken'), + Child('TryKeyword', kind='TryToken', + is_optional=True), + Child('AwaitKeyword', kind='IdentifierToken', + classification='Keyword', + text_choices=['await'], is_optional=True), Child('CaseKeyword', kind='CaseToken', is_optional=True), Child('Pattern', kind='Pattern'),