Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zone updater should work with any record type. #486

Open
wants to merge 6 commits into
base: server-trait-usability-fixes
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 0 additions & 16 deletions src/net/server/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,22 +98,6 @@ pub type ServiceResult<Target> = Result<CallResult<Target>, ServiceError>;
/// }
/// }
///
/// //------------ An anonymous async block service example -------------------
/// struct MyAsyncBlockService;
///
/// impl Service<Vec<u8>, ()> for MyAsyncBlockService {
/// type Target = Vec<u8>;
/// type Stream = Once<Ready<ServiceResult<Self::Target>>>;
/// type Future = Pin<Box<dyn std::future::Future<Output = Self::Stream>>>;
///
/// fn call(
/// &self,
/// msg: Request<Vec<u8>, ()>,
/// ) -> Self::Future {
/// Box::pin(async move { mk_response_stream(&msg) })
/// }
/// }
///
/// //------------ A named Future service example -----------------------------
/// struct MyFut(Request<Vec<u8>, ()>);
///
Expand Down
97 changes: 62 additions & 35 deletions src/zonetree/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
//! content of zones without requiring knowledge of the low-level details of
//! how the [`WritableZone`] trait implemented by [`Zone`] works.
use core::future::Future;
use core::marker::PhantomData;
use core::pin::Pin;

use std::borrow::ToOwned;
use std::boxed::Box;

use bytes::Bytes;
use tracing::trace;

use crate::base::name::{FlattenInto, Label};
use crate::base::{ParsedName, Record, Rtype};
use crate::net::xfr::protocol::ParsedRecord;
use crate::base::scan::ScannerError;
use crate::base::{Name, Record, Rtype, ToName};
use crate::rdata::ZoneRecordData;
use crate::zonetree::{Rrset, SharedRrset};

Expand Down Expand Up @@ -51,7 +51,7 @@ use super::{InMemoryZoneDiff, WritableZone, WritableZoneNode, Zone};
///
/// ```
/// # use std::str::FromStr;
/// #
/// # use bytes::Bytes;
/// # use domain::base::iana::Class;
/// # use domain::base::MessageBuilder;
/// # use domain::base::Name;
Expand All @@ -76,8 +76,8 @@ use super::{InMemoryZoneDiff, WritableZone, WritableZoneNode, Zone};
/// #
/// # // Prepare some records to pass to ZoneUpdater
/// # let serial = Serial::now();
/// # let mname = ParsedName::from(Name::from_str("mname").unwrap());
/// # let rname = ParsedName::from(Name::from_str("rname").unwrap());
/// # let mname = ParsedName::from(Name::<Bytes>::from_str("mname").unwrap());
/// # let rname = ParsedName::from(Name::<Bytes>::from_str("rname").unwrap());
/// # let ttl = Ttl::from_secs(0);
/// # let new_soa_rec = Record::new(
/// # ParsedName::from(Name::from_str("example.com").unwrap()),
Expand Down Expand Up @@ -106,7 +106,7 @@ use super::{InMemoryZoneDiff, WritableZone, WritableZoneNode, Zone};
///
/// ```rust
/// # use std::str::FromStr;
/// #
/// # use bytes::Bytes;
/// # use domain::base::iana::Class;
/// # use domain::base::MessageBuilder;
/// # use domain::base::Name;
Expand All @@ -133,8 +133,8 @@ use super::{InMemoryZoneDiff, WritableZone, WritableZoneNode, Zone};
/// #
/// # // Prepare some records to pass to ZoneUpdater
/// # let serial = Serial::now();
/// # let mname = ParsedName::from(Name::from_str("mname").unwrap());
/// # let rname = ParsedName::from(Name::from_str("rname").unwrap());
/// # let mname = ParsedName::from(Name::<Bytes>::from_str("mname").unwrap());
/// # let rname = ParsedName::from(Name::<Bytes>::from_str("rname").unwrap());
/// # let ttl = Ttl::from_secs(0);
/// # let new_soa_rec = Record::new(
/// # ParsedName::from(Name::from_str("example.com").unwrap()),
Expand Down Expand Up @@ -214,7 +214,7 @@ use super::{InMemoryZoneDiff, WritableZone, WritableZoneNode, Zone};
/// ```
///
/// [`apply()`]: ZoneUpdater::apply()
pub struct ZoneUpdater {
pub struct ZoneUpdater<N> {
/// The zone to be updated.
zone: Zone,

Expand All @@ -226,9 +226,15 @@ pub struct ZoneUpdater {

/// The current state of the updater.
state: ZoneUpdaterState,

_phantom: PhantomData<N>,
}

impl ZoneUpdater {
impl<N> ZoneUpdater<N>
where
N: ToName + Clone,
ZoneRecordData<Bytes, N>: FlattenInto<ZoneRecordData<Bytes, Name<Bytes>>>,
{
/// Creates a new [`ZoneUpdater`] that will update the given [`Zone`]
/// content.
///
Expand All @@ -246,12 +252,17 @@ impl ZoneUpdater {
zone,
write,
state: Default::default(),
_phantom: PhantomData,
})
})
}
}

