From 4bbcc2a7fbf1a363719d224dd80dfbcbbee6f26e Mon Sep 17 00:00:00 2001 From: Peter Krull Date: Sun, 3 Mar 2024 15:35:52 +0100 Subject: Add OnceLock sync primitive --- embassy-sync/src/lib.rs | 1 + embassy-sync/src/once_lock.rs | 237 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 238 insertions(+) create mode 100644 embassy-sync/src/once_lock.rs (limited to 'embassy-sync') diff --git a/embassy-sync/src/lib.rs b/embassy-sync/src/lib.rs index d88c76db5..61b173e80 100644 --- a/embassy-sync/src/lib.rs +++ b/embassy-sync/src/lib.rs @@ -13,6 +13,7 @@ mod ring_buffer; pub mod blocking_mutex; pub mod channel; pub mod mutex; +pub mod once_lock; pub mod pipe; pub mod priority_channel; pub mod pubsub; diff --git a/embassy-sync/src/once_lock.rs b/embassy-sync/src/once_lock.rs new file mode 100644 index 000000000..f83577a6d --- /dev/null +++ b/embassy-sync/src/once_lock.rs @@ -0,0 +1,237 @@ +//! Syncronization primitive for initializing a value once, allowing others to await a reference to the value. + +use core::cell::Cell; +use core::future::poll_fn; +use core::mem::MaybeUninit; +use core::sync::atomic::{AtomicBool, Ordering}; +use core::task::Poll; + +/// The `OnceLock` is a synchronization primitive that allows for +/// initializing a value once, and allowing others to `.await` a +/// reference to the value. This is useful for lazy initialization of +/// a static value. +/// +/// **Note**: this implementation uses a busy loop to poll the value, +/// which is not as efficient as registering a dedicated `Waker`. +/// However, the if the usecase for is to initialize a static variable +/// relatively early in the program life cycle, it should be fine. +/// +/// # Example +/// ``` +/// use futures_executor::block_on; +/// use embassy_sync::once_lock::OnceLock; +/// +/// // Define a static value that will be lazily initialized +/// static VALUE: OnceLock = OnceLock::new(); +/// +/// let f = async { +/// +/// // Initialize the value +/// let reference = VALUE.get_or_init(|| 20); +/// assert_eq!(reference, &20); +/// +/// // Wait for the value to be initialized +/// // and get a static reference it +/// assert_eq!(VALUE.get().await, &20); +/// +/// }; +/// block_on(f) +/// ``` +pub struct OnceLock { + init: AtomicBool, + data: Cell>, +} + +unsafe impl Sync for OnceLock {} + +impl OnceLock { + /// Create a new uninitialized `OnceLock`. + pub const fn new() -> Self { + Self { + init: AtomicBool::new(false), + data: Cell::new(MaybeUninit::zeroed()), + } + } + + /// Get a reference to the underlying value, waiting for it to be set. + /// If the value is already set, this will return immediately. + pub async fn get(&self) -> &T { + + poll_fn(|cx| match self.try_get() { + Some(data) => Poll::Ready(data), + None => { + cx.waker().wake_by_ref(); + Poll::Pending + } + }) + .await + } + + /// Try to get a reference to the underlying value if it exists. + pub fn try_get(&self) -> Option<&T> { + if self.init.load(Ordering::Relaxed) { + Some(unsafe { self.get_ref_unchecked() }) + } else { + None + } + } + + /// Set the underlying value. If the value is already set, this will return an error with the given value. + pub fn init(&self, value: T) -> Result<(), T> { + // Critical section is required to ensure that the value is + // not simultaniously initialized elsewhere at the same time. + critical_section::with(|_| { + // If the value is not set, set it and return Ok. + if !self.init.load(Ordering::Relaxed) { + self.data.set(MaybeUninit::new(value)); + self.init.store(true, Ordering::Relaxed); + Ok(()) + + // Otherwise return an error with the given value. + } else { + Err(value) + } + }) + } + + /// Get a reference to the underlying value, initializing it if it does not exist. + pub fn get_or_init(&self, f: F) -> &T + where + F: FnOnce() -> T, + { + // Critical section is required to ensure that the value is + // not simultaniously initialized elsewhere at the same time. + critical_section::with(|_| { + // If the value is not set, set it. + if !self.init.load(Ordering::Relaxed) { + self.data.set(MaybeUninit::new(f())); + self.init.store(true, Ordering::Relaxed); + } + }); + + // Return a reference to the value. + unsafe { self.get_ref_unchecked() } + } + + /// Consume the `OnceLock`, returning the underlying value if it was initialized. + pub fn into_inner(self) -> Option { + if self.init.load(Ordering::Relaxed) { + Some(unsafe { self.data.into_inner().assume_init() }) + } else { + None + } + } + + /// Take the underlying value if it was initialized, uninitializing the `OnceLock` in the process. + pub fn take(&mut self) -> Option { + // If the value is set, uninitialize the lock and return the value. + critical_section::with(|_| { + if self.init.load(Ordering::Relaxed) { + let val = unsafe { self.data.replace(MaybeUninit::zeroed()).assume_init() }; + self.init.store(false, Ordering::Relaxed); + Some(val) + + // Otherwise return None. + } else { + None + } + }) + } + + /// Check if the value has been set. + pub fn is_set(&self) -> bool { + self.init.load(Ordering::Relaxed) + } + + /// Get a reference to the underlying value. + /// # Safety + /// Must only be used if a value has been set. + unsafe fn get_ref_unchecked(&self) -> &T { + (*self.data.as_ptr()).assume_init_ref() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn once_lock() { + let lock = OnceLock::new(); + assert_eq!(lock.try_get(), None); + assert_eq!(lock.is_set(), false); + + let v = 42; + assert_eq!(lock.init(v), Ok(())); + assert_eq!(lock.is_set(), true); + assert_eq!(lock.try_get(), Some(&v)); + assert_eq!(lock.try_get(), Some(&v)); + + let v = 43; + assert_eq!(lock.init(v), Err(v)); + assert_eq!(lock.is_set(), true); + assert_eq!(lock.try_get(), Some(&42)); + } + + #[test] + fn once_lock_get_or_init() { + let lock = OnceLock::new(); + assert_eq!(lock.try_get(), None); + assert_eq!(lock.is_set(), false); + + let v = lock.get_or_init(|| 42); + assert_eq!(v, &42); + assert_eq!(lock.is_set(), true); + assert_eq!(lock.try_get(), Some(&42)); + + let v = lock.get_or_init(|| 43); + assert_eq!(v, &42); + assert_eq!(lock.is_set(), true); + assert_eq!(lock.try_get(), Some(&42)); + } + + #[test] + fn once_lock_static() { + static LOCK: OnceLock = OnceLock::new(); + + let v: &'static i32 = LOCK.get_or_init(|| 42); + assert_eq!(v, &42); + + let v: &'static i32 = LOCK.get_or_init(|| 43); + assert_eq!(v, &42); + } + + #[futures_test::test] + async fn once_lock_async() { + static LOCK: OnceLock = OnceLock::new(); + + assert!(LOCK.init(42).is_ok()); + + let v: &'static i32 = LOCK.get().await; + assert_eq!(v, &42); + } + + #[test] + fn once_lock_into_inner() { + let lock: OnceLock = OnceLock::new(); + + let v = lock.get_or_init(|| 42); + assert_eq!(v, &42); + + assert_eq!(lock.into_inner(), Some(42)); + } + + #[test] + fn once_lock_take_init() { + let mut lock: OnceLock = OnceLock::new(); + + assert_eq!(lock.get_or_init(|| 42), &42); + assert_eq!(lock.is_set(), true); + + assert_eq!(lock.take(), Some(42)); + assert_eq!(lock.is_set(), false); + + assert_eq!(lock.get_or_init(|| 43), &43); + assert_eq!(lock.is_set(), true); + } +} -- cgit From 245e7d3bc20d66fa77f6763ce5e5bec0ea6ddeff Mon Sep 17 00:00:00 2001 From: Peter Krull Date: Sun, 3 Mar 2024 15:43:01 +0100 Subject: This one is for ci/rustfmt --- embassy-sync/src/once_lock.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'embassy-sync') diff --git a/embassy-sync/src/once_lock.rs b/embassy-sync/src/once_lock.rs index f83577a6d..31cc99711 100644 --- a/embassy-sync/src/once_lock.rs +++ b/embassy-sync/src/once_lock.rs @@ -33,7 +33,7 @@ use core::task::Poll; /// // Wait for the value to be initialized /// // and get a static reference it /// assert_eq!(VALUE.get().await, &20); -/// +/// /// }; /// block_on(f) /// ``` @@ -56,7 +56,6 @@ impl OnceLock { /// Get a reference to the underlying value, waiting for it to be set. /// If the value is already set, this will return immediately. pub async fn get(&self) -> &T { - poll_fn(|cx| match self.try_get() { Some(data) => Poll::Ready(data), None => { -- cgit From 8bb1fe1f653bd4bef6ad0760f33a07bcfe5d91fb Mon Sep 17 00:00:00 2001 From: Ulf Lilleengen Date: Tue, 12 Mar 2024 15:27:52 +0100 Subject: Add conversion into dyn variants for channel futures --- embassy-sync/src/channel.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) (limited to 'embassy-sync') diff --git a/embassy-sync/src/channel.rs b/embassy-sync/src/channel.rs index 01db0d09a..79d3e0378 100644 --- a/embassy-sync/src/channel.rs +++ b/embassy-sync/src/channel.rs @@ -263,6 +263,15 @@ impl<'ch, T> Future for DynamicReceiveFuture<'ch, T> { } } +impl<'ch, M: RawMutex, T, const N: usize> From> for DynamicReceiveFuture<'ch, T> +{ + fn from(value: ReceiveFuture<'ch, M, T, N>) -> Self { + Self { + channel: value.channel, + } + } +} + /// Future returned by [`Channel::send`] and [`Sender::send`]. #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct SendFuture<'ch, M, T, const N: usize> @@ -321,6 +330,16 @@ impl<'ch, T> Future for DynamicSendFuture<'ch, T> { impl<'ch, T> Unpin for DynamicSendFuture<'ch, T> {} +impl<'ch, M: RawMutex, T, const N: usize> From> for DynamicSendFuture<'ch, T> +{ + fn from(value: SendFuture<'ch, M, T, N>) -> Self { + Self { + channel: value.channel, + message: value.message, + } + } +} + pub(crate) trait DynamicChannel { fn try_send_with_context(&self, message: T, cx: Option<&mut Context<'_>>) -> Result<(), TrySendError>; -- cgit From c0d91600ea7b6847e2295a0fee819291c951132f Mon Sep 17 00:00:00 2001 From: Ulf Lilleengen Date: Tue, 12 Mar 2024 15:37:53 +0100 Subject: rustfmt --- embassy-sync/src/channel.rs | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) (limited to 'embassy-sync') diff --git a/embassy-sync/src/channel.rs b/embassy-sync/src/channel.rs index 79d3e0378..48f4dafd6 100644 --- a/embassy-sync/src/channel.rs +++ b/embassy-sync/src/channel.rs @@ -263,12 +263,9 @@ impl<'ch, T> Future for DynamicReceiveFuture<'ch, T> { } } -impl<'ch, M: RawMutex, T, const N: usize> From> for DynamicReceiveFuture<'ch, T> -{ +impl<'ch, M: RawMutex, T, const N: usize> From> for DynamicReceiveFuture<'ch, T> { fn from(value: ReceiveFuture<'ch, M, T, N>) -> Self { - Self { - channel: value.channel, - } + Self { channel: value.channel } } } @@ -330,8 +327,7 @@ impl<'ch, T> Future for DynamicSendFuture<'ch, T> { impl<'ch, T> Unpin for DynamicSendFuture<'ch, T> {} -impl<'ch, M: RawMutex, T, const N: usize> From> for DynamicSendFuture<'ch, T> -{ +impl<'ch, M: RawMutex, T, const N: usize> From> for DynamicSendFuture<'ch, T> { fn from(value: SendFuture<'ch, M, T, N>) -> Self { Self { channel: value.channel, -- cgit From d06dbf332bc2557d83045c93f71a86163663d481 Mon Sep 17 00:00:00 2001 From: Noah Bliss Date: Wed, 20 Mar 2024 03:19:01 +0000 Subject: Doc update: signaled does not clear signal signaled does not clear signal (doc update) --- embassy-sync/src/signal.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'embassy-sync') diff --git a/embassy-sync/src/signal.rs b/embassy-sync/src/signal.rs index d75750ce7..520f1a896 100644 --- a/embassy-sync/src/signal.rs +++ b/embassy-sync/src/signal.rs @@ -125,7 +125,7 @@ where }) } - /// non-blocking method to check whether this signal has been signaled. + /// non-blocking method to check whether this signal has been signaled. This does not clear the signal. pub fn signaled(&self) -> bool { self.state.lock(|cell| { let state = cell.replace(State::None); -- cgit From 3d842dac85a4ea21519f56d4ec6342b528805b8a Mon Sep 17 00:00:00 2001 From: Dario Nieuwenhuis Date: Wed, 20 Mar 2024 14:53:19 +0100 Subject: fmt: disable "unused" warnings. --- embassy-sync/src/fmt.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'embassy-sync') diff --git a/embassy-sync/src/fmt.rs b/embassy-sync/src/fmt.rs index 78e583c1c..2ac42c557 100644 --- a/embassy-sync/src/fmt.rs +++ b/embassy-sync/src/fmt.rs @@ -1,5 +1,5 @@ #![macro_use] -#![allow(unused_macros)] +#![allow(unused)] use core::fmt::{Debug, Display, LowerHex}; @@ -229,7 +229,6 @@ impl Try for Result { } } -#[allow(unused)] pub(crate) struct Bytes<'a>(pub &'a [u8]); impl<'a> Debug for Bytes<'a> { -- cgit From a85e9c6e68fbe23aa8f1fc6a1c1e7ca935c1506a Mon Sep 17 00:00:00 2001 From: Maarten de Vries Date: Fri, 22 Mar 2024 12:29:01 +0100 Subject: Forward the "std" feature to the critical-section crate in embassy-sync. Otherwise, using embassy-sync in unit tests will result in linker errors when using the CriticalSectionRawMutex. --- embassy-sync/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'embassy-sync') diff --git a/embassy-sync/Cargo.toml b/embassy-sync/Cargo.toml index 85673026c..aaf6fab1d 100644 --- a/embassy-sync/Cargo.toml +++ b/embassy-sync/Cargo.toml @@ -20,7 +20,7 @@ src_base_git = "https://github.com/embassy-rs/embassy/blob/$COMMIT/embassy-sync/ target = "thumbv7em-none-eabi" [features] -std = [] +std = ["critical-section/std"] turbowakers = [] [dependencies] -- cgit From a38cbbdc5974fe31f70bffb4fd1995391e4ff331 Mon Sep 17 00:00:00 2001 From: Alexandru Radovici Date: Sat, 30 Mar 2024 22:36:30 +0200 Subject: fix typo --- embassy-sync/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'embassy-sync') diff --git a/embassy-sync/README.md b/embassy-sync/README.md index c2e13799e..2c1c0cf68 100644 --- a/embassy-sync/README.md +++ b/embassy-sync/README.md @@ -5,7 +5,7 @@ An [Embassy](https://embassy.dev) project. Synchronization primitives and data structures with async support: - [`Channel`](channel::Channel) - A Multiple Producer Multiple Consumer (MPMC) channel. Each message is only received by a single consumer. -- [`PriorityChannel`](channel::priority::PriorityChannel) - A Multiple Producer Multiple Consumer (MPMC) channel. Each message is only received by a single consumer. Higher priority items are sifted to the front of the channel. +- [`PriorityChannel`](channel::priority::PriorityChannel) - A Multiple Producer Multiple Consumer (MPMC) channel. Each message is only received by a single consumer. Higher priority items are shifted to the front of the channel. - [`PubSubChannel`](pubsub::PubSubChannel) - A broadcast channel (publish-subscribe) channel. Each message is received by all consumers. - [`Signal`](signal::Signal) - Signalling latest value to a single consumer. - [`Mutex`](mutex::Mutex) - Mutex for synchronizing state between asynchronous tasks. -- cgit From f8a6007e1cb3fe9b13db4a56872228155ed8f1cb Mon Sep 17 00:00:00 2001 From: Alex Moon Date: Sat, 30 Mar 2024 22:19:55 -0400 Subject: Semaphore synchronization primitive This provides both a "greedy" and "fair" async semaphore implementation. --- embassy-sync/src/lib.rs | 1 + embassy-sync/src/semaphore.rs | 704 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 705 insertions(+) create mode 100644 embassy-sync/src/semaphore.rs (limited to 'embassy-sync') diff --git a/embassy-sync/src/lib.rs b/embassy-sync/src/lib.rs index 61b173e80..1873483f9 100644 --- a/embassy-sync/src/lib.rs +++ b/embassy-sync/src/lib.rs @@ -17,6 +17,7 @@ pub mod once_lock; pub mod pipe; pub mod priority_channel; pub mod pubsub; +pub mod semaphore; pub mod signal; pub mod waitqueue; pub mod zerocopy_channel; diff --git a/embassy-sync/src/semaphore.rs b/embassy-sync/src/semaphore.rs new file mode 100644 index 000000000..52c468b4a --- /dev/null +++ b/embassy-sync/src/semaphore.rs @@ -0,0 +1,704 @@ +//! A synchronization primitive for controlling access to a pool of resources. +use core::cell::{Cell, RefCell}; +use core::convert::Infallible; +use core::future::poll_fn; +use core::mem::MaybeUninit; +use core::task::{Poll, Waker}; + +use heapless::Deque; + +use crate::blocking_mutex::raw::RawMutex; +use crate::blocking_mutex::Mutex; +use crate::waitqueue::WakerRegistration; + +/// An asynchronous semaphore. +/// +/// A semaphore tracks a number of permits, typically representing a pool of shared resources. +/// Users can acquire permits to synchronize access to those resources. The semaphore does not +/// contain the resources themselves, only the count of available permits. +pub trait Semaphore: Sized { + /// The error returned when the semaphore is unable to acquire the requested permits. + type Error; + + /// Asynchronously acquire one or more permits from the semaphore. + async fn acquire(&self, permits: usize) -> Result, Self::Error>; + + /// Try to immediately acquire one or more permits from the semaphore. + fn try_acquire(&self, permits: usize) -> Option>; + + /// Asynchronously acquire all permits controlled by the semaphore. + /// + /// This method will wait until at least `min` permits are available, then acquire all available permits + /// from the semaphore. Note that other tasks may have already acquired some permits which could be released + /// back to the semaphore at any time. The number of permits actually acquired may be determined by calling + /// [`SemaphoreReleaser::permits`]. + async fn acquire_all(&self, min: usize) -> Result, Self::Error>; + + /// Try to immediately acquire all available permits from the semaphore, if at least `min` permits are available. + fn try_acquire_all(&self, min: usize) -> Option>; + + /// Release `permits` back to the semaphore, making them available to be acquired. + fn release(&self, permits: usize); + + /// Reset the number of available permints in the semaphore to `permits`. + fn set(&self, permits: usize); +} + +/// A representation of a number of acquired permits. +/// +/// The acquired permits will be released back to the [`Semaphore`] when this is dropped. +pub struct SemaphoreReleaser<'a, S: Semaphore> { + semaphore: &'a S, + permits: usize, +} + +impl<'a, S: Semaphore> Drop for SemaphoreReleaser<'a, S> { + fn drop(&mut self) { + self.semaphore.release(self.permits); + } +} + +impl<'a, S: Semaphore> SemaphoreReleaser<'a, S> { + /// The number of acquired permits. + pub fn permits(&self) -> usize { + self.permits + } + + /// Prevent the acquired permits from being released on drop. + /// + /// Returns the number of acquired permits. + pub fn disarm(self) -> usize { + let permits = self.permits; + core::mem::forget(self); + permits + } +} + +/// A greedy [`Semaphore`] implementation. +/// +/// Tasks can acquire permits as soon as they become available, even if another task +/// is waiting on a larger number of permits. +pub struct GreedySemaphore { + state: Mutex>, +} + +impl Default for GreedySemaphore { + fn default() -> Self { + Self::new(0) + } +} + +impl GreedySemaphore { + /// Create a new `Semaphore`. + pub const fn new(permits: usize) -> Self { + Self { + state: Mutex::new(Cell::new(SemaphoreState { + permits, + waker: WakerRegistration::new(), + })), + } + } + + #[cfg(test)] + fn permits(&self) -> usize { + self.state.lock(|cell| { + let state = cell.replace(SemaphoreState::EMPTY); + let permits = state.permits; + cell.replace(state); + permits + }) + } + + fn poll_acquire( + &self, + permits: usize, + acquire_all: bool, + waker: Option<&Waker>, + ) -> Poll, Infallible>> { + self.state.lock(|cell| { + let mut state = cell.replace(SemaphoreState::EMPTY); + if let Some(permits) = state.take(permits, acquire_all) { + cell.set(state); + Poll::Ready(Ok(SemaphoreReleaser { + semaphore: self, + permits, + })) + } else { + if let Some(waker) = waker { + state.register(waker); + } + cell.set(state); + Poll::Pending + } + }) + } +} + +impl Semaphore for GreedySemaphore { + type Error = Infallible; + + async fn acquire(&self, permits: usize) -> Result, Self::Error> { + poll_fn(|cx| self.poll_acquire(permits, false, Some(cx.waker()))).await + } + + fn try_acquire(&self, permits: usize) -> Option> { + match self.poll_acquire(permits, false, None) { + Poll::Ready(Ok(n)) => Some(n), + _ => None, + } + } + + async fn acquire_all(&self, min: usize) -> Result, Self::Error> { + poll_fn(|cx| self.poll_acquire(min, true, Some(cx.waker()))).await + } + + fn try_acquire_all(&self, min: usize) -> Option> { + match self.poll_acquire(min, true, None) { + Poll::Ready(Ok(n)) => Some(n), + _ => None, + } + } + + fn release(&self, permits: usize) { + if permits > 0 { + self.state.lock(|cell| { + let mut state = cell.replace(SemaphoreState::EMPTY); + state.permits += permits; + state.wake(); + cell.set(state); + }); + } + } + + fn set(&self, permits: usize) { + self.state.lock(|cell| { + let mut state = cell.replace(SemaphoreState::EMPTY); + if permits > state.permits { + state.wake(); + } + state.permits = permits; + cell.set(state); + }); + } +} + +struct SemaphoreState { + permits: usize, + waker: WakerRegistration, +} + +impl SemaphoreState { + const EMPTY: SemaphoreState = SemaphoreState { + permits: 0, + waker: WakerRegistration::new(), + }; + + fn register(&mut self, w: &Waker) { + self.waker.register(w); + } + + fn take(&mut self, mut permits: usize, acquire_all: bool) -> Option { + if self.permits < permits { + None + } else { + if acquire_all { + permits = self.permits; + } + self.permits -= permits; + Some(permits) + } + } + + fn wake(&mut self) { + self.waker.wake(); + } +} + +/// A fair [`Semaphore`] implementation. +/// +/// Tasks are allowed to acquire permits in FIFO order. A task waiting to acquire +/// a large number of permits will prevent other tasks from acquiring any permits +/// until its request is satisfied. +/// +/// Up to `N` tasks may attempt to acquire permits concurrently. If additional +/// tasks attempt to acquire a permit, a [`WaitQueueFull`] error will be returned. +pub struct FairSemaphore +where + M: RawMutex, +{ + state: Mutex>>, +} + +impl Default for FairSemaphore +where + M: RawMutex, +{ + fn default() -> Self { + Self::new(0) + } +} + +impl FairSemaphore +where + M: RawMutex, +{ + /// Create a new `FairSemaphore`. + pub const fn new(permits: usize) -> Self { + Self { + state: Mutex::new(RefCell::new(FairSemaphoreState::new(permits))), + } + } + + #[cfg(test)] + fn permits(&self) -> usize { + self.state.lock(|cell| cell.borrow().permits) + } + + fn poll_acquire( + &self, + permits: usize, + acquire_all: bool, + cx: Option<(&Cell>, &Waker)>, + ) -> Poll, WaitQueueFull>> { + let ticket = cx.as_ref().map(|(cell, _)| cell.get()).unwrap_or(None); + self.state.lock(|cell| { + let mut state = cell.borrow_mut(); + if let Some(permits) = state.take(ticket, permits, acquire_all) { + Poll::Ready(Ok(SemaphoreReleaser { + semaphore: self, + permits, + })) + } else if let Some((cell, waker)) = cx { + match state.register(ticket, waker) { + Ok(ticket) => { + cell.set(Some(ticket)); + Poll::Pending + } + Err(err) => Poll::Ready(Err(err)), + } + } else { + Poll::Pending + } + }) + } +} + +/// An error indicating the [`FairSemaphore`]'s wait queue is full. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct WaitQueueFull; + +impl Semaphore for FairSemaphore { + type Error = WaitQueueFull; + + async fn acquire(&self, permits: usize) -> Result, Self::Error> { + let ticket = Cell::new(None); + let _guard = OnDrop::new(|| self.state.lock(|cell| cell.borrow_mut().cancel(ticket.get()))); + poll_fn(|cx| self.poll_acquire(permits, false, Some((&ticket, cx.waker())))).await + } + + fn try_acquire(&self, permits: usize) -> Option> { + match self.poll_acquire(permits, false, None) { + Poll::Ready(Ok(x)) => Some(x), + _ => None, + } + } + + async fn acquire_all(&self, min: usize) -> Result, Self::Error> { + let ticket = Cell::new(None); + let _guard = OnDrop::new(|| self.state.lock(|cell| cell.borrow_mut().cancel(ticket.get()))); + poll_fn(|cx| self.poll_acquire(min, true, Some((&ticket, cx.waker())))).await + } + + fn try_acquire_all(&self, min: usize) -> Option> { + match self.poll_acquire(min, true, None) { + Poll::Ready(Ok(x)) => Some(x), + _ => None, + } + } + + fn release(&self, permits: usize) { + if permits > 0 { + self.state.lock(|cell| { + let mut state = cell.borrow_mut(); + state.permits += permits; + state.wake(); + }); + } + } + + fn set(&self, permits: usize) { + self.state.lock(|cell| { + let mut state = cell.borrow_mut(); + if permits > state.permits { + state.wake(); + } + state.permits = permits; + }); + } +} + +struct FairSemaphoreState { + permits: usize, + next_ticket: usize, + wakers: Deque, N>, +} + +impl FairSemaphoreState { + /// Create a new empty instance + const fn new(permits: usize) -> Self { + Self { + permits, + next_ticket: 0, + wakers: Deque::new(), + } + } + + /// Register a waker. If the queue is full the function returns an error + fn register(&mut self, ticket: Option, w: &Waker) -> Result { + self.pop_canceled(); + + match ticket { + None => { + let ticket = self.next_ticket.wrapping_add(self.wakers.len()); + self.wakers.push_back(Some(w.clone())).or(Err(WaitQueueFull))?; + Ok(ticket) + } + Some(ticket) => { + self.set_waker(ticket, Some(w.clone())); + Ok(ticket) + } + } + } + + fn cancel(&mut self, ticket: Option) { + if let Some(ticket) = ticket { + self.set_waker(ticket, None); + } + } + + fn set_waker(&mut self, ticket: usize, waker: Option) { + let i = ticket.wrapping_sub(self.next_ticket); + if i < self.wakers.len() { + let (a, b) = self.wakers.as_mut_slices(); + let x = if i < a.len() { &mut a[i] } else { &mut b[i - a.len()] }; + *x = waker; + } + } + + fn take(&mut self, ticket: Option, mut permits: usize, acquire_all: bool) -> Option { + self.pop_canceled(); + + if permits > self.permits { + return None; + } + + match ticket { + Some(n) if n != self.next_ticket => return None, + None if !self.wakers.is_empty() => return None, + _ => (), + } + + if acquire_all { + permits = self.permits; + } + self.permits -= permits; + + if ticket.is_some() { + self.pop(); + } + + Some(permits) + } + + fn pop_canceled(&mut self) { + while let Some(None) = self.wakers.front() { + self.pop(); + } + } + + /// Panics if `self.wakers` is empty + fn pop(&mut self) { + self.wakers.pop_front().unwrap(); + self.next_ticket = self.next_ticket.wrapping_add(1); + } + + fn wake(&mut self) { + self.pop_canceled(); + + if let Some(Some(waker)) = self.wakers.front() { + waker.wake_by_ref(); + } + } +} + +/// A type to delay the drop handler invocation. +#[must_use = "to delay the drop handler invocation to the end of the scope"] +struct OnDrop { + f: MaybeUninit, +} + +impl OnDrop { + /// Create a new instance. + pub fn new(f: F) -> Self { + Self { f: MaybeUninit::new(f) } + } +} + +impl Drop for OnDrop { + fn drop(&mut self) { + unsafe { self.f.as_ptr().read()() } + } +} + +#[cfg(test)] +mod tests { + mod greedy { + use core::pin::pin; + + use futures_util::poll; + + use super::super::*; + use crate::blocking_mutex::raw::NoopRawMutex; + + #[test] + fn try_acquire() { + let semaphore = GreedySemaphore::::new(3); + + let a = semaphore.try_acquire(1).unwrap(); + assert_eq!(a.permits(), 1); + assert_eq!(semaphore.permits(), 2); + + core::mem::drop(a); + assert_eq!(semaphore.permits(), 3); + } + + #[test] + fn disarm() { + let semaphore = GreedySemaphore::::new(3); + + let a = semaphore.try_acquire(1).unwrap(); + assert_eq!(a.disarm(), 1); + assert_eq!(semaphore.permits(), 2); + } + + #[futures_test::test] + async fn acquire() { + let semaphore = GreedySemaphore::::new(3); + + let a = semaphore.acquire(1).await.unwrap(); + assert_eq!(a.permits(), 1); + assert_eq!(semaphore.permits(), 2); + + core::mem::drop(a); + assert_eq!(semaphore.permits(), 3); + } + + #[test] + fn try_acquire_all() { + let semaphore = GreedySemaphore::::new(3); + + let a = semaphore.try_acquire_all(1).unwrap(); + assert_eq!(a.permits(), 3); + assert_eq!(semaphore.permits(), 0); + } + + #[futures_test::test] + async fn acquire_all() { + let semaphore = GreedySemaphore::::new(3); + + let a = semaphore.acquire_all(1).await.unwrap(); + assert_eq!(a.permits(), 3); + assert_eq!(semaphore.permits(), 0); + } + + #[test] + fn release() { + let semaphore = GreedySemaphore::::new(3); + assert_eq!(semaphore.permits(), 3); + semaphore.release(2); + assert_eq!(semaphore.permits(), 5); + } + + #[test] + fn set() { + let semaphore = GreedySemaphore::::new(3); + assert_eq!(semaphore.permits(), 3); + semaphore.set(2); + assert_eq!(semaphore.permits(), 2); + } + + #[test] + fn contested() { + let semaphore = GreedySemaphore::::new(3); + + let a = semaphore.try_acquire(1).unwrap(); + let b = semaphore.try_acquire(3); + assert!(b.is_none()); + + core::mem::drop(a); + + let b = semaphore.try_acquire(3); + assert!(b.is_some()); + } + + #[futures_test::test] + async fn greedy() { + let semaphore = GreedySemaphore::::new(3); + + let a = semaphore.try_acquire(1).unwrap(); + + let b_fut = semaphore.acquire(3); + let mut b_fut = pin!(b_fut); + let b = poll!(b_fut.as_mut()); + assert!(b.is_pending()); + + // Succeed even through `b` is waiting + let c = semaphore.try_acquire(1); + assert!(c.is_some()); + + let b = poll!(b_fut.as_mut()); + assert!(b.is_pending()); + + core::mem::drop(a); + + let b = poll!(b_fut.as_mut()); + assert!(b.is_pending()); + + core::mem::drop(c); + + let b = poll!(b_fut.as_mut()); + assert!(b.is_ready()); + } + } + + mod fair { + use core::pin::pin; + + use futures_util::poll; + + use super::super::*; + use crate::blocking_mutex::raw::NoopRawMutex; + + #[test] + fn try_acquire() { + let semaphore = FairSemaphore::::new(3); + + let a = semaphore.try_acquire(1).unwrap(); + assert_eq!(a.permits(), 1); + assert_eq!(semaphore.permits(), 2); + + core::mem::drop(a); + assert_eq!(semaphore.permits(), 3); + } + + #[test] + fn disarm() { + let semaphore = FairSemaphore::::new(3); + + let a = semaphore.try_acquire(1).unwrap(); + assert_eq!(a.disarm(), 1); + assert_eq!(semaphore.permits(), 2); + } + + #[futures_test::test] + async fn acquire() { + let semaphore = FairSemaphore::::new(3); + + let a = semaphore.acquire(1).await.unwrap(); + assert_eq!(a.permits(), 1); + assert_eq!(semaphore.permits(), 2); + + core::mem::drop(a); + assert_eq!(semaphore.permits(), 3); + } + + #[test] + fn try_acquire_all() { + let semaphore = FairSemaphore::::new(3); + + let a = semaphore.try_acquire_all(1).unwrap(); + assert_eq!(a.permits(), 3); + assert_eq!(semaphore.permits(), 0); + } + + #[futures_test::test] + async fn acquire_all() { + let semaphore = FairSemaphore::::new(3); + + let a = semaphore.acquire_all(1).await.unwrap(); + assert_eq!(a.permits(), 3); + assert_eq!(semaphore.permits(), 0); + } + + #[test] + fn release() { + let semaphore = FairSemaphore::::new(3); + assert_eq!(semaphore.permits(), 3); + semaphore.release(2); + assert_eq!(semaphore.permits(), 5); + } + + #[test] + fn set() { + let semaphore = FairSemaphore::::new(3); + assert_eq!(semaphore.permits(), 3); + semaphore.set(2); + assert_eq!(semaphore.permits(), 2); + } + + #[test] + fn contested() { + let semaphore = FairSemaphore::::new(3); + + let a = semaphore.try_acquire(1).unwrap(); + let b = semaphore.try_acquire(3); + assert!(b.is_none()); + + core::mem::drop(a); + + let b = semaphore.try_acquire(3); + assert!(b.is_some()); + } + + #[futures_test::test] + async fn fairness() { + let semaphore = FairSemaphore::::new(3); + + let a = semaphore.try_acquire(1); + assert!(a.is_some()); + + let b_fut = semaphore.acquire(3); + let mut b_fut = pin!(b_fut); + let b = poll!(b_fut.as_mut()); // Poll `b_fut` once so it is registered + assert!(b.is_pending()); + + let c = semaphore.try_acquire(1); + assert!(c.is_none()); + + let c_fut = semaphore.acquire(1); + let mut c_fut = pin!(c_fut); + let c = poll!(c_fut.as_mut()); // Poll `c_fut` once so it is registered + assert!(c.is_pending()); // `c` is blocked behind `b` + + let d = semaphore.acquire(1).await; + assert!(matches!(d, Err(WaitQueueFull))); + + core::mem::drop(a); + + let c = poll!(c_fut.as_mut()); + assert!(c.is_pending()); // `c` is still blocked behind `b` + + let b = poll!(b_fut.as_mut()); + assert!(b.is_ready()); + + let c = poll!(c_fut.as_mut()); + assert!(c.is_pending()); // `c` is still blocked behind `b` + + core::mem::drop(b); + + let c = poll!(c_fut.as_mut()); + assert!(c.is_ready()); + } + } +} -- cgit