Skip to content

Commit e0df6e0

Browse files
authored
Merge pull request #72234 from DougGregor/attr-implements-assoc-failure-type
[Associated type inference] Support `@_implements` on type witnesses and use it for async sequence `Failure`
2 parents 26226f6 + 5bdd4e5 commit e0df6e0

File tree

12 files changed

+132
-69
lines changed

12 files changed

+132
-69
lines changed

include/swift/Basic/Features.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ SUPPRESSIBLE_LANGUAGE_FEATURE(Extern, 0, "@_extern")
169169
LANGUAGE_FEATURE(ExpressionMacroDefaultArguments, 422, "Expression macro as caller-side default argument")
170170
LANGUAGE_FEATURE(BuiltinStoreRaw, 0, "Builtin.storeRaw")
171171
LANGUAGE_FEATURE(BuiltinCreateTask, 0, "Builtin.createTask and Builtin.createDiscardingTask")
172+
SUPPRESSIBLE_LANGUAGE_FEATURE(AssociatedTypeImplements, 0, "@_implements on associated types")
172173

173174
// Swift 6
174175
UPCOMING_FEATURE(ConciseMagicFile, 274, 6)

lib/AST/ASTPrinter.cpp

Lines changed: 8 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3083,6 +3083,14 @@ static void suppressingFeatureIsolatedAny(PrintOptions &options,
30833083
action();
30843084
}
30853085

