aboutsummaryrefslogtreecommitdiff
path: root/embassy-sync/src/semaphore.rs
diff options
context:
space:
mode:
Diffstat (limited to 'embassy-sync/src/semaphore.rs')
-rw-r--r--embassy-sync/src/semaphore.rs136
1 files changed, 102 insertions, 34 deletions
diff --git a/embassy-sync/src/semaphore.rs b/embassy-sync/src/semaphore.rs
index 52c468b4a..d30eee30b 100644
--- a/embassy-sync/src/semaphore.rs
+++ b/embassy-sync/src/semaphore.rs
@@ -1,8 +1,7 @@
1//! A synchronization primitive for controlling access to a pool of resources. 1//! A synchronization primitive for controlling access to a pool of resources.
2use core::cell::{Cell, RefCell}; 2use core::cell::{Cell, RefCell};
3use core::convert::Infallible; 3use core::convert::Infallible;
4use core::future::poll_fn; 4use core::future::{poll_fn, Future};
5use core::mem::MaybeUninit;
6use core::task::{Poll, Waker}; 5use core::task::{Poll, Waker};
7 6
8use heapless::Deque; 7use heapless::Deque;
@@ -258,9 +257,9 @@ where
258 &self, 257 &self,
259 permits: usize, 258 permits: usize,
260 acquire_all: bool, 259 acquire_all: bool,
261 cx: Option<(&Cell<Option<usize>>, &Waker)>, 260 cx: Option<(&mut Option<usize>, &Waker)>,
262 ) -> Poll<Result<SemaphoreReleaser<'_, Self>, WaitQueueFull>> { 261 ) -> Poll<Result<SemaphoreReleaser<'_, Self>, WaitQueueFull>> {
263 let ticket = cx.as_ref().map(|(cell, _)| cell.get()).unwrap_or(None); 262 let ticket = cx.as_ref().map(|(x, _)| **x).unwrap_or(None);
264 self.state.lock(|cell| { 263 self.state.lock(|cell| {
265 let mut state = cell.borrow_mut(); 264 let mut state = cell.borrow_mut();
266 if let Some(permits) = state.take(ticket, permits, acquire_all) { 265 if let Some(permits) = state.take(ticket, permits, acquire_all) {
@@ -268,10 +267,10 @@ where
268 semaphore: self, 267 semaphore: self,
269 permits, 268 permits,
270 })) 269 }))
271 } else if let Some((cell, waker)) = cx { 270 } else if let Some((ticket_ref, waker)) = cx {
272 match state.register(ticket, waker) { 271 match state.register(ticket, waker) {
273 Ok(ticket) => { 272 Ok(ticket) => {
274 cell.set(Some(ticket)); 273 *ticket_ref = Some(ticket);
275 Poll::Pending 274 Poll::Pending
276 } 275 }
277 Err(err) => Poll::Ready(Err(err)), 276 Err(err) => Poll::Ready(Err(err)),
@@ -291,10 +290,12 @@ pub struct WaitQueueFull;
291impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> { 290impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> {
292 type Error = WaitQueueFull; 291 type Error = WaitQueueFull;
293 292
294 async fn acquire(&self, permits: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error> { 293 fn acquire(&self, permits: usize) -> impl Future<Output = Result<SemaphoreReleaser<'_, Self>, Self::Error>> {
295 let ticket = Cell::new(None); 294 FairAcquire {
296 let _guard = OnDrop::new(|| self.state.lock(|cell| cell.borrow_mut().cancel(ticket.get()))); 295 sema: self,
297 poll_fn(|cx| self.poll_acquire(permits, false, Some((&ticket, cx.waker())))).await 296 permits,
297 ticket: None,
298 }
298 } 299 }
299 300
300 fn try_acquire(&self, permits: usize) -> Option<SemaphoreReleaser<'_, Self>> { 301 fn try_acquire(&self, permits: usize) -> Option<SemaphoreReleaser<'_, Self>> {
@@ -304,10 +305,12 @@ impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> {
304 } 305 }
305 } 306 }
306 307
307 async fn acquire_all(&self, min: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error> { 308 fn acquire_all(&self, min: usize) -> impl Future<Output = Result<SemaphoreReleaser<'_, Self>, Self::Error>> {
308 let ticket = Cell::new(None); 309 FairAcquireAll {
309 let _guard = OnDrop::new(|| self.state.lock(|cell| cell.borrow_mut().cancel(ticket.get()))); 310 sema: self,
310 poll_fn(|cx| self.poll_acquire(min, true, Some((&ticket, cx.waker())))).await 311 min,
312 ticket: None,
313 }
311 } 314 }
312 315
313 fn try_acquire_all(&self, min: usize) -> Option<SemaphoreReleaser<'_, Self>> { 316 fn try_acquire_all(&self, min: usize) -> Option<SemaphoreReleaser<'_, Self>> {
@@ -338,6 +341,52 @@ impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> {
338 } 341 }
339} 342}
340 343
344struct FairAcquire<'a, M: RawMutex, const N: usize> {
345 sema: &'a FairSemaphore<M, N>,
346 permits: usize,
347 ticket: Option<usize>,
348}
349
350impl<'a, M: RawMutex, const N: usize> Drop for FairAcquire<'a, M, N> {
351 fn drop(&mut self) {
352 self.sema
353 .state
354 .lock(|cell| cell.borrow_mut().cancel(self.ticket.take()));
355 }
356}
357
358impl<'a, M: RawMutex, const N: usize> core::future::Future for FairAcquire<'a, M, N> {
359 type Output = Result<SemaphoreReleaser<'a, FairSemaphore<M, N>>, WaitQueueFull>;
360
361 fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
362 self.sema
363 .poll_acquire(self.permits, false, Some((&mut self.ticket, cx.waker())))
364 }
365}
366
367struct FairAcquireAll<'a, M: RawMutex, const N: usize> {
368 sema: &'a FairSemaphore<M, N>,
369 min: usize,
370 ticket: Option<usize>,
371}
372
373impl<'a, M: RawMutex, const N: usize> Drop for FairAcquireAll<'a, M, N> {
374 fn drop(&mut self) {
375 self.sema
376 .state
377 .lock(|cell| cell.borrow_mut().cancel(self.ticket.take()));
378 }
379}
380
381impl<'a, M: RawMutex, const N: usize> core::future::Future for FairAcquireAll<'a, M, N> {
382 type Output = Result<SemaphoreReleaser<'a, FairSemaphore<M, N>>, WaitQueueFull>;
383
384 fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> {
385 self.sema
386 .poll_acquire(self.min, true, Some((&mut self.ticket, cx.waker())))
387 }
388}
389
341struct FairSemaphoreState<const N: usize> { 390struct FairSemaphoreState<const N: usize> {
342 permits: usize, 391 permits: usize,
343 next_ticket: usize, 392 next_ticket: usize,
@@ -406,6 +455,9 @@ impl<const N: usize> FairSemaphoreState<N> {
406 455
407 if ticket.is_some() { 456 if ticket.is_some() {
408 self.pop(); 457 self.pop();
458 if self.permits > 0 {
459 self.wake();
460 }
409 } 461 }
410 462
411 Some(permits) 463 Some(permits)
@@ -432,25 +484,6 @@ impl<const N: usize> FairSemaphoreState<N> {
432 } 484 }
433} 485}
434 486
435/// A type to delay the drop handler invocation.
436#[must_use = "to delay the drop handler invocation to the end of the scope"]
437struct OnDrop<F: FnOnce()> {
438 f: MaybeUninit<F>,
439}
440
441impl<F: FnOnce()> OnDrop<F> {
442 /// Create a new instance.
443 pub fn new(f: F) -> Self {
444 Self { f: MaybeUninit::new(f) }
445 }
446}
447
448impl<F: FnOnce()> Drop for OnDrop<F> {
449 fn drop(&mut self) {
450 unsafe { self.f.as_ptr().read()() }
451 }
452}
453
454#[cfg(test)] 487#[cfg(test)]
455mod tests { 488mod tests {
456 mod greedy { 489 mod greedy {
@@ -574,11 +607,16 @@ mod tests {
574 607
575 mod fair { 608 mod fair {
576 use core::pin::pin; 609 use core::pin::pin;
610 use core::time::Duration;
577 611
612 use futures_executor::ThreadPool;
613 use futures_timer::Delay;
578 use futures_util::poll; 614 use futures_util::poll;
615 use futures_util::task::SpawnExt;
616 use static_cell::StaticCell;
579 617
580 use super::super::*; 618 use super::super::*;
581 use crate::blocking_mutex::raw::NoopRawMutex; 619 use crate::blocking_mutex::raw::{CriticalSectionRawMutex, NoopRawMutex};
582 620
583 #[test] 621 #[test]
584 fn try_acquire() { 622 fn try_acquire() {
@@ -700,5 +738,35 @@ mod tests {
700 let c = poll!(c_fut.as_mut()); 738 let c = poll!(c_fut.as_mut());
701 assert!(c.is_ready()); 739 assert!(c.is_ready());
702 } 740 }
741
742 #[futures_test::test]
743 async fn wakers() {
744 let executor = ThreadPool::new().unwrap();
745
746 static SEMAPHORE: StaticCell<FairSemaphore<CriticalSectionRawMutex, 2>> = StaticCell::new();
747 let semaphore = &*SEMAPHORE.init(FairSemaphore::new(3));
748
749 let a = semaphore.try_acquire(2);
750 assert!(a.is_some());
751
752 let b_task = executor
753 .spawn_with_handle(async move { semaphore.acquire(2).await })
754 .unwrap();
755 while semaphore.state.lock(|x| x.borrow().wakers.is_empty()) {
756 Delay::new(Duration::from_millis(50)).await;
757 }
758
759 let c_task = executor
760 .spawn_with_handle(async move { semaphore.acquire(1).await })
761 .unwrap();
762
763 core::mem::drop(a);
764
765 let b = b_task.await.unwrap();
766 assert_eq!(b.permits(), 2);
767
768 let c = c_task.await.unwrap();
769 assert_eq!(c.permits(), 1);
770 }
703 } 771 }
704} 772}