1use super::{
2 base_path, receipt_from_real_proof, resolve_app_bin_path, ProveResult, Prover, ProverLevel,
3};
4use crate::error::{HostError, Result};
5use crate::proof::{Proof, RealProof};
6use crate::security::SecurityLevel;
7use execution_utils::unrolled_gpu::UnrolledProver;
8use gpu_prover::execution::prover::ExecutionProverConfiguration;
9use riscv_transpiler::abstractions::non_determinism::QuasiUARTSource;
10use std::any::Any;
11use std::path::{Path, PathBuf};
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::{mpsc, Mutex};
14use std::thread::JoinHandle;
15
16pub struct GpuProverBuilder {
18 app_bin_path: PathBuf,
19 worker_threads: Option<usize>,
20 security: SecurityLevel,
21 level: ProverLevel,
22}
23
24impl GpuProverBuilder {
25 pub fn new(app_bin_path: impl AsRef<Path>) -> Self {
26 Self {
27 app_bin_path: app_bin_path.as_ref().to_path_buf(),
28 worker_threads: None,
29 security: SecurityLevel::default(),
30 level: ProverLevel::RecursionUnified,
31 }
32 }
33
34 pub fn with_worker_threads(mut self, worker_threads: usize) -> Self {
35 self.worker_threads = Some(worker_threads);
36 self
37 }
38
39 pub fn maybe_worker_threads(self, worker_threads: Option<usize>) -> Self {
40 match worker_threads {
41 Some(v) => self.with_worker_threads(v),
42 None => self,
43 }
44 }
45
46 pub fn with_level(mut self, level: ProverLevel) -> Self {
47 self.level = level;
48 self
49 }
50
51 pub fn with_security(mut self, security: SecurityLevel) -> Self {
52 self.security = security;
53 self
54 }
55
56 pub fn build(self) -> Result<GpuProver> {
57 GpuProver::new(
58 &self.app_bin_path,
59 self.worker_threads,
60 self.security,
61 self.level,
62 )
63 }
64}
65
66pub struct GpuProver {
78 command_tx: mpsc::Sender<WorkerCommand>,
79 worker_handle: Mutex<Option<JoinHandle<()>>>,
80 poisoned: AtomicBool,
81}
82
83enum WorkerCommand {
84 Prove {
85 input_words: Vec<u32>,
86 response_tx: mpsc::Sender<Result<ProveResult>>,
87 },
88 Shutdown,
89}
90
91impl GpuProver {
92 fn new(
93 app_bin_path: &Path,
94 worker_threads: Option<usize>,
95 security: SecurityLevel,
96 level: ProverLevel,
97 ) -> Result<Self> {
98 if matches!(worker_threads, Some(0)) {
99 return Err(HostError::Prover(
100 "worker thread count must be greater than zero".to_string(),
101 ));
102 }
103
104 let app_bin_path = resolve_app_bin_path(app_bin_path)?;
105 let (command_tx, worker_handle) =
106 spawn_worker(app_bin_path, worker_threads, security, level)?;
107
108 Ok(Self {
109 command_tx,
110 worker_handle: Mutex::new(Some(worker_handle)),
111 poisoned: AtomicBool::new(false),
112 })
113 }
114
115 pub fn is_poisoned(&self) -> bool {
116 self.poisoned.load(Ordering::SeqCst)
117 }
118
119 fn poisoned_error() -> HostError {
120 HostError::Prover("GPU prover is poisoned due to a previous proving panic".to_string())
121 }
122
123 fn handle_worker_failure(&self, operation: &str) -> HostError {
124 if self.poisoned.swap(true, Ordering::SeqCst) {
125 return Self::poisoned_error();
126 }
127
128 match self.take_worker_panic_message() {
129 Some(message) => HostError::Prover(format!(
130 "GPU prover panicked while {operation}; prover is now poisoned: {message}"
131 )),
132 None => HostError::Prover(format!(
133 "GPU prover worker failed while {operation}; prover is now poisoned"
134 )),
135 }
136 }
137
138 fn take_worker_panic_message(&self) -> Option<String> {
139 let mut handle_slot = match self.worker_handle.lock() {
140 Ok(slot) => slot,
141 Err(poisoned) => poisoned.into_inner(),
142 };
143 let handle = handle_slot.take()?;
144
145 match handle.join() {
146 Ok(()) => None,
147 Err(payload) => Some(panic_payload_to_string(payload)),
148 }
149 }
150}
151
152impl Prover for GpuProver {
153 fn prove(&self, input_words: &[u32]) -> Result<ProveResult> {
154 if self.is_poisoned() {
155 return Err(Self::poisoned_error());
156 }
157
158 let (response_tx, response_rx) = mpsc::channel();
159 self.command_tx
160 .send(WorkerCommand::Prove {
161 input_words: input_words.to_vec(),
162 response_tx,
163 })
164 .map_err(|_| self.handle_worker_failure("submitting a prove request"))?;
165
166 response_rx
167 .recv()
168 .map_err(|_| self.handle_worker_failure("receiving a prove response"))?
169 }
170}
171
172impl Drop for GpuProver {
173 fn drop(&mut self) {
174 let _ = self.command_tx.send(WorkerCommand::Shutdown);
175
176 let handle_slot = match self.worker_handle.get_mut() {
177 Ok(slot) => slot,
178 Err(poisoned) => poisoned.into_inner(),
179 };
180
181 if let Some(handle) = handle_slot.take() {
182 let _ = handle.join();
183 }
184 }
185}
186
187fn spawn_worker(
188 app_bin_path: PathBuf,
189 worker_threads: Option<usize>,
190 security: SecurityLevel,
191 level: ProverLevel,
192) -> Result<(mpsc::Sender<WorkerCommand>, JoinHandle<()>)> {
193 let (command_tx, command_rx) = mpsc::channel();
194 let (init_tx, init_rx) = mpsc::channel();
195
196 let worker_handle = std::thread::Builder::new()
197 .name("airbender-gpu-prover".to_string())
198 .spawn(move || {
199 gpu_worker_loop(
200 command_rx,
201 init_tx,
202 app_bin_path,
203 worker_threads,
204 security,
205 level,
206 )
207 })
208 .map_err(|err| {
209 HostError::Prover(format!("failed to spawn GPU prover worker thread: {err}"))
210 })?;
211
212 match init_rx.recv() {
213 Ok(Ok(())) => Ok((command_tx, worker_handle)),
214 Ok(Err(err)) => {
215 let _ = worker_handle.join();
216 Err(err)
217 }
218 Err(_) => {
219 let reason = match worker_handle.join() {
220 Ok(()) => "GPU prover worker exited during initialization".to_string(),
221 Err(payload) => format!(
222 "GPU prover worker panicked during initialization: {}",
223 panic_payload_to_string(payload)
224 ),
225 };
226 Err(HostError::Prover(reason))
227 }
228 }
229}
230
231fn gpu_worker_loop(
232 command_rx: mpsc::Receiver<WorkerCommand>,
233 init_tx: mpsc::Sender<Result<()>>,
234 app_bin_path: PathBuf,
235 worker_threads: Option<usize>,
236 security: SecurityLevel,
237 level: ProverLevel,
238) {
239 let prover = match create_unrolled_prover(
242 &app_bin_path,
243 worker_threads,
244 security,
245 level.as_unrolled_level(),
246 ) {
247 Ok(prover) => prover,
248 Err(err) => {
249 let _ = init_tx.send(Err(err));
250 return;
251 }
252 };
253
254 if init_tx.send(Ok(())).is_err() {
255 return;
256 }
257
258 while let Ok(command) = command_rx.recv() {
259 match command {
260 WorkerCommand::Prove {
261 input_words,
262 response_tx,
263 } => {
264 let oracle = QuasiUARTSource::new_with_reads(input_words);
265 let (inner_proof, cycles) = prover.prove(0, oracle);
267 let receipt = receipt_from_real_proof(&inner_proof);
268 let proof = Proof::Real(RealProof::new(security, level, inner_proof));
269 let result = Ok(ProveResult {
270 proof,
271 cycles,
272 receipt,
273 });
274 let _ = response_tx.send(result);
275 }
276 WorkerCommand::Shutdown => break,
277 }
278 }
279}
280
/// Converts a panic payload into a printable message.
///
/// `String` and `&str` payloads are returned as text; any other payload type
/// yields the fixed string `"unknown panic payload"`.
fn panic_payload_to_string(payload: Box<dyn Any + Send + 'static>) -> String {
    payload
        .downcast_ref::<String>()
        .cloned()
        .or_else(|| {
            payload
                .downcast_ref::<&str>()
                .map(|text| (*text).to_string())
        })
        .unwrap_or_else(|| "unknown panic payload".to_string())
}
291
292fn create_unrolled_prover(
293 app_bin_path: &Path,
294 worker_threads: Option<usize>,
295 security: SecurityLevel,
296 level: execution_utils::unrolled_gpu::UnrolledProverLevel,
297) -> Result<UnrolledProver> {
298 let base_path = base_path(app_bin_path)?;
299 let mut configuration = ExecutionProverConfiguration::default();
300 if let Some(threads) = worker_threads {
301 configuration.max_thread_pool_threads = Some(threads);
302 configuration.replay_worker_threads_count = threads;
303 }
304 Ok(UnrolledProver::new(
305 security.into(),
306 &base_path,
307 configuration,
308 level,
309 ))
310}