diff --git a/arrow-avro/Cargo.toml b/arrow-avro/Cargo.toml index c103c2ecc0f3..b3fb6bfa9384 100644 --- a/arrow-avro/Cargo.toml +++ b/arrow-avro/Cargo.toml @@ -34,7 +34,7 @@ path = "src/lib.rs" bench = false [features] -default = ["deflate", "snappy", "zstd"] +default = ["deflate", "snappy", "zstd", "bzip2", "xz"] deflate = ["flate2"] snappy = ["snap", "crc"] @@ -47,9 +47,10 @@ serde = { version = "1.0.188", features = ["derive"] } flate2 = { version = "1.0", default-features = false, features = ["rust_backend"], optional = true } snap = { version = "1.0", default-features = false, optional = true } zstd = { version = "0.13", default-features = false, optional = true } +bzip2 = { version = "0.4.4", default-features = false, optional = true } +xz = { version = "0.1.0", default-features = false, optional = true } crc = { version = "3.0", optional = true } - [dev-dependencies] -rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } - +arrow-data = { workspace = true } +rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } \ No newline at end of file diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 2ac1ad038bd7..57b2383c3d09 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -15,48 +15,44 @@ // specific language governing permissions and limitations // under the License. -use crate::schema::{Attributes, ComplexType, PrimitiveType, Record, Schema, TypeName}; +use crate::schema::{Attributes, ComplexType, PrimitiveType, Schema, TypeName}; +use arrow_schema::DataType::*; use arrow_schema::{ - ArrowError, DataType, Field, FieldRef, IntervalUnit, SchemaBuilder, SchemaRef, TimeUnit, + ArrowError, DataType, Field, Fields, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, + DECIMAL128_MAX_SCALE, }; -use std::borrow::Cow; use std::collections::HashMap; use std::sync::Arc; /// Avro types are not nullable, with nullability instead encoded as a union /// where one of the variants is the null type. -/// -/// To accommodate this we special case two-variant unions where one of the -/// variants is the null type, and use this to derive arrow's notion of nullability -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum Nullability { - /// The nulls are encoded as the first union variant + /// The nulls are encoded as the first union variant => `[ "null", T ]` NullFirst, - /// The nulls are encoded as the second union variant + /// The nulls are encoded as the second union variant => `[ T, "null" ]` + /// + /// **Important**: In Impala’s out-of-spec approach, branch=0 => null, branch=1 => decode T. + /// This is reversed from the typical “standard” Avro interpretation for `[T,"null"]`. + /// + /// NullSecond, } /// An Avro datatype mapped to the arrow data model #[derive(Debug, Clone)] pub struct AvroDataType { - nullability: Option, - metadata: HashMap, - codec: Codec, + pub nullability: Option, + pub metadata: Arc>, + pub codec: Codec, } impl AvroDataType { - /// Returns an arrow [`Field`] with the given name + /// Returns an arrow [`Field`] with the given name, applying `nullability` if present. 
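+    ///
+    /// Metadata captured from the Avro schema attributes is copied onto the
+    /// returned Arrow field via `with_metadata`.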
pub fn field_with_name(&self, name: &str) -> Field { - let d = self.codec.data_type(); - Field::new(name, d, self.nullability.is_some()).with_metadata(self.metadata.clone()) - } - - pub fn codec(&self) -> &Codec { - &self.codec - } - - pub fn nullability(&self) -> Option { - self.nullability + let is_nullable = self.nullability.is_some(); + let metadata = Arc::try_unwrap(self.metadata.clone()).unwrap_or_else(|arc| (*arc).clone()); + Field::new(name, self.codec.data_type(), is_nullable).with_metadata(metadata) } } @@ -65,12 +61,21 @@ impl AvroDataType { pub struct AvroField { name: String, data_type: AvroDataType, + default: Option, } impl AvroField { /// Returns the arrow [`Field`] pub fn field(&self) -> Field { - self.data_type.field_with_name(&self.name) + let mut fld = self.data_type.field_with_name(&self.name); + if let Some(def_val) = &self.default { + if !def_val.is_null() { + let mut md = fld.metadata().clone(); + md.insert("avro.default".to_string(), def_val.to_string()); + fld = fld.with_metadata(md); + } + } + fld } /// Returns the [`AvroDataType`] @@ -78,6 +83,7 @@ impl AvroField { &self.data_type } + /// Returns the name of this field pub fn name(&self) -> &str { &self.name } @@ -91,9 +97,10 @@ impl<'a> TryFrom<&Schema<'a>> for AvroField { Schema::Complex(ComplexType::Record(r)) => { let mut resolver = Resolver::default(); let data_type = make_data_type(schema, None, &mut resolver)?; - Ok(AvroField { + Ok(Self { data_type, name: r.name.to_string(), + default: None, }) } _ => Err(ArrowError::ParseError(format!( @@ -104,10 +111,9 @@ impl<'a> TryFrom<&Schema<'a>> for AvroField { } /// An Avro encoding -/// -/// #[derive(Debug, Clone)] pub enum Codec { + /// Primitive Null, Boolean, Int32, @@ -115,46 +121,95 @@ pub enum Codec { Float32, Float64, Binary, - Utf8, + String, + /// Complex + Record(Arc<[AvroField]>), + Enum(Arc<[String]>, Arc<[i32]>), + Array(Arc), + Map(Arc), + Fixed(i32), + /// Logical + Decimal(usize, Option, Option), + Uuid, Date32, TimeMillis, TimeMicros, - /// TimestampMillis(is_utc) TimestampMillis(bool), - /// TimestampMicros(is_utc) TimestampMicros(bool), - Fixed(i32), - List(Arc), - Struct(Arc<[AvroField]>), - Interval, + Duration, } impl Codec { - fn data_type(&self) -> DataType { + /// Convert this to an Arrow `DataType` + pub(crate) fn data_type(&self) -> DataType { match self { - Self::Null => DataType::Null, - Self::Boolean => DataType::Boolean, - Self::Int32 => DataType::Int32, - Self::Int64 => DataType::Int64, - Self::Float32 => DataType::Float32, - Self::Float64 => DataType::Float64, - Self::Binary => DataType::Binary, - Self::Utf8 => DataType::Utf8, - Self::Date32 => DataType::Date32, - Self::TimeMillis => DataType::Time32(TimeUnit::Millisecond), - Self::TimeMicros => DataType::Time64(TimeUnit::Microsecond), + // Primitives + Self::Null => Null, + Self::Boolean => Boolean, + Self::Int32 => Int32, + Self::Int64 => Int64, + Self::Float32 => Float32, + Self::Float64 => Float64, + Self::Binary => Binary, + Self::String => Utf8, + Self::Record(fields) => { + let arrow_fields: Vec = fields.iter().map(|f| f.field()).collect(); + Struct(arrow_fields.into()) + } + Self::Enum(_, _) => Dictionary(Box::new(Int32), Box::new(Utf8)), + Self::Array(child_type) => { + let child_dt = child_type.codec.data_type(); + let child_md = Arc::try_unwrap(child_type.metadata.clone()) + .unwrap_or_else(|arc| (*arc).clone()); + let child_field = Field::new(Field::LIST_FIELD_DEFAULT_NAME, child_dt, true) + .with_metadata(child_md); + List(Arc::new(child_field)) + } + 
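+            // An Avro map maps to Arrow's Map type: non-nullable "entries"
+            // structs holding a non-null Utf8 "key" and a nullable "value".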
Self::Map(value_type) => { + let val_dt = value_type.codec.data_type(); + let val_md = Arc::try_unwrap(value_type.metadata.clone()) + .unwrap_or_else(|arc| (*arc).clone()); + let val_field = Field::new("value", val_dt, true).with_metadata(val_md); + Map( + Arc::new(Field::new( + "entries", + Struct(Fields::from(vec![ + Field::new("key", Utf8, false), + val_field, + ])), + false, + )), + false, + ) + } + Self::Fixed(sz) => FixedSizeBinary(*sz), + Self::Decimal(precision, scale, size) => { + let p = *precision as u8; + let s = scale.unwrap_or(0) as i8; + let too_large_for_128 = match *size { + Some(sz) => sz > 16, + None => { + (p as usize) > DECIMAL128_MAX_PRECISION as usize + || (s as usize) > DECIMAL128_MAX_SCALE as usize + } + }; + if too_large_for_128 { + Decimal256(p, s) + } else { + Decimal128(p, s) + } + } + Self::Uuid => FixedSizeBinary(16), + Self::Date32 => Date32, + Self::TimeMillis => Time32(TimeUnit::Millisecond), + Self::TimeMicros => Time64(TimeUnit::Microsecond), Self::TimestampMillis(is_utc) => { - DataType::Timestamp(TimeUnit::Millisecond, is_utc.then(|| "+00:00".into())) + Timestamp(TimeUnit::Millisecond, is_utc.then(|| "+00:00".into())) } Self::TimestampMicros(is_utc) => { - DataType::Timestamp(TimeUnit::Microsecond, is_utc.then(|| "+00:00".into())) - } - Self::Interval => DataType::Interval(IntervalUnit::MonthDayNano), - Self::Fixed(size) => DataType::FixedSizeBinary(*size), - Self::List(f) => { - DataType::List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME))) + Timestamp(TimeUnit::Microsecond, is_utc.then(|| "+00:00".into())) } - Self::Struct(f) => DataType::Struct(f.iter().map(|x| x.field()).collect()), + Self::Duration => Interval(IntervalUnit::MonthDayNano), } } } @@ -169,42 +224,66 @@ impl From for Codec { PrimitiveType::Float => Self::Float32, PrimitiveType::Double => Self::Float64, PrimitiveType::Bytes => Self::Binary, - PrimitiveType::String => Self::Utf8, + PrimitiveType::String => Self::String, } } } /// Resolves Avro type names to [`AvroDataType`] -/// -/// See -#[derive(Debug, Default)] +#[derive(Default, Debug)] struct Resolver<'a> { map: HashMap<(&'a str, &'a str), AvroDataType>, } impl<'a> Resolver<'a> { - fn register(&mut self, name: &'a str, namespace: Option<&'a str>, schema: AvroDataType) { - self.map.insert((name, namespace.unwrap_or("")), schema); + fn register(&mut self, name: &'a str, namespace: Option<&'a str>, dt: AvroDataType) { + let ns = namespace.unwrap_or(""); + self.map.insert((name, ns), dt); } - fn resolve(&self, name: &str, namespace: Option<&'a str>) -> Result { - let (namespace, name) = name - .rsplit_once('.') - .unwrap_or_else(|| (namespace.unwrap_or(""), name)); - + fn resolve( + &self, + full_name: &str, + namespace: Option<&'a str>, + ) -> Result { + let (ns, nm) = match full_name.rsplit_once('.') { + Some((a, b)) => (a, b), + None => (namespace.unwrap_or(""), full_name), + }; self.map - .get(&(namespace, name)) - .ok_or_else(|| ArrowError::ParseError(format!("Failed to resolve {namespace}.{name}"))) + .get(&(nm, ns)) .cloned() + .ok_or_else(|| ArrowError::ParseError(format!("Failed to resolve {ns}.{nm}"))) } } -/// Parses a [`AvroDataType`] from the provided [`Schema`] and the given `name` and `namespace` -/// -/// `name`: is name used to refer to `schema` in its parent -/// `namespace`: an optional qualifier used as part of a type hierarchy -/// -/// See [`Resolver`] for more information +fn parse_decimal_attributes( + attributes: &Attributes, + fallback_size: Option, + precision_required: bool, +) -> Result<(usize, 
usize, Option), ArrowError> { + let precision = attributes + .additional + .get("precision") + .and_then(|v| v.as_u64()) + .or(if precision_required { None } else { Some(10) }) + .ok_or_else(|| ArrowError::ParseError("Decimal requires precision".to_string()))? + as usize; + let scale = attributes + .additional + .get("scale") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize; + let size = attributes + .additional + .get("size") + .and_then(|v| v.as_u64()) + .map(|s| s as usize) + .or(fallback_size); + Ok((precision, scale, size)) +} + +/// Parses a [`AvroDataType`] from the provided [`Schema`], plus optional `namespace`. fn make_data_type<'a>( schema: &Schema<'a>, namespace: Option<&'a str>, @@ -213,117 +292,167 @@ fn make_data_type<'a>( match schema { Schema::TypeName(TypeName::Primitive(p)) => Ok(AvroDataType { nullability: None, - metadata: Default::default(), + metadata: Arc::new(Default::default()), codec: (*p).into(), }), Schema::TypeName(TypeName::Ref(name)) => resolver.resolve(name, namespace), - Schema::Union(f) => { - // Special case the common case of nullable primitives - let null = f + Schema::Union(u) => { + let null_count = u .iter() - .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))); - match (f.len() == 2, null) { - (true, Some(0)) => { - let mut field = make_data_type(&f[1], namespace, resolver)?; - field.nullability = Some(Nullability::NullFirst); - Ok(field) - } - (true, Some(1)) => { - let mut field = make_data_type(&f[0], namespace, resolver)?; - field.nullability = Some(Nullability::NullSecond); - Ok(field) - } - _ => Err(ArrowError::NotYetImplemented(format!( - "Union of {f:?} not currently supported" - ))), + .filter(|x| *x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))) + .count(); + if null_count == 1 && u.len() == 2 { + let null_idx = u + .iter() + .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))) + .unwrap(); + let other_idx = if null_idx == 0 { 1 } else { 0 }; + let mut dt = make_data_type(&u[other_idx], namespace, resolver)?; + dt.nullability = if null_idx == 0 { + Some(Nullability::NullFirst) + } else { + Some(Nullability::NullSecond) + }; + Ok(dt) + } else { + Err(ArrowError::NotYetImplemented(format!( + "Union of {u:?} not currently supported" + ))) } } + Schema::Complex(c) => match c { ComplexType::Record(r) => { - let namespace = r.namespace.or(namespace); + let ns = r.namespace.or(namespace); let fields = r .fields .iter() - .map(|field| { - Ok(AvroField { - name: field.name.to_string(), - data_type: make_data_type(&field.r#type, namespace, resolver)?, + .map(|f| { + let data_type = make_data_type(&f.r#type, ns, resolver)?; + Ok::(AvroField { + name: f.name.to_string(), + data_type, + default: f.default.clone(), }) }) - .collect::>()?; - - let field = AvroDataType { + .collect::, ArrowError>>()?; + let rec = AvroDataType { nullability: None, - codec: Codec::Struct(fields), - metadata: r.attributes.field_metadata(), + metadata: Arc::new(r.attributes.field_metadata()), + codec: Codec::Record(Arc::from(fields)), }; - resolver.register(r.name, namespace, field.clone()); - Ok(field) + resolver.register(r.name, ns, rec.clone()); + Ok(rec) + } + ComplexType::Enum(e) => { + let en = AvroDataType { + nullability: None, + metadata: Arc::new(e.attributes.field_metadata()), + codec: Codec::Enum( + Arc::from(e.symbols.iter().map(|s| s.to_string()).collect::>()), + Arc::from(vec![]), + ), + }; + resolver.register(e.name, namespace, en.clone()); + Ok(en) } ComplexType::Array(a) => { - let 
mut field = make_data_type(a.items.as_ref(), namespace, resolver)?; + let child = make_data_type(&a.items, namespace, resolver)?; Ok(AvroDataType { nullability: None, - metadata: a.attributes.field_metadata(), - codec: Codec::List(Arc::new(field)), + metadata: Arc::new(a.attributes.field_metadata()), + codec: Codec::Array(Arc::new(child)), }) } - ComplexType::Fixed(f) => { - let size = f.size.try_into().map_err(|e| { - ArrowError::ParseError(format!("Overflow converting size to i32: {e}")) - })?; - - let field = AvroDataType { + ComplexType::Map(m) => { + let val = make_data_type(&m.values, namespace, resolver)?; + Ok(AvroDataType { nullability: None, - metadata: f.attributes.field_metadata(), - codec: Codec::Fixed(size), - }; - resolver.register(f.name, namespace, field.clone()); - Ok(field) + metadata: Arc::new(m.attributes.field_metadata()), + codec: Codec::Map(Arc::new(val)), + }) + } + ComplexType::Fixed(fx) => { + let size = fx.size as i32; + if let Some("decimal") = fx.attributes.logical_type { + let (precision, scale, _) = + parse_decimal_attributes(&fx.attributes, Some(size as usize), true)?; + let dec = AvroDataType { + nullability: None, + metadata: Arc::new(fx.attributes.field_metadata()), + codec: Codec::Decimal(precision, Some(scale), Some(size as usize)), + }; + resolver.register(fx.name, namespace, dec.clone()); + Ok(dec) + } else { + let fixed_dt = AvroDataType { + nullability: None, + metadata: Arc::new(fx.attributes.field_metadata()), + codec: Codec::Fixed(size), + }; + resolver.register(fx.name, namespace, fixed_dt.clone()); + Ok(fixed_dt) + } } - ComplexType::Enum(e) => Err(ArrowError::NotYetImplemented(format!( - "Enum of {e:?} not currently supported" - ))), - ComplexType::Map(m) => Err(ArrowError::NotYetImplemented(format!( - "Map of {m:?} not currently supported" - ))), }, - Schema::Type(t) => { - let mut field = - make_data_type(&Schema::TypeName(t.r#type.clone()), namespace, resolver)?; - // https://avro.apache.org/docs/1.11.1/specification/#logical-types - match (t.attributes.logical_type, &mut field.codec) { - (Some("decimal"), c @ Codec::Fixed(_)) => { - return Err(ArrowError::NotYetImplemented( - "Decimals are not currently supported".to_string(), - )) + Schema::Type(t) => { + let mut dt = make_data_type(&Schema::TypeName(t.r#type.clone()), namespace, resolver)?; + match (t.attributes.logical_type, &mut dt.codec) { + (Some("decimal"), Codec::Fixed(sz)) => { + let (prec, sc, size_opt) = + parse_decimal_attributes(&t.attributes, Some(*sz as usize), false)?; + if let Some(sz_actual) = size_opt { + *sz = sz_actual as i32; + } + dt.codec = Codec::Decimal(prec, Some(sc), Some(*sz as usize)); + } + (Some("decimal"), Codec::Binary) => { + let (prec, sc, _) = parse_decimal_attributes(&t.attributes, None, false)?; + dt.codec = Codec::Decimal(prec, Some(sc), None); + } + (Some("uuid"), Codec::String) => { + dt.codec = Codec::Uuid; + } + (Some("date"), Codec::Int32) => { + dt.codec = Codec::Date32; + } + (Some("time-millis"), Codec::Int32) => { + dt.codec = Codec::TimeMillis; + } + (Some("time-micros"), Codec::Int64) => { + dt.codec = Codec::TimeMicros; } - (Some("date"), c @ Codec::Int32) => *c = Codec::Date32, - (Some("time-millis"), c @ Codec::Int32) => *c = Codec::TimeMillis, - (Some("time-micros"), c @ Codec::Int64) => *c = Codec::TimeMicros, - (Some("timestamp-millis"), c @ Codec::Int64) => *c = Codec::TimestampMillis(true), - (Some("timestamp-micros"), c @ Codec::Int64) => *c = Codec::TimestampMicros(true), - (Some("local-timestamp-millis"), c @ Codec::Int64) => { 
- *c = Codec::TimestampMillis(false) + (Some("timestamp-millis"), Codec::Int64) => { + dt.codec = Codec::TimestampMillis(true); } - (Some("local-timestamp-micros"), c @ Codec::Int64) => { - *c = Codec::TimestampMicros(false) + (Some("timestamp-micros"), Codec::Int64) => { + dt.codec = Codec::TimestampMicros(true); } - (Some("duration"), c @ Codec::Fixed(12)) => *c = Codec::Interval, - (Some(logical), _) => { - // Insert unrecognized logical type into metadata map - field.metadata.insert("logicalType".into(), logical.into()); + (Some("local-timestamp-millis"), Codec::Int64) => { + dt.codec = Codec::TimestampMillis(false); + } + (Some("local-timestamp-micros"), Codec::Int64) => { + dt.codec = Codec::TimestampMicros(false); + } + (Some("duration"), Codec::Fixed(12)) => { + dt.codec = Codec::Duration; + } + (Some(other), _) => { + if !dt.metadata.contains_key("logicalType") { + let mut arc_map = (*dt.metadata).clone(); + arc_map.insert("logicalType".into(), other.into()); + dt.metadata = Arc::new(arc_map); + } } (None, _) => {} } - - if !t.attributes.additional.is_empty() { - for (k, v) in &t.attributes.additional { - field.metadata.insert(k.to_string(), v.to_string()); - } + for (k, v) in &t.attributes.additional { + let mut arc_map = (*dt.metadata).clone(); + arc_map.insert(k.to_string(), v.to_string()); + dt.metadata = Arc::new(arc_map); } - Ok(field) + Ok(dt) } } } diff --git a/arrow-avro/src/compression.rs b/arrow-avro/src/compression.rs index f29b8dd07606..5c4c988c899e 100644 --- a/arrow-avro/src/compression.rs +++ b/arrow-avro/src/compression.rs @@ -16,7 +16,6 @@ // under the License. use arrow_schema::ArrowError; -use std::io; use std::io::Read; /// The metadata key used for storing the JSON encoded [`CompressionCodec`] @@ -27,6 +26,8 @@ pub enum CompressionCodec { Deflate, Snappy, ZStandard, + Bzip2, + Xz, } impl CompressionCodec { @@ -65,7 +66,6 @@ impl CompressionCodec { CompressionCodec::Snappy => Err(ArrowError::ParseError( "Snappy codec requires snappy feature".to_string(), )), - #[cfg(feature = "zstd")] CompressionCodec::ZStandard => { let mut decoder = zstd::Decoder::new(block)?; @@ -77,6 +77,28 @@ impl CompressionCodec { CompressionCodec::ZStandard => Err(ArrowError::ParseError( "ZStandard codec requires zstd feature".to_string(), )), + #[cfg(feature = "bzip2")] + CompressionCodec::Bzip2 => { + let mut decoder = bzip2::read::BzDecoder::new(block); + let mut out = Vec::new(); + decoder.read_to_end(&mut out)?; + Ok(out) + } + #[cfg(not(feature = "bzip2"))] + CompressionCodec::Bzip2 => Err(ArrowError::ParseError( + "Bzip2 codec requires bzip2 feature".to_string(), + )), + #[cfg(feature = "xz")] + CompressionCodec::Xz => { + let mut decoder = xz::read::XzDecoder::new(block); + let mut out = Vec::new(); + decoder.read_to_end(&mut out)?; + Ok(out) + } + #[cfg(not(feature = "xz"))] + CompressionCodec::Xz => Err(ArrowError::ParseError( + "XZ codec requires xz feature".to_string(), + )), } } } diff --git a/arrow-avro/src/reader/block.rs b/arrow-avro/src/reader/block.rs index 479f0ef90909..e022031164e7 100644 --- a/arrow-avro/src/reader/block.rs +++ b/arrow-avro/src/reader/block.rs @@ -86,7 +86,6 @@ impl BlockDecoder { "Block count cannot be negative, got {c}" )) })?; - self.state = BlockDecoderState::Size; } } @@ -114,15 +113,18 @@ impl BlockDecoder { } BlockDecoderState::Sync => { let to_decode = buf.len().min(self.bytes_remaining); - let write = &mut self.in_progress.sync[16 - to_decode..]; - write[..to_decode].copy_from_slice(&buf[..to_decode]); + let start = 16 - 
self.bytes_remaining; + let end = start + to_decode; + self.in_progress.sync[start..end].copy_from_slice(&buf[..to_decode]); self.bytes_remaining -= to_decode; buf = &buf[to_decode..]; if self.bytes_remaining == 0 { self.state = BlockDecoderState::Finished; } } - BlockDecoderState::Finished => return Ok(max_read - buf.len()), + BlockDecoderState::Finished => { + return Ok(max_read - buf.len()); + } } } Ok(max_read) diff --git a/arrow-avro/src/reader/cursor.rs b/arrow-avro/src/reader/cursor.rs index 4b6a5a4d65db..9e38a78c63ec 100644 --- a/arrow-avro/src/reader/cursor.rs +++ b/arrow-avro/src/reader/cursor.rs @@ -14,7 +14,6 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. - use crate::reader::vlq::read_varint; use arrow_schema::ArrowError; @@ -65,6 +64,7 @@ impl<'a> AvroCursor<'a> { Ok(val) } + /// Decode a zig-zag encoded Avro int (32-bit). #[inline] pub(crate) fn get_int(&mut self) -> Result { let varint = self.read_vlq()?; @@ -74,18 +74,20 @@ impl<'a> AvroCursor<'a> { Ok((val >> 1) as i32 ^ -((val & 1) as i32)) } + /// Decode a zig-zag encoded Avro long (64-bit). #[inline] pub(crate) fn get_long(&mut self) -> Result { let val = self.read_vlq()?; Ok((val >> 1) as i64 ^ -((val & 1) as i64)) } + /// Read a variable-length byte array from Avro (where the length is stored as an Avro long). pub(crate) fn get_bytes(&mut self) -> Result<&'a [u8], ArrowError> { let len: usize = self.get_long()?.try_into().map_err(|_| { ArrowError::ParseError("offset overflow reading avro bytes".to_string()) })?; - if (self.buf.len() < len) { + if self.buf.len() < len { return Err(ArrowError::ParseError( "Unexpected EOF reading bytes".to_string(), )); @@ -95,9 +97,10 @@ impl<'a> AvroCursor<'a> { Ok(ret) } + /// Read a little-endian 32-bit float #[inline] pub(crate) fn get_float(&mut self) -> Result { - if (self.buf.len() < 4) { + if self.buf.len() < 4 { return Err(ArrowError::ParseError( "Unexpected EOF reading float".to_string(), )); @@ -107,15 +110,28 @@ impl<'a> AvroCursor<'a> { Ok(ret) } + /// Read a little-endian 64-bit float #[inline] pub(crate) fn get_double(&mut self) -> Result { - if (self.buf.len() < 8) { + if self.buf.len() < 8 { return Err(ArrowError::ParseError( - "Unexpected EOF reading float".to_string(), + "Unexpected EOF reading double".to_string(), )); } let ret = f64::from_le_bytes(self.buf[..8].try_into().unwrap()); self.buf = &self.buf[8..]; Ok(ret) } + + /// Read exactly `n` bytes from the buffer (e.g. for Avro `fixed`). + pub(crate) fn get_fixed(&mut self, n: usize) -> Result<&'a [u8], ArrowError> { + if self.buf.len() < n { + return Err(ArrowError::ParseError( + "Unexpected EOF reading fixed".to_string(), + )); + } + let ret = &self.buf[..n]; + self.buf = &self.buf[n..]; + Ok(ret) + } } diff --git a/arrow-avro/src/reader/header.rs b/arrow-avro/src/reader/header.rs index 98c285171bf3..99f2163fa5bb 100644 --- a/arrow-avro/src/reader/header.rs +++ b/arrow-avro/src/reader/header.rs @@ -14,7 +14,6 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. - //! Decoder for [`Header`] use crate::compression::{CompressionCodec, CODEC_METADATA_KEY}; @@ -74,17 +73,18 @@ impl Header { self.sync } - /// Returns the [`CompressionCodec`] if any + /// Returns the [`CompressionCodec`] if any. 
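+    ///
+    /// The codec is read from the file's `avro.codec` metadata entry; a missing
+    /// entry or the literal `null` means the blocks are uncompressed.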
pub fn compression(&self) -> Result, ArrowError> { let v = self.get(CODEC_METADATA_KEY); - match v { None | Some(b"null") => Ok(None), Some(b"deflate") => Ok(Some(CompressionCodec::Deflate)), Some(b"snappy") => Ok(Some(CompressionCodec::Snappy)), Some(b"zstandard") => Ok(Some(CompressionCodec::ZStandard)), + Some(b"bzip2") => Ok(Some(CompressionCodec::Bzip2)), + Some(b"xz") => Ok(Some(CompressionCodec::Xz)), Some(v) => Err(ArrowError::ParseError(format!( - "Unrecognized compression codec \'{}\'", + "Unrecognized compression codec '{}'", String::from_utf8_lossy(v) ))), } @@ -147,8 +147,6 @@ impl HeaderDecoder { /// This method can be called multiple times with consecutive chunks of data, allowing /// integration with chunked IO systems like [`BufRead::fill_buf`] /// - /// All errors should be considered fatal, and decoding aborted - /// /// Once the entire [`Header`] has been decoded this method will not read any further /// input bytes, and the header can be obtained with [`Self::flush`] /// @@ -264,13 +262,13 @@ impl HeaderDecoder { #[cfg(test)] mod test { use super::*; - use crate::codec::{AvroDataType, AvroField}; + use crate::codec::AvroField; use crate::reader::read_header; use crate::schema::SCHEMA_METADATA_KEY; use crate::test_util::arrow_test_data; use arrow_schema::{DataType, Field, Fields, TimeUnit}; use std::fs::File; - use std::io::{BufRead, BufReader}; + use std::io::BufReader; #[test] fn test_header_decode() { diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs index 12fa67d9c8e3..4d0cbb035088 100644 --- a/arrow-avro/src/reader/mod.rs +++ b/arrow-avro/src/reader/mod.rs @@ -14,7 +14,6 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. - //! Read Avro data to Arrow use crate::reader::block::{Block, BlockDecoder}; @@ -45,7 +44,6 @@ fn read_header(mut reader: R) -> Result { break; } } - decoder .flush() .ok_or_else(|| ArrowError::ParseError("Unexpected EOF".to_string())) @@ -54,7 +52,6 @@ fn read_header(mut reader: R) -> Result { /// Return an iterator of [`Block`] from the provided [`BufRead`] fn read_blocks(mut reader: R) -> impl Iterator> { let mut decoder = BlockDecoder::default(); - let mut try_next = move || { loop { let buf = reader.fill_buf()?; @@ -76,42 +73,54 @@ fn read_blocks(mut reader: R) -> impl Iterator RecordBatch { + /// Helper to read an Avro file into a `RecordBatch`. + /// + /// - `strict_mode`: if `true`, we reject unions of the form `[T,"null"]`. 
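+    /// For example, with `strict_mode` enabled a schema containing the union
+    /// `["long", "null"]` is rejected rather than decoded as a nullable long.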
+ fn read_file(file: &str, batch_size: usize, strict_mode: bool) -> RecordBatch { let file = File::open(file).unwrap(); let mut reader = BufReader::new(file); let header = read_header(&mut reader).unwrap(); let compression = header.compression().unwrap(); let schema = header.schema().unwrap().unwrap(); let root = AvroField::try_from(&schema).unwrap(); - let mut decoder = RecordDecoder::try_new(root.data_type()).unwrap(); - + let mut decoder = RecordDecoder::try_new(root.data_type(), strict_mode).unwrap(); for result in read_blocks(reader) { let block = result.unwrap(); assert_eq!(block.sync, header.sync()); - if let Some(c) = compression { - let decompressed = c.decompress(&block.data).unwrap(); - - let mut offset = 0; - let mut remaining = block.count; - while remaining > 0 { - let to_read = remaining.max(batch_size); - offset += decoder - .decode(&decompressed[offset..], block.count) - .unwrap(); - - remaining -= to_read; - } - assert_eq!(offset, decompressed.len()); + let block_data = if let Some(c) = compression { + c.decompress(&block.data).unwrap() + } else { + block.data + }; + let mut offset = 0; + let mut remaining = block.count; + while remaining > 0 { + let to_read = remaining.min(batch_size); + offset += decoder.decode(&block_data[offset..], to_read).unwrap(); + remaining -= to_read; } + assert_eq!(offset, block_data.len()); } decoder.flush().unwrap() } @@ -122,6 +131,8 @@ mod test { "avro/alltypes_plain.avro", "avro/alltypes_plain.snappy.avro", "avro/alltypes_plain.zstandard.avro", + "avro/alltypes_plain.bzip2.avro", + "avro/alltypes_plain.xz.avro", ]; let expected = RecordBatch::try_from_iter_with_nullable([ @@ -207,12 +218,1078 @@ mod test { ), ]) .unwrap(); - for file in files { let file = arrow_test_data(file); + assert_eq!(read_file(&file, 8, false), expected); + assert_eq!(read_file(&file, 3, false), expected); + } + } + + #[test] + fn test_alltypes_dictionary() { + let file = "avro/alltypes_dictionary.avro"; + let expected = RecordBatch::try_from_iter_with_nullable([ + ("id", Arc::new(Int32Array::from(vec![0, 1])) as _, true), + ( + "bool_col", + Arc::new(BooleanArray::from(vec![Some(true), Some(false)])) as _, + true, + ), + ( + "tinyint_col", + Arc::new(Int32Array::from(vec![0, 1])) as _, + true, + ), + ( + "smallint_col", + Arc::new(Int32Array::from(vec![0, 1])) as _, + true, + ), + ("int_col", Arc::new(Int32Array::from(vec![0, 1])) as _, true), + ( + "bigint_col", + Arc::new(Int64Array::from(vec![0, 10])) as _, + true, + ), + ( + "float_col", + Arc::new(Float32Array::from(vec![0.0, 1.1])) as _, + true, + ), + ( + "double_col", + Arc::new(Float64Array::from(vec![0.0, 10.1])) as _, + true, + ), + ( + "date_string_col", + Arc::new(BinaryArray::from_iter_values([b"01/01/09", b"01/01/09"])) as _, + true, + ), + ( + "string_col", + Arc::new(BinaryArray::from_iter_values([b"0", b"1"])) as _, + true, + ), + ( + "timestamp_col", + Arc::new( + TimestampMicrosecondArray::from_iter_values([ + 1230768000000000, // 2009-01-01T00:00:00.000 + 1230768060000000, // 2009-01-01T00:01:00.000 + ]) + .with_timezone("+00:00"), + ) as _, + true, + ), + ]) + .unwrap(); + let file_path = arrow_test_data(file); + let batch_large = read_file(&file_path, 8, false); + assert_eq!( + batch_large, expected, + "Decoded RecordBatch does not match for file {}", + file + ); + let batch_small = read_file(&file_path, 3, false); + assert_eq!( + batch_small, expected, + "Decoded RecordBatch (batch size 3) does not match for file {}", + file + ); + } + + #[test] + fn test_alltypes_nulls_plain() { + let file 
= "avro/alltypes_nulls_plain.avro"; + let expected = RecordBatch::try_from_iter_with_nullable([ + ( + "string_col", + Arc::new(StringArray::from(vec![None::<&str>])) as _, + true, + ), + ("int_col", Arc::new(Int32Array::from(vec![None])) as _, true), + ( + "bool_col", + Arc::new(BooleanArray::from(vec![None])) as _, + true, + ), + ( + "bigint_col", + Arc::new(Int64Array::from(vec![None])) as _, + true, + ), + ( + "float_col", + Arc::new(Float32Array::from(vec![None])) as _, + true, + ), + ( + "double_col", + Arc::new(Float64Array::from(vec![None])) as _, + true, + ), + ( + "bytes_col", + Arc::new(BinaryArray::from(vec![None::<&[u8]>])) as _, + true, + ), + ]) + .unwrap(); + let file_path = arrow_test_data(file); + let batch_large = read_file(&file_path, 8, false); + assert_eq!( + batch_large, expected, + "Decoded RecordBatch does not match for file {}", + file + ); + let batch_small = read_file(&file_path, 3, false); + assert_eq!( + batch_small, expected, + "Decoded RecordBatch (batch size 3) does not match for file {}", + file + ); + } + + #[test] + fn test_binary() { + let file = arrow_test_data("avro/binary.avro"); + let batch = read_file(&file, 8, false); + let expected = RecordBatch::try_from_iter_with_nullable([( + "foo", + Arc::new(BinaryArray::from_iter_values(vec![ + b"\x00".as_ref(), + b"\x01".as_ref(), + b"\x02".as_ref(), + b"\x03".as_ref(), + b"\x04".as_ref(), + b"\x05".as_ref(), + b"\x06".as_ref(), + b"\x07".as_ref(), + b"\x08".as_ref(), + b"\t".as_ref(), + b"\n".as_ref(), + b"\x0b".as_ref(), + ])) as Arc, + true, + )]) + .unwrap(); + assert_eq!(batch, expected); + } + + #[test] + fn test_decimal() { + let files = [ + ("avro/fixed_length_decimal.avro", 25, 2), + ("avro/fixed_length_decimal_legacy.avro", 13, 2), + ("avro/int32_decimal.avro", 4, 2), + ("avro/int64_decimal.avro", 10, 2), + ]; + let decimal_values: Vec = (1..=24).map(|n| n as i128 * 100).collect(); + for (file, precision, scale) in files { + let file_path = arrow_test_data(file); + let actual_batch = read_file(&file_path, 8, false); + let expected_array = Decimal128Array::from_iter_values(decimal_values.clone()) + .with_precision_and_scale(precision, scale) + .unwrap(); + let mut meta = HashMap::new(); + meta.insert("precision".to_string(), precision.to_string()); + meta.insert("scale".to_string(), scale.to_string()); + let field_with_meta = Field::new("value", DataType::Decimal128(precision, scale), true) + .with_metadata(meta); + let expected_schema = Arc::new(Schema::new(vec![field_with_meta])); + let expected_batch = + RecordBatch::try_new(expected_schema.clone(), vec![Arc::new(expected_array)]) + .expect("Failed to build expected RecordBatch"); + assert_eq!( + actual_batch, expected_batch, + "Decoded RecordBatch does not match the expected Decimal128 data for file {}", + file + ); + let actual_batch_small = read_file(&file_path, 3, false); + assert_eq!( + actual_batch_small, expected_batch, + "Decoded RecordBatch does not match the expected Decimal128 data for file {} with batch size 3", + file + ); + } + } + + #[test] + fn test_datapage_v2() { + let file = arrow_test_data("avro/datapage_v2.snappy.avro"); + let batch = read_file(&file, 8, false); + let a = StringArray::from(vec![ + Some("abc"), + Some("abc"), + Some("abc"), + None, + Some("abc"), + ]); + let b = Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]); + let c = Float64Array::from(vec![Some(2.0), Some(3.0), Some(4.0), Some(5.0), Some(2.0)]); + let d = BooleanArray::from(vec![ + Some(true), + Some(true), + Some(true), + Some(false), 
+ Some(true), + ]); + let e_values = Int32Array::from(vec![ + Some(1), + Some(2), + Some(3), + Some(1), + Some(2), + Some(3), + Some(1), + Some(2), + ]); + let e_offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0i32, 3, 3, 3, 6, 8])); + let e_validity = Some(NullBuffer::from(vec![true, false, false, true, true])); + let field_e = Arc::new(Field::new("item", DataType::Int32, true)); + let e = ListArray::new(field_e, e_offsets, Arc::new(e_values), e_validity); + let expected = RecordBatch::try_from_iter_with_nullable([ + ("a", Arc::new(a) as Arc, true), + ("b", Arc::new(b) as Arc, true), + ("c", Arc::new(c) as Arc, true), + ("d", Arc::new(d) as Arc, true), + ("e", Arc::new(e) as Arc, true), + ]) + .unwrap(); + assert_eq!(batch, expected); + } + + #[test] + fn test_dict_pages_offset_zero() { + let file = arrow_test_data("avro/dict-page-offset-zero.avro"); + let batch = read_file(&file, 32, false); + let num_rows = batch.num_rows(); + let expected_field = Int32Array::from(vec![Some(1552); num_rows]); + let expected = RecordBatch::try_from_iter_with_nullable([( + "l_partkey", + Arc::new(expected_field) as Arc, + true, + )]) + .unwrap(); + assert_eq!(batch, expected); + } + + #[test] + fn test_list_columns() { + let file = arrow_test_data("avro/list_columns.avro"); + let mut int64_list_builder = ListBuilder::new(Int64Builder::new()); + { + { + let values = int64_list_builder.values(); + values.append_value(1); + values.append_value(2); + values.append_value(3); + } + int64_list_builder.append(true); + } + { + { + let values = int64_list_builder.values(); + values.append_null(); + values.append_value(1); + } + int64_list_builder.append(true); + } + { + { + let values = int64_list_builder.values(); + values.append_value(4); + } + int64_list_builder.append(true); + } + let int64_list = int64_list_builder.finish(); + let mut utf8_list_builder = ListBuilder::new(StringBuilder::new()); + { + { + let values = utf8_list_builder.values(); + values.append_value("abc"); + values.append_value("efg"); + values.append_value("hij"); + } + utf8_list_builder.append(true); + } + { + utf8_list_builder.append(false); + } + { + { + let values = utf8_list_builder.values(); + values.append_value("efg"); + values.append_null(); + values.append_value("hij"); + values.append_value("xyz"); + } + utf8_list_builder.append(true); + } + let utf8_list = utf8_list_builder.finish(); + let expected = RecordBatch::try_from_iter_with_nullable([ + ("int64_list", Arc::new(int64_list) as Arc, true), + ("utf8_list", Arc::new(utf8_list) as Arc, true), + ]) + .unwrap(); + let batch = read_file(&file, 8, false); + assert_eq!(batch, expected); + } + + #[test] + fn test_nested_lists() { + let file = arrow_test_data("avro/nested_lists.snappy.avro"); + let inner_values = StringArray::from(vec![ + Some("a"), + Some("b"), + Some("c"), + Some("d"), + Some("a"), + Some("b"), + Some("c"), + Some("d"), + Some("e"), + Some("a"), + Some("b"), + Some("c"), + Some("d"), + Some("e"), + Some("f"), + ]); + let inner_offsets = Buffer::from_slice_ref([0, 2, 3, 3, 4, 6, 8, 8, 9, 11, 13, 14, 14, 15]); + let inner_validity = [ + true, true, false, true, true, true, false, true, true, true, true, false, true, + ]; + let inner_null_buffer = Buffer::from_iter(inner_validity.iter().copied()); + let inner_field = Field::new("item", DataType::Utf8, true); + let inner_list_data = ArrayDataBuilder::new(DataType::List(Arc::new(inner_field))) + .len(13) + .add_buffer(inner_offsets) + .add_child_data(inner_values.to_data()) + .null_bit_buffer(Some(inner_null_buffer)) 
+ .build() + .unwrap(); + let inner_list_array = ListArray::from(inner_list_data); + let middle_offsets = Buffer::from_slice_ref([0, 2, 4, 6, 8, 11, 13]); + let middle_validity = [true; 6]; + let middle_null_buffer = Buffer::from_iter(middle_validity.iter().copied()); + let middle_field = Field::new("item", inner_list_array.data_type().clone(), true); + let middle_list_data = ArrayDataBuilder::new(DataType::List(Arc::new(middle_field))) + .len(6) + .add_buffer(middle_offsets) + .add_child_data(inner_list_array.to_data()) + .null_bit_buffer(Some(middle_null_buffer)) + .build() + .unwrap(); + let middle_list_array = ListArray::from(middle_list_data); + let outer_offsets = Buffer::from_slice_ref([0, 2, 4, 6]); + let outer_null_buffer = Buffer::from_slice_ref([0b111]); // all 3 rows valid + let outer_field = Field::new("item", middle_list_array.data_type().clone(), true); + let outer_list_data = ArrayDataBuilder::new(DataType::List(Arc::new(outer_field))) + .len(3) + .add_buffer(outer_offsets) + .add_child_data(middle_list_array.to_data()) + .null_bit_buffer(Some(outer_null_buffer)) + .build() + .unwrap(); + let a_expected = ListArray::from(outer_list_data); + let b_expected = Int32Array::from(vec![1, 1, 1]); + let expected = RecordBatch::try_from_iter_with_nullable([ + ("a", Arc::new(a_expected) as Arc, true), + ("b", Arc::new(b_expected) as Arc, true), + ]) + .unwrap(); + let left = read_file(&file, 8, false); + assert_eq!(left, expected, "Mismatch for batch size=8"); + let left_small = read_file(&file, 3, false); + assert_eq!(left_small, expected, "Mismatch for batch size=3"); + } - assert_eq!(read_file(&file, 8), expected); - assert_eq!(read_file(&file, 3), expected); + #[test] + fn test_nested_records() { + let f1_f1_1 = StringArray::from(vec!["aaa", "bbb"]); + let f1_f1_2 = Int32Array::from(vec![10, 20]); + let rounded_pi = (std::f64::consts::PI * 100.0).round() / 100.0; + let f1_f1_3_1 = Float64Array::from(vec![rounded_pi, rounded_pi]); + let f1_f1_3 = StructArray::from(vec![( + Arc::new(Field::new("f1_3_1", DataType::Float64, false)), + Arc::new(f1_f1_3_1) as Arc, + )]); + let f1_expected = StructArray::from(vec![ + ( + Arc::new(Field::new("f1_1", DataType::Utf8, false)), + Arc::new(f1_f1_1) as Arc, + ), + ( + Arc::new(Field::new("f1_2", DataType::Int32, false)), + Arc::new(f1_f1_2) as Arc, + ), + ( + Arc::new(Field::new( + "f1_3", + DataType::Struct(Fields::from(vec![Field::new( + "f1_3_1", + DataType::Float64, + false, + )])), + false, + )), + Arc::new(f1_f1_3) as Arc, + ), + ]); + let f2_fields = vec![ + Field::new("f2_1", DataType::Boolean, false), + Field::new("f2_2", DataType::Float32, false), + ]; + let f2_struct_builder = StructBuilder::new( + f2_fields + .iter() + .map(|f| Arc::new(f.clone())) + .collect::>>(), + vec![ + Box::new(BooleanBuilder::new()) as Box, + Box::new(Float32Builder::new()) as Box, + ], + ); + let mut f2_list_builder = ListBuilder::new(f2_struct_builder); + { + let struct_builder = f2_list_builder.values(); + struct_builder.append(true); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_value(true); + } + { + let b = struct_builder.field_builder::(1).unwrap(); + b.append_value(1.2_f32); + } + struct_builder.append(true); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_value(true); + } + { + let b = struct_builder.field_builder::(1).unwrap(); + b.append_value(2.2_f32); + } + f2_list_builder.append(true); + } + { + let struct_builder = f2_list_builder.values(); + struct_builder.append(true); + { + let b = 
struct_builder.field_builder::(0).unwrap(); + b.append_value(false); + } + { + let b = struct_builder.field_builder::(1).unwrap(); + b.append_value(10.2_f32); + } + f2_list_builder.append(true); + } + let f2_expected = f2_list_builder.finish(); + let mut f3_struct_builder = StructBuilder::new( + vec![Arc::new(Field::new("f3_1", DataType::Utf8, false))], + vec![Box::new(StringBuilder::new()) as Box], + ); + f3_struct_builder.append(true); + { + let b = f3_struct_builder.field_builder::(0).unwrap(); + b.append_value("xyz"); + } + f3_struct_builder.append(false); + { + let b = f3_struct_builder.field_builder::(0).unwrap(); + b.append_null(); + } + let f3_expected = f3_struct_builder.finish(); + let f4_fields = [Field::new("f4_1", DataType::Int64, false)]; + let f4_struct_builder = StructBuilder::new( + f4_fields + .iter() + .map(|f| Arc::new(f.clone())) + .collect::>>(), + vec![Box::new(Int64Builder::new()) as Box], + ); + let mut f4_list_builder = ListBuilder::new(f4_struct_builder); + { + let struct_builder = f4_list_builder.values(); + struct_builder.append(true); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_value(200); + } + struct_builder.append(false); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_null(); + } + f4_list_builder.append(true); + } + { + let struct_builder = f4_list_builder.values(); + struct_builder.append(false); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_null(); + } + struct_builder.append(true); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_value(300); + } + f4_list_builder.append(true); } + let f4_expected = f4_list_builder.finish(); + let expected = RecordBatch::try_from_iter_with_nullable([ + ("f1", Arc::new(f1_expected) as Arc, false), + ("f2", Arc::new(f2_expected) as Arc, false), + ("f3", Arc::new(f3_expected) as Arc, true), + ("f4", Arc::new(f4_expected) as Arc, false), + ]) + .unwrap(); + let file = arrow_test_data("avro/nested_records.avro"); + let batch_large = read_file(&file, 8, false); + assert_eq!( + batch_large, expected, + "Decoded RecordBatch does not match expected data for nested records (batch size 8)" + ); + let batch_small = read_file(&file, 3, false); + assert_eq!( + batch_small, expected, + "Decoded RecordBatch does not match expected data for nested records (batch size 3)" + ); + } + + #[test] + fn test_nonnullable_impala() { + let file = arrow_test_data("avro/nonnullable.impala.avro"); + let id = Int64Array::from(vec![Some(8)]); + let mut int_array_builder = ListBuilder::new(Int32Builder::new()); + { + let vb = int_array_builder.values(); + vb.append_value(-1); + } + int_array_builder.append(true); // finalize one sub-list + let int_array = int_array_builder.finish(); + let mut iaa_builder = ListBuilder::new(ListBuilder::new(Int32Builder::new())); + { + let inner_list_builder = iaa_builder.values(); + { + let vb = inner_list_builder.values(); + vb.append_value(-1); + vb.append_value(-2); + } + inner_list_builder.append(true); + inner_list_builder.append(true); + } + iaa_builder.append(true); + let int_array_array = iaa_builder.finish(); + use arrow_array::builder::MapFieldNames; + let field_names = MapFieldNames { + entry: "entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }; + let mut int_map_builder = + MapBuilder::new(Some(field_names), StringBuilder::new(), Int32Builder::new()); + { + let (keys, vals) = int_map_builder.entries(); + keys.append_value("k1"); + vals.append_value(-1); + } + 
int_map_builder.append(true).unwrap(); // finalize map for row 0 + let int_map = int_map_builder.finish(); + let field_names2 = MapFieldNames { + entry: "entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }; + let mut ima_builder = ListBuilder::new(MapBuilder::new( + Some(field_names2), + StringBuilder::new(), + Int32Builder::new(), + )); + { + let map_builder = ima_builder.values(); + map_builder.append(true).unwrap(); + { + let (keys, vals) = map_builder.entries(); + keys.append_value("k1"); + vals.append_value(1); + } + map_builder.append(true).unwrap(); + map_builder.append(true).unwrap(); + map_builder.append(true).unwrap(); + } + ima_builder.append(true); + let int_map_array_ = ima_builder.finish(); + let mut nested_sb = StructBuilder::new( + vec![ + Arc::new(Field::new("a", DataType::Int32, true)), + Arc::new(Field::new( + "B", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + )), + Arc::new(Field::new( + "c", + DataType::Struct( + vec![Field::new( + "D", + DataType::List(Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new( + "item", + DataType::Struct( + vec![ + Field::new("e", DataType::Int32, true), + Field::new("f", DataType::Utf8, true), + ] + .into(), + ), + true, + ))), + true, + ))), + true, + )] + .into(), + ), + true, + )), + Arc::new(Field::new( + "G", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct( + vec![ + Field::new("key", DataType::Utf8, false), + Field::new( + "value", + DataType::Struct( + vec![Field::new( + "h", + DataType::Struct( + vec![Field::new( + "i", + DataType::List(Arc::new(Field::new( + "item", + DataType::Float64, + true, + ))), + true, + )] + .into(), + ), + true, + )] + .into(), + ), + true, + ), + ] + .into(), + ), + false, + )), + false, + ), + true, + )), + ], + vec![ + Box::new(Int32Builder::new()), + Box::new(ListBuilder::new(Int32Builder::new())), + { + let d_field = Field::new( + "D", + DataType::List(Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new( + "item", + DataType::Struct( + vec![ + Field::new("e", DataType::Int32, true), + Field::new("f", DataType::Utf8, true), + ] + .into(), + ), + true, + ))), + true, + ))), + true, + ); + Box::new(StructBuilder::new( + vec![Arc::new(d_field)], + vec![Box::new({ + let ef_struct_builder = StructBuilder::new( + vec![ + Arc::new(Field::new("e", DataType::Int32, true)), + Arc::new(Field::new("f", DataType::Utf8, true)), + ], + vec![ + Box::new(Int32Builder::new()), + Box::new(StringBuilder::new()), + ], + ); + let list_of_ef = ListBuilder::new(ef_struct_builder); + ListBuilder::new(list_of_ef) + })], + )) + }, + { + let map_field_names = MapFieldNames { + entry: "entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }; + let i_list_builder = ListBuilder::new(Float64Builder::new()); + let h_struct = StructBuilder::new( + vec![Arc::new(Field::new( + "i", + DataType::List(Arc::new(Field::new("item", DataType::Float64, true))), + true, + ))], + vec![Box::new(i_list_builder)], + ); + let g_value_builder = StructBuilder::new( + vec![Arc::new(Field::new( + "h", + DataType::Struct( + vec![Field::new( + "i", + DataType::List(Arc::new(Field::new( + "item", + DataType::Float64, + true, + ))), + true, + )] + .into(), + ), + true, + ))], + vec![Box::new(h_struct)], + ); + Box::new(MapBuilder::new( + Some(map_field_names), + StringBuilder::new(), + g_value_builder, + )) + }, + ], + ); + nested_sb.append(true); + { + let a_builder = nested_sb.field_builder::(0).unwrap(); + 
a_builder.append_value(-1); + } + { + let b_builder = nested_sb + .field_builder::>(1) + .unwrap(); + { + let vb = b_builder.values(); + vb.append_value(-1); + } + b_builder.append(true); + } + { + let c_struct_builder = nested_sb.field_builder::(2).unwrap(); + c_struct_builder.append(true); + let d_list_builder = c_struct_builder + .field_builder::>>(0) + .unwrap(); + { + let sub_list_builder = d_list_builder.values(); + { + let ef_struct = sub_list_builder.values(); + ef_struct.append(true); + { + let e_b = ef_struct.field_builder::(0).unwrap(); + e_b.append_value(-1); + let f_b = ef_struct.field_builder::(1).unwrap(); + f_b.append_value("nonnullable"); + } + sub_list_builder.append(true); + } + d_list_builder.append(true); + } + } + { + let g_map_builder = nested_sb + .field_builder::>(3) + .unwrap(); + g_map_builder.append(true).unwrap(); + } + let nested_struct = nested_sb.finish(); + let expected = RecordBatch::try_from_iter_with_nullable([ + ("ID", Arc::new(id) as Arc, true), + ("Int_Array", Arc::new(int_array), true), + ("int_array_array", Arc::new(int_array_array), true), + ("Int_Map", Arc::new(int_map), true), + ("int_map_array", Arc::new(int_map_array_), true), + ("nested_Struct", Arc::new(nested_struct), true), + ]) + .unwrap(); + let batch_large = read_file(&file, 8, false); + assert_eq!(batch_large, expected, "Mismatch for batch_size=8"); + let batch_small = read_file(&file, 3, false); + assert_eq!(batch_small, expected, "Mismatch for batch_size=3"); + } + + #[test] + fn test_nullable_impala() { + let file = arrow_test_data("avro/nullable.impala.avro"); + let batch1 = read_file(&file, 3, false); + let batch2 = read_file(&file, 8, false); + assert_eq!(batch1, batch2); + let batch = batch1; + assert_eq!(batch.num_rows(), 7); + let id_array = batch + .column(0) + .as_any() + .downcast_ref::() + .expect("id column should be an Int64Array"); + let expected_ids = [1, 2, 3, 4, 5, 6, 7]; + for (i, &expected_id) in expected_ids.iter().enumerate() { + assert_eq!( + id_array.value(i), + expected_id, + "Mismatch in id at row {}", + i + ); + } + let int_array = batch + .column(1) + .as_any() + .downcast_ref::() + .expect("int_array column should be a ListArray"); + { + let offsets = int_array.value_offsets(); + let start = offsets[0] as usize; + let end = offsets[1] as usize; + let values = int_array + .values() + .as_any() + .downcast_ref::() + .expect("Values of int_array should be an Int32Array"); + let row0: Vec> = (start..end).map(|i| Some(values.value(i))).collect(); + assert_eq!( + row0, + vec![Some(1), Some(2), Some(3)], + "Mismatch in int_array row 0" + ); + } + let nested_struct = batch + .column(5) + .as_any() + .downcast_ref::() + .expect("nested_struct column should be a StructArray"); + let a_array = nested_struct + .column_by_name("A") + .expect("Field A should exist in nested_struct") + .as_any() + .downcast_ref::() + .expect("Field A should be an Int32Array"); + assert_eq!(a_array.value(0), 1, "Mismatch in nested_struct.A at row 0"); + assert!( + !a_array.is_valid(1), + "Expected null in nested_struct.A at row 1" + ); + assert!( + !a_array.is_valid(3), + "Expected null in nested_struct.A at row 3" + ); + assert_eq!(a_array.value(6), 7, "Mismatch in nested_struct.A at row 6"); + } + + #[test] + fn test_nulls_snappy() { + let file = arrow_test_data("avro/nulls.snappy.avro"); + let batch_large = read_file(&file, 8, false); + use arrow_array::{Int32Array, StructArray}; + use arrow_buffer::Buffer; + use arrow_data::ArrayDataBuilder; + use arrow_schema::{DataType, Field, 
Fields}; + let b_c_int = Int32Array::from(vec![None; 8]); + let b_c_int_data = b_c_int.into_data(); + let b_struct_field = Field::new("b_c_int", DataType::Int32, true); + let b_struct_type = DataType::Struct(Fields::from(vec![b_struct_field])); + let struct_validity = Buffer::from_iter((0..8).map(|_| true)); + let b_struct_data = ArrayDataBuilder::new(b_struct_type) + .len(8) + .null_bit_buffer(Some(struct_validity)) + .child_data(vec![b_c_int_data]) + .build() + .unwrap(); + let b_struct_array = StructArray::from(b_struct_data); + let expected = arrow_array::RecordBatch::try_from_iter_with_nullable([( + "b_struct", + Arc::new(b_struct_array) as _, + true, + )]) + .unwrap(); + assert_eq!(batch_large, expected, "Mismatch for batch_size=8"); + let batch_small = read_file(&file, 3, false); + assert_eq!(batch_small, expected, "Mismatch for batch_size=3"); + } + + #[test] + fn test_repeated_no_annotation() { + let file = arrow_test_data("avro/repeated_no_annotation.avro"); + let batch_large = read_file(&file, 8, false); + use arrow_array::{Int32Array, Int64Array, ListArray, StringArray, StructArray}; + use arrow_buffer::Buffer; + use arrow_data::ArrayDataBuilder; + use arrow_schema::{DataType, Field, Fields}; + let id_array = Int32Array::from(vec![1, 2, 3, 4, 5, 6]); + let number_array = Int64Array::from(vec![ + Some(5555555555), + Some(1111111111), + Some(1111111111), + Some(2222222222), + Some(3333333333), + ]); + let kind_array = + StringArray::from(vec![None, Some("home"), Some("home"), None, Some("mobile")]); + let phone_fields = Fields::from(vec![ + Field::new("number", DataType::Int64, true), + Field::new("kind", DataType::Utf8, true), + ]); + let phone_struct_data = ArrayDataBuilder::new(DataType::Struct(phone_fields)) + .len(5) + .child_data(vec![number_array.into_data(), kind_array.into_data()]) + .build() + .unwrap(); + let phone_struct_array = StructArray::from(phone_struct_data); + let phone_list_offsets = Buffer::from_slice_ref([0, 0, 0, 0, 1, 2, 5]); + let phone_list_validity = Buffer::from_iter([false, false, true, true, true, true]); + let phone_item_field = Field::new("item", phone_struct_array.data_type().clone(), true); + let phone_list_data = ArrayDataBuilder::new(DataType::List(Arc::new(phone_item_field))) + .len(6) + .add_buffer(phone_list_offsets) + .null_bit_buffer(Some(phone_list_validity)) + .child_data(vec![phone_struct_array.into_data()]) + .build() + .unwrap(); + let phone_list_array = ListArray::from(phone_list_data); + let phone_numbers_validity = Buffer::from_iter([false, false, true, true, true, true]); + let phone_numbers_field = Field::new("phone", phone_list_array.data_type().clone(), true); + let phone_numbers_struct_data = + ArrayDataBuilder::new(DataType::Struct(Fields::from(vec![phone_numbers_field]))) + .len(6) + .null_bit_buffer(Some(phone_numbers_validity)) + .child_data(vec![phone_list_array.into_data()]) + .build() + .unwrap(); + let phone_numbers_struct_array = StructArray::from(phone_numbers_struct_data); + let expected = arrow_array::RecordBatch::try_from_iter_with_nullable([ + ("id", Arc::new(id_array) as _, true), + ( + "phoneNumbers", + Arc::new(phone_numbers_struct_array) as _, + true, + ), + ]) + .unwrap(); + assert_eq!(batch_large, expected, "Mismatch for batch_size=8"); + let batch_small = read_file(&file, 3, false); + assert_eq!(batch_small, expected, "Mismatch for batch_size=3"); + } + + #[test] + fn test_simple() { + // Each entry: (filename, batch_size1, expected_batch, batch_size2) + let tests = [ + ("avro/simple_enum.avro", 4, 
build_expected_enum(), 2), + ("avro/simple_fixed.avro", 2, build_expected_fixed(), 1), + ]; + + fn build_expected_enum() -> RecordBatch { + let keys_f1 = Int32Array::from(vec![0, 1, 2, 3]); + let vals_f1 = StringArray::from(vec!["a", "b", "c", "d"]); + let f1_dict = + DictionaryArray::::try_new(keys_f1, Arc::new(vals_f1)).unwrap(); + let keys_f2 = Int32Array::from(vec![2, 3, 0, 1]); + let vals_f2 = StringArray::from(vec!["e", "f", "g", "h"]); + let f2_dict = + DictionaryArray::::try_new(keys_f2, Arc::new(vals_f2)).unwrap(); + let keys_f3 = Int32Array::from(vec![Some(1), Some(2), None, Some(0)]); + let vals_f3 = StringArray::from(vec!["i", "j", "k"]); + let f3_dict = + DictionaryArray::::try_new(keys_f3, Arc::new(vals_f3)).unwrap(); + let dict_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let expected_schema = Arc::new(Schema::new(vec![ + Field::new("f1", dict_type.clone(), false), + Field::new("f2", dict_type.clone(), false), + Field::new("f3", dict_type.clone(), true), + ])); + RecordBatch::try_new( + expected_schema, + vec![ + Arc::new(f1_dict) as Arc, + Arc::new(f2_dict) as Arc, + Arc::new(f3_dict) as Arc, + ], + ) + .unwrap() + } + + fn build_expected_fixed() -> RecordBatch { + let f1 = + FixedSizeBinaryArray::try_from_iter(vec![b"abcde", b"12345"].into_iter()).unwrap(); + let f2 = + FixedSizeBinaryArray::try_from_iter(vec![b"fghijklmno", b"1234567890"].into_iter()) + .unwrap(); + let f3 = FixedSizeBinaryArray::try_from_sparse_iter_with_size( + vec![Some(b"ABCDEF" as &[u8]), None].into_iter(), + 6, + ) + .unwrap(); + let expected_schema = Arc::new(Schema::new(vec![ + Field::new("f1", DataType::FixedSizeBinary(5), false), + Field::new("f2", DataType::FixedSizeBinary(10), false), + Field::new("f3", DataType::FixedSizeBinary(6), true), + ])); + RecordBatch::try_new( + expected_schema, + vec![ + Arc::new(f1) as Arc, + Arc::new(f2) as Arc, + Arc::new(f3) as Arc, + ], + ) + .unwrap() + } + for (file_name, batch_size, expected, alt_batch_size) in tests { + let file = arrow_test_data(file_name); + let actual = read_file(&file, batch_size, false); + assert_eq!(actual, expected); + let actual2 = read_file(&file, alt_batch_size, false); + assert_eq!(actual2, expected); + } + } + + #[test] + fn test_single_nan() { + let file = crate::test_util::arrow_test_data("avro/single_nan.avro"); + let actual = read_file(&file, 1, false); + use arrow_array::Float64Array; + let schema = Arc::new(Schema::new(vec![Field::new( + "mycol", + DataType::Float64, + true, + )])); + let col = Float64Array::from(vec![None]); + let expected = RecordBatch::try_new(schema, vec![Arc::new(col)]).unwrap(); + assert_eq!(actual, expected); + let actual2 = read_file(&file, 2, false); + assert_eq!(actual2, expected); } } diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 52a58cf63303..3f56997f5733 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -16,44 +16,53 @@ // under the License. 
 use crate::codec::{AvroDataType, Codec, Nullability};
-use crate::reader::block::{Block, BlockDecoder};
 use crate::reader::cursor::AvroCursor;
-use crate::reader::header::Header;
-use crate::schema::*;
+use arrow_array::builder::{Decimal128Builder, Decimal256Builder, PrimitiveBuilder};
 use arrow_array::types::*;
 use arrow_array::*;
 use arrow_buffer::*;
 use arrow_schema::{
-    ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef,
+    ArrowError, DataType, Field as ArrowField, FieldRef, Fields, IntervalUnit,
+    Schema as ArrowSchema, SchemaRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION,
 };
-use std::collections::HashMap;
-use std::io::Read;
+use std::cmp::Ordering;
 use std::sync::Arc;
 
-/// Decodes avro encoded data into [`RecordBatch`]
+const DEFAULT_CAPACITY: usize = 1024;
+
+/// A decoder that converts Avro-encoded data into an Arrow [`RecordBatch`].
+#[derive(Debug)]
 pub struct RecordDecoder {
     schema: SchemaRef,
     fields: Vec<Decoder>,
 }
 
 impl RecordDecoder {
-    pub fn try_new(data_type: &AvroDataType) -> Result<Self, ArrowError> {
-        match Decoder::try_new(data_type)? {
-            Decoder::Record(fields, encodings) => Ok(Self {
+    /// Create a new [`RecordDecoder`] from an [`AvroDataType`] that must be a `Record`.
+    ///
+    /// - `strict_mode`: if `true`, an error is returned on encountering a union
+    ///   of the form `[T, "null"]` (i.e. `Nullability::NullSecond`).
+    pub fn try_new(data_type: &AvroDataType, strict_mode: bool) -> Result<Self, ArrowError> {
+        match Decoder::try_new(data_type, strict_mode)? {
+            Decoder::Record(fields, decoders) => Ok(Self {
                 schema: Arc::new(ArrowSchema::new(fields)),
-                fields: encodings,
+                fields: decoders,
             }),
-            encoding => Err(ArrowError::ParseError(format!(
-                "Expected record got {encoding:?}"
+            other => Err(ArrowError::ParseError(format!(
+                "Expected record got {other:?}"
             ))),
         }
     }
 
+    /// Return the [`SchemaRef`] describing the Arrow schema of rows produced by this decoder.
     pub fn schema(&self) -> &SchemaRef {
         &self.schema
     }
 
-    /// Decode `count` records from `buf`
+    /// Decode `count` Avro records from `buf`.
+    ///
+    /// This accumulates data in internal buffers. Once done reading, call
+    /// [`Self::flush`] to yield an Arrow [`RecordBatch`].
     pub fn decode(&mut self, buf: &[u8], count: usize) -> Result<usize, ArrowError> {
         let mut cursor = AvroCursor::new(buf);
         for _ in 0..count {
@@ -64,43 +73,66 @@ impl RecordDecoder {
         Ok(cursor.position())
     }
 
-    /// Flush the decoded records into a [`RecordBatch`]
+    /// Flush the accumulated rows into a [`RecordBatch`].
+    ///
+    /// We collect arrays from each `Decoder` and build a new [`RecordBatch`].
     pub fn flush(&mut self) -> Result<RecordBatch, ArrowError> {
         let arrays = self
             .fields
             .iter_mut()
-            .map(|x| x.flush(None))
+            .map(|d| d.flush(None))
             .collect::<Result<Vec<_>, _>>()?;
-
         RecordBatch::try_new(self.schema.clone(), arrays)
     }
 }
 
+/// For 2-branch unions we store either `[null, T]` or `[T, null]`.
+///
+/// - `NullFirst`: `[null, T]` => branch=0 => null, branch=1 => decode T
+/// - `NullSecond`: `[T, null]` => branch=0 => decode T, branch=1 => null
+#[derive(Debug, Copy, Clone)]
+enum UnionOrder {
+    NullFirst,
+    NullSecond,
+}
+
 #[derive(Debug)]
 enum Decoder {
+    /// Primitive Types
     Null(usize),
     Boolean(BooleanBufferBuilder),
     Int32(Vec<i32>),
     Int64(Vec<i64>),
     Float32(Vec<f32>),
     Float64(Vec<f64>),
+    Binary(OffsetBufferBuilder<i32>, Vec<u8>),
+    String(OffsetBufferBuilder<i32>, Vec<u8>),
+    /// Complex Types
+    Record(Fields, Vec<Decoder>),
+    Enum(Arc<[String]>, Vec<i32>),
+    List(FieldRef, OffsetBufferBuilder<i32>, Box<Decoder>),
+    Map(
+        FieldRef,
+        OffsetBufferBuilder<i32>,
+        OffsetBufferBuilder<i32>,
+        Vec<u8>,
+        Box<Decoder>,
+    ),
+    Nullable(UnionOrder, NullBufferBuilder, Box<Decoder>),
+    Fixed(i32, Vec<u8>),
+    /// Logical Types
+    Decimal(usize, Option<usize>, Option<usize>, DecimalBuilder),
     Date32(Vec<i32>),
     TimeMillis(Vec<i32>),
     TimeMicros(Vec<i64>),
     TimestampMillis(bool, Vec<i64>),
     TimestampMicros(bool, Vec<i64>),
-    Binary(OffsetBufferBuilder<i32>, Vec<u8>),
-    String(OffsetBufferBuilder<i32>, Vec<u8>),
-    List(FieldRef, OffsetBufferBuilder<i32>, Box<Decoder>),
-    Record(Fields, Vec<Decoder>),
-    Nullable(Nullability, NullBufferBuilder, Box<Decoder>),
+    Interval(Vec<IntervalMonthDayNano>),
 }
 
 impl Decoder {
-    fn try_new(data_type: &AvroDataType) -> Result<Self, ArrowError> {
-        let nyi = |s: &str| Err(ArrowError::NotYetImplemented(s.to_string()));
-
-        let decoder = match data_type.codec() {
+    fn try_new(data_type: &AvroDataType, strict_mode: bool) -> Result<Self, ArrowError> {
+        let base = match &data_type.codec {
             Codec::Null => Self::Null(0),
             Codec::Boolean => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)),
             Codec::Int32 => Self::Int32(Vec::with_capacity(DEFAULT_CAPACITY)),
@@ -111,182 +143,631 @@ impl Decoder {
                 OffsetBufferBuilder::new(DEFAULT_CAPACITY),
                 Vec::with_capacity(DEFAULT_CAPACITY),
             ),
-            Codec::Utf8 => Self::String(
+            Codec::String => Self::String(
                 OffsetBufferBuilder::new(DEFAULT_CAPACITY),
                 Vec::with_capacity(DEFAULT_CAPACITY),
             ),
-            Codec::Date32 => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)),
-            Codec::TimeMillis => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)),
-            Codec::TimeMicros => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)),
-            Codec::TimestampMillis(is_utc) => {
-                Self::TimestampMillis(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY))
+            Codec::Record(avro_fields) => {
+                let mut fields = Vec::with_capacity(avro_fields.len());
+                let mut children = Vec::with_capacity(avro_fields.len());
+                for f in avro_fields.iter() {
+                    // Recursively build a Decoder for each child
+                    let child = Self::try_new(f.data_type(), strict_mode)?;
+                    fields.push(f.field());
+                    children.push(child);
+                }
+                Self::Record(fields.into(), children)
             }
-            Codec::TimestampMicros(is_utc) => {
-                Self::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY))
+            Codec::Enum(syms, _) => {
+                Self::Enum(Arc::clone(syms), Vec::with_capacity(DEFAULT_CAPACITY))
             }
-            Codec::Fixed(_) => return nyi("decoding fixed"),
-            Codec::Interval => return nyi("decoding interval"),
-            Codec::List(item) => {
-                let decoder = Self::try_new(item)?;
+            Codec::Array(child) => {
+                let child_dec = Self::try_new(child, strict_mode)?;
+                let item_field = child.field_with_name("item").with_nullable(true);
                 Self::List(
-                    Arc::new(item.field_with_name("item")),
+                    Arc::new(item_field),
+                    OffsetBufferBuilder::new(DEFAULT_CAPACITY),
+                    Box::new(child_dec),
+                )
+            }
+            Codec::Map(child) => {
+                let val_field = child.field_with_name("value").with_nullable(true);
+                let map_field = Arc::new(ArrowField::new(
+                    "entries",
+                    DataType::Struct(Fields::from(vec![
+                        ArrowField::new("key", DataType::Utf8, false),
+                        val_field,
+                    ])),
+                    false,
+                ));
+                let valdec = Self::try_new(child, strict_mode)?;
+                Self::Map(
+                    map_field,
+                    OffsetBufferBuilder::new(DEFAULT_CAPACITY),
                     OffsetBufferBuilder::new(DEFAULT_CAPACITY),
-                    Box::new(decoder),
+                    Vec::with_capacity(DEFAULT_CAPACITY),
+                    Box::new(valdec),
                 )
             }
-            Codec::Struct(fields) => {
-                let mut arrow_fields = Vec::with_capacity(fields.len());
-                let mut encodings = Vec::with_capacity(fields.len());
-                for avro_field in fields.iter() {
-                    let encoding = Self::try_new(avro_field.data_type())?;
-                    arrow_fields.push(avro_field.field());
-                    encodings.push(encoding);
+            Codec::Fixed(sz) => Self::Fixed(*sz, Vec::with_capacity(DEFAULT_CAPACITY)),
+            Codec::Decimal(p, s, size) => {
+                let b = DecimalBuilder::new(*p, *s, *size)?;
+                Self::Decimal(*p, *s, *size, b)
+            }
+            Codec::Uuid => Self::Fixed(16, Vec::with_capacity(DEFAULT_CAPACITY)),
+            Codec::Date32 => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)),
+            Codec::TimeMillis => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)),
+            Codec::TimeMicros => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)),
+            Codec::TimestampMillis(utc) => {
+                Self::TimestampMillis(*utc, Vec::with_capacity(DEFAULT_CAPACITY))
+            }
+            Codec::TimestampMicros(utc) => {
+                Self::TimestampMicros(*utc, Vec::with_capacity(DEFAULT_CAPACITY))
+            }
+            Codec::Duration => Self::Interval(Vec::with_capacity(DEFAULT_CAPACITY)),
+        };
+        let union_order = match data_type.nullability {
+            None => None,
+            Some(Nullability::NullFirst) => Some(UnionOrder::NullFirst),
+            Some(Nullability::NullSecond) => {
+                if strict_mode {
+                    return Err(ArrowError::ParseError(
+                        "Found Avro union of the form ['T','null'], which is disallowed in strict_mode"
+                            .to_string(),
+                    ));
                 }
-                Self::Record(arrow_fields.into(), encodings)
+                Some(UnionOrder::NullSecond)
             }
         };
-
-        Ok(match data_type.nullability() {
-            Some(nullability) => Self::Nullable(
-                nullability,
+        let decoder = match union_order {
+            Some(order) => Decoder::Nullable(
+                order,
                 NullBufferBuilder::new(DEFAULT_CAPACITY),
-                Box::new(decoder),
+                Box::new(base),
             ),
-            None => decoder,
-        })
+            None => base,
+        };
+        Ok(decoder)
     }
 
-    /// Append a null record
     fn append_null(&mut self) {
         match self {
-            Self::Null(count) => *count += 1,
+            Self::Null(n) => *n += 1,
             Self::Boolean(b) => b.append(false),
             Self::Int32(v) | Self::Date32(v) | Self::TimeMillis(v) => v.push(0),
             Self::Int64(v)
             | Self::TimeMicros(v)
             | Self::TimestampMillis(_, v)
             | Self::TimestampMicros(_, v) => v.push(0),
-            Self::Float32(v) => v.push(0.),
-            Self::Float64(v) => v.push(0.),
-            Self::Binary(offsets, _) | Self::String(offsets, _) => offsets.push_length(0),
-            Self::List(_, offsets, e) => {
-                offsets.push_length(0);
-                e.append_null();
-            }
-            Self::Record(_, e) => e.iter_mut().for_each(|e| e.append_null()),
-            Self::Nullable(_, _, _) => unreachable!("Nulls cannot be nested"),
+            Self::Float32(v) => v.push(0.0),
+            Self::Float64(v) => v.push(0.0),
+            Self::Binary(off, _) | Self::String(off, _) => off.push_length(0),
+            Self::Record(_, children) => {
+                for c in children {
+                    c.append_null();
+                }
+            }
+            Self::Enum(_, idxs) => idxs.push(0),
+            Self::List(_, off, _) => {
+                off.push_length(0);
+            }
+            Self::Map(_, _koff, moff, _kdata, _valdec) => {
+                moff.push_length(0);
+            }
+            Self::Nullable(_, nb, child) => {
+                nb.append(false);
+                child.append_null();
+            }
+            Self::Fixed(sz, accum) => {
+                accum.extend(std::iter::repeat(0u8).take(*sz as usize));
+            }
+            Self::Decimal(_, _, _, db) => {
+                let _ = db.append_null();
+            }
+            Self::Interval(ivals) => {
+                ivals.push(IntervalMonthDayNano {
+                    months: 0,
+                    days: 0,
+                    nanoseconds: 0,
+                });
+            }
         }
     }
 
-    /// Decode a single record from `buf`
-    fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> {
+    fn decode(&mut self, buf: &mut AvroCursor) -> Result<(), ArrowError> {
         match self {
-            Self::Null(x) => *x += 1,
-            Self::Boolean(values) => values.append(buf.get_bool()?),
-            Self::Int32(values) | Self::Date32(values) | Self::TimeMillis(values) => {
-                values.push(buf.get_int()?)
-            }
-            Self::Int64(values)
-            | Self::TimeMicros(values)
-            | Self::TimestampMillis(_, values)
-            | Self::TimestampMicros(_, values) => values.push(buf.get_long()?),
-            Self::Float32(values) => values.push(buf.get_float()?),
-            Self::Float64(values) => values.push(buf.get_double()?),
-            Self::Binary(offsets, values) | Self::String(offsets, values) => {
-                let data = buf.get_bytes()?;
-                offsets.push_length(data.len());
-                values.extend_from_slice(data);
-            }
-            Self::List(_, _, _) => {
-                return Err(ArrowError::NotYetImplemented(
-                    "Decoding ListArray".to_string(),
-                ))
-            }
-            Self::Record(_, encodings) => {
-                for encoding in encodings {
-                    encoding.decode(buf)?;
-                }
-            }
-            Self::Nullable(nullability, nulls, e) => {
-                let is_valid = buf.get_bool()? == matches!(nullability, Nullability::NullFirst);
-                nulls.append(is_valid);
-                match is_valid {
-                    true => e.decode(buf)?,
-                    false => e.append_null(),
-                }
-            }
+            Self::Null(n) => {
+                *n += 1;
+            }
+            Self::Boolean(b) => {
+                b.append(buf.get_bool()?);
+            }
+            Self::Int32(v) => {
+                v.push(buf.get_int()?);
+            }
+            Self::Int64(v) => {
+                v.push(buf.get_long()?);
+            }
+            Self::Float32(vals) => {
+                vals.push(buf.get_float()?);
+            }
+            Self::Float64(vals) => {
+                vals.push(buf.get_double()?);
+            }
+            Self::Binary(off, data) | Self::String(off, data) => {
+                let bytes = buf.get_bytes()?;
+                off.push_length(bytes.len());
+                data.extend_from_slice(bytes);
+            }
+            Self::Record(_, children) => {
+                for c in children {
+                    c.decode(buf)?;
+                }
+            }
+            Self::Enum(_, idxs) => {
+                idxs.push(buf.get_int()?);
+            }
+            Self::List(_, off, child) => {
+                let total_items = read_array_blocks(buf, |cursor| child.decode(cursor))?;
+                off.push_length(total_items);
+            }
+            Self::Map(_, koff, moff, kdata, valdec) => {
+                let newly_added = read_map_blocks(buf, |cur| {
+                    let kb = cur.get_bytes()?;
+                    koff.push_length(kb.len());
+                    kdata.extend_from_slice(kb);
+                    valdec.decode(cur)
+                })?;
+                moff.push_length(newly_added);
+            }
+            Self::Nullable(order, nb, child) => {
+                let branch = buf.get_int()?;
+                match order {
+                    UnionOrder::NullFirst => {
+                        if branch == 0 {
+                            nb.append(false);
+                            child.append_null();
+                        } else {
+                            nb.append(true);
+                            child.decode(buf)?;
+                        }
+                    }
+                    UnionOrder::NullSecond => {
+                        if branch == 0 {
+                            nb.append(true);
+                            child.decode(buf)?;
+                        } else {
+                            nb.append(false);
+                            child.append_null();
+                        }
+                    }
+                }
+            }
+            Self::Fixed(sz, accum) => {
+                let fx = buf.get_fixed(*sz as usize)?;
+                accum.extend_from_slice(fx);
+            }
+            Self::Decimal(_, _, fsz, db) => {
+                let raw = match *fsz {
+                    Some(n) => buf.get_fixed(n)?,
+                    None => buf.get_bytes()?,
+                };
+                db.append_bytes(raw)?;
+            }
+            Self::Date32(vals) => vals.push(buf.get_int()?),
+            Self::TimeMillis(vals) => vals.push(buf.get_int()?),
+            Self::TimeMicros(vals) => vals.push(buf.get_long()?),
+            Self::TimestampMillis(_, vals) => vals.push(buf.get_long()?),
+            Self::TimestampMicros(_, vals) => vals.push(buf.get_long()?),
+            Self::Interval(ivals) => {
+                let x = buf.get_fixed(12)?;
+                let months = i32::from_le_bytes(x[0..4].try_into().unwrap());
+                let days = i32::from_le_bytes(x[4..8].try_into().unwrap());
+                let ms = i32::from_le_bytes(x[8..12].try_into().unwrap());
+                let nanos = ms as i64 * 1_000_000;
+                ivals.push(IntervalMonthDayNano {
+                    months,
+                    days,
+                    nanoseconds: nanos,
+                });
+            }
         }
         Ok(())
     }
 
-    /// Flush decoded records to an [`ArrayRef`]
-    fn flush(&mut self, nulls: Option<NullBuffer>) -> Result<ArrayRef, ArrowError> {
-        Ok(match self {
-            Self::Nullable(_, n, e) => e.flush(n.finish())?,
-            Self::Null(size) => Arc::new(NullArray::new(std::mem::replace(size, 0))),
-            Self::Boolean(b) => Arc::new(BooleanArray::new(b.finish(), nulls)),
-            Self::Int32(values) => Arc::new(flush_primitive::<Int32Type>(values, nulls)),
-            Self::Date32(values) => Arc::new(flush_primitive::<Date32Type>(values, nulls)),
-            Self::Int64(values) => Arc::new(flush_primitive::<Int64Type>(values, nulls)),
-            Self::TimeMillis(values) => {
-                Arc::new(flush_primitive::<Time32MillisecondType>(values, nulls))
-            }
-            Self::TimeMicros(values) => {
-                Arc::new(flush_primitive::<Time64MicrosecondType>(values, nulls))
-            }
-            Self::TimestampMillis(is_utc, values) => Arc::new(
-                flush_primitive::<TimestampMillisecondType>(values, nulls)
-                    .with_timezone_opt(is_utc.then(|| "+00:00")),
-            ),
-            Self::TimestampMicros(is_utc, values) => Arc::new(
-                flush_primitive::<TimestampMicrosecondType>(values, nulls)
-                    .with_timezone_opt(is_utc.then(|| "+00:00")),
-            ),
-            Self::Float32(values) => Arc::new(flush_primitive::<Float32Type>(values, nulls)),
-            Self::Float64(values) => Arc::new(flush_primitive::<Float64Type>(values, nulls)),
-
-            Self::Binary(offsets, values) => {
-                let offsets = flush_offsets(offsets);
-                let values = flush_values(values).into();
-                Arc::new(BinaryArray::new(offsets, values, nulls))
-            }
-            Self::String(offsets, values) => {
-                let offsets = flush_offsets(offsets);
-                let values = flush_values(values).into();
-                Arc::new(StringArray::new(offsets, values, nulls))
-            }
-            Self::List(field, offsets, values) => {
-                let values = values.flush(None)?;
-                let offsets = flush_offsets(offsets);
-                Arc::new(ListArray::new(field.clone(), offsets, values, nulls))
-            }
-            Self::Record(fields, encodings) => {
-                let arrays = encodings
-                    .iter_mut()
-                    .map(|x| x.flush(None))
-                    .collect::<Result<Vec<_>, _>>()?;
-                Arc::new(StructArray::new(fields.clone(), arrays, nulls))
-            }
-        })
+    fn flush(&mut self, nulls: Option<NullBuffer>) -> Result<Arc<dyn Array>, ArrowError> {
+        match self {
+            Self::Null(count) => {
+                let c = std::mem::replace(count, 0);
+                Ok(Arc::new(NullArray::new(c)) as Arc<dyn Array>)
+            }
+            Self::Boolean(b) => {
+                let bits = b.finish();
+                Ok(Arc::new(BooleanArray::new(bits, nulls)) as Arc<dyn Array>)
+            }
+            Self::Int32(v) => {
+                let arr = flush_primitive::<Int32Type>(v, nulls);
+                Ok(Arc::new(arr) as Arc<dyn Array>)
+            }
+            Self::Date32(v) => {
+                let arr = flush_primitive::<Date32Type>(v, nulls);
+                Ok(Arc::new(arr) as Arc<dyn Array>)
+            }
+            Self::Int64(v) => {
+                let arr = flush_primitive::<Int64Type>(v, nulls);
+                Ok(Arc::new(arr) as Arc<dyn Array>)
+            }
+            Self::Float32(v) => {
+                let arr = flush_primitive::<Float32Type>(v, nulls);
+                Ok(Arc::new(arr) as Arc<dyn Array>)
+            }
+            Self::Float64(v) => {
+                let arr = flush_primitive::<Float64Type>(v, nulls);
+                Ok(Arc::new(arr) as Arc<dyn Array>)
+            }
+            Self::Binary(off, data) => {
+                let offsets = flush_offsets(off);
+                let vals = flush_values(data).into();
+                let arr = BinaryArray::new(offsets, vals, nulls);
+                Ok(Arc::new(arr) as Arc<dyn Array>)
+            }
+            Self::String(off, data) => {
+                let offsets = flush_offsets(off);
+                let vals = flush_values(data).into();
+                let arr = StringArray::new(offsets, vals, nulls);
+                Ok(Arc::new(arr) as Arc<dyn Array>)
+            }
+            Self::Record(fields, children) => {
+                let mut child_arrays = Vec::with_capacity(children.len());
+                for c in children {
+                    child_arrays.push(c.flush(None)?);
+                }
+                let first_len = match child_arrays.first() {
+                    Some(a) => a.len(),
+                    None => 0,
+                };
+                for (i, arr) in child_arrays.iter().enumerate() {
+                    if arr.len() != first_len {
+                        return Err(ArrowError::InvalidArgumentError(format!(
+                            "Inconsistent struct child length for field #{i}. Expected {first_len}, got {}",
+                            arr.len()
+                        )));
+                    }
+                }
+                if let Some(n) = &nulls {
+                    if n.len() != first_len {
+                        return Err(ArrowError::InvalidArgumentError(format!(
+                            "Struct null buffer length {} != struct fields length {first_len}",
+                            n.len()
+                        )));
+                    }
+                }
+                let sarr = StructArray::new(fields.clone(), child_arrays, nulls);
+                Ok(Arc::new(sarr) as Arc<dyn Array>)
+            }
+            Self::Enum(symbols, idxs) => {
+                let dict_vals = StringArray::from_iter_values(symbols.iter());
+                let i32arr = match nulls {
+                    Some(nb) => {
+                        let buff = Buffer::from_slice_ref(&*idxs);
+                        PrimitiveArray::<Int32Type>::try_new(
+                            arrow_buffer::ScalarBuffer::from(buff),
+                            Some(nb),
+                        )?
+                    }
+                    None => Int32Array::from_iter_values(idxs.iter().cloned()),
+                };
+                idxs.clear();
+                let d = DictionaryArray::<Int32Type>::try_new(i32arr, Arc::new(dict_vals))?;
+                Ok(Arc::new(d) as Arc<dyn Array>)
+            }
+            Self::List(item_field, off, child) => {
+                let c = child.flush(None)?;
+                let offsets = flush_offsets(off);
+                let final_len = offsets.len() - 1;
+                if let Some(n) = &nulls {
+                    if n.len() != final_len {
+                        return Err(ArrowError::InvalidArgumentError(format!(
+                            "List array null buffer length {} != final list length {final_len}",
+                            n.len()
+                        )));
+                    }
+                }
+                let larr = ListArray::new(item_field.clone(), offsets, c, nulls);
+                Ok(Arc::new(larr) as Arc<dyn Array>)
+            }
+            Self::Map(map_field, k_off, m_off, kdata, valdec) => {
+                let moff = flush_offsets(m_off);
+                let koff = flush_offsets(k_off);
+                let kd = flush_values(kdata).into();
+                let val_arr = valdec.flush(None)?;
+                let key_arr = StringArray::new(koff, kd, None);
+                if key_arr.len() != val_arr.len() {
+                    return Err(ArrowError::InvalidArgumentError(format!(
+                        "Map keys length ({}) != map values length ({})",
+                        key_arr.len(),
+                        val_arr.len()
+                    )));
+                }
+                let final_len = moff.len() - 1;
+                if let Some(n) = &nulls {
+                    if n.len() != final_len {
+                        return Err(ArrowError::InvalidArgumentError(format!(
+                            "Map array null buffer length {} != final map length {final_len}",
+                            n.len()
+                        )));
+                    }
+                }
+                let entries_struct = StructArray::new(
+                    Fields::from(vec![
+                        Arc::new(ArrowField::new("key", DataType::Utf8, false)),
+                        Arc::new(ArrowField::new("value", val_arr.data_type().clone(), true)),
+                    ]),
+                    vec![Arc::new(key_arr), val_arr],
+                    None,
+                );
+                let map_arr = MapArray::new(map_field.clone(), moff, entries_struct, nulls, false);
+                Ok(Arc::new(map_arr) as Arc<dyn Array>)
+            }
+            Self::Nullable(_, nb_builder, child) => {
+                let mask = nb_builder.finish();
+                child.flush(mask)
+            }
+            Self::Fixed(sz, accum) => {
+                let b: Buffer = flush_values(accum).into();
+                let arr = FixedSizeBinaryArray::try_new(*sz, b, nulls)
+                    .map_err(|e| ArrowError::ParseError(e.to_string()))?;
+                Ok(Arc::new(arr) as Arc<dyn Array>)
+            }
+            Self::Decimal(precision, scale, sz, builder) => {
+                let p = *precision;
+                let s = scale.unwrap_or(0);
+                let new_b = DecimalBuilder::new(p, *scale, *sz)?;
+                let old = std::mem::replace(builder, new_b);
+                let arr = old.finish(nulls, p, s)?;
+                Ok(arr)
+            }
+            Self::TimeMillis(vals) => {
+                let arr = flush_primitive::<Time32MillisecondType>(vals, nulls);
+                Ok(Arc::new(arr) as Arc<dyn Array>)
+            }
+            Self::TimeMicros(vals) => {
+                let arr = flush_primitive::<Time64MicrosecondType>(vals, nulls);
+                Ok(Arc::new(arr) as Arc<dyn Array>)
+            }
+            Self::TimestampMillis(is_utc, vals) => {
+                let arr = flush_primitive::<TimestampMillisecondType>(vals, nulls)
+                    .with_timezone_opt::<Arc<str>>(is_utc.then(|| "+00:00".into()));
+                Ok(Arc::new(arr) as Arc<dyn Array>)
+            }
+            Self::TimestampMicros(is_utc, vals) => {
+                let arr = flush_primitive::<TimestampMicrosecondType>(vals, nulls)
+                    .with_timezone_opt::<Arc<str>>(is_utc.then(|| "+00:00".into()));
+                Ok(Arc::new(arr) as Arc<dyn Array>)
+            }
+            Self::Interval(ivals) => {
+                let len = ivals.len();
+                let mut b = PrimitiveBuilder::<IntervalMonthDayNanoType>::with_capacity(len);
+                for v in ivals.drain(..) {
+                    b.append_value(v);
+                }
+                let arr = b
+                    .finish()
+                    .with_data_type(DataType::Interval(IntervalUnit::MonthDayNano));
+                if let Some(nb) = nulls {
+                    let arr_data = arr.into_data().into_builder().nulls(Some(nb));
+                    let arr_data = arr_data.build()?;
+                    Ok(
+                        Arc::new(PrimitiveArray::<IntervalMonthDayNanoType>::from(arr_data))
+                            as Arc<dyn Array>,
+                    )
+                } else {
+                    Ok(Arc::new(arr) as Arc<dyn Array>)
+                }
+            }
+        }
     }
 }
 
-#[inline]
-fn flush_values<T>(values: &mut Vec<T>) -> Vec<T> {
-    std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY))
+fn read_array_blocks(
+    buf: &mut AvroCursor,
+    decode_item: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>,
+) -> Result<usize, ArrowError> {
+    read_blockwise_items(buf, true, decode_item)
 }
 
-#[inline]
-fn flush_offsets(offsets: &mut OffsetBufferBuilder<i32>) -> OffsetBuffer<i32> {
-    std::mem::replace(offsets, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish()
+fn read_map_blocks(
+    buf: &mut AvroCursor,
+    decode_entry: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>,
+) -> Result<usize, ArrowError> {
+    read_blockwise_items(buf, true, decode_entry)
+}
+
+fn read_blockwise_items(
+    buf: &mut AvroCursor,
+    read_size_after_negative: bool,
+    mut decode_fn: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>,
+) -> Result<usize, ArrowError> {
+    let mut total = 0usize;
+    loop {
+        let blk = buf.get_long()?;
+        match blk.cmp(&0) {
+            Ordering::Equal => break,
+            Ordering::Less => {
+                let cnt = (-blk) as usize;
+                if read_size_after_negative {
+                    let _size_in_bytes = buf.get_long()?;
+                }
+                for _ in 0..cnt {
+                    decode_fn(buf)?;
+                }
+                total += cnt;
+            }
+            Ordering::Greater => {
+                let cnt = blk as usize;
+                for _ in 0..cnt {
+                    decode_fn(buf)?;
+                }
+                total += cnt;
+            }
+        }
+    }
+    Ok(total)
 }
 
-#[inline]
 fn flush_primitive<T: ArrowPrimitiveType>(
-    values: &mut Vec<T::Native>,
-    nulls: Option<NullBuffer>,
+    vals: &mut Vec<T::Native>,
+    nb: Option<NullBuffer>,
 ) -> PrimitiveArray<T> {
-    PrimitiveArray::new(flush_values(values).into(), nulls)
+    PrimitiveArray::new(std::mem::take(vals).into(), nb)
 }
 
-const DEFAULT_CAPACITY: usize = 1024;
+fn flush_offsets(ob: &mut OffsetBufferBuilder<i32>) -> OffsetBuffer<i32> {
+    std::mem::replace(ob, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish()
+}
+
+fn flush_values<T>(vec: &mut Vec<T>) -> Vec<T> {
+    std::mem::replace(vec, Vec::with_capacity(DEFAULT_CAPACITY))
+}
+
+/// A builder for Avro decimal, either 128-bit or 256-bit.
+#[derive(Debug)]
+enum DecimalBuilder {
+    Decimal128(Decimal128Builder),
+    Decimal256(Decimal256Builder),
+}
+
+impl DecimalBuilder {
+    fn new(
+        precision: usize,
+        scale: Option<usize>,
+        size: Option<usize>,
+    ) -> Result<Self, ArrowError> {
+        let prec = precision as u8;
+        let scl = scale.unwrap_or(0) as i8;
+        if let Some(s) = size {
+            if s <= 16 {
+                return Ok(Self::Decimal128(
+                    Decimal128Builder::new().with_precision_and_scale(prec, scl)?,
+                ));
+            }
+            if s <= 32 {
+                return Ok(Self::Decimal256(
+                    Decimal256Builder::new().with_precision_and_scale(prec, scl)?,
+                ));
+            }
+            return Err(ArrowError::ParseError(format!(
+                "Unsupported decimal size: {s:?}"
+            )));
+        }
+        if precision <= DECIMAL128_MAX_PRECISION as usize {
+            Ok(Self::Decimal128(
+                Decimal128Builder::new().with_precision_and_scale(prec, scl)?,
+            ))
+        } else if precision <= DECIMAL256_MAX_PRECISION as usize {
+            Ok(Self::Decimal256(
+                Decimal256Builder::new().with_precision_and_scale(prec, scl)?,
+            ))
+        } else {
+            Err(ArrowError::ParseError(format!(
+                "Decimal precision {} exceeds maximum supported",
+                precision
+            )))
+        }
+    }
+
+    fn append_bytes(&mut self, raw: &[u8]) -> Result<(), ArrowError> {
+        match self {
+            Self::Decimal128(b) => {
+                let ext = sign_extend_to_16(raw)?;
+                let val = i128::from_be_bytes(ext);
+                b.append_value(val);
+            }
+            Self::Decimal256(b) => {
+                let ext = sign_extend_to_32(raw)?;
+                let val = i256::from_be_bytes(ext);
+                b.append_value(val);
+            }
+        }
+        Ok(())
+    }
+
+    fn append_null(&mut self) -> Result<(), ArrowError> {
+        match self {
+            Self::Decimal128(b) => {
+                let zero = [0u8; 16];
+                b.append_value(i128::from_be_bytes(zero));
+            }
+            Self::Decimal256(b) => {
+                let zero = [0u8; 32];
+                b.append_value(i256::from_be_bytes(zero));
+            }
+        }
+        Ok(())
+    }
+
+    fn finish(
+        self,
+        nb: Option<NullBuffer>,
+        precision: usize,
+        scale: usize,
+    ) -> Result<Arc<dyn Array>, ArrowError> {
+        match self {
+            Self::Decimal128(mut b) => {
+                let arr = b.finish();
+                let vals = arr.values().clone();
+                let dec = Decimal128Array::new(vals, nb)
+                    .with_precision_and_scale(precision as u8, scale as i8)?;
+                Ok(Arc::new(dec))
+            }
+            Self::Decimal256(mut b) => {
+                let arr = b.finish();
+                let vals = arr.values().clone();
+                let dec = Decimal256Array::new(vals, nb)
+                    .with_precision_and_scale(precision as u8, scale as i8)?;
+                Ok(Arc::new(dec))
+            }
+        }
+    }
+}
+
+fn sign_extend_to_16(raw: &[u8]) -> Result<[u8; 16], ArrowError> {
+    let ext = sign_extend(raw, 16);
+    if ext.len() != 16 {
+        return Err(ArrowError::ParseError(format!(
+            "Failed to extend to 16 bytes, got {} bytes",
+            ext.len()
+        )));
+    }
+    let mut arr = [0u8; 16];
+    arr.copy_from_slice(&ext);
+    Ok(arr)
+}
+
+fn sign_extend_to_32(raw: &[u8]) -> Result<[u8; 32], ArrowError> {
+    let ext = sign_extend(raw, 32);
+    if ext.len() != 32 {
+        return Err(ArrowError::ParseError(format!(
+            "Failed to extend to 32 bytes, got {} bytes",
+            ext.len()
+        )));
+    }
+    let mut arr = [0u8; 32];
+    arr.copy_from_slice(&ext);
+    Ok(arr)
+}
+
+fn sign_extend(raw: &[u8], target_len: usize) -> Vec<u8> {
+    if raw.is_empty() {
+        return vec![0; target_len];
+    }
+    let sign_bit = raw[0] & 0x80;
+    let mut out = Vec::with_capacity(target_len);
+    if sign_bit != 0 {
+        out.resize(target_len - raw.len(), 0xFF);
+    } else {
+        out.resize(target_len - raw.len(), 0x00);
+    }
+    out.extend_from_slice(raw);
+    out
+}
diff --git a/arrow-avro/src/reader/vlq.rs b/arrow-avro/src/reader/vlq.rs
index b198a0d66f24..818c1f53cc0a 100644
--- a/arrow-avro/src/reader/vlq.rs
+++ b/arrow-avro/src/reader/vlq.rs
@@ -84,7 +84,7 @@ fn read_varint_array(buf: [u8; 10]) -> Option<(u64, usize)> {
 
 #[cold]
 fn read_varint_slow(buf: &[u8]) -> Option<(u64, usize)> {
     let mut value = 0;
-    for (count, byte) in buf.iter().take(10).enumerate() {
+    for (count, _) in buf.iter().take(10).enumerate() {
         let byte = buf[count];
         value |= u64::from(byte & 0x7F) << (count * 7);
         if byte <= 0x7F {
diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs
index a9d91e47948b..6380eef5b839 100644
--- a/arrow-avro/src/schema.rs
+++ b/arrow-avro/src/schema.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize};
 use std::collections::HashMap;
 
 /// The metadata key used for storing the JSON encoded [`Schema`]
@@ -123,29 +123,49 @@ pub enum ComplexType<'a> {
 pub struct Record<'a> {
     #[serde(borrow)]
     pub name: &'a str,
-    #[serde(borrow, default)]
+    #[serde(borrow, default, skip_serializing_if = "Option::is_none")]
     pub namespace: Option<&'a str>,
-    #[serde(borrow, default)]
+    #[serde(borrow, default, skip_serializing_if = "Option::is_none")]
     pub doc: Option<&'a str>,
     #[serde(borrow, default)]
     pub aliases: Vec<&'a str>,
     #[serde(borrow)]
-    pub fields: Vec<Field<'a>>,
+    pub fields: Vec<RecordField<'a>>,
     #[serde(flatten)]
    pub attributes: Attributes<'a>,
 }
 
 /// A field within a [`Record`]
+///
+/// **Modified** to preserve any `"default": null`, even with out-of-spec union ordering.
 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
-pub struct Field<'a> {
+pub struct RecordField<'a> {
     #[serde(borrow)]
     pub name: &'a str,
-    #[serde(borrow, default)]
+    #[serde(borrow, default, skip_serializing_if = "Option::is_none")]
     pub doc: Option<&'a str>,
+    #[serde(borrow, default)]
+    pub aliases: Vec<&'a str>,
     #[serde(borrow)]
     pub r#type: Schema<'a>,
-    #[serde(borrow, default)]
-    pub default: Option<&'a str>,
+    #[serde(
+        default,
+        skip_serializing_if = "Option::is_none",
+        deserialize_with = "allow_out_of_spec_default"
+    )]
+    pub default: Option<serde_json::Value>,
+}
+
+/// Custom parse logic that stores *any* default as raw JSON
+/// (including `"null"` for non-null-first unions).
+fn allow_out_of_spec_default<'de, D>(
+    deserializer: D,
+) -> Result<Option<serde_json::Value>, D::Error>
+where
+    D: Deserializer<'de>,
+{
+    match serde_json::Value::deserialize(deserializer) {
+        Ok(v) => Ok(Some(v)),
+        Err(_) => Ok(None),
+    }
 }
 
 /// An enumeration
@@ -155,16 +175,16 @@ pub struct Enum<'a> {
     #[serde(borrow)]
     pub name: &'a str,
-    #[serde(borrow, default)]
+    #[serde(borrow, default, skip_serializing_if = "Option::is_none")]
     pub namespace: Option<&'a str>,
-    #[serde(borrow, default)]
+    #[serde(borrow, default, skip_serializing_if = "Option::is_none")]
     pub doc: Option<&'a str>,
     #[serde(borrow, default)]
     pub aliases: Vec<&'a str>,
     #[serde(borrow)]
     pub symbols: Vec<&'a str>,
-    #[serde(borrow, default)]
-    pub default: Option<&'a str>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub default: Option<serde_json::Value>,
     #[serde(flatten)]
     pub attributes: Attributes<'a>,
 }
@@ -198,7 +218,7 @@ pub struct Map<'a> {
 pub struct Fixed<'a> {
     #[serde(borrow)]
     pub name: &'a str,
-    #[serde(borrow, default)]
+    #[serde(borrow, default, skip_serializing_if = "Option::is_none")]
     pub namespace: Option<&'a str>,
     #[serde(borrow, default)]
     pub aliases: Vec<&'a str>,
@@ -210,7 +230,7 @@ pub struct Fixed<'a> {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::codec::{AvroDataType, AvroField};
+    use crate::codec::AvroField;
     use arrow_schema::{DataType, Fields, TimeUnit};
     use serde_json::json;
 
@@ -254,6 +274,7 @@ mod tests {
                 "type":"fixed",
                 "name":"fixed",
                 "namespace":"topLevelRecord.value",
+                "aliases":[],
                 "size":11,
                 "logicalType":"decimal",
                 "precision":25,
@@ -309,9 +330,10 @@ mod tests {
             namespace: None,
             doc: None,
             aliases: vec![],
-            fields: vec![Field {
+            fields: vec![RecordField {
                 name: "value",
                 doc: None,
+                aliases: vec![],
                 r#type: Schema::Union(vec![
                     Schema::Complex(decimal),
                     Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)),
@@ -343,15 +365,17 @@ mod tests {
             doc: None,
             aliases: vec!["LinkedLongs"],
             fields: vec![
-                Field {
+                RecordField {
                     name: "value",
                     doc: None,
+                    aliases: vec![],
                     r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)),
                     default: None,
                 },
-                Field {
+                RecordField {
                     name: "next",
                     doc: None,
+                    aliases: vec![],
                     r#type: Schema::Union(vec![
                         Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)),
                         Schema::TypeName(TypeName::Ref("LongList")),
@@ -359,7 +383,7 @@ mod tests {
                     default: None,
                 }
             ],
-            attributes: Attributes::default(),
+            attributes: Default::default(),
            }))
        );
 
@@ -402,18 +426,20 @@ mod tests {
             doc: None,
             aliases: vec![],
             fields: vec![
-                Field {
+                RecordField {
                     name: "id",
                     doc: None,
+                    aliases: vec![],
                     r#type: Schema::Union(vec![
                         Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)),
                         Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)),
                     ]),
                     default: None,
                 },
-                Field {
+                RecordField {
                     name: "timestamp_col",
                     doc: None,
+                    aliases: vec![],
                     r#type: Schema::Union(vec![
                         Schema::Type(timestamp),
                         Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)),
@@ -463,9 +489,10 @@ mod tests {
             doc: None,
             aliases: vec![],
             fields: vec![
-                Field {
+                RecordField {
                     name: "clientHash",
                     doc: None,
+                    aliases: vec![],
                     r#type: Schema::Complex(ComplexType::Fixed(Fixed {
                         name: "MD5",
                         namespace: None,
@@ -475,27 +502,30 @@ mod tests {
                     })),
                     default: None,
                 },
-                Field {
+                RecordField {
                     name: "clientProtocol",
                     doc: None,
+                    aliases: vec![],
                     r#type: Schema::Union(vec![
                         Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)),
                         Schema::TypeName(TypeName::Primitive(PrimitiveType::String)),
                     ]),
                     default: None,
                 },
-                Field {
+                RecordField {
                     name: "serverHash",
                     doc: None,
+                    aliases: vec![],
                     r#type: Schema::TypeName(TypeName::Ref("MD5")),
                     default: None,
                 },
-                Field {
+                RecordField {
                     name: "meta",
                     doc: None,
+                    aliases: vec![],
                     r#type: Schema::Union(vec![
                         Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)),
-                        Schema::Complex(ComplexType::Map(Map {
+                        Schema::Complex(ComplexType::Map(crate::schema::Map {
                             values: Box::new(Schema::TypeName(TypeName::Primitive(
                                 PrimitiveType::Bytes
                             ))),
@@ -508,5 +538,53 @@ mod tests {
             attributes: Default::default(),
            }))
        );
+
+        let t: Type = serde_json::from_str(
+            r#"{
+                "type":"string",
+                "logicalType":"uuid"
+            }"#,
+        )
+        .unwrap();
+
+        let uuid = Type {
+            r#type: TypeName::Primitive(PrimitiveType::String),
+            attributes: Attributes {
+                logical_type: Some("uuid"),
+                additional: Default::default(),
+            },
+        };
+
+        assert_eq!(t, uuid);
+
+        // Ensure aliases are parsed
+        let schema: Schema = serde_json::from_str(
+            r#"{
+                "type": "record",
+                "name": "Foo",
+                "aliases": ["Bar"],
+                "fields" : [
+                    {"name":"id","aliases":["uid"],"type":"int"}
+                ]
+            }"#,
+        )
+        .unwrap();
+
+        let with_aliases = Schema::Complex(ComplexType::Record(Record {
+            name: "Foo",
+            namespace: None,
+            doc: None,
+            aliases: vec!["Bar"],
+            fields: vec![RecordField {
+                name: "id",
+                aliases: vec!["uid"],
+                doc: None,
+                r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)),
+                default: None,
+            }],
+            attributes: Default::default(),
+        }));
+
+        assert_eq!(schema, with_aliases);
     }
 }
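
---

Some standalone notes on the wire-format handling the diff implements. First, the varint/block machinery: `read_varint_slow` and `read_blockwise_items` follow Avro's encoding, where every `int`/`long` is a zig-zag value in LEB128 varint form, and arrays/maps arrive as counted blocks terminated by a zero count (a negative count carries the item count as its absolute value, followed by the block's byte size so readers can skip it). A self-contained sketch of that decode loop, with a toy `Cursor` and no bounds or overflow handling (illustrative only, not the crate's API):

```rust
struct Cursor<'a> {
    buf: &'a [u8],
    pos: usize,
}

impl<'a> Cursor<'a> {
    /// Read one zig-zag LEB128 long, as `read_varint_slow` does.
    fn get_long(&mut self) -> i64 {
        let mut raw = 0u64;
        for shift in 0.. {
            let byte = self.buf[self.pos];
            self.pos += 1;
            raw |= u64::from(byte & 0x7F) << (shift * 7);
            if byte & 0x80 == 0 {
                break;
            }
        }
        // zig-zag decode: 0 => 0, 1 => -1, 2 => 1, 3 => -2, ...
        ((raw >> 1) as i64) ^ -((raw & 1) as i64)
    }
}

fn main() {
    // The array [7, 8]: block count 2 (0x04), items 7 (0x0E) and 8 (0x10),
    // then a terminating block count of 0 (0x00).
    let mut cur = Cursor { buf: &[0x04, 0x0E, 0x10, 0x00], pos: 0 };
    let mut items = Vec::new();
    loop {
        let count = cur.get_long();
        if count == 0 {
            break;
        }
        let n = if count < 0 {
            let _block_bytes = cur.get_long(); // size in bytes, enables skipping
            -count
        } else {
            count
        };
        for _ in 0..n {
            items.push(cur.get_long());
        }
    }
    assert_eq!(items, vec![7, 8]);
}
```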
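The `UnionOrder`/`Nullable` pairing reduces a two-branch union to a single Arrow validity bit: the union's branch index (itself a zig-zag int) says whether a value or a null follows, and which branch means what depends on whether the schema was `["null", T]` or the out-of-spec `[T, "null"]`. A minimal sketch of just that mapping (names mirror the diff but the function is illustrative):

```rust
enum UnionOrder {
    NullFirst,  // ["null", T]
    NullSecond, // [T, "null"]
}

/// True when the branch carries a value, false when it encodes null.
fn branch_is_valid(order: &UnionOrder, branch: i32) -> bool {
    match order {
        UnionOrder::NullFirst => branch != 0,
        UnionOrder::NullSecond => branch == 0,
    }
}

fn main() {
    assert!(!branch_is_valid(&UnionOrder::NullFirst, 0)); // null branch
    assert!(branch_is_valid(&UnionOrder::NullFirst, 1)); // value branch
    assert!(branch_is_valid(&UnionOrder::NullSecond, 0)); // value branch
    assert!(!branch_is_valid(&UnionOrder::NullSecond, 1)); // null branch
}
```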
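The `Enum` decoder's output shape follows from the format: Avro writes an enum as the int index into the schema's symbol list, which lines up directly with Arrow's dictionary encoding (`Dictionary(Int32, Utf8)`). A small isolated example of that correspondence, assuming `arrow-array` is on the dependency path (this mirrors, but is not, the flush logic in the diff):

```rust
use std::sync::Arc;

use arrow_array::types::Int32Type;
use arrow_array::{DictionaryArray, Int32Array, StringArray};

fn main() {
    // Symbols from the Avro schema become the dictionary values...
    let symbols = StringArray::from(vec!["SPADES", "HEARTS", "CLUBS"]);
    // ...and the decoded per-row enum indices become the keys.
    let keys = Int32Array::from(vec![2, 0, 0, 1]);
    let dict = DictionaryArray::<Int32Type>::try_new(keys, Arc::new(symbols)).unwrap();
    assert_eq!(dict.len(), 4);
    // Row 0 logically holds symbols[2] == "CLUBS"
    assert_eq!(dict.keys().value(0), 2);
}
```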
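The sign extension in `DecimalBuilder::append_bytes` exists because Avro stores decimals as big-endian two's-complement using the minimal number of bytes; short buffers must be widened with the sign byte before being reinterpreted as `i128`/`i256`. A worked example (this simplified helper assumes non-empty input; the crate's `sign_extend` also handles the empty case):

```rust
fn sign_extend_to_16(raw: &[u8]) -> [u8; 16] {
    // Fill with 0xFF when the top bit of the first byte is set, else 0x00.
    let fill = if raw[0] & 0x80 != 0 { 0xFF } else { 0x00 };
    let mut out = [fill; 16];
    out[16 - raw.len()..].copy_from_slice(raw);
    out
}

fn main() {
    // 0xFF85 is -123 as a two-byte two's-complement value; widening to
    // 16 bytes must preserve the sign.
    let ext = sign_extend_to_16(&[0xFF, 0x85]);
    assert_eq!(i128::from_be_bytes(ext), -123);
    // 0x0085 is +133; the leading zero byte keeps it positive.
    let ext = sign_extend_to_16(&[0x00, 0x85]);
    assert_eq!(i128::from_be_bytes(ext), 133);
}
```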
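The `Interval` arm decodes Avro's `duration` logical type: a 12-byte fixed value holding three little-endian 32-bit counts (months, days, milliseconds), which the decoder widens into Arrow's month/day/nanosecond interval. A standalone sketch of the byte layout (the diff reads the fields as `i32`, as below; the Avro spec describes them as unsigned):

```rust
fn decode_duration(fixed: [u8; 12]) -> (i32, i32, i64) {
    let months = i32::from_le_bytes(fixed[0..4].try_into().unwrap());
    let days = i32::from_le_bytes(fixed[4..8].try_into().unwrap());
    let millis = i32::from_le_bytes(fixed[8..12].try_into().unwrap());
    // Arrow's MonthDayNano interval wants nanoseconds
    (months, days, millis as i64 * 1_000_000)
}

fn main() {
    // 1 month, 2 days, 3000 ms => 3_000_000_000 ns
    let mut buf = [0u8; 12];
    buf[0..4].copy_from_slice(&1i32.to_le_bytes());
    buf[4..8].copy_from_slice(&2i32.to_le_bytes());
    buf[8..12].copy_from_slice(&3000i32.to_le_bytes());
    assert_eq!(decode_duration(buf), (1, 2, 3_000_000_000));
}
```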
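Finally, the `deserialize_with = "allow_out_of_spec_default"` hook on `RecordField::default` matters because serde's stock `Option<T>` treats a JSON `null` and an absent key identically (both become `None`), losing the distinction the reader needs for out-of-spec unions. A toy illustration of the technique under those assumptions (the struct and helper below are illustrative, not the crate's; the crate's version additionally swallows parse errors):

```rust
use serde::{Deserialize, Deserializer};

fn keep_any_default<'de, D>(d: D) -> Result<Option<serde_json::Value>, D::Error>
where
    D: Deserializer<'de>,
{
    // A present "default" key always yields Some(..), even for literal null.
    serde_json::Value::deserialize(d).map(Some)
}

#[derive(Debug, Deserialize)]
struct Field {
    name: String,
    #[serde(default, deserialize_with = "keep_any_default")]
    default: Option<serde_json::Value>,
}

fn main() {
    let f: Field = serde_json::from_str(r#"{"name":"x","default":null}"#).unwrap();
    assert_eq!(f.default, Some(serde_json::Value::Null)); // null preserved
    let g: Field = serde_json::from_str(r#"{"name":"y"}"#).unwrap();
    assert_eq!(g.default, None); // absent key still distinguishable
}
```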