// cargo_airbender/commands/vk.rs

1use crate::cli::{GenerateVkArgs, ProverLevelArg, VerifyProofArgs};
2use crate::error::{CliError, Result};
3use crate::ui;
4use serde::{de::DeserializeOwned, Serialize};
5use std::path::Path;
6
7pub fn generate(args: GenerateVkArgs) -> Result<()> {
8    ensure_gpu_vk_support()?;
9    let security = args.security;
10    let host_security = security.into();
11
12    let vk = match args.level {
13        ProverLevelArg::RecursionUnified => {
14            let vk = airbender_host::compute_unified_vk(&args.app_bin, host_security).map_err(
15                |err| {
16                    CliError::with_source(
17                        format!(
18                            "failed to compute unified verification keys for `{}`",
19                            args.app_bin.display()
20                        ),
21                        err,
22                    )
23                },
24            )?;
25            airbender_host::VerificationKey::RealUnified(
26                airbender_host::RealUnifiedVerificationKey { vk },
27            )
28        }
29        ProverLevelArg::Base | ProverLevelArg::RecursionUnrolled => {
30            let level = as_host_level(args.level);
31            let vk = airbender_host::compute_unrolled_vk(&args.app_bin, level, host_security)
32                .map_err(|err| {
33                    CliError::with_source(
34                        format!(
35                            "failed to compute unrolled verification keys for `{}`",
36                            args.app_bin.display()
37                        ),
38                        err,
39                    )
40                })?;
41            airbender_host::VerificationKey::RealUnrolled(
42                airbender_host::RealUnrolledVerificationKey { level, vk },
43            )
44        }
45    };
46
47    write_bincode(&args.output, &vk)?;
48
49    ui::success("verification keys generated");
50    ui::field("level", level_name(args.level));
51    ui::field("security", security);
52    ui::field("output", args.output.display());
53
54    Ok(())
55}
56
57fn ensure_gpu_vk_support() -> Result<()> {
58    #[cfg(feature = "gpu-prover")]
59    {
60        Ok(())
61    }
62
63    #[cfg(not(feature = "gpu-prover"))]
64    {
65        Err(CliError::new(
66            "verification key generation requires GPU support in `cargo-airbender`",
67        )
68        .with_hint(
69            "rebuild `cargo-airbender` with default features or pass `--features gpu-prover` to use `generate-vk`",
70        ))
71    }
72}
73
74pub fn verify(args: VerifyProofArgs) -> Result<()> {
75    let expected_output_words = parse_expected_output_words(args.expected_output.as_deref())?;
76
77    let proof: airbender_host::Proof = read_bincode(&args.proof).map_err(|err| {
78        CliError::with_source(
79            format!("failed to decode proof from `{}`", args.proof.display()),
80            err,
81        )
82    })?;
83
84    let vk: airbender_host::VerificationKey = read_bincode(&args.vk).map_err(|err| {
85        CliError::with_source(
86            format!(
87                "failed to decode verification key file `{}`",
88                args.vk.display()
89            ),
90            err,
91        )
92    })?;
93
94    let (level, security) = match &proof {
95        airbender_host::Proof::Dev(_) => {
96            return Err(CliError::new(
97                "detected a dev proof; `cargo airbender verify-proof` supports only real proofs",
98            )
99            .with_hint(
100                "verify dev proofs through `airbender-host` with `Program::dev_verifier()`",
101            ));
102        }
103        airbender_host::Proof::Real(proof) => {
104            let expected_output_commit = expected_output_words
105                .as_ref()
106                .map(|words| words as &dyn airbender_host::Commit);
107
108            airbender_host::verify_real_proof_with_vk(proof, &vk, expected_output_commit)
109                .map_err(|err| CliError::with_source("proof verification failed", err))?;
110            (proof.level(), proof.security())
111        }
112    };
113
114    if expected_output_words.is_none() {
115        tracing::warn!("public outputs were not provided; only proof/VK validity was checked");
116    }
117
118    ui::success("proof verified");
119    ui::field("level", host_level_name(level));
120    ui::field("security", security);
121    if let Some(words) = expected_output_words {
122        ui::field("expected_output", format_output_words(&words));
123    }
124
125    Ok(())
126}
127
128fn parse_expected_output_words(raw: Option<&str>) -> Result<Option<[u32; 8]>> {
129    let Some(raw) = raw else {
130        return Ok(None);
131    };
132
133    let trimmed = raw.trim();
134    if trimmed.is_empty() {
135        return Err(CliError::new("`--expected-output` cannot be empty")
136            .with_hint("provide comma-separated u32 words, for example `--expected-output 42`"));
137    }
138
139    let parts: Vec<&str> = trimmed.split(',').collect();
140    if parts.len() > 8 {
141        return Err(CliError::new(format!(
142            "`--expected-output` accepts at most 8 words (got {})",
143            parts.len()
144        ))
145        .with_hint(
146            "provide up to 8 comma-separated values for x10..x17; missing words are zero-padded",
147        ));
148    }
149
150    let mut words = [0u32; 8];
151    for (index, token) in parts.into_iter().enumerate() {
152        let token = token.trim();
153        if token.is_empty() {
154            return Err(CliError::new(format!(
155                "found an empty word at position {} in `--expected-output`",
156                index + 1
157            ))
158            .with_hint("use comma-separated values like `42,0,0`"));
159        }
160        words[index] = parse_output_word(token, index + 1)?;
161    }
162
163    Ok(Some(words))
164}
165
166fn parse_output_word(token: &str, position: usize) -> Result<u32> {
167    if let Some(hex) = token
168        .strip_prefix("0x")
169        .or_else(|| token.strip_prefix("0X"))
170    {
171        if hex.is_empty() {
172            return Err(CliError::new(format!(
173                "failed to parse output word at position {position}: `{token}`"
174            ))
175            .with_hint("hex output words must use `0x` followed by one or more hex digits"));
176        }
177
178        return u32::from_str_radix(hex, 16)
179            .map_err(|err| {
180                CliError::with_source(
181                    format!("failed to parse output word at position {position}: `{token}`"),
182                    err,
183                )
184            })
185            .map_err(|err| err.with_hint("use decimal or 0x-prefixed hexadecimal u32 words"));
186    }
187
188    token
189        .parse::<u32>()
190        .map_err(|err| {
191            CliError::with_source(
192                format!("failed to parse output word at position {position}: `{token}`"),
193                err,
194            )
195        })
196        .map_err(|err| err.with_hint("use decimal or 0x-prefixed hexadecimal u32 words"))
197}
198
/// Render eight output words as a comma-separated list of decimal values, e.g.
/// `42,0,0,0,0,0,0,0`.
fn format_output_words(words: &[u32; 8]) -> String {
    let mut rendered = String::new();
    for (index, word) in words.iter().enumerate() {
        if index > 0 {
            rendered.push(',');
        }
        rendered.push_str(&word.to_string());
    }
    rendered
}
206
207fn as_host_level(level: ProverLevelArg) -> airbender_host::ProverLevel {
208    match level {
209        ProverLevelArg::Base => airbender_host::ProverLevel::Base,
210        ProverLevelArg::RecursionUnrolled => airbender_host::ProverLevel::RecursionUnrolled,
211        ProverLevelArg::RecursionUnified => airbender_host::ProverLevel::RecursionUnified,
212    }
213}
214
215fn level_name(level: ProverLevelArg) -> &'static str {
216    match level {
217        ProverLevelArg::Base => "base",
218        ProverLevelArg::RecursionUnrolled => "recursion-unrolled",
219        ProverLevelArg::RecursionUnified => "recursion-unified",
220    }
221}
222
223fn host_level_name(level: airbender_host::ProverLevel) -> &'static str {
224    match level {
225        airbender_host::ProverLevel::Base => "base",
226        airbender_host::ProverLevel::RecursionUnrolled => "recursion-unrolled",
227        airbender_host::ProverLevel::RecursionUnified => "recursion-unified",
228    }
229}
230
231fn read_bincode<T: DeserializeOwned>(path: &Path) -> Result<T> {
232    let bytes = std::fs::read(path).map_err(|err| {
233        CliError::with_source(format!("failed to read `{}`", path.display()), err)
234    })?;
235    let (decoded, read_len): (T, usize) =
236        bincode::serde::decode_from_slice(&bytes, bincode::config::standard())
237            .map_err(|err| CliError::with_source("failed to decode bincode payload", err))?;
238
239    if read_len != bytes.len() {
240        tracing::warn!(
241            "bincode decoded {} bytes but file is {} bytes",
242            read_len,
243            bytes.len()
244        );
245    }
246    Ok(decoded)
247}
248
249fn write_bincode<T: Serialize>(path: &Path, value: &T) -> Result<()> {
250    let encoded = bincode::serde::encode_to_vec(value, bincode::config::standard())
251        .map_err(|err| CliError::with_source("failed to encode bincode payload", err))?;
252    std::fs::write(path, encoded).map_err(|err| {
253        CliError::with_source(format!("failed to write `{}`", path.display()), err)
254    })?;
255
256    Ok(())
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(not(feature = "gpu-prover"))]
    use crate::cli::{GenerateVkArgs, SecurityLevelArg};
    #[cfg(not(feature = "gpu-prover"))]
    use std::path::PathBuf;

    #[cfg(not(feature = "gpu-prover"))]
    #[test]
    fn generate_vk_requires_gpu_support() {
        let err = generate(GenerateVkArgs {
            app_bin: PathBuf::from("app.bin"),
            output: PathBuf::from("vk.bin"),
            level: ProverLevelArg::Base,
            security: SecurityLevelArg::default(),
        })
        .expect_err("generate-vk must require gpu-prover support");

        assert!(
            err.to_string().contains("requires GPU support"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn parse_expected_output_none() {
        let parsed = parse_expected_output_words(None).expect("parse should succeed");
        assert!(parsed.is_none());
    }

    #[test]
    fn parse_expected_output_pads_trailing_words() {
        let parsed = parse_expected_output_words(Some("42")).expect("parse should succeed");
        assert_eq!(parsed, Some([42, 0, 0, 0, 0, 0, 0, 0]));
    }

    #[test]
    fn parse_expected_output_supports_hex_and_spaces() {
        let parsed =
            parse_expected_output_words(Some("0x2a, 0X01, 7")).expect("parse should succeed");
        assert_eq!(parsed, Some([42, 1, 7, 0, 0, 0, 0, 0]));
    }

    #[test]
    fn parse_expected_output_accepts_exactly_eight_words() {
        let parsed =
            parse_expected_output_words(Some("1,2,3,4,5,6,7,8")).expect("parse should succeed");
        assert_eq!(parsed, Some([1, 2, 3, 4, 5, 6, 7, 8]));
    }

    #[test]
    fn parse_expected_output_rejects_too_many_words() {
        let err =
            parse_expected_output_words(Some("1,2,3,4,5,6,7,8,9")).expect_err("parse should fail");
        assert!(err.to_string().contains("at most 8 words"));
    }

    #[test]
    fn parse_expected_output_rejects_blank_value() {
        let err = parse_expected_output_words(Some("   ")).expect_err("parse should fail");
        assert!(err.to_string().contains("cannot be empty"));
    }

    #[test]
    fn parse_expected_output_rejects_empty_word() {
        let err = parse_expected_output_words(Some("1,,3")).expect_err("parse should fail");
        assert!(err.to_string().contains("empty word"));
    }

    #[test]
    fn parse_expected_output_rejects_invalid_word() {
        let err = parse_expected_output_words(Some("1,nope")).expect_err("parse should fail");
        assert!(err.to_string().contains("failed to parse output word"));
    }

    #[test]
    fn parse_expected_output_rejects_bare_hex_prefix() {
        let err = parse_expected_output_words(Some("0x")).expect_err("parse should fail");
        assert!(err.to_string().contains("failed to parse output word"));
    }

    #[test]
    fn format_output_words_joins_decimal_words_with_commas() {
        assert_eq!(
            format_output_words(&[42, 1, 7, 0, 0, 0, 0, 0]),
            "42,1,7,0,0,0,0,0"
        );
    }

    #[test]
    fn level_names_use_kebab_case() {
        assert_eq!(level_name(ProverLevelArg::Base), "base");
        assert_eq!(
            level_name(ProverLevelArg::RecursionUnrolled),
            "recursion-unrolled"
        );
        assert_eq!(
            level_name(ProverLevelArg::RecursionUnified),
            "recursion-unified"
        );
    }
}