use std::{fmt, marker::PhantomData, str::FromStr, time::Duration};
use serde::{
de::{self, EnumAccess, Error as DeError, Unexpected, VariantAccess},
Deserialize, Deserializer,
};
use crate::{
de::{CustomKnownOption, DeserializeContext, DeserializeParam, Optional, WellKnown},
error::ErrorWithOrigin,
metadata::{BasicTypes, ParamMetadata, SizeUnit, TimeUnit, TypeDescription, TypeSuffixes},
value::Value,
ByteSize,
};
impl TimeUnit {
    /// Builds the error reported when `checked_mul` cannot convert `raw_val` (expressed in this
    /// unit) into a [`Duration`].
    fn overflow_err(self, raw_val: u64) -> serde_json::Error {
        let message = format!(
            "{raw_val} {} does not fit into `u64` when converted to seconds",
            self.plural()
        );
        DeError::custom(message)
    }

    /// Converts `raw_value` expressed in this unit into a [`Duration`].
    ///
    /// # Errors
    ///
    /// Returns [`Self::overflow_err`] if `checked_mul` reports that the result does not fit.
    fn into_duration(self, raw_value: u64) -> Result<Duration, serde_json::Error> {
        match self.checked_mul(raw_value) {
            Some(duration) => Ok(duration),
            None => Err(self.overflow_err(raw_value)),
        }
    }
}
impl DeserializeParam<Duration> for TimeUnit {
    const EXPECTING: BasicTypes = BasicTypes::INTEGER;

    /// Describes the parameter as a time duration given as an integer in this unit.
    fn describe(&self, description: &mut TypeDescription) {
        description
            .set_details("time duration")
            .set_unit((*self).into());
    }

    /// Reads a `u64` and interprets it in this unit.
    ///
    /// # Errors
    ///
    /// Fails if the value is not a `u64`, or if converting it into a [`Duration`] overflows.
    fn deserialize_param(
        &self,
        ctx: DeserializeContext<'_>,
        param: &'static ParamMetadata,
    ) -> Result<Duration, ErrorWithOrigin> {
        let deserializer = ctx.current_value_deserializer(param.name)?;
        let raw_value = u64::deserialize(deserializer)?;
        self.into_duration(raw_value)
            .map_err(|err| deserializer.enrich_err(err))
    }

    /// Serializes the duration as an integer in this unit, truncating any remainder
    /// (e.g., 10 seconds serialize as 0 minutes).
    fn serialize_param(&self, param: &Duration) -> serde_json::Value {
        match self {
            // `as_millis()` returns a `u128`. serde_json rejects integers above `u64::MAX` unless
            // its `arbitrary_precision` feature is enabled, so the previous `unwrap()` panicked
            // for huge durations (e.g., `Duration::MAX`). Saturate instead: `u64::MAX` ms is also
            // the largest value `deserialize_param` can read back, since it parses a `u64`.
            Self::Millis => serde_json::to_value(param.as_millis())
                .unwrap_or_else(|_| serde_json::Value::from(u64::MAX)),
            Self::Seconds => param.as_secs().into(),
            Self::Minutes => (param.as_secs() / 60).into(),
            Self::Hours => (param.as_secs() / 3_600).into(),
            Self::Days => (param.as_secs() / 86_400).into(),
            Self::Weeks => (param.as_secs() / 86_400 / 7).into(),
        }
    }
}
impl DeserializeParam<ByteSize> for SizeUnit {
    const EXPECTING: BasicTypes = BasicTypes::INTEGER;

    /// Describes the parameter as a byte size given as an integer in this unit.
    fn describe(&self, description: &mut TypeDescription) {
        description
            .set_details("byte size")
            .set_unit((*self).into());
    }

    /// Reads a `u64` and interprets it as a number of units of this size.
    ///
    /// # Errors
    ///
    /// Fails if the value is not a `u64`, or if `ByteSize::checked` reports an overflow.
    fn deserialize_param(
        &self,
        ctx: DeserializeContext<'_>,
        param: &'static ParamMetadata,
    ) -> Result<ByteSize, ErrorWithOrigin> {
        let deserializer = ctx.current_value_deserializer(param.name)?;
        let raw_value = u64::deserialize(deserializer)?;
        match ByteSize::checked(raw_value, *self) {
            Some(size) => Ok(size),
            None => {
                let unit = self.plural();
                let err = DeError::custom(format!("{raw_value} {unit} does not fit into `u64`"));
                Err(deserializer.enrich_err(err))
            }
        }
    }

