1#![doc = include_str!("../README.md")]
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{
6 parse::{Parse, ParseStream},
7 parse_macro_input,
8 spanned::Spanned,
9 Error, ItemFn, ReturnType, Token, Type,
10};
11
12struct MainArgs {
13 allocator_init: Option<syn::Path>,
14}
15
16impl Parse for MainArgs {
17 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
18 if input.is_empty() {
19 return Ok(Self {
20 allocator_init: None,
21 });
22 }
23
24 let key: syn::Ident = input.parse()?;
25 if key != "allocator_init" {
26 return Err(Error::new(
27 key.span(),
28 "unsupported argument; expected `allocator_init = <path>`",
29 ));
30 }
31 input.parse::<Token![=]>()?;
32 let allocator_init = input.parse::<syn::Path>()?;
33
34 if !input.is_empty() {
35 return Err(Error::new(
36 input.span(),
37 "unexpected trailing tokens in `airbender::main` arguments",
38 ));
39 }
40
41 Ok(Self {
42 allocator_init: Some(allocator_init),
43 })
44 }
45}
46
47#[proc_macro_attribute]
48pub fn main(attr: TokenStream, item: TokenStream) -> TokenStream {
49 let args = parse_macro_input!(attr as MainArgs);
50 let input = parse_macro_input!(item as ItemFn);
51 if !input.sig.inputs.is_empty() {
52 return syn::Error::new(
53 input.sig.inputs.span(),
54 "airbender::main does not accept arguments",
55 )
56 .to_compile_error()
57 .into();
58 }
59 if input.sig.asyncness.is_some() {
60 return syn::Error::new(
61 input.sig.asyncness.span(),
62 "airbender::main cannot be async",
63 )
64 .to_compile_error()
65 .into();
66 }
67 if let ReturnType::Type(_, ty) = &input.sig.output {
68 if matches!(**ty, Type::Never(_)) {
69 return syn::Error::new(
70 ty.span(),
71 "airbender::main must return a value implementing Commit (use () if needed)",
72 )
73 .to_compile_error()
74 .into();
75 }
76 }
77
78 let fn_name = &input.sig.ident;
79 let wrapper_name = syn::Ident::new(&format!("__airbender_start_{fn_name}"), fn_name.span());
80
81 let guest_entry = quote! {
82 let output = #fn_name();
83 ::airbender::guest::commit(output)
84 };
85
86 let start_call = if let Some(allocator_init) = args.allocator_init {
87 quote! {
88 ::airbender::rt::start_with_allocator_init(#allocator_init, || {
89 #guest_entry
90 })
91 }
92 } else {
93 quote! {
94 ::airbender::rt::start(|| {
95 #guest_entry
96 })
97 }
98 };
99
100 let expanded = quote! {
101 #input
102
103 #[no_mangle]
104 #[link_section = ".init.rust"]
105 #[export_name = "_start_rust"]
106 pub extern "C" fn #wrapper_name() -> ! {
107 #start_call
108 }
109 };
110
111 expanded.into()
112}