diff options
| author | Alex Moon <[email protected]> | 2024-04-03 19:13:57 -0400 |
|---|---|---|
| committer | Alex Moon <[email protected]> | 2024-04-03 19:13:57 -0400 |
| commit | c9acebf783c64784fe6b659a94b40fa080b6fbe8 (patch) | |
| tree | 1b57c9e0b2046615b08bce88d290506aa4b1767f /embassy-sync/src/semaphore.rs | |
| parent | 1fd260e4b1f875ba71c40a732a606bf487165750 (diff) | |
Fix `FairSemaphore` bugs
- `acquire` and `acquire_all` futures were `!Send`, even for `M: RawMutex + Send` due to the captured `Cell`.
- If multiple `acquire` tasks were queued, waking the first would not wake the second, even if there were permits remaining after the first `acquire` completed.
Diffstat (limited to 'embassy-sync/src/semaphore.rs')
| -rw-r--r-- | embassy-sync/src/semaphore.rs | 136 |
1 files changed, 102 insertions, 34 deletions
diff --git a/embassy-sync/src/semaphore.rs b/embassy-sync/src/semaphore.rs index 52c468b4a..d30eee30b 100644 --- a/embassy-sync/src/semaphore.rs +++ b/embassy-sync/src/semaphore.rs | |||
| @@ -1,8 +1,7 @@ | |||
| 1 | //! A synchronization primitive for controlling access to a pool of resources. | 1 | //! A synchronization primitive for controlling access to a pool of resources. |
| 2 | use core::cell::{Cell, RefCell}; | 2 | use core::cell::{Cell, RefCell}; |
| 3 | use core::convert::Infallible; | 3 | use core::convert::Infallible; |
| 4 | use core::future::poll_fn; | 4 | use core::future::{poll_fn, Future}; |
| 5 | use core::mem::MaybeUninit; | ||
| 6 | use core::task::{Poll, Waker}; | 5 | use core::task::{Poll, Waker}; |
| 7 | 6 | ||
| 8 | use heapless::Deque; | 7 | use heapless::Deque; |
| @@ -258,9 +257,9 @@ where | |||
| 258 | &self, | 257 | &self, |
| 259 | permits: usize, | 258 | permits: usize, |
| 260 | acquire_all: bool, | 259 | acquire_all: bool, |
| 261 | cx: Option<(&Cell<Option<usize>>, &Waker)>, | 260 | cx: Option<(&mut Option<usize>, &Waker)>, |
| 262 | ) -> Poll<Result<SemaphoreReleaser<'_, Self>, WaitQueueFull>> { | 261 | ) -> Poll<Result<SemaphoreReleaser<'_, Self>, WaitQueueFull>> { |
| 263 | let ticket = cx.as_ref().map(|(cell, _)| cell.get()).unwrap_or(None); | 262 | let ticket = cx.as_ref().map(|(x, _)| **x).unwrap_or(None); |
| 264 | self.state.lock(|cell| { | 263 | self.state.lock(|cell| { |
| 265 | let mut state = cell.borrow_mut(); | 264 | let mut state = cell.borrow_mut(); |
| 266 | if let Some(permits) = state.take(ticket, permits, acquire_all) { | 265 | if let Some(permits) = state.take(ticket, permits, acquire_all) { |
| @@ -268,10 +267,10 @@ where | |||
| 268 | semaphore: self, | 267 | semaphore: self, |
| 269 | permits, | 268 | permits, |
| 270 | })) | 269 | })) |
| 271 | } else if let Some((cell, waker)) = cx { | 270 | } else if let Some((ticket_ref, waker)) = cx { |
| 272 | match state.register(ticket, waker) { | 271 | match state.register(ticket, waker) { |
| 273 | Ok(ticket) => { | 272 | Ok(ticket) => { |
| 274 | cell.set(Some(ticket)); | 273 | *ticket_ref = Some(ticket); |
| 275 | Poll::Pending | 274 | Poll::Pending |
| 276 | } | 275 | } |
| 277 | Err(err) => Poll::Ready(Err(err)), | 276 | Err(err) => Poll::Ready(Err(err)), |
| @@ -291,10 +290,12 @@ pub struct WaitQueueFull; | |||
| 291 | impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> { | 290 | impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> { |
| 292 | type Error = WaitQueueFull; | 291 | type Error = WaitQueueFull; |
| 293 | 292 | ||
| 294 | async fn acquire(&self, permits: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error> { | 293 | fn acquire(&self, permits: usize) -> impl Future<Output = Result<SemaphoreReleaser<'_, Self>, Self::Error>> { |
| 295 | let ticket = Cell::new(None); | 294 | FairAcquire { |
| 296 | let _guard = OnDrop::new(|| self.state.lock(|cell| cell.borrow_mut().cancel(ticket.get()))); | 295 | sema: self, |
| 297 | poll_fn(|cx| self.poll_acquire(permits, false, Some((&ticket, cx.waker())))).await | 296 | permits, |
| 297 | ticket: None, | ||
| 298 | } | ||
| 298 | } | 299 | } |
| 299 | 300 | ||
| 300 | fn try_acquire(&self, permits: usize) -> Option<SemaphoreReleaser<'_, Self>> { | 301 | fn try_acquire(&self, permits: usize) -> Option<SemaphoreReleaser<'_, Self>> { |
| @@ -304,10 +305,12 @@ impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> { | |||
| 304 | } | 305 | } |
| 305 | } | 306 | } |
| 306 | 307 | ||
| 307 | async fn acquire_all(&self, min: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error> { | 308 | fn acquire_all(&self, min: usize) -> impl Future<Output = Result<SemaphoreReleaser<'_, Self>, Self::Error>> { |
| 308 | let ticket = Cell::new(None); | 309 | FairAcquireAll { |
| 309 | let _guard = OnDrop::new(|| self.state.lock(|cell| cell.borrow_mut().cancel(ticket.get()))); | 310 | sema: self, |
| 310 | poll_fn(|cx| self.poll_acquire(min, true, Some((&ticket, cx.waker())))).await | 311 | min, |
| 312 | ticket: None, | ||
| 313 | } | ||
| 311 | } | 314 | } |
| 312 | 315 | ||
| 313 | fn try_acquire_all(&self, min: usize) -> Option<SemaphoreReleaser<'_, Self>> { | 316 | fn try_acquire_all(&self, min: usize) -> Option<SemaphoreReleaser<'_, Self>> { |
| @@ -338,6 +341,52 @@ impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> { | |||
| 338 | } | 341 | } |
| 339 | } | 342 | } |
| 340 | 343 | ||
| 344 | struct FairAcquire<'a, M: RawMutex, const N: usize> { | ||
| 345 | sema: &'a FairSemaphore<M, N>, | ||
| 346 | permits: usize, | ||
| 347 | ticket: Option<usize>, | ||
| 348 | } | ||
| 349 | |||
| 350 | impl<'a, M: RawMutex, const N: usize> Drop for FairAcquire<'a, M, N> { | ||
| 351 | fn drop(&mut self) { | ||
| 352 | self.sema | ||
| 353 | .state | ||
| 354 | .lock(|cell| cell.borrow_mut().cancel(self.ticket.take())); | ||
| 355 | } | ||
| 356 | } | ||
| 357 | |||
| 358 | impl<'a, M: RawMutex, const N: usize> core::future::Future for FairAcquire<'a, M, N> { | ||
| 359 | type Output = Result<SemaphoreReleaser<'a, FairSemaphore<M, N>>, WaitQueueFull>; | ||
| 360 | |||
| 361 | fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> { | ||
| 362 | self.sema | ||
| 363 | .poll_acquire(self.permits, false, Some((&mut self.ticket, cx.waker()))) | ||
| 364 | } | ||
| 365 | } | ||
| 366 | |||
| 367 | struct FairAcquireAll<'a, M: RawMutex, const N: usize> { | ||
| 368 | sema: &'a FairSemaphore<M, N>, | ||
| 369 | min: usize, | ||
| 370 | ticket: Option<usize>, | ||
| 371 | } | ||
| 372 | |||
| 373 | impl<'a, M: RawMutex, const N: usize> Drop for FairAcquireAll<'a, M, N> { | ||
| 374 | fn drop(&mut self) { | ||
| 375 | self.sema | ||
| 376 | .state | ||
| 377 | .lock(|cell| cell.borrow_mut().cancel(self.ticket.take())); | ||
| 378 | } | ||
| 379 | } | ||
| 380 | |||
| 381 | impl<'a, M: RawMutex, const N: usize> core::future::Future for FairAcquireAll<'a, M, N> { | ||
| 382 | type Output = Result<SemaphoreReleaser<'a, FairSemaphore<M, N>>, WaitQueueFull>; | ||
| 383 | |||
| 384 | fn poll(mut self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> Poll<Self::Output> { | ||
| 385 | self.sema | ||
| 386 | .poll_acquire(self.min, true, Some((&mut self.ticket, cx.waker()))) | ||
| 387 | } | ||
| 388 | } | ||
| 389 | |||
| 341 | struct FairSemaphoreState<const N: usize> { | 390 | struct FairSemaphoreState<const N: usize> { |
| 342 | permits: usize, | 391 | permits: usize, |
| 343 | next_ticket: usize, | 392 | next_ticket: usize, |
| @@ -406,6 +455,9 @@ impl<const N: usize> FairSemaphoreState<N> { | |||
| 406 | 455 | ||
| 407 | if ticket.is_some() { | 456 | if ticket.is_some() { |
| 408 | self.pop(); | 457 | self.pop(); |
| 458 | if self.permits > 0 { | ||
| 459 | self.wake(); | ||
| 460 | } | ||
| 409 | } | 461 | } |
| 410 | 462 | ||
| 411 | Some(permits) | 463 | Some(permits) |
| @@ -432,25 +484,6 @@ impl<const N: usize> FairSemaphoreState<N> { | |||
| 432 | } | 484 | } |
| 433 | } | 485 | } |
| 434 | 486 | ||
| 435 | /// A type to delay the drop handler invocation. | ||
| 436 | #[must_use = "to delay the drop handler invocation to the end of the scope"] | ||
| 437 | struct OnDrop<F: FnOnce()> { | ||
| 438 | f: MaybeUninit<F>, | ||
| 439 | } | ||
| 440 | |||
| 441 | impl<F: FnOnce()> OnDrop<F> { | ||
| 442 | /// Create a new instance. | ||
| 443 | pub fn new(f: F) -> Self { | ||
| 444 | Self { f: MaybeUninit::new(f) } | ||
| 445 | } | ||
| 446 | } | ||
| 447 | |||
| 448 | impl<F: FnOnce()> Drop for OnDrop<F> { | ||
| 449 | fn drop(&mut self) { | ||
| 450 | unsafe { self.f.as_ptr().read()() } | ||
| 451 | } | ||
| 452 | } | ||
| 453 | |||
| 454 | #[cfg(test)] | 487 | #[cfg(test)] |
| 455 | mod tests { | 488 | mod tests { |
| 456 | mod greedy { | 489 | mod greedy { |
| @@ -574,11 +607,16 @@ mod tests { | |||
| 574 | 607 | ||
| 575 | mod fair { | 608 | mod fair { |
| 576 | use core::pin::pin; | 609 | use core::pin::pin; |
| 610 | use core::time::Duration; | ||
| 577 | 611 | ||
| 612 | use futures_executor::ThreadPool; | ||
| 613 | use futures_timer::Delay; | ||
| 578 | use futures_util::poll; | 614 | use futures_util::poll; |
| 615 | use futures_util::task::SpawnExt; | ||
| 616 | use static_cell::StaticCell; | ||
| 579 | 617 | ||
| 580 | use super::super::*; | 618 | use super::super::*; |
| 581 | use crate::blocking_mutex::raw::NoopRawMutex; | 619 | use crate::blocking_mutex::raw::{CriticalSectionRawMutex, NoopRawMutex}; |
| 582 | 620 | ||
| 583 | #[test] | 621 | #[test] |
| 584 | fn try_acquire() { | 622 | fn try_acquire() { |
| @@ -700,5 +738,35 @@ mod tests { | |||
| 700 | let c = poll!(c_fut.as_mut()); | 738 | let c = poll!(c_fut.as_mut()); |
| 701 | assert!(c.is_ready()); | 739 | assert!(c.is_ready()); |
| 702 | } | 740 | } |
| 741 | |||
| 742 | #[futures_test::test] | ||
| 743 | async fn wakers() { | ||
| 744 | let executor = ThreadPool::new().unwrap(); | ||
| 745 | |||
| 746 | static SEMAPHORE: StaticCell<FairSemaphore<CriticalSectionRawMutex, 2>> = StaticCell::new(); | ||
| 747 | let semaphore = &*SEMAPHORE.init(FairSemaphore::new(3)); | ||
| 748 | |||
| 749 | let a = semaphore.try_acquire(2); | ||
| 750 | assert!(a.is_some()); | ||
| 751 | |||
| 752 | let b_task = executor | ||
| 753 | .spawn_with_handle(async move { semaphore.acquire(2).await }) | ||
| 754 | .unwrap(); | ||
| 755 | while semaphore.state.lock(|x| x.borrow().wakers.is_empty()) { | ||
| 756 | Delay::new(Duration::from_millis(50)).await; | ||
| 757 | } | ||
| 758 | |||
| 759 | let c_task = executor | ||
| 760 | .spawn_with_handle(async move { semaphore.acquire(1).await }) | ||
| 761 | .unwrap(); | ||
| 762 | |||
| 763 | core::mem::drop(a); | ||
| 764 | |||
| 765 | let b = b_task.await.unwrap(); | ||
| 766 | assert_eq!(b.permits(), 2); | ||
| 767 | |||
| 768 | let c = c_task.await.unwrap(); | ||
| 769 | assert_eq!(c.permits(), 1); | ||
| 770 | } | ||
| 703 | } | 771 | } |
| 704 | } | 772 | } |
