Skip to content

[Merge] optimize tasks creation and remove sendable constraint #193

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,14 @@ let package = Package(
.library(name: "_CAsyncSequenceValidationSupport", type: .static, targets: ["AsyncSequenceValidation"]),
.library(name: "AsyncAlgorithms_XCTest", targets: ["AsyncAlgorithms_XCTest"]),
],
dependencies: [],
dependencies: [
.package(url: "https://github.com/apple/swift-collections.git", .upToNextMajor(from: "1.0.3"))
],
targets: [
.target(name: "AsyncAlgorithms"),
.target(
name: "AsyncAlgorithms",
dependencies: [.product(name: "Collections", package: "swift-collections")]
),
.target(
name: "AsyncSequenceValidation",
dependencies: ["_CAsyncSequenceValidationSupport", "AsyncAlgorithms"]),
Expand Down
6 changes: 3 additions & 3 deletions Sources/AsyncAlgorithms/AsyncChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
init(_ channel: AsyncChannel<Element>) {
self.channel = channel
}

/// Await the next sent element or finish.
public mutating func next() async -> Element? {
guard active else {
Expand Down Expand Up @@ -116,7 +116,7 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {

/// Create a new `AsyncChannel` given an element type.
public init(element elementType: Element.Type = Element.self) { }

func establish() -> Int {
state.withCriticalRegion { state in
defer { state.generation &+= 1 }
Expand Down Expand Up @@ -152,7 +152,7 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
}
return UnsafeResumption(continuation: send, success: continuation)
case .awaiting(var nexts):
if nexts.update(with: Awaiting(generation: generation, continuation: continuation)) != nil {
if nexts.update(with: Awaiting(generation: generation, continuation: continuation)) != nil {
nexts.remove(Awaiting(placeholder: generation))
cancelled = true
}
Expand Down
60 changes: 38 additions & 22 deletions Sources/AsyncAlgorithms/AsyncChunksOfCountOrSignalSequence.swift
Original file line number Diff line number Diff line change
Expand Up @@ -60,34 +60,50 @@ public struct AsyncChunksOfCountOrSignalSequence<Base: AsyncSequence, Collected:

public typealias Element = Collected

enum Either {
case first(Base.Element)
case second(Signal.Element)
}

/// The iterator for a `AsyncChunksOfCountOrSignalSequence` instance.
public struct Iterator: AsyncIteratorProtocol, Sendable {
let count: Int?
var state: Merge2StateMachine<Base, Signal>
init(base: Base.AsyncIterator, count: Int?, signal: Signal.AsyncIterator) {
let state: MergeStateMachine<Either>
init(base: Base, count: Int?, signal: Signal) {
self.count = count
self.state = Merge2StateMachine(base, terminatesOnNil: true, signal)
let eitherBase = base.map { Either.first($0) }
let eitherSignal = signal.map { Either.second($0) }
self.state = MergeStateMachine(eitherBase, terminatesOnNil: true, eitherSignal)
}

public mutating func next() async rethrows -> Collected? {
var result : Collected?
while let next = try await state.next() {
switch next {
case .first(let element):
if result == nil {
result = Collected()
}
result!.append(element)
if result?.count == count {
return result
}
case .second(_):
if result != nil {
return result
}
}
var collected: Collected?

loop: while true {
let next = await state.next()

switch next {
case .termination:
break loop
case .element(let result):
let element = try result._rethrowGet()
switch element {
case .first(let element):
if collected == nil {
collected = Collected()
}
collected!.append(element)
if collected?.count == count {
return collected
}
case .second(_):
if collected != nil {
return collected
}
}
}
return result
}
return collected
}
}

Expand All @@ -105,6 +121,6 @@ public struct AsyncChunksOfCountOrSignalSequence<Base: AsyncSequence, Collected:
}

public func makeAsyncIterator() -> Iterator {
return Iterator(base: base.makeAsyncIterator(), count: count, signal: signal.makeAsyncIterator())
return Iterator(base: base, count: count, signal: signal)
}
}
219 changes: 39 additions & 180 deletions Sources/AsyncAlgorithms/AsyncMerge2Sequence.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,200 +10,59 @@
//===----------------------------------------------------------------------===//

/// Creates an asynchronous sequence of elements from two underlying asynchronous sequences
public func merge<Base1: AsyncSequence, Base2: AsyncSequence>(_ base1: Base1, _ base2: Base2) -> AsyncMerge2Sequence<Base1, Base2>
where
Base1.Element == Base2.Element,
Base1: Sendable, Base2: Sendable,
Base1.Element: Sendable,
Base1.AsyncIterator: Sendable, Base2.AsyncIterator: Sendable {
return AsyncMerge2Sequence(base1, base2)
}

struct Merge2StateMachine<Base1: AsyncSequence, Base2: AsyncSequence>: Sendable where Base1.AsyncIterator: Sendable, Base2.AsyncIterator: Sendable, Base1.Element: Sendable, Base2.Element: Sendable {
typealias Element1 = Base1.Element
typealias Element2 = Base2.Element

let iter1TerminatesOnNil : Bool
let iter2terminatesOnNil : Bool

enum Partial: @unchecked Sendable {
case first(Result<Element1?, Error>, Base1.AsyncIterator)
case second(Result<Element2?, Error>, Base2.AsyncIterator)
}

enum Either {
case first(Base1.Element)
case second(Base2.Element)
}

var state: (PartialIteration<Base1.AsyncIterator, Partial>, PartialIteration<Base2.AsyncIterator, Partial>)

init(_ iterator1: Base1.AsyncIterator, terminatesOnNil iter1TerminatesOnNil: Bool = false, _ iterator2: Base2.AsyncIterator, terminatesOnNil iter2terminatesOnNil: Bool = false) {
self.iter1TerminatesOnNil = iter1TerminatesOnNil
self.iter2terminatesOnNil = iter2terminatesOnNil
state = (.idle(iterator1), .idle(iterator2))
}

mutating func apply(_ task1: Task<Merge2StateMachine<Base1, Base2>.Partial, Never>?, _ task2: Task<Merge2StateMachine<Base1, Base2>.Partial, Never>?) async rethrows -> Either? {
switch await Task.select([task1, task2].compactMap ({ $0 })).value {
case .first(let result, let iterator):
do {
guard let value = try state.0.resolve(result, iterator) else {
if iter1TerminatesOnNil {
state.1.cancel()
return nil
}
return try await next()
}
return .first(value)
} catch {
state.1.cancel()
throw error
}
case .second(let result, let iterator):
do {
guard let value = try state.1.resolve(result, iterator) else {
if iter2terminatesOnNil {
state.0.cancel()
return nil
}
return try await next()
}
return .second(value)
} catch {
state.0.cancel()
throw error
}
}
}

func first(_ iterator1: Base1.AsyncIterator) -> Task<Partial, Never> {
Task {
var iter = iterator1
do {
let value = try await iter.next()
return .first(.success(value), iter)
} catch {
return .first(.failure(error), iter)
}
}
}

func second(_ iterator2: Base2.AsyncIterator) -> Task<Partial, Never> {
Task {
var iter = iterator2
do {
let value = try await iter.next()
return .second(.success(value), iter)
} catch {
return .second(.failure(error), iter)
}
}
}

/// Advances to the next element and returns it or `nil` if no next element exists.
mutating func next() async rethrows -> Either? {
if Task.isCancelled {
state.0.cancel()
state.1.cancel()
return nil
}
switch state {
case (.idle(let iterator1), .idle(let iterator2)):
let task1 = first(iterator1)
let task2 = second(iterator2)
state = (.pending(task1), .pending(task2))
return try await apply(task1, task2)
case (.idle(let iterator1), .pending(let task2)):
let task1 = first(iterator1)
state = (.pending(task1), .pending(task2))
return try await apply(task1, task2)
case (.pending(let task1), .idle(let iterator2)):
let task2 = second(iterator2)
state = (.pending(task1), .pending(task2))
return try await apply(task1, task2)
case (.idle(var iterator1), .terminal):
do {
if let value = try await iterator1.next() {
state = (.idle(iterator1), .terminal)
return .first(value)
} else {
state = (.terminal, .terminal)
return nil
}
} catch {
state = (.terminal, .terminal)
throw error
}
case (.terminal, .idle(var iterator2)):
do {
if let value = try await iterator2.next() {
state = (.terminal, .idle(iterator2))
return .second(value)
} else {
state = (.terminal, .terminal)
return nil
}
} catch {
state = (.terminal, .terminal)
throw error
}
case (.terminal, .pending(let task2)):
return try await apply(nil, task2)
case (.pending(let task1), .pending(let task2)):
return try await apply(task1, task2)
case (.pending(let task1), .terminal):
return try await apply(task1, nil)
case (.terminal, .terminal):
return nil
}
}
}

extension Merge2StateMachine.Either where Base1.Element == Base2.Element {
var value : Base1.Element {
switch self {
case .first(let val):
return val
case .second(let val):
return val
}
}
public func merge<Base1: AsyncSequence, Base2: AsyncSequence>(
_ base1: Base1,
_ base2: Base2
) -> AsyncMerge2Sequence<Base1, Base2>{
AsyncMerge2Sequence(base1, base2)
}

/// An asynchronous sequence of elements from two underlying asynchronous sequences
///
/// In a `AsyncMerge2Sequence` instance, the *i*th element is the *i*th element
/// resolved in sequential order out of the two underlying asynchronous sequences.
/// Use the `merge(_:_:)` function to create an `AsyncMerge2Sequence`.
public struct AsyncMerge2Sequence<Base1: AsyncSequence, Base2: AsyncSequence>: AsyncSequence, Sendable
where
Base1.Element == Base2.Element,
Base1: Sendable, Base2: Sendable,
Base1.Element: Sendable,
Base1.AsyncIterator: Sendable, Base2.AsyncIterator: Sendable {
public struct AsyncMerge2Sequence<Base1: AsyncSequence, Base2: AsyncSequence>: AsyncSequence
where Base1.Element == Base2.Element {
public typealias Element = Base1.Element
/// An iterator for `AsyncMerge2Sequence`
public struct Iterator: AsyncIteratorProtocol, Sendable {
var state: Merge2StateMachine<Base1, Base2>
init(_ base1: Base1.AsyncIterator, _ base2: Base2.AsyncIterator) {
state = Merge2StateMachine(base1, base2)
}
public typealias AsyncIterator = Iterator

public mutating func next() async rethrows -> Element? {
return try await state.next()?.value
}
}

let base1: Base1
let base2: Base2
init(_ base1: Base1, _ base2: Base2) {

public init(_ base1: Base1, _ base2: Base2) {
self.base1 = base1
self.base2 = base2
}

public func makeAsyncIterator() -> Iterator {
return Iterator(base1.makeAsyncIterator(), base2.makeAsyncIterator())
Iterator(
base1: self.base1,
base2: self.base2
)
}

public struct Iterator: AsyncIteratorProtocol {
let mergeStateMachine: MergeStateMachine<Element>

init(base1: Base1, base2: Base2) {
self.mergeStateMachine = MergeStateMachine(
base1,
base2
)
}

public mutating func next() async rethrows -> Element? {
let mergedElement = await self.mergeStateMachine.next()
switch mergedElement {
case .element(let result):
return try result._rethrowGet()
case .termination:
return nil
}
}
}
}

extension AsyncMerge2Sequence: Sendable where Base1: Sendable, Base2: Sendable {}
extension AsyncMerge2Sequence.Iterator: Sendable where Base1: Sendable, Base2: Sendable {}
Loading