Skip to content

Add an async throwing source #1252

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions Sources/GRPC/AsyncAwaitSupport/PassthroughMessageSequence.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright 2021, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#if compiler(>=5.5)

/// An ``AsyncSequence`` adapter for a ``PassthroughMessageSource``.`
@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
@usableFromInline
internal struct PassthroughMessageSequence<Element, Failure: Error>: AsyncSequence {
@usableFromInline
internal typealias Element = Element

@usableFromInline
internal typealias AsyncIterator = Iterator

/// The source of messages in the sequence.
@usableFromInline
internal let _source: PassthroughMessageSource<Element, Failure>

@usableFromInline
internal func makeAsyncIterator() -> Iterator {
return Iterator(storage: self._source)
}

@usableFromInline
internal init(consuming source: PassthroughMessageSource<Element, Failure>) {
self._source = source
}

@usableFromInline
internal struct Iterator: AsyncIteratorProtocol {
@usableFromInline
internal let _storage: PassthroughMessageSource<Element, Failure>

fileprivate init(storage: PassthroughMessageSource<Element, Failure>) {
self._storage = storage
}

@inlinable
internal func next() async throws -> Element? {
return try await self._storage.consumeNextElement()
}
}
}

#endif // compiler(>=5.5)
162 changes: 162 additions & 0 deletions Sources/GRPC/AsyncAwaitSupport/PassthroughMessageSource.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
/*
* Copyright 2021, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#if compiler(>=5.5)
import NIOConcurrencyHelpers
import NIOCore

/// The source of messages for a ``PassthroughMessageSequence``.`
///
/// Values may be provided to the source with calls to ``yield(_:)`` which returns whether the value
/// was accepted (and how many values are yet to be consumed) -- or dropped.
///
/// The backing storage has an unbounded capacity and callers should use the number of unconsumed
/// values returned from ``yield(_:)`` as an indication of when to stop providing values.
///
/// The source must be finished exactly once by calling ``finish()`` or ``finish(throwing:)`` to
/// indicate that the sequence should end with an error.
@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
@usableFromInline
internal final class PassthroughMessageSource<Element, Failure: Error> {
@usableFromInline
internal typealias _ContinuationResult = Result<Element?, Error>

/// All state in this class must be accessed via the lock.
///
/// - Important: We use a `class` with a lock rather than an `actor` as we must guarantee that
/// calls to ``yield(_:)`` are not reordered.
@usableFromInline
internal let _lock: Lock

/// A queue of elements which may be consumed as soon as there is demand.
@usableFromInline
internal var _continuationResults: CircularBuffer<_ContinuationResult>

/// A continuation which will be resumed in the future. The continuation must be `nil`
/// if ``continuationResults`` is not empty.
@usableFromInline
internal var _continuation: Optional<CheckedContinuation<Element?, Error>>

/// True if a terminal continuation result (`.success(nil)` or `.failure()`) has been seen.
/// No more values may be enqueued to `continuationResults` if this is `true`.
@usableFromInline
internal var _isTerminated: Bool

@usableFromInline
internal init(initialBufferCapacity: Int = 16) {
self._lock = Lock()
self._continuationResults = CircularBuffer(initialCapacity: initialBufferCapacity)
self._continuation = nil
self._isTerminated = false
}

// MARK: - Append / Yield

@usableFromInline
internal enum YieldResult: Hashable {
/// The value was accepted. The `queueDepth` indicates how many elements are waiting to be
/// consumed.
///
/// If `queueDepth` is zero then the value was consumed immediately.
case accepted(queueDepth: Int)

/// The value was dropped because the source has already been finished.
case dropped
}

@inlinable
internal func yield(_ element: Element) -> YieldResult {
let continuationResult: _ContinuationResult = .success(element)
return self._yield(continuationResult, isTerminator: false)
}

@inlinable
internal func finish(throwing error: Failure? = nil) -> YieldResult {
let continuationResult: _ContinuationResult = error.map { .failure($0) } ?? .success(nil)
return self._yield(continuationResult, isTerminator: true)
}

@usableFromInline
internal enum _YieldResult {
/// The sequence has already been terminated; drop the element.
case alreadyTerminated
/// The element was added to the queue to be consumed later.
case queued(Int)
/// Demand for an element already existed: complete the continuation with the result being
/// yielded.
case resume(CheckedContinuation<Element?, Error>)
}

@inlinable
internal func _yield(
_ continuationResult: _ContinuationResult, isTerminator: Bool
) -> YieldResult {
let result: _YieldResult = self._lock.withLock {
if self._isTerminated {
return .alreadyTerminated
} else if let continuation = self._continuation {
self._continuation = nil
return .resume(continuation)
} else {
self._isTerminated = isTerminator
self._continuationResults.append(continuationResult)
return .queued(self._continuationResults.count)
}
}

let yieldResult: YieldResult
switch result {
case let .queued(size):
yieldResult = .accepted(queueDepth: size)
case let .resume(continuation):
// If we resume a continuation then the queue must be empty
yieldResult = .accepted(queueDepth: 0)
continuation.resume(with: continuationResult)
case .alreadyTerminated:
yieldResult = .dropped
}

return yieldResult
}

// MARK: - Next

@inlinable
internal func consumeNextElement() async throws -> Element? {
return try await withCheckedThrowingContinuation {
self._consumeNextElement(continuation: $0)
}
}

@inlinable
internal func _consumeNextElement(continuation: CheckedContinuation<Element?, Error>) {
let continuationResult: _ContinuationResult? = self._lock.withLock {
if let nextResult = self._continuationResults.popFirst() {
return nextResult
} else {
// Nothing buffered and not terminated yet: save the continuation for later.
assert(self._continuation == nil)
self._continuation = continuation
return nil
}
}

if let continuationResult = continuationResult {
continuation.resume(with: continuationResult)
}
}
}

#endif // compiler(>=5.5)
31 changes: 31 additions & 0 deletions Tests/GRPCTests/AsyncAwaitSupport/AsyncSequence+Helpers.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright 2021, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#if compiler(>=5.5)

@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
extension AsyncSequence {
internal func collect() async throws -> [Element] {
return try await self.reduce(into: []) { accumulated, next in
accumulated.append(next)
}
}

internal func count() async throws -> Int {
return try await self.reduce(0) { count, _ in count + 1 }
}
}

#endif // compiler(>=5.5)
145 changes: 145 additions & 0 deletions Tests/GRPCTests/AsyncAwaitSupport/PassthroughMessageSourceTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
* Copyright 2021, gRPC Authors All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#if compiler(>=5.5)
@testable import GRPC
import XCTest

@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *)
class PassthroughMessageSourceTests: GRPCTestCase {
func testBasicUsage() {
XCTAsyncTest {
let source = PassthroughMessageSource<String, Never>()
let sequence = PassthroughMessageSequence(consuming: source)

XCTAssertEqual(source.yield("foo"), .accepted(queueDepth: 1))
XCTAssertEqual(source.yield("bar"), .accepted(queueDepth: 2))
XCTAssertEqual(source.yield("baz"), .accepted(queueDepth: 3))

let firstTwo = try await sequence.prefix(2).collect()
XCTAssertEqual(firstTwo, ["foo", "bar"])

XCTAssertEqual(source.yield("bar"), .accepted(queueDepth: 2))
XCTAssertEqual(source.yield("foo"), .accepted(queueDepth: 3))

XCTAssertEqual(source.finish(), .accepted(queueDepth: 4))

let theRest = try await sequence.collect()
XCTAssertEqual(theRest, ["baz", "bar", "foo"])
}
}

func testFinishWithError() {
XCTAsyncTest {
let source = PassthroughMessageSource<String, TestError>()

XCTAssertEqual(source.yield("one"), .accepted(queueDepth: 1))
XCTAssertEqual(source.yield("two"), .accepted(queueDepth: 2))
XCTAssertEqual(source.yield("three"), .accepted(queueDepth: 3))
XCTAssertEqual(source.finish(throwing: TestError()), .accepted(queueDepth: 4))

// We should still be able to get the elements before the error.
let sequence = PassthroughMessageSequence(consuming: source)
let elements = try await sequence.prefix(3).collect()
XCTAssertEqual(elements, ["one", "two", "three"])

do {
for try await element in sequence {
XCTFail("Unexpected value '\(element)'")
}
XCTFail("AsyncSequence did not throw")
} catch {
XCTAssert(error is TestError)
}
}
}

func testYieldAfterFinish() {
XCTAsyncTest {
let source = PassthroughMessageSource<String, Never>()
XCTAssertEqual(source.finish(), .accepted(queueDepth: 1))
XCTAssertEqual(source.yield("foo"), .dropped)

let sequence = PassthroughMessageSequence(consuming: source)
let elements = try await sequence.count()
XCTAssertEqual(elements, 0)
}
}

func testMultipleFinishes() {
XCTAsyncTest {
let source = PassthroughMessageSource<String, TestError>()
XCTAssertEqual(source.finish(), .accepted(queueDepth: 1))
XCTAssertEqual(source.finish(), .dropped)
XCTAssertEqual(source.finish(throwing: TestError()), .dropped)

let sequence = PassthroughMessageSequence(consuming: source)
let elements = try await sequence.count()
XCTAssertEqual(elements, 0)
}
}

func testConsumeBeforeYield() {
XCTAsyncTest {
let source = PassthroughMessageSource<String, Never>()
let sequence = PassthroughMessageSequence(consuming: source)

await withThrowingTaskGroup(of: Void.self) { group in
group.addTask(priority: .high) {
let iterator = sequence.makeAsyncIterator()
if let next = try await iterator.next() {
XCTAssertEqual(next, "one")
} else {
XCTFail("No value produced")
}
}

group.addTask(priority: .low) {
let result = source.yield("one")
// We can't guarantee that this task will run after the other so we *may* have a queue
// depth of one.
XCTAssert(result == .accepted(queueDepth: 0) || result == .accepted(queueDepth: 1))
}
}
}
}

func testConsumeBeforeFinish() {
XCTAsyncTest {
let source = PassthroughMessageSource<String, TestError>()
let sequence = PassthroughMessageSequence(consuming: source)

await withThrowingTaskGroup(of: Void.self) { group in
group.addTask(priority: .high) {
let iterator = sequence.makeAsyncIterator()
await XCTAssertThrowsError(_ = try await iterator.next()) { error in
XCTAssert(error is TestError)
}
}

group.addTask(priority: .low) {
let result = source.finish(throwing: TestError())
// We can't guarantee that this task will run after the other so we *may* have a queue
// depth of one.
XCTAssert(result == .accepted(queueDepth: 0) || result == .accepted(queueDepth: 1))
}
}
}
}
}

fileprivate struct TestError: Error {}

#endif // compiler(>=5.5)
Loading