diff --git a/src/base/charstr.rs b/src/base/charstr.rs index b47743c33..ad47bcdfd 100644 --- a/src/base/charstr.rs +++ b/src/base/charstr.rs @@ -24,7 +24,7 @@ //! [RFC 1035]: https://tools.ietf.org/html/rfc1035 use super::cmp::CanonicalOrd; -use super::scan::{BadSymbol, Scanner, Symbol, SymbolCharsError}; +use super::scan::{Scan, ScanError, Token, Tokenizer}; use super::wire::{Compose, ParseError}; #[cfg(feature = "bytes")] use bytes::BytesMut; @@ -317,18 +317,20 @@ impl CharStr<[u8]> { target: &mut impl OctetsBuilder, ) -> Result { let mut len = 0u8; - let mut chars = s.chars(); - while let Some(symbol) = Symbol::from_chars(&mut chars)? { - // We have the max length but there’s another character. Error! - if len == u8::MAX { - return Err(PresentationErrorEnum::LongString.into()); + let mut err = None; + Token::from_raw(s).process(|b| { + if err.is_some() { + return; + } else if len == u8::MAX { + err = Some(PresentationErrorEnum::LongString.into()); + } else { + target + .append_slice(&[b]) + .unwrap_or_else(|e| err = Some(e.into())); + len += 1; } - target - .append_slice(&[symbol.into_octet()?]) - .map_err(Into::into)?; - len += 1; - } - Ok(len) + })?; + err.map(Err).unwrap_or(Ok(len)) } } @@ -382,15 +384,6 @@ impl + ?Sized> CharStr { } } -impl CharStr { - /// Scans the presentation format from a scanner. - pub fn scan>( - scanner: &mut S, - ) -> Result { - scanner.scan_charstr() - } -} - impl + ?Sized> CharStr { /// Returns an object that formats in quoted presentation format. 
///
diff --git a/src/base/name/absolute.rs b/src/base/name/absolute.rs
index 4633d2048..7ac6af92d 100644
--- a/src/base/name/absolute.rs
+++ b/src/base/name/absolute.rs
@@ -4,7 +4,7 @@
 
 use super::super::cmp::CanonicalOrd;
 use super::super::net::IpAddr;
-use super::super::scan::{Scanner, Symbol, SymbolCharsError, Symbols};
+use super::super::scan::{Scan, Token, Tokenizer, ScanError};
 use super::super::wire::{FormError, ParseError};
 use super::builder::{FromStrError, NameBuilder, PushError};
 use super::label::{Label, LabelTypeError, SplitLabelError};
@@ -112,7 +112,7 @@ impl Name {
                 ::Builder::with_capacity(1);
             builder
                 .append_slice(b"\0")
-                .map_err(|_| FromStrError::ShortBuf)?;
+                .map_err(|_| ScanError::custom("could not build domain name"))?;
             return Ok(unsafe {
                 Self::from_octets_unchecked(builder.freeze())
             });
@@ -157,13 +157,6 @@ impl Name {
         })
     }
 
-    /// Reads a name in presentation format from the beginning of a scanner.
-    pub fn scan>(
-        scanner: &mut S,
-    ) -> Result {
-        scanner.scan_name()
-    }
-
     /// Returns a domain name consisting of the root label only.
     ///
     /// This function will work for any kind octets sequence that can be
@@ -738,6 +731,111 @@ impl Name {
     }
 }
 