    /// Serializes the size as an integer number of units, rounding down.
    fn serialize_param(&self, param: &ByteSize) -> serde_json::Value {
        // Each unit is a power of two, so conversion is a right shift by its exponent.
        let shift = match self {
            Self::Bytes => 0,
            Self::KiB => 10,
            Self::MiB => 20,
            Self::GiB => 30,
        };
        (param.0 >> shift).into()
    }
}
/// Deserializer for values that carry a unit of measurement, such as durations and byte sizes.
///
/// Accepts either a string holding a number followed by a unit, or an object with a single
/// unit key.
#[derive(Clone, Copy, Debug)]
pub struct WithUnit;

impl WithUnit {
    /// Basic types accepted by every `WithUnit` deserializer.
    const EXPECTED_TYPES: BasicTypes = BasicTypes::STRING.or(BasicTypes::OBJECT);

    /// Deserializes a mandatory parameter through the intermediate `Raw` representation.
    ///
    /// Strings are parsed with `Raw::from_str`; any other value is handed to the enum visitor.
    fn deserialize<Raw, T>(
        ctx: &DeserializeContext<'_>,
        param: &'static ParamMetadata,
    ) -> Result<T, ErrorWithOrigin>
    where
        Raw: EnumWithUnit + TryInto<T, Error = serde_json::Error>,
    {
        let deserializer = ctx.current_value_deserializer(param.name)?;
        let raw = match deserializer.value() {
            Value::String(s) => s
                .expose()
                .parse::<Raw>()
                .map_err(|err| deserializer.enrich_err(err))?,
            _ => deserializer.deserialize_enum(
                "Raw",
                Raw::VARIANTS,
                EnumVisitor(PhantomData::<Raw>),
            )?,
        };
        raw.try_into().map_err(|err| deserializer.enrich_err(err))
    }

    /// Same as [`Self::deserialize`], but the object form may carry an absent (`None`) payload,
    /// which yields `Ok(None)`.
    fn deserialize_opt<Raw, T>(
        ctx: &DeserializeContext<'_>,
        param: &'static ParamMetadata,
    ) -> Result<Option<T>, ErrorWithOrigin>
    where
        Raw: EnumWithUnit + TryInto<T, Error = serde_json::Error>,
    {
        let deserializer = ctx.current_value_deserializer(param.name)?;
        let maybe_raw = match deserializer.value() {
            Value::String(s) => {
                let parsed = s
                    .expose()
                    .parse::<Raw>()
                    .map_err(|err| deserializer.enrich_err(err))?;
                Some(parsed)
            }
            _ => deserializer.deserialize_enum(
                "Raw",
                Raw::VARIANTS,
                EnumVisitor(PhantomData::<Option<Raw>>),
            )?,
        };
        match maybe_raw {
            None => Ok(None),
            Some(raw) => raw
                .try_into()
                .map(Some)
                .map_err(|err| deserializer.enrich_err(err)),
        }
    }
}
/// Intermediate representation of a value tagged with a unit.
///
/// It can be parsed from a string (via `FromStr`) or built from a unit name plus a number.
trait EnumWithUnit: FromStr<Err = serde_json::Error> {
    /// Every accepted unit name, in the order reported by "unknown variant" errors.
    const VARIANTS: &'static [&'static str];

    /// Returns the constructor for `unit`, or `None` if the unit name is not recognized.
    fn extract_variant(unit: &str) -> Option<fn(u64) -> Self>;

    /// Builds a value from a unit name and an amount.
    ///
    /// # Errors
    ///
    /// Returns an "unknown variant" error if `unit` is not recognized.
    fn parse<E: de::Error>(unit: &str, value: u64) -> Result<Self, E> {
        match Self::extract_variant(unit) {
            Some(make) => Ok(make(value)),
            None => Err(E::unknown_variant(unit, Self::VARIANTS)),
        }
    }

    /// Like [`Self::parse`], but the amount may be absent; the unit is validated regardless.
    fn parse_opt<E: de::Error>(unit: &str, value: Option<u64>) -> Result<Option<Self>, E> {
        match Self::extract_variant(unit) {
            Some(make) => Ok(value.map(make)),
            None => Err(E::unknown_variant(unit, Self::VARIANTS)),
        }
    }
}
/// Serde visitor for the object form of a value with unit, where the single key names the unit
/// (optionally prefixed with `in_`) and the value is the amount.
///
/// `EnumVisitor<T>` requires an amount; `EnumVisitor<Option<T>>` deserializes it as `Option<u64>`.
#[derive(Debug)]
struct EnumVisitor<T>(PhantomData<T>);

