aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--embassy/src/util/mpsc.rs111
1 files changed, 61 insertions, 50 deletions
diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs
index f350c6e53..246bd27e4 100644
--- a/embassy/src/util/mpsc.rs
+++ b/embassy/src/util/mpsc.rs
@@ -39,6 +39,7 @@
39 39
40use core::cell::UnsafeCell; 40use core::cell::UnsafeCell;
41use core::fmt; 41use core::fmt;
42use core::marker::PhantomData;
42use core::mem::MaybeUninit; 43use core::mem::MaybeUninit;
43use core::pin::Pin; 44use core::pin::Pin;
44use core::ptr; 45use core::ptr;
@@ -61,7 +62,7 @@ pub struct Sender<'ch, M, T, const N: usize>
61where 62where
62 M: Mutex<Data = ()>, 63 M: Mutex<Data = ()>,
63{ 64{
64 channel: &'ch Channel<M, T, N>, 65 channel_cell: &'ch UnsafeCell<ChannelCell<M, T, N>>,
65} 66}
66 67
67// Safe to pass the sender around 68// Safe to pass the sender around
@@ -77,7 +78,8 @@ pub struct Receiver<'ch, M, T, const N: usize>
77where 78where
78 M: Mutex<Data = ()>, 79 M: Mutex<Data = ()>,
79{ 80{
80 channel: &'ch Channel<M, T, N>, 81 channel_cell: &'ch UnsafeCell<ChannelCell<M, T, N>>,
82 _receiver_consumed: &'ch mut PhantomData<()>,
81} 83}
82 84
83// Safe to pass the receiver around 85// Safe to pass the receiver around
@@ -111,18 +113,23 @@ unsafe impl<'ch, M, T, const N: usize> Sync for Receiver<'ch, M, T, N> where
111/// 113///
112/// let (sender, receiver) = { 114/// let (sender, receiver) = {
113/// let mut channel = Channel::<WithThreadModeOnly, u32, 3>::with_thread_mode_only(); 115/// let mut channel = Channel::<WithThreadModeOnly, u32, 3>::with_thread_mode_only();
114/// mpsc::split(&channel) 116/// mpsc::split(&mut channel)
115/// }; 117/// };
116/// ``` 118/// ```
117pub fn split<M, T, const N: usize>( 119pub fn split<M, T, const N: usize>(
118 channel: &Channel<M, T, N>, 120 channel: &mut Channel<M, T, N>,
119) -> (Sender<M, T, N>, Receiver<M, T, N>) 121) -> (Sender<M, T, N>, Receiver<M, T, N>)
120where 122where
121 M: Mutex<Data = ()>, 123 M: Mutex<Data = ()>,
122{ 124{
123 let sender = Sender { channel: &channel }; 125 let sender = Sender {
124 let receiver = Receiver { channel: &channel }; 126 channel_cell: &channel.channel_cell,
125 channel.lock(|c| { 127 };
128 let receiver = Receiver {
129 channel_cell: &channel.channel_cell,
130 _receiver_consumed: &mut channel.receiver_consumed,
131 };
132 Channel::lock(&channel.channel_cell, |c| {
126 c.register_receiver(); 133 c.register_receiver();
127 c.register_sender(); 134 c.register_sender();
128 }); 135 });
@@ -154,12 +161,13 @@ where
154 } 161 }
155 162
156 fn recv_poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { 163 fn recv_poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
157 self.channel 164 Channel::lock(self.channel_cell, |c| {
158 .lock(|c| match c.try_recv_with_context(Some(cx)) { 165 match c.try_recv_with_context(Some(cx)) {
159 Ok(v) => Poll::Ready(Some(v)), 166 Ok(v) => Poll::Ready(Some(v)),
160 Err(TryRecvError::Closed) => Poll::Ready(None), 167 Err(TryRecvError::Closed) => Poll::Ready(None),
161 Err(TryRecvError::Empty) => Poll::Pending, 168 Err(TryRecvError::Empty) => Poll::Pending,
162 }) 169 }
170 })
163 } 171 }
164 172
165 /// Attempts to immediately receive a message on this `Receiver` 173 /// Attempts to immediately receive a message on this `Receiver`
@@ -167,7 +175,7 @@ where
167 /// This method will either receive a message from the channel immediately or return an error 175 /// This method will either receive a message from the channel immediately or return an error
168 /// if the channel is empty. 176 /// if the channel is empty.
169 pub fn try_recv(&self) -> Result<T, TryRecvError> { 177 pub fn try_recv(&self) -> Result<T, TryRecvError> {
170 self.channel.lock(|c| c.try_recv()) 178 Channel::lock(self.channel_cell, |c| c.try_recv())
171 } 179 }
172 180
173 /// Closes the receiving half of a channel without dropping it. 181 /// Closes the receiving half of a channel without dropping it.
@@ -181,7 +189,7 @@ where
181 /// until those are released. 189 /// until those are released.
182 /// 190 ///
183 pub fn close(&mut self) { 191 pub fn close(&mut self) {
184 self.channel.lock(|c| c.close()) 192 Channel::lock(self.channel_cell, |c| c.close())
185 } 193 }
186} 194}
187 195
@@ -190,7 +198,7 @@ where
190 M: Mutex<Data = ()>, 198 M: Mutex<Data = ()>,
191{ 199{
192 fn drop(&mut self) { 200 fn drop(&mut self) {
193 self.channel.lock(|c| c.deregister_receiver()) 201 Channel::lock(self.channel_cell, |c| c.deregister_receiver())
194 } 202 }
195} 203}
196 204
@@ -245,7 +253,7 @@ where
245 /// [`channel`]: channel 253 /// [`channel`]: channel
246 /// [`close`]: Receiver::close 254 /// [`close`]: Receiver::close
247 pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> { 255 pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
248 self.channel.lock(|c| c.try_send(message)) 256 Channel::lock(self.channel_cell, |c| c.try_send(message))
249 } 257 }
250 258
251 /// Completes when the receiver has dropped. 259 /// Completes when the receiver has dropped.
@@ -266,7 +274,7 @@ where
266 /// [`Receiver`]: crate::sync::mpsc::Receiver 274 /// [`Receiver`]: crate::sync::mpsc::Receiver
267 /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close 275 /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close
268 pub fn is_closed(&self) -> bool { 276 pub fn is_closed(&self) -> bool {
269 self.channel.lock(|c| c.is_closed()) 277 Channel::lock(self.channel_cell, |c| c.is_closed())
270 } 278 }
271} 279}
272 280
@@ -286,11 +294,9 @@ where
286 294
287 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 295 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
288 match self.message.take() { 296 match self.message.take() {
289 Some(m) => match self 297 Some(m) => match Channel::lock(self.sender.channel_cell, |c| {
290 .sender 298 c.try_send_with_context(m, Some(cx))
291 .channel 299 }) {
292 .lock(|c| c.try_send_with_context(m, Some(cx)))
293 {
294 Ok(..) => Poll::Ready(Ok(())), 300 Ok(..) => Poll::Ready(Ok(())),
295 Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), 301 Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))),
296 Err(TrySendError::Full(m)) => { 302 Err(TrySendError::Full(m)) => {
@@ -319,11 +325,9 @@ where
319 type Output = (); 325 type Output = ();
320 326
321 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 327 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
322 if self 328 if Channel::lock(self.sender.channel_cell, |c| {
323 .sender 329 c.is_closed_with_context(Some(cx))
324 .channel 330 }) {
325 .lock(|c| c.is_closed_with_context(Some(cx)))
326 {
327 Poll::Ready(()) 331 Poll::Ready(())
328 } else { 332 } else {
329 Poll::Pending 333 Poll::Pending
@@ -336,7 +340,7 @@ where
336 M: Mutex<Data = ()>, 340 M: Mutex<Data = ()>,
337{ 341{
338 fn drop(&mut self) { 342 fn drop(&mut self) {
339 self.channel.lock(|c| c.deregister_sender()) 343 Channel::lock(self.channel_cell, |c| c.deregister_sender())
340 } 344 }
341} 345}
342 346
@@ -346,9 +350,9 @@ where
346{ 350{
347 #[allow(clippy::clone_double_ref)] 351 #[allow(clippy::clone_double_ref)]
348 fn clone(&self) -> Self { 352 fn clone(&self) -> Self {
349 self.channel.lock(|c| c.register_sender()); 353 Channel::lock(self.channel_cell, |c| c.register_sender());
350 Sender { 354 Sender {
351 channel: self.channel.clone(), 355 channel_cell: self.channel_cell.clone(),
352 } 356 }
353 } 357 }
354} 358}
@@ -564,6 +568,7 @@ where
564 M: Mutex<Data = ()>, 568 M: Mutex<Data = ()>,
565{ 569{
566 channel_cell: UnsafeCell<ChannelCell<M, T, N>>, 570 channel_cell: UnsafeCell<ChannelCell<M, T, N>>,
571 receiver_consumed: PhantomData<()>,
567} 572}
568 573
569struct ChannelCell<M, T, const N: usize> 574struct ChannelCell<M, T, const N: usize>
@@ -588,7 +593,7 @@ impl<T, const N: usize> Channel<WithCriticalSections, T, N> {
588 /// // Declare a bounded channel of 3 u32s. 593 /// // Declare a bounded channel of 3 u32s.
589 /// let mut channel = Channel::<WithCriticalSections, u32, 3>::with_critical_sections(); 594 /// let mut channel = Channel::<WithCriticalSections, u32, 3>::with_critical_sections();
590 /// // once we have a channel, obtain its sender and receiver 595 /// // once we have a channel, obtain its sender and receiver
591 /// let (sender, receiver) = mpsc::split(&channel); 596 /// let (sender, receiver) = mpsc::split(&mut channel);
592 /// ``` 597 /// ```
593 pub const fn with_critical_sections() -> Self { 598 pub const fn with_critical_sections() -> Self {
594 let mutex = CriticalSectionMutex::new(()); 599 let mutex = CriticalSectionMutex::new(());
@@ -596,6 +601,7 @@ impl<T, const N: usize> Channel<WithCriticalSections, T, N> {
596 let channel_cell = ChannelCell { mutex, state }; 601 let channel_cell = ChannelCell { mutex, state };
597 Channel { 602 Channel {
598 channel_cell: UnsafeCell::new(channel_cell), 603 channel_cell: UnsafeCell::new(channel_cell),
604 receiver_consumed: PhantomData,
599 } 605 }
600 } 606 }
601} 607}
@@ -615,7 +621,7 @@ impl<T, const N: usize> Channel<WithThreadModeOnly, T, N> {
615 /// // Declare a bounded channel of 3 u32s. 621 /// // Declare a bounded channel of 3 u32s.
616 /// let mut channel = Channel::<WithThreadModeOnly, u32, 3>::with_thread_mode_only(); 622 /// let mut channel = Channel::<WithThreadModeOnly, u32, 3>::with_thread_mode_only();
617 /// // once we have a channel, obtain its sender and receiver 623 /// // once we have a channel, obtain its sender and receiver
618 /// let (sender, receiver) = mpsc::split(&channel); 624 /// let (sender, receiver) = mpsc::split(&mut channel);
619 /// ``` 625 /// ```
620 pub const fn with_thread_mode_only() -> Self { 626 pub const fn with_thread_mode_only() -> Self {
621 let mutex = ThreadModeMutex::new(()); 627 let mutex = ThreadModeMutex::new(());
@@ -623,6 +629,7 @@ impl<T, const N: usize> Channel<WithThreadModeOnly, T, N> {
623 let channel_cell = ChannelCell { mutex, state }; 629 let channel_cell = ChannelCell { mutex, state };
624 Channel { 630 Channel {
625 channel_cell: UnsafeCell::new(channel_cell), 631 channel_cell: UnsafeCell::new(channel_cell),
632 receiver_consumed: PhantomData,
626 } 633 }
627 } 634 }
628} 635}
@@ -639,7 +646,7 @@ impl<T, const N: usize> Channel<WithNoThreads, T, N> {
639 /// // Declare a bounded channel of 3 u32s. 646 /// // Declare a bounded channel of 3 u32s.
640 /// let mut channel = Channel::<WithNoThreads, u32, 3>::with_no_threads(); 647 /// let mut channel = Channel::<WithNoThreads, u32, 3>::with_no_threads();
641 /// // once we have a channel, obtain its sender and receiver 648 /// // once we have a channel, obtain its sender and receiver
642 /// let (sender, receiver) = mpsc::split(&channel); 649 /// let (sender, receiver) = mpsc::split(&mut channel);
643 /// ``` 650 /// ```
644 pub const fn with_no_threads() -> Self { 651 pub const fn with_no_threads() -> Self {
645 let mutex = NoopMutex::new(()); 652 let mutex = NoopMutex::new(());
@@ -647,6 +654,7 @@ impl<T, const N: usize> Channel<WithNoThreads, T, N> {
647 let channel_cell = ChannelCell { mutex, state }; 654 let channel_cell = ChannelCell { mutex, state };
648 Channel { 655 Channel {
649 channel_cell: UnsafeCell::new(channel_cell), 656 channel_cell: UnsafeCell::new(channel_cell),
657 receiver_consumed: PhantomData,
650 } 658 }
651 } 659 }
652} 660}
@@ -655,9 +663,12 @@ impl<M, T, const N: usize> Channel<M, T, N>
655where 663where
656 M: Mutex<Data = ()>, 664 M: Mutex<Data = ()>,
657{ 665{
658 fn lock<R>(&self, f: impl FnOnce(&mut ChannelState<T, N>) -> R) -> R { 666 fn lock<R>(
667 channel_cell: &UnsafeCell<ChannelCell<M, T, N>>,
668 f: impl FnOnce(&mut ChannelState<T, N>) -> R,
669 ) -> R {
659 unsafe { 670 unsafe {
660 let channel_cell = &mut *(self.channel_cell.get()); 671 let channel_cell = &mut *(channel_cell.get());
661 let mutex = &mut channel_cell.mutex; 672 let mutex = &mut channel_cell.mutex;
662 let mut state = &mut channel_cell.state; 673 let mut state = &mut channel_cell.state;
663 mutex.lock(|_| f(&mut state)) 674 mutex.lock(|_| f(&mut state))
@@ -747,16 +758,16 @@ mod tests {
747 758
748 #[test] 759 #[test]
749 fn simple_send_and_receive() { 760 fn simple_send_and_receive() {
750 let c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); 761 let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads();
751 let (s, r) = split(&c); 762 let (s, r) = split(&mut c);
752 assert!(s.clone().try_send(1).is_ok()); 763 assert!(s.clone().try_send(1).is_ok());
753 assert_eq!(r.try_recv().unwrap(), 1); 764 assert_eq!(r.try_recv().unwrap(), 1);
754 } 765 }
755 766
756 #[test] 767 #[test]
757 fn should_close_without_sender() { 768 fn should_close_without_sender() {
758 let c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); 769 let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads();
759 let (s, r) = split(&c); 770 let (s, r) = split(&mut c);
760 drop(s); 771 drop(s);
761 match r.try_recv() { 772 match r.try_recv() {
762 Err(TryRecvError::Closed) => assert!(true), 773 Err(TryRecvError::Closed) => assert!(true),
@@ -766,8 +777,8 @@ mod tests {
766 777
767 #[test] 778 #[test]
768 fn should_close_once_drained() { 779 fn should_close_once_drained() {
769 let c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); 780 let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads();
770 let (s, r) = split(&c); 781 let (s, r) = split(&mut c);
771 assert!(s.try_send(1).is_ok()); 782 assert!(s.try_send(1).is_ok());
772 drop(s); 783 drop(s);
773 assert_eq!(r.try_recv().unwrap(), 1); 784 assert_eq!(r.try_recv().unwrap(), 1);
@@ -779,8 +790,8 @@ mod tests {
779 790
780 #[test] 791 #[test]
781 fn should_reject_send_when_receiver_dropped() { 792 fn should_reject_send_when_receiver_dropped() {
782 let c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); 793 let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads();
783 let (s, r) = split(&c); 794 let (s, r) = split(&mut c);
784 drop(r); 795 drop(r);
785 match s.try_send(1) { 796 match s.try_send(1) {
786 Err(TrySendError::Closed(1)) => assert!(true), 797 Err(TrySendError::Closed(1)) => assert!(true),
@@ -790,8 +801,8 @@ mod tests {
790 801
791 #[test] 802 #[test]
792 fn should_reject_send_when_channel_closed() { 803 fn should_reject_send_when_channel_closed() {
793 let c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); 804 let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads();
794 let (s, mut r) = split(&c); 805 let (s, mut r) = split(&mut c);
795 assert!(s.try_send(1).is_ok()); 806 assert!(s.try_send(1).is_ok());
796 r.close(); 807 r.close();
797 assert_eq!(r.try_recv().unwrap(), 1); 808 assert_eq!(r.try_recv().unwrap(), 1);
@@ -808,7 +819,7 @@ mod tests {
808 819
809 static mut CHANNEL: Channel<WithCriticalSections, u32, 3> = 820 static mut CHANNEL: Channel<WithCriticalSections, u32, 3> =
810 Channel::with_critical_sections(); 821 Channel::with_critical_sections();
811 let (s, mut r) = split(unsafe { &CHANNEL }); 822 let (s, mut r) = split(unsafe { &mut CHANNEL });
812 assert!(executor 823 assert!(executor
813 .spawn(async move { 824 .spawn(async move {
814 drop(s); 825 drop(s);
@@ -823,7 +834,7 @@ mod tests {
823 834
824 static mut CHANNEL: Channel<WithCriticalSections, u32, 3> = 835 static mut CHANNEL: Channel<WithCriticalSections, u32, 3> =
825 Channel::with_critical_sections(); 836 Channel::with_critical_sections();
826 let (s, mut r) = split(unsafe { &CHANNEL }); 837 let (s, mut r) = split(unsafe { &mut CHANNEL });
827 assert!(executor 838 assert!(executor
828 .spawn(async move { 839 .spawn(async move {
829 assert!(s.try_send(1).is_ok()); 840 assert!(s.try_send(1).is_ok());
@@ -836,7 +847,7 @@ mod tests {
836 async fn sender_send_completes_if_capacity() { 847 async fn sender_send_completes_if_capacity() {
837 static mut CHANNEL: Channel<WithCriticalSections, u32, 1> = 848 static mut CHANNEL: Channel<WithCriticalSections, u32, 1> =
838 Channel::with_critical_sections(); 849 Channel::with_critical_sections();
839 let (s, mut r) = split(unsafe { &CHANNEL }); 850 let (s, mut r) = split(unsafe { &mut CHANNEL });
840 assert!(s.send(1).await.is_ok()); 851 assert!(s.send(1).await.is_ok());
841 assert_eq!(r.recv().await, Some(1)); 852 assert_eq!(r.recv().await, Some(1));
842 } 853 }
@@ -845,7 +856,7 @@ mod tests {
845 async fn sender_send_completes_if_closed() { 856 async fn sender_send_completes_if_closed() {
846 static mut CHANNEL: Channel<WithCriticalSections, u32, 1> = 857 static mut CHANNEL: Channel<WithCriticalSections, u32, 1> =
847 Channel::with_critical_sections(); 858 Channel::with_critical_sections();
848 let (s, r) = split(unsafe { &CHANNEL }); 859 let (s, r) = split(unsafe { &mut CHANNEL });
849 drop(r); 860 drop(r);
850 match s.send(1).await { 861 match s.send(1).await {
851 Err(SendError(1)) => assert!(true), 862 Err(SendError(1)) => assert!(true),
@@ -859,7 +870,7 @@ mod tests {
859 870
860 static mut CHANNEL: Channel<WithCriticalSections, u32, 1> = 871 static mut CHANNEL: Channel<WithCriticalSections, u32, 1> =
861 Channel::with_critical_sections(); 872 Channel::with_critical_sections();
862 let (s0, mut r) = split(unsafe { &CHANNEL }); 873 let (s0, mut r) = split(unsafe { &mut CHANNEL });
863 assert!(s0.try_send(1).is_ok()); 874 assert!(s0.try_send(1).is_ok());
864 let s1 = s0.clone(); 875 let s1 = s0.clone();
865 let send_task_1 = executor.spawn_with_handle(async move { s0.send(2).await }); 876 let send_task_1 = executor.spawn_with_handle(async move { s0.send(2).await });
@@ -879,7 +890,7 @@ mod tests {
879 async fn sender_close_completes_if_closing() { 890 async fn sender_close_completes_if_closing() {
880 static mut CHANNEL: Channel<WithCriticalSections, u32, 1> = 891 static mut CHANNEL: Channel<WithCriticalSections, u32, 1> =
881 Channel::with_critical_sections(); 892 Channel::with_critical_sections();
882 let (s, mut r) = split(unsafe { &CHANNEL }); 893 let (s, mut r) = split(unsafe { &mut CHANNEL });
883 r.close(); 894 r.close();
884 s.closed().await; 895 s.closed().await;
885 } 896 }
@@ -888,7 +899,7 @@ mod tests {
888 async fn sender_close_completes_if_closed() { 899 async fn sender_close_completes_if_closed() {
889 static mut CHANNEL: Channel<WithCriticalSections, u32, 1> = 900 static mut CHANNEL: Channel<WithCriticalSections, u32, 1> =
890 Channel::with_critical_sections(); 901 Channel::with_critical_sections();
891 let (s, r) = split(unsafe { &CHANNEL }); 902 let (s, r) = split(unsafe { &mut CHANNEL });
892 drop(r); 903 drop(r);
893 s.closed().await; 904 s.closed().await;
894 } 905 }