diff --git a/library/std/src/sync/mod.rs b/library/std/src/sync/mod.rs index e67b4f6f22f5a..a68045da8aa32 100644 --- a/library/std/src/sync/mod.rs +++ b/library/std/src/sync/mod.rs @@ -224,6 +224,8 @@ pub use self::poison::{MappedMutexGuard, MappedRwLockReadGuard, MappedRwLockWrit #[unstable(feature = "mpmc_channel", issue = "126840")] pub mod mpmc; pub mod mpsc; +#[unstable(feature = "oneshot_channel", issue = "143674")] +pub mod oneshot; #[unstable(feature = "sync_poison_mod", issue = "134646")] pub mod poison; diff --git a/library/std/src/sync/oneshot.rs b/library/std/src/sync/oneshot.rs new file mode 100644 index 0000000000000..9a0662bf9039b --- /dev/null +++ b/library/std/src/sync/oneshot.rs @@ -0,0 +1,465 @@ +//! A single-producer, single-consumer (oneshot) channel. +//! +//! This is an experimental module, so the API will likely change. + +use crate::sync::mpmc; +use crate::sync::mpsc::{RecvError, SendError}; +use crate::time::{Duration, Instant}; +use crate::{error, fmt}; + +/// Creates a new oneshot channel, returning the sender/receiver halves. +/// +/// # Examples +/// +/// ``` +/// # #![feature(oneshot_channel)] +/// # use std::sync::oneshot; +/// # use std::thread; +/// # +/// let (sender, receiver) = oneshot::channel(); +/// +/// // Spawn off an expensive computation. +/// thread::spawn(move || { +/// # fn expensive_computation() -> i32 { 42 } +/// sender.send(expensive_computation()).unwrap(); +/// // `sender` is consumed by `send`, so we cannot use it anymore. +/// }); +/// +/// # fn do_other_work() -> i32 { 42 } +/// do_other_work(); +/// +/// // Let's see what that answer was... +/// println!("{:?}", receiver.recv().unwrap()); +/// // `receiver` is consumed by `recv`, so we cannot use it anymore. +/// ``` +#[must_use] +#[unstable(feature = "oneshot_channel", issue = "143674")] +pub fn channel() -> (Sender, Receiver) { + // Using a `sync_channel` with capacity 1 means that the internal implementation will use the + // `Array`-flavored channel implementation. + let (sender, receiver) = mpmc::sync_channel(1); + (Sender { inner: sender }, Receiver { inner: receiver }) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Sender +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// The sending half of a oneshot channel. +/// +/// # Examples +/// +/// ``` +/// # #![feature(oneshot_channel)] +/// # use std::sync::oneshot; +/// # use std::thread; +/// # +/// let (sender, receiver) = oneshot::channel(); +/// +/// thread::spawn(move || { +/// sender.send("Hello from thread!").unwrap(); +/// }); +/// +/// assert_eq!(receiver.recv().unwrap(), "Hello from thread!"); +/// ``` +/// +/// `Sender` cannot be sent between threads if it is sending non-`Send` types. +/// +/// ```compile_fail +/// # #![feature(oneshot_channel)] +/// # use std::sync::oneshot; +/// # use std::thread; +/// # use std::ptr; +/// # +/// let (sender, receiver) = oneshot::channel(); +/// +/// struct NotSend(*mut ()); +/// thread::spawn(move || { +/// sender.send(NotSend(ptr::null_mut())); +/// }); +/// +/// let reply = receiver.try_recv().unwrap(); +/// ``` +#[unstable(feature = "oneshot_channel", issue = "143674")] +pub struct Sender { + /// The `oneshot` channel is simply a wrapper around a `mpmc` channel. + inner: mpmc::Sender, +} + +/// # Safety +/// +/// Since the only methods in which synchronization must occur take full ownership of the +/// [`Sender`], it is perfectly safe to share a `&Sender` between threads (as it is effectively +/// useless without ownership). +#[unstable(feature = "oneshot_channel", issue = "143674")] +unsafe impl Sync for Sender {} + +impl Sender { + /// Attempts to send a value through this channel. This can only fail if the corresponding + /// [`Receiver`] has been dropped. + /// + /// This method is non-blocking (wait-free). + /// + /// # Examples + /// + /// ``` + /// # #![feature(oneshot_channel)] + /// # use std::sync::oneshot; + /// # use std::thread; + /// # + /// let (tx, rx) = oneshot::channel(); + /// + /// thread::spawn(move || { + /// // Perform some computation. + /// let result = 2 + 2; + /// tx.send(result).unwrap(); + /// }); + /// + /// assert_eq!(rx.recv().unwrap(), 4); + /// ``` + #[unstable(feature = "oneshot_channel", issue = "143674")] + pub fn send(self, t: T) -> Result<(), SendError> { + self.inner.send(t) + } +} + +#[unstable(feature = "oneshot_channel", issue = "143674")] +impl fmt::Debug for Sender { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("Sender { .. }") + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Receiver +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// The receiving half of a oneshot channel. +/// +/// # Examples +/// +/// ``` +/// # #![feature(oneshot_channel)] +/// # use std::sync::oneshot; +/// # use std::thread; +/// # use std::time::Duration; +/// # +/// let (sender, receiver) = oneshot::channel(); +/// +/// thread::spawn(move || { +/// thread::sleep(Duration::from_millis(100)); +/// sender.send("Hello after delay!").unwrap(); +/// }); +/// +/// println!("Waiting for message..."); +/// println!("{}", receiver.recv().unwrap()); +/// ``` +/// +/// `Receiver` cannot be sent between threads if it is receiving non-`Send` types. +/// +/// ```compile_fail +/// # #![feature(oneshot_channel)] +/// # use std::sync::oneshot; +/// # use std::thread; +/// # use std::ptr; +/// # +/// let (sender, receiver) = oneshot::channel(); +/// +/// struct NotSend(*mut ()); +/// sender.send(NotSend(ptr::null_mut())); +/// +/// thread::spawn(move || { +/// let reply = receiver.try_recv().unwrap(); +/// }); +/// ``` +#[unstable(feature = "oneshot_channel", issue = "143674")] +pub struct Receiver { + /// The `oneshot` channel is simply a wrapper around a `mpmc` channel. + inner: mpmc::Receiver, +} + +/// # Safety +/// +/// Since the only methods in which synchronization must occur take full ownership of the +/// [`Receiver`], it is perfectly safe to share a `&Receiver` between threads (as it is unable to +/// receive any values without ownership). +#[unstable(feature = "oneshot_channel", issue = "143674")] +unsafe impl Sync for Receiver {} + +impl Receiver { + /// Receives the value from the sending end, blocking the calling thread until it gets it. + /// + /// Can only fail if the corresponding [`Sender`] has been dropped. + /// + /// # Examples + /// + /// ``` + /// # #![feature(oneshot_channel)] + /// # use std::sync::oneshot; + /// # use std::thread; + /// # use std::time::Duration; + /// # + /// let (tx, rx) = oneshot::channel(); + /// + /// thread::spawn(move || { + /// thread::sleep(Duration::from_millis(500)); + /// tx.send("Done!").unwrap(); + /// }); + /// + /// // This will block until the message arrives. + /// println!("{}", rx.recv().unwrap()); + /// ``` + #[unstable(feature = "oneshot_channel", issue = "143674")] + pub fn recv(self) -> Result { + self.inner.recv() + } + + // Fallible methods. + + /// Attempts to return a pending value on this receiver without blocking. + /// + /// # Examples + /// + /// ``` + /// # #![feature(oneshot_channel)] + /// # use std::sync::oneshot; + /// # use std::thread; + /// # use std::time::Duration; + /// # + /// let (sender, mut receiver) = oneshot::channel(); + /// + /// thread::spawn(move || { + /// thread::sleep(Duration::from_millis(100)); + /// sender.send(42).unwrap(); + /// }); + /// + /// // Keep trying until we get the message, doing other work in the process. + /// loop { + /// match receiver.try_recv() { + /// Ok(value) => { + /// assert_eq!(value, 42); + /// break; + /// } + /// Err(oneshot::TryRecvError::Empty(rx)) => { + /// // Retake ownership of the receiver. + /// receiver = rx; + /// # fn do_other_work() { thread::sleep(Duration::from_millis(25)); } + /// do_other_work(); + /// } + /// Err(oneshot::TryRecvError::Disconnected) => panic!("Sender disconnected"), + /// } + /// } + /// ``` + #[unstable(feature = "oneshot_channel", issue = "143674")] + pub fn try_recv(self) -> Result> { + self.inner.try_recv().map_err(|err| match err { + mpmc::TryRecvError::Empty => TryRecvError::Empty(self), + mpmc::TryRecvError::Disconnected => TryRecvError::Disconnected, + }) + } + + /// Attempts to wait for a value on this receiver, returning an error if the corresponding + /// [`Sender`] half of this channel has been dropped, or if it waits more than `timeout`. + /// + /// # Examples + /// + /// ``` + /// # #![feature(oneshot_channel)] + /// # use std::sync::oneshot; + /// # use std::thread; + /// # use std::time::Duration; + /// # + /// let (sender, receiver) = oneshot::channel(); + /// + /// thread::spawn(move || { + /// thread::sleep(Duration::from_millis(500)); + /// sender.send("Success!").unwrap(); + /// }); + /// + /// // Wait up to 1 second for the message + /// match receiver.recv_timeout(Duration::from_secs(1)) { + /// Ok(msg) => println!("Received: {}", msg), + /// Err(oneshot::RecvTimeoutError::Timeout(_)) => println!("Timed out!"), + /// Err(oneshot::RecvTimeoutError::Disconnected) => println!("Sender dropped!"), + /// } + /// ``` + #[unstable(feature = "oneshot_channel", issue = "143674")] + pub fn recv_timeout(self, timeout: Duration) -> Result> { + self.inner.recv_timeout(timeout).map_err(|err| match err { + mpmc::RecvTimeoutError::Timeout => RecvTimeoutError::Timeout(self), + mpmc::RecvTimeoutError::Disconnected => RecvTimeoutError::Disconnected, + }) + } + + /// Attempts to wait for a value on this receiver, returning an error if the corresponding + /// [`Sender`] half of this channel has been dropped, or if `deadline` is reached. + /// + /// # Examples + /// + /// ``` + /// # #![feature(oneshot_channel)] + /// # use std::sync::oneshot; + /// # use std::thread; + /// # use std::time::{Duration, Instant}; + /// # + /// let (sender, receiver) = oneshot::channel(); + /// + /// thread::spawn(move || { + /// thread::sleep(Duration::from_millis(100)); + /// sender.send("Just in time!").unwrap(); + /// }); + /// + /// let deadline = Instant::now() + Duration::from_millis(500); + /// match receiver.recv_deadline(deadline) { + /// Ok(msg) => println!("Received: {}", msg), + /// Err(oneshot::RecvTimeoutError::Timeout(_)) => println!("Missed deadline!"), + /// Err(oneshot::RecvTimeoutError::Disconnected) => println!("Sender dropped!"), + /// } + /// ``` + #[unstable(feature = "oneshot_channel", issue = "143674")] + pub fn recv_deadline(self, deadline: Instant) -> Result> { + self.inner.recv_deadline(deadline).map_err(|err| match err { + mpmc::RecvTimeoutError::Timeout => RecvTimeoutError::Timeout(self), + mpmc::RecvTimeoutError::Disconnected => RecvTimeoutError::Disconnected, + }) + } +} + +#[unstable(feature = "oneshot_channel", issue = "143674")] +impl fmt::Debug for Receiver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad("Receiver { .. }") + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Receiver Errors +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// An error returned from the [`try_recv`](Receiver::try_recv) method. +/// +/// See the documentation for [`try_recv`] for more information on how to use this error. +/// +/// [`try_recv`]: Receiver::try_recv +#[unstable(feature = "oneshot_channel", issue = "143674")] +pub enum TryRecvError { + /// The [`Sender`] has not sent a message yet, but it might in the future (as it has not yet + /// disconnected). This variant contains the [`Receiver`] that [`try_recv`](Receiver::try_recv) + /// took ownership over. + Empty(Receiver), + /// The corresponding [`Sender`] half of this channel has become disconnected, and there will + /// never be any more data sent over the channel. + Disconnected, +} + +/// An error returned from the [`recv_timeout`](Receiver::recv_timeout) or +/// [`recv_deadline`](Receiver::recv_deadline) methods. +/// +/// # Examples +/// +/// Usage of this error is similar to [`TryRecvError`]. +/// +/// ``` +/// # #![feature(oneshot_channel)] +/// # use std::sync::oneshot::{self, RecvTimeoutError}; +/// # use std::thread; +/// # use std::time::Duration; +/// # +/// let (sender, receiver) = oneshot::channel(); +/// +/// thread::spawn(move || { +/// // Simulate a long computation that takes longer than our timeout. +/// thread::sleep(Duration::from_millis(500)); +/// sender.send("Too late!".to_string()).unwrap(); +/// }); +/// +/// // Try to receive the message with a short timeout. +/// match receiver.recv_timeout(Duration::from_millis(100)) { +/// Ok(msg) => println!("Received: {}", msg), +/// Err(RecvTimeoutError::Timeout(rx)) => { +/// println!("Timed out waiting for message!"); +/// // You can reuse the receiver if needed. +/// drop(rx); +/// } +/// Err(RecvTimeoutError::Disconnected) => println!("Sender dropped!"), +/// } +/// ``` +#[unstable(feature = "oneshot_channel", issue = "143674")] +pub enum RecvTimeoutError { + /// The [`Sender`] has not sent a message yet, but it might in the future (as it has not yet + /// disconnected). This variant contains the [`Receiver`] that either + /// [`recv_timeout`](Receiver::recv_timeout) or [`recv_deadline`](Receiver::recv_deadline) took + /// ownership over. + Timeout(Receiver), + /// The corresponding [`Sender`] half of this channel has become disconnected, and there will + /// never be any more data sent over the channel. + Disconnected, +} + +#[unstable(feature = "oneshot_channel", issue = "143674")] +impl fmt::Debug for TryRecvError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + "TryRecvError(..)".fmt(f) + } +} + +#[unstable(feature = "oneshot_channel", issue = "143674")] +impl fmt::Display for TryRecvError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + TryRecvError::Empty(..) => "receiving on an empty channel".fmt(f), + TryRecvError::Disconnected => "receiving on a closed channel".fmt(f), + } + } +} + +#[unstable(feature = "oneshot_channel", issue = "143674")] +impl error::Error for TryRecvError {} + +#[unstable(feature = "oneshot_channel", issue = "143674")] +impl From for TryRecvError { + /// Converts a `RecvError` into a `TryRecvError`. + /// + /// This conversion always returns `TryRecvError::Disconnected`. + /// + /// No data is allocated on the heap. + fn from(err: RecvError) -> TryRecvError { + match err { + RecvError => TryRecvError::Disconnected, + } + } +} + +#[unstable(feature = "oneshot_channel", issue = "143674")] +impl fmt::Debug for RecvTimeoutError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + "RecvTimeoutError(..)".fmt(f) + } +} + +#[unstable(feature = "oneshot_channel", issue = "143674")] +impl fmt::Display for RecvTimeoutError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + RecvTimeoutError::Timeout(..) => "timed out waiting on channel".fmt(f), + RecvTimeoutError::Disconnected => "receiving on a closed channel".fmt(f), + } + } +} + +#[unstable(feature = "oneshot_channel", issue = "143674")] +impl error::Error for RecvTimeoutError {} + +#[unstable(feature = "oneshot_channel", issue = "143674")] +impl From for RecvTimeoutError { + /// Converts a `RecvError` into a `RecvTimeoutError`. + /// + /// This conversion always returns `RecvTimeoutError::Disconnected`. + /// + /// No data is allocated on the heap. + fn from(err: RecvError) -> RecvTimeoutError { + match err { + RecvError => RecvTimeoutError::Disconnected, + } + } +} diff --git a/library/std/tests/sync/lib.rs b/library/std/tests/sync/lib.rs index 51190f0894fb7..a84aa58cc7a13 100644 --- a/library/std/tests/sync/lib.rs +++ b/library/std/tests/sync/lib.rs @@ -1,6 +1,7 @@ #![feature(lazy_get)] #![feature(mapped_lock_guards)] #![feature(mpmc_channel)] +#![feature(oneshot_channel)] #![feature(once_cell_try)] #![feature(lock_value_accessors)] #![feature(reentrant_lock)] @@ -23,6 +24,8 @@ mod mutex; mod once; mod once_lock; #[cfg(not(any(target_os = "emscripten", target_os = "wasi")))] +mod oneshot; +#[cfg(not(any(target_os = "emscripten", target_os = "wasi")))] mod reentrant_lock; #[cfg(not(any(target_os = "emscripten", target_os = "wasi")))] mod rwlock; diff --git a/library/std/tests/sync/oneshot.rs b/library/std/tests/sync/oneshot.rs new file mode 100644 index 0000000000000..6a87c72b9cb5b --- /dev/null +++ b/library/std/tests/sync/oneshot.rs @@ -0,0 +1,342 @@ +//! Inspired by tests from + +use std::sync::mpsc::RecvError; +use std::sync::oneshot; +use std::sync::oneshot::{RecvTimeoutError, TryRecvError}; +use std::time::{Duration, Instant}; +use std::{mem, thread}; + +#[test] +fn send_before_try_recv() { + let (sender, receiver) = oneshot::channel(); + + assert!(sender.send(19i128).is_ok()); + + match receiver.try_recv() { + Ok(19) => {} + _ => panic!("expected Ok(19)"), + } +} + +#[test] +fn send_before_recv() { + let (sender, receiver) = oneshot::channel::<()>(); + + assert!(sender.send(()).is_ok()); + assert_eq!(receiver.recv(), Ok(())); + + let (sender, receiver) = oneshot::channel::(); + + assert!(sender.send(42).is_ok()); + assert_eq!(receiver.recv(), Ok(42)); + + let (sender, receiver) = oneshot::channel::<[u8; 4096]>(); + + assert!(sender.send([0b10101010; 4096]).is_ok()); + assert!(receiver.recv().unwrap()[..] == [0b10101010; 4096][..]); +} + +#[test] +fn sender_drop() { + { + let (sender, receiver) = oneshot::channel::(); + + mem::drop(sender); + + match receiver.recv() { + Err(RecvError) => {} + _ => panic!("expected recv error"), + } + } + + { + let (sender, receiver) = oneshot::channel::(); + + mem::drop(sender); + + match receiver.try_recv() { + Err(TryRecvError::Disconnected) => {} + _ => panic!("expected disconnected error"), + } + } + { + let (sender, receiver) = oneshot::channel::(); + + mem::drop(sender); + + match receiver.recv_timeout(Duration::from_secs(1)) { + Err(RecvTimeoutError::Disconnected) => {} + _ => panic!("expected disconnected error"), + } + } +} + +#[test] +fn send_never_deadline() { + let (sender, receiver) = oneshot::channel::(); + + mem::drop(sender); + + match receiver.recv_deadline(Instant::now()) { + Err(RecvTimeoutError::Disconnected) => {} + _ => panic!("expected disconnected error"), + } +} + +#[test] +fn send_before_recv_timeout() { + let (sender, receiver) = oneshot::channel(); + + assert!(sender.send(22i128).is_ok()); + + let start = Instant::now(); + + let timeout = Duration::from_secs(1); + match receiver.recv_timeout(timeout) { + Ok(22) => {} + _ => panic!("expected Ok(22)"), + } + + assert!(start.elapsed() < timeout); +} + +#[test] +fn send_error() { + let (sender, receiver) = oneshot::channel(); + + mem::drop(receiver); + + let send_error = sender.send(32u128).unwrap_err(); + assert_eq!(send_error.0, 32); +} + +#[test] +fn recv_before_send() { + let (sender, receiver) = oneshot::channel(); + + let t1 = thread::spawn(move || { + thread::sleep(Duration::from_millis(10)); + sender.send(9u128).unwrap(); + }); + let t2 = thread::spawn(move || { + assert_eq!(receiver.recv(), Ok(9)); + }); + + t1.join().unwrap(); + t2.join().unwrap(); +} + +#[test] +fn recv_timeout_before_send() { + let (sender, receiver) = oneshot::channel(); + + let t = thread::spawn(move || { + thread::sleep(Duration::from_millis(100)); + sender.send(99u128).unwrap(); + }); + + match receiver.recv_timeout(Duration::from_secs(1)) { + Ok(99) => {} + _ => panic!("expected Ok(99)"), + } + + t.join().unwrap(); +} + +#[test] +fn recv_then_drop_sender() { + let (sender, receiver) = oneshot::channel::(); + + let t1 = thread::spawn(move || match receiver.recv() { + Err(RecvError) => {} + _ => panic!("expected recv error"), + }); + + let t2 = thread::spawn(move || { + thread::sleep(Duration::from_millis(10)); + mem::drop(sender); + }); + + t1.join().unwrap(); + t2.join().unwrap(); +} + +#[test] +fn drop_sender_then_recv() { + let (sender, receiver) = oneshot::channel::(); + + let t1 = thread::spawn(move || { + thread::sleep(Duration::from_millis(10)); + mem::drop(sender); + }); + + let t2 = thread::spawn(move || match receiver.recv() { + Err(RecvError) => {} + _ => panic!("expected disconnected error"), + }); + + t1.join().unwrap(); + t2.join().unwrap(); +} + +#[test] +fn try_recv_empty() { + let (sender, receiver) = oneshot::channel::(); + match receiver.try_recv() { + Err(TryRecvError::Empty(_)) => {} + _ => panic!("expected empty error"), + } + mem::drop(sender); +} + +#[test] +fn try_recv_then_drop_receiver() { + let (sender, receiver) = oneshot::channel::(); + + let t1 = thread::spawn(move || { + thread::sleep(Duration::from_millis(100)); + let _ = sender.send(42); + }); + + let t2 = thread::spawn(move || match receiver.try_recv() { + Ok(_) => {} + Err(TryRecvError::Empty(r)) => { + mem::drop(r); + } + Err(TryRecvError::Disconnected) => {} + }); + + t2.join().unwrap(); + t1.join().unwrap(); +} + +#[test] +fn recv_no_time() { + let (_sender, receiver) = oneshot::channel::(); + + let start = Instant::now(); + match receiver.recv_deadline(start) { + Err(RecvTimeoutError::Timeout(_)) => {} + _ => panic!("expected timeout error"), + } + + let (_sender, receiver) = oneshot::channel::(); + match receiver.recv_timeout(Duration::from_millis(0)) { + Err(RecvTimeoutError::Timeout(_)) => {} + _ => panic!("expected timeout error"), + } +} + +#[test] +fn recv_deadline_passed() { + let (_sender, receiver) = oneshot::channel::(); + + let start = Instant::now(); + let timeout = Duration::from_millis(100); + + match receiver.recv_deadline(start + timeout) { + Err(RecvTimeoutError::Timeout(_)) => {} + _ => panic!("expected timeout error"), + } + + assert!(start.elapsed() >= timeout); + assert!(start.elapsed() < timeout * 3); +} + +#[test] +fn recv_time_passed() { + let (_sender, receiver) = oneshot::channel::(); + + let start = Instant::now(); + let timeout = Duration::from_millis(100); + match receiver.recv_timeout(timeout) { + Err(RecvTimeoutError::Timeout(_)) => {} + _ => panic!("expected timeout error"), + } + assert!(start.elapsed() >= timeout); + assert!(start.elapsed() < timeout * 3); +} + +#[test] +fn non_send_type_can_be_used_on_same_thread() { + use std::ptr; + + #[derive(Debug, Eq, PartialEq)] + struct NotSend(*mut ()); + + let (sender, receiver) = oneshot::channel(); + sender.send(NotSend(ptr::null_mut())).unwrap(); + let reply = receiver.try_recv().unwrap(); + assert_eq!(reply, NotSend(ptr::null_mut())); +} + +/// Helper for testing drop behavior (taken directly from the `oneshot` crate). +struct DropCounter { + count: std::rc::Rc>, +} + +impl DropCounter { + fn new() -> (DropTracker, DropCounter) { + let count = std::rc::Rc::new(std::cell::RefCell::new(0)); + (DropTracker { count: count.clone() }, DropCounter { count }) + } + + fn count(&self) -> usize { + *self.count.borrow() + } +} + +struct DropTracker { + count: std::rc::Rc>, +} + +impl Drop for DropTracker { + fn drop(&mut self) { + *self.count.borrow_mut() += 1; + } +} + +#[test] +fn message_in_channel_dropped_on_receiver_drop() { + let (sender, receiver) = oneshot::channel(); + + let (message, counter) = DropCounter::new(); + assert_eq!(counter.count(), 0); + + sender.send(message).unwrap(); + assert_eq!(counter.count(), 0); + + mem::drop(receiver); + assert_eq!(counter.count(), 1); +} + +#[test] +fn send_error_drops_message_correctly() { + let (sender, receiver) = oneshot::channel(); + mem::drop(receiver); + + let (message, counter) = DropCounter::new(); + + let send_error = sender.send(message).unwrap_err(); + assert_eq!(counter.count(), 0); + + mem::drop(send_error); + assert_eq!(counter.count(), 1); +} + +#[test] +fn send_error_drops_message_correctly_on_extract() { + let (sender, receiver) = oneshot::channel(); + mem::drop(receiver); + + let (message, counter) = DropCounter::new(); + + let send_error = sender.send(message).unwrap_err(); + assert_eq!(counter.count(), 0); + + let message = send_error.0; // Access the inner value directly + assert_eq!(counter.count(), 0); + + mem::drop(message); + assert_eq!(counter.count(), 1); +}