impl<'v, T: EnumWithUnit> de::Visitor<'v> for EnumVisitor<T> {
    type Value = T;

    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(formatter, "enum with one of {:?} variants", T::VARIANTS)
    }

    fn visit_enum<A: EnumAccess<'v>>(self, data: A) -> Result<Self::Value, A::Error> {
        let (tag, payload): (String, _) = data.variant()?;
        let amount: u64 = payload.newtype_variant()?;
        T::parse(tag.strip_prefix("in_").unwrap_or(&tag), amount)
    }
}

impl<'v, T: EnumWithUnit> de::Visitor<'v> for EnumVisitor<Option<T>> {
    type Value = Option<T>;

    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(formatter, "enum with one of {:?} variants", T::VARIANTS)
    }

    fn visit_enum<A: EnumAccess<'v>>(self, data: A) -> Result<Self::Value, A::Error> {
        let (tag, payload): (String, _) = data.variant()?;
        let amount: Option<u64> = payload.newtype_variant()?;
        T::parse_opt(tag.strip_prefix("in_").unwrap_or(&tag), amount)
    }
}
/// A duration as written in configuration: an amount tagged with its time unit.
#[cfg_attr(test, derive(PartialEq))]
#[derive(Debug)]
enum RawDuration {
    /// Amount in milliseconds.
    Millis(u64),
    /// Amount in seconds.
    Seconds(u64),
    /// Amount in minutes.
    Minutes(u64),
    /// Amount in hours.
    Hours(u64),
    /// Amount in days.
    Days(u64),
    /// Amount in weeks.
    Weeks(u64),
}
/// Generates the `VARIANTS` constant and the `extract_variant` function for an [`EnumWithUnit`]
/// implementation from rules of the form `"alias" | "alias" => constructor,`.
///
/// `VARIANTS` lists every alias in the order given.
macro_rules! impl_enum_with_unit {
    ($($($name:tt)|+ => $func:expr,)+) => {
        const VARIANTS: &'static [&'static str] = &[$($($name),+),+];

        fn extract_variant(unit: &str) -> Option<fn(u64) -> Self> {
            // The explicit type coerces each distinct constructor item to a common fn pointer.
            let constructor: fn(u64) -> Self = match unit {
                $($($name)|+ => $func,)+
                _ => return None,
            };
            Some(constructor)
        }
    };
}
impl EnumWithUnit for RawDuration {
    // Aliases go from the longest to the shortest spelling; this order is preserved in
    // `VARIANTS`, which feeds "unknown variant" errors and suffix checks.
    impl_enum_with_unit! {
        "milliseconds" | "millis" | "ms" => Self::Millis,
        "seconds" | "second" | "secs" | "sec" | "s" => Self::Seconds,
        "minutes" | "minute" | "mins" | "min" | "m" => Self::Minutes,
        "hours" | "hour" | "hr" | "h" => Self::Hours,
        "days" | "day" | "d" => Self::Days,
        "weeks" | "week" | "w" => Self::Weeks,
    }
}

impl RawDuration {
    /// Human-readable description of the accepted string format, used in parse errors.
    const EXPECTING: &'static str = "value with unit, like '10 ms'";
}
impl FromStr for RawDuration {
type Err = serde_json::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let unit_start = s
.find(|ch: char| !ch.is_ascii_digit())
.ok_or_else(|| DeError::invalid_type(Unexpected::Str(s), &Self::EXPECTING))?;
if unit_start == 0 {
return Err(DeError::invalid_type(Unexpected::Str(s), &Self::EXPECTING));
}
let value: u64 = s[..unit_start].parse().map_err(DeError::custom)?;
let unit = s[unit_start..].trim();
Self::parse(unit, value)
}
}
impl TryFrom<RawDuration> for Duration {
    type Error = serde_json::Error;

