From 79934df21bc505c4353138c5c38ce1973419dbc8 Mon Sep 17 00:00:00 2001 From: Michiel De Backker Date: Fri, 10 May 2024 08:18:27 +0200 Subject: [PATCH] barebones sub-query support --- datafusion-federation/src/analyzer.rs | 352 +++++++++++++++++++------- examples/examples/sqlite-subquery.rs | 63 +++++ 2 files changed, 326 insertions(+), 89 deletions(-) create mode 100644 examples/examples/sqlite-subquery.rs diff --git a/datafusion-federation/src/analyzer.rs b/datafusion-federation/src/analyzer.rs index 1aea4fd..59dff9d 100644 --- a/datafusion-federation/src/analyzer.rs +++ b/datafusion-federation/src/analyzer.rs @@ -1,17 +1,19 @@ use std::sync::Arc; +use datafusion::common::not_impl_err; use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion::optimizer::analyzer::Analyzer; use datafusion::{ config::ConfigOptions, datasource::source_as_provider, error::Result, logical_expr::{Expr, LogicalPlan, Projection, TableScan, TableSource}, - error::{DataFusionError, Result}, - logical_expr::{Expr, LogicalPlan, Projection, TableScan, TableSource}, optimizer::analyzer::AnalyzerRule, }; -use crate::{FederatedTableProviderAdaptor, FederatedTableSource, FederationProviderRef}; +use crate::{ + FederatedTableProviderAdaptor, FederatedTableSource, FederationProvider, FederationProviderRef, +}; #[derive(Default)] pub struct FederationAnalyzerRule {} @@ -21,7 +23,7 @@ impl AnalyzerRule for FederationAnalyzerRule { // TableScans from the same FederationProvider. // There 'largest sub-trees' are passed to their respective FederationProvider.optimizer. fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result { - let (optimized, _) = self.optimize_recursively(&plan, None, config)?; + let (optimized, _) = self.optimize_plan_recursively(&plan, true, config)?; if let Some(result) = optimized { return Ok(result); } @@ -34,129 +36,301 @@ impl AnalyzerRule for FederationAnalyzerRule { } } +// tri-state: +// None: no providers +// Some(None): ambiguous +// Some(Some(provider)): sole provider +type ScanResult = Option>; + impl FederationAnalyzerRule { pub fn new() -> Self { Self::default() } + // scans a plan to see if it belongs to a single FederationProvider + fn scan_plan_recursively(&self, plan: &LogicalPlan) -> Result { + let mut sole_provider: ScanResult = None; + + plan.apply(&mut |p: &LogicalPlan| -> Result { + let exprs_provider = self.scan_plan_exprs(plan)?; + let recursion = merge_scan_result(&mut sole_provider, exprs_provider); + if recursion == TreeNodeRecursion::Stop { + return Ok(recursion); + } + + let sub_provider = get_leaf_provider(p)?; + Ok(proc_scan_result(&mut sole_provider, sub_provider)) + })?; + + Ok(sole_provider) + } + + // scans a plan's expressions to see if it belongs to a single FederationProvider + fn scan_plan_exprs(&self, plan: &LogicalPlan) -> Result { + let mut sole_provider: ScanResult = None; + + let exprs = plan.expressions(); + for expr in &exprs { + let expr_result = self.scan_expr_recursively(expr)?; + let recursion = merge_scan_result(&mut sole_provider, expr_result); + if recursion == TreeNodeRecursion::Stop { + return Ok(sole_provider); + } + } + + Ok(sole_provider) + } + + // scans an expression to see if it belongs to a single FederationProvider + fn scan_expr_recursively(&self, expr: &Expr) -> Result { + let mut sole_provider: ScanResult = None; + + expr.apply(&mut |e: &Expr| -> Result { + // TODO: Support other types of sub-queries + if let Expr::ScalarSubquery(ref subquery) = e { + let plan_result = self.scan_plan_recursively(&subquery.subquery)?; + Ok(merge_scan_result(&mut sole_provider, plan_result)) + } else { + Ok(TreeNodeRecursion::Continue) + } + })?; + + Ok(sole_provider) + } + // optimize_recursively recursively finds the largest sub-plans that can be federated // to a single FederationProvider. // Returns a plan if a sub-tree was federated, otherwise None. // Returns a FederationProvider if it covers the entire sub-tree, otherwise None. - fn optimize_recursively( + fn optimize_plan_recursively( &self, plan: &LogicalPlan, - parent: Option<&LogicalPlan>, + is_root: bool, _config: &ConfigOptions, - ) -> Result<(Option, Option)> { - // Check if this node determines the FederationProvider - let sole_provider = self.get_federation_provider(plan)?; - if sole_provider.is_some() { - return Ok((None, sole_provider)); + ) -> Result<(Option, ScanResult)> { + // Used to track if all sources, including tableScan, plan inputs and + // expressions, represents an un-ambiguous or 'sole' FederationProvider + let mut sole_provider: ScanResult = None; + + // Check if this plan node is a leaf that determines the FederationProvider + let leaf_provider = get_leaf_provider(plan)?; + + // Check if the expressions contain, a potentially different, FederationProvider + let exprs_result = self.scan_plan_exprs(plan)?; + let optimize_expressions = exprs_result.is_some(); + + // Return early if this is a leaf and there is no ambiguity with the expressions. + if leaf_provider.is_some() + && (exprs_result.is_none() || scan_result_eq(&exprs_result, &leaf_provider)) + { + return Ok((None, Some(leaf_provider))); } + // Aggregate leaf & expression providers + proc_scan_result(&mut sole_provider, leaf_provider); + merge_scan_result(&mut sole_provider, exprs_result); - // optimize_inputs let inputs = plan.inputs(); - let expressions = plan.expressions(); - if inputs.is_empty() && expressions.is_empty() { + // Return early if there are no sources. + if inputs.is_empty() && sole_provider.is_none() { return Ok((None, None)); } - // Optimize expressions - let mut new_expressions = vec![]; - let optimize_exprs = |expr: Expr| { - if let Expr::ScalarSubquery(ref subquery) = expr { - let (new_subquery, _) = - self.optimize_recursively(&subquery.subquery, parent, _config)?; - if let Some(new_subquery) = new_subquery { - return Ok(Transformed::new( - Expr::ScalarSubquery(subquery.with_plan(new_subquery.into())), - true, - TreeNodeRecursion::Continue, - )); - } - } + // Recursively optimize inputs + let input_results = inputs + .iter() + .map(|i| self.optimize_plan_recursively(i, false, _config)) + .collect::>>()?; + + // Aggregate the input providers + input_results.iter().for_each(|(_, scan_result)| { + merge_scan_result(&mut sole_provider, scan_result.clone()); + }); - Ok(Transformed::new(expr, false, TreeNodeRecursion::Continue)) + let Some(sole_provider) = sole_provider else { + // No providers found + // TODO: Is/should this be reachable? + return Ok((None, None)); }; - for expr in &expressions { - let transformed = expr.clone().transform(&optimize_exprs); - new_expressions.push(transformed.unwrap().data); - } + // If all sources are federated to the same provider + if let Some(provider) = sole_provider { + if !is_root { + // The largest sub-plan is higher up. + return Ok((None, Some(Some(provider)))); + } - // Optimize inputs - let (new_inputs, providers): (Vec<_>, Vec<_>) = inputs - .iter() - .map(|i| self.optimize_recursively(i, Some(plan), _config)) - .collect::>>()? - .into_iter() - .unzip(); - - // Note: assumes provider is None if ambiguous - let first_provider = providers.first().unwrap(); - let is_singular = providers.iter().all(|p| p.is_some() && p == first_provider); - - if is_singular { - if parent.is_none() { - // federate the entire plan - if let Some(provider) = first_provider { - if let Some(optimizer) = provider.analyzer() { - let optimized = optimizer.execute_and_check(plan, _config, |_, _| {})?; - return Ok((Some(optimized), None)); - } - return Ok((None, None)); - } + let Some(optimizer) = provider.analyzer() else { + // No optimizer provided return Ok((None, None)); - } - // The largest sub-plan is higher up. - return Ok((None, first_provider.clone())); + }; + + // If this is the root plan node; federate the entire plan + let optimized = optimizer.execute_and_check(plan, _config, |_, _| {})?; + return Ok((Some(optimized), None)); } - // The plan is ambiguous, any inputs that are not federated and - // have a sole provider, should be federated. - let new_inputs = new_inputs + // The plan is ambiguous; any input that is not yet optimized and has a + // sole provider represents a largest sub-plan and should be federated. + // + // We loop over the input optimization results, federate where needed and + // return a complete list of new inputs for the optimized plan. + let new_inputs = input_results .into_iter() .enumerate() - .map(|(i, new_sub_plan)| { - if let Some(sub_plan) = new_sub_plan { - // Already federated - return Ok(sub_plan); + .map(|(i, (input_plan, provider))| { + if let Some(federated_plan) = input_plan { + // Already federated deeper in the plan tree + return Ok(federated_plan); } - let sub_plan = inputs.get(i).unwrap(); - // Check if the input has a sole provider and can be federated. - if let Some(provider) = providers.get(i).unwrap() { - if let Some(optimizer) = provider.analyzer() { - let wrapped = wrap_projection((*sub_plan).clone())?; - - let optimized = - optimizer.execute_and_check(&wrapped, _config, |_, _| {})?; - return Ok(optimized); - } - // No federation for this sub-plan (no analyzer) - return Ok((*sub_plan).clone()); - } - // No federation for this sub-plan (no provider) - Ok((*sub_plan).clone()) + let provider = provider.unwrap(); + let original_input = (*inputs.get(i).unwrap()).clone(); + let Some(provider) = provider else { + // No provider for this input; use the original input. + return Ok(original_input); + }; + + let Some(optimizer) = provider.analyzer() else { + // No optimizer for this input; use the original input. + return Ok(original_input); + }; + + // Replace the input with the federated counterpart + let wrapped = wrap_projection(original_input)?; + let optimized = optimizer.execute_and_check(&wrapped, _config, |_, _| {})?; + + Ok(optimized) }) .collect::>>()?; + // Optimize expressions if needed + let new_expressions = if optimize_expressions { + self.optimize_plan_exprs(plan, _config)? + } else { + plan.expressions() + }; + + // Construct the optimized plan let new_plan = plan.with_new_exprs(new_expressions, new_inputs)?; - Ok((Some(new_plan), None)) + // Return the federated plan and Some(None) meaning "ambiguous provider" + Ok((Some(new_plan), Some(None))) } - fn get_federation_provider(&self, plan: &LogicalPlan) -> Result> { - match plan { - LogicalPlan::TableScan(TableScan { ref source, .. }) => { - let Some(federated_source) = get_table_source(source)? else { - return Ok(None); + // Optimize all exprs of a plan + fn optimize_plan_exprs( + &self, + plan: &LogicalPlan, + _config: &ConfigOptions, + ) -> Result> { + plan.expressions() + .iter() + .map(|expr| { + let transformed = expr + .clone() + .transform(&|e| self.optimize_expr_recursively(e, _config))?; + Ok(transformed.data) + }) + .collect::>>() + } + + // recursively optimize expressions + // Current logic: individually federate every sub-query. + fn optimize_expr_recursively( + &self, + expr: Expr, + _config: &ConfigOptions, + ) -> Result> { + match expr { + Expr::ScalarSubquery(ref subquery) => { + // Optimize as root to force federating the sub-query + let (new_subquery, _) = + self.optimize_plan_recursively(&subquery.subquery, true, _config)?; + let Some(new_subquery) = new_subquery else { + return Ok(Transformed::no(expr)); }; - let provider = federated_source.federation_provider(); - Ok(Some(provider)) + Ok(Transformed::yes(Expr::ScalarSubquery( + subquery.with_plan(new_subquery.into()), + ))) } - _ => Ok(None), + Expr::InSubquery(_) => not_impl_err!("InSubquery"), + _ => Ok(Transformed::no(expr)), + } + } +} + +fn scan_result_eq(result: &ScanResult, provider: &Option) -> bool { + match (result, provider) { + (None, _) => false, + (Some(left), right) => left == right, + } +} + +fn merge_scan_result(result: &mut ScanResult, other: ScanResult) -> TreeNodeRecursion { + match (&result, other) { + (_, None) => TreeNodeRecursion::Continue, + (_, Some(None)) => { + *result = Some(None); + TreeNodeRecursion::Stop + } + (_, Some(other)) => proc_scan_result(result, other), + } +} + +fn proc_scan_result( + result: &mut ScanResult, + provider: Option, +) -> TreeNodeRecursion { + match (&result, provider) { + (_, None) => TreeNodeRecursion::Continue, // No provider in this plan + (Some(None), _) => { + // Should be unreadable + TreeNodeRecursion::Stop + } + (None, Some(provider)) => { + *result = Some(Some(provider)); + TreeNodeRecursion::Continue + } + (Some(Some(result_unwrapped)), Some(provider_unwrapped)) => { + if *result_unwrapped == provider_unwrapped { + TreeNodeRecursion::Continue + } else { + *result = Some(None); + TreeNodeRecursion::Stop + } + } + } +} + +// NopFederationProvider is used to represent tables that are not federated, but +// are resolved by DataFusion. This simplifies the logic of the optimizer rule. +struct NopFederationProvider {} + +impl FederationProvider for NopFederationProvider { + fn name(&self) -> &str { + "nop" + } + + fn compute_context(&self) -> Option { + None + } + + fn analyzer(&self) -> Option> { + None + } +} + +fn get_leaf_provider(plan: &LogicalPlan) -> Result> { + match plan { + LogicalPlan::TableScan(TableScan { ref source, .. }) => { + let Some(federated_source) = get_table_source(source)? else { + // Table is not federated but provided by a standard table provider. + // We use a placeholder federation provider to simplify the logic. + return Ok(Some(Arc::new(NopFederationProvider {}))); + }; + let provider = federated_source.federation_provider(); + Ok(Some(provider)) } + _ => Ok(None), } } diff --git a/examples/examples/sqlite-subquery.rs b/examples/examples/sqlite-subquery.rs new file mode 100644 index 0000000..4f5043d --- /dev/null +++ b/examples/examples/sqlite-subquery.rs @@ -0,0 +1,63 @@ +use std::sync::Arc; + +use datafusion::{ + catalog::schema::SchemaProvider, + error::Result, + execution::context::{SessionContext, SessionState}, +}; +use datafusion_federation::{FederatedQueryPlanner, FederationAnalyzerRule}; +use datafusion_federation_sql::{connectorx::CXExecutor, SQLFederationProvider, SQLSchemaProvider}; + +#[tokio::main] +async fn main() -> datafusion::error::Result<()> { + let dsn = "sqlite://./examples/examples/chinook.sqlite".to_string(); + let known_tables: Vec = ["Track", "Album", "Artist"] + .iter() + .map(|&x| x.into()) + .collect(); + + let state = SessionContext::new().state(); + + // Register FederationAnalyzer + // TODO: Interaction with other analyzers & optimizers. + let state = state + .add_analyzer_rule(Arc::new(FederationAnalyzerRule::new())) + .with_query_planner(Arc::new(FederatedQueryPlanner::new())); + + // Register schema + // TODO: table inference + let executor = Arc::new(CXExecutor::new(dsn)?); + let provider = Arc::new(SQLFederationProvider::new(executor)); + let schema_provider = + Arc::new(SQLSchemaProvider::new_with_tables(provider, known_tables).await?); + overwrite_default_schema(&state, schema_provider)?; + + // Run query + let ctx = SessionContext::new_with_state(state); + let query = r#"SELECT Name, (SELECT Title FROM Album limit 1) FROM Artist limit 1"#; + // let query = r#"SELECT ArtistId, Name, (SELECT Title FROM Album where ArtistId = a.ArtistId limit 1) FROM Artist a limit 1"#; + let df = ctx.sql(query).await?; + df.show().await?; + + // If the environment variable EXPLAIN is set, print the query plan + if std::env::var("EXPLAIN").is_ok() { + let explain_query = format!("EXPLAIN {query}"); + let df = ctx.sql(explain_query.as_str()).await?; + + df.show().await?; + } + + Ok(()) +} + +fn overwrite_default_schema(state: &SessionState, schema: Arc) -> Result<()> { + let options = &state.config().options().catalog; + let catalog = state + .catalog_list() + .catalog(options.default_catalog.as_str()) + .unwrap(); + + catalog.register_schema(options.default_schema.as_str(), schema)?; + + Ok(()) +}