aboutsummaryrefslogtreecommitdiff
path: root/embassy-executor-macros/src/macros/task.rs
blob: 5b360b128f33e760e1110de51c6f2b34a2f4066e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
use std::str::FromStr;

use darling::export::NestedMeta;
use darling::FromMeta;
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote};
use syn::visit::{self, Visit};
use syn::{Expr, ExprLit, Lit, LitInt, ReturnType, Type};

use crate::util::*;

/// Arguments accepted by the task attribute macro, parsed from the attribute's
/// argument list with `darling` (e.g. `pool_size = 4`).
#[derive(Debug, FromMeta, Default)]
struct Args {
    /// Size of the static task pool generated for this task; it is emitted as
    /// `const POOL_SIZE: usize = <expr>;`, so any `usize` const expression is accepted.
    /// Presumably the number of instances that can be spawned at once. Defaults to `1`.
    #[darling(default)]
    pool_size: Option<syn::Expr>,
    /// Use this to override the `embassy_executor` crate path. Defaults to `::embassy_executor`.
    #[darling(default)]
    embassy_executor: Option<syn::Expr>,
}

pub fn run(args: TokenStream, item: TokenStream) -> TokenStream {
    let mut errors = TokenStream::new();

    // If any of the steps for this macro fail, we still want to expand to an item that is as close
    // to the expected output as possible. This helps out IDEs such that completions and other
    // related features keep working.
    let f: ItemFn = match syn::parse2(item.clone()) {
        Ok(x) => x,
        Err(e) => return token_stream_with_error(item, e),
    };

    let args = match NestedMeta::parse_meta_list(args) {
        Ok(x) => x,
        Err(e) => return token_stream_with_error(item, e),
    };

    let args = match Args::from_list(&args) {
        Ok(x) => x,
        Err(e) => {
            errors.extend(e.write_errors());
            Args::default()
        }
    };

    let pool_size = args.pool_size.unwrap_or(Expr::Lit(ExprLit {
        attrs: vec![],
        lit: Lit::Int(LitInt::new("1", Span::call_site())),
    }));

    let embassy_executor = args
        .embassy_executor
        .unwrap_or(Expr::Verbatim(TokenStream::from_str("::embassy_executor").unwrap()));

    let returns_impl_trait = match &f.sig.output {
        ReturnType::Type(_, ty) => matches!(**ty, Type::ImplTrait(_)),
        _ => false,
    };
    if f.sig.asyncness.is_none() && !returns_impl_trait {
        error(&mut errors, &f.sig, "task functions must be async");
    }
    if !f.sig.generics.params.is_empty() {
        error(&mut errors, &f.sig, "task functions must not be generic");
    }
    if !f.sig.generics.where_clause.is_none() {
        error(&mut errors, &f.sig, "task functions must not have `where` clauses");
    }
    if !f.sig.abi.is_none() {
        error(&mut errors, &f.sig, "task functions must not have an ABI qualifier");
    }
    if !f.sig.variadic.is_none() {
        error(&mut errors, &f.sig, "task functions must not be variadic");
    }
    if f.sig.asyncness.is_some() {
        match &f.sig.output {
            ReturnType::Default => {}
            ReturnType::Type(_, ty) => match &**ty {
                Type::Tuple(tuple) if tuple.elems.is_empty() => {}
                Type::Never(_) => {}
                _ => error(
                    &mut errors,
                    &f.sig,
                    "task functions must either not return a value, return `()` or return `!`",
                ),
            },
        }
    }

    let mut args = Vec::new();
    let mut fargs = f.sig.inputs.clone();

    for arg in fargs.iter_mut() {
        match arg {
            syn::FnArg::Receiver(_) => {
                error(&mut errors, arg, "task functions must not have `self` arguments");
            }
            syn::FnArg::Typed(t) => {
                check_arg_ty(&mut errors, &t.ty);
                match t.pat.as_mut() {
                    syn::Pat::Ident(id) => {
                        id.mutability = None;
                        args.push((id.clone(), t.attrs.clone()));
                    }
                    _ => {
                        error(
                            &mut errors,
                            arg,
                            "pattern matching in task arguments is not yet supported",
                        );
                    }
                }
            }
        }
    }

    // Copy the generics + where clause to avoid more spurious errors.
    let generics = &f.sig.generics;
    let where_clause = &f.sig.generics.where_clause;
    let unsafety = &f.sig.unsafety;
    let visibility = &f.vis;

    // assemble the original input arguments,
    // including any attributes that may have
    // been applied previously
    let mut full_args = Vec::new();
    for (arg, cfgs) in args {
        full_args.push(quote!(
            #(#cfgs)*
            #arg
        ));
    }

    let task_ident = f.sig.ident.clone();
    let task_inner_ident = format_ident!("__{}_task", task_ident);

    let task_inner_future_output = match &f.sig.output {
        ReturnType::Default => quote! {-> impl ::core::future::Future<Output = ()>},
        // Special case the never type since we can't stuff it into a `impl Future<Output = !>`
        ReturnType::Type(arrow, maybe_never) if matches!(**maybe_never, Type::Never(_)) => quote! {
            #arrow #maybe_never
        },
        // Grab the arrow span, why not
        ReturnType::Type(arrow, typ) if f.sig.asyncness.is_some() => quote! {
            #arrow impl ::core::future::Future<Output = #typ>
        },
        // We assume that if `f` isn't async, it must return `-> impl Future<...>`
        // This is checked using traits later
        ReturnType::Type(arrow, typ) => quote! {
            #arrow #typ
        },
    };

    let task_inner_body = if errors.is_empty() {
        quote! {
            #f

            // SAFETY: All the preconditions to `#task_ident` apply to
            //         all contexts `#task_inner_ident` is called in
            #unsafety {
                #task_ident(#(#full_args,)*)
            }
        }
    } else {
        quote! {
            async {::core::todo!()}
        }
    };

    let task_inner = quote! {
        #visibility fn #task_inner_ident #generics (#fargs)
        #task_inner_future_output
        #where_clause
        {
            #task_inner_body
        }
    };

    let spawn = if returns_impl_trait {
        quote!(spawn)
    } else {
        quote!(_spawn_async_fn)
    };

    #[cfg(feature = "nightly")]
    let mut task_outer_body = quote! {
        trait _EmbassyInternalTaskTrait {
            type Fut: ::core::future::Future<Output: #embassy_executor::_export::TaskReturnValue> + 'static;
            fn construct(#fargs) -> Self::Fut;
        }

        impl _EmbassyInternalTaskTrait for () {
            type Fut = impl core::future::Future<Output: #embassy_executor::_export::TaskReturnValue> + 'static;
            fn construct(#fargs) -> Self::Fut {
                #task_inner_ident(#(#full_args,)*)
            }
        }

        const POOL_SIZE: usize = #pool_size;
        static POOL: #embassy_executor::raw::TaskPool<<() as _EmbassyInternalTaskTrait>::Fut, POOL_SIZE> = #embassy_executor::raw::TaskPool::new();
        unsafe { POOL.#spawn(move || <() as _EmbassyInternalTaskTrait>::construct(#(#full_args,)*)) }
    };
    #[cfg(not(feature = "nightly"))]
    let mut task_outer_body = quote! {
        const fn __task_pool_get<F, Args, Fut>(_: F) -> &'static #embassy_executor::raw::TaskPool<Fut, POOL_SIZE>
        where
            F: #embassy_executor::_export::TaskFn<Args, Fut = Fut>,
            Fut: ::core::future::Future + 'static,
        {
            unsafe { &*POOL.get().cast() }
        }

        const POOL_SIZE: usize = #pool_size;
        static POOL: #embassy_executor::_export::TaskPoolHolder<
            {#embassy_executor::_export::task_pool_size::<_, _, _, POOL_SIZE>(#task_inner_ident)},
            {#embassy_executor::_export::task_pool_align::<_, _, _, POOL_SIZE>(#task_inner_ident)},
        > = unsafe { ::core::mem::transmute(#embassy_executor::_export::task_pool_new::<_, _, _, POOL_SIZE>(#task_inner_ident)) };
        unsafe { __task_pool_get(#task_inner_ident).#spawn(move || #task_inner_ident(#(#full_args,)*)) }
    };

    let task_outer_attrs = &f.attrs;

    if !errors.is_empty() {
        task_outer_body = quote! {
            #![allow(unused_variables, unreachable_code)]
            let _x: #embassy_executor::SpawnToken<()> = ::core::todo!();
            _x
        };
    }

    let result = quote! {
        // This is the user's task function, renamed.
        // We put it outside the #task_ident fn below, because otherwise
        // the items defined there (such as POOL) would be in scope
        // in the user's code.
        #[doc(hidden)]
        #task_inner

        #(#task_outer_attrs)*
        #visibility #unsafety fn #task_ident #generics (#fargs) -> #embassy_executor::SpawnToken<impl Sized> #where_clause{
            #task_outer_body
        }

        #errors
    };

    result
}

/// Reports, into `errors`, parts of a task argument type that prevent it from being `'static`.
///
/// The generated task future is required to be `'static`, so every argument it captures must be
/// too. This walks `ty` and emits an error for:
/// - references with an elided lifetime (`&T`),
/// - any named lifetime other than `'static` (`&'a T`, `Foo<'a>`, `dyn Trait + 'a`),
/// - `impl Trait`, which in argument position would make the task generic.
///
/// Higher-ranked lifetimes do not restrict how long the argument itself lives, so they are not
/// reported. This covers everything inside a function-pointer type (`fn(&str)`), everything
/// inside the `Fn(A) -> B` sugar of closure traits (`&'static dyn Fn(&str)`), and lifetimes
/// bound by a `for<'a>` binder on a trait bound.
fn check_arg_ty(errors: &mut TokenStream, ty: &Type) {
    struct Visitor<'a> {
        errors: &'a mut TokenStream,
        /// Lifetime names introduced by the `for<...>` binders currently being walked.
        higher_ranked: Vec<syn::Ident>,
    }

    impl<'a, 'ast> Visit<'ast> for Visitor<'a> {
        fn visit_type_reference(&mut self, i: &'ast syn::TypeReference) {
            // Only check for elided lifetime here. If not elided, it's checked by `visit_lifetime`.
            if i.lifetime.is_none() {
                error(
                    self.errors,
                    i.and_token,
                    "Arguments for tasks must live forever. Try using the `'static` lifetime.",
                )
            }
            visit::visit_type_reference(self, i);
        }

        fn visit_lifetime(&mut self, i: &'ast syn::Lifetime) {
            if i.ident != "static" && !self.higher_ranked.contains(&i.ident) {
                error(
                    self.errors,
                    i,
                    "Arguments for tasks must live forever. Try using the `'static` lifetime.",
                )
            }
        }

        fn visit_type_impl_trait(&mut self, i: &'ast syn::TypeImplTrait) {
            error(self.errors, i, "`impl Trait` is not allowed in task arguments. It is syntax sugar for generics, and tasks can't be generic.");
        }

        // A fn pointer type is `'static` whatever lifetimes its signature mentions. Elided ones
        // are higher-ranked, and named ones can only come from its own `for<...>` binder because
        // tasks can't be generic (anything else fails to compile anyway). Don't descend into it.
        fn visit_type_bare_fn(&mut self, _i: &'ast syn::TypeBareFn) {}

        // Same reasoning for the parenthesized `Fn(A) -> B` sugar of closure traits.
        fn visit_parenthesized_generic_arguments(
            &mut self,
            _i: &'ast syn::ParenthesizedGenericArguments,
        ) {
        }

        // `for<'a> Trait<'a>`: `'a` is higher-ranked within this bound only.
        fn visit_trait_bound(&mut self, i: &'ast syn::TraitBound) {
            let outer_len = self.higher_ranked.len();
            if let Some(binder) = &i.lifetimes {
                self.higher_ranked
                    .extend(binder.lifetimes.iter().filter_map(|param| match param {
                        syn::GenericParam::Lifetime(def) => Some(def.lifetime.ident.clone()),
                        _ => None,
                    }));
            }
            visit::visit_trait_bound(self, i);
            self.higher_ranked.truncate(outer_len);
        }
    }

    Visit::visit_type(
        &mut Visitor {
            errors,
            higher_ranked: Vec::new(),
        },
        ty,
    );
}