From 0a16f8139450a9ea4bc9b4c67aec320598ad929d Mon Sep 17 00:00:00 2001 From: Serge Barral Date: Thu, 13 Jun 2024 17:24:11 +0200 Subject: [PATCH] Validate nanoseconds field on deserialize --- Cargo.toml | 3 + src/lib.rs | 160 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 162 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index ee28b8e..9b50876 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,9 @@ serde = { version = "1", default-features = false, features = ["derive"], option nix = { version = "0.26", default-features = false, features = ["time"], optional = true } defmt = { version = "0.3", optional = true } +[dev-dependencies] +serde_json = "1" + [features] default = ["std"] std = [] diff --git a/src/lib.rs b/src/lib.rs index c295900..c12b7e3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -336,7 +336,7 @@ pub type Tai1972Time = TaiTime<63_072_000>; /// assert_eq!(timestamp.subsec_nanos(), 789_333_333); /// ``` #[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct TaiTime { /// The number of whole seconds in the future (if positive) or in the past @@ -1416,6 +1416,116 @@ impl fmt::Display for TaiTime { } } +#[cfg(feature = "serde")] +impl<'de, const EPOCH_REF: i64> serde::Deserialize<'de> for TaiTime { + fn deserialize(deserializer: D) -> Result + where + D: serde::de::Deserializer<'de>, + { + use serde::de::{self, Deserialize, Deserializer, MapAccess, SeqAccess, Visitor}; + + enum Field { + Secs, + Nanos, + } + + impl<'de> Deserialize<'de> for Field { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct FieldVisitor; + + impl<'de> Visitor<'de> for FieldVisitor { + type Value = Field; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("`secs` or `nanos`") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match value { + "secs" => Ok(Field::Secs), + "nanos" => Ok(Field::Nanos), + _ => Err(de::Error::unknown_field(value, FIELDS)), + } + } + } + + deserializer.deserialize_identifier(FieldVisitor) + } + } + + struct DurationVisitor; + + impl<'de, const EPOCH_REF: i64> Visitor<'de> for DurationVisitor { + type Value = TaiTime; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("struct TaiTime") + } + + fn visit_seq(self, mut seq: V) -> Result, V::Error> + where + V: SeqAccess<'de>, + { + let secs = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(0, &self))?; + let nanos = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(1, &self))?; + + TaiTime::new(secs, nanos).ok_or_else(|| { + de::Error::invalid_value( + de::Unexpected::Unsigned(nanos as u64), + &"a number of nanoseconds between 0 and 999999999", + ) + }) + } + + fn visit_map(self, mut map: V) -> Result, V::Error> + where + V: MapAccess<'de>, + { + let mut secs = None; + let mut nanos = None; + while let Some(key) = map.next_key()? { + match key { + Field::Secs => { + if secs.is_some() { + return Err(de::Error::duplicate_field("secs")); + } + secs = Some(map.next_value()?); + } + Field::Nanos => { + if nanos.is_some() { + return Err(de::Error::duplicate_field("nanos")); + } + nanos = Some(map.next_value()?); + } + } + } + let secs = secs.ok_or_else(|| de::Error::missing_field("secs"))?; + let nanos = nanos.ok_or_else(|| de::Error::missing_field("nanos"))?; + + TaiTime::new(secs, nanos).ok_or_else(|| { + de::Error::invalid_value( + de::Unexpected::Unsigned(nanos as u64), + &"a number of nanoseconds between 0 and 999999999", + ) + }) + } + } + + const FIELDS: &[&str] = &["secs", "nanos"]; + deserializer.deserialize_struct("TaiTime", FIELDS, DurationVisitor::) + } +} + #[cfg(test)] mod tests { use super::*; @@ -1927,4 +2037,52 @@ mod tests { assert!(date_time_str.parse::().is_err()); } } + + #[cfg(feature = "serde")] + #[test] + fn deserialize_from_seq() { + use serde_json; + + let data = r#"[987654321, 123456789]"#; + + let t: GpsTime = serde_json::from_str(data).unwrap(); + assert_eq!(t, GpsTime::new(987654321, 123456789).unwrap()); + } + + #[cfg(feature = "serde")] + #[test] + fn deserialize_from_map() { + use serde_json; + + let data = r#"{"secs": 987654321, "nanos": 123456789}"#; + + let t: GpsTime = serde_json::from_str(data).unwrap(); + assert_eq!(t, GpsTime::new(987654321, 123456789).unwrap()); + } + + #[cfg(feature = "serde")] + #[test] + fn deserialize_invalid_nanos() { + use serde_json; + + let data = r#"{"secs": 987654321, "nanos": 1000000000}"#; + + let t: Result = serde_json::from_str(data); + + assert!(t.is_err()) + } + + #[cfg(feature = "serde")] + #[test] + fn serialize_roundtrip() { + use serde_json; + + let t0 = GpsTime::new(987654321, 123456789).unwrap(); + + let data = serde_json::to_string(&t0).unwrap(); + + let t1: GpsTime = serde_json::from_str(&data).unwrap(); + + assert_eq!(t0, t1); + } }