From a645a62f8a3cec4c03770852a0eeb533e690e676 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 9 Jul 2024 09:24:42 +0100 Subject: [PATCH] support arrow functions with `ExprPlanner` (#26) --- Cargo.toml | 19 ++- src/lib.rs | 1 + src/rewrite.rs | 103 ++++++++++------ tests/main.rs | 295 +++++++++++++++++++++++++++++++++++++++++++-- tests/utils/mod.rs | 4 +- 5 files changed, 371 insertions(+), 51 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index cf1a79e..565c91c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,22 +11,29 @@ repository = "https://github.com/datafusion-contrib/datafusion-functions-json/" rust-version = "1.76.0" [dependencies] -arrow = "52" -arrow-schema = "52" -datafusion-common = "39" -datafusion-expr = "39" +arrow = "52.1.0" +arrow-schema = "52.1.0" +datafusion-common = "40" +datafusion-expr = "40" +datafusion-execution = "40" jiter = "0.5" paste = "1" log = "0.4" -datafusion-execution = "39" [dev-dependencies] codspeed-criterion-compat = "2.3" criterion = "0.5.1" -datafusion = "39" +datafusion = "40" clap = "4" tokio = { version = "1.37", features = ["full"] } +[patch.crates-io] +# TODO: remove this once datafusion 40.0 is released +datafusion = { git = "https://github.com/apache/datafusion.git", rev = "4123ad6" } +datafusion-common = { git = "https://github.com/apache/datafusion.git", rev = "4123ad6" } +datafusion-expr = { git = "https://github.com/apache/datafusion.git", rev = "4123ad6" } +datafusion-execution = { git = "https://github.com/apache/datafusion.git", rev = "4123ad6" } + [lints.clippy] dbg_macro = "deny" print_stdout = "deny" diff --git a/src/lib.rs b/src/lib.rs index 76f9e7e..94484aa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -67,6 +67,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { Ok(()) as Result<()> })?; registry.register_function_rewrite(Arc::new(rewrite::JsonFunctionRewriter))?; + registry.register_expr_planner(Arc::new(rewrite::JsonExprPlanner))?; Ok(()) } diff --git a/src/rewrite.rs b/src/rewrite.rs index b1e2e15..6fed579 100644 --- a/src/rewrite.rs +++ b/src/rewrite.rs @@ -3,9 +3,10 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::Transformed; use datafusion_common::DFSchema; use datafusion_common::Result; -use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::expr::{Alias, Cast, Expr, ScalarFunction}; use datafusion_expr::expr_rewriter::FunctionRewrite; -use datafusion_expr::Expr; +use datafusion_expr::planner::{ExprPlanner, PlannerResult, RawBinaryExpr}; +use datafusion_expr::sqlparser::ast::BinaryOperator; pub(crate) struct JsonFunctionRewriter; @@ -15,25 +16,37 @@ impl FunctionRewrite for JsonFunctionRewriter { } fn rewrite(&self, expr: Expr, _schema: &DFSchema, _config: &ConfigOptions) -> Result> { - if let Expr::Cast(cast) = &expr { - if let Expr::ScalarFunction(func) = &*cast.expr { - if func.func.name() == "json_get" { - if let Some(t) = switch_json_get(&cast.data_type, &func.args) { - return Ok(t); - } - } - } - } else if let Expr::ScalarFunction(func) = &expr { - if let Some(new_func) = unnest_json_calls(func) { - return Ok(Transformed::yes(Expr::ScalarFunction(new_func))); - } - } - Ok(Transformed::no(expr)) + let transform = match &expr { + Expr::Cast(cast) => optimise_json_get_cast(cast), + Expr::ScalarFunction(func) => unnest_json_calls(func), + _ => None, + }; + Ok(transform.unwrap_or_else(|| Transformed::no(expr))) } } +/// This replaces `get_json(foo, bar)::int` with `json_get_int(foo, bar)` so the JSON function can take care of +/// extracting the right value type from JSON without the need to materialize the JSON union. +fn optimise_json_get_cast(cast: &Cast) -> Option> { + let scalar_func = extract_scalar_function(&cast.expr)?; + if scalar_func.func.name() != "json_get" { + return None; + } + let func = match &cast.data_type { + DataType::Boolean => crate::json_get_bool::json_get_bool_udf(), + DataType::Float64 | DataType::Float32 => crate::json_get_float::json_get_float_udf(), + DataType::Int64 | DataType::Int32 => crate::json_get_int::json_get_int_udf(), + DataType::Utf8 => crate::json_get_str::json_get_str_udf(), + _ => return None, + }; + Some(Transformed::yes(Expr::ScalarFunction(ScalarFunction { + func, + args: scalar_func.args.clone(), + }))) +} + // Replace nested JSON functions e.g. `json_get(json_get(col, 'foo'), 'bar')` with `json_get(col, 'foo', 'bar')` -fn unnest_json_calls(func: &ScalarFunction) -> Option { +fn unnest_json_calls(func: &ScalarFunction) -> Option> { if !matches!( func.func.name(), "json_get" | "json_get_bool" | "json_get_float" | "json_get_int" | "json_get_json" | "json_get_str" @@ -42,9 +55,7 @@ fn unnest_json_calls(func: &ScalarFunction) -> Option { } let mut outer_args_iter = func.args.iter(); let first_arg = outer_args_iter.next()?; - let Expr::ScalarFunction(inner_func) = first_arg else { - return None; - }; + let inner_func = extract_scalar_function(first_arg)?; if inner_func.func.name() != "json_get" { return None; } @@ -53,26 +64,48 @@ fn unnest_json_calls(func: &ScalarFunction) -> Option { args.extend(outer_args_iter.cloned()); // See #23, unnest only when all lookup arguments are literals if args.iter().skip(1).all(|arg| matches!(arg, Expr::Literal(_))) { - Some(ScalarFunction { + Some(Transformed::yes(Expr::ScalarFunction(ScalarFunction { func: func.func.clone(), args, - }) + }))) } else { None } } -fn switch_json_get(cast_data_type: &DataType, args: &[Expr]) -> Option> { - let func = match cast_data_type { - DataType::Boolean => crate::json_get_bool::json_get_bool_udf(), - DataType::Float64 | DataType::Float32 => crate::json_get_float::json_get_float_udf(), - DataType::Int64 | DataType::Int32 => crate::json_get_int::json_get_int_udf(), - DataType::Utf8 => crate::json_get_str::json_get_str_udf(), - _ => return None, - }; - let f = ScalarFunction { - func, - args: args.to_vec(), - }; - Some(Transformed::yes(Expr::ScalarFunction(f))) +fn extract_scalar_function(expr: &Expr) -> Option<&ScalarFunction> { + match expr { + Expr::ScalarFunction(func) => Some(func), + Expr::Alias(alias) => extract_scalar_function(&*alias.expr), + _ => None, + } +} + +/// Implement a custom SQL planner to replace postgres JSON operators with custom UDFs +#[derive(Debug, Default)] +pub struct JsonExprPlanner; + +impl ExprPlanner for JsonExprPlanner { + fn plan_binary_op(&self, expr: RawBinaryExpr, _schema: &DFSchema) -> Result> { + let (func, op_display) = match &expr.op { + BinaryOperator::Arrow => (crate::json_get::json_get_udf(), "->"), + BinaryOperator::LongArrow => (crate::json_get_str::json_get_str_udf(), "->>"), + BinaryOperator::Question => (crate::json_contains::json_contains_udf(), "?"), + _ => return Ok(PlannerResult::Original(expr)), + }; + let alias_name = match &expr.left { + Expr::Alias(alias) => format!("{} {} {}", alias.name, op_display, expr.right), + left_expr => format!("{} {} {}", left_expr, op_display, expr.right), + }; + + // we put the alias in so that default column titles are `foo -> bar` instead of `json_get(foo, bar)` + Ok(PlannerResult::Planned(Expr::Alias(Alias::new( + Expr::ScalarFunction(ScalarFunction { + func, + args: vec![expr.left, expr.right], + }), + None::<&str>, + alias_name, + )))) + } } diff --git a/tests/main.rs b/tests/main.rs index 87e5a14..26a79fe 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -362,7 +362,7 @@ async fn test_json_length_object_nested() { async fn test_json_contains_large() { let expected = [ "+----------+", - "| COUNT(*) |", + "| count(*) |", "+----------+", "| 4 |", "+----------+", @@ -378,7 +378,7 @@ async fn test_json_contains_large() { async fn test_json_contains_large_vec() { let expected = [ "+----------+", - "| COUNT(*) |", + "| count(*) |", "+----------+", "| 0 |", "+----------+", @@ -394,7 +394,7 @@ async fn test_json_contains_large_vec() { async fn test_json_contains_large_both() { let expected = [ "+----------+", - "| COUNT(*) |", + "| count(*) |", "+----------+", "| 0 |", "+----------+", @@ -410,7 +410,7 @@ async fn test_json_contains_large_both() { async fn test_json_contains_large_params() { let expected = [ "+----------+", - "| COUNT(*) |", + "| count(*) |", "+----------+", "| 4 |", "+----------+", @@ -426,7 +426,7 @@ async fn test_json_contains_large_params() { async fn test_json_contains_large_both_params() { let expected = [ "+----------+", - "| COUNT(*) |", + "| count(*) |", "+----------+", "| 4 |", "+----------+", @@ -570,7 +570,7 @@ async fn test_json_get_cte() { } #[tokio::test] -async fn test_json_get_cte_plan() { +async fn test_plan_json_get_cte() { // avoid auto-unnesting with a CTE let sql = r#" explain @@ -611,7 +611,7 @@ async fn test_json_get_unnest() { } #[tokio::test] -async fn test_json_get_unnest_plan() { +async fn test_plan_json_get_unnest() { let sql = "explain select json_get(json_get(json_data, 'foo'), 0) v from test"; let expected = [ "Projection: json_get(test.json_data, Utf8(\"foo\"), Int64(0)) AS v", @@ -645,7 +645,7 @@ async fn test_json_get_int_unnest() { } #[tokio::test] -async fn test_json_get_int_unnest_plan() { +async fn test_plan_json_get_int_unnest() { let sql = "explain select json_get(json_get(json_data, 'foo'), 0)::int v from test"; let expected = [ "Projection: json_get_int(test.json_data, Utf8(\"foo\"), Int64(0)) AS v", @@ -684,7 +684,7 @@ async fn test_json_get_union_array_nested() { } #[tokio::test] -async fn test_json_get_union_array_nested_plan() { +async fn test_plan_json_get_union_array_nested() { let sql = "explain select json_get(json_get(json_data, str_key1), str_key2) v from more_nested"; // json_get is not un-nested because lookup types are not literals let expected = [ @@ -713,3 +713,280 @@ async fn test_json_get_union_array_skip_double_nested() { let batches = run_query(sql).await.unwrap(); assert_batches_eq!(expected, &batches); } + +#[tokio::test] +async fn test_arrow() { + let batches = run_query("select name, json_data->'foo' from test").await.unwrap(); + + let expected = [ + "+------------------+--------------------------+", + "| name | json_data -> Utf8(\"foo\") |", + "+------------------+--------------------------+", + "| object_foo | {str=abc} |", + "| object_foo_array | {array=[1]} |", + "| object_foo_obj | {object={}} |", + "| object_foo_null | {null=true} |", + "| object_bar | {null=} |", + "| list_foo | {null=} |", + "| invalid_json | {null=} |", + "+------------------+--------------------------+", + ]; + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_plan_arrow() { + let lines = logical_plan(r#"explain select json_data->'foo' from test"#).await; + + let expected = [ + "Projection: json_get(test.json_data, Utf8(\"foo\")) AS json_data -> Utf8(\"foo\")", + " TableScan: test projection=[json_data]", + ]; + + assert_eq!(lines, expected); +} + +#[tokio::test] +async fn test_long_arrow() { + let batches = run_query("select name, json_data->>'foo' from test").await.unwrap(); + + let expected = [ + "+------------------+---------------------------+", + "| name | json_data ->> Utf8(\"foo\") |", + "+------------------+---------------------------+", + "| object_foo | abc |", + "| object_foo_array | |", + "| object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+---------------------------+", + ]; + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_plan_long_arrow() { + let lines = logical_plan(r#"explain select json_data->>'foo' from test"#).await; + + let expected = [ + "Projection: json_get_str(test.json_data, Utf8(\"foo\")) AS json_data ->> Utf8(\"foo\")", + " TableScan: test projection=[json_data]", + ]; + + assert_eq!(lines, expected); +} + +#[tokio::test] +async fn test_long_arrow_eq_str() { + let batches = run_query(r"select name, (json_data->>'foo')='abc' from test") + .await + .unwrap(); + + let expected = [ + "+------------------+-----------------------------------------+", + "| name | json_data ->> Utf8(\"foo\") = Utf8(\"abc\") |", + "+------------------+-----------------------------------------+", + "| object_foo | true |", + "| object_foo_array | |", + "| object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+-----------------------------------------+", + ]; + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_arrow_cast_int() { + let sql = r#"select ('{"foo": 42}'->'foo')::int"#; + let batches = run_query(sql).await.unwrap(); + + let expected = [ + "+------------------------------------+", + "| Utf8(\"{\"foo\": 42}\") -> Utf8(\"foo\") |", + "+------------------------------------+", + "| 42 |", + "+------------------------------------+", + ]; + assert_batches_eq!(expected, &batches); + + assert_eq!(display_val(batches).await, (DataType::Int64, "42".to_string())); +} + +#[tokio::test] +async fn test_plan_arrow_cast_int() { + let lines = logical_plan(r#"explain select (json_data->'foo')::int from test"#).await; + + let expected = [ + "Projection: json_get_int(test.json_data, Utf8(\"foo\")) AS json_data -> Utf8(\"foo\")", + " TableScan: test projection=[json_data]", + ]; + + assert_eq!(lines, expected); +} + +#[tokio::test] +async fn test_arrow_double_nested() { + let batches = run_query("select name, json_data->'foo'->0 from test").await.unwrap(); + + let expected = [ + "+------------------+--------------------------------------+", + "| name | json_data -> Utf8(\"foo\") -> Int64(0) |", + "+------------------+--------------------------------------+", + "| object_foo | {null=} |", + "| object_foo_array | {int=1} |", + "| object_foo_obj | {null=} |", + "| object_foo_null | {null=} |", + "| object_bar | {null=} |", + "| list_foo | {null=} |", + "| invalid_json | {null=} |", + "+------------------+--------------------------------------+", + ]; + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_plan_arrow_double_nested() { + let lines = logical_plan(r#"explain select json_data->'foo'->0 from test"#).await; + + let expected = [ + "Projection: json_get(test.json_data, Utf8(\"foo\"), Int64(0)) AS json_data -> Utf8(\"foo\") -> Int64(0)", + " TableScan: test projection=[json_data]", + ]; + + assert_eq!(lines, expected); +} + +#[tokio::test] +async fn test_arrow_double_nested_cast() { + let batches = run_query("select name, (json_data->'foo'->0)::int from test") + .await + .unwrap(); + + let expected = [ + "+------------------+--------------------------------------+", + "| name | json_data -> Utf8(\"foo\") -> Int64(0) |", + "+------------------+--------------------------------------+", + "| object_foo | |", + "| object_foo_array | 1 |", + "| object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+--------------------------------------+", + ]; + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_plan_arrow_double_nested_cast() { + let lines = logical_plan(r#"explain select (json_data->'foo'->0)::int from test"#).await; + + let expected = [ + "Projection: json_get_int(test.json_data, Utf8(\"foo\"), Int64(0)) AS json_data -> Utf8(\"foo\") -> Int64(0)", + " TableScan: test projection=[json_data]", + ]; + + assert_eq!(lines, expected); +} + +#[tokio::test] +async fn test_arrow_nested_columns() { + let expected = [ + "+-------------+", + "| v |", + "+-------------+", + "| {array=[0]} |", + "| {null=} |", + "| {null=true} |", + "+-------------+", + ]; + + let sql = "select json_data->str_key1->str_key2 v from more_nested"; + let batches = run_query(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_arrow_nested_double_columns() { + let expected = [ + "+---------+", + "| v |", + "+---------+", + "| {int=0} |", + "| {null=} |", + "| {null=} |", + "+---------+", + ]; + + let sql = "select json_data->str_key1->str_key2->int_key v from more_nested"; + let batches = run_query(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_lexical_precedence_wrong() { + let sql = r#"select '{"a": "b"}'->>'a'='b' as v"#; + let err = run_query(sql).await.unwrap_err(); + assert_eq!(err.to_string(), "Error during planning: Unexpected argument type to 'json_get_str' at position 2, expected string or int, got Boolean.") +} + +#[tokio::test] +async fn test_question_mark_contains() { + let expected = [ + "+------------------+-------------------------+", + "| name | json_data ? Utf8(\"foo\") |", + "+------------------+-------------------------+", + "| object_foo | true |", + "| object_foo_array | true |", + "| object_foo_obj | true |", + "| object_foo_null | true |", + "| object_bar | false |", + "| list_foo | false |", + "| invalid_json | false |", + "+------------------+-------------------------+", + ]; + + let batches = run_query("select name, json_data ? 'foo' from test").await.unwrap(); + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_arrow_filter() { + let batches = run_query("select name from test where (json_data->>'foo') = 'abc'") + .await + .unwrap(); + + let expected = [ + "+------------+", + "| name |", + "+------------+", + "| object_foo |", + "+------------+", + ]; + assert_batches_eq!(expected, &batches); +} + +#[tokio::test] +async fn test_question_filter() { + let batches = run_query("select name from test where json_data ? 'foo'") + .await + .unwrap(); + + let expected = [ + "+------------------+", + "| name |", + "+------------------+", + "| object_foo |", + "| object_foo_array |", + "| object_foo_obj |", + "| object_foo_null |", + "+------------------+", + ]; + assert_batches_eq!(expected, &batches); +} diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index ec5ce9d..a5279f9 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -9,10 +9,12 @@ use arrow::{array::LargeStringArray, array::StringArray, record_batch::RecordBat use datafusion::error::Result; use datafusion::execution::context::SessionContext; use datafusion_common::ParamValues; +use datafusion_execution::config::SessionConfig; use datafusion_functions_json::register_all; async fn create_test_table(large_utf8: bool) -> Result { - let mut ctx = SessionContext::new(); + let config = SessionConfig::new().set_str("datafusion.sql_parser.dialect", "postgres"); + let mut ctx = SessionContext::new_with_config(config); register_all(&mut ctx)?; let test_data = [