aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorhuntc <[email protected]>2021-07-14 16:34:32 +1000
committerhuntc <[email protected]>2021-07-15 12:31:53 +1000
commitd711e8a82cef7ac26191e330aa4bd7cfebd570be (patch)
tree1f2ebfdde872dd62ff5e4f5b44a939f418c87b75
parentbabee7f32a4919957836a002e2c971aac368bfab (diff)
Eliminates unsoundness by using an UnsafeCell for sharing the channel
-rw-r--r--embassy/src/util/mpsc.rs348
1 files changed, 174 insertions, 174 deletions
diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs
index b30e41318..c409161f8 100644
--- a/embassy/src/util/mpsc.rs
+++ b/embassy/src/util/mpsc.rs
@@ -122,11 +122,10 @@ where
122{ 122{
123 let sender = Sender { channel: &channel }; 123 let sender = Sender { channel: &channel };
124 let receiver = Receiver { channel: &channel }; 124 let receiver = Receiver { channel: &channel };
125 { 125 channel.lock(|c| {
126 let c = channel.get();
127 c.register_receiver(); 126 c.register_receiver();
128 c.register_sender(); 127 c.register_sender();
129 } 128 });
130 (sender, receiver) 129 (sender, receiver)
131} 130}
132 131
@@ -155,11 +154,12 @@ where
155 } 154 }
156 155
157 fn recv_poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { 156 fn recv_poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
158 match self.channel.get().try_recv_with_context(Some(cx)) { 157 self.channel
159 Ok(v) => Poll::Ready(Some(v)), 158 .lock(|c| match c.try_recv_with_context(Some(cx)) {
160 Err(TryRecvError::Closed) => Poll::Ready(None), 159 Ok(v) => Poll::Ready(Some(v)),
161 Err(TryRecvError::Empty) => Poll::Pending, 160 Err(TryRecvError::Closed) => Poll::Ready(None),
162 } 161 Err(TryRecvError::Empty) => Poll::Pending,
162 })
163 } 163 }
164 164
165 /// Attempts to immediately receive a message on this `Receiver` 165 /// Attempts to immediately receive a message on this `Receiver`
@@ -167,7 +167,7 @@ where
167 /// This method will either receive a message from the channel immediately or return an error 167 /// This method will either receive a message from the channel immediately or return an error
168 /// if the channel is empty. 168 /// if the channel is empty.
169 pub fn try_recv(&self) -> Result<T, TryRecvError> { 169 pub fn try_recv(&self) -> Result<T, TryRecvError> {
170 self.channel.get().try_recv() 170 self.channel.lock(|c| c.try_recv())
171 } 171 }
172 172
173 /// Closes the receiving half of a channel without dropping it. 173 /// Closes the receiving half of a channel without dropping it.
@@ -181,7 +181,7 @@ where
181 /// until those are released. 181 /// until those are released.
182 /// 182 ///
183 pub fn close(&mut self) { 183 pub fn close(&mut self) {
184 self.channel.get().close() 184 self.channel.lock(|c| c.close())
185 } 185 }
186} 186}
187 187
@@ -190,7 +190,7 @@ where
190 M: Mutex<Data = ()>, 190 M: Mutex<Data = ()>,
191{ 191{
192 fn drop(&mut self) { 192 fn drop(&mut self) {
193 self.channel.get().deregister_receiver() 193 self.channel.lock(|c| c.deregister_receiver())
194 } 194 }
195} 195}
196 196
@@ -245,7 +245,7 @@ where
245 /// [`channel`]: channel 245 /// [`channel`]: channel
246 /// [`close`]: Receiver::close 246 /// [`close`]: Receiver::close
247 pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> { 247 pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
248 self.channel.get().try_send(message) 248 self.channel.lock(|c| c.try_send(message))
249 } 249 }
250 250
251 /// Completes when the receiver has dropped. 251 /// Completes when the receiver has dropped.
@@ -266,7 +266,7 @@ where
266 /// [`Receiver`]: crate::sync::mpsc::Receiver 266 /// [`Receiver`]: crate::sync::mpsc::Receiver
267 /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close 267 /// [`Receiver::close`]: crate::sync::mpsc::Receiver::close
268 pub fn is_closed(&self) -> bool { 268 pub fn is_closed(&self) -> bool {
269 self.channel.get().is_closed() 269 self.channel.lock(|c| c.is_closed())
270 } 270 }
271} 271}
272 272
@@ -286,7 +286,11 @@ where
286 286
287 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 287 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
288 match self.message.take() { 288 match self.message.take() {
289 Some(m) => match self.sender.channel.get().try_send_with_context(m, Some(cx)) { 289 Some(m) => match self
290 .sender
291 .channel
292 .lock(|c| c.try_send_with_context(m, Some(cx)))
293 {
290 Ok(..) => Poll::Ready(Ok(())), 294 Ok(..) => Poll::Ready(Ok(())),
291 Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), 295 Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))),
292 Err(TrySendError::Full(m)) => { 296 Err(TrySendError::Full(m)) => {
@@ -315,7 +319,11 @@ where
315 type Output = (); 319 type Output = ();
316 320
317 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 321 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
318 if self.sender.channel.get().is_closed_with_context(Some(cx)) { 322 if self
323 .sender
324 .channel
325 .lock(|c| c.is_closed_with_context(Some(cx)))
326 {
319 Poll::Ready(()) 327 Poll::Ready(())
320 } else { 328 } else {
321 Poll::Pending 329 Poll::Pending
@@ -328,7 +336,7 @@ where
328 M: Mutex<Data = ()>, 336 M: Mutex<Data = ()>,
329{ 337{
330 fn drop(&mut self) { 338 fn drop(&mut self) {
331 self.channel.get().deregister_sender() 339 self.channel.lock(|c| c.deregister_sender())
332 } 340 }
333} 341}
334 342
@@ -338,7 +346,7 @@ where
338{ 346{
339 #[allow(clippy::clone_double_ref)] 347 #[allow(clippy::clone_double_ref)]
340 fn clone(&self) -> Self { 348 fn clone(&self) -> Self {
341 self.channel.get().register_sender(); 349 self.channel.lock(|c| c.register_sender());
342 Sender { 350 Sender {
343 channel: self.channel.clone(), 351 channel: self.channel.clone(),
344 } 352 }
@@ -421,6 +429,116 @@ impl<T, const N: usize> ChannelState<T, N> {
421 senders_waker: WakerRegistration::new(), 429 senders_waker: WakerRegistration::new(),
422 } 430 }
423 } 431 }
432
433 fn try_recv(&mut self) -> Result<T, TryRecvError> {
434 self.try_recv_with_context(None)
435 }
436
437 fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result<T, TryRecvError> {
438 if self.read_pos != self.write_pos || self.full {
439 if self.full {
440 self.full = false;
441 self.senders_waker.wake();
442 }
443 let message = unsafe { (self.buf[self.read_pos]).assume_init_mut().get().read() };
444 self.read_pos = (self.read_pos + 1) % self.buf.len();
445 Ok(message)
446 } else if !self.closed {
447 cx.into_iter()
448 .for_each(|cx| self.set_receiver_waker(&cx.waker()));
449 Err(TryRecvError::Empty)
450 } else {
451 Err(TryRecvError::Closed)
452 }
453 }
454
455 fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> {
456 self.try_send_with_context(message, None)
457 }
458
459 fn try_send_with_context(
460 &mut self,
461 message: T,
462 cx: Option<&mut Context<'_>>,
463 ) -> Result<(), TrySendError<T>> {
464 if !self.closed {
465 if !self.full {
466 self.buf[self.write_pos] = MaybeUninit::new(message.into());
467 self.write_pos = (self.write_pos + 1) % self.buf.len();
468 if self.write_pos == self.read_pos {
469 self.full = true;
470 }
471 self.receiver_waker.wake();
472 Ok(())
473 } else {
474 cx.into_iter()
475 .for_each(|cx| self.set_senders_waker(&cx.waker()));
476 Err(TrySendError::Full(message))
477 }
478 } else {
479 Err(TrySendError::Closed(message))
480 }
481 }
482
483 fn close(&mut self) {
484 self.receiver_waker.wake();
485 self.closed = true;
486 }
487
488 fn is_closed(&mut self) -> bool {
489 self.is_closed_with_context(None)
490 }
491
492 fn is_closed_with_context(&mut self, cx: Option<&mut Context<'_>>) -> bool {
493 if self.closed {
494 cx.into_iter()
495 .for_each(|cx| self.set_senders_waker(&cx.waker()));
496 true
497 } else {
498 false
499 }
500 }
501
502 fn register_receiver(&mut self) {
503 assert!(!self.receiver_registered);
504 self.receiver_registered = true;
505 }
506
507 fn deregister_receiver(&mut self) {
508 if self.receiver_registered {
509 self.closed = true;
510 self.senders_waker.wake();
511 }
512 self.receiver_registered = false;
513 }
514
515 fn register_sender(&mut self) {
516 self.senders_registered += 1;
517 }
518
519 fn deregister_sender(&mut self) {
520 assert!(self.senders_registered > 0);
521 self.senders_registered -= 1;
522 if self.senders_registered == 0 {
523 self.receiver_waker.wake();
524 self.closed = true;
525 }
526 }
527
528 fn set_receiver_waker(&mut self, receiver_waker: &Waker) {
529 self.receiver_waker.register(receiver_waker);
530 }
531
532 fn set_senders_waker(&mut self, senders_waker: &Waker) {
533 // Dispose of any existing sender causing them to be polled again.
534 // This could cause a spin given multiple concurrent senders, however given that
535 // most sends only block waiting for the receiver to become active, this should
536 // be a short-lived activity. The upside is a greatly simplified implementation
537 // that avoids the need for intrusive linked-lists and unsafe operations on pinned
538 // pointers.
539 self.senders_waker.wake();
540 self.senders_waker.register(senders_waker);
541 }
424} 542}
425 543
426impl<T, const N: usize> Drop for ChannelState<T, N> { 544impl<T, const N: usize> Drop for ChannelState<T, N> {
@@ -445,6 +563,13 @@ pub struct Channel<M, T, const N: usize>
445where 563where
446 M: Mutex<Data = ()>, 564 M: Mutex<Data = ()>,
447{ 565{
566 sync_channel: UnsafeCell<ChannelCell<M, T, N>>,
567}
568
569struct ChannelCell<M, T, const N: usize>
570where
571 M: Mutex<Data = ()>,
572{
448 mutex: M, 573 mutex: M,
449 state: ChannelState<T, N>, 574 state: ChannelState<T, N>,
450} 575}
@@ -468,7 +593,10 @@ impl<T, const N: usize> Channel<WithCriticalSections, T, N> {
468 pub const fn with_critical_sections() -> Self { 593 pub const fn with_critical_sections() -> Self {
469 let mutex = CriticalSectionMutex::new(()); 594 let mutex = CriticalSectionMutex::new(());
470 let state = ChannelState::new(); 595 let state = ChannelState::new();
471 Channel { mutex, state } 596 let sync_channel = ChannelCell { mutex, state };
597 Channel {
598 sync_channel: UnsafeCell::new(sync_channel),
599 }
472 } 600 }
473} 601}
474 602
@@ -492,7 +620,10 @@ impl<T, const N: usize> Channel<WithThreadModeOnly, T, N> {
492 pub const fn with_thread_mode_only() -> Self { 620 pub const fn with_thread_mode_only() -> Self {
493 let mutex = ThreadModeMutex::new(()); 621 let mutex = ThreadModeMutex::new(());
494 let state = ChannelState::new(); 622 let state = ChannelState::new();
495 Channel { mutex, state } 623 let sync_channel = ChannelCell { mutex, state };
624 Channel {
625 sync_channel: UnsafeCell::new(sync_channel),
626 }
496 } 627 }
497} 628}
498 629
@@ -513,7 +644,10 @@ impl<T, const N: usize> Channel<WithNoThreads, T, N> {
513 pub const fn with_no_threads() -> Self { 644 pub const fn with_no_threads() -> Self {
514 let mutex = NoopMutex::new(()); 645 let mutex = NoopMutex::new(());
515 let state = ChannelState::new(); 646 let state = ChannelState::new();
516 Channel { mutex, state } 647 let sync_channel = ChannelCell { mutex, state };
648 Channel {
649 sync_channel: UnsafeCell::new(sync_channel),
650 }
517 } 651 }
518} 652}
519 653
@@ -521,144 +655,13 @@ impl<M, T, const N: usize> Channel<M, T, N>
521where 655where
522 M: Mutex<Data = ()>, 656 M: Mutex<Data = ()>,
523{ 657{
524 fn get(&self) -> &mut Self { 658 fn lock<R>(&self, f: impl FnOnce(&mut ChannelState<T, N>) -> R) -> R {
525 let const_ptr = self as *const Self; 659 unsafe {
526 let mut_ptr = const_ptr as *mut Self; 660 let sync_channel = &mut *(self.sync_channel.get());
527 unsafe { &mut *mut_ptr } 661 let mutex = &mut sync_channel.mutex;
528 } 662 let mut state = &mut sync_channel.state;
529 663 mutex.lock(|_| f(&mut state))
530 fn try_recv(&mut self) -> Result<T, TryRecvError> { 664 }
531 self.try_recv_with_context(None)
532 }
533
534 fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result<T, TryRecvError> {
535 let mut state = &mut self.state;
536 self.mutex.lock(|_| {
537 if state.read_pos != state.write_pos || state.full {
538 if state.full {
539 state.full = false;
540 state.senders_waker.wake();
541 }
542 let message = unsafe { (state.buf[state.read_pos]).assume_init_mut().get().read() };
543 state.read_pos = (state.read_pos + 1) % state.buf.len();
544 Ok(message)
545 } else if !state.closed {
546 cx.into_iter()
547 .for_each(|cx| Self::set_receiver_waker(&mut state, &cx.waker()));
548 Err(TryRecvError::Empty)
549 } else {
550 Err(TryRecvError::Closed)
551 }
552 })
553 }
554
555 fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> {
556 self.try_send_with_context(message, None)
557 }
558
559 fn try_send_with_context(
560 &mut self,
561 message: T,
562 cx: Option<&mut Context<'_>>,
563 ) -> Result<(), TrySendError<T>> {
564 let mut state = &mut self.state;
565 self.mutex.lock(|_| {
566 if !state.closed {
567 if !state.full {
568 state.buf[state.write_pos] = MaybeUninit::new(message.into());
569 state.write_pos = (state.write_pos + 1) % state.buf.len();
570 if state.write_pos == state.read_pos {
571 state.full = true;
572 }
573 state.receiver_waker.wake();
574 Ok(())
575 } else {
576 cx.into_iter()
577 .for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker()));
578 Err(TrySendError::Full(message))
579 }
580 } else {
581 Err(TrySendError::Closed(message))
582 }
583 })
584 }
585
586 fn close(&mut self) {
587 let state = &mut self.state;
588 self.mutex.lock(|_| {
589 state.receiver_waker.wake();
590 state.closed = true;
591 });
592 }
593
594 fn is_closed(&mut self) -> bool {
595 self.is_closed_with_context(None)
596 }
597
598 fn is_closed_with_context(&mut self, cx: Option<&mut Context<'_>>) -> bool {
599 let mut state = &mut self.state;
600 self.mutex.lock(|_| {
601 if state.closed {
602 cx.into_iter()
603 .for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker()));
604 true
605 } else {
606 false
607 }
608 })
609 }
610
611 fn register_receiver(&mut self) {
612 let state = &mut self.state;
613 self.mutex.lock(|_| {
614 assert!(!state.receiver_registered);
615 state.receiver_registered = true;
616 });
617 }
618
619 fn deregister_receiver(&mut self) {
620 let state = &mut self.state;
621 self.mutex.lock(|_| {
622 if state.receiver_registered {
623 state.closed = true;
624 state.senders_waker.wake();
625 }
626 state.receiver_registered = false;
627 })
628 }
629
630 fn register_sender(&mut self) {
631 let state = &mut self.state;
632 self.mutex.lock(|_| {
633 state.senders_registered += 1;
634 })
635 }
636
637 fn deregister_sender(&mut self) {
638 let state = &mut self.state;
639 self.mutex.lock(|_| {
640 assert!(state.senders_registered > 0);
641 state.senders_registered -= 1;
642 if state.senders_registered == 0 {
643 state.receiver_waker.wake();
644 state.closed = true;
645 }
646 })
647 }
648
649 fn set_receiver_waker(state: &mut ChannelState<T, N>, receiver_waker: &Waker) {
650 state.receiver_waker.register(receiver_waker);
651 }
652
653 fn set_senders_waker(state: &mut ChannelState<T, N>, senders_waker: &Waker) {
654 // Dispose of any existing sender causing them to be polled again.
655 // This could cause a spin given multiple concurrent senders, however given that
656 // most sends only block waiting for the receiver to become active, this should
657 // be a short-lived activity. The upside is a greatly simplified implementation
658 // that avoids the need for intrusive linked-lists and unsafe operations on pinned
659 // pointers.
660 state.senders_waker.wake();
661 state.senders_waker.register(senders_waker);
662 } 665 }
663} 666}
664 667
@@ -672,15 +675,12 @@ mod tests {
672 675
673 use super::*; 676 use super::*;
674 677
675 fn capacity<M, T, const N: usize>(c: &Channel<M, T, N>) -> usize 678 fn capacity<T, const N: usize>(c: &ChannelState<T, N>) -> usize {
676 where 679 if !c.full {
677 M: Mutex<Data = ()>, 680 if c.write_pos > c.read_pos {
678 { 681 (c.buf.len() - c.write_pos) + c.read_pos
679 if !c.state.full {
680 if c.state.write_pos > c.state.read_pos {
681 (c.state.buf.len() - c.state.write_pos) + c.state.read_pos
682 } else { 682 } else {
683 (c.state.buf.len() - c.state.read_pos) + c.state.write_pos 683 (c.buf.len() - c.read_pos) + c.write_pos
684 } 684 }
685 } else { 685 } else {
686 0 686 0
@@ -689,14 +689,14 @@ mod tests {
689 689
690 #[test] 690 #[test]
691 fn sending_once() { 691 fn sending_once() {
692 let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); 692 let mut c = ChannelState::<u32, 3>::new();
693 assert!(c.try_send(1).is_ok()); 693 assert!(c.try_send(1).is_ok());
694 assert_eq!(capacity(&c), 2); 694 assert_eq!(capacity(&c), 2);
695 } 695 }
696 696
697 #[test] 697 #[test]
698 fn sending_when_full() { 698 fn sending_when_full() {
699 let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); 699 let mut c = ChannelState::<u32, 3>::new();
700 let _ = c.try_send(1); 700 let _ = c.try_send(1);
701 let _ = c.try_send(1); 701 let _ = c.try_send(1);
702 let _ = c.try_send(1); 702 let _ = c.try_send(1);
@@ -709,8 +709,8 @@ mod tests {
709 709
710 #[test] 710 #[test]
711 fn sending_when_closed() { 711 fn sending_when_closed() {
712 let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); 712 let mut c = ChannelState::<u32, 3>::new();
713 c.state.closed = true; 713 c.closed = true;
714 match c.try_send(2) { 714 match c.try_send(2) {
715 Err(TrySendError::Closed(2)) => assert!(true), 715 Err(TrySendError::Closed(2)) => assert!(true),
716 _ => assert!(false), 716 _ => assert!(false),
@@ -719,7 +719,7 @@ mod tests {
719 719
720 #[test] 720 #[test]
721 fn receiving_once_with_one_send() { 721 fn receiving_once_with_one_send() {
722 let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); 722 let mut c = ChannelState::<u32, 3>::new();
723 assert!(c.try_send(1).is_ok()); 723 assert!(c.try_send(1).is_ok());
724 assert_eq!(c.try_recv().unwrap(), 1); 724 assert_eq!(c.try_recv().unwrap(), 1);
725 assert_eq!(capacity(&c), 3); 725 assert_eq!(capacity(&c), 3);
@@ -727,7 +727,7 @@ mod tests {
727 727
728 #[test] 728 #[test]
729 fn receiving_when_empty() { 729 fn receiving_when_empty() {
730 let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); 730 let mut c = ChannelState::<u32, 3>::new();
731 match c.try_recv() { 731 match c.try_recv() {
732 Err(TryRecvError::Empty) => assert!(true), 732 Err(TryRecvError::Empty) => assert!(true),
733 _ => assert!(false), 733 _ => assert!(false),
@@ -737,8 +737,8 @@ mod tests {
737 737
738 #[test] 738 #[test]
739 fn receiving_when_closed() { 739 fn receiving_when_closed() {
740 let mut c = Channel::<WithNoThreads, u32, 3>::with_no_threads(); 740 let mut c = ChannelState::<u32, 3>::new();
741 c.state.closed = true; 741 c.closed = true;
742 match c.try_recv() { 742 match c.try_recv() {
743 Err(TryRecvError::Closed) => assert!(true), 743 Err(TryRecvError::Closed) => assert!(true),
744 _ => assert!(false), 744 _ => assert!(false),