Skip to content

Commit

Permalink
#Rust Add buffer verification
Browse files Browse the repository at this point in the history
  • Loading branch information
candysonya committed Jan 14, 2025
1 parent 6f730de commit d1d04f4
Show file tree
Hide file tree
Showing 5 changed files with 481 additions and 30 deletions.
2 changes: 1 addition & 1 deletion rust/flatbuffers/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ pub use crate::push::{Push, PushAlignment};
pub use crate::table::{buffer_has_identifier, Table};
pub use crate::vector::{follow_cast_ref, Vector, VectorIter};
pub use crate::verifier::{
ErrorTraceDetail, InvalidFlatbuffer, SimpleToVerifyInSlice, Verifiable, Verifier,
ErrorTraceDetail, InvalidFlatbuffer, SimpleToVerifyInSlice, TableVerifier, Verifiable, Verifier,
VerifierOptions,
};
pub use crate::vtable::field_index_to_field_offset;
Expand Down
84 changes: 56 additions & 28 deletions rust/flatbuffers/src/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ use alloc::vec::Vec;
use core::ops::Range;
use core::option::Option;

#[cfg(not(feature = "std"))]
use alloc::borrow::Cow;
#[cfg(feature = "std")]
use std::borrow::Cow;

#[cfg(all(nightly, not(feature = "std")))]
use core::error::Error;
#[cfg(feature = "std")]
Expand All @@ -20,11 +25,11 @@ pub enum ErrorTraceDetail {
position: usize,
},
TableField {
field_name: &'static str,
field_name: Cow<'static, str>,
position: usize,
},
UnionVariant {
variant: &'static str,
variant: Cow<'static, str>,
position: usize,
},
}
Expand All @@ -44,12 +49,12 @@ impl core::convert::AsRef<[ErrorTraceDetail]> for ErrorTrace {
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum InvalidFlatbuffer {
MissingRequiredField {
required: &'static str,
required: Cow<'static, str>,
error_trace: ErrorTrace,
},
InconsistentUnion {
field: &'static str,
field_type: &'static str,
field: Cow<'static, str>,
field_type: Cow<'static, str>,
error_trace: ErrorTrace,
},
Utf8Error {
Expand All @@ -63,7 +68,7 @@ pub enum InvalidFlatbuffer {
},
Unaligned {
position: usize,
unaligned_type: &'static str,
unaligned_type: Cow<'static, str>,
error_trace: ErrorTrace,
},
RangeOutOfBounds {
Expand Down Expand Up @@ -217,16 +222,19 @@ impl InvalidFlatbuffer {
error_trace: Default::default(),
})
}
fn new_inconsistent_union<T>(field: &'static str, field_type: &'static str) -> Result<T> {
pub fn new_inconsistent_union<T>(
field: impl Into<Cow<'static, str>>,
field_type: impl Into<Cow<'static, str>>,
) -> Result<T> {
Err(Self::InconsistentUnion {
field,
field_type,
field: field.into(),
field_type: field_type.into(),
error_trace: Default::default(),
})
}
fn new_missing_required<T>(required: &'static str) -> Result<T> {
pub fn new_missing_required<T>(required: impl Into<Cow<'static, str>>) -> Result<T> {
Err(Self::MissingRequiredField {
required,
required: required.into(),
error_trace: Default::default(),
})
}
Expand All @@ -251,7 +259,7 @@ fn append_trace<T>(mut res: Result<T>, d: ErrorTraceDetail) -> Result<T> {
}

/// Adds a TableField trace detail if `res` is a data error.
fn trace_field<T>(res: Result<T>, field_name: &'static str, position: usize) -> Result<T> {
fn trace_field<T>(res: Result<T>, field_name: Cow<'static, str>, position: usize) -> Result<T> {
append_trace(
res,
ErrorTraceDetail::TableField {
Expand Down Expand Up @@ -333,19 +341,19 @@ impl<'opts, 'buf> Verifier<'opts, 'buf> {
///
/// Note this does not impact soundness as this crate does not assume alignment of structs
#[inline]
fn is_aligned<T>(&self, pos: usize) -> Result<()> {
pub fn is_aligned<T>(&self, pos: usize) -> Result<()> {
if pos % core::mem::align_of::<T>() == 0 {
Ok(())
} else {
Err(InvalidFlatbuffer::Unaligned {
unaligned_type: core::any::type_name::<T>(),
unaligned_type: Cow::Borrowed(core::any::type_name::<T>()),
position: pos,
error_trace: Default::default(),
})
}
}
#[inline]
fn range_in_buffer(&mut self, pos: usize, size: usize) -> Result<()> {
pub fn range_in_buffer(&mut self, pos: usize, size: usize) -> Result<()> {
let end = pos.saturating_add(size);
if end > self.buffer.len() {
return InvalidFlatbuffer::new_range_oob(pos, end);
Expand All @@ -363,12 +371,17 @@ impl<'opts, 'buf> Verifier<'opts, 'buf> {
self.range_in_buffer(pos, core::mem::size_of::<T>())
}
#[inline]
pub fn get_u8(&mut self, pos: usize) -> Result<u8> {
self.in_buffer::<u8>(pos)?;
Ok(u8::from_le_bytes([self.buffer[pos]]))
}
#[inline]
fn get_u16(&mut self, pos: usize) -> Result<u16> {
self.in_buffer::<u16>(pos)?;
Ok(u16::from_le_bytes([self.buffer[pos], self.buffer[pos + 1]]))
}
#[inline]
fn get_uoffset(&mut self, pos: usize) -> Result<UOffsetT> {
pub fn get_uoffset(&mut self, pos: usize) -> Result<UOffsetT> {
self.in_buffer::<u32>(pos)?;
Ok(u32::from_le_bytes([
self.buffer[pos],
Expand Down Expand Up @@ -434,11 +447,17 @@ impl<'opts, 'buf> Verifier<'opts, 'buf> {
/// tracing the error.
pub fn verify_union_variant<T: Verifiable>(
&mut self,
variant: &'static str,
variant: impl Into<Cow<'static, str>>,
position: usize,
) -> Result<()> {
let res = T::run_verifier(self, position);
append_trace(res, ErrorTraceDetail::UnionVariant { variant, position })
append_trace(
res,
ErrorTraceDetail::UnionVariant {
variant: variant.into(),
position,
},
)
}
}

Expand All @@ -456,7 +475,7 @@ pub struct TableVerifier<'ver, 'opts, 'buf> {
}

impl<'ver, 'opts, 'buf> TableVerifier<'ver, 'opts, 'buf> {
fn deref(&mut self, field: VOffsetT) -> Result<Option<usize>> {
pub fn deref(&mut self, field: VOffsetT) -> Result<Option<usize>> {
let field = field as usize;
if field < self.vtable_len {
let field_offset = self.verifier.get_u16(self.vtable.saturating_add(field))?;
Expand All @@ -469,23 +488,28 @@ impl<'ver, 'opts, 'buf> TableVerifier<'ver, 'opts, 'buf> {
Ok(None)
}

#[inline]
pub fn verifier(&mut self) -> &mut Verifier<'opts, 'buf> {
self.verifier
}

#[inline]
pub fn visit_field<T: Verifiable>(
mut self,
field_name: &'static str,
field_name: impl Into<Cow<'static, str>>,
field: VOffsetT,
required: bool,
) -> Result<Self> {
if let Some(field_pos) = self.deref(field)? {
trace_field(
T::run_verifier(self.verifier, field_pos),
field_name,
field_name.into(),
field_pos,
)?;
return Ok(self);
}
if required {
InvalidFlatbuffer::new_missing_required(field_name)
InvalidFlatbuffer::new_missing_required(field_name.into())
} else {
Ok(self)
}
Expand All @@ -496,9 +520,9 @@ impl<'ver, 'opts, 'buf> TableVerifier<'ver, 'opts, 'buf> {
/// reads the key, then invokes the callback to perform data-dependent verification.
pub fn visit_union<Key, UnionVerifier>(
mut self,
key_field_name: &'static str,
key_field_name: impl Into<Cow<'static, str>>,
key_field_voff: VOffsetT,
val_field_name: &'static str,
val_field_name: impl Into<Cow<'static, str>>,
val_field_voff: VOffsetT,
required: bool,
verify_union: UnionVerifier,
Expand All @@ -515,24 +539,28 @@ impl<'ver, 'opts, 'buf> TableVerifier<'ver, 'opts, 'buf> {
match (key_pos, val_pos) {
(None, None) => {
if required {
InvalidFlatbuffer::new_missing_required(val_field_name)
InvalidFlatbuffer::new_missing_required(val_field_name.into())
} else {
Ok(self)
}
}
(Some(k), Some(v)) => {
trace_field(Key::run_verifier(self.verifier, k), key_field_name, k)?;
trace_field(
Key::run_verifier(self.verifier, k),
key_field_name.into(),
k,
)?;
// Safety:
// Run verifier on `k` above
let discriminant = unsafe { Key::follow(self.verifier.buffer, k) };
trace_field(
verify_union(discriminant, self.verifier, v),
val_field_name,
val_field_name.into(),
v,
)?;
Ok(self)
}
_ => InvalidFlatbuffer::new_inconsistent_union(key_field_name, val_field_name),
_ => InvalidFlatbuffer::new_inconsistent_union(key_field_name.into(), val_field_name.into()),
}
}
pub fn finish(self) -> &'ver mut Verifier<'opts, 'buf> {
Expand Down
7 changes: 7 additions & 0 deletions rust/reflection/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

mod reflection_generated;
pub mod reflection_verifier;
mod r#struct;
pub use crate::r#struct::Struct;
pub use crate::reflection_generated::reflection;
Expand Down Expand Up @@ -47,6 +48,12 @@ pub enum FlatbufferError {
TryFromIntError(#[from] std::num::TryFromIntError),
#[error("Couldn't set string because cache vector is polluted")]
SetStringPolluted(),
#[error("Invalid schema: Polluted buffer or the schema doesn't match the buffer.")]
InvalidSchema(),
#[error("Type not supported: {0}")]
TypeNotSupported(String),
#[error("No type or invalid type found in union enum")]
InvalidUnionEnum(),
}

pub type FlatbufferResult<T, E = FlatbufferError> = core::result::Result<T, E>;
Expand Down
Loading

0 comments on commit d1d04f4

Please sign in to comment.