aboutsummaryrefslogtreecommitdiff
path: root/embassy-sync
diff options
context:
space:
mode:
Diffstat (limited to 'embassy-sync')
-rw-r--r--embassy-sync/src/multi_signal.rs340
1 files changed, 292 insertions, 48 deletions
diff --git a/embassy-sync/src/multi_signal.rs b/embassy-sync/src/multi_signal.rs
index 5f724c76b..1481dc8f8 100644
--- a/embassy-sync/src/multi_signal.rs
+++ b/embassy-sync/src/multi_signal.rs
@@ -97,7 +97,7 @@ impl<'a, M: RawMutex, T: Clone, const N: usize> MultiSignal<M, T, N> {
97 } 97 }
98 98
99 /// Get a [`Receiver`] for the `MultiSignal`. 99 /// Get a [`Receiver`] for the `MultiSignal`.
100 pub fn receiver(&'a self) -> Result<Receiver<'a, M, T, N>, Error> { 100 pub fn receiver<'s>(&'a self) -> Result<Receiver<'a, M, T, N>, Error> {
101 self.mutex.lock(|state| { 101 self.mutex.lock(|state| {
102 let mut s = state.borrow_mut(); 102 let mut s = state.borrow_mut();
103 if s.receiver_count < N { 103 if s.receiver_count < N {
@@ -142,60 +142,36 @@ impl<'a, M: RawMutex, T: Clone, const N: usize> MultiSignal<M, T, N> {
142 fn get_id(&self) -> u64 { 142 fn get_id(&self) -> u64 {
143 self.mutex.lock(|state| state.borrow().current_id) 143 self.mutex.lock(|state| state.borrow().current_id)
144 } 144 }
145
146 /// Poll the `MultiSignal` with an optional context.
147 fn get_with_context(&self, rcv: &mut Rcv<'a, M, T, N>, cx: Option<&mut Context>) -> Poll<T> {
148 self.mutex.lock(|state| {
149 let mut s = state.borrow_mut();
150 match (s.current_id > rcv.at_id, rcv.predicate) {
151 (true, None) => {
152 rcv.at_id = s.current_id;
153 Poll::Ready(s.data.clone())
154 }
155 (true, Some(f)) if f(&s.data) => {
156 rcv.at_id = s.current_id;
157 Poll::Ready(s.data.clone())
158 }
159 _ => {
160 if let Some(cx) = cx {
161 s.wakers.register(cx.waker());
162 }
163 Poll::Pending
164 }
165 }
166 })
167 }
168} 145}
169 146
170/// A receiver is able to `.await` a changed `MultiSignal` value. 147/// A receiver is able to `.await` a changed `MultiSignal` value.
171pub struct Rcv<'a, M: RawMutex, T: Clone, const N: usize> { 148pub struct Rcv<'a, M: RawMutex, T: Clone, const N: usize> {
172 multi_sig: &'a MultiSignal<M, T, N>, 149 multi_sig: &'a MultiSignal<M, T, N>,
173 predicate: Option<fn(&T) -> bool>,
174 at_id: u64, 150 at_id: u64,
175} 151}
176 152
177// f: Option<impl FnMut(&T) -> bool> 153impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> Rcv<'a, M, T, N> {
178impl<'a, M: RawMutex, T: Clone, const N: usize> Rcv<'a, M, T, N> {
179 /// Create a new `Receiver` with a reference the given `MultiSignal`. 154 /// Create a new `Receiver` with a reference the given `MultiSignal`.
180 fn new(multi_sig: &'a MultiSignal<M, T, N>) -> Self { 155 fn new(multi_sig: &'a MultiSignal<M, T, N>) -> Self {
181 Self { 156 Self { multi_sig, at_id: 0 }
182 multi_sig,
183 predicate: None,
184 at_id: 0,
185 }
186 } 157 }
187 158
188 /// Wait for a change to the value of the corresponding `MultiSignal`. 159 /// Wait for a change to the value of the corresponding `MultiSignal`.
189 pub fn changed<'s>(&'s mut self) -> ReceiverFuture<'s, 'a, M, T, N> { 160 pub async fn changed(&mut self) -> T {
190 self.predicate = None; 161 ReceiverWaitFuture { subscriber: self }.await
191 ReceiverFuture { subscriber: self }
192 } 162 }
193 163
194 /// Wait for a change to the value of the corresponding `MultiSignal` which matches the predicate `f`. 164 /// Wait for a change to the value of the corresponding `MultiSignal` which matches the predicate `f`.
195 // TODO: How do we make this work with a FnMut closure? 165 // TODO: How do we make this work with a FnMut closure?
196 pub fn changed_and<'s>(&'s mut self, f: fn(&T) -> bool) -> ReceiverFuture<'s, 'a, M, T, N> { 166 pub async fn changed_and<F>(&mut self, f: F) -> T
197 self.predicate = Some(f); 167 where
198 ReceiverFuture { subscriber: self } 168 F: FnMut(&T) -> bool,
169 {
170 ReceiverPredFuture {
171 subscriber: self,
172 predicate: f,
173 }
174 .await
199 } 175 }
200 176
201 /// Try to get a changed value of the corresponding `MultiSignal`. 177 /// Try to get a changed value of the corresponding `MultiSignal`.
@@ -213,7 +189,10 @@ impl<'a, M: RawMutex, T: Clone, const N: usize> Rcv<'a, M, T, N> {
213 } 189 }
214 190
215 /// Try to get a changed value of the corresponding `MultiSignal` which matches the predicate `f`. 191 /// Try to get a changed value of the corresponding `MultiSignal` which matches the predicate `f`.
216 pub fn try_changed_and(&mut self, mut f: impl FnMut(&T) -> bool) -> Option<T> { 192 pub fn try_changed_and<F>(&mut self, mut f: F) -> Option<T>
193 where
194 F: FnMut(&T) -> bool,
195 {
217 self.multi_sig.mutex.lock(|state| { 196 self.multi_sig.mutex.lock(|state| {
218 let s = state.borrow(); 197 let s = state.borrow();
219 match s.current_id > self.at_id && f(&s.data) { 198 match s.current_id > self.at_id && f(&s.data) {
@@ -232,7 +211,10 @@ impl<'a, M: RawMutex, T: Clone, const N: usize> Rcv<'a, M, T, N> {
232 } 211 }
233 212
234 /// Peek the current value of the corresponding `MultiSignal` and check if it satisfies the predicate `f`. 213 /// Peek the current value of the corresponding `MultiSignal` and check if it satisfies the predicate `f`.
235 pub fn peek_and(&self, f: impl FnMut(&T) -> bool) -> Option<T> { 214 pub fn peek_and<F>(&self, f: F) -> Option<T>
215 where
216 F: FnMut(&T) -> bool,
217 {
236 self.multi_sig.peek_and(f) 218 self.multi_sig.peek_and(f)
237 } 219 }
238 220
@@ -247,7 +229,7 @@ impl<'a, M: RawMutex, T: Clone, const N: usize> Rcv<'a, M, T, N> {
247/// A `Receiver` is able to `.await` a change to the corresponding [`MultiSignal`] value. 229/// A `Receiver` is able to `.await` a change to the corresponding [`MultiSignal`] value.
248pub struct Receiver<'a, M: RawMutex, T: Clone, const N: usize>(Rcv<'a, M, T, N>); 230pub struct Receiver<'a, M: RawMutex, T: Clone, const N: usize>(Rcv<'a, M, T, N>);
249 231
250impl<'a, M: RawMutex, T: Clone, const N: usize> Deref for Receiver<'a, M, T, N> { 232impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> Deref for Receiver<'a, M, T, N> {
251 type Target = Rcv<'a, M, T, N>; 233 type Target = Rcv<'a, M, T, N>;
252 234
253 fn deref(&self) -> &Self::Target { 235 fn deref(&self) -> &Self::Target {
@@ -255,7 +237,7 @@ impl<'a, M: RawMutex, T: Clone, const N: usize> Deref for Receiver<'a, M, T, N>
255 } 237 }
256} 238}
257 239
258impl<'a, M: RawMutex, T: Clone, const N: usize> DerefMut for Receiver<'a, M, T, N> { 240impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> DerefMut for Receiver<'a, M, T, N> {
259 fn deref_mut(&mut self) -> &mut Self::Target { 241 fn deref_mut(&mut self) -> &mut Self::Target {
260 &mut self.0 242 &mut self.0
261 } 243 }
@@ -263,18 +245,280 @@ impl<'a, M: RawMutex, T: Clone, const N: usize> DerefMut for Receiver<'a, M, T,
263 245
264/// Future for the `Receiver` wait action 246/// Future for the `Receiver` wait action
265#[must_use = "futures do nothing unless you `.await` or poll them"] 247#[must_use = "futures do nothing unless you `.await` or poll them"]
266pub struct ReceiverFuture<'s, 'a, M: RawMutex, T: Clone, const N: usize> { 248pub struct ReceiverWaitFuture<'s, 'a, M: RawMutex, T: Clone, const N: usize> {
267 subscriber: &'s mut Rcv<'a, M, T, N>, 249 subscriber: &'s mut Rcv<'a, M, T, N>,
268} 250}
269 251
270impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> Future for ReceiverFuture<'s, 'a, M, T, N> { 252impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> Unpin for ReceiverWaitFuture<'s, 'a, M, T, N> {}
253impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> Future for ReceiverWaitFuture<'s, 'a, M, T, N> {
271 type Output = T; 254 type Output = T;
272 255
273 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 256 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
274 self.subscriber 257 self.get_with_context(Some(cx))
275 .multi_sig 258 }
276 .get_with_context(&mut self.subscriber, Some(cx)) 259}
260
261impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> ReceiverWaitFuture<'s, 'a, M, T, N> {
262 /// Poll the `MultiSignal` with an optional context.
263 fn get_with_context(&mut self, cx: Option<&mut Context>) -> Poll<T> {
264 self.subscriber.multi_sig.mutex.lock(|state| {
265 let mut s = state.borrow_mut();
266 match s.current_id > self.subscriber.at_id {
267 true => {
268 self.subscriber.at_id = s.current_id;
269 Poll::Ready(s.data.clone())
270 }
271 _ => {
272 if let Some(cx) = cx {
273 s.wakers.register(cx.waker());
274 }
275 Poll::Pending
276 }
277 }
278 })
279 }
280}
281
282/// Future for the `Receiver` wait action, with the ability to filter the value with a predicate.
283#[must_use = "futures do nothing unless you `.await` or poll them"]
284pub struct ReceiverPredFuture<'s, 'a, M: RawMutex, T: Clone, F: FnMut(&'a T) -> bool, const N: usize> {
285 subscriber: &'s mut Rcv<'a, M, T, N>,
286 predicate: F,
287}
288
289impl<'s, 'a, M: RawMutex, T: Clone, F: FnMut(&T) -> bool, const N: usize> Unpin for ReceiverPredFuture<'s, 'a, M, T, F, N> {}
290impl<'s, 'a, M: RawMutex, T: Clone, F: FnMut(&T) -> bool, const N: usize> Future for ReceiverPredFuture<'s, 'a, M, T, F, N>{
291 type Output = T;
292
293 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
294 self.get_with_context_pred(Some(cx))
295 }
296}
297
298impl<'s, 'a, M: RawMutex, T: Clone, F: FnMut(&T) -> bool, const N: usize> ReceiverPredFuture<'s, 'a, M, T, F, N> {
299 /// Poll the `MultiSignal` with an optional context.
300 fn get_with_context_pred(&mut self, cx: Option<&mut Context>) -> Poll<T> {
301 self.subscriber.multi_sig.mutex.lock(|state| {
302 let mut s = state.borrow_mut();
303 match s.current_id > self.subscriber.at_id {
304 true if (self.predicate)(&s.data) => {
305 self.subscriber.at_id = s.current_id;
306 Poll::Ready(s.data.clone())
307 }
308 _ => {
309 if let Some(cx) = cx {
310 s.wakers.register(cx.waker());
311 }
312 Poll::Pending
313 }
314 }
315 })
277 } 316 }
278} 317}
279 318
280impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> Unpin for ReceiverFuture<'s, 'a, M, T, N> {} 319#[cfg(test)]
320mod tests {
321 use super::*;
322 use crate::blocking_mutex::raw::CriticalSectionRawMutex;
323 use futures_executor::block_on;
324
325 #[test]
326 fn multiple_writes() {
327 let f = async {
328 static SOME_SIGNAL: MultiSignal<CriticalSectionRawMutex, u8, 2> = MultiSignal::new(0);
329
330 // Obtain Receivers
331 let mut rcv0 = SOME_SIGNAL.receiver().unwrap();
332 let mut rcv1 = SOME_SIGNAL.receiver().unwrap();
333
334 SOME_SIGNAL.write(10);
335
336 // Receive the new value
337 assert_eq!(rcv0.changed().await, 10);
338 assert_eq!(rcv1.changed().await, 10);
339
340 // No update
341 assert_eq!(rcv0.try_changed(), None);
342 assert_eq!(rcv1.try_changed(), None);
343
344 SOME_SIGNAL.write(20);
345
346 assert_eq!(rcv0.changed().await, 20);
347 assert_eq!(rcv1.changed().await, 20);
348 };
349 block_on(f);
350 }
351
352 #[test]
353 fn max_receivers() {
354 let f = async {
355 static SOME_SIGNAL: MultiSignal<CriticalSectionRawMutex, u8, 2> = MultiSignal::new(0);
356
357 // Obtain Receivers
358 let _ = SOME_SIGNAL.receiver().unwrap();
359 let _ = SOME_SIGNAL.receiver().unwrap();
360 assert!(SOME_SIGNAL.receiver().is_err());
361 };
362 block_on(f);
363 }
364
365 // Really weird edge case, but it's possible to have a receiver that never gets a value.
366 #[test]
367 fn receive_initial() {
368 let f = async {
369 static SOME_SIGNAL: MultiSignal<CriticalSectionRawMutex, u8, 2> = MultiSignal::new(0);
370
371 // Obtain Receivers
372 let mut rcv0 = SOME_SIGNAL.receiver().unwrap();
373 let mut rcv1 = SOME_SIGNAL.receiver().unwrap();
374
375 assert_eq!(rcv0.try_changed(), Some(0));
376 assert_eq!(rcv1.try_changed(), Some(0));
377
378 assert_eq!(rcv0.try_changed(), None);
379 assert_eq!(rcv1.try_changed(), None);
380 };
381 block_on(f);
382 }
383
384 #[test]
385 fn count_ids() {
386 let f = async {
387 static SOME_SIGNAL: MultiSignal<CriticalSectionRawMutex, u8, 2> = MultiSignal::new(0);
388
389 // Obtain Receivers
390 let mut rcv0 = SOME_SIGNAL.receiver().unwrap();
391 let mut rcv1 = SOME_SIGNAL.receiver().unwrap();
392
393 SOME_SIGNAL.write(10);
394
395 assert_eq!(rcv0.changed().await, 10);
396 assert_eq!(rcv1.changed().await, 10);
397
398 assert_eq!(rcv0.try_changed(), None);
399 assert_eq!(rcv1.try_changed(), None);
400
401 SOME_SIGNAL.write(20);
402 SOME_SIGNAL.write(20);
403 SOME_SIGNAL.write(20);
404
405 assert_eq!(rcv0.changed().await, 20);
406 assert_eq!(rcv1.changed().await, 20);
407
408 assert_eq!(rcv0.try_changed(), None);
409 assert_eq!(rcv1.try_changed(), None);
410
411 assert_eq!(SOME_SIGNAL.get_id(), 5);
412 };
413 block_on(f);
414 }
415
416 #[test]
417 fn peek_still_await() {
418 let f = async {
419 static SOME_SIGNAL: MultiSignal<CriticalSectionRawMutex, u8, 2> = MultiSignal::new(0);
420
421 // Obtain Receivers
422 let mut rcv0 = SOME_SIGNAL.receiver().unwrap();
423 let mut rcv1 = SOME_SIGNAL.receiver().unwrap();
424
425 SOME_SIGNAL.write(10);
426
427 assert_eq!(rcv0.peek(), 10);
428 assert_eq!(rcv1.peek(), 10);
429
430 assert_eq!(rcv0.changed().await, 10);
431 assert_eq!(rcv1.changed().await, 10);
432 };
433 block_on(f);
434 }
435
436 #[test]
437 fn predicate() {
438 let f = async {
439 static SOME_SIGNAL: MultiSignal<CriticalSectionRawMutex, u8, 2> = MultiSignal::new(0);
440
441 // Obtain Receivers
442 let mut rcv0 = SOME_SIGNAL.receiver().unwrap();
443 let mut rcv1 = SOME_SIGNAL.receiver().unwrap();
444
445 SOME_SIGNAL.write(20);
446
447 assert_eq!(rcv0.changed_and(|x| x > &10).await, 20);
448 assert_eq!(rcv1.try_changed_and(|x| x > &30), None);
449 };
450 block_on(f);
451 }
452
453 #[test]
454 fn mutable_predicate() {
455 let f = async {
456 static SOME_SIGNAL: MultiSignal<CriticalSectionRawMutex, u8, 2> = MultiSignal::new(0);
457
458 // Obtain Receivers
459 let mut rcv = SOME_SIGNAL.receiver().unwrap();
460
461 SOME_SIGNAL.write(10);
462
463 let mut largest = 0;
464 let mut predicate = |x: &u8| {
465 if *x > largest {
466 largest = *x;
467 }
468 true
469 };
470
471 assert_eq!(rcv.changed_and(&mut predicate).await, 10);
472
473 SOME_SIGNAL.write(20);
474
475 assert_eq!(rcv.changed_and(&mut predicate).await, 20);
476
477 SOME_SIGNAL.write(5);
478
479 assert_eq!(rcv.changed_and(&mut predicate).await, 5);
480
481 assert_eq!(largest, 20)
482 };
483 block_on(f);
484 }
485
486 #[test]
487 fn peek_and() {
488 let f = async {
489 static SOME_SIGNAL: MultiSignal<CriticalSectionRawMutex, u8, 2> = MultiSignal::new(0);
490
491 // Obtain Receivers
492 let mut rcv0 = SOME_SIGNAL.receiver().unwrap();
493 let mut rcv1 = SOME_SIGNAL.receiver().unwrap();
494
495 SOME_SIGNAL.write(20);
496
497 assert_eq!(rcv0.peek_and(|x| x > &10), Some(20));
498 assert_eq!(rcv1.peek_and(|x| x > &30), None);
499
500 assert_eq!(rcv0.changed().await, 20);
501 assert_eq!(rcv1.changed().await, 20);
502 };
503 block_on(f);
504 }
505
506 #[test]
507 fn peek_with_static() {
508 let f = async {
509 static SOME_SIGNAL: MultiSignal<CriticalSectionRawMutex, u8, 2> = MultiSignal::new(0);
510
511 // Obtain Receivers
512 let rcv0 = SOME_SIGNAL.receiver().unwrap();
513 let rcv1 = SOME_SIGNAL.receiver().unwrap();
514
515 SOME_SIGNAL.write(20);
516
517 assert_eq!(rcv0.peek(), 20);
518 assert_eq!(rcv1.peek(), 20);
519 assert_eq!(SOME_SIGNAL.peek(), 20);
520 assert_eq!(SOME_SIGNAL.peek_and(|x| x > &30), None);
521 };
522 block_on(f);
523 }
524}