Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug: Fix NULL handling in array_slice, introduce NullHandling enum to Signature #14289

Merged
merged 13 commits into from
Feb 3, 2025
26 changes: 26 additions & 0 deletions datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,8 @@ pub struct Signature {
pub type_signature: TypeSignature,
/// The volatility of the function. See [Volatility] for more information.
pub volatility: Volatility,
/// When true, the function always returns null whenever any of its arguments are null.
pub strict: bool,
}

impl Signature {
Expand All @@ -473,20 +475,23 @@ impl Signature {
Signature {
type_signature,
volatility,
strict: false,
}
}
/// An arbitrary number of arguments with the same type, from those listed in `common_types`.
pub fn variadic(common_types: Vec<DataType>, volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::Variadic(common_types),
volatility,
strict: false,
}
}
/// User-defined coercion rules for the function.
pub fn user_defined(volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::UserDefined,
volatility,
strict: false,
}
}

Expand All @@ -495,6 +500,7 @@ impl Signature {
Self {
type_signature: TypeSignature::Numeric(arg_count),
volatility,
strict: false,
}
}

Expand All @@ -503,6 +509,7 @@ impl Signature {
Self {
type_signature: TypeSignature::String(arg_count),
volatility,
strict: false,
}
}

Expand All @@ -511,6 +518,7 @@ impl Signature {
Self {
type_signature: TypeSignature::VariadicAny,
volatility,
strict: false,
}
}
/// A fixed number of arguments of the same type, from those listed in `valid_types`.
Expand All @@ -522,13 +530,15 @@ impl Signature {
Self {
type_signature: TypeSignature::Uniform(arg_count, valid_types),
volatility,
strict: false,
}
}
/// Exactly matches the types in `exact_types`, in order.
pub fn exact(exact_types: Vec<DataType>, volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::Exact(exact_types),
volatility,
strict: false,
}
}
/// Target coerce types in order
Expand All @@ -539,6 +549,7 @@ impl Signature {
Self {
type_signature: TypeSignature::Coercible(target_types),
volatility,
strict: false,
}
}

Expand All @@ -547,13 +558,15 @@ impl Signature {
Self {
type_signature: TypeSignature::Comparable(arg_count),
volatility,
strict: false,
}
}

pub fn nullary(volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::Nullary,
volatility,
strict: false,
}
}

Expand All @@ -562,13 +575,15 @@ impl Signature {
Signature {
type_signature: TypeSignature::Any(arg_count),
volatility,
strict: false,
}
}
/// Any one of a list of [TypeSignature]s.
pub fn one_of(type_signatures: Vec<TypeSignature>, volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::OneOf(type_signatures),
volatility,
strict: false,
}
}
/// Specialized Signature for ArrayAppend and similar functions
Expand All @@ -578,6 +593,7 @@ impl Signature {
ArrayFunctionSignature::ArrayAndElement,
),
volatility,
strict: false,
}
}
/// Specialized Signature for Array functions with an optional index
Expand All @@ -587,6 +603,7 @@ impl Signature {
ArrayFunctionSignature::ArrayAndElementAndOptionalIndex,
),
volatility,
strict: false,
}
}
/// Specialized Signature for ArrayPrepend and similar functions
Expand All @@ -596,6 +613,7 @@ impl Signature {
ArrayFunctionSignature::ElementAndArray,
),
volatility,
strict: false,
}
}
/// Specialized Signature for ArrayElement and similar functions
Expand All @@ -605,15 +623,23 @@ impl Signature {
ArrayFunctionSignature::ArrayAndIndex,
),
volatility,
strict: false,
}
}
/// Specialized Signature for ArrayEmpty and similar functions
pub fn array(volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::ArraySignature(ArrayFunctionSignature::Array),
volatility,
strict: false,
}
}

/// Returns an equivalent Signature, with strict set to true.
pub fn with_strict(mut self) -> Self {
self.strict = true;
self
}
}

#[cfg(test)]
Expand Down
5 changes: 2 additions & 3 deletions datafusion/functions-nested/src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,8 @@ pub(super) struct ArraySlice {
impl ArraySlice {
pub fn new() -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
// TODO: This signature should use the actual accepted types, not variadic_any.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't been able to figure out a way using the current TypeSignature to create a type signature that accepts (any list type, i64, i64). My only idea is to extend ArrayFunctionSignature with a variant like ArrayAndElements(NonZeroUsize).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another approach, which probably warrants it's own issue/PR, is to add something similar to PostgreSQL psuedo-types (https://www.postgresql.org/docs/current/datatype-pseudo.html). Then I can make a signature of something like

TypeSignature::OneOf(vec![
    TypeSignature::Exact(vec![AnyArray, Int64, Int64]),
    TypeSignature::Exact(vec![AnyArray, Int64, Int64, Int64]),
])

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My only idea is to extend ArrayFunctionSignature with a variant like ArrayAndElements(NonZeroUsize)

I think this is better

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just pushed an update to do this.

signature: Signature::variadic_any(Volatility::Immutable).with_strict(),
aliases: vec![String::from("list_slice")],
}
}
Expand Down Expand Up @@ -430,8 +431,6 @@ fn array_slice_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
}
LargeList(_) => {
let array = as_large_list_array(&args[0])?;
let from_array = as_int64_array(&args[1])?;
let to_array = as_int64_array(&args[2])?;
general_array_slice::<i64>(array, from_array, to_array, stride)
}
_ => exec_err!("array_slice does not support type: {:?}", array_data_type),
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions-nested/src/flatten.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ impl Flatten {
ArrayFunctionSignature::RecursiveArray,
),
volatility: Volatility::Immutable,
strict: false,
},
aliases: vec![],
}
Expand Down
52 changes: 51 additions & 1 deletion datafusion/physical-expr/src/scalar_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use crate::PhysicalExpr;

