Skip to content

[AsyncThrowingChannel] make the fail terminal event non async #164

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
public init(element elementType: Element.Type = Element.self, failure failureType: Failure.Type = Failure.self)

public func send(_ element: Element) async
public func fail(_ error: Error) async where Failure == Error
public func fail(_ error: Error) where Failure == Error
public func finish()

public func makeAsyncIterator() -> Iterator
}
```

Channels are intended to be used as communication types between tasks. Particularly when one task produces values and another task consumes said values. On the one hand, the back pressure applied by `send(_:)` and `fail(_:)` via the suspension/resume ensure that the production of values does not exceed the consumption of values from iteration. Each of these methods suspend after enqueuing the event and are resumed when the next call to `next()` on the `Iterator` is made. On the other hand, the call to `finish()` immediately resumes all the pending operations for every producers and consumers. Thus, every suspended `send(_:)` operations instantly resume, so as every suspended `next()` operations by producing a nil value, indicating the termination of the iterations. Further calls to `send(_:)` will immediately resume.
Channels are intended to be used as communication types between tasks. Particularly when one task produces values and another task consumes said values. On the one hand, the back pressure applied by `send(_:)` via the suspension/resume ensures that the production of values does not exceed the consumption of values from iteration. This method suspends after enqueuing the event and is resumed when the next call to `next()` on the `Iterator` is made. On the other hand, the call to `finish()` or `fail(_:)` immediately resumes all the pending operations for every producers and consumers. Thus, every suspended `send(_:)` operations instantly resume, so as every suspended `next()` operations by producing a nil value, or by throwing an error, indicating the termination of the iterations. Further calls to `send(_:)` will immediately resume.

```swift
let channel = AsyncChannel<String>()
Expand Down
20 changes: 12 additions & 8 deletions Sources/AsyncAlgorithms/AsyncChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
///
/// The `AsyncChannel` class is intended to be used as a communication type between tasks,
/// particularly when one task produces values and another task consumes those values. The back
/// pressure applied by `send(_:)` and `finish()` via the suspension/resume ensures that
/// the production of values does not exceed the consumption of values from iteration. Each of these
/// methods suspends after enqueuing the event and is resumed when the next call to `next()`
/// on the `Iterator` is made.
/// pressure applied by `send(_:)` via the suspension/resume ensures that
/// the production of values does not exceed the consumption of values from iteration. This method
/// suspends after enqueuing the event and is resumed when the next call to `next()`
/// on the `Iterator` is made, or when `finish()` is called from another Task.
/// As `finish()` induces a terminal state, there is no need for a back pressure management.
/// This function does not suspend and will finish all the pending iterations.
public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
/// The iterator for a `AsyncChannel` instance.
public struct Iterator: AsyncIteratorProtocol, Sendable {
Expand Down Expand Up @@ -168,7 +170,7 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
}
}

