Skip to content

Commit

Permalink
Add a 'SafeBuffer' struct which provides safe methods for reflection
Browse files Browse the repository at this point in the history
It verifies buffer against schema during construction and provides all the unsafe getters in lib.rs in a safe way
  • Loading branch information
candysonya committed Jan 15, 2025
1 parent d1d04f4 commit a60d076
Show file tree
Hide file tree
Showing 5 changed files with 1,108 additions and 152 deletions.
7 changes: 6 additions & 1 deletion rust/flatbuffers/src/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,12 @@ impl<'a, T: Follow<'a> + 'a> Vector<'a, T> {
match f(&value, &key) {
Ordering::Equal => return Some(value),
Ordering::Less => left = mid + 1,
Ordering::Greater => right = mid - 1,
Ordering::Greater => {
if mid == 0 {
return None;
}
right = mid - 1;
},
}
}

Expand Down
110 changes: 58 additions & 52 deletions rust/reflection/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
*/

mod reflection_generated;
pub mod reflection_verifier;
mod reflection_verifier;
mod safe_buffer;
mod r#struct;
pub use crate::r#struct::Struct;
pub use crate::reflection_generated::reflection;
pub use crate::safe_buffer::SafeBuffer;

use flatbuffers::{
emplace_scalar, read_scalar, EndianScalar, Follow, ForwardsUOffset, InvalidFlatbuffer,
Expand All @@ -34,26 +36,30 @@ use num::traits::FromPrimitive;
use stdint::uintmax_t;
use thiserror::Error;

#[derive(Error, Debug)]
#[derive(Error, Debug, PartialEq)]
pub enum FlatbufferError {
#[error(transparent)]
VerificationError(#[from] flatbuffers::InvalidFlatbuffer),
#[error("Failed to convert between data type {0} and field type {1}")]
FieldTypeMismatch(String, String),
#[error("Set field value not supported for non-populated or non-scalar fields")]
SetValueNotSupported(),
SetValueNotSupported,
#[error(transparent)]
ParseFloatError(#[from] std::num::ParseFloatError),
#[error(transparent)]
TryFromIntError(#[from] std::num::TryFromIntError),
#[error("Couldn't set string because cache vector is polluted")]
SetStringPolluted(),
SetStringPolluted,
#[error("Invalid schema: Polluted buffer or the schema doesn't match the buffer.")]
InvalidSchema(),
InvalidSchema,
#[error("Type not supported: {0}")]
TypeNotSupported(String),
#[error("No type or invalid type found in union enum")]
InvalidUnionEnum(),
InvalidUnionEnum,
#[error("Table or Struct doesn't belong to the buffer")]
InvalidTableOrStruct,
#[error("Field not found in the table schema")]
FieldNotFound,
}

pub type FlatbufferResult<T, E = FlatbufferError> = core::result::Result<T, E>;
Expand All @@ -73,8 +79,8 @@ pub unsafe fn get_any_root(data: &[u8]) -> Table {
/// # Safety
///
/// The value of the corresponding slot must have type T
pub unsafe fn get_field_integer<'a, T: Follow<'a, Inner = T> + PrimInt + FromPrimitive + 'a>(
table: &'a Table,
pub unsafe fn get_field_integer<T: for<'a> Follow<'a, Inner = T> + PrimInt + FromPrimitive>(
table: &Table,
field: &Field,
) -> FlatbufferResult<Option<T>> {
if size_of::<T>() != get_type_size(field.type_().base_type()) {
Expand All @@ -98,8 +104,8 @@ pub unsafe fn get_field_integer<'a, T: Follow<'a, Inner = T> + PrimInt + FromPri
/// # Safety
///
/// The value of the corresponding slot must have type T
pub unsafe fn get_field_float<'a, T: Follow<'a, Inner = T> + Float + 'a>(
table: &'a Table,
pub unsafe fn get_field_float<T: for<'a> Follow<'a, Inner = T> + Float>(
table: &Table,
field: &Field,
) -> FlatbufferResult<Option<T>> {
if size_of::<T>() != get_type_size(field.type_().base_type()) {
Expand All @@ -124,7 +130,7 @@ pub unsafe fn get_field_float<'a, T: Follow<'a, Inner = T> + Float + 'a>(
///
/// The value of the corresponding slot must have type String
pub unsafe fn get_field_string<'a>(
table: &'a Table,
table: &Table<'a>,
field: &Field,
) -> FlatbufferResult<Option<&'a str>> {
if field.type_().base_type() != BaseType::String {
Expand All @@ -148,7 +154,7 @@ pub unsafe fn get_field_string<'a>(
///
/// The value of the corresponding slot must have type Struct
pub unsafe fn get_field_struct<'a>(
table: &'a Table,
table: &Table<'a>,
field: &Field,
) -> FlatbufferResult<Option<Struct<'a>>> {
// TODO inherited from C++: This does NOT check if the field is a table or struct, but we'd need
Expand All @@ -174,7 +180,7 @@ pub unsafe fn get_field_struct<'a>(
///
/// The value of the corresponding slot must have type Vector
pub unsafe fn get_field_vector<'a, T: Follow<'a, Inner = T>>(
table: &'a Table,
table: &Table<'a>,
field: &Field,
) -> FlatbufferResult<Option<Vector<'a, T>>> {
if field.type_().base_type() != BaseType::Vector
Expand All @@ -200,7 +206,7 @@ pub unsafe fn get_field_vector<'a, T: Follow<'a, Inner = T>>(
///
/// The value of the corresponding slot must have type Table
pub unsafe fn get_field_table<'a>(
table: &'a Table,
table: &Table<'a>,
field: &Field,
) -> FlatbufferResult<Option<Table<'a>>> {
if field.type_().base_type() != BaseType::Obj {
Expand All @@ -218,32 +224,6 @@ pub unsafe fn get_field_table<'a>(
Ok(table.get::<ForwardsUOffset<Table<'a>>>(field.offset(), None))
}

/// Gets a [Struct] struct field given its exact type. Returns error if the type doesn't match.
///
/// # Safety
///
/// The value of the corresponding slot must have type Struct.
pub unsafe fn get_field_struct_in_struct<'a>(
st: &'a Struct,
field: &Field,
) -> FlatbufferResult<Struct<'a>> {
// TODO inherited from C++: This does NOT check if the field is a table or struct, but we'd need
// access to the schema to check the is_struct flag.
if field.type_().base_type() != BaseType::Obj {
return Err(FlatbufferError::FieldTypeMismatch(
String::from("Obj"),
field
.type_()
.base_type()
.variant_name()
.unwrap_or_default()
.to_string(),
));
}

Ok(st.get::<Struct>(field.offset() as usize))
}

/// Returns the value of any table field as a 64-bit int, regardless of what type it is. Returns default integer if the field is not set or error if the value cannot be parsed as integer.
/// [num_traits](https://docs.rs/num-traits/latest/num_traits/cast/trait.NumCast.html) is used for number casting.
///
Expand Down Expand Up @@ -290,6 +270,32 @@ pub unsafe fn get_any_field_string(table: &Table, field: &Field, schema: &Schema
}
}

/// Gets a [Struct] struct field given its exact type. Returns error if the type doesn't match.
///
/// # Safety
///
/// The value of the corresponding slot must have type Struct.
pub unsafe fn get_field_struct_in_struct<'a>(
st: &Struct<'a>,
field: &Field,
) -> FlatbufferResult<Struct<'a>> {
// TODO inherited from C++: This does NOT check if the field is a table or struct, but we'd need
// access to the schema to check the is_struct flag.
if field.type_().base_type() != BaseType::Obj {
return Err(FlatbufferError::FieldTypeMismatch(
String::from("Obj"),
field
.type_()
.base_type()
.variant_name()
.unwrap_or_default()
.to_string(),
));
}

Ok(st.get::<Struct>(field.offset() as usize))
}

/// Returns the value of any struct field as a 64-bit int, regardless of what type it is. Returns error if the value cannot be parsed as integer.
///
/// # Safety
Expand Down Expand Up @@ -348,11 +354,11 @@ pub unsafe fn set_any_field_integer(
let table = Table::follow(buf, table_loc);

let Some(field_loc) = get_field_loc(&table, field) else {
return Err(FlatbufferError::SetValueNotSupported());
return Err(FlatbufferError::SetValueNotSupported);
};

if !is_scalar(field_type) {
return Err(FlatbufferError::SetValueNotSupported());
return Err(FlatbufferError::SetValueNotSupported);
}

set_any_value_integer(field_type, buf, field_loc, v)
Expand All @@ -373,11 +379,11 @@ pub unsafe fn set_any_field_float(
let table = Table::follow(buf, table_loc);

let Some(field_loc) = get_field_loc(&table, field) else {
return Err(FlatbufferError::SetValueNotSupported());
return Err(FlatbufferError::SetValueNotSupported);
};

if !is_scalar(field_type) {
return Err(FlatbufferError::SetValueNotSupported());
return Err(FlatbufferError::SetValueNotSupported);
}

set_any_value_float(field_type, buf, field_loc, v)
Expand All @@ -398,11 +404,11 @@ pub unsafe fn set_any_field_string(
let table = Table::follow(buf, table_loc);

let Some(field_loc) = get_field_loc(&table, field) else {
return Err(FlatbufferError::SetValueNotSupported());
return Err(FlatbufferError::SetValueNotSupported);
};

if !is_scalar(field_type) {
return Err(FlatbufferError::SetValueNotSupported());
return Err(FlatbufferError::SetValueNotSupported);
}

set_any_value_float(field_type, buf, field_loc, v.parse::<f64>()?)
Expand All @@ -423,7 +429,7 @@ pub unsafe fn set_field<T: EndianScalar>(
let table = Table::follow(buf, table_loc);

if !is_scalar(field_type) {
return Err(FlatbufferError::SetValueNotSupported());
return Err(FlatbufferError::SetValueNotSupported);
}

if core::mem::size_of::<T>() != get_type_size(field_type) {
Expand All @@ -434,7 +440,7 @@ pub unsafe fn set_field<T: EndianScalar>(
}

let Some(field_loc) = get_field_loc(&table, field) else {
return Err(FlatbufferError::SetValueNotSupported());
return Err(FlatbufferError::SetValueNotSupported);
};

if buf.len() < field_loc.saturating_add(get_type_size(field_type)) {
Expand Down Expand Up @@ -480,7 +486,7 @@ pub unsafe fn set_string(
let table = Table::follow(buf, table_loc);

let Some(field_loc) = get_field_loc(&table, field) else {
return Err(FlatbufferError::SetValueNotSupported());
return Err(FlatbufferError::SetValueNotSupported);
};

if buf.len() < field_loc + get_type_size(field_type) {
Expand Down Expand Up @@ -814,7 +820,7 @@ fn set_any_value_integer(
))
}
}
_ => Err(FlatbufferError::SetValueNotSupported()),
_ => Err(FlatbufferError::SetValueNotSupported),
}
}

Expand Down Expand Up @@ -927,7 +933,7 @@ fn set_any_value_float(
return Ok(emplace_scalar::<f64>(buf, v));
}
}
_ => return Err(FlatbufferError::SetValueNotSupported()),
_ => return Err(FlatbufferError::SetValueNotSupported),
}
return Err(FlatbufferError::FieldTypeMismatch(
String::from("f64"),
Expand All @@ -954,7 +960,7 @@ unsafe fn update_offset(
offset: isize,
) -> FlatbufferResult<()> {
if updated.len() != buf.len() {
return Err(FlatbufferError::SetStringPolluted());
return Err(FlatbufferError::SetStringPolluted);
}

if updated[table_loc] {
Expand Down
Loading

0 comments on commit a60d076

Please sign in to comment.