Skip to content

Commit 48611b5

Browse files
committed
Finish observation when all group tasks are finished or observed sequence is completed
Given observed sequence is finite, When sequence completes, Then observation finishes Given observed sequence is infinite, When all tasks complete, Then observation finishes
1 parent 2484e0b commit 48611b5

File tree

4 files changed

+103
-9
lines changed

4 files changed

+103
-9
lines changed

Package.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ let package = Package(
2121
.target(
2222
name: "AsyncSwiftly",
2323
dependencies: [
24+
"AsyncTrigger",
25+
"AsyncMaterializedSequence",
2426
.product(name: "AsyncAlgorithms", package: "swift-async-algorithms"),
2527
],
2628
),
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
//
2+
// AsyncTakeUntilSequence.swift
3+
// async-swiftly
4+
//
5+
// Created by Erik Basargin on 05/07/2025.
6+
//
7+
8+
import AsyncAlgorithms
9+
import AsyncMaterializedSequence
10+
11+
extension AsyncSequence {
12+
13+
func takeUntil<TriggerSequence: AsyncSequence>(
14+
_ trigger: TriggerSequence
15+
) -> AsyncTakeUntilSequence<Self, TriggerSequence> where Self: Sendable, Self.Element: Sendable, TriggerSequence: Sendable, TriggerSequence.Element: Sendable {
16+
.init(self, trigger)
17+
}
18+
}
19+
20+
struct AsyncTakeUntilSequence<
21+
Base1: AsyncSequence,
22+
Base2: AsyncSequence
23+
>: AsyncSequence, Sendable where Base1: Sendable, Base1.Element: Sendable, Base2: Sendable, Base2.Element: Sendable {
24+
25+
typealias Base = AsyncCombineLatest2Sequence<
26+
AsyncMaterializedSequence<Base1>,
27+
AsyncMerge2Sequence<AsyncThrowingMapSequence<AsyncSyncSequence<[Int]>, Bool>, AsyncThrowingMapSequence<Base2, Bool>>
28+
>
29+
30+
let base: Base
31+
32+
init(_ base1: Base1, _ base2: Base2) {
33+
let startWith = [1].async.map { _ throws in false }
34+
let triggerBase = base2.map { _ throws in true }
35+
let trigger = merge(startWith, triggerBase)
36+
base = combineLatest(base1.materialize(), trigger)
37+
}
38+
39+
func makeAsyncIterator() -> Iterator {
40+
Iterator(base: base.makeAsyncIterator())
41+
}
42+
43+
struct Iterator: AsyncIteratorProtocol {
44+
45+
var base: Base.AsyncIterator
46+
47+
mutating func next() async throws -> Base1.Element? {
48+
guard let value = try await base.next() else {
49+
return nil
50+
}
51+
52+
switch value {
53+
case (.value(let element), false):
54+
return element
55+
case (.completed, false), (_, true):
56+
return nil
57+
}
58+
}
59+
}
60+
}

Sources/AsyncSwiftly/TestingTaskGroup.swift

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
// Created by Erik Basargin on 25/06/2025.
66
//
77

8+
import AsyncTrigger
9+
import AsyncAlgorithms
810
import os
911
import Foundation
1012

@@ -71,6 +73,7 @@ public struct TestingTaskGroup<ObservationElement: Sendable>: ~Copyable {
7173
let clock: Clock
7274
var group: ThrowingDiscardingTaskGroup<any Error>
7375
let events = EventsStorage<ObservationElement>()
76+
let finishObservationTrigger = AsyncTrigger()
7477

7578
public init(group: ThrowingDiscardingTaskGroup<any Error>) {
7679
self.queue = WorkQueue()
@@ -80,11 +83,11 @@ public struct TestingTaskGroup<ObservationElement: Sendable>: ~Copyable {
8083

8184
public mutating func addObserver<Failure: Error>(
8285
at rawStep: Int,
83-
observer: @escaping @Sendable () -> some AsyncSequence<ObservationElement, Failure>
86+
observer: @escaping @Sendable () -> some AsyncSequence<ObservationElement, Failure> & Sendable
8487
) {
85-
addTask(at: rawStep) { [clock, events] in
88+
addTask(at: rawStep) { [clock, events, finishObservationTrigger] in
8689
do {
87-
for try await element in observer() {
90+
for try await element in observer().takeUntil(finishObservationTrigger) {
8891
let tick = clock.now.when.rawValue
8992
events.append(.value(tick, element), for: rawStep)
9093
}
@@ -98,8 +101,7 @@ public struct TestingTaskGroup<ObservationElement: Sendable>: ~Copyable {
98101
}
99102

100103
public mutating func addTask(at rawStep: Int, operation: sending @escaping @isolated(any) () async -> Void) {
101-
let duration = Clock.Step.step(rawStep)
102-
let instant = Clock.Instant(when: duration)
104+
let instant = Clock.Instant(when: .step(rawStep))
103105
let executor = OperationExecutor(instant: instant, queue: queue)
104106

105107
group.addTask { [queue] in
@@ -113,9 +115,11 @@ public struct TestingTaskGroup<ObservationElement: Sendable>: ~Copyable {
113115
work()
114116
}
115117

118+
finishObservationTrigger.fire()
119+
116120
if queue.isAnyWorkLeft {
117121
// At this point if work remains unfinished, we've got some tasks that cannot be resolved in provided range of time.
118-
// Wait until all work finishes
122+
// Wait until all work finishes or times out.
119123

120124
await queue.waitForAll()
121125
}
@@ -282,7 +286,9 @@ extension TestingTaskGroup.WorkQueue: AsyncSequence {
282286
}
283287

284288
repeat {
285-
await Task.yield()
289+
for _ in 0..<100 {
290+
await Task.yield()
291+
}
286292

287293
switch nextAction() {
288294
case let .awaitNextWork(queue):

Tests/AsyncSwiftlyTests/TestingTaskGroupTests.swift

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,8 @@ struct TestingTaskGroupTests {
118118
}
119119
}
120120

121-
@Test("Given system under test produces values, When stream is observed, Then values are received in correct moment")
122-
func valuesAreObserved() async throws {
121+
@Test("Given observed sequence is finite, When sequence completes, Then observation finishes")
122+
func observationFinishes() async throws {
123123
let (stream, continuation) = AsyncStream.makeStream(of: Int.self)
124124

125125
let result = try await withTestingTaskGroup { group in
@@ -132,6 +132,9 @@ struct TestingTaskGroupTests {
132132
group.addTask(at: 2) {
133133
continuation.yield(2)
134134
}
135+
group.addTask(at: 3) {
136+
continuation.yield(3)
137+
}
135138
}
136139

137140
#expect(result[0] == [
@@ -140,4 +143,27 @@ struct TestingTaskGroupTests {
140143
.finished(2),
141144
])
142145
}
146+
147+
@Test("Given observed sequence is infinite, When all tasks complete, Then observation finishes")
148+
func infiniteObservationFinishes() async throws {
149+
let (stream, continuation) = AsyncStream.makeStream(of: Int.self)
150+
151+
let result = try await withTestingTaskGroup { group in
152+
group.addObserver(at: 0) {
153+
stream
154+
}
155+
group.addTask(at: 1) {
156+
continuation.yield(1)
157+
}
158+
group.addTask(at: 2) {
159+
continuation.yield(2)
160+
}
161+
}
162+
163+
#expect(result[0] == [
164+
.value(1, 1),
165+
.value(2, 2),
166+
.finished(3),
167+
])
168+
}
143169
}

0 commit comments

Comments
 (0)