func finishAll() {
func terminateAll() {
let (sends, nexts) = state.withCriticalRegion { state -> ([UnsafeContinuation<UnsafeContinuation<Element?, Never>?, Never>], Set<Awaiting>) in
if state.terminal {
return ([], [])
Expand All @@ -195,7 +197,7 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {

func _send(_ element: Element) async {
await withTaskCancellationHandler {
finishAll()
terminateAll()
} operation: {
let continuation: UnsafeContinuation<Element?, Never>? = await withUnsafeContinuation { continuation in
state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Never>?, Never>? in
Expand Down Expand Up @@ -225,15 +227,17 @@ public final class AsyncChannel<Element: Sendable>: AsyncSequence, Sendable {
}
}

/// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made.
/// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made
/// or when a call to `finish()` is made from another Task.
/// If the channel is already finished then this returns immediately
public func send(_ element: Element) async {
await _send(element)
}

/// Send a finish to all awaiting iterations.
/// All subsequent calls to `next(_:)` will resume immediately.
public func finish() {
finishAll()
terminateAll()
}

/// Create an `Iterator` for iteration of an `AsyncChannel`
Expand Down
116 changes: 79 additions & 37 deletions Sources/AsyncAlgorithms/AsyncThrowingChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@

/// An error-throwing channel for sending elements from on task to another with back pressure.
///
/// The `AsyncThrowingChannel` class is intended to be used as a communication types between tasks, particularly when one task produces values and another task consumes those values. The back pressure applied by `send(_:)`, `fail(_:)` and `finish()` via suspension/resume ensures that the production of values does not exceed the consumption of values from iteration. Each of these methods suspends after enqueuing the event and is resumed when the next call to `next()` on the `Iterator` is made.
/// The `AsyncThrowingChannel` class is intended to be used as a communication types between tasks,
/// particularly when one task produces values and another task consumes those values. The back
/// pressure applied by `send(_:)` via suspension/resume ensures that the production of values does
/// not exceed the consumption of values from iteration. This method suspends after enqueuing the event
/// and is resumed when the next call to `next()` on the `Iterator` is made, or when `finish()`/`fail(_:)` is called
/// from another Task. As `finish()` and `fail(_:)` induce a terminal state, there is no need for a back pressure management.
/// Those functions do not suspend and will finish all the pending iterations.
public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: AsyncSequence, Sendable {
/// The iterator for an `AsyncThrowingChannel` instance.
public struct Iterator: AsyncIteratorProtocol, Sendable {
Expand Down Expand Up @@ -78,12 +84,23 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
return lhs.generation == rhs.generation
}
}

enum Termination {
case finished
case failed(Error)
}

enum Emission {
case idle
case pending([UnsafeContinuation<UnsafeContinuation<Element?, Error>?, Never>])
case awaiting(Set<Awaiting>)

case terminated(Termination)

var isTerminated: Bool {
guard case .terminated = self else { return false }
return true
}

mutating func cancel(_ generation: Int) -> UnsafeContinuation<Element?, Error>? {
switch self {
case .awaiting(var awaiting):
Expand All @@ -106,9 +123,8 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
struct State {
var emission: Emission = .idle
var generation = 0
var terminal = false
}

let state = ManagedCriticalState(State())

public init(_ elementType: Element.Type = Element.self) { }
Expand All @@ -129,12 +145,9 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
func next(_ generation: Int) async throws -> Element? {
return try await withUnsafeThrowingContinuation { continuation in
var cancelled = false
var terminal = false
var potentialTermination: Termination?

state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Error>?, Never>? in
if state.terminal {
terminal = true
return nil
}
switch state.emission {
case .idle:
state.emission = .awaiting([Awaiting(generation: generation, continuation: continuation)])
Expand All @@ -158,53 +171,78 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
state.emission = .awaiting(nexts)
}
return nil
case .terminated(let termination):
potentialTermination = termination
state.emission = .terminated(.finished)
return nil
}
}?.resume()
if cancelled || terminal {

if cancelled {
continuation.resume(returning: nil)
return
}

switch potentialTermination {
case .none:
return
case .failed(let error):
continuation.resume(throwing: error)
return
case .finished:
continuation.resume(returning: nil)
return
}
}
}
func finishAll() {

func terminateAll(error: Failure? = nil) {
let (sends, nexts) = state.withCriticalRegion { state -> ([UnsafeContinuation<UnsafeContinuation<Element?, Error>?, Never>], Set<Awaiting>) in
if state.terminal {
return ([], [])

let nextState: Emission
if let error = error {
nextState = .terminated(.failed(error))
} else {
nextState = .terminated(.finished)
}
state.terminal = true

switch state.emission {
case .idle:
state.emission = nextState
return ([], [])
case .pending(let nexts):
state.emission = .idle
state.emission = nextState
return (nexts, [])
case .awaiting(let nexts):
state.emission = .idle
state.emission = nextState
return ([], nexts)
case .terminated:
return ([], [])
}
}

for send in sends {
send.resume(returning: nil)
}
for next in nexts {
next.continuation?.resume(returning: nil)

if let error = error {
for next in nexts {
next.continuation?.resume(throwing: error)
}
} else {
for next in nexts {
next.continuation?.resume(returning: nil)
}
}

}

func _send(_ result: Result<Element, Error>) async {
func _send(_ element: Element) async {
await withTaskCancellationHandler {
finishAll()
terminateAll()
} operation: {
let continuation: UnsafeContinuation<Element?, Error>? = await withUnsafeContinuation { continuation in
state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Error>?, Never>? in
if state.terminal {
return UnsafeResumption(continuation: continuation, success: nil)
}

if case .failure = result {
state.terminal = true
}

switch state.emission {
case .idle:
state.emission = .pending([continuation])
Expand All @@ -221,28 +259,32 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
state.emission = .awaiting(nexts)
}
return UnsafeResumption(continuation: continuation, success: next)
case .terminated:
return UnsafeResumption(continuation: continuation, success: nil)
}
}?.resume()
}
continuation?.resume(with: result.map { $0 as Element? })
continuation?.resume(returning: element)
}
}

/// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made.
/// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made
/// or when a call to `finish()`/`fail(_:)` is made from another Task.
/// If the channel is already finished then this returns immediately
public func send(_ element: Element) async {
await _send(.success(element))
await _send(element)
}

/// Send an error to an awaiting iteration. This function will resume when the next call to `next()` is made.
/// If the channel is already finished then this returns immediately
public func fail(_ error: Error) async where Failure == Error {
await _send(.failure(error))
/// Send an error to all awaiting iterations.
/// All subsequent calls to `next(_:)` will resume immediately.
public func fail(_ error: Error) where Failure == Error {
terminateAll(error: error)
}

/// Send a finish to all awaiting iterations.
/// All subsequent calls to `next(_:)` will resume immediately.
public func finish() {
finishAll()
terminateAll()
}

public func makeAsyncIterator() -> Iterator {
Expand Down
77 changes: 16 additions & 61 deletions Tests/AsyncAlgorithmsTests/TestChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,30 @@ final class TestChannel: XCTestCase {
XCTAssertEqual(collected, expected)
}

func test_asyncThrowingChannel_throws_when_fail_is_called() async {
func test_asyncThrowingChannel_throws_and_discards_additional_sent_values_when_fail_is_called() async {
let sendImmediatelyResumes = expectation(description: "Send immediately resumes after fail")

let channel = AsyncThrowingChannel<String, Error>()
Task {
await channel.fail(Failure())
}
channel.fail(Failure())

var iterator = channel.makeAsyncIterator()
do {
let _ = try await iterator.next()
XCTFail("The AsyncThrowingChannel should have thrown")
} catch {
XCTAssertEqual(error as? Failure, Failure())
}

do {
let pastFailure = try await iterator.next()
XCTAssertNil(pastFailure)
} catch {
XCTFail("The AsyncThrowingChannel should not fail when failure has already been fired")
}

await channel.send("send")
sendImmediatelyResumes.fulfill()
wait(for: [sendImmediatelyResumes], timeout: 1.0)
}

func test_asyncChannel_ends_alls_iterators_and_discards_additional_sent_values_when_finish_is_called() async {
Expand Down Expand Up @@ -132,63 +144,6 @@ final class TestChannel: XCTestCase {
wait(for: [additionalSend], timeout: 1.0)
}

func test_asyncChannel_ends_alls_iterators_and_discards_additional_sent_values_when_finish_is_called2() async throws {
let channel = AsyncChannel<String>()
let complete = ManagedCriticalState(false)
let finished = expectation(description: "finished")

let valueFromConsumer1 = ManagedCriticalState<String?>(nil)
let valueFromConsumer2 = ManagedCriticalState<String?>(nil)

let received = expectation(description: "received")
received.expectedFulfillmentCount = 2

let pastEnd = expectation(description: "pastEnd")
pastEnd.expectedFulfillmentCount = 2

Task(priority: .high) {
var iterator = channel.makeAsyncIterator()
let ending = await iterator.next()
valueFromConsumer1.withCriticalRegion { $0 = ending }
received.fulfill()
let item = await iterator.next()
XCTAssertNil(item)
pastEnd.fulfill()
}

Task(priority: .high) {
var iterator = channel.makeAsyncIterator()
let ending = await iterator.next()
valueFromConsumer2.withCriticalRegion { $0 = ending }
received.fulfill()
let item = await iterator.next()
XCTAssertNil(item)
pastEnd.fulfill()
}

try await Task.sleep(nanoseconds: 1_000_000_000)

Task(priority: .low) {
channel.finish()
complete.withCriticalRegion { $0 = true }
finished.fulfill()
}

wait(for: [finished, received], timeout: 1.0)

XCTAssertTrue(complete.withCriticalRegion { $0 })
XCTAssertEqual(valueFromConsumer1.withCriticalRegion { $0 }, nil)
XCTAssertEqual(valueFromConsumer2.withCriticalRegion { $0 }, nil)

wait(for: [pastEnd], timeout: 1.0)
let additionalSend = expectation(description: "additional send")
Task {
await channel.send("test")
additionalSend.fulfill()
}
wait(for: [additionalSend], timeout: 1.0)
}

func test_asyncThrowingChannel_ends_alls_iterators_and_discards_additional_sent_values_when_finish_is_called() async {
let channel = AsyncThrowingChannel<String, Error>()
let complete = ManagedCriticalState(false)
Expand Down