diff --git a/rust/flatbuffers/src/lib.rs b/rust/flatbuffers/src/lib.rs index 1ded368171c..9ed308024fb 100644 --- a/rust/flatbuffers/src/lib.rs +++ b/rust/flatbuffers/src/lib.rs @@ -56,7 +56,7 @@ pub use crate::push::{Push, PushAlignment}; pub use crate::table::{buffer_has_identifier, Table}; pub use crate::vector::{follow_cast_ref, Vector, VectorIter}; pub use crate::verifier::{ - ErrorTraceDetail, InvalidFlatbuffer, SimpleToVerifyInSlice, Verifiable, Verifier, + ErrorTraceDetail, InvalidFlatbuffer, SimpleToVerifyInSlice, TableVerifier, Verifiable, Verifier, VerifierOptions, }; pub use crate::vtable::field_index_to_field_offset; diff --git a/rust/flatbuffers/src/verifier.rs b/rust/flatbuffers/src/verifier.rs index 047d4f61360..c4c55f587d0 100644 --- a/rust/flatbuffers/src/verifier.rs +++ b/rust/flatbuffers/src/verifier.rs @@ -5,6 +5,11 @@ use alloc::vec::Vec; use core::ops::Range; use core::option::Option; +#[cfg(not(feature = "std"))] +use alloc::borrow::Cow; +#[cfg(feature = "std")] +use std::borrow::Cow; + #[cfg(all(nightly, not(feature = "std")))] use core::error::Error; #[cfg(feature = "std")] @@ -20,11 +25,11 @@ pub enum ErrorTraceDetail { position: usize, }, TableField { - field_name: &'static str, + field_name: Cow<'static, str>, position: usize, }, UnionVariant { - variant: &'static str, + variant: Cow<'static, str>, position: usize, }, } @@ -44,12 +49,12 @@ impl core::convert::AsRef<[ErrorTraceDetail]> for ErrorTrace { #[derive(Clone, Debug, PartialEq, Eq)] pub enum InvalidFlatbuffer { MissingRequiredField { - required: &'static str, + required: Cow<'static, str>, error_trace: ErrorTrace, }, InconsistentUnion { - field: &'static str, - field_type: &'static str, + field: Cow<'static, str>, + field_type: Cow<'static, str>, error_trace: ErrorTrace, }, Utf8Error { @@ -63,7 +68,7 @@ pub enum InvalidFlatbuffer { }, Unaligned { position: usize, - unaligned_type: &'static str, + unaligned_type: Cow<'static, str>, error_trace: ErrorTrace, }, RangeOutOfBounds { @@ -217,16 +222,19 @@ impl InvalidFlatbuffer { error_trace: Default::default(), }) } - fn new_inconsistent_union(field: &'static str, field_type: &'static str) -> Result { + pub fn new_inconsistent_union( + field: impl Into>, + field_type: impl Into>, + ) -> Result { Err(Self::InconsistentUnion { - field, - field_type, + field: field.into(), + field_type: field_type.into(), error_trace: Default::default(), }) } - fn new_missing_required(required: &'static str) -> Result { + pub fn new_missing_required(required: impl Into>) -> Result { Err(Self::MissingRequiredField { - required, + required: required.into(), error_trace: Default::default(), }) } @@ -251,7 +259,7 @@ fn append_trace(mut res: Result, d: ErrorTraceDetail) -> Result { } /// Adds a TableField trace detail if `res` is a data error. -fn trace_field(res: Result, field_name: &'static str, position: usize) -> Result { +fn trace_field(res: Result, field_name: Cow<'static, str>, position: usize) -> Result { append_trace( res, ErrorTraceDetail::TableField { @@ -333,19 +341,19 @@ impl<'opts, 'buf> Verifier<'opts, 'buf> { /// /// Note this does not impact soundness as this crate does not assume alignment of structs #[inline] - fn is_aligned(&self, pos: usize) -> Result<()> { + pub fn is_aligned(&self, pos: usize) -> Result<()> { if pos % core::mem::align_of::() == 0 { Ok(()) } else { Err(InvalidFlatbuffer::Unaligned { - unaligned_type: core::any::type_name::(), + unaligned_type: Cow::Borrowed(core::any::type_name::()), position: pos, error_trace: Default::default(), }) } } #[inline] - fn range_in_buffer(&mut self, pos: usize, size: usize) -> Result<()> { + pub fn range_in_buffer(&mut self, pos: usize, size: usize) -> Result<()> { let end = pos.saturating_add(size); if end > self.buffer.len() { return InvalidFlatbuffer::new_range_oob(pos, end); @@ -363,12 +371,17 @@ impl<'opts, 'buf> Verifier<'opts, 'buf> { self.range_in_buffer(pos, core::mem::size_of::()) } #[inline] + pub fn get_u8(&mut self, pos: usize) -> Result { + self.in_buffer::(pos)?; + Ok(u8::from_le_bytes([self.buffer[pos]])) + } + #[inline] fn get_u16(&mut self, pos: usize) -> Result { self.in_buffer::(pos)?; Ok(u16::from_le_bytes([self.buffer[pos], self.buffer[pos + 1]])) } #[inline] - fn get_uoffset(&mut self, pos: usize) -> Result { + pub fn get_uoffset(&mut self, pos: usize) -> Result { self.in_buffer::(pos)?; Ok(u32::from_le_bytes([ self.buffer[pos], @@ -434,11 +447,17 @@ impl<'opts, 'buf> Verifier<'opts, 'buf> { /// tracing the error. pub fn verify_union_variant( &mut self, - variant: &'static str, + variant: impl Into>, position: usize, ) -> Result<()> { let res = T::run_verifier(self, position); - append_trace(res, ErrorTraceDetail::UnionVariant { variant, position }) + append_trace( + res, + ErrorTraceDetail::UnionVariant { + variant: variant.into(), + position, + }, + ) } } @@ -456,7 +475,7 @@ pub struct TableVerifier<'ver, 'opts, 'buf> { } impl<'ver, 'opts, 'buf> TableVerifier<'ver, 'opts, 'buf> { - fn deref(&mut self, field: VOffsetT) -> Result> { + pub fn deref(&mut self, field: VOffsetT) -> Result> { let field = field as usize; if field < self.vtable_len { let field_offset = self.verifier.get_u16(self.vtable.saturating_add(field))?; @@ -469,23 +488,28 @@ impl<'ver, 'opts, 'buf> TableVerifier<'ver, 'opts, 'buf> { Ok(None) } + #[inline] + pub fn verifier(&mut self) -> &mut Verifier<'opts, 'buf> { + self.verifier + } + #[inline] pub fn visit_field( mut self, - field_name: &'static str, + field_name: impl Into>, field: VOffsetT, required: bool, ) -> Result { if let Some(field_pos) = self.deref(field)? { trace_field( T::run_verifier(self.verifier, field_pos), - field_name, + field_name.into(), field_pos, )?; return Ok(self); } if required { - InvalidFlatbuffer::new_missing_required(field_name) + InvalidFlatbuffer::new_missing_required(field_name.into()) } else { Ok(self) } @@ -496,9 +520,9 @@ impl<'ver, 'opts, 'buf> TableVerifier<'ver, 'opts, 'buf> { /// reads the key, then invokes the callback to perform data-dependent verification. pub fn visit_union( mut self, - key_field_name: &'static str, + key_field_name: impl Into>, key_field_voff: VOffsetT, - val_field_name: &'static str, + val_field_name: impl Into>, val_field_voff: VOffsetT, required: bool, verify_union: UnionVerifier, @@ -515,24 +539,28 @@ impl<'ver, 'opts, 'buf> TableVerifier<'ver, 'opts, 'buf> { match (key_pos, val_pos) { (None, None) => { if required { - InvalidFlatbuffer::new_missing_required(val_field_name) + InvalidFlatbuffer::new_missing_required(val_field_name.into()) } else { Ok(self) } } (Some(k), Some(v)) => { - trace_field(Key::run_verifier(self.verifier, k), key_field_name, k)?; + trace_field( + Key::run_verifier(self.verifier, k), + key_field_name.into(), + k, + )?; // Safety: // Run verifier on `k` above let discriminant = unsafe { Key::follow(self.verifier.buffer, k) }; trace_field( verify_union(discriminant, self.verifier, v), - val_field_name, + val_field_name.into(), v, )?; Ok(self) } - _ => InvalidFlatbuffer::new_inconsistent_union(key_field_name, val_field_name), + _ => InvalidFlatbuffer::new_inconsistent_union(key_field_name.into(), val_field_name.into()), } } pub fn finish(self) -> &'ver mut Verifier<'opts, 'buf> { diff --git a/rust/reflection/src/lib.rs b/rust/reflection/src/lib.rs index ed0f011f945..c747ee5932a 100644 --- a/rust/reflection/src/lib.rs +++ b/rust/reflection/src/lib.rs @@ -15,6 +15,7 @@ */ mod reflection_generated; +pub mod reflection_verifier; mod r#struct; pub use crate::r#struct::Struct; pub use crate::reflection_generated::reflection; @@ -47,6 +48,12 @@ pub enum FlatbufferError { TryFromIntError(#[from] std::num::TryFromIntError), #[error("Couldn't set string because cache vector is polluted")] SetStringPolluted(), + #[error("Invalid schema: Polluted buffer or the schema doesn't match the buffer.")] + InvalidSchema(), + #[error("Type not supported: {0}")] + TypeNotSupported(String), + #[error("No type or invalid type found in union enum")] + InvalidUnionEnum(), } pub type FlatbufferResult = core::result::Result; diff --git a/rust/reflection/src/reflection_verifier.rs b/rust/reflection/src/reflection_verifier.rs new file mode 100644 index 00000000000..db504128a72 --- /dev/null +++ b/rust/reflection/src/reflection_verifier.rs @@ -0,0 +1,355 @@ +/* + * Copyright 2018 Google Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use crate::reflection_generated::reflection::{BaseType, Field, Object, Schema}; +use crate::{FlatbufferError, FlatbufferResult}; +use flatbuffers::{ + ForwardsUOffset, InvalidFlatbuffer, TableVerifier, UOffsetT, Vector, Verifiable, Verifier, + VerifierOptions, SIZE_UOFFSET, SIZE_VOFFSET, +}; + +/// Verifies a buffer against its schema with default verification options. +pub fn verify(buffer: &[u8], schema: &Schema) -> FlatbufferResult<()> { + verify_with_options(buffer, schema, &VerifierOptions::default()) +} + +/// Verifies a buffer against its schema with custom verification options. +pub fn verify_with_options( + buffer: &[u8], + schema: &Schema, + opts: &VerifierOptions, +) -> FlatbufferResult<()> { + let mut verifier = Verifier::new(opts, buffer); + if let Some(table_object) = schema.root_table() { + if let core::result::Result::Ok(table_pos) = verifier.get_uoffset(0) { + let mut verified = vec![false; buffer.len()]; + return verify_table( + &mut verifier, + &table_object, + table_pos.try_into()?, + schema, + &mut verified, + ); + } + } + Err(FlatbufferError::InvalidSchema()) +} + +fn verify_table( + verifier: &mut Verifier, + table_object: &Object, + table_pos: usize, + schema: &Schema, + verified: &mut [bool], +) -> FlatbufferResult<()> { + if table_pos < verified.len() && verified[table_pos] { + return Ok(()); + } + + let mut table_verifier = verifier.visit_table(table_pos)?; + + for field in &table_object.fields() { + let field_name = field.name().to_owned(); + table_verifier = match field.type_().base_type() { + BaseType::UType | BaseType::UByte => { + table_verifier.visit_field::(field_name, field.offset(), field.required())? + } + BaseType::Bool => { + table_verifier.visit_field::(field_name, field.offset(), field.required())? + } + BaseType::Byte => { + table_verifier.visit_field::(field_name, field.offset(), field.required())? + } + BaseType::Short => { + table_verifier.visit_field::(field_name, field.offset(), field.required())? + } + BaseType::UShort => { + table_verifier.visit_field::(field_name, field.offset(), field.required())? + } + BaseType::Int => { + table_verifier.visit_field::(field_name, field.offset(), field.required())? + } + BaseType::UInt => { + table_verifier.visit_field::(field_name, field.offset(), field.required())? + } + BaseType::Long => { + table_verifier.visit_field::(field_name, field.offset(), field.required())? + } + BaseType::ULong => { + table_verifier.visit_field::(field_name, field.offset(), field.required())? + } + BaseType::Float => { + table_verifier.visit_field::(field_name, field.offset(), field.required())? + } + BaseType::Double => { + table_verifier.visit_field::(field_name, field.offset(), field.required())? + } + BaseType::String => table_verifier.visit_field::>( + field_name, + field.offset(), + field.required(), + )?, + BaseType::Vector => verify_vector(table_verifier, &field, schema, verified)?, + BaseType::Obj => { + if let Some(field_pos) = table_verifier.deref(field.offset())? { + let child_obj = schema.objects().get(field.type_().index().try_into()?); + if child_obj.is_struct() { + table_verifier + .verifier() + .range_in_buffer(field_pos, child_obj.bytesize().try_into()?)? + } else { + let field_value = table_verifier.verifier().get_uoffset(field_pos)?; + verify_table( + table_verifier.verifier(), + &child_obj, + field_pos.saturating_add(field_value.try_into()?), + schema, + verified, + )?; + } + } else if field.required() { + return InvalidFlatbuffer::new_missing_required(field.name().to_string())?; + } + table_verifier + } + BaseType::Union => { + if let Some(field_pos) = table_verifier.deref(field.offset())? { + let field_value = table_verifier.verifier().get_uoffset(field_pos)?; + verify_union( + table_verifier, + &field, + field_pos.saturating_add(field_value.try_into()?), + schema, + verified, + )? + } else if field.required() { + return InvalidFlatbuffer::new_missing_required(field.name().to_string())?; + } else { + table_verifier + } + } + _ => { + return Err(FlatbufferError::TypeNotSupported( + field + .type_() + .base_type() + .variant_name() + .unwrap_or_default() + .to_string(), + )); + } + }; + } + + table_verifier.finish(); + verified[table_pos] = true; + Ok(()) +} + +fn verify_vector<'a, 'b, 'c>( + mut table_verifier: TableVerifier<'a, 'b, 'c>, + field: &Field, + schema: &Schema, + verified: &mut [bool], +) -> FlatbufferResult> { + let field_name = field.name().to_owned(); + match field.type_().element() { + BaseType::UType | BaseType::UByte => table_verifier + .visit_field::>>( + field_name, + field.offset(), + field.required(), + ) + .map_err(FlatbufferError::VerificationError), + BaseType::Bool => table_verifier + .visit_field::>>( + field_name, + field.offset(), + field.required(), + ) + .map_err(FlatbufferError::VerificationError), + BaseType::Byte => table_verifier + .visit_field::>>( + field_name, + field.offset(), + field.required(), + ) + .map_err(FlatbufferError::VerificationError), + BaseType::Short => table_verifier + .visit_field::>>( + field_name, + field.offset(), + field.required(), + ) + .map_err(FlatbufferError::VerificationError), + BaseType::UShort => table_verifier + .visit_field::>>( + field_name, + field.offset(), + field.required(), + ) + .map_err(FlatbufferError::VerificationError), + BaseType::Int => table_verifier + .visit_field::>>( + field_name, + field.offset(), + field.required(), + ) + .map_err(FlatbufferError::VerificationError), + BaseType::UInt => table_verifier + .visit_field::>>( + field_name, + field.offset(), + field.required(), + ) + .map_err(FlatbufferError::VerificationError), + BaseType::Long => table_verifier + .visit_field::>>( + field_name, + field.offset(), + field.required(), + ) + .map_err(FlatbufferError::VerificationError), + BaseType::ULong => table_verifier + .visit_field::>>( + field_name, + field.offset(), + field.required(), + ) + .map_err(FlatbufferError::VerificationError), + BaseType::Float => table_verifier + .visit_field::>>( + field_name, + field.offset(), + field.required(), + ) + .map_err(FlatbufferError::VerificationError), + BaseType::Double => table_verifier + .visit_field::>>( + field_name, + field.offset(), + field.required(), + ) + .map_err(FlatbufferError::VerificationError), + BaseType::String => table_verifier + .visit_field::>>>( + field_name, + field.offset(), + field.required(), + ) + .map_err(FlatbufferError::VerificationError), + BaseType::Obj => { + if let Some(field_pos) = table_verifier.deref(field.offset())? { + let verifier = table_verifier.verifier(); + let vector_offset = verifier.get_uoffset(field_pos)?; + let vector_pos = field_pos.saturating_add(vector_offset.try_into()?); + let vector_len = verifier.get_uoffset(vector_pos)?; + let vector_start = vector_pos.saturating_add(SIZE_UOFFSET); + let child_obj = schema.objects().get(field.type_().index().try_into()?); + if child_obj.is_struct() { + let vector_size = vector_len.saturating_mul(child_obj.bytesize().try_into()?); + verifier.range_in_buffer(vector_start, vector_size.try_into()?)?; + } else { + verifier.is_aligned::(vector_start)?; + let vector_size = vector_len.saturating_mul(SIZE_UOFFSET.try_into()?); + verifier.range_in_buffer(vector_start, vector_size.try_into()?)?; + let vector_range = core::ops::Range { + start: vector_start, + end: vector_start.saturating_add(vector_size.try_into()?), + }; + for element_pos in vector_range.step_by(SIZE_UOFFSET) { + let table_pos = element_pos + .saturating_add(verifier.get_uoffset(element_pos)?.try_into()?); + verify_table(verifier, &child_obj, table_pos, schema, verified)?; + } + } + } else if field.required() { + return InvalidFlatbuffer::new_missing_required(field.name().to_string())?; + } + Ok(table_verifier) + } + _ => { + return Err(FlatbufferError::TypeNotSupported( + field + .type_() + .base_type() + .variant_name() + .unwrap_or_default() + .to_string(), + )) + } + } +} + +fn verify_union<'a, 'b, 'c>( + mut table_verifier: TableVerifier<'a, 'b, 'c>, + field: &Field, + union_pos: usize, + schema: &Schema, + verified: &mut [bool], +) -> FlatbufferResult> { + let union_enum = schema.enums().get(field.type_().index().try_into()?); + if union_enum.values().is_empty() { + return Err(FlatbufferError::InvalidUnionEnum()); + } + + let enum_offset = field.offset() - u16::try_from(SIZE_VOFFSET)?; + if let Some(enum_pos) = table_verifier.deref(enum_offset)? { + let enum_value = table_verifier.verifier().get_u8(enum_pos)?; + let enum_type = union_enum + .values() + .get(enum_value.into()) + .union_type() + .ok_or(FlatbufferError::InvalidUnionEnum())?; + + match enum_type.base_type() { + BaseType::String => <&str>::run_verifier(table_verifier.verifier(), union_pos)?, + BaseType::Obj => { + let child_obj = schema.objects().get(enum_type.index().try_into()?); + if child_obj.is_struct() { + table_verifier + .verifier() + .range_in_buffer(union_pos, child_obj.bytesize().try_into()?)? + } else { + verify_table( + table_verifier.verifier(), + &child_obj, + union_pos, + schema, + verified, + )?; + } + } + _ => { + return Err(FlatbufferError::TypeNotSupported( + enum_type + .base_type() + .variant_name() + .unwrap_or_default() + .to_string(), + )) + } + } + } else { + return InvalidFlatbuffer::new_inconsistent_union( + format!("{}_type", field.name()), + field.name().to_string(), + )?; + } + + verified[union_pos] = true; + Ok(table_verifier) +} diff --git a/tests/rust_reflection_test/src/lib.rs b/tests/rust_reflection_test/src/lib.rs index f85e835d8ff..a1722bfe888 100644 --- a/tests/rust_reflection_test/src/lib.rs +++ b/tests/rust_reflection_test/src/lib.rs @@ -1,4 +1,5 @@ use flatbuffers_reflection::reflection::{root_as_schema, BaseType, Field}; +use flatbuffers_reflection::reflection_verifier::{verify, verify_with_options}; use flatbuffers_reflection::{ get_any_field_float, get_any_field_float_in_struct, get_any_field_integer, get_any_field_integer_in_struct, get_any_field_string, get_any_field_string_in_struct, @@ -7,8 +8,9 @@ use flatbuffers_reflection::{ set_any_field_integer, set_any_field_string, set_field, set_string, }; -use flatbuffers::FlatBufferBuilder; +use flatbuffers::{FlatBufferBuilder, VerifierOptions}; +use std::error::Error; use std::fs::File; use std::io::Read; @@ -1444,6 +1446,65 @@ fn test_buffer_set_string_diff_type_fails() { ); } +#[test] +fn test_verify_buffer_default_options_succeeds() { + let buffer = create_test_buffer(); + let schema_buffer = load_file_as_buffer("../monster_test.bfbs"); + let schema = root_as_schema(schema_buffer.as_slice()).unwrap(); + + let res = verify(&buffer, &schema); + + assert!(res.is_ok()); +} + +#[test] +fn test_verify_buffer_limit_max_depth_fails() { + let buffer = create_test_buffer(); + let schema_buffer = load_file_as_buffer("../monster_test.bfbs"); + let schema = root_as_schema(schema_buffer.as_slice()).unwrap(); + let verify_options = VerifierOptions { + max_depth: 1, + ..Default::default() + }; + + let res = verify_with_options(&buffer, &schema, &verify_options); + + assert!(res.is_err()); + assert!(format!("{:#?}", res.err().unwrap()).contains("DepthLimitReached")); +} + +#[test] +fn test_verify_buffer_limit_max_table_fails() { + let buffer = create_test_buffer(); + let schema_buffer = load_file_as_buffer("../monster_test.bfbs"); + let schema = root_as_schema(schema_buffer.as_slice()).unwrap(); + let verify_options = VerifierOptions { + max_tables: 1, + ..Default::default() + }; + + let res = verify_with_options(&buffer, &schema, &verify_options); + + assert!(res.is_err()); + assert!(format!("{:#?}", res.err().unwrap()).contains("TooManyTables")); +} + +#[test] +fn test_verify_buffer_limit_max_size_fails() { + let buffer = create_test_buffer(); + let schema_buffer = load_file_as_buffer("../monster_test.bfbs"); + let schema = root_as_schema(schema_buffer.as_slice()).unwrap(); + let verify_options = VerifierOptions { + max_apparent_size: 1 << 6, + ..Default::default() + }; + + let res = verify_with_options(&buffer, &schema, &verify_options); + + assert!(res.is_err()); + assert!(format!("{:#?}", res.err().unwrap()).contains("ApparentSizeTooLarge")); +} + fn load_file_as_buffer(path: &str) -> Vec { std::fs::read(path).unwrap() }