aboutsummaryrefslogtreecommitdiff
path: root/embassy-executor-macros
diff options
context:
space:
mode:
Diffstat (limited to 'embassy-executor-macros')
-rw-r--r--embassy-executor-macros/src/macros/task.rs78
1 files changed, 52 insertions, 26 deletions
diff --git a/embassy-executor-macros/src/macros/task.rs b/embassy-executor-macros/src/macros/task.rs
index f01cc3b6c..5b360b128 100644
--- a/embassy-executor-macros/src/macros/task.rs
+++ b/embassy-executor-macros/src/macros/task.rs
@@ -112,25 +112,11 @@ pub fn run(args: TokenStream, item: TokenStream) -> TokenStream {
112 } 112 }
113 } 113 }
114 114
115 let task_ident = f.sig.ident.clone(); 115 // Copy the generics + where clause to avoid more spurious errors.
116 let task_inner_ident = format_ident!("__{}_task", task_ident); 116 let generics = &f.sig.generics;
117 117 let where_clause = &f.sig.generics.where_clause;
118 let mut task_inner = f.clone(); 118 let unsafety = &f.sig.unsafety;
119 let visibility = task_inner.vis.clone(); 119 let visibility = &f.vis;
120 task_inner.vis = syn::Visibility::Inherited;
121 task_inner.sig.ident = task_inner_ident.clone();
122
123 // Forcefully mark the inner task as safe.
124 // SAFETY: We only ever call task_inner in functions
125 // with the same safety preconditions as task_inner
126 task_inner.sig.unsafety = None;
127 let task_body = task_inner.body;
128 task_inner.body = quote! {
129 #[allow(unused_unsafe, reason = "Not all function bodies may require being in an unsafe block")]
130 unsafe {
131 #task_body
132 }
133 };
134 120
135 // assemble the original input arguments, 121 // assemble the original input arguments,
136 // including any attributes that may have 122 // including any attributes that may have
@@ -143,6 +129,51 @@ pub fn run(args: TokenStream, item: TokenStream) -> TokenStream {
143 )); 129 ));
144 } 130 }
145 131
132 let task_ident = f.sig.ident.clone();
133 let task_inner_ident = format_ident!("__{}_task", task_ident);
134
135 let task_inner_future_output = match &f.sig.output {
136 ReturnType::Default => quote! {-> impl ::core::future::Future<Output = ()>},
137 // Special case the never type since we can't stuff it into a `impl Future<Output = !>`
138 ReturnType::Type(arrow, maybe_never) if matches!(**maybe_never, Type::Never(_)) => quote! {
139 #arrow #maybe_never
140 },
141 // Grab the arrow span, why not
142 ReturnType::Type(arrow, typ) if f.sig.asyncness.is_some() => quote! {
143 #arrow impl ::core::future::Future<Output = #typ>
144 },
145 // We assume that if `f` isn't async, it must return `-> impl Future<...>`
146 // This is checked using traits later
147 ReturnType::Type(arrow, typ) => quote! {
148 #arrow #typ
149 },
150 };
151
152 let task_inner_body = if errors.is_empty() {
153 quote! {
154 #f
155
156 // SAFETY: All the preconditions to `#task_ident` apply to
157 // all contexts `#task_inner_ident` is called in
158 #unsafety {
159 #task_ident(#(#full_args,)*)
160 }
161 }
162 } else {
163 quote! {
164 async {::core::todo!()}
165 }
166 };
167
168 let task_inner = quote! {
169 #visibility fn #task_inner_ident #generics (#fargs)
170 #task_inner_future_output
171 #where_clause
172 {
173 #task_inner_body
174 }
175 };
176
146 let spawn = if returns_impl_trait { 177 let spawn = if returns_impl_trait {
147 quote!(spawn) 178 quote!(spawn)
148 } else { 179 } else {
@@ -185,7 +216,7 @@ pub fn run(args: TokenStream, item: TokenStream) -> TokenStream {
185 unsafe { __task_pool_get(#task_inner_ident).#spawn(move || #task_inner_ident(#(#full_args,)*)) } 216 unsafe { __task_pool_get(#task_inner_ident).#spawn(move || #task_inner_ident(#(#full_args,)*)) }
186 }; 217 };
187 218
188 let task_outer_attrs = task_inner.attrs.clone(); 219 let task_outer_attrs = &f.attrs;
189 220
190 if !errors.is_empty() { 221 if !errors.is_empty() {
191 task_outer_body = quote! { 222 task_outer_body = quote! {
@@ -195,11 +226,6 @@ pub fn run(args: TokenStream, item: TokenStream) -> TokenStream {
195 }; 226 };
196 } 227 }
197 228
198 // Copy the generics + where clause to avoid more spurious errors.
199 let generics = &f.sig.generics;
200 let where_clause = &f.sig.generics.where_clause;
201 let unsafety = &f.sig.unsafety;
202
203 let result = quote! { 229 let result = quote! {
204 // This is the user's task function, renamed. 230 // This is the user's task function, renamed.
205 // We put it outside the #task_ident fn below, because otherwise 231 // We put it outside the #task_ident fn below, because otherwise
@@ -226,7 +252,7 @@ fn check_arg_ty(errors: &mut TokenStream, ty: &Type) {
226 252
227 impl<'a, 'ast> Visit<'ast> for Visitor<'a> { 253 impl<'a, 'ast> Visit<'ast> for Visitor<'a> {
228 fn visit_type_reference(&mut self, i: &'ast syn::TypeReference) { 254 fn visit_type_reference(&mut self, i: &'ast syn::TypeReference) {
229 // only check for elided lifetime here. If not elided, it's checked by `visit_lifetime`. 255 // Only check for elided lifetime here. If not elided, it's checked by `visit_lifetime`.
230 if i.lifetime.is_none() { 256 if i.lifetime.is_none() {
231 error( 257 error(
232 self.errors, 258 self.errors,