aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--embassy/src/util/mpsc.rs84
1 files changed, 53 insertions, 31 deletions
diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs
index 68fcdf7f9..8d534dc49 100644
--- a/embassy/src/util/mpsc.rs
+++ b/embassy/src/util/mpsc.rs
@@ -145,14 +145,11 @@ where
145 futures::future::poll_fn(|cx| self.recv_poll(cx)).await 145 futures::future::poll_fn(|cx| self.recv_poll(cx)).await
146 } 146 }
147 147
148 fn recv_poll(self: &mut Self, cx: &mut Context<'_>) -> Poll<Option<T>> { 148 fn recv_poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
149 match self.try_recv() { 149 match self.channel.get().try_recv_with_context(Some(cx)) {
150 Ok(v) => Poll::Ready(Some(v)), 150 Ok(v) => Poll::Ready(Some(v)),
151 Err(TryRecvError::Closed) => Poll::Ready(None), 151 Err(TryRecvError::Closed) => Poll::Ready(None),
152 Err(TryRecvError::Empty) => { 152 Err(TryRecvError::Empty) => Poll::Pending,
153 self.channel.get().set_receiver_waker(&cx.waker());
154 Poll::Pending
155 }
156 } 153 }
157 } 154 }
158 155
@@ -279,11 +276,15 @@ where
279 type Output = Result<(), SendError<T>>; 276 type Output = Result<(), SendError<T>>;
280 277
281 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 278 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
282 match self.sender.try_send(unsafe { self.message.get().read() }) { 279 match self
280 .sender
281 .channel
282 .get()
283 .try_send_with_context(unsafe { self.message.get().read() }, Some(cx))
284 {
283 Ok(..) => Poll::Ready(Ok(())), 285 Ok(..) => Poll::Ready(Ok(())),
284 Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), 286 Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))),
285 Err(TrySendError::Full(..)) => { 287 Err(TrySendError::Full(..)) => {
286 self.sender.channel.get().set_senders_waker(&cx.waker());
287 Poll::Pending 288 Poll::Pending
288 // Note we leave the existing UnsafeCell contents - they still 289 // Note we leave the existing UnsafeCell contents - they still
289 // contain the original message. We could create another UnsafeCell 290 // contain the original message. We could create another UnsafeCell
@@ -307,10 +308,9 @@ where
307 type Output = (); 308 type Output = ();
308 309
309 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 310 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
310 if self.sender.is_closed() { 311 if self.sender.channel.get().is_closed_with_context(Some(cx)) {
311 Poll::Ready(()) 312 Poll::Ready(())
312 } else { 313 } else {
313 self.sender.channel.get().set_senders_waker(&cx.waker());
314 Poll::Pending 314 Poll::Pending
315 } 315 }
316 } 316 }
@@ -513,7 +513,11 @@ where
513 } 513 }
514 514
515 fn try_recv(&mut self) -> Result<T, TryRecvError> { 515 fn try_recv(&mut self) -> Result<T, TryRecvError> {
516 let state = &mut self.state; 516 self.try_recv_with_context(None)
517 }
518
519 fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result<T, TryRecvError> {
520 let mut state = &mut self.state;
517 self.mutex.lock(|_| { 521 self.mutex.lock(|_| {
518 if !state.closed { 522 if !state.closed {
519 if state.read_pos != state.write_pos || state.full { 523 if state.read_pos != state.write_pos || state.full {
@@ -526,6 +530,8 @@ where
526 state.read_pos = (state.read_pos + 1) % state.buf.len(); 530 state.read_pos = (state.read_pos + 1) % state.buf.len();
527 Ok(message) 531 Ok(message)
528 } else if !state.closing { 532 } else if !state.closing {
533 cx.into_iter()
534 .for_each(|cx| Self::set_receiver_waker(&mut state, &cx.waker()));
529 Err(TryRecvError::Empty) 535 Err(TryRecvError::Empty)
530 } else { 536 } else {
531 state.closed = true; 537 state.closed = true;
@@ -539,7 +545,15 @@ where
539 } 545 }
540 546
541 fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> { 547 fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> {
542 let state = &mut self.state; 548 self.try_send_with_context(message, None)
549 }
550
551 fn try_send_with_context(
552 &mut self,
553 message: T,
554 cx: Option<&mut Context<'_>>,
555 ) -> Result<(), TrySendError<T>> {
556 let mut state = &mut self.state;
543 self.mutex.lock(|_| { 557 self.mutex.lock(|_| {
544 if !state.closed { 558 if !state.closed {
545 if !state.full { 559 if !state.full {
@@ -551,6 +565,8 @@ where
551 state.receiver_waker.wake(); 565 state.receiver_waker.wake();
552 Ok(()) 566 Ok(())
553 } else { 567 } else {
568 cx.into_iter()
569 .for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker()));
554 Err(TrySendError::Full(message)) 570 Err(TrySendError::Full(message))
555 } 571 }
556 } else { 572 } else {
@@ -568,8 +584,20 @@ where
568 } 584 }
569 585
570 fn is_closed(&mut self) -> bool { 586 fn is_closed(&mut self) -> bool {
571 let state = &self.state; 587 self.is_closed_with_context(None)
572 self.mutex.lock(|_| state.closing || state.closed) 588 }
589
590 fn is_closed_with_context(&mut self, cx: Option<&mut Context<'_>>) -> bool {
591 let mut state = &mut self.state;
592 self.mutex.lock(|_| {
593 if state.closing || state.closed {
594 cx.into_iter()
595 .for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker()));
596 true
597 } else {
598 false
599 }
600 })
573 } 601 }
574 602
575 fn register_receiver(&mut self) { 603 fn register_receiver(&mut self) {
@@ -610,25 +638,19 @@ where
610 }) 638 })
611 } 639 }
612 640
613 fn set_receiver_waker(&mut self, receiver_waker: &Waker) { 641 fn set_receiver_waker(state: &mut ChannelState<T, N>, receiver_waker: &Waker) {
614 let state = &mut self.state; 642 state.receiver_waker.register(receiver_waker);
615 self.mutex.lock(|_| {
616 state.receiver_waker.register(receiver_waker);
617 })
618 } 643 }
619 644
620 fn set_senders_waker(&mut self, senders_waker: &Waker) { 645 fn set_senders_waker(state: &mut ChannelState<T, N>, senders_waker: &Waker) {
621 let state = &mut self.state; 646 // Dispose of any existing sender causing them to be polled again.
622 self.mutex.lock(|_| { 647 // This could cause a spin given multiple concurrent senders, however given that
623 // Dispose of any existing sender causing them to be polled again. 648 // most sends only block waiting for the receiver to become active, this should
624 // This could cause a spin given multiple concurrent senders, however given that 649 // be a short-lived activity. The upside is a greatly simplified implementation
625 // most sends only block waiting for the receiver to become active, this should 650 // that avoids the need for intrusive linked-lists and unsafe operations on pinned
626 // be a short-lived activity. The upside is a greatly simplified implementation 651 // pointers.
627 // that avoids the need for intrusive linked-lists and unsafe operations on pinned 652 state.senders_waker.wake();
628 // pointers. 653 state.senders_waker.register(senders_waker);
629 state.senders_waker.wake();
630 state.senders_waker.register(senders_waker);
631 })
632 } 654 }
633} 655}
634 656