use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
use arrow_array::Array;
use arrow_array::{Array, LargeListArray, ListArray};
use datafusion_common::{internal_err, DFSchema, Result, ScalarValue};
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::sort_properties::ExprProperties;
Expand Down Expand Up @@ -186,6 +186,56 @@ impl PhysicalExpr for ScalarFunctionExpr {
.map(|e| e.evaluate(batch))
.collect::<Result<Vec<_>>>()?;

if self.fun.signature().strict && args.iter().any(|arg| arg.data_type().is_null())
{
let null_value = match &self.return_type {
DataType::Null => ScalarValue::Null,
DataType::Boolean => ScalarValue::Boolean(None),
DataType::Int8 => ScalarValue::Int8(None),
DataType::Int16 => ScalarValue::Int16(None),
DataType::Int32 => ScalarValue::Int32(None),
DataType::Int64 => ScalarValue::Int64(None),
DataType::UInt8 => ScalarValue::UInt8(None),
DataType::UInt16 => ScalarValue::UInt16(None),
DataType::UInt32 => ScalarValue::UInt32(None),
DataType::UInt64 => ScalarValue::UInt64(None),
DataType::Float16 => ScalarValue::Float16(None),
DataType::Float32 => ScalarValue::Float32(None),
DataType::Float64 => ScalarValue::Float64(None),
DataType::Timestamp(_, _) => todo!(),
DataType::Date32 => todo!(),
DataType::Date64 => todo!(),
DataType::Time32(_) => todo!(),
DataType::Time64(_) => todo!(),
DataType::Duration(_) => todo!(),
DataType::Interval(_) => todo!(),
DataType::Binary => todo!(),
DataType::FixedSizeBinary(_) => todo!(),
DataType::LargeBinary => todo!(),
DataType::BinaryView => todo!(),
DataType::Utf8 => todo!(),
DataType::LargeUtf8 => todo!(),
DataType::Utf8View => todo!(),
DataType::List(field_ref) => ScalarValue::List(Arc::new(
ListArray::new_null(Arc::clone(field_ref), 1),
)),
DataType::ListView(_) => todo!(),
DataType::FixedSizeList(_, _) => todo!(),
DataType::LargeList(field_ref) => ScalarValue::LargeList(Arc::new(
LargeListArray::new_null(Arc::clone(field_ref), 1),
)),
DataType::LargeListView(_) => todo!(),
DataType::Struct(_) => todo!(),
DataType::Union(_, _) => todo!(),
DataType::Dictionary(_, _) => todo!(),
DataType::Decimal128(_, _) => todo!(),
DataType::Decimal256(_, _) => todo!(),
DataType::Map(_, _) => todo!(),
DataType::RunEndEncoded(_, _) => todo!(),
};
jkosh44 marked this conversation as resolved.
Show resolved Hide resolved
return Ok(ColumnarValue::Scalar(null_value));
}

let input_empty = args.is_empty();
let input_all_scalar = args
.iter()
Expand Down
41 changes: 30 additions & 11 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1817,18 +1817,26 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0,
[1, 2, 3, 4] [h, e, l]

# array_slice scalar function #8 (with NULL and positive number)
query error
query ??
select array_slice(make_array(1, 2, 3, 4, 5), NULL, 4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, 3);
----
NULL NULL

query error
query ??
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL, 4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL, 3);
----
NULL NULL

# array_slice scalar function #9 (with positive number and NULL)
query error
query ??
select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL);
----
NULL NULL

query error
query ??
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, NULL), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3, NULL);
----
NULL NULL

# array_slice scalar function #10 (with zero-zero)
query ??
Expand All @@ -1842,12 +1850,15 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0,
[] []

# array_slice scalar function #11 (with NULL-NULL)
query error
query ??
select array_slice(make_array(1, 2, 3, 4, 5), NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL);
----
NULL NULL

query error
query ??
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL);

----
NULL NULL

# array_slice scalar function #12 (with zero and negative number)
query ??
Expand All @@ -1861,18 +1872,26 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0,
[1, 2] [h, e, l]

# array_slice scalar function #13 (with negative number and NULL)
query error
query ??
select array_slice(make_array(1, 2, 3, 4, 5), -2, NULL), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -3, NULL);
----
NULL NULL

query error
query ??
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -2, NULL), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3, NULL);
----
NULL NULL

# array_slice scalar function #14 (with NULL and negative number)
query error
query ??
select array_slice(make_array(1, 2, 3, 4, 5), NULL, -4), array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, -3);
----
NULL NULL

query error
query ??
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL, -4), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL, -3);
----
NULL NULL

# array_slice scalar function #15 (with negative indexes)
query ??
Expand Down
Loading