Skip to content

Commit 8a5279c

Browse files
committed
merge: optimize task creation and remove sendable constraint
1 parent 68c8dc2 commit 8a5279c

11 files changed

+636
-461
lines changed

Package.swift

+7-2
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,14 @@ let package = Package(
1616
.library(name: "_CAsyncSequenceValidationSupport", type: .static, targets: ["AsyncSequenceValidation"]),
1717
.library(name: "AsyncAlgorithms_XCTest", targets: ["AsyncAlgorithms_XCTest"]),
1818
],
19-
dependencies: [],
19+
dependencies: [
20+
.package(url: "https://github.com/apple/swift-collections.git", .upToNextMajor(from: "1.0.3"))
21+
],
2022
targets: [
21-
.target(name: "AsyncAlgorithms"),
23+
.target(
24+
name: "AsyncAlgorithms",
25+
dependencies: [.product(name: "Collections", package: "swift-collections")]
26+
),
2227
.target(
2328
name: "AsyncSequenceValidation",
2429
dependencies: ["_CAsyncSequenceValidationSupport", "AsyncAlgorithms"]),

Sources/AsyncAlgorithms/AsyncChannel.swift

+3-3
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
2828
init(_ channel: AsyncChannel<Element>) {
2929
self.channel = channel
3030
}
31-
31+
3232
/// Await the next sent element or finish.
3333
public mutating func next() async -> Element? {
3434
guard active else {
@@ -116,7 +116,7 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
116116

117117
/// Create a new `AsyncChannel` given an element type.
118118
public init(element elementType: Element.Type = Element.self) { }
119-
119+
120120
func establish() -> Int {
121121
state.withCriticalRegion { state in
122122
defer { state.generation &+= 1 }
@@ -152,7 +152,7 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
152152
}
153153
return UnsafeResumption(continuation: send, success: continuation)
154154
case .awaiting(var nexts):
155-
if nexts.update(with: Awaiting(generation: generation, continuation: continuation)) != nil {
155+
if nexts.update(with: Awaiting(generation: generation, continuation: continuation)) != nil {
156156
nexts.remove(Awaiting(placeholder: generation))
157157
cancelled = true
158158
}

Sources/AsyncAlgorithms/AsyncChunksOfCountOrSignalSequence.swift

+38-22
Original file line numberDiff line numberDiff line change
@@ -60,34 +60,50 @@ public struct AsyncChunksOfCountOrSignalSequence<Base: AsyncSequence, Collected:
6060

6161
public typealias Element = Collected
6262

63+
enum Either {
64+
case first(Base.Element)
65+
case second(Signal.Element)
66+
}
67+
6368
/// The iterator for a `AsyncChunksOfCountOrSignalSequence` instance.
6469
public struct Iterator: AsyncIteratorProtocol, Sendable {
6570
let count: Int?
66-
var state: Merge2StateMachine<Base, Signal>
67-
init(base: Base.AsyncIterator, count: Int?, signal: Signal.AsyncIterator) {
71+
let state: MergeStateMachine<Either>
72+
init(base: Base, count: Int?, signal: Signal) {
6873
self.count = count
69-
self.state = Merge2StateMachine(base, terminatesOnNil: true, signal)
74+
let eitherBase = base.map { Either.first($0) }
75+
let eitherSignal = signal.map { Either.second($0) }
76+
self.state = MergeStateMachine(eitherBase, terminatesOnNil: true, eitherSignal)
7077
}
71-
78+
7279
public mutating func next() async rethrows -> Collected? {
73-
var result : Collected?
74-
while let next = try await state.next() {
75-
switch next {
76-
case .first(let element):
77-
if result == nil {
78-
result = Collected()
79-
}
80-
result!.append(element)
81-
if result?.count == count {
82-
return result
83-
}
84-
case .second(_):
85-
if result != nil {
86-
return result
87-
}
88-
}
80+
var collected: Collected?
81+
82+
loop: while true {
83+
let next = await state.next()
84+
85+
switch next {
86+
case .termination:
87+
break loop
88+
case .element(let result):
89+
let element = try result._rethrowGet()
90+
switch element {
91+
case .first(let element):
92+
if collected == nil {
93+
collected = Collected()
94+
}
95+
collected!.append(element)
96+
if collected?.count == count {
97+
return collected
98+
}
99+
case .second(_):
100+
if collected != nil {
101+
return collected
102+
}
103+
}
89104
}
90-
return result
105+
}
106+
return collected
91107
}
92108
}
93109

@@ -105,6 +121,6 @@ public struct AsyncChunksOfCountOrSignalSequence<Base: AsyncSequence, Collected:
105121
}
106122

107123
public func makeAsyncIterator() -> Iterator {
108-
return Iterator(base: base.makeAsyncIterator(), count: count, signal: signal.makeAsyncIterator())
124+
return Iterator(base: base, count: count, signal: signal)
109125
}
110126
}

Sources/AsyncAlgorithms/AsyncMerge2Sequence.swift

+39-180
Original file line numberDiff line numberDiff line change
@@ -10,200 +10,59 @@
1010
//===----------------------------------------------------------------------===//
1111

1212
/// Creates an asynchronous sequence of elements from two underlying asynchronous sequences
13-
public func merge<Base1: AsyncSequence, Base2: AsyncSequence>(_ base1: Base1, _ base2: Base2) -> AsyncMerge2Sequence<Base1, Base2>
14-
where
15-
Base1.Element == Base2.Element,
16-
Base1: Sendable, Base2: Sendable,
17-
Base1.Element: Sendable,
18-
Base1.AsyncIterator: Sendable, Base2.AsyncIterator: Sendable {
19-
return AsyncMerge2Sequence(base1, base2)
20-
}
21-
22-
struct Merge2StateMachine<Base1: AsyncSequence, Base2: AsyncSequence>: Sendable where Base1.AsyncIterator: Sendable, Base2.AsyncIterator: Sendable, Base1.Element: Sendable, Base2.Element: Sendable {
23-
typealias Element1 = Base1.Element
24-
typealias Element2 = Base2.Element
25-
26-
let iter1TerminatesOnNil : Bool
27-
let iter2terminatesOnNil : Bool
28-
29-
enum Partial: @unchecked Sendable {
30-
case first(Result<Element1?, Error>, Base1.AsyncIterator)
31-
case second(Result<Element2?, Error>, Base2.AsyncIterator)
32-
}
33-
34-
enum Either {
35-
case first(Base1.Element)
36-
case second(Base2.Element)
37-
}
38-
39-
var state: (PartialIteration<Base1.AsyncIterator, Partial>, PartialIteration<Base2.AsyncIterator, Partial>)
40-
41-
init(_ iterator1: Base1.AsyncIterator, terminatesOnNil iter1TerminatesOnNil: Bool = false, _ iterator2: Base2.AsyncIterator, terminatesOnNil iter2terminatesOnNil: Bool = false) {
42-
self.iter1TerminatesOnNil = iter1TerminatesOnNil
43-
self.iter2terminatesOnNil = iter2terminatesOnNil
44-
state = (.idle(iterator1), .idle(iterator2))
45-
}
46-
47-
mutating func apply(_ task1: Task<Merge2StateMachine<Base1, Base2>.Partial, Never>?, _ task2: Task<Merge2StateMachine<Base1, Base2>.Partial, Never>?) async rethrows -> Either? {
48-
switch await Task.select([task1, task2].compactMap ({ $0 })).value {
49-
case .first(let result, let iterator):
50-
do {
51-
guard let value = try state.0.resolve(result, iterator) else {
52-
if iter1TerminatesOnNil {
53-
state.1.cancel()
54-
return nil
55-
}
56-
return try await next()
57-
}
58-
return .first(value)
59-
} catch {
60-
state.1.cancel()
61-
throw error
62-
}
63-
case .second(let result, let iterator):
64-
do {
65-
guard let value = try state.1.resolve(result, iterator) else {
66-
if iter2terminatesOnNil {
67-
state.0.cancel()
68-
return nil
69-
}
70-
return try await next()
71-
}
72-
return .second(value)
73-
} catch {
74-
state.0.cancel()
75-
throw error
76-
}
77-
}
78-
}
79-
80-
func first(_ iterator1: Base1.AsyncIterator) -> Task<Partial, Never> {
81-
Task {
82-
var iter = iterator1
83-
do {
84-
let value = try await iter.next()
85-
return .first(.success(value), iter)
86-
} catch {
87-
return .first(.failure(error), iter)
88-
}
89-
}
90-
}
91-
92-
func second(_ iterator2: Base2.AsyncIterator) -> Task<Partial, Never> {
93-
Task {
94-
var iter = iterator2
95-
do {
96-
let value = try await iter.next()
97-
return .second(.success(value), iter)
98-
} catch {
99-
return .second(.failure(error), iter)
100-
}
101-
}
102-
}
103-
104-
/// Advances to the next element and returns it or `nil` if no next element exists.
105-
mutating func next() async rethrows -> Either? {
106-
if Task.isCancelled {
107-
state.0.cancel()
108-
state.1.cancel()
109-
return nil
110-
}
111-
switch state {
112-
case (.idle(let iterator1), .idle(let iterator2)):
113-
let task1 = first(iterator1)
114-
let task2 = second(iterator2)
115-
state = (.pending(task1), .pending(task2))
116-
return try await apply(task1, task2)
117-
case (.idle(let iterator1), .pending(let task2)):
118-
let task1 = first(iterator1)
119-
state = (.pending(task1), .pending(task2))
120-
return try await apply(task1, task2)
121-
case (.pending(let task1), .idle(let iterator2)):
122-
let task2 = second(iterator2)
123-
state = (.pending(task1), .pending(task2))
124-
return try await apply(task1, task2)
125-
case (.idle(var iterator1), .terminal):
126-
do {
127-
if let value = try await iterator1.next() {
128-
state = (.idle(iterator1), .terminal)
129-
return .first(value)
130-
} else {
131-
state = (.terminal, .terminal)
132-
return nil
133-
}
134-
} catch {
135-
state = (.terminal, .terminal)
136-
throw error
137-
}
138-
case (.terminal, .idle(var iterator2)):
139-
do {
140-
if let value = try await iterator2.next() {
141-
state = (.terminal, .idle(iterator2))
142-
return .second(value)
143-
} else {
144-
state = (.terminal, .terminal)
145-
return nil
146-
}
147-
} catch {
148-
state = (.terminal, .terminal)
149-
throw error
150-
}
151-
case (.terminal, .pending(let task2)):
152-
return try await apply(nil, task2)
153-
case (.pending(let task1), .pending(let task2)):
154-
return try await apply(task1, task2)
155-
case (.pending(let task1), .terminal):
156-
return try await apply(task1, nil)
157-
case (.terminal, .terminal):
158-
return nil
159-
}
160-
}
161-
}
162-
163-
extension Merge2StateMachine.Either where Base1.Element == Base2.Element {
164-
var value : Base1.Element {
165-
switch self {
166-
case .first(let val):
167-
return val
168-
case .second(let val):
169-
return val
170-
}
171-
}
13+
public func merge<Base1: AsyncSequence, Base2: AsyncSequence>(
14+
_ base1: Base1,
15+
_ base2: Base2
16+
) -> AsyncMerge2Sequence<Base1, Base2>{
17+
AsyncMerge2Sequence(base1, base2)
17218
}
17319

17420
/// An asynchronous sequence of elements from two underlying asynchronous sequences
17521
///
17622
/// In a `AsyncMerge2Sequence` instance, the *i*th element is the *i*th element
17723
/// resolved in sequential order out of the two underlying asynchronous sequences.
17824
/// Use the `merge(_:_:)` function to create an `AsyncMerge2Sequence`.
179-
public struct AsyncMerge2Sequence<Base1: AsyncSequence, Base2: AsyncSequence>: AsyncSequence, Sendable
180-
where
181-
Base1.Element == Base2.Element,
182-
Base1: Sendable, Base2: Sendable,
183-
Base1.Element: Sendable,
184-
Base1.AsyncIterator: Sendable, Base2.AsyncIterator: Sendable {
25+
public struct AsyncMerge2Sequence<Base1: AsyncSequence, Base2: AsyncSequence>: AsyncSequence
26+
where Base1.Element == Base2.Element {
18527
public typealias Element = Base1.Element
186-
/// An iterator for `AsyncMerge2Sequence`
187-
public struct Iterator: AsyncIteratorProtocol, Sendable {
188-
var state: Merge2StateMachine<Base1, Base2>
189-
init(_ base1: Base1.AsyncIterator, _ base2: Base2.AsyncIterator) {
190-
state = Merge2StateMachine(base1, base2)
191-
}
28+
public typealias AsyncIterator = Iterator
19229

193-
public mutating func next() async rethrows -> Element? {
194-
return try await state.next()?.value
195-
}
196-
}
197-
19830
let base1: Base1
19931
let base2: Base2
200-
201-
init(_ base1: Base1, _ base2: Base2) {
32+
33+
public init(_ base1: Base1, _ base2: Base2) {
20234
self.base1 = base1
20335
self.base2 = base2
20436
}
205-
37+
20638
public func makeAsyncIterator() -> Iterator {
207-
return Iterator(base1.makeAsyncIterator(), base2.makeAsyncIterator())
39+
Iterator(
40+
base1: self.base1,
41+
base2: self.base2
42+
)
43+
}
44+
45+
public struct Iterator: AsyncIteratorProtocol {
46+
let mergeStateMachine: MergeStateMachine<Element>
47+
48+
init(base1: Base1, base2: Base2) {
49+
self.mergeStateMachine = MergeStateMachine(
50+
base1,
51+
base2
52+
)
53+
}
54+
55+
public mutating func next() async rethrows -> Element? {
56+
let mergedElement = await self.mergeStateMachine.next()
57+
switch mergedElement {
58+
case .element(let result):
59+
return try result._rethrowGet()
60+
case .termination:
61+
return nil
62+
}
63+
}
20864
}
20965
}
66+
67+
extension AsyncMerge2Sequence: Sendable where Base1: Sendable, Base2: Sendable {}
68+
extension AsyncMerge2Sequence.Iterator: Sendable where Base1: Sendable, Base2: Sendable {}

0 commit comments

Comments
 (0)