aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDario Nieuwenhuis <[email protected]>2021-09-11 02:49:16 +0200
committerDario Nieuwenhuis <[email protected]>2021-09-11 02:49:16 +0200
commitb78f4695c4c7dac91ecf4c2aa420cf8bf4d8be52 (patch)
tree5405e73587d02f2b7772deb5e7bbba6f76b1d578
parent67fa6b06fafc8635d2063e687904d30864f45a05 (diff)
embassy/channel: use heapless::Deque.
-rw-r--r--embassy/Cargo.toml1
-rw-r--r--embassy/src/channel/mpsc.rs71
2 files changed, 22 insertions, 50 deletions
diff --git a/embassy/Cargo.toml b/embassy/Cargo.toml
index 0a8ab4434..ae06bc198 100644
--- a/embassy/Cargo.toml
+++ b/embassy/Cargo.toml
@@ -42,6 +42,7 @@ embassy-traits = { version = "0.1.0", path = "../embassy-traits"}
42atomic-polyfill = "0.1.3" 42atomic-polyfill = "0.1.3"
43critical-section = "0.2.1" 43critical-section = "0.2.1"
44embedded-hal = "0.2.6" 44embedded-hal = "0.2.6"
45heapless = "0.7.5"
45 46
46[dev-dependencies] 47[dev-dependencies]
47embassy = { path = ".", features = ["executor-agnostic"] } 48embassy = { path = ".", features = ["executor-agnostic"] }
diff --git a/embassy/src/channel/mpsc.rs b/embassy/src/channel/mpsc.rs
index b20d48a95..c77452441 100644
--- a/embassy/src/channel/mpsc.rs
+++ b/embassy/src/channel/mpsc.rs
@@ -40,14 +40,13 @@
40use core::cell::UnsafeCell; 40use core::cell::UnsafeCell;
41use core::fmt; 41use core::fmt;
42use core::marker::PhantomData; 42use core::marker::PhantomData;
43use core::mem::MaybeUninit;
44use core::pin::Pin; 43use core::pin::Pin;
45use core::ptr;
46use core::task::Context; 44use core::task::Context;
47use core::task::Poll; 45use core::task::Poll;
48use core::task::Waker; 46use core::task::Waker;
49 47
50use futures::Future; 48use futures::Future;
49use heapless::Deque;
51 50
52use crate::blocking_mutex::{CriticalSectionMutex, Mutex, NoopMutex, ThreadModeMutex}; 51use crate::blocking_mutex::{CriticalSectionMutex, Mutex, NoopMutex, ThreadModeMutex};
53use crate::waitqueue::WakerRegistration; 52use crate::waitqueue::WakerRegistration;
@@ -446,10 +445,7 @@ impl<T> defmt::Format for TrySendError<T> {
446} 445}
447 446
448struct ChannelState<T, const N: usize> { 447struct ChannelState<T, const N: usize> {
449 buf: [MaybeUninit<UnsafeCell<T>>; N], 448 queue: Deque<T, N>,
450 read_pos: usize,
451 write_pos: usize,
452 full: bool,
453 closed: bool, 449 closed: bool,
454 receiver_registered: bool, 450 receiver_registered: bool,
455 senders_registered: u32, 451 senders_registered: u32,
@@ -458,14 +454,9 @@ struct ChannelState<T, const N: usize> {
458} 454}
459 455
460impl<T, const N: usize> ChannelState<T, N> { 456impl<T, const N: usize> ChannelState<T, N> {
461 const INIT: MaybeUninit<UnsafeCell<T>> = MaybeUninit::uninit();
462
463 const fn new() -> Self { 457 const fn new() -> Self {
464 ChannelState { 458 ChannelState {
465 buf: [Self::INIT; N], 459 queue: Deque::new(),
466 read_pos: 0,
467 write_pos: 0,
468 full: false,
469 closed: false, 460 closed: false,
470 receiver_registered: false, 461 receiver_registered: false,
471 senders_registered: 0, 462 senders_registered: 0,
@@ -479,17 +470,16 @@ impl<T, const N: usize> ChannelState<T, N> {
479 } 470 }
480 471
481 fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result<T, TryRecvError> { 472 fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result<T, TryRecvError> {
482 if self.read_pos != self.write_pos || self.full { 473 if self.queue.is_full() {
483 if self.full { 474 self.senders_waker.wake();
484 self.full = false; 475 }
485 self.senders_waker.wake(); 476
486 } 477 if let Some(message) = self.queue.pop_front() {
487 let message = unsafe { (self.buf[self.read_pos]).assume_init_mut().get().read() };
488 self.read_pos = (self.read_pos + 1) % self.buf.len();
489 Ok(message) 478 Ok(message)
490 } else if !self.closed { 479 } else if !self.closed {
491 cx.into_iter() 480 if let Some(cx) = cx {
492 .for_each(|cx| self.set_receiver_waker(&cx.waker())); 481 self.set_receiver_waker(cx.waker());
482 }
493 Err(TryRecvError::Empty) 483 Err(TryRecvError::Empty)
494 } else { 484 } else {
495 Err(TryRecvError::Closed) 485 Err(TryRecvError::Closed)
@@ -505,22 +495,21 @@ impl<T, const N: usize> ChannelState<T, N> {
505 message: T, 495 message: T,
506 cx: Option<&mut Context<'_>>, 496 cx: Option<&mut Context<'_>>,
507 ) -> Result<(), TrySendError<T>> { 497 ) -> Result<(), TrySendError<T>> {
508 if !self.closed { 498 if self.closed {
509 if !self.full { 499 return Err(TrySendError::Closed(message));
510 self.buf[self.write_pos] = MaybeUninit::new(message.into()); 500 }
511 self.write_pos = (self.write_pos + 1) % self.buf.len(); 501
512 if self.write_pos == self.read_pos { 502 match self.queue.push_back(message) {
513 self.full = true; 503 Ok(()) => {
514 }
515 self.receiver_waker.wake(); 504 self.receiver_waker.wake();
505
516 Ok(()) 506 Ok(())
517 } else { 507 }
508 Err(message) => {
518 cx.into_iter() 509 cx.into_iter()
519 .for_each(|cx| self.set_senders_waker(&cx.waker())); 510 .for_each(|cx| self.set_senders_waker(&cx.waker()));
520 Err(TrySendError::Full(message)) 511 Err(TrySendError::Full(message))
521 } 512 }
522 } else {
523 Err(TrySendError::Closed(message))
524 } 513 }
525 } 514 }
526 515
@@ -585,16 +574,6 @@ impl<T, const N: usize> ChannelState<T, N> {
585 } 574 }
586} 575}
587 576
588impl<T, const N: usize> Drop for ChannelState<T, N> {
589 fn drop(&mut self) {
590 while self.read_pos != self.write_pos || self.full {
591 self.full = false;
592 unsafe { ptr::drop_in_place(self.buf[self.read_pos].as_mut_ptr()) };
593 self.read_pos = (self.read_pos + 1) % N;
594 }
595 }
596}
597
598/// A a bounded mpsc channel for communicating between asynchronous tasks 577/// A a bounded mpsc channel for communicating between asynchronous tasks
599/// with backpressure. 578/// with backpressure.
600/// 579///
@@ -676,15 +655,7 @@ mod tests {
676 use super::*; 655 use super::*;
677 656
678 fn capacity<T, const N: usize>(c: &ChannelState<T, N>) -> usize { 657 fn capacity<T, const N: usize>(c: &ChannelState<T, N>) -> usize {
679 if !c.full { 658 c.queue.capacity() - c.queue.len()
680 if c.write_pos > c.read_pos {
681 (c.buf.len() - c.write_pos) + c.read_pos
682 } else {
683 (c.buf.len() - c.read_pos) + c.write_pos
684 }
685 } else {
686 0
687 }
688 } 659 }
689 660
690 #[test] 661 #[test]