From 9b3c27d502adbcda8a98a4de486a9d0baf4307aa Mon Sep 17 00:00:00 2001 From: Pavel Tiunov Date: Wed, 1 Nov 2023 14:57:20 -0700 Subject: [PATCH] feat(cubesql): SQL push down support for `IS NULL` and `IS NOT NULL` expressions --- .../cubesql/src/compile/engine/df/wrapper.rs | 221 +++++++++--------- rust/cubesql/cubesql/src/compile/mod.rs | 29 +++ .../cubesql/src/compile/rewrite/cost.rs | 2 +- .../rewrite/rules/wrapper/is_null_expr.rs | 108 +++++++++ .../src/compile/rewrite/rules/wrapper/mod.rs | 2 + rust/cubesql/cubesql/src/compile/test/mod.rs | 3 +- rust/cubesql/cubesql/src/transport/service.rs | 7 + 7 files changed, 263 insertions(+), 109 deletions(-) create mode 100644 rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/is_null_expr.rs diff --git a/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs b/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs index 5bdd374e4a489..3f891b51e42a7 100644 --- a/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs +++ b/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs @@ -713,6 +713,8 @@ impl CubeScanWrapperNode { ungrouped_scan_node.clone(), ) .await?; + let expr_sql = + Self::escape_interpolation_quotes(expr_sql, ungrouped_scan_node.is_some()); sql = new_sql_query; let original_alias = expr_name(&original_expr, &schema)?; @@ -889,18 +891,15 @@ impl CubeScanWrapperNode { ungrouped_scan_node.clone(), ) .await?; - let resulting_sql = Self::escape_interpolation_quotes( - sql_generator - .get_sql_templates() - .binary_expr(left, op.to_string(), right) - .map_err(|e| { - DataFusionError::Internal(format!( - "Can't generate SQL for binary expr: {}", - e - )) - })?, - ungrouped_scan_node.is_some(), - ); + let resulting_sql = sql_generator + .get_sql_templates() + .binary_expr(left, op.to_string(), right) + .map_err(|e| { + DataFusionError::Internal(format!( + "Can't generate SQL for binary expr: {}", + e + )) + })?; Ok((resulting_sql, sql_query)) } // Expr::AnyExpr { .. } => {} @@ -908,8 +907,46 @@ impl CubeScanWrapperNode { // Expr::ILike(_) => {} // Expr::SimilarTo(_) => {} // Expr::Not(_) => {} - // Expr::IsNotNull(_) => {} - // Expr::IsNull(_) => {} + Expr::IsNotNull(expr) => { + let (expr, sql_query) = Self::generate_sql_for_expr( + plan.clone(), + sql_query, + sql_generator.clone(), + *expr, + ungrouped_scan_node.clone(), + ) + .await?; + let resulting_sql = sql_generator + .get_sql_templates() + .is_null_expr(expr, true) + .map_err(|e| { + DataFusionError::Internal(format!( + "Can't generate SQL for is not null expr: {}", + e + )) + })?; + Ok((resulting_sql, sql_query)) + } + Expr::IsNull(expr) => { + let (expr, sql_query) = Self::generate_sql_for_expr( + plan.clone(), + sql_query, + sql_generator.clone(), + *expr, + ungrouped_scan_node.clone(), + ) + .await?; + let resulting_sql = sql_generator + .get_sql_templates() + .is_null_expr(expr, false) + .map_err(|e| { + DataFusionError::Internal(format!( + "Can't generate SQL for is null expr: {}", + e + )) + })?; + Ok((resulting_sql, sql_query)) + } // Expr::Negative(_) => {} // Expr::GetIndexedField { .. } => {} // Expr::Between { .. } => {} @@ -967,18 +1004,12 @@ impl CubeScanWrapperNode { } else { None }; - let resulting_sql = Self::escape_interpolation_quotes( - sql_generator - .get_sql_templates() - .case(expr, when_then_expr_sql, else_expr) - .map_err(|e| { - DataFusionError::Internal(format!( - "Can't generate SQL for case: {}", - e - )) - })?, - ungrouped_scan_node.is_some(), - ); + let resulting_sql = sql_generator + .get_sql_templates() + .case(expr, when_then_expr_sql, else_expr) + .map_err(|e| { + DataFusionError::Internal(format!("Can't generate SQL for case: {}", e)) + })?; Ok((resulting_sql, sql_query)) } Expr::Cast { expr, data_type } => { @@ -1022,18 +1053,12 @@ impl CubeScanWrapperNode { ))); } }; - let resulting_sql = Self::escape_interpolation_quotes( - sql_generator - .get_sql_templates() - .cast_expr(expr, data_type.to_string()) - .map_err(|e| { - DataFusionError::Internal(format!( - "Can't generate SQL for cast: {}", - e - )) - })?, - ungrouped_scan_node.is_some(), - ); + let resulting_sql = sql_generator + .get_sql_templates() + .cast_expr(expr, data_type.to_string()) + .map_err(|e| { + DataFusionError::Internal(format!("Can't generate SQL for cast: {}", e)) + })?; Ok((resulting_sql, sql_query)) } // Expr::TryCast { .. } => {} @@ -1050,18 +1075,15 @@ impl CubeScanWrapperNode { ungrouped_scan_node.clone(), ) .await?; - let resulting_sql = Self::escape_interpolation_quotes( - sql_generator - .get_sql_templates() - .sort_expr(expr, asc, nulls_first) - .map_err(|e| { - DataFusionError::Internal(format!( - "Can't generate SQL for sort expr: {}", - e - )) - })?, - ungrouped_scan_node.is_some(), - ); + let resulting_sql = sql_generator + .get_sql_templates() + .sort_expr(expr, asc, nulls_first) + .map_err(|e| { + DataFusionError::Internal(format!( + "Can't generate SQL for sort expr: {}", + e + )) + })?; Ok((resulting_sql, sql_query)) } @@ -1142,18 +1164,15 @@ impl CubeScanWrapperNode { }; let interval = format!("{} {}", num, date_part); ( - Self::escape_interpolation_quotes( - sql_generator - .get_sql_templates() - .interval_expr(interval, num, date_part.to_string()) - .map_err(|e| { - DataFusionError::Internal(format!( - "Can't generate SQL for interval: {}", - e - )) - })?, - ungrouped_scan_node.is_some(), - ), + sql_generator + .get_sql_templates() + .interval_expr(interval, num, date_part.to_string()) + .map_err(|e| { + DataFusionError::Internal(format!( + "Can't generate SQL for interval: {}", + e + )) + })?, sql_query, ) } else { @@ -1185,18 +1204,15 @@ impl CubeScanWrapperNode { sql_args.push(sql); } Ok(( - Self::escape_interpolation_quotes( - sql_generator - .get_sql_templates() - .scalar_function(fun.name.to_string(), sql_args, None) - .map_err(|e| { - DataFusionError::Internal(format!( - "Can't generate SQL for scalar function: {}", - e - )) - })?, - ungrouped_scan_node.is_some(), - ), + sql_generator + .get_sql_templates() + .scalar_function(fun.name.to_string(), sql_args, None) + .map_err(|e| { + DataFusionError::Internal(format!( + "Can't generate SQL for scalar function: {}", + e + )) + })?, sql_query, )) } @@ -1221,18 +1237,15 @@ impl CubeScanWrapperNode { ) .await?; return Ok(( - Self::escape_interpolation_quotes( - sql_generator - .get_sql_templates() - .extract_expr(date_part.to_string(), arg_sql) - .map_err(|e| { - DataFusionError::Internal(format!( + sql_generator + .get_sql_templates() + .extract_expr(date_part.to_string(), arg_sql) + .map_err(|e| { + DataFusionError::Internal(format!( "Can't generate SQL for scalar function: {}", e )) - })?, - ungrouped_scan_node.is_some(), - ), + })?, query, )); } @@ -1269,18 +1282,15 @@ impl CubeScanWrapperNode { sql_args.push(sql); } Ok(( - Self::escape_interpolation_quotes( - sql_generator - .get_sql_templates() - .scalar_function(fun.to_string(), sql_args, date_part) - .map_err(|e| { - DataFusionError::Internal(format!( - "Can't generate SQL for scalar function: {}", - e - )) - })?, - ungrouped_scan_node.is_some(), - ), + sql_generator + .get_sql_templates() + .scalar_function(fun.to_string(), sql_args, date_part) + .map_err(|e| { + DataFusionError::Internal(format!( + "Can't generate SQL for scalar function: {}", + e + )) + })?, sql_query, )) } @@ -1311,18 +1321,15 @@ impl CubeScanWrapperNode { sql_args.push(sql); } Ok(( - Self::escape_interpolation_quotes( - sql_generator - .get_sql_templates() - .aggregate_function(fun, sql_args, distinct) - .map_err(|e| { - DataFusionError::Internal(format!( - "Can't generate SQL for aggregate function: {}", - e - )) - })?, - ungrouped_scan_node.is_some(), - ), + sql_generator + .get_sql_templates() + .aggregate_function(fun, sql_args, distinct) + .map_err(|e| { + DataFusionError::Internal(format!( + "Can't generate SQL for aggregate function: {}", + e + )) + })?, sql_query, )) } diff --git a/rust/cubesql/cubesql/src/compile/mod.rs b/rust/cubesql/cubesql/src/compile/mod.rs index 4e40a470afcb5..1e4bc2af486d2 100644 --- a/rust/cubesql/cubesql/src/compile/mod.rs +++ b/rust/cubesql/cubesql/src/compile/mod.rs @@ -18897,6 +18897,35 @@ ORDER BY \"COUNT(count)\" DESC" ); } + #[tokio::test] + async fn test_case_wrapper_with_null() { + if !Rewriter::sql_push_down_enabled() { + return; + } + init_logger(); + + let query_plan = convert_select_to_query_plan( + "SELECT CASE WHEN taxful_total_price IS NULL THEN NULL WHEN taxful_total_price < taxful_total_price * 2 THEN taxful_total_price END FROM KibanaSampleDataEcommerce GROUP BY 1" + .to_string(), + DatabaseProtocol::PostgreSQL, + ) + .await; + + let logical_plan = query_plan.as_logical_plan(); + assert!(logical_plan + .find_cube_scan_wrapper() + .wrapped_sql + .unwrap() + .sql + .contains("CASE WHEN")); + + let physical_plan = query_plan.as_physical_plan().await.unwrap(); + println!( + "Physical plan: {}", + displayable(physical_plan.as_ref()).indent() + ); + } + #[tokio::test] async fn test_case_wrapper_ungrouped_on_dimension() { if !Rewriter::sql_push_down_enabled() { diff --git a/rust/cubesql/cubesql/src/compile/rewrite/cost.rs b/rust/cubesql/cubesql/src/compile/rewrite/cost.rs index ee351ba2d17ab..120461d96d243 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/cost.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/cost.rs @@ -18,11 +18,11 @@ pub struct BestCubePlan; pub struct CubePlanCost { replacers: i64, table_scans: i64, + empty_wrappers: i64, non_detected_cube_scans: i64, filters: i64, structure_points: i64, filter_members: i64, - empty_wrappers: i64, member_errors: i64, wrapper_nodes: i64, ast_size_outside_wrapper: usize, diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/is_null_expr.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/is_null_expr.rs new file mode 100644 index 0000000000000..e60039a9f06d0 --- /dev/null +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/is_null_expr.rs @@ -0,0 +1,108 @@ +use crate::{ + compile::rewrite::{ + analysis::LogicalPlanAnalysis, is_not_null_expr, is_null_expr, rewrite, + rules::wrapper::WrapperRules, transforming_rewrite, wrapper_pullup_replacer, + wrapper_pushdown_replacer, LogicalPlanLanguage, WrapperPullupReplacerAliasToCube, + }, + var, var_iter, +}; +use egg::{EGraph, Rewrite, Subst}; + +impl WrapperRules { + pub fn is_null_expr_rules( + &self, + rules: &mut Vec>, + ) { + rules.extend(vec![ + rewrite( + "wrapper-push-down-is-null-expr", + wrapper_pushdown_replacer( + is_null_expr("?expr"), + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + is_null_expr(wrapper_pushdown_replacer( + "?expr", + "?alias_to_cube", + "?ungrouped", + "?cube_members", + )), + ), + transforming_rewrite( + "wrapper-pull-up-is-null-expr", + is_null_expr(wrapper_pullup_replacer( + "?expr", + "?alias_to_cube", + "?ungrouped", + "?cube_members", + )), + wrapper_pullup_replacer( + is_null_expr("?expr"), + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + self.transform_is_null_expr("?alias_to_cube"), + ), + rewrite( + "wrapper-push-down-is-not-null-expr", + wrapper_pushdown_replacer( + is_not_null_expr("?expr"), + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + is_not_null_expr(wrapper_pushdown_replacer( + "?expr", + "?alias_to_cube", + "?ungrouped", + "?cube_members", + )), + ), + transforming_rewrite( + "wrapper-pull-up-is-not-null-expr", + is_not_null_expr(wrapper_pullup_replacer( + "?expr", + "?alias_to_cube", + "?ungrouped", + "?cube_members", + )), + wrapper_pullup_replacer( + is_not_null_expr("?expr"), + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + self.transform_is_null_expr("?alias_to_cube"), + ), + ]); + } + + fn transform_is_null_expr( + &self, + alias_to_cube_var: &'static str, + ) -> impl Fn(&mut EGraph, &mut Subst) -> bool { + let alias_to_cube_var = var!(alias_to_cube_var); + let meta = self.cube_context.meta.clone(); + move |egraph, subst| { + for alias_to_cube in var_iter!( + egraph[subst[alias_to_cube_var]], + WrapperPullupReplacerAliasToCube + ) + .cloned() + { + if let Some(sql_generator) = meta.sql_generator_by_alias_to_cube(&alias_to_cube) { + if sql_generator + .get_sql_templates() + .templates + .contains_key("expressions/is_null") + { + return true; + } + } + } + false + } + } +} diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/mod.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/mod.rs index c2c6bb7a0db80..9b685426ee00b 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/mod.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/mod.rs @@ -7,6 +7,7 @@ mod cast; mod column; mod cube_scan_wrapper; mod extract; +mod is_null_expr; mod limit; mod literal; mod order; @@ -50,6 +51,7 @@ impl RewriteRules for WrapperRules { self.alias_rules(&mut rules); self.case_rules(&mut rules); self.binary_expr_rules(&mut rules); + self.is_null_expr_rules(&mut rules); self.sort_expr_rules(&mut rules); self.cast_rules(&mut rules); self.column_rules(&mut rules); diff --git a/rust/cubesql/cubesql/src/compile/test/mod.rs b/rust/cubesql/cubesql/src/compile/test/mod.rs index cff33f472bc6d..b657f4473709e 100644 --- a/rust/cubesql/cubesql/src/compile/test/mod.rs +++ b/rust/cubesql/cubesql/src/compile/test/mod.rs @@ -228,7 +228,8 @@ pub fn get_test_tenant_ctx() -> Arc { "{{expr}} {{quoted_alias}}".to_string(), ), ("expressions/binary".to_string(), "{{ left }} {{ op }} {{ right }}".to_string()), - ("expressions/case".to_string(), "CASE {% if expr %}{{ expr }} {% endif %}{% for when, then in when_then %}WHEN {{ when }} THEN {{ then }}{% endfor %}{% if else_expr %} ELSE {{ else_expr }}{% endif %} END".to_string()), + ("expressions/is_null".to_string(), "{{ expr }} {% if negate %}NOT {% endif %}IS NULL".to_string()), + ("expressions/case".to_string(), "CASE{% if expr %}{{ expr }} {% endif %}{% for when, then in when_then %} WHEN {{ when }} THEN {{ then }}{% endfor %}{% if else_expr %} ELSE {{ else_expr }}{% endif %} END".to_string()), ("expressions/sort".to_string(), "{{ expr }} {% if asc %}ASC{% else %}DESC{% endif %}{% if nulls_first %} NULLS FIRST {% endif %}".to_string()), ("expressions/cast".to_string(), "CAST({{ expr }} AS {{ data_type }})".to_string()), ("expressions/interval".to_string(), "INTERVAL '{{ interval }}'".to_string()), diff --git a/rust/cubesql/cubesql/src/transport/service.rs b/rust/cubesql/cubesql/src/transport/service.rs index d77258458a02e..d41cf48fc6b41 100644 --- a/rust/cubesql/cubesql/src/transport/service.rs +++ b/rust/cubesql/cubesql/src/transport/service.rs @@ -442,6 +442,13 @@ impl SqlTemplates { ) } + pub fn is_null_expr(&self, expr: String, negate: bool) -> Result { + self.render_template( + "expressions/is_null", + context! { expr => expr, negate => negate }, + ) + } + pub fn sort_expr( &self, expr: String,