From 8bedecc00b2f1f04d7b1a907152ce0d19b7046a5 Mon Sep 17 00:00:00 2001 From: Arttu Date: Fri, 24 May 2024 15:05:04 +0200 Subject: [PATCH] More properly handle nullability of types/literals in Substrait (#10640) * More properly handle nullability of types/literals in Substrait This isn't perfect; some things are still assumed to just always be nullable (e.g. Literal list elements). But it's still giving a closer match than just assuming everything is nullable. * Avoid cloning and creating DataFusionError Co-authored-by: Jonah Gao * simplify Literal/ScalarValue null handling --------- Co-authored-by: Jonah Gao --- .../substrait/src/logical_plan/consumer.rs | 60 +++- .../substrait/src/logical_plan/producer.rs | 282 +++++------------- 2 files changed, 122 insertions(+), 220 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index a08485fd3555..7e8a0cadb57a 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1136,11 +1136,13 @@ pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result { - let inner_type = - from_substrait_type(list.r#type.as_ref().ok_or_else(|| { - substrait_datafusion_err!("List type must have inner type") - })?)?; - let field = Arc::new(Field::new_list_field(inner_type, true)); + let inner_type = list.r#type.as_ref().ok_or_else(|| { + substrait_datafusion_err!("List type must have inner type") + })?; + let field = Arc::new(Field::new_list_field( + from_substrait_type(inner_type)?, + is_substrait_type_nullable(inner_type)?, + )); match list.type_variation_reference { DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::List(field)), LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeList(field)), @@ -1163,8 +1165,11 @@ pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result { let mut fields = vec![]; for (i, f) in s.types.iter().enumerate() { - let field = - Field::new(&format!("c{i}"), from_substrait_type(f)?, true); + let field = Field::new( + &format!("c{i}"), + from_substrait_type(f)?, + is_substrait_type_nullable(f)?, + ); fields.push(field); } Ok(DataType::Struct(fields.into())) @@ -1175,6 +1180,47 @@ pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result Result { + fn is_nullable(nullability: i32) -> bool { + nullability != substrait::proto::r#type::Nullability::Required as i32 + } + + let nullable = match dtype + .kind + .as_ref() + .ok_or_else(|| substrait_datafusion_err!("Type must contain Kind"))? + { + r#type::Kind::Bool(val) => is_nullable(val.nullability), + r#type::Kind::I8(val) => is_nullable(val.nullability), + r#type::Kind::I16(val) => is_nullable(val.nullability), + r#type::Kind::I32(val) => is_nullable(val.nullability), + r#type::Kind::I64(val) => is_nullable(val.nullability), + r#type::Kind::Fp32(val) => is_nullable(val.nullability), + r#type::Kind::Fp64(val) => is_nullable(val.nullability), + r#type::Kind::String(val) => is_nullable(val.nullability), + r#type::Kind::Binary(val) => is_nullable(val.nullability), + r#type::Kind::Timestamp(val) => is_nullable(val.nullability), + r#type::Kind::Date(val) => is_nullable(val.nullability), + r#type::Kind::Time(val) => is_nullable(val.nullability), + r#type::Kind::IntervalYear(val) => is_nullable(val.nullability), + r#type::Kind::IntervalDay(val) => is_nullable(val.nullability), + r#type::Kind::TimestampTz(val) => is_nullable(val.nullability), + r#type::Kind::Uuid(val) => is_nullable(val.nullability), + r#type::Kind::FixedChar(val) => is_nullable(val.nullability), + r#type::Kind::Varchar(val) => is_nullable(val.nullability), + r#type::Kind::FixedBinary(val) => is_nullable(val.nullability), + r#type::Kind::Decimal(val) => is_nullable(val.nullability), + r#type::Kind::PrecisionTimestamp(val) => is_nullable(val.nullability), + r#type::Kind::PrecisionTimestampTz(val) => is_nullable(val.nullability), + r#type::Kind::Struct(val) => is_nullable(val.nullability), + r#type::Kind::List(val) => is_nullable(val.nullability), + r#type::Kind::Map(val) => is_nullable(val.nullability), + r#type::Kind::UserDefined(val) => is_nullable(val.nullability), + r#type::Kind::UserDefinedTypeReference(_) => true, // not implemented, assume nullable + }; + Ok(nullable) +} + fn from_substrait_bound( bound: &Option, is_lower: bool, diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index e216008c73da..c0aac0c0a406 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -1089,7 +1089,7 @@ pub fn to_substrait_rex( Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type)?), + r#type: Some(to_substrait_type(data_type, true)?), input: Some(Box::new(to_substrait_rex( ctx, expr, @@ -1296,75 +1296,79 @@ pub fn to_substrait_rex( } } -fn to_substrait_type(dt: &DataType) -> Result { - let default_nullability = r#type::Nullability::Required as i32; +fn to_substrait_type(dt: &DataType, nullable: bool) -> Result { + let nullability = if nullable { + r#type::Nullability::Nullable as i32 + } else { + r#type::Nullability::Required as i32 + }; match dt { DataType::Null => internal_err!("Null cast is not valid"), DataType::Boolean => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Bool(r#type::Boolean { type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::Int8 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I8(r#type::I8 { type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::UInt8 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I8(r#type::I8 { type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::Int16 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I16(r#type::I16 { type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::UInt16 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I16(r#type::I16 { type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::Int32 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I32(r#type::I32 { type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::UInt32 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I32(r#type::I32 { type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::Int64 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I64(r#type::I64 { type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::UInt64 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I64(r#type::I64 { type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, + nullability, })), }), // Float16 is not supported in Substrait DataType::Float32 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Fp32(r#type::Fp32 { type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::Float64 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Fp64(r#type::Fp64 { type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + nullability, })), }), // Timezone is ignored. @@ -1378,90 +1382,90 @@ fn to_substrait_type(dt: &DataType) -> Result { Ok(substrait::proto::Type { kind: Some(r#type::Kind::Timestamp(r#type::Timestamp { type_variation_reference, - nullability: default_nullability, + nullability, })), }) } DataType::Date32 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Date(r#type::Date { type_variation_reference: DATE_32_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::Date64 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Date(r#type::Date { type_variation_reference: DATE_64_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::Binary => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Binary(r#type::Binary { type_variation_reference: DEFAULT_CONTAINER_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::FixedSizeBinary(length) => Ok(substrait::proto::Type { kind: Some(r#type::Kind::FixedBinary(r#type::FixedBinary { length: *length, type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::LargeBinary => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Binary(r#type::Binary { type_variation_reference: LARGE_CONTAINER_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::Utf8 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::String(r#type::String { type_variation_reference: DEFAULT_CONTAINER_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::LargeUtf8 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::String(r#type::String { type_variation_reference: LARGE_CONTAINER_TYPE_REF, - nullability: default_nullability, + nullability, })), }), DataType::List(inner) => { - let inner_type = to_substrait_type(inner.data_type())?; + let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::List(Box::new(r#type::List { r#type: Some(Box::new(inner_type)), type_variation_reference: DEFAULT_CONTAINER_TYPE_REF, - nullability: default_nullability, + nullability, }))), }) } DataType::LargeList(inner) => { - let inner_type = to_substrait_type(inner.data_type())?; + let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::List(Box::new(r#type::List { r#type: Some(Box::new(inner_type)), type_variation_reference: LARGE_CONTAINER_TYPE_REF, - nullability: default_nullability, + nullability, }))), }) } DataType::Struct(fields) => { let field_types = fields .iter() - .map(|field| to_substrait_type(field.data_type())) + .map(|field| to_substrait_type(field.data_type(), field.is_nullable())) .collect::>>()?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Struct(r#type::Struct { types: field_types, type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + nullability, })), }) } DataType::Decimal128(p, s) => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Decimal(r#type::Decimal { type_variation_reference: DECIMAL_128_TYPE_REF, - nullability: default_nullability, + nullability, scale: *s as i32, precision: *p as i32, })), @@ -1469,7 +1473,7 @@ fn to_substrait_type(dt: &DataType) -> Result { DataType::Decimal256(p, s) => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Decimal(r#type::Decimal { type_variation_reference: DECIMAL_256_TYPE_REF, - nullability: default_nullability, + nullability, scale: *s as i32, precision: *p as i32, })), @@ -1687,6 +1691,16 @@ fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> { } fn to_substrait_literal(value: &ScalarValue) -> Result { + if value.is_null() { + return Ok(Literal { + nullable: true, + type_variation_reference: DEFAULT_TYPE_REF, + literal_type: Some(LiteralType::Null(to_substrait_type( + &value.data_type(), + true, + )?)), + }); + } let (literal_type, type_variation_reference) = match value { ScalarValue::Boolean(Some(b)) => (LiteralType::Boolean(*b), DEFAULT_TYPE_REF), ScalarValue::Int8(Some(n)) => (LiteralType::I8(*n as i32), DEFAULT_TYPE_REF), @@ -1744,14 +1758,14 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { }), DECIMAL_128_TYPE_REF, ), - ScalarValue::List(l) if !value.is_null() => ( + ScalarValue::List(l) => ( convert_array_to_literal_list(l)?, DEFAULT_CONTAINER_TYPE_REF, ), - ScalarValue::LargeList(l) if !value.is_null() => { + ScalarValue::LargeList(l) => { (convert_array_to_literal_list(l)?, LARGE_CONTAINER_TYPE_REF) } - ScalarValue::Struct(s) if !value.is_null() => ( + ScalarValue::Struct(s) => ( LiteralType::Struct(Struct { fields: s .columns() @@ -1763,11 +1777,14 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { }), DEFAULT_TYPE_REF, ), - _ => (try_to_substrait_null(value)?, DEFAULT_TYPE_REF), + _ => ( + not_impl_err!("Unsupported literal: {value:?}")?, + DEFAULT_TYPE_REF, + ), }; Ok(Literal { - nullable: true, + nullable: false, type_variation_reference, literal_type: Some(literal_type), }) @@ -1784,7 +1801,7 @@ fn convert_array_to_literal_list( .collect::>>()?; if values.is_empty() { - let et = match to_substrait_type(array.data_type())? { + let et = match to_substrait_type(array.data_type(), array.is_nullable())? { substrait::proto::Type { kind: Some(r#type::Kind::List(lt)), } => lt.as_ref().to_owned(), @@ -1832,173 +1849,6 @@ fn to_substrait_unary_scalar_fn( }) } -fn try_to_substrait_null(v: &ScalarValue) -> Result { - let default_nullability = r#type::Nullability::Nullable as i32; - match v { - ScalarValue::Boolean(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Bool(r#type::Boolean { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Int8(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I8(r#type::I8 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::UInt8(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I8(r#type::I8 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Int16(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I16(r#type::I16 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::UInt16(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I16(r#type::I16 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Int32(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I32(r#type::I32 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::UInt32(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I32(r#type::I32 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Int64(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::UInt64(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Float32(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Fp32(r#type::Fp32 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Float64(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Fp64(r#type::Fp64 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::TimestampSecond(None, _) => { - Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Timestamp(r#type::Timestamp { - type_variation_reference: TIMESTAMP_SECOND_TYPE_REF, - nullability: default_nullability, - })), - })) - } - ScalarValue::TimestampMillisecond(None, _) => { - Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Timestamp(r#type::Timestamp { - type_variation_reference: TIMESTAMP_MILLI_TYPE_REF, - nullability: default_nullability, - })), - })) - } - ScalarValue::TimestampMicrosecond(None, _) => { - Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Timestamp(r#type::Timestamp { - type_variation_reference: TIMESTAMP_MICRO_TYPE_REF, - nullability: default_nullability, - })), - })) - } - ScalarValue::TimestampNanosecond(None, _) => { - Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Timestamp(r#type::Timestamp { - type_variation_reference: TIMESTAMP_NANO_TYPE_REF, - nullability: default_nullability, - })), - })) - } - ScalarValue::Date32(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Date(r#type::Date { - type_variation_reference: DATE_32_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Date64(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Date(r#type::Date { - type_variation_reference: DATE_64_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Binary(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Binary(r#type::Binary { - type_variation_reference: DEFAULT_CONTAINER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::LargeBinary(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Binary(r#type::Binary { - type_variation_reference: LARGE_CONTAINER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::FixedSizeBinary(_, None) => { - Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Binary(r#type::Binary { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })) - } - ScalarValue::Utf8(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::String(r#type::String { - type_variation_reference: DEFAULT_CONTAINER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::LargeUtf8(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::String(r#type::String { - type_variation_reference: LARGE_CONTAINER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Decimal128(None, p, s) => { - Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Decimal(r#type::Decimal { - scale: *s as i32, - precision: *p as i32, - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })) - } - ScalarValue::List(l) => Ok(LiteralType::Null(to_substrait_type(l.data_type())?)), - ScalarValue::LargeList(l) => { - Ok(LiteralType::Null(to_substrait_type(l.data_type())?)) - } - ScalarValue::Struct(s) => { - Ok(LiteralType::Null(to_substrait_type(s.data_type())?)) - } - // TODO: Extend support for remaining data types - _ => not_impl_err!("Unsupported literal: {v:?}"), - } -} - /// Try to convert an [Expr] to a [FieldReference]. /// Returns `Err` if the [Expr] is not a [Expr::Column]. fn try_to_substrait_field_reference( @@ -2141,8 +1991,8 @@ mod test { ), )))?; - let c0 = Field::new("c0", DataType::Boolean, true); - let c1 = Field::new("c1", DataType::Int32, true); + let c0 = Field::new("c0", DataType::Boolean, false); + let c1 = Field::new("c1", DataType::Int32, false); let c2 = Field::new("c2", DataType::Utf8, true); round_trip_literal( ScalarStructBuilder::new() @@ -2190,16 +2040,20 @@ mod test { round_trip_type(DataType::LargeUtf8)?; round_trip_type(DataType::Decimal128(10, 2))?; round_trip_type(DataType::Decimal256(30, 2))?; - round_trip_type(DataType::List( - Field::new_list_field(DataType::Int32, true).into(), - ))?; - round_trip_type(DataType::LargeList( - Field::new_list_field(DataType::Int32, true).into(), - ))?; + + for nullable in [true, false] { + round_trip_type(DataType::List( + Field::new_list_field(DataType::Int32, nullable).into(), + ))?; + round_trip_type(DataType::LargeList( + Field::new_list_field(DataType::Int32, nullable).into(), + ))?; + } + round_trip_type(DataType::Struct( vec![ Field::new("c0", DataType::Int32, true), - Field::new("c1", DataType::Utf8, true), + Field::new("c1", DataType::Utf8, false), ] .into(), ))?; @@ -2210,7 +2064,9 @@ mod test { fn round_trip_type(dt: DataType) -> Result<()> { println!("Checking round trip of {dt:?}"); - let substrait = to_substrait_type(&dt)?; + // As DataFusion doesn't consider nullability as a property of the type, but field, + // it doesn't matter if we set nullability to true or false here. + let substrait = to_substrait_type(&dt, true)?; let roundtrip_dt = from_substrait_type(&substrait)?; assert_eq!(dt, roundtrip_dt); Ok(())