diff --git a/Sources/AsyncAlgorithms/AsyncAlgorithms.docc/Guides/Channel.md b/Sources/AsyncAlgorithms/AsyncAlgorithms.docc/Guides/Channel.md index c907cc09..974903b5 100644 --- a/Sources/AsyncAlgorithms/AsyncAlgorithms.docc/Guides/Channel.md +++ b/Sources/AsyncAlgorithms/AsyncAlgorithms.docc/Guides/Channel.md @@ -14,7 +14,7 @@ ## Proposed Solution -To achieve a system that supports back pressure and allows for the communication of more than one value from one task to another we are introducing a new type, the _channel_. The channel will be a reference-type asynchronous sequence with an asynchronous sending capability that awaits the consumption of iteration. Each value sent by the channel, or finish transmitted, will await the consumption of that value or event by iteration. That awaiting behavior will allow for the affordance of back pressure applied from the consumption site to be transmitted to the production site. This means that the rate of production cannot exceed the rate of consumption, and that the rate of consumption cannot exceed the rate of production. +To achieve a system that supports back pressure and allows for the communication of more than one value from one task to another we are introducing a new type, the _channel_. The channel will be a reference-type asynchronous sequence with an asynchronous sending capability that awaits the consumption of iteration. Each value sent by the channel will await the consumption of that value by iteration. That awaiting behavior will allow for the affordance of back pressure applied from the consumption site to be transmitted to the production site. This means that the rate of production cannot exceed the rate of consumption, and that the rate of consumption cannot exceed the rate of production. Sending a terminal event to the channel will instantly resume all pending operations for every producers and consumers. ## Detailed Design @@ -31,7 +31,7 @@ public final class AsyncChannel: AsyncSequence, Sendable { public init(element elementType: Element.Type = Element.self) public func send(_ element: Element) async - public func finish() async + public func finish() public func makeAsyncIterator() -> Iterator } @@ -45,13 +45,13 @@ public final class AsyncThrowingChannel: Asyn public func send(_ element: Element) async public func fail(_ error: Error) async where Failure == Error - public func finish() async + public func finish() public func makeAsyncIterator() -> Iterator } ``` -Channels are intended to be used as communication types between tasks. Particularly when one task produces values and another task consumes said values. The back pressure applied by `send(_:)`, `fail(_:)` and `finish()` via the suspension/resume ensure that the production of values does not exceed the consumption of values from iteration. Each of these methods suspend after enqueuing the event and are resumed when the next call to `next()` on the `Iterator` is made. +Channels are intended to be used as communication types between tasks. Particularly when one task produces values and another task consumes said values. On the one hand, the back pressure applied by `send(_:)` and `fail(_:)` via the suspension/resume ensure that the production of values does not exceed the consumption of values from iteration. Each of these methods suspend after enqueuing the event and are resumed when the next call to `next()` on the `Iterator` is made. On the other hand, the call to `finish()` immediately resumes all the pending operations for every producers and consumers. Thus, every suspended `send(_:)` operations instantly resume, so as every suspended `next()` operations by producing a nil value, indicating the termination of the iterations. Further calls to `send(_:)` will immediately resume. ```swift let channel = AsyncChannel() @@ -59,7 +59,7 @@ Task { while let resultOfLongCalculation = doLongCalculations() { await channel.send(resultOfLongCalculation) } - await channel.finish() + channel.finish() } for await calculationResult in channel { diff --git a/Sources/AsyncAlgorithms/AsyncChannel.swift b/Sources/AsyncAlgorithms/AsyncChannel.swift index a9922a90..0f30a6f6 100644 --- a/Sources/AsyncAlgorithms/AsyncChannel.swift +++ b/Sources/AsyncAlgorithms/AsyncChannel.swift @@ -131,7 +131,12 @@ public final class AsyncChannel: AsyncSequence, Sendable { func next(_ generation: Int) async -> Element? { return await withUnsafeContinuation { continuation in var cancelled = false + var terminal = false state.withCriticalRegion { state -> UnsafeResumption?, Never>? in + if state.terminal { + terminal = true + return nil + } switch state.emission { case .idle: state.emission = .awaiting([Awaiting(generation: generation, continuation: continuation)]) @@ -157,13 +162,13 @@ public final class AsyncChannel: AsyncSequence, Sendable { return nil } }?.resume() - if cancelled { + if cancelled || terminal { continuation.resume(returning: nil) } } } - func cancelSend() { + func finishAll() { let (sends, nexts) = state.withCriticalRegion { state -> ([UnsafeContinuation?, Never>], Set) in if state.terminal { return ([], []) @@ -188,23 +193,15 @@ public final class AsyncChannel: AsyncSequence, Sendable { } } - func _send(_ result: Result) async { + func _send(_ element: Element) async { await withTaskCancellationHandler { - cancelSend() + finishAll() } operation: { let continuation: UnsafeContinuation? = await withUnsafeContinuation { continuation in state.withCriticalRegion { state -> UnsafeResumption?, Never>? in if state.terminal { return UnsafeResumption(continuation: continuation, success: nil) } - switch result { - case .success(let value): - if value == nil { - state.terminal = true - } - case .failure: - state.terminal = true - } switch state.emission { case .idle: state.emission = .pending([continuation]) @@ -224,20 +221,19 @@ public final class AsyncChannel: AsyncSequence, Sendable { } }?.resume() } - continuation?.resume(with: result) + continuation?.resume(returning: element) } } /// Send an element to an awaiting iteration. This function will resume when the next call to `next()` is made. /// If the channel is already finished then this returns immediately public func send(_ element: Element) async { - await _send(.success(element)) + await _send(element) } - /// Send a finish to an awaiting iteration. This function will resume when the next call to `next()` is made. - /// If the channel is already finished then this returns immediately - public func finish() async { - await _send(.success(nil)) + /// Send a finish to all awaiting iterations. + public func finish() { + finishAll() } /// Create an `Iterator` for iteration of an `AsyncChannel` diff --git a/Sources/AsyncAlgorithms/AsyncThrowingChannel.swift b/Sources/AsyncAlgorithms/AsyncThrowingChannel.swift index 205729e3..58c9e8c9 100644 --- a/Sources/AsyncAlgorithms/AsyncThrowingChannel.swift +++ b/Sources/AsyncAlgorithms/AsyncThrowingChannel.swift @@ -129,7 +129,12 @@ public final class AsyncThrowingChannel: Asyn func next(_ generation: Int) async throws -> Element? { return try await withUnsafeThrowingContinuation { continuation in var cancelled = false + var terminal = false state.withCriticalRegion { state -> UnsafeResumption?, Never>? in + if state.terminal { + terminal = true + return nil + } switch state.emission { case .idle: state.emission = .awaiting([Awaiting(generation: generation, continuation: continuation)]) @@ -155,13 +160,13 @@ public final class AsyncThrowingChannel: Asyn return nil } }?.resume() - if cancelled { + if cancelled || terminal { continuation.resume(returning: nil) } } } - func cancelSend() { + func finishAll() { let (sends, nexts) = state.withCriticalRegion { state -> ([UnsafeContinuation?, Never>], Set) in if state.terminal { return ([], []) @@ -186,23 +191,20 @@ public final class AsyncThrowingChannel: Asyn } } - func _send(_ result: Result) async { + func _send(_ result: Result) async { await withTaskCancellationHandler { - cancelSend() + finishAll() } operation: { let continuation: UnsafeContinuation? = await withUnsafeContinuation { continuation in state.withCriticalRegion { state -> UnsafeResumption?, Never>? in if state.terminal { return UnsafeResumption(continuation: continuation, success: nil) } - switch result { - case .success(let value): - if value == nil { - state.terminal = true - } - case .failure: + + if case .failure = result { state.terminal = true } + switch state.emission { case .idle: state.emission = .pending([continuation]) @@ -222,7 +224,7 @@ public final class AsyncThrowingChannel: Asyn } }?.resume() } - continuation?.resume(with: result) + continuation?.resume(with: result.map { $0 as Element? }) } } @@ -238,10 +240,9 @@ public final class AsyncThrowingChannel: Asyn await _send(.failure(error)) } - /// Send a finish to an awaiting iteration. This function will resume when the next call to `next()` is made. - /// If the channel is already finished then this returns immediately - public func finish() async { - await _send(.success(nil)) + /// Send a finish to all awaiting iterations. + public func finish() { + finishAll() } public func makeAsyncIterator() -> Iterator { diff --git a/Tests/AsyncAlgorithmsTests/TestChannel.swift b/Tests/AsyncAlgorithmsTests/TestChannel.swift index 50cb7626..891ec434 100644 --- a/Tests/AsyncAlgorithmsTests/TestChannel.swift +++ b/Tests/AsyncAlgorithmsTests/TestChannel.swift @@ -13,13 +13,16 @@ import AsyncAlgorithms final class TestChannel: XCTestCase { - func test_channel() async { + func test_asyncChannel_delivers_values_when_two_producers_and_two_consumers() async { + let (sentFromProducer1, sentFromProducer2) = ("test1", "test2") + let expected = Set([sentFromProducer1, sentFromProducer2]) + let channel = AsyncChannel() Task { - await channel.send("test1") + await channel.send(sentFromProducer1) } Task { - await channel.send("test2") + await channel.send(sentFromProducer2) } let t: Task = Task { @@ -28,13 +31,17 @@ final class TestChannel: XCTestCase { return value } var iterator = channel.makeAsyncIterator() - let value = await iterator.next() - let other = await t.value - - XCTAssertEqual(Set([value, other]), Set(["test1", "test2"])) + + let (collectedFromConsumer1, collectedFromConsumer2) = (await t.value, await iterator.next()) + let collected = Set([collectedFromConsumer1, collectedFromConsumer2]) + + XCTAssertEqual(collected, expected) } - func test_throwing_channel() async throws { + func test_asyncThrowingChannel_delivers_values_when_two_producers_and_two_consumers() async throws { + let (sentFromProducer1, sentFromProducer2) = ("test1", "test2") + let expected = Set([sentFromProducer1, sentFromProducer2]) + let channel = AsyncThrowingChannel() Task { await channel.send("test1") @@ -49,13 +56,14 @@ final class TestChannel: XCTestCase { return value } var iterator = channel.makeAsyncIterator() - let value = try await iterator.next() - let other = try await t.value + + let (collectedFromConsumer1, collectedFromConsumer2) = (try await t.value, try await iterator.next()) + let collected = Set([collectedFromConsumer1, collectedFromConsumer2]) - XCTAssertEqual(Set([value, other]), Set(["test1", "test2"])) + XCTAssertEqual(collected, expected) } - func test_throwing() async { + func test_asyncThrowingChannel_throws_when_fail_is_called() async { let channel = AsyncThrowingChannel() Task { await channel.fail(Failure()) @@ -63,37 +71,170 @@ final class TestChannel: XCTestCase { var iterator = channel.makeAsyncIterator() do { let _ = try await iterator.next() - XCTFail() + XCTFail("The AsyncThrowingChannel should have thrown") } catch { XCTAssertEqual(error as? Failure, Failure()) } } - - func test_send_finish() async { + + func test_asyncChannel_ends_alls_iterators_and_discards_additional_sent_values_when_finish_is_called() async { let channel = AsyncChannel() let complete = ManagedCriticalState(false) let finished = expectation(description: "finished") + Task { - await channel.finish() + channel.finish() complete.withCriticalRegion { $0 = true } finished.fulfill() } - XCTAssertFalse(complete.withCriticalRegion { $0 }) - let value = ManagedCriticalState(nil) + + let valueFromConsumer1 = ManagedCriticalState(nil) + let valueFromConsumer2 = ManagedCriticalState(nil) + let received = expectation(description: "received") + received.expectedFulfillmentCount = 2 + let pastEnd = expectation(description: "pastEnd") + pastEnd.expectedFulfillmentCount = 2 + Task { var iterator = channel.makeAsyncIterator() let ending = await iterator.next() - value.withCriticalRegion { $0 = ending } + valueFromConsumer1.withCriticalRegion { $0 = ending } received.fulfill() let item = await iterator.next() XCTAssertNil(item) pastEnd.fulfill() } + + Task { + var iterator = channel.makeAsyncIterator() + let ending = await iterator.next() + valueFromConsumer2.withCriticalRegion { $0 = ending } + received.fulfill() + let item = await iterator.next() + XCTAssertNil(item) + pastEnd.fulfill() + } + wait(for: [finished, received], timeout: 1.0) + XCTAssertTrue(complete.withCriticalRegion { $0 }) - XCTAssertEqual(value.withCriticalRegion { $0 }, nil) + XCTAssertEqual(valueFromConsumer1.withCriticalRegion { $0 }, nil) + XCTAssertEqual(valueFromConsumer2.withCriticalRegion { $0 }, nil) + + wait(for: [pastEnd], timeout: 1.0) + let additionalSend = expectation(description: "additional send") + Task { + await channel.send("test") + additionalSend.fulfill() + } + wait(for: [additionalSend], timeout: 1.0) + } + + func test_asyncChannel_ends_alls_iterators_and_discards_additional_sent_values_when_finish_is_called2() async throws { + let channel = AsyncChannel() + let complete = ManagedCriticalState(false) + let finished = expectation(description: "finished") + + let valueFromConsumer1 = ManagedCriticalState(nil) + let valueFromConsumer2 = ManagedCriticalState(nil) + + let received = expectation(description: "received") + received.expectedFulfillmentCount = 2 + + let pastEnd = expectation(description: "pastEnd") + pastEnd.expectedFulfillmentCount = 2 + + Task(priority: .high) { + var iterator = channel.makeAsyncIterator() + let ending = await iterator.next() + valueFromConsumer1.withCriticalRegion { $0 = ending } + received.fulfill() + let item = await iterator.next() + XCTAssertNil(item) + pastEnd.fulfill() + } + + Task(priority: .high) { + var iterator = channel.makeAsyncIterator() + let ending = await iterator.next() + valueFromConsumer2.withCriticalRegion { $0 = ending } + received.fulfill() + let item = await iterator.next() + XCTAssertNil(item) + pastEnd.fulfill() + } + + try await Task.sleep(nanoseconds: 1_000_000_000) + + Task(priority: .low) { + channel.finish() + complete.withCriticalRegion { $0 = true } + finished.fulfill() + } + + wait(for: [finished, received], timeout: 1.0) + + XCTAssertTrue(complete.withCriticalRegion { $0 }) + XCTAssertEqual(valueFromConsumer1.withCriticalRegion { $0 }, nil) + XCTAssertEqual(valueFromConsumer2.withCriticalRegion { $0 }, nil) + + wait(for: [pastEnd], timeout: 1.0) + let additionalSend = expectation(description: "additional send") + Task { + await channel.send("test") + additionalSend.fulfill() + } + wait(for: [additionalSend], timeout: 1.0) + } + + func test_asyncThrowingChannel_ends_alls_iterators_and_discards_additional_sent_values_when_finish_is_called() async { + let channel = AsyncThrowingChannel() + let complete = ManagedCriticalState(false) + let finished = expectation(description: "finished") + + Task { + channel.finish() + complete.withCriticalRegion { $0 = true } + finished.fulfill() + } + + let valueFromConsumer1 = ManagedCriticalState(nil) + let valueFromConsumer2 = ManagedCriticalState(nil) + + let received = expectation(description: "received") + received.expectedFulfillmentCount = 2 + + let pastEnd = expectation(description: "pastEnd") + pastEnd.expectedFulfillmentCount = 2 + + Task { + var iterator = channel.makeAsyncIterator() + let ending = try await iterator.next() + valueFromConsumer1.withCriticalRegion { $0 = ending } + received.fulfill() + let item = try await iterator.next() + XCTAssertNil(item) + pastEnd.fulfill() + } + + Task { + var iterator = channel.makeAsyncIterator() + let ending = try await iterator.next() + valueFromConsumer2.withCriticalRegion { $0 = ending } + received.fulfill() + let item = try await iterator.next() + XCTAssertNil(item) + pastEnd.fulfill() + } + + wait(for: [finished, received], timeout: 1.0) + + XCTAssertTrue(complete.withCriticalRegion { $0 }) + XCTAssertEqual(valueFromConsumer1.withCriticalRegion { $0 }, nil) + XCTAssertEqual(valueFromConsumer2.withCriticalRegion { $0 }, nil) + wait(for: [pastEnd], timeout: 1.0) let additionalSend = expectation(description: "additional send") Task { @@ -103,7 +244,7 @@ final class TestChannel: XCTestCase { wait(for: [additionalSend], timeout: 1.0) } - func test_cancellation() async { + func test_asyncChannel_ends_iterator_when_task_is_cancelled() async { let channel = AsyncChannel() let ready = expectation(description: "ready") let task: Task = Task { @@ -116,8 +257,22 @@ final class TestChannel: XCTestCase { let value = await task.value XCTAssertNil(value) } + + func test_asyncThrowingChannel_ends_iterator_when_task_is_cancelled() async throws { + let channel = AsyncThrowingChannel() + let ready = expectation(description: "ready") + let task: Task = Task { + var iterator = channel.makeAsyncIterator() + ready.fulfill() + return try await iterator.next() + } + wait(for: [ready], timeout: 1.0) + task.cancel() + let value = try await task.value + XCTAssertNil(value) + } - func test_sendCancellation() async { + func test_asyncChannel_resumes_send_when_task_is_cancelled() async { let channel = AsyncChannel() let notYetDone = expectation(description: "not yet done") notYetDone.isInverted = true @@ -132,7 +287,7 @@ final class TestChannel: XCTestCase { wait(for: [done], timeout: 1.0) } - func test_sendCancellation_throwing() async { + func test_asyncThrowingChannel_resumes_send_when_task_is_cancelled() async { let channel = AsyncThrowingChannel() let notYetDone = expectation(description: "not yet done") notYetDone.isInverted = true @@ -146,18 +301,4 @@ final class TestChannel: XCTestCase { task.cancel() wait(for: [done], timeout: 1.0) } - - func test_cancellation_throwing() async throws { - let channel = AsyncThrowingChannel() - let ready = expectation(description: "ready") - let task: Task = Task { - var iterator = channel.makeAsyncIterator() - ready.fulfill() - return try await iterator.next() - } - wait(for: [ready], timeout: 1.0) - task.cancel() - let value = try await task.value - XCTAssertNil(value) - } }