Skip to main content

cargo_airbender/commands/
prove.rs

1use crate::cli::{ProveArgs, ProverBackendArg, ProverLevelArg};
2use crate::error::{CliError, Result};
3use crate::input;
4use crate::ui;
5use airbender_host::Prover;
6
7pub fn run(args: ProveArgs) -> Result<()> {
8    let input_words = input::parse_input_words(&args.input)?;
9    let security = args.security;
10
11    let prove_result = match args.backend {
12        ProverBackendArg::Dev => {
13            if args.threads.is_some() {
14                tracing::warn!("ignoring `--threads` for dev backend");
15            }
16            if args.ram_bound.is_some() {
17                tracing::warn!("ignoring `--ram-bound` for dev backend");
18            }
19            if args.level != ProverLevelArg::RecursionUnified {
20                tracing::warn!("ignoring `--level` for dev backend");
21            }
22
23            let security = security.into();
24            let prover = airbender_host::DevProverBuilder::new(&args.app_bin)
25                .with_security(security)
26                .maybe_cycles(args.cycles)
27                .build()
28                .map_err(|err| {
29                CliError::with_source(
30                    format!(
31                        "failed to initialize dev prover for `{}`",
32                        args.app_bin.display()
33                    ),
34                    err,
35                )
36            })?;
37
38            prover.prove(&input_words)
39        }
40        ProverBackendArg::Gpu => {
41            if args.cycles.is_some() {
42                tracing::warn!("ignoring `--cycles` for gpu backend");
43            }
44            if args.ram_bound.is_some() {
45                tracing::warn!("ignoring `--ram-bound` for gpu backend");
46            }
47
48            #[cfg(feature = "gpu-prover")]
49            {
50                let level = as_host_level(args.level);
51                let security = security.into();
52                let prover = airbender_host::GpuProverBuilder::new(&args.app_bin)
53                    .with_level(level)
54                    .with_security(security)
55                    .maybe_worker_threads(args.threads)
56                    .build()
57                    .map_err(|err| {
58                    CliError::with_source(
59                        format!(
60                            "failed to initialize GPU prover for `{}`",
61                            args.app_bin.display()
62                        ),
63                        err,
64                    )
65                })?;
66
67                prover.prove(&input_words)
68            }
69
70            #[cfg(not(feature = "gpu-prover"))]
71            {
72                return Err(CliError::new(
73                    "GPU backend requires GPU support in `cargo-airbender`",
74                )
75                .with_hint(
76                    "rebuild `cargo-airbender` with default features or pass `--features gpu-prover` to use `--backend gpu`",
77                ));
78            }
79        }
80        ProverBackendArg::Cpu => {
81            let level = as_host_level(args.level);
82            let security = security.into();
83            if level != airbender_host::ProverLevel::Base {
84                return Err(
85                    CliError::new("CPU backend currently supports only `--level base`")
86                        .with_hint("use `--backend gpu` for recursion levels"),
87                );
88            }
89
90            let prover = airbender_host::CpuProverBuilder::new(&args.app_bin)
91                .with_security(security)
92                .maybe_worker_threads(args.threads)
93                .maybe_cycles(args.cycles)
94                .maybe_ram_bound(args.ram_bound)
95                .build()
96                .map_err(|err| {
97                CliError::with_source(
98                    format!(
99                        "failed to initialize CPU prover for `{}`",
100                        args.app_bin.display()
101                    ),
102                    err,
103                )
104            })?;
105
106            prover.prove(&input_words)
107        }
108    }
109    .map_err(|err| {
110        CliError::with_source(
111            format!("failed to generate proof for `{}`", args.app_bin.display()),
112            err,
113        )
114        .with_hint("set `RUST_LOG=info` to inspect prover backend logs")
115    })?;
116
117    tracing::info!("{}", prove_result.proof.debug_info());
118
119    let encoded = bincode::serde::encode_to_vec(&prove_result.proof, bincode::config::standard())
120        .map_err(|err| CliError::with_source("failed to encode proof", err))?;
121    std::fs::write(&args.output, encoded).map_err(|err| {
122        CliError::with_source(
123            format!("failed to write proof to `{}`", args.output.display()),
124            err,
125        )
126    })?;
127
128    ui::success("proof generated");
129    ui::field("backend", backend_name(args.backend));
130    ui::field("level", proof_level(args.backend, args.level));
131    ui::field("security", security);
132    ui::field("cycles", prove_result.cycles);
133    ui::field("output", args.output.display());
134
135    Ok(())
136}
137
138fn backend_name(backend: ProverBackendArg) -> &'static str {
139    match backend {
140        ProverBackendArg::Dev => "dev",
141        ProverBackendArg::Cpu => "cpu",
142        ProverBackendArg::Gpu => "gpu",
143    }
144}
145
146fn proof_level(backend: ProverBackendArg, level: ProverLevelArg) -> &'static str {
147    match backend {
148        ProverBackendArg::Dev => "dev",
149        ProverBackendArg::Cpu | ProverBackendArg::Gpu => level_name(level),
150    }
151}
152
153fn level_name(level: ProverLevelArg) -> &'static str {
154    match level {
155        ProverLevelArg::Base => "base",
156        ProverLevelArg::RecursionUnrolled => "recursion-unrolled",
157        ProverLevelArg::RecursionUnified => "recursion-unified",
158    }
159}
160
161fn as_host_level(level: ProverLevelArg) -> airbender_host::ProverLevel {
162    match level {
163        ProverLevelArg::Base => airbender_host::ProverLevel::Base,
164        ProverLevelArg::RecursionUnrolled => airbender_host::ProverLevel::RecursionUnrolled,
165        ProverLevelArg::RecursionUnified => airbender_host::ProverLevel::RecursionUnified,
166    }
167}