diff --git a/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js b/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js index 040a8ddc3ea17..0c5b8879f943e 100644 --- a/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js +++ b/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js @@ -2479,6 +2479,7 @@ class BaseQuery { binary: '({{ left }} {{ op }} {{ right }})', sort: '{{ expr }} {% if asc %}ASC{% else %}DESC{% endif %}{% if nulls_first %} NULLS FIRST{% endif %}', cast: 'CAST({{ expr }} AS {{ data_type }})', + window_function: '{{ fun_call }} OVER ({% if partition_by %}PARTITION BY {{ partition_by }}{% if order_by %} {% endif %}{% endif %}{% if order_by %}ORDER BY {{ order_by }}{% endif %})' }, quotes: { identifiers: '"', diff --git a/rust/cubesql/cubesql/src/compile/engine/df/scan.rs b/rust/cubesql/cubesql/src/compile/engine/df/scan.rs index 4c2a8e9132d65..a9044c2b89f80 100644 --- a/rust/cubesql/cubesql/src/compile/engine/df/scan.rs +++ b/rust/cubesql/cubesql/src/compile/engine/df/scan.rs @@ -140,6 +140,7 @@ pub struct WrappedSelectNode { pub projection_expr: Vec, pub group_expr: Vec, pub aggr_expr: Vec, + pub window_expr: Vec, pub from: Arc, pub joins: Vec<(Arc, Expr, JoinType)>, pub filter_expr: Vec, @@ -158,6 +159,7 @@ impl WrappedSelectNode { projection_expr: Vec, group_expr: Vec, aggr_expr: Vec, + window_expr: Vec, from: Arc, joins: Vec<(Arc, Expr, JoinType)>, filter_expr: Vec, @@ -174,6 +176,7 @@ impl WrappedSelectNode { projection_expr, group_expr, aggr_expr, + window_expr, from, joins, filter_expr, @@ -207,6 +210,7 @@ impl UserDefinedLogicalNode for WrappedSelectNode { exprs.extend(self.projection_expr.clone()); exprs.extend(self.group_expr.clone()); exprs.extend(self.aggr_expr.clone()); + exprs.extend(self.window_expr.clone()); exprs.extend(self.joins.iter().map(|(_, expr, _)| expr.clone())); exprs.extend(self.filter_expr.clone()); exprs.extend(self.having_expr.clone()); @@ -217,11 +221,12 @@ impl UserDefinedLogicalNode for WrappedSelectNode { fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, - "WrappedSelect: select_type={:?}, projection_expr={:?}, group_expr={:?}, aggregate_expr={:?}, from={:?}, joins={:?}, filter_expr={:?}, having_expr={:?}, limit={:?}, offset={:?}, order_expr={:?}, alias={:?}", + "WrappedSelect: select_type={:?}, projection_expr={:?}, group_expr={:?}, aggregate_expr={:?}, window_expr={:?}, from={:?}, joins={:?}, filter_expr={:?}, having_expr={:?}, limit={:?}, offset={:?}, order_expr={:?}, alias={:?}", self.select_type, self.projection_expr, self.group_expr, self.aggr_expr, + self.window_expr, self.from, self.joins, self.filter_expr, @@ -261,6 +266,7 @@ impl UserDefinedLogicalNode for WrappedSelectNode { let mut projection_expr = vec![]; let mut group_expr = vec![]; let mut aggregate_expr = vec![]; + let mut window_expr = vec![]; let limit = None; let offset = None; let alias = None; @@ -278,6 +284,10 @@ impl UserDefinedLogicalNode for WrappedSelectNode { aggregate_expr.push(exprs_iter.next().unwrap().clone()); } + for _ in self.window_expr.iter() { + window_expr.push(exprs_iter.next().unwrap().clone()); + } + for _ in self.joins.iter() { joins_expr.push(exprs_iter.next().unwrap().clone()); } @@ -300,6 +310,7 @@ impl UserDefinedLogicalNode for WrappedSelectNode { projection_expr, group_expr, aggregate_expr, + window_expr, from, joins .into_iter() diff --git a/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs b/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs index f90844c2a0d39..b594c78d93eec 100644 --- a/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs +++ b/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs @@ -297,6 +297,7 @@ impl CubeScanWrapperNode { projection_expr, group_expr, aggr_expr, + window_expr, from, joins: _joins, filter_expr: _filter_expr, @@ -431,6 +432,20 @@ impl CubeScanWrapperNode { ungrouped_scan_node.clone(), ) .await?; + + let (window, sql) = Self::generate_column_expr( + plan.clone(), + schema.clone(), + window_expr.clone(), + sql, + generator.clone(), + &column_remapping, + &mut next_remapping, + alias.clone(), + can_rename_columns, + ungrouped_scan_node.clone(), + ) + .await?; // Sort node always comes on top and pushed down to select so we need to replace columns here by appropriate column definitions let order_replace_map = projection_expr .iter() @@ -504,6 +519,12 @@ impl CubeScanWrapperNode { ) }), ) + .chain(window.iter().map(|m| { + Self::ungrouped_member_def( + m, + &ungrouped_scan_node.used_cubes, + ) + })) .collect::>()?, ); load_request.dimensions = Some( @@ -1333,7 +1354,80 @@ impl CubeScanWrapperNode { sql_query, )) } - // Expr::WindowFunction { .. } => {} + Expr::WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + } => { + let mut sql_args = Vec::new(); + for arg in args { + let (sql, query) = Self::generate_sql_for_expr( + plan.clone(), + sql_query, + sql_generator.clone(), + arg, + ungrouped_scan_node.clone(), + ) + .await?; + sql_query = query; + sql_args.push(sql); + } + let mut sql_partition_by = Vec::new(); + for arg in partition_by { + let (sql, query) = Self::generate_sql_for_expr( + plan.clone(), + sql_query, + sql_generator.clone(), + arg, + ungrouped_scan_node.clone(), + ) + .await?; + sql_query = query; + sql_partition_by.push(sql); + } + let mut sql_order_by = Vec::new(); + for arg in order_by { + let (sql, query) = Self::generate_sql_for_expr( + plan.clone(), + sql_query, + sql_generator.clone(), + arg, + ungrouped_scan_node.clone(), + ) + .await?; + sql_query = query; + sql_order_by.push( + sql_generator + .get_sql_templates() + // TODO asc/desc + .sort_expr(sql, true, false) + .map_err(|e| { + DataFusionError::Internal(format!( + "Can't generate SQL for sort expr: {}", + e + )) + })?, + ); + } + let resulting_sql = sql_generator + .get_sql_templates() + .window_function_expr( + fun, + sql_args, + sql_partition_by, + sql_order_by, + window_frame, + ) + .map_err(|e| { + DataFusionError::Internal(format!( + "Can't generate SQL for window function: {}", + e + )) + })?; + Ok((resulting_sql, sql_query)) + } // Expr::AggregateUDF { .. } => {} // Expr::InList { .. } => {} // Expr::Wildcard => {} diff --git a/rust/cubesql/cubesql/src/compile/mod.rs b/rust/cubesql/cubesql/src/compile/mod.rs index 1e4bc2af486d2..354b4f6d5de43 100644 --- a/rust/cubesql/cubesql/src/compile/mod.rs +++ b/rust/cubesql/cubesql/src/compile/mod.rs @@ -18847,12 +18847,20 @@ ORDER BY \"COUNT(count)\" DESC" .sql .contains("CASE WHEN")); - assert!(logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql - .contains("1123")); + assert!( + logical_plan + .find_cube_scan_wrapper() + .wrapped_sql + .unwrap() + .sql + .contains("1123"), + "SQL contains 1123: {}", + logical_plan + .find_cube_scan_wrapper() + .wrapped_sql + .unwrap() + .sql + ); let physical_plan = query_plan.as_physical_plan().await.unwrap(); println!( @@ -18883,12 +18891,20 @@ ORDER BY \"COUNT(count)\" DESC" .sql .contains("CASE WHEN")); - assert!(logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql - .contains("LIMIT 1123")); + assert!( + logical_plan + .find_cube_scan_wrapper() + .wrapped_sql + .unwrap() + .sql + .contains("1123"), + "SQL contains 1123: {}", + logical_plan + .find_cube_scan_wrapper() + .wrapped_sql + .unwrap() + .sql + ); let physical_plan = query_plan.as_physical_plan().await.unwrap(); println!( @@ -19063,6 +19079,43 @@ ORDER BY \"COUNT(count)\" DESC" .contains("EXTRACT")); } + #[tokio::test] + async fn test_wrapper_window_function() { + if !Rewriter::sql_push_down_enabled() { + return; + } + init_logger(); + + let query_plan = convert_select_to_query_plan( + "SELECT customer_gender, AVG(avgPrice) mp, SUM(COUNT(count)) OVER() FROM KibanaSampleDataEcommerce a GROUP BY 1 LIMIT 100" + .to_string(), + DatabaseProtocol::PostgreSQL, + ) + .await; + + let logical_plan = query_plan.as_logical_plan(); + assert!( + logical_plan + .find_cube_scan_wrapper() + .wrapped_sql + .unwrap() + .sql + .contains("OVER"), + "SQL should contain 'OVER': {}", + logical_plan + .find_cube_scan_wrapper() + .wrapped_sql + .unwrap() + .sql + ); + + let physical_plan = query_plan.as_physical_plan().await.unwrap(); + println!( + "Physical plan: {}", + displayable(physical_plan.as_ref()).indent() + ); + } + #[tokio::test] async fn test_thoughtspot_pg_date_trunc_year() { init_logger(); diff --git a/rust/cubesql/cubesql/src/compile/rewrite/converter.rs b/rust/cubesql/cubesql/src/compile/rewrite/converter.rs index a08d6862e84ca..364227739aaab 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/converter.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/converter.rs @@ -41,8 +41,8 @@ use datafusion::{ logical_plan::{ build_join_schema, build_table_udf_schema, exprlist_to_fields, normalize_cols, plan::{Aggregate, Extension, Filter, Join, Projection, Sort, TableUDFs, Window}, - CrossJoin, DFField, DFSchema, DFSchemaRef, Distinct, EmptyRelation, Expr, Like, Limit, - LogicalPlan, LogicalPlanBuilder, TableScan, Union, + replace_col_to_expr, CrossJoin, DFField, DFSchema, DFSchemaRef, Distinct, EmptyRelation, + Expr, Like, Limit, LogicalPlan, LogicalPlanBuilder, TableScan, Union, }, physical_plan::planner::DefaultPhysicalPlanner, scalar::ScalarValue, @@ -1671,8 +1671,10 @@ impl LanguageToLogicalPlanConverter { match_expr_list_node!(node_by_id, to_expr, params[2], WrappedSelectGroupExpr); let aggr_expr = match_expr_list_node!(node_by_id, to_expr, params[3], WrappedSelectAggrExpr); - let from = Arc::new(self.to_logical_plan(params[4])?); - let joins = match_list_node!(node_by_id, params[5], WrappedSelectJoins) + let window_expr = + match_expr_list_node!(node_by_id, to_expr, params[4], WrappedSelectWindowExpr); + let from = Arc::new(self.to_logical_plan(params[5])?); + let joins = match_list_node!(node_by_id, params[6], WrappedSelectJoins) .into_iter() .map(|j| { if let LogicalPlanLanguage::WrappedSelectJoin(params) = j { @@ -1688,28 +1690,49 @@ impl LanguageToLogicalPlanConverter { .collect::, _>>()?; let filter_expr = - match_expr_list_node!(node_by_id, to_expr, params[6], WrappedSelectFilterExpr); + match_expr_list_node!(node_by_id, to_expr, params[7], WrappedSelectFilterExpr); let having_expr = - match_expr_list_node!(node_by_id, to_expr, params[7], WrappedSelectHavingExpr); - let limit = match_data_node!(node_by_id, params[8], WrappedSelectLimit); - let offset = match_data_node!(node_by_id, params[9], WrappedSelectOffset); + match_expr_list_node!(node_by_id, to_expr, params[8], WrappedSelectHavingExpr); + let limit = match_data_node!(node_by_id, params[9], WrappedSelectLimit); + let offset = match_data_node!(node_by_id, params[10], WrappedSelectOffset); let order_expr = - match_expr_list_node!(node_by_id, to_expr, params[10], WrappedSelectOrderExpr); - let alias = match_data_node!(node_by_id, params[11], WrappedSelectAlias); - let ungrouped = match_data_node!(node_by_id, params[12], WrappedSelectUngrouped); + match_expr_list_node!(node_by_id, to_expr, params[11], WrappedSelectOrderExpr); + let alias = match_data_node!(node_by_id, params[12], WrappedSelectAlias); + let ungrouped = match_data_node!(node_by_id, params[13], WrappedSelectUngrouped); let group_expr = normalize_cols(group_expr, &from)?; let aggr_expr = normalize_cols(aggr_expr, &from)?; let projection_expr = normalize_cols(projection_expr, &from)?; - let all_expr = match select_type { + let all_expr_without_window = match select_type { WrappedSelectType::Projection => projection_expr.clone(), WrappedSelectType::Aggregate => { group_expr.iter().chain(aggr_expr.iter()).cloned().collect() } }; + let without_window_fields = + exprlist_to_fields(all_expr_without_window.iter(), from.schema())?; + let replace_map = all_expr_without_window + .iter() + .zip(without_window_fields.iter()) + .map(|(e, f)| (f.qualified_column(), e.clone())) + .collect::>(); + let replace_map = replace_map + .iter() + .map(|(c, e)| (c, e)) + .collect::>(); + let window_expr_rebased = window_expr + .iter() + .map(|e| replace_col_to_expr(e.clone(), &replace_map)) + .collect::, _>>()?; let schema = DFSchema::new_with_metadata( // TODO support joins schema - exprlist_to_fields(all_expr.iter(), from.schema())?, + without_window_fields + .into_iter() + .chain( + exprlist_to_fields(window_expr_rebased.iter(), from.schema())? + .into_iter(), + ) + .collect(), HashMap::new(), )?; @@ -1725,6 +1748,7 @@ impl LanguageToLogicalPlanConverter { projection_expr, group_expr, aggr_expr, + window_expr_rebased, from, joins, filter_expr, diff --git a/rust/cubesql/cubesql/src/compile/rewrite/cost.rs b/rust/cubesql/cubesql/src/compile/rewrite/cost.rs index 120461d96d243..95e0114197cb1 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/cost.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/cost.rs @@ -8,11 +8,13 @@ pub struct BestCubePlan; /// This cost struct maintains following structural relationships: /// - `replacers` > other nodes - having replacers in structure means not finished processing /// - `table_scans` > other nodes - having table scan means not detected cube scan +/// - `empty_wrappers` > `non_detected_cube_scans` - we don't want empty wrapper to hide non detected cube scan errors /// - `non_detected_cube_scans` > other nodes - minimize cube scans without members /// - `filters` > `filter_members` - optimize for push down of filters /// - `filter_members` > `cube_members` - optimize for `inDateRange` filter push down to time dimension /// - `member_errors` > `cube_members` - extra cube members may be required (e.g. CASE) /// - `member_errors` > `wrapper_nodes` - use SQL push down where possible if cube scan can't be detected +/// - `non_pushed_down_window` > `wrapper_nodes` - prefer to always push down window functions /// - match errors by priority - optimize for more specific errors #[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq)] pub struct CubePlanCost { @@ -24,6 +26,8 @@ pub struct CubePlanCost { structure_points: i64, filter_members: i64, member_errors: i64, + // TODO if pre-aggregation can be used for window functions, then it'd be suboptimal + non_pushed_down_window: i64, wrapper_nodes: i64, ast_size_outside_wrapper: usize, cube_members: i64, @@ -98,6 +102,7 @@ impl CubePlanCost { 0 }) + other.non_detected_cube_scans, filter_members: self.filter_members + other.filter_members, + non_pushed_down_window: self.non_pushed_down_window + other.non_pushed_down_window, member_errors: self.member_errors + other.member_errors, cube_members: self.cube_members + other.cube_members, errors: self.errors + other.errors, @@ -124,6 +129,7 @@ impl CubePlanCost { }, filter_members: self.filter_members, member_errors: self.member_errors, + non_pushed_down_window: self.non_pushed_down_window, cube_members: self.cube_members, errors: self.errors, structure_points: self.structure_points, @@ -186,6 +192,11 @@ impl CostFunction for BestCubePlan { _ => 0, }; + let non_pushed_down_window = match enode { + LogicalPlanLanguage::Window(_) => 1, + _ => 0, + }; + let ast_size_inside_wrapper = match enode { LogicalPlanLanguage::WrappedSelect(_) => 1, _ => 0, @@ -264,6 +275,7 @@ impl CostFunction for BestCubePlan { filter_members, non_detected_cube_scans, member_errors, + non_pushed_down_window, cube_members, errors: this_errors, structure_points, diff --git a/rust/cubesql/cubesql/src/compile/rewrite/mod.rs b/rust/cubesql/cubesql/src/compile/rewrite/mod.rs index c63bcefc2a6a3..4a86a8bbb3aca 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/mod.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/mod.rs @@ -249,6 +249,7 @@ crate::plan_to_language! { projection_expr: Vec, group_expr: Vec, aggr_expr: Vec, + window_expr: Vec, from: Arc, joins: Vec, filter_expr: Vec, @@ -691,6 +692,19 @@ fn agg_fun_expr(fun_name: impl Display, args: Vec, distinct: impl ) } +fn window_fun_expr_var_arg( + fun_name: impl Display, + arg_list: impl Display, + partition_by: impl Display, + order_by: impl Display, + window_frame: impl Display, +) -> String { + format!( + "(WindowFunctionExpr {} {} {} {} {})", + fun_name, arg_list, partition_by, order_by, window_frame + ) +} + fn udaf_expr(fun_name: impl Display, args: Vec) -> String { format!( "(AggregateUDFExpr {} {})", @@ -703,11 +717,16 @@ fn limit(skip: impl Display, fetch: impl Display, input: impl Display) -> String format!("(Limit {} {} {})", skip, fetch, input) } +fn window(input: impl Display, window_expr: impl Display) -> String { + format!("(Window {} {})", input, window_expr) +} + fn wrapped_select( select_type: impl Display, projection_expr: impl Display, group_expr: impl Display, aggr_expr: impl Display, + window_expr: impl Display, from: impl Display, joins: impl Display, filter_expr: impl Display, @@ -719,11 +738,12 @@ fn wrapped_select( ungrouped: impl Display, ) -> String { format!( - "(WrappedSelect {} {} {} {} {} {} {} {} {} {} {} {} {})", + "(WrappedSelect {} {} {} {} {} {} {} {} {} {} {} {} {} {})", select_type, projection_expr, group_expr, aggr_expr, + window_expr, from, joins, filter_expr, @@ -763,6 +783,15 @@ fn wrapped_select_aggr_expr_empty_tail() -> String { "WrappedSelectAggrExpr".to_string() } +#[allow(dead_code)] +fn wrapped_select_window_expr(left: impl Display, right: impl Display) -> String { + format!("(WrappedSelectWindowExpr {} {})", left, right) +} + +fn wrapped_select_window_expr_empty_tail() -> String { + "WrappedSelectWindowExpr".to_string() +} + #[allow(dead_code)] fn wrapped_select_joins(left: impl Display, right: impl Display) -> String { format!("(WrappedSelectJoins {} {})", left, right) diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate.rs index d5b03f902ddc2..0470437f02ec1 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate.rs @@ -7,10 +7,10 @@ use crate::{ transforming_chain_rewrite, transforming_rewrite, wrapped_select, wrapped_select_filter_expr_empty_tail, wrapped_select_having_expr_empty_tail, wrapped_select_joins_empty_tail, wrapped_select_order_expr_empty_tail, - wrapped_select_projection_expr_empty_tail, wrapper_pullup_replacer, - wrapper_pushdown_replacer, AggregateFunctionExprDistinct, AggregateFunctionExprFun, - AliasExprAlias, ColumnExprColumn, LogicalPlanLanguage, WrappedSelectUngrouped, - WrapperPullupReplacerUngrouped, + wrapped_select_projection_expr_empty_tail, wrapped_select_window_expr_empty_tail, + wrapper_pullup_replacer, wrapper_pushdown_replacer, AggregateFunctionExprDistinct, + AggregateFunctionExprFun, AliasExprAlias, ColumnExprColumn, LogicalPlanLanguage, + WrappedSelectUngrouped, WrapperPullupReplacerUngrouped, }, transport::V1CubeMetaMeasureExt, var, var_iter, @@ -60,6 +60,12 @@ impl WrapperRules { "?ungrouped", "?cube_members", ), + wrapper_pullup_replacer( + wrapped_select_window_expr_empty_tail(), + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), wrapper_pullup_replacer( "?cube_scan_input", "?alias_to_cube", diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs index 2437f96d9c38c..3d6a2c4690fdc 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs @@ -11,7 +11,7 @@ use datafusion::physical_plan::aggregates::AggregateFunction; use egg::{EGraph, Rewrite, Subst}; impl WrapperRules { - pub fn aggregate_function_rules( + pub fn window_function_rules( &self, rules: &mut Vec>, ) { diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/limit.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/limit.rs index 9b816313d6284..fe2e588897380 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/limit.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/limit.rs @@ -22,6 +22,7 @@ impl WrapperRules { "?projection_expr", "?group_expr", "?aggr_expr", + "?window_expr", "?cube_scan_input", "?joins", "?filter_expr", @@ -46,6 +47,7 @@ impl WrapperRules { "?projection_expr", "?group_expr", "?aggr_expr", + "?window_expr", "?cube_scan_input", "?joins", "?filter_expr", diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/mod.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/mod.rs index 9b685426ee00b..e3f469a4778f3 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/mod.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/mod.rs @@ -15,6 +15,8 @@ mod projection; mod scalar_function; mod sort_expr; mod udf_function; +mod window; +mod window_function; mod wrapper_pull_up; use crate::compile::{ @@ -44,7 +46,9 @@ impl RewriteRules for WrapperRules { self.projection_rules(&mut rules); self.limit_rules(&mut rules); self.order_rules(&mut rules); + self.window_rules(&mut rules); self.aggregate_function_rules(&mut rules); + self.window_function_rules(&mut rules); self.scalar_function_rules(&mut rules); self.udf_function_rules(&mut rules); self.extract_rules(&mut rules); diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/order.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/order.rs index 931b76b8ba745..30449176b1827 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/order.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/order.rs @@ -18,6 +18,7 @@ impl WrapperRules { "?projection_expr", "?group_expr", "?aggr_expr", + "?window_expr", "?cube_scan_input", "?joins", "?filter_expr", @@ -56,6 +57,12 @@ impl WrapperRules { "?ungrouped", "?cube_members", ), + wrapper_pullup_replacer( + "?window_expr", + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), wrapper_pullup_replacer( "?cube_scan_input", "?alias_to_cube", diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/projection.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/projection.rs index fc08dc6e88ff3..aba2c26458dbc 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/projection.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/projection.rs @@ -4,9 +4,9 @@ use crate::{ transforming_rewrite, wrapped_select, wrapped_select_aggr_expr_empty_tail, wrapped_select_filter_expr_empty_tail, wrapped_select_group_expr_empty_tail, wrapped_select_having_expr_empty_tail, wrapped_select_joins_empty_tail, - wrapped_select_order_expr_empty_tail, wrapper_pullup_replacer, wrapper_pushdown_replacer, - LogicalPlanLanguage, ProjectionAlias, WrappedSelectAlias, WrappedSelectUngrouped, - WrapperPullupReplacerUngrouped, + wrapped_select_order_expr_empty_tail, wrapped_select_window_expr_empty_tail, + wrapper_pullup_replacer, wrapper_pushdown_replacer, LogicalPlanLanguage, ProjectionAlias, + WrappedSelectAlias, WrappedSelectUngrouped, WrapperPullupReplacerUngrouped, }, var, var_iter, }; @@ -54,6 +54,12 @@ impl WrapperRules { "?ungrouped", "?cube_members", ), + wrapper_pullup_replacer( + wrapped_select_window_expr_empty_tail(), + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), wrapper_pullup_replacer( "?cube_scan_input", "?alias_to_cube", diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/window.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/window.rs new file mode 100644 index 0000000000000..81c366cb102e8 --- /dev/null +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/window.rs @@ -0,0 +1,97 @@ +use crate::compile::rewrite::{ + analysis::LogicalPlanAnalysis, cube_scan_wrapper, rewrite, rules::wrapper::WrapperRules, + window, wrapped_select, wrapped_select_window_expr_empty_tail, wrapper_pullup_replacer, + wrapper_pushdown_replacer, LogicalPlanLanguage, +}; +use egg::Rewrite; + +impl WrapperRules { + pub fn window_rules(&self, rules: &mut Vec>) { + rules.extend(vec![rewrite( + "wrapper-push-down-window-to-cube-scan", + window( + cube_scan_wrapper( + wrapper_pullup_replacer( + wrapped_select( + "?select_type", + "?projection_expr", + "?group_expr", + "?aggr_expr", + wrapped_select_window_expr_empty_tail(), + "?cube_scan_input", + "?joins", + "?filter_expr", + "?having_expr", + "?limit", + "?offset", + "?order_expr", + "?select_alias", + "?select_ungrouped", + ), + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + "CubeScanWrapperFinalized:false", + ), + "?window_expr", + ), + cube_scan_wrapper( + wrapped_select( + "?select_type", + wrapper_pullup_replacer( + "?projection_expr", + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + wrapper_pullup_replacer( + "?group_expr", + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + wrapper_pullup_replacer( + "?aggr_expr", + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + wrapper_pushdown_replacer( + "?window_expr", + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + wrapper_pullup_replacer( + "?cube_scan_input", + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + "?joins", + "?filter_expr", + "?having_expr", + "?limit", + "?offset", + wrapper_pullup_replacer( + "?order_expr", + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + "?select_alias", + "?select_ungrouped", + ), + "CubeScanWrapperFinalized:false", + ), + )]); + + Self::list_pushdown_pullup_rules( + rules, + "wrapper-window-expr", + "WindowWindowExpr", + "WrappedSelectWindowExpr", + ); + } +} diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/window_function.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/window_function.rs new file mode 100644 index 0000000000000..6e7d9cf5fd8ad --- /dev/null +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/window_function.rs @@ -0,0 +1,158 @@ +use crate::{ + compile::rewrite::{ + analysis::LogicalPlanAnalysis, rewrite, rules::wrapper::WrapperRules, transforming_rewrite, + window_fun_expr_var_arg, wrapper_pullup_replacer, wrapper_pushdown_replacer, + LogicalPlanLanguage, WindowFunctionExprFun, WrapperPullupReplacerAliasToCube, + }, + var, var_iter, +}; +use datafusion::physical_plan::window_functions::WindowFunction; +use egg::{EGraph, Rewrite, Subst}; + +impl WrapperRules { + pub fn aggregate_function_rules( + &self, + rules: &mut Vec>, + ) { + rules.extend(vec![ + rewrite( + "wrapper-push-down-window-function", + wrapper_pushdown_replacer( + window_fun_expr_var_arg( + "?fun", + "?expr", + "?partition_by", + "?order_by", + "?window_frame", + ), + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + window_fun_expr_var_arg( + "?fun", + wrapper_pushdown_replacer( + "?expr", + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + wrapper_pushdown_replacer( + "?partition_by", + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + wrapper_pushdown_replacer( + "?order_by", + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + "?window_frame", + ), + ), + transforming_rewrite( + "wrapper-pull-up-window-function", + window_fun_expr_var_arg( + "?fun", + wrapper_pullup_replacer( + "?expr", + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + wrapper_pullup_replacer( + "?partition_by", + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + wrapper_pullup_replacer( + "?order_by", + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + "?window_frame", + ), + wrapper_pullup_replacer( + window_fun_expr_var_arg( + "?fun", + "?expr", + "?partition_by", + "?order_by", + "?window_frame", + ), + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), + self.transform_window_fun_expr("?fun", "?alias_to_cube"), + ), + ]); + + Self::expr_list_pushdown_pullup_rules( + rules, + "wrapper-window-fun-args", + "WindowFunctionExprArgs", + ); + + Self::expr_list_pushdown_pullup_rules( + rules, + "wrapper-window-fun-partition-by", + "WindowFunctionExprPartitionBy", + ); + + Self::expr_list_pushdown_pullup_rules( + rules, + "wrapper-window-fun-order-by", + "WindowFunctionExprOrderBy", + ); + } + + fn transform_window_fun_expr( + &self, + fun_var: &'static str, + alias_to_cube_var: &'static str, + ) -> impl Fn(&mut EGraph, &mut Subst) -> bool { + let fun_var = var!(fun_var); + let alias_to_cube_var = var!(alias_to_cube_var); + let meta = self.cube_context.meta.clone(); + move |egraph, subst| { + for alias_to_cube in var_iter!( + egraph[subst[alias_to_cube_var]], + WrapperPullupReplacerAliasToCube + ) + .cloned() + { + if let Some(sql_generator) = meta.sql_generator_by_alias_to_cube(&alias_to_cube) { + if sql_generator + .get_sql_templates() + .templates + .contains_key("expressions/window_function") + { + for fun in var_iter!(egraph[subst[fun_var]], WindowFunctionExprFun).cloned() + { + let fun = match fun { + WindowFunction::AggregateFunction(agg_fun) => agg_fun.to_string(), + WindowFunction::BuiltInWindowFunction(window_fun) => { + window_fun.to_string() + } + }; + + if sql_generator + .get_sql_templates() + .templates + .contains_key(&format!("functions/{}", fun.as_str())) + { + return true; + } + } + } + } + } + false + } + } +} diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/wrapper_pull_up.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/wrapper_pull_up.rs index 736e20b3f615c..fa701c9f2948b 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/wrapper_pull_up.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/wrapper_pull_up.rs @@ -38,6 +38,12 @@ impl WrapperRules { "?ungrouped", "?cube_members", ), + wrapper_pullup_replacer( + "?window_expr", + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), wrapper_pullup_replacer( "?cube_scan_input", "?alias_to_cube", @@ -67,6 +73,7 @@ impl WrapperRules { "?projection_expr", "?group_expr", "?aggr_expr", + "?window_expr", "?cube_scan_input", wrapped_select_joins_empty_tail(), wrapped_select_filter_expr_empty_tail(), @@ -109,12 +116,19 @@ impl WrapperRules { "?ungrouped", "?cube_members", ), + wrapper_pullup_replacer( + "?window_expr", + "?alias_to_cube", + "?ungrouped", + "?cube_members", + ), wrapper_pullup_replacer( wrapped_select( "?inner_select_type", "?inner_projection_expr", "?inner_group_expr", "?inner_aggr_expr", + "?inner_window_expr", "?inner_cube_scan_input", "?inner_joins", "?inner_filter_expr", @@ -152,11 +166,13 @@ impl WrapperRules { "?projection_expr", "?group_expr", "?aggr_expr", + "?window_expr", wrapped_select( "?inner_select_type", "?inner_projection_expr", "?inner_group_expr", "?inner_aggr_expr", + "?inner_window_expr", "?inner_cube_scan_input", "?inner_joins", "?inner_filter_expr", diff --git a/rust/cubesql/cubesql/src/compile/test/mod.rs b/rust/cubesql/cubesql/src/compile/test/mod.rs index b657f4473709e..094a4199148e5 100644 --- a/rust/cubesql/cubesql/src/compile/test/mod.rs +++ b/rust/cubesql/cubesql/src/compile/test/mod.rs @@ -233,6 +233,7 @@ pub fn get_test_tenant_ctx() -> Arc { ("expressions/sort".to_string(), "{{ expr }} {% if asc %}ASC{% else %}DESC{% endif %}{% if nulls_first %} NULLS FIRST {% endif %}".to_string()), ("expressions/cast".to_string(), "CAST({{ expr }} AS {{ data_type }})".to_string()), ("expressions/interval".to_string(), "INTERVAL '{{ interval }}'".to_string()), + ("expressions/window_function".to_string(), "{{ fun_call }} OVER ({% if partition_by %}PARTITION BY {{ partition_by }}{% if order_by %} {% endif %}{% endif %}{% if order_by %}ORDER BY {{ order_by }}{% endif %})".to_string()), ("quotes/identifiers".to_string(), "\"".to_string()), ("quotes/escape".to_string(), "\"\"".to_string()), ("params/param".to_string(), "${{ param_index + 1 }}".to_string()) diff --git a/rust/cubesql/cubesql/src/transport/service.rs b/rust/cubesql/cubesql/src/transport/service.rs index d41cf48fc6b41..4a655f8784c40 100644 --- a/rust/cubesql/cubesql/src/transport/service.rs +++ b/rust/cubesql/cubesql/src/transport/service.rs @@ -6,7 +6,8 @@ use cubeclient::{ use datafusion::{ arrow::{datatypes::SchemaRef, record_batch::RecordBatch}, - physical_plan::aggregates::AggregateFunction, + logical_plan::window_frames::WindowFrame, + physical_plan::{aggregates::AggregateFunction, window_functions::WindowFunction}, }; use minijinja::{context, value::Value, Environment}; use serde_derive::*; @@ -418,6 +419,54 @@ impl SqlTemplates { ) } + pub fn window_function_name(&self, window_function: WindowFunction) -> String { + match window_function { + WindowFunction::AggregateFunction(aggregate_function) => { + self.aggregate_function_name(aggregate_function, false) + } + WindowFunction::BuiltInWindowFunction(built_in_window_function) => { + built_in_window_function.to_string() + } + } + } + + pub fn window_function( + &self, + window_function: WindowFunction, + args: Vec, + ) -> Result { + let function = self.window_function_name(window_function); + let args_concat = args.join(", "); + self.render_template( + &format!("functions/{}", function), + context! { args_concat => args_concat, args => args }, + ) + } + + pub fn window_function_expr( + &self, + window_function: WindowFunction, + args: Vec, + partition_by: Vec, + order_by: Vec, + _window_frame: Option, + ) -> Result { + let fun_call = self.window_function(window_function, args)?; + let partition_by_concat = partition_by.join(", "); + let order_by_concat = order_by.join(", "); + // TODO window_frame + self.render_template( + "expressions/window_function", + context! { + fun_call => fun_call, + partition_by => partition_by, + partition_by_concat => partition_by_concat, + order_by => order_by, + order_by_concat => order_by_concat + }, + ) + } + pub fn case( &self, expr: Option,