diff --git a/Package.swift b/Package.swift index c43747c0..fc920767 100644 --- a/Package.swift +++ b/Package.swift @@ -16,9 +16,14 @@ let package = Package( .library(name: "_CAsyncSequenceValidationSupport", type: .static, targets: ["AsyncSequenceValidation"]), .library(name: "AsyncAlgorithms_XCTest", targets: ["AsyncAlgorithms_XCTest"]), ], - dependencies: [], + dependencies: [ + .package(url: "https://github.com/apple/swift-collections.git", .upToNextMajor(from: "1.0.3")) + ], targets: [ - .target(name: "AsyncAlgorithms"), + .target( + name: "AsyncAlgorithms", + dependencies: [.product(name: "Collections", package: "swift-collections")] + ), .target( name: "AsyncSequenceValidation", dependencies: ["_CAsyncSequenceValidationSupport", "AsyncAlgorithms"]), diff --git a/Sources/AsyncAlgorithms/AsyncChannel.swift b/Sources/AsyncAlgorithms/AsyncChannel.swift index facdaadf..5d40b014 100644 --- a/Sources/AsyncAlgorithms/AsyncChannel.swift +++ b/Sources/AsyncAlgorithms/AsyncChannel.swift @@ -28,7 +28,7 @@ public final class AsyncChannel: AsyncSequence, Sendable { init(_ channel: AsyncChannel) { self.channel = channel } - + /// Await the next sent element or finish. public mutating func next() async -> Element? { guard active else { @@ -116,7 +116,7 @@ public final class AsyncChannel: AsyncSequence, Sendable { /// Create a new `AsyncChannel` given an element type. public init(element elementType: Element.Type = Element.self) { } - + func establish() -> Int { state.withCriticalRegion { state in defer { state.generation &+= 1 } @@ -152,7 +152,7 @@ public final class AsyncChannel: AsyncSequence, Sendable { } return UnsafeResumption(continuation: send, success: continuation) case .awaiting(var nexts): - if nexts.update(with: Awaiting(generation: generation, continuation: continuation)) != nil { + if nexts.update(with: Awaiting(generation: generation, continuation: continuation)) != nil { nexts.remove(Awaiting(placeholder: generation)) cancelled = true } diff --git a/Sources/AsyncAlgorithms/AsyncChunksOfCountOrSignalSequence.swift b/Sources/AsyncAlgorithms/AsyncChunksOfCountOrSignalSequence.swift index e5b167c7..e3167721 100644 --- a/Sources/AsyncAlgorithms/AsyncChunksOfCountOrSignalSequence.swift +++ b/Sources/AsyncAlgorithms/AsyncChunksOfCountOrSignalSequence.swift @@ -60,34 +60,50 @@ public struct AsyncChunksOfCountOrSignalSequence - init(base: Base.AsyncIterator, count: Int?, signal: Signal.AsyncIterator) { + let state: MergeStateMachine + init(base: Base, count: Int?, signal: Signal) { self.count = count - self.state = Merge2StateMachine(base, terminatesOnNil: true, signal) + let eitherBase = base.map { Either.first($0) } + let eitherSignal = signal.map { Either.second($0) } + self.state = MergeStateMachine(eitherBase, terminatesOnNil: true, eitherSignal) } - + public mutating func next() async rethrows -> Collected? { - var result : Collected? - while let next = try await state.next() { - switch next { - case .first(let element): - if result == nil { - result = Collected() - } - result!.append(element) - if result?.count == count { - return result - } - case .second(_): - if result != nil { - return result - } - } + var collected: Collected? + + loop: while true { + let next = await state.next() + + switch next { + case .termination: + break loop + case .element(let result): + let element = try result._rethrowGet() + switch element { + case .first(let element): + if collected == nil { + collected = Collected() + } + collected!.append(element) + if collected?.count == count { + return collected + } + case .second(_): + if collected != nil { + return collected + } + } } - return result + } + return collected } } @@ -105,6 +121,6 @@ public struct AsyncChunksOfCountOrSignalSequence Iterator { - return Iterator(base: base.makeAsyncIterator(), count: count, signal: signal.makeAsyncIterator()) + return Iterator(base: base, count: count, signal: signal) } } diff --git a/Sources/AsyncAlgorithms/AsyncMerge2Sequence.swift b/Sources/AsyncAlgorithms/AsyncMerge2Sequence.swift index eeaf0246..9175b66d 100644 --- a/Sources/AsyncAlgorithms/AsyncMerge2Sequence.swift +++ b/Sources/AsyncAlgorithms/AsyncMerge2Sequence.swift @@ -10,165 +10,11 @@ //===----------------------------------------------------------------------===// /// Creates an asynchronous sequence of elements from two underlying asynchronous sequences -public func merge(_ base1: Base1, _ base2: Base2) -> AsyncMerge2Sequence -where - Base1.Element == Base2.Element, - Base1: Sendable, Base2: Sendable, - Base1.Element: Sendable, - Base1.AsyncIterator: Sendable, Base2.AsyncIterator: Sendable { - return AsyncMerge2Sequence(base1, base2) -} - -struct Merge2StateMachine: Sendable where Base1.AsyncIterator: Sendable, Base2.AsyncIterator: Sendable, Base1.Element: Sendable, Base2.Element: Sendable { - typealias Element1 = Base1.Element - typealias Element2 = Base2.Element - - let iter1TerminatesOnNil : Bool - let iter2terminatesOnNil : Bool - - enum Partial: @unchecked Sendable { - case first(Result, Base1.AsyncIterator) - case second(Result, Base2.AsyncIterator) - } - - enum Either { - case first(Base1.Element) - case second(Base2.Element) - } - - var state: (PartialIteration, PartialIteration) - - init(_ iterator1: Base1.AsyncIterator, terminatesOnNil iter1TerminatesOnNil: Bool = false, _ iterator2: Base2.AsyncIterator, terminatesOnNil iter2terminatesOnNil: Bool = false) { - self.iter1TerminatesOnNil = iter1TerminatesOnNil - self.iter2terminatesOnNil = iter2terminatesOnNil - state = (.idle(iterator1), .idle(iterator2)) - } - - mutating func apply(_ task1: Task.Partial, Never>?, _ task2: Task.Partial, Never>?) async rethrows -> Either? { - switch await Task.select([task1, task2].compactMap ({ $0 })).value { - case .first(let result, let iterator): - do { - guard let value = try state.0.resolve(result, iterator) else { - if iter1TerminatesOnNil { - state.1.cancel() - return nil - } - return try await next() - } - return .first(value) - } catch { - state.1.cancel() - throw error - } - case .second(let result, let iterator): - do { - guard let value = try state.1.resolve(result, iterator) else { - if iter2terminatesOnNil { - state.0.cancel() - return nil - } - return try await next() - } - return .second(value) - } catch { - state.0.cancel() - throw error - } - } - } - - func first(_ iterator1: Base1.AsyncIterator) -> Task { - Task { - var iter = iterator1 - do { - let value = try await iter.next() - return .first(.success(value), iter) - } catch { - return .first(.failure(error), iter) - } - } - } - - func second(_ iterator2: Base2.AsyncIterator) -> Task { - Task { - var iter = iterator2 - do { - let value = try await iter.next() - return .second(.success(value), iter) - } catch { - return .second(.failure(error), iter) - } - } - } - - /// Advances to the next element and returns it or `nil` if no next element exists. - mutating func next() async rethrows -> Either? { - if Task.isCancelled { - state.0.cancel() - state.1.cancel() - return nil - } - switch state { - case (.idle(let iterator1), .idle(let iterator2)): - let task1 = first(iterator1) - let task2 = second(iterator2) - state = (.pending(task1), .pending(task2)) - return try await apply(task1, task2) - case (.idle(let iterator1), .pending(let task2)): - let task1 = first(iterator1) - state = (.pending(task1), .pending(task2)) - return try await apply(task1, task2) - case (.pending(let task1), .idle(let iterator2)): - let task2 = second(iterator2) - state = (.pending(task1), .pending(task2)) - return try await apply(task1, task2) - case (.idle(var iterator1), .terminal): - do { - if let value = try await iterator1.next() { - state = (.idle(iterator1), .terminal) - return .first(value) - } else { - state = (.terminal, .terminal) - return nil - } - } catch { - state = (.terminal, .terminal) - throw error - } - case (.terminal, .idle(var iterator2)): - do { - if let value = try await iterator2.next() { - state = (.terminal, .idle(iterator2)) - return .second(value) - } else { - state = (.terminal, .terminal) - return nil - } - } catch { - state = (.terminal, .terminal) - throw error - } - case (.terminal, .pending(let task2)): - return try await apply(nil, task2) - case (.pending(let task1), .pending(let task2)): - return try await apply(task1, task2) - case (.pending(let task1), .terminal): - return try await apply(task1, nil) - case (.terminal, .terminal): - return nil - } - } -} - -extension Merge2StateMachine.Either where Base1.Element == Base2.Element { - var value : Base1.Element { - switch self { - case .first(let val): - return val - case .second(let val): - return val - } - } +public func merge( + _ base1: Base1, + _ base2: Base2 +) -> AsyncMerge2Sequence{ + AsyncMerge2Sequence(base1, base2) } /// An asynchronous sequence of elements from two underlying asynchronous sequences @@ -176,34 +22,47 @@ extension Merge2StateMachine.Either where Base1.Element == Base2.Element { /// In a `AsyncMerge2Sequence` instance, the *i*th element is the *i*th element /// resolved in sequential order out of the two underlying asynchronous sequences. /// Use the `merge(_:_:)` function to create an `AsyncMerge2Sequence`. -public struct AsyncMerge2Sequence: AsyncSequence, Sendable -where - Base1.Element == Base2.Element, - Base1: Sendable, Base2: Sendable, - Base1.Element: Sendable, - Base1.AsyncIterator: Sendable, Base2.AsyncIterator: Sendable { +public struct AsyncMerge2Sequence: AsyncSequence +where Base1.Element == Base2.Element { public typealias Element = Base1.Element - /// An iterator for `AsyncMerge2Sequence` - public struct Iterator: AsyncIteratorProtocol, Sendable { - var state: Merge2StateMachine - init(_ base1: Base1.AsyncIterator, _ base2: Base2.AsyncIterator) { - state = Merge2StateMachine(base1, base2) - } + public typealias AsyncIterator = Iterator - public mutating func next() async rethrows -> Element? { - return try await state.next()?.value - } - } - let base1: Base1 let base2: Base2 - - init(_ base1: Base1, _ base2: Base2) { + + public init(_ base1: Base1, _ base2: Base2) { self.base1 = base1 self.base2 = base2 } - + public func makeAsyncIterator() -> Iterator { - return Iterator(base1.makeAsyncIterator(), base2.makeAsyncIterator()) + Iterator( + base1: self.base1, + base2: self.base2 + ) + } + + public struct Iterator: AsyncIteratorProtocol { + let mergeStateMachine: MergeStateMachine + + init(base1: Base1, base2: Base2) { + self.mergeStateMachine = MergeStateMachine( + base1, + base2 + ) + } + + public mutating func next() async rethrows -> Element? { + let mergedElement = await self.mergeStateMachine.next() + switch mergedElement { + case .element(let result): + return try result._rethrowGet() + case .termination: + return nil + } + } } } + +extension AsyncMerge2Sequence: Sendable where Base1: Sendable, Base2: Sendable {} +extension AsyncMerge2Sequence.Iterator: Sendable where Base1: Sendable, Base2: Sendable {} diff --git a/Sources/AsyncAlgorithms/AsyncMerge3Sequence.swift b/Sources/AsyncAlgorithms/AsyncMerge3Sequence.swift index efbbf9a8..ca6cd800 100644 --- a/Sources/AsyncAlgorithms/AsyncMerge3Sequence.swift +++ b/Sources/AsyncAlgorithms/AsyncMerge3Sequence.swift @@ -10,14 +10,12 @@ //===----------------------------------------------------------------------===// /// Creates an asynchronous sequence of elements from three underlying asynchronous sequences -public func merge(_ base1: Base1, _ base2: Base2, _ base3: Base3) -> AsyncMerge3Sequence -where - Base1.Element == Base2.Element, - Base2.Element == Base3.Element, - Base1: Sendable, Base2: Sendable, Base3: Sendable, - Base1.Element: Sendable, - Base1.AsyncIterator: Sendable, Base2.AsyncIterator: Sendable, Base3.AsyncIterator: Sendable { - return AsyncMerge3Sequence(base1, base2, base3) +public func merge( + _ base1: Base1, + _ base2: Base2, + _ base3: Base3 +) -> AsyncMerge3Sequence { + AsyncMerge3Sequence(base1, base2, base3) } /// An asynchronous sequence of elements from three underlying asynchronous sequences @@ -25,260 +23,51 @@ where /// In a `AsyncMerge3Sequence` instance, the *i*th element is the *i*th element /// resolved in sequential order out of the two underlying asynchronous sequences. /// Use the `merge(_:_:_:)` function to create an `AsyncMerge3Sequence`. -public struct AsyncMerge3Sequence: AsyncSequence, Sendable -where - Base1.Element == Base2.Element, - Base2.Element == Base3.Element, - Base1: Sendable, Base2: Sendable, Base3: Sendable, - Base1.Element: Sendable, - Base1.AsyncIterator: Sendable, Base2.AsyncIterator: Sendable, Base3.AsyncIterator: Sendable { +public struct AsyncMerge3Sequence: AsyncSequence +where Base1.Element == Base2.Element, Base3.Element == Base2.Element { public typealias Element = Base1.Element - /// An iterator for `AsyncMerge3Sequence` - public struct Iterator: AsyncIteratorProtocol, Sendable { - enum Partial: @unchecked Sendable { - case first(Result, Base1.AsyncIterator) - case second(Result, Base2.AsyncIterator) - case third(Result, Base3.AsyncIterator) - } - - var state: (PartialIteration, PartialIteration, PartialIteration) - - init(_ iterator1: Base1.AsyncIterator, _ iterator2: Base2.AsyncIterator, _ iterator3: Base3.AsyncIterator) { - state = (.idle(iterator1), .idle(iterator2), .idle(iterator3)) - } - - mutating func apply(_ task1: Task?, _ task2: Task?, _ task3: Task?) async rethrows -> Element? { - switch await Task.select([task1, task2, task3].compactMap { $0 }).value { - case .first(let result, let iterator): - do { - guard let value = try state.0.resolve(result, iterator) else { - return try await next() - } - return value - } catch { - state.1.cancel() - state.2.cancel() - throw error - } - case .second(let result, let iterator): - do { - guard let value = try state.1.resolve(result, iterator) else { - return try await next() - } - return value - } catch { - state.0.cancel() - state.2.cancel() - throw error - } - case .third(let result, let iterator): - do { - guard let value = try state.2.resolve(result, iterator) else { - return try await next() - } - return value - } catch { - state.0.cancel() - state.1.cancel() - throw error - } - } - } - - func first(_ iterator1: Base1.AsyncIterator) -> Task { - Task { - var iter = iterator1 - do { - let value = try await iter.next() - return .first(.success(value), iter) - } catch { - return .first(.failure(error), iter) - } - } - } - - func second(_ iterator2: Base2.AsyncIterator) -> Task { - Task { - var iter = iterator2 - do { - let value = try await iter.next() - return .second(.success(value), iter) - } catch { - return .second(.failure(error), iter) - } - } - } - - func third(_ iterator3: Base3.AsyncIterator) -> Task { - Task { - var iter = iterator3 - do { - let value = try await iter.next() - return .third(.success(value), iter) - } catch { - return .third(.failure(error), iter) - } - } - } - - public mutating func next() async rethrows -> Element? { - // state must have either all terminal or at least 1 idle iterator - // state may not have a saturation of pending tasks - switch state { - // three idle - case (.idle(let iterator1), .idle(let iterator2), .idle(let iterator3)): - let task1 = first(iterator1) - let task2 = second(iterator2) - let task3 = third(iterator3) - state = (.pending(task1), .pending(task2), .pending(task3)) - return try await apply(task1, task2, task3) - // two idle - case (.idle(let iterator1), .idle(let iterator2), .pending(let task3)): - let task1 = first(iterator1) - let task2 = second(iterator2) - state = (.pending(task1), .pending(task2), .pending(task3)) - return try await apply(task1, task2, task3) - case (.idle(let iterator1), .pending(let task2), .idle(let iterator3)): - let task1 = first(iterator1) - let task3 = third(iterator3) - state = (.pending(task1), .pending(task2), .pending(task3)) - return try await apply(task1, task2, task3) - case (.pending(let task1), .idle(let iterator2), .idle(let iterator3)): - let task2 = second(iterator2) - let task3 = third(iterator3) - state = (.pending(task1), .pending(task2), .pending(task3)) - return try await apply(task1, task2, task3) - - // 1 idle - case (.idle(let iterator1), .pending(let task2), .pending(let task3)): - let task1 = first(iterator1) - state = (.pending(task1), .pending(task2), .pending(task3)) - return try await apply(task1, task2, task3) - case (.pending(let task1), .idle(let iterator2), .pending(let task3)): - let task2 = second(iterator2) - state = (.pending(task1), .pending(task2), .pending(task3)) - return try await apply(task1, task2, task3) - case (.pending(let task1), .pending(let task2), .idle(let iterator3)): - let task3 = third(iterator3) - state = (.pending(task1), .pending(task2), .pending(task3)) - return try await apply(task1, task2, task3) - - // terminal degradations - // 1 terminal - case (.terminal, .idle(let iterator2), .idle(let iterator3)): - let task2 = second(iterator2) - let task3 = third(iterator3) - state = (.terminal, .pending(task2), .pending(task3)) - return try await apply(nil, task2, task3) - case (.terminal, .idle(let iterator2), .pending(let task3)): - let task2 = second(iterator2) - state = (.terminal, .pending(task2), .pending(task3)) - return try await apply(nil, task2, task3) - case (.terminal, .pending(let task2), .idle(let iterator3)): - let task3 = third(iterator3) - state = (.terminal, .pending(task2), .pending(task3)) - return try await apply(nil, task2, task3) - case (.idle(let iterator1), .terminal, .idle(let iterator3)): - let task1 = first(iterator1) - let task3 = third(iterator3) - state = (.pending(task1), .terminal, .pending(task3)) - return try await apply(task1, nil, task3) - case (.idle(let iterator1), .terminal, .pending(let task3)): - let task1 = first(iterator1) - state = (.pending(task1), .terminal, .pending(task3)) - return try await apply(task1, nil, task3) - case (.pending(let task1), .terminal, .idle(let iterator3)): - let task3 = third(iterator3) - state = (.pending(task1), .terminal, .pending(task3)) - return try await apply(task1, nil, task3) - case (.idle(let iterator1), .idle(let iterator2), .terminal): - let task1 = first(iterator1) - let task2 = second(iterator2) - state = (.pending(task1), .pending(task2), .terminal) - return try await apply(task1, task2, nil) - case (.idle(let iterator1), .pending(let task2), .terminal): - let task1 = first(iterator1) - state = (.pending(task1), .pending(task2), .terminal) - return try await apply(task1, task2, nil) - case (.pending(let task1), .idle(let iterator2), .terminal): - let task2 = second(iterator2) - state = (.pending(task1), .pending(task2), .terminal) - return try await apply(task1, task2, nil) - - // 2 terminal - // these can be permuted in place since they don't need to run two or more tasks at once - case (.terminal, .terminal, .idle(var iterator3)): - do { - if let value = try await iterator3.next() { - state = (.terminal, .terminal, .idle(iterator3)) - return value - } else { - state = (.terminal, .terminal, .terminal) - return nil - } - } catch { - state = (.terminal, .terminal, .terminal) - throw error - } - case (.terminal, .idle(var iterator2), .terminal): - do { - if let value = try await iterator2.next() { - state = (.terminal, .idle(iterator2), .terminal) - return value - } else { - state = (.terminal, .terminal, .terminal) - return nil - } - } catch { - state = (.terminal, .terminal, .terminal) - throw error - } - case (.idle(var iterator1), .terminal, .terminal): - do { - if let value = try await iterator1.next() { - state = (.idle(iterator1), .terminal, .terminal) - return value - } else { - state = (.terminal, .terminal, .terminal) - return nil - } - } catch { - state = (.terminal, .terminal, .terminal) - throw error - } - // 3 terminal - case (.terminal, .terminal, .terminal): - return nil - // partials - case (.pending(let task1), .pending(let task2), .pending(let task3)): - return try await apply(task1, task2, task3) - case (.pending(let task1), .pending(let task2), .terminal): - return try await apply(task1, task2, nil) - case (.pending(let task1), .terminal, .pending(let task3)): - return try await apply(task1, nil, task3) - case (.terminal, .pending(let task2), .pending(let task3)): - return try await apply(nil, task2, task3) - case (.pending(let task1), .terminal, .terminal): - return try await apply(task1, nil, nil) - case (.terminal, .pending(let task2), .terminal): - return try await apply(nil, task2, nil) - case (.terminal, .terminal, .pending(let task3)): - return try await apply(nil, nil, task3) - } - } - } - + public typealias AsyncIterator = Iterator + let base1: Base1 let base2: Base2 let base3: Base3 - init(_ base1: Base1, _ base2: Base2, _ base3: Base3) { + public init(_ base1: Base1, _ base2: Base2, _ base3: Base3) { self.base1 = base1 self.base2 = base2 self.base3 = base3 } public func makeAsyncIterator() -> Iterator { - return Iterator(base1.makeAsyncIterator(), base2.makeAsyncIterator(), base3.makeAsyncIterator()) + Iterator( + base1: self.base1, + base2: self.base2, + base3: self.base3 + ) + } + + public struct Iterator: AsyncIteratorProtocol { + let mergeStateMachine: MergeStateMachine + + init(base1: Base1, base2: Base2, base3: Base3) { + self.mergeStateMachine = MergeStateMachine( + base1, + base2, + base3 + ) + } + + public mutating func next() async rethrows -> Element? { + let mergedElement = await self.mergeStateMachine.next() + switch mergedElement { + case .element(let result): + return try result._rethrowGet() + case .termination: + return nil + } + } } } + +extension AsyncMerge3Sequence: Sendable where Base1: Sendable, Base2: Sendable, Base3: Sendable {} +extension AsyncMerge3Sequence.Iterator: Sendable where Base1: Sendable, Base2: Sendable, Base3: Sendable {} diff --git a/Sources/AsyncAlgorithms/AsyncMergeSequence.swift b/Sources/AsyncAlgorithms/AsyncMergeSequence.swift new file mode 100644 index 00000000..20afaa85 --- /dev/null +++ b/Sources/AsyncAlgorithms/AsyncMergeSequence.swift @@ -0,0 +1,62 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Async Algorithms open source project +// +// Copyright (c) 2022 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +/// Creates an asynchronous sequence of elements from many underlying asynchronous sequences +public func merge( + _ bases: Base... +) -> AsyncMergeSequence{ + AsyncMergeSequence(bases) +} + +/// An asynchronous sequence of elements from many underlying asynchronous sequences +/// +/// In a `AsyncMergeSequence` instance, the *i*th element is the *i*th element +/// resolved in sequential order out of the two underlying asynchronous sequences. +/// Use the `merge(...)` function to create an `AsyncMergeSequence`. +public struct AsyncMergeSequence: AsyncSequence { + public typealias Element = Base.Element + public typealias AsyncIterator = Iterator + + let bases: [Base] + + public init(_ bases: [Base]) { + self.bases = bases + } + + public func makeAsyncIterator() -> Iterator { + Iterator( + bases: self.bases + ) + } + + public struct Iterator: AsyncIteratorProtocol { + let mergeStateMachine: MergeStateMachine + + init(bases: [Base]) { + self.mergeStateMachine = MergeStateMachine( + bases + ) + } + + public mutating func next() async rethrows -> Element? { + let mergedElement = await self.mergeStateMachine.next() + switch mergedElement { + case .element(let result): + return try result._rethrowGet() + case .termination: + return nil + } + } + } +} + +extension AsyncMergeSequence: Sendable where Base: Sendable {} +extension AsyncMergeSequence.Iterator: Sendable where Base: Sendable {} diff --git a/Sources/AsyncAlgorithms/MergeStateMachine.swift b/Sources/AsyncAlgorithms/MergeStateMachine.swift new file mode 100644 index 00000000..ca339f56 --- /dev/null +++ b/Sources/AsyncAlgorithms/MergeStateMachine.swift @@ -0,0 +1,272 @@ +// +// MergeStateMachine.swift +// +// +// Created by Thibault Wittemberg on 08/09/2022. +// + +import DequeModule + +struct MergeStateMachine: Sendable { + enum MergedElement { + case element(Result) + case termination + } + + enum BufferState { + case idle + case queued(Deque) + case awaiting(UnsafeContinuation) + case closed + } + + struct State { + var buffer: BufferState + var basesToTerminate: Int + } + + struct OnNextDecision { + let continuation: UnsafeContinuation + let mergedElement: MergedElement + } + + let requestNextRegulatedElements: @Sendable () -> Void + let state: ManagedCriticalState + let task: Task + + init( + _ base1: Base1, + terminatesOnNil base1TerminatesOnNil: Bool = false, + _ base2: Base2, + terminatesOnNil base2TerminatesOnNil: Bool = false + ) where Base1.Element == Element, Base2.Element == Element { + self.state = ManagedCriticalState(State(buffer: .idle, basesToTerminate: 2)) + + let regulator1 = Regulator(base1, onNextRegulatedElement: { [state] in Self.onNextRegulatedElement($0, state: state) }) + let regulator2 = Regulator(base2, onNextRegulatedElement: { [state] in Self.onNextRegulatedElement($0, state: state) }) + + self.requestNextRegulatedElements = { + regulator1.requestNextRegulatedElement() + regulator2.requestNextRegulatedElement() + } + + self.task = Task { + await withTaskGroup(of: Void.self) { group in + group.addTask { + await regulator1.iterate(terminatesOnNil: base1TerminatesOnNil) + } + + group.addTask { + await regulator2.iterate(terminatesOnNil: base2TerminatesOnNil) + } + } + } + } + + init( + _ base1: Base1, + terminatesOnNil base1TerminatesOnNil: Bool = false, + _ base2: Base2, + terminatesOnNil base2TerminatesOnNil: Bool = false, + _ base3: Base3, + terminatesOnNil base3TerminatesOnNil: Bool = false + ) where Base1.Element == Element, Base2.Element == Element, Base3.Element == Base1.Element { + self.state = ManagedCriticalState(State(buffer: .idle, basesToTerminate: 3)) + + let regulator1 = Regulator(base1, onNextRegulatedElement: { [state] in Self.onNextRegulatedElement($0, state: state) }) + let regulator2 = Regulator(base2, onNextRegulatedElement: { [state] in Self.onNextRegulatedElement($0, state: state) }) + let regulator3 = Regulator(base3, onNextRegulatedElement: { [state] in Self.onNextRegulatedElement($0, state: state) }) + + self.requestNextRegulatedElements = { + regulator1.requestNextRegulatedElement() + regulator2.requestNextRegulatedElement() + regulator3.requestNextRegulatedElement() + } + + self.task = Task { + await withTaskGroup(of: Void.self) { group in + group.addTask { + await regulator1.iterate(terminatesOnNil: base1TerminatesOnNil) + } + + group.addTask { + await regulator2.iterate(terminatesOnNil: base2TerminatesOnNil) + } + + group.addTask { + await regulator3.iterate(terminatesOnNil: base3TerminatesOnNil) + } + } + } + } + + init( + _ bases: [Base] + ) where Base.Element == Element { + self.state = ManagedCriticalState(State(buffer: .idle, basesToTerminate: bases.count)) + + var regulators = [Regulator]() + + for base in bases { + let regulator = Regulator(base, onNextRegulatedElement: { [state] in Self.onNextRegulatedElement($0, state: state) }) + regulators.append(regulator) + } + + let immutableRegulators = regulators + self.requestNextRegulatedElements = { + for regulator in immutableRegulators { + regulator.requestNextRegulatedElement() + } + } + + self.task = Task { + await withTaskGroup(of: Void.self) { group in + for regulators in immutableRegulators { + group.addTask { + await regulators.iterate(terminatesOnNil: false) + } + } + } + } + } + + @Sendable static func onNextRegulatedElement(_ element: RegulatedElement, state: ManagedCriticalState) { + let decision = state.withCriticalRegion { state -> OnNextDecision? in + switch (state.buffer, element) { + // when buffer is close + case (.closed, _): + return nil + + // when buffer is empty and available + case (.idle, .termination(let forcedTermination)) where forcedTermination == true: + state.basesToTerminate = 0 + state.buffer = .closed + return nil + case (.idle, .termination): + state.basesToTerminate -= 1 + if state.basesToTerminate == 0 { + state.buffer = .closed + } else { + state.buffer = .idle + } + return nil + case (.idle, .element(let result)): + state.buffer = .queued([.element(result)]) + return nil + + // when buffer is queued + case (.queued(var elements), .termination(let forcedTermination)) where forcedTermination == true: + elements.append(.termination) + state.buffer = .queued(elements) + return nil + case (.queued(var elements), .termination): + state.basesToTerminate -= 1 + if state.basesToTerminate == 0 { + elements.append(.termination) + state.buffer = .queued(elements) + } + return nil + case (.queued(var elements), .element(let result)): + elements.append(.element(result)) + state.buffer = .queued(elements) + return nil + + // when buffer is awaiting for base values + case (.awaiting(let continuation), .termination(let forcedTermination)) where forcedTermination == true: + state.basesToTerminate = 0 + state.buffer = .closed + return OnNextDecision(continuation: continuation, mergedElement: .termination) + case (.awaiting(let continuation), .termination): + state.basesToTerminate -= 1 + if state.basesToTerminate == 0 { + state.buffer = .closed + return OnNextDecision(continuation: continuation, mergedElement: .termination) + } else { + state.buffer = .awaiting(continuation) + return nil + } + case (.awaiting(let continuation), .element(.success(let element))): + state.buffer = .idle + return OnNextDecision(continuation: continuation, mergedElement: .element(.success(element))) + case (.awaiting(let continuation), .element(.failure(let error))): + state.buffer = .closed + return OnNextDecision(continuation: continuation, mergedElement: .element(.failure(error))) + } + } + + if let decision = decision { + decision.continuation.resume(returning: decision.mergedElement) + } + } + + @Sendable func unsuspendAndClearOnCancel() { + let continuation = self.state.withCriticalRegion { state -> UnsafeContinuation? in + switch state.buffer { + case .awaiting(let continuation): + state.basesToTerminate = 0 + state.buffer = .closed + return continuation + default: + state.basesToTerminate = 0 + state.buffer = .closed + return nil + } + } + + continuation?.resume(returning: .termination) + self.task.cancel() + } + + func next() async -> MergedElement { + await withTaskCancellationHandler { + self.unsuspendAndClearOnCancel() + } operation: { + self.requestNextRegulatedElements() + + let mergedElement = await withUnsafeContinuation { (continuation: UnsafeContinuation) in + let decision = self.state.withCriticalRegion { state -> OnNextDecision? in + switch state.buffer { + case .closed: + return OnNextDecision(continuation: continuation, mergedElement: .termination) + case .idle: + state.buffer = .awaiting(continuation) + return nil + case .queued(var elements): + guard let mergedElement = elements.popFirst() else { + assertionFailure("The buffer cannot by empty, it should be idle in this case") + return OnNextDecision(continuation: continuation, mergedElement: .termination) + } + switch mergedElement { + case .termination: + state.buffer = .closed + return OnNextDecision(continuation: continuation, mergedElement: .termination) + case .element(.success(let element)): + if elements.isEmpty { + state.buffer = .idle + } else { + state.buffer = .queued(elements) + } + return OnNextDecision(continuation: continuation, mergedElement: .element(.success(element))) + case .element(.failure(let error)): + state.buffer = .closed + return OnNextDecision(continuation: continuation, mergedElement: .element(.failure(error))) + } + case .awaiting: + assertionFailure("The next function cannot be called concurrently") + return OnNextDecision(continuation: continuation, mergedElement: .termination) + } + } + + if let decision = decision { + decision.continuation.resume(returning: decision.mergedElement) + } + } + + if case .termination = mergedElement, case .element(.failure) = mergedElement { + self.task.cancel() + } + + return mergedElement + } + } +} diff --git a/Sources/AsyncAlgorithms/Regulator.swift b/Sources/AsyncAlgorithms/Regulator.swift new file mode 100644 index 00000000..43aabb8a --- /dev/null +++ b/Sources/AsyncAlgorithms/Regulator.swift @@ -0,0 +1,133 @@ +// +// Regulator.swift +// +// +// Created by Thibault Wittemberg on 08/09/2022. +// + +enum RegulatedElement { + case termination(forcedTermination: Bool) + case element(result: Result) +} + +struct Regulator { + enum State { + case idle + case suspended(UnsafeContinuation) + case active + case finished + } + + enum IterationDecision { + case suspend + case resume(continuation: UnsafeContinuation, shouldExit: Bool) + } + + let base: Base + let state: ManagedCriticalState + let onNextRegulatedElement: @Sendable (RegulatedElement) -> Void + + init( + _ base: Base, + onNextRegulatedElement: @Sendable @escaping (RegulatedElement) -> Void + ) { + self.base = base + self.state = ManagedCriticalState(.idle) + self.onNextRegulatedElement = onNextRegulatedElement + } + + func unsuspendAndExitOnCancel() { + let continuation = state.withCriticalRegion { state -> UnsafeContinuation? in + switch state { + case .suspended(let continuation): + state = .finished + return continuation + default: + state = .finished + return nil + } + } + + continuation?.resume(returning: true) + } + + func iterate(terminatesOnNil: Bool) async { + await withTaskCancellationHandler { + self.unsuspendAndExitOnCancel() + } operation: { + + var mutableBase = base.makeAsyncIterator() + + do { + baseLoop: while true { + let shouldExit = await withUnsafeContinuation { (continuation: UnsafeContinuation) in + let decision = self.state.withCriticalRegion { state -> IterationDecision in + + switch state { + case .idle: + state = .suspended(continuation) + return .suspend + case .suspended(let continuation): + assertionFailure("Inconsistent state, the base is already suspended") + return .resume(continuation: continuation, shouldExit: true) + case .active: + return .resume(continuation: continuation, shouldExit: false) + case .finished: + return .resume(continuation: continuation, shouldExit: true) + } + } + + switch decision { + case .suspend: + break + case .resume(let continuation, let shouldExit): + continuation.resume(returning: shouldExit) + } + } + + if shouldExit { + // end the loop ... no more values from this base + break baseLoop + } + + let element = try await mutableBase.next() + + let regulatedElement = self.state.withCriticalRegion { state -> RegulatedElement in + switch element { + case .none: + state = .finished + return .termination(forcedTermination: terminatesOnNil) + case .some(let element): + state = .idle + return .element(result: .success(element)) + } + } + + self.onNextRegulatedElement(regulatedElement) + } + } catch { + self.state.withCriticalRegion { state in + state = .finished + } + self.onNextRegulatedElement(.element(result: .failure(error))) + } + } + } + + @Sendable func requestNextRegulatedElement() { + let continuation = self.state.withCriticalRegion { state -> UnsafeContinuation? in + switch state { + case .suspended(let continuation): + state = .active + return continuation + case .idle: + state = .active + return nil + case .active, .finished: + return nil + } + } + + continuation?.resume(returning: false) + } +} diff --git a/Tests/AsyncAlgorithmsTests/Performance/TestThroughput.swift b/Tests/AsyncAlgorithmsTests/Performance/TestThroughput.swift index 2d490cc9..c7d91fd4 100644 --- a/Tests/AsyncAlgorithmsTests/Performance/TestThroughput.swift +++ b/Tests/AsyncAlgorithmsTests/Performance/TestThroughput.swift @@ -49,6 +49,11 @@ final class TestThroughput: XCTestCase { merge($0, $1, $2) } } + func test_merge4() async { + await measureSequenceThroughput(firstOutput: 1, secondOutput: 2, thirdOutput: 3, fourthOutput: 4) { + merge($0, $1, $2, $3) + } + } func test_removeDuplicates() async { await measureSequenceThroughput(source: (1...).async) { $0.removeDuplicates() diff --git a/Tests/AsyncAlgorithmsTests/Performance/ThroughputMeasurement.swift b/Tests/AsyncAlgorithmsTests/Performance/ThroughputMeasurement.swift index f9356aae..e6c3af05 100644 --- a/Tests/AsyncAlgorithmsTests/Performance/ThroughputMeasurement.swift +++ b/Tests/AsyncAlgorithmsTests/Performance/ThroughputMeasurement.swift @@ -104,7 +104,7 @@ extension XCTestCase { self.wait(for: [exp], timeout: sampleTime * 2) } } - + public func measureSequenceThroughput(firstOutput: @autoclosure () -> Output, secondOutput: @autoclosure () -> Output, thirdOutput: @autoclosure () -> Output, _ sequenceBuilder: (InfiniteAsyncSequence, InfiniteAsyncSequence, InfiniteAsyncSequence) -> S) async where S: Sendable { let metric = _ThroughputMetric() let sampleTime: Double = 0.1 @@ -129,7 +129,40 @@ extension XCTestCase { iterTask.cancel() self.wait(for: [exp], timeout: sampleTime * 2) } -} + } + + public func measureSequenceThroughput( + firstOutput: @autoclosure () -> Output, + secondOutput: @autoclosure () -> Output, + thirdOutput: @autoclosure () -> Output, + fourthOutput: @autoclosure () -> Output, + _ sequenceBuilder: (InfiniteAsyncSequence, InfiniteAsyncSequence, InfiniteAsyncSequence, InfiniteAsyncSequence + ) -> S) async where S: Sendable { + let metric = _ThroughputMetric() + let sampleTime: Double = 0.1 + + measure(metrics: [metric]) { + let firstInfSeq = InfiniteAsyncSequence(value: firstOutput()) + let secondInfSeq = InfiniteAsyncSequence(value: secondOutput()) + let thirdInfSeq = InfiniteAsyncSequence(value: thirdOutput()) + let fourthInfSeq = InfiniteAsyncSequence(value: thirdOutput()) + let seq = sequenceBuilder(firstInfSeq, secondInfSeq, thirdInfSeq, fourthInfSeq) + + let exp = self.expectation(description: "Finished") + let iterTask = Task { + var eventCount = 0 + for try await _ in seq { + eventCount += 1 + } + metric.eventCount = eventCount + exp.fulfill() + return eventCount + } + usleep(UInt32(sampleTime * Double(USEC_PER_SEC))) + iterTask.cancel() + self.wait(for: [exp], timeout: sampleTime * 2) + } + } public func measureSequenceThroughput( source: Source, _ sequenceBuilder: (Source) -> S) async where S: Sendable, Source: Sendable { let metric = _ThroughputMetric() diff --git a/Tests/AsyncAlgorithmsTests/TestMerge.swift b/Tests/AsyncAlgorithmsTests/TestMerge.swift index 3cf7c577..ee9c78f4 100644 --- a/Tests/AsyncAlgorithmsTests/TestMerge.swift +++ b/Tests/AsyncAlgorithmsTests/TestMerge.swift @@ -84,6 +84,7 @@ final class TestMerge2: XCTestCase { } catch { XCTAssertEqual(Failure(), error as? Failure) } + let pastEnd = try await iterator.next() XCTAssertNil(pastEnd)