aboutsummaryrefslogtreecommitdiff
path: root/embassy-executor-macros/src/macros/task.rs
diff options
context:
space:
mode:
Diffstat (limited to 'embassy-executor-macros/src/macros/task.rs')
-rw-r--r--embassy-executor-macros/src/macros/task.rs82
1 files changed, 67 insertions, 15 deletions
diff --git a/embassy-executor-macros/src/macros/task.rs b/embassy-executor-macros/src/macros/task.rs
index 1c5e3571d..fc8673743 100644
--- a/embassy-executor-macros/src/macros/task.rs
+++ b/embassy-executor-macros/src/macros/task.rs
@@ -5,7 +5,7 @@ use darling::FromMeta;
5use proc_macro2::{Span, TokenStream}; 5use proc_macro2::{Span, TokenStream};
6use quote::{format_ident, quote}; 6use quote::{format_ident, quote};
7use syn::visit::{self, Visit}; 7use syn::visit::{self, Visit};
8use syn::{Expr, ExprLit, Lit, LitInt, ReturnType, Type}; 8use syn::{Expr, ExprLit, Lit, LitInt, ReturnType, Type, Visibility};
9 9
10use crate::util::*; 10use crate::util::*;
11 11
@@ -112,13 +112,11 @@ pub fn run(args: TokenStream, item: TokenStream) -> TokenStream {
112 } 112 }
113 } 113 }
114 114
115 let task_ident = f.sig.ident.clone(); 115 // Copy the generics + where clause to avoid more spurious errors.
116 let task_inner_ident = format_ident!("__{}_task", task_ident); 116 let generics = &f.sig.generics;
117 117 let where_clause = &f.sig.generics.where_clause;
118 let mut task_inner = f.clone(); 118 let unsafety = &f.sig.unsafety;
119 let visibility = task_inner.vis.clone(); 119 let visibility = &f.vis;
120 task_inner.vis = syn::Visibility::Inherited;
121 task_inner.sig.ident = task_inner_ident.clone();
122 120
123 // assemble the original input arguments, 121 // assemble the original input arguments,
124 // including any attributes that may have 122 // including any attributes that may have
@@ -131,6 +129,64 @@ pub fn run(args: TokenStream, item: TokenStream) -> TokenStream {
131 )); 129 ));
132 } 130 }
133 131
132 let task_ident = f.sig.ident.clone();
133 let task_inner_ident = format_ident!("__{}_task", task_ident);
134
135 let task_inner_future_output = match &f.sig.output {
136 ReturnType::Default => quote! {-> impl ::core::future::Future<Output = ()>},
137 // Special case the never type since we can't stuff it into a `impl Future<Output = !>`
138 ReturnType::Type(arrow, maybe_never)
139 if f.sig.asyncness.is_some() && matches!(**maybe_never, Type::Never(_)) =>
140 {
141 quote! {
142 #arrow impl ::core::future::Future<Output=#embassy_executor::_export::Never>
143 }
144 }
145 ReturnType::Type(arrow, maybe_never) if matches!(**maybe_never, Type::Never(_)) => quote! {
146 #arrow #maybe_never
147 },
148 // Grab the arrow span, why not
149 ReturnType::Type(arrow, typ) if f.sig.asyncness.is_some() => quote! {
150 #arrow impl ::core::future::Future<Output = #typ>
151 },
152 // We assume that if `f` isn't async, it must return `-> impl Future<...>`
153 // This is checked using traits later
154 ReturnType::Type(arrow, typ) => quote! {
155 #arrow #typ
156 },
157 };
158
159 // We have to rename the function since it might be recursive;
160 let mut task_inner_function = f.clone();
161 let task_inner_function_ident = format_ident!("__{}_task_inner_function", task_ident);
162 task_inner_function.sig.ident = task_inner_function_ident.clone();
163 task_inner_function.vis = Visibility::Inherited;
164
165 let task_inner_body = if errors.is_empty() {
166 quote! {
167 #task_inner_function
168
169 // SAFETY: All the preconditions to `#task_ident` apply to
170 // all contexts `#task_inner_ident` is called in
171 #unsafety {
172 #task_inner_function_ident(#(#full_args,)*)
173 }
174 }
175 } else {
176 quote! {
177 async {::core::todo!()}
178 }
179 };
180
181 let task_inner = quote! {
182 #visibility fn #task_inner_ident #generics (#fargs)
183 #task_inner_future_output
184 #where_clause
185 {
186 #task_inner_body
187 }
188 };
189
134 let spawn = if returns_impl_trait { 190 let spawn = if returns_impl_trait {
135 quote!(spawn) 191 quote!(spawn)
136 } else { 192 } else {
@@ -173,7 +229,7 @@ pub fn run(args: TokenStream, item: TokenStream) -> TokenStream {
173 unsafe { __task_pool_get(#task_inner_ident).#spawn(move || #task_inner_ident(#(#full_args,)*)) } 229 unsafe { __task_pool_get(#task_inner_ident).#spawn(move || #task_inner_ident(#(#full_args,)*)) }
174 }; 230 };
175 231
176 let task_outer_attrs = task_inner.attrs.clone(); 232 let task_outer_attrs = &f.attrs;
177 233
178 if !errors.is_empty() { 234 if !errors.is_empty() {
179 task_outer_body = quote! { 235 task_outer_body = quote! {
@@ -183,10 +239,6 @@ pub fn run(args: TokenStream, item: TokenStream) -> TokenStream {
183 }; 239 };
184 } 240 }
185 241
186 // Copy the generics + where clause to avoid more spurious errors.
187 let generics = &f.sig.generics;
188 let where_clause = &f.sig.generics.where_clause;
189
190 let result = quote! { 242 let result = quote! {
191 // This is the user's task function, renamed. 243 // This is the user's task function, renamed.
192 // We put it outside the #task_ident fn below, because otherwise 244 // We put it outside the #task_ident fn below, because otherwise
@@ -196,7 +248,7 @@ pub fn run(args: TokenStream, item: TokenStream) -> TokenStream {
196 #task_inner 248 #task_inner
197 249
198 #(#task_outer_attrs)* 250 #(#task_outer_attrs)*
199 #visibility fn #task_ident #generics (#fargs) -> #embassy_executor::SpawnToken<impl Sized> #where_clause{ 251 #visibility #unsafety fn #task_ident #generics (#fargs) -> #embassy_executor::SpawnToken<impl Sized> #where_clause{
200 #task_outer_body 252 #task_outer_body
201 } 253 }
202 254
@@ -213,7 +265,7 @@ fn check_arg_ty(errors: &mut TokenStream, ty: &Type) {
213 265
214 impl<'a, 'ast> Visit<'ast> for Visitor<'a> { 266 impl<'a, 'ast> Visit<'ast> for Visitor<'a> {
215 fn visit_type_reference(&mut self, i: &'ast syn::TypeReference) { 267 fn visit_type_reference(&mut self, i: &'ast syn::TypeReference) {
216 // only check for elided lifetime here. If not elided, it's checked by `visit_lifetime`. 268 // Only check for elided lifetime here. If not elided, it's checked by `visit_lifetime`.
217 if i.lifetime.is_none() { 269 if i.lifetime.is_none() {
218 error( 270 error(
219 self.errors, 271 self.errors,