cargo_airbender/commands/
vk.rs1use crate::cli::{GenerateVkArgs, ProverLevelArg, VerifyProofArgs};
2use crate::error::{CliError, Result};
3use crate::ui;
4use serde::{de::DeserializeOwned, Serialize};
5use std::path::Path;
6
7pub fn generate(args: GenerateVkArgs) -> Result<()> {
8 ensure_gpu_vk_support()?;
9 let security = args.security;
10 let host_security = security.into();
11
12 let vk = match args.level {
13 ProverLevelArg::RecursionUnified => {
14 let vk = airbender_host::compute_unified_vk(&args.app_bin, host_security).map_err(
15 |err| {
16 CliError::with_source(
17 format!(
18 "failed to compute unified verification keys for `{}`",
19 args.app_bin.display()
20 ),
21 err,
22 )
23 },
24 )?;
25 airbender_host::VerificationKey::RealUnified(
26 airbender_host::RealUnifiedVerificationKey { vk },
27 )
28 }
29 ProverLevelArg::Base | ProverLevelArg::RecursionUnrolled => {
30 let level = as_host_level(args.level);
31 let vk = airbender_host::compute_unrolled_vk(&args.app_bin, level, host_security)
32 .map_err(|err| {
33 CliError::with_source(
34 format!(
35 "failed to compute unrolled verification keys for `{}`",
36 args.app_bin.display()
37 ),
38 err,
39 )
40 })?;
41 airbender_host::VerificationKey::RealUnrolled(
42 airbender_host::RealUnrolledVerificationKey { level, vk },
43 )
44 }
45 };
46
47 write_bincode(&args.output, &vk)?;
48
49 ui::success("verification keys generated");
50 ui::field("level", level_name(args.level));
51 ui::field("security", security);
52 ui::field("output", args.output.display());
53
54 Ok(())
55}
56
57fn ensure_gpu_vk_support() -> Result<()> {
58 #[cfg(feature = "gpu-prover")]
59 {
60 Ok(())
61 }
62
63 #[cfg(not(feature = "gpu-prover"))]
64 {
65 Err(CliError::new(
66 "verification key generation requires GPU support in `cargo-airbender`",
67 )
68 .with_hint(
69 "rebuild `cargo-airbender` with default features or pass `--features gpu-prover` to use `generate-vk`",
70 ))
71 }
72}
73
74pub fn verify(args: VerifyProofArgs) -> Result<()> {
75 let expected_output_words = parse_expected_output_words(args.expected_output.as_deref())?;
76
77 let proof: airbender_host::Proof = read_bincode(&args.proof).map_err(|err| {
78 CliError::with_source(
79 format!("failed to decode proof from `{}`", args.proof.display()),
80 err,
81 )
82 })?;
83
84 let vk: airbender_host::VerificationKey = read_bincode(&args.vk).map_err(|err| {
85 CliError::with_source(
86 format!(
87 "failed to decode verification key file `{}`",
88 args.vk.display()
89 ),
90 err,
91 )
92 })?;
93
94 let (level, security) = match &proof {
95 airbender_host::Proof::Dev(_) => {
96 return Err(CliError::new(
97 "detected a dev proof; `cargo airbender verify-proof` supports only real proofs",
98 )
99 .with_hint(
100 "verify dev proofs through `airbender-host` with `Program::dev_verifier()`",
101 ));
102 }
103 airbender_host::Proof::Real(proof) => {
104 let expected_output_commit = expected_output_words
105 .as_ref()
106 .map(|words| words as &dyn airbender_host::Commit);
107
108 airbender_host::verify_real_proof_with_vk(proof, &vk, expected_output_commit)
109 .map_err(|err| CliError::with_source("proof verification failed", err))?;
110 (proof.level(), proof.security())
111 }
112 };
113
114 if expected_output_words.is_none() {
115 tracing::warn!("public outputs were not provided; only proof/VK validity was checked");
116 }
117
118 ui::success("proof verified");
119 ui::field("level", host_level_name(level));
120 ui::field("security", security);
121 if let Some(words) = expected_output_words {
122 ui::field("expected_output", format_output_words(&words));
123 }
124
125 Ok(())
126}
127
128fn parse_expected_output_words(raw: Option<&str>) -> Result<Option<[u32; 8]>> {
129 let Some(raw) = raw else {
130 return Ok(None);
131 };
132
133 let trimmed = raw.trim();
134 if trimmed.is_empty() {
135 return Err(CliError::new("`--expected-output` cannot be empty")
136 .with_hint("provide comma-separated u32 words, for example `--expected-output 42`"));
137 }
138
139 let parts: Vec<&str> = trimmed.split(',').collect();
140 if parts.len() > 8 {
141 return Err(CliError::new(format!(
142 "`--expected-output` accepts at most 8 words (got {})",
143 parts.len()
144 ))
145 .with_hint(
146 "provide up to 8 comma-separated values for x10..x17; missing words are zero-padded",
147 ));
148 }
149
150 let mut words = [0u32; 8];
151 for (index, token) in parts.into_iter().enumerate() {
152 let token = token.trim();
153 if token.is_empty() {
154 return Err(CliError::new(format!(
155 "found an empty word at position {} in `--expected-output`",
156 index + 1
157 ))
158 .with_hint("use comma-separated values like `42,0,0`"));
159 }
160 words[index] = parse_output_word(token, index + 1)?;
161 }
162
163 Ok(Some(words))
164}
165
166fn parse_output_word(token: &str, position: usize) -> Result<u32> {
167 if let Some(hex) = token
168 .strip_prefix("0x")
169 .or_else(|| token.strip_prefix("0X"))
170 {
171 if hex.is_empty() {
172 return Err(CliError::new(format!(
173 "failed to parse output word at position {position}: `{token}`"
174 ))
175 .with_hint("hex output words must use `0x` followed by one or more hex digits"));
176 }
177
178 return u32::from_str_radix(hex, 16)
179 .map_err(|err| {
180 CliError::with_source(
181 format!("failed to parse output word at position {position}: `{token}`"),
182 err,
183 )
184 })
185 .map_err(|err| err.with_hint("use decimal or 0x-prefixed hexadecimal u32 words"));
186 }
187
188 token
189 .parse::<u32>()
190 .map_err(|err| {
191 CliError::with_source(
192 format!("failed to parse output word at position {position}: `{token}`"),
193 err,
194 )
195 })
196 .map_err(|err| err.with_hint("use decimal or 0x-prefixed hexadecimal u32 words"))
197}
198
/// Renders the eight output words as comma-separated decimal values, e.g. `42,0,0,0,0,0,0,0`.
fn format_output_words(words: &[u32; 8]) -> String {
    let rendered: Vec<String> = words.iter().map(|word| word.to_string()).collect();
    rendered.join(",")
}
206
207fn as_host_level(level: ProverLevelArg) -> airbender_host::ProverLevel {
208 match level {
209 ProverLevelArg::Base => airbender_host::ProverLevel::Base,
210 ProverLevelArg::RecursionUnrolled => airbender_host::ProverLevel::RecursionUnrolled,
211 ProverLevelArg::RecursionUnified => airbender_host::ProverLevel::RecursionUnified,
212 }
213}
214
215fn level_name(level: ProverLevelArg) -> &'static str {
216 match level {
217 ProverLevelArg::Base => "base",
218 ProverLevelArg::RecursionUnrolled => "recursion-unrolled",
219 ProverLevelArg::RecursionUnified => "recursion-unified",
220 }
221}
222
223fn host_level_name(level: airbender_host::ProverLevel) -> &'static str {
224 match level {
225 airbender_host::ProverLevel::Base => "base",
226 airbender_host::ProverLevel::RecursionUnrolled => "recursion-unrolled",
227 airbender_host::ProverLevel::RecursionUnified => "recursion-unified",
228 }
229}
230
231fn read_bincode<T: DeserializeOwned>(path: &Path) -> Result<T> {
232 let bytes = std::fs::read(path).map_err(|err| {
233 CliError::with_source(format!("failed to read `{}`", path.display()), err)
234 })?;
235 let (decoded, read_len): (T, usize) =
236 bincode::serde::decode_from_slice(&bytes, bincode::config::standard())
237 .map_err(|err| CliError::with_source("failed to decode bincode payload", err))?;
238
239 if read_len != bytes.len() {
240 tracing::warn!(
241 "bincode decoded {} bytes but file is {} bytes",
242 read_len,
243 bytes.len()
244 );
245 }
246 Ok(decoded)
247}
248
249fn write_bincode<T: Serialize>(path: &Path, value: &T) -> Result<()> {
250 let encoded = bincode::serde::encode_to_vec(value, bincode::config::standard())
251 .map_err(|err| CliError::with_source("failed to encode bincode payload", err))?;
252 std::fs::write(path, encoded).map_err(|err| {
253 CliError::with_source(format!("failed to write `{}`", path.display()), err)
254 })?;
255
256 Ok(())
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(not(feature = "gpu-prover"))]
    use crate::cli::{GenerateVkArgs, SecurityLevelArg};
    #[cfg(not(feature = "gpu-prover"))]
    use std::path::PathBuf;

    #[cfg(not(feature = "gpu-prover"))]
    #[test]
    fn generate_vk_requires_gpu_support() {
        let err = generate(GenerateVkArgs {
            app_bin: PathBuf::from("app.bin"),
            output: PathBuf::from("vk.bin"),
            level: ProverLevelArg::Base,
            security: SecurityLevelArg::default(),
        })
        .expect_err("generate-vk must require gpu-prover support");

        assert!(
            err.to_string().contains("requires GPU support"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn parse_expected_output_none() {
        let parsed = parse_expected_output_words(None).expect("parse should succeed");
        assert!(parsed.is_none());
    }

    #[test]
    fn parse_expected_output_pads_trailing_words() {
        let parsed = parse_expected_output_words(Some("42")).expect("parse should succeed");
        assert_eq!(parsed, Some([42, 0, 0, 0, 0, 0, 0, 0]));
    }

    #[test]
    fn parse_expected_output_supports_hex_and_spaces() {
        let parsed =
            parse_expected_output_words(Some("0x2a, 0X01, 7")).expect("parse should succeed");
        assert_eq!(parsed, Some([42, 1, 7, 0, 0, 0, 0, 0]));
    }

    #[test]
    fn parse_expected_output_accepts_exactly_eight_words() {
        let parsed =
            parse_expected_output_words(Some("1,2,3,4,5,6,7,8")).expect("parse should succeed");
        assert_eq!(parsed, Some([1, 2, 3, 4, 5, 6, 7, 8]));
    }

    #[test]
    fn parse_expected_output_accepts_u32_max_in_both_radices() {
        let parsed = parse_expected_output_words(Some("0xffffffff,4294967295"))
            .expect("parse should succeed");
        assert_eq!(parsed, Some([u32::MAX, u32::MAX, 0, 0, 0, 0, 0, 0]));
    }

    #[test]
    fn parse_expected_output_rejects_too_many_words() {
        let err =
            parse_expected_output_words(Some("1,2,3,4,5,6,7,8,9")).expect_err("parse should fail");
        assert!(err.to_string().contains("at most 8 words"));
    }

    #[test]
    fn parse_expected_output_rejects_blank_input() {
        let err = parse_expected_output_words(Some("   ")).expect_err("parse should fail");
        assert!(err.to_string().contains("cannot be empty"));
    }

    #[test]
    fn parse_expected_output_rejects_empty_word() {
        let err = parse_expected_output_words(Some("1,,3")).expect_err("parse should fail");
        assert!(err.to_string().contains("empty word"));
    }

    #[test]
    fn parse_expected_output_rejects_trailing_comma() {
        let err = parse_expected_output_words(Some("42,")).expect_err("parse should fail");
        assert!(err.to_string().contains("empty word"));
    }

    #[test]
    fn parse_expected_output_rejects_invalid_word() {
        let err = parse_expected_output_words(Some("1,nope")).expect_err("parse should fail");
        assert!(err.to_string().contains("failed to parse output word"));
    }

    #[test]
    fn parse_expected_output_rejects_bare_hex_prefix() {
        let err = parse_expected_output_words(Some("0x")).expect_err("parse should fail");
        assert!(err.to_string().contains("failed to parse output word"));
    }

    #[test]
    fn parse_expected_output_rejects_u32_overflow() {
        let err = parse_expected_output_words(Some("4294967296")).expect_err("parse should fail");
        assert!(err.to_string().contains("failed to parse output word"));
    }

    #[test]
    fn format_output_words_joins_with_commas() {
        assert_eq!(
            format_output_words(&[42, 0, 0, 0, 0, 0, 0, 1]),
            "42,0,0,0,0,0,0,1"
        );
    }

    #[test]
    fn level_names_agree_between_cli_and_host() {
        for level in [
            ProverLevelArg::Base,
            ProverLevelArg::RecursionUnrolled,
            ProverLevelArg::RecursionUnified,
        ] {
            assert_eq!(level_name(level), host_level_name(as_host_level(level)));
        }
    }
}