From 60cf8ae49941053b46964d56878224798028d42a Mon Sep 17 00:00:00 2001 From: Sevenannn Date: Thu, 5 Sep 2024 19:30:14 -0700 Subject: [PATCH 01/40] Add feature flag to disable postgres federation --- Cargo.toml | 4 ++++ src/postgres.rs | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index ce0e851..2dd2041 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -99,3 +99,7 @@ flight = [ ] duckdb-federation = ["duckdb"] sqlite-federation = ["sqlite"] +postgres-federation = ["postgres"] + +[patch.crates-io] +datafusion-federation = { git = "https://github.com/spiceai/datafusion-federation.git", rev = "b6682948d07cc3155edb3dfbf03f8b55570fc1d2" } diff --git a/src/postgres.rs b/src/postgres.rs index 5eb3be3..83d8409 100644 --- a/src/postgres.rs +++ b/src/postgres.rs @@ -156,6 +156,8 @@ impl PostgresTableFactory { .await .map_err(|e| Box::new(e) as Box)?, ); + + #[cfg(feature = "postgres-federation")] let table_provider = Arc::new( table_provider .create_federated_table_provider() @@ -301,7 +303,9 @@ impl TableProviderFactory for PostgresTableProviderFactory { Some(Engine::Postgres), )); + #[cfg(feature = "postgres-federation")] let read_provider = Arc::new(read_provider.create_federated_table_provider()?); + Ok(PostgresTableWriter::create( read_provider, postgres, From 51096f4ee9785edb7e56700fcad29ad612f370b0 Mon Sep 17 00:00:00 2001 From: Sevenannn Date: Thu, 5 Sep 2024 20:03:28 -0700 Subject: [PATCH 02/40] Only disable federation for tableproviderfactory --- src/postgres.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/postgres.rs b/src/postgres.rs index 83d8409..494fb44 100644 --- a/src/postgres.rs +++ b/src/postgres.rs @@ -157,7 +157,6 @@ impl PostgresTableFactory { .map_err(|e| Box::new(e) as Box)?, ); - #[cfg(feature = "postgres-federation")] let table_provider = Arc::new( table_provider .create_federated_table_provider() From 28d2409e5354bb1bd743bfe51597723f3b86a37f Mon Sep 17 00:00:00 2001 From: yfu Date: Wed, 4 Sep 2024 16:53:25 +1000 Subject: [PATCH 03/40] DuckDB streaming (#41) * wip * duckdb streaming * clippy * arrow to arrow stream * error message * fix: Support `INTERVAL` in SQLite (#85) * poc: Support interval in SQLite using an AST analyzer * Refactoring * u64 -> i64 * fix: Support INTERVAL expressions in SQLite * docs: Add comment about flattening arguments list * refactor: Rename SQLiteVisitor to SQLiteIntervalVisitor * test: Add some tests --------- Co-authored-by: Phillip LeBlanc * Use DuckDB streaming * Fixes * Fix feature flagging * Fix lint * Add spiceai branch to pull_request --------- Co-authored-by: peasee <98815791+peasee@users.noreply.github.com> Co-authored-by: Phillip LeBlanc --- .github/workflows/pr.yaml | 1 + Cargo.toml | 5 +- src/duckdb.rs | 18 +-- .../dbconnection/duckdbconn.rs | 107 ++++++++++++++++-- src/sql/db_connection_pool/duckdbpool.rs | 12 +- src/sql/sql_provider_datafusion/mod.rs | 36 ++++-- 6 files changed, 144 insertions(+), 35 deletions(-) diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index ccb056e..c45867e 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -5,6 +5,7 @@ on: pull_request: branches: - main + - spiceai jobs: lint: diff --git a/Cargo.toml b/Cargo.toml index 2dd2041..0a68bb3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,6 +62,7 @@ pem = { version = "3.0.4", optional = true } tokio-rusqlite = { version = "0.5.1", optional = true } tonic = { version = "0.11", optional = true } # pinned for arrow-flight compat itertools = "0.13.0" +dyn-clone = { version = "1.0.17", 
optional = true } geo-types = "0.7.13" [dev-dependencies] @@ -81,7 +82,7 @@ arrow-schema = "52.2.0" mysql = ["dep:mysql_async", "dep:async-stream"] postgres = ["dep:tokio-postgres", "dep:uuid", "dep:postgres-native-tls", "dep:bb8", "dep:bb8-postgres", "dep:native-tls", "dep:pem", "dep:async-stream"] sqlite = ["dep:rusqlite", "dep:tokio-rusqlite"] -duckdb = ["dep:duckdb", "dep:r2d2", "dep:uuid"] +duckdb = ["dep:duckdb", "dep:r2d2", "dep:uuid", "dep:dyn-clone", "dep:async-stream"] flight = [ "dep:arrow-array", "dep:arrow-cast", @@ -102,4 +103,4 @@ sqlite-federation = ["sqlite"] postgres-federation = ["postgres"] [patch.crates-io] -datafusion-federation = { git = "https://github.com/spiceai/datafusion-federation.git", rev = "b6682948d07cc3155edb3dfbf03f8b55570fc1d2" } +duckdb = { git = "https://github.com/spiceai/duckdb-rs.git", rev = "f2ca47d094a5636df8b9f3792b2f474a7b210dc1" } diff --git a/src/duckdb.rs b/src/duckdb.rs index 70cd090..bb9ea8a 100644 --- a/src/duckdb.rs +++ b/src/duckdb.rs @@ -1,7 +1,9 @@ use crate::sql::db_connection_pool::{ self, dbconnection::{ - duckdbconn::{flatten_table_function_name, is_table_function, DuckDbConnection}, + duckdbconn::{ + flatten_table_function_name, is_table_function, DuckDBParameter, DuckDbConnection, + }, get_schema, DbConnection, }, duckdbpool::DuckDbConnectionPool, @@ -25,7 +27,7 @@ use datafusion::{ logical_expr::CreateExternalTable, sql::TableReference, }; -use duckdb::{AccessMode, DuckdbConnectionManager, ToSql, Transaction}; +use duckdb::{AccessMode, DuckdbConnectionManager, Transaction}; use itertools::Itertools; use snafu::prelude::*; use std::{cmp, collections::HashMap, sync::Arc}; @@ -177,7 +179,7 @@ impl Default for DuckDBTableProviderFactory { } } -type DynDuckDbConnectionPool = dyn DbConnectionPool, &'static dyn ToSql> +type DynDuckDbConnectionPool = dyn DbConnectionPool, DuckDBParameter> + Send + Sync; @@ -317,18 +319,18 @@ impl DuckDB { pub fn connect_sync( &self, ) -> Result< - Box, &'static dyn ToSql>>, + Box, DuckDBParameter>>, > { Arc::clone(&self.pool) .connect_sync() .context(DbConnectionSnafu) } - pub fn duckdb_conn<'a>( - db_connection: &'a mut Box< - dyn DbConnection, &'static dyn ToSql>, + pub fn duckdb_conn( + db_connection: &mut Box< + dyn DbConnection, DuckDBParameter>, >, - ) -> Result<&'a mut DuckDbConnection> { + ) -> Result<&mut DuckDbConnection> { db_connection .as_any_mut() .downcast_mut::() diff --git a/src/sql/db_connection_pool/dbconnection/duckdbconn.rs b/src/sql/db_connection_pool/dbconnection/duckdbconn.rs index 4ad5140..f97ed69 100644 --- a/src/sql/db_connection_pool/dbconnection/duckdbconn.rs +++ b/src/sql/db_connection_pool/dbconnection/duckdbconn.rs @@ -1,16 +1,20 @@ use std::any::Any; use arrow::array::RecordBatch; +use async_stream::stream; use datafusion::arrow::datatypes::SchemaRef; +use datafusion::error::DataFusionError; use datafusion::execution::SendableRecordBatchStream; -use datafusion::physical_plan::memory::MemoryStream; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::sql::sqlparser::ast::TableFactor; use datafusion::sql::sqlparser::parser::Parser; use datafusion::sql::sqlparser::{dialect::DuckDbDialect, tokenizer::Tokenizer}; use datafusion::sql::TableReference; use duckdb::DuckdbConnectionManager; use duckdb::ToSql; +use dyn_clone::DynClone; use snafu::{prelude::*, ResultExt}; +use tokio::sync::mpsc::Sender; use super::DbConnection; use super::Result; @@ -20,7 +24,22 @@ use super::SyncDbConnection; pub enum Error { #[snafu(display("DuckDBError: {source}"))] 
DuckDBError { source: duckdb::Error }, + + #[snafu(display("ChannelError: {message}"))] + ChannelError { message: String }, +} + +pub trait DuckDBSyncParameter: ToSql + Sync + Send + DynClone { + fn as_input_parameter(&self) -> &dyn ToSql; +} + +impl DuckDBSyncParameter for T { + fn as_input_parameter(&self) -> &dyn ToSql { + self + } } +dyn_clone::clone_trait_object!(DuckDBSyncParameter); +pub type DuckDBParameter = Box; pub struct DuckDbConnection { pub conn: r2d2::PooledConnection, @@ -34,7 +53,7 @@ impl DuckDbConnection { } } -impl<'a> DbConnection, &'a dyn ToSql> +impl DbConnection, DuckDBParameter> for DuckDbConnection { fn as_any(&self) -> &dyn Any { @@ -47,13 +66,14 @@ impl<'a> DbConnection, &'a dyn T fn as_sync( &self, - ) -> Option<&dyn SyncDbConnection, &'a dyn ToSql>> - { + ) -> Option< + &dyn SyncDbConnection, DuckDBParameter>, + > { Some(self) } } -impl SyncDbConnection, &dyn ToSql> +impl SyncDbConnection, DuckDBParameter> for DuckDbConnection { fn new(conn: r2d2::PooledConnection) -> Self { @@ -83,23 +103,88 @@ impl SyncDbConnection, &dyn ToSq fn query_arrow( &self, sql: &str, - params: &[&dyn ToSql], + params: &[DuckDBParameter], _projected_schema: Option, ) -> Result { - let mut stmt = self.conn.prepare(sql).context(DuckDBSnafu)?; + let (batch_tx, mut batch_rx) = tokio::sync::mpsc::channel::(4); + + let fetch_schema_sql = + format!("WITH fetch_schema AS ({sql}) SELECT * FROM fetch_schema LIMIT 0"); + let mut stmt = self + .conn + .prepare(&fetch_schema_sql) + .boxed() + .context(super::UnableToGetSchemaSnafu)?; + + let result: duckdb::Arrow<'_> = stmt + .query_arrow([]) + .boxed() + .context(super::UnableToGetSchemaSnafu)?; - let result: duckdb::Arrow<'_> = stmt.query_arrow(params).context(DuckDBSnafu)?; let schema = result.get_schema(); - let recs: Vec = result.collect(); - Ok(Box::pin(MemoryStream::try_new(recs, schema, None)?)) + + let params = params.iter().map(dyn_clone::clone).collect::>(); + + let conn = self.conn.try_clone()?; + let sql = sql.to_string(); + + let cloned_schema = schema.clone(); + + let join_handle = tokio::task::spawn_blocking(move || { + let mut stmt = conn.prepare(&sql).context(DuckDBSnafu)?; + let params: &[&dyn ToSql] = ¶ms + .iter() + .map(|f| f.as_input_parameter()) + .collect::>(); + let result: duckdb::ArrowStream<'_> = stmt + .stream_arrow(params, cloned_schema) + .context(DuckDBSnafu)?; + for i in result { + blocking_channel_send(&batch_tx, i)?; + } + + Ok::<_, Box>(()) + }); + + let output_stream = stream! 
{ + while let Some(batch) = batch_rx.recv().await { + yield Ok(batch); + } + + if let Err(e) = join_handle.await { + yield Err(DataFusionError::Execution(format!( + "Failed to execute DuckDB query: {e}" + ))) + } + }; + + Ok(Box::pin(RecordBatchStreamAdapter::new( + schema, + output_stream, + ))) } - fn execute(&self, sql: &str, params: &[&dyn ToSql]) -> Result { + fn execute(&self, sql: &str, params: &[DuckDBParameter]) -> Result { + let params: &[&dyn ToSql] = ¶ms + .iter() + .map(|f| f.as_input_parameter()) + .collect::>(); + let rows_modified = self.conn.execute(sql, params).context(DuckDBSnafu)?; Ok(rows_modified as u64) } } +fn blocking_channel_send(channel: &Sender, item: T) -> Result<()> { + match channel.blocking_send(item) { + Ok(()) => Ok(()), + Err(e) => Err(Error::ChannelError { + message: format!("{e}"), + } + .into()), + } +} + #[must_use] pub fn flatten_table_function_name(table_reference: &TableReference) -> String { let table_name = table_reference.table(); diff --git a/src/sql/db_connection_pool/duckdbpool.rs b/src/sql/db_connection_pool/duckdbpool.rs index a0fc87a..db8b6a4 100644 --- a/src/sql/db_connection_pool/duckdbpool.rs +++ b/src/sql/db_connection_pool/duckdbpool.rs @@ -1,9 +1,9 @@ use async_trait::async_trait; -use duckdb::{vtab::arrow::ArrowVTab, AccessMode, DuckdbConnectionManager, ToSql}; +use duckdb::{vtab::arrow::ArrowVTab, AccessMode, DuckdbConnectionManager}; use snafu::{prelude::*, ResultExt}; use std::sync::Arc; -use super::{DbConnectionPool, Result}; +use super::{dbconnection::duckdbconn::DuckDBParameter, DbConnectionPool, Result}; use crate::sql::db_connection_pool::{ dbconnection::{duckdbconn::DuckDbConnection, DbConnection, SyncDbConnection}, JoinPushDown, @@ -134,7 +134,7 @@ impl DuckDbConnectionPool { pub fn connect_sync( self: Arc, ) -> Result< - Box, &'static dyn ToSql>>, + Box, DuckDBParameter>>, > { let pool = Arc::clone(&self.pool); let conn: r2d2::PooledConnection = @@ -144,13 +144,13 @@ impl DuckDbConnectionPool { } #[async_trait] -impl DbConnectionPool, &'static dyn ToSql> +impl DbConnectionPool, DuckDBParameter> for DuckDbConnectionPool { async fn connect( &self, ) -> Result< - Box, &'static dyn ToSql>>, + Box, DuckDBParameter>>, > { let pool = Arc::clone(&self.pool); let conn: r2d2::PooledConnection = @@ -225,7 +225,6 @@ fn extract_db_name(file_path: Arc) -> Result { #[cfg(test)] mod test { - use rand::Rng; use super::*; @@ -265,6 +264,7 @@ mod test { } #[tokio::test] + #[cfg(feature = "duckdb-federation")] async fn test_duckdb_connection_pool_with_attached_databases() { let db_base_name = random_db_name(); let db_attached_name = random_db_name(); diff --git a/src/sql/sql_provider_datafusion/mod.rs b/src/sql/sql_provider_datafusion/mod.rs index 73f9b37..56dabb7 100644 --- a/src/sql/sql_provider_datafusion/mod.rs +++ b/src/sql/sql_provider_datafusion/mod.rs @@ -604,19 +604,30 @@ mod tests { #[cfg(feature = "duckdb")] mod duckdb_tests { use super::*; - use crate::sql::db_connection_pool::dbconnection::duckdbconn::DuckDbConnection; + use crate::sql::db_connection_pool::dbconnection::duckdbconn::{ + DuckDBSyncParameter, DuckDbConnection, + }; use crate::sql::db_connection_pool::{duckdbpool::DuckDbConnectionPool, DbConnectionPool}; - use duckdb::{DuckdbConnectionManager, ToSql}; + use duckdb::DuckdbConnectionManager; #[tokio::test] async fn test_duckdb_table() -> Result<(), Box> { let t = setup_tracing(); let ctx = SessionContext::new(); let pool: Arc< - dyn DbConnectionPool, &dyn ToSql> - + Send + dyn DbConnectionPool< + 
r2d2::PooledConnection, + Box, + > + Send + Sync, - > = Arc::new(DuckDbConnectionPool::new_memory()?); + > = Arc::new(DuckDbConnectionPool::new_memory()?) + as Arc< + dyn DbConnectionPool< + r2d2::PooledConnection, + Box, + > + Send + + Sync, + >; let conn = pool.connect().await?; let db_conn = conn .as_any() @@ -639,10 +650,19 @@ mod tests { let t = setup_tracing(); let ctx = SessionContext::new(); let pool: Arc< - dyn DbConnectionPool, &dyn ToSql> - + Send + dyn DbConnectionPool< + r2d2::PooledConnection, + Box, + > + Send + Sync, - > = Arc::new(DuckDbConnectionPool::new_memory()?); + > = Arc::new(DuckDbConnectionPool::new_memory()?) + as Arc< + dyn DbConnectionPool< + r2d2::PooledConnection, + Box, + > + Send + + Sync, + >; let conn = pool.connect().await?; let db_conn = conn .as_any() From df09d7f67fcd3306394ad339ebfa2c70550739ce Mon Sep 17 00:00:00 2001 From: peasee <98815791+peasee@users.noreply.github.com> Date: Fri, 6 Sep 2024 12:48:51 +1000 Subject: [PATCH 04/40] fix: Disable federation in memory mode databases (#86) --- src/duckdb.rs | 14 ++++++++++++-- src/sql/db_connection_pool/duckdbpool.rs | 10 +++++++++- src/sqlite.rs | 7 ++++++- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/src/duckdb.rs b/src/duckdb.rs index bb9ea8a..c60712f 100644 --- a/src/duckdb.rs +++ b/src/duckdb.rs @@ -267,7 +267,12 @@ impl TableProviderFactory for DuckDBTableProviderFactory { )); #[cfg(feature = "duckdb-federation")] - let read_provider = Arc::new(read_provider.create_federated_table_provider()?); + let read_provider: Arc = if mode == Mode::File { + // federation is disabled for in-memory mode until memory connections are updated to use the same database instance instead of separate instances + Arc::new(read_provider.create_federated_table_provider()?) + } else { + read_provider + }; Ok(DuckDBTableWriter::create( read_provider, @@ -443,7 +448,12 @@ impl DuckDBTableFactory { )); #[cfg(feature = "duckdb-federation")] - let table_provider = Arc::new(table_provider.create_federated_table_provider()?); + let table_provider: Arc = if self.pool.mode() == Mode::File { + // federation is disabled for in-memory mode until memory connections are updated to use the same database instance instead of separate instances + Arc::new(table_provider.create_federated_table_provider()?) + } else { + table_provider + }; Ok(table_provider) } diff --git a/src/sql/db_connection_pool/duckdbpool.rs b/src/sql/db_connection_pool/duckdbpool.rs index db8b6a4..45b8fbe 100644 --- a/src/sql/db_connection_pool/duckdbpool.rs +++ b/src/sql/db_connection_pool/duckdbpool.rs @@ -3,7 +3,7 @@ use duckdb::{vtab::arrow::ArrowVTab, AccessMode, DuckdbConnectionManager}; use snafu::{prelude::*, ResultExt}; use std::sync::Arc; -use super::{dbconnection::duckdbconn::DuckDBParameter, DbConnectionPool, Result}; +use super::{dbconnection::duckdbconn::DuckDBParameter, DbConnectionPool, Mode, Result}; use crate::sql::db_connection_pool::{ dbconnection::{duckdbconn::DuckDbConnection, DbConnection, SyncDbConnection}, JoinPushDown, @@ -36,6 +36,7 @@ pub struct DuckDbConnectionPool { pool: Arc>, join_push_down: JoinPushDown, attached_databases: Vec>, + mode: Mode, } impl DuckDbConnectionPool { @@ -71,6 +72,7 @@ impl DuckDbConnectionPool { // There can't be any other tables that share the same context for an in-memory DuckDB. 
join_push_down: JoinPushDown::Disallow, attached_databases: Vec::new(), + mode: Mode::Memory, }) } @@ -108,6 +110,7 @@ impl DuckDbConnectionPool { // Allow join-push down for any other instances that connect to the same underlying file. join_push_down: JoinPushDown::AllowedFor(path.to_string()), attached_databases: Vec::new(), + mode: Mode::File, }) } @@ -141,6 +144,11 @@ impl DuckDbConnectionPool { pool.get().context(ConnectionPoolSnafu)?; Ok(Box::new(DuckDbConnection::new(conn))) } + + #[must_use] + pub fn mode(&self) -> Mode { + self.mode + } } #[async_trait] diff --git a/src/sqlite.rs b/src/sqlite.rs index 829531a..32e6183 100644 --- a/src/sqlite.rs +++ b/src/sqlite.rs @@ -261,7 +261,12 @@ impl TableProviderFactory for SqliteTableProviderFactory { .map_err(to_datafusion_error)?; #[cfg(feature = "sqlite-federation")] - let read_provider = Arc::new(read_provider.create_federated_table_provider()?); + let read_provider: Arc = if mode == Mode::File { + // federation is disabled for in-memory mode until memory connections are updated to use the same database instance instead of separate instances + Arc::new(read_provider.create_federated_table_provider()?) + } else { + read_provider + }; Ok(SqliteTableWriter::create( read_provider, From b5faa36a28e8cd8e89251f393e4e3e21e2ac9537 Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Sat, 7 Sep 2024 00:47:01 -0700 Subject: [PATCH 05/40] SQLite: Validate expected indexes when attaching local datasets (#88) * SQLite: Validate expected indexes when attaching local datasets * Add test for indexes creation and retrieval (SQLite) * Update warning messages --- src/sqlite.rs | 156 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 155 insertions(+), 1 deletion(-) diff --git a/src/sqlite.rs b/src/sqlite.rs index 32e6183..2b017a1 100644 --- a/src/sqlite.rs +++ b/src/sqlite.rs @@ -1,5 +1,5 @@ use crate::sql::arrow_sql_gen::statement::{CreateTableBuilder, IndexBuilder, InsertBuilder}; -use crate::sql::db_connection_pool::dbconnection::{self, get_schema}; +use crate::sql::db_connection_pool::dbconnection::{self, get_schema, AsyncDbConnection}; use crate::sql::db_connection_pool::sqlitepool::SqliteConnectionPoolFactory; use crate::sql::db_connection_pool::{ self, @@ -8,6 +8,7 @@ use crate::sql::db_connection_pool::{ DbConnectionPool, Mode, }; use crate::sql::sql_provider_datafusion; +use arrow::array::StringArray; use arrow::{array::RecordBatch, datatypes::SchemaRef}; use async_trait::async_trait; use datafusion::catalog::Session; @@ -19,9 +20,11 @@ use datafusion::{ logical_expr::CreateExternalTable, sql::TableReference, }; +use futures::TryStreamExt; use rusqlite::{ToSql, Transaction}; use snafu::prelude::*; use sql_table::SQLiteTable; +use std::collections::HashSet; use std::{collections::HashMap, sync::Arc}; use tokio_rusqlite::Connection; @@ -246,6 +249,11 @@ impl TableProviderFactory for SqliteTableProviderFactory { .await .context(UnableToCreateTableSnafu) .map_err(to_datafusion_error)?; + } else if !sqlite.verify_indexes_match(sqlite_conn, &indexes).await? { + tracing::warn!( + "The local table definition at '{db_path}' for '{name}' does not match the expected configuration. To fix this, drop the existing local copy. 
A new table with the correct schema will be automatically created upon first access.", + name = name + ); } let dyn_pool: Arc = read_pool; @@ -442,4 +450,150 @@ impl Sqlite { Ok(()) } + + async fn get_indexes( + &self, + sqlite_conn: &mut SqliteConnection, + ) -> DataFusionResult> { + let query_result = sqlite_conn + .query_arrow( + format!("PRAGMA index_list({name})", name = self.table_name).as_str(), + &[], + None, + ) + .await?; + + let mut indexes = HashSet::new(); + + query_result + .try_collect::>() + .await + .into_iter() + .flatten() + .for_each(|batch| { + if let Some(name_array) = batch + .column_by_name("name") + .and_then(|col| col.as_any().downcast_ref::()) + { + for index_name in name_array.iter().flatten() { + // Filter out SQLite's auto-generated indexes + if !index_name.starts_with("sqlite_autoindex_") { + indexes.insert(index_name.to_string()); + } + } + } + }); + + Ok(indexes) + } + + async fn verify_indexes_match( + &self, + sqlite_conn: &mut SqliteConnection, + indexes: &[(ColumnReference, IndexType)], + ) -> DataFusionResult { + let expected_indexes_str_map: HashSet = indexes + .iter() + .map(|(col, _)| IndexBuilder::new(&self.table_name, col.iter().collect()).index_name()) + .collect(); + + let actual_indexes_str_map = self.get_indexes(sqlite_conn).await?; + + let missing_in_actual = expected_indexes_str_map + .difference(&actual_indexes_str_map) + .collect::>(); + let extra_in_actual = actual_indexes_str_map + .difference(&expected_indexes_str_map) + .collect::>(); + + if !missing_in_actual.is_empty() { + tracing::warn!( + "Missing indexes detected for the table '{name}': {:?}.", + missing_in_actual, + name = self.table_name + ); + } + if !extra_in_actual.is_empty() { + tracing::warn!( + "The table '{name}' contains unexpected indexes not presented in the configuration: {:?}.", + extra_in_actual, + name = self.table_name + ); + } + + Ok(missing_in_actual.is_empty() && extra_in_actual.is_empty()) + } +} + +#[cfg(test)] +pub(crate) mod tests { + + use arrow::datatypes::{DataType, Schema}; + use datafusion::{common::ToDFSchema, prelude::SessionContext}; + + use super::*; + + #[tokio::test] + async fn test_sqlite_table_creation_with_indexes() { + let schema = Arc::new(Schema::new(vec![ + arrow::datatypes::Field::new("first_name", DataType::Utf8, false), + arrow::datatypes::Field::new("last_name", DataType::Utf8, false), + arrow::datatypes::Field::new("id", DataType::Int64, false), + ])); + + let options: HashMap = [( + "indexes".to_string(), + "id:enabled;(first_name, last_name):unique".to_string(), + )] + .iter() + .cloned() + .collect(); + + let expected_indexes: HashSet = [ + "i_test_table_id".to_string(), + "i_test_table_first_name_last_name".to_string(), + ] + .iter() + .cloned() + .collect(); + + let df_schema = ToDFSchema::to_dfschema_ref(Arc::clone(&schema)).expect("df schema"); + + let external_table = CreateExternalTable { + schema: df_schema, + name: TableReference::bare("test_table"), + location: String::new(), + file_type: String::new(), + table_partition_cols: vec![], + if_not_exists: true, + definition: None, + order_exprs: vec![], + unbounded: false, + options, + constraints: Constraints::empty(), + column_defaults: HashMap::default(), + }; + let ctx = SessionContext::new(); + let table = SqliteTableProviderFactory::default() + .create(&ctx.state(), &external_table) + .await + .expect("table should be created"); + + let sqlite = table + .as_any() + .downcast_ref::() + .expect("downcast to SqliteTableWriter") + .sqlite(); + + let mut db_conn = 
sqlite.connect().await.expect("should connect to db"); + let sqlite_conn = + Sqlite::sqlite_conn(&mut db_conn).expect("should create sqlite connection"); + + let retrieved_indexes = sqlite + .get_indexes(sqlite_conn) + .await + .expect("should get indexes"); + + assert_eq!(retrieved_indexes, expected_indexes); + } } From ba766257c9868366f552e28b6a4cea56d21fcb6d Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Sun, 8 Sep 2024 17:55:25 -0700 Subject: [PATCH 06/40] SQLite: Validate expected primary keys when attaching local datasets (#89) --- src/sqlite.rs | 118 +++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 113 insertions(+), 5 deletions(-) diff --git a/src/sqlite.rs b/src/sqlite.rs index 2b017a1..55b71d6 100644 --- a/src/sqlite.rs +++ b/src/sqlite.rs @@ -8,7 +8,7 @@ use crate::sql::db_connection_pool::{ DbConnectionPool, Mode, }; use crate::sql::sql_provider_datafusion; -use arrow::array::StringArray; +use arrow::array::{Int64Array, StringArray}; use arrow::{array::RecordBatch, datatypes::SchemaRef}; use async_trait::async_trait; use datafusion::catalog::Session; @@ -249,11 +249,20 @@ impl TableProviderFactory for SqliteTableProviderFactory { .await .context(UnableToCreateTableSnafu) .map_err(to_datafusion_error)?; - } else if !sqlite.verify_indexes_match(sqlite_conn, &indexes).await? { - tracing::warn!( + } else { + let mut table_definition_matches = true; + + table_definition_matches &= sqlite.verify_indexes_match(sqlite_conn, &indexes).await?; + table_definition_matches &= sqlite + .verify_primary_keys_match(sqlite_conn, &primary_keys) + .await?; + + if !table_definition_matches { + tracing::warn!( "The local table definition at '{db_path}' for '{name}' does not match the expected configuration. To fix this, drop the existing local copy. 
A new table with the correct schema will be automatically created upon first access.", name = name ); + } } let dyn_pool: Arc = read_pool; @@ -487,6 +496,47 @@ impl Sqlite { Ok(indexes) } + async fn get_primary_keys( + &self, + sqlite_conn: &mut SqliteConnection, + ) -> DataFusionResult> { + let query_result = sqlite_conn + .query_arrow( + format!("PRAGMA table_info({name})", name = self.table_name).as_str(), + &[], + None, + ) + .await?; + + let mut primary_keys = HashSet::new(); + + query_result + .try_collect::>() + .await + .into_iter() + .flatten() + .for_each(|batch| { + if let (Some(name_array), Some(pk_array)) = ( + batch + .column_by_name("name") + .and_then(|col| col.as_any().downcast_ref::()), + batch + .column_by_name("pk") + .and_then(|col| col.as_any().downcast_ref::()), + ) { + // name and pk fields can't be None so it is safe to flatten both + for (name, pk) in name_array.iter().flatten().zip(pk_array.iter().flatten()) { + if pk > 0 { + // pk > 0 indicates primary key + primary_keys.insert(name.to_string()); + } + } + } + }); + + Ok(primary_keys) + } + async fn verify_indexes_match( &self, sqlite_conn: &mut SqliteConnection, @@ -523,13 +573,49 @@ impl Sqlite { Ok(missing_in_actual.is_empty() && extra_in_actual.is_empty()) } + + async fn verify_primary_keys_match( + &self, + sqlite_conn: &mut SqliteConnection, + primary_keys: &[String], + ) -> DataFusionResult { + let expected_pk_keys_str_map: HashSet = primary_keys.iter().cloned().collect(); + + let actual_pk_keys_str_map = self.get_primary_keys(sqlite_conn).await?; + + let missing_in_actual = expected_pk_keys_str_map + .difference(&actual_pk_keys_str_map) + .collect::>(); + let extra_in_actual = actual_pk_keys_str_map + .difference(&expected_pk_keys_str_map) + .collect::>(); + + if !missing_in_actual.is_empty() { + tracing::warn!( + "Missing primary keys detected for the table '{name}': {:?}.", + missing_in_actual, + name = self.table_name + ); + } + if !extra_in_actual.is_empty() { + tracing::warn!( + "The table '{name}' contains unexpected primary keys not presented in the configuration: {:?}.", + extra_in_actual, + name = self.table_name + ); + } + + Ok(missing_in_actual.is_empty() && extra_in_actual.is_empty()) + } } #[cfg(test)] pub(crate) mod tests { use arrow::datatypes::{DataType, Schema}; - use datafusion::{common::ToDFSchema, prelude::SessionContext}; + use datafusion::{ + common::ToDFSchema, prelude::SessionContext, sql::sqlparser::ast::TableConstraint, + }; use super::*; @@ -559,6 +645,21 @@ pub(crate) mod tests { let df_schema = ToDFSchema::to_dfschema_ref(Arc::clone(&schema)).expect("df schema"); + let expected_primary_keys: HashSet = ["id".to_string()].iter().cloned().collect(); + + let primary_keys_constraints = Constraints::new_from_table_constraints( + &[TableConstraint::PrimaryKey { + columns: vec!["id"].into_iter().map(Into::into).collect(), + name: None, + index_name: None, + index_options: vec![], + characteristics: None, + index_type: None, + }], + &df_schema, + ) + .expect("should create constraints"); + let external_table = CreateExternalTable { schema: df_schema, name: TableReference::bare("test_table"), @@ -570,7 +671,7 @@ pub(crate) mod tests { order_exprs: vec![], unbounded: false, options, - constraints: Constraints::empty(), + constraints: primary_keys_constraints, column_defaults: HashMap::default(), }; let ctx = SessionContext::new(); @@ -595,5 +696,12 @@ pub(crate) mod tests { .expect("should get indexes"); assert_eq!(retrieved_indexes, expected_indexes); + + let retrieved_primary_keys = 
sqlite + .get_primary_keys(sqlite_conn) + .await + .expect("should get primary keys"); + + assert_eq!(retrieved_primary_keys, expected_primary_keys); } } From dc9223e8e8a5fc322858c0ee115e9aa7fd2c19c0 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Mon, 9 Sep 2024 15:43:48 +0900 Subject: [PATCH 07/40] Change to use Spice AI fork of sea_query for SQLite decimal support (#90) --- Cargo.toml | 3 +-- src/sql/arrow_sql_gen/statement.rs | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0a68bb3..c1c886a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,6 @@ num-bigint = "0.4.4" base64 = { version = "0.22.1", optional = true } bytes = { version = "1.7.1", optional = true } bigdecimal = "0.4.5" -bigdecimal_0_3_0 = { package = "bigdecimal", version = "0.3.0" } byteorder = "1.5.0" chrono = "0.4.38" datafusion = "41.0.0" @@ -42,7 +41,7 @@ mysql_async = { version = "0.34.1", features = ["native-tls-tls", "chrono"], opt prost = { version = "0.12" , optional = true } # pinned for arrow-flight compat r2d2 = { version = "0.8.10", optional = true } rusqlite = { version = "0.31.0", optional = true } -sea-query = { version = "0.31.0", features = ["backend-sqlite", "backend-postgres", "postgres-array", "with-rust_decimal", "with-bigdecimal", "with-time", "with-chrono"] } +sea-query = { git = "https://github.com/spiceai/sea-query.git", rev = "213b6b876068f58159ebdd5852604a021afaebf9", features = ["backend-sqlite", "backend-postgres", "postgres-array", "with-rust_decimal", "with-bigdecimal", "with-time", "with-chrono"] } secrecy = "0.8.0" serde = { version = "1.0.209", optional = true } serde_json = "1.0.124" diff --git a/src/sql/arrow_sql_gen/statement.rs b/src/sql/arrow_sql_gen/statement.rs index 1d336ca..344765a 100644 --- a/src/sql/arrow_sql_gen/statement.rs +++ b/src/sql/arrow_sql_gen/statement.rs @@ -7,7 +7,7 @@ use arrow::{ datatypes::{DataType, Field, Fields, IntervalUnit, Schema, SchemaRef, TimeUnit}, util::display::array_value_to_string, }; -use bigdecimal_0_3_0::BigDecimal; +use bigdecimal::BigDecimal; use chrono::{DateTime, FixedOffset}; use num_bigint::BigInt; use sea_query::{ From f0e846786513b6a10c8dce52712d8a7181a57333 Mon Sep 17 00:00:00 2001 From: peasee <98815791+peasee@users.noreply.github.com> Date: Tue, 10 Sep 2024 12:08:24 +1000 Subject: [PATCH 08/40] fix: Don't silence blocking task errors (#91) * fix: Don't silence blocking task errors * fix: Cover Ok(Err()) match arm for DuckDB writer handle * refactor: Rename overloaded error e --- src/duckdb/write.rs | 24 ++++++++++++++----- .../dbconnection/duckdbconn.rs | 16 +++++++++---- src/sql/db_connection_pool/duckdbpool.rs | 3 ++- 3 files changed, 32 insertions(+), 11 deletions(-) diff --git a/src/duckdb/write.rs b/src/duckdb/write.rs index a13b7b8..a7eb3ce 100644 --- a/src/duckdb/write.rs +++ b/src/duckdb/write.rs @@ -184,16 +184,28 @@ impl DataSink for DuckDBDataSink { .context(super::ConstraintViolationSnafu) .map_err(to_datafusion_error)?; - batch_tx.send(batch).await.map_err(|e| { - DataFusionError::Execution(format!( - "Unable to send RecordBatch to duckdb writer: {e}" - )) - })?; + if let Err(send_error) = batch_tx.send(batch).await { + match duckdb_write_handle.await { + Err(join_error) => { + return Err(DataFusionError::Execution(format!( + "Error writing to DuckDB: {join_error}" + ))); + } + Ok(Err(datafusion_error)) => { + return Err(datafusion_error); + } + _ => { + return Err(DataFusionError::Execution(format!( + "Unable to send RecordBatch to DuckDB writer: {send_error}" + 
))) + } + }; + } } if notify_commit_transaction.send(()).is_err() { return Err(DataFusionError::Execution( - "Unable to send message to commit transaction to duckdb writer.".to_string(), + "Unable to send message to commit transaction to DuckDB writer.".to_string(), )); }; diff --git a/src/sql/db_connection_pool/dbconnection/duckdbconn.rs b/src/sql/db_connection_pool/dbconnection/duckdbconn.rs index f97ed69..563cfc7 100644 --- a/src/sql/db_connection_pool/dbconnection/duckdbconn.rs +++ b/src/sql/db_connection_pool/dbconnection/duckdbconn.rs @@ -151,10 +151,18 @@ impl SyncDbConnection, DuckDBPar yield Ok(batch); } - if let Err(e) = join_handle.await { - yield Err(DataFusionError::Execution(format!( - "Failed to execute DuckDB query: {e}" - ))) + match join_handle.await { + Ok(Err(task_error)) => { + yield Err(DataFusionError::Execution(format!( + "Failed to execute DuckDB query: {task_error}" + ))) + }, + Err(join_error) => { + yield Err(DataFusionError::Execution(format!( + "Failed to execute DuckDB query: {join_error}" + ))) + }, + _ => {} } }; diff --git a/src/sql/db_connection_pool/duckdbpool.rs b/src/sql/db_connection_pool/duckdbpool.rs index 45b8fbe..be5ff94 100644 --- a/src/sql/db_connection_pool/duckdbpool.rs +++ b/src/sql/db_connection_pool/duckdbpool.rs @@ -186,6 +186,7 @@ impl DbConnectionPool, DuckDBPar .context(DuckDBSnafu)?; db_ids.push(db_id); } + conn.execute(&format!("SET search_path = \"{}\"", db_ids.join(",")), []) .context(DuckDBSnafu)?; } @@ -247,7 +248,7 @@ mod test { name.push(rng.gen_range(b'a'..=b'z') as char); } - format!("./{name}.sqlite") + format!("./{name}.duckdb") } #[tokio::test] From 996cef0fc8c820f0125b70b572ad376499fa23f4 Mon Sep 17 00:00:00 2001 From: peasee <98815791+peasee@users.noreply.github.com> Date: Tue, 10 Sep 2024 23:31:28 +1000 Subject: [PATCH 09/40] fix: Re-attach databases on each DuckDB query (#92) * fix: Re-attach databases on each query * Update src/sql/db_connection_pool/dbconnection/duckdbconn.rs --------- Co-authored-by: Phillip LeBlanc --- .../dbconnection/duckdbconn.rs | 153 +++++++++++++++++- src/sql/db_connection_pool/duckdbpool.rs | 58 +++---- 2 files changed, 179 insertions(+), 32 deletions(-) diff --git a/src/sql/db_connection_pool/dbconnection/duckdbconn.rs b/src/sql/db_connection_pool/dbconnection/duckdbconn.rs index 563cfc7..31c8aeb 100644 --- a/src/sql/db_connection_pool/dbconnection/duckdbconn.rs +++ b/src/sql/db_connection_pool/dbconnection/duckdbconn.rs @@ -1,4 +1,5 @@ use std::any::Any; +use std::sync::Arc; use arrow::array::RecordBatch; use async_stream::stream; @@ -10,8 +11,8 @@ use datafusion::sql::sqlparser::ast::TableFactor; use datafusion::sql::sqlparser::parser::Parser; use datafusion::sql::sqlparser::{dialect::DuckDbDialect, tokenizer::Tokenizer}; use datafusion::sql::TableReference; -use duckdb::DuckdbConnectionManager; use duckdb::ToSql; +use duckdb::{Connection, DuckdbConnectionManager}; use dyn_clone::DynClone; use snafu::{prelude::*, ResultExt}; use tokio::sync::mpsc::Sender; @@ -27,6 +28,15 @@ pub enum Error { #[snafu(display("ChannelError: {message}"))] ChannelError { message: String }, + + #[snafu(display("Unable to attach DuckDB database {path}: {source}"))] + UnableToAttachDatabase { + path: Arc, + source: std::io::Error, + }, + + #[snafu(display("Unable to extract database name from database file path"))] + UnableToExtractDatabaseNameFromPath { path: Arc }, } pub trait DuckDBSyncParameter: ToSql + Sync + Send + DynClone { @@ -41,8 +51,105 @@ impl DuckDBSyncParameter for T { 
dyn_clone::clone_trait_object!(DuckDBSyncParameter); pub type DuckDBParameter = Box; +#[derive(Debug)] +pub struct DuckDBAttachments { + attachments: Vec>, + search_path: Arc, +} + +impl DuckDBAttachments { + /// Creates a new instance of a `DuckDBAttachments`, which instructs DuckDB connections to attach other DuckDB databases for queries. + #[must_use] + pub fn new(id: &str, attachments: &[Arc]) -> Self { + let search_path = Self::get_search_path(id, attachments); + Self { + attachments: attachments.to_owned(), + search_path, + } + } + + /// Returns the search path for the given database and attachments. + /// The given database needs to be included separately, as search path by default do not include the main database. + #[must_use] + pub fn get_search_path(id: &str, attachments: &[Arc]) -> Arc { + // search path includes the main database and all attached databases + let mut search_path: Vec> = vec![id.into()]; + + search_path.extend( + attachments + .iter() + .enumerate() + .map(|(i, _)| format!("attachment_{i}").into()), + ); + + search_path.join(",").into() + } + + /// Sets the search path for the given connection. + /// + /// # Errors + /// + /// Returns an error if the search path cannot be set or the connection fails. + pub fn set_search_path(&self, conn: &Connection) -> Result<()> { + conn.execute(&format!("SET search_path ='{}'", self.search_path), []) + .context(DuckDBSnafu)?; + Ok(()) + } + + /// Resets the search path for the given connection to default. + /// + /// # Errors + /// + /// Returns an error if the search path cannot be set or the connection fails. + pub fn reset_search_path(&self, conn: &Connection) -> Result<()> { + conn.execute("RESET search_path", []).context(DuckDBSnafu)?; + Ok(()) + } + + /// Attaches the databases to the given connection and sets the search path for the newly attached databases. + /// + /// # Errors + /// + /// Returns an error if a specific attachment is missing, cannot be attached, search path cannot be set or the connection fails. + pub fn attach(&self, conn: &Connection) -> Result<()> { + for (i, db) in self.attachments.iter().enumerate() { + // check the db file exists + std::fs::metadata(db.as_ref()).context(UnableToAttachDatabaseSnafu { + path: Arc::clone(db), + })?; + + conn.execute( + &format!("ATTACH IF NOT EXISTS '{db}' AS attachment_{i} (READ_ONLY)"), + [], + ) + .context(DuckDBSnafu)?; + } + + self.set_search_path(conn)?; + + Ok(()) + } + + /// Detaches the databases from the given connection and resets the search path to default. + /// + /// # Errors + /// + /// Returns an error if an attachment cannot be detached, search path cannot be set or the connection fails. + pub fn detach(&self, conn: &Connection) -> Result<()> { + for (i, _) in self.attachments.iter().enumerate() { + conn.execute(&format!("DETACH attachment_{i}"), []) + .context(DuckDBSnafu)?; + } + + self.reset_search_path(conn)?; + + Ok(()) + } +} + pub struct DuckDbConnection { pub conn: r2d2::PooledConnection, + attachments: Option>, } impl DuckDbConnection { @@ -51,6 +158,36 @@ impl DuckDbConnection { ) -> &mut r2d2::PooledConnection { &mut self.conn } + + #[must_use] + pub fn with_attachments(mut self, attachments: Option>) -> Self { + self.attachments = attachments; + self + } + + /// Passthrough if Option is Some for `DuckDBAttachments::attach` + /// + /// # Errors + /// + /// See `DuckDBAttachments::attach` for more information. 
+ pub fn attach(conn: &Connection, attachments: &Option>) -> Result<()> { + if let Some(attachments) = attachments { + attachments.attach(conn)?; + } + Ok(()) + } + + /// Passthrough if Option is Some for `DuckDBAttachments::detach` + /// + /// # Errors + /// + /// See `DuckDBAttachments::detach` for more information. + pub fn detach(conn: &Connection, attachments: &Option>) -> Result<()> { + if let Some(attachments) = attachments { + attachments.detach(conn)?; + } + Ok(()) + } } impl DbConnection, DuckDBParameter> @@ -77,7 +214,10 @@ impl SyncDbConnection, DuckDBPar for DuckDbConnection { fn new(conn: r2d2::PooledConnection) -> Self { - DuckDbConnection { conn } + DuckDbConnection { + conn, + attachments: None, + } } fn get_schema(&self, table_reference: &TableReference) -> Result { @@ -108,6 +248,7 @@ impl SyncDbConnection, DuckDBPar ) -> Result { let (batch_tx, mut batch_rx) = tokio::sync::mpsc::channel::(4); + Self::attach(&self.conn, &self.attachments)?; let fetch_schema_sql = format!("WITH fetch_schema AS ({sql}) SELECT * FROM fetch_schema LIMIT 0"); let mut stmt = self @@ -121,16 +262,21 @@ impl SyncDbConnection, DuckDBPar .boxed() .context(super::UnableToGetSchemaSnafu)?; + Self::detach(&self.conn, &self.attachments)?; + let schema = result.get_schema(); let params = params.iter().map(dyn_clone::clone).collect::>(); - let conn = self.conn.try_clone()?; + let conn = self.conn.try_clone()?; // try_clone creates a new connection to the same database + // this creates a new connection session, requiring resetting the ATTACHments and search_path let sql = sql.to_string(); let cloned_schema = schema.clone(); + let attachments = self.attachments.clone(); let join_handle = tokio::task::spawn_blocking(move || { + Self::attach(&conn, &attachments)?; // this attach could happen when we clone the connection, but we can't detach after the thread closes because the connection isn't thread safe let mut stmt = conn.prepare(&sql).context(DuckDBSnafu)?; let params: &[&dyn ToSql] = ¶ms .iter() @@ -143,6 +289,7 @@ impl SyncDbConnection, DuckDBPar blocking_channel_send(&batch_tx, i)?; } + Self::detach(&conn, &attachments)?; Ok::<_, Box>(()) }); diff --git a/src/sql/db_connection_pool/duckdbpool.rs b/src/sql/db_connection_pool/duckdbpool.rs index be5ff94..e7964c4 100644 --- a/src/sql/db_connection_pool/duckdbpool.rs +++ b/src/sql/db_connection_pool/duckdbpool.rs @@ -3,7 +3,10 @@ use duckdb::{vtab::arrow::ArrowVTab, AccessMode, DuckdbConnectionManager}; use snafu::{prelude::*, ResultExt}; use std::sync::Arc; -use super::{dbconnection::duckdbconn::DuckDBParameter, DbConnectionPool, Mode, Result}; +use super::{ + dbconnection::duckdbconn::{DuckDBAttachments, DuckDBParameter}, + DbConnectionPool, Mode, Result, +}; use crate::sql::db_connection_pool::{ dbconnection::{duckdbconn::DuckDbConnection, DbConnection, SyncDbConnection}, JoinPushDown, @@ -142,13 +145,33 @@ impl DuckDbConnectionPool { let pool = Arc::clone(&self.pool); let conn: r2d2::PooledConnection = pool.get().context(ConnectionPoolSnafu)?; - Ok(Box::new(DuckDbConnection::new(conn))) + + let attachments = self.get_attachments()?; + + Ok(Box::new( + DuckDbConnection::new(conn).with_attachments(attachments), + )) } #[must_use] pub fn mode(&self) -> Mode { self.mode } + + pub fn get_attachments(&self) -> Result>> { + if self.attached_databases.is_empty() { + Ok(None) + } else { + #[cfg(not(feature = "duckdb-federation"))] + return Ok(None); + + #[cfg(feature = "duckdb-federation")] + Ok(Some(Arc::new(DuckDBAttachments::new( + 
&extract_db_name(Arc::clone(&self.path))?, + &self.attached_databases, + )))) + } + } } #[async_trait] @@ -164,34 +187,11 @@ impl DbConnectionPool, DuckDBPar let conn: r2d2::PooledConnection = pool.get().context(ConnectionPoolSnafu)?; - #[cfg(feature = "duckdb-federation")] - if !self.attached_databases.is_empty() { - let mut db_ids = Vec::new(); - db_ids.push(extract_db_name(Arc::clone(&self.path))?); - - for (i, db) in self.attached_databases.iter().enumerate() { - // check the db file exists - std::fs::metadata(db.as_ref()).context(UnableToAttachDatabaseSnafu { - path: Arc::clone(db), - })?; - - let db_id = format!("attachment_{i}"); - conn.execute( - &format!( - "ATTACH IF NOT EXISTS '{db}' AS {} (READ_ONLY)", - db_id.clone() - ), - [], - ) - .context(DuckDBSnafu)?; - db_ids.push(db_id); - } - - conn.execute(&format!("SET search_path = \"{}\"", db_ids.join(",")), []) - .context(DuckDBSnafu)?; - } + let attachments = self.get_attachments()?; - Ok(Box::new(DuckDbConnection::new(conn))) + Ok(Box::new( + DuckDbConnection::new(conn).with_attachments(attachments), + )) } fn join_push_down(&self) -> JoinPushDown { From 9cb103c45d2f71cb2294d272eac55240ecd7efe7 Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Fri, 13 Sep 2024 23:47:57 -0700 Subject: [PATCH 10/40] Correctly handle mysql timestamp() and datetime() types (#96) * Correctly handle mysql timestamp() and datetime() types * Restructure MySQL test, add test for timestamp() types * Include test for datetime types --- src/sql/arrow_sql_gen/mysql.rs | 64 +--- .../dbconnection/mysqlconn.rs | 6 +- tests/mysql/mod.rs | 273 +++++++++++++++--- 3 files changed, 255 insertions(+), 88 deletions(-) diff --git a/src/sql/arrow_sql_gen/mysql.rs b/src/sql/arrow_sql_gen/mysql.rs index 1a7682e..4e54460 100644 --- a/src/sql/arrow_sql_gen/mysql.rs +++ b/src/sql/arrow_sql_gen/mysql.rs @@ -4,7 +4,7 @@ use arrow::{ ArrayBuilder, ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, LargeStringBuilder, NullBuilder, RecordBatch, RecordBatchOptions, Time64NanosecondBuilder, - TimestampMillisecondBuilder, UInt64Builder, + TimestampMicrosecondBuilder, UInt64Builder, }, datatypes::{DataType, Date32Type, Field, Schema, SchemaRef, TimeUnit}, }; @@ -14,6 +14,7 @@ use chrono::{NaiveDate, NaiveTime, Timelike}; use mysql_async::{consts::ColumnFlags, consts::ColumnType, FromValueError, Row, Value}; use snafu::{ResultExt, Snafu}; use std::{convert, sync::Arc}; +use time::PrimitiveDateTime; #[derive(Debug, Snafu)] pub enum Error { @@ -358,61 +359,25 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option) -> Resu }; let Some(builder) = builder .as_any_mut() - .downcast_mut::() + .downcast_mut::() else { return FailedToDowncastBuilderSnafu { mysql_type: format!("{mysql_type:?}"), } .fail(); }; - let v = handle_null_error(row.get_opt::(i).transpose()).context( - FailedToGetRowValueSnafu { + let v = + handle_null_error(row.get_opt::(i).transpose()) + .context(FailedToGetRowValueSnafu { mysql_type: column_type, - }, - )?; + })?; match v { Some(v) => { - let timestamp = match v { - Value::Date(year, month, day, hour, minute, second, micros) => { - let timestamp = chrono::NaiveDate::from_ymd_opt( - i32::from(year), - u32::from(month), - u32::from(day), - ) - .unwrap_or_default() - .and_hms_micro_opt( - u32::from(hour), - u32::from(minute), - u32::from(second), - micros, - ) - .unwrap_or_default() - .and_utc(); - timestamp.timestamp() * 
1000 - } - Value::Time(is_neg, days, hours, minutes, seconds, micros) => { - let naive_time = chrono::NaiveTime::from_hms_micro_opt( - u32::from(hours), - u32::from(minutes), - u32::from(seconds), - micros, - ) - .unwrap_or_default(); - - let time: i64 = naive_time.num_seconds_from_midnight().into(); - - let timestamp = i64::from(days) * 24 * 60 * 60 + time; - - if is_neg { - -timestamp - } else { - timestamp - } - } - _ => 0, - }; - builder.append_value(timestamp); + #[allow(clippy::cast_possible_truncation)] + let timestamp_micros = + (v.assume_utc().unix_timestamp_nanos() / 1_000) as i64; + builder.append_value(timestamp_micros); } None => builder.append_null(), } @@ -450,9 +415,9 @@ pub fn map_column_to_data_type( ColumnType::MYSQL_TYPE_DOUBLE => Some(DataType::Float64), // Decimal precision must be a value between 0x00 - 0x51, so it's safe to unwrap_or_default here ColumnType::MYSQL_TYPE_DECIMAL | ColumnType::MYSQL_TYPE_NEWDECIMAL => Some(DataType::Decimal128(column_decimal_precision.unwrap_or_default(), column_decimal_scale.unwrap_or_default())), - ColumnType::MYSQL_TYPE_TIMESTAMP | ColumnType::MYSQL_TYPE_DATETIME | ColumnType::MYSQL_TYPE_DATETIME2 => { - Some(DataType::Timestamp(TimeUnit::Millisecond, None)) - } + ColumnType::MYSQL_TYPE_TIMESTAMP | ColumnType::MYSQL_TYPE_DATETIME => { + Some(DataType::Timestamp(TimeUnit::Microsecond, None)) + }, ColumnType::MYSQL_TYPE_DATE => Some(DataType::Date32), ColumnType::MYSQL_TYPE_TIME => { Some(DataType::Time64(TimeUnit::Nanosecond)) @@ -480,6 +445,7 @@ pub fn map_column_to_data_type( // Unsupported yet | ColumnType::MYSQL_TYPE_UNKNOWN | ColumnType::MYSQL_TYPE_TIMESTAMP2 + | ColumnType::MYSQL_TYPE_DATETIME2 | ColumnType::MYSQL_TYPE_TIME2 | ColumnType::MYSQL_TYPE_GEOMETRY => { unimplemented!("Unsupported column type {:?}", column_type) diff --git a/src/sql/db_connection_pool/dbconnection/mysqlconn.rs b/src/sql/db_connection_pool/dbconnection/mysqlconn.rs index dc8b7f4..74f149d 100644 --- a/src/sql/db_connection_pool/dbconnection/mysqlconn.rs +++ b/src/sql/db_connection_pool/dbconnection/mysqlconn.rs @@ -203,10 +203,8 @@ fn map_str_type_to_column_type(data_type: &str) -> Result { _ if data_type.starts_with("float") => ColumnType::MYSQL_TYPE_FLOAT, _ if data_type.starts_with("double") => ColumnType::MYSQL_TYPE_DOUBLE, _ if data_type.eq("null") => ColumnType::MYSQL_TYPE_NULL, - _ if data_type.eq("timestamp2") => ColumnType::MYSQL_TYPE_TIMESTAMP2, - _ if data_type.eq("timestamp") => ColumnType::MYSQL_TYPE_TIMESTAMP, - _ if data_type.eq("datetime2") => ColumnType::MYSQL_TYPE_DATETIME2, - _ if data_type.eq("datetime") => ColumnType::MYSQL_TYPE_DATETIME, + _ if data_type.starts_with("timestamp") => ColumnType::MYSQL_TYPE_TIMESTAMP, + _ if data_type.starts_with("datetime") => ColumnType::MYSQL_TYPE_DATETIME, _ if data_type.eq("time2") => ColumnType::MYSQL_TYPE_TIME2, _ if data_type.eq("time") => ColumnType::MYSQL_TYPE_TIME, _ if data_type.eq("date") => ColumnType::MYSQL_TYPE_DATE, diff --git a/tests/mysql/mod.rs b/tests/mysql/mod.rs index 32950dd..412f2c9 100644 --- a/tests/mysql/mod.rs +++ b/tests/mysql/mod.rs @@ -5,7 +5,10 @@ use datafusion_table_providers::{ use rstest::rstest; use std::sync::Arc; -use arrow::array::*; +use arrow::{ + array::*, + datatypes::{DataType, Field, Schema, TimeUnit}, +}; use datafusion_table_providers::sql::db_connection_pool::dbconnection::AsyncDbConnection; @@ -13,11 +16,235 @@ use crate::docker::RunningContainer; mod common; -async fn arrow_mysql_one_way(port: usize) { - let table_name = "test_table"; +async fn 
test_mysql_decimal_types(port: usize) { + let table_name = "decimal_table"; + let create_table_stmt = " + CREATE TABLE IF NOT EXISTS decimal_table (decimal_col DECIMAL(10, 2)); + "; + let insert_table_stmt = " + INSERT INTO decimal_table (decimal_col) VALUES (NULL), (12); + "; + + let schema = Arc::new(Schema::new(vec![Field::new( + "decimal_col", + DataType::Decimal128(10, 2), + true, + )])); + + let expected_record = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new( + Decimal128Array::from(vec![None, Some(i128::from(1200))]) + .with_precision_and_scale(10, 2) + .unwrap(), + )], + ) + .expect("Failed to created arrow record batch"); + + let decimal_record = arrow_mysql_one_way( + port, + "decimal_table", + create_table_stmt, + insert_table_stmt, + expected_record, + ) + .await; +} + +async fn test_mysql_timestamp_types(port: usize) { + let create_table_stmt = " + CREATE TABLE timestamp_table ( + timestamp_no_fraction TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, + timestamp_one_fraction TIMESTAMP(1), + timestamp_two_fraction TIMESTAMP(2), + timestamp_three_fraction TIMESTAMP(3), + timestamp_four_fraction TIMESTAMP(4), + timestamp_five_fraction TIMESTAMP(5), + timestamp_six_fraction TIMESTAMP(6) +); + "; + let insert_table_stmt = " +INSERT INTO timestamp_table ( + timestamp_no_fraction, + timestamp_one_fraction, + timestamp_two_fraction, + timestamp_three_fraction, + timestamp_four_fraction, + timestamp_five_fraction, + timestamp_six_fraction +) +VALUES +( + '2024-09-12 10:00:00', + '2024-09-12 10:00:00.1', + '2024-09-12 10:00:00.12', + '2024-09-12 10:00:00.123', + '2024-09-12 10:00:00.1234', + '2024-09-12 10:00:00.12345', + '2024-09-12 10:00:00.123456' +); + "; + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "timestamp_no_fraction", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + Field::new( + "timestamp_one_fraction", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + Field::new( + "timestamp_two_fraction", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + Field::new( + "timestamp_three_fraction", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + Field::new( + "timestamp_four_fraction", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + Field::new( + "timestamp_five_fraction", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + Field::new( + "timestamp_six_fraction", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + ])); + + let expected_record = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(TimestampMicrosecondArray::from(vec![1_726_135_200_000_000])), + Arc::new(TimestampMicrosecondArray::from(vec![1_726_135_200_100_000])), + Arc::new(TimestampMicrosecondArray::from(vec![1_726_135_200_120_000])), + Arc::new(TimestampMicrosecondArray::from(vec![1_726_135_200_123_000])), + Arc::new(TimestampMicrosecondArray::from(vec![1_726_135_200_123_400])), + Arc::new(TimestampMicrosecondArray::from(vec![1_726_135_200_123_450])), + Arc::new(TimestampMicrosecondArray::from(vec![1_726_135_200_123_456])), + ], + ) + .expect("Failed to created arrow record batch"); + + arrow_mysql_one_way( + port, + "timestamp_table", + create_table_stmt, + insert_table_stmt, + expected_record, + ) + .await; +} + +async fn test_mysql_datetime_types(port: usize) { + let create_table_stmt = " +CREATE TABLE datetime_table ( + dt0 DATETIME(0), + dt1 DATETIME(1), + dt2 DATETIME(2), + dt3 DATETIME(3), + dt4 DATETIME(4), + dt5 DATETIME(5), + dt6 DATETIME(6) +); + + "; + let 
insert_table_stmt = " +INSERT INTO datetime_table (dt0, dt1, dt2, dt3, dt4, dt5, dt6) +VALUES ( + '2024-09-12 10:00:00', + '2024-09-12 10:00:00.1', + '2024-09-12 10:00:00.12', + '2024-09-12 10:00:00.123', + '2024-09-12 10:00:00.1234', + '2024-09-12 10:00:00.12345', + '2024-09-12 10:00:00.123456' +); + "; + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "dt0", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + Field::new( + "dt1", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + Field::new( + "dt2", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + Field::new( + "dt3", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + Field::new( + "dt4", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + Field::new( + "dt5", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + Field::new( + "dt6", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + ])); + + let expected_record = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(TimestampMicrosecondArray::from(vec![1_726_135_200_000_000])), + Arc::new(TimestampMicrosecondArray::from(vec![1_726_135_200_100_000])), + Arc::new(TimestampMicrosecondArray::from(vec![1_726_135_200_120_000])), + Arc::new(TimestampMicrosecondArray::from(vec![1_726_135_200_123_000])), + Arc::new(TimestampMicrosecondArray::from(vec![1_726_135_200_123_400])), + Arc::new(TimestampMicrosecondArray::from(vec![1_726_135_200_123_450])), + Arc::new(TimestampMicrosecondArray::from(vec![1_726_135_200_123_456])), + ], + ) + .expect("Failed to created arrow record batch"); + + arrow_mysql_one_way( + port, + "datetime_table", + create_table_stmt, + insert_table_stmt, + expected_record, + ) + .await; +} + +async fn arrow_mysql_one_way( + port: usize, + table_name: &str, + create_table_stmt: &str, + insert_table_stmt: &str, + expected_record: RecordBatch, +) -> Vec { tracing::debug!("Running tests on {table_name}"); - let ctx = SessionContext::new(); + let ctx = SessionContext::new(); let pool = common::get_mysql_connection_pool(port) .await .expect("MySQL connection pool should be created"); @@ -27,14 +254,6 @@ async fn arrow_mysql_one_way(port: usize) { .await .expect("Connection should be established"); - // Create mysql table with decimal columns that contains null value - let create_table_stmt = " - CREATE TABLE IF NOT EXISTS test_table (id INT AUTO_INCREMENT PRIMARY KEY, salary DECIMAL(10, 2)); - "; - let insert_table_stmt = " - INSERT INTO test_table (salary) VALUES (NULL), (12); - "; - // Create table and insert data into mysql test_table let _ = db_conn .execute(create_table_stmt, &[]) @@ -67,29 +286,10 @@ async fn arrow_mysql_one_way(port: usize) { record_batch[0].columns() ); - let int32_array = Int32Array::from(vec![1, 2]); - let decimal128_array = Decimal128Array::from(vec![None, Some(i128::from(1200))]) - .with_precision_and_scale(10, 2) - .unwrap(); - - // Check results assert_eq!(record_batch.len(), 1); - assert_eq!( - record_batch[0] - .column(0) - .as_any() - .downcast_ref::() - .unwrap(), - &int32_array - ); - assert_eq!( - record_batch[0] - .column(1) - .as_any() - .downcast_ref::() - .unwrap(), - &decimal128_array - ); + assert_eq!(record_batch[0], expected_record); + + return record_batch; } async fn start_mysql_container(port: usize) -> RunningContainer { @@ -108,6 +308,9 @@ async fn test_mysql_arrow_oneway() { let port = crate::get_random_port(); let mysql_container = start_mysql_container(port).await; - arrow_mysql_one_way(port).await; + 
test_mysql_decimal_types(port).await; + test_mysql_timestamp_types(port).await; + test_mysql_datetime_types(port).await; + mysql_container.remove().await.expect("container to stop"); } From 7017b03d80c6c498cd7b4061c637e2df17475723 Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Sat, 14 Sep 2024 17:54:49 -0700 Subject: [PATCH 11/40] Postgres enum support (#100) * Postgres enum support * Add enum test as part of integration test * update * Remove the duplicate function * fix --------- Co-authored-by: Phillip LeBlanc --- src/sql/arrow_sql_gen/arrow.rs | 15 +++-- src/sql/arrow_sql_gen/postgres.rs | 56 +++++++++++++++++- tests/arrow_record_batch_gen/mod.rs | 22 ++++++- tests/postgres/common.rs | 12 ++++ tests/postgres/mod.rs | 90 +++++++++++++++++++++++++++++ 5 files changed, 187 insertions(+), 8 deletions(-) diff --git a/src/sql/arrow_sql_gen/arrow.rs b/src/sql/arrow_sql_gen/arrow.rs index 3577b3c..eb2fbd0 100644 --- a/src/sql/arrow_sql_gen/arrow.rs +++ b/src/sql/arrow_sql_gen/arrow.rs @@ -1,12 +1,13 @@ use arrow::{ array::{ - ArrayBuilder, BinaryBuilder, BooleanBuilder, Date32Builder, Date64Builder, + types::Int8Type, ArrayBuilder, BinaryBuilder, BooleanBuilder, Date32Builder, Date64Builder, Decimal128Builder, Decimal256Builder, FixedSizeBinaryBuilder, FixedSizeListBuilder, Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, IntervalMonthDayNanoBuilder, LargeBinaryBuilder, LargeStringBuilder, ListBuilder, - NullBuilder, StringBuilder, StructBuilder, Time64NanosecondBuilder, - TimestampMicrosecondBuilder, TimestampMillisecondBuilder, TimestampNanosecondBuilder, - TimestampSecondBuilder, UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder, + NullBuilder, StringBuilder, StringDictionaryBuilder, StructBuilder, + Time64NanosecondBuilder, TimestampMicrosecondBuilder, TimestampMillisecondBuilder, + TimestampNanosecondBuilder, TimestampSecondBuilder, UInt16Builder, UInt32Builder, + UInt64Builder, UInt8Builder, }, datatypes::{DataType, TimeUnit}, }; @@ -62,6 +63,12 @@ pub fn map_data_type_to_array_builder(data_type: &DataType) -> Box match (&**key_type, &**value_type) { + (DataType::Int8, DataType::Utf8) => { + Box::new(StringDictionaryBuilder::::new()) + } + _ => unimplemented!("Unimplemented dictionary type"), + }, DataType::Date32 => Box::new(Date32Builder::new()), DataType::Date64 => Box::new(Date64Builder::new()), // For time format, always use nanosecond diff --git a/src/sql/arrow_sql_gen/postgres.rs b/src/sql/arrow_sql_gen/postgres.rs index df07478..9781590 100644 --- a/src/sql/arrow_sql_gen/postgres.rs +++ b/src/sql/arrow_sql_gen/postgres.rs @@ -1,4 +1,5 @@ use std::convert; +use std::io::Read; use std::sync::Arc; use crate::sql::arrow_sql_gen::arrow::map_data_type_to_array_builder_optional; @@ -7,11 +8,11 @@ use arrow::array::{ ArrayBuilder, ArrayRef, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Builder, FixedSizeListBuilder, Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, IntervalMonthDayNanoBuilder, LargeBinaryBuilder, LargeStringBuilder, ListBuilder, - RecordBatch, RecordBatchOptions, StringBuilder, StructBuilder, Time64NanosecondBuilder, - TimestampNanosecondBuilder, UInt32Builder, + RecordBatch, RecordBatchOptions, StringBuilder, StringDictionaryBuilder, StructBuilder, + Time64NanosecondBuilder, TimestampNanosecondBuilder, UInt32Builder, }; use arrow::datatypes::{ - DataType, Date32Type, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, + 
DataType, Date32Type, Field, Int8Type, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, }; use bigdecimal::num_bigint::BigInt; use bigdecimal::num_bigint::Sign; @@ -745,6 +746,31 @@ pub fn rows_to_arrow(rows: &[Row]) -> Result { ); } } + Kind::Enum(_) => { + let Some(builder) = builder else { + return NoBuilderForIndexSnafu { index: i }.fail(); + }; + let Some(builder) = builder + .as_any_mut() + .downcast_mut::>() + else { + return FailedToDowncastBuilderSnafu { + postgres_type: format!("{postgres_type}"), + } + .fail(); + }; + + let v = row.try_get::>(i).context( + FailedToGetRowValueSnafu { + pg_type: postgres_type.clone(), + }, + )?; + + match v { + Some(v) => builder.append_value(v.enum_value), + None => builder.append_null(), + } + } _ => { unimplemented!("Unsupported type {:?} for column index {i}", postgres_type,) } @@ -850,6 +876,10 @@ fn map_column_type_to_data_type(column_type: &Type) -> Option { } Some(DataType::Struct(arrow_fields.into())) } + Kind::Enum(_) => Some(DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Utf8), + )), _ => unimplemented!("Unsupported column type {:?}", column_type), }, } @@ -1009,6 +1039,26 @@ impl<'a> FromSql<'a> for MoneyFromSql { } } +struct EnumValueFromSql { + enum_value: String, +} + +impl<'a> FromSql<'a> for EnumValueFromSql { + fn from_sql( + _ty: &Type, + raw: &'a [u8], + ) -> Result> { + let mut cursor = std::io::Cursor::new(raw); + let mut enum_value = String::new(); + cursor.read_to_string(&mut enum_value)?; + Ok(EnumValueFromSql { enum_value }) + } + + fn accepts(ty: &Type) -> bool { + matches!(*ty.kind(), Kind::Enum(_)) + } +} + pub struct GeometryFromSql<'a> { wkb: &'a [u8], } diff --git a/tests/arrow_record_batch_gen/mod.rs b/tests/arrow_record_batch_gen/mod.rs index 9e39610..8e36793 100644 --- a/tests/arrow_record_batch_gen/mod.rs +++ b/tests/arrow_record_batch_gen/mod.rs @@ -2,7 +2,7 @@ use arrow::array::RecordBatch; use arrow::{ array::*, datatypes::{ - i256, DataType, Date32Type, Date64Type, Field, Fields, IntervalDayTime, + i256, DataType, Date32Type, Date64Type, Field, Fields, Int8Type, IntervalDayTime, IntervalMonthDayNano, IntervalUnit, IntervalYearMonthType, Schema, SchemaRef, TimeUnit, }, }; @@ -535,6 +535,26 @@ pub(crate) fn get_arrow_bytea_array_record_batch() -> (RecordBatch, SchemaRef) { (record_batch, schema) } +// DICTIONARY_ARRAY +pub(crate) fn get_arrow_dictionary_array_record_batch() -> (RecordBatch, SchemaRef) { + let mut builder = StringDictionaryBuilder::::new(); + builder.append_value("happy"); + builder.append_value("sad"); + builder.append_value("neutral"); + let array: DictionaryArray = builder.finish(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "mood_status", + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)), + true, + )])); + + let record_batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]) + .expect("Failed to created arrow dictionary array record batch"); + + (record_batch, schema) +} + // Custom Test Case for Sqlite <-> Arrow Decimal Roundtrip // SQLite supports up to 16 precision for decimal numbers through REAL type, conforming to IEEE 754 Binary-64 format - https://www.sqlite.org/floatingpoint.html pub(crate) fn get_sqlite_arrow_decimal_record_batch() -> (RecordBatch, SchemaRef) { diff --git a/tests/postgres/common.rs b/tests/postgres/common.rs index 0fc0046..35617bf 100644 --- a/tests/postgres/common.rs +++ b/tests/postgres/common.rs @@ -1,4 +1,7 @@ use bollard::secret::HealthConfig; +#[cfg(feature = "postgres")] 
+use datafusion_table_providers::sql::db_connection_pool::postgrespool::PostgresConnectionPool; +use datafusion_table_providers::util::secrets::to_secret_map; use std::collections::HashMap; use tracing::instrument; @@ -57,3 +60,12 @@ pub(super) async fn start_postgres_docker_container( tokio::time::sleep(std::time::Duration::from_millis(5000)).await; Ok(running_container) } + +#[instrument] +pub(super) async fn get_postgres_connection_pool( + port: usize, +) -> Result { + let pool = PostgresConnectionPool::new(to_secret_map(get_pg_params(port))).await?; + + Ok(pool) +} diff --git a/tests/postgres/mod.rs b/tests/postgres/mod.rs index a9ddadf..dfe12f6 100644 --- a/tests/postgres/mod.rs +++ b/tests/postgres/mod.rs @@ -137,3 +137,93 @@ async fn test_arrow_postgres_roundtrip( ) .await; } + +#[rstest] +#[test_log::test(tokio::test)] +async fn test_arrow_postgres_one_way(container_manager: &Mutex) { + let mut container_manager = container_manager.lock().await; + if !container_manager.claimed { + container_manager.claimed = true; + start_container(&container_manager).await; + } + + let extra_stmt = Some("CREATE TYPE mood AS ENUM ('happy', 'sad', 'neutral');"); + let create_table_stmt = " + CREATE TABLE person_mood ( + mood_status mood NOT NULL + );"; + + let insert_table_stmt = " + INSERT INTO person_mood (mood_status) VALUES ('happy'), ('sad'), ('neutral'); + "; + + let (expected_record, _) = get_arrow_dictionary_array_record_batch(); + + arrow_postgres_one_way( + container_manager.port, + "person_mood", + create_table_stmt, + insert_table_stmt, + extra_stmt, + expected_record, + ) + .await; +} + +async fn arrow_postgres_one_way( + port: usize, + table_name: &str, + create_table_stmt: &str, + insert_table_stmt: &str, + extra_stmt: Option<&str>, + expected_record: RecordBatch, +) { + tracing::debug!("Running tests on {table_name}"); + let ctx = SessionContext::new(); + + let pool = common::get_postgres_connection_pool(port) + .await + .expect("Postgres connection pool should be created"); + + let db_conn = pool + .connect_direct() + .await + .expect("Connection should be established"); + + if let Some(extra_stmt) = extra_stmt { + let _ = db_conn + .conn + .execute(extra_stmt, &[]) + .await + .expect("Statement should be created"); + } + + let _ = db_conn + .conn + .execute(create_table_stmt, &[]) + .await + .expect("Postgres table should be created"); + + let _ = db_conn + .conn + .execute(insert_table_stmt, &[]) + .await + .expect("Postgres table data should be inserted"); + + // Register datafusion table, test row -> arrow conversion + let sqltable_pool: Arc = Arc::new(pool); + let table = SqlTable::new("postgres", &sqltable_pool, table_name, None) + .await + .expect("Table should be created"); + ctx.register_table(table_name, Arc::new(table)) + .expect("Table should be registered"); + let sql = format!("SELECT * FROM {table_name}"); + let df = ctx + .sql(&sql) + .await + .expect("DataFrame should be created from query"); + + let record_batch = df.collect().await.expect("RecordBatch should be collected"); + + assert_eq!(record_batch[0], expected_record); +} From 1bd4ed35114108b3c708dd8ad7f4315cd0c648ff Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Sat, 14 Sep 2024 22:55:06 -0700 Subject: [PATCH 12/40] Fix SQLite Invalid column type Real bug (#98) --- src/sql/arrow_sql_gen/sqlite.rs | 29 +++++++++++++++++-- .../dbconnection/sqliteconn.rs | 7 +++-- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/src/sql/arrow_sql_gen/sqlite.rs b/src/sql/arrow_sql_gen/sqlite.rs index 
9ec12dd..454af1d 100644 --- a/src/sql/arrow_sql_gen/sqlite.rs +++ b/src/sql/arrow_sql_gen/sqlite.rs @@ -29,6 +29,7 @@ use arrow::array::StringBuilder; use arrow::datatypes::DataType; use arrow::datatypes::Field; use arrow::datatypes::Schema; +use arrow::datatypes::SchemaRef; use rusqlite::types::Type; use rusqlite::Row; use rusqlite::Rows; @@ -60,7 +61,7 @@ pub type Result = std::result::Result; /// # Errors /// /// Returns an error if there is a failure in converting the rows to a `RecordBatch`. -pub fn rows_to_arrow(mut rows: Rows, num_cols: usize) -> Result { +pub fn rows_to_arrow(mut rows: Rows, num_cols: usize, projected_schema: Option) -> Result { let mut arrow_fields: Vec = Vec::new(); let mut arrow_columns_builders: Vec> = Vec::new(); let mut sqlite_types: Vec = Vec::new(); @@ -68,7 +69,7 @@ pub fn rows_to_arrow(mut rows: Rows, num_cols: usize) -> Result { if let Ok(Some(row)) = rows.next() { for i in 0..num_cols { - let column_type = row + let mut column_type = row .get_ref(i) .context(FailedToExtractRowValueSnafu)? .data_type(); @@ -77,6 +78,30 @@ pub fn rows_to_arrow(mut rows: Rows, num_cols: usize) -> Result { .column_name(i) .context(FailedToExtractColumnNameSnafu)? .to_string(); + + // SQLite can store floating point values without a fractional component as integers. + // Therefore, we need to verify if the column is actually a floating point type + // by examining the projected schema. + // Note: The same column may contain both integer and floating point values. + // Reading values as Float is safe even if the value is stored as an integer. + // Refer to the rusqlite type handling documentation for more details: + // https://github.com/rusqlite/rusqlite/blob/95680270eca6f405fb51f5fbe6a214aac5fdce58/src/types/mod.rs#L21C1-L22C75 + // + // `REAL` to integer: always returns an [`Error::InvalidColumnType`](crate::Error::InvalidColumnType) error. + // `INTEGER` to float: casts using `as` operator. Never fails. + // `REAL` to float: casts using `as` operator. Never fails. + + if column_type == Type::Integer { + if let Some(projected_schema) = projected_schema.as_ref() { + match projected_schema.fields[i].data_type() { + DataType::Decimal128(..) 
| DataType::Float16 | DataType::Float32 | DataType::Float64 => { + column_type = Type::Real; + } + _ => {} + } + } + } + let data_type = map_column_type_to_data_type(column_type); arrow_fields.push(Field::new(column_name, data_type.clone(), true)); diff --git a/src/sql/db_connection_pool/dbconnection/sqliteconn.rs b/src/sql/db_connection_pool/dbconnection/sqliteconn.rs index 5569a71..ef2d783 100644 --- a/src/sql/db_connection_pool/dbconnection/sqliteconn.rs +++ b/src/sql/db_connection_pool/dbconnection/sqliteconn.rs @@ -63,7 +63,7 @@ impl AsyncDbConnection for SqliteConnec let mut stmt = conn.prepare(&format!("SELECT * FROM {table_reference} LIMIT 1"))?; let column_count = stmt.column_count(); let rows = stmt.query([])?; - let rec = rows_to_arrow(rows, column_count) + let rec = rows_to_arrow(rows, column_count, None) .context(ConversionSnafu) .map_err(to_tokio_rusqlite_error)?; let schema = rec.schema(); @@ -80,7 +80,7 @@ impl AsyncDbConnection for SqliteConnec &self, sql: &str, params: &[&'static (dyn ToSql + Sync)], - _projected_schema: Option, + projected_schema: Option, ) -> Result { let sql = sql.to_string(); let params = params.to_vec(); @@ -94,7 +94,8 @@ impl AsyncDbConnection for SqliteConnec } let column_count = stmt.column_count(); let rows = stmt.raw_query(); - let rec = rows_to_arrow(rows, column_count) + + let rec = rows_to_arrow(rows, column_count, projected_schema) .context(ConversionSnafu) .map_err(to_tokio_rusqlite_error)?; Ok(rec) From 2a92a77ffb7b3d52f496191074d3abdd4b9d18f2 Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Sat, 14 Sep 2024 23:02:22 -0700 Subject: [PATCH 13/40] Prevent SQLite from writing incomplete data on errors (#101) --- src/sqlite/write.rs | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/sqlite/write.rs b/src/sqlite/write.rs index 6f14fd1..27c911a 100644 --- a/src/sqlite/write.rs +++ b/src/sqlite/write.rs @@ -121,6 +121,11 @@ impl DataSink for SqliteDataSink { _context: &Arc, ) -> datafusion::common::Result { let (batch_tx, mut batch_rx) = tokio::sync::mpsc::channel::(1); + + // Since the main task/stream can be dropped or fail, we use a oneshot channel to signal that all data is received and we should commit the transaction + let (notify_commit_transaction, mut on_commit_transaction) = + tokio::sync::oneshot::channel(); + let mut db_conn = self.sqlite.connect().await.map_err(to_datafusion_error)?; let sqlite_conn = Sqlite::sqlite_conn(&mut db_conn).map_err(to_datafusion_error)?; @@ -144,6 +149,15 @@ impl DataSink for SqliteDataSink { })?; } + if notify_commit_transaction.send(()).is_err() { + return Err(DataFusionError::Execution( + "Unable to send message to commit transaction to SQLite writer.".to_string(), + )); + }; + + // Drop the sender to signal the receiver that no more data is coming + drop(batch_tx); + Ok::<_, DataFusionError>(num_rows) }); @@ -165,6 +179,12 @@ impl DataSink for SqliteDataSink { } } + if on_commit_transaction.try_recv().is_err() { + return Err(tokio_rusqlite::Error::Other( + "No message to commit transaction has been received.".into(), + )); + } + transaction.commit()?; Ok(()) From a5d60a6dfceff642cab43bf2ef07ebe4470d8615 Mon Sep 17 00:00:00 2001 From: hozan23 Date: Thu, 19 Sep 2024 11:27:24 +0200 Subject: [PATCH 14/40] cargo fmt --- src/sql/arrow_sql_gen/sqlite.rs | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/sql/arrow_sql_gen/sqlite.rs b/src/sql/arrow_sql_gen/sqlite.rs index 454af1d..6466305 100644 --- a/src/sql/arrow_sql_gen/sqlite.rs +++ 
b/src/sql/arrow_sql_gen/sqlite.rs @@ -61,7 +61,11 @@ pub type Result = std::result::Result; /// # Errors /// /// Returns an error if there is a failure in converting the rows to a `RecordBatch`. -pub fn rows_to_arrow(mut rows: Rows, num_cols: usize, projected_schema: Option) -> Result { +pub fn rows_to_arrow( + mut rows: Rows, + num_cols: usize, + projected_schema: Option, +) -> Result { let mut arrow_fields: Vec = Vec::new(); let mut arrow_columns_builders: Vec> = Vec::new(); let mut sqlite_types: Vec = Vec::new(); @@ -80,9 +84,9 @@ pub fn rows_to_arrow(mut rows: Rows, num_cols: usize, projected_schema: Option { + DataType::Decimal128(..) + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 => { column_type = Type::Real; } _ => {} From 4d0b67bb4c3ce86661bf6044f5053bd9f03d78e9 Mon Sep 17 00:00:00 2001 From: hozan23 Date: Thu, 19 Sep 2024 11:40:06 +0200 Subject: [PATCH 15/40] minor fix to the integration postgres test --- tests/postgres/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/postgres/mod.rs b/tests/postgres/mod.rs index dfe12f6..cc925a3 100644 --- a/tests/postgres/mod.rs +++ b/tests/postgres/mod.rs @@ -7,7 +7,9 @@ use datafusion::logical_expr::CreateExternalTable; use datafusion::physical_plan::collect; use datafusion::physical_plan::memory::MemoryExec; use datafusion_federation::schema_cast::record_convert::try_cast_to; +use datafusion_table_providers::postgres::DynPostgresConnectionPool; use datafusion_table_providers::postgres::PostgresTableProviderFactory; +use datafusion_table_providers::sql::sql_provider_datafusion::SqlTable; use rstest::{fixture, rstest}; use std::collections::HashMap; use std::sync::Arc; From 6e33170f28076457bbd643e4b9fc07ca2cd450ad Mon Sep 17 00:00:00 2001 From: hozan23 Date: Thu, 19 Sep 2024 14:12:53 +0200 Subject: [PATCH 16/40] use the latest version of sea-query crate --- Cargo.toml | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c1c886a..fe70bad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,14 +41,25 @@ mysql_async = { version = "0.34.1", features = ["native-tls-tls", "chrono"], opt prost = { version = "0.12" , optional = true } # pinned for arrow-flight compat r2d2 = { version = "0.8.10", optional = true } rusqlite = { version = "0.31.0", optional = true } -sea-query = { git = "https://github.com/spiceai/sea-query.git", rev = "213b6b876068f58159ebdd5852604a021afaebf9", features = ["backend-sqlite", "backend-postgres", "postgres-array", "with-rust_decimal", "with-bigdecimal", "with-time", "with-chrono"] } +sea-query = { version = "0.32.0-rc.1", features = [ + "backend-sqlite", + "backend-postgres", + "postgres-array", + "with-rust_decimal", + "with-bigdecimal", + "with-time", + "with-chrono"] } secrecy = "0.8.0" serde = { version = "1.0.209", optional = true } serde_json = "1.0.124" snafu = "0.8.3" time = "0.3.36" tokio = { version = "1.38.0", features = ["macros", "fs"] } -tokio-postgres = { version = "0.7.10", features = ["with-chrono-0_4", "with-uuid-1", "with-serde_json-1", "with-geo-types-0_7"], optional = true } +tokio-postgres = { version = "0.7.10", features = [ + "with-chrono-0_4", + "with-uuid-1", + "with-serde_json-1", + "with-geo-types-0_7"], optional = true } tracing = "0.1.40" uuid = { version = "1.9.1", optional = true } postgres-native-tls = { version = "0.5.0", optional = true } From 21b42b98e0162473e0855932c2730448f4a71dbc Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Mon, 16 Sep 2024 16:54:22 +0900 Subject: [PATCH 17/40] 
DuckDBTableProviderFactory keeps track of opened instances (#105) --- examples/duckdb_external_table.rs | 2 +- src/duckdb.rs | 61 +++++++++++++++++++++-------- src/sql/db_connection_pool/mod.rs | 20 ++++++++++ tests/arrow_record_batch_gen/mod.rs | 5 +-- tests/duckdb/mod.rs | 2 +- tests/mysql/common.rs | 6 +-- tests/mysql/mod.rs | 5 +-- tests/postgres/common.rs | 6 +-- tests/postgres/mod.rs | 8 ++-- 9 files changed, 77 insertions(+), 38 deletions(-) diff --git a/examples/duckdb_external_table.rs b/examples/duckdb_external_table.rs index bd619fb..6d8884f 100644 --- a/examples/duckdb_external_table.rs +++ b/examples/duckdb_external_table.rs @@ -12,7 +12,7 @@ use duckdb::AccessMode; /// DuckDB-backed tables can be created at runtime. #[tokio::main] async fn main() { - let duckdb = Arc::new(DuckDBTableProviderFactory::new().access_mode(AccessMode::ReadWrite)); + let duckdb = Arc::new(DuckDBTableProviderFactory::new(AccessMode::ReadWrite)); let runtime = Arc::new(RuntimeEnv::default()); let state = SessionStateBuilder::new() diff --git a/src/duckdb.rs b/src/duckdb.rs index c60712f..2d07091 100644 --- a/src/duckdb.rs +++ b/src/duckdb.rs @@ -7,7 +7,7 @@ use crate::sql::db_connection_pool::{ get_schema, DbConnection, }, duckdbpool::DuckDbConnectionPool, - DbConnectionPool, Mode, + DbConnectionPool, DbInstanceKey, Mode, }; use crate::sql::sql_provider_datafusion; use crate::util::{ @@ -31,6 +31,7 @@ use duckdb::{AccessMode, DuckdbConnectionManager, Transaction}; use itertools::Itertools; use snafu::prelude::*; use std::{cmp, collections::HashMap, sync::Arc}; +use tokio::sync::Mutex; use self::{creator::TableCreator, sql_table::DuckDBTable, write::DuckDBTableWriter}; @@ -123,6 +124,7 @@ type Result = std::result::Result; pub struct DuckDBTableProviderFactory { access_mode: AccessMode, + instances: Arc>>, } const DUCKDB_DB_PATH_PARAM: &str = "open"; @@ -131,9 +133,10 @@ const DUCKDB_ATTACH_DATABASES_PARAM: &str = "attach_databases"; impl DuckDBTableProviderFactory { #[must_use] - pub fn new() -> Self { + pub fn new(access_mode: AccessMode) -> Self { Self { - access_mode: AccessMode::ReadOnly, + access_mode, + instances: Arc::new(Mutex::new(HashMap::new())), } } @@ -150,12 +153,6 @@ impl DuckDBTableProviderFactory { .unwrap_or_default() } - #[must_use] - pub fn access_mode(mut self, access_mode: AccessMode) -> Self { - self.access_mode = access_mode; - self - } - #[must_use] pub fn duckdb_file_path(&self, name: &str, options: &mut HashMap) -> String { let options = util::remove_prefix_from_hashmap_keys(options.clone(), "duckdb_"); @@ -171,11 +168,40 @@ impl DuckDBTableProviderFactory { .cloned() .unwrap_or(default_filepath) } -} -impl Default for DuckDBTableProviderFactory { - fn default() -> Self { - Self::new() + pub async fn get_or_init_memory_instance(&self) -> Result { + let key = DbInstanceKey::memory(); + let mut instances = self.instances.lock().await; + + if let Some(instance) = instances.get(&key) { + return Ok(instance.clone()); + } + + let pool = DuckDbConnectionPool::new_memory().context(DbConnectionPoolSnafu)?; + + instances.insert(key, pool.clone()); + + Ok(pool) + } + + pub async fn get_or_init_file_instance( + &self, + db_path: impl Into>, + ) -> Result { + let db_path = db_path.into(); + let key = DbInstanceKey::file(Arc::clone(&db_path)); + let mut instances = self.instances.lock().await; + + if let Some(instance) = instances.get(&key) { + return Ok(instance.clone()); + } + + let pool = DuckDbConnectionPool::new_file(&db_path, &self.access_mode) + .context(DbConnectionPoolSnafu)?; + 
+ instances.insert(key, pool.clone()); + + Ok(pool) } } @@ -231,12 +257,13 @@ impl TableProviderFactory for DuckDBTableProviderFactory { // open duckdb at given path or create a new one let db_path = self.duckdb_file_path(&name, &mut options); - DuckDbConnectionPool::new_file(&db_path, &self.access_mode) - .context(DbConnectionPoolSnafu) + self.get_or_init_file_instance(db_path) + .await .map_err(to_datafusion_error)? } - Mode::Memory => DuckDbConnectionPool::new_memory() - .context(DbConnectionPoolSnafu) + Mode::Memory => self + .get_or_init_memory_instance() + .await .map_err(to_datafusion_error)?, }; diff --git a/src/sql/db_connection_pool/mod.rs b/src/sql/db_connection_pool/mod.rs index a642bfd..dff8300 100644 --- a/src/sql/db_connection_pool/mod.rs +++ b/src/sql/db_connection_pool/mod.rs @@ -1,5 +1,6 @@ use async_trait::async_trait; use dbconnection::DbConnection; +use std::sync::Arc; pub mod dbconnection; #[cfg(feature = "duckdb")] @@ -48,3 +49,22 @@ impl From<&str> for Mode { } } } + +/// A key that uniquely identifies a database instance. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum DbInstanceKey { + /// The database is a file on disk, with the given path. + File(Arc), + /// The database is in memory. + Memory, +} + +impl DbInstanceKey { + pub fn memory() -> Self { + DbInstanceKey::Memory + } + + pub fn file(path: Arc) -> Self { + DbInstanceKey::File(path) + } +} diff --git a/tests/arrow_record_batch_gen/mod.rs b/tests/arrow_record_batch_gen/mod.rs index 8e36793..61949f0 100644 --- a/tests/arrow_record_batch_gen/mod.rs +++ b/tests/arrow_record_batch_gen/mod.rs @@ -2,13 +2,12 @@ use arrow::array::RecordBatch; use arrow::{ array::*, datatypes::{ - i256, DataType, Date32Type, Date64Type, Field, Fields, Int8Type, IntervalDayTime, - IntervalMonthDayNano, IntervalUnit, IntervalYearMonthType, Schema, SchemaRef, TimeUnit, + i256, DataType, Date32Type, Date64Type, Field, Int8Type, IntervalDayTime, + IntervalMonthDayNano, IntervalUnit, Schema, SchemaRef, TimeUnit, }, }; use chrono::NaiveDate; use std::sync::Arc; -use types::IntervalDayTimeType; // Helper functions to create arrow record batches of different types diff --git a/tests/duckdb/mod.rs b/tests/duckdb/mod.rs index a10c9f6..adae065 100644 --- a/tests/duckdb/mod.rs +++ b/tests/duckdb/mod.rs @@ -18,7 +18,7 @@ async fn arrow_duckdb_round_trip( source_schema: SchemaRef, table_name: &str, ) { - let factory = DuckDBTableProviderFactory::new().access_mode(duckdb::AccessMode::ReadWrite); + let factory = DuckDBTableProviderFactory::new(duckdb::AccessMode::ReadWrite); let ctx = SessionContext::new(); let cmd = CreateExternalTable { schema: Arc::new(arrow_record.schema().to_dfschema().expect("to df schema")), diff --git a/tests/mysql/common.rs b/tests/mysql/common.rs index 986ca3b..bd88200 100644 --- a/tests/mysql/common.rs +++ b/tests/mysql/common.rs @@ -45,11 +45,7 @@ fn get_mysql_params(port: usize) -> HashMap { pub async fn start_mysql_docker_container(port: usize) -> Result { let container_name = format!("{MYSQL_DOCKER_CONTAINER}-{port}"); - let port = if let Ok(port) = port.try_into() { - port - } else { - 15432 - }; + let port = port.try_into().unwrap_or(15432); let mysql_docker_image = std::env::var("MYSQL_DOCKER_IMAGE") .unwrap_or_else(|_| format!("{}mysql:latest", container_registry())); diff --git a/tests/mysql/mod.rs b/tests/mysql/mod.rs index 412f2c9..34a3b3c 100644 --- a/tests/mysql/mod.rs +++ b/tests/mysql/mod.rs @@ -17,7 +17,6 @@ use crate::docker::RunningContainer; mod common; async fn test_mysql_decimal_types(port: 
usize) { - let table_name = "decimal_table"; let create_table_stmt = " CREATE TABLE IF NOT EXISTS decimal_table (decimal_col DECIMAL(10, 2)); "; @@ -41,7 +40,7 @@ async fn test_mysql_decimal_types(port: usize) { ) .expect("Failed to created arrow record batch"); - let decimal_record = arrow_mysql_one_way( + let _ = arrow_mysql_one_way( port, "decimal_table", create_table_stmt, @@ -289,7 +288,7 @@ async fn arrow_mysql_one_way( assert_eq!(record_batch.len(), 1); assert_eq!(record_batch[0], expected_record); - return record_batch; + record_batch } async fn start_mysql_container(port: usize) -> RunningContainer { diff --git a/tests/postgres/common.rs b/tests/postgres/common.rs index 35617bf..6cc9912 100644 --- a/tests/postgres/common.rs +++ b/tests/postgres/common.rs @@ -29,11 +29,7 @@ pub(super) async fn start_postgres_docker_container( port: usize, ) -> Result { let container_name = format!("{PG_DOCKER_CONTAINER}-{port}"); - let port = if let Ok(port) = port.try_into() { - port - } else { - 15432 - }; + let port = port.try_into().unwrap_or(15432); let pg_docker_image = std::env::var("PG_DOCKER_IMAGE") .unwrap_or_else(|_| format!("{}postgres:latest", container_registry())); diff --git a/tests/postgres/mod.rs b/tests/postgres/mod.rs index cc925a3..c5645f1 100644 --- a/tests/postgres/mod.rs +++ b/tests/postgres/mod.rs @@ -7,9 +7,11 @@ use datafusion::logical_expr::CreateExternalTable; use datafusion::physical_plan::collect; use datafusion::physical_plan::memory::MemoryExec; use datafusion_federation::schema_cast::record_convert::try_cast_to; -use datafusion_table_providers::postgres::DynPostgresConnectionPool; -use datafusion_table_providers::postgres::PostgresTableProviderFactory; -use datafusion_table_providers::sql::sql_provider_datafusion::SqlTable; + +use datafusion_table_providers::{ + postgres::{DynPostgresConnectionPool, PostgresTableProviderFactory}, + sql::sql_provider_datafusion::SqlTable, +}; use rstest::{fixture, rstest}; use std::collections::HashMap; use std::sync::Arc; From 2acb5c7564a3a8a2f3b5be3b85e1a7cdeff43c66 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Mon, 16 Sep 2024 21:44:12 +0900 Subject: [PATCH 18/40] Ignore CHECKPOINT errors (#107) --- src/duckdb.rs | 5 ----- src/duckdb/write.rs | 11 +++++++---- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/duckdb.rs b/src/duckdb.rs index 2d07091..3ff516b 100644 --- a/src/duckdb.rs +++ b/src/duckdb.rs @@ -90,11 +90,6 @@ pub enum Error { #[snafu(display("Unable to commit transaction: {source}"))] UnableToCommitTransaction { source: duckdb::Error }, - #[snafu(display("Unable to checkpoint duckdb: {source}"))] - UnableToCheckpoint { - source: Box, - }, - #[snafu(display("Unable to begin duckdb transaction: {source}"))] UnableToBeginTransaction { source: duckdb::Error }, diff --git a/src/duckdb/write.rs b/src/duckdb/write.rs index a7eb3ce..00d40cb 100644 --- a/src/duckdb/write.rs +++ b/src/duckdb/write.rs @@ -214,12 +214,15 @@ impl DataSink for DuckDBDataSink { match duckdb_write_handle.await { Ok(result) => { - // before returning the result, CHECKPOINT to flush the WAL to disk + // before returning the result, attempt to CHECKPOINT to flush the WAL to disk let mut conn = self.duckdb.connect_sync().map_err(to_datafusion_error)?; let conn = DuckDB::duckdb_conn(&mut conn).map_err(to_datafusion_error)?; - conn.execute("CHECKPOINT", &[]).map_err(|err| { - to_datafusion_error(super::Error::UnableToCheckpoint { source: err }) - })?; + + // This may fail if multiple transactions are active (i.e. 
actively writing data) + // we can ignore the error since it will be written once the last transaction finishes. + if let Err(e) = conn.execute("CHECKPOINT", &[]) { + tracing::trace!("DuckDB CHECKPOINT failed - this is expected if there are multiple active transactions: {e}"); + }; result } From afcddc89c1f4d25513d2f53d680fe9fd4baf7fb9 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Mon, 16 Sep 2024 22:53:12 +0900 Subject: [PATCH 19/40] Don't attempt to CHECKPOINT after writing to DuckDB (#108) --- src/duckdb/write.rs | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/src/duckdb/write.rs b/src/duckdb/write.rs index 00d40cb..1d076e4 100644 --- a/src/duckdb/write.rs +++ b/src/duckdb/write.rs @@ -1,7 +1,6 @@ use std::{any::Any, fmt, sync::Arc}; use crate::duckdb::DuckDB; -use crate::sql::db_connection_pool::dbconnection::SyncDbConnection; use crate::util::{ constraints, on_conflict::OnConflict, retriable_error::check_and_mark_retriable_error, }; @@ -213,19 +212,7 @@ impl DataSink for DuckDBDataSink { drop(batch_tx); match duckdb_write_handle.await { - Ok(result) => { - // before returning the result, attempt to CHECKPOINT to flush the WAL to disk - let mut conn = self.duckdb.connect_sync().map_err(to_datafusion_error)?; - let conn = DuckDB::duckdb_conn(&mut conn).map_err(to_datafusion_error)?; - - // This may fail if multiple transactions are active (i.e. actively writing data) - // we can ignore the error since it will be written once the last transaction finishes. - if let Err(e) = conn.execute("CHECKPOINT", &[]) { - tracing::trace!("DuckDB CHECKPOINT failed - this is expected if there are multiple active transactions: {e}"); - }; - - result - } + Ok(result) => result, Err(e) => Err(DataFusionError::Execution(format!( "Error writing to DuckDB: {e}" ))), From 311c8b507fbf45e2f6c54ed10539a4c3957fe235 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Tue, 17 Sep 2024 22:46:41 +0900 Subject: [PATCH 20/40] SqliteTableProviderFactory keeps track of opened instances (#109) * wip * wip * wip * tweak --- src/sql/db_connection_pool/sqlitepool.rs | 29 +++++++++++++ src/sqlite.rs | 54 +++++++++++++++++------- 2 files changed, 68 insertions(+), 15 deletions(-) diff --git a/src/sql/db_connection_pool/sqlitepool.rs b/src/sql/db_connection_pool/sqlitepool.rs index c81ccc3..5b16cd0 100644 --- a/src/sql/db_connection_pool/sqlitepool.rs +++ b/src/sql/db_connection_pool/sqlitepool.rs @@ -198,6 +198,35 @@ impl SqliteConnectionPool { pub fn connect_sync(&self) -> Box> { Box::new(SqliteConnection::new(self.conn.clone())) } + + /// Will attempt to clone the connection pool. This will always succeed for in-memory mode. + /// For file-mode, it will attempt to create a new connection pool with the same configuration. + /// + /// Due to the way the connection pool is implemented, it doesn't allow multiple concurrent reads/writes + /// using the same connection pool instance. 
+ pub async fn try_clone(&self) -> Result { + match self.mode { + Mode::Memory => Ok(SqliteConnectionPool { + conn: self.conn.clone(), + join_push_down: self.join_push_down.clone(), + mode: self.mode, + path: Arc::clone(&self.path), + attach_databases: self.attach_databases.clone(), + }), + Mode::File => { + let attach_databases = if self.attach_databases.is_empty() { + None + } else { + Some(self.attach_databases.clone()) + }; + + SqliteConnectionPoolFactory::new(&self.path, self.mode) + .with_databases(attach_databases) + .build() + .await + } + } + } } #[async_trait] diff --git a/src/sqlite.rs b/src/sqlite.rs index 55b71d6..9f905c2 100644 --- a/src/sqlite.rs +++ b/src/sqlite.rs @@ -1,6 +1,7 @@ use crate::sql::arrow_sql_gen::statement::{CreateTableBuilder, IndexBuilder, InsertBuilder}; use crate::sql::db_connection_pool::dbconnection::{self, get_schema, AsyncDbConnection}; use crate::sql::db_connection_pool::sqlitepool::SqliteConnectionPoolFactory; +use crate::sql::db_connection_pool::DbInstanceKey; use crate::sql::db_connection_pool::{ self, dbconnection::{sqliteconn::SqliteConnection, DbConnection}, @@ -26,6 +27,7 @@ use snafu::prelude::*; use sql_table::SQLiteTable; use std::collections::HashSet; use std::{collections::HashMap, sync::Arc}; +use tokio::sync::Mutex; use tokio_rusqlite::Connection; use crate::util::{ @@ -95,7 +97,9 @@ pub enum Error { type Result = std::result::Result; -pub struct SqliteTableProviderFactory {} +pub struct SqliteTableProviderFactory { + instances: Arc>>, +} const SQLITE_DB_PATH_PARAM: &str = "file"; const SQLITE_DB_BASE_FOLDER_PARAM: &str = "data_directory"; @@ -104,7 +108,9 @@ const SQLITE_ATTACH_DATABASES_PARAM: &str = "attach_databases"; impl SqliteTableProviderFactory { #[must_use] pub fn new() -> Self { - Self {} + Self { + instances: Arc::new(Mutex::new(HashMap::new())), + } } #[must_use] @@ -132,6 +138,32 @@ impl SqliteTableProviderFactory { .cloned() .unwrap_or(default_filepath) } + + pub async fn get_or_init_instance( + &self, + db_path: impl Into>, + mode: Mode, + ) -> Result { + let db_path = db_path.into(); + let key = match mode { + Mode::Memory => DbInstanceKey::memory(), + Mode::File => DbInstanceKey::file(Arc::clone(&db_path)), + }; + let mut instances = self.instances.lock().await; + + if let Some(instance) = instances.get(&key) { + return instance.try_clone().await.context(DbConnectionPoolSnafu); + } + + let pool = SqliteConnectionPoolFactory::new(&db_path, mode) + .build() + .await + .context(DbConnectionPoolSnafu)?; + + instances.insert(key, pool.try_clone().await.context(DbConnectionPoolSnafu)?); + + Ok(pool) + } } impl Default for SqliteTableProviderFactory { @@ -143,10 +175,6 @@ impl Default for SqliteTableProviderFactory { pub type DynSqliteConnectionPool = dyn DbConnectionPool + Send + Sync; -fn handle_db_error(err: db_connection_pool::Error) -> DataFusionError { - to_datafusion_error(Error::DbConnectionPoolError { source: err }) -} - #[async_trait] impl TableProviderFactory for SqliteTableProviderFactory { #[allow(clippy::too_many_lines)] @@ -191,13 +219,12 @@ impl TableProviderFactory for SqliteTableProviderFactory { ); } - let db_path = self.sqlite_file_path(&name, &cmd.options); + let db_path: Arc = self.sqlite_file_path(&name, &cmd.options).into(); let pool: Arc = Arc::new( - SqliteConnectionPoolFactory::new(&db_path, mode) - .build() + self.get_or_init_instance(Arc::clone(&db_path), mode) .await - .map_err(handle_db_error)?, + .map_err(to_datafusion_error)?, ); let read_pool = if mode == Mode::Memory { @@ -207,11 +234,9 @@ 
impl TableProviderFactory for SqliteTableProviderFactory { // even though we setup SQLite to use WAL mode, the pool isn't really a pool so shares the same connection // and we can't have concurrent writes when sharing the same connection Arc::new( - SqliteConnectionPoolFactory::new(&db_path, mode) - .with_databases(self.attach_databases(&options)) - .build() + self.get_or_init_instance(Arc::clone(&db_path), mode) .await - .map_err(handle_db_error)?, + .map_err(to_datafusion_error)?, ) }; @@ -611,7 +636,6 @@ impl Sqlite { #[cfg(test)] pub(crate) mod tests { - use arrow::datatypes::{DataType, Schema}; use datafusion::{ common::ToDFSchema, prelude::SessionContext, sql::sqlparser::ast::TableConstraint, From a989de01036281bc09b68e0526f3c7ccdac03e4b Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Wed, 18 Sep 2024 06:41:26 -0700 Subject: [PATCH 21/40] Support all time() types in MySQL (#97) * Support all time() types in MySQL * Include test for time types --- .../dbconnection/mysqlconn.rs | 3 +- tests/mysql/mod.rs | 73 +++++++++++++++++++ 2 files changed, 74 insertions(+), 2 deletions(-) diff --git a/src/sql/db_connection_pool/dbconnection/mysqlconn.rs b/src/sql/db_connection_pool/dbconnection/mysqlconn.rs index 74f149d..755ff8f 100644 --- a/src/sql/db_connection_pool/dbconnection/mysqlconn.rs +++ b/src/sql/db_connection_pool/dbconnection/mysqlconn.rs @@ -203,10 +203,9 @@ fn map_str_type_to_column_type(data_type: &str) -> Result { _ if data_type.starts_with("float") => ColumnType::MYSQL_TYPE_FLOAT, _ if data_type.starts_with("double") => ColumnType::MYSQL_TYPE_DOUBLE, _ if data_type.eq("null") => ColumnType::MYSQL_TYPE_NULL, + _ if data_type.starts_with("time") => ColumnType::MYSQL_TYPE_TIME, _ if data_type.starts_with("timestamp") => ColumnType::MYSQL_TYPE_TIMESTAMP, _ if data_type.starts_with("datetime") => ColumnType::MYSQL_TYPE_DATETIME, - _ if data_type.eq("time2") => ColumnType::MYSQL_TYPE_TIME2, - _ if data_type.eq("time") => ColumnType::MYSQL_TYPE_TIME, _ if data_type.eq("date") => ColumnType::MYSQL_TYPE_DATE, _ if data_type.eq("year") => ColumnType::MYSQL_TYPE_YEAR, _ if data_type.eq("newdate") => ColumnType::MYSQL_TYPE_NEWDATE, diff --git a/tests/mysql/mod.rs b/tests/mysql/mod.rs index 34a3b3c..5b2378b 100644 --- a/tests/mysql/mod.rs +++ b/tests/mysql/mod.rs @@ -234,6 +234,78 @@ VALUES ( .await; } +async fn test_time_types(port: usize) { + let create_table_stmt = " +CREATE TABLE time_table ( + t0 TIME(0), + t1 TIME(1), + t2 TIME(2), + t3 TIME(3), + t4 TIME(4), + t5 TIME(5), + t6 TIME(6) +); + "; + let insert_table_stmt = " +INSERT INTO time_table (t0, t1, t2, t3, t4, t5, t6) +VALUES + ('12:30:00', + '12:30:00.1', + '12:30:00.12', + '12:30:00.123', + '12:30:00.1234', + '12:30:00.12345', + '12:30:00.123456'); + "; + + let schema = Arc::new(Schema::new(vec![ + Field::new("t0", DataType::Time64(TimeUnit::Nanosecond), true), + Field::new("t1", DataType::Time64(TimeUnit::Nanosecond), true), + Field::new("t2", DataType::Time64(TimeUnit::Nanosecond), true), + Field::new("t3", DataType::Time64(TimeUnit::Nanosecond), true), + Field::new("t4", DataType::Time64(TimeUnit::Nanosecond), true), + Field::new("t5", DataType::Time64(TimeUnit::Nanosecond), true), + Field::new("t6", DataType::Time64(TimeUnit::Nanosecond), true), + ])); + + let expected_record = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Time64NanosecondArray::from(vec![ + (12 * 3600 + 30 * 60 + 0) * 1_000_000_000, + ])), + Arc::new(Time64NanosecondArray::from(vec![ + (12 * 
3600 + 30 * 60) * 1_000_000_000 + 100_000_000, + ])), + Arc::new(Time64NanosecondArray::from(vec![ + (12 * 3600 + 30 * 60) * 1_000_000_000 + 120_000_000, + ])), + Arc::new(Time64NanosecondArray::from(vec![ + (12 * 3600 + 30 * 60) * 1_000_000_000 + 123_000_000, + ])), + Arc::new(Time64NanosecondArray::from(vec![ + (12 * 3600 + 30 * 60) * 1_000_000_000 + 123_400_000, + ])), + Arc::new(Time64NanosecondArray::from(vec![ + (12 * 3600 + 30 * 60) * 1_000_000_000 + 123_450_000, + ])), + Arc::new(Time64NanosecondArray::from(vec![ + (12 * 3600 + 30 * 60) * 1_000_000_000 + 123_456_000, + ])), + ], + ) + .expect("Failed to created arrow record batch"); + + arrow_mysql_one_way( + port, + "time_table", + create_table_stmt, + insert_table_stmt, + expected_record, + ) + .await; +} + async fn arrow_mysql_one_way( port: usize, table_name: &str, @@ -310,6 +382,7 @@ async fn test_mysql_arrow_oneway() { test_mysql_decimal_types(port).await; test_mysql_timestamp_types(port).await; test_mysql_datetime_types(port).await; + test_time_types(port).await; mysql_container.remove().await.expect("container to stop"); } From c4b6c709f3246451b36c4fc344dd04339331b44d Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Thu, 19 Sep 2024 17:22:35 +0900 Subject: [PATCH 22/40] Upgrade to Arrow 53, DataFusion 42 and DuckDB 1.1 (#111) --- Cargo.toml | 17 ++++++++--------- src/util/constraints.rs | 4 +--- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fe70bad..d5919ba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,12 +8,11 @@ license = "Apache-2.0" description = "Extend the capabilities of DataFusion to support additional data sources via implementations of the `TableProvider` trait." [dependencies] -arrow = "52.2.0" -arrow-array = { version = "52.2.0", optional = true } -arrow-cast = { version = "52.2.0", optional = true } -arrow-flight = { version = "52.2.0", optional = true, features = ["flight-sql-experimental", "tls"] } -arrow-schema = { version = "52.2.0", optional = true, features = ["serde"] } -arrow-json = "52.2.0" +arrow = "53" +arrow-array = { version = "53", optional = true } +arrow-flight = { version = "53", optional = true, features = ["flight-sql-experimental", "tls"] } +arrow-schema = { version = "53", optional = true, features = ["serde"] } +arrow-json = "53" async-stream = { version = "0.3.5", optional = true } async-trait = "0.1.80" num-bigint = "0.4.4" @@ -38,7 +37,7 @@ duckdb = { version = "1", features = [ fallible-iterator = "0.3.0" futures = "0.3.30" mysql_async = { version = "0.34.1", features = ["native-tls-tls", "chrono"], optional = true } -prost = { version = "0.12" , optional = true } # pinned for arrow-flight compat +prost = { version = "0.13.2", optional = true } r2d2 = { version = "0.8.10", optional = true } rusqlite = { version = "0.31.0", optional = true } sea-query = { version = "0.32.0-rc.1", features = [ @@ -70,7 +69,7 @@ trust-dns-resolver = "0.23.2" url = "2.5.1" pem = { version = "3.0.4", optional = true } tokio-rusqlite = { version = "0.5.1", optional = true } -tonic = { version = "0.11", optional = true } # pinned for arrow-flight compat +tonic = { version = "0.12.2", optional = true } itertools = "0.13.0" dyn-clone = { version = "1.0.17", optional = true } geo-types = "0.7.13" @@ -113,4 +112,4 @@ sqlite-federation = ["sqlite"] postgres-federation = ["postgres"] [patch.crates-io] -duckdb = { git = "https://github.com/spiceai/duckdb-rs.git", rev = "f2ca47d094a5636df8b9f3792b2f474a7b210dc1" } +duckdb = { git = 
"https://github.com/spiceai/duckdb-rs.git", rev = "5b98603705a381ceeb5cc371e4f606b7332b57ce" } diff --git a/src/util/constraints.rs b/src/util/constraints.rs index 472c48e..2e9d97a 100644 --- a/src/util/constraints.rs +++ b/src/util/constraints.rs @@ -61,9 +61,7 @@ async fn validate_batch_with_constraint( let ctx = SessionContext::new(); let df = ctx.read_batches(batches).context(DataFusionSnafu)?; - let Ok(count_name) = count(lit(COUNT_STAR_EXPANSION)).display_name() else { - unreachable!() - }; + let count_name = count(lit(COUNT_STAR_EXPANSION)).schema_name().to_string(); // This is equivalent to: // ```sql From a0492fc067de4cf3e313ab006c2e324cf0582ce2 Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Thu, 19 Sep 2024 13:00:03 -0700 Subject: [PATCH 23/40] Handle inconsistent scale in Postgres Numeric Type data (#110) --- src/sql/arrow_sql_gen/postgres.rs | 25 +++++++++++- tests/postgres/mod.rs | 68 ++++++++++++++++++++++++++++++- 2 files changed, 89 insertions(+), 4 deletions(-) diff --git a/src/sql/arrow_sql_gen/postgres.rs b/src/sql/arrow_sql_gen/postgres.rs index 9781590..f247c16 100644 --- a/src/sql/arrow_sql_gen/postgres.rs +++ b/src/sql/arrow_sql_gen/postgres.rs @@ -78,6 +78,9 @@ pub enum Error { #[snafu(display("No Arrow field found for index {index}"))] NoArrowFieldForIndex { index: usize }, + #[snafu(display("No PostgreSQL scale found for index {index}"))] + NoPostgresScaleForIndex { index: usize }, + #[snafu(display("No column name for index: {index}"))] NoColumnNameForIndex { index: usize }, } @@ -183,6 +186,7 @@ pub fn rows_to_arrow(rows: &[Row]) -> Result { let mut arrow_fields: Vec> = Vec::new(); let mut arrow_columns_builders: Vec>> = Vec::new(); let mut postgres_types: Vec = Vec::new(); + let mut postgres_numeric_scales: Vec> = Vec::new(); let mut column_names: Vec = Vec::new(); if !rows.is_empty() { @@ -197,6 +201,7 @@ pub fn rows_to_arrow(rows: &[Row]) -> Result { } None => arrow_fields.push(None), } + postgres_numeric_scales.push(None); arrow_columns_builders .push(map_data_type_to_array_builder_optional(data_type.as_ref())); postgres_types.push(column_type.clone()); @@ -214,6 +219,10 @@ pub fn rows_to_arrow(rows: &[Row]) -> Result { return NoArrowFieldForIndexSnafu { index: i }.fail(); }; + let Some(postgres_numeric_scale) = postgres_numeric_scales.get_mut(i) else { + return NoPostgresScaleForIndexSnafu { index: i }.fail(); + }; + match *postgres_type { Type::INT2 => { handle_primitive_type!(builder, Type::INT2, Int16Builder, i16, row, i); @@ -440,12 +449,19 @@ pub fn rows_to_arrow(rows: &[Row]) -> Result { *arrow_field = Some(new_arrow_field); } + if postgres_numeric_scale.is_none() { + *postgres_numeric_scale = Some(scale); + }; + let Some(v) = v else { dec_builder.append_null(); continue; }; - let Some(v_i128) = v.to_decimal_128() else { + // Record Batch Scale is determined by first row, while Postgres Numeric Type doesn't have fixed scale + // Resolve scale difference for incoming records + let dest_scale = postgres_numeric_scale.unwrap_or_default(); + let Some(v_i128) = v.to_decimal_128_with_scale(dest_scale) else { return FailedToConvertBigDecimalToI128Snafu { big_decimal: v.inner, } @@ -909,7 +925,12 @@ struct BigDecimalFromSql { } impl BigDecimalFromSql { - fn to_decimal_128(&self) -> Option { + fn to_decimal_128_with_scale(&self, dest_scale: u16) -> Option { + // Resolve scale difference by upscaling / downscaling to the scale of arrow Decimal128 type + if dest_scale != self.scale { + return (&self.inner * 
10i128.pow(u32::from(dest_scale))).to_i128(); + } + (&self.inner * 10i128.pow(u32::from(self.scale))).to_i128() } diff --git a/tests/postgres/mod.rs b/tests/postgres/mod.rs index c5645f1..3e5d26e 100644 --- a/tests/postgres/mod.rs +++ b/tests/postgres/mod.rs @@ -1,5 +1,8 @@ use crate::arrow_record_batch_gen::*; -use arrow::{array::RecordBatch, datatypes::SchemaRef}; +use arrow::{ + array::{Decimal128Array, RecordBatch}, + datatypes::{DataType, Field, Schema, SchemaRef}, +}; use datafusion::catalog::TableProviderFactory; use datafusion::common::{Constraints, ToDFSchema}; use datafusion::execution::context::SessionContext; @@ -151,6 +154,11 @@ async fn test_arrow_postgres_one_way(container_manager: &Mutex start_container(&container_manager).await; } + test_postgres_enum_type(container_manager.port).await; + test_postgres_numeric_type(container_manager.port).await; +} + +async fn test_postgres_enum_type(port: usize) { let extra_stmt = Some("CREATE TYPE mood AS ENUM ('happy', 'sad', 'neutral');"); let create_table_stmt = " CREATE TABLE person_mood ( @@ -164,7 +172,7 @@ async fn test_arrow_postgres_one_way(container_manager: &Mutex let (expected_record, _) = get_arrow_dictionary_array_record_batch(); arrow_postgres_one_way( - container_manager.port, + port, "person_mood", create_table_stmt, insert_table_stmt, @@ -174,6 +182,62 @@ async fn test_arrow_postgres_one_way(container_manager: &Mutex .await; } +async fn test_postgres_numeric_type(port: usize) { + let extra_stmt = None; + let create_table_stmt = " + CREATE TABLE numeric_values ( + first_column NUMERIC, -- No precision specified + second_column NUMERIC -- No precision specified +);"; + + let insert_table_stmt = " + INSERT INTO numeric_values (first_column, second_column) VALUES +(1.0917217805754313, 0.00000000000000000000), +(0.97824560830666753739, 1220.9175000000000000), +(1.0917217805754313, 52.9533333333333333); + "; + + let schema = Arc::new(Schema::new(vec![ + Field::new("first_column", DataType::Decimal128(38, 16), true), + Field::new("second_column", DataType::Decimal128(38, 20), true), + ])); + + let expected_record = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new( + Decimal128Array::from(vec![ + 10917217805754313i128, + 9782456083066675i128, + 10917217805754313i128, + ]) + .with_precision_and_scale(38, 16) + .unwrap(), + ), + Arc::new( + Decimal128Array::from(vec![ + 0i128, + 122091750000000000000000i128, + 5295333333333333330000i128, + ]) + .with_precision_and_scale(38, 20) + .unwrap(), + ), + ], + ) + .expect("Failed to created arrow record batch"); + + arrow_postgres_one_way( + port, + "numeric_values", + create_table_stmt, + insert_table_stmt, + extra_stmt, + expected_record, + ) + .await; +} + async fn arrow_postgres_one_way( port: usize, table_name: &str, From d71e9d9e7771500a0f74da579d8fdbf406a72789 Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Thu, 19 Sep 2024 21:36:40 -0700 Subject: [PATCH 24/40] Verify MySQL parameters and connections before creating connection pool (#113) * Verify MySQL parameters and connections before creating connection pool * Update --- src/sql/db_connection_pool/mysqlpool.rs | 95 ++++++++++++++++++------- 1 file changed, 68 insertions(+), 27 deletions(-) diff --git a/src/sql/db_connection_pool/mysqlpool.rs b/src/sql/db_connection_pool/mysqlpool.rs index 73cde8b..78ba178 100644 --- a/src/sql/db_connection_pool/mysqlpool.rs +++ b/src/sql/db_connection_pool/mysqlpool.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, path::PathBuf, 
sync::Arc}; use async_trait::async_trait; use mysql_async::{ prelude::{Queryable, ToValue}, - DriverError, Params, Row, SslOpts, + DriverError, Opts, Params, Row, SslOpts, }; use secrecy::{ExposeSecret, Secret, SecretString}; use snafu::{ResultExt, Snafu}; @@ -13,24 +13,41 @@ use crate::{ dbconnection::{mysqlconn::MySQLConnection, AsyncDbConnection, DbConnection}, JoinPushDown, }, - util, + util::{self, ns_lookup::verify_ns_lookup_and_tcp_connect}, }; -use super::{DbConnectionPool, Result}; +use super::DbConnectionPool; + +pub type Result = std::result::Result; #[derive(Debug, Snafu)] pub enum Error { - #[snafu(display("ConnectionPoolError: {source}"))] - ConnectionPoolError { source: mysql_async::UrlError }, + #[snafu(display("MySQL connection error: {source}"))] + MySQLConnectionError { source: mysql_async::Error }, - #[snafu(display("ConnectionPoolRunError: {source}"))] - ConnectionPoolRunError { source: mysql_async::Error }, + #[snafu(display("Invalid MySQL connection string: {source} "))] + InvalidConnectionString { source: mysql_async::UrlError }, #[snafu(display("Invalid parameter: {parameter_name}"))] InvalidParameterError { parameter_name: String }, #[snafu(display("Invalid root cert path: {path}"))] InvalidRootCertPathError { path: String }, + + #[snafu(display("Cannot connect to MySQL on {host}:{port}. Ensure that the host and port are correctly configured, and that the host is reachable."))] + InvalidHostOrPortError { + source: crate::util::ns_lookup::Error, + host: String, + port: u16, + }, + + #[snafu(display( + "Authentication failed. Ensure that the username and password are correctly configured." + ))] + InvalidUsernameOrPassword, + + #[snafu(display("{message}"))] + UnknownMySQLDatabase { message: String }, } pub struct MySQLConnectionPool { @@ -68,9 +85,10 @@ impl MySQLConnectionPool { if let Some(mysql_connection_string) = params.get("connection_string").map(Secret::expose_secret) { - connection_string = mysql_async::OptsBuilder::from_opts(mysql_async::Opts::from_url( - mysql_connection_string.as_str(), - )?); + connection_string = mysql_async::OptsBuilder::from_opts( + mysql_async::Opts::from_url(mysql_connection_string.as_str()) + .context(InvalidConnectionStringSnafu)?, + ); } else { if let Some(mysql_host) = params.get("host").map(Secret::expose_secret) { connection_string = connection_string.ip_or_hostname(mysql_host.as_str()); @@ -120,29 +138,39 @@ impl MySQLConnectionPool { let opts = mysql_async::Opts::from(connection_string); + verify_mysql_opts(&opts).await?; + let join_push_down = get_join_context(&opts); let pool = mysql_async::Pool::new(opts); // Test the connection - let mut conn = pool - .get_conn() - .await - .map_err(|err| match err { - // In case of an incorrect user name, the error `Unknown authentication plugin 'sha256_password'` is returned. - // We override it with a more user-friendly error message. - mysql_async::Error::Driver(DriverError::UnknownAuthPlugin { .. }) => { - mysql_async::Error::Other( - "Unable to authenticate. Is the user name and password correct?".into(), - ) + let mut conn = pool.get_conn().await.map_err(|err| match err { + // In case of an incorrect user name, the error `Unknown authentication plugin 'sha256_password'` is returned. + // We override it with a more user-friendly error message. + mysql_async::Error::Driver(DriverError::UnknownAuthPlugin { .. 
}) => { + Error::InvalidUsernameOrPassword + } + mysql_async::Error::Server(server_error) => { + match server_error.code { + // Code 1049: Server error: `ERROR 42000 (1049): Unknown database + 1049 => Error::UnknownMySQLDatabase { + message: server_error.message, + }, + // Code 1045: Server error: ERROR 1045 (28000): Access denied for user (using password: YES / NO) + 1045 => Error::InvalidUsernameOrPassword, + _ => Error::MySQLConnectionError { + source: mysql_async::Error::Server(server_error), + }, } - _ => err, - }) - .context(ConnectionPoolRunSnafu)?; + } + _ => Error::MySQLConnectionError { source: err }, + })?; + let _rows: Vec = conn .exec("SELECT 1", Params::Empty) .await - .context(ConnectionPoolRunSnafu)?; + .context(MySQLConnectionSnafu)?; Ok(Self { pool: Arc::new(pool), @@ -157,11 +185,23 @@ impl MySQLConnectionPool { /// Returns an error if there is a problem creating the connection pool. pub async fn connect_direct(&self) -> super::Result { let pool = Arc::clone(&self.pool); - let conn = pool.get_conn().await.context(ConnectionPoolRunSnafu)?; + let conn = pool.get_conn().await.context(MySQLConnectionSnafu)?; Ok(MySQLConnection::new(conn)) } } +async fn verify_mysql_opts(opts: &Opts) -> Result<()> { + // Verify the host and port are correct + let host = opts.ip_or_hostname(); + let port = opts.tcp_port(); + + verify_ns_lookup_and_tcp_connect(host, port) + .await + .context(InvalidHostOrPortSnafu { host, port })?; + + Ok(()) +} + fn get_join_context(opts: &mysql_async::Opts) -> JoinPushDown { let mut join_context = format!("host={},port={}", opts.ip_or_hostname(), opts.tcp_port()); if let Some(db_name) = opts.db_name() { @@ -201,9 +241,10 @@ fn get_ssl_opts(ssl_mode: &str, rootcert_path: Option) -> Option for MySQLConnectionPool { async fn connect( &self, - ) -> Result>> { + ) -> super::Result>> + { let pool = Arc::clone(&self.pool); - let conn = pool.get_conn().await.context(ConnectionPoolRunSnafu)?; + let conn = pool.get_conn().await.context(MySQLConnectionSnafu)?; Ok(Box::new(MySQLConnection::new(conn))) } From fc710faa0ed9107d1d8062b1db71ee28724eb1d1 Mon Sep 17 00:00:00 2001 From: hozan23 Date: Tue, 24 Sep 2024 11:12:58 +0200 Subject: [PATCH 25/40] minor fix: upgrade datafusion to 42 --- Cargo.toml | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d5919ba..f01735f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ description = "Extend the capabilities of DataFusion to support additional data [dependencies] arrow = "53" arrow-array = { version = "53", optional = true } +arrow-cast = { version = "53", optional = true } arrow-flight = { version = "53", optional = true, features = ["flight-sql-experimental", "tls"] } arrow-schema = { version = "53", optional = true, features = ["serde"] } arrow-json = "53" @@ -21,12 +22,12 @@ bytes = { version = "1.7.1", optional = true } bigdecimal = "0.4.5" byteorder = "1.5.0" chrono = "0.4.38" -datafusion = "41.0.0" -datafusion-expr = { version = "41.0.0", optional = true } -datafusion-physical-expr = { version = "41.0.0", optional = true } -datafusion-physical-plan = { version = "41.0.0", optional = true } -datafusion-proto = { version = "41.0.0", optional = true } -datafusion-federation = { version = "0.2.2", features = ["sql"] } +datafusion = "42.0.0" +datafusion-expr = { version = "42.0.0", optional = true } +datafusion-physical-expr = { version = "42.0.0", optional = true } +datafusion-physical-plan = { version = "42.0.0", optional = true } +datafusion-proto = { 
version = "42.0.0", optional = true } +datafusion-federation = { version = "0.3.0", features = ["sql"] } duckdb = { version = "1", features = [ "bundled", "r2d2", From 496487eb576a393d084d578ab86928976f4d5f62 Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Tue, 24 Sep 2024 11:13:56 -0700 Subject: [PATCH 26/40] Propagate MySQL wrong table error (#114) --- src/postgres.rs | 11 +++----- .../dbconnection/mysqlconn.rs | 26 +++++++++++++++---- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/src/postgres.rs b/src/postgres.rs index 494fb44..9cfe6c1 100644 --- a/src/postgres.rs +++ b/src/postgres.rs @@ -7,7 +7,9 @@ use crate::sql::db_connection_pool::{ postgrespool::{self, PostgresConnectionPool}, DbConnectionPool, }; -use crate::sql::sql_provider_datafusion::{self, Engine, SqlTable}; + +use crate::sql::sql_provider_datafusion::{Engine, SqlTable}; + use arrow::{ array::RecordBatch, datatypes::{Schema, SchemaRef}, @@ -86,13 +88,6 @@ pub enum Error { source: tokio_postgres::error::Error, }, - #[snafu(display( - "Unable to construct the DataFusion SQL Table Provider for Postgres: {source}" - ))] - UnableToConstructSqlTable { - source: sql_provider_datafusion::Error, - }, - #[snafu(display("Unable to generate SQL: {source}"))] UnableToGenerateSQL { source: DataFusionError }, diff --git a/src/sql/db_connection_pool/dbconnection/mysqlconn.rs b/src/sql/db_connection_pool/dbconnection/mysqlconn.rs index 755ff8f..d0785fa 100644 --- a/src/sql/db_connection_pool/dbconnection/mysqlconn.rs +++ b/src/sql/db_connection_pool/dbconnection/mysqlconn.rs @@ -78,11 +78,27 @@ impl<'a> AsyncDbConnection for MySQLConnection { let columns_meta_query = format!("SHOW COLUMNS FROM {}", table_reference.to_quoted_string()); - let columns_meta: Vec = conn - .exec(&columns_meta_query, Params::Empty) - .await - .boxed() - .context(super::UnableToGetSchemaSnafu)?; + let columns_meta: Vec = match conn.exec(&columns_meta_query, Params::Empty).await { + Ok(columns_meta) => columns_meta, + Err(e) => match e { + mysql_async::Error::Server(server_error) => { + if server_error.code == 1146 { + return Err(super::Error::UndefinedTable { + source: Box::new(server_error.clone()), + table_name: table_reference.to_string(), + }); + } + return Err(super::Error::UnableToGetSchema { + source: Box::new(mysql_async::Error::Server(server_error)), + }); + } + _ => { + return Err(super::Error::UnableToGetSchema { + source: Box::new(e), + }) + } + }, + }; Ok(columns_meta_to_schema(columns_meta).context(super::UnableToGetSchemaSnafu)?) 
} From fb7376e568e2a38eff95d4fb35ad5e5ed0fcee86 Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Tue, 24 Sep 2024 19:12:09 -0700 Subject: [PATCH 27/40] Fix MySQL timestamp type (#116) --- src/sql/db_connection_pool/dbconnection/mysqlconn.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sql/db_connection_pool/dbconnection/mysqlconn.rs b/src/sql/db_connection_pool/dbconnection/mysqlconn.rs index d0785fa..0e4ee61 100644 --- a/src/sql/db_connection_pool/dbconnection/mysqlconn.rs +++ b/src/sql/db_connection_pool/dbconnection/mysqlconn.rs @@ -219,8 +219,8 @@ fn map_str_type_to_column_type(data_type: &str) -> Result { _ if data_type.starts_with("float") => ColumnType::MYSQL_TYPE_FLOAT, _ if data_type.starts_with("double") => ColumnType::MYSQL_TYPE_DOUBLE, _ if data_type.eq("null") => ColumnType::MYSQL_TYPE_NULL, - _ if data_type.starts_with("time") => ColumnType::MYSQL_TYPE_TIME, _ if data_type.starts_with("timestamp") => ColumnType::MYSQL_TYPE_TIMESTAMP, + _ if data_type.starts_with("time") => ColumnType::MYSQL_TYPE_TIME, _ if data_type.starts_with("datetime") => ColumnType::MYSQL_TYPE_DATETIME, _ if data_type.eq("date") => ColumnType::MYSQL_TYPE_DATE, _ if data_type.eq("year") => ColumnType::MYSQL_TYPE_YEAR, From a397212f5c0ef59492636ae443004440e2f62b14 Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Wed, 25 Sep 2024 22:10:43 -0700 Subject: [PATCH 28/40] Postgres should respect target decimal precision and scale (#120) --- src/sql/arrow_sql_gen/postgres.rs | 38 +++++++++++++++++-- .../dbconnection/postgresconn.rs | 8 ++-- 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/src/sql/arrow_sql_gen/postgres.rs b/src/sql/arrow_sql_gen/postgres.rs index f247c16..221618f 100644 --- a/src/sql/arrow_sql_gen/postgres.rs +++ b/src/sql/arrow_sql_gen/postgres.rs @@ -12,7 +12,8 @@ use arrow::array::{ Time64NanosecondBuilder, TimestampNanosecondBuilder, UInt32Builder, }; use arrow::datatypes::{ - DataType, Date32Type, Field, Int8Type, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, + DataType, Date32Type, Field, Int8Type, IntervalMonthDayNanoType, IntervalUnit, Schema, + SchemaRef, TimeUnit, }; use bigdecimal::num_bigint::BigInt; use bigdecimal::num_bigint::Sign; @@ -182,7 +183,7 @@ macro_rules! handle_composite_types { /// /// Returns an error if there is a failure in converting the rows to a `RecordBatch`. 
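For illustration, a caller-side sketch of the `projected_schema` argument this patch threads through `rows_to_arrow` (the column name and precision/scale are hypothetical, and `rows` is assumed to have been fetched already via tokio-postgres):

    use std::sync::Arc;
    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};

    // Postgres NUMERIC carries no fixed precision/scale in the row metadata, so the
    // desired Decimal128 type is supplied through the projected schema instead.
    let projected: SchemaRef = Arc::new(Schema::new(vec![Field::new(
        "price",                     // hypothetical column name
        DataType::Decimal128(38, 9), // hypothetical precision and scale
        true,
    )]));
    let batch = rows_to_arrow(rows.as_slice(), &Some(projected))?;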
#[allow(clippy::too_many_lines)] -pub fn rows_to_arrow(rows: &[Row]) -> Result { +pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option) -> Result { let mut arrow_fields: Vec> = Vec::new(); let mut arrow_columns_builders: Vec>> = Vec::new(); let mut postgres_types: Vec = Vec::new(); @@ -194,14 +195,32 @@ pub fn rows_to_arrow(rows: &[Row]) -> Result { for column in row.columns() { let column_name = column.name(); let column_type = column.type_(); - let data_type = map_column_type_to_data_type(column_type); + + let mut numeric_scale: Option = None; + + let data_type = if *column_type == Type::NUMERIC { + if let Some(schema) = projected_schema.as_ref() { + match get_decimal_column_precision_and_scale(column_name, schema) { + Some((precision, scale)) => { + numeric_scale = Some(u16::try_from(scale).unwrap_or_default()); + Some(DataType::Decimal128(precision, scale)) + } + None => None, + } + } else { + None + } + } else { + map_column_type_to_data_type(column_type) + }; + match &data_type { Some(data_type) => { arrow_fields.push(Some(Field::new(column_name, data_type.clone(), true))); } None => arrow_fields.push(None), } - postgres_numeric_scales.push(None); + postgres_numeric_scales.push(numeric_scale); arrow_columns_builders .push(map_data_type_to_array_builder_optional(data_type.as_ref())); postgres_types.push(column_type.clone()); @@ -1251,3 +1270,14 @@ mod tests { assert_eq!(positive_result.wkb, positive_geometry); } } + +fn get_decimal_column_precision_and_scale( + column_name: &str, + projected_schema: &SchemaRef, +) -> Option<(u8, i8)> { + let field = projected_schema.field_with_name(column_name).ok()?; + match field.data_type() { + DataType::Decimal128(precision, scale) => Some((*precision, *scale)), + _ => None, + } +} diff --git a/src/sql/db_connection_pool/dbconnection/postgresconn.rs b/src/sql/db_connection_pool/dbconnection/postgresconn.rs index 870e92c..0c78d48 100644 --- a/src/sql/db_connection_pool/dbconnection/postgresconn.rs +++ b/src/sql/db_connection_pool/dbconnection/postgresconn.rs @@ -111,7 +111,7 @@ impl<'a> } }; - let rec = match rows_to_arrow(rows.as_slice()) { + let rec = match rows_to_arrow(rows.as_slice(), &None) { Ok(rec) => rec, Err(e) => { return Err(super::Error::UnableToGetSchema { @@ -128,7 +128,7 @@ impl<'a> &self, sql: &str, params: &[&'a (dyn ToSql + Sync)], - _projected_schema: Option, + projected_schema: Option, ) -> Result { // TODO: We should have a way to detect if params have been passed // if they haven't we should use .copy_out instead, because it should be much faster @@ -139,12 +139,12 @@ impl<'a> .context(QuerySnafu)?; // chunk the stream into groups of rows - let mut stream = streamable.chunks(4_000).boxed().map(|rows| { + let mut stream = streamable.chunks(4_000).boxed().map(move |rows| { let rows = rows .into_iter() .collect::, _>>() .context(QuerySnafu)?; - let rec = rows_to_arrow(rows.as_slice()).context(ConversionSnafu)?; + let rec = rows_to_arrow(rows.as_slice(), &projected_schema).context(ConversionSnafu)?; Ok::<_, PostgresError>(rec) }); From 83a52952df73fc629bc76b0e4e82f4cbda1a34ce Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Thu, 26 Sep 2024 11:07:45 -0700 Subject: [PATCH 29/40] Update row -> arrow conversion for all MYSQL_TYPE_VAR_STRING and MYSQL_TYPE_STRING types (#118) --- src/sql/arrow_sql_gen/mysql.rs | 13 ++--- .../dbconnection/mysqlconn.rs | 2 +- tests/mysql/mod.rs | 54 +++++++++++++++++++ 3 files changed, 58 insertions(+), 11 deletions(-) diff --git 
a/src/sql/arrow_sql_gen/mysql.rs b/src/sql/arrow_sql_gen/mysql.rs index 4e54460..fd433e0 100644 --- a/src/sql/arrow_sql_gen/mysql.rs +++ b/src/sql/arrow_sql_gen/mysql.rs @@ -3,7 +3,7 @@ use arrow::{ array::{ ArrayBuilder, ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, LargeStringBuilder, - NullBuilder, RecordBatch, RecordBatchOptions, Time64NanosecondBuilder, + NullBuilder, RecordBatch, RecordBatchOptions, StringBuilder, Time64NanosecondBuilder, TimestampMicrosecondBuilder, UInt64Builder, }, datatypes::{DataType, Date32Type, Field, Schema, SchemaRef, TimeUnit}, @@ -292,14 +292,7 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option) -> Resu i ); } else { - handle_primitive_type!( - builder, - column_type, - LargeStringBuilder, - String, - row, - i - ); + handle_primitive_type!(builder, column_type, StringBuilder, String, row, i); } } ColumnType::MYSQL_TYPE_DATE => { @@ -435,7 +428,7 @@ pub fn map_column_to_data_type( if column_is_binary { Some(DataType::Binary) } else { - Some(DataType::LargeUtf8) + Some(DataType::Utf8) } }, // replication only diff --git a/src/sql/db_connection_pool/dbconnection/mysqlconn.rs b/src/sql/db_connection_pool/dbconnection/mysqlconn.rs index 0e4ee61..09832bd 100644 --- a/src/sql/db_connection_pool/dbconnection/mysqlconn.rs +++ b/src/sql/db_connection_pool/dbconnection/mysqlconn.rs @@ -225,7 +225,6 @@ fn map_str_type_to_column_type(data_type: &str) -> Result { _ if data_type.eq("date") => ColumnType::MYSQL_TYPE_DATE, _ if data_type.eq("year") => ColumnType::MYSQL_TYPE_YEAR, _ if data_type.eq("newdate") => ColumnType::MYSQL_TYPE_NEWDATE, - _ if data_type.starts_with("varchar") => ColumnType::MYSQL_TYPE_VARCHAR, _ if data_type.starts_with("bit") => ColumnType::MYSQL_TYPE_BIT, _ if data_type.starts_with("array") => ColumnType::MYSQL_TYPE_TYPED_ARRAY, _ if data_type.starts_with("json") => ColumnType::MYSQL_TYPE_JSON, @@ -240,6 +239,7 @@ fn map_str_type_to_column_type(data_type: &str) -> Result { _ if data_type.starts_with("longtext") => ColumnType::MYSQL_TYPE_LONG_BLOB, _ if data_type.starts_with("blob") => ColumnType::MYSQL_TYPE_BLOB, _ if data_type.starts_with("text") => ColumnType::MYSQL_TYPE_BLOB, + _ if data_type.starts_with("varchar") => ColumnType::MYSQL_TYPE_VAR_STRING, _ if data_type.starts_with("varbinary") => ColumnType::MYSQL_TYPE_VAR_STRING, _ if data_type.starts_with("char") => ColumnType::MYSQL_TYPE_STRING, _ if data_type.starts_with("binary") => ColumnType::MYSQL_TYPE_STRING, diff --git a/tests/mysql/mod.rs b/tests/mysql/mod.rs index 5b2378b..bbea979 100644 --- a/tests/mysql/mod.rs +++ b/tests/mysql/mod.rs @@ -306,6 +306,59 @@ VALUES .await; } +async fn test_mysql_string_types(port: usize) { + let create_table_stmt = " +CREATE TABLE string_table ( + name VARCHAR(255), + data VARBINARY(255), + fixed_name CHAR(10), + fixed_data BINARY(10) +); + "; + let insert_table_stmt = " +INSERT INTO string_table (name, data, fixed_name, fixed_data) +VALUES +('Alice', 'Alice', 'ALICE', 'abc'), +('Bob', 'Bob', 'BOB', 'bob1234567'), +('Charlie', 'Charlie', 'CHARLIE', '0123456789'), +('Dave', 'Dave', 'DAVE', 'dave000000'); + "; + + let schema = Arc::new(Schema::new(vec![ + Field::new("name", DataType::Utf8, true), + Field::new("data", DataType::Binary, true), + Field::new("fixed_name", DataType::Utf8, true), + Field::new("fixed_data", DataType::Binary, true), + ])); + + let expected_record = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + 
Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie", "Dave"])), + Arc::new(BinaryArray::from_vec(vec![ + b"Alice", b"Bob", b"Charlie", b"Dave", + ])), + Arc::new(StringArray::from(vec!["ALICE", "BOB", "CHARLIE", "DAVE"])), + Arc::new(BinaryArray::from_vec(vec![ + b"abc\0\0\0\0\0\0\0", + b"bob1234567", + b"0123456789", + b"dave000000", + ])), + ], + ) + .expect("Failed to created arrow record batch"); + + arrow_mysql_one_way( + port, + "string_table", + create_table_stmt, + insert_table_stmt, + expected_record, + ) + .await; +} + async fn arrow_mysql_one_way( port: usize, table_name: &str, @@ -383,6 +436,7 @@ async fn test_mysql_arrow_oneway() { test_mysql_timestamp_types(port).await; test_mysql_datetime_types(port).await; test_time_types(port).await; + test_mysql_string_types(port).await; mysql_container.remove().await.expect("container to stop"); } From d4e0bd282c62ba132583d527de09caed64188f25 Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Thu, 26 Sep 2024 11:48:20 -0700 Subject: [PATCH 30/40] Use Decimal256 instead of Decimal128 for MySQL decimal type (#115) --- src/sql/arrow_sql_gen/mysql.rs | 131 ++++++++++++++++++++++---------- tests/mysql/mod.rs | 133 +++++++++++++++++++++++---------- 2 files changed, 186 insertions(+), 78 deletions(-) diff --git a/src/sql/arrow_sql_gen/mysql.rs b/src/sql/arrow_sql_gen/mysql.rs index fd433e0..b2728a0 100644 --- a/src/sql/arrow_sql_gen/mysql.rs +++ b/src/sql/arrow_sql_gen/mysql.rs @@ -1,20 +1,20 @@ use crate::sql::arrow_sql_gen::arrow::map_data_type_to_array_builder_optional; use arrow::{ array::{ - ArrayBuilder, ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Float32Builder, + ArrayBuilder, ArrayRef, BinaryBuilder, Date32Builder, Decimal256Builder, Decimal128Builder, Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, LargeStringBuilder, NullBuilder, RecordBatch, RecordBatchOptions, StringBuilder, Time64NanosecondBuilder, TimestampMicrosecondBuilder, UInt64Builder, }, - datatypes::{DataType, Date32Type, Field, Schema, SchemaRef, TimeUnit}, + datatypes::{i256, DataType, Date32Type, Field, Schema, SchemaRef, TimeUnit}, }; use bigdecimal::BigDecimal; -use bigdecimal::ToPrimitive; use chrono::{NaiveDate, NaiveTime, Timelike}; use mysql_async::{consts::ColumnFlags, consts::ColumnType, FromValueError, Row, Value}; use snafu::{ResultExt, Snafu}; use std::{convert, sync::Arc}; use time::PrimitiveDateTime; +use bigdecimal::ToPrimitive; #[derive(Debug, Snafu)] pub enum Error { @@ -43,9 +43,6 @@ pub enum Error { source: mysql_async::FromValueError, }, - #[snafu(display("Failed to parse raw Postgres Bytes as BigDecimal: {:?}", bytes))] - FailedToParseBigDecimalFromPostgres { bytes: Vec }, - #[snafu(display("Cannot represent BigDecimal as i128: {big_decimal}"))] FailedToConvertBigDecimalToI128 { big_decimal: BigDecimal }, @@ -105,14 +102,14 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option) -> Resu let (decimal_precision, decimal_scale) = match column_type { ColumnType::MYSQL_TYPE_DECIMAL | ColumnType::MYSQL_TYPE_NEWDECIMAL => { - // use 38 as default precision for decimal types if there is no way to get the precision from the column + // use 76 as default precision for decimal types if there is no way to get the precision from the column match projected_schema { Some(schema) => { let precision = - get_decimal_column_precision(&column_name, schema).unwrap_or(38); + get_decimal_column_precision(&column_name, schema).unwrap_or(76); (Some(precision), 
Some(column.decimals() as i8)) } - None => (Some(38), Some(column.decimals() as i8)), + None => (Some(76), Some(column.decimals() as i8)), } } _ => (None, None), @@ -143,7 +140,7 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option) -> Resu let Some(builder) = arrow_columns_builders.get_mut(i) else { return NoBuilderForIndexSnafu { index: i }.fail(); }; - + match *mysql_type { ColumnType::MYSQL_TYPE_NULL => { let Some(builder) = builder else { @@ -235,34 +232,67 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option) -> Resu return NoBuilderForIndexSnafu { index: i }.fail(); }; - let Some(builder) = builder.as_any_mut().downcast_mut::() - else { - return FailedToDowncastBuilderSnafu { - mysql_type: format!("{mysql_type:?}"), - } - .fail(); + let arrow_field = match arrow_fields.get(i) { + Some(Some(field)) => field, + _ => return NoArrowFieldForIndexSnafu { index: i }.fail(), }; - let val = handle_null_error(row.get_opt::(i).transpose()) - .context(FailedToGetRowValueSnafu { - mysql_type: ColumnType::MYSQL_TYPE_DECIMAL, - })?; - - let scale = match &val { - Some(val) => val.fractional_digit_count(), - None => 0, - }; - - let Some(val) = val else { - builder.append_null(); - continue; - }; - - let Some(val) = to_decimal_128(&val, scale) else { - return FailedToConvertBigDecimalToI128Snafu { big_decimal: val }.fail(); - }; - - builder.append_value(val); + match arrow_field.data_type() { + DataType::Decimal128(_, _) => { + let Some(builder) = builder.as_any_mut().downcast_mut::() + else { + return FailedToDowncastBuilderSnafu { + mysql_type: format!("{mysql_type:?}"), + } + .fail(); + }; + let val = handle_null_error(row.get_opt::(i).transpose()) + .context(FailedToGetRowValueSnafu { + mysql_type: ColumnType::MYSQL_TYPE_DECIMAL, + })?; + + let scale = match &val { + Some(val) => val.fractional_digit_count(), + None => 0, + }; + + let Some(val) = val else { + builder.append_null(); + continue; + }; + + let Some(val) = to_decimal_128(&val, scale) else { + return FailedToConvertBigDecimalToI128Snafu { big_decimal: val }.fail(); + }; + + builder.append_value(val); + } + DataType::Decimal256(_, _) => { + let Some(builder) = builder.as_any_mut().downcast_mut::() + else { + return FailedToDowncastBuilderSnafu { + mysql_type: format!("{mysql_type:?}"), + } + .fail(); + }; + + let val = handle_null_error(row.get_opt::(i).transpose()) + .context(FailedToGetRowValueSnafu { + mysql_type: ColumnType::MYSQL_TYPE_DECIMAL, + })?; + + let Some(val) = val else { + builder.append_null(); + continue; + }; + + let val = to_decimal_256(&val); + + builder.append_value(val); + } + // ColumnType::MYSQL_TYPE_DECIMAL & ColumnType::MYSQL_TYPE_NEWDECIMAL are only mapped to Decimal128/Decimal256 in `map_column_to_data_type` function + _ => unreachable!() + } } column_type @ (ColumnType::MYSQL_TYPE_VARCHAR | ColumnType::MYSQL_TYPE_JSON @@ -407,7 +437,12 @@ pub fn map_column_to_data_type( ColumnType::MYSQL_TYPE_FLOAT => Some(DataType::Float32), ColumnType::MYSQL_TYPE_DOUBLE => Some(DataType::Float64), // Decimal precision must be a value between 0x00 - 0x51, so it's safe to unwrap_or_default here - ColumnType::MYSQL_TYPE_DECIMAL | ColumnType::MYSQL_TYPE_NEWDECIMAL => Some(DataType::Decimal128(column_decimal_precision.unwrap_or_default(), column_decimal_scale.unwrap_or_default())), + ColumnType::MYSQL_TYPE_DECIMAL | ColumnType::MYSQL_TYPE_NEWDECIMAL => { + if column_decimal_precision.unwrap_or_default() > 38 { + return Some(DataType::Decimal256(column_decimal_precision.unwrap_or_default(), 
column_decimal_scale.unwrap_or_default())); + } + Some(DataType::Decimal128(column_decimal_precision.unwrap_or_default(), column_decimal_scale.unwrap_or_default())) + }, ColumnType::MYSQL_TYPE_TIMESTAMP | ColumnType::MYSQL_TYPE_DATETIME => { Some(DataType::Timestamp(TimeUnit::Microsecond, None)) }, @@ -450,14 +485,32 @@ fn to_decimal_128(decimal: &BigDecimal, scale: i64) -> Option { (decimal * 10i128.pow(scale.try_into().unwrap_or_default())).to_i128() } +fn to_decimal_256(decimal: &BigDecimal) -> i256 { + let (bigint_value, _) = decimal.as_bigint_and_exponent(); + let mut bigint_bytes = bigint_value.to_signed_bytes_le(); + + let is_negative = bigint_value.sign() == num_bigint::Sign::Minus; + let fill_byte = if is_negative { 0xFF } else { 0x00 }; + + if bigint_bytes.len() > 32 { + bigint_bytes.truncate(32); + } else { + bigint_bytes.resize(32, fill_byte); + }; + + let mut array = [0u8; 32]; + array.copy_from_slice(&bigint_bytes); + + i256::from_le_bytes(array) +} + fn get_decimal_column_precision(column_name: &str, projected_schema: &SchemaRef) -> Option { let field = projected_schema.field_with_name(column_name).ok()?; match field.data_type() { - DataType::Decimal128(precision, _) => Some(*precision), + DataType::Decimal256(precision, _) | DataType::Decimal128(precision, _) => Some(*precision), _ => None, } } - fn handle_null_error( result: Result, FromValueError>, ) -> Result, FromValueError> { diff --git a/tests/mysql/mod.rs b/tests/mysql/mod.rs index bbea979..0e73272 100644 --- a/tests/mysql/mod.rs +++ b/tests/mysql/mod.rs @@ -7,7 +7,7 @@ use std::sync::Arc; use arrow::{ array::*, - datatypes::{DataType, Field, Schema, TimeUnit}, + datatypes::{i256, DataType, Field, Schema, TimeUnit}, }; use datafusion_table_providers::sql::db_connection_pool::dbconnection::AsyncDbConnection; @@ -16,40 +16,6 @@ use crate::docker::RunningContainer; mod common; -async fn test_mysql_decimal_types(port: usize) { - let create_table_stmt = " - CREATE TABLE IF NOT EXISTS decimal_table (decimal_col DECIMAL(10, 2)); - "; - let insert_table_stmt = " - INSERT INTO decimal_table (decimal_col) VALUES (NULL), (12); - "; - - let schema = Arc::new(Schema::new(vec![Field::new( - "decimal_col", - DataType::Decimal128(10, 2), - true, - )])); - - let expected_record = RecordBatch::try_new( - Arc::clone(&schema), - vec![Arc::new( - Decimal128Array::from(vec![None, Some(i128::from(1200))]) - .with_precision_and_scale(10, 2) - .unwrap(), - )], - ) - .expect("Failed to created arrow record batch"); - - let _ = arrow_mysql_one_way( - port, - "decimal_table", - create_table_stmt, - insert_table_stmt, - expected_record, - ) - .await; -} - async fn test_mysql_timestamp_types(port: usize) { let create_table_stmt = " CREATE TABLE timestamp_table ( @@ -234,7 +200,7 @@ VALUES ( .await; } -async fn test_time_types(port: usize) { +async fn test_mysql_time_types(port: usize) { let create_table_stmt = " CREATE TABLE time_table ( t0 TIME(0), @@ -348,7 +314,6 @@ VALUES ], ) .expect("Failed to created arrow record batch"); - arrow_mysql_one_way( port, "string_table", @@ -359,6 +324,95 @@ VALUES .await; } +async fn test_mysql_decimal_types_to_decimal256(port: usize) { + let create_table_stmt = " +CREATE TABLE high_precision_decimal ( + decimal_values DECIMAL(50, 10) +); + "; + let insert_table_stmt = " +INSERT INTO high_precision_decimal (decimal_values) VALUES +(NULL), +(1234567890123456789012345678901234567890.1234567890), +(-9876543210987654321098765432109876543210.9876543210), +(0.0000000001), +(-0.000000001), +(0); + "; + + let schema = 
Arc::new(Schema::new(vec![Field::new( + "decimal_values", + DataType::Decimal256(50, 10), + true, + )])); + + let expected_record = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new( + Decimal256Array::from(vec![ + None, + Some( + i256::from_string("12345678901234567890123456789012345678901234567890") + .unwrap(), + ), + Some( + i256::from_string("-98765432109876543210987654321098765432109876543210") + .unwrap(), + ), + Some(i256::from_string("1").unwrap()), + Some(i256::from_string("-10").unwrap()), + Some(i256::from_string("0").unwrap()), + ]) + .with_precision_and_scale(50, 10) + .expect("Failed to create decimal256 array"), + )], + ) + .expect("Failed to created arrow record batch"); + + arrow_mysql_one_way( + port, + "high_precision_decimal", + create_table_stmt, + insert_table_stmt, + expected_record, + ) + .await; +} + +async fn test_mysql_decimal_types_to_decimal128(port: usize) { + let create_table_stmt = " + CREATE TABLE IF NOT EXISTS decimal_table (decimal_col DECIMAL(10, 2)); + "; + let insert_table_stmt = " + INSERT INTO decimal_table (decimal_col) VALUES (NULL), (12); + "; + + let schema = Arc::new(Schema::new(vec![Field::new( + "decimal_col", + DataType::Decimal128(10, 2), + true, + )])); + + let expected_record = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new( + Decimal128Array::from(vec![None, Some(i128::from(1200))]) + .with_precision_and_scale(10, 2) + .unwrap(), + )], + ) + .expect("Failed to created arrow record batch"); + + let _ = arrow_mysql_one_way( + port, + "decimal_table", + create_table_stmt, + insert_table_stmt, + expected_record, + ) + .await; +} + async fn arrow_mysql_one_way( port: usize, table_name: &str, @@ -432,11 +486,12 @@ async fn test_mysql_arrow_oneway() { let port = crate::get_random_port(); let mysql_container = start_mysql_container(port).await; - test_mysql_decimal_types(port).await; test_mysql_timestamp_types(port).await; test_mysql_datetime_types(port).await; - test_time_types(port).await; + test_mysql_time_types(port).await; test_mysql_string_types(port).await; + test_mysql_decimal_types_to_decimal128(port).await; + test_mysql_decimal_types_to_decimal256(port).await; mysql_container.remove().await.expect("container to stop"); } From 1b5af50250880d5473f57f1c09417e6a5e8cb7de Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Sat, 28 Sep 2024 10:18:38 -0700 Subject: [PATCH 31/40] Fix mysql blob & text types (#117) --- src/sql/arrow_sql_gen/mysql.rs | 74 ++++++++++++++++--- .../dbconnection/mysqlconn.rs | 39 +++++++--- tests/mysql/mod.rs | 63 ++++++++++++++++ 3 files changed, 154 insertions(+), 22 deletions(-) diff --git a/src/sql/arrow_sql_gen/mysql.rs b/src/sql/arrow_sql_gen/mysql.rs index b2728a0..8a9cd44 100644 --- a/src/sql/arrow_sql_gen/mysql.rs +++ b/src/sql/arrow_sql_gen/mysql.rs @@ -2,9 +2,9 @@ use crate::sql::arrow_sql_gen::arrow::map_data_type_to_array_builder_optional; use arrow::{ array::{ ArrayBuilder, ArrayRef, BinaryBuilder, Date32Builder, Decimal256Builder, Decimal128Builder, Float32Builder, - Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, LargeStringBuilder, - NullBuilder, RecordBatch, RecordBatchOptions, StringBuilder, Time64NanosecondBuilder, - TimestampMicrosecondBuilder, UInt64Builder, + Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, LargeBinaryBuilder, + LargeStringBuilder, NullBuilder, RecordBatch, RecordBatchOptions, + StringBuilder, Time64NanosecondBuilder, TimestampMicrosecondBuilder, UInt64Builder, }, 
datatypes::{i256, DataType, Date32Type, Field, Schema, SchemaRef, TimeUnit}, }; @@ -92,6 +92,7 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option) -> Resu let mut mysql_types: Vec = Vec::new(); let mut column_names: Vec = Vec::new(); let mut column_is_binary_stats: Vec = Vec::new(); + let mut column_use_large_str_or_blob_stats: Vec = Vec::new(); if !rows.is_empty() { let row = &rows[0]; @@ -99,6 +100,7 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option) -> Resu let column_name = column.name_str(); let column_type = column.column_type(); let column_is_binary = column.flags().contains(ColumnFlags::BINARY_FLAG); + let column_use_large_str_or_blob = column.column_length() > 2_u32.pow(31) - 1; let (decimal_precision, decimal_scale) = match column_type { ColumnType::MYSQL_TYPE_DECIMAL | ColumnType::MYSQL_TYPE_NEWDECIMAL => { @@ -118,6 +120,7 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option) -> Resu let data_type = map_column_to_data_type( column_type, column_is_binary, + column_use_large_str_or_blob, decimal_precision, decimal_scale, ); @@ -132,6 +135,7 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option) -> Resu mysql_types.push(column_type); column_names.push(column_name.to_string()); column_is_binary_stats.push(column_is_binary); + column_use_large_str_or_blob_stats.push(column_use_large_str_or_blob); } } @@ -296,10 +300,6 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option) -> Resu } column_type @ (ColumnType::MYSQL_TYPE_VARCHAR | ColumnType::MYSQL_TYPE_JSON - | ColumnType::MYSQL_TYPE_TINY_BLOB - | ColumnType::MYSQL_TYPE_BLOB - | ColumnType::MYSQL_TYPE_MEDIUM_BLOB - | ColumnType::MYSQL_TYPE_LONG_BLOB | ColumnType::MYSQL_TYPE_ENUM) => { handle_primitive_type!( builder, @@ -310,6 +310,45 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option) -> Resu i ); } + ColumnType::MYSQL_TYPE_BLOB => { + match ( + column_use_large_str_or_blob_stats[i], + column_is_binary_stats[i], + ) { + (true, true) => handle_primitive_type!( + builder, + ColumnType::MYSQL_TYPE_BLOB, + LargeBinaryBuilder, + Vec, + row, + i + ), + (true, false) => handle_primitive_type!( + builder, + ColumnType::MYSQL_TYPE_BLOB, + LargeStringBuilder, + String, + row, + i + ), + (false, true) => handle_primitive_type!( + builder, + ColumnType::MYSQL_TYPE_BLOB, + BinaryBuilder, + Vec, + row, + i + ), + (false, false) => handle_primitive_type!( + builder, + ColumnType::MYSQL_TYPE_BLOB, + StringBuilder, + String, + row, + i + ), + } + } column_type @ (ColumnType::MYSQL_TYPE_STRING | ColumnType::MYSQL_TYPE_VAR_STRING) => { if column_is_binary_stats[i] { @@ -424,6 +463,7 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option) -> Resu pub fn map_column_to_data_type( column_type: ColumnType, column_is_binary: bool, + column_use_large_str_or_blob: bool, column_decimal_precision: Option, column_decimal_scale: Option, ) -> Option { @@ -453,11 +493,18 @@ pub fn map_column_to_data_type( ColumnType::MYSQL_TYPE_VARCHAR | ColumnType::MYSQL_TYPE_JSON | ColumnType::MYSQL_TYPE_ENUM - | ColumnType::MYSQL_TYPE_SET - | ColumnType::MYSQL_TYPE_TINY_BLOB - | ColumnType::MYSQL_TYPE_BLOB - | ColumnType::MYSQL_TYPE_MEDIUM_BLOB - | ColumnType::MYSQL_TYPE_LONG_BLOB => Some(DataType::LargeUtf8), + | ColumnType::MYSQL_TYPE_SET => Some(DataType::LargeUtf8), + // MYSQL_TYPE_BLOB includes TINYBLOB, BLOB, MEDIUMBLOB, LONGBLOB, TINYTEXT, TEXT, MEDIUMTEXT, LONGTEXT https://dev.mysql.com/doc/c-api/8.0/en/c-api-data-structures.html + // MySQL String Type Storage requirement: 
https://dev.mysql.com/doc/refman/8.4/en/storage-requirements.html + // Binary / Utf8 stores up to 2^31 - 1 length binary / non-binary string + ColumnType::MYSQL_TYPE_BLOB => { + match (column_use_large_str_or_blob, column_is_binary) { + (true, true) => Some(DataType::LargeBinary), + (true, false) => Some(DataType::LargeUtf8), + (false, true) => Some(DataType::Binary), + (false, false) => Some(DataType::Utf8), + } + } ColumnType::MYSQL_TYPE_STRING | ColumnType::MYSQL_TYPE_VAR_STRING => { if column_is_binary { @@ -475,6 +522,9 @@ pub fn map_column_to_data_type( | ColumnType::MYSQL_TYPE_TIMESTAMP2 | ColumnType::MYSQL_TYPE_DATETIME2 | ColumnType::MYSQL_TYPE_TIME2 + | ColumnType::MYSQL_TYPE_LONG_BLOB + | ColumnType::MYSQL_TYPE_TINY_BLOB + | ColumnType::MYSQL_TYPE_MEDIUM_BLOB | ColumnType::MYSQL_TYPE_GEOMETRY => { unimplemented!("Unsupported column type {:?}", column_type) } diff --git a/src/sql/db_connection_pool/dbconnection/mysqlconn.rs b/src/sql/db_connection_pool/dbconnection/mysqlconn.rs index 09832bd..98553f8 100644 --- a/src/sql/db_connection_pool/dbconnection/mysqlconn.rs +++ b/src/sql/db_connection_pool/dbconnection/mysqlconn.rs @@ -185,6 +185,7 @@ fn columns_meta_to_schema(columns_meta: Vec) -> Result { let column_type = map_str_type_to_column_type(&data_type)?; let column_is_binary = map_str_type_to_is_binary(&data_type); + let column_use_large_str_or_blob = map_str_type_to_use_large_str_or_blob(&data_type); let (precision, scale) = match column_type { ColumnType::MYSQL_TYPE_DECIMAL | ColumnType::MYSQL_TYPE_NEWDECIMAL => { @@ -195,9 +196,14 @@ fn columns_meta_to_schema(columns_meta: Vec) -> Result { _ => (None, None), }; - let arrow_data_type = - map_column_to_data_type(column_type, column_is_binary, precision, scale) - .context(UnsupportedDataTypeSnafu { data_type })?; + let arrow_data_type = map_column_to_data_type( + column_type, + column_is_binary, + column_use_large_str_or_blob, + precision, + scale, + ) + .context(UnsupportedDataTypeSnafu { data_type })?; fields.push(Field::new(&column_name, arrow_data_type, true)); } @@ -231,12 +237,12 @@ fn map_str_type_to_column_type(data_type: &str) -> Result { _ if data_type.starts_with("newdecimal") => ColumnType::MYSQL_TYPE_NEWDECIMAL, _ if data_type.starts_with("enum") => ColumnType::MYSQL_TYPE_ENUM, _ if data_type.starts_with("set") => ColumnType::MYSQL_TYPE_SET, - _ if data_type.starts_with("tinyblob") => ColumnType::MYSQL_TYPE_TINY_BLOB, - _ if data_type.starts_with("tinytext") => ColumnType::MYSQL_TYPE_TINY_BLOB, - _ if data_type.starts_with("mediumblob") => ColumnType::MYSQL_TYPE_MEDIUM_BLOB, - _ if data_type.starts_with("mediumtext") => ColumnType::MYSQL_TYPE_MEDIUM_BLOB, - _ if data_type.starts_with("longblob") => ColumnType::MYSQL_TYPE_LONG_BLOB, - _ if data_type.starts_with("longtext") => ColumnType::MYSQL_TYPE_LONG_BLOB, + _ if data_type.starts_with("tinyblob") => ColumnType::MYSQL_TYPE_BLOB, + _ if data_type.starts_with("tinytext") => ColumnType::MYSQL_TYPE_BLOB, + _ if data_type.starts_with("mediumblob") => ColumnType::MYSQL_TYPE_BLOB, + _ if data_type.starts_with("mediumtext") => ColumnType::MYSQL_TYPE_BLOB, + _ if data_type.starts_with("longblob") => ColumnType::MYSQL_TYPE_BLOB, + _ if data_type.starts_with("longtext") => ColumnType::MYSQL_TYPE_BLOB, _ if data_type.starts_with("blob") => ColumnType::MYSQL_TYPE_BLOB, _ if data_type.starts_with("text") => ColumnType::MYSQL_TYPE_BLOB, _ if data_type.starts_with("varchar") => ColumnType::MYSQL_TYPE_VAR_STRING, @@ -251,7 +257,20 @@ fn map_str_type_to_column_type(data_type: 
&str) -> Result { } fn map_str_type_to_is_binary(data_type: &str) -> bool { - if data_type.starts_with("binary") | data_type.starts_with("varbinary") { + if data_type.starts_with("binary") + | data_type.starts_with("varbinary") + | data_type.starts_with("tinyblob") + | data_type.starts_with("mediumblob") + | data_type.starts_with("blob") + | data_type.starts_with("longblob") + { + return true; + } + false +} + +fn map_str_type_to_use_large_str_or_blob(data_type: &str) -> bool { + if data_type.starts_with("long") { return true; } false diff --git a/tests/mysql/mod.rs b/tests/mysql/mod.rs index 0e73272..65210f6 100644 --- a/tests/mysql/mod.rs +++ b/tests/mysql/mod.rs @@ -272,6 +272,68 @@ VALUES .await; } +async fn test_mysql_blob_types(port: usize) { + let create_table_stmt = " +CREATE TABLE blobs_table ( + tinyblob_col TINYBLOB, + tinytext_col TINYTEXT, + mediumblob_col MEDIUMBLOB, + mediumtext_col MEDIUMTEXT, + blob_col BLOB, + text_col TEXT, + longblob_col LONGBLOB, + longtext_col LONGTEXT +); + "; + let insert_table_stmt = " +INSERT INTO blobs_table ( + tinyblob_col, tinytext_col, mediumblob_col, mediumtext_col, blob_col, text_col, longblob_col, longtext_col +) +VALUES + ( + 'small_blob', 'small_text', + 'medium_blob', 'medium_text', + 'larger_blob', 'larger_text', + 'very_large_blob', 'very_large_text' + ); + "; + + let schema = Arc::new(Schema::new(vec![ + Field::new("tinyblob_col", DataType::Binary, true), + Field::new("tinytext_col", DataType::Utf8, true), + Field::new("mediumblob_col", DataType::Binary, true), + Field::new("mediumtext_col", DataType::Utf8, true), + Field::new("blob_col", DataType::Binary, true), + Field::new("text_col", DataType::Utf8, true), + Field::new("longblob_col", DataType::LargeBinary, true), + Field::new("longtext_col", DataType::LargeUtf8, true), + ])); + + let expected_record = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(BinaryArray::from_vec(vec![b"small_blob"])), + Arc::new(StringArray::from(vec!["small_text"])), + Arc::new(BinaryArray::from_vec(vec![b"medium_blob"])), + Arc::new(StringArray::from(vec!["medium_text"])), + Arc::new(BinaryArray::from_vec(vec![b"larger_blob"])), + Arc::new(StringArray::from(vec!["larger_text"])), + Arc::new(LargeBinaryArray::from_vec(vec![b"very_large_blob"])), + Arc::new(LargeStringArray::from(vec!["very_large_text"])), + ], + ) + .expect("Failed to created arrow record batch"); + + arrow_mysql_one_way( + port, + "blobs_table", + create_table_stmt, + insert_table_stmt, + expected_record, + ) + .await; +} + async fn test_mysql_string_types(port: usize) { let create_table_stmt = " CREATE TABLE string_table ( @@ -489,6 +551,7 @@ async fn test_mysql_arrow_oneway() { test_mysql_timestamp_types(port).await; test_mysql_datetime_types(port).await; test_mysql_time_types(port).await; + test_mysql_blob_types(port).await; test_mysql_string_types(port).await; test_mysql_decimal_types_to_decimal128(port).await; test_mysql_decimal_types_to_decimal256(port).await; From 5e6f163a3e1e3c252f7e03d7ffc009bfb2c09e45 Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Sun, 29 Sep 2024 22:25:54 -0700 Subject: [PATCH 32/40] Add sqlite_busytimeout parameter as user configurable param (#121) * Add sqlite_busytimeout parameter as user configurable param * Remove debug log * Fix lint, fix integration test --------- Co-authored-by: Phillip LeBlanc --- src/sql/db_connection_pool/sqlitepool.rs | 63 ++++++++++++++++-------- src/sqlite.rs | 31 ++++++++++-- tests/sqlite/mod.rs | 12 +++-- 3 files 
changed, 79 insertions(+), 27 deletions(-) diff --git a/src/sql/db_connection_pool/sqlitepool.rs b/src/sql/db_connection_pool/sqlitepool.rs index 5b16cd0..6c22baf 100644 --- a/src/sql/db_connection_pool/sqlitepool.rs +++ b/src/sql/db_connection_pool/sqlitepool.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; use async_trait::async_trait; use snafu::{prelude::*, ResultExt}; @@ -26,14 +26,16 @@ pub struct SqliteConnectionPoolFactory { path: Arc, mode: Mode, attach_databases: Option>>, + busy_timeout: Duration, } impl SqliteConnectionPoolFactory { - pub fn new(path: &str, mode: Mode) -> Self { + pub fn new(path: &str, mode: Mode, busy_timeout: Duration) -> Self { SqliteConnectionPoolFactory { path: path.into(), mode, attach_databases: None, + busy_timeout, } } @@ -80,9 +82,14 @@ impl SqliteConnectionPoolFactory { vec![] }; - let pool = - SqliteConnectionPool::new(&self.path, self.mode, join_push_down, attach_databases) - .await?; + let pool = SqliteConnectionPool::new( + &self.path, + self.mode, + join_push_down, + attach_databases, + self.busy_timeout, + ) + .await?; pool.setup().await?; @@ -96,6 +103,7 @@ pub struct SqliteConnectionPool { mode: Mode, path: Arc, attach_databases: Vec>, + busy_timeout: Duration, } impl SqliteConnectionPool { @@ -113,6 +121,7 @@ impl SqliteConnectionPool { mode: Mode, join_push_down: JoinPushDown, attach_databases: Vec>, + busy_timeout: Duration, ) -> Result { let conn = match mode { Mode::Memory => Connection::open_in_memory() @@ -130,6 +139,7 @@ impl SqliteConnectionPool { mode, attach_databases, path: path.into(), + busy_timeout, }) } @@ -147,19 +157,23 @@ impl SqliteConnectionPool { pub async fn setup(&self) -> Result<()> { let conn = self.conn.clone(); + let busy_timeout = self.busy_timeout; // these configuration options are only applicable for file-mode databases if self.mode == Mode::File { // change transaction mode to Write-Ahead log instead of default atomic rollback journal: https://www.sqlite.org/wal.html // NOTE: This is a no-op if the database is in-memory, as only MEMORY or OFF are supported: https://www.sqlite.org/pragma.html#pragma_journal_mode - conn.call(|conn| { + conn.call(move |conn| { conn.pragma_update(None, "journal_mode", "WAL")?; - conn.pragma_update(None, "busy_timeout", "5000")?; conn.pragma_update(None, "synchronous", "NORMAL")?; conn.pragma_update(None, "cache_size", "-20000")?; conn.pragma_update(None, "foreign_keys", "true")?; conn.pragma_update(None, "temp_store", "memory")?; // conn.set_transaction_behavior(TransactionBehavior::Immediate); introduced in rustqlite 0.32.1, but tokio-rusqlite is still on 0.31.0 + + // Set user configurable connection timeout + conn.busy_timeout(busy_timeout)?; + Ok(()) }) .await @@ -212,6 +226,7 @@ impl SqliteConnectionPool { mode: self.mode, path: Arc::clone(&self.path), attach_databases: self.attach_databases.clone(), + busy_timeout: self.busy_timeout, }), Mode::File => { let attach_databases = if self.attach_databases.is_empty() { @@ -220,7 +235,7 @@ impl SqliteConnectionPool { Some(self.attach_databases.clone()) }; - SqliteConnectionPoolFactory::new(&self.path, self.mode) + SqliteConnectionPoolFactory::new(&self.path, self.mode, self.busy_timeout) .with_databases(attach_databases) .build() .await @@ -250,6 +265,7 @@ mod tests { use crate::sql::db_connection_pool::Mode; use rand::Rng; use rstest::rstest; + use std::time::Duration; fn random_db_name() -> String { let mut rng = rand::thread_rng(); @@ -266,7 +282,7 @@ mod tests { #[tokio::test] async fn 
test_sqlite_connection_pool_factory() { let db_name = random_db_name(); - let factory = SqliteConnectionPoolFactory::new(&db_name, Mode::File); + let factory = SqliteConnectionPoolFactory::new(&db_name, Mode::File, None); let pool = factory.build().await.unwrap(); assert!(pool.join_push_down == JoinPushDown::AllowedFor(db_name.clone())); @@ -285,10 +301,11 @@ mod tests { db_names.sort(); let factory = - SqliteConnectionPoolFactory::new(&db_names[0], Mode::File).with_databases(Some(vec![ - db_names[1].clone().into(), - db_names[2].clone().into(), - ])); + SqliteConnectionPoolFactory::new(&db_names[0], Mode::File, Duration::from_millis(5000)) + .with_databases(Some(vec![ + db_names[1].clone().into(), + db_names[2].clone().into(), + ])); SqliteConnectionPool::init(&db_names[1], Mode::File) .await @@ -317,7 +334,8 @@ mod tests { async fn test_sqlite_connection_pool_factory_with_empty_attachments() { let db_name = random_db_name(); let factory = - SqliteConnectionPoolFactory::new(&db_name, Mode::File).with_databases(Some(vec![])); + SqliteConnectionPoolFactory::new(&db_name, Mode::File, Duration::from_millis(5000)) + .with_databases(Some(vec![])); let pool = factory.build().await.unwrap(); @@ -333,8 +351,12 @@ mod tests { #[tokio::test] async fn test_sqlite_connection_pool_factory_memory_with_attachments() { - let factory = SqliteConnectionPoolFactory::new("./test.sqlite", Mode::Memory) - .with_databases(Some(vec!["./test1.sqlite".into(), "./test2.sqlite".into()])); + let factory = SqliteConnectionPoolFactory::new( + "./test.sqlite", + Mode::Memory, + Duration::from_millis(5000), + ) + .with_databases(Some(vec!["./test1.sqlite".into(), "./test2.sqlite".into()])); let pool = factory.build().await.unwrap(); assert!(pool.join_push_down == JoinPushDown::Disallow); @@ -355,10 +377,11 @@ mod tests { db_names.sort(); let factory = - SqliteConnectionPoolFactory::new(&db_names[0], Mode::File).with_databases(Some(vec![ - db_names[1].clone().into(), - db_names[2].clone().into(), - ])); + SqliteConnectionPoolFactory::new(&db_names[0], Mode::File, Duration::from_millis(5000)) + .with_databases(Some(vec![ + db_names[1].clone().into(), + db_names[2].clone().into(), + ])); let pool = factory.build().await; assert!(pool.is_err()); diff --git a/src/sqlite.rs b/src/sqlite.rs index 9f905c2..173937c 100644 --- a/src/sqlite.rs +++ b/src/sqlite.rs @@ -26,6 +26,7 @@ use rusqlite::{ToSql, Transaction}; use snafu::prelude::*; use sql_table::SQLiteTable; use std::collections::HashSet; +use std::time::Duration; use std::{collections::HashMap, sync::Arc}; use tokio::sync::Mutex; use tokio_rusqlite::Connection; @@ -93,6 +94,9 @@ pub enum Error { #[snafu(display("Unable to infer schema: {source}"))] UnableToInferSchema { source: dbconnection::Error }, + + #[snafu(display("Invalid SQLite busy_timeout value"))] + InvalidBusyTimeoutValue { value: String }, } type Result = std::result::Result; @@ -104,6 +108,7 @@ pub struct SqliteTableProviderFactory { const SQLITE_DB_PATH_PARAM: &str = "file"; const SQLITE_DB_BASE_FOLDER_PARAM: &str = "data_directory"; const SQLITE_ATTACH_DATABASES_PARAM: &str = "attach_databases"; +const SQLITE_BUSY_TIMEOUT_PARAM: &str = "sqlite_busy_timeout"; impl SqliteTableProviderFactory { #[must_use] @@ -139,10 +144,27 @@ impl SqliteTableProviderFactory { .unwrap_or(default_filepath) } + pub fn sqlite_busy_timeout(&self, options: &HashMap) -> Result { + let busy_timeout = options.get(SQLITE_BUSY_TIMEOUT_PARAM).cloned(); + match busy_timeout { + Some(busy_timeout) => { + let result: u64 = 
busy_timeout.parse().map_err(|_| { + InvalidBusyTimeoutValueSnafu { + value: busy_timeout, + } + .build() + })?; + Ok(Duration::from_millis(result)) + } + None => Ok(Duration::from_millis(5000)), + } + } + pub async fn get_or_init_instance( &self, db_path: impl Into>, mode: Mode, + busy_timeout: Duration, ) -> Result { let db_path = db_path.into(); let key = match mode { @@ -155,7 +177,7 @@ impl SqliteTableProviderFactory { return instance.try_clone().await.context(DbConnectionPoolSnafu); } - let pool = SqliteConnectionPoolFactory::new(&db_path, mode) + let pool = SqliteConnectionPoolFactory::new(&db_path, mode, busy_timeout) .build() .await .context(DbConnectionPoolSnafu)?; @@ -219,10 +241,13 @@ impl TableProviderFactory for SqliteTableProviderFactory { ); } + let busy_timeout = self + .sqlite_busy_timeout(&cmd.options) + .map_err(to_datafusion_error)?; let db_path: Arc = self.sqlite_file_path(&name, &cmd.options).into(); let pool: Arc = Arc::new( - self.get_or_init_instance(Arc::clone(&db_path), mode) + self.get_or_init_instance(Arc::clone(&db_path), mode, busy_timeout) .await .map_err(to_datafusion_error)?, ); @@ -234,7 +259,7 @@ impl TableProviderFactory for SqliteTableProviderFactory { // even though we setup SQLite to use WAL mode, the pool isn't really a pool so shares the same connection // and we can't have concurrent writes when sharing the same connection Arc::new( - self.get_or_init_instance(Arc::clone(&db_path), mode) + self.get_or_init_instance(Arc::clone(&db_path), mode, busy_timeout) .await .map_err(to_datafusion_error)?, ) diff --git a/tests/sqlite/mod.rs b/tests/sqlite/mod.rs index 74adbaf..0862e57 100644 --- a/tests/sqlite/mod.rs +++ b/tests/sqlite/mod.rs @@ -21,10 +21,14 @@ async fn arrow_sqlite_round_trip( tracing::debug!("Running tests on {table_name}"); let ctx = SessionContext::new(); - let pool = SqliteConnectionPoolFactory::new(":memory:", Mode::Memory) - .build() - .await - .expect("Sqlite connection pool to be created"); + let pool = SqliteConnectionPoolFactory::new( + ":memory:", + Mode::Memory, + std::time::Duration::from_millis(5000), + ) + .build() + .await + .expect("Sqlite connection pool to be created"); let conn = pool .connect() From d3a37bd1b7758117c376dcf402e1f923e7ff5459 Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Mon, 30 Sep 2024 11:52:19 -0700 Subject: [PATCH 33/40] Use arrow dictionary type for mysql enum type (#119) --- src/sql/arrow_sql_gen/arrow.rs | 6 +- src/sql/arrow_sql_gen/mysql.rs | 63 ++++++++++++++++--- .../dbconnection/mysqlconn.rs | 14 ++++- tests/mysql/mod.rs | 49 ++++++++++++++- 4 files changed, 118 insertions(+), 14 deletions(-) diff --git a/src/sql/arrow_sql_gen/arrow.rs b/src/sql/arrow_sql_gen/arrow.rs index eb2fbd0..f60da64 100644 --- a/src/sql/arrow_sql_gen/arrow.rs +++ b/src/sql/arrow_sql_gen/arrow.rs @@ -9,7 +9,7 @@ use arrow::{ TimestampNanosecondBuilder, TimestampSecondBuilder, UInt16Builder, UInt32Builder, UInt64Builder, UInt8Builder, }, - datatypes::{DataType, TimeUnit}, + datatypes::{DataType, TimeUnit, UInt16Type}, }; pub fn map_data_type_to_array_builder_optional( @@ -21,6 +21,7 @@ pub fn map_data_type_to_array_builder_optional( } } +#[allow(clippy::too_many_lines)] pub fn map_data_type_to_array_builder(data_type: &DataType) -> Box { match data_type { DataType::Int8 => Box::new(Int8Builder::new()), @@ -67,6 +68,9 @@ pub fn map_data_type_to_array_builder(data_type: &DataType) -> Box { Box::new(StringDictionaryBuilder::::new()) } + (DataType::UInt16, DataType::Utf8) => { + 
Box::new(StringDictionaryBuilder::::new()) + } _ => unimplemented!("Unimplemented dictionary type"), }, DataType::Date32 => Box::new(Date32Builder::new()), diff --git a/src/sql/arrow_sql_gen/mysql.rs b/src/sql/arrow_sql_gen/mysql.rs index 8a9cd44..bd4076c 100644 --- a/src/sql/arrow_sql_gen/mysql.rs +++ b/src/sql/arrow_sql_gen/mysql.rs @@ -4,9 +4,9 @@ use arrow::{ ArrayBuilder, ArrayRef, BinaryBuilder, Date32Builder, Decimal256Builder, Decimal128Builder, Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, LargeBinaryBuilder, LargeStringBuilder, NullBuilder, RecordBatch, RecordBatchOptions, - StringBuilder, Time64NanosecondBuilder, TimestampMicrosecondBuilder, UInt64Builder, + StringBuilder, Time64NanosecondBuilder, StringDictionaryBuilder, TimestampMicrosecondBuilder, UInt64Builder, }, - datatypes::{i256, DataType, Date32Type, Field, Schema, SchemaRef, TimeUnit}, + datatypes::{i256, DataType, Date32Type, Field, Schema, SchemaRef, TimeUnit, UInt16Type}, }; use bigdecimal::BigDecimal; use chrono::{NaiveDate, NaiveTime, Timelike}; @@ -92,6 +92,7 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option) -> Resu let mut mysql_types: Vec = Vec::new(); let mut column_names: Vec = Vec::new(); let mut column_is_binary_stats: Vec = Vec::new(); + let mut column_is_enum_stats: Vec = Vec::new(); let mut column_use_large_str_or_blob_stats: Vec = Vec::new(); if !rows.is_empty() { @@ -100,6 +101,7 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option) -> Resu let column_name = column.name_str(); let column_type = column.column_type(); let column_is_binary = column.flags().contains(ColumnFlags::BINARY_FLAG); + let column_is_enum = column.flags().contains(ColumnFlags::ENUM_FLAG); let column_use_large_str_or_blob = column.column_length() > 2_u32.pow(31) - 1; let (decimal_precision, decimal_scale) = match column_type { @@ -120,6 +122,7 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option) -> Resu let data_type = map_column_to_data_type( column_type, column_is_binary, + column_is_enum, column_use_large_str_or_blob, decimal_precision, decimal_scale, @@ -135,6 +138,7 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option) -> Resu mysql_types.push(column_type); column_names.push(column_name.to_string()); column_is_binary_stats.push(column_is_binary); + column_is_enum_stats.push(column_is_enum); column_use_large_str_or_blob_stats.push(column_use_large_str_or_blob); } } @@ -299,8 +303,7 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option) -> Resu } } column_type @ (ColumnType::MYSQL_TYPE_VARCHAR - | ColumnType::MYSQL_TYPE_JSON - | ColumnType::MYSQL_TYPE_ENUM) => { + | ColumnType::MYSQL_TYPE_JSON) => { handle_primitive_type!( builder, column_type, @@ -349,9 +352,40 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option) -> Resu ), } } + ColumnType::MYSQL_TYPE_ENUM => { + // ENUM and SET values are returned as strings. For these, check that the type value is MYSQL_TYPE_STRING and that the ENUM_FLAG or SET_FLAG flag is set in the flags value. 
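As a consumer-side illustration of the dictionary mapping introduced here (variable names are hypothetical; `batch` is assumed to be a `RecordBatch` read through this crate whose first column is a MySQL ENUM):

    use arrow::array::{DictionaryArray, StringArray};
    use arrow::datatypes::UInt16Type;

    // ENUM columns now arrive as Dictionary(UInt16, Utf8): u16 keys index into a
    // dictionary of the string values present in the result set.
    let dict = batch
        .column(0)
        .as_any()
        .downcast_ref::<DictionaryArray<UInt16Type>>()
        .expect("ENUM column decodes to Dictionary(UInt16, Utf8)");
    let values = dict
        .values()
        .as_any()
        .downcast_ref::<StringArray>()
        .expect("dictionary values are Utf8");
    for key in dict.keys() {
        let status = key.map(|k| values.value(k as usize));
        println!("{status:?}"); // e.g. Some("active"), or None for SQL NULL
    }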
+ // https://dev.mysql.com/doc/c-api/9.0/en/c-api-data-structures.html + unreachable!() + } column_type @ (ColumnType::MYSQL_TYPE_STRING | ColumnType::MYSQL_TYPE_VAR_STRING) => { - if column_is_binary_stats[i] { + // Handle MYSQL_TYPE_ENUM value + if column_is_enum_stats[i] { + let Some(builder) = builder else { + return NoBuilderForIndexSnafu { index: i }.fail(); + }; + let Some(builder) = builder + .as_any_mut() + .downcast_mut::>() + else { + return FailedToDowncastBuilderSnafu { + mysql_type: format!("{mysql_type:?}"), + } + .fail(); + }; + + let v = handle_null_error(row.get_opt::(i).transpose()) + .context(FailedToGetRowValueSnafu { + mysql_type: ColumnType::MYSQL_TYPE_ENUM, + })?; + + match v { + Some(v) => { + builder.append_value(v); + } + None => builder.append_null(), + } + } else if column_is_binary_stats[i] { handle_primitive_type!( builder, column_type, @@ -361,7 +395,14 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option) -> Resu i ); } else { - handle_primitive_type!(builder, column_type, StringBuilder, String, row, i); + handle_primitive_type!( + builder, + column_type, + StringBuilder, + String, + row, + i + ); } } ColumnType::MYSQL_TYPE_DATE => { @@ -463,6 +504,7 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option) -> Resu pub fn map_column_to_data_type( column_type: ColumnType, column_is_binary: bool, + column_is_enum: bool, column_use_large_str_or_blob: bool, column_decimal_precision: Option, column_decimal_scale: Option, @@ -491,9 +533,7 @@ pub fn map_column_to_data_type( Some(DataType::Time64(TimeUnit::Nanosecond)) } ColumnType::MYSQL_TYPE_VARCHAR - | ColumnType::MYSQL_TYPE_JSON - | ColumnType::MYSQL_TYPE_ENUM - | ColumnType::MYSQL_TYPE_SET => Some(DataType::LargeUtf8), + | ColumnType::MYSQL_TYPE_JSON => Some(DataType::LargeUtf8), // MYSQL_TYPE_BLOB includes TINYBLOB, BLOB, MEDIUMBLOB, LONGBLOB, TINYTEXT, TEXT, MEDIUMTEXT, LONGTEXT https://dev.mysql.com/doc/c-api/8.0/en/c-api-data-structures.html // MySQL String Type Storage requirement: https://dev.mysql.com/doc/refman/8.4/en/storage-requirements.html // Binary / Utf8 stores up to 2^31 - 1 length binary / non-binary string @@ -505,9 +545,12 @@ pub fn map_column_to_data_type( (false, false) => Some(DataType::Utf8), } } + ColumnType::MYSQL_TYPE_ENUM | ColumnType::MYSQL_TYPE_SET => unreachable!(), ColumnType::MYSQL_TYPE_STRING | ColumnType::MYSQL_TYPE_VAR_STRING => { - if column_is_binary { + if column_is_enum { + Some(DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8))) + } else if column_is_binary { Some(DataType::Binary) } else { Some(DataType::Utf8) diff --git a/src/sql/db_connection_pool/dbconnection/mysqlconn.rs b/src/sql/db_connection_pool/dbconnection/mysqlconn.rs index 98553f8..81eba80 100644 --- a/src/sql/db_connection_pool/dbconnection/mysqlconn.rs +++ b/src/sql/db_connection_pool/dbconnection/mysqlconn.rs @@ -185,6 +185,7 @@ fn columns_meta_to_schema(columns_meta: Vec) -> Result { let column_type = map_str_type_to_column_type(&data_type)?; let column_is_binary = map_str_type_to_is_binary(&data_type); + let column_is_enum = map_str_type_to_is_enum(&data_type); let column_use_large_str_or_blob = map_str_type_to_use_large_str_or_blob(&data_type); let (precision, scale) = match column_type { @@ -199,6 +200,7 @@ fn columns_meta_to_schema(columns_meta: Vec) -> Result { let arrow_data_type = map_column_to_data_type( column_type, column_is_binary, + column_is_enum, column_use_large_str_or_blob, precision, scale, @@ -235,8 +237,9 @@ fn 
map_str_type_to_column_type(data_type: &str) -> Result { _ if data_type.starts_with("array") => ColumnType::MYSQL_TYPE_TYPED_ARRAY, _ if data_type.starts_with("json") => ColumnType::MYSQL_TYPE_JSON, _ if data_type.starts_with("newdecimal") => ColumnType::MYSQL_TYPE_NEWDECIMAL, - _ if data_type.starts_with("enum") => ColumnType::MYSQL_TYPE_ENUM, - _ if data_type.starts_with("set") => ColumnType::MYSQL_TYPE_SET, + // MySQL ENUM & SET value is exported as MYSQL_TYPE_STRING under c api: https://dev.mysql.com/doc/c-api/9.0/en/c-api-data-structures.html + _ if data_type.starts_with("enum") => ColumnType::MYSQL_TYPE_STRING, + _ if data_type.starts_with("set") => ColumnType::MYSQL_TYPE_STRING, _ if data_type.starts_with("tinyblob") => ColumnType::MYSQL_TYPE_BLOB, _ if data_type.starts_with("tinytext") => ColumnType::MYSQL_TYPE_BLOB, _ if data_type.starts_with("mediumblob") => ColumnType::MYSQL_TYPE_BLOB, @@ -276,6 +279,13 @@ fn map_str_type_to_use_large_str_or_blob(data_type: &str) -> bool { false } +fn map_str_type_to_is_enum(data_type: &str) -> bool { + if data_type.starts_with("enum") { + return true; + } + false +} + fn extract_decimal_precision_and_scale(data_type: &str) -> Result<(u8, i8)> { let (start, end) = match (data_type.find('('), data_type.find(')')) { (Some(start), Some(end)) => (start, end), diff --git a/tests/mysql/mod.rs b/tests/mysql/mod.rs index 65210f6..962370c 100644 --- a/tests/mysql/mod.rs +++ b/tests/mysql/mod.rs @@ -7,7 +7,7 @@ use std::sync::Arc; use arrow::{ array::*, - datatypes::{i256, DataType, Field, Schema, TimeUnit}, + datatypes::{i256, DataType, Field, Schema, TimeUnit, UInt16Type}, }; use datafusion_table_providers::sql::db_connection_pool::dbconnection::AsyncDbConnection; @@ -272,6 +272,52 @@ VALUES .await; } +async fn test_mysql_enum_types(port: usize) { + let create_table_stmt = " +CREATE TABLE enum_table ( + status ENUM('active', 'inactive', 'pending', 'suspended') +); + "; + let insert_table_stmt = " +INSERT INTO enum_table (status) +VALUES +(NULL), +('active'), +('inactive'), +('pending'), +('suspended'), +('inactive'); + "; + + let mut builder = StringDictionaryBuilder::::new(); + builder.append_null(); + builder.append_value("active"); + builder.append_value("inactive"); + builder.append_value("pending"); + builder.append_value("suspended"); + builder.append_value("inactive"); + + let array: DictionaryArray = builder.finish(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "status", + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), + true, + )])); + + let expected_record = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]) + .expect("Failed to created arrow dictionary array record batch"); + + arrow_mysql_one_way( + port, + "enum_table", + create_table_stmt, + insert_table_stmt, + expected_record, + ) + .await; +} + async fn test_mysql_blob_types(port: usize) { let create_table_stmt = " CREATE TABLE blobs_table ( @@ -551,6 +597,7 @@ async fn test_mysql_arrow_oneway() { test_mysql_timestamp_types(port).await; test_mysql_datetime_types(port).await; test_mysql_time_types(port).await; + test_mysql_enum_types(port).await; test_mysql_blob_types(port).await; test_mysql_string_types(port).await; test_mysql_decimal_types_to_decimal128(port).await; From 7595f8c4f5a80b569d15b136682d64779cbf0e60 Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Mon, 30 Sep 2024 12:19:52 -0700 Subject: [PATCH 34/40] Remove prefix for sqlite busy timeout param (#123) --- src/sqlite.rs | 2 +- 
1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sqlite.rs b/src/sqlite.rs index 173937c..520bcca 100644 --- a/src/sqlite.rs +++ b/src/sqlite.rs @@ -108,7 +108,7 @@ pub struct SqliteTableProviderFactory { const SQLITE_DB_PATH_PARAM: &str = "file"; const SQLITE_DB_BASE_FOLDER_PARAM: &str = "data_directory"; const SQLITE_ATTACH_DATABASES_PARAM: &str = "attach_databases"; -const SQLITE_BUSY_TIMEOUT_PARAM: &str = "sqlite_busy_timeout"; +const SQLITE_BUSY_TIMEOUT_PARAM: &str = "busy_timeout"; impl SqliteTableProviderFactory { #[must_use] From e6717ecc3a924f75a452310ae40100faa670cb68 Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Mon, 30 Sep 2024 16:10:05 -0700 Subject: [PATCH 35/40] Support parsing sqlite busy_timeout durations with units (#124) --- Cargo.toml | 1 + src/sqlite.rs | 15 ++++++++------- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f01735f..95dc111 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -74,6 +74,7 @@ tonic = { version = "0.12.2", optional = true } itertools = "0.13.0" dyn-clone = { version = "1.0.17", optional = true } geo-types = "0.7.13" +fundu = "2.0.1" [dev-dependencies] anyhow = "1.0.86" diff --git a/src/sqlite.rs b/src/sqlite.rs index 520bcca..c3b74ae 100644 --- a/src/sqlite.rs +++ b/src/sqlite.rs @@ -97,6 +97,11 @@ pub enum Error { #[snafu(display("Invalid SQLite busy_timeout value"))] InvalidBusyTimeoutValue { value: String }, + + #[snafu(display( + "Unable to parse SQLite busy_timeout parameter, ensure it is a valid duration" + ))] + UnableToParseBusyTimeoutParameter { source: fundu::ParseError }, } type Result = std::result::Result; @@ -148,13 +153,9 @@ impl SqliteTableProviderFactory { let busy_timeout = options.get(SQLITE_BUSY_TIMEOUT_PARAM).cloned(); match busy_timeout { Some(busy_timeout) => { - let result: u64 = busy_timeout.parse().map_err(|_| { - InvalidBusyTimeoutValueSnafu { - value: busy_timeout, - } - .build() - })?; - Ok(Duration::from_millis(result)) + let duration = fundu::parse_duration(&busy_timeout) + .context(UnableToParseBusyTimeoutParameterSnafu)?; + Ok(duration) } None => Ok(Duration::from_millis(5000)), } From a6b144679953fd665d71c04c3bdf7187e439a2c6 Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Wed, 2 Oct 2024 12:12:15 -0700 Subject: [PATCH 36/40] Support retries when writing data to SQLite (#125) --- src/sqlite/write.rs | 19 ++++++++++++------- src/util/retriable_error.rs | 16 ++++++++++++++++ 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/src/sqlite/write.rs b/src/sqlite/write.rs index 27c911a..fa267f0 100644 --- a/src/sqlite/write.rs +++ b/src/sqlite/write.rs @@ -19,7 +19,9 @@ use futures::StreamExt; use snafu::prelude::*; use crate::util::{ - constraints, on_conflict::OnConflict, retriable_error::check_and_mark_retriable_error, + constraints, + on_conflict::OnConflict, + retriable_error::{check_and_mark_retriable_error, to_retriable_data_write_error}, }; use super::{to_datafusion_error, Sqlite}; @@ -126,8 +128,13 @@ impl DataSink for SqliteDataSink { let (notify_commit_transaction, mut on_commit_transaction) = tokio::sync::oneshot::channel(); - let mut db_conn = self.sqlite.connect().await.map_err(to_datafusion_error)?; - let sqlite_conn = Sqlite::sqlite_conn(&mut db_conn).map_err(to_datafusion_error)?; + let mut db_conn = self + .sqlite + .connect() + .await + .map_err(to_retriable_data_write_error)?; + let sqlite_conn = + Sqlite::sqlite_conn(&mut db_conn).map_err(to_retriable_data_write_error)?; let constraints 

From a6b144679953fd665d71c04c3bdf7187e439a2c6 Mon Sep 17 00:00:00 2001
From: Sergei Grebnov
Date: Wed, 2 Oct 2024 12:12:15 -0700
Subject: [PATCH 36/40] Support retries when writing data to SQLite (#125)

---
 src/sqlite/write.rs         | 19 ++++++++++++-------
 src/util/retriable_error.rs | 16 ++++++++++++++++
 2 files changed, 28 insertions(+), 7 deletions(-)

diff --git a/src/sqlite/write.rs b/src/sqlite/write.rs
index 27c911a..fa267f0 100644
--- a/src/sqlite/write.rs
+++ b/src/sqlite/write.rs
@@ -19,7 +19,9 @@ use futures::StreamExt;
 use snafu::prelude::*;
 
 use crate::util::{
-    constraints, on_conflict::OnConflict, retriable_error::check_and_mark_retriable_error,
+    constraints,
+    on_conflict::OnConflict,
+    retriable_error::{check_and_mark_retriable_error, to_retriable_data_write_error},
 };
 
 use super::{to_datafusion_error, Sqlite};
@@ -126,8 +128,13 @@ impl DataSink for SqliteDataSink {
         let (notify_commit_transaction, mut on_commit_transaction) =
             tokio::sync::oneshot::channel();
 
-        let mut db_conn = self.sqlite.connect().await.map_err(to_datafusion_error)?;
-        let sqlite_conn = Sqlite::sqlite_conn(&mut db_conn).map_err(to_datafusion_error)?;
+        let mut db_conn = self
+            .sqlite
+            .connect()
+            .await
+            .map_err(to_retriable_data_write_error)?;
+        let sqlite_conn =
+            Sqlite::sqlite_conn(&mut db_conn).map_err(to_retriable_data_write_error)?;
 
         let constraints = self.sqlite.constraints().clone();
         let mut data = data;
@@ -191,11 +198,9 @@ impl DataSink for SqliteDataSink {
             })
             .await
             .context(super::UnableToInsertIntoTableAsyncSnafu)
-            .map_err(to_datafusion_error)?;
+            .map_err(to_retriable_data_write_error)?;
 
-        let num_rows = task.await.map_err(|err| {
-            DataFusionError::Execution(format!("Error sending data batch: {err}"))
-        })??;
+        let num_rows = task.await.map_err(to_retriable_data_write_error)??;
 
         Ok(num_rows)
     }
diff --git a/src/util/retriable_error.rs b/src/util/retriable_error.rs
index 7156233..a9a0a4c 100644
--- a/src/util/retriable_error.rs
+++ b/src/util/retriable_error.rs
@@ -1,3 +1,5 @@
+use std::error::Error;
+
 use datafusion::error::DataFusionError;
 use snafu::Snafu;
 
@@ -7,6 +9,11 @@ pub enum RetriableError {
     DataRetrievalError {
         source: datafusion::error::DataFusionError,
     },
+
+    #[snafu(display("{source}"))]
+    DataWriteError {
+        source: Box<dyn Error + Send + Sync>
+    },
 }
 
 #[must_use]
@@ -36,6 +43,15 @@ pub fn check_and_mark_retriable_error(err: DataFusionError) -> DataFusionError {
     DataFusionError::External(Box::new(RetriableError::DataRetrievalError { source: err }))
 }
 
+// Wraps error as `RetriableError::DataWriteError` so we can detect this error and retry later at a higher level
+#[must_use]
+pub fn to_retriable_data_write_error<E>(error: E) -> DataFusionError
+where
+    E: Error + Send + Sync + 'static,
+{
+    DataFusionError::External(Box::new(RetriableError::DataWriteError { source: error.into() }))
+}
+
 fn is_invalid_query_error(error: &DataFusionError) -> bool {
     match error {
         DataFusionError::Context(_, err) => is_invalid_query_error(err.as_ref()),

From 409b6af0a96372c820d1b6785a99d2fccbee4a15 Mon Sep 17 00:00:00 2001
From: Qianqian <130200611+Sevenannn@users.noreply.github.com>
Date: Wed, 2 Oct 2024 14:49:00 -0700
Subject: [PATCH 37/40] Implement write retry for DuckDB (#128)

---
 src/duckdb/write.rs | 21 ++++++++++++---------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/src/duckdb/write.rs b/src/duckdb/write.rs
index 1d076e4..6d96acb 100644
--- a/src/duckdb/write.rs
+++ b/src/duckdb/write.rs
@@ -2,7 +2,9 @@ use std::{any::Any, fmt, sync::Arc};
 
 use crate::duckdb::DuckDB;
 use crate::util::{
-    constraints, on_conflict::OnConflict, retriable_error::check_and_mark_retriable_error,
+    constraints,
+    on_conflict::OnConflict,
+    retriable_error::{check_and_mark_retriable_error, to_retriable_data_write_error},
 };
 use arrow::{array::RecordBatch, datatypes::SchemaRef};
 use async_trait::async_trait;
@@ -138,9 +140,12 @@ impl DataSink for DuckDBDataSink {
 
         let duckdb_write_handle: JoinHandle<datafusion::common::Result<u64>> =
             tokio::task::spawn_blocking(move || {
-                let mut db_conn = duckdb.connect_sync().map_err(to_datafusion_error)?;
+                let mut db_conn = duckdb
+                    .connect_sync()
+                    .map_err(to_retriable_data_write_error)?;
 
-                let duckdb_conn = DuckDB::duckdb_conn(&mut db_conn).map_err(to_datafusion_error)?;
+                let duckdb_conn =
+                    DuckDB::duckdb_conn(&mut db_conn).map_err(to_retriable_data_write_error)?;
 
                 let tx = duckdb_conn
                     .conn
@@ -159,15 +164,13 @@ impl DataSink for DuckDBDataSink {
                     )?,
                 };
 
-                if on_commit_transaction.try_recv().is_err() {
-                    return Err(DataFusionError::Execution(
-                        "No message to commit transaction has been received.".to_string(),
-                    ));
-                }
+                on_commit_transaction
+                    .try_recv()
+                    .map_err(to_retriable_data_write_error)?;
 
                 tx.commit()
                     .context(super::UnableToCommitTransactionSnafu)
-                    .map_err(to_datafusion_error)?;
+                    .map_err(to_retriable_data_write_error)?;
 
                 Ok(num_rows)
             });
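
A note on PATCH 36 and PATCH 37: both sinks now route connect, insert, and commit failures through `to_retriable_data_write_error`, which wraps the source as `RetriableError::DataWriteError` behind `DataFusionError::External`, so a retry layer above the sink can tell transient write failures apart from ordinary query errors. A rough sketch of how such a caller might detect the marker; the retry policy itself lives outside these patches, so the helper below is illustrative only:

    // Sketch only; RetriableError is the marker enum from src/util/retriable_error.rs.
    use datafusion::error::DataFusionError;

    use crate::util::retriable_error::RetriableError;

    fn is_retriable_write_error(err: &DataFusionError) -> bool {
        match err {
            // to_retriable_data_write_error produces
            // DataFusionError::External(RetriableError::DataWriteError { .. })
            DataFusionError::External(e) => matches!(
                e.downcast_ref::<RetriableError>(),
                Some(RetriableError::DataWriteError { .. })
            ),
            // DataFusion may wrap the original error in additional context.
            DataFusionError::Context(_, inner) => is_retriable_write_error(inner.as_ref()),
            _ => false,
        }
    }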

From ee073c74afa737c48dd68d464af347e5340a7083 Mon Sep 17 00:00:00 2001
From: Sergei Grebnov
Date: Wed, 9 Oct 2024 10:55:44 -0700
Subject: [PATCH 38/40] Preserve records batch order (update datafusion-federation) (#130)

---
 Cargo.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/Cargo.toml b/Cargo.toml
index 95dc111..cbf5a71 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -115,3 +115,4 @@ postgres-federation = ["postgres"]
 
 [patch.crates-io]
 duckdb = { git = "https://github.com/spiceai/duckdb-rs.git", rev = "5b98603705a381ceeb5cc371e4f606b7332b57ce" }
+

From 441510c04ad5b50aa3200db1bb4042063adc35fb Mon Sep 17 00:00:00 2001
From: hozan23
Date: Fri, 11 Oct 2024 12:30:01 +0200
Subject: [PATCH 39/40] cargo fmt

---
 src/sql/arrow_sql_gen/mysql.rs | 80 +++++++++++++++++-----------------
 src/util/retriable_error.rs    |  6 ++-
 2 files changed, 43 insertions(+), 43 deletions(-)

diff --git a/src/sql/arrow_sql_gen/mysql.rs b/src/sql/arrow_sql_gen/mysql.rs
index bd4076c..b346015 100644
--- a/src/sql/arrow_sql_gen/mysql.rs
+++ b/src/sql/arrow_sql_gen/mysql.rs
@@ -1,20 +1,21 @@
 use crate::sql::arrow_sql_gen::arrow::map_data_type_to_array_builder_optional;
 use arrow::{
     array::{
-        ArrayBuilder, ArrayRef, BinaryBuilder, Date32Builder, Decimal256Builder, Decimal128Builder, Float32Builder,
-        Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, LargeBinaryBuilder,
-        LargeStringBuilder, NullBuilder, RecordBatch, RecordBatchOptions,
-        StringBuilder, Time64NanosecondBuilder, StringDictionaryBuilder, TimestampMicrosecondBuilder, UInt64Builder,
+        ArrayBuilder, ArrayRef, BinaryBuilder, Date32Builder, Decimal128Builder, Decimal256Builder,
+        Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder,
+        LargeBinaryBuilder, LargeStringBuilder, NullBuilder, RecordBatch, RecordBatchOptions,
+        StringBuilder, StringDictionaryBuilder, Time64NanosecondBuilder,
+        TimestampMicrosecondBuilder, UInt64Builder,
     },
     datatypes::{i256, DataType, Date32Type, Field, Schema, SchemaRef, TimeUnit, UInt16Type},
 };
 use bigdecimal::BigDecimal;
+use bigdecimal::ToPrimitive;
 use chrono::{NaiveDate, NaiveTime, Timelike};
 use mysql_async::{consts::ColumnFlags, consts::ColumnType, FromValueError, Row, Value};
 use snafu::{ResultExt, Snafu};
 use std::{convert, sync::Arc};
 use time::PrimitiveDateTime;
-use bigdecimal::ToPrimitive;
 
 #[derive(Debug, Snafu)]
 pub enum Error {
@@ -148,7 +149,7 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option<SchemaRef>) -> Resu
             let Some(builder) = arrow_columns_builders.get_mut(i) else {
                 return NoBuilderForIndexSnafu { index: i }.fail();
             };
-
+
             match *mysql_type {
                 ColumnType::MYSQL_TYPE_NULL => {
                     let Some(builder) = builder else {
@@ -247,63 +248,67 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option<SchemaRef>) -> Resu
 
                     match arrow_field.data_type() {
                         DataType::Decimal128(_, _) => {
-                            let Some(builder) = builder.as_any_mut().downcast_mut::<Decimal128Builder>()
+                            let Some(builder) =
+                                builder.as_any_mut().downcast_mut::<Decimal128Builder>()
                             else {
                                 return FailedToDowncastBuilderSnafu {
                                     mysql_type: format!("{mysql_type:?}"),
-                            }
-                            .fail();
-                        };
-                            let val = handle_null_error(row.get_opt::<BigDecimal, usize>(i).transpose())
-                                .context(FailedToGetRowValueSnafu {
-                                    mysql_type: ColumnType::MYSQL_TYPE_DECIMAL,
-                                })?;
-
+                                }
+                                .fail();
+                            };
+                            let val =
+                                handle_null_error(row.get_opt::<BigDecimal, usize>(i).transpose())
+                                    .context(FailedToGetRowValueSnafu {
+                                        mysql_type: ColumnType::MYSQL_TYPE_DECIMAL,
+                                    })?;
+
                             let scale = match &val {
                                 Some(val) => val.fractional_digit_count(),
                                 None => 0,
                             };
-
+
                             let Some(val) = val else {
                                 builder.append_null();
                                 continue;
                             };
-
+
                             let Some(val) = to_decimal_128(&val, scale) else {
-                                return FailedToConvertBigDecimalToI128Snafu { big_decimal: val }.fail();
+                                return FailedToConvertBigDecimalToI128Snafu { big_decimal: val }
+                                    .fail();
                             };
-
+
                             builder.append_value(val);
                         }
                         DataType::Decimal256(_, _) => {
-                            let Some(builder) = builder.as_any_mut().downcast_mut::<Decimal256Builder>()
+                            let Some(builder) =
+                                builder.as_any_mut().downcast_mut::<Decimal256Builder>()
                             else {
                                 return FailedToDowncastBuilderSnafu {
                                     mysql_type: format!("{mysql_type:?}"),
                                 }
                                 .fail();
                             };
-
-                            let val = handle_null_error(row.get_opt::<BigDecimal, usize>(i).transpose())
-                                .context(FailedToGetRowValueSnafu {
-                                    mysql_type: ColumnType::MYSQL_TYPE_DECIMAL,
-                                })?;
-
+
+                            let val =
+                                handle_null_error(row.get_opt::<BigDecimal, usize>(i).transpose())
+                                    .context(FailedToGetRowValueSnafu {
+                                        mysql_type: ColumnType::MYSQL_TYPE_DECIMAL,
+                                    })?;
+
                             let Some(val) = val else {
                                 builder.append_null();
                                 continue;
                             };
-
+
                             let val = to_decimal_256(&val);
-
+
                            builder.append_value(val);
                        }
                        // ColumnType::MYSQL_TYPE_DECIMAL & ColumnType::MYSQL_TYPE_NEWDECIMAL are only mapped to Decimal128/Decimal256 in `map_column_to_data_type` function
-                        _ => unreachable!()
+                        _ => unreachable!(),
                     }
                 }
-                column_type @ (ColumnType::MYSQL_TYPE_VARCHAR
-                | ColumnType::MYSQL_TYPE_JSON) => {
+                column_type @ (ColumnType::MYSQL_TYPE_VARCHAR | ColumnType::MYSQL_TYPE_JSON) => {
                     handle_primitive_type!(
                         builder,
                         column_type,
@@ -395,14 +400,7 @@ pub fn rows_to_arrow(rows: &[Row], projected_schema: &Option<SchemaRef>) -> Resu
                             i
                         );
                     } else {
-                        handle_primitive_type!(
-                            builder,
-                            column_type,
-                            StringBuilder,
-                            String,
-                            row,
-                            i
-                        );
+                        handle_primitive_type!(builder, column_type, StringBuilder, String, row, i);
                     }
                 }
                 ColumnType::MYSQL_TYPE_DATE => {
@@ -522,7 +520,7 @@ pub fn map_column_to_data_type(
         ColumnType::MYSQL_TYPE_DECIMAL | ColumnType::MYSQL_TYPE_NEWDECIMAL => {
             if column_decimal_precision.unwrap_or_default() > 38 {
                 return Some(DataType::Decimal256(column_decimal_precision.unwrap_or_default(), column_decimal_scale.unwrap_or_default()));
-            } 
+            }
             Some(DataType::Decimal128(column_decimal_precision.unwrap_or_default(), column_decimal_scale.unwrap_or_default()))
         },
         ColumnType::MYSQL_TYPE_TIMESTAMP | ColumnType::MYSQL_TYPE_DATETIME => {
@@ -565,7 +563,7 @@ pub fn map_column_to_data_type(
         | ColumnType::MYSQL_TYPE_TIMESTAMP2
         | ColumnType::MYSQL_TYPE_DATETIME2
         | ColumnType::MYSQL_TYPE_TIME2
-        | ColumnType::MYSQL_TYPE_LONG_BLOB 
+        | ColumnType::MYSQL_TYPE_LONG_BLOB
         | ColumnType::MYSQL_TYPE_TINY_BLOB
         | ColumnType::MYSQL_TYPE_MEDIUM_BLOB
         | ColumnType::MYSQL_TYPE_GEOMETRY => {
diff --git a/src/util/retriable_error.rs b/src/util/retriable_error.rs
index a9a0a4c..f01a9e5 100644
--- a/src/util/retriable_error.rs
+++ b/src/util/retriable_error.rs
@@ -12,7 +12,7 @@ pub enum RetriableError {
 
     #[snafu(display("{source}"))]
     DataWriteError {
-        source: Box<dyn Error + Send + Sync>
+        source: Box<dyn Error + Send + Sync>,
     },
 }
 
@@ -49,7 +49,9 @@ pub fn to_retriable_data_write_error<E>(error: E) -> DataFusionError
 where
     E: Error + Send + Sync + 'static,
 {
-    DataFusionError::External(Box::new(RetriableError::DataWriteError { source: error.into() }))
+    DataFusionError::External(Box::new(RetriableError::DataWriteError {
+        source: error.into(),
+    }))
 }
 
 fn is_invalid_query_error(error: &DataFusionError) -> bool {
= "42.0.0", optional = true } datafusion-proto = { version = "42.0.0", optional = true } datafusion-federation = { version = "0.3.0", features = ["sql"] } -duckdb = { version = "1", features = [ +duckdb = { version = "1.1.1", features = [ "bundled", "r2d2", "vtab", @@ -113,6 +113,4 @@ duckdb-federation = ["duckdb"] sqlite-federation = ["sqlite"] postgres-federation = ["postgres"] -[patch.crates-io] -duckdb = { git = "https://github.com/spiceai/duckdb-rs.git", rev = "5b98603705a381ceeb5cc371e4f606b7332b57ce" }