Skip to main content

airbender_guest/
input.rs

1//! Guest input helpers backed by the Airbender codec.
2
3use crate::transport::Transport;
4use airbender_codec::{AirbenderCodec, AirbenderCodecV0, CodecError};
5use airbender_core::wire::read_framed_bytes_with;
6use core::fmt;
7
8/// Errors that can occur when decoding inputs on the guest.
9#[derive(Debug)]
10pub enum GuestError {
11    Codec(CodecError),
12    UnsupportedTarget,
13}
14
15impl From<CodecError> for GuestError {
16    fn from(err: CodecError) -> Self {
17        GuestError::Codec(err)
18    }
19}
20
21impl fmt::Display for GuestError {
22    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23        match self {
24            GuestError::Codec(err) => write!(f, "{err}"),
25            GuestError::UnsupportedTarget => {
26                f.write_str("csr transport is only available on riscv32")
27            }
28        }
29    }
30}
31
32/// Read a single value from the CSR-based transport.
33pub fn read<T: serde::de::DeserializeOwned>() -> Result<T, GuestError> {
34    #[cfg(target_arch = "riscv32")]
35    {
36        let mut transport = crate::transport::CsrTransport;
37        read_with(&mut transport)
38    }
39    #[cfg(not(target_arch = "riscv32"))]
40    {
41        Err(GuestError::UnsupportedTarget)
42    }
43}
44
45/// Read a single value using an explicit transport.
46pub fn read_with<T: serde::de::DeserializeOwned>(
47    transport: &mut impl Transport,
48) -> Result<T, GuestError> {
49    let bytes = read_framed_bytes_with(|| transport.read_word());
50    AirbenderCodecV0::decode(&bytes).map_err(GuestError::Codec)
51}
52
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::MockTransport;
    use airbender_core::wire::frame_words_from_bytes;
    use alloc::vec;

    /// Sample payload mixing a fixed-width field with a variable-length one.
    #[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
    struct Payload {
        counter: u32,
        bytes: alloc::vec::Vec<u8>,
    }

    /// Round-trip: encode -> frame into words -> `MockTransport` -> `read_with`.
    #[test]
    fn reads_value_from_transport() {
        let payload = Payload {
            counter: 7,
            bytes: vec![10u8, 20, 30],
        };
        let encoded = AirbenderCodecV0::encode(&payload).expect("encode");
        let words = frame_words_from_bytes(&encoded).expect("frame words");
        let mut transport = MockTransport::new(words);
        let decoded: Payload = read_with(&mut transport).expect("read");
        assert_eq!(decoded, payload);
    }

    /// Off riscv32 there is no CSR transport, so `read` must report
    /// `UnsupportedTarget` rather than attempting a read.
    #[cfg(not(target_arch = "riscv32"))]
    #[test]
    fn read_reports_unsupported_target_off_riscv() {
        let result: Result<Payload, GuestError> = read();
        assert!(matches!(result, Err(GuestError::UnsupportedTarget)));
    }

    /// Pins the user-facing message for the unsupported-target error.
    #[test]
    fn unsupported_target_message() {
        assert_eq!(
            alloc::format!("{}", GuestError::UnsupportedTarget),
            "csr transport is only available on riscv32"
        );
    }
}