diff --git a/Cargo.lock b/Cargo.lock index 49778ae01..b18b3a819 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -47,6 +47,56 @@ dependencies = [ "libc", ] +[[package]] +name = "anstream" +version = "0.6.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" + +[[package]] +name = "anstyle-parse" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" +dependencies = [ + "anstyle", + "once_cell", + "windows-sys 0.59.0", +] + [[package]] name = "arbitrary" version = "1.4.1" @@ -167,6 +217,12 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "colorchoice" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" + [[package]] name = "concurrent-queue" version = "2.5.0" @@ -232,9 +288,11 @@ version = "0.10.3" dependencies = [ "arbitrary", "arc-swap", + "bumpalo", "bytes", "chrono", "domain-macros", + "env_logger", "futures-util", "hashbrown 0.14.5", "heapless", @@ -287,6 +345,29 @@ version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +[[package]] +name = "env_filter" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcaee3d8e3cfc3fd92428d477bc97fc29ec8716d180c0d74c643bb26166660e0" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "humantime", + "log", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -494,6 +575,12 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + [[package]] name = "iana-time-zone" version = "0.1.61" @@ -527,6 +614,12 @@ dependencies = [ "hashbrown 0.15.2", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" + [[package]] name = "itertools" version = "0.13.0" @@ -1440,6 +1533,12 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "uuid" version = "1.12.1" @@ -1647,6 +1746,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-targets" version = "0.48.5" diff --git a/Cargo.toml b/Cargo.toml index 481da842c..da0deeebb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ license = "BSD-3-Clause" domain-macros = { path = "./macros", version = "0.10.3" } arbitrary = { version = "1.4.1", optional = true, features = ["derive"] } +bumpalo = { version = "3.12", optional = true } octseq = { version = "0.5.2", default-features = false } time = { version = "0.3.1", default-features = false } rand = { version = "0.8", optional = true } @@ -74,7 +75,7 @@ zonefile = ["bytes", "serde", "std"] # Unstable features unstable-client-transport = ["moka", "net", "tracing"] -unstable-server-transport = ["arc-swap", "chrono/clock", "libc", "net", "siphasher", "tracing"] +unstable-server-transport = ["dep:bumpalo", "arc-swap", "chrono/clock", "dep:log", "libc", "net", "rand", "siphasher", "tracing"] unstable-sign = ["std", "dep:secrecy", "unstable-validate", "time/formatting"] unstable-stelline = ["tokio/test-util", "tracing", "tracing-subscriber", "tsig", "unstable-client-transport", "unstable-server-transport", "zonefile"] unstable-validate = ["bytes", "std", "ring"] @@ -99,6 +100,7 @@ tokio-rustls = { version = "0.26", default-features = false, features = [ tokio-test = "0.4" tokio-tfo = { version = "0.2.0" } webpki-roots = { version = "0.26" } +env_logger = { version = "0.11" } # For the "mysql-zone" example #sqlx = { version = "0.6", features = [ "runtime-tokio-native-tls", "mysql" ] } @@ -118,6 +120,10 @@ required-features = ["resolv"] name = "lookup" required-features = ["resolv"] +[[example]] +name = "new-server" +required-features = ["net", "unstable-server-transport", "unstable-client-transport"] + [[example]] name = "resolv-sync" required-features = ["resolv-sync"] diff --git a/examples/new-server.rs b/examples/new-server.rs new file mode 100644 index 000000000..27b4b643a --- /dev/null +++ b/examples/new-server.rs @@ -0,0 +1,92 @@ +use std::ops::ControlFlow; + +use log::trace; + +use domain::new_server::{ + exchange::{OutgoingResponse, ResponseCode}, + layers::{ + cookie::{CookieMetadata, CookiePolicy, CookieSecrets}, + CookieLayer, + }, + transport, Exchange, LocalService, LocalServiceLayer, Service, + ServiceLayer, +}; + +pub struct MyService; + +impl Service for MyService { + async fn respond(&self, exchange: &mut Exchange<'_>) { + let cookie = exchange + .metadata + .iter() + .find_map(|m| m.try_as::()); + + if let Some(CookieMetadata::ServerCookie { .. }) = cookie { + trace!(target: "MyService", "Request had a valid cookie"); + } else { + trace!(target: "MyService", "Request did not have a valid cookie"); + } + + exchange.respond(ResponseCode::Success); + + // Copy all questions from the request to the response. + exchange + .response + .questions + .append(&mut exchange.request.questions); + } +} + +impl LocalService for MyService { + async fn respond_local(&self, exchange: &mut Exchange<'_>) { + self.respond(exchange).await + } +} + +pub struct MyLayer; + +impl ServiceLayer for MyLayer { + async fn process_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + trace!(target: "MyLayer", + "Incoming request (message ID {})", + exchange.request.id); + ControlFlow::Continue(()) + } + + async fn process_outgoing(&self, response: OutgoingResponse<'_, '_>) { + trace!(target: "MyLayer", + "Outgoing response (message ID {})", + response.response.id); + } +} + +impl LocalServiceLayer for MyLayer { + async fn process_local_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + self.process_incoming(exchange).await + } + + async fn process_local_outgoing( + &self, + response: OutgoingResponse<'_, '_>, + ) { + self.process_outgoing(response).await + } +} + +#[tokio::main] +async fn main() { + env_logger::init(); + + let addr = "127.0.0.1:8080".parse().unwrap(); + let cookie_layer = + CookieLayer::new(CookiePolicy::default(), CookieSecrets::generate()); + let service = (MyLayer, cookie_layer, MyService); + let result = transport::serve_udp(addr, service).await; + println!("Ended on result {result:?}"); +} diff --git a/src/lib.rs b/src/lib.rs index e38ebcaa7..b8e02d032 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -213,3 +213,5 @@ pub mod zonetree; pub mod new_base; pub mod new_edns; pub mod new_rdata; + +pub mod new_server; diff --git a/src/new_edns/cookie.rs b/src/new_edns/cookie.rs index 77c810be5..9b7c5a243 100644 --- a/src/new_edns/cookie.rs +++ b/src/new_edns/cookie.rs @@ -5,20 +5,22 @@ //! [RFC 7873]: https://datatracker.ietf.org/doc/html/rfc7873 //! [RFC 9018]: https://datatracker.ietf.org/doc/html/rfc9018 -use core::fmt; +use core::{ + borrow::{Borrow, BorrowMut}, + fmt, + hash::{Hash, Hasher}, + ops::{Deref, DerefMut}, +}; -#[cfg(all(feature = "std", feature = "siphasher"))] -use core::ops::Range; - -#[cfg(all(feature = "std", feature = "siphasher"))] -use std::net::IpAddr; +#[cfg(feature = "siphasher")] +use core::{net::IpAddr, ops::Range}; use domain_macros::*; -use crate::new_base::Serial; - -#[cfg(all(feature = "std", feature = "siphasher"))] -use crate::new_base::wire::{AsBytes, TruncationError}; +use crate::new_base::{ + wire::{AsBytes, ParseBytesByRef}, + Serial, +}; //----------- ClientCookie --------------------------------------------------- @@ -57,35 +59,25 @@ impl ClientCookie { impl ClientCookie { /// Build a [`Cookie`] in response to this request. /// - /// A 24-byte version-1 interoperable cookie will be generated and written - /// to the given buffer. If the buffer is big enough, the remaining part - /// of the buffer is returned. - #[cfg(all(feature = "std", feature = "siphasher"))] - pub fn respond_into<'b>( - &self, - addr: IpAddr, - secret: &[u8; 16], - mut bytes: &'b mut [u8], - ) -> Result<&'b mut [u8], TruncationError> { - use core::hash::Hasher; - + /// A 24-byte version-1 interoperable cookie will be returned. + #[cfg(feature = "siphasher")] + pub fn respond(&self, addr: IpAddr, secret: &[u8; 16]) -> CookieBuf { use siphasher::sip::SipHasher24; - use crate::new_base::wire::BuildBytes; + // Construct a buffer to write into. + let mut bytes = [0u8; 24]; - // Build and hash the cookie simultaneously. - let mut hasher = SipHasher24::new_with_key(secret); - - bytes = self.build_bytes(bytes)?; - hasher.write(self.as_bytes()); + bytes[0..8].copy_from_slice(self.as_bytes()); // The version number and the reserved octets. - bytes = [1, 0, 0, 0].build_bytes(bytes)?; - hasher.write(&[1, 0, 0, 0]); + bytes[8..12].copy_from_slice(&[1, 0, 0, 0]); let timestamp = Serial::unix_time(); - bytes = timestamp.build_bytes(bytes)?; - hasher.write(timestamp.as_bytes()); + bytes[12..16].copy_from_slice(timestamp.as_bytes()); + + // Hash the cookie. + let mut hasher = SipHasher24::new_with_key(secret); + hasher.write(&bytes[0..16]); match addr { IpAddr::V4(addr) => hasher.write(&addr.octets()), @@ -93,9 +85,11 @@ impl ClientCookie { } let hash = hasher.finish().to_le_bytes(); - bytes = hash.build_bytes(bytes)?; + bytes[16..24].copy_from_slice(&hash); - Ok(bytes) + let cookie = Cookie::parse_bytes_by_ref(&bytes) + .expect("Any 24-byte string is a valid 'Cookie'"); + CookieBuf::copy_from(cookie) } } @@ -194,7 +188,7 @@ impl Cookie { /// valid. /// /// [RFC 9018]: https://datatracker.ietf.org/doc/html/rfc9018 - #[cfg(all(feature = "std", feature = "siphasher"))] + #[cfg(feature = "siphasher")] pub fn verify( &self, addr: IpAddr, @@ -229,6 +223,99 @@ impl Cookie { } } +//----------- CookieBuf ------------------------------------------------------ + +/// A 41-byte buffer holding a [`Cookie`]. +#[derive(Clone)] +pub struct CookieBuf { + /// The size of the cookie, in bytes. + /// + /// This value is between 24 and 40, inclusive. + size: u8, + + /// The cookie data, as raw bytes. + data: [u8; 40], +} + +//--- Construction + +impl CookieBuf { + /// Copy a [`Cookie`] into a [`CookieBuf`]. + pub fn copy_from(cookie: &Cookie) -> Self { + let mut data = [0u8; 40]; + let cookie = cookie.as_bytes(); + data[..cookie.len()].copy_from_slice(cookie); + let size = cookie.len() as u8; + Self { size, data } + } +} + +//--- Access to the underlying 'Cookie' + +impl Deref for CookieBuf { + type Target = Cookie; + + fn deref(&self) -> &Self::Target { + let bytes = &self.data[..self.size as usize]; + // SAFETY: A 'CookieBuf' always contains a valid 'Cookie'. + unsafe { Cookie::parse_bytes_by_ref(bytes).unwrap_unchecked() } + } +} + +impl DerefMut for CookieBuf { + fn deref_mut(&mut self) -> &mut Self::Target { + let bytes = &mut self.data[..self.size as usize]; + // SAFETY: A 'CookieBuf' always contains a valid 'Cookie'. + unsafe { Cookie::parse_bytes_by_mut(bytes).unwrap_unchecked() } + } +} + +impl Borrow for CookieBuf { + fn borrow(&self) -> &Cookie { + self + } +} + +impl BorrowMut for CookieBuf { + fn borrow_mut(&mut self) -> &mut Cookie { + self + } +} + +impl AsRef for CookieBuf { + fn as_ref(&self) -> &Cookie { + self + } +} + +impl AsMut for CookieBuf { + fn as_mut(&mut self) -> &mut Cookie { + self + } +} + +//--- Forwarding formatting, equality and hashing + +impl fmt::Debug for CookieBuf { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + (**self).fmt(f) + } +} + +impl PartialEq for CookieBuf { + fn eq(&self, that: &Self) -> bool { + **self == **that + } +} + +impl Eq for CookieBuf {} + +impl Hash for CookieBuf { + fn hash(&self, state: &mut H) { + (**self).hash(state) + } +} + //----------- CookieError ---------------------------------------------------- /// An invalid [`Cookie`] was encountered. diff --git a/src/new_edns/mod.rs b/src/new_edns/mod.rs index f2cf6b710..fec98291d 100644 --- a/src/new_edns/mod.rs +++ b/src/new_edns/mod.rs @@ -8,19 +8,21 @@ use domain_macros::*; use crate::{ new_base::{ + name::RevName, parse::{ParseMessageBytes, SplitMessageBytes}, wire::{ AsBytes, BuildBytes, ParseBytes, ParseBytesByRef, ParseError, SizePrefixed, SplitBytes, TruncationError, U16, }, + RClass, RType, Record, }, - new_rdata::Opt, + new_rdata::{Opt, RecordData}, }; //----------- EDNS option modules -------------------------------------------- mod cookie; -pub use cookie::{ClientCookie, Cookie}; +pub use cookie::{ClientCookie, Cookie, CookieBuf, CookieError}; mod ext_err; pub use ext_err::{ExtError, ExtErrorCode}; @@ -46,6 +48,51 @@ pub struct EdnsRecord<'a> { pub options: SizePrefixed<&'a Opt>, } +//--- Converting to and from 'Record' + +impl<'n, 'a, DN> TryFrom>> + for EdnsRecord<'a> +{ + type Error = ParseError; + + fn try_from( + value: Record<&'n RevName, RecordData<'a, DN>>, + ) -> Result { + if !value.rname.is_root() || value.rtype != RType::OPT { + return Err(ParseError); + } + + let RecordData::Opt(opt) = value.rdata else { + return Err(ParseError); + }; + + let ttl = value.ttl.value.get().to_be_bytes(); + Ok(Self { + max_udp_payload: value.rclass.code, + ext_rcode: ttl[0], + version: ttl[1], + flags: u16::from_be_bytes([ttl[2], ttl[3]]).into(), + options: SizePrefixed::new(opt), + }) + } +} + +impl<'a, DN> From> for Record<&RevName, RecordData<'a, DN>> { + fn from(value: EdnsRecord<'a>) -> Self { + let flags = value.flags.bits().to_be_bytes(); + let ttl = [value.ext_rcode, value.version, flags[0], flags[1]]; + Record { + rname: RevName::ROOT, + rtype: RType::OPT, + rclass: RClass { + code: value.max_udp_payload, + }, + ttl: u32::from_be_bytes(ttl).into(), + rdata: RecordData::Opt(*value.options), + } + } +} + //--- Parsing from DNS messages impl<'a> SplitMessageBytes<'a> for EdnsRecord<'a> { @@ -177,6 +224,22 @@ impl EdnsFlags { } } +//--- Conversion to and from integers + +impl From for EdnsFlags { + fn from(value: u16) -> Self { + Self { + inner: U16::new(value), + } + } +} + +impl From for u16 { + fn from(value: EdnsFlags) -> Self { + value.inner.get() + } +} + //--- Formatting impl fmt::Debug for EdnsFlags { diff --git a/src/new_rdata/edns.rs b/src/new_rdata/edns.rs index 43327b50c..d832d9f5a 100644 --- a/src/new_rdata/edns.rs +++ b/src/new_rdata/edns.rs @@ -26,6 +26,14 @@ pub struct Opt { contents: [u8], } +//--- Associated Constants + +impl Opt { + /// Empty OPT record data. + pub const EMPTY: &'static Self = + unsafe { core::mem::transmute(&[] as &'static [u8]) }; +} + //--- Inspection impl Opt { diff --git a/src/new_rdata/mod.rs b/src/new_rdata/mod.rs index 70f041240..8491a3c6d 100644 --- a/src/new_rdata/mod.rs +++ b/src/new_rdata/mod.rs @@ -63,6 +63,40 @@ pub enum RecordData<'a, N> { Unknown(RType, &'a UnknownRecordData), } +impl<'a, N> RecordData<'a, N> { + /// Transform the compressed domain names in this record data. + pub fn map_names R>( + self, + mut f: F, + ) -> RecordData<'a, R> { + match self { + Self::A(r) => RecordData::A(r), + Self::Ns(r) => RecordData::Ns(Ns { name: (f)(r.name) }), + Self::CName(r) => RecordData::CName(CName { name: (f)(r.name) }), + Self::Soa(r) => RecordData::Soa(Soa { + mname: (f)(r.mname), + rname: (f)(r.rname), + serial: r.serial, + refresh: r.refresh, + retry: r.retry, + expire: r.expire, + minimum: r.minimum, + }), + Self::Wks(r) => RecordData::Wks(r), + Self::Ptr(r) => RecordData::Ptr(Ptr { name: (f)(r.name) }), + Self::HInfo(r) => RecordData::HInfo(r), + Self::Mx(r) => RecordData::Mx(Mx { + preference: r.preference, + exchange: (f)(r.exchange), + }), + Self::Txt(r) => RecordData::Txt(r), + Self::Aaaa(r) => RecordData::Aaaa(r), + Self::Opt(r) => RecordData::Opt(r), + Self::Unknown(t, r) => RecordData::Unknown(t, r), + } + } +} + //--- Parsing record data impl<'a, N> ParseRecordData<'a> for RecordData<'a, N> diff --git a/src/new_server/exchange.rs b/src/new_server/exchange.rs new file mode 100644 index 000000000..9233aa461 --- /dev/null +++ b/src/new_server/exchange.rs @@ -0,0 +1,588 @@ +//! Request-response exchanges for DNS servers. +//! +//! This module provides a number of utility types for the DNS service layer +//! architecture. In particular, an [`Exchange`] represents a DNS request as +//! it is being passed along a server pipeline, and an [`OutgoingResponse`] is +//! the corresponding response as it is passed back through. + +use core::any::{Any, TypeId}; +use std::{boxed::Box, time::SystemTime, vec::Vec}; + +use bumpalo::Bump; + +use crate::{ + new_base::{ + build::{BuilderContext, MessageBuilder}, + name::{RevName, RevNameBuf}, + parse::SplitMessageBytes, + wire::{BuildBytes, ParseError, SizePrefixed, TruncationError, U16}, + HeaderFlags, Message, Question, RType, Record, SectionCounts, + }, + new_edns::{EdnsOption, EdnsRecord}, + new_rdata::{Opt, RecordData}, +}; + +//----------- Exchange ------------------------------------------------------- + +/// A DNS request-response exchange. +/// +/// An [`Exchange`] represents a request sent to a DNS server and the server's +/// response (as it is being built). It tracks basic information about the +/// request, such as when it was sent and the connection it originates from, +/// as well as metadata stored by layers in the DNS server. +pub struct Exchange<'a> { + /// An allocator for storing parts of the message. + pub alloc: Allocator<'a>, + + /// When the exchange began (i.e. when the request was received). + pub reception: SystemTime, + + /// The request message. + pub request: ParsedMessage<'a>, + + /// The response message being built. + pub response: ParsedMessage<'a>, + + /// Dynamic metadata stored by the DNS server. + pub metadata: Vec, +} + +impl Exchange<'_> { + /// Begin a response with the given code. + pub fn respond(&mut self, code: ResponseCode) { + self.response.respond_to(&self.request, code); + } +} + +//----------- OutgoingResponse ----------------------------------------------- + +/// An [`Exchange`] with an initialized response message. +pub struct OutgoingResponse<'e, 'a> { + /// An allocator for storing parts of the message. + pub alloc: &'e mut Allocator<'a>, + + /// The response message being built. + pub response: &'e mut ParsedMessage<'a>, + + /// Dynamic metadata stored by the DNS server. + pub metadata: &'e mut Vec, +} + +impl<'e, 'a> OutgoingResponse<'e, 'a> { + /// Construct an [`OutgoingResponse`] on an [`Exchange`]. + pub fn new(exchange: &'e mut Exchange<'a>) -> Self { + Self { + alloc: &mut exchange.alloc, + response: &mut exchange.response, + metadata: &mut exchange.metadata, + } + } + + /// Reborrow this response for a shorter lifetime. + pub fn reborrow(&mut self) -> OutgoingResponse<'_, 'a> { + OutgoingResponse { + alloc: self.alloc, + response: self.response, + metadata: self.metadata, + } + } +} + +//----------- ParsedMessage -------------------------------------------------- + +/// A pre-parsed DNS message. +/// +/// This is a simple representation of DNS messages outside the wire format, +/// making it easy to inspect and modify them efficiently. +#[derive(Clone, Default, Debug)] +pub struct ParsedMessage<'a> { + /// The message ID. + pub id: U16, + + /// The message flags. + pub flags: HeaderFlags, + + /// Questions in the message. + pub questions: Vec>, + + /// Answer records in the message. + pub answers: Vec>>, + + /// Authority records in the message. + pub authorities: Vec>>, + + /// Additional records in the message. + /// + /// If there is an EDNS record, it will be included here, but its record + /// data (which contains the EDNS options) will be empty. The options are + /// stored in the `options` field for easier access. + pub additional: Vec>>, + + /// EDNS options in the message. + /// + /// These options will be appended to the EDNS record in the additional + /// section (there must be one for any options to exist). The order of + /// the options is meaningless. + pub options: Vec>, +} + +impl<'a> ParsedMessage<'a> { + /// Parse an existing [`Message`]. + /// + /// Decompressed domain names are allocated using the given [`Bump`]. + pub fn parse( + message: &'a Message, + alloc: &mut Allocator<'a>, + ) -> Result { + type ParsedQuestion = Question; + type ParsedRecord<'a> = + Record>; + + /// Map a domain name by placing it in a [`Bump`]. + fn map_name<'a>( + name: RevNameBuf, + alloc: &mut Allocator<'a>, + ) -> &'a RevName { + // Allocate the domain name. + let name = alloc.alloc_slice_copy(name.as_bytes()); + // SAFETY: 'name' has the same bytes as the input 'name'. + unsafe { RevName::from_bytes_unchecked(name) } + } + + let mut this = Self::default(); + let mut offset = 0; + + // Parse the message header. + this.id = message.header.id; + this.flags = message.header.flags; + let counts = message.header.counts; + + // Parse the question section. + this.questions + .reserve(counts.questions.get().max(256) as usize); + for _ in 0..counts.questions.get() { + let (question, rest) = ParsedQuestion::split_message_bytes( + &message.contents, + offset, + )?; + + this.questions + .push(question.map_name(|n| map_name(n, alloc))); + offset = rest; + } + + // Parse the answer section. + this.answers.reserve(counts.answers.get().max(256) as usize); + for _ in 0..counts.answers.get() { + let (answer, rest) = + ParsedRecord::split_message_bytes(&message.contents, offset)?; + + this.answers.push(Record { + rname: map_name(answer.rname, alloc), + rtype: answer.rtype, + rclass: answer.rclass, + ttl: answer.ttl, + rdata: answer.rdata.map_names(|n| map_name(n, alloc)), + }); + offset = rest; + } + + // Parse the authority section. + this.authorities + .reserve(counts.authorities.get().max(256) as usize); + for _ in 0..counts.authorities.get() { + let (authority, rest) = + ParsedRecord::split_message_bytes(&message.contents, offset)?; + + this.authorities.push(Record { + rname: map_name(authority.rname, alloc), + rtype: authority.rtype, + rclass: authority.rclass, + ttl: authority.ttl, + rdata: authority.rdata.map_names(|n| map_name(n, alloc)), + }); + offset = rest; + } + + // The EDNS record data. + let mut edns_data = None; + + // Parse the additional section. + this.additional + .reserve(counts.additional.get().max(256) as usize); + for _ in 0..counts.additional.get() { + let (mut additional, rest) = + ParsedRecord::split_message_bytes(&message.contents, offset)?; + + if let RecordData::Opt(opt) = additional.rdata { + if edns_data.is_some() { + // A message cannot contain two distinct EDNS records. + return Err(ParseError); + } + + edns_data = Some(opt); + + // Deduplicate the EDNS data. + additional.rdata = RecordData::Opt(Opt::EMPTY); + } + + this.additional.push(Record { + rname: map_name(additional.rname, alloc), + rtype: additional.rtype, + rclass: additional.rclass, + ttl: additional.ttl, + rdata: additional.rdata.map_names(|n| map_name(n, alloc)), + }); + offset = rest; + } + + // Ensure there's no other content in the message. + if offset != message.contents.len() { + return Err(ParseError); + } + + // Parse EDNS options. + if let Some(edns_data) = edns_data { + for option in edns_data.options() { + this.options.push(option?); + } + } + + Ok(this) + } + + /// Build this message into the given buffer. + /// + /// If the message could not fit in the given buffer, a + /// [`TruncationError`] is returned. + pub fn build<'b>( + &self, + buffer: &'b mut [u8], + ) -> Result<&'b Message, TruncationError> { + // Construct a 'MessageBuilder'. + if buffer.len() < 12 { + return Err(TruncationError); + } + let mut context = BuilderContext::default(); + let mut builder = MessageBuilder::new(buffer, &mut context); + + // Build the message header. + let header = builder.header_mut(); + header.id = self.id; + header.flags = self.flags; + header.counts = SectionCounts::default(); + + // Build the question section. + for question in &self.questions { + builder + .build_question(question)? + .expect("No answers, authorities, or additionals are built"); + } + + // Build the answer section. + for answer in &self.answers { + builder + .build_answer(answer)? + .expect("No authorities, or additionals are built"); + } + + // Build the authority section. + for authority in &self.authorities { + builder + .build_authority(authority)? + .expect("No additionals are built"); + } + + // Build the additional section. + let mut edns_built = false; + for additional in &self.additional { + if additional.rtype == RType::OPT { + // Technically, multiple OPT records are an error. But this + // isn't the right place to report that. + debug_assert!(!edns_built, "Multiple EDNS records found"); + + let mut builder = builder.build_additional(additional)?; + let mut delegate = builder.delegate(); + let mut uninit = delegate.uninitialized(); + for option in &self.options { + uninit = option.build_bytes(uninit)?; + } + let uninit_len = uninit.len(); + let appended = delegate.uninitialized().len() - uninit_len; + delegate.mark_appended(appended); + delegate.commit(); + builder.commit(); + + edns_built = true; + continue; + } + + builder.build_additional(additional)?; + } + + debug_assert!( + self.options.is_empty() || edns_built, + "EDNS options found, but no OPT record", + ); + + Ok(builder.finish()) + } +} + +impl ParsedMessage<'_> { + /// Whether this message has an EDNS record. + pub fn has_edns(&self) -> bool { + self.additional.iter().any(|r| r.rtype == RType::OPT) + } +} + +impl ParsedMessage<'_> { + /// Reset this object to a blank message. + /// + /// This is helpful in order to reuse the underlying allocations. + pub fn reset(&mut self) { + self.id = U16::new(0); + self.flags = HeaderFlags::default(); + self.questions.clear(); + self.answers.clear(); + self.authorities.clear(); + self.additional.clear(); + self.options.clear(); + } + + /// Begin a new message in response to the given one. + /// + /// The contents of `self` will be overwritten. The message ID and flags + /// will be copied from the request message, and the given response code + /// will be set. The OPT record, if any, will also be copied (without any + /// EDNS options); if an extended response code is used, it will be added. + pub fn respond_to( + &mut self, + request: &ParsedMessage<'_>, + code: ResponseCode, + ) { + self.reset(); + self.id = request.id; + self.flags = request.flags.respond(code.header_bits()); + + if let Some(edns) = request + .additional + .iter() + .find_map(|r| EdnsRecord::try_from(r.clone()).ok()) + { + // Copy the EDNS record, without any options. + let record = EdnsRecord { + max_udp_payload: edns.max_udp_payload, + ext_rcode: edns.ext_rcode, + version: edns.version, + flags: edns.flags, + options: SizePrefixed::new(Opt::EMPTY), + }; + self.additional.push(record.into()); + } + } +} + +//----------- ResponseCode --------------------------------------------------- + +/// A (possibly extended) DNS response code. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub enum ResponseCode { + /// The request was answered successfully. + Success, + + /// The request was misformatted. + FormatError, + + /// The server encountered an internal error. + ServerFailure, + + /// The queried domain name does not exist. + NonExistentDomain, + + /// The server does not support the requested kind of query. + NotImplemented, + + /// Policy prevents the server from answering the query. + Refused, + + /// The TSIG record in the request was invalid. + InvalidTSIG, + + /// The server does not support the request's OPT record version. + UnsupportedOptVersion, + + /// The request did not contain a valid EDNS server cookie. + BadCookie, +} + +impl ResponseCode { + /// This code's representation in the DNS message header. + pub const fn header_bits(&self) -> u8 { + match self { + Self::Success => 0, + Self::FormatError => 1, + Self::ServerFailure => 2, + Self::NonExistentDomain => 3, + Self::NotImplemented => 4, + Self::Refused => 5, + Self::InvalidTSIG => 9, + Self::UnsupportedOptVersion => 0, + Self::BadCookie => 7, + } + } + + /// This code's representation in the EDNS record header. + pub const fn edns_bits(&self) -> u8 { + match self { + Self::Success => 0, + Self::FormatError => 0, + Self::ServerFailure => 0, + Self::NonExistentDomain => 0, + Self::NotImplemented => 0, + Self::Refused => 0, + Self::InvalidTSIG => 0, + Self::UnsupportedOptVersion => 1, + Self::BadCookie => 1, + } + } +} + +//----------- Metadata ------------------------------------------------------- + +/// Arbitrary metadata about a DNS exchange. +/// +/// This should be used by [`ServiceLayer`](super::ServiceLayer)s for storing +/// information they have extracted from an incoming DNS request message. The +/// metadata may be relevant to future layers: for example, some may wish to +/// handle TSIG-signed requests differently from others. The metadata is also +/// relevant to the original layer in [`process_outgoing()`], as it does not +/// have access to the original request. +/// +/// [`process_outgoing()`]: super::ServiceLayer::process_outgoing() +/// +/// # Implementation +/// +/// This is an enhanced version of `Box` that can +/// perform downcasting more efficiently. It stores the [`TypeId`] of the +/// object inline, allowing it to skip a vtable lookup. +pub struct Metadata { + /// The type ID of the object. + type_id: TypeId, + + /// The underlying object. + object: Box, +} + +impl Metadata { + /// Wrap an object in [`Metadata`]. + pub fn new(object: T) -> Self { + let type_id = TypeId::of::(); + let object = Box::new(object) as Box; + Self { type_id, object } + } + + /// Check whether this is metadata of a certain type. + pub fn is(&self) -> bool { + self.type_id == TypeId::of::() + } + + /// Try downcasting to a reference of a particular type. + pub fn try_as(&self) -> Option<&T> { + if !self.is::() { + return None; + } + + let pointer: *const (dyn Any + Send + 'static) = &*self.object; + // SAFETY: 'pointer' was created by 'Box::into_raw()', and thus is + // safe to dereference (the pointer will only be dropped when 'self' + // is, but that cannot happen during the current lifetime). + Some(unsafe { &*pointer.cast::() }) + } + + /// Try downcasting to a mutable reference of a particular type. + pub fn try_as_mut(&mut self) -> Option<&mut T> { + if !self.is::() { + return None; + } + + let pointer: *mut (dyn Any + Send + 'static) = &mut *self.object; + // SAFETY: 'pointer' was created by 'Box::into_raw()', and thus is + // safe to dereference (the pointer will only be dropped when 'self' + // is, but that cannot happen during the current lifetime). + Some(unsafe { &mut *pointer.cast::() }) + } + + /// Try moving this object out of the [`Metadata`]. + pub fn try_into(self) -> Result { + if !self.is::() { + return Err(self); + } + + let pointer: *mut _ = Box::into_raw(self.object); + // SAFETY: 'pointer' was created by 'Box::into_raw()', and thus is + // safe to move into the same 'Box'. + Ok(*unsafe { Box::from_raw(pointer.cast::()) }) + } +} + +//----------- Allocator ------------------------------------------------------ + +/// A bump allocator with a fixed lifetime. +/// +/// This is a wrapper around [`bumpalo::Bump`] that guarantees thread safety. +/// It is equivalent to `&'a mut Bump`, but `&mut &'a mut Bump` does not work +/// (allocated objects only last for the shorter lifetime, not for `'a`). +/// `&mut Allocator<'a>` does work, giving objects of lifetime `'a`. +/// +/// # Thread Safety +/// +/// [`Bump`] is not thread safe; using it from multiple threads simultaneously +/// would cause undefined behaviour. [`Allocator`] implements [`Send`], and +/// so it cannot directly expose shared references to the underlying [`Bump`]; +/// a user could get `&Bump` on one thread, send the [`Allocator`] to another +/// thread, then get `&Bump` over there. This is why [`Allocator`] copies +/// [`Bump`]'s methods instead of implementing [`Deref`] to [`Bump`]. +/// +/// [`Deref`]: core::ops::Deref +#[derive(Debug)] +#[repr(transparent)] +pub struct Allocator<'a> { + /// The underlying allocator. + /// + /// In order to share access to a [`Bump`], even on a single thread, it + /// must be a shared reference (`&'a Bump`). That is how we store it + /// here. However, we guarantee that the [`Allocator`] is constructed + /// from a mutable reference -- thus that this is the only reference to + /// the bump allocator. It is never exposed publicly, so it cannot be + /// copied and used from multiple threads. + inner: &'a Bump, +} + +impl<'a> Allocator<'a> { + /// Construct a new [`Allocator`]. + pub fn new(inner: &'a mut Bump) -> Self { + // NOTE: The 'Bump' is mutably borrowed for lifetime 'a; the reference + // we store is thus guaranteed to be unique. + Self { inner } + } + + /// Allocate an object. + pub fn alloc(&mut self, val: T) -> &'a mut T { + self.inner.alloc(val) + } + + /// Allocate a slice and copy the given contents into it. + pub fn alloc_slice_copy(&mut self, src: &[T]) -> &'a mut [T] { + self.inner.alloc_slice_copy(src) + } +} + +// SAFETY: An 'Allocator' contains '&Bump', which is '!Send' because 'Bump' is +// '!Sync'. However, we guarantee that there are no other references to the +// 'Bump' -- that this is really '&mut Bump' (which is 'Send'). +unsafe impl Send for Allocator<'_> {} + +// NOTE: 'Allocator' acts a bit like the nightly-only 'std::sync::Exclusive', +// since it doesn't provide any shared access to the underlying 'Bump'. It is +// sound for it to implement 'Sync', but we defer this until necessary. diff --git a/src/new_server/impls.rs b/src/new_server/impls.rs new file mode 100644 index 000000000..ffabe4d3a --- /dev/null +++ b/src/new_server/impls.rs @@ -0,0 +1,502 @@ +//! Blanket implementations for service traits. + +use core::ops::ControlFlow; + +#[cfg(feature = "std")] +use std::{boxed::Box, rc::Rc, sync::Arc, vec::Vec}; + +use super::{ + exchange::OutgoingResponse, Exchange, LocalService, LocalServiceLayer, + Service, ServiceLayer, +}; + +//----------- impl Service --------------------------------------------------- + +impl Service for &T { + async fn respond(&self, exchange: &mut Exchange<'_>) { + T::respond(self, exchange).await + } +} + +impl Service for &mut T { + async fn respond(&self, exchange: &mut Exchange<'_>) { + T::respond(self, exchange).await + } +} + +#[cfg(feature = "std")] +impl Service for Box { + async fn respond(&self, exchange: &mut Exchange<'_>) { + T::respond(self, exchange).await + } +} + +#[cfg(feature = "std")] +impl Service for Arc { + async fn respond(&self, exchange: &mut Exchange<'_>) { + T::respond(self, exchange).await + } +} + +impl Service for (A, S) +where + A: ServiceLayer, + S: Service, +{ + async fn respond(&self, exchange: &mut Exchange<'_>) { + if self.0.process_incoming(exchange).await.is_continue() { + self.1.respond(exchange).await; + let response = OutgoingResponse::new(exchange); + self.0.process_outgoing(response).await; + } + } +} + +macro_rules! impl_service_tuple { + ($($layers:ident)* .. $service:ident) => { + impl<$($layers,)* $service: ?Sized> + Service for ($($layers,)* $service,) + where + $($layers: ServiceLayer,)* + $service: Service, + { + async fn respond(&self, exchange: &mut Exchange<'_>) { + #[allow(non_snake_case)] + let ($($layers,)* $service,) = self; + (($($layers),*,), $service).respond(exchange).await + } + } + }; +} + +impl_service_tuple!(A B..S); +impl_service_tuple!(A B C..S); +impl_service_tuple!(A B C D..S); +impl_service_tuple!(A B C D E..S); +impl_service_tuple!(A B C D E F..S); +impl_service_tuple!(A B C D E F G..S); +impl_service_tuple!(A B C D E F G H..S); +impl_service_tuple!(A B C D E F G H I..S); +impl_service_tuple!(A B C D E F G H I J..S); +impl_service_tuple!(A B C D E F G H I J K..S); + +//----------- impl LocalService ---------------------------------------------- + +impl LocalService for &T { + async fn respond_local(&self, exchange: &mut Exchange<'_>) { + T::respond_local(self, exchange).await + } +} + +impl LocalService for &mut T { + async fn respond_local(&self, exchange: &mut Exchange<'_>) { + T::respond_local(self, exchange).await + } +} + +#[cfg(feature = "std")] +impl LocalService for Box { + async fn respond_local(&self, exchange: &mut Exchange<'_>) { + T::respond_local(self, exchange).await + } +} + +#[cfg(feature = "std")] +impl LocalService for Rc { + async fn respond_local(&self, exchange: &mut Exchange<'_>) { + T::respond_local(self, exchange).await + } +} + +#[cfg(feature = "std")] +impl LocalService for Arc { + async fn respond_local(&self, exchange: &mut Exchange<'_>) { + T::respond_local(self, exchange).await + } +} + +impl LocalService for (A, S) +where + A: LocalServiceLayer, + S: LocalService, +{ + async fn respond_local(&self, exchange: &mut Exchange<'_>) { + if self.0.process_local_incoming(exchange).await.is_continue() { + self.1.respond_local(exchange).await; + let response = OutgoingResponse::new(exchange); + self.0.process_local_outgoing(response).await; + } + } +} + +macro_rules! impl_local_service_tuple { + ($($layers:ident)* .. $service:ident) => { + impl<$($layers,)* $service: ?Sized> + LocalService for ($($layers,)* $service,) + where + $($layers: LocalServiceLayer,)* + $service: LocalService, + { + async fn respond_local(&self, exchange: &mut Exchange<'_>) { + #[allow(non_snake_case)] + let ($($layers,)* $service,) = self; + (($($layers),*,), $service).respond_local(exchange).await + } + } + }; +} + +impl_local_service_tuple!(A B..S); +impl_local_service_tuple!(A B C..S); +impl_local_service_tuple!(A B C D..S); +impl_local_service_tuple!(A B C D E..S); +impl_local_service_tuple!(A B C D E F..S); +impl_local_service_tuple!(A B C D E F G..S); +impl_local_service_tuple!(A B C D E F G H..S); +impl_local_service_tuple!(A B C D E F G H I..S); +impl_local_service_tuple!(A B C D E F G H I J..S); +impl_local_service_tuple!(A B C D E F G H I J K..S); + +//----------- impl ServiceLayer ---------------------------------------------- + +impl ServiceLayer for &T { + async fn process_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + T::process_incoming(self, exchange).await + } + + async fn process_outgoing(&self, response: OutgoingResponse<'_, '_>) { + T::process_outgoing(self, response).await + } +} + +impl ServiceLayer for &mut T { + async fn process_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + T::process_incoming(self, exchange).await + } + + async fn process_outgoing(&self, response: OutgoingResponse<'_, '_>) { + T::process_outgoing(self, response).await + } +} + +#[cfg(feature = "std")] +impl ServiceLayer for Box { + async fn process_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + T::process_incoming(self, exchange).await + } + + async fn process_outgoing(&self, response: OutgoingResponse<'_, '_>) { + T::process_outgoing(self, response).await + } +} + +#[cfg(feature = "std")] +impl ServiceLayer for Arc { + async fn process_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + T::process_incoming(self, exchange).await + } + + async fn process_outgoing(&self, response: OutgoingResponse<'_, '_>) { + T::process_outgoing(self, response).await + } +} + +impl ServiceLayer for (A, B) +where + A: ServiceLayer, + B: ServiceLayer, +{ + async fn process_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + self.0.process_incoming(exchange).await?; + self.1.process_incoming(exchange).await + } + + async fn process_outgoing(&self, mut response: OutgoingResponse<'_, '_>) { + self.1.process_outgoing(response.reborrow()).await; + self.0.process_outgoing(response.reborrow()).await + } +} + +macro_rules! impl_service_layer_tuple { + ($first:ident .. $last:ident: $($middle:ident)+) => { + impl<$first, $($middle,)+ $last: ?Sized> + ServiceLayer for ($first, $($middle,)+ $last) + where + $first: ServiceLayer, + $($middle: ServiceLayer,)+ + $last: ServiceLayer, + { + async fn process_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> + { + #[allow(non_snake_case)] + let ($first, $($middle,)+ ref $last) = self; + $first.process_incoming(exchange).await?; + $($middle.process_incoming(exchange).await?;)+ + $last.process_incoming(exchange).await + } + + async fn process_outgoing( + &self, + response: OutgoingResponse<'_, '_>, + ) { + #[allow(non_snake_case)] + let ($first, $($middle,)+ ref $last) = self; + ($first, ($($middle,)+ $last)) + .process_outgoing(response).await + } + } + } +} + +impl_service_layer_tuple!(A..C: B); +impl_service_layer_tuple!(A..D: B C); +impl_service_layer_tuple!(A..E: B C D); +impl_service_layer_tuple!(A..F: B C D E); +impl_service_layer_tuple!(A..G: B C D E F); +impl_service_layer_tuple!(A..H: B C D E F G); +impl_service_layer_tuple!(A..I: B C D E F G H); +impl_service_layer_tuple!(A..J: B C D E F G H I); +impl_service_layer_tuple!(A..K: B C D E F G H I J); +impl_service_layer_tuple!(A..L: B C D E F G H I J K); + +#[cfg(feature = "std")] +impl ServiceLayer for [T] { + async fn process_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + for layer in self { + layer.process_incoming(exchange).await?; + } + ControlFlow::Continue(()) + } + + async fn process_outgoing(&self, mut response: OutgoingResponse<'_, '_>) { + for layer in self.iter().rev() { + layer.process_outgoing(response.reborrow()).await; + } + } +} + +#[cfg(feature = "std")] +impl ServiceLayer for Vec { + async fn process_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + self.as_slice().process_incoming(exchange).await + } + + async fn process_outgoing(&self, response: OutgoingResponse<'_, '_>) { + self.as_slice().process_outgoing(response).await + } +} + +//----------- impl LocalServiceLayer ----------------------------------------- + +impl LocalServiceLayer for &T { + async fn process_local_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + T::process_local_incoming(self, exchange).await + } + + async fn process_local_outgoing( + &self, + response: OutgoingResponse<'_, '_>, + ) { + T::process_local_outgoing(self, response).await + } +} + +impl LocalServiceLayer for &mut T { + async fn process_local_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + T::process_local_incoming(self, exchange).await + } + + async fn process_local_outgoing( + &self, + response: OutgoingResponse<'_, '_>, + ) { + T::process_local_outgoing(self, response).await + } +} + +#[cfg(feature = "std")] +impl LocalServiceLayer for Box { + async fn process_local_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + T::process_local_incoming(self, exchange).await + } + + async fn process_local_outgoing( + &self, + response: OutgoingResponse<'_, '_>, + ) { + T::process_local_outgoing(self, response).await + } +} + +#[cfg(feature = "std")] +impl LocalServiceLayer for Rc { + async fn process_local_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + T::process_local_incoming(self, exchange).await + } + + async fn process_local_outgoing( + &self, + response: OutgoingResponse<'_, '_>, + ) { + T::process_local_outgoing(self, response).await + } +} + +#[cfg(feature = "std")] +impl LocalServiceLayer for Arc { + async fn process_local_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + T::process_local_incoming(self, exchange).await + } + + async fn process_local_outgoing( + &self, + response: OutgoingResponse<'_, '_>, + ) { + T::process_local_outgoing(self, response).await + } +} + +impl LocalServiceLayer for (A, B) +where + A: LocalServiceLayer, + B: LocalServiceLayer, +{ + async fn process_local_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + self.0.process_local_incoming(exchange).await?; + self.1.process_local_incoming(exchange).await?; + ControlFlow::Continue(()) + } + + async fn process_local_outgoing( + &self, + mut response: OutgoingResponse<'_, '_>, + ) { + self.1.process_local_outgoing(response.reborrow()).await; + self.0.process_local_outgoing(response.reborrow()).await + } +} + +macro_rules! impl_local_service_layer_tuple { + ($first:ident .. $last:ident: $($middle:ident)+) => { + impl<$first, $($middle,)+ $last: ?Sized> + LocalServiceLayer for ($first, $($middle,)+ $last) + where + $first: LocalServiceLayer, + $($middle: LocalServiceLayer,)+ + $last: LocalServiceLayer, + { + async fn process_local_incoming( + &self, + exchange: &mut Exchange<'_> + ) -> ControlFlow<()> { + #[allow(non_snake_case)] + let ($first, $($middle,)+ ref $last) = self; + $first.process_local_incoming(exchange).await?; + $($middle.process_local_incoming(exchange).await?;)+ + $last.process_local_incoming(exchange).await + } + + async fn process_local_outgoing( + &self, + response: OutgoingResponse<'_, '_> + ) { + #[allow(non_snake_case)] + let ($first, $($middle,)+ ref $last) = self; + ($first, ($($middle,)+ $last)) + .process_local_outgoing(response).await + } + } + } +} + +impl_local_service_layer_tuple!(A..C: B); +impl_local_service_layer_tuple!(A..D: B C); +impl_local_service_layer_tuple!(A..E: B C D); +impl_local_service_layer_tuple!(A..F: B C D E); +impl_local_service_layer_tuple!(A..G: B C D E F); +impl_local_service_layer_tuple!(A..H: B C D E F G); +impl_local_service_layer_tuple!(A..I: B C D E F G H); +impl_local_service_layer_tuple!(A..J: B C D E F G H I); +impl_local_service_layer_tuple!(A..K: B C D E F G H I J); +impl_local_service_layer_tuple!(A..L: B C D E F G H I J K); + +#[cfg(feature = "std")] +impl LocalServiceLayer for [T] { + async fn process_local_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + for layer in self { + layer.process_local_incoming(exchange).await?; + } + ControlFlow::Continue(()) + } + + async fn process_local_outgoing( + &self, + mut response: OutgoingResponse<'_, '_>, + ) { + for layer in self.iter().rev() { + layer.process_local_outgoing(response.reborrow()).await; + } + } +} + +#[cfg(feature = "std")] +impl LocalServiceLayer for Vec { + async fn process_local_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + self.as_slice().process_local_incoming(exchange).await + } + + async fn process_local_outgoing( + &self, + response: OutgoingResponse<'_, '_>, + ) { + self.as_slice().process_local_outgoing(response).await + } +} diff --git a/src/new_server/layers/cookie.rs b/src/new_server/layers/cookie.rs new file mode 100644 index 000000000..7bd8e00aa --- /dev/null +++ b/src/new_server/layers/cookie.rs @@ -0,0 +1,484 @@ +//! DNS cookie management. + +use core::{ + net::{IpAddr, Ipv4Addr, Ipv6Addr}, + ops::{ControlFlow, Range}, +}; + +use std::{sync::Arc, vec::Vec}; + +use arc_swap::ArcSwap; +use log::trace; +use rand::{CryptoRng, Rng, RngCore}; + +use crate::{ + new_base::{ + wire::{AsBytes, ParseBytesByRef}, + Serial, + }, + new_edns::{ + ClientCookie, Cookie, CookieBuf, CookieError, EdnsOption, OptionCode, + }, + new_server::{ + exchange::{Metadata, OutgoingResponse, ResponseCode}, + transport::{SourceIpAddr, UdpMetadata}, + Exchange, LocalServiceLayer, ServiceLayer, + }, +}; + +//----------- CookieLayer ---------------------------------------------------- + +/// Server-side DNS cookie management. +#[derive(Debug)] +pub struct CookieLayer { + /// The cookie policy to use. + policy: ArcSwap, + + /// The secrets to use for signing and verifying. + secrets: ArcSwap, +} + +//--- Interaction + +impl CookieLayer { + /// Construct a new [`CookieLayer`]. + pub fn new(policy: CookiePolicy, secrets: CookieSecrets) -> Self { + Self { + policy: ArcSwap::new(Arc::new(policy)), + secrets: ArcSwap::new(Arc::new(secrets)), + } + } + + /// Load the cookie policy. + /// + /// The current state of the policy is loaded. The policy may be changed + /// by a different thread, so future calls to the method may result in + /// different policies. + pub fn get_policy(&self) -> Arc { + self.policy.load_full() + } + + /// Replace the cookie policy. + /// + /// This will atomically update the policy, so that future callers of + /// [`get_policy()`](Self::get_policy()) will (soon but not necessarily + /// immediately) see the updated policy. + pub fn set_policy(&self, policy: CookiePolicy) { + self.policy.store(Arc::new(policy)); + } + + /// Load the cookie secrets. + /// + /// The current state of the secrets is loaded. The secrets may be + /// changed by a different thread, so future calls to the method may + /// result in different secrets. + pub fn get_secrets(&self) -> Arc { + self.secrets.load_full() + } + + /// Replace the cookie secrets. + /// + /// This will atomically update the secrets, so that future callers of + /// [`get_secrets()`](Self::get_secrets()) will (soon but not necessarily + /// immediately) see the updated secrets. + pub fn set_secrets(&self, secrets: CookieSecrets) { + self.secrets.store(Arc::new(secrets)); + } +} + +//--- Processing incoming requests + +impl CookieLayer { + /// Respond to an incoming request with an alleged server cookie. + fn process_incoming_server_cookie<'a>( + &self, + exchange: &mut Exchange<'a>, + addr: IpAddr, + cookie: &'a Cookie, + ) -> ControlFlow<()> { + // Determine the validity period of the cookie. + let now = Serial::unix_time(); + let validity = now + -300..now + 3600; + + // Check if the cookie is actually valid. + if self.secrets.load().verify(&addr, validity, cookie).is_err() { + trace!(target: "CookieLayer", + "Ignoring invalid server cookie in request {}", + exchange.request.id); + + // Simply ignore the server part. + return self.process_incoming_wo_server_cookie( + exchange, + addr, + Some(cookie.request()), + ); + } + + trace!(target: "CookieLayer", + "Validated server cookie in request {}", + exchange.request.id); + + // Determine whether the cookie needs to be renewed. + let expiry = now + 1800; + let regenerate = cookie.timestamp() >= expiry; + + // Remember the cookie status. + let cookie = CookieBuf::copy_from(cookie); + let metadata = CookieMetadata::ServerCookie { cookie, regenerate }; + exchange.metadata.push(Metadata::new(metadata)); + + // Continue into the next layer. + ControlFlow::Continue(()) + } + + /// Respond to an incoming request without a (valid) server cookie. + fn process_incoming_wo_server_cookie<'a>( + &self, + exchange: &mut Exchange<'a>, + addr: IpAddr, + cookie: Option<&'a ClientCookie>, + ) -> ControlFlow<()> { + // RFC 7873, section 5.2.3: + // + // > Servers MUST, at least occasionally, respond to such requests to + // > inform the client of the correct Server Cookie. This is + // > necessary so that such a client can bootstrap to the more secure + // > state where requests and responses have recognized Server Cookies + // > and Client Cookies. A server is not expected to maintain + // > per-client state to achieve this. For example, it could respond + // > to every Nth request across all clients. + + // We rate-limit requests based on the cookie policy. If the request + // originates from a restricted IP address, the request is allowed to + // continue with a small probability. All requests from unrestricted + // IP addresses are allowed to go through. All non-UDP requests are + // allowed to go through anyway. + if !exchange.metadata.iter().any(|m| m.is::()) + || !self.policy.load().is_required_for(addr) + || rand::thread_rng().gen_bool(0.05) + { + // The request is allowed to go through. + trace!(target: "CookieLayer", + "Allowing request {} regardless of missing/invalid server cookie", + exchange.request.id); + let metadata = match cookie { + Some(&cookie) => CookieMetadata::ClientCookie(cookie), + None => CookieMetadata::None, + }; + exchange.metadata.push(Metadata::new(metadata)); + return ControlFlow::Continue(()); + } + + // Block the request. + trace!(target: "CookieLayer", + "Blocking request {} due to missing/invalid server cookie", + exchange.request.id); + if exchange.request.has_edns() { + exchange.respond(ResponseCode::BadCookie); + } else { + exchange.respond(ResponseCode::Refused); + exchange.response.flags = + exchange.response.flags.set_truncated(true); + } + exchange + .response + .questions + .append(&mut exchange.request.questions); + + ControlFlow::Break(()) + } +} + +//--- Processing outgoing responses + +impl CookieLayer { + /// Generate an EDNS COOKIE option for a response. + fn generate_cookie( + &self, + addr: IpAddr, + cookie: ClientCookie, + ) -> CookieBuf { + cookie.respond(addr, &self.secrets.load().primary) + } +} + +//--- ServiceLayer + +impl ServiceLayer for CookieLayer { + async fn process_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + // Check for an EDNS COOKIE option. + let cookie = exchange + .request + .options + .iter() + .find(|option| option.code() == OptionCode::COOKIE) + .cloned(); + + // Determine the IP address the request originated from. + let Some(&SourceIpAddr(addr)) = + exchange.metadata.iter().find_map(|m| m.try_as()) + else { + // We couldn't determine the source address. + // TODO: This is unexpected, log it. + return ControlFlow::Continue(()); + }; + + match cookie { + Some(EdnsOption::Cookie(cookie)) => { + self.process_incoming_server_cookie(exchange, addr, cookie) + } + + Some(EdnsOption::ClientCookie(cookie)) => self + .process_incoming_wo_server_cookie( + exchange, + addr, + Some(cookie), + ), + + None => { + self.process_incoming_wo_server_cookie(exchange, addr, None) + } + + _ => unreachable!(), + } + } + + async fn process_outgoing(&self, response: OutgoingResponse<'_, '_>) { + // Determine the IP address the request originated from. + let Some(&SourceIpAddr(addr)) = + response.metadata.iter().find_map(|m| m.try_as()) + else { + // We couldn't determine the source address. + // TODO: This is unexpected, log it. + return; + }; + + // Check for cookie metadata. + let cookie = match response.metadata.iter().find_map(|m| m.try_as()) { + // The request had a client cookie (and possibly an invalid server + // cookie). Generate a new server cookie and include it. + Some(CookieMetadata::ClientCookie(cookie)) => { + trace!(target: "CookieLayer", + "Generating cookie for response {}", + response.response.id); + self.generate_cookie(addr, *cookie) + } + + // The request had a server cookie that may need to be renewed. + Some(CookieMetadata::ServerCookie { cookie, regenerate }) => { + if *regenerate { + trace!(target: "CookieLayer", + "Refreshing cookie for response {}", + response.response.id); + self.generate_cookie(addr, *cookie.request()) + } else { + trace!(target: "CookieLayer", + "Using existing server cookie for response {}", + response.response.id); + cookie.clone() + } + } + + // The request did not contain a cookie, or the cookie layer was + // disabled when answering this request. + Some(CookieMetadata::None) | None => return, + }; + + // Copy the cookie into the response. + // TODO: Check that the response includes an EDNS record. + let cookie = response.alloc.alloc_slice_copy((*cookie).as_bytes()); + let cookie = Cookie::parse_bytes_by_ref(cookie).unwrap(); + let option = EdnsOption::Cookie(cookie); + response.response.options.push(option); + } +} + +impl LocalServiceLayer for CookieLayer { + async fn process_local_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> ControlFlow<()> { + self.process_incoming(exchange).await + } + + async fn process_local_outgoing( + &self, + response: OutgoingResponse<'_, '_>, + ) { + self.process_outgoing(response).await + } +} + +//----------- CookiePolicy --------------------------------------------------- + +/// Configuration for DNS cookie enforcement. +#[derive(Clone, Debug, Default)] +pub struct CookiePolicy { + /// IP addresses that must provide DNS cookies with their queries. + pub required: PrefixTree, + + /// IP addresses that need not provide DNS cookies with their queries. + pub allowed: PrefixTree, +} + +impl CookiePolicy { + /// Whether an IP address is required to use DNS cookies. + pub fn is_required_for(&self, addr: IpAddr) -> bool { + match (self.required.test(addr), self.allowed.test(addr)) { + // The address is restricted, but is more specifically allowed. + (Some(r), Some(a)) if a >= r => false, + + // The address is definitely restricted. + (Some(_), _) => true, + + // There are no restrictions on the address. + (None, _) => true, + } + } +} + +//----------- CookieSecrets -------------------------------------------------- + +/// The secrets used for DNS cookies. +#[derive(Clone, Debug)] +pub struct CookieSecrets { + /// The primary secret (used for generation and verification). + pub primary: [u8; 16], + + /// A secondary secret for verification. + pub secondary: [u8; 16], +} + +impl CookieSecrets { + /// Initialize [`CookieSecrets`] with a random primary. + pub fn generate() -> Self { + Self::generate_with(rand::thread_rng()) + } + + /// Initialize [`CookieSecrets`] with the given RNG. + pub fn generate_with(mut rng: impl CryptoRng + RngCore) -> Self { + let primary = rng.gen(); + Self { + primary, + secondary: primary, + } + } + + /// Verify the given cookie against these secrets. + fn verify( + &self, + addr: &IpAddr, + validity: Range, + cookie: &Cookie, + ) -> Result<(), CookieError> { + let Err(err) = cookie.verify(*addr, &self.primary, validity.clone()) + else { + return Ok(()); + }; + + // TODO: Compare secrets more carefully. + if self.primary == self.secondary { + return Err(err); + } + + cookie.verify(*addr, &self.secondary, validity) + } +} + +//----------- CookieMetadata ------------------------------------------------- + +/// Information about a DNS request's use of cookies. +#[derive(Clone, Debug)] +pub enum CookieMetadata { + /// The request did not use DNS cookies. + None, + + /// The request included a DNS client cookie. + ClientCookie(ClientCookie), + + /// The request included a DNS server cookie. + ServerCookie { + /// The cookie used in the request. + cookie: CookieBuf, + + /// Whether a new cookie should be generated. + regenerate: bool, + }, +} + +//----------- PrefixTree ----------------------------------------------------- + +/// A set of IP addresses represented as prefixes. +#[derive(Clone, Debug, Default)] +pub struct PrefixTree { + /// A list of v4 prefixes, from longest to shortest. + v4_prefixes: Vec<(u8, Ipv4Addr)>, + + /// A list of v6 prefixes, from longest to shortest. + v6_prefixes: Vec<(u8, Ipv6Addr)>, +} + +impl PrefixTree { + /// Build a [`PrefixTree`] from an unsorted list of prefixes. + /// + /// The prefixes will be sorted before being used. Outside the valid + /// length of each prefix, only zero bits must be used. + pub fn from_prefixes( + mut v4_prefixes: Vec<(u8, Ipv4Addr)>, + mut v6_prefixes: Vec<(u8, Ipv6Addr)>, + ) -> Self { + v4_prefixes.sort_unstable_by(|a, b| a.0.cmp(&b.0).reverse()); + v6_prefixes.sort_unstable_by(|a, b| a.0.cmp(&b.0).reverse()); + Self::from_sorted_prefixes(v4_prefixes, v6_prefixes) + } + + /// Build a [`PrefixTree`] from a sorted list of prefixes. + /// + /// The prefixes must be sorted from longest to shortest. Within a + /// particular prefix length, the addresses are unordered. Outside the + /// valid length of each prefix, only zero bits must be used. + pub fn from_sorted_prefixes( + v4_prefixes: Vec<(u8, Ipv4Addr)>, + v6_prefixes: Vec<(u8, Ipv6Addr)>, + ) -> Self { + Self { + v4_prefixes, + v6_prefixes, + } + } + + /// Test whether an IP address is in this prefix tree. + /// + /// If a matching prefix is found, its length is returned. + pub fn test(&self, addr: IpAddr) -> Option { + match addr { + IpAddr::V4(addr) => self.test_v4(addr), + IpAddr::V6(addr) => self.test_v6(addr), + } + } + + /// Test whether an IPv4 address is in this prefix tree. + /// + /// If a matching prefix is found, its length is returned. + pub fn test_v4(&self, addr: Ipv4Addr) -> Option { + self.v4_prefixes + .iter() + .copied() + .find(|(_, prefix)| (prefix & addr) == *prefix) + .map(|(length, _)| length) + } + + /// Test whether an IPv6 address is in this prefix tree. + /// + /// If a matching prefix is found, its length is returned. + pub fn test_v6(&self, addr: Ipv6Addr) -> Option { + self.v6_prefixes + .iter() + .copied() + .find(|(_, prefix)| (prefix & addr) == *prefix) + .map(|(length, _)| length) + } +} diff --git a/src/new_server/layers/mod.rs b/src/new_server/layers/mod.rs new file mode 100644 index 000000000..885e64001 --- /dev/null +++ b/src/new_server/layers/mod.rs @@ -0,0 +1,4 @@ +//! Common plug-in functionality for DNS servers. + +pub mod cookie; +pub use cookie::CookieLayer; diff --git a/src/new_server/mod.rs b/src/new_server/mod.rs new file mode 100644 index 000000000..d9aef6c1a --- /dev/null +++ b/src/new_server/mod.rs @@ -0,0 +1,149 @@ +//! Responding to DNS requests. +//! +//! # Architecture +//! +//! A _transport_ implements a network interface allowing it to receive DNS +//! requests and return DNS responses. Transports can be implemented on UDP, +//! TCP, TLS, etc., and users can implement their own transports. +//! +//! A _service_ implements the business logic of handling a DNS request and +//! building a DNS response. A service can be composed of multiple _layers_, +//! each of which can inspect the request and prepare part of the response. +//! Many common layers are already implemented, but users can define more. + +#![cfg(feature = "unstable-server-transport")] +#![cfg_attr(docsrs, doc(cfg(feature = "unstable-server-transport")))] + +use core::{future::Future, ops::ControlFlow}; + +mod impls; + +pub mod exchange; +pub use exchange::Exchange; +use exchange::OutgoingResponse; + +pub mod transport; + +pub mod layers; + +//----------- Service -------------------------------------------------------- + +/// A (multi-threaded) DNS service, that computes responses for requests. +/// +/// Given a DNS request message, a service computes an appropriate response. +/// Services are usually wrapped in a network transport that receives requests +/// and returns the service's responses. +/// +/// Use [`LocalService`] for a single-threaded equivalent. +/// +/// # Layering +/// +/// Additional functionality can be added to a service by prefixing it with +/// service layers, usually in a tuple. A number of blanket implementations +/// are provided to simplify this. +pub trait Service: LocalService + Sync { + /// Respond to a DNS request. + /// + /// The returned [`Future`] is thread-safe; it implements [`Send`]. Use + /// [`LocalService::respond_local()`] if this is not necessary. + fn respond( + &self, + exchange: &mut Exchange<'_>, + ) -> impl Future + Send; +} + +//----------- LocalService --------------------------------------------------- + +/// A (single-threaded) DNS service, that computes responses for requests. +/// +/// Given a DNS request message, a service computes an appropriate response. +/// Services are usually wrapped in a network transport that receives requests +/// and returns the service's responses. +/// +/// Use [`Service`] for a multi-threaded equivalent. +/// +/// # Layering +/// +/// Additional functionality can be added to a service by prefixing it with +/// service layers, usually in a tuple. A number of blanket implementations +/// are provided to simplify this. +pub trait LocalService { + /// Respond to a DNS request. + /// + /// The returned [`Future`] is thread-local; it does not implement + /// [`Send`]. Use [`Service::respond()`] for a thread-safe alternative. + fn respond_local( + &self, + exchange: &mut Exchange<'_>, + ) -> impl Future; +} + +//----------- ServiceLayer --------------------------------------------------- + +/// A (multi-threaded) layer wrapping a DNS [`Service`]. +/// +/// A layer can be wrapped around a service, inspecting the requests sent to +/// it and transforming the responses returned by it. +/// +/// Use [`LocalServiceLayer`] for a single-threaded equivalent. +/// +/// # Combinations +/// +/// Layers can be combined (usually in a tuple) into larger layers. A number +/// of blanket implementations are provided to simplify this. +pub trait ServiceLayer: LocalServiceLayer + Sync { + /// Process an incoming DNS request. + /// + /// The returned [`Future`] is thread-safe; it implements [`Send`]. Use + /// [`LocalServiceLayer::process_local_incoming()`] if this is not + /// necessary. + fn process_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> impl Future> + Send; + + /// Process an outgoing DNS response. + /// + /// The returned [`Future`] is thread-safe; it implements [`Send`]. Use + /// [`LocalServiceLayer::process_local_outgoing()`] if this is not + /// necessary. + fn process_outgoing( + &self, + response: OutgoingResponse<'_, '_>, + ) -> impl Future + Send; +} + +//----------- LocalServiceLayer ---------------------------------------------- + +/// A (single-threaded) layer wrapping a DNS [`Service`]. +/// +/// A layer can be wrapped around a service, inspecting the requests sent to +/// it and transforming the responses returned by it. +/// +/// Use [`ServiceLayer`] for a multi-threaded equivalent. +/// +/// # Combinations +/// +/// Layers can be combined (usually in a tuple) into larger layers. A number +/// of blanket implementations are provided to simplify this. +pub trait LocalServiceLayer { + /// Process an incoming DNS request. + /// + /// The returned [`Future`] is thread-local; it does not implement + /// [`Send`]. Use [`ServiceLayer::process_incoming()`] for a thread-safe + /// alternative. + fn process_local_incoming( + &self, + exchange: &mut Exchange<'_>, + ) -> impl Future>; + + /// Process an outgoing DNS response. + /// + /// The returned [`Future`] is thread-local; it does not implement + /// [`Send`]. Use [`ServiceLayer::process_outgoing()`] for a thread-safe + /// alternative. + fn process_local_outgoing( + &self, + response: OutgoingResponse<'_, '_>, + ) -> impl Future; +} diff --git a/src/new_server/transport/mod.rs b/src/new_server/transport/mod.rs new file mode 100644 index 000000000..5aae60dc4 --- /dev/null +++ b/src/new_server/transport/mod.rs @@ -0,0 +1,144 @@ +//! Network transports for DNS servers. + +use core::net::{IpAddr, SocketAddr}; +use std::{io, sync::Arc, time::SystemTime, vec::Vec}; + +use bumpalo::Bump; +use log::trace; +use tokio::net::UdpSocket; + +use crate::{ + new_base::{ + wire::{AsBytes, ParseBytesByRef}, + Message, + }, + new_server::exchange::{Allocator, Metadata}, +}; + +use super::{exchange::ParsedMessage, Exchange, Service}; + +//----------- serve_udp() ---------------------------------------------------- + +/// Serve DNS requests over UDP. +/// +/// A UDP socket will be bound to the given address and listened on for DNS +/// requests. Requests will be handed to the given [`Service`] and responses +/// will be returned directly. Each DNS request is handed off to a Tokio task +/// so they can respond asynchronously. +pub async fn serve_udp( + addr: SocketAddr, + service: impl Service + Send + 'static, +) -> io::Result<()> { + /// Internal multi-threaded state. + struct State { + /// The UDP socket serving DNS. + socket: UdpSocket, + + /// The service implementing response logic. + service: S, + } + + impl State { + /// Respond to a particular UDP request. + async fn respond(self: Arc, buffer: Vec, peer: SocketAddr) { + let Ok(message) = Message::parse_bytes_by_ref(&buffer) else { + // This message is fundamentally invalid, just give up. + return; + }; + + let mut allocator = Bump::new(); + let mut allocator = Allocator::new(&mut allocator); + + let Ok(request) = ParsedMessage::parse(message, &mut allocator) + else { + // This message is malformed; inform the peer and stop. + let mut buffer = [0u8; 12]; + let response = Message::parse_bytes_by_mut(&mut buffer) + .expect("Any 12-byte or larger buffer is a 'Message'"); + response.header.id = message.header.id; + response.header.flags = message.header.flags.respond(1); + let response = response.slice_to(0); + let _ = self.socket.send_to(response.as_bytes(), peer).await; + return; + }; + + // Build a complete 'Exchange' around the request. + let mut exchange = Exchange { + alloc: allocator, + reception: SystemTime::now(), + request, + response: ParsedMessage::default(), + metadata: vec![Metadata::new(SourceIpAddr(peer.ip()))], + }; + + trace!(target: "serve_udp", + "Received request {} from {peer}", + exchange.request.id); + + // Generate the appropriate response. + self.service.respond(&mut exchange).await; + + trace!(target: "serve_udp", + "Sending response {} to {peer}", + exchange.response.id); + + // Build up the response message. + let mut buffer = vec![0u8; 65536]; + let message = + exchange.response.build(&mut buffer).unwrap_or_else(|_| { + todo!("how to handle truncation errors?") + }); + + // Send the response back to the peer. + let _ = self.socket.send_to(message.as_bytes(), peer).await; + } + } + + // Generate internal state. + let state = Arc::new(State { + socket: UdpSocket::bind(addr).await?, + service, + }); + + // Main loop: wait on new requests. + loop { + // Allocate a buffer for the request. + let mut buffer = vec![0u8; 65536]; + + // Receive a DNS request. + let (size, peer) = state.socket.recv_from(&mut buffer).await?; + buffer.truncate(size); + + // Spawn a Tokio task to respond to the request. + tokio::task::spawn(state.clone().respond(buffer, peer)); + } +} + +//----------- SourceIpAddr --------------------------------------------------- + +/// The IP address a DNS request originated from. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct SourceIpAddr(pub IpAddr); + +impl From for SourceIpAddr { + fn from(value: IpAddr) -> Self { + Self(value) + } +} + +impl From for IpAddr { + fn from(value: SourceIpAddr) -> Self { + value.0 + } +} + +//----------- UdpMetadata ---------------------------------------------------- + +/// Information about a DNS request on a UDP socket. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct UdpMetadata { + /// The UDP port the request originated from. + /// + /// Use [`SourceIpAddr`] to determine the associated IP address. + pub port: u16, +}