diff options
| author | Dion Dokter <[email protected]> | 2022-06-16 22:11:29 +0200 |
|---|---|---|
| committer | Dion Dokter <[email protected]> | 2022-06-16 22:11:29 +0200 |
| commit | a614a55c7ddecc171d48c61bf9fa8c6c11ed16f4 (patch) | |
| tree | 3f6f246e8b3cd36c2e47ce072af709c65d8c4f46 | |
| parent | dfde157337b379ff0805cfe4aa5463c078ca1d41 (diff) | |
Put most behaviour one level lower (under the mutex instead of above).
Changed the PubSubBehavior to only have high level functions.
| -rw-r--r-- | embassy/src/channel/pubsub.rs | 347 |
1 files changed, 181 insertions, 166 deletions
diff --git a/embassy/src/channel/pubsub.rs b/embassy/src/channel/pubsub.rs index 20878187d..c5a8c01f8 100644 --- a/embassy/src/channel/pubsub.rs +++ b/embassy/src/channel/pubsub.rs | |||
| @@ -94,122 +94,74 @@ impl<M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usi | |||
| 94 | impl<M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> PubSubBehavior<T> | 94 | impl<M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> PubSubBehavior<T> |
| 95 | for PubSubChannel<M, T, CAP, SUBS, PUBS> | 95 | for PubSubChannel<M, T, CAP, SUBS, PUBS> |
| 96 | { | 96 | { |
| 97 | fn try_publish(&self, message: T) -> Result<(), T> { | 97 | fn get_message_with_context( |
| 98 | self.inner.lock(|inner| { | 98 | &self, |
| 99 | let mut s = inner.borrow_mut(); | 99 | next_message_id: &mut u64, |
| 100 | 100 | subscriber_index: usize, | |
| 101 | let active_subscriber_count = s.subscriber_wakers.iter().flatten().count(); | 101 | cx: Option<&mut Context<'_>>, |
| 102 | 102 | ) -> Poll<WaitResult<T>> { | |
| 103 | if active_subscriber_count == 0 { | 103 | self.inner.lock(|s| { |
| 104 | // We don't need to publish anything because there is no one to receive it | 104 | let mut s = s.borrow_mut(); |
| 105 | return Ok(()); | 105 | |
| 106 | } | 106 | // Check if we can read a message |
| 107 | 107 | match s.get_message(*next_message_id) { | |
| 108 | if s.queue.is_full() { | 108 | // Yes, so we are done polling |
| 109 | return Err(message); | 109 | Some(WaitResult::Message(message)) => { |
| 110 | } | 110 | *next_message_id += 1; |
| 111 | // We just did a check for this | 111 | Poll::Ready(WaitResult::Message(message)) |
| 112 | s.queue.push_back((message, active_subscriber_count)).ok().unwrap(); | 112 | } |
| 113 | 113 | // No, so we need to reregister our waker and sleep again | |
| 114 | s.next_message_id += 1; | 114 | None => { |
| 115 | 115 | if let Some(cx) = cx { | |
| 116 | // Wake all of the subscribers | 116 | s.register_subscriber_waker(subscriber_index, cx.waker()); |
| 117 | for active_subscriber in s.subscriber_wakers.iter_mut().flatten() { | 117 | } |
| 118 | active_subscriber.wake() | 118 | Poll::Pending |
| 119 | } | ||
| 120 | // We missed a couple of messages. We must do our internal bookkeeping and return that we lagged | ||
| 121 | Some(WaitResult::Lagged(amount)) => { | ||
| 122 | *next_message_id += amount; | ||
| 123 | Poll::Ready(WaitResult::Lagged(amount)) | ||
| 124 | } | ||
| 119 | } | 125 | } |
| 120 | |||
| 121 | Ok(()) | ||
| 122 | }) | 126 | }) |
| 123 | } | 127 | } |
| 124 | 128 | ||
| 125 | fn publish_immediate(&self, message: T) { | 129 | fn publish_with_context(&self, message: T, publisher_index: usize, cx: Option<&mut Context<'_>>) -> Result<(), T> { |
| 126 | self.inner.lock(|inner| { | 130 | self.inner.lock(|s| { |
| 127 | let mut s = inner.borrow_mut(); | 131 | let mut s = s.borrow_mut(); |
| 128 | 132 | // Try to publish the message | |
| 129 | // Make space in the queue if required | 133 | match s.try_publish(message) { |
| 130 | if s.queue.is_full() { | 134 | // We did it, we are ready |
| 131 | s.queue.pop_front(); | 135 | Ok(()) => Ok(()), |
| 132 | } | 136 | // The queue is full, so we need to reregister our waker and go to sleep |
| 133 | 137 | Err(message) => { | |
| 134 | // We are going to call something is Self again. | 138 | if let Some(cx) = cx { |
| 135 | // The lock is fine, but we need to get rid of the refcell borrow | 139 | s.register_publisher_waker(publisher_index, cx.waker()); |
| 136 | drop(s); | 140 | } |
| 137 | 141 | Err(message) | |
| 138 | // This will succeed because we made sure there is space | 142 | } |
| 139 | self.try_publish(message).ok().unwrap(); | ||
| 140 | }); | ||
| 141 | } | ||
| 142 | |||
| 143 | fn get_message(&self, message_id: u64) -> Option<WaitResult<T>> { | ||
| 144 | self.inner.lock(|inner| { | ||
| 145 | let mut s = inner.borrow_mut(); | ||
| 146 | |||
| 147 | let start_id = s.next_message_id - s.queue.len() as u64; | ||
| 148 | |||
| 149 | if message_id < start_id { | ||
| 150 | return Some(WaitResult::Lagged(start_id - message_id)); | ||
| 151 | } | ||
| 152 | |||
| 153 | let current_message_index = (message_id - start_id) as usize; | ||
| 154 | |||
| 155 | if current_message_index >= s.queue.len() { | ||
| 156 | return None; | ||
| 157 | } | ||
| 158 | |||
| 159 | // We've checked that the index is valid | ||
| 160 | let queue_item = s.queue.iter_mut().nth(current_message_index).unwrap(); | ||
| 161 | |||
| 162 | // We're reading this item, so decrement the counter | ||
| 163 | queue_item.1 -= 1; | ||
| 164 | let message = queue_item.0.clone(); | ||
| 165 | |||
| 166 | if current_message_index == 0 && queue_item.1 == 0 { | ||
| 167 | s.queue.pop_front(); | ||
| 168 | s.publisher_wakers.iter_mut().flatten().for_each(|w| w.wake()); | ||
| 169 | } | 143 | } |
| 170 | |||
| 171 | Some(WaitResult::Message(message)) | ||
| 172 | }) | ||
| 173 | } | ||
| 174 | |||
| 175 | fn register_subscriber_waker(&self, subscriber_index: usize, waker: &Waker) { | ||
| 176 | self.inner.lock(|inner| { | ||
| 177 | let mut s = inner.borrow_mut(); | ||
| 178 | s.subscriber_wakers[subscriber_index].as_mut().unwrap().register(waker); | ||
| 179 | }) | 144 | }) |
| 180 | } | 145 | } |
| 181 | 146 | ||
| 182 | fn register_publisher_waker(&self, publisher_index: usize, waker: &Waker) { | 147 | fn publish_immediate(&self, message: T) { |
| 183 | self.inner.lock(|inner| { | 148 | self.inner.lock(|s| { |
| 184 | let mut s = inner.borrow_mut(); | 149 | let mut s = s.borrow_mut(); |
| 185 | s.publisher_wakers[publisher_index].as_mut().unwrap().register(waker); | 150 | s.publish_immediate(message) |
| 186 | }) | 151 | }) |
| 187 | } | 152 | } |
| 188 | 153 | ||
| 189 | fn unregister_subscriber(&self, subscriber_index: usize, subscriber_next_message_id: u64) { | 154 | fn unregister_subscriber(&self, subscriber_index: usize, subscriber_next_message_id: u64) { |
| 190 | self.inner.lock(|inner| { | 155 | self.inner.lock(|s| { |
| 191 | let mut s = inner.borrow_mut(); | 156 | let mut s = s.borrow_mut(); |
| 192 | 157 | s.unregister_subscriber(subscriber_index, subscriber_next_message_id) | |
| 193 | // Remove the subscriber from the wakers | ||
| 194 | s.subscriber_wakers[subscriber_index] = None; | ||
| 195 | |||
| 196 | // All messages that haven't been read yet by this subscriber must have their counter decremented | ||
| 197 | let start_id = s.next_message_id - s.queue.len() as u64; | ||
| 198 | if subscriber_next_message_id >= start_id { | ||
| 199 | let current_message_index = (subscriber_next_message_id - start_id) as usize; | ||
| 200 | s.queue | ||
| 201 | .iter_mut() | ||
| 202 | .skip(current_message_index) | ||
| 203 | .for_each(|(_, counter)| *counter -= 1); | ||
| 204 | } | ||
| 205 | }) | 158 | }) |
| 206 | } | 159 | } |
| 207 | 160 | ||
| 208 | fn unregister_publisher(&self, publisher_index: usize) { | 161 | fn unregister_publisher(&self, publisher_index: usize) { |
| 209 | self.inner.lock(|inner| { | 162 | self.inner.lock(|s| { |
| 210 | let mut s = inner.borrow_mut(); | 163 | let mut s = s.borrow_mut(); |
| 211 | // Remove the publisher from the wakers | 164 | s.unregister_publisher(publisher_index) |
| 212 | s.publisher_wakers[publisher_index] = None; | ||
| 213 | }) | 165 | }) |
| 214 | } | 166 | } |
| 215 | } | 167 | } |
| @@ -241,6 +193,99 @@ impl<T: Clone, const CAP: usize, const SUBS: usize, const PUBS: usize> PubSubSta | |||
| 241 | publisher_wakers: [WAKER_INIT; PUBS], | 193 | publisher_wakers: [WAKER_INIT; PUBS], |
| 242 | } | 194 | } |
| 243 | } | 195 | } |
| 196 | |||
| 197 | fn try_publish(&mut self, message: T) -> Result<(), T> { | ||
| 198 | let active_subscriber_count = self.subscriber_wakers.iter().flatten().count(); | ||
| 199 | |||
| 200 | if active_subscriber_count == 0 { | ||
| 201 | // We don't need to publish anything because there is no one to receive it | ||
| 202 | return Ok(()); | ||
| 203 | } | ||
| 204 | |||
| 205 | if self.queue.is_full() { | ||
| 206 | return Err(message); | ||
| 207 | } | ||
| 208 | // We just did a check for this | ||
| 209 | self.queue.push_back((message, active_subscriber_count)).ok().unwrap(); | ||
| 210 | |||
| 211 | self.next_message_id += 1; | ||
| 212 | |||
| 213 | // Wake all of the subscribers | ||
| 214 | for active_subscriber in self.subscriber_wakers.iter_mut().flatten() { | ||
| 215 | active_subscriber.wake() | ||
| 216 | } | ||
| 217 | |||
| 218 | Ok(()) | ||
| 219 | } | ||
| 220 | |||
| 221 | fn publish_immediate(&mut self, message: T) { | ||
| 222 | // Make space in the queue if required | ||
| 223 | if self.queue.is_full() { | ||
| 224 | self.queue.pop_front(); | ||
| 225 | } | ||
| 226 | |||
| 227 | // This will succeed because we made sure there is space | ||
| 228 | self.try_publish(message).ok().unwrap(); | ||
| 229 | } | ||
| 230 | |||
| 231 | fn get_message(&mut self, message_id: u64) -> Option<WaitResult<T>> { | ||
| 232 | let start_id = self.next_message_id - self.queue.len() as u64; | ||
| 233 | |||
| 234 | if message_id < start_id { | ||
| 235 | return Some(WaitResult::Lagged(start_id - message_id)); | ||
| 236 | } | ||
| 237 | |||
| 238 | let current_message_index = (message_id - start_id) as usize; | ||
| 239 | |||
| 240 | if current_message_index >= self.queue.len() { | ||
| 241 | return None; | ||
| 242 | } | ||
| 243 | |||
| 244 | // We've checked that the index is valid | ||
| 245 | let queue_item = self.queue.iter_mut().nth(current_message_index).unwrap(); | ||
| 246 | |||
| 247 | // We're reading this item, so decrement the counter | ||
| 248 | queue_item.1 -= 1; | ||
| 249 | let message = queue_item.0.clone(); | ||
| 250 | |||
| 251 | if current_message_index == 0 && queue_item.1 == 0 { | ||
| 252 | self.queue.pop_front(); | ||
| 253 | self.publisher_wakers.iter_mut().flatten().for_each(|w| w.wake()); | ||
| 254 | } | ||
| 255 | |||
| 256 | Some(WaitResult::Message(message)) | ||
| 257 | } | ||
| 258 | |||
| 259 | fn register_subscriber_waker(&mut self, subscriber_index: usize, waker: &Waker) { | ||
| 260 | self.subscriber_wakers[subscriber_index] | ||
| 261 | .as_mut() | ||
| 262 | .unwrap() | ||
| 263 | .register(waker); | ||
| 264 | } | ||
| 265 | |||
| 266 | fn register_publisher_waker(&mut self, publisher_index: usize, waker: &Waker) { | ||
| 267 | self.publisher_wakers[publisher_index].as_mut().unwrap().register(waker); | ||
| 268 | } | ||
| 269 | |||
| 270 | fn unregister_subscriber(&mut self, subscriber_index: usize, subscriber_next_message_id: u64) { | ||
| 271 | // Remove the subscriber from the wakers | ||
| 272 | self.subscriber_wakers[subscriber_index] = None; | ||
| 273 | |||
| 274 | // All messages that haven't been read yet by this subscriber must have their counter decremented | ||
| 275 | let start_id = self.next_message_id - self.queue.len() as u64; | ||
| 276 | if subscriber_next_message_id >= start_id { | ||
| 277 | let current_message_index = (subscriber_next_message_id - start_id) as usize; | ||
| 278 | self.queue | ||
| 279 | .iter_mut() | ||
| 280 | .skip(current_message_index) | ||
| 281 | .for_each(|(_, counter)| *counter -= 1); | ||
| 282 | } | ||
| 283 | } | ||
| 284 | |||
| 285 | fn unregister_publisher(&mut self, publisher_index: usize) { | ||
| 286 | // Remove the publisher from the wakers | ||
| 287 | self.publisher_wakers[publisher_index] = None; | ||
| 288 | } | ||
| 244 | } | 289 | } |
| 245 | 290 | ||
| 246 | /// A subscriber to a channel | 291 | /// A subscriber to a channel |
| @@ -276,15 +321,12 @@ impl<'a, T: Clone> Subscriber<'a, T> { | |||
| 276 | /// | 321 | /// |
| 277 | /// This function does not peek. The message is received if there is one. | 322 | /// This function does not peek. The message is received if there is one. |
| 278 | pub fn try_next_message(&mut self) -> Option<WaitResult<T>> { | 323 | pub fn try_next_message(&mut self) -> Option<WaitResult<T>> { |
| 279 | match self.channel.get_message(self.next_message_id) { | 324 | match self |
| 280 | Some(WaitResult::Lagged(amount)) => { | 325 | .channel |
| 281 | self.next_message_id += amount; | 326 | .get_message_with_context(&mut self.next_message_id, self.subscriber_index, None) |
| 282 | Some(WaitResult::Lagged(amount)) | 327 | { |
| 283 | } | 328 | Poll::Ready(result) => Some(result), |
| 284 | result => { | 329 | Poll::Pending => None, |
| 285 | self.next_message_id += 1; | ||
| 286 | result | ||
| 287 | } | ||
| 288 | } | 330 | } |
| 289 | } | 331 | } |
| 290 | 332 | ||
| @@ -317,26 +359,16 @@ impl<'a, T: Clone> futures::Stream for Subscriber<'a, T> { | |||
| 317 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { | 359 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { |
| 318 | let this = unsafe { self.get_unchecked_mut() }; | 360 | let this = unsafe { self.get_unchecked_mut() }; |
| 319 | 361 | ||
| 320 | // Check if we can read a message | 362 | match this |
| 321 | match this.channel.get_message(this.next_message_id) { | 363 | .channel |
| 322 | // Yes, so we are done polling | 364 | .get_message_with_context(&mut this.next_message_id, this.subscriber_index, Some(cx)) |
| 323 | Some(WaitResult::Message(message)) => { | 365 | { |
| 324 | this.next_message_id += 1; | 366 | Poll::Ready(WaitResult::Message(message)) => Poll::Ready(Some(message)), |
| 325 | Poll::Ready(Some(message)) | 367 | Poll::Ready(WaitResult::Lagged(_)) => { |
| 326 | } | ||
| 327 | // No, so we need to reregister our waker and sleep again | ||
| 328 | None => { | ||
| 329 | this.channel | ||
| 330 | .register_subscriber_waker(this.subscriber_index, cx.waker()); | ||
| 331 | Poll::Pending | ||
| 332 | } | ||
| 333 | // We missed a couple of messages. We must do our internal bookkeeping. | ||
| 334 | // This stream impl doesn't return lag results, so we just ignore and start over | ||
| 335 | Some(WaitResult::Lagged(amount)) => { | ||
| 336 | this.next_message_id += amount; | ||
| 337 | cx.waker().wake_by_ref(); | 368 | cx.waker().wake_by_ref(); |
| 338 | Poll::Pending | 369 | Poll::Pending |
| 339 | } | 370 | } |
| 371 | Poll::Pending => Poll::Pending, | ||
| 340 | } | 372 | } |
| 341 | } | 373 | } |
| 342 | } | 374 | } |
| @@ -369,7 +401,7 @@ impl<'a, T: Clone> Publisher<'a, T> { | |||
| 369 | 401 | ||
| 370 | /// Publish a message if there is space in the message queue | 402 | /// Publish a message if there is space in the message queue |
| 371 | pub fn try_publish(&self, message: T) -> Result<(), T> { | 403 | pub fn try_publish(&self, message: T) -> Result<(), T> { |
| 372 | self.channel.try_publish(message) | 404 | self.channel.publish_with_context(message, self.publisher_index, None) |
| 373 | } | 405 | } |
| 374 | } | 406 | } |
| 375 | 407 | ||
| @@ -395,7 +427,7 @@ impl<'a, T: Clone> ImmediatePublisher<'a, T> { | |||
| 395 | 427 | ||
| 396 | /// Publish a message if there is space in the message queue | 428 | /// Publish a message if there is space in the message queue |
| 397 | pub fn try_publish(&self, message: T) -> Result<(), T> { | 429 | pub fn try_publish(&self, message: T) -> Result<(), T> { |
| 398 | self.channel.try_publish(message) | 430 | self.channel.publish_with_context(message, usize::MAX, None) |
| 399 | } | 431 | } |
| 400 | } | 432 | } |
| 401 | 433 | ||
| @@ -411,19 +443,19 @@ pub enum Error { | |||
| 411 | } | 443 | } |
| 412 | 444 | ||
| 413 | trait PubSubBehavior<T> { | 445 | trait PubSubBehavior<T> { |
| 414 | /// Try to publish a message. If the queue is full it won't succeed | 446 | fn get_message_with_context( |
| 415 | fn try_publish(&self, message: T) -> Result<(), T>; | 447 | &self, |
| 416 | /// Publish a message immediately. If the queue is full, just throw out the oldest one. | 448 | next_message_id: &mut u64, |
| 449 | subscriber_index: usize, | ||
| 450 | cx: Option<&mut Context<'_>>, | ||
| 451 | ) -> Poll<WaitResult<T>>; | ||
| 452 | |||
| 453 | fn publish_with_context(&self, message: T, publisher_index: usize, cx: Option<&mut Context<'_>>) -> Result<(), T>; | ||
| 454 | |||
| 417 | fn publish_immediate(&self, message: T); | 455 | fn publish_immediate(&self, message: T); |
| 418 | /// Tries to read the message if available | 456 | |
| 419 | fn get_message(&self, message_id: u64) -> Option<WaitResult<T>>; | ||
| 420 | /// Register the given waker for the given subscriber. | ||
| 421 | fn register_subscriber_waker(&self, subscriber_index: usize, waker: &Waker); | ||
| 422 | /// Register the given waker for the given publisher. | ||
| 423 | fn register_publisher_waker(&self, publisher_index: usize, waker: &Waker); | ||
| 424 | /// Make the channel forget the subscriber. | ||
| 425 | fn unregister_subscriber(&self, subscriber_index: usize, subscriber_next_message_id: u64); | 457 | fn unregister_subscriber(&self, subscriber_index: usize, subscriber_next_message_id: u64); |
| 426 | /// Make the channel forget the publisher. | 458 | |
| 427 | fn unregister_publisher(&self, publisher_index: usize); | 459 | fn unregister_publisher(&self, publisher_index: usize); |
| 428 | } | 460 | } |
| 429 | 461 | ||
| @@ -436,26 +468,10 @@ impl<'s, 'a, T: Clone> Future for SubscriberWaitFuture<'s, 'a, T> { | |||
| 436 | type Output = WaitResult<T>; | 468 | type Output = WaitResult<T>; |
| 437 | 469 | ||
| 438 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { | 470 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
| 439 | // Check if we can read a message | 471 | let sub_index = self.subscriber.subscriber_index; |
| 440 | match self.subscriber.channel.get_message(self.subscriber.next_message_id) { | 472 | self.subscriber |
| 441 | // Yes, so we are done polling | 473 | .channel |
| 442 | Some(WaitResult::Message(message)) => { | 474 | .get_message_with_context(&mut self.subscriber.next_message_id, sub_index, Some(cx)) |
| 443 | self.subscriber.next_message_id += 1; | ||
| 444 | Poll::Ready(WaitResult::Message(message)) | ||
| 445 | } | ||
| 446 | // No, so we need to reregister our waker and sleep again | ||
| 447 | None => { | ||
| 448 | self.subscriber | ||
| 449 | .channel | ||
| 450 | .register_subscriber_waker(self.subscriber.subscriber_index, cx.waker()); | ||
| 451 | Poll::Pending | ||
| 452 | } | ||
| 453 | // We missed a couple of messages. We must do our internal bookkeeping and return that we lagged | ||
| 454 | Some(WaitResult::Lagged(amount)) => { | ||
| 455 | self.subscriber.next_message_id += amount; | ||
| 456 | Poll::Ready(WaitResult::Lagged(amount)) | ||
| 457 | } | ||
| 458 | } | ||
| 459 | } | 475 | } |
| 460 | } | 476 | } |
| 461 | 477 | ||
| @@ -474,16 +490,15 @@ impl<'s, 'a, T: Clone> Future for PublisherWaitFuture<'s, 'a, T> { | |||
| 474 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { | 490 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
| 475 | let this = unsafe { self.get_unchecked_mut() }; | 491 | let this = unsafe { self.get_unchecked_mut() }; |
| 476 | 492 | ||
| 477 | // Try to publish the message | 493 | let message = this.message.take().unwrap(); |
| 478 | match this.publisher.channel.try_publish(this.message.take().unwrap()) { | 494 | match this |
| 479 | // We did it, we are ready | 495 | .publisher |
| 496 | .channel | ||
| 497 | .publish_with_context(message, this.publisher.publisher_index, Some(cx)) | ||
| 498 | { | ||
| 480 | Ok(()) => Poll::Ready(()), | 499 | Ok(()) => Poll::Ready(()), |
| 481 | // The queue is full, so we need to reregister our waker and go to sleep | ||
| 482 | Err(message) => { | 500 | Err(message) => { |
| 483 | this.message = Some(message); | 501 | this.message = Some(message); |
| 484 | this.publisher | ||
| 485 | .channel | ||
| 486 | .register_publisher_waker(this.publisher.publisher_index, cx.waker()); | ||
| 487 | Poll::Pending | 502 | Poll::Pending |
| 488 | } | 503 | } |
| 489 | } | 504 | } |
