vise_macros/
labels.rs

1//! Derivation of `EncodeLabelValue` and `EncodeLabelSet` traits.
2
3use std::{collections::HashSet, fmt};
4
5use proc_macro::TokenStream;
6use quote::{quote, quote_spanned};
7use syn::{
8    Attribute, Data, DeriveInput, Expr, Field, Fields, Ident, LitStr, Path, PathArguments, Type,
9};
10
11use crate::utils::{ensure_no_generics, metrics_attribute, ParseAttribute};
12
/// Case conversion applied to enum variant names via `#[metrics(rename_all = "...")]`.
#[derive(Debug, Clone, Copy)]
#[allow(clippy::enum_variant_names)]
enum RenameRule {
    /// `lowercase`
    LowerCase,
    /// `UPPERCASE`
    UpperCase,
    /// `camelCase`
    CamelCase,
    /// `snake_case`
    SnakeCase,
    /// `SCREAMING_SNAKE_CASE`
    ScreamingSnakeCase,
    /// `kebab-case`
    KebabCase,
    /// `SCREAMING-KEBAB-CASE`
    ScreamingKebabCase,
}

impl RenameRule {
    /// Parses a rule from the string given in the `rename_all` attribute.
    ///
    /// # Errors
    ///
    /// Returns a static error message listing accepted spellings if `s` is not recognized.
    fn parse(s: &str) -> Result<Self, &'static str> {
        let rule = match s {
            "lowercase" => Self::LowerCase,
            "UPPERCASE" => Self::UpperCase,
            "camelCase" => Self::CamelCase,
            "snake_case" => Self::SnakeCase,
            "SCREAMING_SNAKE_CASE" => Self::ScreamingSnakeCase,
            "kebab-case" => Self::KebabCase,
            "SCREAMING-KEBAB-CASE" => Self::ScreamingKebabCase,
            _ => {
                return Err(
                    "Invalid case specified; should be one of: lowercase, UPPERCASE, camelCase, \
                     snake_case, SCREAMING_SNAKE_CASE, kebab-case, SCREAMING-KEBAB-CASE",
                );
            }
        };
        Ok(rule)
    }

    /// Converts a `PascalCase` identifier according to this rule.
    ///
    /// Every ASCII uppercase char except the first one starts a new word; words are joined
    /// with `_` or `-` for the snake / kebab rules. The identifier must be non-empty ASCII.
    fn transform(self, ident: &str) -> String {
        debug_assert!(ident.is_ascii()); // Should be checked previously
        let (separator, uppercase) = match self {
            Self::LowerCase => return ident.to_ascii_lowercase(),
            Self::UpperCase => return ident.to_ascii_uppercase(),
            Self::CamelCase => {
                // `ident` is ASCII, so splitting after the first byte is on a char boundary.
                let (head, tail) = ident.split_at(1);
                return head.to_ascii_lowercase() + tail;
            }
            Self::SnakeCase => ('_', false),
            Self::ScreamingSnakeCase => ('_', true),
            Self::KebabCase => ('-', false),
            Self::ScreamingKebabCase => ('-', true),
        };

        let convert: fn(&char) -> char = if uppercase {
            char::to_ascii_uppercase
        } else {
            char::to_ascii_lowercase
        };
        let mut result = String::with_capacity(ident.len());
        for (idx, ch) in ident.char_indices() {
            if idx != 0 && ch.is_ascii_uppercase() {
                result.push(separator);
            }
            result.push(convert(&ch));
        }
        result
    }
}
71
72#[derive(Default)]
73struct EncodeLabelAttrs {
74    cr: Option<Path>,
75    rename_all: Option<RenameRule>,
76    format: Option<LitStr>,
77    label: Option<LitStr>,
78}
79
80impl EncodeLabelAttrs {
81    fn path_to_crate(&self, span: proc_macro2::Span) -> proc_macro2::TokenStream {
82        if let Some(cr) = &self.cr {
83            // Overriding the span for `cr` via `quote_spanned!` doesn't work.
84            quote!(#cr)
85        } else {
86            quote_spanned!(span=> vise)
87        }
88    }
89}
90
91impl fmt::Debug for EncodeLabelAttrs {
92    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
93        formatter
94            .debug_struct("EncodeLabelAttrs")
95            .field("cr", &self.cr.as_ref().map(|_| "_"))
96            .field("rename_all", &self.rename_all)
97            .field("format", &self.format.as_ref().map(LitStr::value))
98            .field("label", &self.label.as_ref().map(LitStr::value))
99            .finish()
100    }
101}
102
103impl ParseAttribute for EncodeLabelAttrs {
104    fn parse(raw: &Attribute) -> syn::Result<Self> {
105        let mut attrs = Self::default();
106        raw.parse_nested_meta(|meta| {
107            if meta.path.is_ident("crate") {
108                attrs.cr = Some(meta.value()?.parse()?);
109                Ok(())
110            } else if meta.path.is_ident("rename_all") {
111                let case_str: LitStr = meta.value()?.parse()?;
112                let case = RenameRule::parse(&case_str.value())
113                    .map_err(|message| syn::Error::new(case_str.span(), message))?;
114                attrs.rename_all = Some(case);
115                Ok(())
116            } else if meta.path.is_ident("format") {
117                attrs.format = Some(meta.value()?.parse()?);
118                Ok(())
119            } else if meta.path.is_ident("label") {
120                let label: LitStr = meta.value()?.parse()?;
121                attrs.label = Some(label);
122                Ok(())
123            } else {
124                Err(meta.error(
125                    "Unsupported attribute; only `crate`, `rename_all`, `format` and `label` \
126                     are supported (see `vise` crate docs for details)",
127                ))
128            }
129        })?;
130        Ok(attrs)
131    }
132}
133
134#[derive(Debug)]
135struct EnumVariant {
136    ident: Ident,
137    label_value: String,
138}
139
140impl EnumVariant {
141    fn encode(&self) -> proc_macro2::TokenStream {
142        let ident = &self.ident;
143        let label_value = &self.label_value;
144        quote!(Self::#ident => #label_value)
145    }
146}
147
148#[derive(Default)]
149struct EnumVariantAttrs {
150    name: Option<LitStr>,
151}
152
153impl fmt::Debug for EnumVariantAttrs {
154    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
155        formatter
156            .debug_struct("EnumVariantAttrs")
157            .field("name", &self.name.as_ref().map(LitStr::value))
158            .finish()
159    }
160}
161
162impl ParseAttribute for EnumVariantAttrs {
163    fn parse(raw: &Attribute) -> syn::Result<Self> {
164        let mut attrs = Self::default();
165        raw.parse_nested_meta(|meta| {
166            if meta.path.is_ident("name") {
167                attrs.name = Some(meta.value()?.parse()?);
168                Ok(())
169            } else {
170                Err(meta.error(
171                    "Unsupported attribute; only `name` is supported (see `vise` crate docs \
172                     for details)",
173                ))
174            }
175        })?;
176        Ok(attrs)
177    }
178}
179
180#[derive(Debug)]
181struct EncodeLabelValueImpl {
182    attrs: EncodeLabelAttrs,
183    name: Ident,
184    enum_variants: Option<Vec<EnumVariant>>,
185}
186
187impl EncodeLabelValueImpl {
188    fn new(raw: &DeriveInput) -> syn::Result<Self> {
189        let attrs = Self::parse_attrs(raw, "EncodeLabelValue")?;
190        let enum_variants = attrs
191            .rename_all
192            .map(|case| Self::extract_enum_variants(raw, case))
193            .transpose()?;
194
195        Ok(Self {
196            attrs,
197            enum_variants,
198            name: raw.ident.clone(),
199        })
200    }
201
202    fn parse_attrs(raw: &DeriveInput, derived_macro: &str) -> syn::Result<EncodeLabelAttrs> {
203        ensure_no_generics(&raw.generics, derived_macro)?;
204
205        let attrs: EncodeLabelAttrs = metrics_attribute(&raw.attrs)?;
206        if let Some(format) = &attrs.format {
207            if attrs.rename_all.is_some() {
208                let message = "`rename_all` and `format` attributes cannot be specified together";
209                return Err(syn::Error::new(format.span(), message));
210            }
211        }
212        Ok(attrs)
213    }
214
215    fn extract_enum_variants(raw: &DeriveInput, case: RenameRule) -> syn::Result<Vec<EnumVariant>> {
216        let Data::Enum(data) = &raw.data else {
217            let message = "`rename_all` attribute can only be placed on enums";
218            return Err(syn::Error::new_spanned(raw, message));
219        };
220
221        let mut unique_label_values = HashSet::with_capacity(data.variants.len());
222        let variants = data.variants.iter().map(|variant| {
223            if !matches!(variant.fields, Fields::Unit) {
224                let message = "To use `rename_all` attribute, all enum variants must be plain \
225                    (have no fields)";
226                return Err(syn::Error::new_spanned(variant, message));
227            }
228            let ident_str = variant.ident.to_string();
229            if !ident_str.is_ascii() {
230                let message = "Variant name must consist of ASCII chars";
231                return Err(syn::Error::new(variant.ident.span(), message));
232            }
233            let attrs: EnumVariantAttrs = metrics_attribute(&variant.attrs)?;
234            let label_value = if let Some(name_override) = attrs.name {
235                name_override.value()
236            } else {
237                case.transform(&ident_str)
238            };
239            if !unique_label_values.insert(label_value.clone()) {
240                let message = format!("Label value `{label_value}` is redefined");
241                return Err(syn::Error::new_spanned(variant, message));
242            }
243
244            Ok(EnumVariant {
245                ident: variant.ident.clone(),
246                label_value,
247            })
248        });
249        variants.collect()
250    }
251
252    fn impl_value(&self) -> proc_macro2::TokenStream {
253        let cr = self.attrs.path_to_crate(proc_macro2::Span::call_site());
254        let name = &self.name;
255        let encoding = quote!(#cr::_reexports::encoding);
256
257        let encode_impl = if let Some(enum_variants) = &self.enum_variants {
258            let variant_hands = enum_variants.iter().map(EnumVariant::encode);
259            quote! {
260                use ::core::fmt::Write as _;
261                ::core::write!(encoder, "{}", match self {
262                    #(#variant_hands,)*
263                })
264            }
265        } else {
266            let format_lit;
267            let format = if let Some(format) = &self.attrs.format {
268                format
269            } else {
270                format_lit = LitStr::new("{}", name.span());
271                &format_lit
272            };
273
274            quote_spanned! {format.span()=>
275                use ::core::fmt::Write as _;
276                ::core::write!(encoder, #format, self)
277            }
278        };
279
280        quote! {
281            impl #encoding::EncodeLabelValue for #name {
282                fn encode(
283                    &self,
284                    encoder: &mut #encoding::LabelValueEncoder<'_>,
285                ) -> core::fmt::Result {
286                    #encode_impl
287                }
288            }
289        }
290    }
291}
292
293#[derive(Default)]
294struct LabelFieldAttrs {
295    skip: Option<Path>,
296    unit: Option<Expr>,
297}
298
299impl fmt::Debug for LabelFieldAttrs {
300    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
301        formatter
302            .debug_struct("LabelFieldAttrs")
303            .field("skip", &self.skip.as_ref().map(|_| ".."))
304            .field("unit", &self.unit.as_ref().map(|_| ".."))
305            .finish()
306    }
307}
308
309impl ParseAttribute for LabelFieldAttrs {
310    fn parse(raw: &Attribute) -> syn::Result<Self> {
311        let mut attrs = Self::default();
312        raw.parse_nested_meta(|meta| {
313            if meta.path.is_ident("skip") {
314                attrs.skip = Some(meta.value()?.parse()?);
315                Ok(())
316            } else if meta.path.is_ident("unit") {
317                attrs.unit = Some(meta.value()?.parse()?);
318                Ok(())
319            } else {
320                Err(meta.error("unsupported attribute"))
321            }
322        })?;
323        Ok(attrs)
324    }
325}
326
327#[derive(Debug)]
328struct LabelField {
329    name: Ident,
330    is_option: bool,
331    attrs: LabelFieldAttrs,
332}
333
334impl LabelField {
335    fn parse(raw: &Field) -> syn::Result<Self> {
336        let name = raw.ident.clone().ok_or_else(|| {
337            let message = "Encoded fields must be named";
338            syn::Error::new_spanned(raw, message)
339        })?;
340
341        Ok(Self {
342            name,
343            is_option: Self::detect_is_option(&raw.ty),
344            attrs: metrics_attribute(&raw.attrs)?,
345        })
346    }
347
348    /// Strips the `r#` prefix from raw identifiers.
349    fn label_string(&self) -> String {
350        let label = self.name.to_string();
351        if let Some(stripped) = label.strip_prefix("r#") {
352            stripped.to_owned()
353        } else {
354            label
355        }
356    }
357
358    fn label_literal(&self) -> LitStr {
359        let name = &self.name;
360        LitStr::new(&self.label_string(), name.span())
361    }
362
363    fn detect_is_option(ty: &Type) -> bool {
364        let Type::Path(ty) = ty else {
365            return false;
366        };
367        if ty.path.segments.len() != 1 {
368            return false;
369        }
370        let first_segment = ty.path.segments.first().unwrap();
371        first_segment.ident == "Option"
372            && matches!(
373                &first_segment.arguments,
374                PathArguments::AngleBracketed(args) if args.args.len() == 1
375            )
376    }
377
378    fn encode(&self, cr: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
379        let encoding = quote!(#cr::_reexports::encoding);
380
381        let name = &self.name;
382        let span = name.span();
383        let label = self.label_literal();
384        let label = if let Some(unit) = &self.attrs.unit {
385            quote_spanned!(span=> #cr::LabelWithUnit::new(#label, #unit))
386        } else {
387            quote_spanned!(span=> #label)
388        };
389
390        // Skip `Option`al fields by default if they are `None`.
391        let default_skip: Path;
392        let skip = if self.is_option && self.attrs.skip.is_none() {
393            default_skip = syn::parse_quote_spanned!(span=> ::core::option::Option::is_none);
394            Some(&default_skip)
395        } else {
396            self.attrs.skip.as_ref()
397        };
398
399        let encode_inner = quote_spanned! {span=>
400            let mut label_encoder = encoder.encode_label();
401            let mut key_encoder = label_encoder.encode_label_key()?;
402            #encoding::EncodeLabelKey::encode(&#label, &mut key_encoder)?;
403            let mut value_encoder = key_encoder.encode_label_value()?;
404            {
405                let _guard = #cr::_private::EncodingContext::LabelValue.enter();
406                #encoding::EncodeLabelValue::encode(&self.#name, &mut value_encoder)?;
407            }
408            value_encoder.finish()?;
409        };
410        if let Some(skip) = skip {
411            quote_spanned! {span=>
412                #[allow(clippy::needless_borrow)]
413                // ^ Allows for common cases, such as applying `str::is_empty` for a `&str` field
414                if !#skip(&self.#name) {
415                    #encode_inner
416                }
417            }
418        } else {
419            encode_inner
420        }
421    }
422}
423
424#[derive(Debug)]
425struct EncodeLabelSetImpl {
426    attrs: EncodeLabelAttrs,
427    name: Ident,
428    fields: Option<Vec<LabelField>>,
429}
430
431impl EncodeLabelSetImpl {
432    fn new(raw: &DeriveInput) -> syn::Result<Self> {
433        let attrs = EncodeLabelValueImpl::parse_attrs(raw, "EncodeLabelSet")?;
434        let name = raw.ident.clone();
435
436        let fields = if attrs.label.is_some() {
437            None
438        } else {
439            let Data::Struct(data) = &raw.data else {
440                let message = "Non-singleton `EncodeLabelSet` can only be derived on structs";
441                return Err(syn::Error::new_spanned(raw, message));
442            };
443            let fields: syn::Result<_> = data.fields.iter().map(LabelField::parse).collect();
444            Some(fields?)
445        };
446
447        Ok(Self {
448            attrs,
449            name,
450            fields,
451        })
452    }
453
454    fn validate(&self) -> proc_macro2::TokenStream {
455        let label_assertions = if let Some(label) = &self.attrs.label {
456            let span = label.span();
457            let cr = self.attrs.path_to_crate(span);
458            quote_spanned!(span=> #cr::validation::assert_label_name(#label);)
459        } else {
460            let fields = self.fields.as_ref().unwrap();
461            let field_assertions = fields.iter().map(|field| {
462                let label = field.label_literal();
463                let span = label.span();
464                let cr = self.attrs.path_to_crate(span);
465                quote_spanned!(span=> #cr::validation::assert_label_name(#label))
466            });
467            quote!(#(#field_assertions;)*)
468        };
469        quote! {
470            const _: () = { #label_assertions };
471        }
472    }
473
474    fn impl_set(&self) -> proc_macro2::TokenStream {
475        let encode_impl = if let Some(label) = &self.attrs.label {
476            let cr = self.attrs.path_to_crate(label.span());
477            let encoding = quote!(#cr::_reexports::encoding);
478            quote_spanned! {label.span()=>
479                let mut label_encoder = encoder.encode_label();
480                let mut key_encoder = label_encoder.encode_label_key()?;
481                #encoding::EncodeLabelKey::encode(&#label, &mut key_encoder)?;
482                let mut value_encoder = key_encoder.encode_label_value()?;
483                {
484                    let _guard = #cr::_private::EncodingContext::LabelValue.enter();
485                    #encoding::EncodeLabelValue::encode(self, &mut value_encoder)?;
486                }
487                value_encoder.finish()
488            }
489        } else {
490            let fields = self.fields.as_ref().unwrap();
491            let fields = fields.iter().map(|field| {
492                let cr = self.attrs.path_to_crate(field.name.span());
493                field.encode(&cr)
494            });
495            quote! {
496                #(#fields)*
497                ::core::fmt::Result::Ok(())
498            }
499        };
500
501        let name = &self.name;
502        let cr = self.attrs.path_to_crate(proc_macro2::Span::call_site());
503        let encoding = quote!(#cr::_reexports::encoding);
504        quote! {
505            impl #cr::traits::EncodeLabelSet for #name {
506                fn encode(
507                    &self,
508                    encoder: &mut #encoding::LabelSetEncoder<'_>,
509                ) -> ::core::fmt::Result {
510                    #encode_impl
511                }
512            }
513        }
514    }
515}
516
517pub(crate) fn impl_encode_label_value(input: TokenStream) -> TokenStream {
518    let input: DeriveInput = syn::parse(input).unwrap();
519    let trait_impl = match EncodeLabelValueImpl::new(&input) {
520        Ok(trait_impl) => trait_impl,
521        Err(err) => return err.into_compile_error().into(),
522    };
523    trait_impl.impl_value().into()
524}
525
526pub(crate) fn impl_encode_label_set(input: TokenStream) -> TokenStream {
527    let input: DeriveInput = syn::parse(input).unwrap();
528    let trait_impl = match EncodeLabelSetImpl::new(&input) {
529        Ok(trait_impl) => trait_impl,
530        Err(err) => return err.into_compile_error().into(),
531    };
532    let validations = trait_impl.validate();
533    let set_impl = trait_impl.impl_set();
534    quote!(#validations #set_impl).into()
535}
536
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn renaming_rules() {
        let ident = "TestIdent";
        let rules_and_expected_outcomes = [
            (RenameRule::LowerCase, "testident"),
            (RenameRule::UpperCase, "TESTIDENT"),
            (RenameRule::CamelCase, "testIdent"),
            (RenameRule::SnakeCase, "test_ident"),
            (RenameRule::ScreamingSnakeCase, "TEST_IDENT"),
            (RenameRule::KebabCase, "test-ident"),
            (RenameRule::ScreamingKebabCase, "TEST-IDENT"),
        ];
        for (rule, expected) in rules_and_expected_outcomes {
            assert_eq!(rule.transform(ident), expected);
        }
    }

    #[test]
    fn parsing_rename_rules() {
        assert!(matches!(
            RenameRule::parse("lowercase"),
            Ok(RenameRule::LowerCase)
        ));
        assert!(matches!(
            RenameRule::parse("SCREAMING-KEBAB-CASE"),
            Ok(RenameRule::ScreamingKebabCase)
        ));
        let err = RenameRule::parse("PascalCase").unwrap_err();
        assert!(err.starts_with("Invalid case specified"), "{err}");
    }

    #[test]
    fn encoding_label_set() {
        let input: DeriveInput = syn::parse_quote! {
            struct TestLabels {
                r#type: &'static str,
                #[metrics(skip = str::is_empty)]
                kind: &'static str,
            }
        };
        let label_set = EncodeLabelSetImpl::new(&input).unwrap();
        let fields = label_set.fields.as_ref().unwrap();
        assert_eq!(fields.len(), 2);
        assert_eq!(fields[0].label_string(), "type");
        assert_eq!(fields[1].label_string(), "kind");
        assert!(fields[1].attrs.skip.is_some());
    }

    #[test]
    fn detecting_option_fields() {
        let option: Type = syn::parse_quote!(Option<u64>);
        assert!(LabelField::detect_is_option(&option));
        let bare_option: Type = syn::parse_quote!(Option);
        assert!(!LabelField::detect_is_option(&bare_option));
        let vec: Type = syn::parse_quote!(Vec<u64>);
        assert!(!LabelField::detect_is_option(&vec));

        let input: DeriveInput = syn::parse_quote! {
            struct TestLabels {
                maybe: Option<&'static str>,
                always: &'static str,
            }
        };
        let label_set = EncodeLabelSetImpl::new(&input).unwrap();
        let fields = label_set.fields.as_ref().unwrap();
        assert!(fields[0].is_option);
        assert!(!fields[1].is_option);
    }

    #[test]
    fn label_value_redefinition_error() {
        let input: DeriveInput = syn::parse_quote! {
            #[metrics(rename_all = "snake_case")]
            enum Label {
                First,
                #[metrics(name = "first")]
                Second,
            }
        };
        let err = EncodeLabelValueImpl::new(&input).unwrap_err().to_string();
        assert!(err.contains("Label value `first` is redefined"), "{err}");
    }

    #[test]
    fn rename_all_with_format_error() {
        let input: DeriveInput = syn::parse_quote! {
            #[metrics(rename_all = "snake_case", format = "{}")]
            enum Label {
                First,
            }
        };
        let err = EncodeLabelValueImpl::new(&input).unwrap_err().to_string();
        assert!(err.contains("cannot be specified together"), "{err}");
    }

    #[test]
    fn rename_all_on_struct_error() {
        let input: DeriveInput = syn::parse_quote! {
            #[metrics(rename_all = "snake_case")]
            struct Label {
                value: u8,
            }
        };
        let err = EncodeLabelValueImpl::new(&input).unwrap_err().to_string();
        assert!(err.contains("can only be placed on enums"), "{err}");
    }
}