Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bug: Fix NULL handling in array_slice, introduce NullHandling enum to Signature #14289

Merged
merged 13 commits into from
Feb 3, 2025
21 changes: 15 additions & 6 deletions datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
//! and return types of functions in DataFusion.

use std::fmt::Display;
use std::num::NonZeroUsize;

use crate::type_coercion::aggregates::NUMERICS;
use arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
Expand Down Expand Up @@ -236,9 +237,9 @@ pub enum ArrayFunctionSignature {
/// The first argument should be non-list or list, and the second argument should be List/LargeList.
/// The first argument's list dimension should be one dimension less than the second argument's list dimension.
ElementAndArray,
/// Specialized Signature for Array functions of the form (List/LargeList, Index)
/// The first argument should be List/LargeList/FixedSizedList, and the second argument should be Int64.
ArrayAndIndex,
/// Specialized Signature for Array functions of the form (List/LargeList, Index+)
/// The first argument should be List/LargeList/FixedSizedList, and the next n arguments should be Int64.
ArrayAndIndexes(NonZeroUsize),
/// Specialized Signature for Array functions of the form (List/LargeList, Element, Optional Index)
ArrayAndElementAndOptionalIndex,
/// Specialized Signature for ArrayEmpty and similar functions
Expand All @@ -265,8 +266,12 @@ impl Display for ArrayFunctionSignature {
ArrayFunctionSignature::ElementAndArray => {
write!(f, "element, array")
}
ArrayFunctionSignature::ArrayAndIndex => {
write!(f, "array, index")
ArrayFunctionSignature::ArrayAndIndexes(count) => {
write!(f, "array")?;
for _ in 0..count.get() {
write!(f, ", index")?;
}
Ok(())
}
ArrayFunctionSignature::Array => {
write!(f, "array")
Expand Down Expand Up @@ -600,9 +605,13 @@ impl Signature {
}
/// Specialized Signature for ArrayElement and similar functions
pub fn array_and_index(volatility: Volatility) -> Self {
Self::array_and_indexes(volatility, NonZeroUsize::new(1).expect("1 is non-zero"))
}
/// Specialized Signature for ArraySlice and similar functions
pub fn array_and_indexes(volatility: Volatility, count: NonZeroUsize) -> Self {
Signature {
type_signature: TypeSignature::ArraySignature(
ArrayFunctionSignature::ArrayAndIndex,
ArrayFunctionSignature::ArrayAndIndexes(count),
),
volatility,
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ pub use udaf::{
aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs,
};
pub use udf::{
scalar_doc_sections, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF,
ScalarUDFImpl,
scalar_doc_sections, NullHandling, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs,
ScalarUDF, ScalarUDFImpl,
};
pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl};
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
Expand Down
13 changes: 10 additions & 3 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -671,13 +671,20 @@ fn get_valid_types(
ArrayFunctionSignature::ElementAndArray => {
array_append_or_prepend_valid_types(current_types, false)?
}
ArrayFunctionSignature::ArrayAndIndex => {
if current_types.len() != 2 {
ArrayFunctionSignature::ArrayAndIndexes(count) => {
if current_types.len() != count.get() + 1 {
return Ok(vec![vec![]]);
}
array(&current_types[0]).map_or_else(
|| vec![vec![]],
|array_type| vec![vec![array_type, DataType::Int64]],
|array_type| {
let mut inner = Vec::with_capacity(count.get() + 1);
inner.push(array_type);
for _ in 0..count.get() {
inner.push(DataType::Int64);
}
vec![inner]
},
)
}
ArrayFunctionSignature::ArrayAndElementAndOptionalIndex => {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {

/// Whether the aggregate function is nullable.
///
/// Nullable means that that the function could return `null` for any inputs.
/// Nullable means that the function could return `null` for any inputs.
/// For example, aggregate functions like `COUNT` always return a non null value
/// but others like `MIN` will return `NULL` if there is nullable input.
/// Note that if the function is declared as *not* nullable, make sure the [`AggregateUDFImpl::default_value`] is `non-null`
Expand Down
19 changes: 19 additions & 0 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,11 @@ impl ScalarUDF {
self.inner.return_type_from_args(args)
}

/// Returns the behavior that this function has when any of the inputs are Null.
pub fn null_handling(&self) -> NullHandling {
self.inner.null_handling()
}

/// Do the function rewrite
///
/// See [`ScalarUDFImpl::simplify`] for more details.
Expand Down Expand Up @@ -417,6 +422,15 @@ impl ReturnInfo {
}
}

/// A function's behavior when the input is Null.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
pub enum NullHandling {
/// Null inputs are passed into the function implementation.
PassThrough,
/// Any Null input causes the function to return Null.
Propagate,
}

/// Trait for implementing user defined scalar functions.
///
/// This trait exposes the full API for implementing user defined functions and
Expand Down Expand Up @@ -589,6 +603,11 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
true
}

/// Returns the behavior that this function has when any of the inputs are Null.
fn null_handling(&self) -> NullHandling {
NullHandling::PassThrough
}

/// Invoke the function on `args`, returning the appropriate result
///
/// Note: This method is deprecated and will be removed in future releases.
Expand Down
75 changes: 54 additions & 21 deletions datafusion/functions-nested/src/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use arrow::array::MutableArrayData;
use arrow::array::OffsetSizeTrait;
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::DataType;
use arrow_buffer::NullBufferBuilder;
use arrow_schema::DataType::{FixedSizeList, LargeList, List};
use arrow_schema::Field;
use datafusion_common::cast::as_int64_array;
Expand All @@ -35,12 +36,13 @@ use datafusion_common::cast::as_list_array;
use datafusion_common::{
exec_err, internal_datafusion_err, plan_err, DataFusionError, Result,
};
use datafusion_expr::Expr;
use datafusion_expr::{ArrayFunctionSignature, Expr, TypeSignature};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
ColumnarValue, Documentation, NullHandling, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_macros::user_doc;
use std::any::Any;
use std::num::NonZeroUsize;
use std::sync::Arc;

use crate::utils::make_scalar_function;
Expand Down Expand Up @@ -330,7 +332,26 @@ pub(super) struct ArraySlice {
impl ArraySlice {
pub fn new() -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
signature: Signature::one_of(
vec![
TypeSignature::ArraySignature(
ArrayFunctionSignature::ArrayAndIndexes(
NonZeroUsize::new(1).expect("1 is non-zero"),
),
),
TypeSignature::ArraySignature(
ArrayFunctionSignature::ArrayAndIndexes(
NonZeroUsize::new(2).expect("2 is non-zero"),
),
),
TypeSignature::ArraySignature(
ArrayFunctionSignature::ArrayAndIndexes(
NonZeroUsize::new(3).expect("3 is non-zero"),
),
),
],
Volatility::Immutable,
),
aliases: vec![String::from("list_slice")],
}
}
Expand Down Expand Up @@ -374,6 +395,10 @@ impl ScalarUDFImpl for ArraySlice {
Ok(arg_types[0].clone())
}

fn null_handling(&self) -> NullHandling {
NullHandling::Propagate
}

fn invoke_batch(
&self,
args: &[ColumnarValue],
Expand Down Expand Up @@ -430,8 +455,6 @@ fn array_slice_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
}
LargeList(_) => {
let array = as_large_list_array(&args[0])?;
let from_array = as_int64_array(&args[1])?;
let to_array = as_int64_array(&args[2])?;
general_array_slice::<i64>(array, from_array, to_array, stride)
}
_ => exec_err!("array_slice does not support type: {:?}", array_data_type),
Expand All @@ -451,9 +474,8 @@ where
let original_data = values.to_data();
let capacity = Capacities::Array(original_data.len());

// use_nulls: false, we don't need nulls but empty array for array_slice, so we don't need explicit nulls but adjust offset to indicate nulls.
let mut mutable =
MutableArrayData::with_capacities(vec![&original_data], false, capacity);
MutableArrayData::with_capacities(vec![&original_data], true, capacity);

// We have the slice syntax compatible with DuckDB v0.8.1.
// The rule `adjusted_from_index` and `adjusted_to_index` follows the rule of array_slice in duckdb.
Expand Down Expand Up @@ -516,30 +538,33 @@ where
}

let mut offsets = vec![O::usize_as(0)];
let mut null_builder = NullBufferBuilder::new(array.len());

for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
let start = offset_window[0];
let end = offset_window[1];
let len = end - start;

// len 0 indicate array is null, return empty array in this row.
// If any input is null, return null.
if array.is_null(row_index)
|| from_array.is_null(row_index)
|| to_array.is_null(row_index)
{
mutable.extend_nulls(1);
offsets.push(offsets[row_index] + O::usize_as(1));
null_builder.append_null();
continue;
}
null_builder.append_non_null();

// Empty arrays always return an empty array.
if len == O::usize_as(0) {
offsets.push(offsets[row_index]);
continue;
}

// If index is null, we consider it as the minimum / maximum index of the array.
let from_index = if from_array.is_null(row_index) {
Some(O::usize_as(0))
} else {
adjusted_from_index::<O>(from_array.value(row_index), len)?
};

let to_index = if to_array.is_null(row_index) {
Some(len - O::usize_as(1))
} else {
adjusted_to_index::<O>(to_array.value(row_index), len)?
};
let from_index = adjusted_from_index::<O>(from_array.value(row_index), len)?;
let to_index = adjusted_to_index::<O>(to_array.value(row_index), len)?;

if let (Some(from), Some(to)) = (from_index, to_index) {
let stride = stride.map(|s| s.value(row_index));
Expand Down Expand Up @@ -613,7 +638,7 @@ where
Arc::new(Field::new_list_field(array.value_type(), true)),
OffsetBuffer::<O>::new(offsets.into()),
arrow_array::make_array(data),
None,
null_builder.finish(),
)?))
}

Expand Down Expand Up @@ -665,6 +690,10 @@ impl ScalarUDFImpl for ArrayPopFront {
Ok(arg_types[0].clone())
}

fn null_handling(&self) -> NullHandling {
NullHandling::Propagate
}

fn invoke_batch(
&self,
args: &[ColumnarValue],
Expand Down Expand Up @@ -765,6 +794,10 @@ impl ScalarUDFImpl for ArrayPopBack {
Ok(arg_types[0].clone())
}

fn null_handling(&self) -> NullHandling {
NullHandling::Propagate
}

fn invoke_batch(
&self,
args: &[ColumnarValue],
Expand Down
12 changes: 11 additions & 1 deletion datafusion/physical-expr/src/scalar_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::sort_properties::ExprProperties;
use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf;
use datafusion_expr::{
expr_vec_fmt, ColumnarValue, Expr, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF,
expr_vec_fmt, ColumnarValue, Expr, NullHandling, ReturnTypeArgs, ScalarFunctionArgs,
ScalarUDF,
};

/// Physical expression of a scalar function
Expand Down Expand Up @@ -186,6 +187,15 @@ impl PhysicalExpr for ScalarFunctionExpr {
.map(|e| e.evaluate(batch))
.collect::<Result<Vec<_>>>()?;

if self.fun.null_handling() == NullHandling::Propagate
&& args.iter().any(
|arg| matches!(arg, ColumnarValue::Scalar(scalar) if scalar.is_null()),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not super confident about this check, how should ColumnarValue::Arrays be treated?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I think I understand this now. If the function is called with a single set of arguments then each arg will be a ColumnarValue::Scalar. However, if the function is called on a batch of arguments, then each arg will be a ColumnarValue::Array of all the arguments. So this does not work in the batch case.

What we'd really like is to identify all indexes, i, s.t. one of the args at index i is Null. Then somehow skip all rows at the identified indexes and immediately return Null for those. That seems a little tricky because it looks like we pass the entire ArrayRef to the function implementation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I think I understand this now. If the function is called with a single set of arguments then each arg will be a ColumnarValue::Scalar. However, if the function is called on a batch of arguments, then each arg will be a ColumnarValue::Array of all the arguments. So this does not work in the batch case.

What we'd really like is to identify all indexes, i, s.t. one of the args at index i is Null. Then somehow skip all rows at the identified indexes and immediately return Null for those. That seems a little tricky because it looks like we pass the entire ArrayRef to the function implementation.

I don't think we need to peek the null for column case, they should be specific logic handled for each function. For scalar case, since most of the scalar function returns null if any one of args is null, it is beneficial to introduce another null handling method. It is just convenient method nice to have but not the must have solution for null handling since they can be handled in 'invoke' as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is just convenient method nice to have but not the must have solution for null handling since they can be handled in 'invoke' as well.

If someone forgets to handle nulls in invoke, then don't we run the risk of accidentally returning different results depending on if the function was called with scalar arguments or with a batch of arguments?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I'm not sure I understand the point of the check here, if the invoke implementation also has to handle nulls, but maybe I'm misunderstanding what you're saying.

Copy link
Contributor

@jayzhan211 jayzhan211 Feb 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scalar_function(null, null, ...) and scalar_function(column_contains_null, ...). We can only handling nulls but not column_contains_null because we don't now the data in the column

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sorry, I don't think I understand your response. So I should leave this code block here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

)
{
let null_value = ScalarValue::try_from(&self.return_type)?;
return Ok(ColumnarValue::Scalar(null_value));
}

let input_empty = args.is_empty();
let input_all_scalar = args
.iter()
Expand Down
Loading