smart_config/metadata/_private.rs

//! Metadata validations performed at compile time.
2
3use std::{any, marker::PhantomData};
4
5use compile_fmt::{Ascii, CompileArgs, clip, clip_ascii, compile_args, compile_panic};
6
7use super::{ConfigMetadata, NestedConfigMetadata, ParamMetadata};
8use crate::{
9    DeserializeConfig, DeserializeConfigError,
10    de::DeserializeContext,
11    utils::const_eq,
12    visit::{ConfigVisitor, VisitConfig},
13};
14
15pub type BoxedDeserializer =
16    fn(DeserializeContext<'_>) -> Result<Box<dyn any::Any>, DeserializeConfigError>;
17
18pub type BoxedVisitor = fn(&dyn any::Any, &mut dyn ConfigVisitor);
19
20pub const fn box_config_visitor<T: VisitConfig + 'static>() -> BoxedVisitor {
21    |boxed_config, visitor| {
22        let config = boxed_config
23            .downcast_ref::<T>()
24            .expect("Internal error: visit target has incorrect type");
25        config.visit_config(visitor);
26    }
27}
28
29pub trait DeserializeBoxedConfig {
30    fn deserialize_boxed_config(
31        &self,
32        ctx: DeserializeContext<'_>,
33    ) -> Result<Box<dyn any::Any>, DeserializeConfigError>;
34}
35
36impl<T: DeserializeConfig> DeserializeBoxedConfig for PhantomData<T> {
37    fn deserialize_boxed_config(
38        &self,
39        ctx: DeserializeContext<'_>,
40    ) -> Result<Box<dyn any::Any>, DeserializeConfigError> {
41        T::deserialize_config(ctx).map(|config| Box::new(config) as Box<dyn any::Any>)
42    }
43}
44
45impl<T> DeserializeBoxedConfig for &PhantomData<T> {
46    fn deserialize_boxed_config(
47        &self,
48        _ctx: DeserializeContext<'_>,
49    ) -> Result<Box<dyn any::Any>, DeserializeConfigError> {
50        Err(DeserializeConfigError::new())
51    }
52}
53
/// Returns `true` if `ch` may start a name or a path segment: `_` or a lowercase ASCII letter.
const fn is_valid_start_name_char(ch: u8) -> bool {
    matches!(ch, b'_' | b'a'..=b'z')
}
57
/// Returns `true` if `ch` may appear in a name after its first char: `_`, a lowercase ASCII letter,
/// or an ASCII digit.
const fn is_valid_name_char(ch: u8) -> bool {
    matches!(ch, b'_' | b'a'..=b'z' | b'0'..=b'9')
}
61
62#[derive(Debug, Clone, Copy)]
63enum AllowedChars {
64    NameStart,
65    Name,
66    Path,
67}
68
69impl AllowedChars {
70    const fn as_str(self) -> Ascii<'static> {
71        Ascii::new(match self {
72            Self::NameStart => "[_a-z]",
73            Self::Name => "[_a-z0-9]",
74            Self::Path => "[_a-z0-9.]",
75        })
76    }
77}
78
79#[derive(Debug)]
80enum ValidationError {
81    Empty,
82    NonAscii {
83        pos: usize,
84    },
85    DisallowedChar {
86        pos: usize,
87        ch: char,
88        allowed: AllowedChars,
89    },
90}
91
92type ErrorArgs = CompileArgs<101>;
93
94impl ValidationError {
95    const fn fmt(self) -> ErrorArgs {
96        match self {
97            Self::Empty => compile_args!(capacity: ErrorArgs::CAPACITY, "name cannot be empty"),
98            Self::NonAscii { pos } => compile_args!(
99                capacity: ErrorArgs::CAPACITY,
100                "name contains non-ASCII chars, first at position ",
101                pos => compile_fmt::fmt::<usize>()
102            ),
103            Self::DisallowedChar { pos, ch, allowed } => compile_args!(
104                "name contains a disallowed char '",
105                ch => compile_fmt::fmt::<char>(),
106                "' at position ", pos => compile_fmt::fmt::<usize>(),
107                "; allowed chars are ",
108                allowed.as_str() => clip_ascii(10, "")
109            ),
110        }
111    }
112}
113
114const fn validate_name(name: &str) -> Result<(), ValidationError> {
115    if name.is_empty() {
116        return Err(ValidationError::Empty);
117    }
118
119    let name_bytes = name.as_bytes();
120    let mut pos = 0;
121    while pos < name.len() {
122        if name_bytes[pos] > 127 {
123            return Err(ValidationError::NonAscii { pos });
124        }
125        let ch = name_bytes[pos];
126        let is_disallowed = (pos == 0 && !is_valid_start_name_char(ch)) || !is_valid_name_char(ch);
127        if is_disallowed {
128            return Err(ValidationError::DisallowedChar {
129                pos,
130                ch: ch as char,
131                allowed: if pos == 0 {
132                    AllowedChars::NameStart
133                } else {
134                    AllowedChars::Name
135                },
136            });
137        }
138        pos += 1;
139    }
140    Ok(())
141}
142
143/// Checks that a param name is valid.
144#[track_caller]
145pub const fn assert_param_name(name: &str) {
146    if let Err(err) = validate_name(name) {
147        compile_panic!(
148            "Param / config name `", name => clip(32, "…"), "` is invalid: ",
149            &err.fmt() => compile_fmt::fmt::<&ErrorArgs>()
150        );
151    }
152}
153
154#[track_caller]
155pub const fn assert_param_alias(name: &str) {
156    let mut path_start = None;
157    if !name.is_empty() {
158        let name_bytes = name.as_bytes();
159        let mut pos = 0;
160        while pos < name.len() && name_bytes[pos] == b'.' {
161            pos += 1;
162        }
163
164        if pos > 0 {
165            path_start = Some(pos);
166        }
167    }
168
169    if let Some(path_start) = path_start {
170        if let Err(err) = validate_path(name, path_start) {
171            compile_panic!(
172                "Param / config alias path `", name => clip(32, "…"), "` is invalid: ",
173                &err.fmt() => compile_fmt::fmt::<&ErrorArgs>()
174            );
175        }
176    } else {
177        assert_param_name(name);
178    }
179}
180
181#[track_caller]
182const fn assert_param_against_config(
183    param_parent: &'static str,
184    param: &ParamMetadata,
185    config_parent: &'static str,
186    config: &NestedConfigMetadata,
187) {
188    let mut param_i = 0;
189    while param_i <= param.aliases.len() {
190        let param_name = if param_i == 0 {
191            param.name
192        } else {
193            param.aliases[param_i - 1].0
194        };
195
196        if param_name.as_bytes()[0] == b'.' {
197            // Path-like alias; skip checks
198        }
199
200        let mut config_i = 0;
201        while config_i <= config.aliases.len() {
202            let config_name = if config_i == 0 {
203                config.name
204            } else {
205                config.aliases[config_i - 1].0
206            };
207
208            if const_eq(param_name.as_bytes(), config_name.as_bytes()) {
209                compile_panic!(
210                    "Name / alias `", param_name => clip(32, "…"), "` of param `",
211                    param_parent => clip(32, "…"), ".",
212                    param.rust_field_name  => clip(32, "…"),
213                    "` coincides with a name / alias of a nested config `",
214                    config_parent => clip(32, "…"), ".",
215                    config.rust_field_name  => clip(32, "…"),
216                    "`. This is an unconditional error; \
217                    config deserialization relies on the fact that configs never coincide with params"
218                );
219            }
220
221            config_i += 1;
222        }
223        param_i += 1;
224    }
225
226    if const_eq(param.name.as_bytes(), config.name.as_bytes()) {
227        compile_panic!(
228            "Name `", param.name => clip(32, "…"), "` of param `",
229            param_parent => clip(32, "…"), ".",
230            param.rust_field_name  => clip(32, "…"),
231            "` coincides with a name of a nested config `",
232            config_parent => clip(32, "…"), ".",
233            config.rust_field_name  => clip(32, "…"),
234            "`. This is an unconditional error; \
235            config deserialization relies on the fact that configs never coincide with params"
236        );
237    }
238
239    let mut alias_i = 0;
240    while alias_i < param.aliases.len() {
241        let alias = param.aliases[alias_i].0;
242        if const_eq(alias.as_bytes(), config.name.as_bytes()) {
243            compile_panic!(
244                "Alias `", alias => clip(32, "…"), "` of param `",
245                param_parent => clip(32, "…"), ".",
246                param.rust_field_name  => clip(32, "…"),
247                "` coincides with a name of a nested config `",
248                config_parent => clip(32, "…"), ".",
249                config.rust_field_name  => clip(32, "…"),
250                "`. This is an unconditional error; \
251                config deserialization relies on the fact that configs never coincide with params"
252            );
253        }
254        alias_i += 1;
255    }
256}
257
258#[track_caller]
259const fn assert_param_name_is_not_a_config(
260    param_parent: &'static str,
261    param: &ParamMetadata,
262    config: &ConfigMetadata,
263) {
264    let mut config_i = 0;
265    while config_i < config.nested_configs.len() {
266        let nested = &config.nested_configs[config_i];
267        if nested.name.is_empty() {
268            // Flattened config; recurse.
269            assert_param_name_is_not_a_config(param_parent, param, nested.meta);
270        } else {
271            assert_param_against_config(param_parent, param, config.ty.name_in_code(), nested);
272        }
273        config_i += 1;
274    }
275}
276
277#[track_caller]
278const fn assert_config_name_is_not_a_param(
279    config_parent: &'static str,
280    config: &NestedConfigMetadata,
281    configs: &[NestedConfigMetadata],
282) {
283    let mut config_i = 0;
284    while config_i < configs.len() {
285        let flattened = &configs[config_i];
286        if flattened.name.is_empty() {
287            let param_parent = flattened.meta.ty.name_in_code();
288            let params = flattened.meta.params;
289            let mut param_i = 0;
290            while param_i < params.len() {
291                assert_param_against_config(param_parent, &params[param_i], config_parent, config);
292                param_i += 1;
293            }
294
295            // Recurse into the next level.
296            assert_config_name_is_not_a_param(config_parent, config, flattened.meta.nested_configs);
297        }
298        config_i += 1;
299    }
300}
301
302impl ConfigMetadata {
303    #[track_caller]
304    pub const fn assert_valid(&self) {
305        // Check that param names don't coincide with nested config names (both params and nested configs include
306        // ones through flattened configs, recursively). Having both a param and a config bound to the same location
307        // doesn't logically make sense, and accounting for it would make merging / deserialization logic unreasonably complex.
308        self.assert_params_are_not_configs();
309        self.assert_configs_are_not_params();
310    }
311
312    #[track_caller]
313    const fn assert_params_are_not_configs(&self) {
314        let mut param_i = 0;
315        while param_i < self.params.len() {
316            assert_param_name_is_not_a_config(self.ty.name_in_code(), &self.params[param_i], self);
317            param_i += 1;
318        }
319    }
320
321    #[track_caller]
322    const fn assert_configs_are_not_params(&self) {
323        let mut config_i = 0;
324        while config_i < self.nested_configs.len() {
325            let config = &self.nested_configs[config_i];
326            if !config.name.is_empty() {
327                assert_config_name_is_not_a_param(
328                    self.ty.name_in_code(),
329                    config,
330                    self.nested_configs,
331                );
332            }
333            config_i += 1;
334        }
335    }
336}
337
338// TODO: validate param types (non-empty intersection)
339
340const fn validate_path(name: &str, start: usize) -> Result<(), ValidationError> {
341    if name.is_empty() {
342        return Err(ValidationError::Empty);
343    }
344
345    let name_bytes = name.as_bytes();
346    let mut pos = start;
347    let mut is_segment_start = true;
348    while pos < name.len() {
349        if name_bytes[pos] > 127 {
350            return Err(ValidationError::NonAscii { pos });
351        }
352        let ch = name_bytes[pos];
353
354        let is_disallowed = (is_segment_start && !is_valid_start_name_char(ch))
355            || (ch != b'.' && !is_valid_name_char(ch));
356        if is_disallowed {
357            return Err(ValidationError::DisallowedChar {
358                pos,
359                ch: ch as char,
360                allowed: if is_segment_start {
361                    AllowedChars::NameStart
362                } else {
363                    AllowedChars::Path
364                },
365            });
366        }
367
368        is_segment_start = ch == b'.';
369        pos += 1;
370    }
371    Ok(())
372}
373
/// Returns `true` if `a` and `b` are equal, or if one is a prefix of the other ending at a segment
/// boundary (i.e., followed by `.` in the longer string).
const fn have_prefix_relation(a: &str, b: &str) -> bool {
    let (shorter, longer) = if a.len() <= b.len() {
        (a.as_bytes(), b.as_bytes())
    } else {
        (b.as_bytes(), a.as_bytes())
    };

    let mut idx = 0;
    while idx < shorter.len() {
        if shorter[idx] != longer[idx] {
            return false;
        }
        idx += 1;
    }

    // `shorter` is a byte-wise prefix of `longer`; it only counts if it ends on a segment boundary.
    shorter.len() == longer.len() || longer[shorter.len()] == b'.'
}
391
392/// Asserts config paths for the `config!` macro.
393#[track_caller]
394pub const fn assert_paths(paths: &[&str]) {
395    // First, validate each path in isolation.
396    let mut i = 0;
397    while i < paths.len() {
398        let path = paths[i];
399        if let Err(err) = validate_path(path, 0) {
400            compile_panic!(
401                "Path #", i => compile_fmt::fmt::<usize>(), " `", path => clip(32, "…"), "` is invalid: ",
402                &err.fmt() => compile_fmt::fmt::<&ErrorArgs>()
403            );
404        }
405        i += 1;
406    }
407
408    let mut i = 0;
409    while i + 1 < paths.len() {
410        let path = paths[i];
411        let mut j = i + 1;
412        while j < paths.len() {
413            let other_path = paths[j];
414            if have_prefix_relation(path, other_path) {
415                let (short_i, short, long_i, long) = if path.len() < other_path.len() {
416                    (i, path, j, other_path)
417                } else {
418                    (j, other_path, i, path)
419                };
420
421                compile_panic!(
422                    "Path #", short_i => compile_fmt::fmt::<usize>(), " `", short => clip(32, "…"), "` is a prefix of path #",
423                    long_i => compile_fmt::fmt::<usize>(), " `", long => clip(32, "…"), "`"
424                );
425            }
426            j += 1;
427        }
428        i += 1;
429    }
430}
431
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use super::*;

    #[test]
    fn validating_paths() {
        for valid in ["test", "long.test_path._with_3_segments"] {
            validate_path(valid, 0).unwrap();
        }

        for invalid in ["test.pa!th", "test.3", "test..path"] {
            assert_matches!(
                validate_path(invalid, 0).unwrap_err(),
                ValidationError::DisallowedChar { .. }
            );
        }
    }

    #[test]
    fn checking_prefix_relations() {
        let related = [
            ("test", "test.path"),
            ("test.path", "test"),
            ("test.path", "test.path"),
        ];
        for (a, b) in related {
            assert!(have_prefix_relation(a, b), "{a} / {b}");
        }

        let unrelated = [("test.path", "test_path"), ("test", "test_path")];
        for (a, b) in unrelated {
            assert!(!have_prefix_relation(a, b), "{a} / {b}");
        }
    }
}