Skip to content

Commit a14c464

Browse files
authored
Add a passthrough message source and sequence (#1252)
Motivation: `AsyncThrowingStream` provides an implementation of `AsyncSequence` which allows the holder to provide new values to that sequence from within a closure provided to the initializer. This API doesn't fit our needs: we must be able to provide the values 'from the outside' rather than during initialization. Modifications: - Add a `PassthroughMessageSequence`, an implementation of `AsyncSequence` which consumes messages from a `PassthroughMessagesSource`. - The source may have values provided to it via `yield(_:)` and terminated with `finish()` or `finish(throwing:)`. - Add tests and a few `AsyncSequence` helpers. Result: We have an `AsyncSequence` implementation which can have values provided to it.
1 parent 54ea824 commit a14c464

7 files changed

+484
-40
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Copyright 2021, gRPC Authors All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#if compiler(>=5.5)
17+
18+
/// An ``AsyncSequence`` adapter for a ``PassthroughMessageSource``.`
19+
@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
20+
@usableFromInline
21+
internal struct PassthroughMessageSequence<Element, Failure: Error>: AsyncSequence {
22+
@usableFromInline
23+
internal typealias Element = Element
24+
25+
@usableFromInline
26+
internal typealias AsyncIterator = Iterator
27+
28+
/// The source of messages in the sequence.
29+
@usableFromInline
30+
internal let _source: PassthroughMessageSource<Element, Failure>
31+
32+
@usableFromInline
33+
internal func makeAsyncIterator() -> Iterator {
34+
return Iterator(storage: self._source)
35+
}
36+
37+
@usableFromInline
38+
internal init(consuming source: PassthroughMessageSource<Element, Failure>) {
39+
self._source = source
40+
}
41+
42+
@usableFromInline
43+
internal struct Iterator: AsyncIteratorProtocol {
44+
@usableFromInline
45+
internal let _storage: PassthroughMessageSource<Element, Failure>
46+
47+
fileprivate init(storage: PassthroughMessageSource<Element, Failure>) {
48+
self._storage = storage
49+
}
50+
51+
@inlinable
52+
internal func next() async throws -> Element? {
53+
return try await self._storage.consumeNextElement()
54+
}
55+
}
56+
}
57+
58+
#endif // compiler(>=5.5)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
/*
2+
* Copyright 2021, gRPC Authors All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#if compiler(>=5.5)
17+
import NIOConcurrencyHelpers
18+
import NIOCore
19+
20+
/// The source of messages for a ``PassthroughMessageSequence``.`
21+
///
22+
/// Values may be provided to the source with calls to ``yield(_:)`` which returns whether the value
23+
/// was accepted (and how many values are yet to be consumed) -- or dropped.
24+
///
25+
/// The backing storage has an unbounded capacity and callers should use the number of unconsumed
26+
/// values returned from ``yield(_:)`` as an indication of when to stop providing values.
27+
///
28+
/// The source must be finished exactly once by calling ``finish()`` or ``finish(throwing:)`` to
29+
/// indicate that the sequence should end with an error.
30+
@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
31+
@usableFromInline
32+
internal final class PassthroughMessageSource<Element, Failure: Error> {
33+
@usableFromInline
34+
internal typealias _ContinuationResult = Result<Element?, Error>
35+
36+
/// All state in this class must be accessed via the lock.
37+
///
38+
/// - Important: We use a `class` with a lock rather than an `actor` as we must guarantee that
39+
/// calls to ``yield(_:)`` are not reordered.
40+
@usableFromInline
41+
internal let _lock: Lock
42+
43+
/// A queue of elements which may be consumed as soon as there is demand.
44+
@usableFromInline
45+
internal var _continuationResults: CircularBuffer<_ContinuationResult>
46+
47+
/// A continuation which will be resumed in the future. The continuation must be `nil`
48+
/// if ``continuationResults`` is not empty.
49+
@usableFromInline
50+
internal var _continuation: Optional<CheckedContinuation<Element?, Error>>
51+
52+
/// True if a terminal continuation result (`.success(nil)` or `.failure()`) has been seen.
53+
/// No more values may be enqueued to `continuationResults` if this is `true`.
54+
@usableFromInline
55+
internal var _isTerminated: Bool
56+
57+
@usableFromInline
58+
internal init(initialBufferCapacity: Int = 16) {
59+
self._lock = Lock()
60+
self._continuationResults = CircularBuffer(initialCapacity: initialBufferCapacity)
61+
self._continuation = nil
62+
self._isTerminated = false
63+
}
64+
65+
// MARK: - Append / Yield
66+
67+
@usableFromInline
68+
internal enum YieldResult: Hashable {
69+
/// The value was accepted. The `queueDepth` indicates how many elements are waiting to be
70+
/// consumed.
71+
///
72+
/// If `queueDepth` is zero then the value was consumed immediately.
73+
case accepted(queueDepth: Int)
74+
75+
/// The value was dropped because the source has already been finished.
76+
case dropped
77+
}
78+
79+
@inlinable
80+
internal func yield(_ element: Element) -> YieldResult {
81+
let continuationResult: _ContinuationResult = .success(element)
82+
return self._yield(continuationResult, isTerminator: false)
83+
}
84+
85+
@inlinable
86+
internal func finish(throwing error: Failure? = nil) -> YieldResult {
87+
let continuationResult: _ContinuationResult = error.map { .failure($0) } ?? .success(nil)
88+
return self._yield(continuationResult, isTerminator: true)
89+
}
90+
91+
@usableFromInline
92+
internal enum _YieldResult {
93+
/// The sequence has already been terminated; drop the element.
94+
case alreadyTerminated
95+
/// The element was added to the queue to be consumed later.
96+
case queued(Int)
97+
/// Demand for an element already existed: complete the continuation with the result being
98+
/// yielded.
99+
case resume(CheckedContinuation<Element?, Error>)
100+
}
101+
102+
@inlinable
103+
internal func _yield(
104+
_ continuationResult: _ContinuationResult, isTerminator: Bool
105+
) -> YieldResult {
106+
let result: _YieldResult = self._lock.withLock {
107+
if self._isTerminated {
108+
return .alreadyTerminated
109+
} else if let continuation = self._continuation {
110+
self._continuation = nil
111+
return .resume(continuation)
112+
} else {
113+
self._isTerminated = isTerminator
114+
self._continuationResults.append(continuationResult)
115+
return .queued(self._continuationResults.count)
116+
}
117+
}
118+
119+
let yieldResult: YieldResult
120+
switch result {
121+
case let .queued(size):
122+
yieldResult = .accepted(queueDepth: size)
123+
case let .resume(continuation):
124+
// If we resume a continuation then the queue must be empty
125+
yieldResult = .accepted(queueDepth: 0)
126+
continuation.resume(with: continuationResult)
127+
case .alreadyTerminated:
128+
yieldResult = .dropped
129+
}
130+
131+
return yieldResult
132+
}
133+
134+
// MARK: - Next
135+
136+
@inlinable
137+
internal func consumeNextElement() async throws -> Element? {
138+
return try await withCheckedThrowingContinuation {
139+
self._consumeNextElement(continuation: $0)
140+
}
141+
}
142+
143+
@inlinable
144+
internal func _consumeNextElement(continuation: CheckedContinuation<Element?, Error>) {
145+
let continuationResult: _ContinuationResult? = self._lock.withLock {
146+
if let nextResult = self._continuationResults.popFirst() {
147+
return nextResult
148+
} else {
149+
// Nothing buffered and not terminated yet: save the continuation for later.
150+
assert(self._continuation == nil)
151+
self._continuation = continuation
152+
return nil
153+
}
154+
}
155+
156+
if let continuationResult = continuationResult {
157+
continuation.resume(with: continuationResult)
158+
}
159+
}
160+
}
161+
162+
#endif // compiler(>=5.5)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright 2021, gRPC Authors All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#if compiler(>=5.5)
17+
18+
@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
19+
extension AsyncSequence {
20+
internal func collect() async throws -> [Element] {
21+
return try await self.reduce(into: []) { accumulated, next in
22+
accumulated.append(next)
23+
}
24+
}
25+
26+
internal func count() async throws -> Int {
27+
return try await self.reduce(0) { count, _ in count + 1 }
28+
}
29+
}
30+
31+
#endif // compiler(>=5.5)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/*
2+
* Copyright 2021, gRPC Authors All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#if compiler(>=5.5)
17+
@testable import GRPC
18+
import XCTest
19+
20+
@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
21+
class PassthroughMessageSourceTests: GRPCTestCase {
22+
func testBasicUsage() {
23+
XCTAsyncTest {
24+
let source = PassthroughMessageSource<String, Never>()
25+
let sequence = PassthroughMessageSequence(consuming: source)
26+
27+
XCTAssertEqual(source.yield("foo"), .accepted(queueDepth: 1))
28+
XCTAssertEqual(source.yield("bar"), .accepted(queueDepth: 2))
29+
XCTAssertEqual(source.yield("baz"), .accepted(queueDepth: 3))
30+
31+
let firstTwo = try await sequence.prefix(2).collect()
32+
XCTAssertEqual(firstTwo, ["foo", "bar"])
33+
34+
XCTAssertEqual(source.yield("bar"), .accepted(queueDepth: 2))
35+
XCTAssertEqual(source.yield("foo"), .accepted(queueDepth: 3))
36+
37+
XCTAssertEqual(source.finish(), .accepted(queueDepth: 4))
38+
39+
let theRest = try await sequence.collect()
40+
XCTAssertEqual(theRest, ["baz", "bar", "foo"])
41+
}
42+
}
43+
44+
func testFinishWithError() {
45+
XCTAsyncTest {
46+
let source = PassthroughMessageSource<String, TestError>()
47+
48+
XCTAssertEqual(source.yield("one"), .accepted(queueDepth: 1))
49+
XCTAssertEqual(source.yield("two"), .accepted(queueDepth: 2))
50+
XCTAssertEqual(source.yield("three"), .accepted(queueDepth: 3))
51+
XCTAssertEqual(source.finish(throwing: TestError()), .accepted(queueDepth: 4))
52+
53+
// We should still be able to get the elements before the error.
54+
let sequence = PassthroughMessageSequence(consuming: source)
55+
let elements = try await sequence.prefix(3).collect()
56+
XCTAssertEqual(elements, ["one", "two", "three"])
57+
58+
do {
59+
for try await element in sequence {
60+
XCTFail("Unexpected value '\(element)'")
61+
}
62+
XCTFail("AsyncSequence did not throw")
63+
} catch {
64+
XCTAssert(error is TestError)
65+
}
66+
}
67+
}
68+
69+
func testYieldAfterFinish() {
70+
XCTAsyncTest {
71+
let source = PassthroughMessageSource<String, Never>()
72+
XCTAssertEqual(source.finish(), .accepted(queueDepth: 1))
73+
XCTAssertEqual(source.yield("foo"), .dropped)
74+
75+
let sequence = PassthroughMessageSequence(consuming: source)
76+
let elements = try await sequence.count()
77+
XCTAssertEqual(elements, 0)
78+
}
79+
}
80+
81+
func testMultipleFinishes() {
82+
XCTAsyncTest {
83+
let source = PassthroughMessageSource<String, TestError>()
84+
XCTAssertEqual(source.finish(), .accepted(queueDepth: 1))
85+
XCTAssertEqual(source.finish(), .dropped)
86+
XCTAssertEqual(source.finish(throwing: TestError()), .dropped)
87+
88+
let sequence = PassthroughMessageSequence(consuming: source)
89+
let elements = try await sequence.count()
90+
XCTAssertEqual(elements, 0)
91+
}
92+
}
93+
94+
func testConsumeBeforeYield() {
95+
XCTAsyncTest {
96+
let source = PassthroughMessageSource<String, Never>()
97+
let sequence = PassthroughMessageSequence(consuming: source)
98+
99+
await withThrowingTaskGroup(of: Void.self) { group in
100+
group.addTask(priority: .high) {
101+
let iterator = sequence.makeAsyncIterator()
102+
if let next = try await iterator.next() {
103+
XCTAssertEqual(next, "one")
104+
} else {
105+
XCTFail("No value produced")
106+
}
107+
}
108+
109+
group.addTask(priority: .low) {
110+
let result = source.yield("one")
111+
// We can't guarantee that this task will run after the other so we *may* have a queue
112+
// depth of one.
113+
XCTAssert(result == .accepted(queueDepth: 0) || result == .accepted(queueDepth: 1))
114+
}
115+
}
116+
}
117+
}
118+
119+
func testConsumeBeforeFinish() {
120+
XCTAsyncTest {
121+
let source = PassthroughMessageSource<String, TestError>()
122+
let sequence = PassthroughMessageSequence(consuming: source)
123+
124+
await withThrowingTaskGroup(of: Void.self) { group in
125+
group.addTask(priority: .high) {
126+
let iterator = sequence.makeAsyncIterator()
127+
await XCTAssertThrowsError(_ = try await iterator.next()) { error in
128+
XCTAssert(error is TestError)
129+
}
130+
}
131+
132+
group.addTask(priority: .low) {
133+
let result = source.finish(throwing: TestError())
134+
// We can't guarantee that this task will run after the other so we *may* have a queue
135+
// depth of one.
136+
XCTAssert(result == .accepted(queueDepth: 0) || result == .accepted(queueDepth: 1))
137+
}
138+
}
139+
}
140+
}
141+
}
142+
143+
fileprivate struct TestError: Error {}
144+
145+
#endif // compiler(>=5.5)

0 commit comments

Comments
 (0)