aboutsummaryrefslogtreecommitdiff
path: root/embassy-sync/src
diff options
context:
space:
mode:
Diffstat (limited to 'embassy-sync/src')
-rw-r--r--embassy-sync/src/lib.rs1
-rw-r--r--embassy-sync/src/semaphore.rs704
2 files changed, 705 insertions, 0 deletions
diff --git a/embassy-sync/src/lib.rs b/embassy-sync/src/lib.rs
index 61b173e80..1873483f9 100644
--- a/embassy-sync/src/lib.rs
+++ b/embassy-sync/src/lib.rs
@@ -17,6 +17,7 @@ pub mod once_lock;
17pub mod pipe; 17pub mod pipe;
18pub mod priority_channel; 18pub mod priority_channel;
19pub mod pubsub; 19pub mod pubsub;
20pub mod semaphore;
20pub mod signal; 21pub mod signal;
21pub mod waitqueue; 22pub mod waitqueue;
22pub mod zerocopy_channel; 23pub mod zerocopy_channel;
diff --git a/embassy-sync/src/semaphore.rs b/embassy-sync/src/semaphore.rs
new file mode 100644
index 000000000..52c468b4a
--- /dev/null
+++ b/embassy-sync/src/semaphore.rs
@@ -0,0 +1,704 @@
1//! A synchronization primitive for controlling access to a pool of resources.
2use core::cell::{Cell, RefCell};
3use core::convert::Infallible;
4use core::future::poll_fn;
5use core::mem::MaybeUninit;
6use core::task::{Poll, Waker};
7
8use heapless::Deque;
9
10use crate::blocking_mutex::raw::RawMutex;
11use crate::blocking_mutex::Mutex;
12use crate::waitqueue::WakerRegistration;
13
14/// An asynchronous semaphore.
15///
16/// A semaphore tracks a number of permits, typically representing a pool of shared resources.
17/// Users can acquire permits to synchronize access to those resources. The semaphore does not
18/// contain the resources themselves, only the count of available permits.
19pub trait Semaphore: Sized {
20 /// The error returned when the semaphore is unable to acquire the requested permits.
21 type Error;
22
23 /// Asynchronously acquire one or more permits from the semaphore.
24 async fn acquire(&self, permits: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error>;
25
26 /// Try to immediately acquire one or more permits from the semaphore.
27 fn try_acquire(&self, permits: usize) -> Option<SemaphoreReleaser<'_, Self>>;
28
29 /// Asynchronously acquire all permits controlled by the semaphore.
30 ///
31 /// This method will wait until at least `min` permits are available, then acquire all available permits
32 /// from the semaphore. Note that other tasks may have already acquired some permits which could be released
33 /// back to the semaphore at any time. The number of permits actually acquired may be determined by calling
34 /// [`SemaphoreReleaser::permits`].
35 async fn acquire_all(&self, min: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error>;
36
37 /// Try to immediately acquire all available permits from the semaphore, if at least `min` permits are available.
38 fn try_acquire_all(&self, min: usize) -> Option<SemaphoreReleaser<'_, Self>>;
39
40 /// Release `permits` back to the semaphore, making them available to be acquired.
41 fn release(&self, permits: usize);
42
43 /// Reset the number of available permints in the semaphore to `permits`.
44 fn set(&self, permits: usize);
45}
46
47/// A representation of a number of acquired permits.
48///
49/// The acquired permits will be released back to the [`Semaphore`] when this is dropped.
50pub struct SemaphoreReleaser<'a, S: Semaphore> {
51 semaphore: &'a S,
52 permits: usize,
53}
54
55impl<'a, S: Semaphore> Drop for SemaphoreReleaser<'a, S> {
56 fn drop(&mut self) {
57 self.semaphore.release(self.permits);
58 }
59}
60
61impl<'a, S: Semaphore> SemaphoreReleaser<'a, S> {
62 /// The number of acquired permits.
63 pub fn permits(&self) -> usize {
64 self.permits
65 }
66
67 /// Prevent the acquired permits from being released on drop.
68 ///
69 /// Returns the number of acquired permits.
70 pub fn disarm(self) -> usize {
71 let permits = self.permits;
72 core::mem::forget(self);
73 permits
74 }
75}
76
77/// A greedy [`Semaphore`] implementation.
78///
79/// Tasks can acquire permits as soon as they become available, even if another task
80/// is waiting on a larger number of permits.
81pub struct GreedySemaphore<M: RawMutex> {
82 state: Mutex<M, Cell<SemaphoreState>>,
83}
84
85impl<M: RawMutex> Default for GreedySemaphore<M> {
86 fn default() -> Self {
87 Self::new(0)
88 }
89}
90
91impl<M: RawMutex> GreedySemaphore<M> {
92 /// Create a new `Semaphore`.
93 pub const fn new(permits: usize) -> Self {
94 Self {
95 state: Mutex::new(Cell::new(SemaphoreState {
96 permits,
97 waker: WakerRegistration::new(),
98 })),
99 }
100 }
101
102 #[cfg(test)]
103 fn permits(&self) -> usize {
104 self.state.lock(|cell| {
105 let state = cell.replace(SemaphoreState::EMPTY);
106 let permits = state.permits;
107 cell.replace(state);
108 permits
109 })
110 }
111
112 fn poll_acquire(
113 &self,
114 permits: usize,
115 acquire_all: bool,
116 waker: Option<&Waker>,
117 ) -> Poll<Result<SemaphoreReleaser<'_, Self>, Infallible>> {
118 self.state.lock(|cell| {
119 let mut state = cell.replace(SemaphoreState::EMPTY);
120 if let Some(permits) = state.take(permits, acquire_all) {
121 cell.set(state);
122 Poll::Ready(Ok(SemaphoreReleaser {
123 semaphore: self,
124 permits,
125 }))
126 } else {
127 if let Some(waker) = waker {
128 state.register(waker);
129 }
130 cell.set(state);
131 Poll::Pending
132 }
133 })
134 }
135}
136
137impl<M: RawMutex> Semaphore for GreedySemaphore<M> {
138 type Error = Infallible;
139
140 async fn acquire(&self, permits: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error> {
141 poll_fn(|cx| self.poll_acquire(permits, false, Some(cx.waker()))).await
142 }
143
144 fn try_acquire(&self, permits: usize) -> Option<SemaphoreReleaser<'_, Self>> {
145 match self.poll_acquire(permits, false, None) {
146 Poll::Ready(Ok(n)) => Some(n),
147 _ => None,
148 }
149 }
150
151 async fn acquire_all(&self, min: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error> {
152 poll_fn(|cx| self.poll_acquire(min, true, Some(cx.waker()))).await
153 }
154
155 fn try_acquire_all(&self, min: usize) -> Option<SemaphoreReleaser<'_, Self>> {
156 match self.poll_acquire(min, true, None) {
157 Poll::Ready(Ok(n)) => Some(n),
158 _ => None,
159 }
160 }
161
162 fn release(&self, permits: usize) {
163 if permits > 0 {
164 self.state.lock(|cell| {
165 let mut state = cell.replace(SemaphoreState::EMPTY);
166 state.permits += permits;
167 state.wake();
168 cell.set(state);
169 });
170 }
171 }
172
173 fn set(&self, permits: usize) {
174 self.state.lock(|cell| {
175 let mut state = cell.replace(SemaphoreState::EMPTY);
176 if permits > state.permits {
177 state.wake();
178 }
179 state.permits = permits;
180 cell.set(state);
181 });
182 }
183}
184
185struct SemaphoreState {
186 permits: usize,
187 waker: WakerRegistration,
188}
189
190impl SemaphoreState {
191 const EMPTY: SemaphoreState = SemaphoreState {
192 permits: 0,
193 waker: WakerRegistration::new(),
194 };
195
196 fn register(&mut self, w: &Waker) {
197 self.waker.register(w);
198 }
199
200 fn take(&mut self, mut permits: usize, acquire_all: bool) -> Option<usize> {
201 if self.permits < permits {
202 None
203 } else {
204 if acquire_all {
205 permits = self.permits;
206 }
207 self.permits -= permits;
208 Some(permits)
209 }
210 }
211
212 fn wake(&mut self) {
213 self.waker.wake();
214 }
215}
216
217/// A fair [`Semaphore`] implementation.
218///
219/// Tasks are allowed to acquire permits in FIFO order. A task waiting to acquire
220/// a large number of permits will prevent other tasks from acquiring any permits
221/// until its request is satisfied.
222///
223/// Up to `N` tasks may attempt to acquire permits concurrently. If additional
224/// tasks attempt to acquire a permit, a [`WaitQueueFull`] error will be returned.
225pub struct FairSemaphore<M, const N: usize>
226where
227 M: RawMutex,
228{
229 state: Mutex<M, RefCell<FairSemaphoreState<N>>>,
230}
231
232impl<M, const N: usize> Default for FairSemaphore<M, N>
233where
234 M: RawMutex,
235{
236 fn default() -> Self {
237 Self::new(0)
238 }
239}
240
241impl<M, const N: usize> FairSemaphore<M, N>
242where
243 M: RawMutex,
244{
245 /// Create a new `FairSemaphore`.
246 pub const fn new(permits: usize) -> Self {
247 Self {
248 state: Mutex::new(RefCell::new(FairSemaphoreState::new(permits))),
249 }
250 }
251
252 #[cfg(test)]
253 fn permits(&self) -> usize {
254 self.state.lock(|cell| cell.borrow().permits)
255 }
256
257 fn poll_acquire(
258 &self,
259 permits: usize,
260 acquire_all: bool,
261 cx: Option<(&Cell<Option<usize>>, &Waker)>,
262 ) -> Poll<Result<SemaphoreReleaser<'_, Self>, WaitQueueFull>> {
263 let ticket = cx.as_ref().map(|(cell, _)| cell.get()).unwrap_or(None);
264 self.state.lock(|cell| {
265 let mut state = cell.borrow_mut();
266 if let Some(permits) = state.take(ticket, permits, acquire_all) {
267 Poll::Ready(Ok(SemaphoreReleaser {
268 semaphore: self,
269 permits,
270 }))
271 } else if let Some((cell, waker)) = cx {
272 match state.register(ticket, waker) {
273 Ok(ticket) => {
274 cell.set(Some(ticket));
275 Poll::Pending
276 }
277 Err(err) => Poll::Ready(Err(err)),
278 }
279 } else {
280 Poll::Pending
281 }
282 })
283 }
284}
285
286/// An error indicating the [`FairSemaphore`]'s wait queue is full.
287#[derive(Debug, Clone, Copy, PartialEq, Eq)]
288#[cfg_attr(feature = "defmt", derive(defmt::Format))]
289pub struct WaitQueueFull;
290
291impl<M: RawMutex, const N: usize> Semaphore for FairSemaphore<M, N> {
292 type Error = WaitQueueFull;
293
294 async fn acquire(&self, permits: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error> {
295 let ticket = Cell::new(None);
296 let _guard = OnDrop::new(|| self.state.lock(|cell| cell.borrow_mut().cancel(ticket.get())));
297 poll_fn(|cx| self.poll_acquire(permits, false, Some((&ticket, cx.waker())))).await
298 }
299
300 fn try_acquire(&self, permits: usize) -> Option<SemaphoreReleaser<'_, Self>> {
301 match self.poll_acquire(permits, false, None) {
302 Poll::Ready(Ok(x)) => Some(x),
303 _ => None,
304 }
305 }
306
307 async fn acquire_all(&self, min: usize) -> Result<SemaphoreReleaser<'_, Self>, Self::Error> {
308 let ticket = Cell::new(None);
309 let _guard = OnDrop::new(|| self.state.lock(|cell| cell.borrow_mut().cancel(ticket.get())));
310 poll_fn(|cx| self.poll_acquire(min, true, Some((&ticket, cx.waker())))).await
311 }
312
313 fn try_acquire_all(&self, min: usize) -> Option<SemaphoreReleaser<'_, Self>> {
314 match self.poll_acquire(min, true, None) {
315 Poll::Ready(Ok(x)) => Some(x),
316 _ => None,
317 }
318 }
319
320 fn release(&self, permits: usize) {
321 if permits > 0 {
322 self.state.lock(|cell| {
323 let mut state = cell.borrow_mut();
324 state.permits += permits;
325 state.wake();
326 });
327 }
328 }
329
330 fn set(&self, permits: usize) {
331 self.state.lock(|cell| {
332 let mut state = cell.borrow_mut();
333 if permits > state.permits {
334 state.wake();
335 }
336 state.permits = permits;
337 });
338 }
339}
340
341struct FairSemaphoreState<const N: usize> {
342 permits: usize,
343 next_ticket: usize,
344 wakers: Deque<Option<Waker>, N>,
345}
346
347impl<const N: usize> FairSemaphoreState<N> {
348 /// Create a new empty instance
349 const fn new(permits: usize) -> Self {
350 Self {
351 permits,
352 next_ticket: 0,
353 wakers: Deque::new(),
354 }
355 }
356
357 /// Register a waker. If the queue is full the function returns an error
358 fn register(&mut self, ticket: Option<usize>, w: &Waker) -> Result<usize, WaitQueueFull> {
359 self.pop_canceled();
360
361 match ticket {
362 None => {
363 let ticket = self.next_ticket.wrapping_add(self.wakers.len());
364 self.wakers.push_back(Some(w.clone())).or(Err(WaitQueueFull))?;
365 Ok(ticket)
366 }
367 Some(ticket) => {
368 self.set_waker(ticket, Some(w.clone()));
369 Ok(ticket)
370 }
371 }
372 }
373
374 fn cancel(&mut self, ticket: Option<usize>) {
375 if let Some(ticket) = ticket {
376 self.set_waker(ticket, None);
377 }
378 }
379
380 fn set_waker(&mut self, ticket: usize, waker: Option<Waker>) {
381 let i = ticket.wrapping_sub(self.next_ticket);
382 if i < self.wakers.len() {
383 let (a, b) = self.wakers.as_mut_slices();
384 let x = if i < a.len() { &mut a[i] } else { &mut b[i - a.len()] };
385 *x = waker;
386 }
387 }
388
389 fn take(&mut self, ticket: Option<usize>, mut permits: usize, acquire_all: bool) -> Option<usize> {
390 self.pop_canceled();
391
392 if permits > self.permits {
393 return None;
394 }
395
396 match ticket {
397 Some(n) if n != self.next_ticket => return None,
398 None if !self.wakers.is_empty() => return None,
399 _ => (),
400 }
401
402 if acquire_all {
403 permits = self.permits;
404 }
405 self.permits -= permits;
406
407 if ticket.is_some() {
408 self.pop();
409 }
410
411 Some(permits)
412 }
413
414 fn pop_canceled(&mut self) {
415 while let Some(None) = self.wakers.front() {
416 self.pop();
417 }
418 }
419
420 /// Panics if `self.wakers` is empty
421 fn pop(&mut self) {
422 self.wakers.pop_front().unwrap();
423 self.next_ticket = self.next_ticket.wrapping_add(1);
424 }
425
426 fn wake(&mut self) {
427 self.pop_canceled();
428
429 if let Some(Some(waker)) = self.wakers.front() {
430 waker.wake_by_ref();
431 }
432 }
433}
434
435/// A type to delay the drop handler invocation.
436#[must_use = "to delay the drop handler invocation to the end of the scope"]
437struct OnDrop<F: FnOnce()> {
438 f: MaybeUninit<F>,
439}
440
441impl<F: FnOnce()> OnDrop<F> {
442 /// Create a new instance.
443 pub fn new(f: F) -> Self {
444 Self { f: MaybeUninit::new(f) }
445 }
446}
447
448impl<F: FnOnce()> Drop for OnDrop<F> {
449 fn drop(&mut self) {
450 unsafe { self.f.as_ptr().read()() }
451 }
452}
453
454#[cfg(test)]
455mod tests {
456 mod greedy {
457 use core::pin::pin;
458
459 use futures_util::poll;
460
461 use super::super::*;
462 use crate::blocking_mutex::raw::NoopRawMutex;
463
464 #[test]
465 fn try_acquire() {
466 let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);
467
468 let a = semaphore.try_acquire(1).unwrap();
469 assert_eq!(a.permits(), 1);
470 assert_eq!(semaphore.permits(), 2);
471
472 core::mem::drop(a);
473 assert_eq!(semaphore.permits(), 3);
474 }
475
476 #[test]
477 fn disarm() {
478 let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);
479
480 let a = semaphore.try_acquire(1).unwrap();
481 assert_eq!(a.disarm(), 1);
482 assert_eq!(semaphore.permits(), 2);
483 }
484
485 #[futures_test::test]
486 async fn acquire() {
487 let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);
488
489 let a = semaphore.acquire(1).await.unwrap();
490 assert_eq!(a.permits(), 1);
491 assert_eq!(semaphore.permits(), 2);
492
493 core::mem::drop(a);
494 assert_eq!(semaphore.permits(), 3);
495 }
496
497 #[test]
498 fn try_acquire_all() {
499 let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);
500
501 let a = semaphore.try_acquire_all(1).unwrap();
502 assert_eq!(a.permits(), 3);
503 assert_eq!(semaphore.permits(), 0);
504 }
505
506 #[futures_test::test]
507 async fn acquire_all() {
508 let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);
509
510 let a = semaphore.acquire_all(1).await.unwrap();
511 assert_eq!(a.permits(), 3);
512 assert_eq!(semaphore.permits(), 0);
513 }
514
515 #[test]
516 fn release() {
517 let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);
518 assert_eq!(semaphore.permits(), 3);
519 semaphore.release(2);
520 assert_eq!(semaphore.permits(), 5);
521 }
522
523 #[test]
524 fn set() {
525 let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);
526 assert_eq!(semaphore.permits(), 3);
527 semaphore.set(2);
528 assert_eq!(semaphore.permits(), 2);
529 }
530
531 #[test]
532 fn contested() {
533 let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);
534
535 let a = semaphore.try_acquire(1).unwrap();
536 let b = semaphore.try_acquire(3);
537 assert!(b.is_none());
538
539 core::mem::drop(a);
540
541 let b = semaphore.try_acquire(3);
542 assert!(b.is_some());
543 }
544
545 #[futures_test::test]
546 async fn greedy() {
547 let semaphore = GreedySemaphore::<NoopRawMutex>::new(3);
548
549 let a = semaphore.try_acquire(1).unwrap();
550
551 let b_fut = semaphore.acquire(3);
552 let mut b_fut = pin!(b_fut);
553 let b = poll!(b_fut.as_mut());
554 assert!(b.is_pending());
555
556 // Succeed even through `b` is waiting
557 let c = semaphore.try_acquire(1);
558 assert!(c.is_some());
559
560 let b = poll!(b_fut.as_mut());
561 assert!(b.is_pending());
562
563 core::mem::drop(a);
564
565 let b = poll!(b_fut.as_mut());
566 assert!(b.is_pending());
567
568 core::mem::drop(c);
569
570 let b = poll!(b_fut.as_mut());
571 assert!(b.is_ready());
572 }
573 }
574
575 mod fair {
576 use core::pin::pin;
577
578 use futures_util::poll;
579
580 use super::super::*;
581 use crate::blocking_mutex::raw::NoopRawMutex;
582
583 #[test]
584 fn try_acquire() {
585 let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);
586
587 let a = semaphore.try_acquire(1).unwrap();
588 assert_eq!(a.permits(), 1);
589 assert_eq!(semaphore.permits(), 2);
590
591 core::mem::drop(a);
592 assert_eq!(semaphore.permits(), 3);
593 }
594
595 #[test]
596 fn disarm() {
597 let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);
598
599 let a = semaphore.try_acquire(1).unwrap();
600 assert_eq!(a.disarm(), 1);
601 assert_eq!(semaphore.permits(), 2);
602 }
603
604 #[futures_test::test]
605 async fn acquire() {
606 let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);
607
608 let a = semaphore.acquire(1).await.unwrap();
609 assert_eq!(a.permits(), 1);
610 assert_eq!(semaphore.permits(), 2);
611
612 core::mem::drop(a);
613 assert_eq!(semaphore.permits(), 3);
614 }
615
616 #[test]
617 fn try_acquire_all() {
618 let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);
619
620 let a = semaphore.try_acquire_all(1).unwrap();
621 assert_eq!(a.permits(), 3);
622 assert_eq!(semaphore.permits(), 0);
623 }
624
625 #[futures_test::test]
626 async fn acquire_all() {
627 let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);
628
629 let a = semaphore.acquire_all(1).await.unwrap();
630 assert_eq!(a.permits(), 3);
631 assert_eq!(semaphore.permits(), 0);
632 }
633
634 #[test]
635 fn release() {
636 let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);
637 assert_eq!(semaphore.permits(), 3);
638 semaphore.release(2);
639 assert_eq!(semaphore.permits(), 5);
640 }
641
642 #[test]
643 fn set() {
644 let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);
645 assert_eq!(semaphore.permits(), 3);
646 semaphore.set(2);
647 assert_eq!(semaphore.permits(), 2);
648 }
649
650 #[test]
651 fn contested() {
652 let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);
653
654 let a = semaphore.try_acquire(1).unwrap();
655 let b = semaphore.try_acquire(3);
656 assert!(b.is_none());
657
658 core::mem::drop(a);
659
660 let b = semaphore.try_acquire(3);
661 assert!(b.is_some());
662 }
663
664 #[futures_test::test]
665 async fn fairness() {
666 let semaphore = FairSemaphore::<NoopRawMutex, 2>::new(3);
667
668 let a = semaphore.try_acquire(1);
669 assert!(a.is_some());
670
671 let b_fut = semaphore.acquire(3);
672 let mut b_fut = pin!(b_fut);
673 let b = poll!(b_fut.as_mut()); // Poll `b_fut` once so it is registered
674 assert!(b.is_pending());
675
676 let c = semaphore.try_acquire(1);
677 assert!(c.is_none());
678
679 let c_fut = semaphore.acquire(1);
680 let mut c_fut = pin!(c_fut);
681 let c = poll!(c_fut.as_mut()); // Poll `c_fut` once so it is registered
682 assert!(c.is_pending()); // `c` is blocked behind `b`
683
684 let d = semaphore.acquire(1).await;
685 assert!(matches!(d, Err(WaitQueueFull)));
686
687 core::mem::drop(a);
688
689 let c = poll!(c_fut.as_mut());
690 assert!(c.is_pending()); // `c` is still blocked behind `b`
691
692 let b = poll!(b_fut.as_mut());
693 assert!(b.is_ready());
694
695 let c = poll!(c_fut.as_mut());
696 assert!(c.is_pending()); // `c` is still blocked behind `b`
697
698 core::mem::drop(b);
699
700 let c = poll!(c_fut.as_mut());
701 assert!(c.is_ready());
702 }
703 }
704}