Skip to content

Commit

Permalink
refactor(cubesql): Extract cube join condition check for rewrites to …
Browse files Browse the repository at this point in the history
…function
  • Loading branch information
mcheshkov committed Nov 13, 2024
1 parent 5fd13d1 commit 4659f8c
Showing 1 changed file with 87 additions and 71 deletions.
158 changes: 87 additions & 71 deletions rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2713,77 +2713,14 @@ impl MemberRules {
let left_aliases_var = var!(left_aliases_var);
let right_aliases_var = var!(right_aliases_var);
move |egraph, subst| {
if egraph
.index(subst[left_aliases_var])
.data
.member_name_to_expr
.is_some()
{
if egraph
.index(subst[right_aliases_var])
.data
.member_name_to_expr
.is_some()
{
let left_join_ons: Vec<Vec<_>> =
var_iter!(egraph[subst[left_on_var]], JoinLeftOn)
.map(|elem| elem.iter().cloned().collect())
.collect();
for left_join_on in left_join_ons {
for join_on in left_join_on {
let member_names_to_expr_left = &mut egraph
.index_mut(subst[left_aliases_var])
.data
.member_name_to_expr
.as_mut()
.unwrap();

// TODO: Avoid the join_on.*.clone() calls (should be trivial).
let mut column_name = join_on.name.clone();
if let Some(name) = find_column_by_alias(
&column_name,
member_names_to_expr_left,
&join_on.relation.clone().unwrap_or_default(),
) {
column_name = name.split(".").last().unwrap().to_string();
}

if column_name == "__cubeJoinField" {
let right_join_ons: Vec<Vec<_>> =
var_iter!(egraph[subst[right_on_var]], JoinRightOn)
.map(|elem| elem.iter().cloned().collect())
.collect();
for right_join_on in right_join_ons {
for join_on in right_join_on.iter() {
let member_names_to_expr_right = &mut egraph
.index_mut(subst[right_aliases_var])
.data
.member_name_to_expr
.as_mut()
.unwrap();

let mut column_name = join_on.name.clone();
if let Some(name) = find_column_by_alias(
&column_name,
member_names_to_expr_right,
&join_on.relation.clone().unwrap_or_default(),
) {
column_name =
name.split(".").last().unwrap().to_string();
}

if column_name == "__cubeJoinField" {
return true;
}
}
}
}
}
}
}
}

false
is_proper_cube_join_condition(
egraph,
subst,
left_aliases_var,
left_on_var,
right_aliases_var,
right_on_var,
)
}
}

Expand Down Expand Up @@ -2961,6 +2898,85 @@ fn find_column_by_alias(
None
}

fn is_proper_cube_join_condition(
egraph: &mut CubeEGraph,
subst: &Subst,
left_cube_members_var: Var,
left_on_var: Var,
right_cube_members_var: Var,
right_on_var: Var,
) -> bool {
if egraph
.index(subst[left_cube_members_var])
.data
.member_name_to_expr
.is_some()
{
if egraph
.index(subst[right_cube_members_var])
.data
.member_name_to_expr
.is_some()
{
let left_join_ons: Vec<Vec<_>> = var_iter!(egraph[subst[left_on_var]], JoinLeftOn)
.map(|elem| elem.iter().cloned().collect())
.collect();
for left_join_on in left_join_ons {
for join_on in left_join_on {
let member_names_to_expr_left = &mut egraph
.index_mut(subst[left_cube_members_var])
.data
.member_name_to_expr
.as_mut()
.unwrap();

// TODO: Avoid the join_on.*.clone() calls (should be trivial).
let mut column_name = join_on.name.clone();
if let Some(name) = find_column_by_alias(
&column_name,
member_names_to_expr_left,
&join_on.relation.clone().unwrap_or_default(),
) {
column_name = name.split(".").last().unwrap().to_string();
}

if column_name == "__cubeJoinField" {
let right_join_ons: Vec<Vec<_>> =
var_iter!(egraph[subst[right_on_var]], JoinRightOn)
.map(|elem| elem.iter().cloned().collect())
.collect();
for right_join_on in right_join_ons {
for join_on in right_join_on.iter() {
let member_names_to_expr_right = &mut egraph
.index_mut(subst[right_cube_members_var])
.data
.member_name_to_expr
.as_mut()
.unwrap();

let mut column_name = join_on.name.clone();
if let Some(name) = find_column_by_alias(
&column_name,
member_names_to_expr_right,
&join_on.relation.clone().unwrap_or_default(),
) {
column_name = name.split(".").last().unwrap().to_string();
}

if column_name == "__cubeJoinField" {
return true;
}

Check warning on line 2968 in rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/rewrite/rules/members.rs#L2968

Added line #L2968 was not covered by tests
}
}
}
}
}
}
}

false
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit 4659f8c

Please sign in to comment.