Skip to content

Commit

Permalink
improve union behaviour with null values (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Jul 9, 2024
1 parent a645a62 commit efe0491
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 29 deletions.
9 changes: 3 additions & 6 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,8 @@ pub fn invoke<C: FromIterator<Option<I>> + 'static, I>(
Some(ColumnarValue::Array(a)) => {
if args.len() > 2 {
// TODO perhaps we could support this by zipping the arrays, but it's not trivial, #23
return exec_err!(
"More than 1 path element is not supported when querying JSON using an array."
);
}
if let Some(str_path_array) = a.as_any().downcast_ref::<StringArray>() {
exec_err!("More than 1 path element is not supported when querying JSON using an array.")
} else if let Some(str_path_array) = a.as_any().downcast_ref::<StringArray>() {
let paths = str_path_array.iter().map(|opt_key| opt_key.map(JsonPath::Key));
zip_apply(json_array, paths, jiter_find, true)
} else if let Some(str_path_array) = a.as_any().downcast_ref::<LargeStringArray>() {
Expand All @@ -94,7 +91,7 @@ pub fn invoke<C: FromIterator<Option<I>> + 'static, I>(
let paths = int_path_array.iter().map(|opt_index| opt_index.map(Into::into));
zip_apply(json_array, paths, jiter_find, false)
} else {
return exec_err!("unexpected second argument type, expected string or int array");
exec_err!("unexpected second argument type, expected string or int array")
}
}
Some(ColumnarValue::Scalar(_)) => scalar_apply(json_array, &JsonPath::extract_path(args), jiter_find),
Expand Down
35 changes: 16 additions & 19 deletions src/common_union.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::sync::{Arc, OnceLock};

use arrow::array::{Array, ArrayRef, BooleanArray, Float64Array, Int64Array, StringArray, UnionArray};
use arrow::array::{Array, ArrayRef, BooleanArray, Float64Array, Int64Array, NullArray, StringArray, UnionArray};
use arrow::buffer::Buffer;
use arrow_schema::{DataType, Field, UnionFields, UnionMode};
use datafusion_common::ScalarValue;
Expand Down Expand Up @@ -42,7 +42,6 @@ pub(crate) fn json_from_union_scalar<'a>(

#[derive(Debug)]
pub(crate) struct JsonUnion {
nulls: Vec<Option<bool>>,
bools: Vec<Option<bool>>,
ints: Vec<Option<i64>>,
floats: Vec<Option<f64>>,
Expand All @@ -51,22 +50,21 @@ pub(crate) struct JsonUnion {
objects: Vec<Option<String>>,
type_ids: Vec<i8>,
index: usize,
capacity: usize,
length: usize,
}

impl JsonUnion {
fn new(capacity: usize) -> Self {
fn new(length: usize) -> Self {
Self {
nulls: vec![None; capacity],
bools: vec![None; capacity],
ints: vec![None; capacity],
floats: vec![None; capacity],
strings: vec![None; capacity],
arrays: vec![None; capacity],
objects: vec![None; capacity],
type_ids: vec![0; capacity],
bools: vec![None; length],
ints: vec![None; length],
floats: vec![None; length],
strings: vec![None; length],
arrays: vec![None; length],
objects: vec![None; length],
type_ids: vec![0; length],
index: 0,
capacity,
length,
}
}

Expand All @@ -77,7 +75,7 @@ impl JsonUnion {
fn push(&mut self, field: JsonUnionField) {
self.type_ids[self.index] = field.type_id();
match field {
JsonUnionField::JsonNull => self.nulls[self.index] = Some(true),
JsonUnionField::JsonNull => (),
JsonUnionField::Bool(value) => self.bools[self.index] = Some(value),
JsonUnionField::Int(value) => self.ints[self.index] = Some(value),
JsonUnionField::Float(value) => self.floats[self.index] = Some(value),
Expand All @@ -86,13 +84,12 @@ impl JsonUnion {
JsonUnionField::Object(value) => self.objects[self.index] = Some(value),
}
self.index += 1;
debug_assert!(self.index <= self.capacity);
debug_assert!(self.index <= self.length);
}

fn push_none(&mut self) {
self.type_ids[self.index] = TYPE_ID_NULL;
self.index += 1;
debug_assert!(self.index <= self.capacity);
debug_assert!(self.index <= self.length);
}
}

Expand All @@ -119,7 +116,7 @@ impl TryFrom<JsonUnion> for UnionArray {

fn try_from(value: JsonUnion) -> Result<Self, Self::Error> {
let children: Vec<Arc<dyn Array>> = vec![
Arc::new(BooleanArray::from(value.nulls)),
Arc::new(NullArray::new(value.length)),
Arc::new(BooleanArray::from(value.bools)),
Arc::new(Int64Array::from(value.ints)),
Arc::new(Float64Array::from(value.floats)),
Expand Down Expand Up @@ -155,7 +152,7 @@ fn union_fields() -> UnionFields {
FIELDS
.get_or_init(|| {
UnionFields::from_iter([
(TYPE_ID_NULL, Arc::new(Field::new("null", DataType::Boolean, true))),
(TYPE_ID_NULL, Arc::new(Field::new("null", DataType::Null, true))),
(TYPE_ID_BOOL, Arc::new(Field::new("bool", DataType::Boolean, false))),
(TYPE_ID_INT, Arc::new(Field::new("int", DataType::Int64, false))),
(TYPE_ID_FLOAT, Arc::new(Field::new("float", DataType::Float64, false))),
Expand Down
117 changes: 113 additions & 4 deletions tests/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ async fn test_json_get_union() {
"| object_foo | {str=abc} |",
"| object_foo_array | {array=[1]} |",
"| object_foo_obj | {object={}} |",
"| object_foo_null | {null=true} |",
"| object_foo_null | {null=} |",
"| object_bar | {null=} |",
"| list_foo | {null=} |",
"| invalid_json | {null=} |",
Expand Down Expand Up @@ -675,7 +675,7 @@ async fn test_json_get_union_array_nested() {
"+-------------+",
"| {array=[0]} |",
"| {null=} |",
"| {null=true} |",
"| {null=} |",
"+-------------+",
];

Expand Down Expand Up @@ -725,7 +725,7 @@ async fn test_arrow() {
"| object_foo | {str=abc} |",
"| object_foo_array | {array=[1]} |",
"| object_foo_obj | {object={}} |",
"| object_foo_null | {null=true} |",
"| object_foo_null | {null=} |",
"| object_bar | {null=} |",
"| list_foo | {null=} |",
"| invalid_json | {null=} |",
Expand Down Expand Up @@ -903,7 +903,7 @@ async fn test_arrow_nested_columns() {
"+-------------+",
"| {array=[0]} |",
"| {null=} |",
"| {null=true} |",
"| {null=} |",
"+-------------+",
];

Expand Down Expand Up @@ -990,3 +990,112 @@ async fn test_question_filter() {
];
assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_json_get_union_is_null() {
let batches = run_query("select name, json_get(json_data, 'foo') is null from test")
.await
.unwrap();

let expected = [
"+------------------+----------------------------------------------+",
"| name | json_get(test.json_data,Utf8(\"foo\")) IS NULL |",
"+------------------+----------------------------------------------+",
"| object_foo | false |",
"| object_foo_array | false |",
"| object_foo_obj | false |",
"| object_foo_null | true |",
"| object_bar | true |",
"| list_foo | true |",
"| invalid_json | true |",
"+------------------+----------------------------------------------+",
];
assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_json_get_union_is_not_null() {
let batches = run_query("select name, json_get(json_data, 'foo') is not null from test")
.await
.unwrap();

let expected = [
"+------------------+--------------------------------------------------+",
"| name | json_get(test.json_data,Utf8(\"foo\")) IS NOT NULL |",
"+------------------+--------------------------------------------------+",
"| object_foo | true |",
"| object_foo_array | true |",
"| object_foo_obj | true |",
"| object_foo_null | false |",
"| object_bar | false |",
"| list_foo | false |",
"| invalid_json | false |",
"+------------------+--------------------------------------------------+",
];
assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_arrow_union_is_null() {
let batches = run_query("select name, (json_data->'foo') is null from test")
.await
.unwrap();

let expected = [
"+------------------+----------------------------------+",
"| name | json_data -> Utf8(\"foo\") IS NULL |",
"+------------------+----------------------------------+",
"| object_foo | false |",
"| object_foo_array | false |",
"| object_foo_obj | false |",
"| object_foo_null | true |",
"| object_bar | true |",
"| list_foo | true |",
"| invalid_json | true |",
"+------------------+----------------------------------+",
];
assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_arrow_union_is_not_null() {
let batches = run_query("select name, (json_data->'foo') is not null from test")
.await
.unwrap();

let expected = [
"+------------------+--------------------------------------+",
"| name | json_data -> Utf8(\"foo\") IS NOT NULL |",
"+------------------+--------------------------------------+",
"| object_foo | true |",
"| object_foo_array | true |",
"| object_foo_obj | true |",
"| object_foo_null | false |",
"| object_bar | false |",
"| list_foo | false |",
"| invalid_json | false |",
"+------------------+--------------------------------------+",
];
assert_batches_eq!(expected, &batches);
}

#[tokio::test]
async fn test_arrow_scalar_union_is_null() {
let batches = run_query(
r#"
select ('{"x": 1}'->'foo') is null as not_contains,
('{"foo": 1}'->'foo') is null as contains_num,
('{"foo": null}'->'foo') is null as contains_null"#,
)
.await
.unwrap();

let expected = [
"+--------------+--------------+---------------+",
"| not_contains | contains_num | contains_null |",
"+--------------+--------------+---------------+",
"| true | false | true |",
"+--------------+--------------+---------------+",
];
assert_batches_eq!(expected, &batches);
}

0 comments on commit efe0491

Please sign in to comment.