Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(cubesql): SQL push down support for window functions #7403

Merged
merged 1 commit into from
Nov 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/cubejs-schema-compiler/src/adapter/BaseQuery.js
Original file line number Diff line number Diff line change
Expand Up @@ -2479,6 +2479,7 @@ class BaseQuery {
binary: '({{ left }} {{ op }} {{ right }})',
sort: '{{ expr }} {% if asc %}ASC{% else %}DESC{% endif %}{% if nulls_first %} NULLS FIRST{% endif %}',
cast: 'CAST({{ expr }} AS {{ data_type }})',
window_function: '{{ fun_call }} OVER ({% if partition_by %}PARTITION BY {{ partition_by }}{% if order_by %} {% endif %}{% endif %}{% if order_by %}ORDER BY {{ order_by }}{% endif %})'
},
quotes: {
identifiers: '"',
Expand Down
13 changes: 12 additions & 1 deletion rust/cubesql/cubesql/src/compile/engine/df/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ pub struct WrappedSelectNode {
pub projection_expr: Vec<Expr>,
pub group_expr: Vec<Expr>,
pub aggr_expr: Vec<Expr>,
pub window_expr: Vec<Expr>,
pub from: Arc<LogicalPlan>,
pub joins: Vec<(Arc<LogicalPlan>, Expr, JoinType)>,
pub filter_expr: Vec<Expr>,
Expand All @@ -158,6 +159,7 @@ impl WrappedSelectNode {
projection_expr: Vec<Expr>,
group_expr: Vec<Expr>,
aggr_expr: Vec<Expr>,
window_expr: Vec<Expr>,
from: Arc<LogicalPlan>,
joins: Vec<(Arc<LogicalPlan>, Expr, JoinType)>,
filter_expr: Vec<Expr>,
Expand All @@ -174,6 +176,7 @@ impl WrappedSelectNode {
projection_expr,
group_expr,
aggr_expr,
window_expr,
from,
joins,
filter_expr,
Expand Down Expand Up @@ -207,6 +210,7 @@ impl UserDefinedLogicalNode for WrappedSelectNode {
exprs.extend(self.projection_expr.clone());
exprs.extend(self.group_expr.clone());
exprs.extend(self.aggr_expr.clone());
exprs.extend(self.window_expr.clone());
exprs.extend(self.joins.iter().map(|(_, expr, _)| expr.clone()));
exprs.extend(self.filter_expr.clone());
exprs.extend(self.having_expr.clone());
Expand All @@ -217,11 +221,12 @@ impl UserDefinedLogicalNode for WrappedSelectNode {
fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"WrappedSelect: select_type={:?}, projection_expr={:?}, group_expr={:?}, aggregate_expr={:?}, from={:?}, joins={:?}, filter_expr={:?}, having_expr={:?}, limit={:?}, offset={:?}, order_expr={:?}, alias={:?}",
"WrappedSelect: select_type={:?}, projection_expr={:?}, group_expr={:?}, aggregate_expr={:?}, window_expr={:?}, from={:?}, joins={:?}, filter_expr={:?}, having_expr={:?}, limit={:?}, offset={:?}, order_expr={:?}, alias={:?}",
self.select_type,
self.projection_expr,
self.group_expr,
self.aggr_expr,
self.window_expr,
self.from,
self.joins,
self.filter_expr,
Expand Down Expand Up @@ -261,6 +266,7 @@ impl UserDefinedLogicalNode for WrappedSelectNode {
let mut projection_expr = vec![];
let mut group_expr = vec![];
let mut aggregate_expr = vec![];
let mut window_expr = vec![];
let limit = None;
let offset = None;
let alias = None;
Expand All @@ -278,6 +284,10 @@ impl UserDefinedLogicalNode for WrappedSelectNode {
aggregate_expr.push(exprs_iter.next().unwrap().clone());
}

for _ in self.window_expr.iter() {
window_expr.push(exprs_iter.next().unwrap().clone());
}

for _ in self.joins.iter() {
joins_expr.push(exprs_iter.next().unwrap().clone());
}
Expand All @@ -300,6 +310,7 @@ impl UserDefinedLogicalNode for WrappedSelectNode {
projection_expr,
group_expr,
aggregate_expr,
window_expr,
from,
joins
.into_iter()
Expand Down
96 changes: 95 additions & 1 deletion rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ impl CubeScanWrapperNode {
projection_expr,
group_expr,
aggr_expr,
window_expr,
from,
joins: _joins,
filter_expr: _filter_expr,
Expand Down Expand Up @@ -431,6 +432,20 @@ impl CubeScanWrapperNode {
ungrouped_scan_node.clone(),
)
.await?;

let (window, sql) = Self::generate_column_expr(
plan.clone(),
schema.clone(),
window_expr.clone(),
sql,
generator.clone(),
&column_remapping,
&mut next_remapping,
alias.clone(),
can_rename_columns,
ungrouped_scan_node.clone(),
)
.await?;
// Sort node always comes on top and pushed down to select so we need to replace columns here by appropriate column definitions
let order_replace_map = projection_expr
.iter()
Expand Down Expand Up @@ -504,6 +519,12 @@ impl CubeScanWrapperNode {
)
}),
)
.chain(window.iter().map(|m| {
Self::ungrouped_member_def(
m,
&ungrouped_scan_node.used_cubes,
)
}))
.collect::<Result<_>>()?,
);
load_request.dimensions = Some(
Expand Down Expand Up @@ -1333,7 +1354,80 @@ impl CubeScanWrapperNode {
sql_query,
))
}
// Expr::WindowFunction { .. } => {}
Expr::WindowFunction {
fun,
args,
partition_by,
order_by,
window_frame,
} => {
let mut sql_args = Vec::new();
for arg in args {
let (sql, query) = Self::generate_sql_for_expr(
plan.clone(),
sql_query,
sql_generator.clone(),
arg,
ungrouped_scan_node.clone(),
)
.await?;
sql_query = query;
sql_args.push(sql);
}
let mut sql_partition_by = Vec::new();
for arg in partition_by {
let (sql, query) = Self::generate_sql_for_expr(
plan.clone(),
sql_query,
sql_generator.clone(),
arg,
ungrouped_scan_node.clone(),
)
.await?;
sql_query = query;
sql_partition_by.push(sql);
}
let mut sql_order_by = Vec::new();
for arg in order_by {
let (sql, query) = Self::generate_sql_for_expr(
plan.clone(),
sql_query,
sql_generator.clone(),
arg,
ungrouped_scan_node.clone(),
)
.await?;
sql_query = query;
sql_order_by.push(
sql_generator
.get_sql_templates()
// TODO asc/desc
.sort_expr(sql, true, false)
.map_err(|e| {
DataFusionError::Internal(format!(
"Can't generate SQL for sort expr: {}",
e
))
})?,
);
}
let resulting_sql = sql_generator
.get_sql_templates()
.window_function_expr(
fun,
sql_args,
sql_partition_by,
sql_order_by,
window_frame,
)
.map_err(|e| {
DataFusionError::Internal(format!(
"Can't generate SQL for window function: {}",
e
))
})?;
Ok((resulting_sql, sql_query))
}
// Expr::AggregateUDF { .. } => {}
// Expr::InList { .. } => {}
// Expr::Wildcard => {}
Expand Down
77 changes: 65 additions & 12 deletions rust/cubesql/cubesql/src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18847,12 +18847,20 @@ ORDER BY \"COUNT(count)\" DESC"
.sql
.contains("CASE WHEN"));

assert!(logical_plan
.find_cube_scan_wrapper()
.wrapped_sql
.unwrap()
.sql
.contains("1123"));
assert!(
logical_plan
.find_cube_scan_wrapper()
.wrapped_sql
.unwrap()
.sql
.contains("1123"),
"SQL contains 1123: {}",
logical_plan
.find_cube_scan_wrapper()
.wrapped_sql
.unwrap()
.sql
);

let physical_plan = query_plan.as_physical_plan().await.unwrap();
println!(
Expand Down Expand Up @@ -18883,12 +18891,20 @@ ORDER BY \"COUNT(count)\" DESC"
.sql
.contains("CASE WHEN"));

assert!(logical_plan
.find_cube_scan_wrapper()
.wrapped_sql
.unwrap()
.sql
.contains("LIMIT 1123"));
assert!(
logical_plan
.find_cube_scan_wrapper()
.wrapped_sql
.unwrap()
.sql
.contains("1123"),
"SQL contains 1123: {}",
logical_plan
.find_cube_scan_wrapper()
.wrapped_sql
.unwrap()
.sql
);

let physical_plan = query_plan.as_physical_plan().await.unwrap();
println!(
Expand Down Expand Up @@ -19063,6 +19079,43 @@ ORDER BY \"COUNT(count)\" DESC"
.contains("EXTRACT"));
}

#[tokio::test]
async fn test_wrapper_window_function() {
    // Window-function push-down is only exercised when the SQL push-down
    // rewrite is enabled; otherwise there is nothing to verify here.
    if !Rewriter::sql_push_down_enabled() {
        return;
    }
    init_logger();

    // Plan a query containing a window function (SUM(...) OVER ()) so the
    // wrapper has to render an OVER clause in the generated SQL.
    let plan = convert_select_to_query_plan(
        "SELECT customer_gender, AVG(avgPrice) mp, SUM(COUNT(count)) OVER() FROM KibanaSampleDataEcommerce a GROUP BY 1 LIMIT 100"
            .to_string(),
        DatabaseProtocol::PostgreSQL,
    )
    .await;

    // The wrapped SQL produced for the CubeScan must contain the window
    // function's OVER clause; include the SQL in the failure message so a
    // broken template is easy to diagnose.
    let generated_sql = plan
        .as_logical_plan()
        .find_cube_scan_wrapper()
        .wrapped_sql
        .unwrap()
        .sql;
    assert!(
        generated_sql.contains("OVER"),
        "SQL should contain 'OVER': {}",
        generated_sql
    );

    // Building the physical plan must also succeed for the pushed-down query.
    let physical_plan = plan.as_physical_plan().await.unwrap();
    println!(
        "Physical plan: {}",
        displayable(physical_plan.as_ref()).indent()
    );
}

#[tokio::test]
async fn test_thoughtspot_pg_date_trunc_year() {
init_logger();
Expand Down
50 changes: 37 additions & 13 deletions rust/cubesql/cubesql/src/compile/rewrite/converter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ use datafusion::{
logical_plan::{
build_join_schema, build_table_udf_schema, exprlist_to_fields, normalize_cols,
plan::{Aggregate, Extension, Filter, Join, Projection, Sort, TableUDFs, Window},
CrossJoin, DFField, DFSchema, DFSchemaRef, Distinct, EmptyRelation, Expr, Like, Limit,
LogicalPlan, LogicalPlanBuilder, TableScan, Union,
replace_col_to_expr, CrossJoin, DFField, DFSchema, DFSchemaRef, Distinct, EmptyRelation,
Expr, Like, Limit, LogicalPlan, LogicalPlanBuilder, TableScan, Union,
},
physical_plan::planner::DefaultPhysicalPlanner,
scalar::ScalarValue,
Expand Down Expand Up @@ -1671,8 +1671,10 @@ impl LanguageToLogicalPlanConverter {
match_expr_list_node!(node_by_id, to_expr, params[2], WrappedSelectGroupExpr);
let aggr_expr =
match_expr_list_node!(node_by_id, to_expr, params[3], WrappedSelectAggrExpr);
let from = Arc::new(self.to_logical_plan(params[4])?);
let joins = match_list_node!(node_by_id, params[5], WrappedSelectJoins)
let window_expr =
match_expr_list_node!(node_by_id, to_expr, params[4], WrappedSelectWindowExpr);
let from = Arc::new(self.to_logical_plan(params[5])?);
let joins = match_list_node!(node_by_id, params[6], WrappedSelectJoins)
.into_iter()
.map(|j| {
if let LogicalPlanLanguage::WrappedSelectJoin(params) = j {
Expand All @@ -1688,28 +1690,49 @@ impl LanguageToLogicalPlanConverter {
.collect::<Result<Vec<_>, _>>()?;

let filter_expr =
match_expr_list_node!(node_by_id, to_expr, params[6], WrappedSelectFilterExpr);
match_expr_list_node!(node_by_id, to_expr, params[7], WrappedSelectFilterExpr);
let having_expr =
match_expr_list_node!(node_by_id, to_expr, params[7], WrappedSelectHavingExpr);
let limit = match_data_node!(node_by_id, params[8], WrappedSelectLimit);
let offset = match_data_node!(node_by_id, params[9], WrappedSelectOffset);
match_expr_list_node!(node_by_id, to_expr, params[8], WrappedSelectHavingExpr);
let limit = match_data_node!(node_by_id, params[9], WrappedSelectLimit);
let offset = match_data_node!(node_by_id, params[10], WrappedSelectOffset);
let order_expr =
match_expr_list_node!(node_by_id, to_expr, params[10], WrappedSelectOrderExpr);
let alias = match_data_node!(node_by_id, params[11], WrappedSelectAlias);
let ungrouped = match_data_node!(node_by_id, params[12], WrappedSelectUngrouped);
match_expr_list_node!(node_by_id, to_expr, params[11], WrappedSelectOrderExpr);
let alias = match_data_node!(node_by_id, params[12], WrappedSelectAlias);
let ungrouped = match_data_node!(node_by_id, params[13], WrappedSelectUngrouped);

let group_expr = normalize_cols(group_expr, &from)?;
let aggr_expr = normalize_cols(aggr_expr, &from)?;
let projection_expr = normalize_cols(projection_expr, &from)?;
let all_expr = match select_type {
let all_expr_without_window = match select_type {
WrappedSelectType::Projection => projection_expr.clone(),
WrappedSelectType::Aggregate => {
group_expr.iter().chain(aggr_expr.iter()).cloned().collect()
}
};
let without_window_fields =
exprlist_to_fields(all_expr_without_window.iter(), from.schema())?;
let replace_map = all_expr_without_window
.iter()
.zip(without_window_fields.iter())
.map(|(e, f)| (f.qualified_column(), e.clone()))
.collect::<Vec<_>>();
let replace_map = replace_map
.iter()
.map(|(c, e)| (c, e))
.collect::<HashMap<_, _>>();
let window_expr_rebased = window_expr
.iter()
.map(|e| replace_col_to_expr(e.clone(), &replace_map))
.collect::<Result<Vec<_>, _>>()?;
let schema = DFSchema::new_with_metadata(
// TODO support joins schema
exprlist_to_fields(all_expr.iter(), from.schema())?,
without_window_fields
.into_iter()
.chain(
exprlist_to_fields(window_expr_rebased.iter(), from.schema())?
.into_iter(),
)
.collect(),
HashMap::new(),
)?;

Expand All @@ -1725,6 +1748,7 @@ impl LanguageToLogicalPlanConverter {
projection_expr,
group_expr,
aggr_expr,
window_expr_rebased,
from,
joins,
filter_expr,
Expand Down
Loading