diff --git a/stdlib/public/Concurrency/AsyncStream.swift b/stdlib/public/Concurrency/AsyncStream.swift index ca6eb979695c1..bbec11db6cb17 100644 --- a/stdlib/public/Concurrency/AsyncStream.swift +++ b/stdlib/public/Concurrency/AsyncStream.swift @@ -336,18 +336,20 @@ public struct AsyncStream { unfolding produce: @escaping @Sendable () async -> Element?, onCancel: (@Sendable () -> Void)? = nil ) { - let storage: _AsyncStreamCriticalStorage Element?>> - = .create(produce) + let storage = _AsyncStreamCriticalStorage<_UnfoldingState?>.create( + _UnfoldingState(produce: produce, onCancel: onCancel) + ) + context = _Context { return await withTaskCancellationHandler { - guard let result = await storage.value?() else { + guard let result = await storage.value?.produce() else { storage.value = nil return nil } return result } onCancel: { - storage.value = nil - onCancel?() + let state = storage.withLock { $0.take() } + state?.onCancel?() } } } diff --git a/stdlib/public/Concurrency/AsyncStreamBuffer.swift b/stdlib/public/Concurrency/AsyncStreamBuffer.swift index bc73a4d7ad078..17519afca562c 100644 --- a/stdlib/public/Concurrency/AsyncStreamBuffer.swift +++ b/stdlib/public/Concurrency/AsyncStreamBuffer.swift @@ -276,6 +276,14 @@ extension AsyncStream { return storage } } + + // MARK: - Unfolding + + /// State for the `AsyncStream.init(unfolding:)` variant. + internal struct _UnfoldingState { + var produce: @Sendable () async -> Element? + var onCancel: (@Sendable () -> Void)? + } } @available(SwiftStdlib 5.1, *) @@ -540,6 +548,7 @@ extension AsyncThrowingStream { // this is used to store closures; which are two words final class _AsyncStreamCriticalStorage: @unchecked Sendable { + // FIXME: stop paying the cost of exclusivity checks when accessing this var _value: Contents private init(_doNotCallMe: ()) { fatalError("_AsyncStreamCriticalStorage must be initialized by create") @@ -559,21 +568,25 @@ final class _AsyncStreamCriticalStorage: @unchecked Sendable { var value: Contents { get { - lock() - let contents = _value - unlock() - return contents + self.withLock { $0 } } set { - lock() - withExtendedLifetime(_value) { - _value = newValue - unlock() + let oldValue = self.withLock { + let old = $0 + $0 = newValue + return old } + extendLifetime(oldValue) } } + func withLock(_ body: (inout Contents) -> Result) -> Result { + lock() + defer { unlock() } + return body(&self._value) + } + static func create(_ initial: Contents) -> _AsyncStreamCriticalStorage { let minimumCapacity = _lockWordCount() let storage = unsafe Builtin.allocWithTailElems_1( diff --git a/test/Concurrency/Runtime/async_stream.swift b/test/Concurrency/Runtime/async_stream.swift index 553d4d736bbd8..5faa62714bc5d 100644 --- a/test/Concurrency/Runtime/async_stream.swift +++ b/test/Concurrency/Runtime/async_stream.swift @@ -487,6 +487,57 @@ class NotSendable {} _ = await consumer2.value } + // MARK: - Unfolding + + tests.test("unfolding stream calls onCancel at most once") { @MainActor in + nonisolated(unsafe) var cancelCallbackCount = 0 + + let (innerControlStream, _) = AsyncStream.makeStream() + let (outerControlStream, outerContinuation) = AsyncStream.makeStream() + + let task = Task { @MainActor in + let stream = AsyncStream { @MainActor in + outerContinuation.yield("started") + do { + var iter = innerControlStream.makeAsyncIterator() + let next = await iter.next() + assert(next == nil) // should only return from Task cancellation + } + return 42 + } onCancel: { + cancelCallbackCount += 1 + } + + for await value in stream { + _ = value + } + } + + // wait for unfolding closure to start + do { + var iter = outerControlStream.makeAsyncIterator() + let next = await iter.next() + assert(next == "started") + } + + // ensure iterator is suspended + await MainActor.run {} + + // cancel task + task.cancel() + + // cancel callback should be invoked + expectEqual(cancelCallbackCount, 1) + + // ensure task completes + _ = await task.value + + // check that the cancel callback wasn't invoked again + expectEqual(cancelCallbackCount, 1) + } + + // MARK: - + await runAllTestsAsync() } }