Skip to main content

cargo_airbender/commands/new/template.rs

1use super::profiles::ProverBackendProfile;
2use crate::cli::NewAllocatorArg;
3use crate::error::{CliError, Result};
4use airbender_build::{DEFAULT_GUEST_TARGET, DEFAULT_GUEST_TOOLCHAIN};
5use serde::Serialize;
6use std::fs;
7use std::path::Path;
8use tera::{Context, Tera};
9
10const GITIGNORE_TEMPLATE: &str = include_str!("../../../templates/.gitignore.template");
11const ROOT_README_TEMPLATE: &str = include_str!("../../../templates/README.md.template");
12const GUEST_CARGO_TEMPLATE: &str = include_str!("../../../templates/guest/Cargo.toml.template");
13const GUEST_MAIN_TEMPLATE: &str = include_str!("../../../templates/guest/src/main.rs.template");
14const GUEST_TOOLCHAIN_TEMPLATE: &str =
15    include_str!("../../../templates/guest/rust-toolchain.toml.template");
16const GUEST_CARGO_CONFIG_TEMPLATE: &str =
17    include_str!("../../../templates/guest/.cargo/config.toml.template");
18const HOST_CARGO_TEMPLATE: &str = include_str!("../../../templates/host/Cargo.toml.template");
19const HOST_TOOLCHAIN_TEMPLATE: &str =
20    include_str!("../../../templates/host/rust-toolchain.toml.template");
21const CUSTOM_ALLOCATOR_MODULE_TEMPLATE: &str =
22    include_str!("../../../templates/snippets/custom_allocator_module.rs.template");
23
/// One file emitted by the project generator.
///
/// `relative_path` is both the destination path under the new project's
/// root and the name the template is registered under in Tera.
#[derive(Copy, Clone)]
struct TemplateFile<'a> {
    /// Destination path relative to the project root, e.g. `guest/Cargo.toml`.
    relative_path: &'static str,
    /// Raw Tera template text for this file.
    source: &'a str,
}
29
30pub(super) struct TemplateContext<'a> {
31    project_name: &'a str,
32    sdk_dependency: &'a str,
33    host_dependency: &'a str,
34    enable_std: bool,
35    allocator: NewAllocatorArg,
36    host_dependency_features: &'a str,
37    readme_prover_backend_doc: &'a str,
38}
39
40#[derive(Serialize)]
41struct TemplateData {
42    project_name: String,
43    sdk_dep: String,
44    sdk_default_features: String,
45    sdk_features: String,
46    host_dep: String,
47    host_dep_features: String,
48    prover_backend_doc: String,
49    guest_attributes: String,
50    main_attr_args: String,
51    custom_allocator_block: String,
52    rust_toolchain_channel: String,
53    guest_target: String,
54}
55
56impl<'a> TemplateContext<'a> {
57    pub(super) fn new(
58        project_name: &'a str,
59        sdk_dependency: &'a str,
60        host_dependency: &'a str,
61        enable_std: bool,
62        allocator: NewAllocatorArg,
63        host_dependency_features: &'a str,
64        readme_prover_backend_doc: &'a str,
65    ) -> Self {
66        Self {
67            project_name,
68            sdk_dependency,
69            host_dependency,
70            enable_std,
71            allocator,
72            host_dependency_features,
73            readme_prover_backend_doc,
74        }
75    }
76
77    fn into_template_data(self) -> TemplateData {
78        TemplateData {
79            project_name: self.project_name.to_string(),
80            sdk_dep: self.sdk_dependency.to_string(),
81            sdk_default_features: sdk_default_features(self.allocator).to_string(),
82            sdk_features: sdk_features(self.enable_std, self.allocator),
83            host_dep: self.host_dependency.to_string(),
84            host_dep_features: self.host_dependency_features.to_string(),
85            prover_backend_doc: self.readme_prover_backend_doc.to_string(),
86            guest_attributes: guest_attributes(self.enable_std).to_string(),
87            main_attr_args: main_attr_args(self.allocator).to_string(),
88            custom_allocator_block: custom_allocator_block(self.allocator),
89            rust_toolchain_channel: DEFAULT_GUEST_TOOLCHAIN.to_string(),
90            guest_target: DEFAULT_GUEST_TARGET.to_string(),
91        }
92    }
93}
94
95pub(super) fn write_templates(
96    destination_root: &Path,
97    context: TemplateContext<'_>,
98    profile: ProverBackendProfile,
99) -> Result<()> {
100    let template_data = context.into_template_data();
101    let template_context = Context::from_serialize(&template_data)
102        .map_err(|err| CliError::with_source("failed to build template context", err))?;
103    let template_renderer = template_renderer(profile)?;
104
105    for template in template_files(profile) {
106        let destination_path = destination_root.join(template.relative_path);
107        if let Some(parent) = destination_path.parent() {
108            fs::create_dir_all(parent).map_err(|err| {
109                CliError::with_source(
110                    format!("failed to create directory `{}`", parent.display()),
111                    err,
112                )
113            })?;
114        }
115
116        let rendered = render_template(
117            &template_renderer,
118            template.relative_path,
119            &template_context,
120        )?;
121
122        fs::write(&destination_path, rendered).map_err(|err| {
123            CliError::with_source(
124                format!("failed to write `{}`", destination_path.display()),
125                err,
126            )
127        })?;
128    }
129
130    Ok(())
131}
132
133fn template_files(profile: ProverBackendProfile) -> [TemplateFile<'static>; 9] {
134    [
135        TemplateFile {
136            relative_path: ".gitignore",
137            source: GITIGNORE_TEMPLATE,
138        },
139        TemplateFile {
140            relative_path: "README.md",
141            source: ROOT_README_TEMPLATE,
142        },
143        TemplateFile {
144            relative_path: "guest/Cargo.toml",
145            source: GUEST_CARGO_TEMPLATE,
146        },
147        TemplateFile {
148            relative_path: "guest/src/main.rs",
149            source: GUEST_MAIN_TEMPLATE,
150        },
151        TemplateFile {
152            relative_path: "guest/rust-toolchain.toml",
153            source: GUEST_TOOLCHAIN_TEMPLATE,
154        },
155        TemplateFile {
156            relative_path: "guest/.cargo/config.toml",
157            source: GUEST_CARGO_CONFIG_TEMPLATE,
158        },
159        TemplateFile {
160            relative_path: "host/Cargo.toml",
161            source: HOST_CARGO_TEMPLATE,
162        },
163        TemplateFile {
164            relative_path: "host/src/main.rs",
165            source: profile.host_main_template,
166        },
167        TemplateFile {
168            relative_path: "host/rust-toolchain.toml",
169            source: HOST_TOOLCHAIN_TEMPLATE,
170        },
171    ]
172}
173
174fn template_renderer(profile: ProverBackendProfile) -> Result<Tera> {
175    let mut tera = Tera::default();
176    for template in template_files(profile) {
177        tera.add_raw_template(template.relative_path, template.source)
178            .map_err(|err| {
179                CliError::with_source(
180                    format!("failed to parse template `{}`", template.relative_path),
181                    err,
182                )
183            })?;
184    }
185    Ok(tera)
186}
187
188fn render_template(
189    template_renderer: &Tera,
190    relative_path: &str,
191    context: &Context,
192) -> Result<String> {
193    template_renderer
194        .render(relative_path, context)
195        .map_err(|err| {
196            CliError::with_source(format!("failed to render template `{relative_path}`"), err)
197        })
198}
199
/// Crate-level inner attributes for the generated guest crate.
///
/// The guest is always `#![no_main]`. `#![no_std]` is prepended unless std
/// support was requested.
fn guest_attributes(enable_std: bool) -> &'static str {
    match enable_std {
        true => "#![no_main]",
        false => "#![no_std]\n#![no_main]",
    }
}
207
208fn sdk_default_features(allocator: NewAllocatorArg) -> &'static str {
209    match allocator {
210        NewAllocatorArg::Talc => "",
211        NewAllocatorArg::Bump | NewAllocatorArg::Custom => ", default-features = false",
212    }
213}
214
215fn sdk_features(enable_std: bool, allocator: NewAllocatorArg) -> String {
216    let mut sdk_feature_flags = Vec::new();
217    if enable_std {
218        sdk_feature_flags.push("std");
219    }
220    match allocator {
221        NewAllocatorArg::Talc => {}
222        NewAllocatorArg::Bump => sdk_feature_flags.push("allocator-bump"),
223        NewAllocatorArg::Custom => sdk_feature_flags.push("allocator-custom"),
224    }
225
226    if sdk_feature_flags.is_empty() {
227        return String::new();
228    }
229
230    let rendered = sdk_feature_flags
231        .iter()
232        .map(|flag| format!("\"{flag}\""))
233        .collect::<Vec<_>>()
234        .join(", ");
235    format!(", features = [{rendered}]")
236}
237
238fn main_attr_args(allocator: NewAllocatorArg) -> &'static str {
239    match allocator {
240        NewAllocatorArg::Custom => "(allocator_init = crate::custom_allocator::init)",
241        NewAllocatorArg::Talc | NewAllocatorArg::Bump => "",
242    }
243}
244
245fn custom_allocator_block(allocator: NewAllocatorArg) -> String {
246    match allocator {
247        NewAllocatorArg::Custom => {
248            format!("\n\n{}", CUSTOM_ALLOCATOR_MODULE_TEMPLATE.trim_end())
249        }
250        NewAllocatorArg::Talc | NewAllocatorArg::Bump => String::new(),
251    }
252}