From a25a523e0e4e6def29c1675fee944f07c70d869b Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Thu, 9 May 2024 14:56:42 +0900 Subject: [PATCH 01/48] Update to DataFusion 37.1 --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 744de0b..f084a1f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,5 +21,5 @@ readme = "README.md" [workspace.dependencies] async-trait = "0.1.77" -datafusion = "37.0.0" -datafusion-substrait = "37.0.0" +datafusion = "37.1.0" +datafusion-substrait = "37.1.0" From a852cd3f7747ef980ece217b27c28d98215b5a57 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Thu, 9 May 2024 16:04:21 +0900 Subject: [PATCH 02/48] Allow TableScans of non-federated sources --- datafusion-federation/src/analyzer.rs | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/datafusion-federation/src/analyzer.rs b/datafusion-federation/src/analyzer.rs index 7da905a..f5ba3a1 100644 --- a/datafusion-federation/src/analyzer.rs +++ b/datafusion-federation/src/analyzer.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use datafusion::{ config::ConfigOptions, datasource::source_as_provider, - error::{DataFusionError, Result}, + error::Result, logical_expr::{Expr, LogicalPlan, Projection, TableScan, TableSource}, optimizer::analyzer::AnalyzerRule, }; @@ -121,7 +121,9 @@ impl FederationAnalyzerRule { fn get_federation_provider(&self, plan: &LogicalPlan) -> Result> { match plan { LogicalPlan::TableScan(TableScan { ref source, .. }) => { - let federated_source = get_table_source(source.clone())?; + let Some(federated_source) = get_table_source(source)? else { + return Ok(None); + }; let provider = federated_source.federation_provider(); Ok(Some(provider)) } @@ -149,18 +151,20 @@ fn wrap_projection(plan: LogicalPlan) -> Result { } } -pub fn get_table_source(source: Arc) -> Result> { +pub fn get_table_source( + source: &Arc, +) -> Result>> { // Unwrap TableSource - let source = source_as_provider(&source)?; + let source = source_as_provider(source)?; // Get FederatedTableProviderAdaptor - let wrapper = source + let Some(wrapper) = source .as_any() .downcast_ref::() - .ok_or(DataFusionError::Plan( - "expected a FederatedTableSourceWrapper".to_string(), - ))?; + else { + return Ok(None); + }; // Return original FederatedTableSource - Ok(wrapper.source.clone()) + Ok(Some(Arc::clone(&wrapper.source))) } From c81a766f4c3ba105bbc836324eff58640bd53cc0 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Thu, 9 May 2024 16:13:51 +0900 Subject: [PATCH 03/48] .clone() -> Arc::clone(& --- datafusion-federation/src/plan_node.rs | 4 ++-- sources/flight-sql/src/executor/mod.rs | 4 ++-- sources/sql/src/lib.rs | 10 +++++----- sources/sql/src/schema.rs | 15 ++++++++------- 4 files changed, 17 insertions(+), 16 deletions(-) diff --git a/datafusion-federation/src/plan_node.rs b/datafusion-federation/src/plan_node.rs index 35a9306..2940b07 100644 --- a/datafusion-federation/src/plan_node.rs +++ b/datafusion-federation/src/plan_node.rs @@ -62,7 +62,7 @@ impl UserDefinedLogicalNodeCore for FederatedPlanNode { assert_eq!(exprs.len(), 0, "expression size inconsistent"); Self { plan: self.plan.clone(), - planner: self.planner.clone(), + planner: Arc::clone(&self.planner), } } } @@ -144,7 +144,7 @@ impl ExtensionPlanner for FederatedPlanner { assert_eq!(logical_inputs.len(), 0, "Inconsistent number of inputs"); assert_eq!(physical_inputs.len(), 0, "Inconsistent number of inputs"); - let fed_planner = fed_node.planner.clone(); + let fed_planner = 
Arc::clone(&fed_node.planner); let exec_plan = fed_planner.plan_federation(fed_node, session_state).await?; return Ok(Some(exec_plan)); } diff --git a/sources/flight-sql/src/executor/mod.rs b/sources/flight-sql/src/executor/mod.rs index a5c5a38..537030a 100644 --- a/sources/flight-sql/src/executor/mod.rs +++ b/sources/flight-sql/src/executor/mod.rs @@ -68,11 +68,11 @@ impl SQLExecutor for FlightSQLExecutor { } fn execute(&self, sql: &str, schema: SchemaRef) -> Result { let future_stream = - make_flight_sql_stream(sql.to_string(), self.client.clone(), schema.clone()); + make_flight_sql_stream(sql.to_string(), self.client.clone(), Arc::clone(&schema)); let stream = futures::stream::once(future_stream).try_flatten(); Ok(Box::pin(RecordBatchStreamAdapter::new( - schema.clone(), + Arc::clone(&schema), stream, ))) } diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index b814904..5baf59e 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -38,7 +38,7 @@ impl SQLFederationProvider { pub fn new(executor: Arc) -> Self { Self { analyzer: Arc::new(Analyzer::with_rules(vec![Arc::new( - SQLFederationAnalyzerRule::new(executor.clone()), + SQLFederationAnalyzerRule::new(Arc::clone(&executor)), )])), executor, } @@ -55,7 +55,7 @@ impl FederationProvider for SQLFederationProvider { } fn analyzer(&self) -> Option> { - Some(self.analyzer.clone()) + Some(Arc::clone(&self.analyzer)) } } @@ -66,7 +66,7 @@ struct SQLFederationAnalyzerRule { impl SQLFederationAnalyzerRule { pub fn new(executor: Arc) -> Self { Self { - planner: Arc::new(SQLFederationPlanner::new(executor.clone())), + planner: Arc::new(SQLFederationPlanner::new(Arc::clone(&executor))), } } } @@ -75,7 +75,7 @@ impl AnalyzerRule for SQLFederationAnalyzerRule { fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result { // Simply accept the entire plan for now - let fed_plan = FederatedPlanNode::new(plan.clone(), self.planner.clone()); + let fed_plan = FederatedPlanNode::new(plan.clone(), Arc::clone(&self.planner)); let ext_node = Extension { node: Arc::new(fed_plan), }; @@ -106,7 +106,7 @@ impl FederationPlanner for SQLFederationPlanner { ) -> Result> { Ok(Arc::new(VirtualExecutionPlan::new( node.plan().clone(), - self.executor.clone(), + Arc::clone(&self.executor), ))) } } diff --git a/sources/sql/src/schema.rs b/sources/sql/src/schema.rs index c780f23..3666711 100644 --- a/sources/sql/src/schema.rs +++ b/sources/sql/src/schema.rs @@ -21,7 +21,7 @@ pub struct SQLSchemaProvider { impl SQLSchemaProvider { pub async fn new(provider: Arc) -> Result { - let tables = provider.clone().executor.table_names().await?; + let tables = Arc::clone(&provider).executor.table_names().await?; Self::new_with_tables(provider, tables).await } @@ -32,7 +32,7 @@ impl SQLSchemaProvider { ) -> Result { let futures: Vec<_> = tables .into_iter() - .map(|t| SQLTableSource::new(provider.clone(), t)) + .map(|t| SQLTableSource::new(Arc::clone(&provider), t)) .collect(); let results: Result> = join_all(futures).await.into_iter().collect(); let sources = results?.into_iter().map(Arc::new).collect(); @@ -60,7 +60,9 @@ impl SchemaProvider for SQLSchemaProvider { .iter() .find(|s| s.table_name.eq_ignore_ascii_case(name)) { - let adaptor = FederatedTableProviderAdaptor::new(source.clone()); + let adaptor = FederatedTableProviderAdaptor::new( + Arc::clone(source) as Arc + ); return Ok(Some(Arc::new(adaptor))); } Ok(None) @@ -116,8 +118,7 @@ pub struct SQLTableSource { impl SQLTableSource { // creates a SQLTableSource and infers the table schema 
pub async fn new(provider: Arc, table_name: String) -> Result { - let schema = provider - .clone() + let schema = Arc::clone(&provider) .executor .get_table_schema(table_name.as_str()) .await?; @@ -139,7 +140,7 @@ impl SQLTableSource { impl FederatedTableSource for SQLTableSource { fn federation_provider(&self) -> Arc { - self.provider.clone() + Arc::clone(&self.provider) as Arc } } @@ -148,7 +149,7 @@ impl TableSource for SQLTableSource { self } fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } fn table_type(&self) -> TableType { TableType::Temporary From c6c95420772e7360a41fccec1e4caf43c3f40f70 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Thu, 9 May 2024 16:14:56 +0900 Subject: [PATCH 04/48] Update Cargo.toml license to Apache-2.0 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index f084a1f..197862f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ members = [ [workspace.package] version = "0.1.3" edition = "2021" -license = "MIT" +license = "Apache-2.0" readme = "README.md" From 5d9b98f08618930e06072002bd7bc4a1365d02ac Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Mon, 13 May 2024 21:35:21 +0900 Subject: [PATCH 05/48] Upgrade to DataFusion 38 --- Cargo.toml | 4 ++-- datafusion-federation/src/analyzer.rs | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 197862f..283aa4f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,5 +21,5 @@ readme = "README.md" [workspace.dependencies] async-trait = "0.1.77" -datafusion = "37.1.0" -datafusion-substrait = "37.1.0" +datafusion = "38.0.0" +datafusion-substrait = "38.0.0" diff --git a/datafusion-federation/src/analyzer.rs b/datafusion-federation/src/analyzer.rs index f5ba3a1..1690555 100644 --- a/datafusion-federation/src/analyzer.rs +++ b/datafusion-federation/src/analyzer.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use datafusion::{ + common::Column, config::ConfigOptions, datasource::source_as_provider, error::Result, @@ -74,7 +75,8 @@ impl FederationAnalyzerRule { // federate the entire plan if let Some(provider) = first_provider { if let Some(optimizer) = provider.analyzer() { - let optimized = optimizer.execute_and_check(plan, _config, |_, _| {})?; + let optimized = + optimizer.execute_and_check(plan.clone(), _config, |_, _| {})?; return Ok((Some(optimized), None)); } return Ok((None, None)); @@ -101,8 +103,7 @@ impl FederationAnalyzerRule { if let Some(optimizer) = provider.analyzer() { let wrapped = wrap_projection((*sub_plan).clone())?; - let optimized = - optimizer.execute_and_check(&wrapped, _config, |_, _| {})?; + let optimized = optimizer.execute_and_check(wrapped, _config, |_, _| {})?; return Ok(optimized); } // No federation for this sub-plan (no analyzer) @@ -141,7 +142,7 @@ fn wrap_projection(plan: LogicalPlan) -> Result { .schema() .fields() .iter() - .map(|f| Expr::Column(f.qualified_column())) + .map(|f| Expr::Column(Column::new_unqualified(f.name()))) .collect::>(); Ok(LogicalPlan::Projection(Projection::try_new( expr, From 3de50a9a3e4468087f91c1270e37b7179e11ca0d Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Tue, 14 May 2024 15:30:03 +0900 Subject: [PATCH 06/48] Add fallback TableProvider to FederatedTableProviderAdaptor --- datafusion-federation/src/table_provider.rs | 72 +++++++++++++++++++-- 1 file changed, 67 insertions(+), 5 deletions(-) diff --git a/datafusion-federation/src/table_provider.rs b/datafusion-federation/src/table_provider.rs index 
6a9afaa..c01ec9b 100644 --- a/datafusion-federation/src/table_provider.rs +++ b/datafusion-federation/src/table_provider.rs @@ -17,11 +17,28 @@ use crate::FederationProvider; // from a TableScan. This wrapper may be avoidable. pub struct FederatedTableProviderAdaptor { pub source: Arc, + pub table_provider: Option>, } impl FederatedTableProviderAdaptor { pub fn new(source: Arc) -> Self { - Self { source } + Self { + source, + table_provider: None, + } + } + + /// Creates a new FederatedTableProviderAdaptor that falls back to the + /// provided TableProvider. This is useful if used within a DataFusion + /// context without the federation optimizer. + pub fn new_with_provider( + source: Arc, + table_provider: Arc, + ) -> Self { + Self { + source, + table_provider: Some(table_provider), + } } } @@ -31,18 +48,44 @@ impl TableProvider for FederatedTableProviderAdaptor { self } fn schema(&self) -> SchemaRef { + if let Some(table_provider) = &self.table_provider { + return table_provider.schema(); + } + self.source.schema() } fn constraints(&self) -> Option<&Constraints> { + if let Some(table_provider) = &self.table_provider { + return table_provider + .constraints() + .or_else(|| self.source.constraints()); + } + self.source.constraints() } fn table_type(&self) -> TableType { + if let Some(table_provider) = &self.table_provider { + return table_provider.table_type(); + } + self.source.table_type() } fn get_logical_plan(&self) -> Option<&LogicalPlan> { + if let Some(table_provider) = &self.table_provider { + return table_provider + .get_logical_plan() + .or_else(|| self.source.get_logical_plan()); + } + self.source.get_logical_plan() } fn get_column_default(&self, column: &str) -> Option<&Expr> { + if let Some(table_provider) = &self.table_provider { + return table_provider + .get_column_default(column) + .or_else(|| self.source.get_column_default(column)); + } + self.source.get_column_default(column) } @@ -50,15 +93,34 @@ impl TableProvider for FederatedTableProviderAdaptor { // with a virtual TableProvider that provides federation for a sub-plan. 
async fn scan( &self, - _state: &SessionState, - _projection: Option<&Vec>, - _filters: &[Expr], - _limit: Option, + state: &SessionState, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, ) -> Result> { + if let Some(table_provider) = &self.table_provider { + return table_provider.scan(state, projection, filters, limit).await; + } + Err(DataFusionError::NotImplemented( "FederatedTableProviderAdaptor cannot scan".to_string(), )) } + + async fn insert_into( + &self, + _state: &SessionState, + input: Arc, + overwrite: bool, + ) -> Result> { + if let Some(table_provider) = &self.table_provider { + return table_provider.insert_into(_state, input, overwrite).await; + } + + Err(DataFusionError::NotImplemented( + "FederatedTableProviderAdaptor cannot insert_into".to_string(), + )) + } } // FederatedTableProvider extends DataFusion's TableProvider trait From 0fadae1037056bd37fdc90297fa6df2fdd8dbb94 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Tue, 14 May 2024 20:00:26 +0900 Subject: [PATCH 07/48] Make connectorx optional --- sources/sql/Cargo.toml | 5 ++++- sources/sql/src/lib.rs | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/sources/sql/Cargo.toml b/sources/sql/Cargo.toml index bc89d95..6073584 100644 --- a/sources/sql/Cargo.toml +++ b/sources/sql/Cargo.toml @@ -16,9 +16,12 @@ async-trait.workspace = true connectorx = { git = "https://github.com/devinjdangelo/connector-x.git", features = [ "dst_arrow", "src_sqlite" -] } +], optional = true } datafusion.workspace = true datafusion-federation.path = "../../datafusion-federation" # derive_builder = "0.13.0" futures = "0.3.30" tokio = "1.35.1" + +[features] +connectorx = ["dep:connectorx"] \ No newline at end of file diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index 5baf59e..4c1b5ff 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -21,6 +21,7 @@ use datafusion_federation::{FederatedPlanNode, FederationPlanner, FederationProv mod schema; pub use schema::*; +#[cfg(feature = "connectorx")] pub mod connectorx; mod executor; pub use executor::*; From 8ba03bd75ccf3a7d6fdf5d30bd92160278f5c16a Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Thu, 16 May 2024 22:00:05 +0900 Subject: [PATCH 08/48] Rewrite all table scans to use federated table name (#4) * testing * testing2 * Rewrite all table scans to use federated table name * Wrap the table scan in a subquery alias --- sources/sql/src/lib.rs | 57 ++++++++++++++++++++++++++++++++++++--- sources/sql/src/schema.rs | 4 +++ 2 files changed, 57 insertions(+), 4 deletions(-) diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index 4c1b5ff..4864a0d 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -7,16 +7,18 @@ use datafusion::{ config::ConfigOptions, error::Result, execution::{context::SessionState, TaskContext}, - logical_expr::{Extension, LogicalPlan}, + logical_expr::{Extension, LogicalPlan, SubqueryAlias}, optimizer::analyzer::{Analyzer, AnalyzerRule}, physical_expr::EquivalenceProperties, physical_plan::{ DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, Partitioning, PlanProperties, SendableRecordBatchStream, }, - sql::unparser::plan_to_sql, + sql::{unparser::plan_to_sql, TableReference}, +}; +use datafusion_federation::{ + get_table_source, FederatedPlanNode, FederationPlanner, FederationProvider, }; -use datafusion_federation::{FederatedPlanNode, FederationPlanner, FederationProvider}; mod schema; pub use schema::*; @@ -74,7 +76,8 @@ impl SQLFederationAnalyzerRule { impl AnalyzerRule 
for SQLFederationAnalyzerRule { fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result { - // Simply accept the entire plan for now + // Find all table scans, recover the SQLTableSource, find the remote table name and replace the name of the TableScan table. + let plan = rewrite_table_scans(&plan)?; let fed_plan = FederatedPlanNode::new(plan.clone(), Arc::clone(&self.planner)); let ext_node = Extension { @@ -88,6 +91,52 @@ impl AnalyzerRule for SQLFederationAnalyzerRule { "federate_sql" } } + +/// Rewrite table scans to use the original federated table name. +fn rewrite_table_scans(plan: &LogicalPlan) -> Result { + if plan.inputs().is_empty() { + if let LogicalPlan::TableScan(table_scan) = plan { + let original_table_name = table_scan.table_name.clone(); + let mut new_table_scan = table_scan.clone(); + + let Some(federated_source) = get_table_source(&table_scan.source)? else { + // Not a federated source + return Ok(plan.clone()); + }; + + match federated_source.as_any().downcast_ref::() { + Some(sql_table_source) => { + new_table_scan.table_name = TableReference::from(sql_table_source.table_name()); + } + None => { + // Not a SQLTableSource (is this possible?) + return Ok(plan.clone()); + } + } + + // Wrap the table scan in a SubqueryAlias back to the original table name, so references continue to work. + let subquery_alias = LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( + Arc::new(LogicalPlan::TableScan(new_table_scan)), + original_table_name, + )?); + + return Ok(subquery_alias); + } else { + return Ok(plan.clone()); + } + } + + let rewritten_inputs = plan + .inputs() + .into_iter() + .map(rewrite_table_scans) + .collect::>>()?; + + let new_plan = plan.with_new_exprs(plan.expressions(), rewritten_inputs)?; + + Ok(new_plan) +} + struct SQLFederationPlanner { executor: Arc, } diff --git a/sources/sql/src/schema.rs b/sources/sql/src/schema.rs index 3666711..e417961 100644 --- a/sources/sql/src/schema.rs +++ b/sources/sql/src/schema.rs @@ -136,6 +136,10 @@ impl SQLTableSource { schema, }) } + + pub fn table_name(&self) -> &str { + self.table_name.as_str() + } } impl FederatedTableSource for SQLTableSource { From 8fcda49d7e561a70e1efe439ac9c57c8b9ad30a2 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Mon, 20 May 2024 23:48:26 +0900 Subject: [PATCH 09/48] Rewrite subquery table scans to point to remote table (#5) * Rewrite subquery table scans to point to remote table * Add tracing * more tracing * more debug logging * Handle subquery in binary expressions --- sources/sql/Cargo.toml | 3 ++- sources/sql/src/lib.rs | 34 ++++++++++++++++++++++++++++++++-- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/sources/sql/Cargo.toml b/sources/sql/Cargo.toml index 6073584..c5fd1d2 100644 --- a/sources/sql/Cargo.toml +++ b/sources/sql/Cargo.toml @@ -22,6 +22,7 @@ datafusion-federation.path = "../../datafusion-federation" # derive_builder = "0.13.0" futures = "0.3.30" tokio = "1.35.1" +tracing = "0.1.40" [features] -connectorx = ["dep:connectorx"] \ No newline at end of file +connectorx = ["dep:connectorx"] diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index 4864a0d..b997ada 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -7,7 +7,7 @@ use datafusion::{ config::ConfigOptions, error::Result, execution::{context::SessionState, TaskContext}, - logical_expr::{Extension, LogicalPlan, SubqueryAlias}, + logical_expr::{BinaryExpr, Expr, Extension, LogicalPlan, Subquery, SubqueryAlias}, optimizer::analyzer::{Analyzer, AnalyzerRule}, 
physical_expr::EquivalenceProperties, physical_plan::{ @@ -126,17 +126,47 @@ fn rewrite_table_scans(plan: &LogicalPlan) -> Result { } } + let mut new_expressions = vec![]; + for expression in plan.expressions() { + new_expressions.push(rewrite_table_scans_in_subqueries(expression)?); + } + let rewritten_inputs = plan .inputs() .into_iter() .map(rewrite_table_scans) .collect::>>()?; - let new_plan = plan.with_new_exprs(plan.expressions(), rewritten_inputs)?; + let new_plan = plan.with_new_exprs(new_expressions, rewritten_inputs)?; Ok(new_plan) } +fn rewrite_table_scans_in_subqueries(expr: Expr) -> Result { + match expr { + Expr::ScalarSubquery(subquery) => { + let new_subquery = rewrite_table_scans(&subquery.subquery)?; + Ok(Expr::ScalarSubquery(Subquery { + subquery: Arc::new(new_subquery), + outer_ref_columns: subquery.outer_ref_columns, + })) + } + Expr::BinaryExpr(binary_expr) => { + let left = rewrite_table_scans_in_subqueries(*binary_expr.left)?; + let right = rewrite_table_scans_in_subqueries(*binary_expr.right)?; + Ok(Expr::BinaryExpr(BinaryExpr::new( + Box::new(left), + binary_expr.op, + Box::new(right), + ))) + } + _ => { + tracing::debug!("rewrite_table_scans_in_subqueries: no match for expr={expr:?}",); + Ok(expr) + } + } +} + struct SQLFederationPlanner { executor: Arc, } From 1c77d8ccfbf30c74530281119d26e0f24e7c3326 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Sat, 25 May 2024 00:04:18 +0900 Subject: [PATCH 10/48] Fix the fallback to the table provider for supports_filters_pushdown --- datafusion-federation/src/table_provider.rs | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/datafusion-federation/src/table_provider.rs b/datafusion-federation/src/table_provider.rs index c01ec9b..a1acc30 100644 --- a/datafusion-federation/src/table_provider.rs +++ b/datafusion-federation/src/table_provider.rs @@ -7,7 +7,7 @@ use datafusion::{ datasource::TableProvider, error::{DataFusionError, Result}, execution::context::SessionState, - logical_expr::{Expr, LogicalPlan, TableSource, TableType}, + logical_expr::{Expr, LogicalPlan, TableProviderFilterPushDown, TableSource, TableType}, physical_plan::ExecutionPlan, }; @@ -88,6 +88,19 @@ impl TableProvider for FederatedTableProviderAdaptor { self.source.get_column_default(column) } + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + if let Some(table_provider) = &self.table_provider { + return table_provider.supports_filters_pushdown(filters); + } + + Ok(vec![ + TableProviderFilterPushDown::Unsupported; + filters.len() + ]) + } // Scan is not supported; the adaptor should be replaced // with a virtual TableProvider that provides federation for a sub-plan. 
From b221f4ec87ad57dc7e026fa9e5f5935bfe6b3961 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Sat, 25 May 2024 00:09:45 +0900 Subject: [PATCH 11/48] Use SpiceAI datafusion fork --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 283aa4f..7fe5cdb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,5 +21,5 @@ readme = "README.md" [workspace.dependencies] async-trait = "0.1.77" -datafusion = "38.0.0" -datafusion-substrait = "38.0.0" +datafusion = { git = "https://github.com/spiceai/datafusion.git", rev = "c533d36788eb66b5a90ce158bdad182d6b3a0da9" } +datafusion-substrait = { git = "https://github.com/spiceai/datafusion.git", folder = "datafusion/substrait", rev = "c533d36788eb66b5a90ce158bdad182d6b3a0da9" } \ No newline at end of file From f5d4797d89389610a8830b810429fe10c61e246b Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Mon, 27 May 2024 19:23:20 +0900 Subject: [PATCH 12/48] Fix the table scan rewrite to properly rewrite column relations (#3) * Fix the table scan rewrite to properly rewrite column relations * Handle more expressions for table scan rewrites --- sources/sql/src/lib.rs | 468 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 442 insertions(+), 26 deletions(-) diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index b997ada..1130bcc 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -1,13 +1,21 @@ use core::fmt; -use std::{any::Any, sync::Arc, vec}; +use std::{any::Any, collections::HashMap, sync::Arc, vec}; use async_trait::async_trait; use datafusion::{ arrow::datatypes::{Schema, SchemaRef}, + common::Column, config::ConfigOptions, error::Result, execution::{context::SessionState, TaskContext}, - logical_expr::{BinaryExpr, Expr, Extension, LogicalPlan, Subquery, SubqueryAlias}, + logical_expr::{ + expr::{ + AggregateFunction, Alias, Exists, InList, InSubquery, ScalarFunction, Sort, Unnest, + WindowFunction, + }, + Between, BinaryExpr, Case, Cast, Expr, Extension, GetIndexedField, GroupingSet, Like, + LogicalPlan, Subquery, TryCast, + }, optimizer::analyzer::{Analyzer, AnalyzerRule}, physical_expr::EquivalenceProperties, physical_plan::{ @@ -77,7 +85,8 @@ impl SQLFederationAnalyzerRule { impl AnalyzerRule for SQLFederationAnalyzerRule { fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result { // Find all table scans, recover the SQLTableSource, find the remote table name and replace the name of the TableScan table. - let plan = rewrite_table_scans(&plan)?; + let mut known_rewrites = HashMap::new(); + let plan = rewrite_table_scans(&plan, &mut known_rewrites)?; let fed_plan = FederatedPlanNode::new(plan.clone(), Arc::clone(&self.planner)); let ext_node = Extension { @@ -93,7 +102,10 @@ impl AnalyzerRule for SQLFederationAnalyzerRule { } /// Rewrite table scans to use the original federated table name. 
-fn rewrite_table_scans(plan: &LogicalPlan) -> Result { +fn rewrite_table_scans( + plan: &LogicalPlan, + known_rewrites: &mut HashMap, +) -> Result { if plan.inputs().is_empty() { if let LogicalPlan::TableScan(table_scan) = plan { let original_table_name = table_scan.table_name.clone(); @@ -106,7 +118,15 @@ fn rewrite_table_scans(plan: &LogicalPlan) -> Result { match federated_source.as_any().downcast_ref::() { Some(sql_table_source) => { - new_table_scan.table_name = TableReference::from(sql_table_source.table_name()); + let remote_table_name = TableReference::from(sql_table_source.table_name()); + known_rewrites.insert(original_table_name, remote_table_name.clone()); + + // Rewrite the schema of this node to have the remote table as the qualifier. + let new_schema = (*new_table_scan.projected_schema) + .clone() + .replace_qualifier(remote_table_name.clone()); + new_table_scan.projected_schema = Arc::new(new_schema); + new_table_scan.table_name = remote_table_name; } None => { // Not a SQLTableSource (is this possible?) @@ -114,56 +134,361 @@ fn rewrite_table_scans(plan: &LogicalPlan) -> Result { } } - // Wrap the table scan in a SubqueryAlias back to the original table name, so references continue to work. - let subquery_alias = LogicalPlan::SubqueryAlias(SubqueryAlias::try_new( - Arc::new(LogicalPlan::TableScan(new_table_scan)), - original_table_name, - )?); - - return Ok(subquery_alias); + return Ok(LogicalPlan::TableScan(new_table_scan)); } else { return Ok(plan.clone()); } } - let mut new_expressions = vec![]; - for expression in plan.expressions() { - new_expressions.push(rewrite_table_scans_in_subqueries(expression)?); - } - let rewritten_inputs = plan .inputs() .into_iter() - .map(rewrite_table_scans) + .map(|i| rewrite_table_scans(i, known_rewrites)) .collect::>>()?; + let mut new_expressions = vec![]; + for expression in plan.expressions() { + let new_expr = rewrite_table_scans_in_expr(expression.clone(), known_rewrites)?; + new_expressions.push(new_expr); + } + let new_plan = plan.with_new_exprs(new_expressions, rewritten_inputs)?; Ok(new_plan) } -fn rewrite_table_scans_in_subqueries(expr: Expr) -> Result { +fn rewrite_table_scans_in_expr( + expr: Expr, + known_rewrites: &mut HashMap, +) -> Result { match expr { Expr::ScalarSubquery(subquery) => { - let new_subquery = rewrite_table_scans(&subquery.subquery)?; + let new_subquery = rewrite_table_scans(&subquery.subquery, known_rewrites)?; + let outer_ref_columns = subquery + .outer_ref_columns + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; Ok(Expr::ScalarSubquery(Subquery { subquery: Arc::new(new_subquery), - outer_ref_columns: subquery.outer_ref_columns, + outer_ref_columns, })) } Expr::BinaryExpr(binary_expr) => { - let left = rewrite_table_scans_in_subqueries(*binary_expr.left)?; - let right = rewrite_table_scans_in_subqueries(*binary_expr.right)?; + let left = rewrite_table_scans_in_expr(*binary_expr.left, known_rewrites)?; + let right = rewrite_table_scans_in_expr(*binary_expr.right, known_rewrites)?; Ok(Expr::BinaryExpr(BinaryExpr::new( Box::new(left), binary_expr.op, Box::new(right), ))) } - _ => { - tracing::debug!("rewrite_table_scans_in_subqueries: no match for expr={expr:?}",); - Ok(expr) + Expr::Column(col) => { + if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) { + Ok(Expr::Column(Column::new(Some(rewrite.clone()), &col.name))) + } else { + Ok(Expr::Column(col)) + } + } + Expr::Alias(alias) => { + let expr = 
rewrite_table_scans_in_expr(*alias.expr, known_rewrites)?; + if let Some(relation) = &alias.relation { + if let Some(rewrite) = known_rewrites.get(relation) { + return Ok(Expr::Alias(Alias::new( + expr, + Some(rewrite.clone()), + alias.name, + ))); + } + } + Ok(Expr::Alias(Alias::new(expr, alias.relation, alias.name))) + } + Expr::Like(like) => { + let expr = rewrite_table_scans_in_expr(*like.expr, known_rewrites)?; + let pattern = rewrite_table_scans_in_expr(*like.pattern, known_rewrites)?; + Ok(Expr::Like(Like::new( + like.negated, + Box::new(expr), + Box::new(pattern), + like.escape_char, + like.case_insensitive, + ))) + } + Expr::SimilarTo(similar_to) => { + let expr = rewrite_table_scans_in_expr(*similar_to.expr, known_rewrites)?; + let pattern = rewrite_table_scans_in_expr(*similar_to.pattern, known_rewrites)?; + Ok(Expr::SimilarTo(Like::new( + similar_to.negated, + Box::new(expr), + Box::new(pattern), + similar_to.escape_char, + similar_to.case_insensitive, + ))) + } + Expr::Not(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::Not(Box::new(expr))) + } + Expr::IsNotNull(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsNotNull(Box::new(expr))) + } + Expr::IsNull(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsNull(Box::new(expr))) + } + Expr::IsTrue(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsTrue(Box::new(expr))) + } + Expr::IsFalse(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsFalse(Box::new(expr))) + } + Expr::IsUnknown(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsUnknown(Box::new(expr))) + } + Expr::IsNotTrue(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsNotTrue(Box::new(expr))) + } + Expr::IsNotFalse(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsNotFalse(Box::new(expr))) + } + Expr::IsNotUnknown(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsNotUnknown(Box::new(expr))) + } + Expr::Negative(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::Negative(Box::new(expr))) + } + Expr::GetIndexedField(indexed_field) => { + let expr = rewrite_table_scans_in_expr(*indexed_field.expr, known_rewrites)?; + Ok(Expr::GetIndexedField(GetIndexedField::new( + Box::new(expr), + indexed_field.field, + ))) + } + Expr::Between(between) => { + let expr = rewrite_table_scans_in_expr(*between.expr, known_rewrites)?; + let low = rewrite_table_scans_in_expr(*between.low, known_rewrites)?; + let high = rewrite_table_scans_in_expr(*between.high, known_rewrites)?; + Ok(Expr::Between(Between::new( + Box::new(expr), + between.negated, + Box::new(low), + Box::new(high), + ))) + } + Expr::Case(case) => { + let expr = case + .expr + .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) + .transpose()? + .map(Box::new); + let else_expr = case + .else_expr + .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) + .transpose()? 
+ .map(Box::new); + let when_expr = case + .when_then_expr + .into_iter() + .map(|(when, then)| { + let when = rewrite_table_scans_in_expr(*when, known_rewrites); + let then = rewrite_table_scans_in_expr(*then, known_rewrites); + + match (when, then) { + (Ok(when), Ok(then)) => Ok((Box::new(when), Box::new(then))), + (Err(e), _) | (_, Err(e)) => Err(e), + } + }) + .collect::, Box)>>>()?; + Ok(Expr::Case(Case::new(expr, when_expr, else_expr))) + } + Expr::Cast(cast) => { + let expr = rewrite_table_scans_in_expr(*cast.expr, known_rewrites)?; + Ok(Expr::Cast(Cast::new(Box::new(expr), cast.data_type))) + } + Expr::TryCast(try_cast) => { + let expr = rewrite_table_scans_in_expr(*try_cast.expr, known_rewrites)?; + Ok(Expr::TryCast(TryCast::new( + Box::new(expr), + try_cast.data_type, + ))) + } + Expr::Sort(sort) => { + let expr = rewrite_table_scans_in_expr(*sort.expr, known_rewrites)?; + Ok(Expr::Sort(Sort::new( + Box::new(expr), + sort.asc, + sort.nulls_first, + ))) + } + Expr::ScalarFunction(sf) => { + let args = sf + .args + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + Ok(Expr::ScalarFunction(ScalarFunction { + func_def: sf.func_def, + args, + })) + } + Expr::AggregateFunction(af) => { + let args = af + .args + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + let filter = af + .filter + .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) + .transpose()? + .map(Box::new); + let order_by = af + .order_by + .map(|e| { + e.into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>() + }) + .transpose()?; + Ok(Expr::AggregateFunction(AggregateFunction { + func_def: af.func_def, + args, + distinct: af.distinct, + filter, + order_by, + null_treatment: af.null_treatment, + })) + } + Expr::WindowFunction(wf) => { + let args = wf + .args + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + let partition_by = wf + .partition_by + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + let order_by = wf + .order_by + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + Ok(Expr::WindowFunction(WindowFunction::new( + wf.fun, + args, + partition_by, + order_by, + wf.window_frame, + wf.null_treatment, + ))) + } + Expr::InList(il) => { + let expr = rewrite_table_scans_in_expr(*il.expr, known_rewrites)?; + let list = il + .list + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + Ok(Expr::InList(InList::new(Box::new(expr), list, il.negated))) + } + Expr::Exists(exists) => { + let subquery_plan = rewrite_table_scans(&exists.subquery.subquery, known_rewrites)?; + let outer_ref_columns = exists + .subquery + .outer_ref_columns + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + let subquery = Subquery { + subquery: Arc::new(subquery_plan), + outer_ref_columns, + }; + Ok(Expr::Exists(Exists::new(subquery, exists.negated))) + } + Expr::InSubquery(is) => { + let expr = rewrite_table_scans_in_expr(*is.expr, known_rewrites)?; + let subquery_plan = rewrite_table_scans(&is.subquery.subquery, known_rewrites)?; + let outer_ref_columns = is + .subquery + .outer_ref_columns + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + let subquery = Subquery { + subquery: Arc::new(subquery_plan), + outer_ref_columns, + }; + Ok(Expr::InSubquery(InSubquery::new( + 
Box::new(expr), + subquery, + is.negated, + ))) + } + Expr::Wildcard { qualifier } => { + if let Some(rewrite) = qualifier + .as_ref() + .and_then(|q| known_rewrites.get(&TableReference::from(q))) + { + Ok(Expr::Wildcard { + qualifier: Some(rewrite.clone().to_string()), + }) + } else { + Ok(Expr::Wildcard { qualifier }) + } + } + Expr::GroupingSet(gs) => match gs { + GroupingSet::Rollup(exprs) => { + let exprs = exprs + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + Ok(Expr::GroupingSet(GroupingSet::Rollup(exprs))) + } + GroupingSet::Cube(exprs) => { + let exprs = exprs + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + Ok(Expr::GroupingSet(GroupingSet::Cube(exprs))) + } + GroupingSet::GroupingSets(vec_exprs) => { + let vec_exprs = vec_exprs + .into_iter() + .map(|exprs| { + exprs + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>() + }) + .collect::>>>()?; + Ok(Expr::GroupingSet(GroupingSet::GroupingSets(vec_exprs))) + } + }, + Expr::OuterReferenceColumn(dt, col) => { + if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) { + Ok(Expr::OuterReferenceColumn( + dt, + Column::new(Some(rewrite.clone()), &col.name), + )) + } else { + Ok(Expr::OuterReferenceColumn(dt, col)) + } + } + Expr::Unnest(unnest) => { + let expr = rewrite_table_scans_in_expr(*unnest.expr, known_rewrites)?; + Ok(Expr::Unnest(Unnest::new(expr))) } + Expr::ScalarVariable(_, _) | Expr::Literal(_) | Expr::Placeholder(_) => Ok(expr), } } @@ -268,3 +593,94 @@ impl ExecutionPlan for VirtualExecutionPlan { &self.props } } + +#[cfg(test)] +mod tests { + use datafusion::{ + arrow::datatypes::{DataType, Field}, + common::Column, + datasource::DefaultTableSource, + error::DataFusionError, + logical_expr::LogicalPlanBuilder, + sql::sqlparser::dialect::{Dialect, GenericDialect}, + }; + use datafusion_federation::FederatedTableProviderAdaptor; + + use super::*; + + struct TestSQLExecutor {} + + #[async_trait] + impl SQLExecutor for TestSQLExecutor { + fn name(&self) -> &str { + "test_sql_table_source" + } + + fn compute_context(&self) -> Option { + None + } + + fn dialect(&self) -> Arc { + Arc::new(GenericDialect {}) + } + + fn execute(&self, _query: &str, _schema: SchemaRef) -> Result { + Err(DataFusionError::NotImplemented( + "execute not implemented".to_string(), + )) + } + + async fn table_names(&self) -> Result> { + Err(DataFusionError::NotImplemented( + "table inference not implemented".to_string(), + )) + } + + async fn get_table_schema(&self, _table_name: &str) -> Result { + Err(DataFusionError::NotImplemented( + "table inference not implemented".to_string(), + )) + } + } + + #[test] + fn test_rewrite_table_scans() -> Result<()> { + let sql_federation_provider = + Arc::new(SQLFederationProvider::new(Arc::new(TestSQLExecutor {}))); + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Date32, false), + ])); + let table_source = Arc::new(SQLTableSource::new_with_schema( + sql_federation_provider, + "remote_table".to_string(), + schema, + )?); + let table_provider_adaptor = Arc::new(FederatedTableProviderAdaptor::new(table_source)); + let default_table_source = Arc::new(DefaultTableSource::new(table_provider_adaptor)); + let plan = + LogicalPlanBuilder::scan("foo.df_table", default_table_source, None)?.project(vec![ + Expr::Column(Column::from_qualified_name("foo.df_table.a")), 
+ Expr::Column(Column::from_qualified_name("foo.df_table.b")), + Expr::Column(Column::from_qualified_name("foo.df_table.c")), + ])?; + + let mut known_rewrites = HashMap::new(); + let rewritten_plan = rewrite_table_scans(&plan.build()?, &mut known_rewrites)?; + + println!("rewritten_plan: \n{:#?}", rewritten_plan); + + let unparsed_sql = plan_to_sql(&rewritten_plan)?; + + println!("unparsed_sql: \n{unparsed_sql}"); + + assert_eq!( + format!("{unparsed_sql}"), + r#"SELECT "remote_table"."a", "remote_table"."b", "remote_table"."c" FROM "remote_table""# + ); + + Ok(()) + } +} From 532ecd01bdf4ce77787a56c5b354bb6332eae8c4 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Mon, 27 May 2024 23:16:03 +0900 Subject: [PATCH 13/48] Handle table rewrites for aggregation functions --- sources/sql/Cargo.toml | 3 + sources/sql/src/lib.rs | 131 +++++++++++++++++++++++++++++++++++++---- 2 files changed, 122 insertions(+), 12 deletions(-) diff --git a/sources/sql/Cargo.toml b/sources/sql/Cargo.toml index c5fd1d2..5d1d030 100644 --- a/sources/sql/Cargo.toml +++ b/sources/sql/Cargo.toml @@ -26,3 +26,6 @@ tracing = "0.1.40" [features] connectorx = ["dep:connectorx"] + +[dev-dependencies] +tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index 1130bcc..3d7d326 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -183,11 +183,25 @@ fn rewrite_table_scans_in_expr( Box::new(right), ))) } - Expr::Column(col) => { + Expr::Column(mut col) => { if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) { Ok(Expr::Column(Column::new(Some(rewrite.clone()), &col.name))) } else { - Ok(Expr::Column(col)) + // Check if any of the rewrites match any substring in col.name, and replace that part of the string if so. 
+ // This will handles cases like "MAX(foo.df_table.a)" -> "MAX(remote_table.a)" + let rewritten_name = known_rewrites.iter().find_map(|(table_ref, rewrite)| { + let table_ref_str = table_ref.to_string(); + if col.name.contains(&table_ref_str) { + Some(col.name.replace(&table_ref_str, &rewrite.to_string())) + } else { + None + } + }); + if let Some(new_name) = rewritten_name { + Ok(Expr::Column(Column::new(col.relation.take(), new_name))) + } else { + Ok(Expr::Column(col)) + } } } Expr::Alias(alias) => { @@ -598,9 +612,11 @@ impl ExecutionPlan for VirtualExecutionPlan { mod tests { use datafusion::{ arrow::datatypes::{DataType, Field}, + catalog::schema::{MemorySchemaProvider, SchemaProvider}, common::Column, - datasource::DefaultTableSource, + datasource::{DefaultTableSource, TableProvider}, error::DataFusionError, + execution::context::SessionContext, logical_expr::LogicalPlanBuilder, sql::sqlparser::dialect::{Dialect, GenericDialect}, }; @@ -643,8 +659,7 @@ mod tests { } } - #[test] - fn test_rewrite_table_scans() -> Result<()> { + fn get_test_table_provider() -> Arc { let sql_federation_provider = Arc::new(SQLFederationProvider::new(Arc::new(TestSQLExecutor {}))); @@ -653,13 +668,39 @@ mod tests { Field::new("b", DataType::Utf8, false), Field::new("c", DataType::Date32, false), ])); - let table_source = Arc::new(SQLTableSource::new_with_schema( - sql_federation_provider, - "remote_table".to_string(), - schema, - )?); - let table_provider_adaptor = Arc::new(FederatedTableProviderAdaptor::new(table_source)); - let default_table_source = Arc::new(DefaultTableSource::new(table_provider_adaptor)); + let table_source = Arc::new( + SQLTableSource::new_with_schema( + sql_federation_provider, + "remote_table".to_string(), + schema, + ) + .expect("to have a valid SQLTableSource"), + ); + Arc::new(FederatedTableProviderAdaptor::new(table_source)) + } + + fn get_test_table_source() -> Arc { + Arc::new(DefaultTableSource::new(get_test_table_provider())) + } + + fn get_test_df_context() -> SessionContext { + let ctx = SessionContext::new(); + let catalog = ctx + .catalog("datafusion") + .expect("default catalog is datafusion"); + let foo_schema = Arc::new(MemorySchemaProvider::new()) as Arc; + catalog + .register_schema("foo", Arc::clone(&foo_schema)) + .expect("to register schema"); + foo_schema + .register_table("df_table".to_string(), get_test_table_provider()) + .expect("to register table"); + ctx + } + + #[test] + fn test_rewrite_table_scans_basic() -> Result<()> { + let default_table_source = get_test_table_source(); let plan = LogicalPlanBuilder::scan("foo.df_table", default_table_source, None)?.project(vec![ Expr::Column(Column::from_qualified_name("foo.df_table.a")), @@ -683,4 +724,70 @@ mod tests { Ok(()) } + + fn init_tracing() { + let subscriber = tracing_subscriber::FmtSubscriber::builder() + .with_env_filter("debug") + .with_ansi(true) + .finish(); + let _ = tracing::subscriber::set_global_default(subscriber); + } + + #[tokio::test] + async fn test_rewrite_table_scans_agg() -> Result<()> { + init_tracing(); + let ctx = get_test_df_context(); + + let agg_tests = vec![ + ( + "SELECT MAX(a) FROM foo.df_table", + r#"SELECT MAX("remote_table"."a") FROM "remote_table""#, + ), + ( + "SELECT MIN(a) FROM foo.df_table", + r#"SELECT MIN("remote_table"."a") FROM "remote_table""#, + ), + ( + "SELECT AVG(a) FROM foo.df_table", + r#"SELECT AVG("remote_table"."a") FROM "remote_table""#, + ), + ( + "SELECT SUM(a) FROM foo.df_table", + r#"SELECT SUM("remote_table"."a") FROM "remote_table""#, + ), + ( + 
"SELECT COUNT(a) FROM foo.df_table", + r#"SELECT COUNT("remote_table"."a") FROM "remote_table""#, + ), + ( + "SELECT COUNT(a) as cnt FROM foo.df_table", + r#"SELECT COUNT("remote_table"."a") AS "cnt" FROM "remote_table""#, + ), + ]; + + for test in agg_tests { + let data_frame = ctx.sql(test.0).await?; + + println!("before optimization: \n{:#?}", data_frame.logical_plan()); + + let mut known_rewrites = HashMap::new(); + let rewritten_plan = + rewrite_table_scans(data_frame.logical_plan(), &mut known_rewrites)?; + + println!("rewritten_plan: \n{:#?}", rewritten_plan); + + let unparsed_sql = plan_to_sql(&rewritten_plan)?; + + println!("unparsed_sql: \n{unparsed_sql}"); + + assert_eq!( + format!("{unparsed_sql}"), + test.1, + "SQL under test: {}", + test.0 + ); + } + + Ok(()) + } } From f5d20bb92b87cb5902c924aebe739d865cf31151 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Wed, 29 May 2024 17:13:46 +0900 Subject: [PATCH 14/48] Update datafusion version --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7fe5cdb..d6c0003 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,5 +21,5 @@ readme = "README.md" [workspace.dependencies] async-trait = "0.1.77" -datafusion = { git = "https://github.com/spiceai/datafusion.git", rev = "c533d36788eb66b5a90ce158bdad182d6b3a0da9" } -datafusion-substrait = { git = "https://github.com/spiceai/datafusion.git", folder = "datafusion/substrait", rev = "c533d36788eb66b5a90ce158bdad182d6b3a0da9" } \ No newline at end of file +datafusion = { git = "https://github.com/spiceai/datafusion.git", rev = "ea1176ae08dd0f94c99ef7f1d7dc989e383a3586" } +datafusion-substrait = { git = "https://github.com/spiceai/datafusion.git", folder = "datafusion/substrait", rev = "ea1176ae08dd0f94c99ef7f1d7dc989e383a3586" } \ No newline at end of file From 65d4be2f3039cb580790ee8dc68b34fd84ff01bd Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Wed, 22 May 2024 23:31:22 +0900 Subject: [PATCH 15/48] Use Unparser Dialect for plan_to_sql --- sources/flight-sql/src/executor/mod.rs | 4 ++-- sources/sql/src/connectorx/executor.rs | 6 +++--- sources/sql/src/executor.rs | 2 +- sources/sql/src/lib.rs | 9 +++++++-- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/sources/flight-sql/src/executor/mod.rs b/sources/flight-sql/src/executor/mod.rs index 537030a..c4f1225 100644 --- a/sources/flight-sql/src/executor/mod.rs +++ b/sources/flight-sql/src/executor/mod.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use datafusion::{ error::{DataFusionError, Result}, physical_plan::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream}, - sql::sqlparser::dialect::{Dialect, GenericDialect}, + sql::unparser::dialect::{DefaultDialect, Dialect}, }; use datafusion_federation_sql::SQLExecutor; use futures::TryStreamExt; @@ -96,7 +96,7 @@ impl SQLExecutor for FlightSQLExecutor { } fn dialect(&self) -> Arc { - Arc::new(GenericDialect {}) + Arc::new(DefaultDialect {}) } } diff --git a/sources/sql/src/connectorx/executor.rs b/sources/sql/src/connectorx/executor.rs index fc5ea3d..b5964ce 100644 --- a/sources/sql/src/connectorx/executor.rs +++ b/sources/sql/src/connectorx/executor.rs @@ -10,7 +10,7 @@ use datafusion::{ physical_plan::{ stream::RecordBatchStreamAdapter, EmptyRecordBatchStream, SendableRecordBatchStream, }, - sql::sqlparser::dialect::{Dialect, GenericDialect, PostgreSqlDialect, SQLiteDialect}, + sql::unparser::dialect::{DefaultDialect, Dialect, PostgreSqlDialect, SqliteDialect}, }; use futures::executor::block_on; use 
std::sync::Arc; @@ -92,8 +92,8 @@ impl SQLExecutor for CXExecutor { fn dialect(&self) -> Arc { match &self.conn.ty { SourceType::Postgres => Arc::new(PostgreSqlDialect {}), - SourceType::SQLite => Arc::new(SQLiteDialect {}), - _ => Arc::new(GenericDialect {}), + SourceType::SQLite => Arc::new(SqliteDialect {}), + _ => Arc::new(DefaultDialect {}), } } } diff --git a/sources/sql/src/executor.rs b/sources/sql/src/executor.rs index 7f05910..6042f77 100644 --- a/sources/sql/src/executor.rs +++ b/sources/sql/src/executor.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use core::fmt; use datafusion::{ arrow::datatypes::SchemaRef, error::Result, physical_plan::SendableRecordBatchStream, - sql::sqlparser::dialect::Dialect, + sql::unparser::dialect::Dialect, }; use std::sync::Arc; diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index 3d7d326..05ae6b9 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -22,7 +22,10 @@ use datafusion::{ DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, Partitioning, PlanProperties, SendableRecordBatchStream, }, - sql::{unparser::plan_to_sql, TableReference}, + sql::{ + unparser::{plan_to_sql, Unparser}, + TableReference, + }, }; use datafusion_federation::{ get_table_source, FederatedPlanNode, FederationPlanner, FederationProvider, @@ -597,7 +600,9 @@ impl ExecutionPlan for VirtualExecutionPlan { _partition: usize, _context: Arc, ) -> Result { - let ast = plan_to_sql(&self.plan)?; + let dialect = self.executor.dialect(); + let unparser = Unparser::new(dialect.as_ref()); + let ast = unparser.plan_to_sql(&self.plan)?; let query = format!("{ast}"); self.executor.execute(query.as_str(), self.schema()) From 6d714cb1b6137a6152e815a74f1f08e2a84fdf10 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Thu, 30 May 2024 00:42:45 +0900 Subject: [PATCH 16/48] Use unparser Dialect --- sources/sql/src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index 05ae6b9..839e4e9 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -623,7 +623,7 @@ mod tests { error::DataFusionError, execution::context::SessionContext, logical_expr::LogicalPlanBuilder, - sql::sqlparser::dialect::{Dialect, GenericDialect}, + sql::{unparser::dialect::DefaultDialect, unparser::dialect::Dialect}, }; use datafusion_federation::FederatedTableProviderAdaptor; @@ -642,7 +642,7 @@ mod tests { } fn dialect(&self) -> Arc { - Arc::new(GenericDialect {}) + Arc::new(DefaultDialect {}) } fn execute(&self, _query: &str, _schema: SchemaRef) -> Result { From 9dc7404eb251b3949247f3cc8c2ba89feea5dee7 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Thu, 30 May 2024 00:43:57 +0900 Subject: [PATCH 17/48] bump datafusion fork --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d6c0003..6529024 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,5 +21,5 @@ readme = "README.md" [workspace.dependencies] async-trait = "0.1.77" -datafusion = { git = "https://github.com/spiceai/datafusion.git", rev = "ea1176ae08dd0f94c99ef7f1d7dc989e383a3586" } -datafusion-substrait = { git = "https://github.com/spiceai/datafusion.git", folder = "datafusion/substrait", rev = "ea1176ae08dd0f94c99ef7f1d7dc989e383a3586" } \ No newline at end of file +datafusion = { git = "https://github.com/spiceai/datafusion.git", rev = "4b6489ffd8d138b138c2049966b19d073867885f" } +datafusion-substrait = { git = "https://github.com/spiceai/datafusion.git", folder = 
"datafusion/substrait", rev = "4b6489ffd8d138b138c2049966b19d073867885f" } \ No newline at end of file From bf5fac66ca196ca52221df1ac90f38879d9b6ea8 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Thu, 30 May 2024 15:52:34 +0900 Subject: [PATCH 18/48] Add layer to convert actual schema to expected schema --- Cargo.toml | 2 + datafusion-federation/Cargo.toml | 2 + datafusion-federation/src/lib.rs | 1 + datafusion-federation/src/schema_cast.rs | 101 +++++++++++++++ .../src/schema_cast/record_convert.rs | 119 ++++++++++++++++++ sources/sql/src/lib.rs | 9 +- 6 files changed, 231 insertions(+), 3 deletions(-) create mode 100644 datafusion-federation/src/schema_cast.rs create mode 100644 datafusion-federation/src/schema_cast/record_convert.rs diff --git a/Cargo.toml b/Cargo.toml index 6529024..12c7770 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,5 +21,7 @@ readme = "README.md" [workspace.dependencies] async-trait = "0.1.77" +async-stream = "0.3.5" +futures = "0.3.30" datafusion = { git = "https://github.com/spiceai/datafusion.git", rev = "4b6489ffd8d138b138c2049966b19d073867885f" } datafusion-substrait = { git = "https://github.com/spiceai/datafusion.git", folder = "datafusion/substrait", rev = "4b6489ffd8d138b138c2049966b19d073867885f" } \ No newline at end of file diff --git a/datafusion-federation/Cargo.toml b/datafusion-federation/Cargo.toml index 83e2723..2975a0d 100644 --- a/datafusion-federation/Cargo.toml +++ b/datafusion-federation/Cargo.toml @@ -13,6 +13,8 @@ path = "src/lib.rs" [dependencies] async-trait.workspace = true datafusion.workspace = true +async-stream.workspace = true +futures.workspace = true [package.metadata.docs.rs] diff --git a/datafusion-federation/src/lib.rs b/datafusion-federation/src/lib.rs index 999d296..41ce145 100644 --- a/datafusion-federation/src/lib.rs +++ b/datafusion-federation/src/lib.rs @@ -13,6 +13,7 @@ pub use table_provider::*; mod plan_node; pub use plan_node::*; +pub mod schema_cast; pub type FederationProviderRef = Arc; pub trait FederationProvider: Send + Sync { diff --git a/datafusion-federation/src/schema_cast.rs b/datafusion-federation/src/schema_cast.rs new file mode 100644 index 0000000..6f5db4a --- /dev/null +++ b/datafusion-federation/src/schema_cast.rs @@ -0,0 +1,101 @@ +use async_stream::stream; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::error::{DataFusionError, Result}; +use datafusion::execution::{SendableRecordBatchStream, TaskContext}; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, +}; +use futures::StreamExt; +use std::any::Any; +use std::clone::Clone; +use std::fmt; +use std::sync::Arc; + +mod record_convert; + +#[derive(Debug)] +#[allow(clippy::module_name_repetitions)] +pub struct SchemaCastScanExec { + input: Arc, + schema: SchemaRef, + properties: PlanProperties, +} + +impl SchemaCastScanExec { + pub fn new(input: Arc, schema: SchemaRef) -> Self { + let eq_properties = input.equivalence_properties().clone(); + let execution_mode = input.execution_mode(); + let properties = PlanProperties::new( + eq_properties, + input.output_partitioning().clone(), + execution_mode, + ); + Self { + input, + schema, + properties, + } + } +} + +impl DisplayAs for SchemaCastScanExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "SchemaCastScanExec") + } +} + +impl ExecutionPlan for SchemaCastScanExec { + fn as_any(&self) -> &dyn Any 
{ + self + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn children(&self) -> Vec> { + vec![Arc::clone(&self.input)] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + if children.len() == 1 { + Ok(Arc::new(Self::new( + Arc::clone(&children[0]), + Arc::clone(&self.schema), + ))) + } else { + Err(DataFusionError::Execution( + "SchemaCastScanExec expects exactly one input".to_string(), + )) + } + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let mut stream = self.input.execute(partition, context)?; + let schema = Arc::clone(&self.schema); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + { + stream! { + while let Some(batch) = stream.next().await { + let batch = record_convert::try_cast_to(batch?, Arc::clone(&schema)); + yield batch.map_err(|e| { DataFusionError::External(Box::new(e)) }); + } + } + }, + ))) + } +} diff --git a/datafusion-federation/src/schema_cast/record_convert.rs b/datafusion-federation/src/schema_cast/record_convert.rs new file mode 100644 index 0000000..95563e3 --- /dev/null +++ b/datafusion-federation/src/schema_cast/record_convert.rs @@ -0,0 +1,119 @@ +use datafusion::arrow::{ + array::{Array, RecordBatch}, + compute::cast, + datatypes::SchemaRef, +}; +use std::sync::Arc; + +pub type Result = std::result::Result; + +#[derive(Debug)] +pub enum Error { + UnableToConvertRecordBatch { + source: datafusion::arrow::error::ArrowError, + }, + + UnexpectedNumberOfColumns { + expected: usize, + found: usize, + }, +} + +impl std::error::Error for Error {} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Error::UnableToConvertRecordBatch { source } => { + write!(f, "Unable to convert record batch: {}", source) + } + Error::UnexpectedNumberOfColumns { expected, found } => { + write!( + f, + "Unexpected number of columns. Expected: {}, Found: {}", + expected, found + ) + } + } + } +} + +/// Cast a given record batch into a new record batch with the given schema. +/// It assumes the record batch columns are correctly ordered. 
+#[allow(clippy::needless_pass_by_value)] +pub(crate) fn try_cast_to( + record_batch: RecordBatch, + expected_schema: SchemaRef, +) -> Result { + let actual_schema = record_batch.schema(); + + if actual_schema.fields().len() != expected_schema.fields().len() { + return Err(Error::UnexpectedNumberOfColumns { + expected: expected_schema.fields().len(), + found: actual_schema.fields().len(), + }); + } + + let cols = expected_schema + .fields() + .iter() + .enumerate() + .map(|(i, expected_field)| { + let record_batch_col = record_batch.column(i); + + return cast(&Arc::clone(record_batch_col), expected_field.data_type()) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }); + }) + .collect::>>>()?; + + RecordBatch::try_new(expected_schema, cols) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) +} + +#[cfg(test)] +mod test { + use datafusion::arrow::{ + array::{Int32Array, StringArray}, + datatypes::{DataType, Field, Schema, TimeUnit}, + }; + + use super::*; + + fn schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Utf8, false), + ])) + } + + fn to_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::LargeUtf8, false), + Field::new("c", DataType::Timestamp(TimeUnit::Microsecond, None), false), + ])) + } + + fn batch_input() -> RecordBatch { + RecordBatch::try_new( + schema(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["foo", "bar", "baz"])), + Arc::new(StringArray::from(vec![ + "2024-01-13 03:18:09.000000", + "2024-01-13 03:18:09", + "2024-01-13 03:18:09.000", + ])), + ], + ) + .expect("record batch should not panic") + } + + #[test] + fn test_string_to_timestamp_conversion() { + let result = try_cast_to(batch_input(), to_schema()).expect("converted"); + assert_eq!(3, result.num_rows()); + } +} diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index 839e4e9..1eac4a4 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -28,7 +28,7 @@ use datafusion::{ }, }; use datafusion_federation::{ - get_table_source, FederatedPlanNode, FederationPlanner, FederationProvider, + get_table_source, schema_cast, FederatedPlanNode, FederationPlanner, FederationProvider, }; mod schema; @@ -526,10 +526,13 @@ impl FederationPlanner for SQLFederationPlanner { node: &FederatedPlanNode, _session_state: &SessionState, ) -> Result> { - Ok(Arc::new(VirtualExecutionPlan::new( + let schema = Arc::new(node.plan().schema().as_arrow().clone()); + let input = Arc::new(VirtualExecutionPlan::new( node.plan().clone(), Arc::clone(&self.executor), - ))) + )); + let schema_cast_exec = schema_cast::SchemaCastScanExec::new(input, schema); + Ok(Arc::new(schema_cast_exec)) } } From 4ea973e9feebfdb776bcfb6daeef91e897b0375f Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Tue, 11 Jun 2024 12:39:36 +0900 Subject: [PATCH 19/48] Update datafusion-federation to datafusion 39 (#6) --- .vscode/launch.json | 213 +++++++++++++++++++++++ Cargo.toml | 4 +- datafusion-federation/src/plan_node.rs | 6 +- datafusion-federation/src/schema_cast.rs | 4 +- examples/Cargo.toml | 2 +- sources/flight-sql/Cargo.toml | 4 +- sources/flight-sql/src/server/service.rs | 6 +- sources/sql/src/lib.rs | 15 +- 8 files changed, 230 insertions(+), 24 deletions(-) create mode 100644 .vscode/launch.json diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 
0000000..6982f30 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,213 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'datafusion_federation'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=datafusion-federation" + ], + "filter": { + "name": "datafusion_federation", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug example 'sqlite'", + "cargo": { + "args": [ + "build", + "--example=sqlite", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "sqlite", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in example 'sqlite'", + "cargo": { + "args": [ + "test", + "--no-run", + "--example=sqlite", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "sqlite", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug example 'flight-sql'", + "cargo": { + "args": [ + "build", + "--example=flight-sql", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "flight-sql", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in example 'flight-sql'", + "cargo": { + "args": [ + "test", + "--no-run", + "--example=flight-sql", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "flight-sql", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug example 'postgres-partial'", + "cargo": { + "args": [ + "build", + "--example=postgres-partial", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "postgres-partial", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in example 'postgres-partial'", + "cargo": { + "args": [ + "test", + "--no-run", + "--example=postgres-partial", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "postgres-partial", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug example 'sqlite-partial'", + "cargo": { + "args": [ + "build", + "--example=sqlite-partial", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "sqlite-partial", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in example 'sqlite-partial'", + "cargo": { + "args": [ + "test", + "--no-run", + "--example=sqlite-partial", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "sqlite-partial", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'datafusion_federation_flight_sql'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=datafusion-federation-flight-sql" + ], + "filter": { + 
"name": "datafusion_federation_flight_sql", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'datafusion_federation_sql'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=datafusion-federation-sql" + ], + "filter": { + "name": "datafusion_federation_sql", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + } + ] +} \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 12c7770..2a3608a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,5 +23,5 @@ readme = "README.md" async-trait = "0.1.77" async-stream = "0.3.5" futures = "0.3.30" -datafusion = { git = "https://github.com/spiceai/datafusion.git", rev = "4b6489ffd8d138b138c2049966b19d073867885f" } -datafusion-substrait = { git = "https://github.com/spiceai/datafusion.git", folder = "datafusion/substrait", rev = "4b6489ffd8d138b138c2049966b19d073867885f" } \ No newline at end of file +datafusion = { git = "https://github.com/spiceai/datafusion.git", rev = "be2c2c1f74823956e609a23ca38657cd76c2fcfe" } +datafusion-substrait = { git = "https://github.com/spiceai/datafusion.git", folder = "datafusion/substrait", rev = "be2c2c1f74823956e609a23ca38657cd76c2fcfe" } \ No newline at end of file diff --git a/datafusion-federation/src/plan_node.rs b/datafusion-federation/src/plan_node.rs index 2940b07..0d29211 100644 --- a/datafusion-federation/src/plan_node.rs +++ b/datafusion-federation/src/plan_node.rs @@ -57,13 +57,13 @@ impl UserDefinedLogicalNodeCore for FederatedPlanNode { write!(f, "Federated\n {:?}", self.plan) } - fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { + fn with_exprs_and_inputs(&self, exprs: Vec, inputs: Vec) -> Result { assert_eq!(inputs.len(), 0, "input size inconsistent"); assert_eq!(exprs.len(), 0, "expression size inconsistent"); - Self { + Ok(Self { plan: self.plan.clone(), planner: Arc::clone(&self.planner), - } + }) } } diff --git a/datafusion-federation/src/schema_cast.rs b/datafusion-federation/src/schema_cast.rs index 6f5db4a..53038f2 100644 --- a/datafusion-federation/src/schema_cast.rs +++ b/datafusion-federation/src/schema_cast.rs @@ -58,8 +58,8 @@ impl ExecutionPlan for SchemaCastScanExec { Arc::clone(&self.schema) } - fn children(&self) -> Vec> { - vec![Arc::clone(&self.input)] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } fn with_new_children( diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 13eb599..52acc4e 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -6,7 +6,7 @@ license.workspace = true readme.workspace = true [dev-dependencies] -arrow-flight = { version = "51.0.0", features = ["flight-sql-experimental"] } +arrow-flight = { version = "52.0.0", features = ["flight-sql-experimental"] } tokio = "1.35.1" async-trait.workspace = true datafusion.workspace = true diff --git a/sources/flight-sql/Cargo.toml b/sources/flight-sql/Cargo.toml index f778dfc..d9c1a1d 100644 --- a/sources/flight-sql/Cargo.toml +++ b/sources/flight-sql/Cargo.toml @@ -18,6 +18,6 @@ datafusion-federation-sql.path = "../sql" futures = "0.3.30" tonic = {version="0.11.0", features=["tls"] } prost = "0.12.3" -arrow = "51.0.0" -arrow-flight = { version = "51.0.0", features = ["flight-sql-experimental"] } +arrow = "52.0.0" +arrow-flight = { version = "52.0.0", features = ["flight-sql-experimental"] } log = "0.4.20" diff --git a/sources/flight-sql/src/server/service.rs b/sources/flight-sql/src/server/service.rs index 
c49cf32..6bc77ff 100644 --- a/sources/flight-sql/src/server/service.rs +++ b/sources/flight-sql/src/server/service.rs @@ -16,8 +16,8 @@ use arrow_flight::sql::{ CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, - CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate, SqlInfo, - TicketStatementQuery, + CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate, + DoPutPreparedStatementResult, SqlInfo, TicketStatementQuery, }; use arrow_flight::{ Action, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse, @@ -601,7 +601,7 @@ impl ArrowFlightSqlService for FlightSqlService { &self, _query: CommandPreparedStatementQuery, request: Request, - ) -> Result::DoPutStream>> { + ) -> Result { info!("do_put_prepared_statement_query"); let (_, _) = self.new_context(request)?; diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index 1eac4a4..0d2236b 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -13,8 +13,8 @@ use datafusion::{ AggregateFunction, Alias, Exists, InList, InSubquery, ScalarFunction, Sort, Unnest, WindowFunction, }, - Between, BinaryExpr, Case, Cast, Expr, Extension, GetIndexedField, GroupingSet, Like, - LogicalPlan, Subquery, TryCast, + Between, BinaryExpr, Case, Cast, Expr, Extension, GroupingSet, Like, LogicalPlan, Subquery, + TryCast, }, optimizer::analyzer::{Analyzer, AnalyzerRule}, physical_expr::EquivalenceProperties, @@ -282,13 +282,6 @@ fn rewrite_table_scans_in_expr( let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; Ok(Expr::Negative(Box::new(expr))) } - Expr::GetIndexedField(indexed_field) => { - let expr = rewrite_table_scans_in_expr(*indexed_field.expr, known_rewrites)?; - Ok(Expr::GetIndexedField(GetIndexedField::new( - Box::new(expr), - indexed_field.field, - ))) - } Expr::Between(between) => { let expr = rewrite_table_scans_in_expr(*between.expr, known_rewrites)?; let low = rewrite_table_scans_in_expr(*between.low, known_rewrites)?; @@ -352,7 +345,7 @@ fn rewrite_table_scans_in_expr( .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) .collect::>>()?; Ok(Expr::ScalarFunction(ScalarFunction { - func_def: sf.func_def, + func: sf.func, args, })) } @@ -587,7 +580,7 @@ impl ExecutionPlan for VirtualExecutionPlan { self.schema() } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } From e85aa9652326c9d1649f6535620990e12efa37a2 Mon Sep 17 00:00:00 2001 From: sgrebnov Date: Wed, 12 Jun 2024 11:21:17 -0700 Subject: [PATCH 20/48] Improve rewrite logic for expressions Update Support for multiple table occurrences in single expression Support for multiple different tables in single aggregation expression Document rewrite_column_name_in_expr logic Update Fix spelling --- sources/sql/src/lib.rs | 192 ++++++++++++++++++++++++++++++++++------- 1 file changed, 161 insertions(+), 31 deletions(-) diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index 0d2236b..d89dcfd 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -160,6 +160,76 @@ fn rewrite_table_scans( Ok(new_plan) } +// The function replaces occurrences of table_ref_str in col_name with the new name defined by rewrite. +// The name to rewrite should NOT be a substring of another name. +// Supports multiple occurrences of table_ref_str in col_name. 
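+// For example, with table_ref_str = "foo.df_table" and rewrite = "remote_table",
+// "MAX(foo.df_table.a + foo.df_table.b)" is rewritten to "MAX(remote_table.a + remote_table.b)",
+// while "foo.df_table_2.a" is left unchanged because the occurrence is followed by an identifier character.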
+fn rewrite_column_name_in_expr( + col_name: &str, + table_ref_str: &str, + rewrite: &str, + start_pos: usize, +) -> Option { + if start_pos >= col_name.len() { + return None; + } + + // Find the first occurrence of table_ref_str starting from start_pos + let Some(idx) = col_name[start_pos..].find(table_ref_str) else { + return None; + }; + + // Calculate the absolute index of the occurrence in string as the index above is relative to start_pos + let idx = start_pos + idx; + + if idx > 0 { + // Check if the previous character is alphabetic, numeric, underscore or period, in which case we + // should not rewrite as it is a part of another name. + if let Some(prev_char) = col_name.chars().nth(idx - 1) { + if prev_char.is_alphabetic() + || prev_char.is_numeric() + || prev_char == '_' + || prev_char == '.' + { + return rewrite_column_name_in_expr( + col_name, + table_ref_str, + rewrite, + idx + table_ref_str.len(), + ); + } + } + } + + // Check if the next character is alphabetic, numeric or underscore, in which case we + // should not rewrite as it is a part of another name. + if let Some(next_char) = col_name.chars().nth(idx + table_ref_str.len()) { + if next_char.is_alphabetic() || next_char.is_numeric() || next_char == '_' { + return rewrite_column_name_in_expr( + col_name, + table_ref_str, + rewrite, + idx + table_ref_str.len(), + ); + } + } + + // Found full match, replace table_ref_str occurrence with rewrite + let rewritten_name = format!( + "{}{}{}", + &col_name[..idx], + rewrite, + &col_name[idx + table_ref_str.len()..] + ); + + // Check if the rewritten name contains more occurrence of table_ref_str, and rewrite them as well + // This is done by providing the updated start_pos for search + match rewrite_column_name_in_expr(&rewritten_name, table_ref_str, rewrite, idx + rewrite.len()) + { + Some(new_name) => Some(new_name), // more occurrences found + None => Some(rewritten_name), // no more occurrences/changes + } +} + fn rewrite_table_scans_in_expr( expr: Expr, known_rewrites: &mut HashMap, @@ -192,15 +262,21 @@ fn rewrite_table_scans_in_expr( } else { // Check if any of the rewrites match any substring in col.name, and replace that part of the string if so. 
// This will handles cases like "MAX(foo.df_table.a)" -> "MAX(remote_table.a)" - let rewritten_name = known_rewrites.iter().find_map(|(table_ref, rewrite)| { - let table_ref_str = table_ref.to_string(); - if col.name.contains(&table_ref_str) { - Some(col.name.replace(&table_ref_str, &rewrite.to_string())) - } else { - None - } - }); - if let Some(new_name) = rewritten_name { + let (new_name, was_rewritten) = known_rewrites.iter().fold( + (col.name.to_string(), false), + |(col_name, was_rewritten), (table_ref, rewrite)| { + match rewrite_column_name_in_expr( + &col_name, + &table_ref.to_string(), + &rewrite.to_string(), + 0, + ) { + Some(new_name) => (new_name, true), + None => (col_name, was_rewritten), + } + }, + ); + if was_rewritten { Ok(Expr::Column(Column::new(col.relation.take(), new_name))) } else { Ok(Expr::Column(col)) @@ -696,6 +772,14 @@ mod tests { foo_schema .register_table("df_table".to_string(), get_test_table_provider()) .expect("to register table"); + + let public_schema = catalog + .schema("public") + .expect("public schema should exist"); + public_schema + .register_table("app_table".to_string(), get_test_table_provider()) + .expect("to register table"); + ctx } @@ -720,7 +804,7 @@ mod tests { assert_eq!( format!("{unparsed_sql}"), - r#"SELECT "remote_table"."a", "remote_table"."b", "remote_table"."c" FROM "remote_table""# + r#"SELECT remote_table.a, remote_table.b, remote_table.c FROM remote_table"# ); Ok(()) @@ -742,52 +826,98 @@ mod tests { let agg_tests = vec![ ( "SELECT MAX(a) FROM foo.df_table", - r#"SELECT MAX("remote_table"."a") FROM "remote_table""#, + r#"SELECT MAX(remote_table.a) FROM remote_table"#, ), ( "SELECT MIN(a) FROM foo.df_table", - r#"SELECT MIN("remote_table"."a") FROM "remote_table""#, + r#"SELECT MIN(remote_table.a) FROM remote_table"#, ), ( "SELECT AVG(a) FROM foo.df_table", - r#"SELECT AVG("remote_table"."a") FROM "remote_table""#, + r#"SELECT AVG(remote_table.a) FROM remote_table"#, ), ( "SELECT SUM(a) FROM foo.df_table", - r#"SELECT SUM("remote_table"."a") FROM "remote_table""#, + r#"SELECT SUM(remote_table.a) FROM remote_table"#, ), ( "SELECT COUNT(a) FROM foo.df_table", - r#"SELECT COUNT("remote_table"."a") FROM "remote_table""#, + r#"SELECT COUNT(remote_table.a) FROM remote_table"#, ), ( "SELECT COUNT(a) as cnt FROM foo.df_table", - r#"SELECT COUNT("remote_table"."a") AS "cnt" FROM "remote_table""#, + r#"SELECT COUNT(remote_table.a) AS cnt FROM remote_table"#, + ), + // multiple occurrences of the same table in single aggregation expression + ( + "SELECT COUNT(CASE WHEN a > 0 THEN a ELSE 0 END) FROM app_table", + r#"SELECT COUNT(CASE WHEN (remote_table.a > 0) THEN remote_table.a ELSE 0 END) FROM remote_table"#, + ), + // different tables in single aggregation expression + ( + "SELECT COUNT(CASE WHEN app_table.a > 0 THEN app_table.a ELSE foo.df_table.a END) FROM app_table, foo.df_table", + r#"SELECT COUNT(CASE WHEN (remote_table.a > 0) THEN remote_table.a ELSE remote_table.a END) FROM remote_table JOIN remote_table ON true"#, ), ]; for test in agg_tests { - let data_frame = ctx.sql(test.0).await?; + test_sql(&ctx, test.0, test.1).await?; + } - println!("before optimization: \n{:#?}", data_frame.logical_plan()); + Ok(()) + } - let mut known_rewrites = HashMap::new(); - let rewritten_plan = - rewrite_table_scans(data_frame.logical_plan(), &mut known_rewrites)?; + #[tokio::test] + async fn test_rewrite_table_scans_alias() -> Result<()> { + init_tracing(); + let ctx = get_test_df_context(); - println!("rewritten_plan: \n{:#?}", rewritten_plan); 
+ let tests = vec![ + ( + "SELECT COUNT(app_table_a) FROM (SELECT a as app_table_a FROM app_table)", + r#"SELECT COUNT(app_table_a) FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#, + ), + ( + "SELECT app_table_a FROM (SELECT a as app_table_a FROM app_table)", + r#"SELECT app_table_a FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#, + ), + ( + "SELECT aapp_table FROM (SELECT a as aapp_table FROM app_table)", + r#"SELECT aapp_table FROM (SELECT remote_table.a AS aapp_table FROM remote_table)"#, + ), + ]; - let unparsed_sql = plan_to_sql(&rewritten_plan)?; + for test in tests { + test_sql(&ctx, test.0, test.1).await?; + } - println!("unparsed_sql: \n{unparsed_sql}"); + Ok(()) + } - assert_eq!( - format!("{unparsed_sql}"), - test.1, - "SQL under test: {}", - test.0 - ); - } + async fn test_sql( + ctx: &SessionContext, + sql_query: &str, + expected_sql: &str, + ) -> Result<(), datafusion::error::DataFusionError> { + let data_frame = ctx.sql(sql_query).await?; + + println!("before optimization: \n{:#?}", data_frame.logical_plan()); + + let mut known_rewrites = HashMap::new(); + let rewritten_plan = rewrite_table_scans(data_frame.logical_plan(), &mut known_rewrites)?; + + println!("rewritten_plan: \n{:#?}", rewritten_plan); + + let unparsed_sql = plan_to_sql(&rewritten_plan)?; + + println!("unparsed_sql: \n{unparsed_sql}"); + + assert_eq!( + format!("{unparsed_sql}"), + expected_sql, + "SQL under test: {}", + sql_query + ); Ok(()) } From 15d4bd95d0fe09aa963492044944b33ef66b540d Mon Sep 17 00:00:00 2001 From: yfu Date: Mon, 24 Jun 2024 14:56:06 +1000 Subject: [PATCH 21/48] update spiceai datafusion to include new unparsers logic (#8) --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2a3608a..ae8a922 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,5 +23,5 @@ readme = "README.md" async-trait = "0.1.77" async-stream = "0.3.5" futures = "0.3.30" -datafusion = { git = "https://github.com/spiceai/datafusion.git", rev = "be2c2c1f74823956e609a23ca38657cd76c2fcfe" } -datafusion-substrait = { git = "https://github.com/spiceai/datafusion.git", folder = "datafusion/substrait", rev = "be2c2c1f74823956e609a23ca38657cd76c2fcfe" } \ No newline at end of file +datafusion = { git = "https://github.com/spiceai/datafusion.git", rev = "a155dee3293a5536ca5c5514f3e87884aa32e5ae" } +datafusion-substrait = { git = "https://github.com/spiceai/datafusion.git", folder = "datafusion/substrait", rev = "a155dee3293a5536ca5c5514f3e87884aa32e5ae" } \ No newline at end of file From 776930f0af7cc185fbb8e8ee47f01a83f412d527 Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Tue, 25 Jun 2024 19:24:48 -0700 Subject: [PATCH 22/48] Update datafusion to include time predicate fix (#9) --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ae8a922..6119d04 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,5 +23,5 @@ readme = "README.md" async-trait = "0.1.77" async-stream = "0.3.5" futures = "0.3.30" -datafusion = { git = "https://github.com/spiceai/datafusion.git", rev = "a155dee3293a5536ca5c5514f3e87884aa32e5ae" } -datafusion-substrait = { git = "https://github.com/spiceai/datafusion.git", folder = "datafusion/substrait", rev = "a155dee3293a5536ca5c5514f3e87884aa32e5ae" } \ No newline at end of file +datafusion = { git = "https://github.com/spiceai/datafusion.git", rev = "f23af338e80e495b77fb38ced23f0ef88a94662c" } +datafusion-substrait = { git = 
"https://github.com/spiceai/datafusion.git", folder = "datafusion/substrait", rev = "f23af338e80e495b77fb38ced23f0ef88a94662c" } From c3f696adceaad7f579af7f163cb933e08d7f8ba1 Mon Sep 17 00:00:00 2001 From: yfu Date: Thu, 27 Jun 2024 11:38:07 +1000 Subject: [PATCH 23/48] Remove table reference mutating during logical planning, converting query when executing (#10) * keep original qualified_field when wrap_projections * only mutate the table reference when sending plan to execute * formatting * rename method --- datafusion-federation/src/analyzer.rs | 14 +++++++++++++- sources/sql/src/lib.rs | 26 +++++++++++++++----------- 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/datafusion-federation/src/analyzer.rs b/datafusion-federation/src/analyzer.rs index 1690555..7760734 100644 --- a/datafusion-federation/src/analyzer.rs +++ b/datafusion-federation/src/analyzer.rs @@ -7,6 +7,7 @@ use datafusion::{ error::Result, logical_expr::{Expr, LogicalPlan, Projection, TableScan, TableSource}, optimizer::analyzer::AnalyzerRule, + sql::TableReference, }; use crate::{FederatedTableProviderAdaptor, FederatedTableSource, FederationProviderRef}; @@ -142,7 +143,18 @@ fn wrap_projection(plan: LogicalPlan) -> Result { .schema() .fields() .iter() - .map(|f| Expr::Column(Column::new_unqualified(f.name()))) + .enumerate() + .map(|(i, f)| { + Expr::Column(Column::from_qualified_name(format!( + "{}.{}", + plan.schema() + .qualified_field(i) + .0 + .map(TableReference::table) + .unwrap_or_default(), + f.name() + ))) + }) .collect::>(); Ok(LogicalPlan::Projection(Projection::try_new( expr, diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index d89dcfd..e8d1c7f 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -87,10 +87,6 @@ impl SQLFederationAnalyzerRule { impl AnalyzerRule for SQLFederationAnalyzerRule { fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result { - // Find all table scans, recover the SQLTableSource, find the remote table name and replace the name of the TableScan table. - let mut known_rewrites = HashMap::new(); - let plan = rewrite_table_scans(&plan, &mut known_rewrites)?; - let fed_plan = FederatedPlanNode::new(plan.clone(), Arc::clone(&self.planner)); let ext_node = Extension { node: Arc::new(fed_plan), @@ -631,6 +627,14 @@ impl VirtualExecutionPlan { let df_schema = self.plan.schema().as_ref(); Arc::new(Schema::from(df_schema)) } + + fn sql(&self) -> Result { + // Find all table scans, recover the SQLTableSource, find the remote table name and replace the name of the TableScan table. 
+ let mut known_rewrites = HashMap::new(); + let ast = Unparser::new(self.executor.dialect().as_ref()) + .plan_to_sql(&rewrite_table_scans(&self.plan, &mut known_rewrites)?)?; + Ok(format!("{ast}")) + } } impl DisplayAs for VirtualExecutionPlan { @@ -643,7 +647,12 @@ impl DisplayAs for VirtualExecutionPlan { if let Some(ctx) = self.executor.compute_context() { write!(f, " compute_context={ctx}")?; } - write!(f, " sql={ast}") + write!(f, " sql={ast}")?; + if let Ok(query) = self.sql() { + write!(f, " rewritten_sql={query}")?; + }; + + Ok(()) } } @@ -672,12 +681,7 @@ impl ExecutionPlan for VirtualExecutionPlan { _partition: usize, _context: Arc, ) -> Result { - let dialect = self.executor.dialect(); - let unparser = Unparser::new(dialect.as_ref()); - let ast = unparser.plan_to_sql(&self.plan)?; - let query = format!("{ast}"); - - self.executor.execute(query.as_str(), self.schema()) + self.executor.execute(self.sql()?.as_str(), self.schema()) } fn properties(&self) -> &PlanProperties { From 1becf0c961da12993009f53216ded2099c38d425 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Mon, 8 Jul 2024 11:48:22 +0900 Subject: [PATCH 24/48] Update SpiceAI Datafusion fork --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6119d04..601e292 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,5 +23,5 @@ readme = "README.md" async-trait = "0.1.77" async-stream = "0.3.5" futures = "0.3.30" -datafusion = { git = "https://github.com/spiceai/datafusion.git", rev = "f23af338e80e495b77fb38ced23f0ef88a94662c" } -datafusion-substrait = { git = "https://github.com/spiceai/datafusion.git", folder = "datafusion/substrait", rev = "f23af338e80e495b77fb38ced23f0ef88a94662c" } +datafusion = { git = "https://github.com/spiceai/datafusion.git", rev = "f865563b1ae8670df21e9a054c5307112f85fbf6" } +datafusion-substrait = { git = "https://github.com/spiceai/datafusion.git", folder = "datafusion/substrait", rev = "f865563b1ae8670df21e9a054c5307112f85fbf6" } From 77548d3e2932516f9bfd0719f2e543760a87d699 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Mon, 15 Jul 2024 10:03:08 +0900 Subject: [PATCH 25/48] Upgrade to DF 40 --- .gitignore | 1 + Cargo.toml | 4 ++-- datafusion-federation/src/schema_cast.rs | 4 ++++ sources/sql/src/lib.rs | 11 ++++++----- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index 66b190e..c92a89f 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ /node_modules package-lock.json package.json +.DS_Store \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 601e292..ce6084d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,5 +23,5 @@ readme = "README.md" async-trait = "0.1.77" async-stream = "0.3.5" futures = "0.3.30" -datafusion = { git = "https://github.com/spiceai/datafusion.git", rev = "f865563b1ae8670df21e9a054c5307112f85fbf6" } -datafusion-substrait = { git = "https://github.com/spiceai/datafusion.git", folder = "datafusion/substrait", rev = "f865563b1ae8670df21e9a054c5307112f85fbf6" } +datafusion = "40.0.0" +datafusion-substrait = "40.0.0" diff --git a/datafusion-federation/src/schema_cast.rs b/datafusion-federation/src/schema_cast.rs index 53038f2..31f52ca 100644 --- a/datafusion-federation/src/schema_cast.rs +++ b/datafusion-federation/src/schema_cast.rs @@ -46,6 +46,10 @@ impl DisplayAs for SchemaCastScanExec { } impl ExecutionPlan for SchemaCastScanExec { + fn name(&self) -> &str { + "SchemaCastScanExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git 
a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index e8d1c7f..be678c3 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -517,12 +517,9 @@ fn rewrite_table_scans_in_expr( ))) } Expr::Wildcard { qualifier } => { - if let Some(rewrite) = qualifier - .as_ref() - .and_then(|q| known_rewrites.get(&TableReference::from(q))) - { + if let Some(rewrite) = qualifier.as_ref().and_then(|q| known_rewrites.get(q)) { Ok(Expr::Wildcard { - qualifier: Some(rewrite.clone().to_string()), + qualifier: Some(rewrite.clone()), }) } else { Ok(Expr::Wildcard { qualifier }) @@ -657,6 +654,10 @@ impl DisplayAs for VirtualExecutionPlan { } impl ExecutionPlan for VirtualExecutionPlan { + fn name(&self) -> &str { + "VirtualExecutionPlan" + } + fn as_any(&self) -> &dyn Any { self } From f430401ef978c3687a1c8d10342c5032ce842952 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Mon, 15 Jul 2024 11:29:11 +0900 Subject: [PATCH 26/48] Bump datafusion-federation version to 0.1.5 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index ce6084d..0962ed3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ members = [ # datafusion = { path = "../arrow-datafusion/datafusion/core" } [workspace.package] -version = "0.1.3" +version = "0.1.5" edition = "2021" license = "Apache-2.0" readme = "README.md" From eeb9b9c0ed41650db282ba27bc663feb64e62147 Mon Sep 17 00:00:00 2001 From: yfu Date: Tue, 16 Jul 2024 05:02:19 +1000 Subject: [PATCH 27/48] do not over eager rewrite column when col relation is there (#12) * do not over eager rewrite column when col relation is there * add more tests * Update sources/sql/src/lib.rs Co-authored-by: Phillip LeBlanc --------- Co-authored-by: Phillip LeBlanc --- sources/sql/src/lib.rs | 36 +++++++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index be678c3..0d14b92 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -256,6 +256,12 @@ fn rewrite_table_scans_in_expr( if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) { Ok(Expr::Column(Column::new(Some(rewrite.clone()), &col.name))) } else { + // This prevent over-eager rewrite and only pass the column into below rewritten + // rule like MAX(...) + if col.relation.is_some() { + return Ok(Expr::Column(col)); + } + // Check if any of the rewrites match any substring in col.name, and replace that part of the string if so. 
// This will handles cases like "MAX(foo.df_table.a)" -> "MAX(remote_table.a)" let (new_name, was_rewritten) = known_rewrites.iter().fold( @@ -833,35 +839,51 @@ mod tests { "SELECT MAX(a) FROM foo.df_table", r#"SELECT MAX(remote_table.a) FROM remote_table"#, ), + ( + "SELECT foo.df_table.a FROM foo.df_table", + r#"SELECT remote_table.a FROM remote_table"#, + ), ( "SELECT MIN(a) FROM foo.df_table", r#"SELECT MIN(remote_table.a) FROM remote_table"#, ), ( "SELECT AVG(a) FROM foo.df_table", - r#"SELECT AVG(remote_table.a) FROM remote_table"#, + r#"SELECT avg(remote_table.a) FROM remote_table"#, ), ( "SELECT SUM(a) FROM foo.df_table", - r#"SELECT SUM(remote_table.a) FROM remote_table"#, + r#"SELECT sum(remote_table.a) FROM remote_table"#, ), ( "SELECT COUNT(a) FROM foo.df_table", - r#"SELECT COUNT(remote_table.a) FROM remote_table"#, + r#"SELECT count(remote_table.a) FROM remote_table"#, ), ( "SELECT COUNT(a) as cnt FROM foo.df_table", - r#"SELECT COUNT(remote_table.a) AS cnt FROM remote_table"#, + r#"SELECT count(remote_table.a) AS cnt FROM remote_table"#, + ), + ( + "SELECT COUNT(a) as cnt FROM foo.df_table", + r#"SELECT count(remote_table.a) AS cnt FROM remote_table"#, + ), + ( + "SELECT app_table from (SELECT a as app_table FROM app_table) b", + r#"SELECT b.app_table FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#, + ), + ( + "SELECT MAX(app_table) from (SELECT a as app_table FROM app_table) b", + r#"SELECT MAX(b.app_table) FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#, ), // multiple occurrences of the same table in single aggregation expression ( "SELECT COUNT(CASE WHEN a > 0 THEN a ELSE 0 END) FROM app_table", - r#"SELECT COUNT(CASE WHEN (remote_table.a > 0) THEN remote_table.a ELSE 0 END) FROM remote_table"#, + r#"SELECT count(CASE WHEN (remote_table.a > 0) THEN remote_table.a ELSE 0 END) FROM remote_table"#, ), // different tables in single aggregation expression ( "SELECT COUNT(CASE WHEN app_table.a > 0 THEN app_table.a ELSE foo.df_table.a END) FROM app_table, foo.df_table", - r#"SELECT COUNT(CASE WHEN (remote_table.a > 0) THEN remote_table.a ELSE remote_table.a END) FROM remote_table JOIN remote_table ON true"#, + r#"SELECT count(CASE WHEN (remote_table.a > 0) THEN remote_table.a ELSE remote_table.a END) FROM remote_table JOIN remote_table ON true"#, ), ]; @@ -880,7 +902,7 @@ mod tests { let tests = vec![ ( "SELECT COUNT(app_table_a) FROM (SELECT a as app_table_a FROM app_table)", - r#"SELECT COUNT(app_table_a) FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#, + r#"SELECT count(app_table_a) FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#, ), ( "SELECT app_table_a FROM (SELECT a as app_table_a FROM app_table)", From 9b4392f45546dd3a1232c18ad36189bcc3cd408d Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Wed, 14 Aug 2024 16:27:05 +0900 Subject: [PATCH 28/48] Upgrade to DataFusion v41 --- Cargo.toml | 6 +++--- datafusion-federation/src/table_provider.rs | 6 +++--- sources/sql/src/lib.rs | 15 ++++++++------- sources/sql/src/schema.rs | 3 +-- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0962ed3..ec0960f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ members = [ # datafusion = { path = "../arrow-datafusion/datafusion/core" } [workspace.package] -version = "0.1.5" +version = "0.1.6" edition = "2021" license = "Apache-2.0" readme = "README.md" @@ -23,5 +23,5 @@ readme = "README.md" async-trait = "0.1.77" async-stream = "0.3.5" futures = "0.3.30" -datafusion 
= "40.0.0" -datafusion-substrait = "40.0.0" +datafusion = "41.0.0" +datafusion-substrait = "41.0.0" diff --git a/datafusion-federation/src/table_provider.rs b/datafusion-federation/src/table_provider.rs index a1acc30..b820b6e 100644 --- a/datafusion-federation/src/table_provider.rs +++ b/datafusion-federation/src/table_provider.rs @@ -3,10 +3,10 @@ use std::{any::Any, sync::Arc}; use async_trait::async_trait; use datafusion::{ arrow::datatypes::SchemaRef, + catalog::Session, common::Constraints, datasource::TableProvider, error::{DataFusionError, Result}, - execution::context::SessionState, logical_expr::{Expr, LogicalPlan, TableProviderFilterPushDown, TableSource, TableType}, physical_plan::ExecutionPlan, }; @@ -106,7 +106,7 @@ impl TableProvider for FederatedTableProviderAdaptor { // with a virtual TableProvider that provides federation for a sub-plan. async fn scan( &self, - state: &SessionState, + state: &dyn Session, projection: Option<&Vec>, filters: &[Expr], limit: Option, @@ -122,7 +122,7 @@ impl TableProvider for FederatedTableProviderAdaptor { async fn insert_into( &self, - _state: &SessionState, + _state: &dyn Session, input: Arc, overwrite: bool, ) -> Result> { diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index 0d14b92..e563576 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -447,7 +447,7 @@ fn rewrite_table_scans_in_expr( }) .transpose()?; Ok(Expr::AggregateFunction(AggregateFunction { - func_def: af.func_def, + func: af.func, args, distinct: af.distinct, filter, @@ -471,14 +471,14 @@ fn rewrite_table_scans_in_expr( .into_iter() .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) .collect::>>()?; - Ok(Expr::WindowFunction(WindowFunction::new( - wf.fun, + Ok(Expr::WindowFunction(WindowFunction { + fun: wf.fun, args, partition_by, order_by, - wf.window_frame, - wf.null_treatment, - ))) + window_frame: wf.window_frame, + null_treatment: wf.null_treatment, + })) } Expr::InList(il) => { let expr = rewrite_table_scans_in_expr(*il.expr, known_rewrites)?; @@ -700,7 +700,8 @@ impl ExecutionPlan for VirtualExecutionPlan { mod tests { use datafusion::{ arrow::datatypes::{DataType, Field}, - catalog::schema::{MemorySchemaProvider, SchemaProvider}, + catalog::SchemaProvider, + catalog_common::MemorySchemaProvider, common::Column, datasource::{DefaultTableSource, TableProvider}, error::DataFusionError, diff --git a/sources/sql/src/schema.rs b/sources/sql/src/schema.rs index e417961..aa23fd0 100644 --- a/sources/sql/src/schema.rs +++ b/sources/sql/src/schema.rs @@ -2,8 +2,7 @@ use async_trait::async_trait; use datafusion::logical_expr::{TableSource, TableType}; use datafusion::{ - arrow::datatypes::SchemaRef, catalog::schema::SchemaProvider, datasource::TableProvider, - error::Result, + arrow::datatypes::SchemaRef, catalog::SchemaProvider, datasource::TableProvider, error::Result, }; use futures::future::join_all; use std::{any::Any, sync::Arc}; From ae9da45f8b90451119f3ca032b75ad82a28d7547 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Thu, 15 Aug 2024 11:39:10 +0900 Subject: [PATCH 29/48] Use `Display` for displaying federated plans --- datafusion-federation/src/plan_node.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion-federation/src/plan_node.rs b/datafusion-federation/src/plan_node.rs index 0d29211..5359107 100644 --- a/datafusion-federation/src/plan_node.rs +++ b/datafusion-federation/src/plan_node.rs @@ -54,7 +54,7 @@ impl UserDefinedLogicalNodeCore for FederatedPlanNode { } fn fmt_for_explain(&self, f: &mut 
fmt::Formatter) -> fmt::Result { - write!(f, "Federated\n {:?}", self.plan) + write!(f, "Federated\n {}", self.plan) } fn with_exprs_and_inputs(&self, exprs: Vec, inputs: Vec) -> Result { From 21cfcc8cbc668b33235df52f16ea002f72efc74d Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Thu, 22 Aug 2024 15:14:18 +0900 Subject: [PATCH 30/48] Make the `FederatedPlanner` public for use in custom query planners (#14) --- datafusion-federation/src/plan_node.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion-federation/src/plan_node.rs b/datafusion-federation/src/plan_node.rs index 5359107..c99204f 100644 --- a/datafusion-federation/src/plan_node.rs +++ b/datafusion-federation/src/plan_node.rs @@ -121,7 +121,7 @@ impl Hash for FederatedPlanNode { } #[derive(Default)] -struct FederatedPlanner {} +pub struct FederatedPlanner {} impl FederatedPlanner { pub fn new() -> Self { From 2cb333011ae704b0e808e168d489ad303ab8e428 Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Sun, 25 Aug 2024 05:07:30 +0300 Subject: [PATCH 31/48] Add List types parsing support to schema cast (#13) * Add List types parsing support to schema cast * Update to use arrow-json for lists parsing --- Cargo.toml | 1 + datafusion-federation/Cargo.toml | 1 + datafusion-federation/src/schema_cast.rs | 3 +- .../src/schema_cast/lists_cast.rs | 619 ++++++++++++++++++ .../src/schema_cast/record_convert.rs | 35 +- 5 files changed, 651 insertions(+), 8 deletions(-) create mode 100644 datafusion-federation/src/schema_cast/lists_cast.rs diff --git a/Cargo.toml b/Cargo.toml index ec0960f..bced60f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,4 @@ async-stream = "0.3.5" futures = "0.3.30" datafusion = "41.0.0" datafusion-substrait = "41.0.0" +arrow-json = "52.2.0" diff --git a/datafusion-federation/Cargo.toml b/datafusion-federation/Cargo.toml index 2975a0d..f344a71 100644 --- a/datafusion-federation/Cargo.toml +++ b/datafusion-federation/Cargo.toml @@ -15,6 +15,7 @@ async-trait.workspace = true datafusion.workspace = true async-stream.workspace = true futures.workspace = true +arrow-json.workspace = true [package.metadata.docs.rs] diff --git a/datafusion-federation/src/schema_cast.rs b/datafusion-federation/src/schema_cast.rs index 31f52ca..41d14b5 100644 --- a/datafusion-federation/src/schema_cast.rs +++ b/datafusion-federation/src/schema_cast.rs @@ -12,7 +12,8 @@ use std::clone::Clone; use std::fmt; use std::sync::Arc; -mod record_convert; +pub mod record_convert; +mod lists_cast; #[derive(Debug)] #[allow(clippy::module_name_repetitions)] diff --git a/datafusion-federation/src/schema_cast/lists_cast.rs b/datafusion-federation/src/schema_cast/lists_cast.rs new file mode 100644 index 0000000..9a63b28 --- /dev/null +++ b/datafusion-federation/src/schema_cast/lists_cast.rs @@ -0,0 +1,619 @@ +use arrow_json::ReaderBuilder; +use datafusion::arrow::{ + array::{ + Array, ArrayRef, BooleanArray, BooleanBuilder, FixedSizeListBuilder, Float32Array, + Float32Builder, Float64Array, Float64Builder, Int16Array, Int16Builder, Int32Array, + Int32Builder, Int64Array, Int64Builder, Int8Array, Int8Builder, LargeListBuilder, + LargeStringArray, LargeStringBuilder, ListArray, ListBuilder, StringArray, StringBuilder, + }, + datatypes::{DataType, Field, FieldRef}, + error::ArrowError, +}; +use std::sync::Arc; + +pub type Result = std::result::Result; + +macro_rules! 
cast_string_to_list_array { + ($string_array:expr, $field_name:expr, $data_type:expr, $builder_type:expr, $primitive_type:ty) => {{ + let item_field = Arc::new(Field::new($field_name, $data_type, true)); + let mut list_builder = ListBuilder::with_capacity($builder_type, $string_array.len()) + .with_field(Arc::clone(&item_field)); + + let list_field = Arc::new(Field::new_list("i", item_field, true)); + let mut decoder = ReaderBuilder::new_with_field(Arc::clone(&list_field)) + .build_decoder() + .map_err(|e| ArrowError::CastError(format!("Failed to create decoder: {e}")))?; + + for value in $string_array { + match value { + None => list_builder.append_null(), + Some(string_value) => { + decoder.decode(string_value.as_bytes()).map_err(|e| { + ArrowError::CastError(format!("Failed to decode value: {e}")) + })?; + + if let Some(b) = decoder.flush().map_err(|e| { + ArrowError::CastError(format!("Failed to decode decoder: {e}")) + })? { + let list_array = b + .column(0) + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to ListArray" + .to_string(), + ) + })?; + let primitive_array = list_array + .values() + .as_any() + .downcast_ref::<$primitive_type>() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to PrimitiveType" + .to_string(), + ) + })?; + primitive_array + .iter() + .for_each(|maybe_value| match maybe_value { + Some(value) => list_builder.values().append_value(value), + None => list_builder.values().append_null(), + }); + list_builder.append(true); + } + } + } + } + + Ok(Arc::new(list_builder.finish())) + }}; +} + +macro_rules! cast_string_to_large_list_array { + ($string_array:expr, $field_name:expr, $data_type:expr, $builder_type:expr, $primitive_type:ty) => {{ + let item_field = Arc::new(Field::new($field_name, $data_type, true)); + let mut list_builder = LargeListBuilder::with_capacity($builder_type, $string_array.len()) + .with_field(Arc::clone(&item_field)); + + let list_field = Arc::new(Field::new_list("i", item_field, true)); + let mut decoder = ReaderBuilder::new_with_field(Arc::clone(&list_field)) + .build_decoder() + .map_err(|e| ArrowError::CastError(format!("Failed to create decoder: {e}")))?; + + for value in $string_array { + match value { + None => list_builder.append_null(), + Some(string_value) => { + decoder.decode(string_value.as_bytes()).map_err(|e| { + ArrowError::CastError(format!("Failed to decode value: {e}")) + })?; + + if let Some(b) = decoder.flush().map_err(|e| { + ArrowError::CastError(format!("Failed to decode decoder: {e}")) + })? { + let list_array = b + .column(0) + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to ListArray" + .to_string(), + ) + })?; + let primitive_array = list_array + .values() + .as_any() + .downcast_ref::<$primitive_type>() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to PrimitiveType" + .to_string(), + ) + })?; + primitive_array + .iter() + .for_each(|maybe_value| match maybe_value { + Some(value) => list_builder.values().append_value(value), + None => list_builder.values().append_null(), + }); + list_builder.append(true); + } + } + } + } + + Ok(Arc::new(list_builder.finish())) + }}; +} + +macro_rules! 
cast_string_to_fixed_size_list_array { + ($string_array:expr, $field_name:expr, $data_type:expr, $builder_type:expr, $primitive_type:ty, $value_length:expr) => {{ + let item_field = Arc::new(Field::new($field_name, $data_type, true)); + let mut list_builder = + FixedSizeListBuilder::with_capacity($builder_type, $value_length, $string_array.len()) + .with_field(Arc::clone(&item_field)); + + let list_field = Arc::new(Field::new_list("i", item_field, true)); + let mut decoder = ReaderBuilder::new_with_field(Arc::clone(&list_field)) + .build_decoder() + .map_err(|e| ArrowError::CastError(format!("Failed to create decoder: {e}")))?; + + for value in $string_array { + match value { + None => { + for _ in 0..$value_length { + list_builder.values().append_null() + } + list_builder.append(true) + } + Some(string_value) => { + decoder.decode(string_value.as_bytes()).map_err(|e| { + ArrowError::CastError(format!("Failed to decode value: {e}")) + })?; + + if let Some(b) = decoder.flush().map_err(|e| { + ArrowError::CastError(format!("Failed to decode decoder: {e}")) + })? { + let list_array = b + .column(0) + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to ListArray" + .to_string(), + ) + })?; + let primitive_array = list_array + .values() + .as_any() + .downcast_ref::<$primitive_type>() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to PrimitiveType" + .to_string(), + ) + })?; + primitive_array + .iter() + .for_each(|maybe_value| match maybe_value { + Some(value) => list_builder.values().append_value(value), + None => list_builder.values().append_null(), + }); + list_builder.append(true); + } + } + } + } + + Ok(Arc::new(list_builder.finish())) + }}; +} + +pub(crate) fn cast_string_to_list( + array: &dyn Array, + list_item_field: &FieldRef, +) -> Result { + let string_array = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to StringArray".to_string(), + ) + })?; + + let field_name = list_item_field.name(); + + match list_item_field.data_type() { + DataType::Utf8 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Utf8, + StringBuilder::new(), + StringArray + ) + } + DataType::LargeUtf8 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::LargeUtf8, + LargeStringBuilder::new(), + LargeStringArray + ) + } + DataType::Boolean => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Boolean, + BooleanBuilder::new(), + BooleanArray + ) + } + DataType::Int8 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Int8, + Int8Builder::new(), + Int8Array + ) + } + DataType::Int16 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Int16, + Int16Builder::new(), + Int16Array + ) + } + DataType::Int32 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Int32, + Int32Builder::new(), + Int32Array + ) + } + DataType::Int64 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Int64, + Int64Builder::new(), + Int64Array + ) + } + DataType::Float32 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Float32, + Float32Builder::new(), + Float32Array + ) + } + DataType::Float64 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Float64, + Float64Builder::new(), + Float64Array + ) + } + _ => 
Err(ArrowError::CastError(format!( + "Unsupported list item type: {}", + list_item_field.data_type() + ))), + } +} + +pub(crate) fn cast_string_to_large_list( + array: &dyn Array, + list_item_field: &FieldRef, +) -> Result { + let string_array = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to StringArray".to_string(), + ) + })?; + + let field_name = list_item_field.name(); + + match list_item_field.data_type() { + DataType::Utf8 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Utf8, + StringBuilder::new(), + StringArray + ) + } + DataType::LargeUtf8 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::LargeUtf8, + LargeStringBuilder::new(), + LargeStringArray + ) + } + DataType::Boolean => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Boolean, + BooleanBuilder::new(), + BooleanArray + ) + } + DataType::Int8 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Int8, + Int8Builder::new(), + Int8Array + ) + } + DataType::Int16 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Int16, + Int16Builder::new(), + Int16Array + ) + } + DataType::Int32 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Int32, + Int32Builder::new(), + Int32Array + ) + } + DataType::Int64 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Int64, + Int64Builder::new(), + Int64Array + ) + } + DataType::Float32 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Float32, + Float32Builder::new(), + Float32Array + ) + } + DataType::Float64 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Float64, + Float64Builder::new(), + Float64Array + ) + } + _ => Err(ArrowError::CastError(format!( + "Unsupported list item type: {}", + list_item_field.data_type() + ))), + } +} + +pub(crate) fn cast_string_to_fixed_size_list( + array: &dyn Array, + list_item_field: &FieldRef, + value_length: i32, +) -> Result { + let string_array = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to StringArray".to_string(), + ) + })?; + + let field_name = list_item_field.name(); + + match list_item_field.data_type() { + DataType::Utf8 => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Utf8, + StringBuilder::new(), + StringArray, + value_length + ) + } + DataType::LargeUtf8 => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::LargeUtf8, + LargeStringBuilder::new(), + LargeStringArray, + value_length + ) + } + DataType::Boolean => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Boolean, + BooleanBuilder::new(), + BooleanArray, + value_length + ) + } + DataType::Int8 => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Int8, + Int8Builder::new(), + Int8Array, + value_length + ) + } + DataType::Int16 => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Int16, + Int16Builder::new(), + Int16Array, + value_length + ) + } + DataType::Int32 => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Int32, + Int32Builder::new(), + Int32Array, + value_length + ) + } + DataType::Int64 => { + 
cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Int64, + Int64Builder::new(), + Int64Array, + value_length + ) + } + DataType::Float32 => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Float32, + Float32Builder::new(), + Float32Array, + value_length + ) + } + DataType::Float64 => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Float64, + Float64Builder::new(), + Float64Array, + value_length + ) + } + _ => Err(ArrowError::CastError(format!( + "Unsupported list item type: {}", + list_item_field.data_type() + ))), + } +} + +#[cfg(test)] +mod test { + use datafusion::arrow::{ + array::{RecordBatch, StringArray}, + datatypes::{DataType, Field, Schema, SchemaRef}, + }; + + use crate::schema_cast::record_convert::try_cast_to; + + use super::*; + + fn input_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Utf8, false), + ])) + } + + fn output_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new( + "a", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + false, + ), + Field::new( + "b", + DataType::LargeList(Arc::new(Field::new("item", DataType::Utf8, true))), + false, + ), + Field::new( + "c", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Boolean, true)), 3), + false, + ), + ])) + } + + fn batch_input() -> RecordBatch { + RecordBatch::try_new( + input_schema(), + vec![ + Arc::new(StringArray::from(vec![ + Some("[1, 2, 3]"), + Some("[4, 5, 6]"), + ])), + Arc::new(StringArray::from(vec![ + Some("[\"foo\", \"bar\"]"), + Some("[\"baz\", \"qux\"]"), + ])), + Arc::new(StringArray::from(vec![ + Some("[true, false, true]"), + Some("[false, true, false]"), + ])), + ], + ) + .expect("record batch should not panic") + } + + fn batch_expected() -> RecordBatch { + let mut list_builder = ListBuilder::new(Int32Builder::new()); + list_builder.append_value([Some(1), Some(2), Some(3)]); + list_builder.append_value([Some(4), Some(5), Some(6)]); + let list_array = list_builder.finish(); + + let mut large_list_builder = LargeListBuilder::new(StringBuilder::new()); + large_list_builder.append_value([Some("foo"), Some("bar")]); + large_list_builder.append_value([Some("baz"), Some("qux")]); + let large_list_array = large_list_builder.finish(); + + let mut fixed_size_list_builder = FixedSizeListBuilder::new(BooleanBuilder::new(), 3); + fixed_size_list_builder.values().append_value(true); + fixed_size_list_builder.values().append_value(false); + fixed_size_list_builder.values().append_value(true); + fixed_size_list_builder.append(true); + fixed_size_list_builder.values().append_value(false); + fixed_size_list_builder.values().append_value(true); + fixed_size_list_builder.values().append_value(false); + fixed_size_list_builder.append(true); + let fixed_size_list_array = fixed_size_list_builder.finish(); + + RecordBatch::try_new( + output_schema(), + vec![ + Arc::new(list_array), + Arc::new(large_list_array), + Arc::new(fixed_size_list_array), + ], + ) + .expect("Failed to create expected RecordBatch") + } + + #[test] + fn test_cast_to_list_largelist_fixedsizelist() { + let input_batch = batch_input(); + let expected = batch_expected(); + let actual = try_cast_to(input_batch, output_schema()).expect("cast should succeed"); + + assert_eq!(actual, expected); + } +} diff --git a/datafusion-federation/src/schema_cast/record_convert.rs 
b/datafusion-federation/src/schema_cast/record_convert.rs index 95563e3..17d6d23 100644 --- a/datafusion-federation/src/schema_cast/record_convert.rs +++ b/datafusion-federation/src/schema_cast/record_convert.rs @@ -1,10 +1,14 @@ use datafusion::arrow::{ array::{Array, RecordBatch}, compute::cast, - datatypes::SchemaRef, + datatypes::{DataType, SchemaRef}, }; use std::sync::Arc; +use super::lists_cast::{ + cast_string_to_fixed_size_list, cast_string_to_large_list, cast_string_to_list, +}; + pub type Result = std::result::Result; #[derive(Debug)] @@ -41,10 +45,7 @@ impl std::fmt::Display for Error { /// Cast a given record batch into a new record batch with the given schema. /// It assumes the record batch columns are correctly ordered. #[allow(clippy::needless_pass_by_value)] -pub(crate) fn try_cast_to( - record_batch: RecordBatch, - expected_schema: SchemaRef, -) -> Result { +pub fn try_cast_to(record_batch: RecordBatch, expected_schema: SchemaRef) -> Result { let actual_schema = record_batch.schema(); if actual_schema.fields().len() != expected_schema.fields().len() { @@ -61,8 +62,28 @@ pub(crate) fn try_cast_to( .map(|(i, expected_field)| { let record_batch_col = record_batch.column(i); - return cast(&Arc::clone(record_batch_col), expected_field.data_type()) - .map_err(|e| Error::UnableToConvertRecordBatch { source: e }); + match (record_batch_col.data_type(), expected_field.data_type()) { + (DataType::Utf8, DataType::List(item_type)) => { + return cast_string_to_list(&Arc::clone(record_batch_col), item_type) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }); + } + (DataType::Utf8, DataType::LargeList(item_type)) => { + return cast_string_to_large_list(&Arc::clone(record_batch_col), item_type) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }); + } + (DataType::Utf8, DataType::FixedSizeList(item_type, value_length)) => { + return cast_string_to_fixed_size_list( + &Arc::clone(record_batch_col), + item_type, + value_length.clone(), + ) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }); + } + _ => { + return cast(&Arc::clone(record_batch_col), expected_field.data_type()) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }); + } + } }) .collect::>>>()?; From e0a4eaef873949c4f4bb3edf64e411821e62dcb2 Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Sat, 24 Aug 2024 19:18:00 -0700 Subject: [PATCH 32/48] Add interval cast given original schema (#15) * Add List types parsing support to schema cast * Update to use arrow-json for lists parsing * Add interval cast given original schema * Fix test --------- Co-authored-by: sgrebnov Co-authored-by: Phillip LeBlanc --- datafusion-federation/src/schema_cast.rs | 3 +- .../src/schema_cast/intervals_cast.rs | 177 ++++++++++++++++++ .../src/schema_cast/record_convert.rs | 37 ++-- sources/sql/src/lib.rs | 6 +- 4 files changed, 206 insertions(+), 17 deletions(-) create mode 100644 datafusion-federation/src/schema_cast/intervals_cast.rs diff --git a/datafusion-federation/src/schema_cast.rs b/datafusion-federation/src/schema_cast.rs index 41d14b5..ade8091 100644 --- a/datafusion-federation/src/schema_cast.rs +++ b/datafusion-federation/src/schema_cast.rs @@ -12,8 +12,9 @@ use std::clone::Clone; use std::fmt; use std::sync::Arc; -pub mod record_convert; +mod intervals_cast; mod lists_cast; +pub mod record_convert; #[derive(Debug)] #[allow(clippy::module_name_repetitions)] diff --git a/datafusion-federation/src/schema_cast/intervals_cast.rs 
b/datafusion-federation/src/schema_cast/intervals_cast.rs new file mode 100644 index 0000000..3e445c5 --- /dev/null +++ b/datafusion-federation/src/schema_cast/intervals_cast.rs @@ -0,0 +1,177 @@ +use datafusion::arrow::{ + array::{ + Array, ArrayRef, IntervalDayTimeBuilder, IntervalMonthDayNanoArray, + IntervalYearMonthBuilder, + }, + datatypes::{IntervalDayTimeType, IntervalYearMonthType}, + error::ArrowError, +}; +use std::sync::Arc; + +pub(crate) fn cast_interval_monthdaynano_to_yearmonth( + interval_monthdaynano_array: &dyn Array, +) -> Result { + let interval_monthdaynano_array = interval_monthdaynano_array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError("Failed to cast IntervalMonthDayNanoArray: Unable to downcast to IntervalMonthDayNanoArray".to_string()) + })?; + + let mut interval_yearmonth_builder = + IntervalYearMonthBuilder::with_capacity(interval_monthdaynano_array.len()); + + for value in interval_monthdaynano_array { + match value { + None => interval_yearmonth_builder.append_null(), + Some(interval_monthdaynano_value) => { + if interval_monthdaynano_value.days != 0 + || interval_monthdaynano_value.nanoseconds != 0 + { + return Err(ArrowError::CastError( + "Failed to cast IntervalMonthDayNanoArray to IntervalYearMonthArray: Non-zero days or nanoseconds".to_string(), + )); + } + interval_yearmonth_builder.append_value(IntervalYearMonthType::make_value( + 0, + interval_monthdaynano_value.months, + )); + } + } + } + + Ok(Arc::new(interval_yearmonth_builder.finish())) +} + +#[allow(clippy::cast_possible_truncation)] +pub(crate) fn cast_interval_monthdaynano_to_daytime( + interval_monthdaynano_array: &dyn Array, +) -> Result { + let interval_monthdaynano_array = interval_monthdaynano_array + .as_any() + .downcast_ref::() + .ok_or_else(|| + ArrowError::CastError("Failed to cast IntervalMonthDayNanoArray: Unable to downcast to IntervalMonthDayNanoArray".to_string()))?; + + let mut interval_daytime_builder = + IntervalDayTimeBuilder::with_capacity(interval_monthdaynano_array.len()); + + for value in interval_monthdaynano_array { + match value { + None => interval_daytime_builder.append_null(), + Some(interval_monthdaynano_value) => { + if interval_monthdaynano_value.months != 0 { + return Err( + ArrowError::CastError("Failed to cast IntervalMonthDayNanoArray to IntervalDayTimeArray: Non-zero months".to_string()), + ); + } + interval_daytime_builder.append_value(IntervalDayTimeType::make_value( + interval_monthdaynano_value.days, + (interval_monthdaynano_value.nanoseconds / 1_000_000) as i32, + )); + } + } + } + Ok(Arc::new(interval_daytime_builder.finish())) +} + +#[cfg(test)] +mod test { + use datafusion::arrow::{ + array::{RecordBatch, IntervalDayTimeArray, IntervalYearMonthArray}, + datatypes::{DataType, Field, Schema, SchemaRef, IntervalUnit, IntervalMonthDayNano, IntervalDayTime}, + }; + + use crate::schema_cast::record_convert::try_cast_to; + + use super::*; + + fn input_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("interval_daytime", DataType::Interval(IntervalUnit::MonthDayNano), false), + Field::new("interval_monthday_nano", DataType::Interval(IntervalUnit::MonthDayNano), false), + Field::new("interval_yearmonth", DataType::Interval(IntervalUnit::MonthDayNano), false), + ])) + } + + fn output_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new( + "interval_daytime", + DataType::Interval(IntervalUnit::DayTime), + false, + ), + Field::new( + "interval_monthday_nano", + 
DataType::Interval(IntervalUnit::MonthDayNano), + false, + ), + Field::new( + "interval_yearmonth", + DataType::Interval(IntervalUnit::YearMonth), + false, + ), + ])) + } + + fn batch_input() -> RecordBatch { + let interval_daytime_array = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNano::new(0, 1, 1_000_000_000), + IntervalMonthDayNano::new(0, 33, 0), + IntervalMonthDayNano::new(0, 0, 43_200_000_000_000), + ]); + let interval_monthday_nano_array = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNano::new(1, 2, 1000), + IntervalMonthDayNano::new(12, 1, 0), + IntervalMonthDayNano::new(0, 0, 12 * 1000 * 1000), + ]); + let interval_yearmonth_array = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNano::new(2, 0, 0), + IntervalMonthDayNano::new(25, 0, 0), + IntervalMonthDayNano::new(-1, 0, 0), + ]); + + RecordBatch::try_new( + input_schema(), + vec![ + Arc::new(interval_daytime_array), + Arc::new(interval_monthday_nano_array), + Arc::new(interval_yearmonth_array), + ], + ) + .expect("Failed to created arrow interval record batch") + } + + fn batch_expected() -> RecordBatch { + let interval_daytime_array = IntervalDayTimeArray::from(vec![ + IntervalDayTime::new(1, 1000), + IntervalDayTime::new(33, 0), + IntervalDayTime::new(0, 12 * 60 * 60 * 1000), + ]); + let interval_monthday_nano_array = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNano::new(1, 2, 1000), + IntervalMonthDayNano::new(12, 1, 0), + IntervalMonthDayNano::new(0, 0, 12 * 1000 * 1000), + ]); + let interval_yearmonth_array = IntervalYearMonthArray::from(vec![2, 25, -1]); + + RecordBatch::try_new( + output_schema(), + vec![ + Arc::new(interval_daytime_array), + Arc::new(interval_monthday_nano_array), + Arc::new(interval_yearmonth_array), + ], + ) + .expect("Failed to created arrow interval record batch") + } + + #[test] + fn test_cast_interval_with_schema() { + let input_batch = batch_input(); + let expected = batch_expected(); + let actual = try_cast_to(input_batch, output_schema()).expect("cast should succeed"); + + assert_eq!(actual, expected); + } +} diff --git a/datafusion-federation/src/schema_cast/record_convert.rs b/datafusion-federation/src/schema_cast/record_convert.rs index 17d6d23..1a324ee 100644 --- a/datafusion-federation/src/schema_cast/record_convert.rs +++ b/datafusion-federation/src/schema_cast/record_convert.rs @@ -1,12 +1,15 @@ use datafusion::arrow::{ array::{Array, RecordBatch}, compute::cast, - datatypes::{DataType, SchemaRef}, + datatypes::{DataType, IntervalUnit, SchemaRef}, }; use std::sync::Arc; -use super::lists_cast::{ - cast_string_to_fixed_size_list, cast_string_to_large_list, cast_string_to_list, +use super::{ + intervals_cast::{ + cast_interval_monthdaynano_to_daytime, cast_interval_monthdaynano_to_yearmonth, + }, + lists_cast::{cast_string_to_fixed_size_list, cast_string_to_large_list, cast_string_to_list}, }; pub type Result = std::result::Result; @@ -64,25 +67,33 @@ pub fn try_cast_to(record_batch: RecordBatch, expected_schema: SchemaRef) -> Res match (record_batch_col.data_type(), expected_field.data_type()) { (DataType::Utf8, DataType::List(item_type)) => { - return cast_string_to_list(&Arc::clone(record_batch_col), item_type) - .map_err(|e| Error::UnableToConvertRecordBatch { source: e }); + cast_string_to_list(&Arc::clone(record_batch_col), item_type) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) } (DataType::Utf8, DataType::LargeList(item_type)) => { - return cast_string_to_large_list(&Arc::clone(record_batch_col), item_type) - 
.map_err(|e| Error::UnableToConvertRecordBatch { source: e }); + cast_string_to_large_list(&Arc::clone(record_batch_col), item_type) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) } (DataType::Utf8, DataType::FixedSizeList(item_type, value_length)) => { - return cast_string_to_fixed_size_list( + cast_string_to_fixed_size_list( &Arc::clone(record_batch_col), item_type, value_length.clone(), ) - .map_err(|e| Error::UnableToConvertRecordBatch { source: e }); - } - _ => { - return cast(&Arc::clone(record_batch_col), expected_field.data_type()) - .map_err(|e| Error::UnableToConvertRecordBatch { source: e }); + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) } + ( + DataType::Interval(IntervalUnit::MonthDayNano), + DataType::Interval(IntervalUnit::YearMonth), + ) => cast_interval_monthdaynano_to_yearmonth(&Arc::clone(record_batch_col)) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }), + ( + DataType::Interval(IntervalUnit::MonthDayNano), + DataType::Interval(IntervalUnit::DayTime), + ) => cast_interval_monthdaynano_to_daytime(&Arc::clone(record_batch_col)) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }), + _ => cast(&Arc::clone(record_batch_col), expected_field.data_type()) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }), } }) .collect::>>>()?; diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index e563576..496bfbf 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -838,7 +838,7 @@ mod tests { let agg_tests = vec![ ( "SELECT MAX(a) FROM foo.df_table", - r#"SELECT MAX(remote_table.a) FROM remote_table"#, + r#"SELECT max(remote_table.a) FROM remote_table"#, ), ( "SELECT foo.df_table.a FROM foo.df_table", @@ -846,7 +846,7 @@ mod tests { ), ( "SELECT MIN(a) FROM foo.df_table", - r#"SELECT MIN(remote_table.a) FROM remote_table"#, + r#"SELECT min(remote_table.a) FROM remote_table"#, ), ( "SELECT AVG(a) FROM foo.df_table", @@ -874,7 +874,7 @@ mod tests { ), ( "SELECT MAX(app_table) from (SELECT a as app_table FROM app_table) b", - r#"SELECT MAX(b.app_table) FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#, + r#"SELECT max(b.app_table) FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#, ), // multiple occurrences of the same table in single aggregation expression ( From 21f07bec7284bcbff2bf4e570008290b66e3dc6f Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Mon, 26 Aug 2024 06:49:50 +0300 Subject: [PATCH 33/48] Add Structs parsing support to schema cast (#16) * Add Structs parsing support to schema cast * Optimize memory space --- datafusion-federation/src/schema_cast.rs | 1 + .../src/schema_cast/record_convert.rs | 10 +- .../src/schema_cast/struct_cast.rs | 169 ++++++++++++++++++ 3 files changed, 177 insertions(+), 3 deletions(-) create mode 100644 datafusion-federation/src/schema_cast/struct_cast.rs diff --git a/datafusion-federation/src/schema_cast.rs b/datafusion-federation/src/schema_cast.rs index ade8091..d23a470 100644 --- a/datafusion-federation/src/schema_cast.rs +++ b/datafusion-federation/src/schema_cast.rs @@ -14,6 +14,7 @@ use std::sync::Arc; mod intervals_cast; mod lists_cast; +mod struct_cast; pub mod record_convert; #[derive(Debug)] diff --git a/datafusion-federation/src/schema_cast/record_convert.rs b/datafusion-federation/src/schema_cast/record_convert.rs index 1a324ee..140ca38 100644 --- a/datafusion-federation/src/schema_cast/record_convert.rs +++ b/datafusion-federation/src/schema_cast/record_convert.rs @@ -1,7 +1,7 @@ use 
datafusion::arrow::{ array::{Array, RecordBatch}, compute::cast, - datatypes::{DataType, IntervalUnit, SchemaRef}, + datatypes::{DataType, IntervalUnit, SchemaRef} }; use std::sync::Arc; @@ -9,7 +9,7 @@ use super::{ intervals_cast::{ cast_interval_monthdaynano_to_daytime, cast_interval_monthdaynano_to_yearmonth, }, - lists_cast::{cast_string_to_fixed_size_list, cast_string_to_large_list, cast_string_to_list}, + lists_cast::{cast_string_to_fixed_size_list, cast_string_to_large_list, cast_string_to_list}, struct_cast::cast_string_to_struct, }; pub type Result = std::result::Result; @@ -78,10 +78,14 @@ pub fn try_cast_to(record_batch: RecordBatch, expected_schema: SchemaRef) -> Res cast_string_to_fixed_size_list( &Arc::clone(record_batch_col), item_type, - value_length.clone(), + *value_length, ) .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) } + (DataType::Utf8, DataType::Struct(_)) => { + cast_string_to_struct(&Arc::clone(record_batch_col), expected_field.clone()) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) + } ( DataType::Interval(IntervalUnit::MonthDayNano), DataType::Interval(IntervalUnit::YearMonth), diff --git a/datafusion-federation/src/schema_cast/struct_cast.rs b/datafusion-federation/src/schema_cast/struct_cast.rs new file mode 100644 index 0000000..85ad369 --- /dev/null +++ b/datafusion-federation/src/schema_cast/struct_cast.rs @@ -0,0 +1,169 @@ +use arrow_json::ReaderBuilder; +use datafusion::arrow::{ + array::{Array, ArrayRef, StringArray}, + datatypes::Field, + error::ArrowError, +}; +use std::sync::Arc; + +pub type Result = std::result::Result; + +pub(crate) fn cast_string_to_struct( + array: &dyn Array, + struct_field: Arc, +) -> Result { + let string_array = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::CastError("Failed to downcast to StringArray".to_string()))?; + + let mut decoder = ReaderBuilder::new_with_field(struct_field) + .build_decoder() + .map_err(|e| ArrowError::CastError(format!("Failed to create JSON decoder: {e}")))?; + + for value in string_array { + match value { + Some(v) => { + decoder.decode(v.as_bytes()).map_err(|e| { + ArrowError::CastError(format!("Failed to decode struct array: {e}")) + })?; + } + None => { + decoder.decode("null".as_bytes()).map_err(|e| { + ArrowError::CastError(format!("Failed to decode struct array: {e}")) + })?; + } + } + } + + let record = match decoder.flush() { + Ok(Some(record)) => record, + Ok(None) => { + return Err(ArrowError::CastError( + "Failed to flush decoder: No record".to_string(), + )); + } + Err(e) => { + return Err(ArrowError::CastError(format!( + "Failed to decode struct array: {e}" + ))); + } + }; + // struct_field is single struct column + return Ok(Arc::clone(record.column(0))); +} + +#[cfg(test)] +mod test { + use datafusion::arrow::{ + array::{Int32Builder, RecordBatch, StringArray, StringBuilder, StructBuilder}, + datatypes::{DataType, Field, Schema, SchemaRef}, + }; + + use crate::schema_cast::record_convert::try_cast_to; + + use super::*; + + fn input_schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new( + "struct_string", + DataType::Utf8, + true, + )])) + } + + fn output_schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new( + "struct", + DataType::Struct( + vec![ + Field::new("name", DataType::Utf8, false), + Field::new("age", DataType::Int32, false), + ] + .into(), + ), + true, + )])) + } + + fn batch_input() -> RecordBatch { + RecordBatch::try_new( + input_schema(), + vec![Arc::new(StringArray::from(vec![ + 
Some(r#"{"name":"John","age":30}"#), + None, + None, + Some(r#"{"name":"Jane","age":25}"#), + ]))], + ) + .expect("record batch should not panic") + } + + fn batch_expected() -> RecordBatch { + let name_field = Field::new("name", DataType::Utf8, false); + let age_field = Field::new("age", DataType::Int32, false); + + let mut struct_builder = StructBuilder::new( + vec![name_field, age_field], + vec![ + Box::new(StringBuilder::new()), + Box::new(Int32Builder::new()), + ], + ); + + struct_builder + .field_builder::(0) + .expect("should return field builder") + .append_value("John"); + struct_builder + .field_builder::(1) + .expect("should return field builder") + .append_value(30); + struct_builder.append(true); + + struct_builder + .field_builder::(0) + .expect("should return field builder") + .append_null(); + struct_builder + .field_builder::(1) + .expect("should return field builder") + .append_null(); + struct_builder.append(false); + + struct_builder + .field_builder::(0) + .expect("should return field builder") + .append_null(); + struct_builder + .field_builder::(1) + .expect("should return field builder") + .append_null(); + struct_builder.append(false); + + struct_builder + .field_builder::(0) + .expect("should return field builder") + .append_value("Jane"); + struct_builder + .field_builder::(1) + .expect("should return field builder") + .append_value(25); + struct_builder.append(true); + + let struct_array = struct_builder.finish(); + + RecordBatch::try_new(output_schema(), vec![Arc::new(struct_array)]) + .expect("Failed to create expected RecordBatch") + } + + #[test] + fn test_cast_to_struct() { + let input_batch = batch_input(); + let expected = batch_expected(); + + let actual = try_cast_to(input_batch, output_schema()).expect("cast should succeed"); + + assert_eq!(actual, expected); + } +} From b6682948d07cc3155edb3dfbf03f8b55570fc1d2 Mon Sep 17 00:00:00 2001 From: peasee <98815791+peasee@users.noreply.github.com> Date: Fri, 30 Aug 2024 17:45:40 +1000 Subject: [PATCH 34/48] feat: Add AST analyzer middleware (#19) * feat: Add AST analyzer middleware * fix: Remove redundant ast_analyzer call * chore: Variable does not need to be mutable --- Cargo.toml | 2 +- sources/sql/src/executor.rs | 8 +++++++- sources/sql/src/lib.rs | 10 ++++++++-- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bced60f..1cc0252 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,4 +25,4 @@ async-stream = "0.3.5" futures = "0.3.30" datafusion = "41.0.0" datafusion-substrait = "41.0.0" -arrow-json = "52.2.0" +arrow-json = "52.2.0" \ No newline at end of file diff --git a/sources/sql/src/executor.rs b/sources/sql/src/executor.rs index 6042f77..275bc68 100644 --- a/sources/sql/src/executor.rs +++ b/sources/sql/src/executor.rs @@ -2,11 +2,12 @@ use async_trait::async_trait; use core::fmt; use datafusion::{ arrow::datatypes::SchemaRef, error::Result, physical_plan::SendableRecordBatchStream, - sql::unparser::dialect::Dialect, + sql::sqlparser::ast, sql::unparser::dialect::Dialect, }; use std::sync::Arc; pub type SQLExecutorRef = Arc; +pub type AstAnalyzer = Box Result>; #[async_trait] pub trait SQLExecutor: Sync + Send { @@ -20,6 +21,11 @@ pub trait SQLExecutor: Sync + Send { // The specific SQL dialect (currently supports 'sqlite', 'postgres', 'flight') fn dialect(&self) -> Arc; + /// Returns an AST analyzer specific for this engine to modify the AST before execution + fn ast_analyzer(&self) -> Option { + None + } + // Execution /// Execute a SQL query fn execute(&self, 
query: &str, schema: SchemaRef) -> Result; diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index 496bfbf..5c2388d 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -634,8 +634,13 @@ impl VirtualExecutionPlan { fn sql(&self) -> Result { // Find all table scans, recover the SQLTableSource, find the remote table name and replace the name of the TableScan table. let mut known_rewrites = HashMap::new(); - let ast = Unparser::new(self.executor.dialect().as_ref()) + let mut ast = Unparser::new(self.executor.dialect().as_ref()) .plan_to_sql(&rewrite_table_scans(&self.plan, &mut known_rewrites)?)?; + + if let Some(analyzer) = self.executor.ast_analyzer() { + ast = analyzer(ast)?; + } + Ok(format!("{ast}")) } } @@ -649,7 +654,8 @@ impl DisplayAs for VirtualExecutionPlan { write!(f, " name={}", self.executor.name())?; if let Some(ctx) = self.executor.compute_context() { write!(f, " compute_context={ctx}")?; - } + }; + write!(f, " sql={ast}")?; if let Ok(query) = self.sql() { write!(f, " rewritten_sql={query}")?; From 7c69f7bc5d4f67469aaa215650e89000941d6915 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Wed, 18 Sep 2024 23:26:20 +0900 Subject: [PATCH 35/48] Upgrade to DF 42 & Arrow 53 --- Cargo.toml | 6 ++--- datafusion-federation/src/table_provider.rs | 4 +-- sources/flight-sql/Cargo.toml | 6 ++--- sources/sql/src/lib.rs | 27 ++++++++++----------- 4 files changed, 21 insertions(+), 22 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1cc0252..6264f79 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,6 @@ readme = "README.md" async-trait = "0.1.77" async-stream = "0.3.5" futures = "0.3.30" -datafusion = "41.0.0" -datafusion-substrait = "41.0.0" -arrow-json = "52.2.0" \ No newline at end of file +datafusion = "42" +datafusion-substrait = "42" +arrow-json = "53" \ No newline at end of file diff --git a/datafusion-federation/src/table_provider.rs b/datafusion-federation/src/table_provider.rs index b820b6e..92df798 100644 --- a/datafusion-federation/src/table_provider.rs +++ b/datafusion-federation/src/table_provider.rs @@ -1,4 +1,4 @@ -use std::{any::Any, sync::Arc}; +use std::{any::Any, borrow::Cow, sync::Arc}; use async_trait::async_trait; use datafusion::{ @@ -70,7 +70,7 @@ impl TableProvider for FederatedTableProviderAdaptor { self.source.table_type() } - fn get_logical_plan(&self) -> Option<&LogicalPlan> { + fn get_logical_plan(&self) -> Option> { if let Some(table_provider) = &self.table_provider { return table_provider .get_logical_plan() diff --git a/sources/flight-sql/Cargo.toml b/sources/flight-sql/Cargo.toml index d9c1a1d..63ce13f 100644 --- a/sources/flight-sql/Cargo.toml +++ b/sources/flight-sql/Cargo.toml @@ -16,8 +16,8 @@ datafusion-substrait.workspace = true datafusion-federation.path = "../../datafusion-federation" datafusion-federation-sql.path = "../sql" futures = "0.3.30" -tonic = {version="0.11.0", features=["tls"] } +tonic = {version="0.12.2", features=["tls"] } prost = "0.12.3" -arrow = "52.0.0" -arrow-flight = { version = "52.0.0", features = ["flight-sql-experimental"] } +arrow = "53.0.0" +arrow-flight = { version = "53.0.0", features = ["flight-sql-experimental"] } log = "0.4.20" diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index 5c2388d..852786a 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -408,14 +408,6 @@ fn rewrite_table_scans_in_expr( try_cast.data_type, ))) } - Expr::Sort(sort) => { - let expr = rewrite_table_scans_in_expr(*sort.expr, known_rewrites)?; - Ok(Expr::Sort(Sort::new( - 
Box::new(expr), - sort.asc, - sort.nulls_first, - ))) - } Expr::ScalarFunction(sf) => { let args = sf .args @@ -442,8 +434,11 @@ fn rewrite_table_scans_in_expr( .order_by .map(|e| { e.into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>() + .map(|s| { + rewrite_table_scans_in_expr(s.expr, known_rewrites) + .map(|e| Sort::new(e, s.asc, s.nulls_first)) + }) + .collect::>>() }) .transpose()?; Ok(Expr::AggregateFunction(AggregateFunction { @@ -469,8 +464,11 @@ fn rewrite_table_scans_in_expr( let order_by = wf .order_by .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) - .collect::>>()?; + .map(|s| { + rewrite_table_scans_in_expr(s.expr, known_rewrites) + .map(|e| Sort::new(e, s.asc, s.nulls_first)) + }) + .collect::>>()?; Ok(Expr::WindowFunction(WindowFunction { fun: wf.fun, args, @@ -522,13 +520,14 @@ fn rewrite_table_scans_in_expr( is.negated, ))) } - Expr::Wildcard { qualifier } => { + Expr::Wildcard { qualifier, options } => { if let Some(rewrite) = qualifier.as_ref().and_then(|q| known_rewrites.get(q)) { Ok(Expr::Wildcard { qualifier: Some(rewrite.clone()), + options, }) } else { - Ok(Expr::Wildcard { qualifier }) + Ok(Expr::Wildcard { qualifier, options }) } } Expr::GroupingSet(gs) => match gs { From f1e7b17755d96b30bafb4d185467ca4c55f85aec Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Wed, 9 Oct 2024 10:34:33 -0700 Subject: [PATCH 36/48] Preserve records batch order when SchemaCastScanExec is involved (#20) --- datafusion-federation/src/schema_cast.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/datafusion-federation/src/schema_cast.rs b/datafusion-federation/src/schema_cast.rs index d23a470..502d4b8 100644 --- a/datafusion-federation/src/schema_cast.rs +++ b/datafusion-federation/src/schema_cast.rs @@ -4,7 +4,8 @@ use datafusion::error::{DataFusionError, Result}; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, + PlanProperties, }; use futures::StreamExt; use std::any::Any; @@ -69,6 +70,12 @@ impl ExecutionPlan for SchemaCastScanExec { vec![&self.input] } + /// Prevents the introduction of additional `RepartitionExec` and processing input in parallel. + /// This guarantees that the input is processed as a single stream, preserving the order of the data. 
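+    /// The returned vector has one entry per child input; returning `false` for the single child
+    /// signals the physical optimizer that repartitioning this input brings no benefit, so it is
+    /// left as a single partition.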
+ fn benefits_from_input_partitioning(&self) -> Vec { + vec![false] + } + fn with_new_children( self: Arc, children: Vec>, From 06b90003d2e8df046d32d92665c75c4abaa0b970 Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Thu, 17 Oct 2024 11:55:41 -0700 Subject: [PATCH 37/48] Support for rewriting plans with UNNEST (#21) --- Cargo.toml | 4 +- sources/sql/src/lib.rs | 122 ++++++++++++++++++++++++++++++++++++++--- 2 files changed, 117 insertions(+), 9 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6264f79..063ded9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,8 @@ members = [ [patch.crates-io] # connectorx = { path = "../connector-x/connectorx" } # datafusion = { path = "../arrow-datafusion/datafusion/core" } +# Pending next Datafusion release with `unnest` unparsing support +datafusion = { git = "https://github.com/spiceai/datafusion.git", rev = "2b71f31beec5d3f78d0ae8534976409509423c1b" } [workspace.package] version = "0.1.6" @@ -25,4 +27,4 @@ async-stream = "0.3.5" futures = "0.3.30" datafusion = "42" datafusion-substrait = "42" -arrow-json = "53" \ No newline at end of file +arrow-json = "53" diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index 852786a..0cf0879 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -6,15 +6,16 @@ use datafusion::{ arrow::datatypes::{Schema, SchemaRef}, common::Column, config::ConfigOptions, - error::Result, + error::{DataFusionError, Result}, execution::{context::SessionState, TaskContext}, logical_expr::{ + self, expr::{ AggregateFunction, Alias, Exists, InList, InSubquery, ScalarFunction, Sort, Unnest, WindowFunction, }, - Between, BinaryExpr, Case, Cast, Expr, Extension, GroupingSet, Like, LogicalPlan, Subquery, - TryCast, + Between, BinaryExpr, Case, Cast, Expr, Extension, GroupingSet, Like, LogicalPlan, + LogicalPlanBuilder, Projection, Subquery, TryCast, }, optimizer::analyzer::{Analyzer, AnalyzerRule}, physical_expr::EquivalenceProperties, @@ -145,13 +146,86 @@ fn rewrite_table_scans( .map(|i| rewrite_table_scans(i, known_rewrites)) .collect::>>()?; - let mut new_expressions = vec![]; - for expression in plan.expressions() { - let new_expr = rewrite_table_scans_in_expr(expression.clone(), known_rewrites)?; - new_expressions.push(new_expr); + match plan { + LogicalPlan::Unnest(unnest) => { + // The Union plan cannot be constructed from rewritten expressions. It requires specialized logic to handle + // the renaming in UNNEST columns and the corresponding column aliases in the underlying projection plan. + rewrite_unnest_plan(unnest, rewritten_inputs, known_rewrites) + } + _ => { + let mut new_expressions = vec![]; + for expression in plan.expressions() { + let new_expr = rewrite_table_scans_in_expr(expression.clone(), known_rewrites)?; + new_expressions.push(new_expr); + } + let new_plan = plan.with_new_exprs(new_expressions, rewritten_inputs)?; + Ok(new_plan) + } } +} - let new_plan = plan.with_new_exprs(new_expressions, rewritten_inputs)?; +/// Rewrite an unnest plan to use the original federated table name. +/// In a standard unnest plan, column names are typically referenced in projection columns by wrapping them +/// in aliases such as "UNNEST(table_name.column_name)". `rewrite_table_scans_in_expr` does not handle alias +/// rewriting so we manually collect the rewritten unnest column names/aliases and update the projection +/// plan to ensure that the aliases reflect the new names. 
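+/// For example (using the table names from the tests below), a projection alias "UNNEST(app_table.d)"
+/// must become "UNNEST(remote_table.d)" once the scan is rewritten to the remote table, so that the
+/// rebuilt unnest plan and its input projection still refer to the same column name.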
+fn rewrite_unnest_plan( + unnest: &logical_expr::Unnest, + mut rewritten_inputs: Vec, + known_rewrites: &mut HashMap, +) -> Result { + // Unnest plan has a single input + let input = rewritten_inputs.remove(0); + + let mut known_unnest_rewrites: HashMap = HashMap::new(); + + // `exec_columns` represent columns to run UNNEST on: rewrite them and collect new names + let unnest_columns = unnest + .exec_columns + .iter() + .map(|c: &Column| { + match rewrite_table_scans_in_expr(Expr::Column(c.clone()), known_rewrites)? { + Expr::Column(column) => { + known_unnest_rewrites.insert(c.name.clone(), column.name.clone()); + Ok(column) + } + _ => Err(DataFusionError::Plan( + "Rewritten column expression must be a column".to_string(), + )), + } + }) + .collect::>>()?; + + let LogicalPlan::Projection(projection) = input else { + return Err(DataFusionError::Plan( + "The input to the unnest plan should be a projection plan".to_string(), + )); + }; + + // rewrite aliases in inner projection; columns were rewritten via `rewrite_table_scans_in_expr`` + let new_expressions = projection + .expr + .into_iter() + .map(|expr| match expr { + Expr::Alias(alias) => { + let name = match known_unnest_rewrites.get(&alias.name) { + Some(name) => name, + None => &alias.name, + }; + Ok(Expr::Alias(Alias::new(*alias.expr, alias.relation, name))) + } + _ => Ok(expr), + }) + .collect::>>()?; + + let updated_unnest_inner_projection = + Projection::try_new(new_expressions, Arc::clone(&projection.input))?; + + // reconstruct the unnest plan with updated projection and rewritten column names + let new_plan = + LogicalPlanBuilder::new(LogicalPlan::Projection(updated_unnest_inner_projection)) + .unnest_columns_with_options(unnest_columns, unnest.options.clone())? + .build()?; Ok(new_plan) } @@ -761,6 +835,11 @@ mod tests { Field::new("a", DataType::Int64, false), Field::new("b", DataType::Utf8, false), Field::new("c", DataType::Date32, false), + Field::new( + "d", + DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + false, + ), ])); let table_source = Arc::new( SQLTableSource::new_with_schema( @@ -927,6 +1006,33 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_rewrite_table_scans_unnest() -> Result<()> { + init_tracing(); + let ctx = get_test_df_context(); + + let tests = vec![ + ( + "SELECT UNNEST([1, 2, 2, 5, NULL]), b, c from app_table where a > 10 order by b limit 10;", + r#"SELECT UNNEST(make_array(1, 2, 2, 5, NULL)), remote_table.b, remote_table.c FROM remote_table WHERE (remote_table.a > 10) ORDER BY remote_table.b ASC NULLS LAST LIMIT 10"#, + ), + ( + "SELECT UNNEST(app_table.d), b, c from app_table where a > 10 order by b limit 10;", + r#"SELECT UNNEST(remote_table.d), remote_table.b, remote_table.c FROM remote_table WHERE (remote_table.a > 10) ORDER BY remote_table.b ASC NULLS LAST LIMIT 10"#, + ), + ( + "SELECT sum(b.x) AS total FROM (SELECT UNNEST(d) AS x from app_table where a > 0) AS b;", + r#"SELECT sum(b.x) AS total FROM (SELECT UNNEST(remote_table.d) AS x FROM remote_table WHERE (remote_table.a > 0)) AS b"#, + ), + ]; + + for test in tests { + test_sql(&ctx, test.0, test.1).await?; + } + + Ok(()) + } + async fn test_sql( ctx: &SessionContext, sql_query: &str, From ce1ce6158ce03f88c67cde25fca00f436c3e1690 Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Thu, 17 Oct 2024 23:34:54 -0700 Subject: [PATCH 38/48] Correctly handle unnest by FederationAnalyzerRule (#22) --- datafusion-federation/src/analyzer.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git 
a/datafusion-federation/src/analyzer.rs b/datafusion-federation/src/analyzer.rs index 7760734..764dec3 100644 --- a/datafusion-federation/src/analyzer.rs +++ b/datafusion-federation/src/analyzer.rs @@ -115,7 +115,12 @@ impl FederationAnalyzerRule { }) .collect::>>()?; - let new_plan = plan.with_new_exprs(plan.expressions(), new_inputs)?; + let new_plan = match plan { + // Unnest returns columns to unnest as `expressions` but does not support passing them back to `with_new_exprs`. + // Instead, it uses data from its internal representation to create a new plan. + LogicalPlan::Unnest(_) => plan.with_new_exprs(vec![], new_inputs)?, + _ => plan.with_new_exprs(plan.expressions(), new_inputs)?, + }; Ok((Some(new_plan), None)) } From 0aabd423ef8aea0ae05479331d1486e9ef8f8061 Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Mon, 21 Oct 2024 12:54:06 -0700 Subject: [PATCH 39/48] Filters push down optimization for federation queries (#23) --- datafusion-federation/src/analyzer.rs | 46 ++++++++---- datafusion-federation/src/lib.rs | 1 + datafusion-federation/src/optimize.rs | 104 ++++++++++++++++++++++++++ 3 files changed, 138 insertions(+), 13 deletions(-) create mode 100644 datafusion-federation/src/optimize.rs diff --git a/datafusion-federation/src/analyzer.rs b/datafusion-federation/src/analyzer.rs index 764dec3..0c91a4a 100644 --- a/datafusion-federation/src/analyzer.rs +++ b/datafusion-federation/src/analyzer.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use datafusion::{ - common::Column, + common::{tree_node::TreeNode, Column}, config::ConfigOptions, datasource::source_as_provider, error::Result, @@ -10,7 +10,8 @@ use datafusion::{ sql::TableReference, }; -use crate::{FederatedTableProviderAdaptor, FederatedTableSource, FederationProviderRef}; +use crate::{optimize::optimize_plan, FederatedTableProviderAdaptor, FederatedTableSource, FederationProviderRef}; + #[derive(Default)] pub struct FederationAnalyzerRule {} @@ -20,6 +21,13 @@ impl AnalyzerRule for FederationAnalyzerRule { // TableScans from the same FederationProvider. // There 'largest sub-trees' are passed to their respective FederationProvider.optimizer. fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result { + + if !contains_federated_table(&plan)? { + return Ok(plan); + } + + let plan = optimize_plan(plan)?; + let (optimized, _) = self.optimize_recursively(&plan, None, config)?; if let Some(result) = optimized { return Ok(result); @@ -33,6 +41,18 @@ impl AnalyzerRule for FederationAnalyzerRule { } } +fn contains_federated_table(plan: &LogicalPlan) -> Result { + let federated_table_exists = plan.exists(|x| { + if let Some(provider) = get_federation_provider(x)? { + // federated table provider should have an analyzer + return Ok(provider.analyzer().is_some()); + } + Ok(false) + })?; + + Ok(federated_table_exists) +} + impl FederationAnalyzerRule { pub fn new() -> Self { Self::default() @@ -49,7 +69,7 @@ impl FederationAnalyzerRule { _config: &ConfigOptions, ) -> Result<(Option, Option)> { // Check if this node determines the FederationProvider - let sole_provider = self.get_federation_provider(plan)?; + let sole_provider = get_federation_provider(plan)?; if sole_provider.is_some() { return Ok((None, sole_provider)); } @@ -124,18 +144,18 @@ impl FederationAnalyzerRule { Ok((Some(new_plan), None)) } +} - fn get_federation_provider(&self, plan: &LogicalPlan) -> Result> { - match plan { - LogicalPlan::TableScan(TableScan { ref source, .. }) => { - let Some(federated_source) = get_table_source(source)? 
else { - return Ok(None); - }; - let provider = federated_source.federation_provider(); - Ok(Some(provider)) - } - _ => Ok(None), +fn get_federation_provider(plan: &LogicalPlan) -> Result> { + match plan { + LogicalPlan::TableScan(TableScan { ref source, .. }) => { + let Some(federated_source) = get_table_source(source)? else { + return Ok(None); + }; + let provider = federated_source.federation_provider(); + Ok(Some(provider)) } + _ => Ok(None), } } diff --git a/datafusion-federation/src/lib.rs b/datafusion-federation/src/lib.rs index 41ce145..eb4a635 100644 --- a/datafusion-federation/src/lib.rs +++ b/datafusion-federation/src/lib.rs @@ -8,6 +8,7 @@ use datafusion::optimizer::analyzer::Analyzer; mod analyzer; pub use analyzer::*; +mod optimize; mod table_provider; pub use table_provider::*; diff --git a/datafusion-federation/src/optimize.rs b/datafusion-federation/src/optimize.rs new file mode 100644 index 0000000..c49b0ad --- /dev/null +++ b/datafusion-federation/src/optimize.rs @@ -0,0 +1,104 @@ +use datafusion::{ + common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}, + error::Result, + logical_expr:: LogicalPlan, + optimizer::{ + optimizer::ApplyOrder, push_down_filter, OptimizerConfig, OptimizerContext, OptimizerRule + } +}; + +pub(crate) fn optimize_plan(plan: LogicalPlan) -> Result { + let push_down_filter_rule = push_down_filter::PushDownFilter::new(); + // `push_down_filter` does not use config so it can be default + let optimizer_config = OptimizerContext::default(); + + let res = match push_down_filter_rule.apply_order() { + Some(apply_order) => plan.rewrite(&mut Rewriter::new( + apply_order, + &push_down_filter_rule, + &optimizer_config, + )), + None => optimize_plan_node(plan, &push_down_filter_rule, &optimizer_config), + }; + + Ok(res?.data) +} + +struct Rewriter<'a> { + apply_order: ApplyOrder, + rule: &'a dyn OptimizerRule, + config: &'a dyn OptimizerConfig, +} + +impl<'a> Rewriter<'a> { + fn new( + apply_order: ApplyOrder, + rule: &'a dyn OptimizerRule, + config: &'a dyn OptimizerConfig, + ) -> Self { + Self { + apply_order, + rule, + config, + } + } +} + +impl<'a> TreeNodeRewriter for Rewriter<'a> { + type Node = LogicalPlan; + + fn f_down(&mut self, node: LogicalPlan) -> Result> { + if self.apply_order == ApplyOrder::TopDown { + optimize_plan_node(node, self.rule, self.config) + } else { + Ok(Transformed::no(node)) + } + } + + fn f_up(&mut self, node: LogicalPlan) -> Result> { + if self.apply_order == ApplyOrder::BottomUp { + optimize_plan_node(node, self.rule, self.config) + } else { + Ok(Transformed::no(node)) + } + } +} + +fn should_run_rule_for_node(node: &LogicalPlan, _rule: &dyn OptimizerRule) -> bool { + // this logic is applicable only for `push_down_filter_rule`; we don't have any other rules + if let LogicalPlan::Filter(x) = node { + // Applying the `push_down_filter_rule` to certain nodes like `SubqueryAlias`, `Aggregate`, and `CrossJoin` + // can cause issues during unparsing, thus the optimization is only applied to nodes that are currently supported. 
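+        // For example, a Filter directly above a TableScan or Join can be pushed toward the scan and
+        // still unparse cleanly (typically ending up in the remote query's WHERE clause), whereas
+        // pushing it below a SubqueryAlias or Aggregate has produced plans the unparser cannot handle.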
+ matches!( + x.input.as_ref(), + LogicalPlan::Join(_) | LogicalPlan::TableScan(_) | LogicalPlan::Projection(_) | LogicalPlan::Filter(_) | LogicalPlan::Distinct(_) | LogicalPlan::Sort(_) + ) + } else { + true + } +} + +fn optimize_plan_node( + plan: LogicalPlan, + rule: &dyn OptimizerRule, + config: &dyn OptimizerConfig, +) -> Result> { + if !should_run_rule_for_node(&plan, rule) { + return Ok(Transformed::no(plan)); + } + + if rule.supports_rewrite() { + return rule.rewrite(plan, config); + } + + #[allow(deprecated)] + rule.try_optimize(&plan, config).map(|maybe_plan| { + match maybe_plan { + Some(new_plan) => { + // if the node was rewritten by the optimizer, replace the node + Transformed::yes(new_plan) + } + None => Transformed::no(plan), + } + }) +} From 914bd0836baa6990c5d03f977e5e87fe5eeaf4d6 Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Mon, 4 Nov 2024 10:12:27 -0800 Subject: [PATCH 40/48] Run `optimize_projections` as part of federated plan optimization (#24) --- Cargo.toml | 2 +- datafusion-federation/src/analyzer.rs | 13 +++-- datafusion-federation/src/optimize.rs | 77 ++++++++++++++++++++------- 3 files changed, 67 insertions(+), 25 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 063ded9..3836fd7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ members = [ # connectorx = { path = "../connector-x/connectorx" } # datafusion = { path = "../arrow-datafusion/datafusion/core" } # Pending next Datafusion release with `unnest` unparsing support -datafusion = { git = "https://github.com/spiceai/datafusion.git", rev = "2b71f31beec5d3f78d0ae8534976409509423c1b" } +datafusion = { git = "https://github.com/spiceai/datafusion.git", rev = "06969ee5af853f0c071a98683dc2b9fde71b81a9" } [workspace.package] version = "0.1.6" diff --git a/datafusion-federation/src/analyzer.rs b/datafusion-federation/src/analyzer.rs index 0c91a4a..be57a04 100644 --- a/datafusion-federation/src/analyzer.rs +++ b/datafusion-federation/src/analyzer.rs @@ -10,11 +10,14 @@ use datafusion::{ sql::TableReference, }; -use crate::{optimize::optimize_plan, FederatedTableProviderAdaptor, FederatedTableSource, FederationProviderRef}; - +use crate::{ + optimize::Optimizer, FederatedTableProviderAdaptor, FederatedTableSource, FederationProviderRef, +}; #[derive(Default)] -pub struct FederationAnalyzerRule {} +pub struct FederationAnalyzerRule { + optimizer: Optimizer, +} impl AnalyzerRule for FederationAnalyzerRule { // Walk over the plan, look for the largest subtrees that only have @@ -26,7 +29,7 @@ impl AnalyzerRule for FederationAnalyzerRule { return Ok(plan); } - let plan = optimize_plan(plan)?; + let plan = self.optimizer.optimize_plan(plan)?; let (optimized, _) = self.optimize_recursively(&plan, None, config)?; if let Some(result) = optimized { @@ -43,7 +46,7 @@ impl AnalyzerRule for FederationAnalyzerRule { fn contains_federated_table(plan: &LogicalPlan) -> Result { let federated_table_exists = plan.exists(|x| { - if let Some(provider) = get_federation_provider(x)? { + if let Some(provider) = get_federation_provider(x)? 
{ // federated table provider should have an analyzer return Ok(provider.analyzer().is_some()); } diff --git a/datafusion-federation/src/optimize.rs b/datafusion-federation/src/optimize.rs index c49b0ad..070581c 100644 --- a/datafusion-federation/src/optimize.rs +++ b/datafusion-federation/src/optimize.rs @@ -1,27 +1,61 @@ use datafusion::{ - common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}, + common::tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, error::Result, - logical_expr:: LogicalPlan, + execution::{SessionState, SessionStateBuilder}, + logical_expr::LogicalPlan, optimizer::{ - optimizer::ApplyOrder, push_down_filter, OptimizerConfig, OptimizerContext, OptimizerRule - } + optimize_projections::OptimizeProjections, optimizer::ApplyOrder, + push_down_filter::PushDownFilter, OptimizerConfig, OptimizerRule, + }, + prelude::SessionConfig, }; -pub(crate) fn optimize_plan(plan: LogicalPlan) -> Result { - let push_down_filter_rule = push_down_filter::PushDownFilter::new(); - // `push_down_filter` does not use config so it can be default - let optimizer_config = OptimizerContext::default(); +pub(crate) struct Optimizer { + config: SessionState, + push_down_filter: PushDownFilter, + optimize_projections: OptimizeProjections, +} + +impl Default for Optimizer { + fn default() -> Self { + // `push_down_filter` and `optimize_projections` does not use config (except `optimize_projections_preserve_existing_projections`) so it can be default + // `SessionState` implements `OptimizerConfig` allowing specification of the required configuration for optimization rules. + let config = SessionStateBuilder::new() + .with_config( + SessionConfig::new().with_optimize_projections_preserve_existing_projections(true), + ) + .build(); - let res = match push_down_filter_rule.apply_order() { - Some(apply_order) => plan.rewrite(&mut Rewriter::new( - apply_order, - &push_down_filter_rule, - &optimizer_config, - )), - None => optimize_plan_node(plan, &push_down_filter_rule, &optimizer_config), - }; + Self { + config, + push_down_filter: PushDownFilter::new(), + optimize_projections: OptimizeProjections::new(), + } + } +} + +impl Optimizer { + pub fn new() -> Self { + Self::default() + } - Ok(res?.data) + pub(crate) fn optimize_plan(&self, plan: LogicalPlan) -> Result { + let mut optimized_plan = plan + .rewrite(&mut Rewriter::new( + ApplyOrder::TopDown, + &self.push_down_filter, + &self.config, + ))? + .data; + + // `optimize_projections` is applied recursively top down so it can be applied only once to the root node + optimized_plan = self + .optimize_projections + .rewrite(optimized_plan, &self.config) + .data()?; + + Ok(optimized_plan) + } } struct Rewriter<'a> { @@ -65,13 +99,18 @@ impl<'a> TreeNodeRewriter for Rewriter<'a> { } fn should_run_rule_for_node(node: &LogicalPlan, _rule: &dyn OptimizerRule) -> bool { - // this logic is applicable only for `push_down_filter_rule`; we don't have any other rules + // this logic is applicable only for `push_down_filter_rule`; we don't have any other rules using `should_run_rule_for_node` if let LogicalPlan::Filter(x) = node { // Applying the `push_down_filter_rule` to certain nodes like `SubqueryAlias`, `Aggregate`, and `CrossJoin` // can cause issues during unparsing, thus the optimization is only applied to nodes that are currently supported. 
matches!( x.input.as_ref(), - LogicalPlan::Join(_) | LogicalPlan::TableScan(_) | LogicalPlan::Projection(_) | LogicalPlan::Filter(_) | LogicalPlan::Distinct(_) | LogicalPlan::Sort(_) + LogicalPlan::Join(_) + | LogicalPlan::TableScan(_) + | LogicalPlan::Projection(_) + | LogicalPlan::Filter(_) + | LogicalPlan::Distinct(_) + | LogicalPlan::Sort(_) ) } else { true From 5af0df83c2cd1d3f82f293b066b401a4dfd4064b Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Tue, 12 Nov 2024 17:07:02 +0900 Subject: [PATCH 41/48] Upgrade to DF 43 --- Cargo.toml | 6 +++--- datafusion-federation/src/analyzer.rs | 3 +-- datafusion-federation/src/optimize.rs | 1 + datafusion-federation/src/plan_node.rs | 9 +++++++-- datafusion-federation/src/table_provider.rs | 14 ++++++++++++-- sources/sql/src/lib.rs | 6 ++++++ sources/sql/src/schema.rs | 11 +++++++++++ 7 files changed, 41 insertions(+), 9 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3836fd7..15d64d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ members = [ # connectorx = { path = "../connector-x/connectorx" } # datafusion = { path = "../arrow-datafusion/datafusion/core" } # Pending next Datafusion release with `unnest` unparsing support -datafusion = { git = "https://github.com/spiceai/datafusion.git", rev = "06969ee5af853f0c071a98683dc2b9fde71b81a9" } +datafusion = { git = "https://github.com/spiceai/datafusion.git", rev = "0bad328656a07fd5ab899186462b09e119e21f90" } [workspace.package] version = "0.1.6" @@ -25,6 +25,6 @@ readme = "README.md" async-trait = "0.1.77" async-stream = "0.3.5" futures = "0.3.30" -datafusion = "42" -datafusion-substrait = "42" +datafusion = "43" +datafusion-substrait = "43" arrow-json = "53" diff --git a/datafusion-federation/src/analyzer.rs b/datafusion-federation/src/analyzer.rs index be57a04..644b034 100644 --- a/datafusion-federation/src/analyzer.rs +++ b/datafusion-federation/src/analyzer.rs @@ -14,7 +14,7 @@ use crate::{ optimize::Optimizer, FederatedTableProviderAdaptor, FederatedTableSource, FederationProviderRef, }; -#[derive(Default)] +#[derive(Default, Debug)] pub struct FederationAnalyzerRule { optimizer: Optimizer, } @@ -24,7 +24,6 @@ impl AnalyzerRule for FederationAnalyzerRule { // TableScans from the same FederationProvider. // There 'largest sub-trees' are passed to their respective FederationProvider.optimizer. fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result { - if !contains_federated_table(&plan)? 
{ return Ok(plan); } diff --git a/datafusion-federation/src/optimize.rs b/datafusion-federation/src/optimize.rs index 070581c..082ddf6 100644 --- a/datafusion-federation/src/optimize.rs +++ b/datafusion-federation/src/optimize.rs @@ -10,6 +10,7 @@ use datafusion::{ prelude::SessionConfig, }; +#[derive(Debug)] pub(crate) struct Optimizer { config: SessionState, push_down_filter: PushDownFilter, diff --git a/datafusion-federation/src/plan_node.rs b/datafusion-federation/src/plan_node.rs index c99204f..f938290 100644 --- a/datafusion-federation/src/plan_node.rs +++ b/datafusion-federation/src/plan_node.rs @@ -30,6 +30,12 @@ impl FederatedPlanNode { } } +impl PartialOrd for FederatedPlanNode { + fn partial_cmp(&self, other: &Self) -> Option { + self.plan.partial_cmp(&other.plan) + } +} + impl Debug for FederatedPlanNode { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { UserDefinedLogicalNodeCore::fmt_for_explain(self, f) @@ -67,8 +73,7 @@ impl UserDefinedLogicalNodeCore for FederatedPlanNode { } } -#[derive(Default)] - +#[derive(Default, Debug)] pub struct FederatedQueryPlanner {} impl FederatedQueryPlanner { diff --git a/datafusion-federation/src/table_provider.rs b/datafusion-federation/src/table_provider.rs index 92df798..93eb0aa 100644 --- a/datafusion-federation/src/table_provider.rs +++ b/datafusion-federation/src/table_provider.rs @@ -7,7 +7,9 @@ use datafusion::{ common::Constraints, datasource::TableProvider, error::{DataFusionError, Result}, - logical_expr::{Expr, LogicalPlan, TableProviderFilterPushDown, TableSource, TableType}, + logical_expr::{ + dml::InsertOp, Expr, LogicalPlan, TableProviderFilterPushDown, TableSource, TableType, + }, physical_plan::ExecutionPlan, }; @@ -20,6 +22,14 @@ pub struct FederatedTableProviderAdaptor { pub table_provider: Option>, } +impl std::fmt::Debug for FederatedTableProviderAdaptor { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FederatedTableProviderAdaptor") + .field("table_provider", &self.table_provider) + .finish() + } +} + impl FederatedTableProviderAdaptor { pub fn new(source: Arc) -> Self { Self { @@ -124,7 +134,7 @@ impl TableProvider for FederatedTableProviderAdaptor { &self, _state: &dyn Session, input: Arc, - overwrite: bool, + overwrite: InsertOp, ) -> Result> { if let Some(table_provider) = &self.table_provider { return table_provider.insert_into(_state, input, overwrite).await; diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index 0cf0879..0186f1a 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -78,6 +78,12 @@ struct SQLFederationAnalyzerRule { planner: Arc, } +impl std::fmt::Debug for SQLFederationAnalyzerRule { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SQLFederationAnalyzerRule").finish() + } +} + impl SQLFederationAnalyzerRule { pub fn new(executor: Arc) -> Self { Self { diff --git a/sources/sql/src/schema.rs b/sources/sql/src/schema.rs index aa23fd0..86c58ff 100644 --- a/sources/sql/src/schema.rs +++ b/sources/sql/src/schema.rs @@ -13,6 +13,7 @@ use datafusion_federation::{ use crate::SQLFederationProvider; +#[derive(Debug)] pub struct SQLSchemaProvider { // provider: Arc, tables: Vec>, @@ -74,6 +75,7 @@ impl SchemaProvider for SQLSchemaProvider { } } +#[derive(Debug)] pub struct MultiSchemaProvider { children: Vec>, } @@ -114,6 +116,15 @@ pub struct SQLTableSource { schema: SchemaRef, } +impl std::fmt::Debug for SQLTableSource { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + 
f.debug_struct("SQLTableSource") + .field("table_name", &self.table_name) + .field("schema", &self.schema) + .finish() + } +} + impl SQLTableSource { // creates a SQLTableSource and infers the table schema pub async fn new(provider: Arc, table_name: String) -> Result { From bbc0f89637dff6e13a49db3e27423f83143c97aa Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Fri, 29 Nov 2024 10:18:32 -0800 Subject: [PATCH 42/48] Add own modified version of Datafusion `optimize_projections` rule (#25) --- datafusion-federation/src/optimize.rs | 23 +- .../src/optimize/optimize_projections/mod.rs | 1047 +++++++++++++++++ .../optimize_projections/required_indices.rs | 229 ++++ 3 files changed, 1283 insertions(+), 16 deletions(-) create mode 100644 datafusion-federation/src/optimize/optimize_projections/mod.rs create mode 100644 datafusion-federation/src/optimize/optimize_projections/required_indices.rs diff --git a/datafusion-federation/src/optimize.rs b/datafusion-federation/src/optimize.rs index 082ddf6..b669210 100644 --- a/datafusion-federation/src/optimize.rs +++ b/datafusion-federation/src/optimize.rs @@ -1,31 +1,26 @@ use datafusion::{ common::tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, error::Result, - execution::{SessionState, SessionStateBuilder}, logical_expr::LogicalPlan, optimizer::{ - optimize_projections::OptimizeProjections, optimizer::ApplyOrder, - push_down_filter::PushDownFilter, OptimizerConfig, OptimizerRule, + optimizer::ApplyOrder, push_down_filter::PushDownFilter, OptimizerConfig, OptimizerContext, OptimizerRule }, - prelude::SessionConfig, }; +use optimize_projections::OptimizeProjections; + +mod optimize_projections; #[derive(Debug)] pub(crate) struct Optimizer { - config: SessionState, + config: OptimizerContext, push_down_filter: PushDownFilter, optimize_projections: OptimizeProjections, } impl Default for Optimizer { fn default() -> Self { - // `push_down_filter` and `optimize_projections` does not use config (except `optimize_projections_preserve_existing_projections`) so it can be default - // `SessionState` implements `OptimizerConfig` allowing specification of the required configuration for optimization rules. - let config = SessionStateBuilder::new() - .with_config( - SessionConfig::new().with_optimize_projections_preserve_existing_projections(true), - ) - .build(); + // `push_down_filter` and `optimize_projections` does not use config so it can be default + let config = OptimizerContext::default(); Self { config, @@ -36,10 +31,6 @@ impl Default for Optimizer { } impl Optimizer { - pub fn new() -> Self { - Self::default() - } - pub(crate) fn optimize_plan(&self, plan: LogicalPlan) -> Result { let mut optimized_plan = plan .rewrite(&mut Rewriter::new( diff --git a/datafusion-federation/src/optimize/optimize_projections/mod.rs b/datafusion-federation/src/optimize/optimize_projections/mod.rs new file mode 100644 index 0000000..699287c --- /dev/null +++ b/datafusion-federation/src/optimize/optimize_projections/mod.rs @@ -0,0 +1,1047 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ + collections::{HashMap, HashSet}, + result, + sync::Arc, +}; + +use datafusion::{ + common::{ + get_required_group_by_exprs_indices, internal_datafusion_err, internal_err, + tree_node::{Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion}, + Column, JoinType, + }, + error::DataFusionError, + logical_expr::{ + expr::Alias, Aggregate, Distinct, LogicalPlan, Projection, TableScan, Unnest, Window, + }, + optimizer::{optimizer::ApplyOrder, utils::NamePreserver, OptimizerConfig, OptimizerRule}, + prelude::Expr, +}; +use required_indices::RequiredIndicies; + +type Result = result::Result; + +/// A modified version of DataFusion's [`OptimizeProjections`](https://github.com/apache/datafusion/blob/main/datafusion/optimizer/src/optimize_projections/mod.rs) rule. +/// +/// Unlike the original DataFusion implementation, this version does not create +/// or remove `Projection` nodes. This is required for `unparser` to correctly convert +/// `LogicalPlan` to SQL and also keeps the query layout simple and readable while relying on the +/// underlying SQL engine to apply its own optimizations during execution. +/// +/// For more information, see the [DataFusion discussion](https://github.com/apache/datafusion/pull/13267). +mod required_indices; + +#[derive(Default, Debug)] +pub struct OptimizeProjections {} + +impl OptimizeProjections { + #[must_use] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for OptimizeProjections { + fn name(&self) -> &str { + "federation_optimize_projections" + } + + fn apply_order(&self) -> Option { + None + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + // All output fields are necessary: + let indices = RequiredIndicies::new_for_all_exprs(&plan); + optimize_projections(plan, config, indices) + } +} + +/// Removes unnecessary columns (e.g. columns that do not appear in the output +/// schema and/or are not used during any computation step such as expression +/// evaluation) from the logical plan and its inputs. +/// +/// # Parameters +/// +/// - `plan`: A reference to the input `LogicalPlan` to optimize. +/// - `config`: A reference to the optimizer configuration. +/// - `indices`: A slice of column indices that represent the necessary column +/// indices for downstream (parent) plan nodes. +/// +/// # Returns +/// +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(LogicalPlan))`: An optimized `LogicalPlan` without unnecessary +/// columns. +/// - `Ok(None)`: Signal that the given logical plan did not require any change. +/// - `Err(error)`: An error occurred during the optimization process. +fn optimize_projections( + plan: LogicalPlan, + config: &dyn OptimizerConfig, + indices: RequiredIndicies, +) -> Result> { + // Recursively rewrite any nodes that may be able to avoid computation given + // their parents' required indices. 
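+    // Projection, Aggregate, Window and TableScan get specialized handling in the match below; every
+    // other plan node falls through to the generic logic further down, which only computes the column
+    // indices its children are required to produce.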
+ match plan { + LogicalPlan::Projection(proj) => { + return merge_consecutive_projections(proj)?.transform_data(|proj| { + rewrite_projection_given_requirements(proj, config, &indices) + }) + } + LogicalPlan::Aggregate(aggregate) => { + // Split parent requirements to GROUP BY and aggregate sections: + let n_group_exprs = aggregate.group_expr_len()?; + // Offset aggregate indices so that they point to valid indices at + // `aggregate.aggr_expr`: + let (group_by_reqs, aggregate_reqs) = indices.split_off(n_group_exprs); + + // Get absolutely necessary GROUP BY fields: + let group_by_expr_existing = aggregate + .group_expr + .iter() + .map(|group_by_expr| group_by_expr.schema_name().to_string()) + .collect::>(); + + let new_group_bys = if let Some(simplest_groupby_indices) = + get_required_group_by_exprs_indices( + aggregate.input.schema(), + &group_by_expr_existing, + ) { + // Some of the fields in the GROUP BY may be required by the + // parent even if these fields are unnecessary in terms of + // functional dependency. + group_by_reqs + .append(&simplest_groupby_indices) + .get_at_indices(&aggregate.group_expr) + } else { + aggregate.group_expr + }; + + // Only use the absolutely necessary aggregate expressions required + // by the parent: + let mut new_aggr_expr = aggregate_reqs.get_at_indices(&aggregate.aggr_expr); + + // Aggregations always need at least one aggregate expression. + // With a nested count, we don't require any column as input, but + // still need to create a correct aggregate, which may be optimized + // out later. As an example, consider the following query: + // + // SELECT count(*) FROM (SELECT count(*) FROM [...]) + // + // which always returns 1. + if new_aggr_expr.is_empty() + && new_group_bys.is_empty() + && !aggregate.aggr_expr.is_empty() + { + // take the old, first aggregate expression + new_aggr_expr = aggregate.aggr_expr; + new_aggr_expr.resize_with(1, || unreachable!()); + } + + let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter()); + let schema = aggregate.input.schema(); + let necessary_indices = RequiredIndicies::new().with_exprs(schema, all_exprs_iter); + let necessary_exprs = necessary_indices.get_required_exprs(schema); + + return optimize_projections( + Arc::unwrap_or_clone(aggregate.input), + config, + necessary_indices, + )? + .transform_data(|aggregate_input| { + // Simplify the input of the aggregation by adding a projection so + // that its input only contains absolutely necessary columns for + // the aggregate expressions. Note that necessary_indices refer to + // fields in `aggregate.input.schema()`. + add_projection_on_top_if_helpful(aggregate_input, necessary_exprs, config) + })? 
+ .map_data(|aggregate_input| { + // Create a new aggregate plan with the updated input and only the + // absolutely necessary fields: + Aggregate::try_new(Arc::new(aggregate_input), new_group_bys, new_aggr_expr) + .map(LogicalPlan::Aggregate) + }); + } + LogicalPlan::Window(window) => { + let input_schema = Arc::clone(window.input.schema()); + // Split parent requirements to child and window expression sections: + let n_input_fields = input_schema.fields().len(); + // Offset window expression indices so that they point to valid + // indices at `window.window_expr`: + let (child_reqs, window_reqs) = indices.split_off(n_input_fields); + + // Only use window expressions that are absolutely necessary according + // to parent requirements: + let new_window_expr = window_reqs.get_at_indices(&window.window_expr); + + // Get all the required column indices at the input, either by the + // parent or window expression requirements. + let required_indices = child_reqs.with_exprs(&input_schema, &new_window_expr); + + return optimize_projections( + Arc::unwrap_or_clone(window.input), + config, + required_indices.clone(), + )? + .transform_data(|window_child| { + if new_window_expr.is_empty() { + // When no window expression is necessary, use the input directly: + Ok(Transformed::no(window_child)) + } else { + // Calculate required expressions at the input of the window. + // Please note that we use `input_schema`, because `required_indices` + // refers to that schema + let required_exprs = required_indices.get_required_exprs(&input_schema); + + let window_child = + add_projection_on_top_if_helpful(window_child, required_exprs, config)? + .data; + + Window::try_new(new_window_expr, Arc::new(window_child)) + .map(LogicalPlan::Window) + .map(Transformed::yes) + } + }); + } + LogicalPlan::TableScan(table_scan) => { + let TableScan { + table_name, + source, + projection, + filters, + fetch, + projected_schema: _, + } = table_scan; + + // Get indices referred to in the original (schema with all fields) + // given projected indices. + let projection = match &projection { + Some(projection) => indices.into_mapped_indices(|idx| projection[idx]), + None => indices.into_inner(), + }; + return TableScan::try_new(table_name, source, Some(projection), filters, fetch) + .map(LogicalPlan::TableScan) + .map(Transformed::yes); + } + // Other node types are handled below + _ => {} + }; + + // For other plan node types, calculate indices for columns they use and + // try to rewrite their children + let mut child_required_indices: Vec = match &plan { + LogicalPlan::Sort(_) + | LogicalPlan::Filter(_) + | LogicalPlan::Repartition(_) + | LogicalPlan::Union(_) + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Distinct(Distinct::On(_)) => { + // Pass index requirements from the parent as well as column indices + // that appear in this plan's expressions to its child. All these + // operators benefit from "small" inputs, so the projection_beneficial + // flag is `true`. + plan.inputs() + .into_iter() + .map(|input| { + indices + .clone() + .with_projection_beneficial() + .with_plan_exprs(&plan, input.schema()) + }) + .collect::>()? + } + LogicalPlan::Limit(_) | LogicalPlan::Prepare(_) => { + // Pass index requirements from the parent as well as column indices + // that appear in this plan's expressions to its child. These operators + // do not benefit from "small" inputs, so the projection_beneficial + // flag is `false`. 
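The `TableScan` arm above composes the parent's requirements with any projection the scan already carries: a required index addresses the scan's current output, so `indices.into_mapped_indices(|idx| projection[idx])` maps it back to a position in the full table schema. A stand-alone sketch of that index arithmetic, using plain `Vec<usize>` in place of `RequiredIndicies`:

```rust
fn main() {
    // The scan already projects these columns of the full table schema
    // (full schema: c0, c1, c2, c3, c4, c5), so its output is [c2, c4, c5].
    let existing_projection = vec![2usize, 4, 5];

    // The parent only needs output columns 0 and 2 of the scan, i.e. c2 and c5.
    let required = vec![0usize, 2];

    // Map each required output index back to an index in the full schema,
    // mirroring `indices.into_mapped_indices(|idx| projection[idx])`.
    let new_projection: Vec<usize> =
        required.iter().map(|&idx| existing_projection[idx]).collect();

    assert_eq!(new_projection, vec![2, 5]); // the rewritten scan reads only c2 and c5
}
```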
+ plan.inputs() + .into_iter() + .map(|input| indices.clone().with_plan_exprs(&plan, input.schema())) + .collect::>()? + } + LogicalPlan::Copy(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::Dml(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::Subquery(_) + | LogicalPlan::Distinct(Distinct::All(_)) => { + // These plans require all their fields, and their children should + // be treated as final plans -- otherwise, we may have schema a + // mismatch. + // TODO: For some subquery variants (e.g. a subquery arising from an + // EXISTS expression), we may not need to require all indices. + plan.inputs() + .into_iter() + .map(RequiredIndicies::new_for_all_exprs) + .collect() + } + LogicalPlan::Extension(extension) => { + let Some(necessary_children_indices) = + extension.node.necessary_children_exprs(indices.indices()) + else { + // Requirements from parent cannot be routed down to user defined logical plan safely + return Ok(Transformed::no(plan)); + }; + let children = extension.node.inputs(); + if children.len() != necessary_children_indices.len() { + return internal_err!("Inconsistent length between children and necessary children indices. \ + Make sure `.necessary_children_exprs` implementation of the `UserDefinedLogicalNode` is \ + consistent with actual children length for the node."); + } + children + .into_iter() + .zip(necessary_children_indices) + .map(|(child, necessary_indices)| { + RequiredIndicies::new_from_indices(necessary_indices) + .with_plan_exprs(&plan, child.schema()) + }) + .collect::>>()? + } + LogicalPlan::EmptyRelation(_) + | LogicalPlan::RecursiveQuery(_) + | LogicalPlan::Statement(_) + | LogicalPlan::Values(_) + | LogicalPlan::DescribeTable(_) + | LogicalPlan::Execute(_) => { + // These operators have no inputs, so stop the optimization process. + return Ok(Transformed::no(plan)); + } + LogicalPlan::Join(join) => { + let left_len = join.left.schema().fields().len(); + let (left_req_indices, right_req_indices) = + split_join_requirements(left_len, indices, &join.join_type); + let left_indices = left_req_indices.with_plan_exprs(&plan, join.left.schema())?; + let right_indices = right_req_indices.with_plan_exprs(&plan, join.right.schema())?; + // Joins benefit from "small" input tables (lower memory usage). + // Therefore, each child benefits from projection: + vec![ + left_indices.with_projection_beneficial(), + right_indices.with_projection_beneficial(), + ] + } + // these nodes are explicitly rewritten in the match statement above + LogicalPlan::Projection(_) + | LogicalPlan::Aggregate(_) + | LogicalPlan::Window(_) + | LogicalPlan::TableScan(_) => { + return internal_err!( + "OptimizeProjection: should have handled in the match statement above" + ); + } + LogicalPlan::Unnest(Unnest { + dependency_indices, .. + }) => { + vec![RequiredIndicies::new_from_indices( + dependency_indices.clone(), + )] + } + }; + + // Required indices are currently ordered (child0, child1, ...) 
+ // but the loop pops off the last element, so we need to reverse the order + child_required_indices.reverse(); + if child_required_indices.len() != plan.inputs().len() { + return internal_err!( + "OptimizeProjection: child_required_indices length mismatch with plan inputs" + ); + } + + // Rewrite children of the plan + let transformed_plan = plan.map_children(|child| { + let required_indices = child_required_indices.pop().ok_or_else(|| { + internal_datafusion_err!( + "Unexpected number of required_indices in OptimizeProjections rule" + ) + })?; + + let projection_beneficial = required_indices.projection_beneficial(); + let project_exprs = required_indices.get_required_exprs(child.schema()); + + optimize_projections(child, config, required_indices)?.transform_data(|new_input| { + if projection_beneficial { + add_projection_on_top_if_helpful(new_input, project_exprs, config) + } else { + Ok(Transformed::no(new_input)) + } + }) + })?; + + // If any of the children are transformed, we need to potentially update the plan's schema + if transformed_plan.transformed { + transformed_plan.map_data(|plan| plan.recompute_schema()) + } else { + Ok(transformed_plan) + } +} + +/// Merges consecutive projections. +/// +/// Given a projection `proj`, this function attempts to merge it with a previous +/// projection if it exists and if merging is beneficial. Merging is considered +/// beneficial when expressions in the current projection are non-trivial and +/// appear more than once in its input fields. This can act as a caching mechanism +/// for non-trivial computations. +/// +/// # Parameters +/// +/// * `proj` - A reference to the `Projection` to be merged. +/// +/// # Returns +/// +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(Projection))`: Merge was beneficial and successful. Contains the +/// merged projection. +/// - `Ok(None)`: Signals that merge is not beneficial (and has not taken place). +/// - `Err(error)`: An error occured during the function call. +fn merge_consecutive_projections(proj: Projection) -> Result> { + let Projection { + expr, + input, + schema, + .. + } = proj; + let LogicalPlan::Projection(prev_projection) = input.as_ref() else { + return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no); + }; + + // Count usages (referrals) of each projection expression in its input fields: + let mut column_referral_map = HashMap::<&Column, usize>::new(); + expr.iter() + .for_each(|expr| expr.add_column_ref_counts(&mut column_referral_map)); + + // If an expression is non-trivial and appears more than once, do not merge + // them as consecutive projections will benefit from a compute-once approach. 
+ // For details, see: https://github.com/apache/datafusion/issues/8296 + if column_referral_map.into_iter().any(|(col, usage)| { + usage > 1 + && !is_expr_trivial( + &prev_projection.expr[prev_projection.schema.index_of_column(col).unwrap()], + ) + }) { + // no change + return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no); + } + + let LogicalPlan::Projection(prev_projection) = Arc::unwrap_or_clone(input) else { + // We know it is a `LogicalPlan::Projection` from check above + unreachable!(); + }; + + // Try to rewrite the expressions in the current projection using the + // previous projection as input: + let name_preserver = NamePreserver::new_for_projection(); + let mut original_names = vec![]; + let new_exprs = expr.into_iter().map_until_stop_and_collect(|expr| { + original_names.push(name_preserver.save(&expr)); + + // do not rewrite top level Aliases (rewriter will remove all aliases within exprs) + match expr { + Expr::Alias(Alias { + expr, + relation, + name, + }) => rewrite_expr(*expr, &prev_projection).map(|result| { + result.update_data(|expr| Expr::Alias(Alias::new(expr, relation, name))) + }), + e => rewrite_expr(e, &prev_projection), + } + })?; + + // if the expressions could be rewritten, create a new projection with the + // new expressions + if new_exprs.transformed { + // Add any needed aliases back to the expressions + let new_exprs = new_exprs + .data + .into_iter() + .zip(original_names) + .map(|(expr, original_name)| original_name.restore(expr)) + .collect::>(); + Projection::try_new(new_exprs, prev_projection.input).map(Transformed::yes) + } else { + // not rewritten, so put the projection back together + let input = Arc::new(LogicalPlan::Projection(prev_projection)); + Projection::try_new_with_schema(new_exprs.data, input, schema).map(Transformed::no) + } +} + +// Check whether `expr` is trivial; i.e. it doesn't imply any computation. +fn is_expr_trivial(expr: &Expr) -> bool { + matches!(expr, Expr::Column(_) | Expr::Literal(_)) +} + +/// Rewrites a projection expression using the projection before it (i.e. its input) +/// This is a subroutine to the `merge_consecutive_projections` function. +/// +/// # Parameters +/// +/// * `expr` - A reference to the expression to rewrite. +/// * `input` - A reference to the input of the projection expression (itself +/// a projection). +/// +/// # Returns +/// +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(Expr))`: Rewrite was successful. Contains the rewritten result. +/// - `Ok(None)`: Signals that `expr` can not be rewritten. +/// - `Err(error)`: An error occurred during the function call. +/// +/// # Notes +/// This rewrite also removes any unnecessary layers of aliasing. +/// +/// Without trimming, we can end up with unnecessary indirections inside expressions +/// during projection merges. 
+/// +/// Consider: +/// +/// ```text +/// Projection(a1 + b1 as sum1) +/// --Projection(a as a1, b as b1) +/// ----Source(a, b) +/// ``` +/// +/// After merge, we want to produce: +/// +/// ```text +/// Projection(a + b as sum1) +/// --Source(a, b) +/// ``` +/// +/// Without trimming, we would end up with: +/// +/// ```text +/// Projection((a as a1 + b as b1) as sum1) +/// --Source(a, b) +/// ``` +fn rewrite_expr(expr: Expr, input: &Projection) -> Result> { + expr.transform_up(|expr| { + match expr { + // remove any intermediate aliases + Expr::Alias(alias) => Ok(Transformed::yes(*alias.expr)), + Expr::Column(col) => { + // Find index of column: + let idx = input.schema.index_of_column(&col)?; + // get the corresponding unaliased input expression + // + // For example: + // * the input projection is [`a + b` as c, `d + e` as f] + // * the current column is an expression "f" + // + // return the expression `d + e` (not `d + e` as f) + let input_expr = input.expr[idx].clone().unalias_nested().data; + Ok(Transformed::yes(input_expr)) + } + // Unsupported type for consecutive projection merge analysis. + _ => Ok(Transformed::no(expr)), + } + }) +} + +/// Accumulates outer-referenced columns by the +/// given expression, `expr`. +/// +/// # Parameters +/// +/// * `expr` - The expression to analyze for outer-referenced columns. +/// * `columns` - A mutable reference to a `HashSet` where detected +/// columns are collected. +fn outer_columns<'a>(expr: &'a Expr, columns: &mut HashSet<&'a Column>) { + // inspect_expr_pre doesn't handle subquery references, so find them explicitly + expr.apply(|expr| { + match expr { + Expr::OuterReferenceColumn(_, col) => { + columns.insert(col); + } + Expr::ScalarSubquery(subquery) => { + outer_columns_helper_multi(&subquery.outer_ref_columns, columns); + } + Expr::Exists(exists) => { + outer_columns_helper_multi(&exists.subquery.outer_ref_columns, columns); + } + Expr::InSubquery(insubquery) => { + outer_columns_helper_multi(&insubquery.subquery.outer_ref_columns, columns); + } + _ => {} + }; + Ok(TreeNodeRecursion::Continue) + }) + // unwrap: closure above never returns Err, so can not be Err here + .unwrap(); +} + +/// A recursive subroutine that accumulates outer-referenced columns by the +/// given expressions (`exprs`). +/// +/// # Parameters +/// +/// * `exprs` - The expressions to analyze for outer-referenced columns. +/// * `columns` - A mutable reference to a `HashSet` where detected +/// columns are collected. +fn outer_columns_helper_multi<'a, 'b>( + exprs: impl IntoIterator, + columns: &'b mut HashSet<&'a Column>, +) { + exprs.into_iter().for_each(|e| outer_columns(e, columns)); +} + +/// Splits requirement indices for a join into left and right children based on +/// the join type. +/// +/// This function takes the length of the left child, a slice of requirement +/// indices, and the type of join (e.g. `INNER`, `LEFT`, `RIGHT`) as arguments. +/// Depending on the join type, it divides the requirement indices into those +/// that apply to the left child and those that apply to the right child. +/// +/// - For `INNER`, `LEFT`, `RIGHT` and `FULL` joins, the requirements are split +/// between left and right children. The right child indices are adjusted to +/// point to valid positions within the right child by subtracting the length +/// of the left child. 
+/// +/// - For `LEFT ANTI`, `LEFT SEMI`, `RIGHT SEMI` and `RIGHT ANTI` joins, all +/// requirements are re-routed to either the left child or the right child +/// directly, depending on the join type. +/// +/// # Parameters +/// +/// * `left_len` - The length of the left child. +/// * `indices` - A slice of requirement indices. +/// * `join_type` - The type of join (e.g. `INNER`, `LEFT`, `RIGHT`). +/// +/// # Returns +/// +/// A tuple containing two vectors of `usize` indices: The first vector represents +/// the requirements for the left child, and the second vector represents the +/// requirements for the right child. The indices are appropriately split and +/// adjusted based on the join type. +fn split_join_requirements( + left_len: usize, + indices: RequiredIndicies, + join_type: &JoinType, +) -> (RequiredIndicies, RequiredIndicies) { + match join_type { + // In these cases requirements are split between left/right children: + JoinType::Inner + | JoinType::Left + | JoinType::Right + | JoinType::Full + | JoinType::LeftMark => { + // Decrease right side indices by `left_len` so that they point to valid + // positions within the right child: + indices.split_off(left_len) + } + // All requirements can be re-routed to left child directly. + JoinType::LeftAnti | JoinType::LeftSemi => (indices, RequiredIndicies::new()), + // All requirements can be re-routed to right side directly. + // No need to change index, join schema is right child schema. + JoinType::RightSemi | JoinType::RightAnti => (RequiredIndicies::new(), indices), + } +} + +fn add_projection_on_top_if_helpful( + plan: LogicalPlan, + _project_exprs: Vec, + _config: &dyn OptimizerConfig, +) -> Result> { + // Always prefer the original query layout. + Ok(Transformed::no(plan)) +} + +/// Rewrite the given projection according to the fields required by its +/// ancestors. +/// +/// # Parameters +/// +/// * `proj` - A reference to the original projection to rewrite. +/// * `config` - A reference to the optimizer configuration. +/// * `indices` - A slice of indices representing the columns required by the +/// ancestors of the given projection. +/// +/// # Returns +/// +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(LogicalPlan))`: Contains the rewritten projection +/// - `Ok(None)`: No rewrite necessary. +/// - `Err(error)`: An error occured during the function call. +fn rewrite_projection_given_requirements( + proj: Projection, + config: &dyn OptimizerConfig, + indices: &RequiredIndicies, +) -> Result> { + let Projection { expr, input, .. } = proj; + + let exprs_used = indices.get_at_indices(&expr); + + let required_indices = RequiredIndicies::new().with_exprs(input.schema(), exprs_used.iter()); + + // rewrite the children projection, and if they are changed rewrite the + // projection down + optimize_projections(Arc::unwrap_or_clone(input), config, required_indices)?.transform_data( + |input| { + if is_projection_unnecessary(&input, &exprs_used, config)? { + Ok(Transformed::yes(input)) + } else { + Projection::try_new(exprs_used, Arc::new(input)) + .map(LogicalPlan::Projection) + .map(Transformed::yes) + } + }, + ) +} + +fn is_projection_unnecessary( + _input: &LogicalPlan, + _proj_exprs: &[Expr], + _config: &dyn OptimizerConfig, +) -> Result { + // Always prefer the original query layout. Return false to keep the projection. 
+ Ok(false) +} + +#[cfg(test)] +mod tests { + use std::{collections::HashMap, sync::Arc}; + + use datafusion::{ + arrow::datatypes::{DataType, Field, Schema}, + common::{DFSchema, JoinType}, + error::DataFusionError, + functions_aggregate::{ + count::{count, count_udaf}, + min_max::max, + }, + logical_expr::{ + LogicalPlan, LogicalPlanBuilder, LogicalTableSource, TableSource, UNNAMED_TABLE, + }, + optimizer::{Optimizer, OptimizerContext, OptimizerRule}, + prelude::{col, lit, Expr, ExprFunctionExt}, + sql::TableReference, + }; + + type Result = std::result::Result; + + #[test] + fn aggregate_filter_pushdown_preserve_projection() -> Result<()> { + let table_scan = test_table_scan()?; + let aggr_with_filter = count_udaf() + .call(vec![col("b")]) + .filter(col("c").gt(lit(42))) + .build()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![count(col("b")), aggr_with_filter.alias("count2")], + )? + .project(vec![col("a"), col("count(test.b)"), col("count2")])? + .build()?; + + let expected = "Projection: test.a, count(test.b), count2\ + \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.b), count(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]\ + \n TableScan: test projection=[a, b, c]"; + + assert_optimized_plan_equal(plan, expected)?; + + Ok(()) + } + + // Selected tests and helpers from the original DataFusion implementation of `OptimizeProjections` rule to ensure + // that the modified version behaves as expected. + // See: + + #[test] + fn aggregate_no_group_by() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(Vec::::new(), vec![max(col("b"))])? + .build()?; + + let expected = "Aggregate: groupBy=[[]], aggr=[[max(test.b)]]\ + \n TableScan: test projection=[b]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn aggregate_group_by() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("c")], vec![max(col("b"))])? + .build()?; + + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[max(test.b)]]\ + \n TableScan: test projection=[b, c]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn aggregate_group_by_with_table_alias() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .alias("a")? + .aggregate(vec![col("c")], vec![max(col("b"))])? + .build()?; + + let expected = "Aggregate: groupBy=[[a.c]], aggr=[[max(a.b)]]\ + \n SubqueryAlias: a\ + \n TableScan: test projection=[b, c]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn aggregate_no_group_by_with_filter() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(col("c").gt(lit(1)))? + .aggregate(Vec::::new(), vec![max(col("b"))])? + .build()?; + + let expected = "Aggregate: groupBy=[[]], aggr=[[max(test.b)]]\ + \n Filter: test.c > Int32(1)\ + \n TableScan: test projection=[b, c]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn join_schema_trim_using_join() -> Result<()> { + // shared join columns from using join should be pushed to both sides + + let table_scan = test_table_scan()?; + + let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]); + let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .join_using(table2_scan, JoinType::Left, vec!["a"])? 
+ .project(vec![col("a"), col("b")])? + .build()?; + + // make sure projections are pushed down to table scan + let expected = "Projection: test.a, test.b\ + \n Left Join: Using test.a = test2.a\ + \n TableScan: test projection=[a, b]\ + \n TableScan: test2 projection=[a]"; + + let optimized_plan = optimize(plan)?; + let formatted_plan = format!("{optimized_plan}"); + assert_eq!(formatted_plan, expected); + + // make sure schema for join node include both join columns + let optimized_join = optimized_plan.inputs()[0]; + assert_eq!( + **optimized_join.schema(), + DFSchema::new_with_metadata( + vec![ + ( + Some("test".into()), + Arc::new(Field::new("a", DataType::UInt32, false)) + ), + ( + Some("test".into()), + Arc::new(Field::new("b", DataType::UInt32, false)) + ), + ( + Some("test2".into()), + Arc::new(Field::new("a", DataType::UInt32, true)) + ), + ], + HashMap::new() + )?, + ); + + Ok(()) + } + + #[test] + fn redundant_project() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b"), col("c")])? + .project(vec![col("a"), col("c"), col("b")])? + .build()?; + let expected = "Projection: test.a, test.c, test.b\ + \n TableScan: test projection=[a, b, c]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn reorder_projection() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("c"), col("b"), col("a")])? + .build()?; + let expected = "Projection: test.c, test.b, test.a\ + \n TableScan: test projection=[a, b, c]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn table_scan_without_projection() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan).build()?; + // should expand projection to all columns without projection + let expected = "TableScan: test projection=[a, b, c]"; + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn table_unused_column() -> Result<()> { + let table_scan = test_table_scan()?; + assert_eq!(3, table_scan.schema().fields().len()); + assert_fields_eq(&table_scan, vec!["a", "b", "c"]); + + // we never use "b" in the first projection => remove it + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("c"), col("a"), col("b")])? + .filter(col("c").gt(lit(1)))? + .aggregate(vec![col("c")], vec![max(col("a"))])? 
+ .build()?; + + assert_fields_eq(&plan, vec!["c", "max(test.a)"]); + + let plan = optimize(plan).expect("failed to optimize plan"); + let expected = "\ + Aggregate: groupBy=[[test.c]], aggr=[[max(test.a)]]\ + \n Filter: test.c > Int32(1)\ + \n Projection: test.c, test.a\ + \n TableScan: test projection=[a, c]"; + + assert_optimized_plan_equal(plan, expected) + } + + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { + assert_optimized_plan_eq(Arc::new(super::OptimizeProjections::new()), plan, expected) + } + + pub fn assert_optimized_plan_eq( + rule: Arc, + plan: LogicalPlan, + expected: &str, + ) -> Result<()> { + // Apply the rule once + let opt_context = OptimizerContext::new().with_max_passes(1); + + let optimizer = Optimizer::with_rules(vec![Arc::clone(&rule)]); + let optimized_plan = optimizer.optimize(plan, &opt_context, observe)?; + let formatted_plan = format!("{optimized_plan}"); + assert_eq!(formatted_plan, expected); + + Ok(()) + } + + fn optimize(plan: LogicalPlan) -> Result { + let optimizer = Optimizer::with_rules(vec![Arc::new(super::OptimizeProjections::new())]); + let optimized_plan = optimizer.optimize(plan, &OptimizerContext::new(), observe)?; + Ok(optimized_plan) + } + + pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { + let actual: Vec = plan + .schema() + .fields() + .iter() + .map(|f| f.name().clone()) + .collect(); + assert_eq!(actual, expected); + } + + pub fn test_table_scan_fields() -> Vec { + vec![ + Field::new("a", DataType::UInt32, false), + Field::new("b", DataType::UInt32, false), + Field::new("c", DataType::UInt32, false), + ] + } + + /// some tests share a common table with different names + pub fn test_table_scan_with_name(name: &str) -> Result { + let schema = Schema::new(test_table_scan_fields()); + table_scan(Some(name), &schema, None)?.build() + } + + /// some tests share a common table + pub fn test_table_scan() -> Result { + test_table_scan_with_name("test") + } + + /// Scan an empty data source, mainly used in tests + pub fn scan_empty( + name: Option<&str>, + table_schema: &Schema, + projection: Option>, + ) -> Result { + table_scan(name, table_schema, projection) + } + /// Create a LogicalPlanBuilder representing a scan of a table with the provided name and schema. + /// This is mostly used for testing and documentation. + pub fn table_scan( + name: Option>, + table_schema: &Schema, + projection: Option>, + ) -> Result { + table_scan_with_filters(name, table_schema, projection, vec![]) + } + + /// Create a LogicalPlanBuilder representing a scan of a table with the provided name and schema, + /// and inlined filters. + /// This is mostly used for testing and documentation. 
+ pub fn table_scan_with_filters( + name: Option>, + table_schema: &Schema, + projection: Option>, + filters: Vec, + ) -> Result { + let table_source = table_source(table_schema); + let name = name + .map(|n| n.into()) + .unwrap_or_else(|| TableReference::bare(UNNAMED_TABLE)); + LogicalPlanBuilder::scan_with_filters(name, table_source, projection, filters) + } + + fn table_source(table_schema: &Schema) -> Arc { + let table_schema = Arc::new(table_schema.clone()); + Arc::new(LogicalTableSource::new(table_schema)) + } +} diff --git a/datafusion-federation/src/optimize/optimize_projections/required_indices.rs b/datafusion-federation/src/optimize/optimize_projections/required_indices.rs new file mode 100644 index 0000000..aa4889f --- /dev/null +++ b/datafusion-federation/src/optimize/optimize_projections/required_indices.rs @@ -0,0 +1,229 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Copy of DataFusion's [`RequiredIndicies`](https://github.com/apache/datafusion/blob/main/datafusion/optimizer/src/optimize_projections/required_indices.rs) implementation. + +use std::result; + +use datafusion::{common::{tree_node::TreeNodeRecursion, Column, DFSchemaRef}, error::DataFusionError, logical_expr::LogicalPlan, prelude::Expr}; + +use super::outer_columns; + +type Result = result::Result; + +/// Represents columns in a schema which are required (used) by a plan node +/// +/// Also carries a flag indicating if putting a projection above children is +/// beneficial for the parent. For example `LogicalPlan::Filter` benefits from +/// small tables. Hence for filter child this flag would be `true`. Defaults to +/// `false` +/// +/// # Invariant +/// +/// Indices are always in order and without duplicates. For example, if these +/// indices were added `[3, 2, 4, 3, 6, 1]`, the instance would be represented +/// by `[1, 2, 3, 6]`. +#[derive(Debug, Clone, Default)] +pub(super) struct RequiredIndicies { + /// The indices of the required columns in the + indices: Vec, + /// If putting a projection above children is beneficial for the parent. + /// Defaults to false. 
+ projection_beneficial: bool, +} + +impl RequiredIndicies { + /// Create a new, empty instance + pub fn new() -> Self { + Self::default() + } + + /// Create a new instance that requires all columns from the specified plan + pub fn new_for_all_exprs(plan: &LogicalPlan) -> Self { + Self { + indices: (0..plan.schema().fields().len()).collect(), + projection_beneficial: false, + } + } + + /// Create a new instance with the specified indices as required + pub fn new_from_indices(indices: Vec) -> Self { + Self { + indices, + projection_beneficial: false, + } + .compact() + } + + /// Convert the instance to its inner indices + pub fn into_inner(self) -> Vec { + self.indices + } + + /// Set the projection beneficial flag + pub fn with_projection_beneficial(mut self) -> Self { + self.projection_beneficial = true; + self + } + + /// Return the value of projection beneficial flag + pub fn projection_beneficial(&self) -> bool { + self.projection_beneficial + } + + /// Return a reference to the underlying indices + pub fn indices(&self) -> &[usize] { + &self.indices + } + + /// Add required indices for all `exprs` used in plan + pub fn with_plan_exprs( + mut self, + plan: &LogicalPlan, + schema: &DFSchemaRef, + ) -> Result { + // Add indices of the child fields referred to by the expressions in the + // parent + plan.apply_expressions(|e| { + self.add_expr(schema, e); + Ok(TreeNodeRecursion::Continue) + })?; + Ok(self.compact()) + } + + /// Adds the indices of the fields referred to by the given expression + /// `expr` within the given schema (`input_schema`). + /// + /// Self is NOT compacted (and thus this method is not pub) + /// + /// # Parameters + /// + /// * `input_schema`: The input schema to analyze for index requirements. + /// * `expr`: An expression for which we want to find necessary field indices. + fn add_expr(&mut self, input_schema: &DFSchemaRef, expr: &Expr) { + // TODO could remove these clones (and visit the expression directly) + let mut cols = expr.column_refs(); + // Get outer-referenced (subquery) columns: + outer_columns(expr, &mut cols); + self.indices.reserve(cols.len()); + for col in cols { + if let Some(idx) = input_schema.maybe_index_of_column(col) { + self.indices.push(idx); + } + } + } + + /// Adds the indices of the fields referred to by the given expressions + /// `within the given schema. + /// + /// # Parameters + /// + /// * `input_schema`: The input schema to analyze for index requirements. + /// * `exprs`: the expressions for which we want to find field indices. + pub fn with_exprs<'a>( + self, + schema: &DFSchemaRef, + exprs: impl IntoIterator, + ) -> Self { + exprs + .into_iter() + .fold(self, |mut acc, expr| { + acc.add_expr(schema, expr); + acc + }) + .compact() + } + + /// Adds all `indices` into this instance. + pub fn append(mut self, indices: &[usize]) -> Self { + self.indices.extend_from_slice(indices); + self.compact() + } + + /// Splits this instance into a tuple with two instances: + /// * The first `n` indices + /// * The remaining indices, adjusted down by n + pub fn split_off(self, n: usize) -> (Self, Self) { + let (l, r) = self.partition(|idx| idx < n); + (l, r.map_indices(|idx| idx - n)) + } + + /// Partitions the indices in this instance into two groups based on the + /// given predicate function `f`. 
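A quick restatement of the index arithmetic used by `split_off` above and by the sorted, deduplicated invariant, with plain vectors instead of the actual type; this is what the join and aggregate handling in the optimizer rule relies on when a node's schema is the concatenation of two children.

```rust
/// Sort and deduplicate, mirroring the compaction invariant.
fn compact(mut indices: Vec<usize>) -> Vec<usize> {
    indices.sort_unstable();
    indices.dedup();
    indices
}

/// Split into indices below `n` and indices at or above `n`, shifting the
/// latter down by `n` so they address the right-hand child's own schema.
fn split_off(indices: &[usize], n: usize) -> (Vec<usize>, Vec<usize>) {
    let left: Vec<usize> = indices.iter().copied().filter(|&i| i < n).collect();
    let right: Vec<usize> = indices.iter().copied().filter(|&i| i >= n).map(|i| i - n).collect();
    (left, right)
}

fn main() {
    // Out-of-order, duplicated insertions collapse into a canonical form.
    assert_eq!(compact(vec![5, 2, 5, 0]), vec![0, 2, 5]);

    // A join producing [l0, l1, l2, r0, r1]: requiring {1, 3, 4} means the left
    // child must keep column 1 and the right child columns 0 and 1.
    let (left, right) = split_off(&[1, 3, 4], 3);
    assert_eq!(left, vec![1]);
    assert_eq!(right, vec![0, 1]);
}
```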
+ fn partition(&self, f: F) -> (Self, Self) + where + F: Fn(usize) -> bool, + { + let (l, r): (Vec, Vec) = + self.indices.iter().partition(|&&idx| f(idx)); + let projection_beneficial = self.projection_beneficial; + + ( + Self { + indices: l, + projection_beneficial, + }, + Self { + indices: r, + projection_beneficial, + }, + ) + } + + /// Map the indices in this instance to a new set of indices based on the + /// given function `f`, returning the mapped indices + /// + /// Not `pub` as it might not preserve the invariant of compacted indices + fn map_indices(mut self, f: F) -> Self + where + F: Fn(usize) -> usize, + { + self.indices.iter_mut().for_each(|idx| *idx = f(*idx)); + self + } + + /// Apply the given function `f` to each index in this instance, returning + /// the mapped indices + pub fn into_mapped_indices(self, f: F) -> Vec + where + F: Fn(usize) -> usize, + { + self.map_indices(f).into_inner() + } + + /// Returns the `Expr`s from `exprs` that are at the indices in this instance + pub fn get_at_indices(&self, exprs: &[Expr]) -> Vec { + self.indices.iter().map(|&idx| exprs[idx].clone()).collect() + } + + /// Generates the required expressions (columns) that reside at `indices` of + /// the given `input_schema`. + pub fn get_required_exprs(&self, input_schema: &DFSchemaRef) -> Vec { + self.indices + .iter() + .map(|&idx| Expr::from(Column::from(input_schema.qualified_field(idx)))) + .collect() + } + + /// Compacts the indices of this instance so they are sorted + /// (ascending) and deduplicated. + fn compact(mut self) -> Self { + self.indices.sort_unstable(); + self.indices.dedup(); + self + } +} \ No newline at end of file From 04487baa6e29eaff6e89883ac3c2b4d146081387 Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Mon, 23 Dec 2024 11:12:32 -0800 Subject: [PATCH 43/48] Federation analyzer should federate the subquery when possible (#26) * optimize_recursively should check the provider for subqueries * Preserve the vec,use apply_subqueries * Federate subtree in subquery * clean up * Rewrite federation as optimizer rule * Revert "Rewrite federation as optimizer rule" This reverts commit e0ff999a9b7ccf4c64e52143c69c6287e1960b41. 
* Keep federaion rule as analyzer rule, handle InSubquery federation * Wrap a Projection plan when the entire Insubquery can be federated * not run scalar_subquery_to_join and update scalar subquery handling * Resolve outref federation provider * clean up * Include table reference created by subquery aliases * Remove redundant clone * Do not panic on ScanResult unwrap --- datafusion-federation/src/analyzer.rs | 210 ------- datafusion-federation/src/analyzer/mod.rs | 514 ++++++++++++++++++ .../src/analyzer/scan_result.rs | 113 ++++ datafusion-federation/src/optimize.rs | 3 +- datafusion-federation/src/schema_cast.rs | 7 +- 5 files changed, 632 insertions(+), 215 deletions(-) delete mode 100644 datafusion-federation/src/analyzer.rs create mode 100644 datafusion-federation/src/analyzer/mod.rs create mode 100644 datafusion-federation/src/analyzer/scan_result.rs diff --git a/datafusion-federation/src/analyzer.rs b/datafusion-federation/src/analyzer.rs deleted file mode 100644 index 644b034..0000000 --- a/datafusion-federation/src/analyzer.rs +++ /dev/null @@ -1,210 +0,0 @@ -use std::sync::Arc; - -use datafusion::{ - common::{tree_node::TreeNode, Column}, - config::ConfigOptions, - datasource::source_as_provider, - error::Result, - logical_expr::{Expr, LogicalPlan, Projection, TableScan, TableSource}, - optimizer::analyzer::AnalyzerRule, - sql::TableReference, -}; - -use crate::{ - optimize::Optimizer, FederatedTableProviderAdaptor, FederatedTableSource, FederationProviderRef, -}; - -#[derive(Default, Debug)] -pub struct FederationAnalyzerRule { - optimizer: Optimizer, -} - -impl AnalyzerRule for FederationAnalyzerRule { - // Walk over the plan, look for the largest subtrees that only have - // TableScans from the same FederationProvider. - // There 'largest sub-trees' are passed to their respective FederationProvider.optimizer. - fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result { - if !contains_federated_table(&plan)? { - return Ok(plan); - } - - let plan = self.optimizer.optimize_plan(plan)?; - - let (optimized, _) = self.optimize_recursively(&plan, None, config)?; - if let Some(result) = optimized { - return Ok(result); - } - Ok(plan.clone()) - } - - /// A human readable name for this optimizer rule - fn name(&self) -> &str { - "federation_optimizer_rule" - } -} - -fn contains_federated_table(plan: &LogicalPlan) -> Result { - let federated_table_exists = plan.exists(|x| { - if let Some(provider) = get_federation_provider(x)? { - // federated table provider should have an analyzer - return Ok(provider.analyzer().is_some()); - } - Ok(false) - })?; - - Ok(federated_table_exists) -} - -impl FederationAnalyzerRule { - pub fn new() -> Self { - Self::default() - } - - // optimize_recursively recursively finds the largest sub-plans that can be federated - // to a single FederationProvider. - // Returns a plan if a sub-tree was federated, otherwise None. - // Returns a FederationProvider if it covers the entire sub-tree, otherwise None. 
- fn optimize_recursively( - &self, - plan: &LogicalPlan, - parent: Option<&LogicalPlan>, - _config: &ConfigOptions, - ) -> Result<(Option, Option)> { - // Check if this node determines the FederationProvider - let sole_provider = get_federation_provider(plan)?; - if sole_provider.is_some() { - return Ok((None, sole_provider)); - } - - // optimize_inputs - let inputs = plan.inputs(); - if inputs.is_empty() { - return Ok((None, None)); - } - - let (new_inputs, providers): (Vec<_>, Vec<_>) = inputs - .iter() - .map(|i| self.optimize_recursively(i, Some(plan), _config)) - .collect::>>()? - .into_iter() - .unzip(); - - // Note: assumes provider is None if ambiguous - let first_provider = providers.first().unwrap(); - let is_singular = providers.iter().all(|p| p.is_some() && p == first_provider); - - if is_singular { - if parent.is_none() { - // federate the entire plan - if let Some(provider) = first_provider { - if let Some(optimizer) = provider.analyzer() { - let optimized = - optimizer.execute_and_check(plan.clone(), _config, |_, _| {})?; - return Ok((Some(optimized), None)); - } - return Ok((None, None)); - } - return Ok((None, None)); - } - // The largest sub-plan is higher up. - return Ok((None, first_provider.clone())); - } - - // The plan is ambiguous, any inputs that are not federated and - // have a sole provider, should be federated. - let new_inputs = new_inputs - .into_iter() - .enumerate() - .map(|(i, new_sub_plan)| { - if let Some(sub_plan) = new_sub_plan { - // Already federated - return Ok(sub_plan); - } - let sub_plan = inputs.get(i).unwrap(); - // Check if the input has a sole provider and can be federated. - if let Some(provider) = providers.get(i).unwrap() { - if let Some(optimizer) = provider.analyzer() { - let wrapped = wrap_projection((*sub_plan).clone())?; - - let optimized = optimizer.execute_and_check(wrapped, _config, |_, _| {})?; - return Ok(optimized); - } - // No federation for this sub-plan (no analyzer) - return Ok((*sub_plan).clone()); - } - // No federation for this sub-plan (no provider) - Ok((*sub_plan).clone()) - }) - .collect::>>()?; - - let new_plan = match plan { - // Unnest returns columns to unnest as `expressions` but does not support passing them back to `with_new_exprs`. - // Instead, it uses data from its internal representation to create a new plan. - LogicalPlan::Unnest(_) => plan.with_new_exprs(vec![], new_inputs)?, - _ => plan.with_new_exprs(plan.expressions(), new_inputs)?, - }; - - Ok((Some(new_plan), None)) - } -} - -fn get_federation_provider(plan: &LogicalPlan) -> Result> { - match plan { - LogicalPlan::TableScan(TableScan { ref source, .. }) => { - let Some(federated_source) = get_table_source(source)? 
else { - return Ok(None); - }; - let provider = federated_source.federation_provider(); - Ok(Some(provider)) - } - _ => Ok(None), - } -} - -fn wrap_projection(plan: LogicalPlan) -> Result { - // TODO: minimize requested columns - match plan { - LogicalPlan::Projection(_) => Ok(plan), - _ => { - let expr = plan - .schema() - .fields() - .iter() - .enumerate() - .map(|(i, f)| { - Expr::Column(Column::from_qualified_name(format!( - "{}.{}", - plan.schema() - .qualified_field(i) - .0 - .map(TableReference::table) - .unwrap_or_default(), - f.name() - ))) - }) - .collect::>(); - Ok(LogicalPlan::Projection(Projection::try_new( - expr, - Arc::new(plan), - )?)) - } - } -} - -pub fn get_table_source( - source: &Arc, -) -> Result>> { - // Unwrap TableSource - let source = source_as_provider(source)?; - - // Get FederatedTableProviderAdaptor - let Some(wrapper) = source - .as_any() - .downcast_ref::() - else { - return Ok(None); - }; - - // Return original FederatedTableSource - Ok(Some(Arc::clone(&wrapper.source))) -} diff --git a/datafusion-federation/src/analyzer/mod.rs b/datafusion-federation/src/analyzer/mod.rs new file mode 100644 index 0000000..74d954a --- /dev/null +++ b/datafusion-federation/src/analyzer/mod.rs @@ -0,0 +1,514 @@ +mod scan_result; + +use crate::FederationProvider; +use crate::{ + optimize::Optimizer, FederatedTableProviderAdaptor, FederatedTableSource, FederationProviderRef, +}; +use datafusion::error::DataFusionError; +use datafusion::logical_expr::{col, expr::InSubquery, LogicalPlanBuilder}; +use datafusion::{ + common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}, + config::ConfigOptions, + datasource::source_as_provider, + error::Result, + logical_expr::{Expr, Extension, LogicalPlan, Projection, TableScan, TableSource}, + optimizer::analyzer::AnalyzerRule, + sql::TableReference, +}; +use scan_result::ScanResult; +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::RwLock; + +#[derive(Debug)] +pub struct FederationAnalyzerRule { + optimizer: Optimizer, + provider_map: Arc>>, +} + +impl Default for FederationAnalyzerRule { + fn default() -> Self { + Self { + optimizer: Optimizer::default(), + provider_map: Arc::new(RwLock::new(HashMap::new())), + } + } +} + +impl AnalyzerRule for FederationAnalyzerRule { + // Walk over the plan, look for the largest subtrees that only have + // TableScans from the same FederationProvider. + // There 'largest sub-trees' are passed to their respective FederationProvider.optimizer. + fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result { + if !contains_federated_table(&plan)? { + return Ok(plan); + } + // Run selected optimizer rules before federation + let plan = self.optimizer.optimize_plan(plan)?; + + // Find all federation providers for TableReference appeared in the plan + let providers = get_plan_provider_recursively(&plan)?; + let mut write_map = self.provider_map.write().map_err(|_| { + DataFusionError::External( + "Failed to create federated plan: failed to find all federated providers.".into(), + ) + })?; + write_map.extend(providers); + drop(write_map); + + match self.optimize_plan_recursively(&plan, true, config)? { + (Some(optimized_plan), _) => Ok(optimized_plan), + (None, _) => Ok(plan), + } + } + + /// A human readable name for this optimizer rule + fn name(&self) -> &str { + "federation_optimizer_rule" + } +} + +impl FederationAnalyzerRule { + pub fn new() -> Self { + Self::default() + } + + /// Scans a plan to see if it belongs to a single [`FederationProvider`]. 
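The scanning and optimization functions that follow implement the strategy stated above: locate the largest sub-plans whose scans all resolve to one provider and hand exactly those sub-plans to that provider. A toy model of that control flow, using a hypothetical `Node` type and a `&str` provider id, with none of the expression, subquery, or root-plan handling of the real rule:

```rust
#[derive(Debug, PartialEq)]
enum Node {
    Scan(&'static str),      // leaf: the provider that serves this table
    Union(Vec<Node>),        // stands in for any multi-input plan node
    Federated(&'static str), // stands in for a provider-specific sub-plan
}

/// Returns the sole provider of a subtree, or `None` when providers are mixed.
/// When they are mixed, every child that does have a sole provider is replaced
/// by a `Federated` sub-plan, so the largest single-provider subtrees are
/// pushed down as a unit.
fn federate(node: &mut Node) -> Option<&'static str> {
    match node {
        Node::Scan(p) | Node::Federated(p) => Some(*p),
        Node::Union(children) => {
            let providers: Vec<_> = children.iter_mut().map(federate).collect();
            let first = providers[0];
            if first.is_some() && providers.iter().all(|p| *p == first) {
                return first; // the largest single-provider subtree is higher up
            }
            // Mixed providers below this node: federate each single-provider child.
            for (child, provider) in children.iter_mut().zip(providers) {
                if let Some(p) = provider {
                    *child = Node::Federated(p);
                }
            }
            None
        }
    }
}

fn main() {
    let mut plan = Node::Union(vec![
        Node::Union(vec![Node::Scan("sqlite"), Node::Scan("sqlite")]),
        Node::Scan("postgres"),
    ]);
    federate(&mut plan);
    // The whole sqlite subtree is handed to its provider as one unit.
    assert_eq!(
        plan,
        Node::Union(vec![Node::Federated("sqlite"), Node::Federated("postgres")])
    );
}
```

In the real rule the root is treated specially (a plan that is single-provider end to end is federated as a whole), and expressions and subqueries feed into the same decision through `ScanResult`.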
+ fn scan_plan_recursively(&self, plan: &LogicalPlan) -> Result { + let mut sole_provider: ScanResult = ScanResult::None; + + plan.apply(&mut |p: &LogicalPlan| -> Result { + let exprs_provider = self.scan_plan_exprs(p)?; + sole_provider.merge(exprs_provider); + + if sole_provider.is_ambiguous() { + return Ok(TreeNodeRecursion::Stop); + } + + let (sub_provider, _) = get_leaf_provider(p)?; + sole_provider.add(sub_provider); + + Ok(sole_provider.check_recursion()) + })?; + + Ok(sole_provider) + } + + /// Scans a plan's expressions to see if it belongs to a single [`FederationProvider`]. + fn scan_plan_exprs(&self, plan: &LogicalPlan) -> Result { + let mut sole_provider: ScanResult = ScanResult::None; + + let exprs = plan.expressions(); + for expr in &exprs { + let expr_result = self.scan_expr_recursively(expr)?; + sole_provider.merge(expr_result); + + if sole_provider.is_ambiguous() { + return Ok(sole_provider); + } + } + + Ok(sole_provider) + } + + /// scans an expression to see if it belongs to a single [`FederationProvider`] + fn scan_expr_recursively(&self, expr: &Expr) -> Result { + let mut sole_provider: ScanResult = ScanResult::None; + + expr.apply(&mut |e: &Expr| -> Result { + // TODO: Support other types of sub-queries + match e { + Expr::ScalarSubquery(ref subquery) => { + let plan_result = self.scan_plan_recursively(&subquery.subquery)?; + + sole_provider.merge(plan_result); + Ok(sole_provider.check_recursion()) + } + Expr::InSubquery(ref insubquery) => { + let plan_result = self.scan_plan_recursively(&insubquery.subquery.subquery)?; + + sole_provider.merge(plan_result); + Ok(sole_provider.check_recursion()) + } + Expr::OuterReferenceColumn(_, ref col) => { + if let Some(table) = &col.relation { + let map = self.provider_map.read().map_err(|_| { + DataFusionError::External( + "Failed to create federated plan: failed to obtain a read lock on federated providers.".into(), + ) + })?; + if let Some(plan_result) = map.get(table) { + sole_provider.merge(plan_result.clone()); + return Ok(sole_provider.check_recursion()); + } + } + // Subqueries that reference outer columns are not supported + // for now. We handle this here as ambiguity to force + // federation lower in the plan tree. + sole_provider = ScanResult::Ambiguous; + Ok(TreeNodeRecursion::Stop) + } + _ => Ok(TreeNodeRecursion::Continue), + } + })?; + + Ok(sole_provider) + } + + /// Recursively finds the largest sub-plans that can be federated + /// to a single FederationProvider. + /// + /// Returns a plan if a sub-tree was federated, otherwise None. + /// + /// Returns a ScanResult of all FederationProviders in the subtree. + fn optimize_plan_recursively( + &self, + plan: &LogicalPlan, + is_root: bool, + _config: &ConfigOptions, + ) -> Result<(Option, ScanResult)> { + let mut sole_provider: ScanResult = ScanResult::None; + + if let LogicalPlan::Extension(Extension { ref node }) = plan { + if node.name() == "Federated" { + // Avoid attempting double federation + return Ok((None, ScanResult::Ambiguous)); + } + } + + // Check if this plan node is a leaf that determines the FederationProvider + let (leaf_provider, _) = get_leaf_provider(plan)?; + + // Check if the expressions contain, a potentially different, FederationProvider + let exprs_result = self.scan_plan_exprs(plan)?; + + // Return early if this is a leaf and there is no ambiguity with the expressions. 
+ if leaf_provider.is_some() && (exprs_result.is_none() || exprs_result == leaf_provider) { + return Ok((None, leaf_provider.into())); + } + // Aggregate leaf & expression providers + sole_provider.add(leaf_provider); + sole_provider.merge(exprs_result.clone()); + + let inputs = plan.inputs(); + // Return early if there are no sources. + if inputs.is_empty() && sole_provider.is_none() { + return Ok((None, ScanResult::None)); + } + + // Recursively optimize inputs + let input_results = inputs + .iter() + .map(|i| self.optimize_plan_recursively(i, false, _config)) + .collect::>>()?; + + // Aggregate the input providers + input_results.iter().for_each(|(_, scan_result)| { + sole_provider.merge(scan_result.clone()); + }); + + if sole_provider.is_none() { + // No providers found + // TODO: Is/should this be reachable? + return Ok((None, ScanResult::None)); + } + + // Federate Exprs when Exprs provider is ambiguous or Exprs provider differs from the sole_provider of current plan + // When Exprs provider is the same as sole_provider and non-ambiguous, the larger sub-plan is higher up + let optimize_expressions = exprs_result.is_some() + && (!(sole_provider == exprs_result) || exprs_result.is_ambiguous()); + + // If all sources are federated to the same provider + if let ScanResult::Distinct(provider) = sole_provider { + if !is_root { + // The largest sub-plan is higher up. + return Ok((None, ScanResult::Distinct(provider))); + } + + let Some(optimizer) = provider.analyzer() else { + // No optimizer provided + return Ok((None, ScanResult::None)); + }; + + // If this is the root plan node; federate the entire plan + let optimized = optimizer.execute_and_check(plan.clone(), _config, |_, _| {})?; + return Ok((Some(optimized), ScanResult::None)); + } + + // The plan is ambiguous; any input that is not yet optimized and has a + // sole provider represents a largest sub-plan and should be federated. + // + // We loop over the input optimization results, federate where needed and + // return a complete list of new inputs for the optimized plan. + let new_inputs = input_results + .into_iter() + .enumerate() + .map(|(i, (input_plan, input_result))| { + if let Some(federated_plan) = input_plan { + // Already federated deeper in the plan tree + return Ok(federated_plan); + } + + let original_input = (*inputs.get(i).unwrap()).clone(); + if input_result.is_ambiguous() { + // Can happen if the input is already federated, so use + // the original input. + return Ok(original_input); + } + + let provider = input_result.unwrap()?; + let Some(provider) = provider else { + // No provider for this input; use the original input. + return Ok(original_input); + }; + + let Some(optimizer) = provider.analyzer() else { + // No optimizer for this input; use the original input. + return Ok(original_input); + }; + + // Replace the input with the federated counterpart + let wrapped = wrap_projection(original_input)?; + let optimized = optimizer.execute_and_check(wrapped, _config, |_, _| {})?; + + Ok(optimized) + }) + .collect::>>()?; + + // Optimize expressions if needed + let new_expressions = if optimize_expressions { + self.optimize_plan_exprs(plan, _config)? 
+ } else { + plan.expressions() + }; + + // Construct the optimized plan + let new_plan = plan.with_new_exprs(new_expressions, new_inputs)?; + + // Return the federated plan + Ok((Some(new_plan), ScanResult::Ambiguous)) + } + + /// Optimizes all exprs of a plan + fn optimize_plan_exprs( + &self, + plan: &LogicalPlan, + _config: &ConfigOptions, + ) -> Result> { + plan.expressions() + .iter() + .map(|expr| { + let transformed = expr + .clone() + .transform(&|e| self.optimize_expr_recursively(e, _config))?; + Ok(transformed.data) + }) + .collect::>>() + } + + /// recursively optimize expressions + /// Current logic: individually federate every sub-query. + fn optimize_expr_recursively( + &self, + expr: Expr, + _config: &ConfigOptions, + ) -> Result> { + match expr { + Expr::ScalarSubquery(ref subquery) => { + // Optimize as root to force federating the sub-query + let (new_subquery, _) = + self.optimize_plan_recursively(&subquery.subquery, true, _config)?; + let Some(new_subquery) = new_subquery else { + return Ok(Transformed::no(expr)); + }; + + // ScalarSubqueryToJoin optimizer rule doesn't support federated node (LogicalPlan::Extension(_)) as subquery + // Wrap a `non-op` Projection LogicalPlan outside the federated node to facilitate ScalarSubqueryToJoin optimization + if matches!(new_subquery, LogicalPlan::Extension(_)) { + let all_columns = new_subquery + .schema() + .fields() + .iter() + .map(|field| col(field.name())) + .collect::>(); + + let projection_plan = LogicalPlanBuilder::from(new_subquery) + .project(all_columns)? + .build()?; + + return Ok(Transformed::yes(Expr::ScalarSubquery( + subquery.with_plan(projection_plan.into()), + ))); + } + + Ok(Transformed::yes(Expr::ScalarSubquery( + subquery.with_plan(new_subquery.into()), + ))) + } + Expr::InSubquery(ref in_subquery) => { + let (new_subquery, _) = + self.optimize_plan_recursively(&in_subquery.subquery.subquery, true, _config)?; + let Some(new_subquery) = new_subquery else { + return Ok(Transformed::no(expr)); + }; + + // DecorrelatePredicateSubquery optimizer rule doesn't support federated node (LogicalPlan::Extension(_)) as subquery + // Wrap a `non-op` Projection LogicalPlan outside the federated node to facilitate DecorrelatePredicateSubquery optimization + if matches!(new_subquery, LogicalPlan::Extension(_)) { + let all_columns = new_subquery + .schema() + .fields() + .iter() + .map(|field| col(field.name())) + .collect::>(); + + let projection_plan = LogicalPlanBuilder::from(new_subquery) + .project(all_columns)? + .build()?; + + return Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( + in_subquery.expr.clone(), + in_subquery.subquery.with_plan(projection_plan.into()), + in_subquery.negated, + )))); + } + + Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( + in_subquery.expr.clone(), + in_subquery.subquery.with_plan(new_subquery.into()), + in_subquery.negated, + )))) + } + _ => Ok(Transformed::no(expr)), + } + } +} + +/// NopFederationProvider is used to represent tables that are not federated, but +/// are resolved by DataFusion. This simplifies the logic of the optimizer rule. +struct NopFederationProvider {} + +impl FederationProvider for NopFederationProvider { + fn name(&self) -> &str { + "nop" + } + + fn compute_context(&self) -> Option { + None + } + + fn analyzer(&self) -> Option> { + None + } +} + +/// Recursively find the [`FederationProvider`] for all [`TableReference`] instances in the plan. +/// This information is used to resolve the federation provider for [`Expr::OuterReferenceColumn`]. 
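The function that follows builds that lookup: conceptually, a map from every table reference appearing in the plan, including subquery aliases, to the provider that produced it, so an `Expr::OuterReferenceColumn` can later be attributed to a provider by its relation name alone. A simplified sketch with string keys (the real code keys on `TableReference` and stores a `ScanResult`):

```rust
use std::collections::HashMap;

// Toy plan: scans know their table name and provider; an alias renames a subtree.
enum Plan {
    Scan { table: &'static str, provider: &'static str },
    Alias { name: &'static str, input: Box<Plan> },
    Filter { input: Box<Plan> },
}

/// Record `table reference -> provider` for every scan, and also for every
/// alias, because an outer-referenced column may name the alias rather than
/// the underlying table.
fn collect(plan: &Plan, map: &mut HashMap<&'static str, &'static str>) -> Option<&'static str> {
    match plan {
        Plan::Scan { table, provider } => {
            map.insert(*table, *provider);
            Some(*provider)
        }
        Plan::Alias { name, input } => {
            let provider = collect(input, map);
            if let Some(p) = provider {
                map.insert(*name, p);
            }
            provider
        }
        Plan::Filter { input } => collect(input, map),
    }
}

fn main() {
    // Roughly: SELECT ... FROM (SELECT * FROM remote.orders WHERE ...) AS t
    let plan = Plan::Alias {
        name: "t",
        input: Box::new(Plan::Filter {
            input: Box::new(Plan::Scan { table: "remote.orders", provider: "postgres" }),
        }),
    };
    let mut map = HashMap::new();
    collect(&plan, &mut map);
    // An outer reference such as `t.id` can now be attributed to postgres.
    assert_eq!(map.get("t"), Some(&"postgres"));
    assert_eq!(map.get("remote.orders"), Some(&"postgres"));
}
```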
+fn get_plan_provider_recursively( + plan: &LogicalPlan, +) -> Result> { + let mut providers: HashMap = HashMap::new(); + + plan.apply(&mut |p: &LogicalPlan| -> Result { + // LogicalPlan::SubqueryAlias can also be referred by OuterReferenceColumn + // Get the federation provider for TableReference representing LogicalPlan::SubqueryAlias + if let LogicalPlan::SubqueryAlias(a) = p { + let subquery_alias_providers = get_plan_provider_recursively(&Arc::clone(&a.input))?; + let mut provider: ScanResult = ScanResult::None; + for (_, i) in subquery_alias_providers { + provider.merge(i); + } + providers.insert(a.alias.clone(), provider); + } + + let (federation_provider, table_reference) = get_leaf_provider(p)?; + if let Some(table_reference) = table_reference { + providers.insert(table_reference, federation_provider.into()); + } + + let _ = p.apply_subqueries(|sub_query| { + let subquery_providers = get_plan_provider_recursively(sub_query)?; + providers.extend(subquery_providers); + Ok(TreeNodeRecursion::Continue) + }); + + Ok(TreeNodeRecursion::Continue) + })?; + + Ok(providers) +} + +fn wrap_projection(plan: LogicalPlan) -> Result { + // TODO: minimize requested columns + match plan { + LogicalPlan::Projection(_) => Ok(plan), + _ => { + let expr = plan + .schema() + .columns() + .iter() + .map(|c| Expr::Column(c.clone())) + .collect::>(); + Ok(LogicalPlan::Projection(Projection::try_new( + expr, + Arc::new(plan), + )?)) + } + } +} + +fn contains_federated_table(plan: &LogicalPlan) -> Result { + let federated_table_exists = plan.exists(|x| { + if let (Some(provider), _) = get_leaf_provider(x)? { + // federated table provider should have an analyzer + return Ok(provider.analyzer().is_some()); + } + Ok(false) + })?; + + Ok(federated_table_exists) +} + +fn get_leaf_provider( + plan: &LogicalPlan, +) -> Result<(Option, Option)> { + match plan { + LogicalPlan::TableScan(TableScan { + ref table_name, + ref source, + .. + }) => { + let table_reference = table_name.clone(); + let Some(federated_source) = get_table_source(source)? else { + // Table is not federated but provided by a standard table provider. + // We use a placeholder federation provider to simplify the logic. + return Ok(( + Some(Arc::new(NopFederationProvider {})), + Some(table_reference), + )); + }; + let provider = federated_source.federation_provider(); + Ok((Some(provider), Some(table_reference))) + } + _ => Ok((None, None)), + } +} + +#[allow(clippy::missing_errors_doc)] +pub fn get_table_source( + source: &Arc, +) -> Result>> { + // Unwrap TableSource + let source = source_as_provider(source)?; + + // Get FederatedTableProviderAdaptor + let Some(wrapper) = source + .as_any() + .downcast_ref::() + else { + return Ok(None); + }; + + // Return original FederatedTableSource + Ok(Some(Arc::clone(&wrapper.source))) +} diff --git a/datafusion-federation/src/analyzer/scan_result.rs b/datafusion-federation/src/analyzer/scan_result.rs new file mode 100644 index 0000000..dc01906 --- /dev/null +++ b/datafusion-federation/src/analyzer/scan_result.rs @@ -0,0 +1,113 @@ +use crate::FederationProviderRef; +use datafusion::common::tree_node::TreeNodeRecursion; +use datafusion::error::{DataFusionError, Result}; + +/// Used to track if all sources, including tableScan, plan inputs and +/// expressions, represents an un-ambiguous, none or a sole' [`crate::FederationProvider`]. 
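The `ScanResult` defined just below behaves like a small merge lattice: `None` is the identity, equal providers stay `Distinct`, and any disagreement (or anything already ambiguous) collapses to `Ambiguous`. A self-contained sketch of that merge table, with a plain string standing in for the provider reference:

```rust
#[derive(Debug, PartialEq)]
enum Scan {
    None,
    Distinct(&'static str),
    Ambiguous,
}

impl Scan {
    /// Mirrors the merge semantics: None is neutral, matching providers are
    /// kept, and a conflict (or an existing Ambiguous) stays Ambiguous.
    fn merge(&mut self, other: Scan) {
        *self = match (std::mem::replace(self, Scan::None), other) {
            (s, Scan::None) => s,
            (Scan::None, o) => o,
            (Scan::Ambiguous, _) | (_, Scan::Ambiguous) => Scan::Ambiguous,
            (Scan::Distinct(a), Scan::Distinct(b)) => {
                if a == b { Scan::Distinct(a) } else { Scan::Ambiguous }
            }
        };
    }
}

fn main() {
    let mut result = Scan::None;
    result.merge(Scan::Distinct("sqlite"));
    assert_eq!(result, Scan::Distinct("sqlite")); // still a single provider

    result.merge(Scan::None);
    assert_eq!(result, Scan::Distinct("sqlite")); // None never changes the answer

    result.merge(Scan::Distinct("postgres"));
    assert_eq!(result, Scan::Ambiguous); // two providers: stop federating here
}
```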
+pub enum ScanResult { + None, + Distinct(FederationProviderRef), + Ambiguous, +} + +impl ScanResult { + pub fn merge(&mut self, other: Self) { + match (&self, &other) { + (_, ScanResult::None) => {} + (ScanResult::None, _) => *self = other, + (ScanResult::Ambiguous, _) | (_, ScanResult::Ambiguous) => { + *self = ScanResult::Ambiguous; + } + (ScanResult::Distinct(provider), ScanResult::Distinct(other_provider)) => { + if provider != other_provider { + *self = ScanResult::Ambiguous; + } + } + } + } + + pub fn add(&mut self, provider: Option) { + self.merge(ScanResult::from(provider)) + } + + pub fn is_ambiguous(&self) -> bool { + matches!(self, ScanResult::Ambiguous) + } + + pub fn is_none(&self) -> bool { + matches!(self, ScanResult::None) + } + pub fn is_some(&self) -> bool { + !self.is_none() + } + + pub fn unwrap(self) -> Result> { + match self { + ScanResult::None => Ok(None), + ScanResult::Distinct(provider) => Ok(Some(provider)), + ScanResult::Ambiguous => Err(DataFusionError::External( + "called `ScanResult::unwrap()` on a `Ambiguous` value".into(), + )), + } + } + + pub fn check_recursion(&self) -> TreeNodeRecursion { + if self.is_ambiguous() { + TreeNodeRecursion::Stop + } else { + TreeNodeRecursion::Continue + } + } +} + +impl From> for ScanResult { + fn from(provider: Option) -> Self { + match provider { + Some(provider) => ScanResult::Distinct(provider), + None => ScanResult::None, + } + } +} + +impl PartialEq> for ScanResult { + fn eq(&self, other: &Option) -> bool { + match (self, other) { + (ScanResult::None, None) => true, + (ScanResult::Distinct(provider), Some(other_provider)) => provider == other_provider, + _ => false, + } + } +} + +impl PartialEq for ScanResult { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (ScanResult::None, ScanResult::None) => true, + (ScanResult::Distinct(provider1), ScanResult::Distinct(provider2)) => { + provider1 == provider2 + } + (ScanResult::Ambiguous, ScanResult::Ambiguous) => true, + _ => false, + } + } +} + +impl std::fmt::Debug for ScanResult { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::None => write!(f, "ScanResult::None"), + Self::Distinct(provider) => write!(f, "ScanResult::Distinct({})", provider.name()), + Self::Ambiguous => write!(f, "ScanResult::Ambiguous"), + } + } +} + +impl Clone for ScanResult { + fn clone(&self) -> Self { + match self { + ScanResult::None => ScanResult::None, + ScanResult::Distinct(provider) => ScanResult::Distinct(provider.clone()), + ScanResult::Ambiguous => ScanResult::Ambiguous, + } + } +} diff --git a/datafusion-federation/src/optimize.rs b/datafusion-federation/src/optimize.rs index b669210..5f543ae 100644 --- a/datafusion-federation/src/optimize.rs +++ b/datafusion-federation/src/optimize.rs @@ -3,7 +3,8 @@ use datafusion::{ error::Result, logical_expr::LogicalPlan, optimizer::{ - optimizer::ApplyOrder, push_down_filter::PushDownFilter, OptimizerConfig, OptimizerContext, OptimizerRule + optimizer::ApplyOrder, push_down_filter::PushDownFilter, OptimizerConfig, OptimizerContext, + OptimizerRule, }, }; use optimize_projections::OptimizeProjections; diff --git a/datafusion-federation/src/schema_cast.rs b/datafusion-federation/src/schema_cast.rs index 502d4b8..5b1ba5a 100644 --- a/datafusion-federation/src/schema_cast.rs +++ b/datafusion-federation/src/schema_cast.rs @@ -4,8 +4,7 @@ use datafusion::error::{DataFusionError, Result}; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; use 
datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, - PlanProperties, + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, }; use futures::StreamExt; use std::any::Any; @@ -15,8 +14,8 @@ use std::sync::Arc; mod intervals_cast; mod lists_cast; -mod struct_cast; pub mod record_convert; +mod struct_cast; #[derive(Debug)] #[allow(clippy::module_name_repetitions)] @@ -70,7 +69,7 @@ impl ExecutionPlan for SchemaCastScanExec { vec![&self.input] } - /// Prevents the introduction of additional `RepartitionExec` and processing input in parallel. + /// Prevents the introduction of additional `RepartitionExec` and processing input in parallel. /// This guarantees that the input is processed as a single stream, preserving the order of the data. fn benefits_from_input_partitioning(&self) -> Vec { vec![false] From 72374f605ec9617ff1852ca43297233202ec2d43 Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Tue, 24 Dec 2024 12:09:28 -0800 Subject: [PATCH 44/48] Fix unnest rewriting logic (#28) --- sources/sql/src/lib.rs | 77 +++++++++++++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 20 deletions(-) diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index 0186f1a..2ed0113 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -4,7 +4,7 @@ use std::{any::Any, collections::HashMap, sync::Arc, vec}; use async_trait::async_trait; use datafusion::{ arrow::datatypes::{Schema, SchemaRef}, - common::Column, + common::{Column, RecursionUnnestOption, UnnestOptions}, config::ConfigOptions, error::{DataFusionError, Result}, execution::{context::SessionState, TaskContext}, @@ -208,7 +208,7 @@ fn rewrite_unnest_plan( )); }; - // rewrite aliases in inner projection; columns were rewritten via `rewrite_table_scans_in_expr`` + // rewrite aliases in inner projection; columns were rewritten via `rewrite_table_scans_in_expr` let new_expressions = projection .expr .into_iter() @@ -227,15 +227,66 @@ fn rewrite_unnest_plan( let updated_unnest_inner_projection = Projection::try_new(new_expressions, Arc::clone(&projection.input))?; + let unnest_options = rewrite_unnest_options(&unnest.options, known_rewrites); + // reconstruct the unnest plan with updated projection and rewritten column names let new_plan = LogicalPlanBuilder::new(LogicalPlan::Projection(updated_unnest_inner_projection)) - .unnest_columns_with_options(unnest_columns, unnest.options.clone())? + .unnest_columns_with_options(unnest_columns, unnest_options)? .build()?; Ok(new_plan) } +/// Rewrites columns names in the unnest options to use the original federated table name: +/// "unnest_placeholder(foo.df_table.a,depth=1)"" -> "unnest_placeholder(remote_table.a,depth=1)"" +fn rewrite_unnest_options( + options: &UnnestOptions, + known_rewrites: &HashMap, +) -> UnnestOptions { + let mut new_options = options.clone(); + new_options + .recursions + .iter_mut() + .for_each(|x: &mut RecursionUnnestOption| { + if let Some(new_name) = rewrite_column_name(&x.input_column.name, known_rewrites) { + x.input_column.name = new_name; + } + + if let Some(new_name) = rewrite_column_name(&x.output_column.name, known_rewrites) { + x.output_column.name = new_name; + } + }); + new_options +} + +/// Checks if any of the rewrites match any substring in col_name, and replace that part of the string if so. 
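// Minimal std-only sketch (toy helper, not the crate function) of the substring
// rewrite described here for flattened column names such as aggregate outputs.
// The real implementation walks all known_rewrites, guards against matches that
// are only part of a longer identifier, and supports multiple occurrences.
fn rewrite_flat_name(col_name: &str, table_ref: &str, rewrite: &str) -> Option<String> {
    if col_name.contains(table_ref) {
        Some(col_name.replace(table_ref, rewrite))
    } else {
        None
    }
}

fn main() {
    // e.g. an aggregate expression name produced against the local table reference
    let rewritten = rewrite_flat_name("MAX(foo.df_table.a)", "foo.df_table", "remote_table");
    assert_eq!(rewritten.as_deref(), Some("MAX(remote_table.a)"));
    // names that never mention the local table reference are left untouched
    assert!(rewrite_flat_name("MAX(other.a)", "foo.df_table", "remote_table").is_none());
}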
+/// This handles cases like "MAX(foo.df_table.a)" -> "MAX(remote_table.a)" +/// Returns the rewritten name if any rewrite was applied, otherwise None. +fn rewrite_column_name( + col_name: &str, + known_rewrites: &HashMap, +) -> Option { + let (new_col_name, was_rewritten) = known_rewrites.iter().fold( + (col_name.to_string(), false), + |(col_name, was_rewritten), (table_ref, rewrite)| match rewrite_column_name_in_expr( + &col_name, + &table_ref.to_string(), + &rewrite.to_string(), + 0, + ) { + Some(new_name) => (new_name, true), + None => (col_name, was_rewritten), + }, + ); + + if was_rewritten { + Some(new_col_name) + } else { + None + } +} + // The function replaces occurrences of table_ref_str in col_name with the new name defined by rewrite. // The name to rewrite should NOT be a substring of another name. // Supports multiple occurrences of table_ref_str in col_name. @@ -344,21 +395,7 @@ fn rewrite_table_scans_in_expr( // Check if any of the rewrites match any substring in col.name, and replace that part of the string if so. // This will handles cases like "MAX(foo.df_table.a)" -> "MAX(remote_table.a)" - let (new_name, was_rewritten) = known_rewrites.iter().fold( - (col.name.to_string(), false), - |(col_name, was_rewritten), (table_ref, rewrite)| { - match rewrite_column_name_in_expr( - &col_name, - &table_ref.to_string(), - &rewrite.to_string(), - 0, - ) { - Some(new_name) => (new_name, true), - None => (col_name, was_rewritten), - } - }, - ); - if was_rewritten { + if let Some(new_name) = rewrite_column_name(&col.name, known_rewrites) { Ok(Expr::Column(Column::new(col.relation.take(), new_name))) } else { Ok(Expr::Column(col)) @@ -1020,11 +1057,11 @@ mod tests { let tests = vec![ ( "SELECT UNNEST([1, 2, 2, 5, NULL]), b, c from app_table where a > 10 order by b limit 10;", - r#"SELECT UNNEST(make_array(1, 2, 2, 5, NULL)), remote_table.b, remote_table.c FROM remote_table WHERE (remote_table.a > 10) ORDER BY remote_table.b ASC NULLS LAST LIMIT 10"#, + r#"SELECT UNNEST(make_array(1, 2, 2, 5, NULL)) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(2),Int64(5),NULL))", remote_table.b, remote_table.c FROM remote_table WHERE (remote_table.a > 10) ORDER BY remote_table.b ASC NULLS LAST LIMIT 10"#, ), ( "SELECT UNNEST(app_table.d), b, c from app_table where a > 10 order by b limit 10;", - r#"SELECT UNNEST(remote_table.d), remote_table.b, remote_table.c FROM remote_table WHERE (remote_table.a > 10) ORDER BY remote_table.b ASC NULLS LAST LIMIT 10"#, + r#"SELECT UNNEST(remote_table.d) AS "UNNEST(app_table.d)", remote_table.b, remote_table.c FROM remote_table WHERE (remote_table.a > 10) ORDER BY remote_table.b ASC NULLS LAST LIMIT 10"#, ), ( "SELECT sum(b.x) AS total FROM (SELECT UNNEST(d) AS x from app_table where a > 0) AS b;", From 884b2a5413e50232111be51df436568745da33f0 Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Thu, 26 Dec 2024 13:34:36 -0800 Subject: [PATCH 45/48] fix: clean up outdated dependencies & modules and fix tests (#32) * Fix tests * fix * fix formatting errors * fix clippy * clean dependencies, remove outdated crates and test * update test to follow upstream fix: align to upstream changes and pass all tests# * remove commented out code * fix --- .github/workflows/check.yml | 2 +- datafusion-federation/src/optimize.rs | 2 +- .../optimize_projections/required_indices.rs | 20 ++- .../src/schema_cast/intervals_cast.rs | 26 +++- .../src/schema_cast/record_convert.rs | 5 +- .../src/schema_cast/struct_cast.rs | 2 +- examples/Cargo.toml | 
6 +- examples/examples/flight-sql.rs | 11 +- examples/examples/postgres-partial.rs | 70 ---------- examples/examples/sqlite-partial.rs | 107 --------------- examples/examples/sqlite.rs | 70 ---------- sources/flight-sql/Cargo.toml | 2 +- sources/sql/src/connectorx/executor.rs | 125 ------------------ sources/sql/src/connectorx/mod.rs | 2 - sources/sql/src/lib.rs | 14 +- 15 files changed, 50 insertions(+), 414 deletions(-) delete mode 100644 examples/examples/postgres-partial.rs delete mode 100644 examples/examples/sqlite-partial.rs delete mode 100644 examples/examples/sqlite.rs delete mode 100644 sources/sql/src/connectorx/executor.rs delete mode 100644 sources/sql/src/connectorx/mod.rs diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index fe66d27..92e1dad 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -10,4 +10,4 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-node@v4 - run: npm install prettier prettier-plugin-toml - - run: npx prettier --check --no-config . + - run: npx prettier --write --no-config . diff --git a/datafusion-federation/src/optimize.rs b/datafusion-federation/src/optimize.rs index 5f543ae..993a2ab 100644 --- a/datafusion-federation/src/optimize.rs +++ b/datafusion-federation/src/optimize.rs @@ -71,7 +71,7 @@ impl<'a> Rewriter<'a> { } } -impl<'a> TreeNodeRewriter for Rewriter<'a> { +impl TreeNodeRewriter for Rewriter<'_> { type Node = LogicalPlan; fn f_down(&mut self, node: LogicalPlan) -> Result> { diff --git a/datafusion-federation/src/optimize/optimize_projections/required_indices.rs b/datafusion-federation/src/optimize/optimize_projections/required_indices.rs index aa4889f..e6df841 100644 --- a/datafusion-federation/src/optimize/optimize_projections/required_indices.rs +++ b/datafusion-federation/src/optimize/optimize_projections/required_indices.rs @@ -16,12 +16,15 @@ // under the License. /// Copy of DataFusion's [`RequiredIndicies`](https://github.com/apache/datafusion/blob/main/datafusion/optimizer/src/optimize_projections/required_indices.rs) implementation. 
- use std::result; -use datafusion::{common::{tree_node::TreeNodeRecursion, Column, DFSchemaRef}, error::DataFusionError, logical_expr::LogicalPlan, prelude::Expr}; - use super::outer_columns; +use datafusion::{ + common::{tree_node::TreeNodeRecursion, Column, DFSchemaRef}, + error::DataFusionError, + logical_expr::LogicalPlan, + prelude::Expr, +}; type Result = result::Result; @@ -91,11 +94,7 @@ impl RequiredIndicies { } /// Add required indices for all `exprs` used in plan - pub fn with_plan_exprs( - mut self, - plan: &LogicalPlan, - schema: &DFSchemaRef, - ) -> Result { + pub fn with_plan_exprs(mut self, plan: &LogicalPlan, schema: &DFSchemaRef) -> Result { // Add indices of the child fields referred to by the expressions in the // parent plan.apply_expressions(|e| { @@ -168,8 +167,7 @@ impl RequiredIndicies { where F: Fn(usize) -> bool, { - let (l, r): (Vec, Vec) = - self.indices.iter().partition(|&&idx| f(idx)); + let (l, r): (Vec, Vec) = self.indices.iter().partition(|&&idx| f(idx)); let projection_beneficial = self.projection_beneficial; ( @@ -226,4 +224,4 @@ impl RequiredIndicies { self.indices.dedup(); self } -} \ No newline at end of file +} diff --git a/datafusion-federation/src/schema_cast/intervals_cast.rs b/datafusion-federation/src/schema_cast/intervals_cast.rs index 3e445c5..629203e 100644 --- a/datafusion-federation/src/schema_cast/intervals_cast.rs +++ b/datafusion-federation/src/schema_cast/intervals_cast.rs @@ -50,7 +50,7 @@ pub(crate) fn cast_interval_monthdaynano_to_daytime( let interval_monthdaynano_array = interval_monthdaynano_array .as_any() .downcast_ref::() - .ok_or_else(|| + .ok_or_else(|| ArrowError::CastError("Failed to cast IntervalMonthDayNanoArray: Unable to downcast to IntervalMonthDayNanoArray".to_string()))?; let mut interval_daytime_builder = @@ -78,8 +78,10 @@ pub(crate) fn cast_interval_monthdaynano_to_daytime( #[cfg(test)] mod test { use datafusion::arrow::{ - array::{RecordBatch, IntervalDayTimeArray, IntervalYearMonthArray}, - datatypes::{DataType, Field, Schema, SchemaRef, IntervalUnit, IntervalMonthDayNano, IntervalDayTime}, + array::{IntervalDayTimeArray, IntervalYearMonthArray, RecordBatch}, + datatypes::{ + DataType, Field, IntervalDayTime, IntervalMonthDayNano, IntervalUnit, Schema, SchemaRef, + }, }; use crate::schema_cast::record_convert::try_cast_to; @@ -88,9 +90,21 @@ mod test { fn input_schema() -> SchemaRef { Arc::new(Schema::new(vec![ - Field::new("interval_daytime", DataType::Interval(IntervalUnit::MonthDayNano), false), - Field::new("interval_monthday_nano", DataType::Interval(IntervalUnit::MonthDayNano), false), - Field::new("interval_yearmonth", DataType::Interval(IntervalUnit::MonthDayNano), false), + Field::new( + "interval_daytime", + DataType::Interval(IntervalUnit::MonthDayNano), + false, + ), + Field::new( + "interval_monthday_nano", + DataType::Interval(IntervalUnit::MonthDayNano), + false, + ), + Field::new( + "interval_yearmonth", + DataType::Interval(IntervalUnit::MonthDayNano), + false, + ), ])) } diff --git a/datafusion-federation/src/schema_cast/record_convert.rs b/datafusion-federation/src/schema_cast/record_convert.rs index 140ca38..a20401a 100644 --- a/datafusion-federation/src/schema_cast/record_convert.rs +++ b/datafusion-federation/src/schema_cast/record_convert.rs @@ -1,7 +1,7 @@ use datafusion::arrow::{ array::{Array, RecordBatch}, compute::cast, - datatypes::{DataType, IntervalUnit, SchemaRef} + datatypes::{DataType, IntervalUnit, SchemaRef}, }; use std::sync::Arc; @@ -9,7 +9,8 @@ use super::{ 
intervals_cast::{ cast_interval_monthdaynano_to_daytime, cast_interval_monthdaynano_to_yearmonth, }, - lists_cast::{cast_string_to_fixed_size_list, cast_string_to_large_list, cast_string_to_list}, struct_cast::cast_string_to_struct, + lists_cast::{cast_string_to_fixed_size_list, cast_string_to_large_list, cast_string_to_list}, + struct_cast::cast_string_to_struct, }; pub type Result = std::result::Result; diff --git a/datafusion-federation/src/schema_cast/struct_cast.rs b/datafusion-federation/src/schema_cast/struct_cast.rs index 85ad369..5b6c1f0 100644 --- a/datafusion-federation/src/schema_cast/struct_cast.rs +++ b/datafusion-federation/src/schema_cast/struct_cast.rs @@ -50,7 +50,7 @@ pub(crate) fn cast_string_to_struct( } }; // struct_field is single struct column - return Ok(Arc::clone(record.column(0))); + Ok(Arc::clone(record.column(0))) } #[cfg(test)] diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 52acc4e..48c910c 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -6,18 +6,18 @@ license.workspace = true readme.workspace = true [dev-dependencies] -arrow-flight = { version = "52.0.0", features = ["flight-sql-experimental"] } +arrow-flight = { version = "53.0.0", features = ["flight-sql-experimental"] } tokio = "1.35.1" async-trait.workspace = true datafusion.workspace = true datafusion-federation.path = "../datafusion-federation" -datafusion-federation-sql.path = "../sources/sql" +datafusion-federation-sql = { path = "../sources/sql", features = ["connectorx"] } datafusion-federation-flight-sql.path = "../sources/flight-sql" connectorx = { git = "https://github.com/devinjdangelo/connector-x.git", features = [ "dst_arrow", "src_sqlite" ] } -tonic = "0.11.0" +tonic = "0.12.2" [dependencies] async-std = "1.12.0" diff --git a/examples/examples/flight-sql.rs b/examples/examples/flight-sql.rs index 7a32e29..1f8578f 100644 --- a/examples/examples/flight-sql.rs +++ b/examples/examples/flight-sql.rs @@ -1,8 +1,9 @@ use std::{sync::Arc, time::Duration}; use arrow_flight::sql::client::FlightSqlServiceClient; +use datafusion::execution::SessionStateBuilder; use datafusion::{ - catalog::schema::SchemaProvider, + catalog::SchemaProvider, error::{DataFusionError, Result}, execution::{ context::{SessionContext, SessionState}, @@ -39,14 +40,14 @@ async fn main() -> Result<()> { sleep(Duration::from_secs(3)).await; // Local context - let state = SessionContext::new().state(); let known_tables: Vec = ["test"].iter().map(|&x| x.into()).collect(); // Register FederationAnalyzer // TODO: Interaction with other analyzers & optimizers. 
- let state = state - .add_analyzer_rule(Arc::new(FederationAnalyzerRule::new())) - .with_query_planner(Arc::new(FederatedQueryPlanner::new())); + let mut state = SessionStateBuilder::new() + .with_query_planner(Arc::new(FederatedQueryPlanner::new())) + .build(); + state.add_analyzer_rule(Arc::new(FederationAnalyzerRule::new())); // Register schema // TODO: table inference diff --git a/examples/examples/postgres-partial.rs b/examples/examples/postgres-partial.rs deleted file mode 100644 index 873dd40..0000000 --- a/examples/examples/postgres-partial.rs +++ /dev/null @@ -1,70 +0,0 @@ -use std::sync::Arc; -use tokio::task; - -use datafusion::{ - catalog::schema::SchemaProvider, - error::Result, - execution::context::{SessionContext, SessionState}, -}; -use datafusion_federation::{FederatedQueryPlanner, FederationAnalyzerRule}; -use datafusion_federation_sql::connectorx::CXExecutor; -use datafusion_federation_sql::{MultiSchemaProvider, SQLFederationProvider, SQLSchemaProvider}; - -#[tokio::main] -async fn main() -> Result<()> { - let state = SessionContext::new().state(); - // Register FederationAnalyzer - // TODO: Interaction with other analyzers & optimizers. - let state = state - .add_analyzer_rule(Arc::new(FederationAnalyzerRule::new())) - .with_query_planner(Arc::new(FederatedQueryPlanner::new())); - - let df = task::spawn_blocking(move || { - // Register schema - let pg_provider_1 = async_std::task::block_on(create_postgres_provider(vec!["class"], "conn1")).unwrap(); - let pg_provider_2 = async_std::task::block_on(create_postgres_provider(vec!["teacher"], "conn2")).unwrap(); - let provider = MultiSchemaProvider::new(vec![ - pg_provider_1, - pg_provider_2, - ]); - - overwrite_default_schema(&state, Arc::new(provider)).unwrap(); - - // Run query - let ctx = SessionContext::new_with_state(state); - let query = r#"SELECT class.name AS classname, teacher.name AS teachername FROM class JOIN teacher ON class.id = teacher.class_id"#; - let df = async_std::task::block_on(ctx.sql(query)).unwrap(); - - df - }).await.unwrap(); - - task::spawn_blocking(move || async_std::task::block_on(df.show())) - .await - .unwrap() -} - -async fn create_postgres_provider( - known_tables: Vec<&str>, - context: &str, -) -> Result> { - let dsn = "postgresql://:@localhost:/".to_string(); - let known_tables: Vec = known_tables.iter().map(|&x| x.into()).collect(); - let mut executor = CXExecutor::new(dsn)?; - executor.context(context.to_string()); - let provider = Arc::new(SQLFederationProvider::new(Arc::new(executor))); - Ok(Arc::new( - SQLSchemaProvider::new_with_tables(provider, known_tables).await?, - )) -} - -fn overwrite_default_schema(state: &SessionState, schema: Arc) -> Result<()> { - let options = &state.config().options().catalog; - let catalog = state - .catalog_list() - .catalog(options.default_catalog.as_str()) - .unwrap(); - - catalog.register_schema(options.default_schema.as_str(), schema)?; - - Ok(()) -} diff --git a/examples/examples/sqlite-partial.rs b/examples/examples/sqlite-partial.rs deleted file mode 100644 index 780462b..0000000 --- a/examples/examples/sqlite-partial.rs +++ /dev/null @@ -1,107 +0,0 @@ -use std::{any::Any, sync::Arc}; - -use async_trait::async_trait; -use datafusion::{ - catalog::schema::SchemaProvider, - datasource::TableProvider, - error::Result, - execution::context::{SessionContext, SessionState}, -}; -use datafusion_federation::{FederatedQueryPlanner, FederationAnalyzerRule}; -use datafusion_federation_sql::{connectorx::CXExecutor, SQLFederationProvider, SQLSchemaProvider}; 
- -#[tokio::main] -async fn main() -> Result<()> { - let state = SessionContext::new().state(); - // Register FederationAnalyzer - // TODO: Interaction with other analyzers & optimizers. - let state = state - .add_analyzer_rule(Arc::new(FederationAnalyzerRule::new())) - .with_query_planner(Arc::new(FederatedQueryPlanner::new())); - - // Register schema - let provider = MultiSchemaProvider::new(vec![ - create_sqlite_provider(vec!["Artist"], "conn1").await?, - create_sqlite_provider(vec!["Track", "Album"], "conn2").await?, - ]); - - overwrite_default_schema(&state, Arc::new(provider))?; - - // Run query - let ctx = SessionContext::new_with_state(state); - let query = r#"SELECT - t.TrackId, - t.Name AS TrackName, - a.Title AS AlbumTitle, - ar.Name AS ArtistName - FROM Track t - JOIN Album a ON t.AlbumId = a.AlbumId - JOIN Artist ar ON a.ArtistId = ar.ArtistId - limit 10"#; - let df = ctx.sql(query).await?; - - // let explain = df.clone().explain(true, false)?; - // explain.show().await?; - - df.show().await -} - -async fn create_sqlite_provider( - known_tables: Vec<&str>, - context: &str, -) -> Result> { - let dsn = "sqlite://./examples/examples/chinook.sqlite".to_string(); - let known_tables: Vec = known_tables.iter().map(|&x| x.into()).collect(); - let mut executor = CXExecutor::new(dsn)?; - executor.context(context.to_string()); - let provider = Arc::new(SQLFederationProvider::new(Arc::new(executor))); - Ok(Arc::new( - SQLSchemaProvider::new_with_tables(provider, known_tables).await?, - )) -} - -struct MultiSchemaProvider { - children: Vec>, -} - -impl MultiSchemaProvider { - pub fn new(children: Vec>) -> Self { - Self { children } - } -} - -fn overwrite_default_schema(state: &SessionState, schema: Arc) -> Result<()> { - let options = &state.config().options().catalog; - let catalog = state - .catalog_list() - .catalog(options.default_catalog.as_str()) - .unwrap(); - - catalog.register_schema(options.default_schema.as_str(), schema)?; - - Ok(()) -} - -#[async_trait] -impl SchemaProvider for MultiSchemaProvider { - fn as_any(&self) -> &dyn Any { - self - } - - fn table_names(&self) -> Vec { - self.children.iter().flat_map(|p| p.table_names()).collect() - } - - async fn table(&self, name: &str) -> Result>> { - for child in &self.children { - if let Ok(Some(table)) = child.table(name).await { - return Ok(Some(table)); - } - } - Ok(None) - } - - fn table_exist(&self, name: &str) -> bool { - self.children.iter().any(|p| p.table_exist(name)) - } -} diff --git a/examples/examples/sqlite.rs b/examples/examples/sqlite.rs deleted file mode 100644 index 74e5371..0000000 --- a/examples/examples/sqlite.rs +++ /dev/null @@ -1,70 +0,0 @@ -use std::sync::Arc; - -use datafusion::{ - catalog::schema::SchemaProvider, - error::Result, - execution::context::{SessionContext, SessionState}, -}; -use datafusion_federation::{FederatedQueryPlanner, FederationAnalyzerRule}; -use datafusion_federation_sql::{connectorx::CXExecutor, SQLFederationProvider, SQLSchemaProvider}; - -#[tokio::main] -async fn main() -> datafusion::error::Result<()> { - let dsn = "sqlite://./examples/examples/chinook.sqlite".to_string(); - let known_tables: Vec = ["Track", "Album", "Artist"] - .iter() - .map(|&x| x.into()) - .collect(); - - let state = SessionContext::new().state(); - - // Register FederationAnalyzer - // TODO: Interaction with other analyzers & optimizers. 
- let state = state - .add_analyzer_rule(Arc::new(FederationAnalyzerRule::new())) - .with_query_planner(Arc::new(FederatedQueryPlanner::new())); - - // Register schema - // TODO: table inference - let executor = Arc::new(CXExecutor::new(dsn)?); - let provider = Arc::new(SQLFederationProvider::new(executor)); - let schema_provider = - Arc::new(SQLSchemaProvider::new_with_tables(provider, known_tables).await?); - overwrite_default_schema(&state, schema_provider)?; - - // Run query - let ctx = SessionContext::new_with_state(state); - let query = r#"SELECT - t.TrackId, - t.Name AS TrackName, - a.Title AS AlbumTitle, - ar.Name AS ArtistName - FROM Track t - JOIN Album a ON t.AlbumId = a.AlbumId - JOIN Artist ar ON a.ArtistId = ar.ArtistId - limit 10"#; - let df = ctx.sql(query).await?; - df.show().await?; - - // If the environment variable EXPLAIN is set, print the query plan - if std::env::var("EXPLAIN").is_ok() { - let explain_query = format!("EXPLAIN {query}"); - let df = ctx.sql(explain_query.as_str()).await?; - - df.show().await?; - } - - Ok(()) -} - -fn overwrite_default_schema(state: &SessionState, schema: Arc) -> Result<()> { - let options = &state.config().options().catalog; - let catalog = state - .catalog_list() - .catalog(options.default_catalog.as_str()) - .unwrap(); - - catalog.register_schema(options.default_schema.as_str(), schema)?; - - Ok(()) -} diff --git a/sources/flight-sql/Cargo.toml b/sources/flight-sql/Cargo.toml index 63ce13f..91e8af0 100644 --- a/sources/flight-sql/Cargo.toml +++ b/sources/flight-sql/Cargo.toml @@ -17,7 +17,7 @@ datafusion-federation.path = "../../datafusion-federation" datafusion-federation-sql.path = "../sql" futures = "0.3.30" tonic = {version="0.12.2", features=["tls"] } -prost = "0.12.3" +prost = "0.13.4" arrow = "53.0.0" arrow-flight = { version = "53.0.0", features = ["flight-sql-experimental"] } log = "0.4.20" diff --git a/sources/sql/src/connectorx/executor.rs b/sources/sql/src/connectorx/executor.rs deleted file mode 100644 index b5964ce..0000000 --- a/sources/sql/src/connectorx/executor.rs +++ /dev/null @@ -1,125 +0,0 @@ -use async_trait::async_trait; -use connectorx::{ - destinations::arrow::ArrowDestinationError, - errors::{ConnectorXError, ConnectorXOutError}, - prelude::{get_arrow, CXQuery, SourceConn, SourceType}, -}; -use datafusion::{ - arrow::datatypes::{Field, Schema, SchemaRef}, - error::{DataFusionError, Result}, - physical_plan::{ - stream::RecordBatchStreamAdapter, EmptyRecordBatchStream, SendableRecordBatchStream, - }, - sql::unparser::dialect::{DefaultDialect, Dialect, PostgreSqlDialect, SqliteDialect}, -}; -use futures::executor::block_on; -use std::sync::Arc; -use tokio::task; - -use crate::executor::SQLExecutor; - -pub struct CXExecutor { - context: String, - conn: SourceConn, -} - -impl CXExecutor { - pub fn new(dsn: String) -> Result { - let conn = SourceConn::try_from(dsn.as_str()).map_err(cx_error_to_df)?; - Ok(Self { context: dsn, conn }) - } - - pub fn new_with_conn(conn: SourceConn) -> Self { - Self { - context: conn.conn.to_string(), - conn, - } - } - - pub fn context(&mut self, context: String) { - self.context = context; - } -} - -fn cx_error_to_df(err: ConnectorXError) -> DataFusionError { - DataFusionError::External(format!("ConnectorX: {err:?}").into()) -} - -#[async_trait] -impl SQLExecutor for CXExecutor { - fn name(&self) -> &str { - "connector_x_executor" - } - fn compute_context(&self) -> Option { - Some(self.context.clone()) - } - fn execute(&self, sql: &str, schema: SchemaRef) -> Result { - let conn = 
self.conn.clone(); - let query: CXQuery = sql.into(); - - let mut dst = block_on(task::spawn_blocking(move || -> Result<_, _> { - get_arrow(&conn, None, &[query.clone()]).map_err(cx_out_error_to_df) - })) - .map_err(|err| DataFusionError::External(err.to_string().into()))??; - let stream = if let Some(batch) = dst.record_batch().map_err(cx_dst_error_to_df)? { - futures::stream::once(async move { Ok(batch) }) - } else { - return Ok(Box::pin(EmptyRecordBatchStream::new(Arc::new( - Schema::empty(), - )))); - }; - - Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) - } - - async fn table_names(&self) -> Result> { - Err(DataFusionError::NotImplemented( - "connector_x source: table inference not implemented".to_string(), - )) - } - - async fn get_table_schema(&self, table_name: &str) -> Result { - let conn = self.conn.clone(); - let query: CXQuery = format!("select * from {table_name} limit 1") - .as_str() - .into(); - - let dst = get_arrow(&conn, None, &[query.clone()]).map_err(cx_out_error_to_df)?; - let schema = schema_to_lowercase(dst.arrow_schema()); - Ok(schema) - } - - fn dialect(&self) -> Arc { - match &self.conn.ty { - SourceType::Postgres => Arc::new(PostgreSqlDialect {}), - SourceType::SQLite => Arc::new(SqliteDialect {}), - _ => Arc::new(DefaultDialect {}), - } - } -} - -fn cx_dst_error_to_df(err: ArrowDestinationError) -> DataFusionError { - DataFusionError::External(format!("ConnectorX failed to run query: {err:?}").into()) -} - -/// Get the schema with lowercase field names -fn schema_to_lowercase(schema: SchemaRef) -> SchemaRef { - // DF needs lower case schema - let lower_fields: Vec<_> = schema - .fields - .iter() - .map(|f| { - Field::new( - f.name().to_ascii_lowercase(), - f.data_type().clone(), - f.is_nullable(), - ) - }) - .collect(); - - Arc::new(Schema::new(lower_fields)) -} - -fn cx_out_error_to_df(err: ConnectorXOutError) -> DataFusionError { - DataFusionError::External(format!("ConnectorX failed to run query: {err:?}").into()) -} diff --git a/sources/sql/src/connectorx/mod.rs b/sources/sql/src/connectorx/mod.rs deleted file mode 100644 index 600069a..0000000 --- a/sources/sql/src/connectorx/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -mod executor; -pub use executor::*; diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index 2ed0113..92c301c 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -35,8 +35,6 @@ use datafusion_federation::{ mod schema; pub use schema::*; -#[cfg(feature = "connectorx")] -pub mod connectorx; mod executor; pub use executor::*; @@ -301,9 +299,7 @@ fn rewrite_column_name_in_expr( } // Find the first occurrence of table_ref_str starting from start_pos - let Some(idx) = col_name[start_pos..].find(table_ref_str) else { - return None; - }; + let idx = col_name[start_pos..].find(table_ref_str)?; // Calculate the absolute index of the occurrence in string as the index above is relative to start_pos let idx = start_pos + idx; @@ -1010,8 +1006,8 @@ mod tests { ), // different tables in single aggregation expression ( - "SELECT COUNT(CASE WHEN app_table.a > 0 THEN app_table.a ELSE foo.df_table.a END) FROM app_table, foo.df_table", - r#"SELECT count(CASE WHEN (remote_table.a > 0) THEN remote_table.a ELSE remote_table.a END) FROM remote_table JOIN remote_table ON true"#, + "SELECT COUNT(CASE WHEN appt.a > 0 THEN appt.a ELSE dft.a END) FROM app_table as appt, foo.df_table as dft", + "SELECT count(CASE WHEN (appt.a > 0) THEN appt.a ELSE dft.a END) FROM remote_table AS appt JOIN remote_table AS dft" ), ]; @@ -1083,12 +1079,12 @@ 
mod tests { ) -> Result<(), datafusion::error::DataFusionError> { let data_frame = ctx.sql(sql_query).await?; - println!("before optimization: \n{:#?}", data_frame.logical_plan()); + // println!("before optimization: \n{:#?}", data_frame.logical_plan()); let mut known_rewrites = HashMap::new(); let rewritten_plan = rewrite_table_scans(data_frame.logical_plan(), &mut known_rewrites)?; - println!("rewritten_plan: \n{:#?}", rewritten_plan); + // println!("rewritten_plan: \n{:#?}", rewritten_plan); let unparsed_sql = plan_to_sql(&rewritten_plan)?; From ce91d5a6479465ba349f95e3775653e7c27a9185 Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Fri, 27 Dec 2024 08:10:21 -0800 Subject: [PATCH 46/48] fix: rewrite shouldn't be performed on a column name same as the table name (#33) * Rewrite shouldn't be performed on a column name same as the table name * fix comment * Add unit test --- sources/sql/src/lib.rs | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index 92c301c..c8b9132 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -304,6 +304,12 @@ fn rewrite_column_name_in_expr( // Calculate the absolute index of the occurrence in string as the index above is relative to start_pos let idx = start_pos + idx; + // Table name same as column name + // Shouldn't rewrite in this case + if idx == 0 && start_pos == 0 { + return None; + } + if idx > 0 { // Check if the previous character is alphabetic, numeric, underscore or period, in which case we // should not rewrite as it is a part of another name. @@ -1072,6 +1078,23 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_rewrite_same_column_table_name() -> Result<()> { + init_tracing(); + let ctx = get_test_df_context(); + + let tests = vec![( + "SELECT app_table FROM (SELECT a app_table from app_table limit 100);", + r#"SELECT app_table FROM (SELECT remote_table.a AS app_table FROM remote_table LIMIT 100)"#, + )]; + + for test in tests { + test_sql(&ctx, test.0, test.1).await?; + } + + Ok(()) + } + async fn test_sql( ctx: &SessionContext, sql_query: &str, From 4c0ff795deeddcd728b4f973c27f002468e5dd9f Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Mon, 30 Dec 2024 12:11:04 -0600 Subject: [PATCH 47/48] fix: don't rewrite table names inside subquery when engine doesn't support it (#29) * Don't rewrite table names inside subquery when engine doesn't support it * resolve merge conflicts * cleanup and update condition for detecting subquery * Add test & cleanup * fix: make clippy happy * Define subquery_use_partial_path as sqlexecutor trait --- sources/sql/src/executor.rs | 5 + sources/sql/src/lib.rs | 588 +++++++++++++++++++++++++++++++----- 2 files changed, 517 insertions(+), 76 deletions(-) diff --git a/sources/sql/src/executor.rs b/sources/sql/src/executor.rs index 275bc68..0922ce1 100644 --- a/sources/sql/src/executor.rs +++ b/sources/sql/src/executor.rs @@ -35,6 +35,11 @@ pub trait SQLExecutor: Sync + Send { async fn table_names(&self) -> Result>; /// Returns the schema of table_name within this SQLExecutor async fn get_table_schema(&self, table_name: &str) -> Result; + + /// Returns whether the executor requires partial table path in subquery + fn subquery_use_partial_path(&self) -> bool { + false + } } impl fmt::Debug for dyn SQLExecutor { diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index c8b9132..756acea 100644 --- a/sources/sql/src/lib.rs +++ 
b/sources/sql/src/lib.rs @@ -1,5 +1,10 @@ use core::fmt; -use std::{any::Any, collections::HashMap, sync::Arc, vec}; +use std::{ + any::Any, + collections::{HashMap, HashSet}, + sync::Arc, + vec, +}; use async_trait::async_trait; use datafusion::{ @@ -109,6 +114,8 @@ impl AnalyzerRule for SQLFederationAnalyzerRule { fn rewrite_table_scans( plan: &LogicalPlan, known_rewrites: &mut HashMap, + subquery_uses_partial_path: bool, + subquery_table_scans: &mut Option>, ) -> Result { if plan.inputs().is_empty() { if let LogicalPlan::TableScan(table_scan) = plan { @@ -123,7 +130,11 @@ fn rewrite_table_scans( match federated_source.as_any().downcast_ref::() { Some(sql_table_source) => { let remote_table_name = TableReference::from(sql_table_source.table_name()); - known_rewrites.insert(original_table_name, remote_table_name.clone()); + known_rewrites.insert(original_table_name.clone(), remote_table_name.clone()); + + if let Some(s) = subquery_table_scans { + s.insert(original_table_name); + } // Rewrite the schema of this node to have the remote table as the qualifier. let new_schema = (*new_table_scan.projected_schema) @@ -147,19 +158,37 @@ fn rewrite_table_scans( let rewritten_inputs = plan .inputs() .into_iter() - .map(|i| rewrite_table_scans(i, known_rewrites)) + .map(|i| { + rewrite_table_scans( + i, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) .collect::>>()?; match plan { LogicalPlan::Unnest(unnest) => { // The Union plan cannot be constructed from rewritten expressions. It requires specialized logic to handle // the renaming in UNNEST columns and the corresponding column aliases in the underlying projection plan. - rewrite_unnest_plan(unnest, rewritten_inputs, known_rewrites) + rewrite_unnest_plan( + unnest, + rewritten_inputs, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) } _ => { let mut new_expressions = vec![]; for expression in plan.expressions() { - let new_expr = rewrite_table_scans_in_expr(expression.clone(), known_rewrites)?; + let new_expr = rewrite_table_scans_in_expr( + expression.clone(), + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; new_expressions.push(new_expr); } let new_plan = plan.with_new_exprs(new_expressions, rewritten_inputs)?; @@ -177,6 +206,8 @@ fn rewrite_unnest_plan( unnest: &logical_expr::Unnest, mut rewritten_inputs: Vec, known_rewrites: &mut HashMap, + subquery_uses_partial_path: bool, + subquery_table_scans: &mut Option>, ) -> Result { // Unnest plan has a single input let input = rewritten_inputs.remove(0); @@ -188,7 +219,12 @@ fn rewrite_unnest_plan( .exec_columns .iter() .map(|c: &Column| { - match rewrite_table_scans_in_expr(Expr::Column(c.clone()), known_rewrites)? { + match rewrite_table_scans_in_expr( + Expr::Column(c.clone()), + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )? 
{ Expr::Column(column) => { known_unnest_rewrites.insert(c.name.clone(), column.name.clone()); Ok(column) @@ -225,7 +261,8 @@ fn rewrite_unnest_plan( let updated_unnest_inner_projection = Projection::try_new(new_expressions, Arc::clone(&projection.input))?; - let unnest_options = rewrite_unnest_options(&unnest.options, known_rewrites); + let unnest_options = + rewrite_unnest_options(&unnest.options, known_rewrites, subquery_table_scans); // reconstruct the unnest plan with updated projection and rewritten column names let new_plan = @@ -241,17 +278,22 @@ fn rewrite_unnest_plan( fn rewrite_unnest_options( options: &UnnestOptions, known_rewrites: &HashMap, + subquery_table_scans: &mut Option>, ) -> UnnestOptions { let mut new_options = options.clone(); new_options .recursions .iter_mut() .for_each(|x: &mut RecursionUnnestOption| { - if let Some(new_name) = rewrite_column_name(&x.input_column.name, known_rewrites) { + if let Some(new_name) = + rewrite_column_name(&x.input_column.name, known_rewrites, subquery_table_scans) + { x.input_column.name = new_name; } - if let Some(new_name) = rewrite_column_name(&x.output_column.name, known_rewrites) { + if let Some(new_name) = + rewrite_column_name(&x.output_column.name, known_rewrites, subquery_table_scans) + { x.output_column.name = new_name; } }); @@ -264,17 +306,22 @@ fn rewrite_unnest_options( fn rewrite_column_name( col_name: &str, known_rewrites: &HashMap, + subquery_table_scans: &mut Option>, ) -> Option { let (new_col_name, was_rewritten) = known_rewrites.iter().fold( (col_name.to_string(), false), - |(col_name, was_rewritten), (table_ref, rewrite)| match rewrite_column_name_in_expr( - &col_name, - &table_ref.to_string(), - &rewrite.to_string(), - 0, - ) { - Some(new_name) => (new_name, true), - None => (col_name, was_rewritten), + |(col_name, was_rewritten), (table_ref, rewrite)| { + let mut rewrite_string = rewrite.to_string(); + if let Some(subquery_reference) = subquery_table_scans { + if subquery_reference.get(table_ref).is_some() { + rewrite_string = get_partial_table_name(rewrite); + } + } + match rewrite_column_name_in_expr(&col_name, &table_ref.to_string(), &rewrite_string, 0) + { + Some(new_name) => (new_name, true), + None => (col_name, was_rewritten), + } }, ); @@ -285,6 +332,12 @@ fn rewrite_column_name( } } +fn get_partial_table_name(full_table_reference: &TableReference) -> String { + let full_table_path = full_table_reference.table().to_owned(); + let path_parts: Vec<&str> = full_table_path.split('.').collect(); + path_parts[path_parts.len() - 1].to_owned() +} + // The function replaces occurrences of table_ref_str in col_name with the new name defined by rewrite. // The name to rewrite should NOT be a substring of another name. // Supports multiple occurrences of table_ref_str in col_name. @@ -362,14 +415,38 @@ fn rewrite_column_name_in_expr( fn rewrite_table_scans_in_expr( expr: Expr, known_rewrites: &mut HashMap, + subquery_uses_partial_path: bool, + subquery_table_scans: &mut Option>, ) -> Result { match expr { Expr::ScalarSubquery(subquery) => { - let new_subquery = rewrite_table_scans(&subquery.subquery, known_rewrites)?; + let new_subquery = if subquery_table_scans.is_some() || !subquery_uses_partial_path { + rewrite_table_scans( + &subquery.subquery, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )? + } else { + let mut scans = Some(HashSet::new()); + rewrite_table_scans( + &subquery.subquery, + known_rewrites, + subquery_uses_partial_path, + &mut scans, + )? 
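// A minimal std-only sketch (hypothetical helper name) of the path reduction that
// get_partial_table_name performs when SQLExecutor::subquery_use_partial_path()
// returns true: inside a subquery, the fully qualified remote table name is reduced
// to its last path segment, for engines that do not accept fully qualified names
// in subqueries.
fn partial_table_name(full_path: &str) -> &str {
    full_path.rsplit('.').next().unwrap_or(full_path)
}

fn main() {
    assert_eq!(
        partial_table_name("remote_db.remote_schema.remote_table"),
        "remote_table"
    );
    // a bare table name is returned unchanged
    assert_eq!(partial_table_name("remote_table"), "remote_table");
}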
+ }; let outer_ref_columns = subquery .outer_ref_columns .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .map(|e| { + rewrite_table_scans_in_expr( + e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) .collect::>>()?; Ok(Expr::ScalarSubquery(Subquery { subquery: Arc::new(new_subquery), @@ -377,8 +454,18 @@ fn rewrite_table_scans_in_expr( })) } Expr::BinaryExpr(binary_expr) => { - let left = rewrite_table_scans_in_expr(*binary_expr.left, known_rewrites)?; - let right = rewrite_table_scans_in_expr(*binary_expr.right, known_rewrites)?; + let left = rewrite_table_scans_in_expr( + *binary_expr.left, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + let right = rewrite_table_scans_in_expr( + *binary_expr.right, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; Ok(Expr::BinaryExpr(BinaryExpr::new( Box::new(left), binary_expr.op, @@ -387,6 +474,25 @@ fn rewrite_table_scans_in_expr( } Expr::Column(mut col) => { if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) { + if let Some(subquery_reference) = subquery_table_scans { + if col + .relation + .as_ref() + .and_then(|r| subquery_reference.get(r)) + .is_some() + { + // Use the partial table path from source for rewrite + // e.g. If the fully qualified name is foo_db.foo_schema.foo + // Use foo as partial path + let partial_path = get_partial_table_name(rewrite); + let partial_table_reference = TableReference::from(partial_path); + + return Ok(Expr::Column(Column::new( + Some(partial_table_reference), + &col.name, + ))); + } + } Ok(Expr::Column(Column::new(Some(rewrite.clone()), &col.name))) } else { // This prevent over-eager rewrite and only pass the column into below rewritten @@ -397,7 +503,9 @@ fn rewrite_table_scans_in_expr( // Check if any of the rewrites match any substring in col.name, and replace that part of the string if so. 
// This will handles cases like "MAX(foo.df_table.a)" -> "MAX(remote_table.a)" - if let Some(new_name) = rewrite_column_name(&col.name, known_rewrites) { + if let Some(new_name) = + rewrite_column_name(&col.name, known_rewrites, subquery_table_scans) + { Ok(Expr::Column(Column::new(col.relation.take(), new_name))) } else { Ok(Expr::Column(col)) @@ -405,7 +513,12 @@ fn rewrite_table_scans_in_expr( } } Expr::Alias(alias) => { - let expr = rewrite_table_scans_in_expr(*alias.expr, known_rewrites)?; + let expr = rewrite_table_scans_in_expr( + *alias.expr, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; if let Some(relation) = &alias.relation { if let Some(rewrite) = known_rewrites.get(relation) { return Ok(Expr::Alias(Alias::new( @@ -418,8 +531,18 @@ fn rewrite_table_scans_in_expr( Ok(Expr::Alias(Alias::new(expr, alias.relation, alias.name))) } Expr::Like(like) => { - let expr = rewrite_table_scans_in_expr(*like.expr, known_rewrites)?; - let pattern = rewrite_table_scans_in_expr(*like.pattern, known_rewrites)?; + let expr = rewrite_table_scans_in_expr( + *like.expr, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + let pattern = rewrite_table_scans_in_expr( + *like.pattern, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; Ok(Expr::Like(Like::new( like.negated, Box::new(expr), @@ -429,8 +552,18 @@ fn rewrite_table_scans_in_expr( ))) } Expr::SimilarTo(similar_to) => { - let expr = rewrite_table_scans_in_expr(*similar_to.expr, known_rewrites)?; - let pattern = rewrite_table_scans_in_expr(*similar_to.pattern, known_rewrites)?; + let expr = rewrite_table_scans_in_expr( + *similar_to.expr, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + let pattern = rewrite_table_scans_in_expr( + *similar_to.pattern, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; Ok(Expr::SimilarTo(Like::new( similar_to.negated, Box::new(expr), @@ -440,49 +573,114 @@ fn rewrite_table_scans_in_expr( ))) } Expr::Not(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + let expr = rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; Ok(Expr::Not(Box::new(expr))) } Expr::IsNotNull(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + let expr = rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; Ok(Expr::IsNotNull(Box::new(expr))) } Expr::IsNull(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + let expr = rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; Ok(Expr::IsNull(Box::new(expr))) } Expr::IsTrue(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + let expr = rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; Ok(Expr::IsTrue(Box::new(expr))) } Expr::IsFalse(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + let expr = rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; Ok(Expr::IsFalse(Box::new(expr))) } Expr::IsUnknown(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + let expr = rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; Ok(Expr::IsUnknown(Box::new(expr))) } Expr::IsNotTrue(e) => { - let 
expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + let expr = rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; Ok(Expr::IsNotTrue(Box::new(expr))) } Expr::IsNotFalse(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + let expr = rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; Ok(Expr::IsNotFalse(Box::new(expr))) } Expr::IsNotUnknown(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + let expr = rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; Ok(Expr::IsNotUnknown(Box::new(expr))) } Expr::Negative(e) => { - let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + let expr = rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; Ok(Expr::Negative(Box::new(expr))) } Expr::Between(between) => { - let expr = rewrite_table_scans_in_expr(*between.expr, known_rewrites)?; - let low = rewrite_table_scans_in_expr(*between.low, known_rewrites)?; - let high = rewrite_table_scans_in_expr(*between.high, known_rewrites)?; + let expr = rewrite_table_scans_in_expr( + *between.expr, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + let low = rewrite_table_scans_in_expr( + *between.low, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + let high = rewrite_table_scans_in_expr( + *between.high, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; Ok(Expr::Between(Between::new( Box::new(expr), between.negated, @@ -493,20 +691,44 @@ fn rewrite_table_scans_in_expr( Expr::Case(case) => { let expr = case .expr - .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) + .map(|e| { + rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) .transpose()? .map(Box::new); let else_expr = case .else_expr - .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) + .map(|e| { + rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) .transpose()? 
.map(Box::new); let when_expr = case .when_then_expr .into_iter() .map(|(when, then)| { - let when = rewrite_table_scans_in_expr(*when, known_rewrites); - let then = rewrite_table_scans_in_expr(*then, known_rewrites); + let when = rewrite_table_scans_in_expr( + *when, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ); + let then = rewrite_table_scans_in_expr( + *then, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ); match (when, then) { (Ok(when), Ok(then)) => Ok((Box::new(when), Box::new(then))), @@ -517,11 +739,21 @@ fn rewrite_table_scans_in_expr( Ok(Expr::Case(Case::new(expr, when_expr, else_expr))) } Expr::Cast(cast) => { - let expr = rewrite_table_scans_in_expr(*cast.expr, known_rewrites)?; + let expr = rewrite_table_scans_in_expr( + *cast.expr, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; Ok(Expr::Cast(Cast::new(Box::new(expr), cast.data_type))) } Expr::TryCast(try_cast) => { - let expr = rewrite_table_scans_in_expr(*try_cast.expr, known_rewrites)?; + let expr = rewrite_table_scans_in_expr( + *try_cast.expr, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; Ok(Expr::TryCast(TryCast::new( Box::new(expr), try_cast.data_type, @@ -531,7 +763,14 @@ fn rewrite_table_scans_in_expr( let args = sf .args .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .map(|e| { + rewrite_table_scans_in_expr( + e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) .collect::>>()?; Ok(Expr::ScalarFunction(ScalarFunction { func: sf.func, @@ -542,11 +781,25 @@ fn rewrite_table_scans_in_expr( let args = af .args .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .map(|e| { + rewrite_table_scans_in_expr( + e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) .collect::>>()?; let filter = af .filter - .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) + .map(|e| { + rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) .transpose()? 
.map(Box::new); let order_by = af @@ -554,8 +807,13 @@ fn rewrite_table_scans_in_expr( .map(|e| { e.into_iter() .map(|s| { - rewrite_table_scans_in_expr(s.expr, known_rewrites) - .map(|e| Sort::new(e, s.asc, s.nulls_first)) + rewrite_table_scans_in_expr( + s.expr, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + .map(|e| Sort::new(e, s.asc, s.nulls_first)) }) .collect::>>() }) @@ -573,19 +831,38 @@ fn rewrite_table_scans_in_expr( let args = wf .args .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .map(|e| { + rewrite_table_scans_in_expr( + e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) .collect::>>()?; let partition_by = wf .partition_by .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .map(|e| { + rewrite_table_scans_in_expr( + e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) .collect::>>()?; let order_by = wf .order_by .into_iter() .map(|s| { - rewrite_table_scans_in_expr(s.expr, known_rewrites) - .map(|e| Sort::new(e, s.asc, s.nulls_first)) + rewrite_table_scans_in_expr( + s.expr, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + .map(|e| Sort::new(e, s.asc, s.nulls_first)) }) .collect::>>()?; Ok(Expr::WindowFunction(WindowFunction { @@ -598,21 +875,56 @@ fn rewrite_table_scans_in_expr( })) } Expr::InList(il) => { - let expr = rewrite_table_scans_in_expr(*il.expr, known_rewrites)?; + let expr = rewrite_table_scans_in_expr( + *il.expr, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; let list = il .list .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .map(|e| { + rewrite_table_scans_in_expr( + e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) .collect::>>()?; Ok(Expr::InList(InList::new(Box::new(expr), list, il.negated))) } Expr::Exists(exists) => { - let subquery_plan = rewrite_table_scans(&exists.subquery.subquery, known_rewrites)?; + let subquery_plan = if subquery_table_scans.is_some() || !subquery_uses_partial_path { + rewrite_table_scans( + &exists.subquery.subquery, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )? + } else { + let mut scans = Some(HashSet::new()); + rewrite_table_scans( + &exists.subquery.subquery, + known_rewrites, + subquery_uses_partial_path, + &mut scans, + )? + }; + let outer_ref_columns = exists .subquery .outer_ref_columns .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .map(|e| { + rewrite_table_scans_in_expr( + e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) .collect::>>()?; let subquery = Subquery { subquery: Arc::new(subquery_plan), @@ -621,13 +933,40 @@ fn rewrite_table_scans_in_expr( Ok(Expr::Exists(Exists::new(subquery, exists.negated))) } Expr::InSubquery(is) => { - let expr = rewrite_table_scans_in_expr(*is.expr, known_rewrites)?; - let subquery_plan = rewrite_table_scans(&is.subquery.subquery, known_rewrites)?; + let expr = rewrite_table_scans_in_expr( + *is.expr, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + let subquery_plan = if subquery_table_scans.is_some() || !subquery_uses_partial_path { + rewrite_table_scans( + &is.subquery.subquery, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )? 
+ } else { + let mut scans = Some(HashSet::new()); + rewrite_table_scans( + &is.subquery.subquery, + known_rewrites, + subquery_uses_partial_path, + &mut scans, + )? + }; let outer_ref_columns = is .subquery .outer_ref_columns .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .map(|e| { + rewrite_table_scans_in_expr( + e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) .collect::>>()?; let subquery = Subquery { subquery: Arc::new(subquery_plan), @@ -653,14 +992,28 @@ fn rewrite_table_scans_in_expr( GroupingSet::Rollup(exprs) => { let exprs = exprs .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .map(|e| { + rewrite_table_scans_in_expr( + e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) .collect::>>()?; Ok(Expr::GroupingSet(GroupingSet::Rollup(exprs))) } GroupingSet::Cube(exprs) => { let exprs = exprs .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .map(|e| { + rewrite_table_scans_in_expr( + e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) .collect::>>()?; Ok(Expr::GroupingSet(GroupingSet::Cube(exprs))) } @@ -670,7 +1023,14 @@ fn rewrite_table_scans_in_expr( .map(|exprs| { exprs .into_iter() - .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .map(|e| { + rewrite_table_scans_in_expr( + e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) .collect::>>() }) .collect::>>>()?; @@ -688,7 +1048,12 @@ fn rewrite_table_scans_in_expr( } } Expr::Unnest(unnest) => { - let expr = rewrite_table_scans_in_expr(*unnest.expr, known_rewrites)?; + let expr = rewrite_table_scans_in_expr( + *unnest.expr, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; Ok(Expr::Unnest(Unnest::new(expr))) } Expr::ScalarVariable(_, _) | Expr::Literal(_) | Expr::Placeholder(_) => Ok(expr), @@ -752,8 +1117,15 @@ impl VirtualExecutionPlan { fn sql(&self) -> Result { // Find all table scans, recover the SQLTableSource, find the remote table name and replace the name of the TableScan table. 
let mut known_rewrites = HashMap::new(); - let mut ast = Unparser::new(self.executor.dialect().as_ref()) - .plan_to_sql(&rewrite_table_scans(&self.plan, &mut known_rewrites)?)?; + let subquery_uses_partial_path = self.executor.subquery_use_partial_path(); + let rewritten_plan = rewrite_table_scans( + &self.plan, + &mut known_rewrites, + subquery_uses_partial_path, + &mut None, + )?; + let mut ast = + Unparser::new(self.executor.dialect().as_ref()).plan_to_sql(&rewritten_plan)?; if let Some(analyzer) = self.executor.ast_analyzer() { ast = analyzer(ast)?; @@ -897,6 +1269,31 @@ mod tests { Arc::new(FederatedTableProviderAdaptor::new(table_source)) } + fn get_test_table_provider_with_full_path() -> Arc { + let sql_federation_provider = + Arc::new(SQLFederationProvider::new(Arc::new(TestSQLExecutor {}))); + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Date32, false), + Field::new( + "d", + DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + false, + ), + ])); + let table_source = Arc::new( + SQLTableSource::new_with_schema( + sql_federation_provider, + "remote_db.remote_schema.remote_table".to_string(), + schema, + ) + .expect("to have a valid SQLTableSource"), + ); + Arc::new(FederatedTableProviderAdaptor::new(table_source)) + } + fn get_test_table_source() -> Arc { Arc::new(DefaultTableSource::new(get_test_table_provider())) } @@ -921,6 +1318,13 @@ mod tests { .register_table("app_table".to_string(), get_test_table_provider()) .expect("to register table"); + let public_schema = catalog + .schema("public") + .expect("public schema should exist"); + public_schema + .register_table("bar".to_string(), get_test_table_provider_with_full_path()) + .expect("to register table"); + ctx } @@ -935,7 +1339,8 @@ mod tests { ])?; let mut known_rewrites = HashMap::new(); - let rewritten_plan = rewrite_table_scans(&plan.build()?, &mut known_rewrites)?; + let rewritten_plan = + rewrite_table_scans(&plan.build()?, &mut known_rewrites, false, &mut None)?; println!("rewritten_plan: \n{:#?}", rewritten_plan); @@ -1018,7 +1423,7 @@ mod tests { ]; for test in agg_tests { - test_sql(&ctx, test.0, test.1).await?; + test_sql(&ctx, test.0, test.1, false).await?; } Ok(()) @@ -1045,7 +1450,7 @@ mod tests { ]; for test in tests { - test_sql(&ctx, test.0, test.1).await?; + test_sql(&ctx, test.0, test.1, false).await?; } Ok(()) @@ -1072,7 +1477,32 @@ mod tests { ]; for test in tests { - test_sql(&ctx, test.0, test.1).await?; + test_sql(&ctx, test.0, test.1, false).await?; + } + + Ok(()) + } + + #[tokio::test] + async fn test_subquery_requires_partial_path() -> Result<()> { + init_tracing(); + let ctx = get_test_df_context(); + + let tests = vec![ + ( + "SELECT a FROM bar where a IN (SELECT a FROM bar)", + r#"SELECT remote_db.remote_schema.remote_table.a FROM remote_db.remote_schema.remote_table WHERE remote_db.remote_schema.remote_table.a IN (SELECT a FROM remote_db.remote_schema.remote_table)"#, + true, + ), + ( + "SELECT a FROM bar where a IN (SELECT a FROM bar)", + r#"SELECT remote_db.remote_schema.remote_table.a FROM remote_db.remote_schema.remote_table WHERE remote_db.remote_schema.remote_table.a IN (SELECT remote_db.remote_schema.remote_table.a FROM remote_db.remote_schema.remote_table)"#, + false, + ), + ]; + + for test in tests { + test_sql(&ctx, test.0, test.1, test.2).await?; } Ok(()) @@ -1089,7 +1519,7 @@ mod tests { )]; for test in tests { - test_sql(&ctx, test.0, test.1).await?; + 
test_sql(&ctx, test.0, test.1, false).await?;
         }
 
         Ok(())
     }
@@ -1099,15 +1529,21 @@ async fn test_sql(
         ctx: &SessionContext,
         sql_query: &str,
         expected_sql: &str,
+        subquery_uses_partial_path: bool,
     ) -> Result<(), datafusion::error::DataFusionError> {
         let data_frame = ctx.sql(sql_query).await?;
 
-        // println!("before optimization: \n{:#?}", data_frame.logical_plan());
+        println!("before optimization: \n{:#?}", data_frame.logical_plan());
 
         let mut known_rewrites = HashMap::new();
-        let rewritten_plan = rewrite_table_scans(data_frame.logical_plan(), &mut known_rewrites)?;
+        let rewritten_plan = rewrite_table_scans(
+            data_frame.logical_plan(),
+            &mut known_rewrites,
+            subquery_uses_partial_path,
+            &mut None,
+        )?;
 
-        // println!("rewritten_plan: \n{:#?}", rewritten_plan);
+        println!("rewritten_plan: \n{:#?}", rewritten_plan);
 
         let unparsed_sql = plan_to_sql(&rewritten_plan)?;
 

From 58324710998466a2730da10a4818ae9157818d43 Mon Sep 17 00:00:00 2001
From: Evgenii Khramkov
Date: Thu, 2 Jan 2025 19:27:26 +0900
Subject: [PATCH 48/48] Handle LogicalPlan::Limit separately to preserve skip
 and offset in rewrite_table_scans

---
 sources/sql/src/lib.rs | 90 +++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 89 insertions(+), 1 deletion(-)

diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs
index 756acea..09a1fb3 100644
--- a/sources/sql/src/lib.rs
+++ b/sources/sql/src/lib.rs
@@ -19,7 +19,7 @@ use datafusion::{
             AggregateFunction, Alias, Exists, InList, InSubquery, ScalarFunction, Sort, Unnest,
             WindowFunction,
         },
-        Between, BinaryExpr, Case, Cast, Expr, Extension, GroupingSet, Like, LogicalPlan,
+        Between, BinaryExpr, Case, Cast, Expr, Extension, GroupingSet, Like, Limit, LogicalPlan,
         LogicalPlanBuilder, Projection, Subquery, TryCast,
     },
     optimizer::analyzer::{Analyzer, AnalyzerRule},
@@ -180,6 +180,44 @@ fn rewrite_table_scans(
                 subquery_table_scans,
             )
         }
+        LogicalPlan::Limit(limit) => {
+            let rewritten_skip = limit
+                .skip
+                .as_ref()
+                .map(|skip| {
+                    rewrite_table_scans_in_expr(
+                        *skip.clone(),
+                        known_rewrites,
+                        subquery_uses_partial_path,
+                        subquery_table_scans,
+                    )
+                    .map(Box::new)
+                })
+                .transpose()?;
+
+            let rewritten_fetch = limit
+                .fetch
+                .as_ref()
+                .map(|fetch| {
+                    rewrite_table_scans_in_expr(
+                        *fetch.clone(),
+                        known_rewrites,
+                        subquery_uses_partial_path,
+                        subquery_table_scans,
+                    )
+                    .map(Box::new)
+                })
+                .transpose()?;
+
+            // explicitly set fetch and skip
+            let new_plan = LogicalPlan::Limit(Limit {
+                skip: rewritten_skip,
+                fetch: rewritten_fetch,
+                input: Arc::new(rewritten_inputs[0].clone()),
+            });
+            Ok(new_plan)
+        }
+
         _ => {
             let mut new_expressions = vec![];
             for expression in plan.expressions() {
@@ -1558,4 +1596,54 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_rewrite_table_scans_limit_offset() -> Result<()> {
+        init_tracing();
+        let ctx = get_test_df_context();
+
+        let tests = vec![
+            // Basic LIMIT
+            (
+                "SELECT a FROM foo.df_table LIMIT 5",
+                r#"SELECT remote_table.a FROM remote_table LIMIT 5"#,
+            ),
+            // Basic OFFSET
+            (
+                "SELECT a FROM foo.df_table OFFSET 5",
+                r#"SELECT remote_table.a FROM remote_table OFFSET 5"#,
+            ),
+            // OFFSET after LIMIT
+            (
+                "SELECT a FROM foo.df_table LIMIT 10 OFFSET 5",
+                r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#,
+            ),
+            // LIMIT after OFFSET
+            (
+                "SELECT a FROM foo.df_table OFFSET 5 LIMIT 10",
+                r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#,
+            ),
+            // Zero OFFSET
+            (
+                "SELECT a FROM foo.df_table OFFSET 0",
+                r#"SELECT remote_table.a FROM remote_table OFFSET 0"#,
+            ),
+            // Zero LIMIT
+            (
+                "SELECT a FROM foo.df_table LIMIT 0",
+                r#"SELECT remote_table.a FROM remote_table LIMIT 0"#,
+            ),
+            // Zero LIMIT and OFFSET
+            (
+                "SELECT a FROM foo.df_table LIMIT 0 OFFSET 0",
+                r#"SELECT remote_table.a FROM remote_table LIMIT 0 OFFSET 0"#,
+            ),
+        ];
+
+        for test in tests {
+            test_sql(&ctx, test.0, test.1, false).await?;
+        }
+
+        Ok(())
+    }
 }