Skip to content

Commit

Permalink
refactor: minimize deserialize_newtype_struct
Browse files Browse the repository at this point in the history
Also update example because we do not change fdt binary anymore.

Signed-off-by: Woshiluo Luo <[email protected]>
  • Loading branch information
woshiluo committed Nov 18, 2024
1 parent c4574b4 commit 6913268
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 171 deletions.
2 changes: 1 addition & 1 deletion examples/qemu-virt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ fn main() -> Result<(), Error> {

// 解析过程中,设备树的内容被修改了。
// 因此若要以其他方式再次访问设备树,先将这次解析的结果释放。
assert_ne!(slice, RAW_DEVICE_TREE);
// assert_ne!(slice, RAW_DEVICE_TREE);
}
// 释放后,内存会恢复原状。
assert_eq!(slice, RAW_DEVICE_TREE);
Expand Down
10 changes: 0 additions & 10 deletions src/de_mut/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,16 +246,6 @@ impl PropCursor {
))
}
}

pub fn operate_on(&self, dtb: RefDtb<'_>, f: impl FnOnce(&mut [u8])) {
if let [_, len_data, _, data @ ..] = &mut dtb.borrow_mut().structure[self.0..] {
f(unsafe {
core::slice::from_raw_parts_mut(data.as_mut_ptr() as _, len_data.as_usize())
});
} else {
todo!()
}
}
}

#[derive(Debug)]
Expand Down
32 changes: 1 addition & 31 deletions src/de_mut/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,37 +239,7 @@ impl<'de> de::Deserializer<'de> for &mut ValueDeserializer<'de> {
if name == super::VALUE_DESERIALIZER_NAME {
return visitor.visit_newtype_struct(self);
}
match self.cursor {
ValueCursor::Prop(_, cursor) => match name {
"StrSeq" => {
let inner = super::str_seq::Inner {
dtb: self.dtb,
cursor,
};
visitor.visit_borrowed_bytes(unsafe {
core::slice::from_raw_parts(
&inner as *const _ as *const u8,
core::mem::size_of_val(&inner),
)
})
}
"Reg" => {
let inner = super::reg::Inner {
dtb: self.dtb,
reg: self.reg,
cursor,
};
visitor.visit_borrowed_bytes(unsafe {
core::slice::from_raw_parts(
&inner as *const _ as *const u8,
core::mem::size_of_val(&inner),
)
})
}
_ => visitor.visit_newtype_struct(self),
},
ValueCursor::Body(_) => visitor.visit_newtype_struct(self),
}
unreachable!("unknown newtype struct");
}

fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
Expand Down
62 changes: 14 additions & 48 deletions src/de_mut/reg.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{PropCursor, RefDtb, StructureBlock, BLOCK_LEN};
use core::{fmt::Debug, marker::PhantomData, mem::MaybeUninit, ops::Range};
use serde::{de, Deserialize};
use super::{PropCursor, RefDtb, StructureBlock, ValueCursor, BLOCK_LEN};
use core::{fmt::Debug, ops::Range};
use serde::Deserialize;

/// 节点地址空间。
pub struct Reg<'de>(Inner<'de>);
Expand Down Expand Up @@ -40,58 +40,24 @@ impl<'de> Deserialize<'de> for Reg<'_> {
where
D: serde::Deserializer<'de>,
{
struct Visitor<'de, 'b> {
marker: PhantomData<Reg<'b>>,
lifetime: PhantomData<&'de ()>,
}
impl<'de, 'b> de::Visitor<'de> for Visitor<'de, 'b> {
type Value = Reg<'b>;

fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
write!(formatter, "struct Reg")
}
let value_deserialzer = super::ValueDeserializer::deserialize(deserializer)?;

fn visit_borrowed_bytes<E>(self, v: &'de [u8]) -> Result<Self::Value, E>
where
E: de::Error,
{
// 结构体转为内存切片,然后拷贝过来
if v.len() == core::mem::size_of::<Self::Value>() {
Ok(Self::Value::from_raw_parts(v.as_ptr()))
} else {
Err(E::invalid_length(
v.len(),
&"`Reg` is copied with wrong size.",
))
let inner = Inner {
dtb: value_deserialzer.dtb,
reg: value_deserialzer.reg,
cursor: match value_deserialzer.cursor {
ValueCursor::Prop(_, cursor) => cursor,
_ => {
unreachable!("Reg Deserialize should only be called by prop cursor")
}
}
}

serde::Deserializer::deserialize_newtype_struct(
deserializer,
"Reg",
Visitor {
marker: PhantomData,
lifetime: PhantomData,
},
)
};

Ok(Self(inner))
}
}

