Skip to main content

airbender_host/prover/
gpu_prover.rs

1use super::{
2    base_path, receipt_from_real_proof, resolve_app_bin_path, ProveResult, Prover, ProverLevel,
3};
4use crate::error::{HostError, Result};
5use crate::proof::{Proof, RealProof};
6use crate::security::SecurityLevel;
7use execution_utils::unrolled_gpu::UnrolledProver;
8use gpu_prover::execution::prover::ExecutionProverConfiguration;
9use riscv_transpiler::abstractions::non_determinism::QuasiUARTSource;
10use std::any::Any;
11use std::path::{Path, PathBuf};
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::{mpsc, Mutex};
14use std::thread::JoinHandle;
15
16/// Builder for creating a configured cached GPU prover.
17pub struct GpuProverBuilder {
18    app_bin_path: PathBuf,
19    worker_threads: Option<usize>,
20    security: SecurityLevel,
21    level: ProverLevel,
22}
23
24impl GpuProverBuilder {
25    pub fn new(app_bin_path: impl AsRef<Path>) -> Self {
26        Self {
27            app_bin_path: app_bin_path.as_ref().to_path_buf(),
28            worker_threads: None,
29            security: SecurityLevel::default(),
30            level: ProverLevel::RecursionUnified,
31        }
32    }
33
34    pub fn with_worker_threads(mut self, worker_threads: usize) -> Self {
35        self.worker_threads = Some(worker_threads);
36        self
37    }
38
39    pub fn maybe_worker_threads(self, worker_threads: Option<usize>) -> Self {
40        match worker_threads {
41            Some(v) => self.with_worker_threads(v),
42            None => self,
43        }
44    }
45
46    pub fn with_level(mut self, level: ProverLevel) -> Self {
47        self.level = level;
48        self
49    }
50
51    pub fn with_security(mut self, security: SecurityLevel) -> Self {
52        self.security = security;
53        self
54    }
55
56    pub fn build(self) -> Result<GpuProver> {
57        GpuProver::new(
58            &self.app_bin_path,
59            self.worker_threads,
60            self.security,
61            self.level,
62        )
63    }
64}
65
66/// GPU prover wrapper that owns and reuses a single `UnrolledProver` instance.
67///
68/// ## Poisoning
69///
70/// Actual proving happens on a separate thread, and in case the program cannot be
71/// proven, the prover can panic. Prover panics are not unwind safe, so the thread
72/// and the prover will be disposed of, making this prover object poisoned, e.g. not
73/// usable for future proving attempts. Once poisoned, the prover will return an error
74/// on all operations.
75///
76/// After poisioning, you can instantiate a new prover if required.
77pub struct GpuProver {
78    command_tx: mpsc::Sender<WorkerCommand>,
79    worker_handle: Mutex<Option<JoinHandle<()>>>,
80    poisoned: AtomicBool,
81}
82
83enum WorkerCommand {
84    Prove {
85        input_words: Vec<u32>,
86        response_tx: mpsc::Sender<Result<ProveResult>>,
87    },
88    Shutdown,
89}
90
91impl GpuProver {
92    fn new(
93        app_bin_path: &Path,
94        worker_threads: Option<usize>,
95        security: SecurityLevel,
96        level: ProverLevel,
97    ) -> Result<Self> {
98        if matches!(worker_threads, Some(0)) {
99            return Err(HostError::Prover(
100                "worker thread count must be greater than zero".to_string(),
101            ));
102        }
103
104        let app_bin_path = resolve_app_bin_path(app_bin_path)?;
105        let (command_tx, worker_handle) =
106            spawn_worker(app_bin_path, worker_threads, security, level)?;
107
108        Ok(Self {
109            command_tx,
110            worker_handle: Mutex::new(Some(worker_handle)),
111            poisoned: AtomicBool::new(false),
112        })
113    }
114
115    pub fn is_poisoned(&self) -> bool {
116        self.poisoned.load(Ordering::SeqCst)
117    }
118
119    fn poisoned_error() -> HostError {
120        HostError::Prover("GPU prover is poisoned due to a previous proving panic".to_string())
121    }
122
123    fn handle_worker_failure(&self, operation: &str) -> HostError {
124        if self.poisoned.swap(true, Ordering::SeqCst) {
125            return Self::poisoned_error();
126        }
127
128        match self.take_worker_panic_message() {
129            Some(message) => HostError::Prover(format!(
130                "GPU prover panicked while {operation}; prover is now poisoned: {message}"
131            )),
132            None => HostError::Prover(format!(
133                "GPU prover worker failed while {operation}; prover is now poisoned"
134            )),
135        }
136    }
137
138    fn take_worker_panic_message(&self) -> Option<String> {
139        let mut handle_slot = match self.worker_handle.lock() {
140            Ok(slot) => slot,
141            Err(poisoned) => poisoned.into_inner(),
142        };
143        let handle = handle_slot.take()?;
144
145        match handle.join() {
146            Ok(()) => None,
147            Err(payload) => Some(panic_payload_to_string(payload)),
148        }
149    }
150}
151
152impl Prover for GpuProver {
153    fn prove(&self, input_words: &[u32]) -> Result<ProveResult> {
154        if self.is_poisoned() {
155            return Err(Self::poisoned_error());
156        }
157
158        let (response_tx, response_rx) = mpsc::channel();
159        self.command_tx
160            .send(WorkerCommand::Prove {
161                input_words: input_words.to_vec(),
162                response_tx,
163            })
164            .map_err(|_| self.handle_worker_failure("submitting a prove request"))?;
165
166        response_rx
167            .recv()
168            .map_err(|_| self.handle_worker_failure("receiving a prove response"))?
169    }
170}
171
172impl Drop for GpuProver {
173    fn drop(&mut self) {
174        let _ = self.command_tx.send(WorkerCommand::Shutdown);
175
176        let handle_slot = match self.worker_handle.get_mut() {
177            Ok(slot) => slot,
178            Err(poisoned) => poisoned.into_inner(),
179        };
180
181        if let Some(handle) = handle_slot.take() {
182            let _ = handle.join();
183        }
184    }
185}
186
187fn spawn_worker(
188    app_bin_path: PathBuf,
189    worker_threads: Option<usize>,
190    security: SecurityLevel,
191    level: ProverLevel,
192) -> Result<(mpsc::Sender<WorkerCommand>, JoinHandle<()>)> {
193    let (command_tx, command_rx) = mpsc::channel();
194    let (init_tx, init_rx) = mpsc::channel();
195
196    let worker_handle = std::thread::Builder::new()
197        .name("airbender-gpu-prover".to_string())
198        .spawn(move || {
199            gpu_worker_loop(
200                command_rx,
201                init_tx,
202                app_bin_path,
203                worker_threads,
204                security,
205                level,
206            )
207        })
208        .map_err(|err| {
209            HostError::Prover(format!("failed to spawn GPU prover worker thread: {err}"))
210        })?;
211
212    match init_rx.recv() {
213        Ok(Ok(())) => Ok((command_tx, worker_handle)),
214        Ok(Err(err)) => {
215            let _ = worker_handle.join();
216            Err(err)
217        }
218        Err(_) => {
219            let reason = match worker_handle.join() {
220                Ok(()) => "GPU prover worker exited during initialization".to_string(),
221                Err(payload) => format!(
222                    "GPU prover worker panicked during initialization: {}",
223                    panic_payload_to_string(payload)
224                ),
225            };
226            Err(HostError::Prover(reason))
227        }
228    }
229}
230
231fn gpu_worker_loop(
232    command_rx: mpsc::Receiver<WorkerCommand>,
233    init_tx: mpsc::Sender<Result<()>>,
234    app_bin_path: PathBuf,
235    worker_threads: Option<usize>,
236    security: SecurityLevel,
237    level: ProverLevel,
238) {
239    // Keep all prover state inside this dedicated thread so a panic does not unwind
240    // through host-call boundaries or require `AssertUnwindSafe`.
241    let prover = match create_unrolled_prover(
242        &app_bin_path,
243        worker_threads,
244        security,
245        level.as_unrolled_level(),
246    ) {
247        Ok(prover) => prover,
248        Err(err) => {
249            let _ = init_tx.send(Err(err));
250            return;
251        }
252    };
253
254    if init_tx.send(Ok(())).is_err() {
255        return;
256    }
257
258    while let Ok(command) = command_rx.recv() {
259        match command {
260            WorkerCommand::Prove {
261                input_words,
262                response_tx,
263            } => {
264                let oracle = QuasiUARTSource::new_with_reads(input_words);
265                // TODO: we use `batch 0` for all the jobs, which can cause issues when generating multiple proofs in parallel.
266                let (inner_proof, cycles) = prover.prove(0, oracle);
267                let receipt = receipt_from_real_proof(&inner_proof);
268                let proof = Proof::Real(RealProof::new(security, level, inner_proof));
269                let result = Ok(ProveResult {
270                    proof,
271                    cycles,
272                    receipt,
273                });
274                let _ = response_tx.send(result);
275            }
276            WorkerCommand::Shutdown => break,
277        }
278    }
279}
280
/// Extracts a readable message from a panic payload.
///
/// Payloads holding a `String` or a `&str` are returned as text; anything else
/// yields `"unknown panic payload"`.
fn panic_payload_to_string(payload: Box<dyn Any + Send + 'static>) -> String {
    payload
        .downcast_ref::<String>()
        .cloned()
        .or_else(|| {
            payload
                .downcast_ref::<&str>()
                .map(|text| (*text).to_string())
        })
        .unwrap_or_else(|| "unknown panic payload".to_string())
}
291
292fn create_unrolled_prover(
293    app_bin_path: &Path,
294    worker_threads: Option<usize>,
295    security: SecurityLevel,
296    level: execution_utils::unrolled_gpu::UnrolledProverLevel,
297) -> Result<UnrolledProver> {
298    let base_path = base_path(app_bin_path)?;
299    let mut configuration = ExecutionProverConfiguration::default();
300    if let Some(threads) = worker_threads {
301        configuration.max_thread_pool_threads = Some(threads);
302        configuration.replay_worker_threads_count = threads;
303    }
304    Ok(UnrolledProver::new(
305        security.into(),
306        &base_path,
307        configuration,
308        level,
309    ))
310}