Skip to content

Commit

Permalink
support arrow functions with ExprPlanner (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Jul 9, 2024
1 parent 40a8090 commit a645a62
Show file tree
Hide file tree
Showing 5 changed files with 371 additions and 51 deletions.
19 changes: 13 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,29 @@ repository = "https://github.com/datafusion-contrib/datafusion-functions-json/"
rust-version = "1.76.0"

[dependencies]
arrow = "52"
arrow-schema = "52"
datafusion-common = "39"
datafusion-expr = "39"
arrow = "52.1.0"
arrow-schema = "52.1.0"
datafusion-common = "40"
datafusion-expr = "40"
datafusion-execution = "40"
jiter = "0.5"
paste = "1"
log = "0.4"
datafusion-execution = "39"

[dev-dependencies]
codspeed-criterion-compat = "2.3"
criterion = "0.5.1"
datafusion = "39"
datafusion = "40"
clap = "4"
tokio = { version = "1.37", features = ["full"] }

[patch.crates-io]
# TODO: remove this once datafusion 40.0 is released
datafusion = { git = "https://github.com/apache/datafusion.git", rev = "4123ad6" }
datafusion-common = { git = "https://github.com/apache/datafusion.git", rev = "4123ad6" }
datafusion-expr = { git = "https://github.com/apache/datafusion.git", rev = "4123ad6" }
datafusion-execution = { git = "https://github.com/apache/datafusion.git", rev = "4123ad6" }

[lints.clippy]
dbg_macro = "deny"
print_stdout = "deny"
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
Ok(()) as Result<()>
})?;
registry.register_function_rewrite(Arc::new(rewrite::JsonFunctionRewriter))?;
registry.register_expr_planner(Arc::new(rewrite::JsonExprPlanner))?;

Ok(())
}
103 changes: 68 additions & 35 deletions src/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::Transformed;
use datafusion_common::DFSchema;
use datafusion_common::Result;
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::expr::{Alias, Cast, Expr, ScalarFunction};
use datafusion_expr::expr_rewriter::FunctionRewrite;
use datafusion_expr::Expr;
use datafusion_expr::planner::{ExprPlanner, PlannerResult, RawBinaryExpr};
use datafusion_expr::sqlparser::ast::BinaryOperator;

pub(crate) struct JsonFunctionRewriter;

Expand All @@ -15,25 +16,37 @@ impl FunctionRewrite for JsonFunctionRewriter {
}

fn rewrite(&self, expr: Expr, _schema: &DFSchema, _config: &ConfigOptions) -> Result<Transformed<Expr>> {
if let Expr::Cast(cast) = &expr {
if let Expr::ScalarFunction(func) = &*cast.expr {
if func.func.name() == "json_get" {
if let Some(t) = switch_json_get(&cast.data_type, &func.args) {
return Ok(t);
}
}
}
} else if let Expr::ScalarFunction(func) = &expr {
if let Some(new_func) = unnest_json_calls(func) {
return Ok(Transformed::yes(Expr::ScalarFunction(new_func)));
}
}
Ok(Transformed::no(expr))
let transform = match &expr {
Expr::Cast(cast) => optimise_json_get_cast(cast),
Expr::ScalarFunction(func) => unnest_json_calls(func),
_ => None,
};
Ok(transform.unwrap_or_else(|| Transformed::no(expr)))
}
}

/// This replaces `get_json(foo, bar)::int` with `json_get_int(foo, bar)` so the JSON function can take care of
/// extracting the right value type from JSON without the need to materialize the JSON union.
fn optimise_json_get_cast(cast: &Cast) -> Option<Transformed<Expr>> {
let scalar_func = extract_scalar_function(&cast.expr)?;
if scalar_func.func.name() != "json_get" {
return None;
}
let func = match &cast.data_type {
DataType::Boolean => crate::json_get_bool::json_get_bool_udf(),
DataType::Float64 | DataType::Float32 => crate::json_get_float::json_get_float_udf(),
DataType::Int64 | DataType::Int32 => crate::json_get_int::json_get_int_udf(),
DataType::Utf8 => crate::json_get_str::json_get_str_udf(),
_ => return None,
};
Some(Transformed::yes(Expr::ScalarFunction(ScalarFunction {
func,
args: scalar_func.args.clone(),
})))
}

