diff --git a/datafusion-federation/src/sql/executor.rs b/datafusion-federation/src/sql/executor.rs index 6042f77..275bc68 100644 --- a/datafusion-federation/src/sql/executor.rs +++ b/datafusion-federation/src/sql/executor.rs @@ -2,11 +2,12 @@ use async_trait::async_trait; use core::fmt; use datafusion::{ arrow::datatypes::SchemaRef, error::Result, physical_plan::SendableRecordBatchStream, - sql::unparser::dialect::Dialect, + sql::sqlparser::ast, sql::unparser::dialect::Dialect, }; use std::sync::Arc; pub type SQLExecutorRef = Arc; +pub type AstAnalyzer = Box Result>; #[async_trait] pub trait SQLExecutor: Sync + Send { @@ -20,6 +21,11 @@ pub trait SQLExecutor: Sync + Send { // The specific SQL dialect (currently supports 'sqlite', 'postgres', 'flight') fn dialect(&self) -> Arc; + /// Returns an AST analyzer specific for this engine to modify the AST before execution + fn ast_analyzer(&self) -> Option { + None + } + // Execution /// Execute a SQL query fn execute(&self, query: &str, schema: SchemaRef) -> Result; diff --git a/datafusion-federation/src/sql/mod.rs b/datafusion-federation/src/sql/mod.rs index 830a99d..9489173 100644 --- a/datafusion-federation/src/sql/mod.rs +++ b/datafusion-federation/src/sql/mod.rs @@ -645,8 +645,13 @@ impl VirtualExecutionPlan { fn sql(&self) -> Result { // Find all table scans, recover the SQLTableSource, find the remote table name and replace the name of the TableScan table. let mut known_rewrites = HashMap::new(); - let ast = Unparser::new(self.executor.dialect().as_ref()) + let mut ast = Unparser::new(self.executor.dialect().as_ref()) .plan_to_sql(&rewrite_table_scans(&self.plan, &mut known_rewrites)?)?; + + if let Some(analyzer) = self.executor.ast_analyzer() { + ast = analyzer(ast)?; + } + Ok(format!("{ast}")) } } @@ -660,7 +665,8 @@ impl DisplayAs for VirtualExecutionPlan { write!(f, " name={}", self.executor.name())?; if let Some(ctx) = self.executor.compute_context() { write!(f, " compute_context={ctx}")?; - } + }; + write!(f, " sql={ast}")?; if let Ok(query) = self.sql() { write!(f, " rewritten_sql={query}")?;