smart_config/de/
deserializer.rs

1//! `serde`-compatible deserializer based on a value with origin.
2
3use std::sync::Arc;
4
5use serde::{
6    Deserialize, Deserializer,
7    de::{
8        self, DeserializeSeed, Error as DeError, IntoDeserializer,
9        value::{MapDeserializer, SeqDeserializer},
10    },
11};
12
13use crate::{
14    error::ErrorWithOrigin,
15    utils::EnumVariant,
16    value::{Map, StrValue, Value, ValueOrigin, WithOrigin},
17};
18
/// Options tuning how configuration values are deserialized.
///
/// The derived [`Default`] leaves every option switched off.
#[derive(Clone, Debug, Default)]
pub struct DeserializerOptions {
    /// If set, enum variant names are coerced between naming cases before matching,
    /// e.g. from `SHOUTING_CASE` to `shouting_case`.
    pub coerce_variant_names: bool,
}
25
26impl WithOrigin {
27    #[cold]
28    pub(crate) fn invalid_type(&self, expected: &str) -> ErrorWithOrigin {
29        let actual = match &self.inner {
30            Value::Null => de::Unexpected::Unit,
31            Value::Bool(value) => de::Unexpected::Bool(*value),
32            Value::Number(value) => {
33                if let Some(value) = value.as_u64() {
34                    de::Unexpected::Unsigned(value)
35                } else if let Some(value) = value.as_i64() {
36                    de::Unexpected::Signed(value)
37                } else if let Some(value) = value.as_f64() {
38                    de::Unexpected::Float(value)
39                } else {
40                    de::Unexpected::Other("number")
41                }
42            }
43            Value::String(StrValue::Plain(s)) => de::Unexpected::Str(s),
44            Value::String(StrValue::Secret(_)) => de::Unexpected::Other("secret"),
45            Value::Array(_) => de::Unexpected::Seq,
46            Value::Object(_) => de::Unexpected::Map,
47        };
48        ErrorWithOrigin::json(
49            DeError::invalid_type(actual, &expected),
50            self.origin.clone(),
51        )
52    }
53}
54
// Generates `deserialize_*` methods for numeric types inside a `Deserializer` impl.
//
// Each generated method accepts:
//
// - a string, which is parsed into `$ty` and then fed to the visitor via the matching
//   `$method` of the primitive deserializer;
// - a number, which is handed to the visitor through the number's own `deserialize_any`.
//
// Any other kind of value results in an "invalid type" error.
macro_rules! parse_int_value {
    ($($ty:ident => $method:ident,)*) => {
        $(
        fn $method<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
            let outcome: Result<V::Value, Self::Error> = match self.value() {
                Value::String(s) => match s.expose().parse::<$ty>() {
                    Ok(parsed) => parsed.into_deserializer().$method(visitor),
                    Err(err) => {
                        // Parsing failures are reported right away with the value origin attached.
                        let err = DeError::custom(format_args!(
                            "{err} while parsing {} value '{s}'",
                            stringify!($ty)
                        ));
                        return Err(self.enrich_err(err));
                    }
                },
                Value::Number(number) => number
                    .deserialize_any(visitor)
                    .map_err(|err| self.enrich_err(err)),
                _ => return Err(self.invalid_type(concat!(stringify!($ty), " number"))),
            };
            // Errors raised by the visitor may lack an origin; fall back to the origin of this value.
            outcome.map_err(|err| err.set_origin_if_unset(self.origin()))
        }
        )*
    };
}
77
78#[derive(Debug, Clone, Copy)]
79pub struct ValueDeserializer<'a> {
80    value: &'a WithOrigin,
81    options: &'a DeserializerOptions,
82}
83
84impl<'a> ValueDeserializer<'a> {
85    pub(super) fn new(value: &'a WithOrigin, options: &'a DeserializerOptions) -> Self {
86        Self { value, options }
87    }
88
89    pub(super) fn value(&self) -> &'a Value {
90        &self.value.inner
91    }
92
93    pub(super) fn origin(&self) -> &Arc<ValueOrigin> {
94        &self.value.origin
95    }
96
97    pub(super) fn enrich_err(&self, err: serde_json::Error) -> ErrorWithOrigin {
98        ErrorWithOrigin::json(err, self.value.origin.clone())
99    }
100
101    pub(super) fn invalid_type(&self, expected: &str) -> ErrorWithOrigin {
102        self.value.invalid_type(expected)
103    }
104
105    fn parse_array<'de, V: de::Visitor<'de>>(
106        &self,
107        array: &[WithOrigin],
108        visitor: V,
109    ) -> Result<V::Value, ErrorWithOrigin> {
110        let mut deserializer = SeqDeserializer::new(
111            array
112                .iter()
113                .map(|val| ValueDeserializer::new(val, self.options)),
114        );
115        let seq = visitor.visit_seq(&mut deserializer)?;
116        deserializer.end()?;
117        Ok(seq)
118    }
119
120    fn parse_object<'de, V: de::Visitor<'de>>(
121        &self,
122        object: &Map,
123        visitor: V,
124    ) -> Result<V::Value, ErrorWithOrigin> {
125        let mut deserializer = MapDeserializer::new(
126            object
127                .iter()
128                .map(|(key, value)| (key.as_str(), ValueDeserializer::new(value, self.options))),
129        );
130        let map = visitor.visit_map(&mut deserializer)?;
131        deserializer.end()?;
132        Ok(map)
133    }
134}
135
136impl<'de> Deserializer<'de> for ValueDeserializer<'_> {
137    type Error = ErrorWithOrigin;
138
139    fn deserialize_any<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
140        let result = match self.value() {
141            Value::Null => visitor.visit_none(),
142            Value::Bool(value) => visitor.visit_bool(*value),
143            Value::Number(value) => value
144                .deserialize_any(visitor)
145                .map_err(|err| self.enrich_err(err)),
146            Value::String(value) => visitor.visit_str(value.expose()),
147            Value::Array(array) => self.parse_array(array, visitor),
148            Value::Object(object) => self.parse_object(object, visitor),
149        };
150        result.map_err(|err| err.set_origin_if_unset(&self.value.origin))
151    }
152
153    fn deserialize_option<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
154        let result = match self.value() {
155            Value::Null => visitor.visit_none(),
156            _ => visitor.visit_some(self),
157        };
158        result.map_err(|err| err.set_origin_if_unset(&self.value.origin))
159    }
160
161    fn deserialize_newtype_struct<V: de::Visitor<'de>>(
162        self,
163        _name: &'static str,
164        visitor: V,
165    ) -> Result<V::Value, Self::Error> {
166        visitor
167            .visit_newtype_struct(self)
168            .map_err(|err| err.set_origin_if_unset(&self.value.origin))
169    }
170
171    fn deserialize_seq<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
172        let result = match self.value() {
173            Value::Array(array) => self.parse_array(array, visitor),
174            _ => Err(self.invalid_type("array")),
175        };
176        result.map_err(|err| err.set_origin_if_unset(&self.value.origin))
177    }
178
179    fn deserialize_tuple<V: de::Visitor<'de>>(
180        self,
181        _len: usize,
182        visitor: V,
183    ) -> Result<V::Value, Self::Error> {
184        self.deserialize_seq(visitor)
185            .map_err(|err| err.set_origin_if_unset(&self.value.origin))
186    }
187
188    fn deserialize_tuple_struct<V: de::Visitor<'de>>(
189        self,
190        _name: &'static str,
191        _len: usize,
192        visitor: V,
193    ) -> Result<V::Value, Self::Error> {
194        self.deserialize_seq(visitor)
195            .map_err(|err| err.set_origin_if_unset(&self.value.origin))
196    }
197
198    fn deserialize_map<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
199        let result = match self.value() {
200            Value::Object(object) => self.parse_object(object, visitor),
201            _ => Err(self.invalid_type("object")),
202        };
203        result.map_err(|err| err.set_origin_if_unset(&self.value.origin))
204    }
205
206    fn deserialize_struct<V: de::Visitor<'de>>(
207        self,
208        _name: &'static str,
209        _fields: &'static [&'static str],
210        visitor: V,
211    ) -> Result<V::Value, Self::Error> {
212        let result = match self.value() {
213            Value::Array(array) => self.parse_array(array, visitor),
214            Value::Object(object) => self.parse_object(object, visitor),
215            _ => Err(self.invalid_type("array or object")),
216        };
217        result.map_err(|err| err.set_origin_if_unset(&self.value.origin))
218    }
219
220    fn deserialize_enum<V: de::Visitor<'de>>(
221        self,
222        _name: &'static str,
223        variants: &'static [&'static str],
224        visitor: V,
225    ) -> Result<V::Value, Self::Error> {
226        let (mut variant, value) = match self.value() {
227            Value::Object(object) if object.len() == 1 => {
228                let (variant, value) = object.iter().next().unwrap();
229                (variant.as_str(), Some(value))
230            }
231            Value::String(s) => (s.expose(), None),
232            _ => return Err(self.invalid_type("string or object with single key")),
233        };
234
235        if self.options.coerce_variant_names
236            && let Some(parsed) = EnumVariant::new(variant)
237            && let Some(expected_variant) = parsed.try_match(variants)
238        {
239            variant = expected_variant;
240        }
241
242        visitor
243            .visit_enum(EnumDeserializer {
244                variant,
245                inner: VariantDeserializer {
246                    value,
247                    options: self.options,
248                    parent_origin: self.value.origin.clone(),
249                },
250            })
251            .map_err(|err| err.set_origin_if_unset(&self.value.origin))
252    }
253
254    // Primitive values
255
256    fn deserialize_bool<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
257        let result = match self.value() {
258            Value::Bool(value) => visitor.visit_bool(*value),
259            Value::String(s) => match s.expose().parse::<bool>() {
260                Ok(val) => visitor.visit_bool(val),
261                Err(err) => {
262                    let err =
263                        DeError::custom(format_args!("{err} while parsing value '{s}' as boolean"));
264                    return Err(self.enrich_err(err));
265                }
266            },
267            _ => return Err(self.invalid_type("boolean or boolean-like string")),
268        };
269        result.map_err(|err: serde_json::Error| self.enrich_err(err))
270    }
271
272    parse_int_value! {
273        u8 => deserialize_u8,
274        u16 => deserialize_u16,
275        u32 => deserialize_u32,
276        u64 => deserialize_u64,
277        i8 => deserialize_i8,
278        i16 => deserialize_i16,
279        i32 => deserialize_i32,
280        i64 => deserialize_i64,
281        u128 => deserialize_u128,
282        i128 => deserialize_i128,
283        f32 => deserialize_f32,
284        f64 => deserialize_f64,
285    }
286
287    fn deserialize_string<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
288        let result = match self.value() {
289            Value::String(s) => visitor.visit_str(s.expose()),
290            Value::Null => visitor.visit_string("null".to_string()),
291            Value::Bool(value) => visitor.visit_string(value.to_string()),
292            Value::Number(value) => visitor.visit_string(value.to_string()),
293            _ => Err(self.invalid_type("string or other primitive type")),
294        };
295        result.map_err(|err| err.set_origin_if_unset(&self.value.origin))
296    }
297
298    fn deserialize_char<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
299        self.deserialize_string(visitor)
300    }
301
302    fn deserialize_str<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
303        self.deserialize_string(visitor)
304    }
305
306    fn deserialize_byte_buf<V: de::Visitor<'de>>(
307        self,
308        visitor: V,
309    ) -> Result<V::Value, Self::Error> {
310        let result = match self.value() {
311            Value::String(s) => visitor.visit_str(s.expose()),
312            Value::Array(array) => self.parse_array(array, visitor),
313            _ => return Err(self.invalid_type("string or array")),
314        };
315        result.map_err(|err| err.set_origin_if_unset(&self.value.origin))
316    }
317
318    fn deserialize_bytes<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
319        self.deserialize_byte_buf(visitor)
320    }
321
322    fn deserialize_identifier<V: de::Visitor<'de>>(
323        self,
324        visitor: V,
325    ) -> Result<V::Value, Self::Error> {
326        self.deserialize_string(visitor)
327    }
328
329    fn deserialize_unit<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
330        let result = match self.value() {
331            Value::Null => visitor.visit_unit(),
332            _ => Err(self.invalid_type("null")),
333        };
334        result.map_err(|err| err.set_origin_if_unset(&self.value.origin))
335    }
336
337    fn deserialize_unit_struct<V: de::Visitor<'de>>(
338        self,
339        _name: &'static str,
340        visitor: V,
341    ) -> Result<V::Value, Self::Error> {
342        self.deserialize_unit(visitor)
343    }
344
345    fn deserialize_ignored_any<V: de::Visitor<'de>>(
346        self,
347        visitor: V,
348    ) -> Result<V::Value, Self::Error> {
349        visitor
350            .visit_unit()
351            .map_err(|err: serde_json::Error| self.enrich_err(err))
352    }
353}
354
355impl IntoDeserializer<'_, ErrorWithOrigin> for ValueDeserializer<'_> {
356    type Deserializer = Self;
357
358    fn into_deserializer(self) -> Self::Deserializer {
359        self
360    }
361}
362
363#[derive(Debug)]
364struct EnumDeserializer<'a> {
365    variant: &'a str,
366    inner: VariantDeserializer<'a>,
367}
368
369impl<'a, 'de> de::EnumAccess<'de> for EnumDeserializer<'a> {
370    type Error = ErrorWithOrigin;
371    type Variant = VariantDeserializer<'a>;
372
373    fn variant_seed<V: DeserializeSeed<'de>>(
374        self,
375        seed: V,
376    ) -> Result<(V::Value, Self::Variant), Self::Error> {
377        let variant = self.variant.into_deserializer();
378        match seed.deserialize(variant) {
379            Ok(val) => Ok((val, self.inner)),
380            Err(err) => Err(ErrorWithOrigin::json(err, self.inner.origin().clone())),
381        }
382    }
383}
384
385#[derive(Debug)]
386struct VariantDeserializer<'a> {
387    value: Option<&'a WithOrigin>,
388    options: &'a DeserializerOptions,
389    parent_origin: Arc<ValueOrigin>,
390}
391
392impl VariantDeserializer<'_> {
393    fn origin(&self) -> &Arc<ValueOrigin> {
394        self.value.map_or(&self.parent_origin, |val| &val.origin)
395    }
396}
397
398impl<'de> de::VariantAccess<'de> for VariantDeserializer<'_> {
399    type Error = ErrorWithOrigin;
400
401    fn unit_variant(self) -> Result<(), Self::Error> {
402        if let Some(value) = self.value {
403            Deserialize::deserialize(ValueDeserializer::new(value, self.options))
404        } else {
405            Ok(())
406        }
407    }
408
409    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
410    where
411        T: DeserializeSeed<'de>,
412    {
413        if let Some(value) = self.value {
414            seed.deserialize(ValueDeserializer::new(value, self.options))
415        } else {
416            let err = DeError::invalid_type(de::Unexpected::Unit, &"newtype variant");
417            Err(ErrorWithOrigin::json(err, self.parent_origin))
418        }
419    }
420
421    fn tuple_variant<V: de::Visitor<'de>>(
422        self,
423        _len: usize,
424        visitor: V,
425    ) -> Result<V::Value, Self::Error> {
426        if let Some(value) = self.value {
427            de::Deserializer::deserialize_seq(ValueDeserializer::new(value, self.options), visitor)
428        } else {
429            let err = DeError::invalid_type(de::Unexpected::Unit, &"tuple variant");
430            Err(ErrorWithOrigin::json(err, self.parent_origin))
431        }
432    }
433
434    fn struct_variant<V: de::Visitor<'de>>(
435        self,
436        _fields: &'static [&'static str],
437        visitor: V,
438    ) -> Result<V::Value, Self::Error> {
439        if let Some(value) = self.value {
440            de::Deserializer::deserialize_map(ValueDeserializer::new(value, self.options), visitor)
441        } else {
442            let err = DeError::invalid_type(de::Unexpected::Unit, &"struct variant");
443            Err(ErrorWithOrigin::json(err, self.parent_origin))
444        }
445    }
446}