aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDion Dokter <[email protected]>2022-06-16 22:11:29 +0200
committerDion Dokter <[email protected]>2022-06-16 22:11:29 +0200
commita614a55c7ddecc171d48c61bf9fa8c6c11ed16f4 (patch)
tree3f6f246e8b3cd36c2e47ce072af709c65d8c4f46
parentdfde157337b379ff0805cfe4aa5463c078ca1d41 (diff)
Put most behaviour one level lower (under the mutex instead of above).
Changed the PubSubBehavior to only have high level functions.
-rw-r--r--embassy/src/channel/pubsub.rs347
1 files changed, 181 insertions, 166 deletions
diff --git a/embassy/src/channel/pubsub.rs b/embassy/src/channel/pubsub.rs
index 20878187d..c5a8c01f8 100644
--- a/embassy/src/channel/pubsub.rs
+++ b/embassy/src/channel/pubsub.rs
@@ -94,122 +94,74 @@ impl<M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usi
94impl<M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> PubSubBehavior<T> 94impl<M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> PubSubBehavior<T>
95 for PubSubChannel<M, T, CAP, SUBS, PUBS> 95 for PubSubChannel<M, T, CAP, SUBS, PUBS>
96{ 96{
97 fn try_publish(&self, message: T) -> Result<(), T> { 97 fn get_message_with_context(
98 self.inner.lock(|inner| { 98 &self,
99 let mut s = inner.borrow_mut(); 99 next_message_id: &mut u64,
100 100 subscriber_index: usize,
101 let active_subscriber_count = s.subscriber_wakers.iter().flatten().count(); 101 cx: Option<&mut Context<'_>>,
102 102 ) -> Poll<WaitResult<T>> {
103 if active_subscriber_count == 0 { 103 self.inner.lock(|s| {
104 // We don't need to publish anything because there is no one to receive it 104 let mut s = s.borrow_mut();
105 return Ok(()); 105
106 } 106 // Check if we can read a message
107 107 match s.get_message(*next_message_id) {
108 if s.queue.is_full() { 108 // Yes, so we are done polling
109 return Err(message); 109 Some(WaitResult::Message(message)) => {
110 } 110 *next_message_id += 1;
111 // We just did a check for this 111 Poll::Ready(WaitResult::Message(message))
112 s.queue.push_back((message, active_subscriber_count)).ok().unwrap(); 112 }
113 113 // No, so we need to reregister our waker and sleep again
114 s.next_message_id += 1; 114 None => {
115 115 if let Some(cx) = cx {
116 // Wake all of the subscribers 116 s.register_subscriber_waker(subscriber_index, cx.waker());
117 for active_subscriber in s.subscriber_wakers.iter_mut().flatten() { 117 }
118 active_subscriber.wake() 118 Poll::Pending
119 }
120 // We missed a couple of messages. We must do our internal bookkeeping and return that we lagged
121 Some(WaitResult::Lagged(amount)) => {
122 *next_message_id += amount;
123 Poll::Ready(WaitResult::Lagged(amount))
124 }
119 } 125 }
120
121 Ok(())
122 }) 126 })
123 } 127 }
124 128
125 fn publish_immediate(&self, message: T) { 129 fn publish_with_context(&self, message: T, publisher_index: usize, cx: Option<&mut Context<'_>>) -> Result<(), T> {
126 self.inner.lock(|inner| { 130 self.inner.lock(|s| {
127 let mut s = inner.borrow_mut(); 131 let mut s = s.borrow_mut();
128 132 // Try to publish the message
129 // Make space in the queue if required 133 match s.try_publish(message) {
130 if s.queue.is_full() { 134 // We did it, we are ready
131 s.queue.pop_front(); 135 Ok(()) => Ok(()),
132 } 136 // The queue is full, so we need to reregister our waker and go to sleep
133 137 Err(message) => {
134 // We are going to call something is Self again. 138 if let Some(cx) = cx {
135 // The lock is fine, but we need to get rid of the refcell borrow 139 s.register_publisher_waker(publisher_index, cx.waker());
136 drop(s); 140 }
137 141 Err(message)
138 // This will succeed because we made sure there is space 142 }
139 self.try_publish(message).ok().unwrap();
140 });
141 }
142
143 fn get_message(&self, message_id: u64) -> Option<WaitResult<T>> {
144 self.inner.lock(|inner| {
145 let mut s = inner.borrow_mut();
146
147 let start_id = s.next_message_id - s.queue.len() as u64;
148
149 if message_id < start_id {
150 return Some(WaitResult::Lagged(start_id - message_id));
151 }
152
153 let current_message_index = (message_id - start_id) as usize;
154
155 if current_message_index >= s.queue.len() {
156 return None;
157 }
158
159 // We've checked that the index is valid
160 let queue_item = s.queue.iter_mut().nth(current_message_index).unwrap();
161
162 // We're reading this item, so decrement the counter
163 queue_item.1 -= 1;
164 let message = queue_item.0.clone();
165
166 if current_message_index == 0 && queue_item.1 == 0 {
167 s.queue.pop_front();
168 s.publisher_wakers.iter_mut().flatten().for_each(|w| w.wake());
169 } 143 }
170
171 Some(WaitResult::Message(message))
172 })
173 }
174
175 fn register_subscriber_waker(&self, subscriber_index: usize, waker: &Waker) {
176 self.inner.lock(|inner| {
177 let mut s = inner.borrow_mut();
178 s.subscriber_wakers[subscriber_index].as_mut().unwrap().register(waker);
179 }) 144 })
180 } 145 }
181 146
182 fn register_publisher_waker(&self, publisher_index: usize, waker: &Waker) { 147 fn publish_immediate(&self, message: T) {
183 self.inner.lock(|inner| { 148 self.inner.lock(|s| {
184 let mut s = inner.borrow_mut(); 149 let mut s = s.borrow_mut();
185 s.publisher_wakers[publisher_index].as_mut().unwrap().register(waker); 150 s.publish_immediate(message)
186 }) 151 })
187 } 152 }
188 153
189 fn unregister_subscriber(&self, subscriber_index: usize, subscriber_next_message_id: u64) { 154 fn unregister_subscriber(&self, subscriber_index: usize, subscriber_next_message_id: u64) {
190 self.inner.lock(|inner| { 155 self.inner.lock(|s| {
191 let mut s = inner.borrow_mut(); 156 let mut s = s.borrow_mut();
192 157 s.unregister_subscriber(subscriber_index, subscriber_next_message_id)
193 // Remove the subscriber from the wakers
194 s.subscriber_wakers[subscriber_index] = None;
195
196 // All messages that haven't been read yet by this subscriber must have their counter decremented
197 let start_id = s.next_message_id - s.queue.len() as u64;
198 if subscriber_next_message_id >= start_id {
199 let current_message_index = (subscriber_next_message_id - start_id) as usize;
200 s.queue
201 .iter_mut()
202 .skip(current_message_index)
203 .for_each(|(_, counter)| *counter -= 1);
204 }
205 }) 158 })
206 } 159 }
207 160
208 fn unregister_publisher(&self, publisher_index: usize) { 161 fn unregister_publisher(&self, publisher_index: usize) {
209 self.inner.lock(|inner| { 162 self.inner.lock(|s| {
210 let mut s = inner.borrow_mut(); 163 let mut s = s.borrow_mut();
211 // Remove the publisher from the wakers 164 s.unregister_publisher(publisher_index)
212 s.publisher_wakers[publisher_index] = None;
213 }) 165 })
214 } 166 }
215} 167}
@@ -241,6 +193,99 @@ impl<T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> PubSubSta
241 publisher_wakers: [WAKER_INIT; PUBS], 193 publisher_wakers: [WAKER_INIT; PUBS],
242 } 194 }
243 } 195 }
196
197 fn try_publish(&mut self, message: T) -> Result<(), T> {
198 let active_subscriber_count = self.subscriber_wakers.iter().flatten().count();
199
200 if active_subscriber_count == 0 {
201 // We don't need to publish anything because there is no one to receive it
202 return Ok(());
203 }
204
205 if self.queue.is_full() {
206 return Err(message);
207 }
208 // We just did a check for this
209 self.queue.push_back((message, active_subscriber_count)).ok().unwrap();
210
211 self.next_message_id += 1;
212
213 // Wake all of the subscribers
214 for active_subscriber in self.subscriber_wakers.iter_mut().flatten() {
215 active_subscriber.wake()
216 }
217
218 Ok(())
219 }
220
221 fn publish_immediate(&mut self, message: T) {
222 // Make space in the queue if required
223 if self.queue.is_full() {
224 self.queue.pop_front();
225 }
226
227 // This will succeed because we made sure there is space
228 self.try_publish(message).ok().unwrap();
229 }
230
231 fn get_message(&mut self, message_id: u64) -> Option<WaitResult<T>> {
232 let start_id = self.next_message_id - self.queue.len() as u64;
233
234 if message_id < start_id {
235 return Some(WaitResult::Lagged(start_id - message_id));
236 }
237
238 let current_message_index = (message_id - start_id) as usize;
239
240 if current_message_index >= self.queue.len() {
241 return None;
242 }
243
244 // We've checked that the index is valid
245 let queue_item = self.queue.iter_mut().nth(current_message_index).unwrap();
246
247 // We're reading this item, so decrement the counter
248 queue_item.1 -= 1;
249 let message = queue_item.0.clone();
250
251 if current_message_index == 0 && queue_item.1 == 0 {
252 self.queue.pop_front();
253 self.publisher_wakers.iter_mut().flatten().for_each(|w| w.wake());
254 }
255
256 Some(WaitResult::Message(message))
257 }
258
259 fn register_subscriber_waker(&mut self, subscriber_index: usize, waker: &Waker) {
260 self.subscriber_wakers[subscriber_index]
261 .as_mut()
262 .unwrap()
263 .register(waker);
264 }
265
266 fn register_publisher_waker(&mut self, publisher_index: usize, waker: &Waker) {
267 self.publisher_wakers[publisher_index].as_mut().unwrap().register(waker);
268 }
269
270 fn unregister_subscriber(&mut self, subscriber_index: usize, subscriber_next_message_id: u64) {
271 // Remove the subscriber from the wakers
272 self.subscriber_wakers[subscriber_index] = None;
273
274 // All messages that haven't been read yet by this subscriber must have their counter decremented
275 let start_id = self.next_message_id - self.queue.len() as u64;
276 if subscriber_next_message_id >= start_id {
277 let current_message_index = (subscriber_next_message_id - start_id) as usize;
278 self.queue
279 .iter_mut()
280 .skip(current_message_index)
281 .for_each(|(_, counter)| *counter -= 1);
282 }
283 }
284
285 fn unregister_publisher(&mut self, publisher_index: usize) {
286 // Remove the publisher from the wakers
287 self.publisher_wakers[publisher_index] = None;
288 }
244} 289}
245 290
246/// A subscriber to a channel 291/// A subscriber to a channel
@@ -276,15 +321,12 @@ impl<'a, T: Clone> Subscriber<'a, T> {
276 /// 321 ///
277 /// This function does not peek. The message is received if there is one. 322 /// This function does not peek. The message is received if there is one.
278 pub fn try_next_message(&mut self) -> Option<WaitResult<T>> { 323 pub fn try_next_message(&mut self) -> Option<WaitResult<T>> {
279 match self.channel.get_message(self.next_message_id) { 324 match self
280 Some(WaitResult::Lagged(amount)) => { 325 .channel
281 self.next_message_id += amount; 326 .get_message_with_context(&mut self.next_message_id, self.subscriber_index, None)
282 Some(WaitResult::Lagged(amount)) 327 {
283 } 328 Poll::Ready(result) => Some(result),
284 result => { 329 Poll::Pending => None,
285 self.next_message_id += 1;
286 result
287 }
288 } 330 }
289 } 331 }
290 332
@@ -317,26 +359,16 @@ impl<'a, T: Clone> futures::Stream for Subscriber<'a, T> {
317 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { 359 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
318 let this = unsafe { self.get_unchecked_mut() }; 360 let this = unsafe { self.get_unchecked_mut() };
319 361
320 // Check if we can read a message 362 match this
321 match this.channel.get_message(this.next_message_id) { 363 .channel
322 // Yes, so we are done polling 364 .get_message_with_context(&mut this.next_message_id, this.subscriber_index, Some(cx))
323 Some(WaitResult::Message(message)) => { 365 {
324 this.next_message_id += 1; 366 Poll::Ready(WaitResult::Message(message)) => Poll::Ready(Some(message)),
325 Poll::Ready(Some(message)) 367 Poll::Ready(WaitResult::Lagged(_)) => {
326 }
327 // No, so we need to reregister our waker and sleep again
328 None => {
329 this.channel
330 .register_subscriber_waker(this.subscriber_index, cx.waker());
331 Poll::Pending
332 }
333 // We missed a couple of messages. We must do our internal bookkeeping.
334 // This stream impl doesn't return lag results, so we just ignore and start over
335 Some(WaitResult::Lagged(amount)) => {
336 this.next_message_id += amount;
337 cx.waker().wake_by_ref(); 368 cx.waker().wake_by_ref();
338 Poll::Pending 369 Poll::Pending
339 } 370 }
371 Poll::Pending => Poll::Pending,
340 } 372 }
341 } 373 }
342} 374}
@@ -369,7 +401,7 @@ impl<'a, T: Clone> Publisher<'a, T> {
369 401
370 /// Publish a message if there is space in the message queue 402 /// Publish a message if there is space in the message queue
371 pub fn try_publish(&self, message: T) -> Result<(), T> { 403 pub fn try_publish(&self, message: T) -> Result<(), T> {
372 self.channel.try_publish(message) 404 self.channel.publish_with_context(message, self.publisher_index, None)
373 } 405 }
374} 406}
375 407
@@ -395,7 +427,7 @@ impl<'a, T: Clone> ImmediatePublisher<'a, T> {
395 427
396 /// Publish a message if there is space in the message queue 428 /// Publish a message if there is space in the message queue
397 pub fn try_publish(&self, message: T) -> Result<(), T> { 429 pub fn try_publish(&self, message: T) -> Result<(), T> {
398 self.channel.try_publish(message) 430 self.channel.publish_with_context(message, usize::MAX, None)
399 } 431 }
400} 432}
401 433
@@ -411,19 +443,19 @@ pub enum Error {
411} 443}
412 444
413trait PubSubBehavior<T> { 445trait PubSubBehavior<T> {
414 /// Try to publish a message. If the queue is full it won't succeed 446 fn get_message_with_context(
415 fn try_publish(&self, message: T) -> Result<(), T>; 447 &self,
416 /// Publish a message immediately. If the queue is full, just throw out the oldest one. 448 next_message_id: &mut u64,
449 subscriber_index: usize,
450 cx: Option<&mut Context<'_>>,
451 ) -> Poll<WaitResult<T>>;
452
453 fn publish_with_context(&self, message: T, publisher_index: usize, cx: Option<&mut Context<'_>>) -> Result<(), T>;
454
417 fn publish_immediate(&self, message: T); 455 fn publish_immediate(&self, message: T);
418 /// Tries to read the message if available 456
419 fn get_message(&self, message_id: u64) -> Option<WaitResult<T>>;
420 /// Register the given waker for the given subscriber.
421 fn register_subscriber_waker(&self, subscriber_index: usize, waker: &Waker);
422 /// Register the given waker for the given publisher.
423 fn register_publisher_waker(&self, publisher_index: usize, waker: &Waker);
424 /// Make the channel forget the subscriber.
425 fn unregister_subscriber(&self, subscriber_index: usize, subscriber_next_message_id: u64); 457 fn unregister_subscriber(&self, subscriber_index: usize, subscriber_next_message_id: u64);
426 /// Make the channel forget the publisher. 458
427 fn unregister_publisher(&self, publisher_index: usize); 459 fn unregister_publisher(&self, publisher_index: usize);
428} 460}
429 461
@@ -436,26 +468,10 @@ impl<'s, 'a, T: Clone> Future for SubscriberWaitFuture<'s, 'a, T> {
436 type Output = WaitResult<T>; 468 type Output = WaitResult<T>;
437 469
438 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 470 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
439 // Check if we can read a message 471 let sub_index = self.subscriber.subscriber_index;
440 match self.subscriber.channel.get_message(self.subscriber.next_message_id) { 472 self.subscriber
441 // Yes, so we are done polling 473 .channel
442 Some(WaitResult::Message(message)) => { 474 .get_message_with_context(&mut self.subscriber.next_message_id, sub_index, Some(cx))
443 self.subscriber.next_message_id += 1;
444 Poll::Ready(WaitResult::Message(message))
445 }
446 // No, so we need to reregister our waker and sleep again
447 None => {
448 self.subscriber
449 .channel
450 .register_subscriber_waker(self.subscriber.subscriber_index, cx.waker());
451 Poll::Pending
452 }
453 // We missed a couple of messages. We must do our internal bookkeeping and return that we lagged
454 Some(WaitResult::Lagged(amount)) => {
455 self.subscriber.next_message_id += amount;
456 Poll::Ready(WaitResult::Lagged(amount))
457 }
458 }
459 } 475 }
460} 476}
461 477
@@ -474,16 +490,15 @@ impl<'s, 'a, T: Clone> Future for PublisherWaitFuture<'s, 'a, T> {
474 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 490 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
475 let this = unsafe { self.get_unchecked_mut() }; 491 let this = unsafe { self.get_unchecked_mut() };
476 492
477 // Try to publish the message 493 let message = this.message.take().unwrap();
478 match this.publisher.channel.try_publish(this.message.take().unwrap()) { 494 match this
479 // We did it, we are ready 495 .publisher
496 .channel
497 .publish_with_context(message, this.publisher.publisher_index, Some(cx))
498 {
480 Ok(()) => Poll::Ready(()), 499 Ok(()) => Poll::Ready(()),
481 // The queue is full, so we need to reregister our waker and go to sleep
482 Err(message) => { 500 Err(message) => {
483 this.message = Some(message); 501 this.message = Some(message);
484 this.publisher
485 .channel
486 .register_publisher_waker(this.publisher.publisher_index, cx.waker());
487 Poll::Pending 502 Poll::Pending
488 } 503 }
489 } 504 }