smart_config/source/
yaml.rs

1use std::sync::Arc;
2
3use anyhow::Context;
4
5use super::{ConfigSource, Hierarchical};
6use crate::value::{FileFormat, Map, Pointer, Value, ValueOrigin, WithOrigin};
7
8/// YAML-based configuration source.
9#[derive(Debug, Clone)]
10pub struct Yaml {
11    origin: Arc<ValueOrigin>,
12    inner: Map,
13}
14
15impl Yaml {
16    /// Creates a source with the specified name and contents.
17    ///
18    /// # Errors
19    ///
20    /// Returns an error if the input doesn't conform to the JSON object model; e.g., if it has objects / maps
21    /// with array or object keys.
22    pub fn new(filename: &str, object: serde_yaml::Mapping) -> anyhow::Result<Self> {
23        let origin = Arc::new(ValueOrigin::File {
24            name: filename.to_owned(),
25            format: FileFormat::Yaml,
26        });
27        let inner =
28            Self::map_value(serde_yaml::Value::Mapping(object), &origin, String::new())?.inner;
29        let Value::Object(inner) = inner else {
30            unreachable!();
31        };
32        Ok(Self { origin, inner })
33    }
34
35    fn map_key(key: serde_yaml::Value, parent_path: &str) -> anyhow::Result<String> {
36        Ok(match key {
37            serde_yaml::Value::String(value) => value,
38            serde_yaml::Value::Number(value) => value.to_string(),
39            serde_yaml::Value::Bool(value) => value.to_string(),
40            serde_yaml::Value::Null => "null".into(),
41            _ => anyhow::bail!(
42                "unsupported key type at {parent_path:?}: {key:?}; only primitive value types are supported as keys"
43            ),
44        })
45    }
46
47    fn map_number(number: &serde_yaml::Number, path: &str) -> anyhow::Result<serde_json::Number> {
48        Ok(if let Some(number) = number.as_u64() {
49            number.into()
50        } else if let Some(number) = number.as_i64() {
51            number.into()
52        } else if let Some(number) = number.as_f64() {
53            serde_json::Number::from_f64(number)
54                .with_context(|| format!("unsupported number at {path:?}: {number:?}"))?
55        } else {
56            anyhow::bail!("unsupported number at {path:?}: {number:?}")
57        })
58    }
59
60    fn map_value(
61        value: serde_yaml::Value,
62        file_origin: &Arc<ValueOrigin>,
63        path: String,
64    ) -> anyhow::Result<WithOrigin> {
65        let inner = match value {
66            serde_yaml::Value::Null => Value::Null,
67            serde_yaml::Value::Bool(value) => value.into(),
68            serde_yaml::Value::Number(value) => Value::Number(Self::map_number(&value, &path)?),
69            serde_yaml::Value::String(value) => value.into(),
70            serde_yaml::Value::Sequence(items) => Value::Array(
71                items
72                    .into_iter()
73                    .enumerate()
74                    .map(|(i, value)| {
75                        let child_path = Pointer(&path).join(&i.to_string());
76                        Self::map_value(value, file_origin, child_path)
77                    })
78                    .collect::<anyhow::Result<_>>()?,
79            ),
80            serde_yaml::Value::Mapping(items) => Value::Object(
81                items
82                    .into_iter()
83                    .map(|(key, value)| {
84                        let key = Self::map_key(key, &path)?;
85                        let child_path = Pointer(&path).join(&key);
86                        anyhow::Ok((key, Self::map_value(value, file_origin, child_path)?))
87                    })
88                    .collect::<anyhow::Result<_>>()?,
89            ),
90            serde_yaml::Value::Tagged(tagged) => {
91                return Self::map_value(tagged.value, file_origin, path);
92            }
93        };
94
95        Ok(WithOrigin {
96            inner,
97            origin: if path.is_empty() {
98                file_origin.clone()
99            } else {
100                Arc::new(ValueOrigin::Path {
101                    source: file_origin.clone(),
102                    path,
103                })
104            },
105        })
106    }
107}
108
109impl ConfigSource for Yaml {
110    type Kind = Hierarchical;
111
112    fn into_contents(self) -> WithOrigin<Map> {
113        WithOrigin::new(self.inner, self.origin)
114    }
115}
116
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use super::*;
    use crate::value::StrValue;

    const YAML_CONFIG: &str = r#"
bool: true
nested:
    int: 123
    string: "what?"
array:
    - test: 23
    "#;

    /// Parses a YAML document whose root is expected to be a mapping.
    fn parse_mapping(yaml: &str) -> serde_yaml::Mapping {
        match serde_yaml::from_str(yaml).unwrap() {
            serde_yaml::Value::Mapping(mapping) => mapping,
            other => panic!("expected a YAML mapping, got {other:?}"),
        }
    }

    /// Extracts the file name from a YAML file origin; panics on any other origin.
    fn filename(source: &ValueOrigin) -> &str {
        match source {
            ValueOrigin::File {
                name,
                format: FileFormat::Yaml,
            } => name.as_str(),
            _ => panic!("unexpected source: {source:?}"),
        }
    }

    #[test]
    fn creating_yaml_config() {
        let yaml = Yaml::new("test.yml", parse_mapping(YAML_CONFIG)).unwrap();

        assert_matches!(yaml.inner["bool"].inner, Value::Bool(true));
        assert_matches!(
            yaml.inner["bool"].origin.as_ref(),
            ValueOrigin::Path { path, source } if filename(source) == "test.yml" && path == "bool"
        );

        let str = yaml.inner["nested"].get(Pointer("string")).unwrap();
        assert_matches!(&str.inner, Value::String(StrValue::Plain(s)) if s == "what?");
        assert_matches!(
            str.origin.as_ref(),
            ValueOrigin::Path { path, source } if filename(source) == "test.yml" && path == "nested.string"
        );

        let inner_int = yaml.inner["array"].get(Pointer("0.test")).unwrap();
        assert_matches!(&inner_int.inner, Value::Number(num) if *num == 23_u64.into());
    }

    #[test]
    fn scalar_keys_are_stringified() {
        let yaml = Yaml::new("test.yml", parse_mapping("1: one\ntrue: two\nnull: three\n")).unwrap();

        for (key, expected) in [("1", "one"), ("true", "two"), ("null", "three")] {
            assert_matches!(
                &yaml.inner[key].inner,
                Value::String(StrValue::Plain(s)) if s == expected
            );
        }
    }

    #[test]
    fn tags_are_stripped_from_values() {
        let yaml = Yaml::new("test.yml", parse_mapping("value: !Custom test\n")).unwrap();

        let value = &yaml.inner["value"];
        assert_matches!(&value.inner, Value::String(StrValue::Plain(s)) if s == "test");
        assert_matches!(
            value.origin.as_ref(),
            ValueOrigin::Path { path, source } if filename(source) == "test.yml" && path == "value"
        );
    }

    #[test]
    fn non_finite_number() {
        let err = Yaml::new("test.yml", parse_mapping("value: .nan\n"))
            .unwrap_err()
            .to_string();
        assert!(err.contains("unsupported number"), "{err}");
        assert!(err.contains("value"), "{err}");
    }

    #[test]
    fn unsupported_key() {
        let yaml = r"
array:
    - [12, 34]: bogus
        ";

        let err = Yaml::new("test.yml", parse_mapping(yaml))
            .unwrap_err()
            .to_string();
        assert!(err.contains("unsupported key type"), "{err}");
        assert!(err.contains("array.0"), "{err}");
    }
}