impl ZoneUpdater {
impl<N> ZoneUpdater<N>
where
N: ToName + Clone,
ZoneRecordData<Bytes, N>: FlattenInto<ZoneRecordData<Bytes, Name<Bytes>>>,
{
/// Apply the given [`ZoneUpdate`] to the [`Zone`] being updated.
///
/// Returns `Ok` on success, `Err` otherwise. On success, if changes were
Expand All @@ -266,7 +277,7 @@ impl ZoneUpdater {
/// progress and re-open the zone for editing again.
pub async fn apply(
&mut self,
update: ZoneUpdate<ParsedRecord>,
update: ZoneUpdate<Record<N, ZoneRecordData<Bytes, N>>>,
) -> Result<Option<InMemoryZoneDiff>, Error> {
trace!("Update: {update}");

Expand Down Expand Up @@ -344,7 +355,11 @@ impl ZoneUpdater {
}
}

impl ZoneUpdater {
impl<N> ZoneUpdater<N>
where
N: ToName + Clone,
ZoneRecordData<Bytes, N>: FlattenInto<ZoneRecordData<Bytes, Name<Bytes>>>,
{
/// Given a zone record, obtain a [`WritableZoneNode`] for the owner.
///
/// A [`Zone`] is a tree structure which can be modified by descending the
Expand All @@ -364,7 +379,7 @@ impl ZoneUpdater {
/// the record owner name.
async fn get_writable_child_node_for_owner(
&mut self,
rec: &ParsedRecord,
rec: &Record<N, ZoneRecordData<Bytes, N>>,
) -> Result<Option<Box<dyn WritableZoneNode>>, Error> {
let mut it = rel_name_rev_iter(self.zone.apex_name(), rec.owner())?;

Expand All @@ -386,17 +401,19 @@ impl ZoneUpdater {
/// Create or update the SOA RRset using the given SOA record.
async fn update_soa(
&mut self,
new_soa: Record<
ParsedName<Bytes>,
ZoneRecordData<Bytes, ParsedName<Bytes>>,
>,
new_soa: Record<N, ZoneRecordData<Bytes, N>>,
) -> Result<(), Error> {
if new_soa.rtype() != Rtype::SOA {
return Err(Error::NotSoaRecord);
}

let mut rrset = Rrset::new(Rtype::SOA, new_soa.ttl());
rrset.push_data(new_soa.data().to_owned().flatten_into());
let Ok(flattened) = new_soa.data().clone().try_flatten_into() else {
return Err(Error::IoError(std::io::Error::custom(
"Unable to flatten bytes",
)));
};
rrset.push_data(flattened);
self.write
.update_root_rrset(SharedRrset::new(rrset))
.await?;
Expand All @@ -407,10 +424,7 @@ impl ZoneUpdater {
/// Find and delete a resource record in the zone by exact match.
async fn delete_record_from_rrset(
&mut self,
rec: Record<
ParsedName<Bytes>,
ZoneRecordData<Bytes, ParsedName<Bytes>>,
>,
rec: Record<N, ZoneRecordData<Bytes, N>>,
) -> Result<(), Error> {
// Find or create the point to edit in the node tree.
let tree_node = self.get_writable_child_node_for_owner(&rec).await?;
Expand Down Expand Up @@ -443,11 +457,12 @@ impl ZoneUpdater {
/// Add a resource record to a new or existing RRset.
async fn add_record_to_rrset(
&mut self,
rec: Record<
ParsedName<Bytes>,
ZoneRecordData<Bytes, ParsedName<Bytes>>,
>,
) -> Result<(), Error> {
rec: Record<N, ZoneRecordData<Bytes, N>>,
) -> Result<(), Error>
where
ZoneRecordData<Bytes, N>:
FlattenInto<ZoneRecordData<Bytes, Name<Bytes>>>,
{
// Find or create the point to edit in the node tree.
let tree_node = self.get_writable_child_node_for_owner(&rec).await?;
let tree_node = tree_node.as_ref().unwrap_or(self.write.root());
Expand All @@ -456,7 +471,11 @@ impl ZoneUpdater {
// RRset in the tree plus the one to add.
let mut rrset = Rrset::new(rec.rtype(), rec.ttl());
let rtype = rec.rtype();
let data = rec.into_data().flatten_into();
let Ok(data) = rec.into_data().try_flatten_into() else {
return Err(Error::IoError(std::io::Error::custom(
"Unable to flatten bytes",
)));
};

rrset.push_data(data);

Expand Down Expand Up @@ -904,7 +923,7 @@ mod tests {

// IN NS NS.JAIN.AD.JP.
let ns_1 = Record::new(
ParsedName::from(Name::from_str("JAIN.AD.JP.").unwrap()),
ParsedName::from(Name::<Bytes>::from_str("JAIN.AD.JP.").unwrap()),
Class::IN,
Ttl::from_secs(0),
Ns::new(ParsedName::from(
Expand All @@ -919,7 +938,9 @@ mod tests {

// NS.JAIN.AD.JP. IN A 133.69.136.1
let a_1 = Record::new(
ParsedName::from(Name::from_str("NS.JAIN.AD.JP.").unwrap()),
ParsedName::from(
Name::<Bytes>::from_str("NS.JAIN.AD.JP.").unwrap(),
),
Class::IN,
Ttl::from_secs(0),
A::new(Ipv4Addr::new(133, 69, 136, 1)).into(),
Expand All @@ -931,7 +952,9 @@ mod tests {

// NEZU.JAIN.AD.JP. IN A 133.69.136.5
let nezu = Record::new(
ParsedName::from(Name::from_str("NEZU.JAIN.AD.JP.").unwrap()),
ParsedName::from(
Name::<Bytes>::from_str("NEZU.JAIN.AD.JP.").unwrap(),
),
Class::IN,
Ttl::from_secs(0),
A::new(Ipv4Addr::new(133, 69, 136, 5)).into(),
Expand All @@ -956,7 +979,9 @@ mod tests {
.await
.unwrap();
let a_2 = Record::new(
ParsedName::from(Name::from_str("JAIN-BB.JAIN.AD.JP.").unwrap()),
ParsedName::from(
Name::<Bytes>::from_str("JAIN-BB.JAIN.AD.JP.").unwrap(),
),
Class::IN,
Ttl::from_secs(0),
A::new(Ipv4Addr::new(133, 69, 136, 4)).into(),
Expand Down Expand Up @@ -991,7 +1016,9 @@ mod tests {
.await
.unwrap();
let a_4 = Record::new(
ParsedName::from(Name::from_str("JAIN-BB.JAIN.AD.JP.").unwrap()),
ParsedName::from(
Name::<Bytes>::from_str("JAIN-BB.JAIN.AD.JP.").unwrap(),
),
Class::IN,
Ttl::from_secs(0),
A::new(Ipv4Addr::new(133, 69, 136, 3)).into(),
Expand Down