    /// Converts the raw amount via the matching [`TimeUnit`], reporting overflow as an error.
    fn try_from(value: RawDuration) -> Result<Self, Self::Error> {
        match value {
            RawDuration::Millis(amount) => TimeUnit::Millis.into_duration(amount),
            RawDuration::Seconds(amount) => TimeUnit::Seconds.into_duration(amount),
            RawDuration::Minutes(amount) => TimeUnit::Minutes.into_duration(amount),
            RawDuration::Hours(amount) => TimeUnit::Hours.into_duration(amount),
            RawDuration::Days(amount) => TimeUnit::Days.into_duration(amount),
            RawDuration::Weeks(amount) => TimeUnit::Weeks.into_duration(amount),
        }
    }
}
impl DeserializeParam<Duration> for WithUnit {
    const EXPECTING: BasicTypes = Self::EXPECTED_TYPES;

    fn describe(&self, description: &mut TypeDescription) {
        description.set_details("duration with unit, or object with single unit key");
        description.set_suffixes(TypeSuffixes::DurationUnits);
    }

    /// Deserializes a duration from a string with unit or a single-unit-key object.
    fn deserialize_param(
        &self,
        ctx: DeserializeContext<'_>,
        param: &'static ParamMetadata,
    ) -> Result<Duration, ErrorWithOrigin> {
        Self::deserialize::<RawDuration, _>(&ctx, param)
    }

    /// Serializes a duration as a string using the largest of `ms`/`s`/`min`/`h`/`d`/`w`
    /// that represents it exactly, e.g. `"5min"` or `"5050ms"`.
    ///
    /// Sub-millisecond precision cannot be expressed and is truncated.
    fn serialize_param(&self, param: &Duration) -> serde_json::Value {
        let duration_string = if param.subsec_millis() != 0 {
            format!("{}ms", param.as_millis())
        } else {
            let seconds = param.as_secs();
            if seconds == 0 {
                // Covers `Duration::ZERO` as well as non-zero sub-millisecond durations. The
                // latter previously fell through every divisibility check below and became "0w".
                "0s".to_owned()
            } else if seconds % 60 != 0 {
                format!("{seconds}s")
            } else if seconds % 3_600 != 0 {
                format!("{}min", seconds / 60)
            } else if seconds % 86_400 != 0 {
                format!("{}h", seconds / 3_600)
            } else if seconds % (86_400 * 7) != 0 {
                format!("{}d", seconds / 86_400)
            } else {
                format!("{}w", seconds / (86_400 * 7))
            }
        };
        duration_string.into()
    }
}
impl DeserializeParam<Option<Duration>> for WithUnit {
const EXPECTING: BasicTypes = Self::EXPECTED_TYPES;
fn describe(&self, description: &mut TypeDescription) {
<Self as DeserializeParam<Duration>>::describe(self, description);
}
fn deserialize_param(
&self,
ctx: DeserializeContext<'_>,
param: &'static ParamMetadata,
) -> Result<Option<Duration>, ErrorWithOrigin> {
Self::deserialize_opt::<RawDuration, _>(&ctx, param)
}
fn serialize_param(&self, param: &Option<Duration>) -> serde_json::Value {
match param {
Some(val) => self.serialize_param(val),
None => serde_json::Value::Null,
}
}
}
/// Registers [`WithUnit`] as the `WellKnown` deserializer for [`Duration`].
impl WellKnown for Duration {
    type Deserializer = WithUnit;
    const DE: WithUnit = WithUnit;
}

/// Registers `Optional<WithUnit, true>` as the deserializer for optional [`Duration`]s.
impl CustomKnownOption for Duration {
    type OptDeserializer = Optional<WithUnit, true>;
    const OPT_DE: Optional<WithUnit, true> = Optional(WithUnit);
}
/// A byte size as written in configuration: an amount tagged with its size unit.
///
/// Kilo/mega/gigabytes are converted using the binary units (`KiB`/`MiB`/`GiB`).
#[cfg_attr(test, derive(PartialEq))]
#[derive(Debug)]
enum RawByteSize {
    /// Amount in bytes.
    Bytes(u64),
    /// Amount in kilobytes (converted as KiB).
    Kilobytes(u64),
    /// Amount in megabytes (converted as MiB).
    Megabytes(u64),
    /// Amount in gigabytes (converted as GiB).
    Gigabytes(u64),
}
impl EnumWithUnit for RawByteSize {
    // Full names first, then abbreviations; all lowercase (string input is lowercased before
    // lookup). The order is preserved in `VARIANTS`.
    impl_enum_with_unit! {
        "bytes" | "b" => Self::Bytes,
        "kilobytes" | "kb" | "kib" => Self::Kilobytes,
        "megabytes" | "mb" | "mib" => Self::Megabytes,
        "gigabytes" | "gb" | "gib" => Self::Gigabytes,
    }
}

