Skip to content

Commit

Permalink
Fix OuterReferenceColumns not being rewritten correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
phillipleblanc committed Jan 10, 2025
1 parent 4e65b2b commit a05bc57
Showing 1 changed file with 69 additions and 10 deletions.
79 changes: 69 additions & 10 deletions sources/sql/src/rewrite/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,56 @@ use datafusion_federation::{get_table_source, table_reference::MultiPartTableRef

use crate::SQLTableSource;

fn collect_known_rewrites(
plan: &LogicalPlan,
known_rewrites: &mut HashMap<TableReference, MultiPartTableReference>,
) -> Result<()> {
if let LogicalPlan::TableScan(table_scan) = plan {
let original_table_name = table_scan.table_name.clone();

if let Some(federated_source) = get_table_source(&table_scan.source)? {
if let Some(sql_table_source) =
federated_source.as_any().downcast_ref::<SQLTableSource>()
{
let remote_table_name = sql_table_source.table_name();
known_rewrites.insert(original_table_name, remote_table_name.clone());
}
}
}

// Recursively collect from all inputs
for input in plan.inputs() {
collect_known_rewrites(input, known_rewrites)?;
}

Ok(())
}

/// Rewrite table scans to use the original federated table name.
pub(crate) fn rewrite_table_scans(
plan: &LogicalPlan,
known_rewrites: &mut HashMap<TableReference, MultiPartTableReference>,
subquery_uses_partial_path: bool,
subquery_table_scans: &mut Option<HashSet<TableReference>>,
) -> Result<LogicalPlan> {
// First pass: collect all known rewrites
collect_known_rewrites(plan, known_rewrites)?;

// Second pass: do the actual rewriting with complete known_rewrites
rewrite_plan_with_known_rewrites(
plan,
known_rewrites,
subquery_uses_partial_path,
subquery_table_scans,
)
}

// Move the main rewriting logic to this function
fn rewrite_plan_with_known_rewrites(
plan: &LogicalPlan,
known_rewrites: &HashMap<TableReference, MultiPartTableReference>,
subquery_uses_partial_path: bool,
subquery_table_scans: &mut Option<HashSet<TableReference>>,
) -> Result<LogicalPlan> {
if plan.inputs().is_empty() {
if let LogicalPlan::TableScan(table_scan) = plan {
Expand All @@ -41,7 +85,6 @@ pub(crate) fn rewrite_table_scans(
match federated_source.as_any().downcast_ref::<SQLTableSource>() {
Some(sql_table_source) => {
let remote_table_name = sql_table_source.table_name();
known_rewrites.insert(original_table_name.clone(), remote_table_name.clone());

// If the remote table name is a MultiPartTableReference, we will not rewrite it here, but rewrite it after the final unparsing on the AST directly.
let MultiPartTableReference::TableReference(remote_table_name) =
Expand Down Expand Up @@ -91,7 +134,7 @@ pub(crate) fn rewrite_table_scans(
.inputs()
.into_iter()
.map(|i| {
rewrite_table_scans(
rewrite_plan_with_known_rewrites(
i,
known_rewrites,
subquery_uses_partial_path,
Expand Down Expand Up @@ -172,7 +215,7 @@ pub(crate) fn rewrite_table_scans(
fn rewrite_unnest_plan(
unnest: &logical_expr::Unnest,
mut rewritten_inputs: Vec<LogicalPlan>,
known_rewrites: &mut HashMap<TableReference, MultiPartTableReference>,
known_rewrites: &HashMap<TableReference, MultiPartTableReference>,
subquery_uses_partial_path: bool,
subquery_table_scans: &mut Option<HashSet<TableReference>>,
) -> Result<LogicalPlan> {
Expand Down Expand Up @@ -391,22 +434,22 @@ fn rewrite_column_name_in_expr(

fn rewrite_table_scans_in_expr(
expr: Expr,
known_rewrites: &mut HashMap<TableReference, MultiPartTableReference>,
known_rewrites: &HashMap<TableReference, MultiPartTableReference>,
subquery_uses_partial_path: bool,
subquery_table_scans: &mut Option<HashSet<TableReference>>,
) -> Result<Expr> {
match expr {
Expr::ScalarSubquery(subquery) => {
let new_subquery = if subquery_table_scans.is_some() || !subquery_uses_partial_path {
rewrite_table_scans(
rewrite_plan_with_known_rewrites(
&subquery.subquery,
known_rewrites,
subquery_uses_partial_path,
subquery_table_scans,
)?
} else {
let mut scans = Some(HashSet::new());
rewrite_table_scans(
rewrite_plan_with_known_rewrites(
&subquery.subquery,
known_rewrites,
subquery_uses_partial_path,
Expand Down Expand Up @@ -888,15 +931,15 @@ fn rewrite_table_scans_in_expr(
}
Expr::Exists(exists) => {
let subquery_plan = if subquery_table_scans.is_some() || !subquery_uses_partial_path {
rewrite_table_scans(
rewrite_plan_with_known_rewrites(
&exists.subquery.subquery,
known_rewrites,
subquery_uses_partial_path,
subquery_table_scans,
)?
} else {
let mut scans = Some(HashSet::new());
rewrite_table_scans(
rewrite_plan_with_known_rewrites(
&exists.subquery.subquery,
known_rewrites,
subquery_uses_partial_path,
Expand Down Expand Up @@ -930,15 +973,15 @@ fn rewrite_table_scans_in_expr(
subquery_table_scans,
)?;
let subquery_plan = if subquery_table_scans.is_some() || !subquery_uses_partial_path {
rewrite_table_scans(
rewrite_plan_with_known_rewrites(
&is.subquery.subquery,
known_rewrites,
subquery_uses_partial_path,
subquery_table_scans,
)?
} else {
let mut scans = Some(HashSet::new());
rewrite_table_scans(
rewrite_plan_with_known_rewrites(
&is.subquery.subquery,
known_rewrites,
subquery_uses_partial_path,
Expand Down Expand Up @@ -1380,6 +1423,22 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_rewrite_outer_ref_columns() -> Result<()> {
init_tracing();
let ctx = get_test_df_context();
let tests = vec![(
"SELECT foo.df_table.a FROM bar JOIN foo.df_table ON foo.df_table.a = (SELECT bar.a FROM bar WHERE bar.a > foo.df_table.a)",
r#"SELECT remote_table.a FROM remote_db.remote_schema.remote_table JOIN remote_table ON (remote_table.a = (SELECT a FROM remote_db.remote_schema.remote_table WHERE (remote_table.a > remote_table.a)))"#,
true,
)];
for test in tests {
test_sql(&ctx, test.0, test.1, test.2).await?;
}

Ok(())
}

#[tokio::test]
async fn test_rewrite_column_name_in_expr() -> Result<()> {
init_tracing();
Expand Down

0 comments on commit a05bc57

Please sign in to comment.