Skip to content

Commit

Permalink
Fix column expression rewrite (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sevenannn authored Jan 4, 2025
1 parent 0cdae99 commit 4e65b2b
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions sources/sql/src/rewrite/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ fn rewrite_column_name_in_expr(

// Table name same as column name
// Shouldn't rewrite in this case
if idx == 0 && start_pos == 0 {
if idx == 0 && table_ref_str.len() == col_name.len() {
return None;
}

Expand Down Expand Up @@ -1381,14 +1381,21 @@ mod tests {
}

#[tokio::test]
async fn test_rewrite_same_column_table_name() -> Result<()> {
async fn test_rewrite_column_name_in_expr() -> Result<()> {
init_tracing();
let ctx = get_test_df_context();

let tests = vec![(
"SELECT app_table FROM (SELECT a app_table from app_table limit 100);",
r#"SELECT app_table FROM (SELECT remote_table.a AS app_table FROM remote_table LIMIT 100)"#,
)];
let tests = vec![
(
// Column alias name same as table name
"SELECT app_table FROM (SELECT a app_table from app_table limit 100);",
r#"SELECT app_table FROM (SELECT remote_table.a AS app_table FROM remote_table LIMIT 100)"#,
),
(
"SELECT a - 1, COUNT(*) AS c FROM app_table GROUP BY a - 1;",
r#"SELECT (remote_table.a - 1), count(*) AS c FROM remote_table GROUP BY (remote_table.a - 1)"#,
),
];

for test in tests {
test_sql(&ctx, test.0, test.1, false).await?;
Expand All @@ -1405,8 +1412,6 @@ mod tests {
) -> Result<(), datafusion::error::DataFusionError> {
let data_frame = ctx.sql(sql_query).await?;

// println!("before optimization: \n{:#?}", data_frame.logical_plan());

let mut known_rewrites = HashMap::new();
let rewritten_plan = rewrite_table_scans(
data_frame.logical_plan(),
Expand All @@ -1415,12 +1420,8 @@ mod tests {
&mut None,
)?;

// println!("rewritten_plan: \n{:#?}", rewritten_plan);

let unparsed_sql = plan_to_sql(&rewritten_plan)?;

println!("unparsed_sql: \n{unparsed_sql}");

assert_eq!(
format!("{unparsed_sql}"),
expected_sql,
Expand Down

0 comments on commit 4e65b2b

Please sign in to comment.