cargo_airbender/commands/new/
template.rs1use super::profiles::ProverBackendProfile;
2use crate::cli::NewAllocatorArg;
3use crate::error::{CliError, Result};
4use airbender_build::{DEFAULT_GUEST_TARGET, DEFAULT_GUEST_TOOLCHAIN};
5use serde::Serialize;
6use std::fs;
7use std::path::Path;
8use tera::{Context, Tera};
9
10const GITIGNORE_TEMPLATE: &str = include_str!("../../../templates/.gitignore.template");
11const ROOT_README_TEMPLATE: &str = include_str!("../../../templates/README.md.template");
12const GUEST_CARGO_TEMPLATE: &str = include_str!("../../../templates/guest/Cargo.toml.template");
13const GUEST_MAIN_TEMPLATE: &str = include_str!("../../../templates/guest/src/main.rs.template");
14const GUEST_TOOLCHAIN_TEMPLATE: &str =
15 include_str!("../../../templates/guest/rust-toolchain.toml.template");
16const GUEST_CARGO_CONFIG_TEMPLATE: &str =
17 include_str!("../../../templates/guest/.cargo/config.toml.template");
18const HOST_CARGO_TEMPLATE: &str = include_str!("../../../templates/host/Cargo.toml.template");
19const HOST_TOOLCHAIN_TEMPLATE: &str =
20 include_str!("../../../templates/host/rust-toolchain.toml.template");
21const CUSTOM_ALLOCATOR_MODULE_TEMPLATE: &str =
22 include_str!("../../../templates/snippets/custom_allocator_module.rs.template");
23
/// A single file of the generated project, paired with the template it is rendered from.
#[derive(Clone, Copy)]
struct TemplateFile<'a> {
    /// Destination path relative to the project root. It doubles as the template's
    /// name when registered with Tera.
    relative_path: &'static str,
    /// Raw, unrendered Tera template text.
    source: &'a str,
}
29
30pub(super) struct TemplateContext<'a> {
31 project_name: &'a str,
32 sdk_dependency: &'a str,
33 host_dependency: &'a str,
34 enable_std: bool,
35 allocator: NewAllocatorArg,
36 host_dependency_features: &'a str,
37 readme_prover_backend_doc: &'a str,
38}
39
40#[derive(Serialize)]
41struct TemplateData {
42 project_name: String,
43 sdk_dep: String,
44 sdk_default_features: String,
45 sdk_features: String,
46 host_dep: String,
47 host_dep_features: String,
48 prover_backend_doc: String,
49 guest_attributes: String,
50 main_attr_args: String,
51 custom_allocator_block: String,
52 rust_toolchain_channel: String,
53 guest_target: String,
54}
55
56impl<'a> TemplateContext<'a> {
57 pub(super) fn new(
58 project_name: &'a str,
59 sdk_dependency: &'a str,
60 host_dependency: &'a str,
61 enable_std: bool,
62 allocator: NewAllocatorArg,
63 host_dependency_features: &'a str,
64 readme_prover_backend_doc: &'a str,
65 ) -> Self {
66 Self {
67 project_name,
68 sdk_dependency,
69 host_dependency,
70 enable_std,
71 allocator,
72 host_dependency_features,
73 readme_prover_backend_doc,
74 }
75 }
76
77 fn into_template_data(self) -> TemplateData {
78 TemplateData {
79 project_name: self.project_name.to_string(),
80 sdk_dep: self.sdk_dependency.to_string(),
81 sdk_default_features: sdk_default_features(self.allocator).to_string(),
82 sdk_features: sdk_features(self.enable_std, self.allocator),
83 host_dep: self.host_dependency.to_string(),
84 host_dep_features: self.host_dependency_features.to_string(),
85 prover_backend_doc: self.readme_prover_backend_doc.to_string(),
86 guest_attributes: guest_attributes(self.enable_std).to_string(),
87 main_attr_args: main_attr_args(self.allocator).to_string(),
88 custom_allocator_block: custom_allocator_block(self.allocator),
89 rust_toolchain_channel: DEFAULT_GUEST_TOOLCHAIN.to_string(),
90 guest_target: DEFAULT_GUEST_TARGET.to_string(),
91 }
92 }
93}
94
95pub(super) fn write_templates(
96 destination_root: &Path,
97 context: TemplateContext<'_>,
98 profile: ProverBackendProfile,
99) -> Result<()> {
100 let template_data = context.into_template_data();
101 let template_context = Context::from_serialize(&template_data)
102 .map_err(|err| CliError::with_source("failed to build template context", err))?;
103 let template_renderer = template_renderer(profile)?;
104
105 for template in template_files(profile) {
106 let destination_path = destination_root.join(template.relative_path);
107 if let Some(parent) = destination_path.parent() {
108 fs::create_dir_all(parent).map_err(|err| {
109 CliError::with_source(
110 format!("failed to create directory `{}`", parent.display()),
111 err,
112 )
113 })?;
114 }
115
116 let rendered = render_template(
117 &template_renderer,
118 template.relative_path,
119 &template_context,
120 )?;
121
122 fs::write(&destination_path, rendered).map_err(|err| {
123 CliError::with_source(
124 format!("failed to write `{}`", destination_path.display()),
125 err,
126 )
127 })?;
128 }
129
130 Ok(())
131}
132
133fn template_files(profile: ProverBackendProfile) -> [TemplateFile<'static>; 9] {
134 [
135 TemplateFile {
136 relative_path: ".gitignore",
137 source: GITIGNORE_TEMPLATE,
138 },
139 TemplateFile {
140 relative_path: "README.md",
141 source: ROOT_README_TEMPLATE,
142 },
143 TemplateFile {
144 relative_path: "guest/Cargo.toml",
145 source: GUEST_CARGO_TEMPLATE,
146 },
147 TemplateFile {
148 relative_path: "guest/src/main.rs",
149 source: GUEST_MAIN_TEMPLATE,
150 },
151 TemplateFile {
152 relative_path: "guest/rust-toolchain.toml",
153 source: GUEST_TOOLCHAIN_TEMPLATE,
154 },
155 TemplateFile {
156 relative_path: "guest/.cargo/config.toml",
157 source: GUEST_CARGO_CONFIG_TEMPLATE,
158 },
159 TemplateFile {
160 relative_path: "host/Cargo.toml",
161 source: HOST_CARGO_TEMPLATE,
162 },
163 TemplateFile {
164 relative_path: "host/src/main.rs",
165 source: profile.host_main_template,
166 },
167 TemplateFile {
168 relative_path: "host/rust-toolchain.toml",
169 source: HOST_TOOLCHAIN_TEMPLATE,
170 },
171 ]
172}
173
174fn template_renderer(profile: ProverBackendProfile) -> Result<Tera> {
175 let mut tera = Tera::default();
176 for template in template_files(profile) {
177 tera.add_raw_template(template.relative_path, template.source)
178 .map_err(|err| {
179 CliError::with_source(
180 format!("failed to parse template `{}`", template.relative_path),
181 err,
182 )
183 })?;
184 }
185 Ok(tera)
186}
187
188fn render_template(
189 template_renderer: &Tera,
190 relative_path: &str,
191 context: &Context,
192) -> Result<String> {
193 template_renderer
194 .render(relative_path, context)
195 .map_err(|err| {
196 CliError::with_source(format!("failed to render template `{relative_path}`"), err)
197 })
198}
199
/// Crate-level (inner) attributes for the guest entry point.
///
/// The guest is always `no_main`. It is additionally `no_std` unless std support
/// was requested.
fn guest_attributes(enable_std: bool) -> &'static str {
    match enable_std {
        true => "#![no_main]",
        false => "#![no_std]\n#![no_main]",
    }
}
207
208fn sdk_default_features(allocator: NewAllocatorArg) -> &'static str {
209 match allocator {
210 NewAllocatorArg::Talc => "",
211 NewAllocatorArg::Bump | NewAllocatorArg::Custom => ", default-features = false",
212 }
213}
214
215fn sdk_features(enable_std: bool, allocator: NewAllocatorArg) -> String {
216 let mut sdk_feature_flags = Vec::new();
217 if enable_std {
218 sdk_feature_flags.push("std");
219 }
220 match allocator {
221 NewAllocatorArg::Talc => {}
222 NewAllocatorArg::Bump => sdk_feature_flags.push("allocator-bump"),
223 NewAllocatorArg::Custom => sdk_feature_flags.push("allocator-custom"),
224 }
225
226 if sdk_feature_flags.is_empty() {
227 return String::new();
228 }
229
230 let rendered = sdk_feature_flags
231 .iter()
232 .map(|flag| format!("\"{flag}\""))
233 .collect::<Vec<_>>()
234 .join(", ");
235 format!(", features = [{rendered}]")
236}
237
238fn main_attr_args(allocator: NewAllocatorArg) -> &'static str {
239 match allocator {
240 NewAllocatorArg::Custom => "(allocator_init = crate::custom_allocator::init)",
241 NewAllocatorArg::Talc | NewAllocatorArg::Bump => "",
242 }
243}
244
245fn custom_allocator_block(allocator: NewAllocatorArg) -> String {
246 match allocator {
247 NewAllocatorArg::Custom => {
248 format!("\n\n{}", CUSTOM_ALLOCATOR_MODULE_TEMPLATE.trim_end())
249 }
250 NewAllocatorArg::Talc | NewAllocatorArg::Bump => String::new(),
251 }
252}