From 1b42e624246f9355a91ef98ddf96d5af1b9b3687 Mon Sep 17 00:00:00 2001 From: Brezak Date: Wed, 23 Jul 2025 19:20:09 +0200 Subject: embassy-executor: explicitly return impl Future in task inner task --- embassy-executor-macros/src/macros/task.rs | 78 ++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 26 deletions(-) (limited to 'embassy-executor-macros') diff --git a/embassy-executor-macros/src/macros/task.rs b/embassy-executor-macros/src/macros/task.rs index f01cc3b6c..5b360b128 100644 --- a/embassy-executor-macros/src/macros/task.rs +++ b/embassy-executor-macros/src/macros/task.rs @@ -112,25 +112,11 @@ pub fn run(args: TokenStream, item: TokenStream) -> TokenStream { } } - let task_ident = f.sig.ident.clone(); - let task_inner_ident = format_ident!("__{}_task", task_ident); - - let mut task_inner = f.clone(); - let visibility = task_inner.vis.clone(); - task_inner.vis = syn::Visibility::Inherited; - task_inner.sig.ident = task_inner_ident.clone(); - - // Forcefully mark the inner task as safe. - // SAFETY: We only ever call task_inner in functions - // with the same safety preconditions as task_inner - task_inner.sig.unsafety = None; - let task_body = task_inner.body; - task_inner.body = quote! { - #[allow(unused_unsafe, reason = "Not all function bodies may require being in an unsafe block")] - unsafe { - #task_body - } - }; + // Copy the generics + where clause to avoid more spurious errors. + let generics = &f.sig.generics; + let where_clause = &f.sig.generics.where_clause; + let unsafety = &f.sig.unsafety; + let visibility = &f.vis; // assemble the original input arguments, // including any attributes that may have @@ -143,6 +129,51 @@ pub fn run(args: TokenStream, item: TokenStream) -> TokenStream { )); } + let task_ident = f.sig.ident.clone(); + let task_inner_ident = format_ident!("__{}_task", task_ident); + + let task_inner_future_output = match &f.sig.output { + ReturnType::Default => quote! {-> impl ::core::future::Future}, + // Special case the never type since we can't stuff it into a `impl Future` + ReturnType::Type(arrow, maybe_never) if matches!(**maybe_never, Type::Never(_)) => quote! { + #arrow #maybe_never + }, + // Grab the arrow span, why not + ReturnType::Type(arrow, typ) if f.sig.asyncness.is_some() => quote! { + #arrow impl ::core::future::Future + }, + // We assume that if `f` isn't async, it must return `-> impl Future<...>` + // This is checked using traits later + ReturnType::Type(arrow, typ) => quote! { + #arrow #typ + }, + }; + + let task_inner_body = if errors.is_empty() { + quote! { + #f + + // SAFETY: All the preconditions to `#task_ident` apply to + // all contexts `#task_inner_ident` is called in + #unsafety { + #task_ident(#(#full_args,)*) + } + } + } else { + quote! { + async {::core::todo!()} + } + }; + + let task_inner = quote! { + #visibility fn #task_inner_ident #generics (#fargs) + #task_inner_future_output + #where_clause + { + #task_inner_body + } + }; + let spawn = if returns_impl_trait { quote!(spawn) } else { @@ -185,7 +216,7 @@ pub fn run(args: TokenStream, item: TokenStream) -> TokenStream { unsafe { __task_pool_get(#task_inner_ident).#spawn(move || #task_inner_ident(#(#full_args,)*)) } }; - let task_outer_attrs = task_inner.attrs.clone(); + let task_outer_attrs = &f.attrs; if !errors.is_empty() { task_outer_body = quote! { @@ -195,11 +226,6 @@ pub fn run(args: TokenStream, item: TokenStream) -> TokenStream { }; } - // Copy the generics + where clause to avoid more spurious errors. - let generics = &f.sig.generics; - let where_clause = &f.sig.generics.where_clause; - let unsafety = &f.sig.unsafety; - let result = quote! { // This is the user's task function, renamed. // We put it outside the #task_ident fn below, because otherwise @@ -226,7 +252,7 @@ fn check_arg_ty(errors: &mut TokenStream, ty: &Type) { impl<'a, 'ast> Visit<'ast> for Visitor<'a> { fn visit_type_reference(&mut self, i: &'ast syn::TypeReference) { - // only check for elided lifetime here. If not elided, it's checked by `visit_lifetime`. + // Only check for elided lifetime here. If not elided, it's checked by `visit_lifetime`. if i.lifetime.is_none() { error( self.errors, -- cgit