3086+
static void suppressingFeatureAssociatedTypeImplements(PrintOptions &options,
3087+
llvm::function_ref<void()> action) {
3088+
unsigned originalExcludeAttrCount = options.ExcludeAttrList.size();
3089+
options.ExcludeAttrList.push_back(DeclAttrKind::Implements);
3090+
action();
3091+
options.ExcludeAttrList.resize(originalExcludeAttrCount);
3092+
}
3093+
30863094
/// Suppress the printing of a particular feature.
30873095
static void suppressingFeature(PrintOptions &options, Feature feature,
30883096
llvm::function_ref<void()> action) {
@@ -5512,32 +5520,6 @@ void Decl::printInherited(ASTPrinter &Printer, const PrintOptions &Opts) const {
55125520
printer.printInherited(this);
55135521
}
55145522

5515-
/// Determine whether this typealias is an inferred typealias "Failure" that
5516-
/// would conflict with another entity named failure in the same type.
5517-
static bool isConflictingFailureTypeWitness(
5518-
const TypeAliasDecl *typealias) {
5519-
if (!typealias->isImplicit())
5520-
return false;
5521-
5522-
ASTContext &ctx = typealias->getASTContext();
5523-
if (typealias->getName() != ctx.Id_Failure)
5524-
return false;
5525-
5526-
auto nominal = typealias->getDeclContext()->getSelfNominalTypeDecl();
5527-
if (!nominal)
5528-
return false;
5529-
5530-
// Look for another entity with the same name.
5531-
auto lookupResults = nominal->lookupDirect(
5532-
typealias->getName(), typealias->getLoc());
5533-
for (auto found : lookupResults) {
5534-
if (found != typealias)
5535-
return true;
5536-
}
5537-
5538-
return false;
5539-
}
5540-
55415523
bool Decl::shouldPrintInContext(const PrintOptions &PO) const {
55425524
// Skip getters/setters. They are part of the variable or subscript.
55435525
if (isa<AccessorDecl>(this))
@@ -5577,14 +5559,6 @@ bool Decl::shouldPrintInContext(const PrintOptions &PO) const {
55775559
return PO.PrintIfConfig;
55785560
}
55795561

5580-
// Prior to Swift 6, we shouldn't print the inferred associated type
5581-
// witness for AsyncSequence.Failure. It is always determined from the
5582-
// AsyncIteratorProtocol witness.
5583-
if (auto typealias = dyn_cast<TypeAliasDecl>(this)) {
5584-
if (isConflictingFailureTypeWitness(typealias))
5585-
return false;
5586-
}
5587-
55885562
// Print everything else.
55895563
return true;
55905564
}

lib/AST/FeatureSet.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,10 @@ static bool usesFeatureExtern(Decl *decl) {
342342
return decl->getAttrs().hasAttribute<ExternAttr>();
343343
}
344344

345+
static bool usesFeatureAssociatedTypeImplements(Decl *decl) {
346+
return isa<TypeDecl>(decl) && decl->getAttrs().hasAttribute<ImplementsAttr>();
347+
}
348+
345349
static bool usesFeatureExpressionMacroDefaultArguments(Decl *decl) {
346350
if (auto func = dyn_cast<AbstractFunctionDecl>(decl)) {
347351
for (auto param : *func->getParameters()) {

lib/Sema/AssociatedTypeInference.cpp

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,17 @@ static bool containsConcreteDependentMemberType(Type ty) {
231231
});
232232
}
233233

234+
/// Determine whether this is the AsyncIteratorProtocol.Failure or
235+
/// AsyncSequence.Failure associated type.
236+
static bool isAsyncIteratorOrSequenceFailure(AssociatedTypeDecl *assocType) {
237+
auto proto = assocType->getProtocol();
238+
if (!proto->isSpecificProtocol(KnownProtocolKind::AsyncIteratorProtocol) &&
239+
!proto->isSpecificProtocol(KnownProtocolKind::AsyncSequence))
240+
return false;
241+
242+
return assocType->getName() == assocType->getASTContext().Id_Failure;
243+
}
244+
234245
static void recordTypeWitness(NormalProtocolConformance *conformance,
235246
AssociatedTypeDecl *assocType,
236247
Type type,
@@ -254,14 +265,40 @@ static void recordTypeWitness(NormalProtocolConformance *conformance,
254265

255266
// If there was no type declaration, synthesize one.
256267
if (typeDecl == nullptr) {
268+
Identifier name;
269+
bool needsImplementsAttr;
270+
if (isAsyncIteratorOrSequenceFailure(assocType)) {
271+
// Use __<protocol>_<assocType> as the name, to keep it out of the
272+
// way of other names.
273+
llvm::SmallString<32> nameBuffer;
274+
nameBuffer += "__";
275+
nameBuffer += assocType->getProtocol()->getName().str();
276+
nameBuffer += "_";
277+
nameBuffer += assocType->getName().str();
278+
279+
name = ctx.getIdentifier(nameBuffer);
280+
needsImplementsAttr = true;
281+
} else {
282+
// Declare a typealias with the same name as the associated type.
283+
name = assocType->getName();
284+
needsImplementsAttr = false;
285+
}
286+
257287
auto aliasDecl = new (ctx) TypeAliasDecl(
258-
SourceLoc(), SourceLoc(), assocType->getName(), SourceLoc(),
288+
SourceLoc(), SourceLoc(), name, SourceLoc(),
259289
/*genericparams*/ nullptr, dc);
260290
aliasDecl->setUnderlyingType(type);
261291

262292
aliasDecl->setImplicit();
263293
aliasDecl->setSynthesized();
264294

295+
// If needed, add an @_implements(Protocol, Name) attribute.
296+
if (needsImplementsAttr) {
297+
auto attr = ImplementsAttr::create(
298+
dc, assocType->getProtocol(), assocType->getName());
299+
aliasDecl->getAttrs().add(attr);
300+
}
301+
265302
// Inject the typealias into the nominal decl that conforms to the protocol.
266303
auto nominal = dc->getSelfNominalTypeDecl();
267304
auto requiredAccessScope = evaluateOrDefault(
@@ -395,17 +432,6 @@ static bool isAsyncSequenceFailure(AssociatedTypeDecl *assocType) {
395432
return assocType->getName() == assocType->getASTContext().Id_Failure;
396433
}
397434

398-
/// Determine whether this is the AsyncIteratorProtocol.Failure or
399-
/// AsyncSequence.Failure associated type.
400-
static bool isAsyncIteratorOrSequenceFailure(AssociatedTypeDecl *assocType) {
401-
auto proto = assocType->getProtocol();
402-
if (!proto->isSpecificProtocol(KnownProtocolKind::AsyncIteratorProtocol) &&
403-
!proto->isSpecificProtocol(KnownProtocolKind::AsyncSequence))
404-
return false;
405-
406-
return assocType->getName() == assocType->getASTContext().Id_Failure;
407-
}
408-
409435
/// Attempt to resolve a type witness via member name lookup.
410436
static ResolveWitnessResult resolveTypeWitnessViaLookup(
411437
NormalProtocolConformance *conformance,
@@ -429,7 +455,8 @@ static ResolveWitnessResult resolveTypeWitnessViaLookup(
429455
abort();
430456
}
431457

432-
NLOptions subOptions = (NL_QualifiedDefault | NL_OnlyTypes | NL_ProtocolMembers);
458+
NLOptions subOptions = (NL_QualifiedDefault | NL_OnlyTypes |
459+
NL_ProtocolMembers | NL_IncludeAttributeImplements);
433460

434461
// Look for a member type with the same name as the associated type.
435462
SmallVector<ValueDecl *, 4> candidates;
@@ -455,6 +482,16 @@ static ResolveWitnessResult resolveTypeWitnessViaLookup(
455482
if (isa<AssociatedTypeDecl>(typeDecl))
456483
continue;
457484

485+
// If the name doesn't match and there's no appropriate @_implements
486+
// attribute, skip this candidate.
487+
//
488+
// Also skip candidates in protocol extensions, because they tend to cause
489+
// request cycles. We'll look at those during associated type inference.
490+
if (assocType->getName() != typeDecl->getName() &&
491+
!(witnessHasImplementsAttrForRequiredName(typeDecl, assocType) &&
492+
!typeDecl->getDeclContext()->getSelfProtocolDecl()))
493+
continue;
494+
458495
auto *genericDecl = cast<GenericTypeDecl>(typeDecl);
459496

460497
// If the declaration has generic parameters, it cannot witness an
@@ -2008,7 +2045,8 @@ AssociatedTypeInference::inferTypeWitnessesViaAssociatedType(
20082045

20092046
NLOptions subOptions = (NL_QualifiedDefault |
20102047
NL_OnlyTypes |
2011-
NL_ProtocolMembers);
2048+
NL_ProtocolMembers |
2049+
NL_IncludeAttributeImplements);
20122050

20132051
// Look for types with the given default name that have appropriate
20142052
// @_implements attributes.
@@ -2030,6 +2068,12 @@ AssociatedTypeInference::inferTypeWitnessesViaAssociatedType(
20302068
if (!defaultProto)
20312069
continue;
20322070

2071+
// If the name doesn't match and there's no appropriate @_implements
2072+
// attribute, skip this candidate.
2073+
if (defaultName.getBaseName() != typeDecl->getName() &&
2074+
!witnessHasImplementsAttrForRequiredName(typeDecl, assocType))
2075+
continue;
2076+
20332077
// Determine the witness type.
20342078
Type witnessType = getWitnessTypeForMatching(conformance, typeDecl);
20352079
if (!witnessType) continue;

lib/Sema/TypeCheckAttr.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3885,7 +3885,7 @@ void AttributeChecker::visitImplementsAttr(ImplementsAttr *attr) {
38853885
// conforms to the specified protocol.
38863886
NominalTypeDecl *NTD = DC->getSelfNominalTypeDecl();
38873887
if (auto *OtherPD = dyn_cast<ProtocolDecl>(NTD)) {
3888-
if (!OtherPD->inheritsFrom(PD) &&
3888+
if (!(OtherPD == PD || OtherPD->inheritsFrom(PD)) &&
38893889
!(OtherPD->isSpecificProtocol(KnownProtocolKind::DistributedActor) ||
38903890
PD->isSpecificProtocol(KnownProtocolKind::Actor))) {
38913891
diagnose(attr->getLocation(),

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,18 +1281,18 @@ swift::matchWitness(WitnessChecker::RequirementEnvironmentCache &reqEnvCache,
12811281
return matchWitness(dc, req, witness, setup, matchTypes, finalize);
12821282
}
12831283

1284-
static bool
1285-
witnessHasImplementsAttrForRequiredName(ValueDecl *witness,
1286-
ValueDecl *requirement) {
1284+
bool
1285+
swift::witnessHasImplementsAttrForRequiredName(ValueDecl *witness,
1286+
ValueDecl *requirement) {
12871287
if (auto A = witness->getAttrs().getAttribute<ImplementsAttr>()) {
12881288
return A->getMemberName() == requirement->getName();
12891289
}
12901290
return false;
12911291
}
12921292

1293-
static bool
1294-
witnessHasImplementsAttrForExactRequirement(ValueDecl *witness,
1295-
ValueDecl *requirement) {
1293+
bool
1294+
swift::witnessHasImplementsAttrForExactRequirement(ValueDecl *witness,
1295+
ValueDecl *requirement) {
12961296
assert(requirement->isProtocolRequirement());
12971297
auto *PD = cast<ProtocolDecl>(requirement->getDeclContext());
12981298
if (auto A = witness->getAttrs().getAttribute<ImplementsAttr>()) {

lib/Sema/TypeCheckProtocol.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,16 @@ AssociatedTypeDecl *findDefaultedAssociatedType(
223223
DeclContext *dc, NominalTypeDecl *adoptee,
224224
AssociatedTypeDecl *assocType);
225225

226+
/// Determine whether this witness has an `@_implements` attribute whose
227+
/// name matches that of the given requirement.
228+
bool witnessHasImplementsAttrForRequiredName(ValueDecl *witness,
229+
ValueDecl *requirement);
230+
231+
/// Determine whether this witness has an `@_implements` attribute whose name
232+
/// and protocol match that of the requirement exactly.
233+
bool witnessHasImplementsAttrForExactRequirement(ValueDecl *witness,
234+
ValueDecl *requirement);
235+
226236
}
227237

228238
#endif // SWIFT_SEMA_PROTOCOL_H

stdlib/public/Distributed/DistributedActor.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,4 +448,4 @@ public protocol _DistributedActorStub where Self: DistributedActor {}
448448
@available(SwiftStdlib 6.0, *)
449449
public func _distributedStubFatalError(function: String = #function) -> Never {
450450
fatalError("Unexpected invocation of distributed method '\(function)' stub!")
451-
}
451+
}

stdlib/public/core/Sequence.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ public protocol Sequence<Element> {
453453
// Provides a default associated type witness for Iterator when the
454454
// Self type is both a Sequence and an Iterator.
455455
extension Sequence where Self: IteratorProtocol {
456-
// @_implements(Sequence, Iterator)
456+
@_implements(Sequence, Iterator)
457457
public typealias _Default_Iterator = Self
458458
}
459459

test/Concurrency/async_iterator_inference.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,11 @@ case fail
7676

7777
@available(SwiftStdlib 5.1, *)
7878
func testAssocTypeInference(sf: S.Failure, tsf: TS.Failure, gtsf1: GenericTS<MyError>.Failure, adapter: SequenceAdapter<SpecificTS<MyError>>.Failure, ntas: NormalThrowingAsyncSequence<String, MyError>.Failure) {
79-
let _: Int = sf // expected-error{{cannot convert value of type 'S.Failure' (aka 'Never') to specified type 'Int'}}
80-
let _: Int = tsf // expected-error{{cannot convert value of type 'TS.Failure' (aka 'any Error') to specified type 'Int'}}
81-
let _: Int = gtsf1 // expected-error{{cannot convert value of type 'GenericTS<MyError>.Failure' (aka 'any Error') to specified type 'Int'}}
82-
let _: Int = adapter // expected-error{{cannot convert value of type 'SequenceAdapter<SpecificTS<MyError>>.Failure' (aka 'MyError') to specified type 'Int'}}
83-
let _: Int = ntas // expected-error{{cannot convert value of type 'NormalThrowingAsyncSequence<String, MyError>.Failure' (aka 'any Error') to specified type 'Int'}}
79+
let _: Int = sf // expected-error{{cannot convert value of type 'S.__AsyncSequence_Failure' (aka 'Never') to specified type 'Int'}}
80+
let _: Int = tsf // expected-error{{cannot convert value of type 'TS.__AsyncSequence_Failure' (aka 'any Error') to specified type 'Int'}}
81+
let _: Int = gtsf1 // expected-error{{cannot convert value of type 'GenericTS<MyError>.__AsyncSequence_Failure' (aka 'any Error') to specified type 'Int'}}
82+
let _: Int = adapter // expected-error{{cannot convert value of type 'SequenceAdapter<SpecificTS<MyError>>.__AsyncSequence_Failure' (aka 'MyError') to specified type 'Int'}}
83+
let _: Int = ntas // expected-error{{cannot convert value of type 'NormalThrowingAsyncSequence<String, MyError>.__AsyncSequence_Failure' (aka 'any Error') to specified type 'Int'}}
8484
}
8585

8686

test/ModuleInterface/async_sequence_conformance.swift

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,21 @@
44

55
// REQUIRES: concurrency, OS=macosx
66

7-
// CHECK: @available(
8-
// CHECK-NEXT: public struct SequenceAdapte
7+
// CHECK: public struct SequenceAdapte
98
@available(SwiftStdlib 5.1, *)
109
public struct SequenceAdapter<Base: AsyncSequence>: AsyncSequence {
1110
// CHECK-LABEL: public struct AsyncIterator
1211
// CHECK: @available{{.*}}macOS 10.15
1312
// CHECK-NEXT: public typealias Element = Base.Element
13+
14+
// CHECK: #if compiler(>=5.3) && $AssociatedTypeImplements
1415
// CHECK: @available(
15-
// CHECK-NEXT: public typealias Failure = Base.Failure
16+
// CHECK: @_implements(_Concurrency.AsyncIteratorProtocol, Failure)
17+
// CHECK-SAME: public typealias __AsyncIteratorProtocol_Failure = Base.Failure
18+
// CHECK-NEXT: #else
19+
// CHECK-NOT: @_implements
20+
// CHECK: public typealias __AsyncIteratorProtocol_Failure = Base.Failure
21+
// CHECK-NEXT: #endif
1622
public typealias Element = Base.Element
1723

1824
public struct AsyncIterator: AsyncIteratorProtocol {
@@ -23,11 +29,11 @@ public struct SequenceAdapter<Base: AsyncSequence>: AsyncSequence {
2329
public func makeAsyncIterator() -> AsyncIterator { AsyncIterator() }
2430

2531
// CHECK: @available(
26-
// CHECK-NEXT: public typealias Failure = Base.Failure
32+
// CHECK: @_implements(_Concurrency.AsyncSequence, Failure)
33+
// CHECK-SAME: public typealias __AsyncSequence_Failure = Base.Failure
2734
}
2835

29-
// CHECK: @available(
30-
// CHECK-NEXT: public struct OtherSequenceAdapte
36+
// CHECK: public struct OtherSequenceAdapte
3137
@available(SwiftStdlib 5.1, *)
3238
public struct OtherSequenceAdapter<Base: AsyncSequence>: AsyncSequence {
3339
// CHECK: public typealias Element = Base.Element
@@ -37,7 +43,8 @@ public struct OtherSequenceAdapter<Base: AsyncSequence>: AsyncSequence {
3743
// CHECK-LABEL: public struct AsyncIterator
3844
// CHECK: @available{{.*}}macOS 10.15
3945
// CHECK: @available(
40-
// CHECK-NEXT: public typealias Failure = Base.Failure
46+
// CHECK: @_implements(_Concurrency.AsyncIteratorProtocol, Failure)
47+
// CHECK-SAME: public typealias __AsyncIteratorProtocol_Failure = Base.Failure
4148
public typealias Element = Base.Element
4249

4350
public struct Failure: Error { }
@@ -52,3 +59,16 @@ public struct OtherSequenceAdapter<Base: AsyncSequence>: AsyncSequence {
5259

5360
// CHECK-NOT: public typealias Failure
5461
}
62+
63+
// CHECK: public struct MineOwnIterator
64+
@available(SwiftStdlib 5.1, *)
65+
public struct MineOwnIterator<Element>: AsyncSequence, AsyncIteratorProtocol {
66+
public mutating func next() async -> Element? { nil }
67+
public func makeAsyncIterator() -> Self { self }
68+
69+
// CHECK: @_implements(_Concurrency.AsyncIteratorProtocol, Failure)
70+
// CHECK-SAME: public typealias __AsyncIteratorProtocol_Failure = Swift.Never
71+
72+
// CHECK: @_implements(_Concurrency.AsyncSequence, Failure)
73+
// CHECK-SAME: public typealias __AsyncSequence_Failure = Swift.Never
74+
}

test/attr/attr_implements.swift

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,13 @@ func falseWhenSpecificType(_ x: SpecificType) -> Bool { return x == x }
9797

9898
assert(trueWhenJustEquatable(SpecificType()))
9999
assert(!falseWhenSpecificType(SpecificType()))
100+
101+
// @_implements on associated types
102+
protocol PWithAssoc {
103+
associatedtype A
104+
}
105+
106+
struct XWithAssoc: PWithAssoc {
107+
@_implements(PWithAssoc, A)
108+
typealias __P_A = Int
109+
}

0 commit comments

Comments
 (0)