diff options
Diffstat (limited to 'embassy-futures/src/join.rs')
| -rw-r--r-- | embassy-futures/src/join.rs | 68 |
1 files changed, 68 insertions, 0 deletions
diff --git a/embassy-futures/src/join.rs b/embassy-futures/src/join.rs index 7600d4b8a..bc0cb5303 100644 --- a/embassy-futures/src/join.rs +++ b/embassy-futures/src/join.rs | |||
| @@ -1,6 +1,7 @@ | |||
| 1 | //! Wait for multiple futures to complete. | 1 | //! Wait for multiple futures to complete. |
| 2 | 2 | ||
| 3 | use core::future::Future; | 3 | use core::future::Future; |
| 4 | use core::mem::MaybeUninit; | ||
| 4 | use core::pin::Pin; | 5 | use core::pin::Pin; |
| 5 | use core::task::{Context, Poll}; | 6 | use core::task::{Context, Poll}; |
| 6 | use core::{fmt, mem}; | 7 | use core::{fmt, mem}; |
| @@ -252,3 +253,70 @@ where | |||
| 252 | { | 253 | { |
| 253 | Join5::new(future1, future2, future3, future4, future5) | 254 | Join5::new(future1, future2, future3, future4, future5) |
| 254 | } | 255 | } |
| 256 | |||
| 257 | // ===================================================== | ||
| 258 | |||
| 259 | /// Future for the [`join_array`] function. | ||
| 260 | #[must_use = "futures do nothing unless you `.await` or poll them"] | ||
| 261 | pub struct JoinArray<Fut: Future, const N: usize> { | ||
| 262 | futures: [MaybeDone<Fut>; N], | ||
| 263 | } | ||
| 264 | |||
| 265 | impl<Fut: Future, const N: usize> fmt::Debug for JoinArray<Fut, N> | ||
| 266 | where | ||
| 267 | Fut: Future + fmt::Debug, | ||
| 268 | Fut::Output: fmt::Debug, | ||
| 269 | { | ||
| 270 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
| 271 | f.debug_struct("JoinArray").field("futures", &self.futures).finish() | ||
| 272 | } | ||
| 273 | } | ||
| 274 | |||
| 275 | impl<Fut: Future, const N: usize> Future for JoinArray<Fut, N> { | ||
| 276 | type Output = [Fut::Output; N]; | ||
| 277 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { | ||
| 278 | let this = unsafe { self.get_unchecked_mut() }; | ||
| 279 | let mut all_done = true; | ||
| 280 | for f in this.futures.iter_mut() { | ||
| 281 | all_done &= unsafe { Pin::new_unchecked(f) }.poll(cx); | ||
| 282 | } | ||
| 283 | |||
| 284 | if all_done { | ||
| 285 | let mut array: [MaybeUninit<Fut::Output>; N] = unsafe { MaybeUninit::uninit().assume_init() }; | ||
| 286 | for i in 0..N { | ||
| 287 | array[i].write(this.futures[i].take_output()); | ||
| 288 | } | ||
| 289 | Poll::Ready(unsafe { (&array as *const _ as *const [Fut::Output; N]).read() }) | ||
| 290 | } else { | ||
| 291 | Poll::Pending | ||
| 292 | } | ||
| 293 | } | ||
| 294 | } | ||
| 295 | |||
| 296 | /// Joins the result of an array of futures, waiting for them all to complete. | ||
| 297 | /// | ||
| 298 | /// This function will return a new future which awaits all futures to | ||
| 299 | /// complete. The returned future will finish with a tuple of all results. | ||
| 300 | /// | ||
| 301 | /// Note that this function consumes the passed futures and returns a | ||
| 302 | /// wrapped version of it. | ||
| 303 | /// | ||
| 304 | /// # Examples | ||
| 305 | /// | ||
| 306 | /// ``` | ||
| 307 | /// # embassy_futures::block_on(async { | ||
| 308 | /// | ||
| 309 | /// async fn foo(n: u32) -> u32 { n } | ||
| 310 | /// let a = foo(1); | ||
| 311 | /// let b = foo(2); | ||
| 312 | /// let c = foo(3); | ||
| 313 | /// let res = embassy_futures::join::join_array([a, b, c]).await; | ||
| 314 | /// | ||
| 315 | /// assert_eq!(res, [1, 2, 3]); | ||
| 316 | /// # }); | ||
| 317 | /// ``` | ||
| 318 | pub fn join_array<Fut: Future, const N: usize>(futures: [Fut; N]) -> JoinArray<Fut, N> { | ||
| 319 | JoinArray { | ||
| 320 | futures: futures.map(MaybeDone::Future), | ||
| 321 | } | ||
| 322 | } | ||
