diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index fe66d27..92e1dad 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -10,4 +10,4 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-node@v4 - run: npm install prettier prettier-plugin-toml - - run: npx prettier --check --no-config . + - run: npx prettier --write --no-config . diff --git a/.gitignore b/.gitignore index 66b190e..c92a89f 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ /node_modules package-lock.json package.json +.DS_Store \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..6982f30 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,213 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'datafusion_federation'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=datafusion-federation" + ], + "filter": { + "name": "datafusion_federation", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug example 'sqlite'", + "cargo": { + "args": [ + "build", + "--example=sqlite", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "sqlite", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in example 'sqlite'", + "cargo": { + "args": [ + "test", + "--no-run", + "--example=sqlite", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "sqlite", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug example 'flight-sql'", + "cargo": { + "args": [ + "build", + "--example=flight-sql", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "flight-sql", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in example 'flight-sql'", + "cargo": { + "args": [ + "test", + "--no-run", + "--example=flight-sql", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "flight-sql", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug example 'postgres-partial'", + "cargo": { + "args": [ + "build", + "--example=postgres-partial", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "postgres-partial", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in example 'postgres-partial'", + "cargo": { + "args": [ + "test", + "--no-run", + "--example=postgres-partial", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "postgres-partial", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug example 'sqlite-partial'", + "cargo": { + "args": [ + "build", + "--example=sqlite-partial", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "sqlite-partial", + "kind": 
"example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in example 'sqlite-partial'", + "cargo": { + "args": [ + "test", + "--no-run", + "--example=sqlite-partial", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "sqlite-partial", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'datafusion_federation_flight_sql'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=datafusion-federation-flight-sql" + ], + "filter": { + "name": "datafusion_federation_flight_sql", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'datafusion_federation_sql'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=datafusion-federation-sql" + ], + "filter": { + "name": "datafusion_federation_sql", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + } + ] +} \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 744de0b..15d64d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,15 +11,20 @@ members = [ [patch.crates-io] # connectorx = { path = "../connector-x/connectorx" } # datafusion = { path = "../arrow-datafusion/datafusion/core" } +# Pending next Datafusion release with `unnest` unparsing support +datafusion = { git = "https://github.com/spiceai/datafusion.git", rev = "0bad328656a07fd5ab899186462b09e119e21f90" } [workspace.package] -version = "0.1.3" +version = "0.1.6" edition = "2021" -license = "MIT" +license = "Apache-2.0" readme = "README.md" [workspace.dependencies] async-trait = "0.1.77" -datafusion = "37.0.0" -datafusion-substrait = "37.0.0" +async-stream = "0.3.5" +futures = "0.3.30" +datafusion = "43" +datafusion-substrait = "43" +arrow-json = "53" diff --git a/datafusion-federation/Cargo.toml b/datafusion-federation/Cargo.toml index 83e2723..f344a71 100644 --- a/datafusion-federation/Cargo.toml +++ b/datafusion-federation/Cargo.toml @@ -13,6 +13,9 @@ path = "src/lib.rs" [dependencies] async-trait.workspace = true datafusion.workspace = true +async-stream.workspace = true +futures.workspace = true +arrow-json.workspace = true [package.metadata.docs.rs] diff --git a/datafusion-federation/src/analyzer.rs b/datafusion-federation/src/analyzer.rs deleted file mode 100644 index f5ba3a1..0000000 --- a/datafusion-federation/src/analyzer.rs +++ /dev/null @@ -1,170 +0,0 @@ -use std::sync::Arc; - -use datafusion::{ - config::ConfigOptions, - datasource::source_as_provider, - error::Result, - logical_expr::{Expr, LogicalPlan, Projection, TableScan, TableSource}, - optimizer::analyzer::AnalyzerRule, -}; - -use crate::{FederatedTableProviderAdaptor, FederatedTableSource, FederationProviderRef}; - -#[derive(Default)] -pub struct FederationAnalyzerRule {} - -impl AnalyzerRule for FederationAnalyzerRule { - // Walk over the plan, look for the largest subtrees that only have - // TableScans from the same FederationProvider. - // There 'largest sub-trees' are passed to their respective FederationProvider.optimizer. 
- fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result { - let (optimized, _) = self.optimize_recursively(&plan, None, config)?; - if let Some(result) = optimized { - return Ok(result); - } - Ok(plan.clone()) - } - - /// A human readable name for this optimizer rule - fn name(&self) -> &str { - "federation_optimizer_rule" - } -} - -impl FederationAnalyzerRule { - pub fn new() -> Self { - Self::default() - } - - // optimize_recursively recursively finds the largest sub-plans that can be federated - // to a single FederationProvider. - // Returns a plan if a sub-tree was federated, otherwise None. - // Returns a FederationProvider if it covers the entire sub-tree, otherwise None. - fn optimize_recursively( - &self, - plan: &LogicalPlan, - parent: Option<&LogicalPlan>, - _config: &ConfigOptions, - ) -> Result<(Option, Option)> { - // Check if this node determines the FederationProvider - let sole_provider = self.get_federation_provider(plan)?; - if sole_provider.is_some() { - return Ok((None, sole_provider)); - } - - // optimize_inputs - let inputs = plan.inputs(); - if inputs.is_empty() { - return Ok((None, None)); - } - - let (new_inputs, providers): (Vec<_>, Vec<_>) = inputs - .iter() - .map(|i| self.optimize_recursively(i, Some(plan), _config)) - .collect::>>()? - .into_iter() - .unzip(); - - // Note: assumes provider is None if ambiguous - let first_provider = providers.first().unwrap(); - let is_singular = providers.iter().all(|p| p.is_some() && p == first_provider); - - if is_singular { - if parent.is_none() { - // federate the entire plan - if let Some(provider) = first_provider { - if let Some(optimizer) = provider.analyzer() { - let optimized = optimizer.execute_and_check(plan, _config, |_, _| {})?; - return Ok((Some(optimized), None)); - } - return Ok((None, None)); - } - return Ok((None, None)); - } - // The largest sub-plan is higher up. - return Ok((None, first_provider.clone())); - } - - // The plan is ambiguous, any inputs that are not federated and - // have a sole provider, should be federated. - let new_inputs = new_inputs - .into_iter() - .enumerate() - .map(|(i, new_sub_plan)| { - if let Some(sub_plan) = new_sub_plan { - // Already federated - return Ok(sub_plan); - } - let sub_plan = inputs.get(i).unwrap(); - // Check if the input has a sole provider and can be federated. - if let Some(provider) = providers.get(i).unwrap() { - if let Some(optimizer) = provider.analyzer() { - let wrapped = wrap_projection((*sub_plan).clone())?; - - let optimized = - optimizer.execute_and_check(&wrapped, _config, |_, _| {})?; - return Ok(optimized); - } - // No federation for this sub-plan (no analyzer) - return Ok((*sub_plan).clone()); - } - // No federation for this sub-plan (no provider) - Ok((*sub_plan).clone()) - }) - .collect::>>()?; - - let new_plan = plan.with_new_exprs(plan.expressions(), new_inputs)?; - - Ok((Some(new_plan), None)) - } - - fn get_federation_provider(&self, plan: &LogicalPlan) -> Result> { - match plan { - LogicalPlan::TableScan(TableScan { ref source, .. }) => { - let Some(federated_source) = get_table_source(source)? 
else { - return Ok(None); - }; - let provider = federated_source.federation_provider(); - Ok(Some(provider)) - } - _ => Ok(None), - } - } -} - -fn wrap_projection(plan: LogicalPlan) -> Result { - // TODO: minimize requested columns - match plan { - LogicalPlan::Projection(_) => Ok(plan), - _ => { - let expr = plan - .schema() - .fields() - .iter() - .map(|f| Expr::Column(f.qualified_column())) - .collect::>(); - Ok(LogicalPlan::Projection(Projection::try_new( - expr, - Arc::new(plan), - )?)) - } - } -} - -pub fn get_table_source( - source: &Arc, -) -> Result>> { - // Unwrap TableSource - let source = source_as_provider(source)?; - - // Get FederatedTableProviderAdaptor - let Some(wrapper) = source - .as_any() - .downcast_ref::() - else { - return Ok(None); - }; - - // Return original FederatedTableSource - Ok(Some(Arc::clone(&wrapper.source))) -} diff --git a/datafusion-federation/src/analyzer/mod.rs b/datafusion-federation/src/analyzer/mod.rs new file mode 100644 index 0000000..74d954a --- /dev/null +++ b/datafusion-federation/src/analyzer/mod.rs @@ -0,0 +1,514 @@ +mod scan_result; + +use crate::FederationProvider; +use crate::{ + optimize::Optimizer, FederatedTableProviderAdaptor, FederatedTableSource, FederationProviderRef, +}; +use datafusion::error::DataFusionError; +use datafusion::logical_expr::{col, expr::InSubquery, LogicalPlanBuilder}; +use datafusion::{ + common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}, + config::ConfigOptions, + datasource::source_as_provider, + error::Result, + logical_expr::{Expr, Extension, LogicalPlan, Projection, TableScan, TableSource}, + optimizer::analyzer::AnalyzerRule, + sql::TableReference, +}; +use scan_result::ScanResult; +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::RwLock; + +#[derive(Debug)] +pub struct FederationAnalyzerRule { + optimizer: Optimizer, + provider_map: Arc>>, +} + +impl Default for FederationAnalyzerRule { + fn default() -> Self { + Self { + optimizer: Optimizer::default(), + provider_map: Arc::new(RwLock::new(HashMap::new())), + } + } +} + +impl AnalyzerRule for FederationAnalyzerRule { + // Walk over the plan, look for the largest subtrees that only have + // TableScans from the same FederationProvider. + // There 'largest sub-trees' are passed to their respective FederationProvider.optimizer. + fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result { + if !contains_federated_table(&plan)? { + return Ok(plan); + } + // Run selected optimizer rules before federation + let plan = self.optimizer.optimize_plan(plan)?; + + // Find all federation providers for TableReference appeared in the plan + let providers = get_plan_provider_recursively(&plan)?; + let mut write_map = self.provider_map.write().map_err(|_| { + DataFusionError::External( + "Failed to create federated plan: failed to find all federated providers.".into(), + ) + })?; + write_map.extend(providers); + drop(write_map); + + match self.optimize_plan_recursively(&plan, true, config)? { + (Some(optimized_plan), _) => Ok(optimized_plan), + (None, _) => Ok(plan), + } + } + + /// A human readable name for this optimizer rule + fn name(&self) -> &str { + "federation_optimizer_rule" + } +} + +impl FederationAnalyzerRule { + pub fn new() -> Self { + Self::default() + } + + /// Scans a plan to see if it belongs to a single [`FederationProvider`]. 
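+    ///
+    /// Illustrative example (hypothetical providers `sqlite` and `postgres`):
+    /// a plan whose `TableScan`s all resolve to `sqlite` scans as
+    /// `ScanResult::Distinct(sqlite)`, while a plan mixing `sqlite` and
+    /// `postgres` scans as `ScanResult::Ambiguous`.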
+ fn scan_plan_recursively(&self, plan: &LogicalPlan) -> Result { + let mut sole_provider: ScanResult = ScanResult::None; + + plan.apply(&mut |p: &LogicalPlan| -> Result { + let exprs_provider = self.scan_plan_exprs(p)?; + sole_provider.merge(exprs_provider); + + if sole_provider.is_ambiguous() { + return Ok(TreeNodeRecursion::Stop); + } + + let (sub_provider, _) = get_leaf_provider(p)?; + sole_provider.add(sub_provider); + + Ok(sole_provider.check_recursion()) + })?; + + Ok(sole_provider) + } + + /// Scans a plan's expressions to see if it belongs to a single [`FederationProvider`]. + fn scan_plan_exprs(&self, plan: &LogicalPlan) -> Result { + let mut sole_provider: ScanResult = ScanResult::None; + + let exprs = plan.expressions(); + for expr in &exprs { + let expr_result = self.scan_expr_recursively(expr)?; + sole_provider.merge(expr_result); + + if sole_provider.is_ambiguous() { + return Ok(sole_provider); + } + } + + Ok(sole_provider) + } + + /// scans an expression to see if it belongs to a single [`FederationProvider`] + fn scan_expr_recursively(&self, expr: &Expr) -> Result { + let mut sole_provider: ScanResult = ScanResult::None; + + expr.apply(&mut |e: &Expr| -> Result { + // TODO: Support other types of sub-queries + match e { + Expr::ScalarSubquery(ref subquery) => { + let plan_result = self.scan_plan_recursively(&subquery.subquery)?; + + sole_provider.merge(plan_result); + Ok(sole_provider.check_recursion()) + } + Expr::InSubquery(ref insubquery) => { + let plan_result = self.scan_plan_recursively(&insubquery.subquery.subquery)?; + + sole_provider.merge(plan_result); + Ok(sole_provider.check_recursion()) + } + Expr::OuterReferenceColumn(_, ref col) => { + if let Some(table) = &col.relation { + let map = self.provider_map.read().map_err(|_| { + DataFusionError::External( + "Failed to create federated plan: failed to obtain a read lock on federated providers.".into(), + ) + })?; + if let Some(plan_result) = map.get(table) { + sole_provider.merge(plan_result.clone()); + return Ok(sole_provider.check_recursion()); + } + } + // Subqueries that reference outer columns are not supported + // for now. We handle this here as ambiguity to force + // federation lower in the plan tree. + sole_provider = ScanResult::Ambiguous; + Ok(TreeNodeRecursion::Stop) + } + _ => Ok(TreeNodeRecursion::Continue), + } + })?; + + Ok(sole_provider) + } + + /// Recursively finds the largest sub-plans that can be federated + /// to a single FederationProvider. + /// + /// Returns a plan if a sub-tree was federated, otherwise None. + /// + /// Returns a ScanResult of all FederationProviders in the subtree. + fn optimize_plan_recursively( + &self, + plan: &LogicalPlan, + is_root: bool, + _config: &ConfigOptions, + ) -> Result<(Option, ScanResult)> { + let mut sole_provider: ScanResult = ScanResult::None; + + if let LogicalPlan::Extension(Extension { ref node }) = plan { + if node.name() == "Federated" { + // Avoid attempting double federation + return Ok((None, ScanResult::Ambiguous)); + } + } + + // Check if this plan node is a leaf that determines the FederationProvider + let (leaf_provider, _) = get_leaf_provider(plan)?; + + // Check if the expressions contain, a potentially different, FederationProvider + let exprs_result = self.scan_plan_exprs(plan)?; + + // Return early if this is a leaf and there is no ambiguity with the expressions. 
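+        // For example, a bare `TableScan` with no subquery expressions is not
+        // federated at this point; its provider is reported upward so that the
+        // largest single-provider sub-plan can be federated higher in the tree.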
+ if leaf_provider.is_some() && (exprs_result.is_none() || exprs_result == leaf_provider) { + return Ok((None, leaf_provider.into())); + } + // Aggregate leaf & expression providers + sole_provider.add(leaf_provider); + sole_provider.merge(exprs_result.clone()); + + let inputs = plan.inputs(); + // Return early if there are no sources. + if inputs.is_empty() && sole_provider.is_none() { + return Ok((None, ScanResult::None)); + } + + // Recursively optimize inputs + let input_results = inputs + .iter() + .map(|i| self.optimize_plan_recursively(i, false, _config)) + .collect::>>()?; + + // Aggregate the input providers + input_results.iter().for_each(|(_, scan_result)| { + sole_provider.merge(scan_result.clone()); + }); + + if sole_provider.is_none() { + // No providers found + // TODO: Is/should this be reachable? + return Ok((None, ScanResult::None)); + } + + // Federate Exprs when Exprs provider is ambiguous or Exprs provider differs from the sole_provider of current plan + // When Exprs provider is the same as sole_provider and non-ambiguous, the larger sub-plan is higher up + let optimize_expressions = exprs_result.is_some() + && (!(sole_provider == exprs_result) || exprs_result.is_ambiguous()); + + // If all sources are federated to the same provider + if let ScanResult::Distinct(provider) = sole_provider { + if !is_root { + // The largest sub-plan is higher up. + return Ok((None, ScanResult::Distinct(provider))); + } + + let Some(optimizer) = provider.analyzer() else { + // No optimizer provided + return Ok((None, ScanResult::None)); + }; + + // If this is the root plan node; federate the entire plan + let optimized = optimizer.execute_and_check(plan.clone(), _config, |_, _| {})?; + return Ok((Some(optimized), ScanResult::None)); + } + + // The plan is ambiguous; any input that is not yet optimized and has a + // sole provider represents a largest sub-plan and should be federated. + // + // We loop over the input optimization results, federate where needed and + // return a complete list of new inputs for the optimized plan. + let new_inputs = input_results + .into_iter() + .enumerate() + .map(|(i, (input_plan, input_result))| { + if let Some(federated_plan) = input_plan { + // Already federated deeper in the plan tree + return Ok(federated_plan); + } + + let original_input = (*inputs.get(i).unwrap()).clone(); + if input_result.is_ambiguous() { + // Can happen if the input is already federated, so use + // the original input. + return Ok(original_input); + } + + let provider = input_result.unwrap()?; + let Some(provider) = provider else { + // No provider for this input; use the original input. + return Ok(original_input); + }; + + let Some(optimizer) = provider.analyzer() else { + // No optimizer for this input; use the original input. + return Ok(original_input); + }; + + // Replace the input with the federated counterpart + let wrapped = wrap_projection(original_input)?; + let optimized = optimizer.execute_and_check(wrapped, _config, |_, _| {})?; + + Ok(optimized) + }) + .collect::>>()?; + + // Optimize expressions if needed + let new_expressions = if optimize_expressions { + self.optimize_plan_exprs(plan, _config)? 
+ } else { + plan.expressions() + }; + + // Construct the optimized plan + let new_plan = plan.with_new_exprs(new_expressions, new_inputs)?; + + // Return the federated plan + Ok((Some(new_plan), ScanResult::Ambiguous)) + } + + /// Optimizes all exprs of a plan + fn optimize_plan_exprs( + &self, + plan: &LogicalPlan, + _config: &ConfigOptions, + ) -> Result> { + plan.expressions() + .iter() + .map(|expr| { + let transformed = expr + .clone() + .transform(&|e| self.optimize_expr_recursively(e, _config))?; + Ok(transformed.data) + }) + .collect::>>() + } + + /// recursively optimize expressions + /// Current logic: individually federate every sub-query. + fn optimize_expr_recursively( + &self, + expr: Expr, + _config: &ConfigOptions, + ) -> Result> { + match expr { + Expr::ScalarSubquery(ref subquery) => { + // Optimize as root to force federating the sub-query + let (new_subquery, _) = + self.optimize_plan_recursively(&subquery.subquery, true, _config)?; + let Some(new_subquery) = new_subquery else { + return Ok(Transformed::no(expr)); + }; + + // ScalarSubqueryToJoin optimizer rule doesn't support federated node (LogicalPlan::Extension(_)) as subquery + // Wrap a `non-op` Projection LogicalPlan outside the federated node to facilitate ScalarSubqueryToJoin optimization + if matches!(new_subquery, LogicalPlan::Extension(_)) { + let all_columns = new_subquery + .schema() + .fields() + .iter() + .map(|field| col(field.name())) + .collect::>(); + + let projection_plan = LogicalPlanBuilder::from(new_subquery) + .project(all_columns)? + .build()?; + + return Ok(Transformed::yes(Expr::ScalarSubquery( + subquery.with_plan(projection_plan.into()), + ))); + } + + Ok(Transformed::yes(Expr::ScalarSubquery( + subquery.with_plan(new_subquery.into()), + ))) + } + Expr::InSubquery(ref in_subquery) => { + let (new_subquery, _) = + self.optimize_plan_recursively(&in_subquery.subquery.subquery, true, _config)?; + let Some(new_subquery) = new_subquery else { + return Ok(Transformed::no(expr)); + }; + + // DecorrelatePredicateSubquery optimizer rule doesn't support federated node (LogicalPlan::Extension(_)) as subquery + // Wrap a `non-op` Projection LogicalPlan outside the federated node to facilitate DecorrelatePredicateSubquery optimization + if matches!(new_subquery, LogicalPlan::Extension(_)) { + let all_columns = new_subquery + .schema() + .fields() + .iter() + .map(|field| col(field.name())) + .collect::>(); + + let projection_plan = LogicalPlanBuilder::from(new_subquery) + .project(all_columns)? + .build()?; + + return Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( + in_subquery.expr.clone(), + in_subquery.subquery.with_plan(projection_plan.into()), + in_subquery.negated, + )))); + } + + Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( + in_subquery.expr.clone(), + in_subquery.subquery.with_plan(new_subquery.into()), + in_subquery.negated, + )))) + } + _ => Ok(Transformed::no(expr)), + } + } +} + +/// NopFederationProvider is used to represent tables that are not federated, but +/// are resolved by DataFusion. This simplifies the logic of the optimizer rule. +struct NopFederationProvider {} + +impl FederationProvider for NopFederationProvider { + fn name(&self) -> &str { + "nop" + } + + fn compute_context(&self) -> Option { + None + } + + fn analyzer(&self) -> Option> { + None + } +} + +/// Recursively find the [`FederationProvider`] for all [`TableReference`] instances in the plan. +/// This information is used to resolve the federation provider for [`Expr::OuterReferenceColumn`]. 
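+///
+/// Illustrative example (hypothetical tables): for
+/// `SELECT * FROM orders o WHERE EXISTS (SELECT 1 FROM customers WHERE customers.id = o.cid)`
+/// the returned map associates `orders`, the alias `o`, and `customers` with their
+/// providers, so the outer reference `o.cid` inside the subquery can be resolved.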
+fn get_plan_provider_recursively( + plan: &LogicalPlan, +) -> Result> { + let mut providers: HashMap = HashMap::new(); + + plan.apply(&mut |p: &LogicalPlan| -> Result { + // LogicalPlan::SubqueryAlias can also be referred by OuterReferenceColumn + // Get the federation provider for TableReference representing LogicalPlan::SubqueryAlias + if let LogicalPlan::SubqueryAlias(a) = p { + let subquery_alias_providers = get_plan_provider_recursively(&Arc::clone(&a.input))?; + let mut provider: ScanResult = ScanResult::None; + for (_, i) in subquery_alias_providers { + provider.merge(i); + } + providers.insert(a.alias.clone(), provider); + } + + let (federation_provider, table_reference) = get_leaf_provider(p)?; + if let Some(table_reference) = table_reference { + providers.insert(table_reference, federation_provider.into()); + } + + let _ = p.apply_subqueries(|sub_query| { + let subquery_providers = get_plan_provider_recursively(sub_query)?; + providers.extend(subquery_providers); + Ok(TreeNodeRecursion::Continue) + }); + + Ok(TreeNodeRecursion::Continue) + })?; + + Ok(providers) +} + +fn wrap_projection(plan: LogicalPlan) -> Result { + // TODO: minimize requested columns + match plan { + LogicalPlan::Projection(_) => Ok(plan), + _ => { + let expr = plan + .schema() + .columns() + .iter() + .map(|c| Expr::Column(c.clone())) + .collect::>(); + Ok(LogicalPlan::Projection(Projection::try_new( + expr, + Arc::new(plan), + )?)) + } + } +} + +fn contains_federated_table(plan: &LogicalPlan) -> Result { + let federated_table_exists = plan.exists(|x| { + if let (Some(provider), _) = get_leaf_provider(x)? { + // federated table provider should have an analyzer + return Ok(provider.analyzer().is_some()); + } + Ok(false) + })?; + + Ok(federated_table_exists) +} + +fn get_leaf_provider( + plan: &LogicalPlan, +) -> Result<(Option, Option)> { + match plan { + LogicalPlan::TableScan(TableScan { + ref table_name, + ref source, + .. + }) => { + let table_reference = table_name.clone(); + let Some(federated_source) = get_table_source(source)? else { + // Table is not federated but provided by a standard table provider. + // We use a placeholder federation provider to simplify the logic. + return Ok(( + Some(Arc::new(NopFederationProvider {})), + Some(table_reference), + )); + }; + let provider = federated_source.federation_provider(); + Ok((Some(provider), Some(table_reference))) + } + _ => Ok((None, None)), + } +} + +#[allow(clippy::missing_errors_doc)] +pub fn get_table_source( + source: &Arc, +) -> Result>> { + // Unwrap TableSource + let source = source_as_provider(source)?; + + // Get FederatedTableProviderAdaptor + let Some(wrapper) = source + .as_any() + .downcast_ref::() + else { + return Ok(None); + }; + + // Return original FederatedTableSource + Ok(Some(Arc::clone(&wrapper.source))) +} diff --git a/datafusion-federation/src/analyzer/scan_result.rs b/datafusion-federation/src/analyzer/scan_result.rs new file mode 100644 index 0000000..dc01906 --- /dev/null +++ b/datafusion-federation/src/analyzer/scan_result.rs @@ -0,0 +1,113 @@ +use crate::FederationProviderRef; +use datafusion::common::tree_node::TreeNodeRecursion; +use datafusion::error::{DataFusionError, Result}; + +/// Used to track if all sources, including tableScan, plan inputs and +/// expressions, represents an un-ambiguous, none or a sole' [`crate::FederationProvider`]. 
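+///
+/// Merging is effectively a small lattice: `None` is the identity, two equal
+/// `Distinct` providers stay `Distinct`, two different providers collapse to
+/// `Ambiguous`, and `Ambiguous` absorbs everything. Illustrative sketch
+/// (hypothetical `provider_a`/`provider_b` refs):
+///
+/// ```rust,ignore
+/// let mut result = ScanResult::None;
+/// result.add(Some(provider_a.clone())); // -> Distinct(provider_a)
+/// result.add(Some(provider_a.clone())); // -> still Distinct(provider_a)
+/// result.add(Some(provider_b.clone())); // -> Ambiguous (providers differ)
+/// assert!(result.is_ambiguous());
+/// ```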
+pub enum ScanResult { + None, + Distinct(FederationProviderRef), + Ambiguous, +} + +impl ScanResult { + pub fn merge(&mut self, other: Self) { + match (&self, &other) { + (_, ScanResult::None) => {} + (ScanResult::None, _) => *self = other, + (ScanResult::Ambiguous, _) | (_, ScanResult::Ambiguous) => { + *self = ScanResult::Ambiguous; + } + (ScanResult::Distinct(provider), ScanResult::Distinct(other_provider)) => { + if provider != other_provider { + *self = ScanResult::Ambiguous; + } + } + } + } + + pub fn add(&mut self, provider: Option) { + self.merge(ScanResult::from(provider)) + } + + pub fn is_ambiguous(&self) -> bool { + matches!(self, ScanResult::Ambiguous) + } + + pub fn is_none(&self) -> bool { + matches!(self, ScanResult::None) + } + pub fn is_some(&self) -> bool { + !self.is_none() + } + + pub fn unwrap(self) -> Result> { + match self { + ScanResult::None => Ok(None), + ScanResult::Distinct(provider) => Ok(Some(provider)), + ScanResult::Ambiguous => Err(DataFusionError::External( + "called `ScanResult::unwrap()` on a `Ambiguous` value".into(), + )), + } + } + + pub fn check_recursion(&self) -> TreeNodeRecursion { + if self.is_ambiguous() { + TreeNodeRecursion::Stop + } else { + TreeNodeRecursion::Continue + } + } +} + +impl From> for ScanResult { + fn from(provider: Option) -> Self { + match provider { + Some(provider) => ScanResult::Distinct(provider), + None => ScanResult::None, + } + } +} + +impl PartialEq> for ScanResult { + fn eq(&self, other: &Option) -> bool { + match (self, other) { + (ScanResult::None, None) => true, + (ScanResult::Distinct(provider), Some(other_provider)) => provider == other_provider, + _ => false, + } + } +} + +impl PartialEq for ScanResult { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (ScanResult::None, ScanResult::None) => true, + (ScanResult::Distinct(provider1), ScanResult::Distinct(provider2)) => { + provider1 == provider2 + } + (ScanResult::Ambiguous, ScanResult::Ambiguous) => true, + _ => false, + } + } +} + +impl std::fmt::Debug for ScanResult { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::None => write!(f, "ScanResult::None"), + Self::Distinct(provider) => write!(f, "ScanResult::Distinct({})", provider.name()), + Self::Ambiguous => write!(f, "ScanResult::Ambiguous"), + } + } +} + +impl Clone for ScanResult { + fn clone(&self) -> Self { + match self { + ScanResult::None => ScanResult::None, + ScanResult::Distinct(provider) => ScanResult::Distinct(provider.clone()), + ScanResult::Ambiguous => ScanResult::Ambiguous, + } + } +} diff --git a/datafusion-federation/src/lib.rs b/datafusion-federation/src/lib.rs index 999d296..eb4a635 100644 --- a/datafusion-federation/src/lib.rs +++ b/datafusion-federation/src/lib.rs @@ -8,11 +8,13 @@ use datafusion::optimizer::analyzer::Analyzer; mod analyzer; pub use analyzer::*; +mod optimize; mod table_provider; pub use table_provider::*; mod plan_node; pub use plan_node::*; +pub mod schema_cast; pub type FederationProviderRef = Arc; pub trait FederationProvider: Send + Sync { diff --git a/datafusion-federation/src/optimize.rs b/datafusion-federation/src/optimize.rs new file mode 100644 index 0000000..993a2ab --- /dev/null +++ b/datafusion-federation/src/optimize.rs @@ -0,0 +1,136 @@ +use datafusion::{ + common::tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, + error::Result, + logical_expr::LogicalPlan, + optimizer::{ + optimizer::ApplyOrder, push_down_filter::PushDownFilter, OptimizerConfig, OptimizerContext, + 
OptimizerRule, + }, +}; +use optimize_projections::OptimizeProjections; + +mod optimize_projections; + +#[derive(Debug)] +pub(crate) struct Optimizer { + config: OptimizerContext, + push_down_filter: PushDownFilter, + optimize_projections: OptimizeProjections, +} + +impl Default for Optimizer { + fn default() -> Self { + // `push_down_filter` and `optimize_projections` does not use config so it can be default + let config = OptimizerContext::default(); + + Self { + config, + push_down_filter: PushDownFilter::new(), + optimize_projections: OptimizeProjections::new(), + } + } +} + +impl Optimizer { + pub(crate) fn optimize_plan(&self, plan: LogicalPlan) -> Result { + let mut optimized_plan = plan + .rewrite(&mut Rewriter::new( + ApplyOrder::TopDown, + &self.push_down_filter, + &self.config, + ))? + .data; + + // `optimize_projections` is applied recursively top down so it can be applied only once to the root node + optimized_plan = self + .optimize_projections + .rewrite(optimized_plan, &self.config) + .data()?; + + Ok(optimized_plan) + } +} + +struct Rewriter<'a> { + apply_order: ApplyOrder, + rule: &'a dyn OptimizerRule, + config: &'a dyn OptimizerConfig, +} + +impl<'a> Rewriter<'a> { + fn new( + apply_order: ApplyOrder, + rule: &'a dyn OptimizerRule, + config: &'a dyn OptimizerConfig, + ) -> Self { + Self { + apply_order, + rule, + config, + } + } +} + +impl TreeNodeRewriter for Rewriter<'_> { + type Node = LogicalPlan; + + fn f_down(&mut self, node: LogicalPlan) -> Result> { + if self.apply_order == ApplyOrder::TopDown { + optimize_plan_node(node, self.rule, self.config) + } else { + Ok(Transformed::no(node)) + } + } + + fn f_up(&mut self, node: LogicalPlan) -> Result> { + if self.apply_order == ApplyOrder::BottomUp { + optimize_plan_node(node, self.rule, self.config) + } else { + Ok(Transformed::no(node)) + } + } +} + +fn should_run_rule_for_node(node: &LogicalPlan, _rule: &dyn OptimizerRule) -> bool { + // this logic is applicable only for `push_down_filter_rule`; we don't have any other rules using `should_run_rule_for_node` + if let LogicalPlan::Filter(x) = node { + // Applying the `push_down_filter_rule` to certain nodes like `SubqueryAlias`, `Aggregate`, and `CrossJoin` + // can cause issues during unparsing, thus the optimization is only applied to nodes that are currently supported. + matches!( + x.input.as_ref(), + LogicalPlan::Join(_) + | LogicalPlan::TableScan(_) + | LogicalPlan::Projection(_) + | LogicalPlan::Filter(_) + | LogicalPlan::Distinct(_) + | LogicalPlan::Sort(_) + ) + } else { + true + } +} + +fn optimize_plan_node( + plan: LogicalPlan, + rule: &dyn OptimizerRule, + config: &dyn OptimizerConfig, +) -> Result> { + if !should_run_rule_for_node(&plan, rule) { + return Ok(Transformed::no(plan)); + } + + if rule.supports_rewrite() { + return rule.rewrite(plan, config); + } + + #[allow(deprecated)] + rule.try_optimize(&plan, config).map(|maybe_plan| { + match maybe_plan { + Some(new_plan) => { + // if the node was rewritten by the optimizer, replace the node + Transformed::yes(new_plan) + } + None => Transformed::no(plan), + } + }) +} diff --git a/datafusion-federation/src/optimize/optimize_projections/mod.rs b/datafusion-federation/src/optimize/optimize_projections/mod.rs new file mode 100644 index 0000000..699287c --- /dev/null +++ b/datafusion-federation/src/optimize/optimize_projections/mod.rs @@ -0,0 +1,1047 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ + collections::{HashMap, HashSet}, + result, + sync::Arc, +}; + +use datafusion::{ + common::{ + get_required_group_by_exprs_indices, internal_datafusion_err, internal_err, + tree_node::{Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion}, + Column, JoinType, + }, + error::DataFusionError, + logical_expr::{ + expr::Alias, Aggregate, Distinct, LogicalPlan, Projection, TableScan, Unnest, Window, + }, + optimizer::{optimizer::ApplyOrder, utils::NamePreserver, OptimizerConfig, OptimizerRule}, + prelude::Expr, +}; +use required_indices::RequiredIndicies; + +type Result = result::Result; + +/// A modified version of DataFusion's [`OptimizeProjections`](https://github.com/apache/datafusion/blob/main/datafusion/optimizer/src/optimize_projections/mod.rs) rule. +/// +/// Unlike the original DataFusion implementation, this version does not create +/// or remove `Projection` nodes. This is required for `unparser` to correctly convert +/// `LogicalPlan` to SQL and also keeps the query layout simple and readable while relying on the +/// underlying SQL engine to apply its own optimizations during execution. +/// +/// For more information, see the [DataFusion discussion](https://github.com/apache/datafusion/pull/13267). +mod required_indices; + +#[derive(Default, Debug)] +pub struct OptimizeProjections {} + +impl OptimizeProjections { + #[must_use] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for OptimizeProjections { + fn name(&self) -> &str { + "federation_optimize_projections" + } + + fn apply_order(&self) -> Option { + None + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + // All output fields are necessary: + let indices = RequiredIndicies::new_for_all_exprs(&plan); + optimize_projections(plan, config, indices) + } +} + +/// Removes unnecessary columns (e.g. columns that do not appear in the output +/// schema and/or are not used during any computation step such as expression +/// evaluation) from the logical plan and its inputs. +/// +/// # Parameters +/// +/// - `plan`: A reference to the input `LogicalPlan` to optimize. +/// - `config`: A reference to the optimizer configuration. +/// - `indices`: A slice of column indices that represent the necessary column +/// indices for downstream (parent) plan nodes. +/// +/// # Returns +/// +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(LogicalPlan))`: An optimized `LogicalPlan` without unnecessary +/// columns. +/// - `Ok(None)`: Signal that the given logical plan did not require any change. +/// - `Err(error)`: An error occurred during the optimization process. 
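+///
+/// Illustrative example: for `SELECT max(b) FROM t GROUP BY c`, only columns
+/// `b` and `c` are required from `t`, so the `TableScan` is narrowed to
+/// `projection=[b, c]` while the plan shape above it is left unchanged.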
+fn optimize_projections( + plan: LogicalPlan, + config: &dyn OptimizerConfig, + indices: RequiredIndicies, +) -> Result> { + // Recursively rewrite any nodes that may be able to avoid computation given + // their parents' required indices. + match plan { + LogicalPlan::Projection(proj) => { + return merge_consecutive_projections(proj)?.transform_data(|proj| { + rewrite_projection_given_requirements(proj, config, &indices) + }) + } + LogicalPlan::Aggregate(aggregate) => { + // Split parent requirements to GROUP BY and aggregate sections: + let n_group_exprs = aggregate.group_expr_len()?; + // Offset aggregate indices so that they point to valid indices at + // `aggregate.aggr_expr`: + let (group_by_reqs, aggregate_reqs) = indices.split_off(n_group_exprs); + + // Get absolutely necessary GROUP BY fields: + let group_by_expr_existing = aggregate + .group_expr + .iter() + .map(|group_by_expr| group_by_expr.schema_name().to_string()) + .collect::>(); + + let new_group_bys = if let Some(simplest_groupby_indices) = + get_required_group_by_exprs_indices( + aggregate.input.schema(), + &group_by_expr_existing, + ) { + // Some of the fields in the GROUP BY may be required by the + // parent even if these fields are unnecessary in terms of + // functional dependency. + group_by_reqs + .append(&simplest_groupby_indices) + .get_at_indices(&aggregate.group_expr) + } else { + aggregate.group_expr + }; + + // Only use the absolutely necessary aggregate expressions required + // by the parent: + let mut new_aggr_expr = aggregate_reqs.get_at_indices(&aggregate.aggr_expr); + + // Aggregations always need at least one aggregate expression. + // With a nested count, we don't require any column as input, but + // still need to create a correct aggregate, which may be optimized + // out later. As an example, consider the following query: + // + // SELECT count(*) FROM (SELECT count(*) FROM [...]) + // + // which always returns 1. + if new_aggr_expr.is_empty() + && new_group_bys.is_empty() + && !aggregate.aggr_expr.is_empty() + { + // take the old, first aggregate expression + new_aggr_expr = aggregate.aggr_expr; + new_aggr_expr.resize_with(1, || unreachable!()); + } + + let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter()); + let schema = aggregate.input.schema(); + let necessary_indices = RequiredIndicies::new().with_exprs(schema, all_exprs_iter); + let necessary_exprs = necessary_indices.get_required_exprs(schema); + + return optimize_projections( + Arc::unwrap_or_clone(aggregate.input), + config, + necessary_indices, + )? + .transform_data(|aggregate_input| { + // Simplify the input of the aggregation by adding a projection so + // that its input only contains absolutely necessary columns for + // the aggregate expressions. Note that necessary_indices refer to + // fields in `aggregate.input.schema()`. + add_projection_on_top_if_helpful(aggregate_input, necessary_exprs, config) + })? 
+ .map_data(|aggregate_input| { + // Create a new aggregate plan with the updated input and only the + // absolutely necessary fields: + Aggregate::try_new(Arc::new(aggregate_input), new_group_bys, new_aggr_expr) + .map(LogicalPlan::Aggregate) + }); + } + LogicalPlan::Window(window) => { + let input_schema = Arc::clone(window.input.schema()); + // Split parent requirements to child and window expression sections: + let n_input_fields = input_schema.fields().len(); + // Offset window expression indices so that they point to valid + // indices at `window.window_expr`: + let (child_reqs, window_reqs) = indices.split_off(n_input_fields); + + // Only use window expressions that are absolutely necessary according + // to parent requirements: + let new_window_expr = window_reqs.get_at_indices(&window.window_expr); + + // Get all the required column indices at the input, either by the + // parent or window expression requirements. + let required_indices = child_reqs.with_exprs(&input_schema, &new_window_expr); + + return optimize_projections( + Arc::unwrap_or_clone(window.input), + config, + required_indices.clone(), + )? + .transform_data(|window_child| { + if new_window_expr.is_empty() { + // When no window expression is necessary, use the input directly: + Ok(Transformed::no(window_child)) + } else { + // Calculate required expressions at the input of the window. + // Please note that we use `input_schema`, because `required_indices` + // refers to that schema + let required_exprs = required_indices.get_required_exprs(&input_schema); + + let window_child = + add_projection_on_top_if_helpful(window_child, required_exprs, config)? + .data; + + Window::try_new(new_window_expr, Arc::new(window_child)) + .map(LogicalPlan::Window) + .map(Transformed::yes) + } + }); + } + LogicalPlan::TableScan(table_scan) => { + let TableScan { + table_name, + source, + projection, + filters, + fetch, + projected_schema: _, + } = table_scan; + + // Get indices referred to in the original (schema with all fields) + // given projected indices. + let projection = match &projection { + Some(projection) => indices.into_mapped_indices(|idx| projection[idx]), + None => indices.into_inner(), + }; + return TableScan::try_new(table_name, source, Some(projection), filters, fetch) + .map(LogicalPlan::TableScan) + .map(Transformed::yes); + } + // Other node types are handled below + _ => {} + }; + + // For other plan node types, calculate indices for columns they use and + // try to rewrite their children + let mut child_required_indices: Vec = match &plan { + LogicalPlan::Sort(_) + | LogicalPlan::Filter(_) + | LogicalPlan::Repartition(_) + | LogicalPlan::Union(_) + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Distinct(Distinct::On(_)) => { + // Pass index requirements from the parent as well as column indices + // that appear in this plan's expressions to its child. All these + // operators benefit from "small" inputs, so the projection_beneficial + // flag is `true`. + plan.inputs() + .into_iter() + .map(|input| { + indices + .clone() + .with_projection_beneficial() + .with_plan_exprs(&plan, input.schema()) + }) + .collect::>()? + } + LogicalPlan::Limit(_) | LogicalPlan::Prepare(_) => { + // Pass index requirements from the parent as well as column indices + // that appear in this plan's expressions to its child. These operators + // do not benefit from "small" inputs, so the projection_beneficial + // flag is `false`. 
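+            // For example, a `LIMIT 10` forwards whole rows unchanged, so adding
+            // an extra projection directly beneath it would not reduce its work.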
+ plan.inputs() + .into_iter() + .map(|input| indices.clone().with_plan_exprs(&plan, input.schema())) + .collect::>()? + } + LogicalPlan::Copy(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::Dml(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::Subquery(_) + | LogicalPlan::Distinct(Distinct::All(_)) => { + // These plans require all their fields, and their children should + // be treated as final plans -- otherwise, we may have schema a + // mismatch. + // TODO: For some subquery variants (e.g. a subquery arising from an + // EXISTS expression), we may not need to require all indices. + plan.inputs() + .into_iter() + .map(RequiredIndicies::new_for_all_exprs) + .collect() + } + LogicalPlan::Extension(extension) => { + let Some(necessary_children_indices) = + extension.node.necessary_children_exprs(indices.indices()) + else { + // Requirements from parent cannot be routed down to user defined logical plan safely + return Ok(Transformed::no(plan)); + }; + let children = extension.node.inputs(); + if children.len() != necessary_children_indices.len() { + return internal_err!("Inconsistent length between children and necessary children indices. \ + Make sure `.necessary_children_exprs` implementation of the `UserDefinedLogicalNode` is \ + consistent with actual children length for the node."); + } + children + .into_iter() + .zip(necessary_children_indices) + .map(|(child, necessary_indices)| { + RequiredIndicies::new_from_indices(necessary_indices) + .with_plan_exprs(&plan, child.schema()) + }) + .collect::>>()? + } + LogicalPlan::EmptyRelation(_) + | LogicalPlan::RecursiveQuery(_) + | LogicalPlan::Statement(_) + | LogicalPlan::Values(_) + | LogicalPlan::DescribeTable(_) + | LogicalPlan::Execute(_) => { + // These operators have no inputs, so stop the optimization process. + return Ok(Transformed::no(plan)); + } + LogicalPlan::Join(join) => { + let left_len = join.left.schema().fields().len(); + let (left_req_indices, right_req_indices) = + split_join_requirements(left_len, indices, &join.join_type); + let left_indices = left_req_indices.with_plan_exprs(&plan, join.left.schema())?; + let right_indices = right_req_indices.with_plan_exprs(&plan, join.right.schema())?; + // Joins benefit from "small" input tables (lower memory usage). + // Therefore, each child benefits from projection: + vec![ + left_indices.with_projection_beneficial(), + right_indices.with_projection_beneficial(), + ] + } + // these nodes are explicitly rewritten in the match statement above + LogicalPlan::Projection(_) + | LogicalPlan::Aggregate(_) + | LogicalPlan::Window(_) + | LogicalPlan::TableScan(_) => { + return internal_err!( + "OptimizeProjection: should have handled in the match statement above" + ); + } + LogicalPlan::Unnest(Unnest { + dependency_indices, .. + }) => { + vec![RequiredIndicies::new_from_indices( + dependency_indices.clone(), + )] + } + }; + + // Required indices are currently ordered (child0, child1, ...) 
+ // but the loop pops off the last element, so we need to reverse the order + child_required_indices.reverse(); + if child_required_indices.len() != plan.inputs().len() { + return internal_err!( + "OptimizeProjection: child_required_indices length mismatch with plan inputs" + ); + } + + // Rewrite children of the plan + let transformed_plan = plan.map_children(|child| { + let required_indices = child_required_indices.pop().ok_or_else(|| { + internal_datafusion_err!( + "Unexpected number of required_indices in OptimizeProjections rule" + ) + })?; + + let projection_beneficial = required_indices.projection_beneficial(); + let project_exprs = required_indices.get_required_exprs(child.schema()); + + optimize_projections(child, config, required_indices)?.transform_data(|new_input| { + if projection_beneficial { + add_projection_on_top_if_helpful(new_input, project_exprs, config) + } else { + Ok(Transformed::no(new_input)) + } + }) + })?; + + // If any of the children are transformed, we need to potentially update the plan's schema + if transformed_plan.transformed { + transformed_plan.map_data(|plan| plan.recompute_schema()) + } else { + Ok(transformed_plan) + } +} + +/// Merges consecutive projections. +/// +/// Given a projection `proj`, this function attempts to merge it with a previous +/// projection if it exists and if merging is beneficial. Merging is considered +/// beneficial when expressions in the current projection are non-trivial and +/// appear more than once in its input fields. This can act as a caching mechanism +/// for non-trivial computations. +/// +/// # Parameters +/// +/// * `proj` - A reference to the `Projection` to be merged. +/// +/// # Returns +/// +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(Projection))`: Merge was beneficial and successful. Contains the +/// merged projection. +/// - `Ok(None)`: Signals that merge is not beneficial (and has not taken place). +/// - `Err(error)`: An error occured during the function call. +fn merge_consecutive_projections(proj: Projection) -> Result> { + let Projection { + expr, + input, + schema, + .. + } = proj; + let LogicalPlan::Projection(prev_projection) = input.as_ref() else { + return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no); + }; + + // Count usages (referrals) of each projection expression in its input fields: + let mut column_referral_map = HashMap::<&Column, usize>::new(); + expr.iter() + .for_each(|expr| expr.add_column_ref_counts(&mut column_referral_map)); + + // If an expression is non-trivial and appears more than once, do not merge + // them as consecutive projections will benefit from a compute-once approach. 
+ // For details, see: https://github.com/apache/datafusion/issues/8296 + if column_referral_map.into_iter().any(|(col, usage)| { + usage > 1 + && !is_expr_trivial( + &prev_projection.expr[prev_projection.schema.index_of_column(col).unwrap()], + ) + }) { + // no change + return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no); + } + + let LogicalPlan::Projection(prev_projection) = Arc::unwrap_or_clone(input) else { + // We know it is a `LogicalPlan::Projection` from check above + unreachable!(); + }; + + // Try to rewrite the expressions in the current projection using the + // previous projection as input: + let name_preserver = NamePreserver::new_for_projection(); + let mut original_names = vec![]; + let new_exprs = expr.into_iter().map_until_stop_and_collect(|expr| { + original_names.push(name_preserver.save(&expr)); + + // do not rewrite top level Aliases (rewriter will remove all aliases within exprs) + match expr { + Expr::Alias(Alias { + expr, + relation, + name, + }) => rewrite_expr(*expr, &prev_projection).map(|result| { + result.update_data(|expr| Expr::Alias(Alias::new(expr, relation, name))) + }), + e => rewrite_expr(e, &prev_projection), + } + })?; + + // if the expressions could be rewritten, create a new projection with the + // new expressions + if new_exprs.transformed { + // Add any needed aliases back to the expressions + let new_exprs = new_exprs + .data + .into_iter() + .zip(original_names) + .map(|(expr, original_name)| original_name.restore(expr)) + .collect::>(); + Projection::try_new(new_exprs, prev_projection.input).map(Transformed::yes) + } else { + // not rewritten, so put the projection back together + let input = Arc::new(LogicalPlan::Projection(prev_projection)); + Projection::try_new_with_schema(new_exprs.data, input, schema).map(Transformed::no) + } +} + +// Check whether `expr` is trivial; i.e. it doesn't imply any computation. +fn is_expr_trivial(expr: &Expr) -> bool { + matches!(expr, Expr::Column(_) | Expr::Literal(_)) +} + +/// Rewrites a projection expression using the projection before it (i.e. its input) +/// This is a subroutine to the `merge_consecutive_projections` function. +/// +/// # Parameters +/// +/// * `expr` - A reference to the expression to rewrite. +/// * `input` - A reference to the input of the projection expression (itself +/// a projection). +/// +/// # Returns +/// +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(Expr))`: Rewrite was successful. Contains the rewritten result. +/// - `Ok(None)`: Signals that `expr` can not be rewritten. +/// - `Err(error)`: An error occurred during the function call. +/// +/// # Notes +/// This rewrite also removes any unnecessary layers of aliasing. +/// +/// Without trimming, we can end up with unnecessary indirections inside expressions +/// during projection merges. 
+/// +/// Consider: +/// +/// ```text +/// Projection(a1 + b1 as sum1) +/// --Projection(a as a1, b as b1) +/// ----Source(a, b) +/// ``` +/// +/// After merge, we want to produce: +/// +/// ```text +/// Projection(a + b as sum1) +/// --Source(a, b) +/// ``` +/// +/// Without trimming, we would end up with: +/// +/// ```text +/// Projection((a as a1 + b as b1) as sum1) +/// --Source(a, b) +/// ``` +fn rewrite_expr(expr: Expr, input: &Projection) -> Result> { + expr.transform_up(|expr| { + match expr { + // remove any intermediate aliases + Expr::Alias(alias) => Ok(Transformed::yes(*alias.expr)), + Expr::Column(col) => { + // Find index of column: + let idx = input.schema.index_of_column(&col)?; + // get the corresponding unaliased input expression + // + // For example: + // * the input projection is [`a + b` as c, `d + e` as f] + // * the current column is an expression "f" + // + // return the expression `d + e` (not `d + e` as f) + let input_expr = input.expr[idx].clone().unalias_nested().data; + Ok(Transformed::yes(input_expr)) + } + // Unsupported type for consecutive projection merge analysis. + _ => Ok(Transformed::no(expr)), + } + }) +} + +/// Accumulates outer-referenced columns by the +/// given expression, `expr`. +/// +/// # Parameters +/// +/// * `expr` - The expression to analyze for outer-referenced columns. +/// * `columns` - A mutable reference to a `HashSet` where detected +/// columns are collected. +fn outer_columns<'a>(expr: &'a Expr, columns: &mut HashSet<&'a Column>) { + // inspect_expr_pre doesn't handle subquery references, so find them explicitly + expr.apply(|expr| { + match expr { + Expr::OuterReferenceColumn(_, col) => { + columns.insert(col); + } + Expr::ScalarSubquery(subquery) => { + outer_columns_helper_multi(&subquery.outer_ref_columns, columns); + } + Expr::Exists(exists) => { + outer_columns_helper_multi(&exists.subquery.outer_ref_columns, columns); + } + Expr::InSubquery(insubquery) => { + outer_columns_helper_multi(&insubquery.subquery.outer_ref_columns, columns); + } + _ => {} + }; + Ok(TreeNodeRecursion::Continue) + }) + // unwrap: closure above never returns Err, so can not be Err here + .unwrap(); +} + +/// A recursive subroutine that accumulates outer-referenced columns by the +/// given expressions (`exprs`). +/// +/// # Parameters +/// +/// * `exprs` - The expressions to analyze for outer-referenced columns. +/// * `columns` - A mutable reference to a `HashSet` where detected +/// columns are collected. +fn outer_columns_helper_multi<'a, 'b>( + exprs: impl IntoIterator, + columns: &'b mut HashSet<&'a Column>, +) { + exprs.into_iter().for_each(|e| outer_columns(e, columns)); +} + +/// Splits requirement indices for a join into left and right children based on +/// the join type. +/// +/// This function takes the length of the left child, a slice of requirement +/// indices, and the type of join (e.g. `INNER`, `LEFT`, `RIGHT`) as arguments. +/// Depending on the join type, it divides the requirement indices into those +/// that apply to the left child and those that apply to the right child. +/// +/// - For `INNER`, `LEFT`, `RIGHT` and `FULL` joins, the requirements are split +/// between left and right children. The right child indices are adjusted to +/// point to valid positions within the right child by subtracting the length +/// of the left child. 
+/// +/// - For `LEFT ANTI`, `LEFT SEMI`, `RIGHT SEMI` and `RIGHT ANTI` joins, all +/// requirements are re-routed to either the left child or the right child +/// directly, depending on the join type. +/// +/// # Parameters +/// +/// * `left_len` - The length of the left child. +/// * `indices` - A slice of requirement indices. +/// * `join_type` - The type of join (e.g. `INNER`, `LEFT`, `RIGHT`). +/// +/// # Returns +/// +/// A tuple containing two vectors of `usize` indices: The first vector represents +/// the requirements for the left child, and the second vector represents the +/// requirements for the right child. The indices are appropriately split and +/// adjusted based on the join type. +fn split_join_requirements( + left_len: usize, + indices: RequiredIndicies, + join_type: &JoinType, +) -> (RequiredIndicies, RequiredIndicies) { + match join_type { + // In these cases requirements are split between left/right children: + JoinType::Inner + | JoinType::Left + | JoinType::Right + | JoinType::Full + | JoinType::LeftMark => { + // Decrease right side indices by `left_len` so that they point to valid + // positions within the right child: + indices.split_off(left_len) + } + // All requirements can be re-routed to left child directly. + JoinType::LeftAnti | JoinType::LeftSemi => (indices, RequiredIndicies::new()), + // All requirements can be re-routed to right side directly. + // No need to change index, join schema is right child schema. + JoinType::RightSemi | JoinType::RightAnti => (RequiredIndicies::new(), indices), + } +} + +fn add_projection_on_top_if_helpful( + plan: LogicalPlan, + _project_exprs: Vec, + _config: &dyn OptimizerConfig, +) -> Result> { + // Always prefer the original query layout. + Ok(Transformed::no(plan)) +} + +/// Rewrite the given projection according to the fields required by its +/// ancestors. +/// +/// # Parameters +/// +/// * `proj` - A reference to the original projection to rewrite. +/// * `config` - A reference to the optimizer configuration. +/// * `indices` - A slice of indices representing the columns required by the +/// ancestors of the given projection. +/// +/// # Returns +/// +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(LogicalPlan))`: Contains the rewritten projection +/// - `Ok(None)`: No rewrite necessary. +/// - `Err(error)`: An error occured during the function call. +fn rewrite_projection_given_requirements( + proj: Projection, + config: &dyn OptimizerConfig, + indices: &RequiredIndicies, +) -> Result> { + let Projection { expr, input, .. } = proj; + + let exprs_used = indices.get_at_indices(&expr); + + let required_indices = RequiredIndicies::new().with_exprs(input.schema(), exprs_used.iter()); + + // rewrite the children projection, and if they are changed rewrite the + // projection down + optimize_projections(Arc::unwrap_or_clone(input), config, required_indices)?.transform_data( + |input| { + if is_projection_unnecessary(&input, &exprs_used, config)? { + Ok(Transformed::yes(input)) + } else { + Projection::try_new(exprs_used, Arc::new(input)) + .map(LogicalPlan::Projection) + .map(Transformed::yes) + } + }, + ) +} + +fn is_projection_unnecessary( + _input: &LogicalPlan, + _proj_exprs: &[Expr], + _config: &dyn OptimizerConfig, +) -> Result { + // Always prefer the original query layout. Return false to keep the projection. 
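+    // For example, even a `Projection: t.a, t.b` sitting directly on
+    // `TableScan: t projection=[a, b]` is kept, so the unparsed SQL retains an
+    // explicit column list.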
+ Ok(false) +} + +#[cfg(test)] +mod tests { + use std::{collections::HashMap, sync::Arc}; + + use datafusion::{ + arrow::datatypes::{DataType, Field, Schema}, + common::{DFSchema, JoinType}, + error::DataFusionError, + functions_aggregate::{ + count::{count, count_udaf}, + min_max::max, + }, + logical_expr::{ + LogicalPlan, LogicalPlanBuilder, LogicalTableSource, TableSource, UNNAMED_TABLE, + }, + optimizer::{Optimizer, OptimizerContext, OptimizerRule}, + prelude::{col, lit, Expr, ExprFunctionExt}, + sql::TableReference, + }; + + type Result = std::result::Result; + + #[test] + fn aggregate_filter_pushdown_preserve_projection() -> Result<()> { + let table_scan = test_table_scan()?; + let aggr_with_filter = count_udaf() + .call(vec![col("b")]) + .filter(col("c").gt(lit(42))) + .build()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![count(col("b")), aggr_with_filter.alias("count2")], + )? + .project(vec![col("a"), col("count(test.b)"), col("count2")])? + .build()?; + + let expected = "Projection: test.a, count(test.b), count2\ + \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.b), count(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]\ + \n TableScan: test projection=[a, b, c]"; + + assert_optimized_plan_equal(plan, expected)?; + + Ok(()) + } + + // Selected tests and helpers from the original DataFusion implementation of `OptimizeProjections` rule to ensure + // that the modified version behaves as expected. + // See: + + #[test] + fn aggregate_no_group_by() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(Vec::::new(), vec![max(col("b"))])? + .build()?; + + let expected = "Aggregate: groupBy=[[]], aggr=[[max(test.b)]]\ + \n TableScan: test projection=[b]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn aggregate_group_by() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("c")], vec![max(col("b"))])? + .build()?; + + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[max(test.b)]]\ + \n TableScan: test projection=[b, c]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn aggregate_group_by_with_table_alias() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .alias("a")? + .aggregate(vec![col("c")], vec![max(col("b"))])? + .build()?; + + let expected = "Aggregate: groupBy=[[a.c]], aggr=[[max(a.b)]]\ + \n SubqueryAlias: a\ + \n TableScan: test projection=[b, c]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn aggregate_no_group_by_with_filter() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(col("c").gt(lit(1)))? + .aggregate(Vec::::new(), vec![max(col("b"))])? + .build()?; + + let expected = "Aggregate: groupBy=[[]], aggr=[[max(test.b)]]\ + \n Filter: test.c > Int32(1)\ + \n TableScan: test projection=[b, c]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn join_schema_trim_using_join() -> Result<()> { + // shared join columns from using join should be pushed to both sides + + let table_scan = test_table_scan()?; + + let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]); + let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .join_using(table2_scan, JoinType::Left, vec!["a"])? 
+ .project(vec![col("a"), col("b")])? + .build()?; + + // make sure projections are pushed down to table scan + let expected = "Projection: test.a, test.b\ + \n Left Join: Using test.a = test2.a\ + \n TableScan: test projection=[a, b]\ + \n TableScan: test2 projection=[a]"; + + let optimized_plan = optimize(plan)?; + let formatted_plan = format!("{optimized_plan}"); + assert_eq!(formatted_plan, expected); + + // make sure schema for join node include both join columns + let optimized_join = optimized_plan.inputs()[0]; + assert_eq!( + **optimized_join.schema(), + DFSchema::new_with_metadata( + vec![ + ( + Some("test".into()), + Arc::new(Field::new("a", DataType::UInt32, false)) + ), + ( + Some("test".into()), + Arc::new(Field::new("b", DataType::UInt32, false)) + ), + ( + Some("test2".into()), + Arc::new(Field::new("a", DataType::UInt32, true)) + ), + ], + HashMap::new() + )?, + ); + + Ok(()) + } + + #[test] + fn redundant_project() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b"), col("c")])? + .project(vec![col("a"), col("c"), col("b")])? + .build()?; + let expected = "Projection: test.a, test.c, test.b\ + \n TableScan: test projection=[a, b, c]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn reorder_projection() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("c"), col("b"), col("a")])? + .build()?; + let expected = "Projection: test.c, test.b, test.a\ + \n TableScan: test projection=[a, b, c]"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn table_scan_without_projection() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan).build()?; + // should expand projection to all columns without projection + let expected = "TableScan: test projection=[a, b, c]"; + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn table_unused_column() -> Result<()> { + let table_scan = test_table_scan()?; + assert_eq!(3, table_scan.schema().fields().len()); + assert_fields_eq(&table_scan, vec!["a", "b", "c"]); + + // we never use "b" in the first projection => remove it + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("c"), col("a"), col("b")])? + .filter(col("c").gt(lit(1)))? + .aggregate(vec![col("c")], vec![max(col("a"))])? 
+ .build()?; + + assert_fields_eq(&plan, vec!["c", "max(test.a)"]); + + let plan = optimize(plan).expect("failed to optimize plan"); + let expected = "\ + Aggregate: groupBy=[[test.c]], aggr=[[max(test.a)]]\ + \n Filter: test.c > Int32(1)\ + \n Projection: test.c, test.a\ + \n TableScan: test projection=[a, c]"; + + assert_optimized_plan_equal(plan, expected) + } + + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { + assert_optimized_plan_eq(Arc::new(super::OptimizeProjections::new()), plan, expected) + } + + pub fn assert_optimized_plan_eq( + rule: Arc, + plan: LogicalPlan, + expected: &str, + ) -> Result<()> { + // Apply the rule once + let opt_context = OptimizerContext::new().with_max_passes(1); + + let optimizer = Optimizer::with_rules(vec![Arc::clone(&rule)]); + let optimized_plan = optimizer.optimize(plan, &opt_context, observe)?; + let formatted_plan = format!("{optimized_plan}"); + assert_eq!(formatted_plan, expected); + + Ok(()) + } + + fn optimize(plan: LogicalPlan) -> Result { + let optimizer = Optimizer::with_rules(vec![Arc::new(super::OptimizeProjections::new())]); + let optimized_plan = optimizer.optimize(plan, &OptimizerContext::new(), observe)?; + Ok(optimized_plan) + } + + pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { + let actual: Vec = plan + .schema() + .fields() + .iter() + .map(|f| f.name().clone()) + .collect(); + assert_eq!(actual, expected); + } + + pub fn test_table_scan_fields() -> Vec { + vec![ + Field::new("a", DataType::UInt32, false), + Field::new("b", DataType::UInt32, false), + Field::new("c", DataType::UInt32, false), + ] + } + + /// some tests share a common table with different names + pub fn test_table_scan_with_name(name: &str) -> Result { + let schema = Schema::new(test_table_scan_fields()); + table_scan(Some(name), &schema, None)?.build() + } + + /// some tests share a common table + pub fn test_table_scan() -> Result { + test_table_scan_with_name("test") + } + + /// Scan an empty data source, mainly used in tests + pub fn scan_empty( + name: Option<&str>, + table_schema: &Schema, + projection: Option>, + ) -> Result { + table_scan(name, table_schema, projection) + } + /// Create a LogicalPlanBuilder representing a scan of a table with the provided name and schema. + /// This is mostly used for testing and documentation. + pub fn table_scan( + name: Option>, + table_schema: &Schema, + projection: Option>, + ) -> Result { + table_scan_with_filters(name, table_schema, projection, vec![]) + } + + /// Create a LogicalPlanBuilder representing a scan of a table with the provided name and schema, + /// and inlined filters. + /// This is mostly used for testing and documentation. 
+ pub fn table_scan_with_filters( + name: Option>, + table_schema: &Schema, + projection: Option>, + filters: Vec, + ) -> Result { + let table_source = table_source(table_schema); + let name = name + .map(|n| n.into()) + .unwrap_or_else(|| TableReference::bare(UNNAMED_TABLE)); + LogicalPlanBuilder::scan_with_filters(name, table_source, projection, filters) + } + + fn table_source(table_schema: &Schema) -> Arc { + let table_schema = Arc::new(table_schema.clone()); + Arc::new(LogicalTableSource::new(table_schema)) + } +} diff --git a/datafusion-federation/src/optimize/optimize_projections/required_indices.rs b/datafusion-federation/src/optimize/optimize_projections/required_indices.rs new file mode 100644 index 0000000..e6df841 --- /dev/null +++ b/datafusion-federation/src/optimize/optimize_projections/required_indices.rs @@ -0,0 +1,227 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Copy of DataFusion's [`RequiredIndicies`](https://github.com/apache/datafusion/blob/main/datafusion/optimizer/src/optimize_projections/required_indices.rs) implementation. +use std::result; + +use super::outer_columns; +use datafusion::{ + common::{tree_node::TreeNodeRecursion, Column, DFSchemaRef}, + error::DataFusionError, + logical_expr::LogicalPlan, + prelude::Expr, +}; + +type Result = result::Result; + +/// Represents columns in a schema which are required (used) by a plan node +/// +/// Also carries a flag indicating if putting a projection above children is +/// beneficial for the parent. For example `LogicalPlan::Filter` benefits from +/// small tables. Hence for filter child this flag would be `true`. Defaults to +/// `false` +/// +/// # Invariant +/// +/// Indices are always in order and without duplicates. For example, if these +/// indices were added `[3, 2, 4, 3, 6, 1]`, the instance would be represented +/// by `[1, 2, 3, 6]`. +#[derive(Debug, Clone, Default)] +pub(super) struct RequiredIndicies { + /// The indices of the required columns in the + indices: Vec, + /// If putting a projection above children is beneficial for the parent. + /// Defaults to false. 
+ projection_beneficial: bool, +} + +impl RequiredIndicies { + /// Create a new, empty instance + pub fn new() -> Self { + Self::default() + } + + /// Create a new instance that requires all columns from the specified plan + pub fn new_for_all_exprs(plan: &LogicalPlan) -> Self { + Self { + indices: (0..plan.schema().fields().len()).collect(), + projection_beneficial: false, + } + } + + /// Create a new instance with the specified indices as required + pub fn new_from_indices(indices: Vec) -> Self { + Self { + indices, + projection_beneficial: false, + } + .compact() + } + + /// Convert the instance to its inner indices + pub fn into_inner(self) -> Vec { + self.indices + } + + /// Set the projection beneficial flag + pub fn with_projection_beneficial(mut self) -> Self { + self.projection_beneficial = true; + self + } + + /// Return the value of projection beneficial flag + pub fn projection_beneficial(&self) -> bool { + self.projection_beneficial + } + + /// Return a reference to the underlying indices + pub fn indices(&self) -> &[usize] { + &self.indices + } + + /// Add required indices for all `exprs` used in plan + pub fn with_plan_exprs(mut self, plan: &LogicalPlan, schema: &DFSchemaRef) -> Result { + // Add indices of the child fields referred to by the expressions in the + // parent + plan.apply_expressions(|e| { + self.add_expr(schema, e); + Ok(TreeNodeRecursion::Continue) + })?; + Ok(self.compact()) + } + + /// Adds the indices of the fields referred to by the given expression + /// `expr` within the given schema (`input_schema`). + /// + /// Self is NOT compacted (and thus this method is not pub) + /// + /// # Parameters + /// + /// * `input_schema`: The input schema to analyze for index requirements. + /// * `expr`: An expression for which we want to find necessary field indices. + fn add_expr(&mut self, input_schema: &DFSchemaRef, expr: &Expr) { + // TODO could remove these clones (and visit the expression directly) + let mut cols = expr.column_refs(); + // Get outer-referenced (subquery) columns: + outer_columns(expr, &mut cols); + self.indices.reserve(cols.len()); + for col in cols { + if let Some(idx) = input_schema.maybe_index_of_column(col) { + self.indices.push(idx); + } + } + } + + /// Adds the indices of the fields referred to by the given expressions + /// `within the given schema. + /// + /// # Parameters + /// + /// * `input_schema`: The input schema to analyze for index requirements. + /// * `exprs`: the expressions for which we want to find field indices. + pub fn with_exprs<'a>( + self, + schema: &DFSchemaRef, + exprs: impl IntoIterator, + ) -> Self { + exprs + .into_iter() + .fold(self, |mut acc, expr| { + acc.add_expr(schema, expr); + acc + }) + .compact() + } + + /// Adds all `indices` into this instance. + pub fn append(mut self, indices: &[usize]) -> Self { + self.indices.extend_from_slice(indices); + self.compact() + } + + /// Splits this instance into a tuple with two instances: + /// * The first `n` indices + /// * The remaining indices, adjusted down by n + pub fn split_off(self, n: usize) -> (Self, Self) { + let (l, r) = self.partition(|idx| idx < n); + (l, r.map_indices(|idx| idx - n)) + } + + /// Partitions the indices in this instance into two groups based on the + /// given predicate function `f`. 
+ fn partition(&self, f: F) -> (Self, Self) + where + F: Fn(usize) -> bool, + { + let (l, r): (Vec, Vec) = self.indices.iter().partition(|&&idx| f(idx)); + let projection_beneficial = self.projection_beneficial; + + ( + Self { + indices: l, + projection_beneficial, + }, + Self { + indices: r, + projection_beneficial, + }, + ) + } + + /// Map the indices in this instance to a new set of indices based on the + /// given function `f`, returning the mapped indices + /// + /// Not `pub` as it might not preserve the invariant of compacted indices + fn map_indices(mut self, f: F) -> Self + where + F: Fn(usize) -> usize, + { + self.indices.iter_mut().for_each(|idx| *idx = f(*idx)); + self + } + + /// Apply the given function `f` to each index in this instance, returning + /// the mapped indices + pub fn into_mapped_indices(self, f: F) -> Vec + where + F: Fn(usize) -> usize, + { + self.map_indices(f).into_inner() + } + + /// Returns the `Expr`s from `exprs` that are at the indices in this instance + pub fn get_at_indices(&self, exprs: &[Expr]) -> Vec { + self.indices.iter().map(|&idx| exprs[idx].clone()).collect() + } + + /// Generates the required expressions (columns) that reside at `indices` of + /// the given `input_schema`. + pub fn get_required_exprs(&self, input_schema: &DFSchemaRef) -> Vec { + self.indices + .iter() + .map(|&idx| Expr::from(Column::from(input_schema.qualified_field(idx)))) + .collect() + } + + /// Compacts the indices of this instance so they are sorted + /// (ascending) and deduplicated. + fn compact(mut self) -> Self { + self.indices.sort_unstable(); + self.indices.dedup(); + self + } +} diff --git a/datafusion-federation/src/plan_node.rs b/datafusion-federation/src/plan_node.rs index 35a9306..f938290 100644 --- a/datafusion-federation/src/plan_node.rs +++ b/datafusion-federation/src/plan_node.rs @@ -30,6 +30,12 @@ impl FederatedPlanNode { } } +impl PartialOrd for FederatedPlanNode { + fn partial_cmp(&self, other: &Self) -> Option { + self.plan.partial_cmp(&other.plan) + } +} + impl Debug for FederatedPlanNode { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { UserDefinedLogicalNodeCore::fmt_for_explain(self, f) @@ -54,21 +60,20 @@ impl UserDefinedLogicalNodeCore for FederatedPlanNode { } fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Federated\n {:?}", self.plan) + write!(f, "Federated\n {}", self.plan) } - fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { + fn with_exprs_and_inputs(&self, exprs: Vec, inputs: Vec) -> Result { assert_eq!(inputs.len(), 0, "input size inconsistent"); assert_eq!(exprs.len(), 0, "expression size inconsistent"); - Self { + Ok(Self { plan: self.plan.clone(), - planner: self.planner.clone(), - } + planner: Arc::clone(&self.planner), + }) } } -#[derive(Default)] - +#[derive(Default, Debug)] pub struct FederatedQueryPlanner {} impl FederatedQueryPlanner { @@ -121,7 +126,7 @@ impl Hash for FederatedPlanNode { } #[derive(Default)] -struct FederatedPlanner {} +pub struct FederatedPlanner {} impl FederatedPlanner { pub fn new() -> Self { @@ -144,7 +149,7 @@ impl ExtensionPlanner for FederatedPlanner { assert_eq!(logical_inputs.len(), 0, "Inconsistent number of inputs"); assert_eq!(physical_inputs.len(), 0, "Inconsistent number of inputs"); - let fed_planner = fed_node.planner.clone(); + let fed_planner = Arc::clone(&fed_node.planner); let exec_plan = fed_planner.plan_federation(fed_node, session_state).await?; return Ok(Some(exec_plan)); } diff --git 
a/datafusion-federation/src/schema_cast.rs b/datafusion-federation/src/schema_cast.rs new file mode 100644 index 0000000..5b1ba5a --- /dev/null +++ b/datafusion-federation/src/schema_cast.rs @@ -0,0 +1,114 @@ +use async_stream::stream; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::error::{DataFusionError, Result}; +use datafusion::execution::{SendableRecordBatchStream, TaskContext}; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, +}; +use futures::StreamExt; +use std::any::Any; +use std::clone::Clone; +use std::fmt; +use std::sync::Arc; + +mod intervals_cast; +mod lists_cast; +pub mod record_convert; +mod struct_cast; + +#[derive(Debug)] +#[allow(clippy::module_name_repetitions)] +pub struct SchemaCastScanExec { + input: Arc, + schema: SchemaRef, + properties: PlanProperties, +} + +impl SchemaCastScanExec { + pub fn new(input: Arc, schema: SchemaRef) -> Self { + let eq_properties = input.equivalence_properties().clone(); + let execution_mode = input.execution_mode(); + let properties = PlanProperties::new( + eq_properties, + input.output_partitioning().clone(), + execution_mode, + ); + Self { + input, + schema, + properties, + } + } +} + +impl DisplayAs for SchemaCastScanExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "SchemaCastScanExec") + } +} + +impl ExecutionPlan for SchemaCastScanExec { + fn name(&self) -> &str { + "SchemaCastScanExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + /// Prevents the introduction of additional `RepartitionExec` and processing input in parallel. + /// This guarantees that the input is processed as a single stream, preserving the order of the data. + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + if children.len() == 1 { + Ok(Arc::new(Self::new( + Arc::clone(&children[0]), + Arc::clone(&self.schema), + ))) + } else { + Err(DataFusionError::Execution( + "SchemaCastScanExec expects exactly one input".to_string(), + )) + } + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let mut stream = self.input.execute(partition, context)?; + let schema = Arc::clone(&self.schema); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + { + stream! 
{ + while let Some(batch) = stream.next().await { + let batch = record_convert::try_cast_to(batch?, Arc::clone(&schema)); + yield batch.map_err(|e| { DataFusionError::External(Box::new(e)) }); + } + } + }, + ))) + } +} diff --git a/datafusion-federation/src/schema_cast/intervals_cast.rs b/datafusion-federation/src/schema_cast/intervals_cast.rs new file mode 100644 index 0000000..629203e --- /dev/null +++ b/datafusion-federation/src/schema_cast/intervals_cast.rs @@ -0,0 +1,191 @@ +use datafusion::arrow::{ + array::{ + Array, ArrayRef, IntervalDayTimeBuilder, IntervalMonthDayNanoArray, + IntervalYearMonthBuilder, + }, + datatypes::{IntervalDayTimeType, IntervalYearMonthType}, + error::ArrowError, +}; +use std::sync::Arc; + +pub(crate) fn cast_interval_monthdaynano_to_yearmonth( + interval_monthdaynano_array: &dyn Array, +) -> Result { + let interval_monthdaynano_array = interval_monthdaynano_array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError("Failed to cast IntervalMonthDayNanoArray: Unable to downcast to IntervalMonthDayNanoArray".to_string()) + })?; + + let mut interval_yearmonth_builder = + IntervalYearMonthBuilder::with_capacity(interval_monthdaynano_array.len()); + + for value in interval_monthdaynano_array { + match value { + None => interval_yearmonth_builder.append_null(), + Some(interval_monthdaynano_value) => { + if interval_monthdaynano_value.days != 0 + || interval_monthdaynano_value.nanoseconds != 0 + { + return Err(ArrowError::CastError( + "Failed to cast IntervalMonthDayNanoArray to IntervalYearMonthArray: Non-zero days or nanoseconds".to_string(), + )); + } + interval_yearmonth_builder.append_value(IntervalYearMonthType::make_value( + 0, + interval_monthdaynano_value.months, + )); + } + } + } + + Ok(Arc::new(interval_yearmonth_builder.finish())) +} + +#[allow(clippy::cast_possible_truncation)] +pub(crate) fn cast_interval_monthdaynano_to_daytime( + interval_monthdaynano_array: &dyn Array, +) -> Result { + let interval_monthdaynano_array = interval_monthdaynano_array + .as_any() + .downcast_ref::() + .ok_or_else(|| + ArrowError::CastError("Failed to cast IntervalMonthDayNanoArray: Unable to downcast to IntervalMonthDayNanoArray".to_string()))?; + + let mut interval_daytime_builder = + IntervalDayTimeBuilder::with_capacity(interval_monthdaynano_array.len()); + + for value in interval_monthdaynano_array { + match value { + None => interval_daytime_builder.append_null(), + Some(interval_monthdaynano_value) => { + if interval_monthdaynano_value.months != 0 { + return Err( + ArrowError::CastError("Failed to cast IntervalMonthDayNanoArray to IntervalDayTimeArray: Non-zero months".to_string()), + ); + } + interval_daytime_builder.append_value(IntervalDayTimeType::make_value( + interval_monthdaynano_value.days, + (interval_monthdaynano_value.nanoseconds / 1_000_000) as i32, + )); + } + } + } + Ok(Arc::new(interval_daytime_builder.finish())) +} + +#[cfg(test)] +mod test { + use datafusion::arrow::{ + array::{IntervalDayTimeArray, IntervalYearMonthArray, RecordBatch}, + datatypes::{ + DataType, Field, IntervalDayTime, IntervalMonthDayNano, IntervalUnit, Schema, SchemaRef, + }, + }; + + use crate::schema_cast::record_convert::try_cast_to; + + use super::*; + + fn input_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new( + "interval_daytime", + DataType::Interval(IntervalUnit::MonthDayNano), + false, + ), + Field::new( + "interval_monthday_nano", + DataType::Interval(IntervalUnit::MonthDayNano), + false, + ), + Field::new( + 
"interval_yearmonth", + DataType::Interval(IntervalUnit::MonthDayNano), + false, + ), + ])) + } + + fn output_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new( + "interval_daytime", + DataType::Interval(IntervalUnit::DayTime), + false, + ), + Field::new( + "interval_monthday_nano", + DataType::Interval(IntervalUnit::MonthDayNano), + false, + ), + Field::new( + "interval_yearmonth", + DataType::Interval(IntervalUnit::YearMonth), + false, + ), + ])) + } + + fn batch_input() -> RecordBatch { + let interval_daytime_array = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNano::new(0, 1, 1_000_000_000), + IntervalMonthDayNano::new(0, 33, 0), + IntervalMonthDayNano::new(0, 0, 43_200_000_000_000), + ]); + let interval_monthday_nano_array = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNano::new(1, 2, 1000), + IntervalMonthDayNano::new(12, 1, 0), + IntervalMonthDayNano::new(0, 0, 12 * 1000 * 1000), + ]); + let interval_yearmonth_array = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNano::new(2, 0, 0), + IntervalMonthDayNano::new(25, 0, 0), + IntervalMonthDayNano::new(-1, 0, 0), + ]); + + RecordBatch::try_new( + input_schema(), + vec![ + Arc::new(interval_daytime_array), + Arc::new(interval_monthday_nano_array), + Arc::new(interval_yearmonth_array), + ], + ) + .expect("Failed to created arrow interval record batch") + } + + fn batch_expected() -> RecordBatch { + let interval_daytime_array = IntervalDayTimeArray::from(vec![ + IntervalDayTime::new(1, 1000), + IntervalDayTime::new(33, 0), + IntervalDayTime::new(0, 12 * 60 * 60 * 1000), + ]); + let interval_monthday_nano_array = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNano::new(1, 2, 1000), + IntervalMonthDayNano::new(12, 1, 0), + IntervalMonthDayNano::new(0, 0, 12 * 1000 * 1000), + ]); + let interval_yearmonth_array = IntervalYearMonthArray::from(vec![2, 25, -1]); + + RecordBatch::try_new( + output_schema(), + vec![ + Arc::new(interval_daytime_array), + Arc::new(interval_monthday_nano_array), + Arc::new(interval_yearmonth_array), + ], + ) + .expect("Failed to created arrow interval record batch") + } + + #[test] + fn test_cast_interval_with_schema() { + let input_batch = batch_input(); + let expected = batch_expected(); + let actual = try_cast_to(input_batch, output_schema()).expect("cast should succeed"); + + assert_eq!(actual, expected); + } +} diff --git a/datafusion-federation/src/schema_cast/lists_cast.rs b/datafusion-federation/src/schema_cast/lists_cast.rs new file mode 100644 index 0000000..9a63b28 --- /dev/null +++ b/datafusion-federation/src/schema_cast/lists_cast.rs @@ -0,0 +1,619 @@ +use arrow_json::ReaderBuilder; +use datafusion::arrow::{ + array::{ + Array, ArrayRef, BooleanArray, BooleanBuilder, FixedSizeListBuilder, Float32Array, + Float32Builder, Float64Array, Float64Builder, Int16Array, Int16Builder, Int32Array, + Int32Builder, Int64Array, Int64Builder, Int8Array, Int8Builder, LargeListBuilder, + LargeStringArray, LargeStringBuilder, ListArray, ListBuilder, StringArray, StringBuilder, + }, + datatypes::{DataType, Field, FieldRef}, + error::ArrowError, +}; +use std::sync::Arc; + +pub type Result = std::result::Result; + +macro_rules! 
cast_string_to_list_array { + ($string_array:expr, $field_name:expr, $data_type:expr, $builder_type:expr, $primitive_type:ty) => {{ + let item_field = Arc::new(Field::new($field_name, $data_type, true)); + let mut list_builder = ListBuilder::with_capacity($builder_type, $string_array.len()) + .with_field(Arc::clone(&item_field)); + + let list_field = Arc::new(Field::new_list("i", item_field, true)); + let mut decoder = ReaderBuilder::new_with_field(Arc::clone(&list_field)) + .build_decoder() + .map_err(|e| ArrowError::CastError(format!("Failed to create decoder: {e}")))?; + + for value in $string_array { + match value { + None => list_builder.append_null(), + Some(string_value) => { + decoder.decode(string_value.as_bytes()).map_err(|e| { + ArrowError::CastError(format!("Failed to decode value: {e}")) + })?; + + if let Some(b) = decoder.flush().map_err(|e| { + ArrowError::CastError(format!("Failed to decode decoder: {e}")) + })? { + let list_array = b + .column(0) + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to ListArray" + .to_string(), + ) + })?; + let primitive_array = list_array + .values() + .as_any() + .downcast_ref::<$primitive_type>() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to PrimitiveType" + .to_string(), + ) + })?; + primitive_array + .iter() + .for_each(|maybe_value| match maybe_value { + Some(value) => list_builder.values().append_value(value), + None => list_builder.values().append_null(), + }); + list_builder.append(true); + } + } + } + } + + Ok(Arc::new(list_builder.finish())) + }}; +} + +macro_rules! cast_string_to_large_list_array { + ($string_array:expr, $field_name:expr, $data_type:expr, $builder_type:expr, $primitive_type:ty) => {{ + let item_field = Arc::new(Field::new($field_name, $data_type, true)); + let mut list_builder = LargeListBuilder::with_capacity($builder_type, $string_array.len()) + .with_field(Arc::clone(&item_field)); + + let list_field = Arc::new(Field::new_list("i", item_field, true)); + let mut decoder = ReaderBuilder::new_with_field(Arc::clone(&list_field)) + .build_decoder() + .map_err(|e| ArrowError::CastError(format!("Failed to create decoder: {e}")))?; + + for value in $string_array { + match value { + None => list_builder.append_null(), + Some(string_value) => { + decoder.decode(string_value.as_bytes()).map_err(|e| { + ArrowError::CastError(format!("Failed to decode value: {e}")) + })?; + + if let Some(b) = decoder.flush().map_err(|e| { + ArrowError::CastError(format!("Failed to decode decoder: {e}")) + })? { + let list_array = b + .column(0) + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to ListArray" + .to_string(), + ) + })?; + let primitive_array = list_array + .values() + .as_any() + .downcast_ref::<$primitive_type>() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to PrimitiveType" + .to_string(), + ) + })?; + primitive_array + .iter() + .for_each(|maybe_value| match maybe_value { + Some(value) => list_builder.values().append_value(value), + None => list_builder.values().append_null(), + }); + list_builder.append(true); + } + } + } + } + + Ok(Arc::new(list_builder.finish())) + }}; +} + +macro_rules! 
cast_string_to_fixed_size_list_array { + ($string_array:expr, $field_name:expr, $data_type:expr, $builder_type:expr, $primitive_type:ty, $value_length:expr) => {{ + let item_field = Arc::new(Field::new($field_name, $data_type, true)); + let mut list_builder = + FixedSizeListBuilder::with_capacity($builder_type, $value_length, $string_array.len()) + .with_field(Arc::clone(&item_field)); + + let list_field = Arc::new(Field::new_list("i", item_field, true)); + let mut decoder = ReaderBuilder::new_with_field(Arc::clone(&list_field)) + .build_decoder() + .map_err(|e| ArrowError::CastError(format!("Failed to create decoder: {e}")))?; + + for value in $string_array { + match value { + None => { + for _ in 0..$value_length { + list_builder.values().append_null() + } + list_builder.append(true) + } + Some(string_value) => { + decoder.decode(string_value.as_bytes()).map_err(|e| { + ArrowError::CastError(format!("Failed to decode value: {e}")) + })?; + + if let Some(b) = decoder.flush().map_err(|e| { + ArrowError::CastError(format!("Failed to decode decoder: {e}")) + })? { + let list_array = b + .column(0) + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to ListArray" + .to_string(), + ) + })?; + let primitive_array = list_array + .values() + .as_any() + .downcast_ref::<$primitive_type>() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to PrimitiveType" + .to_string(), + ) + })?; + primitive_array + .iter() + .for_each(|maybe_value| match maybe_value { + Some(value) => list_builder.values().append_value(value), + None => list_builder.values().append_null(), + }); + list_builder.append(true); + } + } + } + } + + Ok(Arc::new(list_builder.finish())) + }}; +} + +pub(crate) fn cast_string_to_list( + array: &dyn Array, + list_item_field: &FieldRef, +) -> Result { + let string_array = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to StringArray".to_string(), + ) + })?; + + let field_name = list_item_field.name(); + + match list_item_field.data_type() { + DataType::Utf8 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Utf8, + StringBuilder::new(), + StringArray + ) + } + DataType::LargeUtf8 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::LargeUtf8, + LargeStringBuilder::new(), + LargeStringArray + ) + } + DataType::Boolean => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Boolean, + BooleanBuilder::new(), + BooleanArray + ) + } + DataType::Int8 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Int8, + Int8Builder::new(), + Int8Array + ) + } + DataType::Int16 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Int16, + Int16Builder::new(), + Int16Array + ) + } + DataType::Int32 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Int32, + Int32Builder::new(), + Int32Array + ) + } + DataType::Int64 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Int64, + Int64Builder::new(), + Int64Array + ) + } + DataType::Float32 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Float32, + Float32Builder::new(), + Float32Array + ) + } + DataType::Float64 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Float64, + Float64Builder::new(), + Float64Array + ) + } + _ => 
Err(ArrowError::CastError(format!( + "Unsupported list item type: {}", + list_item_field.data_type() + ))), + } +} + +pub(crate) fn cast_string_to_large_list( + array: &dyn Array, + list_item_field: &FieldRef, +) -> Result { + let string_array = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to StringArray".to_string(), + ) + })?; + + let field_name = list_item_field.name(); + + match list_item_field.data_type() { + DataType::Utf8 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Utf8, + StringBuilder::new(), + StringArray + ) + } + DataType::LargeUtf8 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::LargeUtf8, + LargeStringBuilder::new(), + LargeStringArray + ) + } + DataType::Boolean => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Boolean, + BooleanBuilder::new(), + BooleanArray + ) + } + DataType::Int8 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Int8, + Int8Builder::new(), + Int8Array + ) + } + DataType::Int16 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Int16, + Int16Builder::new(), + Int16Array + ) + } + DataType::Int32 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Int32, + Int32Builder::new(), + Int32Array + ) + } + DataType::Int64 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Int64, + Int64Builder::new(), + Int64Array + ) + } + DataType::Float32 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Float32, + Float32Builder::new(), + Float32Array + ) + } + DataType::Float64 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Float64, + Float64Builder::new(), + Float64Array + ) + } + _ => Err(ArrowError::CastError(format!( + "Unsupported list item type: {}", + list_item_field.data_type() + ))), + } +} + +pub(crate) fn cast_string_to_fixed_size_list( + array: &dyn Array, + list_item_field: &FieldRef, + value_length: i32, +) -> Result { + let string_array = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to StringArray".to_string(), + ) + })?; + + let field_name = list_item_field.name(); + + match list_item_field.data_type() { + DataType::Utf8 => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Utf8, + StringBuilder::new(), + StringArray, + value_length + ) + } + DataType::LargeUtf8 => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::LargeUtf8, + LargeStringBuilder::new(), + LargeStringArray, + value_length + ) + } + DataType::Boolean => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Boolean, + BooleanBuilder::new(), + BooleanArray, + value_length + ) + } + DataType::Int8 => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Int8, + Int8Builder::new(), + Int8Array, + value_length + ) + } + DataType::Int16 => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Int16, + Int16Builder::new(), + Int16Array, + value_length + ) + } + DataType::Int32 => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Int32, + Int32Builder::new(), + Int32Array, + value_length + ) + } + DataType::Int64 => { + 
cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Int64, + Int64Builder::new(), + Int64Array, + value_length + ) + } + DataType::Float32 => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Float32, + Float32Builder::new(), + Float32Array, + value_length + ) + } + DataType::Float64 => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Float64, + Float64Builder::new(), + Float64Array, + value_length + ) + } + _ => Err(ArrowError::CastError(format!( + "Unsupported list item type: {}", + list_item_field.data_type() + ))), + } +} + +#[cfg(test)] +mod test { + use datafusion::arrow::{ + array::{RecordBatch, StringArray}, + datatypes::{DataType, Field, Schema, SchemaRef}, + }; + + use crate::schema_cast::record_convert::try_cast_to; + + use super::*; + + fn input_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Utf8, false), + ])) + } + + fn output_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new( + "a", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + false, + ), + Field::new( + "b", + DataType::LargeList(Arc::new(Field::new("item", DataType::Utf8, true))), + false, + ), + Field::new( + "c", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Boolean, true)), 3), + false, + ), + ])) + } + + fn batch_input() -> RecordBatch { + RecordBatch::try_new( + input_schema(), + vec![ + Arc::new(StringArray::from(vec![ + Some("[1, 2, 3]"), + Some("[4, 5, 6]"), + ])), + Arc::new(StringArray::from(vec![ + Some("[\"foo\", \"bar\"]"), + Some("[\"baz\", \"qux\"]"), + ])), + Arc::new(StringArray::from(vec![ + Some("[true, false, true]"), + Some("[false, true, false]"), + ])), + ], + ) + .expect("record batch should not panic") + } + + fn batch_expected() -> RecordBatch { + let mut list_builder = ListBuilder::new(Int32Builder::new()); + list_builder.append_value([Some(1), Some(2), Some(3)]); + list_builder.append_value([Some(4), Some(5), Some(6)]); + let list_array = list_builder.finish(); + + let mut large_list_builder = LargeListBuilder::new(StringBuilder::new()); + large_list_builder.append_value([Some("foo"), Some("bar")]); + large_list_builder.append_value([Some("baz"), Some("qux")]); + let large_list_array = large_list_builder.finish(); + + let mut fixed_size_list_builder = FixedSizeListBuilder::new(BooleanBuilder::new(), 3); + fixed_size_list_builder.values().append_value(true); + fixed_size_list_builder.values().append_value(false); + fixed_size_list_builder.values().append_value(true); + fixed_size_list_builder.append(true); + fixed_size_list_builder.values().append_value(false); + fixed_size_list_builder.values().append_value(true); + fixed_size_list_builder.values().append_value(false); + fixed_size_list_builder.append(true); + let fixed_size_list_array = fixed_size_list_builder.finish(); + + RecordBatch::try_new( + output_schema(), + vec![ + Arc::new(list_array), + Arc::new(large_list_array), + Arc::new(fixed_size_list_array), + ], + ) + .expect("Failed to create expected RecordBatch") + } + + #[test] + fn test_cast_to_list_largelist_fixedsizelist() { + let input_batch = batch_input(); + let expected = batch_expected(); + let actual = try_cast_to(input_batch, output_schema()).expect("cast should succeed"); + + assert_eq!(actual, expected); + } +} diff --git a/datafusion-federation/src/schema_cast/record_convert.rs 
b/datafusion-federation/src/schema_cast/record_convert.rs new file mode 100644 index 0000000..a20401a --- /dev/null +++ b/datafusion-federation/src/schema_cast/record_convert.rs @@ -0,0 +1,156 @@ +use datafusion::arrow::{ + array::{Array, RecordBatch}, + compute::cast, + datatypes::{DataType, IntervalUnit, SchemaRef}, +}; +use std::sync::Arc; + +use super::{ + intervals_cast::{ + cast_interval_monthdaynano_to_daytime, cast_interval_monthdaynano_to_yearmonth, + }, + lists_cast::{cast_string_to_fixed_size_list, cast_string_to_large_list, cast_string_to_list}, + struct_cast::cast_string_to_struct, +}; + +pub type Result = std::result::Result; + +#[derive(Debug)] +pub enum Error { + UnableToConvertRecordBatch { + source: datafusion::arrow::error::ArrowError, + }, + + UnexpectedNumberOfColumns { + expected: usize, + found: usize, + }, +} + +impl std::error::Error for Error {} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Error::UnableToConvertRecordBatch { source } => { + write!(f, "Unable to convert record batch: {}", source) + } + Error::UnexpectedNumberOfColumns { expected, found } => { + write!( + f, + "Unexpected number of columns. Expected: {}, Found: {}", + expected, found + ) + } + } + } +} + +/// Cast a given record batch into a new record batch with the given schema. +/// It assumes the record batch columns are correctly ordered. +#[allow(clippy::needless_pass_by_value)] +pub fn try_cast_to(record_batch: RecordBatch, expected_schema: SchemaRef) -> Result { + let actual_schema = record_batch.schema(); + + if actual_schema.fields().len() != expected_schema.fields().len() { + return Err(Error::UnexpectedNumberOfColumns { + expected: expected_schema.fields().len(), + found: actual_schema.fields().len(), + }); + } + + let cols = expected_schema + .fields() + .iter() + .enumerate() + .map(|(i, expected_field)| { + let record_batch_col = record_batch.column(i); + + match (record_batch_col.data_type(), expected_field.data_type()) { + (DataType::Utf8, DataType::List(item_type)) => { + cast_string_to_list(&Arc::clone(record_batch_col), item_type) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) + } + (DataType::Utf8, DataType::LargeList(item_type)) => { + cast_string_to_large_list(&Arc::clone(record_batch_col), item_type) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) + } + (DataType::Utf8, DataType::FixedSizeList(item_type, value_length)) => { + cast_string_to_fixed_size_list( + &Arc::clone(record_batch_col), + item_type, + *value_length, + ) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) + } + (DataType::Utf8, DataType::Struct(_)) => { + cast_string_to_struct(&Arc::clone(record_batch_col), expected_field.clone()) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) + } + ( + DataType::Interval(IntervalUnit::MonthDayNano), + DataType::Interval(IntervalUnit::YearMonth), + ) => cast_interval_monthdaynano_to_yearmonth(&Arc::clone(record_batch_col)) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }), + ( + DataType::Interval(IntervalUnit::MonthDayNano), + DataType::Interval(IntervalUnit::DayTime), + ) => cast_interval_monthdaynano_to_daytime(&Arc::clone(record_batch_col)) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }), + _ => cast(&Arc::clone(record_batch_col), expected_field.data_type()) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }), + } + }) + .collect::>>>()?; + + RecordBatch::try_new(expected_schema, cols) + 
.map_err(|e| Error::UnableToConvertRecordBatch { source: e }) +} + +#[cfg(test)] +mod test { + use datafusion::arrow::{ + array::{Int32Array, StringArray}, + datatypes::{DataType, Field, Schema, TimeUnit}, + }; + + use super::*; + + fn schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Utf8, false), + ])) + } + + fn to_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::LargeUtf8, false), + Field::new("c", DataType::Timestamp(TimeUnit::Microsecond, None), false), + ])) + } + + fn batch_input() -> RecordBatch { + RecordBatch::try_new( + schema(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["foo", "bar", "baz"])), + Arc::new(StringArray::from(vec![ + "2024-01-13 03:18:09.000000", + "2024-01-13 03:18:09", + "2024-01-13 03:18:09.000", + ])), + ], + ) + .expect("record batch should not panic") + } + + #[test] + fn test_string_to_timestamp_conversion() { + let result = try_cast_to(batch_input(), to_schema()).expect("converted"); + assert_eq!(3, result.num_rows()); + } +} diff --git a/datafusion-federation/src/schema_cast/struct_cast.rs b/datafusion-federation/src/schema_cast/struct_cast.rs new file mode 100644 index 0000000..5b6c1f0 --- /dev/null +++ b/datafusion-federation/src/schema_cast/struct_cast.rs @@ -0,0 +1,169 @@ +use arrow_json::ReaderBuilder; +use datafusion::arrow::{ + array::{Array, ArrayRef, StringArray}, + datatypes::Field, + error::ArrowError, +}; +use std::sync::Arc; + +pub type Result = std::result::Result; + +pub(crate) fn cast_string_to_struct( + array: &dyn Array, + struct_field: Arc, +) -> Result { + let string_array = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::CastError("Failed to downcast to StringArray".to_string()))?; + + let mut decoder = ReaderBuilder::new_with_field(struct_field) + .build_decoder() + .map_err(|e| ArrowError::CastError(format!("Failed to create JSON decoder: {e}")))?; + + for value in string_array { + match value { + Some(v) => { + decoder.decode(v.as_bytes()).map_err(|e| { + ArrowError::CastError(format!("Failed to decode struct array: {e}")) + })?; + } + None => { + decoder.decode("null".as_bytes()).map_err(|e| { + ArrowError::CastError(format!("Failed to decode struct array: {e}")) + })?; + } + } + } + + let record = match decoder.flush() { + Ok(Some(record)) => record, + Ok(None) => { + return Err(ArrowError::CastError( + "Failed to flush decoder: No record".to_string(), + )); + } + Err(e) => { + return Err(ArrowError::CastError(format!( + "Failed to decode struct array: {e}" + ))); + } + }; + // struct_field is single struct column + Ok(Arc::clone(record.column(0))) +} + +#[cfg(test)] +mod test { + use datafusion::arrow::{ + array::{Int32Builder, RecordBatch, StringArray, StringBuilder, StructBuilder}, + datatypes::{DataType, Field, Schema, SchemaRef}, + }; + + use crate::schema_cast::record_convert::try_cast_to; + + use super::*; + + fn input_schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new( + "struct_string", + DataType::Utf8, + true, + )])) + } + + fn output_schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new( + "struct", + DataType::Struct( + vec![ + Field::new("name", DataType::Utf8, false), + Field::new("age", DataType::Int32, false), + ] + .into(), + ), + true, + )])) + } + + fn batch_input() -> RecordBatch { + RecordBatch::try_new( + input_schema(), + 
vec![Arc::new(StringArray::from(vec![ + Some(r#"{"name":"John","age":30}"#), + None, + None, + Some(r#"{"name":"Jane","age":25}"#), + ]))], + ) + .expect("record batch should not panic") + } + + fn batch_expected() -> RecordBatch { + let name_field = Field::new("name", DataType::Utf8, false); + let age_field = Field::new("age", DataType::Int32, false); + + let mut struct_builder = StructBuilder::new( + vec![name_field, age_field], + vec![ + Box::new(StringBuilder::new()), + Box::new(Int32Builder::new()), + ], + ); + + struct_builder + .field_builder::(0) + .expect("should return field builder") + .append_value("John"); + struct_builder + .field_builder::(1) + .expect("should return field builder") + .append_value(30); + struct_builder.append(true); + + struct_builder + .field_builder::(0) + .expect("should return field builder") + .append_null(); + struct_builder + .field_builder::(1) + .expect("should return field builder") + .append_null(); + struct_builder.append(false); + + struct_builder + .field_builder::(0) + .expect("should return field builder") + .append_null(); + struct_builder + .field_builder::(1) + .expect("should return field builder") + .append_null(); + struct_builder.append(false); + + struct_builder + .field_builder::(0) + .expect("should return field builder") + .append_value("Jane"); + struct_builder + .field_builder::(1) + .expect("should return field builder") + .append_value(25); + struct_builder.append(true); + + let struct_array = struct_builder.finish(); + + RecordBatch::try_new(output_schema(), vec![Arc::new(struct_array)]) + .expect("Failed to create expected RecordBatch") + } + + #[test] + fn test_cast_to_struct() { + let input_batch = batch_input(); + let expected = batch_expected(); + + let actual = try_cast_to(input_batch, output_schema()).expect("cast should succeed"); + + assert_eq!(actual, expected); + } +} diff --git a/datafusion-federation/src/table_provider.rs b/datafusion-federation/src/table_provider.rs index 6a9afaa..93eb0aa 100644 --- a/datafusion-federation/src/table_provider.rs +++ b/datafusion-federation/src/table_provider.rs @@ -1,13 +1,15 @@ -use std::{any::Any, sync::Arc}; +use std::{any::Any, borrow::Cow, sync::Arc}; use async_trait::async_trait; use datafusion::{ arrow::datatypes::SchemaRef, + catalog::Session, common::Constraints, datasource::TableProvider, error::{DataFusionError, Result}, - execution::context::SessionState, - logical_expr::{Expr, LogicalPlan, TableSource, TableType}, + logical_expr::{ + dml::InsertOp, Expr, LogicalPlan, TableProviderFilterPushDown, TableSource, TableType, + }, physical_plan::ExecutionPlan, }; @@ -17,11 +19,36 @@ use crate::FederationProvider; // from a TableScan. This wrapper may be avoidable. pub struct FederatedTableProviderAdaptor { pub source: Arc, + pub table_provider: Option>, +} + +impl std::fmt::Debug for FederatedTableProviderAdaptor { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FederatedTableProviderAdaptor") + .field("table_provider", &self.table_provider) + .finish() + } } impl FederatedTableProviderAdaptor { pub fn new(source: Arc) -> Self { - Self { source } + Self { + source, + table_provider: None, + } + } + + /// Creates a new FederatedTableProviderAdaptor that falls back to the + /// provided TableProvider. This is useful if used within a DataFusion + /// context without the federation optimizer. 
+ pub fn new_with_provider( + source: Arc, + table_provider: Arc, + ) -> Self { + Self { + source, + table_provider: Some(table_provider), + } } } @@ -31,34 +58,92 @@ impl TableProvider for FederatedTableProviderAdaptor { self } fn schema(&self) -> SchemaRef { + if let Some(table_provider) = &self.table_provider { + return table_provider.schema(); + } + self.source.schema() } fn constraints(&self) -> Option<&Constraints> { + if let Some(table_provider) = &self.table_provider { + return table_provider + .constraints() + .or_else(|| self.source.constraints()); + } + self.source.constraints() } fn table_type(&self) -> TableType { + if let Some(table_provider) = &self.table_provider { + return table_provider.table_type(); + } + self.source.table_type() } - fn get_logical_plan(&self) -> Option<&LogicalPlan> { + fn get_logical_plan(&self) -> Option> { + if let Some(table_provider) = &self.table_provider { + return table_provider + .get_logical_plan() + .or_else(|| self.source.get_logical_plan()); + } + self.source.get_logical_plan() } fn get_column_default(&self, column: &str) -> Option<&Expr> { + if let Some(table_provider) = &self.table_provider { + return table_provider + .get_column_default(column) + .or_else(|| self.source.get_column_default(column)); + } + self.source.get_column_default(column) } + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + if let Some(table_provider) = &self.table_provider { + return table_provider.supports_filters_pushdown(filters); + } + + Ok(vec![ + TableProviderFilterPushDown::Unsupported; + filters.len() + ]) + } // Scan is not supported; the adaptor should be replaced // with a virtual TableProvider that provides federation for a sub-plan. async fn scan( &self, - _state: &SessionState, - _projection: Option<&Vec>, - _filters: &[Expr], - _limit: Option, + state: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, ) -> Result> { + if let Some(table_provider) = &self.table_provider { + return table_provider.scan(state, projection, filters, limit).await; + } + Err(DataFusionError::NotImplemented( "FederatedTableProviderAdaptor cannot scan".to_string(), )) } + + async fn insert_into( + &self, + _state: &dyn Session, + input: Arc, + overwrite: InsertOp, + ) -> Result> { + if let Some(table_provider) = &self.table_provider { + return table_provider.insert_into(_state, input, overwrite).await; + } + + Err(DataFusionError::NotImplemented( + "FederatedTableProviderAdaptor cannot insert_into".to_string(), + )) + } } // FederatedTableProvider extends DataFusion's TableProvider trait diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 13eb599..48c910c 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -6,18 +6,18 @@ license.workspace = true readme.workspace = true [dev-dependencies] -arrow-flight = { version = "51.0.0", features = ["flight-sql-experimental"] } +arrow-flight = { version = "53.0.0", features = ["flight-sql-experimental"] } tokio = "1.35.1" async-trait.workspace = true datafusion.workspace = true datafusion-federation.path = "../datafusion-federation" -datafusion-federation-sql.path = "../sources/sql" +datafusion-federation-sql = { path = "../sources/sql", features = ["connectorx"] } datafusion-federation-flight-sql.path = "../sources/flight-sql" connectorx = { git = "https://github.com/devinjdangelo/connector-x.git", features = [ "dst_arrow", "src_sqlite" ] } -tonic = "0.11.0" +tonic = "0.12.2" [dependencies] async-std = "1.12.0" diff --git a/examples/examples/flight-sql.rs 
b/examples/examples/flight-sql.rs index 7a32e29..1f8578f 100644 --- a/examples/examples/flight-sql.rs +++ b/examples/examples/flight-sql.rs @@ -1,8 +1,9 @@ use std::{sync::Arc, time::Duration}; use arrow_flight::sql::client::FlightSqlServiceClient; +use datafusion::execution::SessionStateBuilder; use datafusion::{ - catalog::schema::SchemaProvider, + catalog::SchemaProvider, error::{DataFusionError, Result}, execution::{ context::{SessionContext, SessionState}, @@ -39,14 +40,14 @@ async fn main() -> Result<()> { sleep(Duration::from_secs(3)).await; // Local context - let state = SessionContext::new().state(); let known_tables: Vec = ["test"].iter().map(|&x| x.into()).collect(); // Register FederationAnalyzer // TODO: Interaction with other analyzers & optimizers. - let state = state - .add_analyzer_rule(Arc::new(FederationAnalyzerRule::new())) - .with_query_planner(Arc::new(FederatedQueryPlanner::new())); + let mut state = SessionStateBuilder::new() + .with_query_planner(Arc::new(FederatedQueryPlanner::new())) + .build(); + state.add_analyzer_rule(Arc::new(FederationAnalyzerRule::new())); // Register schema // TODO: table inference diff --git a/examples/examples/postgres-partial.rs b/examples/examples/postgres-partial.rs deleted file mode 100644 index 873dd40..0000000 --- a/examples/examples/postgres-partial.rs +++ /dev/null @@ -1,70 +0,0 @@ -use std::sync::Arc; -use tokio::task; - -use datafusion::{ - catalog::schema::SchemaProvider, - error::Result, - execution::context::{SessionContext, SessionState}, -}; -use datafusion_federation::{FederatedQueryPlanner, FederationAnalyzerRule}; -use datafusion_federation_sql::connectorx::CXExecutor; -use datafusion_federation_sql::{MultiSchemaProvider, SQLFederationProvider, SQLSchemaProvider}; - -#[tokio::main] -async fn main() -> Result<()> { - let state = SessionContext::new().state(); - // Register FederationAnalyzer - // TODO: Interaction with other analyzers & optimizers. 
- let state = state - .add_analyzer_rule(Arc::new(FederationAnalyzerRule::new())) - .with_query_planner(Arc::new(FederatedQueryPlanner::new())); - - let df = task::spawn_blocking(move || { - // Register schema - let pg_provider_1 = async_std::task::block_on(create_postgres_provider(vec!["class"], "conn1")).unwrap(); - let pg_provider_2 = async_std::task::block_on(create_postgres_provider(vec!["teacher"], "conn2")).unwrap(); - let provider = MultiSchemaProvider::new(vec![ - pg_provider_1, - pg_provider_2, - ]); - - overwrite_default_schema(&state, Arc::new(provider)).unwrap(); - - // Run query - let ctx = SessionContext::new_with_state(state); - let query = r#"SELECT class.name AS classname, teacher.name AS teachername FROM class JOIN teacher ON class.id = teacher.class_id"#; - let df = async_std::task::block_on(ctx.sql(query)).unwrap(); - - df - }).await.unwrap(); - - task::spawn_blocking(move || async_std::task::block_on(df.show())) - .await - .unwrap() -} - -async fn create_postgres_provider( - known_tables: Vec<&str>, - context: &str, -) -> Result> { - let dsn = "postgresql://:@localhost:/".to_string(); - let known_tables: Vec = known_tables.iter().map(|&x| x.into()).collect(); - let mut executor = CXExecutor::new(dsn)?; - executor.context(context.to_string()); - let provider = Arc::new(SQLFederationProvider::new(Arc::new(executor))); - Ok(Arc::new( - SQLSchemaProvider::new_with_tables(provider, known_tables).await?, - )) -} - -fn overwrite_default_schema(state: &SessionState, schema: Arc) -> Result<()> { - let options = &state.config().options().catalog; - let catalog = state - .catalog_list() - .catalog(options.default_catalog.as_str()) - .unwrap(); - - catalog.register_schema(options.default_schema.as_str(), schema)?; - - Ok(()) -} diff --git a/examples/examples/sqlite-partial.rs b/examples/examples/sqlite-partial.rs deleted file mode 100644 index 780462b..0000000 --- a/examples/examples/sqlite-partial.rs +++ /dev/null @@ -1,107 +0,0 @@ -use std::{any::Any, sync::Arc}; - -use async_trait::async_trait; -use datafusion::{ - catalog::schema::SchemaProvider, - datasource::TableProvider, - error::Result, - execution::context::{SessionContext, SessionState}, -}; -use datafusion_federation::{FederatedQueryPlanner, FederationAnalyzerRule}; -use datafusion_federation_sql::{connectorx::CXExecutor, SQLFederationProvider, SQLSchemaProvider}; - -#[tokio::main] -async fn main() -> Result<()> { - let state = SessionContext::new().state(); - // Register FederationAnalyzer - // TODO: Interaction with other analyzers & optimizers. 
- let state = state - .add_analyzer_rule(Arc::new(FederationAnalyzerRule::new())) - .with_query_planner(Arc::new(FederatedQueryPlanner::new())); - - // Register schema - let provider = MultiSchemaProvider::new(vec![ - create_sqlite_provider(vec!["Artist"], "conn1").await?, - create_sqlite_provider(vec!["Track", "Album"], "conn2").await?, - ]); - - overwrite_default_schema(&state, Arc::new(provider))?; - - // Run query - let ctx = SessionContext::new_with_state(state); - let query = r#"SELECT - t.TrackId, - t.Name AS TrackName, - a.Title AS AlbumTitle, - ar.Name AS ArtistName - FROM Track t - JOIN Album a ON t.AlbumId = a.AlbumId - JOIN Artist ar ON a.ArtistId = ar.ArtistId - limit 10"#; - let df = ctx.sql(query).await?; - - // let explain = df.clone().explain(true, false)?; - // explain.show().await?; - - df.show().await -} - -async fn create_sqlite_provider( - known_tables: Vec<&str>, - context: &str, -) -> Result> { - let dsn = "sqlite://./examples/examples/chinook.sqlite".to_string(); - let known_tables: Vec = known_tables.iter().map(|&x| x.into()).collect(); - let mut executor = CXExecutor::new(dsn)?; - executor.context(context.to_string()); - let provider = Arc::new(SQLFederationProvider::new(Arc::new(executor))); - Ok(Arc::new( - SQLSchemaProvider::new_with_tables(provider, known_tables).await?, - )) -} - -struct MultiSchemaProvider { - children: Vec>, -} - -impl MultiSchemaProvider { - pub fn new(children: Vec>) -> Self { - Self { children } - } -} - -fn overwrite_default_schema(state: &SessionState, schema: Arc) -> Result<()> { - let options = &state.config().options().catalog; - let catalog = state - .catalog_list() - .catalog(options.default_catalog.as_str()) - .unwrap(); - - catalog.register_schema(options.default_schema.as_str(), schema)?; - - Ok(()) -} - -#[async_trait] -impl SchemaProvider for MultiSchemaProvider { - fn as_any(&self) -> &dyn Any { - self - } - - fn table_names(&self) -> Vec { - self.children.iter().flat_map(|p| p.table_names()).collect() - } - - async fn table(&self, name: &str) -> Result>> { - for child in &self.children { - if let Ok(Some(table)) = child.table(name).await { - return Ok(Some(table)); - } - } - Ok(None) - } - - fn table_exist(&self, name: &str) -> bool { - self.children.iter().any(|p| p.table_exist(name)) - } -} diff --git a/examples/examples/sqlite.rs b/examples/examples/sqlite.rs deleted file mode 100644 index 74e5371..0000000 --- a/examples/examples/sqlite.rs +++ /dev/null @@ -1,70 +0,0 @@ -use std::sync::Arc; - -use datafusion::{ - catalog::schema::SchemaProvider, - error::Result, - execution::context::{SessionContext, SessionState}, -}; -use datafusion_federation::{FederatedQueryPlanner, FederationAnalyzerRule}; -use datafusion_federation_sql::{connectorx::CXExecutor, SQLFederationProvider, SQLSchemaProvider}; - -#[tokio::main] -async fn main() -> datafusion::error::Result<()> { - let dsn = "sqlite://./examples/examples/chinook.sqlite".to_string(); - let known_tables: Vec = ["Track", "Album", "Artist"] - .iter() - .map(|&x| x.into()) - .collect(); - - let state = SessionContext::new().state(); - - // Register FederationAnalyzer - // TODO: Interaction with other analyzers & optimizers. 
- let state = state - .add_analyzer_rule(Arc::new(FederationAnalyzerRule::new())) - .with_query_planner(Arc::new(FederatedQueryPlanner::new())); - - // Register schema - // TODO: table inference - let executor = Arc::new(CXExecutor::new(dsn)?); - let provider = Arc::new(SQLFederationProvider::new(executor)); - let schema_provider = - Arc::new(SQLSchemaProvider::new_with_tables(provider, known_tables).await?); - overwrite_default_schema(&state, schema_provider)?; - - // Run query - let ctx = SessionContext::new_with_state(state); - let query = r#"SELECT - t.TrackId, - t.Name AS TrackName, - a.Title AS AlbumTitle, - ar.Name AS ArtistName - FROM Track t - JOIN Album a ON t.AlbumId = a.AlbumId - JOIN Artist ar ON a.ArtistId = ar.ArtistId - limit 10"#; - let df = ctx.sql(query).await?; - df.show().await?; - - // If the environment variable EXPLAIN is set, print the query plan - if std::env::var("EXPLAIN").is_ok() { - let explain_query = format!("EXPLAIN {query}"); - let df = ctx.sql(explain_query.as_str()).await?; - - df.show().await?; - } - - Ok(()) -} - -fn overwrite_default_schema(state: &SessionState, schema: Arc) -> Result<()> { - let options = &state.config().options().catalog; - let catalog = state - .catalog_list() - .catalog(options.default_catalog.as_str()) - .unwrap(); - - catalog.register_schema(options.default_schema.as_str(), schema)?; - - Ok(()) -} diff --git a/sources/flight-sql/Cargo.toml b/sources/flight-sql/Cargo.toml index f778dfc..91e8af0 100644 --- a/sources/flight-sql/Cargo.toml +++ b/sources/flight-sql/Cargo.toml @@ -16,8 +16,8 @@ datafusion-substrait.workspace = true datafusion-federation.path = "../../datafusion-federation" datafusion-federation-sql.path = "../sql" futures = "0.3.30" -tonic = {version="0.11.0", features=["tls"] } -prost = "0.12.3" -arrow = "51.0.0" -arrow-flight = { version = "51.0.0", features = ["flight-sql-experimental"] } +tonic = {version="0.12.2", features=["tls"] } +prost = "0.13.4" +arrow = "53.0.0" +arrow-flight = { version = "53.0.0", features = ["flight-sql-experimental"] } log = "0.4.20" diff --git a/sources/flight-sql/src/executor/mod.rs b/sources/flight-sql/src/executor/mod.rs index a5c5a38..c4f1225 100644 --- a/sources/flight-sql/src/executor/mod.rs +++ b/sources/flight-sql/src/executor/mod.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use datafusion::{ error::{DataFusionError, Result}, physical_plan::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream}, - sql::sqlparser::dialect::{Dialect, GenericDialect}, + sql::unparser::dialect::{DefaultDialect, Dialect}, }; use datafusion_federation_sql::SQLExecutor; use futures::TryStreamExt; @@ -68,11 +68,11 @@ impl SQLExecutor for FlightSQLExecutor { } fn execute(&self, sql: &str, schema: SchemaRef) -> Result { let future_stream = - make_flight_sql_stream(sql.to_string(), self.client.clone(), schema.clone()); + make_flight_sql_stream(sql.to_string(), self.client.clone(), Arc::clone(&schema)); let stream = futures::stream::once(future_stream).try_flatten(); Ok(Box::pin(RecordBatchStreamAdapter::new( - schema.clone(), + Arc::clone(&schema), stream, ))) } @@ -96,7 +96,7 @@ impl SQLExecutor for FlightSQLExecutor { } fn dialect(&self) -> Arc { - Arc::new(GenericDialect {}) + Arc::new(DefaultDialect {}) } } diff --git a/sources/flight-sql/src/server/service.rs b/sources/flight-sql/src/server/service.rs index c49cf32..6bc77ff 100644 --- a/sources/flight-sql/src/server/service.rs +++ b/sources/flight-sql/src/server/service.rs @@ -16,8 +16,8 @@ use arrow_flight::sql::{ 
CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, - CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate, SqlInfo, - TicketStatementQuery, + CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate, + DoPutPreparedStatementResult, SqlInfo, TicketStatementQuery, }; use arrow_flight::{ Action, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse, @@ -601,7 +601,7 @@ impl ArrowFlightSqlService for FlightSqlService { &self, _query: CommandPreparedStatementQuery, request: Request, - ) -> Result::DoPutStream>> { + ) -> Result { info!("do_put_prepared_statement_query"); let (_, _) = self.new_context(request)?; diff --git a/sources/sql/Cargo.toml b/sources/sql/Cargo.toml index bc89d95..5d1d030 100644 --- a/sources/sql/Cargo.toml +++ b/sources/sql/Cargo.toml @@ -16,9 +16,16 @@ async-trait.workspace = true connectorx = { git = "https://github.com/devinjdangelo/connector-x.git", features = [ "dst_arrow", "src_sqlite" -] } +], optional = true } datafusion.workspace = true datafusion-federation.path = "../../datafusion-federation" # derive_builder = "0.13.0" futures = "0.3.30" tokio = "1.35.1" +tracing = "0.1.40" + +[features] +connectorx = ["dep:connectorx"] + +[dev-dependencies] +tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } diff --git a/sources/sql/src/connectorx/executor.rs b/sources/sql/src/connectorx/executor.rs deleted file mode 100644 index fc5ea3d..0000000 --- a/sources/sql/src/connectorx/executor.rs +++ /dev/null @@ -1,125 +0,0 @@ -use async_trait::async_trait; -use connectorx::{ - destinations::arrow::ArrowDestinationError, - errors::{ConnectorXError, ConnectorXOutError}, - prelude::{get_arrow, CXQuery, SourceConn, SourceType}, -}; -use datafusion::{ - arrow::datatypes::{Field, Schema, SchemaRef}, - error::{DataFusionError, Result}, - physical_plan::{ - stream::RecordBatchStreamAdapter, EmptyRecordBatchStream, SendableRecordBatchStream, - }, - sql::sqlparser::dialect::{Dialect, GenericDialect, PostgreSqlDialect, SQLiteDialect}, -}; -use futures::executor::block_on; -use std::sync::Arc; -use tokio::task; - -use crate::executor::SQLExecutor; - -pub struct CXExecutor { - context: String, - conn: SourceConn, -} - -impl CXExecutor { - pub fn new(dsn: String) -> Result { - let conn = SourceConn::try_from(dsn.as_str()).map_err(cx_error_to_df)?; - Ok(Self { context: dsn, conn }) - } - - pub fn new_with_conn(conn: SourceConn) -> Self { - Self { - context: conn.conn.to_string(), - conn, - } - } - - pub fn context(&mut self, context: String) { - self.context = context; - } -} - -fn cx_error_to_df(err: ConnectorXError) -> DataFusionError { - DataFusionError::External(format!("ConnectorX: {err:?}").into()) -} - -#[async_trait] -impl SQLExecutor for CXExecutor { - fn name(&self) -> &str { - "connector_x_executor" - } - fn compute_context(&self) -> Option { - Some(self.context.clone()) - } - fn execute(&self, sql: &str, schema: SchemaRef) -> Result { - let conn = self.conn.clone(); - let query: CXQuery = sql.into(); - - let mut dst = block_on(task::spawn_blocking(move || -> Result<_, _> { - get_arrow(&conn, None, &[query.clone()]).map_err(cx_out_error_to_df) - })) - .map_err(|err| DataFusionError::External(err.to_string().into()))??; - let stream = if let Some(batch) = 
dst.record_batch().map_err(cx_dst_error_to_df)? { - futures::stream::once(async move { Ok(batch) }) - } else { - return Ok(Box::pin(EmptyRecordBatchStream::new(Arc::new( - Schema::empty(), - )))); - }; - - Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) - } - - async fn table_names(&self) -> Result> { - Err(DataFusionError::NotImplemented( - "connector_x source: table inference not implemented".to_string(), - )) - } - - async fn get_table_schema(&self, table_name: &str) -> Result { - let conn = self.conn.clone(); - let query: CXQuery = format!("select * from {table_name} limit 1") - .as_str() - .into(); - - let dst = get_arrow(&conn, None, &[query.clone()]).map_err(cx_out_error_to_df)?; - let schema = schema_to_lowercase(dst.arrow_schema()); - Ok(schema) - } - - fn dialect(&self) -> Arc { - match &self.conn.ty { - SourceType::Postgres => Arc::new(PostgreSqlDialect {}), - SourceType::SQLite => Arc::new(SQLiteDialect {}), - _ => Arc::new(GenericDialect {}), - } - } -} - -fn cx_dst_error_to_df(err: ArrowDestinationError) -> DataFusionError { - DataFusionError::External(format!("ConnectorX failed to run query: {err:?}").into()) -} - -/// Get the schema with lowercase field names -fn schema_to_lowercase(schema: SchemaRef) -> SchemaRef { - // DF needs lower case schema - let lower_fields: Vec<_> = schema - .fields - .iter() - .map(|f| { - Field::new( - f.name().to_ascii_lowercase(), - f.data_type().clone(), - f.is_nullable(), - ) - }) - .collect(); - - Arc::new(Schema::new(lower_fields)) -} - -fn cx_out_error_to_df(err: ConnectorXOutError) -> DataFusionError { - DataFusionError::External(format!("ConnectorX failed to run query: {err:?}").into()) -} diff --git a/sources/sql/src/connectorx/mod.rs b/sources/sql/src/connectorx/mod.rs deleted file mode 100644 index 600069a..0000000 --- a/sources/sql/src/connectorx/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -mod executor; -pub use executor::*; diff --git a/sources/sql/src/executor.rs b/sources/sql/src/executor.rs index 7f05910..0922ce1 100644 --- a/sources/sql/src/executor.rs +++ b/sources/sql/src/executor.rs @@ -2,11 +2,12 @@ use async_trait::async_trait; use core::fmt; use datafusion::{ arrow::datatypes::SchemaRef, error::Result, physical_plan::SendableRecordBatchStream, - sql::sqlparser::dialect::Dialect, + sql::sqlparser::ast, sql::unparser::dialect::Dialect, }; use std::sync::Arc; pub type SQLExecutorRef = Arc; +pub type AstAnalyzer = Box Result>; #[async_trait] pub trait SQLExecutor: Sync + Send { @@ -20,6 +21,11 @@ pub trait SQLExecutor: Sync + Send { // The specific SQL dialect (currently supports 'sqlite', 'postgres', 'flight') fn dialect(&self) -> Arc; + /// Returns an AST analyzer specific for this engine to modify the AST before execution + fn ast_analyzer(&self) -> Option { + None + } + // Execution /// Execute a SQL query fn execute(&self, query: &str, schema: SchemaRef) -> Result; @@ -29,6 +35,11 @@ pub trait SQLExecutor: Sync + Send { async fn table_names(&self) -> Result>; /// Returns the schema of table_name within this SQLExecutor async fn get_table_schema(&self, table_name: &str) -> Result; + + /// Returns whether the executor requires partial table path in subquery + fn subquery_use_partial_path(&self) -> bool { + false + } } impl fmt::Debug for dyn SQLExecutor { diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index b814904..09a1fb3 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -1,27 +1,45 @@ use core::fmt; -use std::{any::Any, sync::Arc, vec}; +use std::{ + any::Any, + 
collections::{HashMap, HashSet}, + sync::Arc, + vec, +}; use async_trait::async_trait; use datafusion::{ arrow::datatypes::{Schema, SchemaRef}, + common::{Column, RecursionUnnestOption, UnnestOptions}, config::ConfigOptions, - error::Result, + error::{DataFusionError, Result}, execution::{context::SessionState, TaskContext}, - logical_expr::{Extension, LogicalPlan}, + logical_expr::{ + self, + expr::{ + AggregateFunction, Alias, Exists, InList, InSubquery, ScalarFunction, Sort, Unnest, + WindowFunction, + }, + Between, BinaryExpr, Case, Cast, Expr, Extension, GroupingSet, Like, Limit, LogicalPlan, + LogicalPlanBuilder, Projection, Subquery, TryCast, + }, optimizer::analyzer::{Analyzer, AnalyzerRule}, physical_expr::EquivalenceProperties, physical_plan::{ DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, Partitioning, PlanProperties, SendableRecordBatchStream, }, - sql::unparser::plan_to_sql, + sql::{ + unparser::{plan_to_sql, Unparser}, + TableReference, + }, +}; +use datafusion_federation::{ + get_table_source, schema_cast, FederatedPlanNode, FederationPlanner, FederationProvider, }; -use datafusion_federation::{FederatedPlanNode, FederationPlanner, FederationProvider}; mod schema; pub use schema::*; -pub mod connectorx; mod executor; pub use executor::*; @@ -38,7 +56,7 @@ impl SQLFederationProvider { pub fn new(executor: Arc) -> Self { Self { analyzer: Arc::new(Analyzer::with_rules(vec![Arc::new( - SQLFederationAnalyzerRule::new(executor.clone()), + SQLFederationAnalyzerRule::new(Arc::clone(&executor)), )])), executor, } @@ -55,7 +73,7 @@ impl FederationProvider for SQLFederationProvider { } fn analyzer(&self) -> Option> { - Some(self.analyzer.clone()) + Some(Arc::clone(&self.analyzer)) } } @@ -63,19 +81,23 @@ struct SQLFederationAnalyzerRule { planner: Arc, } +impl std::fmt::Debug for SQLFederationAnalyzerRule { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SQLFederationAnalyzerRule").finish() + } +} + impl SQLFederationAnalyzerRule { pub fn new(executor: Arc) -> Self { Self { - planner: Arc::new(SQLFederationPlanner::new(executor.clone())), + planner: Arc::new(SQLFederationPlanner::new(Arc::clone(&executor))), } } } impl AnalyzerRule for SQLFederationAnalyzerRule { fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result { - // Simply accept the entire plan for now - - let fed_plan = FederatedPlanNode::new(plan.clone(), self.planner.clone()); + let fed_plan = FederatedPlanNode::new(plan.clone(), Arc::clone(&self.planner)); let ext_node = Extension { node: Arc::new(fed_plan), }; @@ -87,6 +109,995 @@ impl AnalyzerRule for SQLFederationAnalyzerRule { "federate_sql" } } + +/// Rewrite table scans to use the original federated table name. +fn rewrite_table_scans( + plan: &LogicalPlan, + known_rewrites: &mut HashMap, + subquery_uses_partial_path: bool, + subquery_table_scans: &mut Option>, +) -> Result { + if plan.inputs().is_empty() { + if let LogicalPlan::TableScan(table_scan) = plan { + let original_table_name = table_scan.table_name.clone(); + let mut new_table_scan = table_scan.clone(); + + let Some(federated_source) = get_table_source(&table_scan.source)? 
else { + // Not a federated source + return Ok(plan.clone()); + }; + + match federated_source.as_any().downcast_ref::() { + Some(sql_table_source) => { + let remote_table_name = TableReference::from(sql_table_source.table_name()); + known_rewrites.insert(original_table_name.clone(), remote_table_name.clone()); + + if let Some(s) = subquery_table_scans { + s.insert(original_table_name); + } + + // Rewrite the schema of this node to have the remote table as the qualifier. + let new_schema = (*new_table_scan.projected_schema) + .clone() + .replace_qualifier(remote_table_name.clone()); + new_table_scan.projected_schema = Arc::new(new_schema); + new_table_scan.table_name = remote_table_name; + } + None => { + // Not a SQLTableSource (is this possible?) + return Ok(plan.clone()); + } + } + + return Ok(LogicalPlan::TableScan(new_table_scan)); + } else { + return Ok(plan.clone()); + } + } + + let rewritten_inputs = plan + .inputs() + .into_iter() + .map(|i| { + rewrite_table_scans( + i, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) + .collect::>>()?; + + match plan { + LogicalPlan::Unnest(unnest) => { + // The Unnest plan cannot be constructed from rewritten expressions. It requires specialized logic to handle + // the renaming in UNNEST columns and the corresponding column aliases in the underlying projection plan. + rewrite_unnest_plan( + unnest, + rewritten_inputs, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + } + LogicalPlan::Limit(limit) => { + let rewritten_skip = limit + .skip + .as_ref() + .map(|skip| { + rewrite_table_scans_in_expr( + *skip.clone(), + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + .map(Box::new) + }) + .transpose()?; + + let rewritten_fetch = limit + .fetch + .as_ref() + .map(|fetch| { + rewrite_table_scans_in_expr( + *fetch.clone(), + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + .map(Box::new) + }) + .transpose()?; + + // explicitly set fetch and skip + let new_plan = LogicalPlan::Limit(Limit { + skip: rewritten_skip, + fetch: rewritten_fetch, + input: Arc::new(rewritten_inputs[0].clone()), + }); + Ok(new_plan) + } + + _ => { + let mut new_expressions = vec![]; + for expression in plan.expressions() { + let new_expr = rewrite_table_scans_in_expr( + expression.clone(), + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + new_expressions.push(new_expr); + } + let new_plan = plan.with_new_exprs(new_expressions, rewritten_inputs)?; + Ok(new_plan) + } + } +} + +/// Rewrite an unnest plan to use the original federated table name. +/// In a standard unnest plan, column names are typically referenced in projection columns by wrapping them +/// in aliases such as "UNNEST(table_name.column_name)". `rewrite_table_scans_in_expr` does not handle alias +/// rewriting so we manually collect the rewritten unnest column names/aliases and update the projection +/// plan to ensure that the aliases reflect the new names.
+fn rewrite_unnest_plan( + unnest: &logical_expr::Unnest, + mut rewritten_inputs: Vec, + known_rewrites: &mut HashMap, + subquery_uses_partial_path: bool, + subquery_table_scans: &mut Option>, +) -> Result { + // Unnest plan has a single input + let input = rewritten_inputs.remove(0); + + let mut known_unnest_rewrites: HashMap = HashMap::new(); + + // `exec_columns` represent columns to run UNNEST on: rewrite them and collect new names + let unnest_columns = unnest + .exec_columns + .iter() + .map(|c: &Column| { + match rewrite_table_scans_in_expr( + Expr::Column(c.clone()), + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )? { + Expr::Column(column) => { + known_unnest_rewrites.insert(c.name.clone(), column.name.clone()); + Ok(column) + } + _ => Err(DataFusionError::Plan( + "Rewritten column expression must be a column".to_string(), + )), + } + }) + .collect::>>()?; + + let LogicalPlan::Projection(projection) = input else { + return Err(DataFusionError::Plan( + "The input to the unnest plan should be a projection plan".to_string(), + )); + }; + + // rewrite aliases in inner projection; columns were rewritten via `rewrite_table_scans_in_expr` + let new_expressions = projection + .expr + .into_iter() + .map(|expr| match expr { + Expr::Alias(alias) => { + let name = match known_unnest_rewrites.get(&alias.name) { + Some(name) => name, + None => &alias.name, + }; + Ok(Expr::Alias(Alias::new(*alias.expr, alias.relation, name))) + } + _ => Ok(expr), + }) + .collect::>>()?; + + let updated_unnest_inner_projection = + Projection::try_new(new_expressions, Arc::clone(&projection.input))?; + + let unnest_options = + rewrite_unnest_options(&unnest.options, known_rewrites, subquery_table_scans); + + // reconstruct the unnest plan with updated projection and rewritten column names + let new_plan = + LogicalPlanBuilder::new(LogicalPlan::Projection(updated_unnest_inner_projection)) + .unnest_columns_with_options(unnest_columns, unnest_options)? + .build()?; + + Ok(new_plan) +} + +/// Rewrites columns names in the unnest options to use the original federated table name: +/// "unnest_placeholder(foo.df_table.a,depth=1)"" -> "unnest_placeholder(remote_table.a,depth=1)"" +fn rewrite_unnest_options( + options: &UnnestOptions, + known_rewrites: &HashMap, + subquery_table_scans: &mut Option>, +) -> UnnestOptions { + let mut new_options = options.clone(); + new_options + .recursions + .iter_mut() + .for_each(|x: &mut RecursionUnnestOption| { + if let Some(new_name) = + rewrite_column_name(&x.input_column.name, known_rewrites, subquery_table_scans) + { + x.input_column.name = new_name; + } + + if let Some(new_name) = + rewrite_column_name(&x.output_column.name, known_rewrites, subquery_table_scans) + { + x.output_column.name = new_name; + } + }); + new_options +} + +/// Checks if any of the rewrites match any substring in col_name, and replace that part of the string if so. +/// This handles cases like "MAX(foo.df_table.a)" -> "MAX(remote_table.a)" +/// Returns the rewritten name if any rewrite was applied, otherwise None. 
+fn rewrite_column_name( + col_name: &str, + known_rewrites: &HashMap, + subquery_table_scans: &mut Option>, +) -> Option { + let (new_col_name, was_rewritten) = known_rewrites.iter().fold( + (col_name.to_string(), false), + |(col_name, was_rewritten), (table_ref, rewrite)| { + let mut rewrite_string = rewrite.to_string(); + if let Some(subquery_reference) = subquery_table_scans { + if subquery_reference.get(table_ref).is_some() { + rewrite_string = get_partial_table_name(rewrite); + } + } + match rewrite_column_name_in_expr(&col_name, &table_ref.to_string(), &rewrite_string, 0) + { + Some(new_name) => (new_name, true), + None => (col_name, was_rewritten), + } + }, + ); + + if was_rewritten { + Some(new_col_name) + } else { + None + } +} + +fn get_partial_table_name(full_table_reference: &TableReference) -> String { + let full_table_path = full_table_reference.table().to_owned(); + let path_parts: Vec<&str> = full_table_path.split('.').collect(); + path_parts[path_parts.len() - 1].to_owned() +} + +// The function replaces occurrences of table_ref_str in col_name with the new name defined by rewrite. +// The name to rewrite should NOT be a substring of another name. +// Supports multiple occurrences of table_ref_str in col_name. +fn rewrite_column_name_in_expr( + col_name: &str, + table_ref_str: &str, + rewrite: &str, + start_pos: usize, +) -> Option { + if start_pos >= col_name.len() { + return None; + } + + // Find the first occurrence of table_ref_str starting from start_pos + let idx = col_name[start_pos..].find(table_ref_str)?; + + // Calculate the absolute index of the occurrence in string as the index above is relative to start_pos + let idx = start_pos + idx; + + // Table name same as column name + // Shouldn't rewrite in this case + if idx == 0 && start_pos == 0 { + return None; + } + + if idx > 0 { + // Check if the previous character is alphabetic, numeric, underscore or period, in which case we + // should not rewrite as it is a part of another name. + if let Some(prev_char) = col_name.chars().nth(idx - 1) { + if prev_char.is_alphabetic() + || prev_char.is_numeric() + || prev_char == '_' + || prev_char == '.' + { + return rewrite_column_name_in_expr( + col_name, + table_ref_str, + rewrite, + idx + table_ref_str.len(), + ); + } + } + } + + // Check if the next character is alphabetic, numeric or underscore, in which case we + // should not rewrite as it is a part of another name. + if let Some(next_char) = col_name.chars().nth(idx + table_ref_str.len()) { + if next_char.is_alphabetic() || next_char.is_numeric() || next_char == '_' { + return rewrite_column_name_in_expr( + col_name, + table_ref_str, + rewrite, + idx + table_ref_str.len(), + ); + } + } + + // Found full match, replace table_ref_str occurrence with rewrite + let rewritten_name = format!( + "{}{}{}", + &col_name[..idx], + rewrite, + &col_name[idx + table_ref_str.len()..] 
+ ); + + // Check if the rewritten name contains more occurrence of table_ref_str, and rewrite them as well + // This is done by providing the updated start_pos for search + match rewrite_column_name_in_expr(&rewritten_name, table_ref_str, rewrite, idx + rewrite.len()) + { + Some(new_name) => Some(new_name), // more occurrences found + None => Some(rewritten_name), // no more occurrences/changes + } +} + +fn rewrite_table_scans_in_expr( + expr: Expr, + known_rewrites: &mut HashMap, + subquery_uses_partial_path: bool, + subquery_table_scans: &mut Option>, +) -> Result { + match expr { + Expr::ScalarSubquery(subquery) => { + let new_subquery = if subquery_table_scans.is_some() || !subquery_uses_partial_path { + rewrite_table_scans( + &subquery.subquery, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )? + } else { + let mut scans = Some(HashSet::new()); + rewrite_table_scans( + &subquery.subquery, + known_rewrites, + subquery_uses_partial_path, + &mut scans, + )? + }; + let outer_ref_columns = subquery + .outer_ref_columns + .into_iter() + .map(|e| { + rewrite_table_scans_in_expr( + e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) + .collect::>>()?; + Ok(Expr::ScalarSubquery(Subquery { + subquery: Arc::new(new_subquery), + outer_ref_columns, + })) + } + Expr::BinaryExpr(binary_expr) => { + let left = rewrite_table_scans_in_expr( + *binary_expr.left, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + let right = rewrite_table_scans_in_expr( + *binary_expr.right, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + Ok(Expr::BinaryExpr(BinaryExpr::new( + Box::new(left), + binary_expr.op, + Box::new(right), + ))) + } + Expr::Column(mut col) => { + if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) { + if let Some(subquery_reference) = subquery_table_scans { + if col + .relation + .as_ref() + .and_then(|r| subquery_reference.get(r)) + .is_some() + { + // Use the partial table path from source for rewrite + // e.g. If the fully qualified name is foo_db.foo_schema.foo + // Use foo as partial path + let partial_path = get_partial_table_name(rewrite); + let partial_table_reference = TableReference::from(partial_path); + + return Ok(Expr::Column(Column::new( + Some(partial_table_reference), + &col.name, + ))); + } + } + Ok(Expr::Column(Column::new(Some(rewrite.clone()), &col.name))) + } else { + // This prevent over-eager rewrite and only pass the column into below rewritten + // rule like MAX(...) + if col.relation.is_some() { + return Ok(Expr::Column(col)); + } + + // Check if any of the rewrites match any substring in col.name, and replace that part of the string if so. 
+ // This will handles cases like "MAX(foo.df_table.a)" -> "MAX(remote_table.a)" + if let Some(new_name) = + rewrite_column_name(&col.name, known_rewrites, subquery_table_scans) + { + Ok(Expr::Column(Column::new(col.relation.take(), new_name))) + } else { + Ok(Expr::Column(col)) + } + } + } + Expr::Alias(alias) => { + let expr = rewrite_table_scans_in_expr( + *alias.expr, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + if let Some(relation) = &alias.relation { + if let Some(rewrite) = known_rewrites.get(relation) { + return Ok(Expr::Alias(Alias::new( + expr, + Some(rewrite.clone()), + alias.name, + ))); + } + } + Ok(Expr::Alias(Alias::new(expr, alias.relation, alias.name))) + } + Expr::Like(like) => { + let expr = rewrite_table_scans_in_expr( + *like.expr, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + let pattern = rewrite_table_scans_in_expr( + *like.pattern, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + Ok(Expr::Like(Like::new( + like.negated, + Box::new(expr), + Box::new(pattern), + like.escape_char, + like.case_insensitive, + ))) + } + Expr::SimilarTo(similar_to) => { + let expr = rewrite_table_scans_in_expr( + *similar_to.expr, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + let pattern = rewrite_table_scans_in_expr( + *similar_to.pattern, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + Ok(Expr::SimilarTo(Like::new( + similar_to.negated, + Box::new(expr), + Box::new(pattern), + similar_to.escape_char, + similar_to.case_insensitive, + ))) + } + Expr::Not(e) => { + let expr = rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + Ok(Expr::Not(Box::new(expr))) + } + Expr::IsNotNull(e) => { + let expr = rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + Ok(Expr::IsNotNull(Box::new(expr))) + } + Expr::IsNull(e) => { + let expr = rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + Ok(Expr::IsNull(Box::new(expr))) + } + Expr::IsTrue(e) => { + let expr = rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + Ok(Expr::IsTrue(Box::new(expr))) + } + Expr::IsFalse(e) => { + let expr = rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + Ok(Expr::IsFalse(Box::new(expr))) + } + Expr::IsUnknown(e) => { + let expr = rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + Ok(Expr::IsUnknown(Box::new(expr))) + } + Expr::IsNotTrue(e) => { + let expr = rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + Ok(Expr::IsNotTrue(Box::new(expr))) + } + Expr::IsNotFalse(e) => { + let expr = rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + Ok(Expr::IsNotFalse(Box::new(expr))) + } + Expr::IsNotUnknown(e) => { + let expr = rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + Ok(Expr::IsNotUnknown(Box::new(expr))) + } + Expr::Negative(e) => { + let expr = rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + Ok(Expr::Negative(Box::new(expr))) + } + 
Expr::Between(between) => { + let expr = rewrite_table_scans_in_expr( + *between.expr, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + let low = rewrite_table_scans_in_expr( + *between.low, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + let high = rewrite_table_scans_in_expr( + *between.high, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + Ok(Expr::Between(Between::new( + Box::new(expr), + between.negated, + Box::new(low), + Box::new(high), + ))) + } + Expr::Case(case) => { + let expr = case + .expr + .map(|e| { + rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) + .transpose()? + .map(Box::new); + let else_expr = case + .else_expr + .map(|e| { + rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) + .transpose()? + .map(Box::new); + let when_expr = case + .when_then_expr + .into_iter() + .map(|(when, then)| { + let when = rewrite_table_scans_in_expr( + *when, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ); + let then = rewrite_table_scans_in_expr( + *then, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ); + + match (when, then) { + (Ok(when), Ok(then)) => Ok((Box::new(when), Box::new(then))), + (Err(e), _) | (_, Err(e)) => Err(e), + } + }) + .collect::, Box)>>>()?; + Ok(Expr::Case(Case::new(expr, when_expr, else_expr))) + } + Expr::Cast(cast) => { + let expr = rewrite_table_scans_in_expr( + *cast.expr, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + Ok(Expr::Cast(Cast::new(Box::new(expr), cast.data_type))) + } + Expr::TryCast(try_cast) => { + let expr = rewrite_table_scans_in_expr( + *try_cast.expr, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + Ok(Expr::TryCast(TryCast::new( + Box::new(expr), + try_cast.data_type, + ))) + } + Expr::ScalarFunction(sf) => { + let args = sf + .args + .into_iter() + .map(|e| { + rewrite_table_scans_in_expr( + e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) + .collect::>>()?; + Ok(Expr::ScalarFunction(ScalarFunction { + func: sf.func, + args, + })) + } + Expr::AggregateFunction(af) => { + let args = af + .args + .into_iter() + .map(|e| { + rewrite_table_scans_in_expr( + e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) + .collect::>>()?; + let filter = af + .filter + .map(|e| { + rewrite_table_scans_in_expr( + *e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) + .transpose()? 
+ .map(Box::new); + let order_by = af + .order_by + .map(|e| { + e.into_iter() + .map(|s| { + rewrite_table_scans_in_expr( + s.expr, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + .map(|e| Sort::new(e, s.asc, s.nulls_first)) + }) + .collect::>>() + }) + .transpose()?; + Ok(Expr::AggregateFunction(AggregateFunction { + func: af.func, + args, + distinct: af.distinct, + filter, + order_by, + null_treatment: af.null_treatment, + })) + } + Expr::WindowFunction(wf) => { + let args = wf + .args + .into_iter() + .map(|e| { + rewrite_table_scans_in_expr( + e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) + .collect::>>()?; + let partition_by = wf + .partition_by + .into_iter() + .map(|e| { + rewrite_table_scans_in_expr( + e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) + .collect::>>()?; + let order_by = wf + .order_by + .into_iter() + .map(|s| { + rewrite_table_scans_in_expr( + s.expr, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + .map(|e| Sort::new(e, s.asc, s.nulls_first)) + }) + .collect::>>()?; + Ok(Expr::WindowFunction(WindowFunction { + fun: wf.fun, + args, + partition_by, + order_by, + window_frame: wf.window_frame, + null_treatment: wf.null_treatment, + })) + } + Expr::InList(il) => { + let expr = rewrite_table_scans_in_expr( + *il.expr, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + let list = il + .list + .into_iter() + .map(|e| { + rewrite_table_scans_in_expr( + e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) + .collect::>>()?; + Ok(Expr::InList(InList::new(Box::new(expr), list, il.negated))) + } + Expr::Exists(exists) => { + let subquery_plan = if subquery_table_scans.is_some() || !subquery_uses_partial_path { + rewrite_table_scans( + &exists.subquery.subquery, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )? + } else { + let mut scans = Some(HashSet::new()); + rewrite_table_scans( + &exists.subquery.subquery, + known_rewrites, + subquery_uses_partial_path, + &mut scans, + )? + }; + + let outer_ref_columns = exists + .subquery + .outer_ref_columns + .into_iter() + .map(|e| { + rewrite_table_scans_in_expr( + e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) + .collect::>>()?; + let subquery = Subquery { + subquery: Arc::new(subquery_plan), + outer_ref_columns, + }; + Ok(Expr::Exists(Exists::new(subquery, exists.negated))) + } + Expr::InSubquery(is) => { + let expr = rewrite_table_scans_in_expr( + *is.expr, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + let subquery_plan = if subquery_table_scans.is_some() || !subquery_uses_partial_path { + rewrite_table_scans( + &is.subquery.subquery, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )? + } else { + let mut scans = Some(HashSet::new()); + rewrite_table_scans( + &is.subquery.subquery, + known_rewrites, + subquery_uses_partial_path, + &mut scans, + )? 
+ }; + let outer_ref_columns = is + .subquery + .outer_ref_columns + .into_iter() + .map(|e| { + rewrite_table_scans_in_expr( + e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) + .collect::>>()?; + let subquery = Subquery { + subquery: Arc::new(subquery_plan), + outer_ref_columns, + }; + Ok(Expr::InSubquery(InSubquery::new( + Box::new(expr), + subquery, + is.negated, + ))) + } + Expr::Wildcard { qualifier, options } => { + if let Some(rewrite) = qualifier.as_ref().and_then(|q| known_rewrites.get(q)) { + Ok(Expr::Wildcard { + qualifier: Some(rewrite.clone()), + options, + }) + } else { + Ok(Expr::Wildcard { qualifier, options }) + } + } + Expr::GroupingSet(gs) => match gs { + GroupingSet::Rollup(exprs) => { + let exprs = exprs + .into_iter() + .map(|e| { + rewrite_table_scans_in_expr( + e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) + .collect::>>()?; + Ok(Expr::GroupingSet(GroupingSet::Rollup(exprs))) + } + GroupingSet::Cube(exprs) => { + let exprs = exprs + .into_iter() + .map(|e| { + rewrite_table_scans_in_expr( + e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) + .collect::>>()?; + Ok(Expr::GroupingSet(GroupingSet::Cube(exprs))) + } + GroupingSet::GroupingSets(vec_exprs) => { + let vec_exprs = vec_exprs + .into_iter() + .map(|exprs| { + exprs + .into_iter() + .map(|e| { + rewrite_table_scans_in_expr( + e, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) + }) + .collect::>>() + }) + .collect::>>>()?; + Ok(Expr::GroupingSet(GroupingSet::GroupingSets(vec_exprs))) + } + }, + Expr::OuterReferenceColumn(dt, col) => { + if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) { + Ok(Expr::OuterReferenceColumn( + dt, + Column::new(Some(rewrite.clone()), &col.name), + )) + } else { + Ok(Expr::OuterReferenceColumn(dt, col)) + } + } + Expr::Unnest(unnest) => { + let expr = rewrite_table_scans_in_expr( + *unnest.expr, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + )?; + Ok(Expr::Unnest(Unnest::new(expr))) + } + Expr::ScalarVariable(_, _) | Expr::Literal(_) | Expr::Placeholder(_) => Ok(expr), + } +} + struct SQLFederationPlanner { executor: Arc, } @@ -104,10 +1115,13 @@ impl FederationPlanner for SQLFederationPlanner { node: &FederatedPlanNode, _session_state: &SessionState, ) -> Result> { - Ok(Arc::new(VirtualExecutionPlan::new( + let schema = Arc::new(node.plan().schema().as_arrow().clone()); + let input = Arc::new(VirtualExecutionPlan::new( node.plan().clone(), - self.executor.clone(), - ))) + Arc::clone(&self.executor), + )); + let schema_cast_exec = schema_cast::SchemaCastScanExec::new(input, schema); + Ok(Arc::new(schema_cast_exec)) } } @@ -137,6 +1151,26 @@ impl VirtualExecutionPlan { let df_schema = self.plan.schema().as_ref(); Arc::new(Schema::from(df_schema)) } + + fn sql(&self) -> Result { + // Find all table scans, recover the SQLTableSource, find the remote table name and replace the name of the TableScan table. 
+ let mut known_rewrites = HashMap::new(); + let subquery_uses_partial_path = self.executor.subquery_use_partial_path(); + let rewritten_plan = rewrite_table_scans( + &self.plan, + &mut known_rewrites, + subquery_uses_partial_path, + &mut None, + )?; + let mut ast = + Unparser::new(self.executor.dialect().as_ref()).plan_to_sql(&rewritten_plan)?; + + if let Some(analyzer) = self.executor.ast_analyzer() { + ast = analyzer(ast)?; + } + + Ok(format!("{ast}")) + } } impl DisplayAs for VirtualExecutionPlan { @@ -148,12 +1182,22 @@ impl DisplayAs for VirtualExecutionPlan { write!(f, " name={}", self.executor.name())?; if let Some(ctx) = self.executor.compute_context() { write!(f, " compute_context={ctx}")?; - } - write!(f, " sql={ast}") + }; + + write!(f, " sql={ast}")?; + if let Ok(query) = self.sql() { + write!(f, " rewritten_sql={query}")?; + }; + + Ok(()) } } impl ExecutionPlan for VirtualExecutionPlan { + fn name(&self) -> &str { + "VirtualExecutionPlan" + } + fn as_any(&self) -> &dyn Any { self } @@ -162,7 +1206,7 @@ impl ExecutionPlan for VirtualExecutionPlan { self.schema() } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } @@ -178,13 +1222,428 @@ impl ExecutionPlan for VirtualExecutionPlan { _partition: usize, _context: Arc, ) -> Result { - let ast = plan_to_sql(&self.plan)?; - let query = format!("{ast}"); - - self.executor.execute(query.as_str(), self.schema()) + self.executor.execute(self.sql()?.as_str(), self.schema()) } fn properties(&self) -> &PlanProperties { &self.props } } + +#[cfg(test)] +mod tests { + use datafusion::{ + arrow::datatypes::{DataType, Field}, + catalog::SchemaProvider, + catalog_common::MemorySchemaProvider, + common::Column, + datasource::{DefaultTableSource, TableProvider}, + error::DataFusionError, + execution::context::SessionContext, + logical_expr::LogicalPlanBuilder, + sql::{unparser::dialect::DefaultDialect, unparser::dialect::Dialect}, + }; + use datafusion_federation::FederatedTableProviderAdaptor; + + use super::*; + + struct TestSQLExecutor {} + + #[async_trait] + impl SQLExecutor for TestSQLExecutor { + fn name(&self) -> &str { + "test_sql_table_source" + } + + fn compute_context(&self) -> Option { + None + } + + fn dialect(&self) -> Arc { + Arc::new(DefaultDialect {}) + } + + fn execute(&self, _query: &str, _schema: SchemaRef) -> Result { + Err(DataFusionError::NotImplemented( + "execute not implemented".to_string(), + )) + } + + async fn table_names(&self) -> Result> { + Err(DataFusionError::NotImplemented( + "table inference not implemented".to_string(), + )) + } + + async fn get_table_schema(&self, _table_name: &str) -> Result { + Err(DataFusionError::NotImplemented( + "table inference not implemented".to_string(), + )) + } + } + + fn get_test_table_provider() -> Arc { + let sql_federation_provider = + Arc::new(SQLFederationProvider::new(Arc::new(TestSQLExecutor {}))); + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Date32, false), + Field::new( + "d", + DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + false, + ), + ])); + let table_source = Arc::new( + SQLTableSource::new_with_schema( + sql_federation_provider, + "remote_table".to_string(), + schema, + ) + .expect("to have a valid SQLTableSource"), + ); + Arc::new(FederatedTableProviderAdaptor::new(table_source)) + } + + fn get_test_table_provider_with_full_path() -> Arc { + let sql_federation_provider = + 
Arc::new(SQLFederationProvider::new(Arc::new(TestSQLExecutor {}))); + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Date32, false), + Field::new( + "d", + DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + false, + ), + ])); + let table_source = Arc::new( + SQLTableSource::new_with_schema( + sql_federation_provider, + "remote_db.remote_schema.remote_table".to_string(), + schema, + ) + .expect("to have a valid SQLTableSource"), + ); + Arc::new(FederatedTableProviderAdaptor::new(table_source)) + } + + fn get_test_table_source() -> Arc { + Arc::new(DefaultTableSource::new(get_test_table_provider())) + } + + fn get_test_df_context() -> SessionContext { + let ctx = SessionContext::new(); + let catalog = ctx + .catalog("datafusion") + .expect("default catalog is datafusion"); + let foo_schema = Arc::new(MemorySchemaProvider::new()) as Arc; + catalog + .register_schema("foo", Arc::clone(&foo_schema)) + .expect("to register schema"); + foo_schema + .register_table("df_table".to_string(), get_test_table_provider()) + .expect("to register table"); + + let public_schema = catalog + .schema("public") + .expect("public schema should exist"); + public_schema + .register_table("app_table".to_string(), get_test_table_provider()) + .expect("to register table"); + + let public_schema = catalog + .schema("public") + .expect("public schema should exist"); + public_schema + .register_table("bar".to_string(), get_test_table_provider_with_full_path()) + .expect("to register table"); + + ctx + } + + #[test] + fn test_rewrite_table_scans_basic() -> Result<()> { + let default_table_source = get_test_table_source(); + let plan = + LogicalPlanBuilder::scan("foo.df_table", default_table_source, None)?.project(vec![ + Expr::Column(Column::from_qualified_name("foo.df_table.a")), + Expr::Column(Column::from_qualified_name("foo.df_table.b")), + Expr::Column(Column::from_qualified_name("foo.df_table.c")), + ])?; + + let mut known_rewrites = HashMap::new(); + let rewritten_plan = + rewrite_table_scans(&plan.build()?, &mut known_rewrites, false, &mut None)?; + + println!("rewritten_plan: \n{:#?}", rewritten_plan); + + let unparsed_sql = plan_to_sql(&rewritten_plan)?; + + println!("unparsed_sql: \n{unparsed_sql}"); + + assert_eq!( + format!("{unparsed_sql}"), + r#"SELECT remote_table.a, remote_table.b, remote_table.c FROM remote_table"# + ); + + Ok(()) + } + + fn init_tracing() { + let subscriber = tracing_subscriber::FmtSubscriber::builder() + .with_env_filter("debug") + .with_ansi(true) + .finish(); + let _ = tracing::subscriber::set_global_default(subscriber); + } + + #[tokio::test] + async fn test_rewrite_table_scans_agg() -> Result<()> { + init_tracing(); + let ctx = get_test_df_context(); + + let agg_tests = vec![ + ( + "SELECT MAX(a) FROM foo.df_table", + r#"SELECT max(remote_table.a) FROM remote_table"#, + ), + ( + "SELECT foo.df_table.a FROM foo.df_table", + r#"SELECT remote_table.a FROM remote_table"#, + ), + ( + "SELECT MIN(a) FROM foo.df_table", + r#"SELECT min(remote_table.a) FROM remote_table"#, + ), + ( + "SELECT AVG(a) FROM foo.df_table", + r#"SELECT avg(remote_table.a) FROM remote_table"#, + ), + ( + "SELECT SUM(a) FROM foo.df_table", + r#"SELECT sum(remote_table.a) FROM remote_table"#, + ), + ( + "SELECT COUNT(a) FROM foo.df_table", + r#"SELECT count(remote_table.a) FROM remote_table"#, + ), + ( + "SELECT COUNT(a) as cnt FROM foo.df_table", + r#"SELECT count(remote_table.a) AS cnt 
FROM remote_table"#, + ), + ( + "SELECT COUNT(a) as cnt FROM foo.df_table", + r#"SELECT count(remote_table.a) AS cnt FROM remote_table"#, + ), + ( + "SELECT app_table from (SELECT a as app_table FROM app_table) b", + r#"SELECT b.app_table FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#, + ), + ( + "SELECT MAX(app_table) from (SELECT a as app_table FROM app_table) b", + r#"SELECT max(b.app_table) FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#, + ), + // multiple occurrences of the same table in single aggregation expression + ( + "SELECT COUNT(CASE WHEN a > 0 THEN a ELSE 0 END) FROM app_table", + r#"SELECT count(CASE WHEN (remote_table.a > 0) THEN remote_table.a ELSE 0 END) FROM remote_table"#, + ), + // different tables in single aggregation expression + ( + "SELECT COUNT(CASE WHEN appt.a > 0 THEN appt.a ELSE dft.a END) FROM app_table as appt, foo.df_table as dft", + "SELECT count(CASE WHEN (appt.a > 0) THEN appt.a ELSE dft.a END) FROM remote_table AS appt JOIN remote_table AS dft" + ), + ]; + + for test in agg_tests { + test_sql(&ctx, test.0, test.1, false).await?; + } + + Ok(()) + } + + #[tokio::test] + async fn test_rewrite_table_scans_alias() -> Result<()> { + init_tracing(); + let ctx = get_test_df_context(); + + let tests = vec![ + ( + "SELECT COUNT(app_table_a) FROM (SELECT a as app_table_a FROM app_table)", + r#"SELECT count(app_table_a) FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#, + ), + ( + "SELECT app_table_a FROM (SELECT a as app_table_a FROM app_table)", + r#"SELECT app_table_a FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#, + ), + ( + "SELECT aapp_table FROM (SELECT a as aapp_table FROM app_table)", + r#"SELECT aapp_table FROM (SELECT remote_table.a AS aapp_table FROM remote_table)"#, + ), + ]; + + for test in tests { + test_sql(&ctx, test.0, test.1, false).await?; + } + + Ok(()) + } + + #[tokio::test] + async fn test_rewrite_table_scans_unnest() -> Result<()> { + init_tracing(); + let ctx = get_test_df_context(); + + let tests = vec![ + ( + "SELECT UNNEST([1, 2, 2, 5, NULL]), b, c from app_table where a > 10 order by b limit 10;", + r#"SELECT UNNEST(make_array(1, 2, 2, 5, NULL)) AS "UNNEST(make_array(Int64(1),Int64(2),Int64(2),Int64(5),NULL))", remote_table.b, remote_table.c FROM remote_table WHERE (remote_table.a > 10) ORDER BY remote_table.b ASC NULLS LAST LIMIT 10"#, + ), + ( + "SELECT UNNEST(app_table.d), b, c from app_table where a > 10 order by b limit 10;", + r#"SELECT UNNEST(remote_table.d) AS "UNNEST(app_table.d)", remote_table.b, remote_table.c FROM remote_table WHERE (remote_table.a > 10) ORDER BY remote_table.b ASC NULLS LAST LIMIT 10"#, + ), + ( + "SELECT sum(b.x) AS total FROM (SELECT UNNEST(d) AS x from app_table where a > 0) AS b;", + r#"SELECT sum(b.x) AS total FROM (SELECT UNNEST(remote_table.d) AS x FROM remote_table WHERE (remote_table.a > 0)) AS b"#, + ), + ]; + + for test in tests { + test_sql(&ctx, test.0, test.1, false).await?; + } + + Ok(()) + } + + #[tokio::test] + async fn test_subquery_requires_partial_path() -> Result<()> { + init_tracing(); + let ctx = get_test_df_context(); + + let tests = vec![ + ( + "SELECT a FROM bar where a IN (SELECT a FROM bar)", + r#"SELECT remote_db.remote_schema.remote_table.a FROM remote_db.remote_schema.remote_table WHERE remote_db.remote_schema.remote_table.a IN (SELECT a FROM remote_db.remote_schema.remote_table)"#, + true, + ), + ( + "SELECT a FROM bar where a IN (SELECT a FROM bar)", + r#"SELECT remote_db.remote_schema.remote_table.a FROM 
remote_db.remote_schema.remote_table WHERE remote_db.remote_schema.remote_table.a IN (SELECT remote_db.remote_schema.remote_table.a FROM remote_db.remote_schema.remote_table)"#, + false, + ), + ]; + + for test in tests { + test_sql(&ctx, test.0, test.1, test.2).await?; + } + + Ok(()) + } + + #[tokio::test] + async fn test_rewrite_same_column_table_name() -> Result<()> { + init_tracing(); + let ctx = get_test_df_context(); + + let tests = vec![( + "SELECT app_table FROM (SELECT a app_table from app_table limit 100);", + r#"SELECT app_table FROM (SELECT remote_table.a AS app_table FROM remote_table LIMIT 100)"#, + )]; + + for test in tests { + test_sql(&ctx, test.0, test.1, false).await?; + } + + Ok(()) + } + + async fn test_sql( + ctx: &SessionContext, + sql_query: &str, + expected_sql: &str, + subquery_uses_partial_path: bool, + ) -> Result<(), datafusion::error::DataFusionError> { + let data_frame = ctx.sql(sql_query).await?; + + println!("before optimization: \n{:#?}", data_frame.logical_plan()); + + let mut known_rewrites = HashMap::new(); + let rewritten_plan = rewrite_table_scans( + data_frame.logical_plan(), + &mut known_rewrites, + subquery_uses_partial_path, + &mut None, + )?; + + println!("rewritten_plan: \n{:#?}", rewritten_plan); + + let unparsed_sql = plan_to_sql(&rewritten_plan)?; + + println!("unparsed_sql: \n{unparsed_sql}"); + + assert_eq!( + format!("{unparsed_sql}"), + expected_sql, + "SQL under test: {}", + sql_query + ); + + Ok(()) + } + + #[tokio::test] + async fn test_rewrite_table_scans_limit_offset() -> Result<()> { + init_tracing(); + let ctx = get_test_df_context(); + + let tests = vec![ + // Basic LIMIT + ( + "SELECT a FROM foo.df_table LIMIT 5", + r#"SELECT remote_table.a FROM remote_table LIMIT 5"#, + ), + // Basic OFFSET + ( + "SELECT a FROM foo.df_table OFFSET 5", + r#"SELECT remote_table.a FROM remote_table OFFSET 5"#, + ), + // OFFSET after LIMIT + ( + "SELECT a FROM foo.df_table LIMIT 10 OFFSET 5", + r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#, + ), + // LIMIT after OFFSET + ( + "SELECT a FROM foo.df_table OFFSET 5 LIMIT 10", + r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#, + ), + // Zero OFFSET + ( + "SELECT a FROM foo.df_table OFFSET 0", + r#"SELECT remote_table.a FROM remote_table OFFSET 0"#, + ), + // Zero LIMIT + ( + "SELECT a FROM foo.df_table LIMIT 0", + r#"SELECT remote_table.a FROM remote_table LIMIT 0"#, + ), + // Zero LIMIT and OFFSET + ( + "SELECT a FROM foo.df_table LIMIT 0 OFFSET 0", + r#"SELECT remote_table.a FROM remote_table LIMIT 0 OFFSET 0"#, + ), + ]; + + for test in tests { + test_sql(&ctx, test.0, test.1, false).await?; + } + + Ok(()) + } +} diff --git a/sources/sql/src/schema.rs b/sources/sql/src/schema.rs index c780f23..86c58ff 100644 --- a/sources/sql/src/schema.rs +++ b/sources/sql/src/schema.rs @@ -2,8 +2,7 @@ use async_trait::async_trait; use datafusion::logical_expr::{TableSource, TableType}; use datafusion::{ - arrow::datatypes::SchemaRef, catalog::schema::SchemaProvider, datasource::TableProvider, - error::Result, + arrow::datatypes::SchemaRef, catalog::SchemaProvider, datasource::TableProvider, error::Result, }; use futures::future::join_all; use std::{any::Any, sync::Arc}; @@ -14,6 +13,7 @@ use datafusion_federation::{ use crate::SQLFederationProvider; +#[derive(Debug)] pub struct SQLSchemaProvider { // provider: Arc, tables: Vec>, @@ -21,7 +21,7 @@ pub struct SQLSchemaProvider { impl SQLSchemaProvider { pub async fn new(provider: Arc) -> Result { - let tables = 
provider.clone().executor.table_names().await?; + let tables = Arc::clone(&provider).executor.table_names().await?; Self::new_with_tables(provider, tables).await } @@ -32,7 +32,7 @@ impl SQLSchemaProvider { ) -> Result { let futures: Vec<_> = tables .into_iter() - .map(|t| SQLTableSource::new(provider.clone(), t)) + .map(|t| SQLTableSource::new(Arc::clone(&provider), t)) .collect(); let results: Result> = join_all(futures).await.into_iter().collect(); let sources = results?.into_iter().map(Arc::new).collect(); @@ -60,7 +60,9 @@ impl SchemaProvider for SQLSchemaProvider { .iter() .find(|s| s.table_name.eq_ignore_ascii_case(name)) { - let adaptor = FederatedTableProviderAdaptor::new(source.clone()); + let adaptor = FederatedTableProviderAdaptor::new( + Arc::clone(source) as Arc + ); return Ok(Some(Arc::new(adaptor))); } Ok(None) @@ -73,6 +75,7 @@ impl SchemaProvider for SQLSchemaProvider { } } +#[derive(Debug)] pub struct MultiSchemaProvider { children: Vec>, } @@ -113,11 +116,19 @@ pub struct SQLTableSource { schema: SchemaRef, } +impl std::fmt::Debug for SQLTableSource { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SQLTableSource") + .field("table_name", &self.table_name) + .field("schema", &self.schema) + .finish() + } +} + impl SQLTableSource { // creates a SQLTableSource and infers the table schema pub async fn new(provider: Arc, table_name: String) -> Result { - let schema = provider - .clone() + let schema = Arc::clone(&provider) .executor .get_table_schema(table_name.as_str()) .await?; @@ -135,11 +146,15 @@ impl SQLTableSource { schema, }) } + + pub fn table_name(&self) -> &str { + self.table_name.as_str() + } } impl FederatedTableSource for SQLTableSource { fn federation_provider(&self) -> Arc { - self.provider.clone() + Arc::clone(&self.provider) as Arc } } @@ -148,7 +163,7 @@ impl TableSource for SQLTableSource { self } fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } fn table_type(&self) -> TableType { TableType::Temporary