// Replace nested JSON functions e.g. `json_get(json_get(col, 'foo'), 'bar')` with `json_get(col, 'foo', 'bar')`
fn unnest_json_calls(func: &ScalarFunction) -> Option<ScalarFunction> {
fn unnest_json_calls(func: &ScalarFunction) -> Option<Transformed<Expr>> {
if !matches!(
func.func.name(),
"json_get" | "json_get_bool" | "json_get_float" | "json_get_int" | "json_get_json" | "json_get_str"
Expand All @@ -42,9 +55,7 @@ fn unnest_json_calls(func: &ScalarFunction) -> Option<ScalarFunction> {
}
let mut outer_args_iter = func.args.iter();
let first_arg = outer_args_iter.next()?;
let Expr::ScalarFunction(inner_func) = first_arg else {
return None;
};
let inner_func = extract_scalar_function(first_arg)?;
if inner_func.func.name() != "json_get" {
return None;
}
Expand All @@ -53,26 +64,48 @@ fn unnest_json_calls(func: &ScalarFunction) -> Option<ScalarFunction> {
args.extend(outer_args_iter.cloned());
// See #23, unnest only when all lookup arguments are literals
if args.iter().skip(1).all(|arg| matches!(arg, Expr::Literal(_))) {
Some(ScalarFunction {
Some(Transformed::yes(Expr::ScalarFunction(ScalarFunction {
func: func.func.clone(),
args,
})
})))
} else {
None
}
}

fn switch_json_get(cast_data_type: &DataType, args: &[Expr]) -> Option<Transformed<Expr>> {
let func = match cast_data_type {
DataType::Boolean => crate::json_get_bool::json_get_bool_udf(),
DataType::Float64 | DataType::Float32 => crate::json_get_float::json_get_float_udf(),
DataType::Int64 | DataType::Int32 => crate::json_get_int::json_get_int_udf(),
DataType::Utf8 => crate::json_get_str::json_get_str_udf(),
_ => return None,
};
let f = ScalarFunction {
func,
args: args.to_vec(),
};
Some(Transformed::yes(Expr::ScalarFunction(f)))
fn extract_scalar_function(expr: &Expr) -> Option<&ScalarFunction> {
match expr {
Expr::ScalarFunction(func) => Some(func),
Expr::Alias(alias) => extract_scalar_function(&*alias.expr),
_ => None,
}
}

/// Implement a custom SQL planner to replace postgres JSON operators with custom UDFs
#[derive(Debug, Default)]
pub struct JsonExprPlanner;

impl ExprPlanner for JsonExprPlanner {
fn plan_binary_op(&self, expr: RawBinaryExpr, _schema: &DFSchema) -> Result<PlannerResult<RawBinaryExpr>> {
let (func, op_display) = match &expr.op {
BinaryOperator::Arrow => (crate::json_get::json_get_udf(), "->"),
BinaryOperator::LongArrow => (crate::json_get_str::json_get_str_udf(), "->>"),
BinaryOperator::Question => (crate::json_contains::json_contains_udf(), "?"),
_ => return Ok(PlannerResult::Original(expr)),
};
let alias_name = match &expr.left {
Expr::Alias(alias) => format!("{} {} {}", alias.name, op_display, expr.right),
left_expr => format!("{} {} {}", left_expr, op_display, expr.right),
};

// we put the alias in so that default column titles are `foo -> bar` instead of `json_get(foo, bar)`
Ok(PlannerResult::Planned(Expr::Alias(Alias::new(
Expr::ScalarFunction(ScalarFunction {
func,
args: vec![expr.left, expr.right],
}),
None::<&str>,
alias_name,
))))
}
}
Loading

0 comments on commit a645a62

Please sign in to comment.