diff --git a/.gitignore b/.gitignore index 66b190e..c92a89f 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ /node_modules package-lock.json package.json +.DS_Store \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..6982f30 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,213 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'datafusion_federation'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=datafusion-federation" + ], + "filter": { + "name": "datafusion_federation", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug example 'sqlite'", + "cargo": { + "args": [ + "build", + "--example=sqlite", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "sqlite", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in example 'sqlite'", + "cargo": { + "args": [ + "test", + "--no-run", + "--example=sqlite", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "sqlite", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug example 'flight-sql'", + "cargo": { + "args": [ + "build", + "--example=flight-sql", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "flight-sql", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in example 'flight-sql'", + "cargo": { + "args": [ + "test", + "--no-run", + "--example=flight-sql", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "flight-sql", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug example 'postgres-partial'", + "cargo": { + "args": [ + "build", + "--example=postgres-partial", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "postgres-partial", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in example 'postgres-partial'", + "cargo": { + "args": [ + "test", + "--no-run", + "--example=postgres-partial", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "postgres-partial", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug example 'sqlite-partial'", + "cargo": { + "args": [ + "build", + "--example=sqlite-partial", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "sqlite-partial", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in example 'sqlite-partial'", + "cargo": { + "args": [ + "test", + "--no-run", + "--example=sqlite-partial", + "--package=datafusion-federation-examples" + ], + "filter": { + "name": "sqlite-partial", + "kind": "example" + } + }, + "args": [], + 
"cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'datafusion_federation_flight_sql'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=datafusion-federation-flight-sql" + ], + "filter": { + "name": "datafusion_federation_flight_sql", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'datafusion_federation_sql'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=datafusion-federation-sql" + ], + "filter": { + "name": "datafusion_federation_sql", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + } + ] +} \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 83ca62e..2540ee5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,12 +8,15 @@ members = [ ] [workspace.package] -version = "0.1.3" +version = "0.1.6" edition = "2021" -license = "MIT" +license = "Apache-2.0" readme = "README.md" [workspace.dependencies] async-trait = "0.1.81" +async-stream = "0.3.5" +futures = "0.3.30" datafusion = "41.0.0" datafusion-substrait = "41.0.0" +arrow-json = "52.2.0" diff --git a/datafusion-federation/Cargo.toml b/datafusion-federation/Cargo.toml index 6af448c..1d80673 100644 --- a/datafusion-federation/Cargo.toml +++ b/datafusion-federation/Cargo.toml @@ -17,16 +17,20 @@ all-features = true no-default-features = true [features] -sql = ["futures"] +sql = [] [dependencies] +futures.workspace = true async-trait.workspace = true datafusion.workspace = true +async-stream.workspace = true +arrow-json.workspace = true -futures = { version = "0.3.30", optional = true } [dev-dependencies] tokio = { version = "1.39.3", features = ["full"] } +tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } +tracing = "0.1.40" [[example]] name = "df-csv" diff --git a/datafusion-federation/examples/df-csv.rs b/datafusion-federation/examples/df-csv.rs index 24ac495..0d83f63 100644 --- a/datafusion-federation/examples/df-csv.rs +++ b/datafusion-federation/examples/df-csv.rs @@ -10,7 +10,7 @@ use datafusion::{ options::CsvReadOptions, }, physical_plan::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream}, - sql::sqlparser::dialect::{Dialect, GenericDialect}, + sql::unparser::dialect::{DefaultDialect, Dialect}, }; use datafusion_federation::sql::{SQLExecutor, SQLFederationProvider, SQLSchemaProvider}; use futures::TryStreamExt; @@ -106,7 +106,7 @@ impl SQLExecutor for InMemorySQLExecutor { } fn dialect(&self) -> Arc { - Arc::new(GenericDialect {}) + Arc::new(DefaultDialect {}) } } diff --git a/datafusion-federation/src/lib.rs b/datafusion-federation/src/lib.rs index b3eaff3..f7b4ab2 100644 --- a/datafusion-federation/src/lib.rs +++ b/datafusion-federation/src/lib.rs @@ -16,9 +16,15 @@ use datafusion::{ }; pub use optimizer::{get_table_source, FederationOptimizerRule}; -pub use plan_node::{FederatedPlanNode, FederatedQueryPlanner, FederationPlanner}; +pub use plan_node::{ + FederatedPlanNode, FederatedPlanner, FederatedQueryPlanner, FederationPlanner, +}; pub use table_provider::{FederatedTableProviderAdaptor, FederatedTableSource}; +// TODO clean up this +// TODO move schema_cast.rs to schema_cast directory +pub mod schema_cast; + pub fn default_session_state() -> SessionState { let rules = default_optimizer_rules(); SessionStateBuilder::new() diff --git a/datafusion-federation/src/plan_node.rs b/datafusion-federation/src/plan_node.rs index 5a8cd00..6473aca 
100644 --- a/datafusion-federation/src/plan_node.rs +++ b/datafusion-federation/src/plan_node.rs @@ -54,7 +54,7 @@ impl UserDefinedLogicalNodeCore for FederatedPlanNode { } fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Federated\n {:?}", self.plan) + write!(f, "Federated\n {}", self.plan) } fn with_exprs_and_inputs(&self, exprs: Vec<Expr>, inputs: Vec<LogicalPlan>) -> Result<Self> { @@ -126,7 +126,7 @@ impl Hash for FederatedPlanNode { } #[derive(Default)] -struct FederatedPlanner {} +pub struct FederatedPlanner {} impl FederatedPlanner { pub fn new() -> Self { @@ -152,7 +152,7 @@ impl ExtensionPlanner for FederatedPlanner { )); } - let fed_planner = fed_node.planner.clone(); + let fed_planner = Arc::clone(&fed_node.planner); let exec_plan = fed_planner.plan_federation(fed_node, session_state).await?; return Ok(Some(exec_plan)); } diff --git a/datafusion-federation/src/schema_cast.rs b/datafusion-federation/src/schema_cast.rs new file mode 100644 index 0000000..f38f65c --- /dev/null +++ b/datafusion-federation/src/schema_cast.rs @@ -0,0 +1,108 @@ +use async_stream::stream; +use datafusion::arrow::datatypes::SchemaRef; +use datafusion::error::{DataFusionError, Result}; +use datafusion::execution::{SendableRecordBatchStream, TaskContext}; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, +}; +use futures::StreamExt; +use std::any::Any; +use std::clone::Clone; +use std::fmt; +use std::sync::Arc; + +mod intervals_cast; +mod lists_cast; +pub mod record_convert; +mod struct_cast; + +#[derive(Debug)] +#[allow(clippy::module_name_repetitions)] +pub struct SchemaCastScanExec { + input: Arc<dyn ExecutionPlan>, + schema: SchemaRef, + properties: PlanProperties, +} + +impl SchemaCastScanExec { + pub fn new(input: Arc<dyn ExecutionPlan>, schema: SchemaRef) -> Self { + let eq_properties = input.equivalence_properties().clone(); + let execution_mode = input.execution_mode(); + let properties = PlanProperties::new( + eq_properties, + input.output_partitioning().clone(), + execution_mode, + ); + Self { + input, + schema, + properties, + } + } +} + +impl DisplayAs for SchemaCastScanExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "SchemaCastScanExec") + } +} + +impl ExecutionPlan for SchemaCastScanExec { + fn name(&self) -> &str { + "SchemaCastScanExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { + vec![&self.input] + } + + fn with_new_children( + self: Arc<Self>, + children: Vec<Arc<dyn ExecutionPlan>>, + ) -> Result<Arc<dyn ExecutionPlan>> { + if children.len() == 1 { + Ok(Arc::new(Self::new( + Arc::clone(&children[0]), + Arc::clone(&self.schema), + ))) + } else { + Err(DataFusionError::Execution( + "SchemaCastScanExec expects exactly one input".to_string(), + )) + } + } + + fn execute( + &self, + partition: usize, + context: Arc<TaskContext>, + ) -> Result<SendableRecordBatchStream> { + let mut stream = self.input.execute(partition, context)?; + let schema = Arc::clone(&self.schema); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&schema), + { + stream!
{ + while let Some(batch) = stream.next().await { + let batch = record_convert::try_cast_to(batch?, Arc::clone(&schema)); + yield batch.map_err(|e| { DataFusionError::External(Box::new(e)) }); + } + } + }, + ))) + } +} diff --git a/datafusion-federation/src/schema_cast/intervals_cast.rs b/datafusion-federation/src/schema_cast/intervals_cast.rs new file mode 100644 index 0000000..5fbd806 --- /dev/null +++ b/datafusion-federation/src/schema_cast/intervals_cast.rs @@ -0,0 +1,190 @@ +use datafusion::arrow::{ + array::{ + Array, ArrayRef, IntervalDayTimeBuilder, IntervalMonthDayNanoArray, + IntervalYearMonthBuilder, + }, + datatypes::{IntervalDayTimeType, IntervalYearMonthType}, + error::ArrowError, +}; +use std::sync::Arc; + +pub(crate) fn cast_interval_monthdaynano_to_yearmonth( + interval_monthdaynano_array: &dyn Array, +) -> Result<ArrayRef, ArrowError> { + let interval_monthdaynano_array = interval_monthdaynano_array + .as_any() + .downcast_ref::<IntervalMonthDayNanoArray>() + .ok_or_else(|| { + ArrowError::CastError("Failed to cast IntervalMonthDayNanoArray: Unable to downcast to IntervalMonthDayNanoArray".to_string()) + })?; + + let mut interval_yearmonth_builder = + IntervalYearMonthBuilder::with_capacity(interval_monthdaynano_array.len()); + + for value in interval_monthdaynano_array { + match value { + None => interval_yearmonth_builder.append_null(), + Some(interval_monthdaynano_value) => { + if interval_monthdaynano_value.days != 0 + || interval_monthdaynano_value.nanoseconds != 0 + { + return Err(ArrowError::CastError( + "Failed to cast IntervalMonthDayNanoArray to IntervalYearMonthArray: Non-zero days or nanoseconds".to_string(), + )); + } + interval_yearmonth_builder.append_value(IntervalYearMonthType::make_value( + 0, + interval_monthdaynano_value.months, + )); + } + } + } + + Ok(Arc::new(interval_yearmonth_builder.finish())) +} + +#[allow(clippy::cast_possible_truncation)] +pub(crate) fn cast_interval_monthdaynano_to_daytime( + interval_monthdaynano_array: &dyn Array, +) -> Result<ArrayRef, ArrowError> { + let interval_monthdaynano_array = interval_monthdaynano_array + .as_any() + .downcast_ref::<IntervalMonthDayNanoArray>() + .ok_or_else(|| ArrowError::CastError("Failed to cast IntervalMonthDayNanoArray: Unable to downcast to IntervalMonthDayNanoArray".to_string()))?; + + let mut interval_daytime_builder = + IntervalDayTimeBuilder::with_capacity(interval_monthdaynano_array.len()); + + for value in interval_monthdaynano_array { + match value { + None => interval_daytime_builder.append_null(), + Some(interval_monthdaynano_value) => { + if interval_monthdaynano_value.months != 0 { + return Err( + ArrowError::CastError("Failed to cast IntervalMonthDayNanoArray to IntervalDayTimeArray: Non-zero months".to_string()), + ); + } + interval_daytime_builder.append_value(IntervalDayTimeType::make_value( + interval_monthdaynano_value.days, + (interval_monthdaynano_value.nanoseconds / 1_000_000) as i32, + )); + } + } + } + Ok(Arc::new(interval_daytime_builder.finish())) +} + +#[cfg(test)] +mod test { + use datafusion::arrow::{ + array::{IntervalDayTimeArray, IntervalYearMonthArray, RecordBatch}, + datatypes::{ + DataType, Field, IntervalDayTime, IntervalMonthDayNano, IntervalUnit, Schema, SchemaRef, + }, + }; + + use crate::schema_cast::record_convert::try_cast_to; + + use super::*; + + fn input_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new( + "interval_daytime", + DataType::Interval(IntervalUnit::MonthDayNano), + false, + ), + Field::new( + "interval_monthday_nano", + DataType::Interval(IntervalUnit::MonthDayNano), + false, + ), + Field::new( + "interval_yearmonth",
+ DataType::Interval(IntervalUnit::MonthDayNano), + false, + ), + ])) + } + + fn output_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new( + "interval_daytime", + DataType::Interval(IntervalUnit::DayTime), + false, + ), + Field::new( + "interval_monthday_nano", + DataType::Interval(IntervalUnit::MonthDayNano), + false, + ), + Field::new( + "interval_yearmonth", + DataType::Interval(IntervalUnit::YearMonth), + false, + ), + ])) + } + + fn batch_input() -> RecordBatch { + let interval_daytime_array = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNano::new(0, 1, 1_000_000_000), + IntervalMonthDayNano::new(0, 33, 0), + IntervalMonthDayNano::new(0, 0, 43_200_000_000_000), + ]); + let interval_monthday_nano_array = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNano::new(1, 2, 1000), + IntervalMonthDayNano::new(12, 1, 0), + IntervalMonthDayNano::new(0, 0, 12 * 1000 * 1000), + ]); + let interval_yearmonth_array = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNano::new(2, 0, 0), + IntervalMonthDayNano::new(25, 0, 0), + IntervalMonthDayNano::new(-1, 0, 0), + ]); + + RecordBatch::try_new( + input_schema(), + vec![ + Arc::new(interval_daytime_array), + Arc::new(interval_monthday_nano_array), + Arc::new(interval_yearmonth_array), + ], + ) + .expect("Failed to created arrow interval record batch") + } + + fn batch_expected() -> RecordBatch { + let interval_daytime_array = IntervalDayTimeArray::from(vec![ + IntervalDayTime::new(1, 1000), + IntervalDayTime::new(33, 0), + IntervalDayTime::new(0, 12 * 60 * 60 * 1000), + ]); + let interval_monthday_nano_array = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNano::new(1, 2, 1000), + IntervalMonthDayNano::new(12, 1, 0), + IntervalMonthDayNano::new(0, 0, 12 * 1000 * 1000), + ]); + let interval_yearmonth_array = IntervalYearMonthArray::from(vec![2, 25, -1]); + + RecordBatch::try_new( + output_schema(), + vec![ + Arc::new(interval_daytime_array), + Arc::new(interval_monthday_nano_array), + Arc::new(interval_yearmonth_array), + ], + ) + .expect("Failed to created arrow interval record batch") + } + + #[test] + fn test_cast_interval_with_schema() { + let input_batch = batch_input(); + let expected = batch_expected(); + let actual = try_cast_to(input_batch, output_schema()).expect("cast should succeed"); + + assert_eq!(actual, expected); + } +} diff --git a/datafusion-federation/src/schema_cast/lists_cast.rs b/datafusion-federation/src/schema_cast/lists_cast.rs new file mode 100644 index 0000000..9a63b28 --- /dev/null +++ b/datafusion-federation/src/schema_cast/lists_cast.rs @@ -0,0 +1,619 @@ +use arrow_json::ReaderBuilder; +use datafusion::arrow::{ + array::{ + Array, ArrayRef, BooleanArray, BooleanBuilder, FixedSizeListBuilder, Float32Array, + Float32Builder, Float64Array, Float64Builder, Int16Array, Int16Builder, Int32Array, + Int32Builder, Int64Array, Int64Builder, Int8Array, Int8Builder, LargeListBuilder, + LargeStringArray, LargeStringBuilder, ListArray, ListBuilder, StringArray, StringBuilder, + }, + datatypes::{DataType, Field, FieldRef}, + error::ArrowError, +}; +use std::sync::Arc; + +pub type Result = std::result::Result; + +macro_rules! 
cast_string_to_list_array { + ($string_array:expr, $field_name:expr, $data_type:expr, $builder_type:expr, $primitive_type:ty) => {{ + let item_field = Arc::new(Field::new($field_name, $data_type, true)); + let mut list_builder = ListBuilder::with_capacity($builder_type, $string_array.len()) + .with_field(Arc::clone(&item_field)); + + let list_field = Arc::new(Field::new_list("i", item_field, true)); + let mut decoder = ReaderBuilder::new_with_field(Arc::clone(&list_field)) + .build_decoder() + .map_err(|e| ArrowError::CastError(format!("Failed to create decoder: {e}")))?; + + for value in $string_array { + match value { + None => list_builder.append_null(), + Some(string_value) => { + decoder.decode(string_value.as_bytes()).map_err(|e| { + ArrowError::CastError(format!("Failed to decode value: {e}")) + })?; + + if let Some(b) = decoder.flush().map_err(|e| { + ArrowError::CastError(format!("Failed to decode decoder: {e}")) + })? { + let list_array = b + .column(0) + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to ListArray" + .to_string(), + ) + })?; + let primitive_array = list_array + .values() + .as_any() + .downcast_ref::<$primitive_type>() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to PrimitiveType" + .to_string(), + ) + })?; + primitive_array + .iter() + .for_each(|maybe_value| match maybe_value { + Some(value) => list_builder.values().append_value(value), + None => list_builder.values().append_null(), + }); + list_builder.append(true); + } + } + } + } + + Ok(Arc::new(list_builder.finish())) + }}; +} + +macro_rules! cast_string_to_large_list_array { + ($string_array:expr, $field_name:expr, $data_type:expr, $builder_type:expr, $primitive_type:ty) => {{ + let item_field = Arc::new(Field::new($field_name, $data_type, true)); + let mut list_builder = LargeListBuilder::with_capacity($builder_type, $string_array.len()) + .with_field(Arc::clone(&item_field)); + + let list_field = Arc::new(Field::new_list("i", item_field, true)); + let mut decoder = ReaderBuilder::new_with_field(Arc::clone(&list_field)) + .build_decoder() + .map_err(|e| ArrowError::CastError(format!("Failed to create decoder: {e}")))?; + + for value in $string_array { + match value { + None => list_builder.append_null(), + Some(string_value) => { + decoder.decode(string_value.as_bytes()).map_err(|e| { + ArrowError::CastError(format!("Failed to decode value: {e}")) + })?; + + if let Some(b) = decoder.flush().map_err(|e| { + ArrowError::CastError(format!("Failed to decode decoder: {e}")) + })? { + let list_array = b + .column(0) + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to ListArray" + .to_string(), + ) + })?; + let primitive_array = list_array + .values() + .as_any() + .downcast_ref::<$primitive_type>() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to PrimitiveType" + .to_string(), + ) + })?; + primitive_array + .iter() + .for_each(|maybe_value| match maybe_value { + Some(value) => list_builder.values().append_value(value), + None => list_builder.values().append_null(), + }); + list_builder.append(true); + } + } + } + } + + Ok(Arc::new(list_builder.finish())) + }}; +} + +macro_rules! 
cast_string_to_fixed_size_list_array { + ($string_array:expr, $field_name:expr, $data_type:expr, $builder_type:expr, $primitive_type:ty, $value_length:expr) => {{ + let item_field = Arc::new(Field::new($field_name, $data_type, true)); + let mut list_builder = + FixedSizeListBuilder::with_capacity($builder_type, $value_length, $string_array.len()) + .with_field(Arc::clone(&item_field)); + + let list_field = Arc::new(Field::new_list("i", item_field, true)); + let mut decoder = ReaderBuilder::new_with_field(Arc::clone(&list_field)) + .build_decoder() + .map_err(|e| ArrowError::CastError(format!("Failed to create decoder: {e}")))?; + + for value in $string_array { + match value { + None => { + for _ in 0..$value_length { + list_builder.values().append_null() + } + list_builder.append(true) + } + Some(string_value) => { + decoder.decode(string_value.as_bytes()).map_err(|e| { + ArrowError::CastError(format!("Failed to decode value: {e}")) + })?; + + if let Some(b) = decoder.flush().map_err(|e| { + ArrowError::CastError(format!("Failed to decode decoder: {e}")) + })? { + let list_array = b + .column(0) + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to ListArray" + .to_string(), + ) + })?; + let primitive_array = list_array + .values() + .as_any() + .downcast_ref::<$primitive_type>() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to PrimitiveType" + .to_string(), + ) + })?; + primitive_array + .iter() + .for_each(|maybe_value| match maybe_value { + Some(value) => list_builder.values().append_value(value), + None => list_builder.values().append_null(), + }); + list_builder.append(true); + } + } + } + } + + Ok(Arc::new(list_builder.finish())) + }}; +} + +pub(crate) fn cast_string_to_list( + array: &dyn Array, + list_item_field: &FieldRef, +) -> Result { + let string_array = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to StringArray".to_string(), + ) + })?; + + let field_name = list_item_field.name(); + + match list_item_field.data_type() { + DataType::Utf8 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Utf8, + StringBuilder::new(), + StringArray + ) + } + DataType::LargeUtf8 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::LargeUtf8, + LargeStringBuilder::new(), + LargeStringArray + ) + } + DataType::Boolean => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Boolean, + BooleanBuilder::new(), + BooleanArray + ) + } + DataType::Int8 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Int8, + Int8Builder::new(), + Int8Array + ) + } + DataType::Int16 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Int16, + Int16Builder::new(), + Int16Array + ) + } + DataType::Int32 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Int32, + Int32Builder::new(), + Int32Array + ) + } + DataType::Int64 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Int64, + Int64Builder::new(), + Int64Array + ) + } + DataType::Float32 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Float32, + Float32Builder::new(), + Float32Array + ) + } + DataType::Float64 => { + cast_string_to_list_array!( + string_array, + field_name, + DataType::Float64, + Float64Builder::new(), + Float64Array + ) + } + _ => 
Err(ArrowError::CastError(format!( + "Unsupported list item type: {}", + list_item_field.data_type() + ))), + } +} + +pub(crate) fn cast_string_to_large_list( + array: &dyn Array, + list_item_field: &FieldRef, +) -> Result { + let string_array = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to StringArray".to_string(), + ) + })?; + + let field_name = list_item_field.name(); + + match list_item_field.data_type() { + DataType::Utf8 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Utf8, + StringBuilder::new(), + StringArray + ) + } + DataType::LargeUtf8 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::LargeUtf8, + LargeStringBuilder::new(), + LargeStringArray + ) + } + DataType::Boolean => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Boolean, + BooleanBuilder::new(), + BooleanArray + ) + } + DataType::Int8 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Int8, + Int8Builder::new(), + Int8Array + ) + } + DataType::Int16 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Int16, + Int16Builder::new(), + Int16Array + ) + } + DataType::Int32 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Int32, + Int32Builder::new(), + Int32Array + ) + } + DataType::Int64 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Int64, + Int64Builder::new(), + Int64Array + ) + } + DataType::Float32 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Float32, + Float32Builder::new(), + Float32Array + ) + } + DataType::Float64 => { + cast_string_to_large_list_array!( + string_array, + field_name, + DataType::Float64, + Float64Builder::new(), + Float64Array + ) + } + _ => Err(ArrowError::CastError(format!( + "Unsupported list item type: {}", + list_item_field.data_type() + ))), + } +} + +pub(crate) fn cast_string_to_fixed_size_list( + array: &dyn Array, + list_item_field: &FieldRef, + value_length: i32, +) -> Result { + let string_array = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::CastError( + "Failed to decode value: unable to downcast to StringArray".to_string(), + ) + })?; + + let field_name = list_item_field.name(); + + match list_item_field.data_type() { + DataType::Utf8 => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Utf8, + StringBuilder::new(), + StringArray, + value_length + ) + } + DataType::LargeUtf8 => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::LargeUtf8, + LargeStringBuilder::new(), + LargeStringArray, + value_length + ) + } + DataType::Boolean => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Boolean, + BooleanBuilder::new(), + BooleanArray, + value_length + ) + } + DataType::Int8 => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Int8, + Int8Builder::new(), + Int8Array, + value_length + ) + } + DataType::Int16 => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Int16, + Int16Builder::new(), + Int16Array, + value_length + ) + } + DataType::Int32 => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Int32, + Int32Builder::new(), + Int32Array, + value_length + ) + } + DataType::Int64 => { + 
cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Int64, + Int64Builder::new(), + Int64Array, + value_length + ) + } + DataType::Float32 => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Float32, + Float32Builder::new(), + Float32Array, + value_length + ) + } + DataType::Float64 => { + cast_string_to_fixed_size_list_array!( + string_array, + field_name, + DataType::Float64, + Float64Builder::new(), + Float64Array, + value_length + ) + } + _ => Err(ArrowError::CastError(format!( + "Unsupported list item type: {}", + list_item_field.data_type() + ))), + } +} + +#[cfg(test)] +mod test { + use datafusion::arrow::{ + array::{RecordBatch, StringArray}, + datatypes::{DataType, Field, Schema, SchemaRef}, + }; + + use crate::schema_cast::record_convert::try_cast_to; + + use super::*; + + fn input_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Utf8, false), + ])) + } + + fn output_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new( + "a", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + false, + ), + Field::new( + "b", + DataType::LargeList(Arc::new(Field::new("item", DataType::Utf8, true))), + false, + ), + Field::new( + "c", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Boolean, true)), 3), + false, + ), + ])) + } + + fn batch_input() -> RecordBatch { + RecordBatch::try_new( + input_schema(), + vec![ + Arc::new(StringArray::from(vec![ + Some("[1, 2, 3]"), + Some("[4, 5, 6]"), + ])), + Arc::new(StringArray::from(vec![ + Some("[\"foo\", \"bar\"]"), + Some("[\"baz\", \"qux\"]"), + ])), + Arc::new(StringArray::from(vec![ + Some("[true, false, true]"), + Some("[false, true, false]"), + ])), + ], + ) + .expect("record batch should not panic") + } + + fn batch_expected() -> RecordBatch { + let mut list_builder = ListBuilder::new(Int32Builder::new()); + list_builder.append_value([Some(1), Some(2), Some(3)]); + list_builder.append_value([Some(4), Some(5), Some(6)]); + let list_array = list_builder.finish(); + + let mut large_list_builder = LargeListBuilder::new(StringBuilder::new()); + large_list_builder.append_value([Some("foo"), Some("bar")]); + large_list_builder.append_value([Some("baz"), Some("qux")]); + let large_list_array = large_list_builder.finish(); + + let mut fixed_size_list_builder = FixedSizeListBuilder::new(BooleanBuilder::new(), 3); + fixed_size_list_builder.values().append_value(true); + fixed_size_list_builder.values().append_value(false); + fixed_size_list_builder.values().append_value(true); + fixed_size_list_builder.append(true); + fixed_size_list_builder.values().append_value(false); + fixed_size_list_builder.values().append_value(true); + fixed_size_list_builder.values().append_value(false); + fixed_size_list_builder.append(true); + let fixed_size_list_array = fixed_size_list_builder.finish(); + + RecordBatch::try_new( + output_schema(), + vec![ + Arc::new(list_array), + Arc::new(large_list_array), + Arc::new(fixed_size_list_array), + ], + ) + .expect("Failed to create expected RecordBatch") + } + + #[test] + fn test_cast_to_list_largelist_fixedsizelist() { + let input_batch = batch_input(); + let expected = batch_expected(); + let actual = try_cast_to(input_batch, output_schema()).expect("cast should succeed"); + + assert_eq!(actual, expected); + } +} diff --git a/datafusion-federation/src/schema_cast/record_convert.rs 
b/datafusion-federation/src/schema_cast/record_convert.rs new file mode 100644 index 0000000..a20401a --- /dev/null +++ b/datafusion-federation/src/schema_cast/record_convert.rs @@ -0,0 +1,156 @@ +use datafusion::arrow::{ + array::{Array, RecordBatch}, + compute::cast, + datatypes::{DataType, IntervalUnit, SchemaRef}, +}; +use std::sync::Arc; + +use super::{ + intervals_cast::{ + cast_interval_monthdaynano_to_daytime, cast_interval_monthdaynano_to_yearmonth, + }, + lists_cast::{cast_string_to_fixed_size_list, cast_string_to_large_list, cast_string_to_list}, + struct_cast::cast_string_to_struct, +}; + +pub type Result<T, E = Error> = std::result::Result<T, E>; + +#[derive(Debug)] +pub enum Error { + UnableToConvertRecordBatch { + source: datafusion::arrow::error::ArrowError, + }, + + UnexpectedNumberOfColumns { + expected: usize, + found: usize, + }, +} + +impl std::error::Error for Error {} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Error::UnableToConvertRecordBatch { source } => { + write!(f, "Unable to convert record batch: {}", source) + } + Error::UnexpectedNumberOfColumns { expected, found } => { + write!( + f, + "Unexpected number of columns. Expected: {}, Found: {}", + expected, found + ) + } + } + } +} + +/// Cast a given record batch into a new record batch with the given schema. +/// It assumes the record batch columns are correctly ordered. +#[allow(clippy::needless_pass_by_value)] +pub fn try_cast_to(record_batch: RecordBatch, expected_schema: SchemaRef) -> Result<RecordBatch> { + let actual_schema = record_batch.schema(); + + if actual_schema.fields().len() != expected_schema.fields().len() { + return Err(Error::UnexpectedNumberOfColumns { + expected: expected_schema.fields().len(), + found: actual_schema.fields().len(), + }); + } + + let cols = expected_schema + .fields() + .iter() + .enumerate() + .map(|(i, expected_field)| { + let record_batch_col = record_batch.column(i); + + match (record_batch_col.data_type(), expected_field.data_type()) { + (DataType::Utf8, DataType::List(item_type)) => { + cast_string_to_list(&Arc::clone(record_batch_col), item_type) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) + } + (DataType::Utf8, DataType::LargeList(item_type)) => { + cast_string_to_large_list(&Arc::clone(record_batch_col), item_type) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) + } + (DataType::Utf8, DataType::FixedSizeList(item_type, value_length)) => { + cast_string_to_fixed_size_list( + &Arc::clone(record_batch_col), + item_type, + *value_length, + ) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) + } + (DataType::Utf8, DataType::Struct(_)) => { + cast_string_to_struct(&Arc::clone(record_batch_col), expected_field.clone()) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }) + } + ( + DataType::Interval(IntervalUnit::MonthDayNano), + DataType::Interval(IntervalUnit::YearMonth), + ) => cast_interval_monthdaynano_to_yearmonth(&Arc::clone(record_batch_col)) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }), + ( + DataType::Interval(IntervalUnit::MonthDayNano), + DataType::Interval(IntervalUnit::DayTime), + ) => cast_interval_monthdaynano_to_daytime(&Arc::clone(record_batch_col)) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }), + _ => cast(&Arc::clone(record_batch_col), expected_field.data_type()) + .map_err(|e| Error::UnableToConvertRecordBatch { source: e }), + } + }) + .collect::<Result<Vec<Arc<dyn Array>>>>()?; + + RecordBatch::try_new(expected_schema, cols)
.map_err(|e| Error::UnableToConvertRecordBatch { source: e }) +} + +#[cfg(test)] +mod test { + use datafusion::arrow::{ + array::{Int32Array, StringArray}, + datatypes::{DataType, Field, Schema, TimeUnit}, + }; + + use super::*; + + fn schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + Field::new("c", DataType::Utf8, false), + ])) + } + + fn to_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::LargeUtf8, false), + Field::new("c", DataType::Timestamp(TimeUnit::Microsecond, None), false), + ])) + } + + fn batch_input() -> RecordBatch { + RecordBatch::try_new( + schema(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["foo", "bar", "baz"])), + Arc::new(StringArray::from(vec![ + "2024-01-13 03:18:09.000000", + "2024-01-13 03:18:09", + "2024-01-13 03:18:09.000", + ])), + ], + ) + .expect("record batch should not panic") + } + + #[test] + fn test_string_to_timestamp_conversion() { + let result = try_cast_to(batch_input(), to_schema()).expect("converted"); + assert_eq!(3, result.num_rows()); + } +} diff --git a/datafusion-federation/src/schema_cast/struct_cast.rs b/datafusion-federation/src/schema_cast/struct_cast.rs new file mode 100644 index 0000000..85ad369 --- /dev/null +++ b/datafusion-federation/src/schema_cast/struct_cast.rs @@ -0,0 +1,169 @@ +use arrow_json::ReaderBuilder; +use datafusion::arrow::{ + array::{Array, ArrayRef, StringArray}, + datatypes::Field, + error::ArrowError, +}; +use std::sync::Arc; + +pub type Result = std::result::Result; + +pub(crate) fn cast_string_to_struct( + array: &dyn Array, + struct_field: Arc, +) -> Result { + let string_array = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::CastError("Failed to downcast to StringArray".to_string()))?; + + let mut decoder = ReaderBuilder::new_with_field(struct_field) + .build_decoder() + .map_err(|e| ArrowError::CastError(format!("Failed to create JSON decoder: {e}")))?; + + for value in string_array { + match value { + Some(v) => { + decoder.decode(v.as_bytes()).map_err(|e| { + ArrowError::CastError(format!("Failed to decode struct array: {e}")) + })?; + } + None => { + decoder.decode("null".as_bytes()).map_err(|e| { + ArrowError::CastError(format!("Failed to decode struct array: {e}")) + })?; + } + } + } + + let record = match decoder.flush() { + Ok(Some(record)) => record, + Ok(None) => { + return Err(ArrowError::CastError( + "Failed to flush decoder: No record".to_string(), + )); + } + Err(e) => { + return Err(ArrowError::CastError(format!( + "Failed to decode struct array: {e}" + ))); + } + }; + // struct_field is single struct column + return Ok(Arc::clone(record.column(0))); +} + +#[cfg(test)] +mod test { + use datafusion::arrow::{ + array::{Int32Builder, RecordBatch, StringArray, StringBuilder, StructBuilder}, + datatypes::{DataType, Field, Schema, SchemaRef}, + }; + + use crate::schema_cast::record_convert::try_cast_to; + + use super::*; + + fn input_schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new( + "struct_string", + DataType::Utf8, + true, + )])) + } + + fn output_schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new( + "struct", + DataType::Struct( + vec![ + Field::new("name", DataType::Utf8, false), + Field::new("age", DataType::Int32, false), + ] + .into(), + ), + true, + )])) + } + + fn batch_input() -> RecordBatch { + RecordBatch::try_new( + input_schema(), + 
vec![Arc::new(StringArray::from(vec![ + Some(r#"{"name":"John","age":30}"#), + None, + None, + Some(r#"{"name":"Jane","age":25}"#), + ]))], + ) + .expect("record batch should not panic") + } + + fn batch_expected() -> RecordBatch { + let name_field = Field::new("name", DataType::Utf8, false); + let age_field = Field::new("age", DataType::Int32, false); + + let mut struct_builder = StructBuilder::new( + vec![name_field, age_field], + vec![ + Box::new(StringBuilder::new()), + Box::new(Int32Builder::new()), + ], + ); + + struct_builder + .field_builder::(0) + .expect("should return field builder") + .append_value("John"); + struct_builder + .field_builder::(1) + .expect("should return field builder") + .append_value(30); + struct_builder.append(true); + + struct_builder + .field_builder::(0) + .expect("should return field builder") + .append_null(); + struct_builder + .field_builder::(1) + .expect("should return field builder") + .append_null(); + struct_builder.append(false); + + struct_builder + .field_builder::(0) + .expect("should return field builder") + .append_null(); + struct_builder + .field_builder::(1) + .expect("should return field builder") + .append_null(); + struct_builder.append(false); + + struct_builder + .field_builder::(0) + .expect("should return field builder") + .append_value("Jane"); + struct_builder + .field_builder::(1) + .expect("should return field builder") + .append_value(25); + struct_builder.append(true); + + let struct_array = struct_builder.finish(); + + RecordBatch::try_new(output_schema(), vec![Arc::new(struct_array)]) + .expect("Failed to create expected RecordBatch") + } + + #[test] + fn test_cast_to_struct() { + let input_batch = batch_input(); + let expected = batch_expected(); + + let actual = try_cast_to(input_batch, output_schema()).expect("cast should succeed"); + + assert_eq!(actual, expected); + } +} diff --git a/datafusion-federation/src/sql/executor.rs b/datafusion-federation/src/sql/executor.rs index 7f05910..6042f77 100644 --- a/datafusion-federation/src/sql/executor.rs +++ b/datafusion-federation/src/sql/executor.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use core::fmt; use datafusion::{ arrow::datatypes::SchemaRef, error::Result, physical_plan::SendableRecordBatchStream, - sql::sqlparser::dialect::Dialect, + sql::unparser::dialect::Dialect, }; use std::sync::Arc; diff --git a/datafusion-federation/src/sql/mod.rs b/datafusion-federation/src/sql/mod.rs index e68f8fa..830a99d 100644 --- a/datafusion-federation/src/sql/mod.rs +++ b/datafusion-federation/src/sql/mod.rs @@ -1,27 +1,40 @@ mod executor; mod schema; -use std::{any::Any, fmt, sync::Arc, vec}; +use std::{any::Any, collections::HashMap, fmt, sync::Arc, vec}; use async_trait::async_trait; use datafusion::{ arrow::datatypes::{Schema, SchemaRef}, + common::Column, error::Result, execution::{context::SessionState, TaskContext}, - logical_expr::{Extension, LogicalPlan}, + logical_expr::{ + expr::{ + AggregateFunction, Alias, Exists, InList, InSubquery, ScalarFunction, Sort, Unnest, + WindowFunction, + }, + Between, BinaryExpr, Case, Cast, Expr, Extension, GroupingSet, Like, LogicalPlan, Subquery, + TryCast, + }, optimizer::{optimizer::Optimizer, OptimizerConfig, OptimizerRule}, physical_expr::EquivalenceProperties, physical_plan::{ DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, Partitioning, PlanProperties, SendableRecordBatchStream, }, - sql::unparser::plan_to_sql, + sql::{ + unparser::{plan_to_sql, Unparser}, + TableReference, + }, }; pub use executor::{SQLExecutor, 
SQLExecutorRef}; pub use schema::{MultiSchemaProvider, SQLSchemaProvider, SQLTableSource}; -use crate::{FederatedPlanNode, FederationPlanner, FederationProvider}; +use crate::{ + get_table_source, schema_cast, FederatedPlanNode, FederationPlanner, FederationProvider, +}; // #[macro_use] // extern crate derive_builder; @@ -64,7 +77,7 @@ struct SQLFederationOptimizerRule { impl SQLFederationOptimizerRule { pub fn new(executor: Arc<dyn SQLExecutor>) -> Self { Self { - planner: Arc::new(SQLFederationPlanner::new(executor.clone())), + planner: Arc::new(SQLFederationPlanner::new(Arc::clone(&executor))), } } } @@ -100,6 +113,481 @@ impl OptimizerRule for SQLFederationOptimizerRule { false } } + +/// Rewrite table scans to use the original federated table name. +fn rewrite_table_scans( + plan: &LogicalPlan, + known_rewrites: &mut HashMap<TableReference, TableReference>, +) -> Result<LogicalPlan> { + if plan.inputs().is_empty() { + if let LogicalPlan::TableScan(table_scan) = plan { + let original_table_name = table_scan.table_name.clone(); + let mut new_table_scan = table_scan.clone(); + + let Some(federated_source) = get_table_source(&table_scan.source)? else { + // Not a federated source + return Ok(plan.clone()); + }; + + match federated_source.as_any().downcast_ref::<SQLTableSource>() { + Some(sql_table_source) => { + let remote_table_name = TableReference::from(sql_table_source.table_name()); + known_rewrites.insert(original_table_name, remote_table_name.clone()); + + // Rewrite the schema of this node to have the remote table as the qualifier. + let new_schema = (*new_table_scan.projected_schema) + .clone() + .replace_qualifier(remote_table_name.clone()); + new_table_scan.projected_schema = Arc::new(new_schema); + new_table_scan.table_name = remote_table_name; + } + None => { + // Not a SQLTableSource (is this possible?) + return Ok(plan.clone()); + } + } + + return Ok(LogicalPlan::TableScan(new_table_scan)); + } else { + return Ok(plan.clone()); + } + } + + let rewritten_inputs = plan + .inputs() + .into_iter() + .map(|i| rewrite_table_scans(i, known_rewrites)) + .collect::<Result<Vec<_>>>()?; + + let mut new_expressions = vec![]; + for expression in plan.expressions() { + let new_expr = rewrite_table_scans_in_expr(expression.clone(), known_rewrites)?; + new_expressions.push(new_expr); + } + + let new_plan = plan.with_new_exprs(new_expressions, rewritten_inputs)?; + + Ok(new_plan) +} + +// The function replaces occurrences of table_ref_str in col_name with the new name defined by rewrite. +// The name to rewrite should NOT be a substring of another name. +// Supports multiple occurrences of table_ref_str in col_name. +fn rewrite_column_name_in_expr( + col_name: &str, + table_ref_str: &str, + rewrite: &str, + start_pos: usize, +) -> Option<String> { + if start_pos >= col_name.len() { + return None; + } + + // Find the first occurrence of table_ref_str starting from start_pos + let idx = col_name[start_pos..].find(table_ref_str)?; + + // Calculate the absolute index of the occurrence in the string, as the index above is relative to start_pos + let idx = start_pos + idx; + + if idx > 0 { + // Check if the previous character is alphabetic, numeric, underscore or period, in which case we + // should not rewrite as it is a part of another name. + if let Some(prev_char) = col_name.chars().nth(idx - 1) { + if prev_char.is_alphabetic() + || prev_char.is_numeric() + || prev_char == '_' + || prev_char == '.'
+ { + return rewrite_column_name_in_expr( + col_name, + table_ref_str, + rewrite, + idx + table_ref_str.len(), + ); + } + } + } + + // Check if the next character is alphabetic, numeric or underscore, in which case we + // should not rewrite as it is a part of another name. + if let Some(next_char) = col_name.chars().nth(idx + table_ref_str.len()) { + if next_char.is_alphabetic() || next_char.is_numeric() || next_char == '_' { + return rewrite_column_name_in_expr( + col_name, + table_ref_str, + rewrite, + idx + table_ref_str.len(), + ); + } + } + + // Found full match, replace table_ref_str occurrence with rewrite + let rewritten_name = format!( + "{}{}{}", + &col_name[..idx], + rewrite, + &col_name[idx + table_ref_str.len()..] + ); + // Check if the rewritten name contains more occurrences of table_ref_str, and rewrite them as well + // This is done by providing the updated start_pos for the search + match rewrite_column_name_in_expr(&rewritten_name, table_ref_str, rewrite, idx + rewrite.len()) + { + Some(new_name) => Some(new_name), // more occurrences found + None => Some(rewritten_name), // no more occurrences/changes + } +} + +fn rewrite_table_scans_in_expr( + expr: Expr, + known_rewrites: &mut HashMap<TableReference, TableReference>, +) -> Result<Expr> { + match expr { + Expr::ScalarSubquery(subquery) => { + let new_subquery = rewrite_table_scans(&subquery.subquery, known_rewrites)?; + let outer_ref_columns = subquery + .outer_ref_columns + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::<Result<Vec<Expr>>>()?; + Ok(Expr::ScalarSubquery(Subquery { + subquery: Arc::new(new_subquery), + outer_ref_columns, + })) + } + Expr::BinaryExpr(binary_expr) => { + let left = rewrite_table_scans_in_expr(*binary_expr.left, known_rewrites)?; + let right = rewrite_table_scans_in_expr(*binary_expr.right, known_rewrites)?; + Ok(Expr::BinaryExpr(BinaryExpr::new( + Box::new(left), + binary_expr.op, + Box::new(right), + ))) + } + Expr::Column(mut col) => { + if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) { + Ok(Expr::Column(Column::new(Some(rewrite.clone()), &col.name))) + } else { + // This prevents over-eager rewrites and only passes the column into the rewrite + // rules below, e.g. MAX(...) + if col.relation.is_some() { + return Ok(Expr::Column(col)); + } + + // Check if any of the rewrites match any substring in col.name, and replace that part of the string if so.
+ // This will handles cases like "MAX(foo.df_table.a)" -> "MAX(remote_table.a)" + let (new_name, was_rewritten) = known_rewrites.iter().fold( + (col.name.to_string(), false), + |(col_name, was_rewritten), (table_ref, rewrite)| { + match rewrite_column_name_in_expr( + &col_name, + &table_ref.to_string(), + &rewrite.to_string(), + 0, + ) { + Some(new_name) => (new_name, true), + None => (col_name, was_rewritten), + } + }, + ); + if was_rewritten { + Ok(Expr::Column(Column::new(col.relation.take(), new_name))) + } else { + Ok(Expr::Column(col)) + } + } + } + Expr::Alias(alias) => { + let expr = rewrite_table_scans_in_expr(*alias.expr, known_rewrites)?; + if let Some(relation) = &alias.relation { + if let Some(rewrite) = known_rewrites.get(relation) { + return Ok(Expr::Alias(Alias::new( + expr, + Some(rewrite.clone()), + alias.name, + ))); + } + } + Ok(Expr::Alias(Alias::new(expr, alias.relation, alias.name))) + } + Expr::Like(like) => { + let expr = rewrite_table_scans_in_expr(*like.expr, known_rewrites)?; + let pattern = rewrite_table_scans_in_expr(*like.pattern, known_rewrites)?; + Ok(Expr::Like(Like::new( + like.negated, + Box::new(expr), + Box::new(pattern), + like.escape_char, + like.case_insensitive, + ))) + } + Expr::SimilarTo(similar_to) => { + let expr = rewrite_table_scans_in_expr(*similar_to.expr, known_rewrites)?; + let pattern = rewrite_table_scans_in_expr(*similar_to.pattern, known_rewrites)?; + Ok(Expr::SimilarTo(Like::new( + similar_to.negated, + Box::new(expr), + Box::new(pattern), + similar_to.escape_char, + similar_to.case_insensitive, + ))) + } + Expr::Not(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::Not(Box::new(expr))) + } + Expr::IsNotNull(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsNotNull(Box::new(expr))) + } + Expr::IsNull(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsNull(Box::new(expr))) + } + Expr::IsTrue(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsTrue(Box::new(expr))) + } + Expr::IsFalse(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsFalse(Box::new(expr))) + } + Expr::IsUnknown(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsUnknown(Box::new(expr))) + } + Expr::IsNotTrue(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsNotTrue(Box::new(expr))) + } + Expr::IsNotFalse(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsNotFalse(Box::new(expr))) + } + Expr::IsNotUnknown(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::IsNotUnknown(Box::new(expr))) + } + Expr::Negative(e) => { + let expr = rewrite_table_scans_in_expr(*e, known_rewrites)?; + Ok(Expr::Negative(Box::new(expr))) + } + Expr::Between(between) => { + let expr = rewrite_table_scans_in_expr(*between.expr, known_rewrites)?; + let low = rewrite_table_scans_in_expr(*between.low, known_rewrites)?; + let high = rewrite_table_scans_in_expr(*between.high, known_rewrites)?; + Ok(Expr::Between(Between::new( + Box::new(expr), + between.negated, + Box::new(low), + Box::new(high), + ))) + } + Expr::Case(case) => { + let expr = case + .expr + .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) + .transpose()? + .map(Box::new); + let else_expr = case + .else_expr + .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) + .transpose()? 
+ .map(Box::new); + let when_expr = case + .when_then_expr + .into_iter() + .map(|(when, then)| { + let when = rewrite_table_scans_in_expr(*when, known_rewrites); + let then = rewrite_table_scans_in_expr(*then, known_rewrites); + + match (when, then) { + (Ok(when), Ok(then)) => Ok((Box::new(when), Box::new(then))), + (Err(e), _) | (_, Err(e)) => Err(e), + } + }) + .collect::, Box)>>>()?; + Ok(Expr::Case(Case::new(expr, when_expr, else_expr))) + } + Expr::Cast(cast) => { + let expr = rewrite_table_scans_in_expr(*cast.expr, known_rewrites)?; + Ok(Expr::Cast(Cast::new(Box::new(expr), cast.data_type))) + } + Expr::TryCast(try_cast) => { + let expr = rewrite_table_scans_in_expr(*try_cast.expr, known_rewrites)?; + Ok(Expr::TryCast(TryCast::new( + Box::new(expr), + try_cast.data_type, + ))) + } + Expr::Sort(sort) => { + let expr = rewrite_table_scans_in_expr(*sort.expr, known_rewrites)?; + Ok(Expr::Sort(Sort::new( + Box::new(expr), + sort.asc, + sort.nulls_first, + ))) + } + Expr::ScalarFunction(sf) => { + let args = sf + .args + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + Ok(Expr::ScalarFunction(ScalarFunction { + func: sf.func, + args, + })) + } + Expr::AggregateFunction(af) => { + let args = af + .args + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + let filter = af + .filter + .map(|e| rewrite_table_scans_in_expr(*e, known_rewrites)) + .transpose()? + .map(Box::new); + let order_by = af + .order_by + .map(|e| { + e.into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>() + }) + .transpose()?; + Ok(Expr::AggregateFunction(AggregateFunction { + func: af.func, + args, + distinct: af.distinct, + filter, + order_by, + null_treatment: af.null_treatment, + })) + } + Expr::WindowFunction(wf) => { + let args = wf + .args + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + let partition_by = wf + .partition_by + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + let order_by = wf + .order_by + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + Ok(Expr::WindowFunction(WindowFunction { + fun: wf.fun, + args, + partition_by, + order_by, + window_frame: wf.window_frame, + null_treatment: wf.null_treatment, + })) + } + Expr::InList(il) => { + let expr = rewrite_table_scans_in_expr(*il.expr, known_rewrites)?; + let list = il + .list + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + Ok(Expr::InList(InList::new(Box::new(expr), list, il.negated))) + } + Expr::Exists(exists) => { + let subquery_plan = rewrite_table_scans(&exists.subquery.subquery, known_rewrites)?; + let outer_ref_columns = exists + .subquery + .outer_ref_columns + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + let subquery = Subquery { + subquery: Arc::new(subquery_plan), + outer_ref_columns, + }; + Ok(Expr::Exists(Exists::new(subquery, exists.negated))) + } + Expr::InSubquery(is) => { + let expr = rewrite_table_scans_in_expr(*is.expr, known_rewrites)?; + let subquery_plan = rewrite_table_scans(&is.subquery.subquery, known_rewrites)?; + let outer_ref_columns = is + .subquery + .outer_ref_columns + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + let subquery = Subquery { + subquery: Arc::new(subquery_plan), + outer_ref_columns, + }; + 
Ok(Expr::InSubquery(InSubquery::new( + Box::new(expr), + subquery, + is.negated, + ))) + } + Expr::Wildcard { qualifier } => { + if let Some(rewrite) = qualifier.as_ref().and_then(|q| known_rewrites.get(q)) { + Ok(Expr::Wildcard { + qualifier: Some(rewrite.clone()), + }) + } else { + Ok(Expr::Wildcard { qualifier }) + } + } + Expr::GroupingSet(gs) => match gs { + GroupingSet::Rollup(exprs) => { + let exprs = exprs + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + Ok(Expr::GroupingSet(GroupingSet::Rollup(exprs))) + } + GroupingSet::Cube(exprs) => { + let exprs = exprs + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>()?; + Ok(Expr::GroupingSet(GroupingSet::Cube(exprs))) + } + GroupingSet::GroupingSets(vec_exprs) => { + let vec_exprs = vec_exprs + .into_iter() + .map(|exprs| { + exprs + .into_iter() + .map(|e| rewrite_table_scans_in_expr(e, known_rewrites)) + .collect::>>() + }) + .collect::>>>()?; + Ok(Expr::GroupingSet(GroupingSet::GroupingSets(vec_exprs))) + } + }, + Expr::OuterReferenceColumn(dt, col) => { + if let Some(rewrite) = col.relation.as_ref().and_then(|r| known_rewrites.get(r)) { + Ok(Expr::OuterReferenceColumn( + dt, + Column::new(Some(rewrite.clone()), &col.name), + )) + } else { + Ok(Expr::OuterReferenceColumn(dt, col)) + } + } + Expr::Unnest(unnest) => { + let expr = rewrite_table_scans_in_expr(*unnest.expr, known_rewrites)?; + Ok(Expr::Unnest(Unnest::new(expr))) + } + Expr::ScalarVariable(_, _) | Expr::Literal(_) | Expr::Placeholder(_) => Ok(expr), + } +} + struct SQLFederationPlanner { executor: Arc, } @@ -117,10 +605,13 @@ impl FederationPlanner for SQLFederationPlanner { node: &FederatedPlanNode, _session_state: &SessionState, ) -> Result> { - Ok(Arc::new(VirtualExecutionPlan::new( + let schema = Arc::new(node.plan().schema().as_arrow().clone()); + let input = Arc::new(VirtualExecutionPlan::new( node.plan().clone(), - self.executor.clone(), - ))) + Arc::clone(&self.executor), + )); + let schema_cast_exec = schema_cast::SchemaCastScanExec::new(input, schema); + Ok(Arc::new(schema_cast_exec)) } } @@ -150,6 +641,14 @@ impl VirtualExecutionPlan { let df_schema = self.plan.schema().as_ref(); Arc::new(Schema::from(df_schema)) } + + fn sql(&self) -> Result { + // Find all table scans, recover the SQLTableSource, find the remote table name and replace the name of the TableScan table. 
+        let mut known_rewrites = HashMap::new();
+        let ast = Unparser::new(self.executor.dialect().as_ref())
+            .plan_to_sql(&rewrite_table_scans(&self.plan, &mut known_rewrites)?)?;
+        Ok(format!("{ast}"))
+    }
 }
 
 impl DisplayAs for VirtualExecutionPlan {
@@ -162,6 +661,11 @@ impl DisplayAs for VirtualExecutionPlan {
         if let Some(ctx) = self.executor.compute_context() {
             write!(f, " compute_context={ctx}")?;
         }
+        write!(f, " sql={ast}")?;
+        if let Ok(query) = self.sql() {
+            write!(f, " rewritten_sql={query}")?;
+        };
+        Ok(())
     }
 }
