diff options
| -rw-r--r-- | embassy/src/util/mpsc.rs | 84 |
1 files changed, 53 insertions, 31 deletions
diff --git a/embassy/src/util/mpsc.rs b/embassy/src/util/mpsc.rs index 68fcdf7f9..8d534dc49 100644 --- a/embassy/src/util/mpsc.rs +++ b/embassy/src/util/mpsc.rs | |||
| @@ -145,14 +145,11 @@ where | |||
| 145 | futures::future::poll_fn(|cx| self.recv_poll(cx)).await | 145 | futures::future::poll_fn(|cx| self.recv_poll(cx)).await |
| 146 | } | 146 | } |
| 147 | 147 | ||
| 148 | fn recv_poll(self: &mut Self, cx: &mut Context<'_>) -> Poll<Option<T>> { | 148 | fn recv_poll(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> { |
| 149 | match self.try_recv() { | 149 | match self.channel.get().try_recv_with_context(Some(cx)) { |
| 150 | Ok(v) => Poll::Ready(Some(v)), | 150 | Ok(v) => Poll::Ready(Some(v)), |
| 151 | Err(TryRecvError::Closed) => Poll::Ready(None), | 151 | Err(TryRecvError::Closed) => Poll::Ready(None), |
| 152 | Err(TryRecvError::Empty) => { | 152 | Err(TryRecvError::Empty) => Poll::Pending, |
| 153 | self.channel.get().set_receiver_waker(&cx.waker()); | ||
| 154 | Poll::Pending | ||
| 155 | } | ||
| 156 | } | 153 | } |
| 157 | } | 154 | } |
| 158 | 155 | ||
| @@ -279,11 +276,15 @@ where | |||
| 279 | type Output = Result<(), SendError<T>>; | 276 | type Output = Result<(), SendError<T>>; |
| 280 | 277 | ||
| 281 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { | 278 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
| 282 | match self.sender.try_send(unsafe { self.message.get().read() }) { | 279 | match self |
| 280 | .sender | ||
| 281 | .channel | ||
| 282 | .get() | ||
| 283 | .try_send_with_context(unsafe { self.message.get().read() }, Some(cx)) | ||
| 284 | { | ||
| 283 | Ok(..) => Poll::Ready(Ok(())), | 285 | Ok(..) => Poll::Ready(Ok(())), |
| 284 | Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), | 286 | Err(TrySendError::Closed(m)) => Poll::Ready(Err(SendError(m))), |
| 285 | Err(TrySendError::Full(..)) => { | 287 | Err(TrySendError::Full(..)) => { |
| 286 | self.sender.channel.get().set_senders_waker(&cx.waker()); | ||
| 287 | Poll::Pending | 288 | Poll::Pending |
| 288 | // Note we leave the existing UnsafeCell contents - they still | 289 | // Note we leave the existing UnsafeCell contents - they still |
| 289 | // contain the original message. We could create another UnsafeCell | 290 | // contain the original message. We could create another UnsafeCell |
| @@ -307,10 +308,9 @@ where | |||
| 307 | type Output = (); | 308 | type Output = (); |
| 308 | 309 | ||
| 309 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { | 310 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
| 310 | if self.sender.is_closed() { | 311 | if self.sender.channel.get().is_closed_with_context(Some(cx)) { |
| 311 | Poll::Ready(()) | 312 | Poll::Ready(()) |
| 312 | } else { | 313 | } else { |
| 313 | self.sender.channel.get().set_senders_waker(&cx.waker()); | ||
| 314 | Poll::Pending | 314 | Poll::Pending |
| 315 | } | 315 | } |
| 316 | } | 316 | } |
| @@ -513,7 +513,11 @@ where | |||
| 513 | } | 513 | } |
| 514 | 514 | ||
| 515 | fn try_recv(&mut self) -> Result<T, TryRecvError> { | 515 | fn try_recv(&mut self) -> Result<T, TryRecvError> { |
| 516 | let state = &mut self.state; | 516 | self.try_recv_with_context(None) |
| 517 | } | ||
| 518 | |||
| 519 | fn try_recv_with_context(&mut self, cx: Option<&mut Context<'_>>) -> Result<T, TryRecvError> { | ||
| 520 | let mut state = &mut self.state; | ||
| 517 | self.mutex.lock(|_| { | 521 | self.mutex.lock(|_| { |
| 518 | if !state.closed { | 522 | if !state.closed { |
| 519 | if state.read_pos != state.write_pos || state.full { | 523 | if state.read_pos != state.write_pos || state.full { |
| @@ -526,6 +530,8 @@ where | |||
| 526 | state.read_pos = (state.read_pos + 1) % state.buf.len(); | 530 | state.read_pos = (state.read_pos + 1) % state.buf.len(); |
| 527 | Ok(message) | 531 | Ok(message) |
| 528 | } else if !state.closing { | 532 | } else if !state.closing { |
| 533 | cx.into_iter() | ||
| 534 | .for_each(|cx| Self::set_receiver_waker(&mut state, &cx.waker())); | ||
| 529 | Err(TryRecvError::Empty) | 535 | Err(TryRecvError::Empty) |
| 530 | } else { | 536 | } else { |
| 531 | state.closed = true; | 537 | state.closed = true; |
| @@ -539,7 +545,15 @@ where | |||
| 539 | } | 545 | } |
| 540 | 546 | ||
| 541 | fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> { | 547 | fn try_send(&mut self, message: T) -> Result<(), TrySendError<T>> { |
| 542 | let state = &mut self.state; | 548 | self.try_send_with_context(message, None) |
| 549 | } | ||
| 550 | |||
| 551 | fn try_send_with_context( | ||
| 552 | &mut self, | ||
| 553 | message: T, | ||
| 554 | cx: Option<&mut Context<'_>>, | ||
| 555 | ) -> Result<(), TrySendError<T>> { | ||
| 556 | let mut state = &mut self.state; | ||
| 543 | self.mutex.lock(|_| { | 557 | self.mutex.lock(|_| { |
| 544 | if !state.closed { | 558 | if !state.closed { |
| 545 | if !state.full { | 559 | if !state.full { |
| @@ -551,6 +565,8 @@ where | |||
| 551 | state.receiver_waker.wake(); | 565 | state.receiver_waker.wake(); |
| 552 | Ok(()) | 566 | Ok(()) |
| 553 | } else { | 567 | } else { |
| 568 | cx.into_iter() | ||
| 569 | .for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker())); | ||
| 554 | Err(TrySendError::Full(message)) | 570 | Err(TrySendError::Full(message)) |
| 555 | } | 571 | } |
| 556 | } else { | 572 | } else { |
| @@ -568,8 +584,20 @@ where | |||
| 568 | } | 584 | } |
| 569 | 585 | ||
| 570 | fn is_closed(&mut self) -> bool { | 586 | fn is_closed(&mut self) -> bool { |
| 571 | let state = &self.state; | 587 | self.is_closed_with_context(None) |
| 572 | self.mutex.lock(|_| state.closing || state.closed) | 588 | } |
| 589 | |||
| 590 | fn is_closed_with_context(&mut self, cx: Option<&mut Context<'_>>) -> bool { | ||
| 591 | let mut state = &mut self.state; | ||
| 592 | self.mutex.lock(|_| { | ||
| 593 | if state.closing || state.closed { | ||
| 594 | cx.into_iter() | ||
| 595 | .for_each(|cx| Self::set_senders_waker(&mut state, &cx.waker())); | ||
| 596 | true | ||
| 597 | } else { | ||
| 598 | false | ||
| 599 | } | ||
| 600 | }) | ||
| 573 | } | 601 | } |
| 574 | 602 | ||
| 575 | fn register_receiver(&mut self) { | 603 | fn register_receiver(&mut self) { |
| @@ -610,25 +638,19 @@ where | |||
| 610 | }) | 638 | }) |
| 611 | } | 639 | } |
| 612 | 640 | ||
| 613 | fn set_receiver_waker(&mut self, receiver_waker: &Waker) { | 641 | fn set_receiver_waker(state: &mut ChannelState<T, N>, receiver_waker: &Waker) { |
| 614 | let state = &mut self.state; | 642 | state.receiver_waker.register(receiver_waker); |
| 615 | self.mutex.lock(|_| { | ||
| 616 | state.receiver_waker.register(receiver_waker); | ||
| 617 | }) | ||
| 618 | } | 643 | } |
| 619 | 644 | ||
| 620 | fn set_senders_waker(&mut self, senders_waker: &Waker) { | 645 | fn set_senders_waker(state: &mut ChannelState<T, N>, senders_waker: &Waker) { |
| 621 | let state = &mut self.state; | 646 | // Dispose of any existing sender causing them to be polled again. |
| 622 | self.mutex.lock(|_| { | 647 | // This could cause a spin given multiple concurrent senders, however given that |
| 623 | // Dispose of any existing sender causing them to be polled again. | 648 | // most sends only block waiting for the receiver to become active, this should |
| 624 | // This could cause a spin given multiple concurrent senders, however given that | 649 | // be a short-lived activity. The upside is a greatly simplified implementation |
| 625 | // most sends only block waiting for the receiver to become active, this should | 650 | // that avoids the need for intrusive linked-lists and unsafe operations on pinned |
| 626 | // be a short-lived activity. The upside is a greatly simplified implementation | 651 | // pointers. |
| 627 | // that avoids the need for intrusive linked-lists and unsafe operations on pinned | 652 | state.senders_waker.wake(); |
| 628 | // pointers. | 653 | state.senders_waker.register(senders_waker); |
| 629 | state.senders_waker.wake(); | ||
| 630 | state.senders_waker.register(senders_waker); | ||
| 631 | }) | ||
| 632 | } | 654 | } |
| 633 | } | 655 | } |
| 634 | 656 | ||
