aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDion Dokter <[email protected]>2022-06-16 12:28:12 +0200
committerDion Dokter <[email protected]>2022-06-16 12:28:12 +0200
commit12a6ddfbcd79f2ab62ba264acd997dca0ac64a99 (patch)
treec0a49db0471a0e8b3f85e351ab74e88e1efd874f
parent23177ba7eb800d32a48f70fe568ae9bf6b2cde98 (diff)
Added a pubsub channel implementation
-rw-r--r--embassy/src/channel/mod.rs2
-rw-r--r--embassy/src/channel/pubsub.rs590
2 files changed, 591 insertions, 1 deletions
diff --git a/embassy/src/channel/mod.rs b/embassy/src/channel/mod.rs
index 05edc55d1..5df1f5c5c 100644
--- a/embassy/src/channel/mod.rs
+++ b/embassy/src/channel/mod.rs
@@ -1,5 +1,5 @@
1//! Async channels 1//! Async channels
2 2
3pub mod mpmc; 3pub mod mpmc;
4 4pub mod pubsub;
5pub mod signal; 5pub mod signal;
diff --git a/embassy/src/channel/pubsub.rs b/embassy/src/channel/pubsub.rs
new file mode 100644
index 000000000..5e5cce9cf
--- /dev/null
+++ b/embassy/src/channel/pubsub.rs
@@ -0,0 +1,590 @@
1//! Implementation of [PubSubChannel], a queue where published messages get received by all subscribers.
2
3use core::cell::RefCell;
4use core::fmt::Debug;
5use core::future::Future;
6use core::pin::Pin;
7use core::task::{Context, Poll, Waker};
8
9use heapless::Deque;
10
11use crate::blocking_mutex::raw::RawMutex;
12use crate::blocking_mutex::Mutex;
13use crate::waitqueue::WakerRegistration;
14
15/// A broadcast channel implementation where multiple publishers can send messages to multiple subscribers
16///
17/// Any published message can be read by all subscribers.
18/// A publisher can choose how it sends its message.
19///
20/// - With [Publisher::publish] the publisher has to wait until there is space in the internal message queue.
21/// - With [Publisher::publish_immediate] the publisher doesn't await and instead lets the oldest message
22/// in the queue drop if necessary. This will cause any [Subscriber] that missed the message to receive
23/// an error to indicate that it has lagged.
24pub struct PubSubChannel<M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> {
25 inner: Mutex<M, RefCell<PubSubState<T, CAP, SUBS, PUBS>>>,
26}
27
28impl<M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> PubSubChannel<M, T, CAP, SUBS, PUBS> {
29 /// Create a new channel
30 pub const fn new() -> Self {
31 Self {
32 inner: Mutex::const_new(M::INIT, RefCell::new(PubSubState::new())),
33 }
34 }
35
36 /// Create a new subscriber. It will only receive messages that are published after its creation.
37 ///
38 /// If there are no subscriber slots left, an error will be returned.
39 pub fn subscriber(&self) -> Result<Subscriber<'_, T>, Error> {
40 self.inner.lock(|inner| {
41 let mut s = inner.borrow_mut();
42
43 // Search for an empty subscriber spot
44 for (i, sub_spot) in s.subscriber_wakers.iter_mut().enumerate() {
45 if sub_spot.is_none() {
46 // We've found a spot, so now fill it and create the subscriber
47 *sub_spot = Some(WakerRegistration::new());
48 return Ok(Subscriber {
49 subscriber_index: i,
50 next_message_id: s.next_message_id,
51 channel: self,
52 });
53 }
54 }
55
56 // No spot was found, we're full
57 Err(Error::MaximumSubscribersReached)
58 })
59 }
60
61 /// Create a new publisher
62 ///
63 /// If there are no publisher slots left, an error will be returned.
64 pub fn publisher(&self) -> Result<Publisher<'_, T>, Error> {
65 self.inner.lock(|inner| {
66 let mut s = inner.borrow_mut();
67
68 // Search for an empty publisher spot
69 for (i, pub_spot) in s.publisher_wakers.iter_mut().enumerate() {
70 if pub_spot.is_none() {
71 // We've found a spot, so now fill it and create the subscriber
72 *pub_spot = Some(WakerRegistration::new());
73 return Ok(Publisher {
74 publisher_index: i,
75 channel: self,
76 });
77 }
78 }
79
80 // No spot was found, we're full
81 Err(Error::MaximumPublishersReached)
82 })
83 }
84
85 /// Create a new publisher that can only send immediate messages.
86 /// This kind of publisher does not take up a publisher slot.
87 pub fn immediate_publisher(&self) -> ImmediatePublisher<'_, T> {
88 ImmediatePublisher { channel: self }
89 }
90}
91
92impl<M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> PubSubBehavior<T>
93 for PubSubChannel<M, T, CAP, SUBS, PUBS>
94{
95 fn try_publish(&self, message: T) -> Result<(), T> {
96 self.inner.lock(|inner| {
97 let mut s = inner.borrow_mut();
98
99 let active_subscriber_count = s.subscriber_wakers.iter().flatten().count();
100
101 if active_subscriber_count == 0 {
102 // We don't need to publish anything because there is no one to receive it
103 return Ok(());
104 }
105
106 if s.queue.is_full() {
107 return Err(message);
108 }
109 // We just did a check for this
110 unsafe {
111 s.queue.push_back_unchecked((message, active_subscriber_count));
112 }
113
114 s.next_message_id += 1;
115
116 // Wake all of the subscribers
117 for active_subscriber in s.subscriber_wakers.iter_mut().flatten() {
118 active_subscriber.wake()
119 }
120
121 Ok(())
122 })
123 }
124
125 fn publish_immediate(&self, message: T) {
126 self.inner.lock(|inner| {
127 let mut s = inner.borrow_mut();
128
129 // Make space in the queue if required
130 if s.queue.is_full() {
131 s.queue.pop_front();
132 }
133
134 // We are going to call something is Self again.
135 // The lock is fine, but we need to get rid of the refcell borrow
136 drop(s);
137
138 // This will succeed because we made sure there is space
139 unsafe { self.try_publish(message).unwrap_unchecked() };
140 });
141 }
142
143 fn get_message(&self, message_id: u64) -> Option<WaitResult<T>> {
144 self.inner.lock(|inner| {
145 let mut s = inner.borrow_mut();
146
147 let start_id = s.next_message_id - s.queue.len() as u64;
148
149 if message_id < start_id {
150 return Some(WaitResult::Lagged(start_id - message_id));
151 }
152
153 let current_message_index = (message_id - start_id) as usize;
154
155 if current_message_index >= s.queue.len() {
156 return None;
157 }
158
159 // We've checked that the index is valid
160 unsafe {
161 let queue_item = s.queue.iter_mut().nth(current_message_index).unwrap_unchecked();
162
163 // We're reading this item, so decrement the counter
164 queue_item.1 -= 1;
165 let message = queue_item.0.clone();
166
167 if current_message_index == 0 && queue_item.1 == 0 {
168 s.queue.pop_front();
169 s.publisher_wakers.iter_mut().flatten().for_each(|w| w.wake());
170 }
171
172 Some(WaitResult::Message(message))
173 }
174 })
175 }
176
177 unsafe fn register_subscriber_waker(&self, subscriber_index: usize, waker: &Waker) {
178 self.inner.lock(|inner| {
179 let mut s = inner.borrow_mut();
180 s.subscriber_wakers
181 .get_unchecked_mut(subscriber_index)
182 .as_mut()
183 .unwrap_unchecked()
184 .register(waker);
185 })
186 }
187
188 unsafe fn register_publisher_waker(&self, publisher_index: usize, waker: &Waker) {
189 self.inner.lock(|inner| {
190 let mut s = inner.borrow_mut();
191 s.publisher_wakers
192 .get_unchecked_mut(publisher_index)
193 .as_mut()
194 .unwrap_unchecked()
195 .register(waker);
196 })
197 }
198
199 unsafe fn unregister_subscriber(&self, subscriber_index: usize, subscriber_next_message_id: u64) {
200 self.inner.lock(|inner| {
201 let mut s = inner.borrow_mut();
202
203 // Remove the subscriber from the wakers
204 *s.subscriber_wakers.get_unchecked_mut(subscriber_index) = None;
205
206 // All messages that haven't been read yet by this subscriber must have their counter decremented
207 let start_id = s.next_message_id - s.queue.len() as u64;
208 if subscriber_next_message_id >= start_id {
209 let current_message_index = (subscriber_next_message_id - start_id) as usize;
210 s.queue
211 .iter_mut()
212 .skip(current_message_index)
213 .for_each(|(_, counter)| *counter -= 1);
214 }
215 })
216 }
217
218 unsafe fn unregister_publisher(&self, publisher_index: usize) {
219 self.inner.lock(|inner| {
220 let mut s = inner.borrow_mut();
221 // Remove the publisher from the wakers
222 *s.publisher_wakers.get_unchecked_mut(publisher_index) = None;
223 })
224 }
225}
226
227/// Internal state for the PubSub channel
228struct PubSubState<T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> {
229 /// The queue contains the last messages that have been published and a countdown of how many subscribers are yet to read it
230 queue: Deque<(T, usize), CAP>,
231 /// Every message has an id.
232 /// Don't worry, we won't run out.
233 /// If a million messages were published every second, then the ID's would run out in about 584942 years.
234 next_message_id: u64,
235 /// Collection of wakers for Subscribers that are waiting.
236 /// The [Subscriber::subscriber_index] field indexes into this array.
237 subscriber_wakers: [Option<WakerRegistration>; SUBS],
238 /// Collection of wakers for Publishers that are waiting.
239 /// The [Publisher::publisher_index] field indexes into this array.
240 publisher_wakers: [Option<WakerRegistration>; PUBS],
241}
242
243impl<T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> PubSubState<T, CAP, SUBS, PUBS> {
244 /// Create a new internal channel state
245 const fn new() -> Self {
246 const WAKER_INIT: Option<WakerRegistration> = None;
247 Self {
248 queue: Deque::new(),
249 next_message_id: 0,
250 subscriber_wakers: [WAKER_INIT; SUBS],
251 publisher_wakers: [WAKER_INIT; PUBS],
252 }
253 }
254}
255
256/// A subscriber to a channel
257///
258/// This instance carries a reference to the channel, but uses a trait object for it so that the channel's
259/// generics are erased on this subscriber
260pub struct Subscriber<'a, T: Clone> {
261 /// Our index into the channel
262 subscriber_index: usize,
263 /// The message id of the next message we are yet to receive
264 next_message_id: u64,
265 /// The channel we are a subscriber to
266 channel: &'a dyn PubSubBehavior<T>,
267}
268
269impl<'a, T: Clone> Subscriber<'a, T> {
270 /// Wait for a published message
271 pub fn wait<'s>(&'s mut self) -> SubscriberWaitFuture<'s, 'a, T> {
272 SubscriberWaitFuture { subscriber: self }
273 }
274
275 /// Try to see if there's a published message we haven't received yet.
276 ///
277 /// This function does not peek. The message is received if there is one.
278 pub fn check(&mut self) -> Option<WaitResult<T>> {
279 match self.channel.get_message(self.next_message_id) {
280 Some(WaitResult::Lagged(amount)) => {
281 self.next_message_id += amount;
282 Some(WaitResult::Lagged(amount))
283 }
284 result => {
285 self.next_message_id += 1;
286 result
287 }
288 }
289 }
290}
291
292impl<'a, T: Clone> Drop for Subscriber<'a, T> {
293 fn drop(&mut self) {
294 unsafe {
295 self.channel
296 .unregister_subscriber(self.subscriber_index, self.next_message_id)
297 }
298 }
299}
300
301/// A publisher to a channel
302///
303/// This instance carries a reference to the channel, but uses a trait object for it so that the channel's
304/// generics are erased on this subscriber
305pub struct Publisher<'a, T: Clone> {
306 /// Our index into the channel
307 publisher_index: usize,
308 /// The channel we are a publisher for
309 channel: &'a dyn PubSubBehavior<T>,
310}
311
312impl<'a, T: Clone> Publisher<'a, T> {
313 /// Publish a message right now even when the queue is full.
314 /// This may cause a subscriber to miss an older message.
315 pub fn publish_immediate(&self, message: T) {
316 self.channel.publish_immediate(message)
317 }
318
319 /// Publish a message. But if the message queue is full, wait for all subscribers to have read the last message
320 pub fn publish<'s>(&'s self, message: T) -> PublisherWaitFuture<'s, 'a, T> {
321 PublisherWaitFuture {
322 message: Some(message),
323 publisher: self,
324 }
325 }
326
327 /// Publish a message if there is space in the message queue
328 pub fn try_publish(&self, message: T) -> Result<(), T> {
329 self.channel.try_publish(message)
330 }
331}
332
333impl<'a, T: Clone> Drop for Publisher<'a, T> {
334 fn drop(&mut self) {
335 unsafe { self.channel.unregister_publisher(self.publisher_index) }
336 }
337}
338
339/// A publisher that can only use the `publish_immediate` function, but it doesn't have to be registered with the channel.
340/// (So an infinite amount is possible)
341pub struct ImmediatePublisher<'a, T: Clone> {
342 /// The channel we are a publisher for
343 channel: &'a dyn PubSubBehavior<T>,
344}
345
346impl<'a, T: Clone> ImmediatePublisher<'a, T> {
347 /// Publish the message right now even when the queue is full.
348 /// This may cause a subscriber to miss an older message.
349 pub fn publish_immediate(&mut self, message: T) {
350 self.channel.publish_immediate(message)
351 }
352
353 /// Publish a message if there is space in the message queue
354 pub fn try_publish(&self, message: T) -> Result<(), T> {
355 self.channel.try_publish(message)
356 }
357
358}
359
360/// Error type for the [PubSubChannel]
361#[derive(Debug, PartialEq, Clone)]
362pub enum Error {
363 /// All subscriber slots are used. To add another subscriber, first another subscriber must be dropped or
364 /// the capacity of the channels must be increased.
365 MaximumSubscribersReached,
366 /// All publisher slots are used. To add another publisher, first another publisher must be dropped or
367 /// the capacity of the channels must be increased.
368 MaximumPublishersReached,
369}
370
371trait PubSubBehavior<T> {
372 /// Try to publish a message. If the queue is full it won't succeed
373 fn try_publish(&self, message: T) -> Result<(), T>;
374 /// Publish a message immediately. If the queue is full, just throw out the oldest one.
375 fn publish_immediate(&self, message: T);
376 /// Tries to read the message if available
377 fn get_message(&self, message_id: u64) -> Option<WaitResult<T>>;
378 /// Register the given waker for the given subscriber.
379 ///
380 /// ## Safety
381 ///
382 /// The subscriber index must be of a valid and active subscriber
383 unsafe fn register_subscriber_waker(&self, subscriber_index: usize, waker: &Waker);
384 /// Register the given waker for the given publisher.
385 ///
386 /// ## Safety
387 ///
388 /// The subscriber index must be of a valid and active publisher
389 unsafe fn register_publisher_waker(&self, publisher_index: usize, waker: &Waker);
390 /// Make the channel forget the subscriber.
391 ///
392 /// ## Safety
393 ///
394 /// The subscriber index must be of a valid and active subscriber which must not be used again
395 /// unless a new subscriber takes on that index.
396 unsafe fn unregister_subscriber(&self, subscriber_index: usize, subscriber_next_message_id: u64);
397 /// Make the channel forget the publisher.
398 ///
399 /// ## Safety
400 ///
401 /// The publisher index must be of a valid and active publisher which must not be used again
402 /// unless a new publisher takes on that index.
403 unsafe fn unregister_publisher(&self, publisher_index: usize);
404}
405
406/// Future for the subscriber wait action
407pub struct SubscriberWaitFuture<'s, 'a, T: Clone> {
408 subscriber: &'s mut Subscriber<'a, T>,
409}
410
411impl<'s, 'a, T: Clone> Future for SubscriberWaitFuture<'s, 'a, T> {
412 type Output = WaitResult<T>;
413
414 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
415 // Check if we can read a message
416 match self.subscriber.channel.get_message(self.subscriber.next_message_id) {
417 // Yes, so we are done polling
418 Some(WaitResult::Message(message)) => {
419 self.subscriber.next_message_id += 1;
420 Poll::Ready(WaitResult::Message(message))
421 }
422 // No, so we need to reregister our waker and sleep again
423 None => {
424 unsafe {
425 self.subscriber
426 .channel
427 .register_subscriber_waker(self.subscriber.subscriber_index, cx.waker());
428 }
429 Poll::Pending
430 }
431 // We missed a couple of messages. We must do our internal bookkeeping and return that we lagged
432 Some(WaitResult::Lagged(amount)) => {
433 self.subscriber.next_message_id += amount;
434 Poll::Ready(WaitResult::Lagged(amount))
435 }
436 }
437 }
438}
439
440/// Future for the publisher wait action
441pub struct PublisherWaitFuture<'s, 'a, T: Clone> {
442 /// The message we need to publish
443 message: Option<T>,
444 publisher: &'s Publisher<'a, T>,
445}
446
447impl<'s, 'a, T: Clone> Future for PublisherWaitFuture<'s, 'a, T> {
448 type Output = ();
449
450 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
451 let this = unsafe { self.get_unchecked_mut() };
452
453 // Try to publish the message
454 match this.publisher.channel.try_publish(this.message.take().unwrap()) {
455 // We did it, we are ready
456 Ok(()) => Poll::Ready(()),
457 // The queue is full, so we need to reregister our waker and go to sleep
458 Err(message) => {
459 this.message = Some(message);
460 unsafe {
461 this.publisher
462 .channel
463 .register_publisher_waker(this.publisher.publisher_index, cx.waker());
464 }
465 Poll::Pending
466 }
467 }
468 }
469}
470
471/// The result of the subscriber wait procedure
472#[derive(Debug, Clone, PartialEq)]
473pub enum WaitResult<T> {
474 /// The subscriber did not receive all messages and lagged by the given amount of messages.
475 /// (This is the amount of messages that were missed)
476 Lagged(u64),
477 /// A message was received
478 Message(T),
479}
480
481#[cfg(test)]
482mod tests {
483 use crate::blocking_mutex::raw::NoopRawMutex;
484 use super::*;
485
486 #[futures_test::test]
487 async fn all_subscribers_receive() {
488 let channel = PubSubChannel::<NoopRawMutex, u32, 4, 4, 4>::new();
489
490 let mut sub0 = channel.subscriber().unwrap();
491 let mut sub1 = channel.subscriber().unwrap();
492 let pub0 = channel.publisher().unwrap();
493
494 pub0.publish(42).await;
495
496 assert_eq!(sub0.wait().await, WaitResult::Message(42));
497 assert_eq!(sub1.wait().await, WaitResult::Message(42));
498
499 assert_eq!(sub0.check(), None);
500 assert_eq!(sub1.check(), None);
501 }
502
503 #[futures_test::test]
504 async fn lag_when_queue_full_on_immediate_publish() {
505 let channel = PubSubChannel::<NoopRawMutex, u32, 4, 4, 4>::new();
506
507 let mut sub0 = channel.subscriber().unwrap();
508 let pub0 = channel.publisher().unwrap();
509
510 pub0.publish_immediate(42);
511 pub0.publish_immediate(43);
512 pub0.publish_immediate(44);
513 pub0.publish_immediate(45);
514 pub0.publish_immediate(46);
515 pub0.publish_immediate(47);
516
517 assert_eq!(sub0.check(), Some(WaitResult::Lagged(2)));
518 assert_eq!(sub0.wait().await, WaitResult::Message(44));
519 assert_eq!(sub0.wait().await, WaitResult::Message(45));
520 assert_eq!(sub0.wait().await, WaitResult::Message(46));
521 assert_eq!(sub0.wait().await, WaitResult::Message(47));
522 assert_eq!(sub0.check(), None);
523 }
524
525 #[test]
526 fn limited_subs_and_pubs() {
527 let channel = PubSubChannel::<NoopRawMutex, u32, 4, 4, 4>::new();
528
529 let sub0 = channel.subscriber();
530 let sub1 = channel.subscriber();
531 let sub2 = channel.subscriber();
532 let sub3 = channel.subscriber();
533 let sub4 = channel.subscriber();
534
535 assert!(sub0.is_ok());
536 assert!(sub1.is_ok());
537 assert!(sub2.is_ok());
538 assert!(sub3.is_ok());
539 assert_eq!(sub4.err().unwrap(), Error::MaximumSubscribersReached);
540
541 drop(sub0);
542
543 let sub5 = channel.subscriber();
544 assert!(sub5.is_ok());
545
546 // publishers
547
548 let pub0 = channel.publisher();
549 let pub1 = channel.publisher();
550 let pub2 = channel.publisher();
551 let pub3 = channel.publisher();
552 let pub4 = channel.publisher();
553
554 assert!(pub0.is_ok());
555 assert!(pub1.is_ok());
556 assert!(pub2.is_ok());
557 assert!(pub3.is_ok());
558 assert_eq!(pub4.err().unwrap(), Error::MaximumPublishersReached);
559
560 drop(pub0);
561
562 let pub5 = channel.publisher();
563 assert!(pub5.is_ok());
564 }
565
566 #[test]
567 fn publisher_wait_on_full_queue() {
568 let channel = PubSubChannel::<NoopRawMutex, u32, 4, 4, 4>::new();
569
570 let pub0 = channel.publisher().unwrap();
571
572 // There are no subscribers, so the queue will never be full
573 assert_eq!(pub0.try_publish(0), Ok(()));
574 assert_eq!(pub0.try_publish(0), Ok(()));
575 assert_eq!(pub0.try_publish(0), Ok(()));
576 assert_eq!(pub0.try_publish(0), Ok(()));
577 assert_eq!(pub0.try_publish(0), Ok(()));
578
579 let sub0 = channel.subscriber().unwrap();
580
581 assert_eq!(pub0.try_publish(0), Ok(()));
582 assert_eq!(pub0.try_publish(0), Ok(()));
583 assert_eq!(pub0.try_publish(0), Ok(()));
584 assert_eq!(pub0.try_publish(0), Ok(()));
585 assert_eq!(pub0.try_publish(0), Err(0));
586
587 drop(sub0);
588 }
589
590}