impl RawByteSize {
    /// Human-readable description of the accepted string format, used in parse errors.
    const EXPECTING: &'static str = "value with unit, like '32 MB'";
}
impl<'de> Deserialize<'de> for RawByteSize {
    /// Deserializes the object (single unit key) form of a byte size.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let visitor = EnumVisitor::<Self>(PhantomData);
        deserializer.deserialize_enum("RawByteSize", Self::VARIANTS, visitor)
    }
}
impl FromStr for RawByteSize {
type Err = serde_json::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let unit_start = s
.find(|ch: char| !ch.is_ascii_digit())
.ok_or_else(|| DeError::invalid_type(Unexpected::Str(s), &Self::EXPECTING))?;
if unit_start == 0 {
return Err(DeError::invalid_type(Unexpected::Str(s), &Self::EXPECTING));
}
let value: u64 = s[..unit_start].parse().map_err(DeError::custom)?;
let unit = s[unit_start..].trim();
Self::parse(&unit.to_lowercase(), value)
}
}
impl TryFrom<RawByteSize> for ByteSize {
    type Error = serde_json::Error;

    /// Converts the raw amount using the matching binary [`SizeUnit`].
    ///
    /// # Errors
    ///
    /// Fails if `ByteSize::checked` reports that the size does not fit.
    fn try_from(value: RawByteSize) -> Result<Self, Self::Error> {
        let (amount, unit) = match value {
            RawByteSize::Bytes(amount) => (amount, SizeUnit::Bytes),
            RawByteSize::Kilobytes(amount) => (amount, SizeUnit::KiB),
            RawByteSize::Megabytes(amount) => (amount, SizeUnit::MiB),
            RawByteSize::Gigabytes(amount) => (amount, SizeUnit::GiB),
        };
        match ByteSize::checked(amount, unit) {
            Some(size) => Ok(size),
            None => {
                let plural = unit.plural();
                Err(DeError::custom(format!("{amount} {plural} does not fit into `u64`")))
            }
        }
    }
}
impl DeserializeParam<ByteSize> for WithUnit {
    const EXPECTING: BasicTypes = Self::EXPECTED_TYPES;

    fn describe(&self, description: &mut TypeDescription) {
        description.set_details("size with unit, or object with single unit key");
        description.set_suffixes(TypeSuffixes::SizeUnits);
    }

    /// Deserializes a byte size from a string with unit or a single-unit-key object.
    fn deserialize_param(
        &self,
        ctx: DeserializeContext<'_>,
        param: &'static ParamMetadata,
    ) -> Result<ByteSize, ErrorWithOrigin> {
        Self::deserialize::<RawByteSize, ByteSize>(&ctx, param)
    }

    /// Serializes the size using its `Display` representation (e.g. `"32 KiB"`).
    fn serialize_param(&self, param: &ByteSize) -> serde_json::Value {
        serde_json::Value::String(param.to_string())
    }
}
impl DeserializeParam<Option<ByteSize>> for WithUnit {
const EXPECTING: BasicTypes = Self::EXPECTED_TYPES;
fn describe(&self, description: &mut TypeDescription) {
<Self as DeserializeParam<ByteSize>>::describe(self, description);
}
fn deserialize_param(
&self,
ctx: DeserializeContext<'_>,
param: &'static ParamMetadata,
) -> Result<Option<ByteSize>, ErrorWithOrigin> {
Self::deserialize_opt::<RawByteSize, _>(&ctx, param)
}
fn serialize_param(&self, param: &Option<ByteSize>) -> serde_json::Value {
match param {
Some(val) => val.to_string().into(),
None => serde_json::Value::Null,
}
}
}
/// Registers [`WithUnit`] as the `WellKnown` deserializer for [`ByteSize`].
impl WellKnown for ByteSize {
    type Deserializer = WithUnit;
    const DE: WithUnit = WithUnit;
}