+//--- Scan
+
+impl Scan for Name
+where
+    Octs: FromBuilder,
+    ::Builder: EmptyBuilder
+        + FreezeBuilder
+        + AsRef<[u8]>
+        + AsMut<[u8]>,
+{
+    /// Scans a domain name from the next token.
+    ///
+    /// Quoted tokens and empty labels are rejected, and a lone `.` yields the
+    /// root name. Labels end at unescaped dots only.
+    fn scan(tokens: &mut Tokenizer<'_>) -> Result {
+        let token = tokens.next()?;
+
+        if token.is_quoted() {
+            return Err(ScanError::custom("domain names cannot be quoted"));
+        }
+
+        if token.raw() == "." {
+            // Make a root name.
+            let mut builder = ::Builder::with_capacity(token.len());
+            builder
+                .append_slice(b"\0")
+                .map_err(|_| ScanError::custom("could not create a domain name"))?;
+            return Ok(unsafe {
+                Self::from_octets_unchecked(builder.freeze())
+            });
+        }
+
+        let mut builder = NameBuilder::::new();
+
+        // We don't expect to find escaped periods in the domain name.
+        //
+        // In both cases, we use 'split_inclusive()' to handle a final '.' correctly.
+        if !token.raw().contains("\\.") {
+            // We know that dots are never escaped, so we blindly split by them.
+            for token in token.raw().split_inclusive('.') {
+                // Strip the delimiting '.' if there is one.
+                let token = token
+                    .strip_suffix('.')
+                    .unwrap_or(token);
+
+                if token.is_empty() {
+                    return Err(ScanError::custom("empty label"));
+                }
+
+                // NOTE(review): the CharStr code in this change wraps the callback
+                // argument as `&[b]`; confirm that `process()` yields slices here.
+                let mut err = None;
+                Token::from_raw(token).process(|slice| {
+                    if err.is_none() {
+                        builder.append_slice(slice)
+                            .unwrap_or_else(|e| err = Some(e));
+                    }
+                })?;
+                if let Some(err) = err { return Err(err.into()); }
+
+                // Mark the end of this label.
+                builder.end_label();
+            }
+        } else {
+            // We split by dots and check for preceding backslashes.
+            for mut token in token.raw().split_inclusive('.') {
+                // Strip the delimiting '.' if it is not escaped.
+                if let Some(prefix) = token.strip_suffix('.') {
+                    // Count the backslashes directly in front of the '.'.
+                    let num = prefix.bytes()
+                        .rev()
+                        .take_while(|&b| b == b'\\')
+                        .count();
+
+                    // An even count means the backslashes escape each other, so
+                    // the '.' itself is a real label delimiter.
+                    if num % 2 == 0 {
+                        token = prefix;
+                    }
+                }
+
+                if token.is_empty() && !builder.in_label() {
+                    return Err(ScanError::custom("empty label"));
+                }
+
+                let mut err = None;
+                Token::from_raw(token).process(|slice| {
+                    if err.is_none() {
+                        builder.append_slice(slice)
+                            .unwrap_or_else(|e| err = Some(e));
+                    }
+                })?;
+                if let Some(err) = err { return Err(err.into()); }
+
+                // We don't end the label if we didn't strip out the '.'.
+                if !token.ends_with('.') {
+                    builder.end_label();
+                }
+            }
+        }
+
+        builder.into_name().map_err(Into::into)
+    }
+}
+
 //--- AsRef
 
 impl AsRef for Name {
diff --git a/src/base/scan.rs b/src/base/scan.rs
index 7a80cfbe0..65d52970d 100644
--- a/src/base/scan.rs
+++ b/src/base/scan.rs
@@ -111,6 +111,23 @@ impl<'a> Tokenizer<'a> {
         Self { input, pos: 0 }
     }
 
+    /// Skip a token if it exactly matches the given raw text.
+    pub fn try_skip_exactly(&mut self, raw: &str) -> bool {
+        if let Some(suffix) = self.input[self.pos ..].strip_prefix(raw) {
+            // Only accept `raw` as a whole token: whitespace or the end of input must follow.
+            // NOTE(review): a newline directly after `raw` is not accepted; confirm the intent.
+            if suffix.as_bytes().first().map_or(true, |b| b" \t\r".contains(b)) {
+                self.pos += raw.len();
+                // Move past any whitespace, to the next token, comment, or newline.
+                self.pos += self.input[self.pos ..].bytes()
+                    .position(|b| !b" \t\r".contains(&b))
+                    .unwrap_or(self.input.len() - self.pos);
+                return true;
+            }
+        }
+        false
+    }
+
     /// Extract the next token from the text.
     pub fn next(&mut self) -> Result, ScanError> {
         // TODO: We could use 'memchr' to track the position of the next
@@ -224,6 +241,11 @@ impl<'a> Tokenizer<'a> {
         // Move past this token.
         self.pos += len;
 
+        // Move past any whitespace, to the next token, comment, or newline.
+        self.pos += input[self.pos ..].iter()
+            .position(|b| !b" \t\r".contains(b))
+            .unwrap_or(input.len() - self.pos);
+
         // Ensure the token doesn't contain any invalid characters.
         if input[pos..][..len].any(|&b| b < 0x20 && b != b'\t') {
             return Err(ScanError::custom(
diff --git a/src/rdata/aaaa.rs b/src/rdata/aaaa.rs
index 3cd8bc0da..9530254bf 100644
--- a/src/rdata/aaaa.rs
+++ b/src/rdata/aaaa.rs
@@ -8,7 +8,7 @@ use crate::base::cmp::CanonicalOrd;
 use crate::base::iana::Rtype;
 use crate::base::net::Ipv6Addr;
 use crate::base::rdata::{ComposeRecordData, ParseRecordData, RecordData};
-use crate::base::scan::{Scanner, ScannerError};
+use crate::base::scan::{Scan, Tokenizer, ScanError};
 use crate::base::wire::{Composer, Parse, ParseError};
 use core::cmp::Ordering;
 use core::convert::Infallible;
@@ -57,14 +57,6 @@ impl Aaaa {
     ) -> Result {
         Ipv6Addr::parse(parser).map(Self::new)
     }
-
-    pub fn scan(scanner: &mut S) -> Result {
-        let token = scanner.scan_octets()?;
-        let token = str::from_utf8(token.as_ref())
-            .map_err(|_| S::Error::custom("expected IPv6 address"))?;
-        Aaaa::from_str(token)
-            .map_err(|_| S::Error::custom("expected IPv6 address"))
-    }
 }
 
 //--- From and FromStr
@@ -148,6 +140,18 @@ impl
ComposeRecordData for Aaaa {
     }
 }
 
+//--- Scan
+
+#[cfg(feature = "std")]
+impl Scan for Aaaa {
+    fn scan(tokens: &mut Tokenizer<'_>) -> Result {
+        let token = tokens.next()?.processed_bytes()?;
+        str::from_utf8(token.as_ref()).ok()
+            .and_then(|s| Self::from_str(s).ok())
+            .ok_or_else(|| ScanError::custom("expected IPv6 address"))
+    }
+}
+
 //--- Display
 
 impl fmt::Display for Aaaa {
diff --git a/src/rdata/macros.rs b/src/rdata/macros.rs
index 5f855701e..cd98be877 100644
--- a/src/rdata/macros.rs
+++ b/src/rdata/macros.rs
@@ -39,7 +39,7 @@ macro_rules! rdata_types {
             ComposeRecordData, ParseAnyRecordData, ParseRecordData,
             RecordData, UnknownRecordData,
         };
-        use crate::base::scan::ScannerError;
+        use crate::base::scan::{Scan, ScanError, Tokenizer};
         use crate::base::wire::{Composer, ParseError};
         use octseq::octets::{Octets, OctetsFrom};
         use octseq::parse::Parser;
@@ -89,29 +89,26 @@ macro_rules! rdata_types {
             /// If the record data is given via the notation for unknown
             /// record types, the returned value will be of the
             /// `ZoneRecordData::Unknown(_)` variant.
-            pub fn scan(
+            pub fn scan(
+                tokens: &mut Tokenizer<'_>,
                 rtype: Rtype,
-                scanner: &mut S
-            ) -> Result
-            where
-                S: $crate::base::scan::Scanner
-            {
-                if scanner.scan_opt_unknown_marker()? {
+            ) -> Result {
+                // Check for (and consume) the unknown RDATA marker `\#`.
+                if tokens.try_skip_exactly("\\#") {
                     UnknownRecordData::scan_without_marker(
-                        rtype, scanner
+                        tokens, rtype
                     ).map(ZoneRecordData::Unknown)
-                }
-                else {
+                } else {
                     match rtype {
                         $( $( $(
                             $mtype::RTYPE => {
                                 $mtype::scan(
-                                    scanner
+                                    tokens
                                 ).map(ZoneRecordData::$mtype)
                             }
                         )* )* )*
                         _ => {
-                            Err(S::Error::custom(
+                            Err(ScanError::custom(
                                 "unknown record type with concrete data"
                             ))
                         }
@@ -1144,12 +1141,6 @@ macro_rules! name_type_base {
                 self.$field
             }
 
-            pub fn scan>(
-                scanner: &mut S
-            ) -> Result {
-                scanner.scan_name().map(Self::new)
-            }
-
             pub(in crate::rdata) fn convert_octets>(
                 self
             ) -> Result<$target, Target::Error> {
@@ -1174,6 +1165,15 @@ macro_rules! name_type_base {
             }
         }
 
+        //--- Scan
+
+        impl Scan for $target {
+            /// Scans a value of type `N` and wraps it via `Self::new`.
+            fn scan(tokens: &mut Tokenizer<'_>) -> Result {
+                N::scan(tokens).map(Self::new)
+            }
+        }
+
         //--- From and FromStr
 
         impl From for $target {