impl Reg<'_> {
fn from_raw_parts(ptr: *const u8) -> Self {
// 直接从指针拷贝
unsafe {
let mut res = MaybeUninit::<Self>::uninit();
core::ptr::copy_nonoverlapping(
ptr,
res.as_mut_ptr() as *mut _,
core::mem::size_of::<Self>(),
);
res.assume_init()
}
}

pub fn iter(&self) -> RegIter {
RegIter {
data: self.0.cursor.data_on(self.0.dtb),
Expand Down
103 changes: 22 additions & 81 deletions src/de_mut/str_seq.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{PropCursor, RefDtb};
use core::{fmt::Debug, marker::PhantomData, mem::MaybeUninit};
use serde::{de, Deserialize};
use super::{PropCursor, RefDtb, ValueCursor};
use core::fmt::Debug;
use serde::Deserialize;

/// 一组 '\0' 分隔字符串的映射。
///
Expand Down Expand Up @@ -31,70 +31,23 @@ impl<'de> Deserialize<'de> for StrSeq<'_> {
where
D: serde::Deserializer<'de>,
{
struct Visitor<'de, 'b> {
marker: PhantomData<StrSeq<'b>>,
lifetime: PhantomData<&'de ()>,
}
impl<'de, 'b> de::Visitor<'de> for Visitor<'de, 'b> {
type Value = StrSeq<'b>;

fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
write!(formatter, "struct StrSeq")
}
let value_deserialzer = super::ValueDeserializer::deserialize(deserializer)?;

fn visit_borrowed_bytes<E>(self, v: &'de [u8]) -> Result<Self::Value, E>
where
E: de::Error,
{
// 结构体转为内存切片,然后拷贝过来
if v.len() == core::mem::size_of::<Self::Value>() {
Ok(Self::Value::from_raw_parts(v.as_ptr()))
} else {
Err(E::invalid_length(
v.len(),
&"`StrSeq` is copied with wrong size.",
))
let inner = Inner {
dtb: value_deserialzer.dtb,
cursor: match value_deserialzer.cursor {
ValueCursor::Prop(_, cursor) => cursor,
_ => {
unreachable!("Reg Deserialize should only be called by prop cursor")
}
}
}

serde::Deserializer::deserialize_newtype_struct(
deserializer,
"StrSeq",
Visitor {
marker: PhantomData,
lifetime: PhantomData,
},
)
};

Ok(Self(inner))
}
}

impl StrSeq<'_> {
fn from_raw_parts(ptr: *const u8) -> Self {
// 直接从指针拷贝
let res = unsafe {
let mut res = MaybeUninit::<Self>::uninit();
core::ptr::copy_nonoverlapping(
ptr,
res.as_mut_ptr() as *mut _,
core::mem::size_of::<Self>(),
);
res.assume_init()
};
// 初始化
res.0.cursor.operate_on(res.0.dtb, |data| {
let mut i = data.len() - 1;
for j in (0..data.len() - 1).rev() {
if data[j] == b'\0' {
data[i] = (i - j - 1) as _;
i = j;
}
}
data[i] = i as u8;
});
res
}

/// 构造一个可访问每个字符串的迭代器。
pub fn iter(&self) -> StrSeqIter {
StrSeqIter {
Expand Down Expand Up @@ -125,27 +78,15 @@ impl<'de> Iterator for StrSeqIter<'de> {
if self.data.is_empty() {
None
} else {
let len = *self.data.last().unwrap() as usize;
let (a, b) = self.data.split_at(self.data.len() - len - 1);
self.data = a;
Some(unsafe { core::str::from_utf8_unchecked(&b[..len]) })
let pos = self
.data
.iter()
.position(|&x| x == b'\0')
.unwrap_or(self.data.len());
let (a, b) = self.data.split_at(pos + 1);
self.data = b;
// Remove \0 at end
Some(unsafe { core::str::from_utf8_unchecked(&a[..a.len() - 1]) })
}
}
}

impl Drop for StrSeq<'_> {
fn drop(&mut self) {
self.0.cursor.operate_on(self.0.dtb, |data| {
let mut idx = data.len() - 1;
loop {
let len = data[idx] as usize;
data[idx] = 0;
if idx > len {
idx -= len + 1;
} else {
break;
}
}
})
}
}

0 comments on commit 6913268

Please sign in to comment.