/// Registers `Optional<WithUnit, true>` as the deserializer for optional [`ByteSize`]s.
impl CustomKnownOption for ByteSize {
    type OptDeserializer = Optional<WithUnit, true>;
    const OPT_DE: Optional<WithUnit, true> = Optional(WithUnit);
}
impl TypeSuffixes {
    /// Checks whether `suffix` names a unit accepted by these suffixes.
    ///
    /// An optional `in_` prefix is ignored, so both `ms` and `in_ms` match duration units.
    /// `All` accepts any suffix.
    pub(crate) fn contains(self, suffix: &str) -> bool {
        let unit = suffix.strip_prefix("in_").unwrap_or(suffix);
        let accepted = match self {
            Self::All => return true,
            Self::DurationUnits => RawDuration::VARIANTS,
            Self::SizeUnits => RawByteSize::VARIANTS,
        };
        accepted.contains(&unit)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parsing_time_string() {
        let cases = [
            ("10ms", RawDuration::Millis(10)),
            ("50 seconds", RawDuration::Seconds(50)),
            ("40s", RawDuration::Seconds(40)),
            ("10 min", RawDuration::Minutes(10)),
            ("10m", RawDuration::Minutes(10)),
            ("12 hours", RawDuration::Hours(12)),
            ("12h", RawDuration::Hours(12)),
            ("30d", RawDuration::Days(30)),
            ("1 day", RawDuration::Days(1)),
            ("2 weeks", RawDuration::Weeks(2)),
            ("3w", RawDuration::Weeks(3)),
        ];
        for (input, expected) in cases {
            let parsed: RawDuration = input.parse().unwrap();
            assert_eq!(parsed, expected, "{input}");
        }
    }

    #[test]
    fn parsing_time_string_errors() {
        // Inputs lacking either a leading number or a unit.
        for input in ["", "???", "10", "hours"] {
            let err = input.parse::<RawDuration>().unwrap_err().to_string();
            assert!(err.starts_with("invalid type"), "{err}");
        }

        let err = "111111111111111111111111111111111111111111s"
            .parse::<RawDuration>()
            .unwrap_err()
            .to_string();
        assert!(err.contains("too large"), "{err}");

        let err = "10 months".parse::<RawDuration>().unwrap_err().to_string();
        assert!(err.starts_with("unknown variant"), "{err}");
    }

    #[test]
    fn parsing_byte_size_string() {
        let cases = [
            ("16bytes", RawByteSize::Bytes(16)),
            ("128 KiB", RawByteSize::Kilobytes(128)),
            ("16 kb", RawByteSize::Kilobytes(16)),
            ("4MB", RawByteSize::Megabytes(4)),
            ("1 GB", RawByteSize::Gigabytes(1)),
        ];
        for (input, expected) in cases {
            let parsed: RawByteSize = input.parse().unwrap();
            assert_eq!(parsed, expected, "{input}");
        }
    }

    #[test]
    fn serializing_with_time_unit() {
        let cases = [
            (TimeUnit::Millis, Duration::from_millis(10), 10_u64),
            (TimeUnit::Millis, Duration::from_secs(10), 10_000),
            (TimeUnit::Seconds, Duration::from_secs(10), 10),
            (TimeUnit::Minutes, Duration::from_secs(10), 0),
            (TimeUnit::Minutes, Duration::from_secs(120), 2),
        ];
        for (unit, duration, expected) in cases {
            assert_eq!(unit.serialize_param(&duration), expected);
        }
    }

    #[test]
    fn serializing_with_size_unit() {
        let cases = [
            (SizeUnit::Bytes, ByteSize(128), 128_u64),
            (SizeUnit::Bytes, ByteSize(1 << 16), 1 << 16),
            (SizeUnit::KiB, ByteSize(1 << 16), 1 << 6),
            (SizeUnit::MiB, ByteSize(1 << 16), 0),
            (SizeUnit::MiB, ByteSize::new(3, SizeUnit::MiB), 3),
        ];
        for (unit, size, expected) in cases {
            assert_eq!(unit.serialize_param(&size), expected);
        }
    }

    #[test]
    fn serializing_with_duration() {
        let cases = [
            (Duration::ZERO, "0s"),
            (Duration::from_millis(10), "10ms"),
            (Duration::from_secs(5), "5s"),
            (Duration::from_millis(5_050), "5050ms"),
            (Duration::from_secs(300), "5min"),
            (Duration::from_secs(7_200), "2h"),
            (Duration::from_secs(86_400), "1d"),
        ];
        for (duration, expected) in cases {
            assert_eq!(WithUnit.serialize_param(&duration), expected);
        }
    }

    #[test]
    fn serializing_with_byte_size() {
        let cases = [
            (ByteSize(0), "0 B"),
            (ByteSize(128), "128 B"),
            (ByteSize(32 << 10), "32 KiB"),
            (ByteSize(3 << 20), "3 MiB"),
        ];
        for (size, expected) in cases {
            assert_eq!(WithUnit.serialize_param(&size), expected);
        }
    }
}