@@ -205,3 +709,258 @@ impl ExecutionPlan for VirtualExecutionPlan {
         &self.props
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::FederatedTableProviderAdaptor;
+    use datafusion::{
+        arrow::datatypes::{DataType, Field},
+        catalog::SchemaProvider,
+        catalog_common::MemorySchemaProvider,
+        common::Column,
+        datasource::{DefaultTableSource, TableProvider},
+        error::DataFusionError,
+        execution::context::SessionContext,
+        logical_expr::LogicalPlanBuilder,
+        sql::{unparser::dialect::DefaultDialect, unparser::dialect::Dialect},
+    };
+
+    use super::*;
+
+    struct TestSQLExecutor {}
+
+    #[async_trait]
+    impl SQLExecutor for TestSQLExecutor {
+        fn name(&self) -> &str {
+            "test_sql_table_source"
+        }
+
+        fn compute_context(&self) -> Option<String> {
+            None
+        }
+
+        fn dialect(&self) -> Arc<dyn Dialect> {
+            Arc::new(DefaultDialect {})
+        }
+
+        fn execute(&self, _query: &str, _schema: SchemaRef) -> Result<SendableRecordBatchStream> {
+            Err(DataFusionError::NotImplemented(
+                "execute not implemented".to_string(),
+            ))
+        }
+
+        async fn table_names(&self) -> Result<Vec<String>> {
+            Err(DataFusionError::NotImplemented(
+                "table inference not implemented".to_string(),
+            ))
+        }
+
+        async fn get_table_schema(&self, _table_name: &str) -> Result<SchemaRef> {
+            Err(DataFusionError::NotImplemented(
+                "table inference not implemented".to_string(),
+            ))
+        }
+    }
+
+    fn get_test_table_provider() -> Arc<dyn TableProvider> {
+        let sql_federation_provider =
+            Arc::new(SQLFederationProvider::new(Arc::new(TestSQLExecutor {})));
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int64, false),
+            Field::new("b", DataType::Utf8, false),
+            Field::new("c", DataType::Date32, false),
+        ]));
+        let table_source = Arc::new(
+            SQLTableSource::new_with_schema(
+                sql_federation_provider,
+                "remote_table".to_string(),
+                schema,
+            )
+            .expect("to have a valid SQLTableSource"),
+        );
+        Arc::new(FederatedTableProviderAdaptor::new(table_source))
+    }
+
+    fn get_test_table_source() -> Arc<DefaultTableSource> {
+        Arc::new(DefaultTableSource::new(get_test_table_provider()))
+    }
+
+    fn get_test_df_context() -> SessionContext {
+        let ctx = SessionContext::new();
+        let catalog = ctx
+            .catalog("datafusion")
+            .expect("default catalog is datafusion");
+        let foo_schema = Arc::new(MemorySchemaProvider::new()) as Arc<dyn SchemaProvider>;
+        catalog
+            .register_schema("foo", Arc::clone(&foo_schema))
+            .expect("to register schema");
+        foo_schema
+            .register_table("df_table".to_string(), get_test_table_provider())
+            .expect("to register table");
+
+        let public_schema = catalog
+            .schema("public")
+            .expect("public schema should exist");
+        public_schema
+            .register_table("app_table".to_string(), get_test_table_provider())
+            .expect("to register table");
+
+        ctx
+    }
+
+    #[test]
+    fn test_rewrite_table_scans_basic() -> Result<()> {
+        let default_table_source = get_test_table_source();
+        let plan =
+            LogicalPlanBuilder::scan("foo.df_table", default_table_source, None)?.project(vec![
+                Expr::Column(Column::from_qualified_name("foo.df_table.a")),
+                Expr::Column(Column::from_qualified_name("foo.df_table.b")),
+                Expr::Column(Column::from_qualified_name("foo.df_table.c")),
+            ])?;
+
+        let mut known_rewrites = HashMap::new();
+        let rewritten_plan = rewrite_table_scans(&plan.build()?, &mut known_rewrites)?;
+
+        println!("rewritten_plan: \n{:#?}", rewritten_plan);
+
+        let unparsed_sql = plan_to_sql(&rewritten_plan)?;
+
+        println!("unparsed_sql: \n{unparsed_sql}");
+
+        assert_eq!(
+            format!("{unparsed_sql}"),
+            r#"SELECT remote_table.a, remote_table.b, remote_table.c FROM remote_table"#
+        );
+
+        Ok(())
+    }
+
+    fn init_tracing() {
+        let subscriber = tracing_subscriber::FmtSubscriber::builder()
+            .with_env_filter("debug")
+            .with_ansi(true)
+            .finish();
+        let _ = tracing::subscriber::set_global_default(subscriber);
+    }
+
+    #[tokio::test]
+    async fn test_rewrite_table_scans_agg() -> Result<()> {
+        init_tracing();
+        let ctx = get_test_df_context();
+
+        let agg_tests = vec![
+            (
+                "SELECT MAX(a) FROM foo.df_table",
+                r#"SELECT max(remote_table.a) FROM remote_table"#,
+            ),
+            (
+                "SELECT foo.df_table.a FROM foo.df_table",
+                r#"SELECT remote_table.a FROM remote_table"#,
+            ),
+            (
+                "SELECT MIN(a) FROM foo.df_table",
+                r#"SELECT min(remote_table.a) FROM remote_table"#,
+            ),
+            (
+                "SELECT AVG(a) FROM foo.df_table",
+                r#"SELECT avg(remote_table.a) FROM remote_table"#,
+            ),
+            (
+                "SELECT SUM(a) FROM foo.df_table",
+                r#"SELECT sum(remote_table.a) FROM remote_table"#,
+            ),
+            (
+                "SELECT COUNT(a) FROM foo.df_table",
+                r#"SELECT count(remote_table.a) FROM remote_table"#,
+            ),
+            (
+                "SELECT COUNT(a) as cnt FROM foo.df_table",
+                r#"SELECT count(remote_table.a) AS cnt FROM remote_table"#,
+            ),
+            (
+                "SELECT COUNT(a) as cnt FROM foo.df_table",
+                r#"SELECT count(remote_table.a) AS cnt FROM remote_table"#,
+            ),
+            (
+                "SELECT app_table from (SELECT a as app_table FROM app_table) b",
+                r#"SELECT b.app_table FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#,
+            ),
+            (
+                "SELECT MAX(app_table) from (SELECT a as app_table FROM app_table) b",
+                r#"SELECT max(b.app_table) FROM (SELECT remote_table.a AS app_table FROM remote_table) AS b"#,
+            ),
+            // multiple occurrences of the same table in single aggregation expression
+            (
+                "SELECT COUNT(CASE WHEN a > 0 THEN a ELSE 0 END) FROM app_table",
+                r#"SELECT count(CASE WHEN (remote_table.a > 0) THEN remote_table.a ELSE 0 END) FROM remote_table"#,
+            ),
+            // different tables in single aggregation expression
+            (
+                "SELECT COUNT(CASE WHEN app_table.a > 0 THEN app_table.a ELSE foo.df_table.a END) FROM app_table, foo.df_table",
+                r#"SELECT count(CASE WHEN (remote_table.a > 0) THEN remote_table.a ELSE remote_table.a END) FROM remote_table JOIN remote_table ON true"#,
+            ),
+        ];
+
+        for test in agg_tests {
+            test_sql(&ctx, test.0, test.1).await?;
+        }
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_rewrite_table_scans_alias() -> Result<()> {
+        init_tracing();
+        let ctx = get_test_df_context();
+
+        let tests = vec![
+            (
+                "SELECT COUNT(app_table_a) FROM (SELECT a as app_table_a FROM app_table)",
+                r#"SELECT count(app_table_a) FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#,
+            ),
+            (
+                "SELECT app_table_a FROM (SELECT a as app_table_a FROM app_table)",
+                r#"SELECT app_table_a FROM (SELECT remote_table.a AS app_table_a FROM remote_table)"#,
+            ),
+            (
+                "SELECT aapp_table FROM (SELECT a as aapp_table FROM app_table)",
+                r#"SELECT aapp_table FROM (SELECT remote_table.a AS aapp_table FROM remote_table)"#,
+            ),
+        ];
+
+        for test in tests {
+            test_sql(&ctx, test.0, test.1).await?;
+        }
+
+        Ok(())
+    }
+
+    async fn test_sql(
+        ctx: &SessionContext,
+        sql_query: &str,
+        expected_sql: &str,
+    ) -> Result<(), datafusion::error::DataFusionError> {
+        let data_frame = ctx.sql(sql_query).await?;
+
+        println!("before optimization: \n{:#?}", data_frame.logical_plan());
+
+        let mut known_rewrites = HashMap::new();
+        let rewritten_plan = rewrite_table_scans(data_frame.logical_plan(), &mut known_rewrites)?;
+
+        println!("rewritten_plan: \n{:#?}", rewritten_plan);
+
+        let unparsed_sql = plan_to_sql(&rewritten_plan)?;
+
+        println!("unparsed_sql: \n{unparsed_sql}");
+
+        assert_eq!(
+            format!("{unparsed_sql}"),
+            expected_sql,
+            "SQL under test: {}",
+            sql_query
+        );
+
+        Ok(())
+    }
+}
diff --git a/datafusion-federation/src/sql/schema.rs b/datafusion-federation/src/sql/schema.rs
index cb35ee6..8e9cf25 100644
--- a/datafusion-federation/src/sql/schema.rs
+++ b/datafusion-federation/src/sql/schema.rs
@@ -19,7 +19,7 @@ pub struct SQLSchemaProvider {
 impl SQLSchemaProvider {
     pub async fn new(provider: Arc<SQLFederationProvider>) -> Result<Self> {
-        let tables = provider.clone().executor.table_names().await?;
+        let tables = Arc::clone(&provider).executor.table_names().await?;
         Self::new_with_tables(provider, tables).await
     }
@@ -30,7 +30,7 @@ impl SQLSchemaProvider {
     ) -> Result<Self> {
         let futures: Vec<_> = tables
             .into_iter()
-            .map(|t| SQLTableSource::new(provider.clone(), t))
+            .map(|t| SQLTableSource::new(Arc::clone(&provider), t))
             .collect();
         let results: Result<Vec<_>> = join_all(futures).await.into_iter().collect();
         let sources = results?.into_iter().map(Arc::new).collect();
@@ -58,7 +58,9 @@ impl SchemaProvider for SQLSchemaProvider {
             .iter()
             .find(|s| s.table_name.eq_ignore_ascii_case(name))
         {
-            let adaptor = FederatedTableProviderAdaptor::new(source.clone());
+            let adaptor = FederatedTableProviderAdaptor::new(
+                Arc::clone(source) as Arc<dyn FederatedTableSource>
+            );
             return Ok(Some(Arc::new(adaptor)));
         }
         Ok(None)
@@ -114,8 +116,7 @@ pub struct SQLTableSource {
 impl SQLTableSource {
     // creates a SQLTableSource and infers the table schema
     pub async fn new(provider: Arc<SQLFederationProvider>, table_name: String) -> Result<Self> {
-        let schema = provider
-            .clone()
+        let schema = Arc::clone(&provider)
             .executor
             .get_table_schema(table_name.as_str())
             .await?;
@@ -133,11 +134,15 @@ impl SQLTableSource {
             schema,
         })
     }
+
+    pub fn table_name(&self) -> &str {
+        self.table_name.as_str()
+    }
 }
 
 impl FederatedTableSource for SQLTableSource {
     fn federation_provider(&self) -> Arc<dyn FederationProvider> {
-        self.provider.clone()
+        Arc::clone(&self.provider) as Arc<dyn FederationProvider>
     }
 }
@@ -146,7 +151,7 @@ impl TableSource for SQLTableSource {
         self
     }
     fn schema(&self) -> SchemaRef {
-        self.schema.clone()
+        Arc::clone(&self.schema)
     }
     fn table_type(&self) -> TableType {
         TableType::Temporary
diff --git a/datafusion-flight-sql-table-provider/src/lib.rs b/datafusion-flight-sql-table-provider/src/lib.rs
index 930939d..0bcf432 100644
--- a/datafusion-flight-sql-table-provider/src/lib.rs
+++ b/datafusion-flight-sql-table-provider/src/lib.rs
@@ -6,7 +6,7 @@ use async_trait::async_trait;
 use datafusion::{
     error::{DataFusionError, Result},
     physical_plan::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream},
-    sql::sqlparser::dialect::{Dialect, GenericDialect},
+    sql::unparser::dialect::{DefaultDialect, Dialect},
 };
 use datafusion_federation::sql::SQLExecutor;
 use futures::TryStreamExt;
@@ -68,11 +68,11 @@ impl SQLExecutor for FlightSQLExecutor {
     }
     fn execute(&self, sql: &str, schema: SchemaRef) -> Result<SendableRecordBatchStream> {
         let future_stream =
-            make_flight_sql_stream(sql.to_string(), self.client.clone(), schema.clone());
+            make_flight_sql_stream(sql.to_string(), self.client.clone(), Arc::clone(&schema));
         let stream = futures::stream::once(future_stream).try_flatten();
 
         Ok(Box::pin(RecordBatchStreamAdapter::new(
-            schema.clone(),
+            Arc::clone(&schema),
             stream,
         )))
     }
@@ -96,7 +96,7 @@ impl SQLExecutor for FlightSQLExecutor {
     }
 
     fn dialect(&self) -> Arc<dyn Dialect> {
-        Arc::new(GenericDialect {})
+        Arc::new(DefaultDialect {})
     }
 }
diff --git a/examples/Cargo.toml b/examples/Cargo.toml
index 13eb599..52acc4e 100644
--- a/examples/Cargo.toml
+++ b/examples/Cargo.toml
@@ -6,7 +6,7 @@ license.workspace = true
 readme.workspace = true
 
 [dev-dependencies]
-arrow-flight = { version = "51.0.0", features = ["flight-sql-experimental"] }
+arrow-flight = { version = "52.0.0", features = ["flight-sql-experimental"] }
 tokio = "1.35.1"
 async-trait.workspace = true
 datafusion.workspace = true