Commit

Merge pull request #112 from hozan23/rebase-spiceai-main-branches
Rebase spiceai onto main branch
phillipleblanc authored Oct 11, 2024
2 parents 55cfc67 + 9332074 commit 8f2a63f
Showing 30 changed files with 2,179 additions and 414 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pr.yaml
@@ -5,6 +5,7 @@ on:
pull_request:
branches:
- main
- spiceai

jobs:
lint:
42 changes: 29 additions & 13 deletions Cargo.toml
@@ -8,26 +8,25 @@ license = "Apache-2.0"
description = "Extend the capabilities of DataFusion to support additional data sources via implementations of the `TableProvider` trait."

[dependencies]
arrow = "53.1.0"
arrow-array = { version = "53.1.0", optional = true }
arrow-cast = { version = "53.1.0", optional = true }
arrow-flight = { version = "53.1.0", optional = true, features = ["flight-sql-experimental", "tls"] }
arrow-schema = { version = "53.1.0", optional = true, features = ["serde"] }
arrow-json = "53.1.0"
arrow = "53"
arrow-array = { version = "53", optional = true }
arrow-cast = { version = "53", optional = true }
arrow-flight = { version = "53", optional = true, features = ["flight-sql-experimental", "tls"] }
arrow-schema = { version = "53", optional = true, features = ["serde"] }
arrow-json = "53"
async-stream = { version = "0.3.5", optional = true }
async-trait = "0.1.80"
num-bigint = "0.4.4"
bigdecimal = "0.4.5"
bigdecimal_0_3_0 = { package = "bigdecimal", version = "0.3.0" }
byteorder = "1.5.0"
chrono = "0.4.38"
datafusion = "42.0.0"
datafusion-expr = { version = "42.0.0", optional = true }
datafusion-physical-expr = { version = "42.0.0", optional = true }
datafusion-physical-plan = { version = "42.0.0", optional = true }
datafusion-proto = { version = "42.0.0", optional = true }
datafusion-federation = { version = "0.3.0", features = ["sql"] }
duckdb = { version = "1", features = [
datafusion-federation = { version = "0.3.0", features = ["sql"] }
duckdb = { version = "1.1.1", features = [
"bundled",
"r2d2",
"vtab",
@@ -37,16 +36,28 @@ duckdb = { version = "1", features = [
fallible-iterator = "0.3.0"
futures = "0.3.30"
mysql_async = { version = "0.34.1", features = ["native-tls-tls", "chrono"], optional = true }
prost = { version = "0.13.2", optional = true }
r2d2 = { version = "0.8.10", optional = true }
rusqlite = { version = "0.31.0", optional = true }
sea-query = { version = "0.31.0", features = ["backend-sqlite", "backend-postgres", "postgres-array", "with-rust_decimal", "with-bigdecimal", "with-time", "with-chrono"] }
sea-query = { version = "0.32.0-rc.1", features = [
"backend-sqlite",
"backend-postgres",
"postgres-array",
"with-rust_decimal",
"with-bigdecimal",
"with-time",
"with-chrono"] }
secrecy = "0.8.0"
serde = { version = "1.0.209", optional = true }
serde_json = "1.0.124"
snafu = "0.8.3"
time = "0.3.36"
tokio = { version = "1.38.0", features = ["macros", "fs"] }
tokio-postgres = { version = "0.7.10", features = ["with-chrono-0_4", "with-uuid-1", "with-serde_json-1", "with-geo-types-0_7"], optional = true }
tokio-postgres = { version = "0.7.10", features = [
"with-chrono-0_4",
"with-uuid-1",
"with-serde_json-1",
"with-geo-types-0_7"], optional = true }
tracing = "0.1.40"
uuid = { version = "1.9.1", optional = true }
postgres-native-tls = { version = "0.5.0", optional = true }
@@ -57,9 +68,11 @@ trust-dns-resolver = "0.23.2"
url = "2.5.1"
pem = { version = "3.0.4", optional = true }
tokio-rusqlite = { version = "0.5.1", optional = true }
tonic = { version = "0.12", optional = true }
tonic = { version = "0.12.2", optional = true }
itertools = "0.13.0"
dyn-clone = { version = "1.0.17", optional = true }
geo-types = "0.7.13"
fundu = "2.0.1"

[dev-dependencies]
anyhow = "1.0.86"
@@ -79,7 +92,7 @@ prost = { version = "0.13"}
mysql = ["dep:mysql_async", "dep:async-stream"]
postgres = ["dep:tokio-postgres", "dep:uuid", "dep:postgres-native-tls", "dep:bb8", "dep:bb8-postgres", "dep:native-tls", "dep:pem", "dep:async-stream"]
sqlite = ["dep:rusqlite", "dep:tokio-rusqlite"]
duckdb = ["dep:duckdb", "dep:r2d2", "dep:uuid"]
duckdb = ["dep:duckdb", "dep:r2d2", "dep:uuid", "dep:dyn-clone", "dep:async-stream"]
flight = [
"dep:arrow-array",
"dep:arrow-cast",
@@ -94,3 +107,6 @@ flight = [
]
duckdb-federation = ["duckdb"]
sqlite-federation = ["sqlite"]
postgres-federation = ["postgres"]


2 changes: 1 addition & 1 deletion examples/duckdb_external_table.rs
@@ -12,7 +12,7 @@ use duckdb::AccessMode;
/// DuckDB-backed tables can be created at runtime.
#[tokio::main]
async fn main() {
let duckdb = Arc::new(DuckDBTableProviderFactory::new().access_mode(AccessMode::ReadWrite));
let duckdb = Arc::new(DuckDBTableProviderFactory::new(AccessMode::ReadWrite));

let runtime = Arc::new(RuntimeEnv::default());
let state = SessionStateBuilder::new()
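
The example change above tracks the API change in this commit: DuckDBTableProviderFactory::new now takes the AccessMode directly, replacing the previous read-only default and the access_mode builder method (which this diff removes along with the Default impl). A minimal migration sketch, assuming the crate is consumed as datafusion_table_providers; that import path is not shown in this diff:

    use std::sync::Arc;

    use duckdb::AccessMode;
    // Assumed import path; only the type name appears in the diff above.
    use datafusion_table_providers::duckdb::DuckDBTableProviderFactory;

    fn build_duckdb_factory() -> Arc<DuckDBTableProviderFactory> {
        // Before this commit: DuckDBTableProviderFactory::new().access_mode(AccessMode::ReadWrite)
        // Now the access mode is a required constructor argument:
        Arc::new(DuckDBTableProviderFactory::new(AccessMode::ReadWrite))
    }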
98 changes: 66 additions & 32 deletions src/duckdb.rs
@@ -1,11 +1,13 @@
use crate::sql::db_connection_pool::{
self,
dbconnection::{
duckdbconn::{flatten_table_function_name, is_table_function, DuckDbConnection},
duckdbconn::{
flatten_table_function_name, is_table_function, DuckDBParameter, DuckDbConnection,
},
get_schema, DbConnection,
},
duckdbpool::DuckDbConnectionPool,
DbConnectionPool, Mode,
DbConnectionPool, DbInstanceKey, Mode,
};
use crate::sql::sql_provider_datafusion;
use crate::util::{
@@ -25,10 +27,11 @@ use datafusion::{
logical_expr::CreateExternalTable,
sql::TableReference,
};
use duckdb::{AccessMode, DuckdbConnectionManager, ToSql, Transaction};
use duckdb::{AccessMode, DuckdbConnectionManager, Transaction};
use itertools::Itertools;
use snafu::prelude::*;
use std::{cmp, collections::HashMap, sync::Arc};
use tokio::sync::Mutex;

use self::{creator::TableCreator, sql_table::DuckDBTable, write::DuckDBTableWriter};

@@ -87,11 +90,6 @@ pub enum Error {
#[snafu(display("Unable to commit transaction: {source}"))]
UnableToCommitTransaction { source: duckdb::Error },

#[snafu(display("Unable to checkpoint duckdb: {source}"))]
UnableToCheckpoint {
source: Box<dyn std::error::Error + Send + Sync>,
},

#[snafu(display("Unable to begin duckdb transaction: {source}"))]
UnableToBeginTransaction { source: duckdb::Error },

@@ -121,6 +119,7 @@ type Result<T, E = Error> = std::result::Result<T, E>;

pub struct DuckDBTableProviderFactory {
access_mode: AccessMode,
instances: Arc<Mutex<HashMap<DbInstanceKey, DuckDbConnectionPool>>>,
}

const DUCKDB_DB_PATH_PARAM: &str = "open";
@@ -129,9 +128,10 @@ const DUCKDB_ATTACH_DATABASES_PARAM: &str = "attach_databases";

impl DuckDBTableProviderFactory {
#[must_use]
pub fn new() -> Self {
pub fn new(access_mode: AccessMode) -> Self {
Self {
access_mode: AccessMode::ReadOnly,
access_mode,
instances: Arc::new(Mutex::new(HashMap::new())),
}
}

@@ -148,12 +148,6 @@ impl DuckDBTableProviderFactory {
.unwrap_or_default()
}

#[must_use]
pub fn access_mode(mut self, access_mode: AccessMode) -> Self {
self.access_mode = access_mode;
self
}

#[must_use]
pub fn duckdb_file_path(&self, name: &str, options: &mut HashMap<String, String>) -> String {
let options = util::remove_prefix_from_hashmap_keys(options.clone(), "duckdb_");
@@ -169,15 +163,44 @@ impl DuckDBTableProviderFactory {
.cloned()
.unwrap_or(default_filepath)
}
}

impl Default for DuckDBTableProviderFactory {
fn default() -> Self {
Self::new()
pub async fn get_or_init_memory_instance(&self) -> Result<DuckDbConnectionPool> {
let key = DbInstanceKey::memory();
let mut instances = self.instances.lock().await;

if let Some(instance) = instances.get(&key) {
return Ok(instance.clone());
}

let pool = DuckDbConnectionPool::new_memory().context(DbConnectionPoolSnafu)?;

instances.insert(key, pool.clone());

Ok(pool)
}

pub async fn get_or_init_file_instance(
&self,
db_path: impl Into<Arc<str>>,
) -> Result<DuckDbConnectionPool> {
let db_path = db_path.into();
let key = DbInstanceKey::file(Arc::clone(&db_path));
let mut instances = self.instances.lock().await;

if let Some(instance) = instances.get(&key) {
return Ok(instance.clone());
}

let pool = DuckDbConnectionPool::new_file(&db_path, &self.access_mode)
.context(DbConnectionPoolSnafu)?;

instances.insert(key, pool.clone());

Ok(pool)
}
}

type DynDuckDbConnectionPool = dyn DbConnectionPool<r2d2::PooledConnection<DuckdbConnectionManager>, &'static dyn ToSql>
type DynDuckDbConnectionPool = dyn DbConnectionPool<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>
+ Send
+ Sync;

@@ -229,12 +252,13 @@ impl TableProviderFactory for DuckDBTableProviderFactory {
// open duckdb at given path or create a new one
let db_path = self.duckdb_file_path(&name, &mut options);

DuckDbConnectionPool::new_file(&db_path, &self.access_mode)
.context(DbConnectionPoolSnafu)
self.get_or_init_file_instance(db_path)
.await
.map_err(to_datafusion_error)?
}
Mode::Memory => DuckDbConnectionPool::new_memory()
.context(DbConnectionPoolSnafu)
Mode::Memory => self
.get_or_init_memory_instance()
.await
.map_err(to_datafusion_error)?,
};

@@ -265,7 +289,12 @@ impl TableProviderFactory for DuckDBTableProviderFactory {
));

#[cfg(feature = "duckdb-federation")]
let read_provider = Arc::new(read_provider.create_federated_table_provider()?);
let read_provider: Arc<dyn TableProvider> = if mode == Mode::File {
// federation is disabled for in-memory mode until memory connections are updated to use the same database instance instead of separate instances
Arc::new(read_provider.create_federated_table_provider()?)
} else {
read_provider
};

Ok(DuckDBTableWriter::create(
read_provider,
@@ -317,18 +346,18 @@ impl DuckDB {
pub fn connect_sync(
&self,
) -> Result<
Box<dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, &'static dyn ToSql>>,
Box<dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>>,
> {
Arc::clone(&self.pool)
.connect_sync()
.context(DbConnectionSnafu)
}

pub fn duckdb_conn<'a>(
db_connection: &'a mut Box<
dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, &'static dyn ToSql>,
pub fn duckdb_conn(
db_connection: &mut Box<
dyn DbConnection<r2d2::PooledConnection<DuckdbConnectionManager>, DuckDBParameter>,
>,
) -> Result<&'a mut DuckDbConnection> {
) -> Result<&mut DuckDbConnection> {
db_connection
.as_any_mut()
.downcast_mut::<DuckDbConnection>()
@@ -441,7 +470,12 @@ impl DuckDBTableFactory {
));

#[cfg(feature = "duckdb-federation")]
let table_provider = Arc::new(table_provider.create_federated_table_provider()?);
let table_provider: Arc<dyn TableProvider> = if self.pool.mode() == Mode::File {
// federation is disabled for in-memory mode until memory connections are updated to use the same database instance instead of separate instances
Arc::new(table_provider.create_federated_table_provider()?)
} else {
table_provider
};

Ok(table_provider)
}
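
Beyond the parameter-type change (DuckDBParameter replacing &'static dyn ToSql), the main behavioral change in src/duckdb.rs is the instances map keyed by DbInstanceKey: repeated CREATE EXTERNAL TABLE statements against the same DuckDB file, or against the shared in-memory database, now reuse a single cached connection pool instead of opening separate instances. A rough usage sketch of the new get_or_init_* helpers, assuming the same datafusion_table_providers import path as above and an illustrative file path:

    use duckdb::AccessMode;
    // Assumed import path, as in the previous sketch.
    use datafusion_table_providers::duckdb::DuckDBTableProviderFactory;

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        let factory = DuckDBTableProviderFactory::new(AccessMode::ReadWrite);

        // Both calls resolve to the same cached entry (DbInstanceKey::file(...)),
        // so the returned pools are clones backed by one DuckDB instance.
        let pool_a = factory.get_or_init_file_instance("/tmp/example.duckdb").await?;
        let pool_b = factory.get_or_init_file_instance("/tmp/example.duckdb").await?;

        // The in-memory pool is cached the same way under DbInstanceKey::memory().
        let mem_pool = factory.get_or_init_memory_instance().await?;

        drop((pool_a, pool_b, mem_pool));
        Ok(())
    }

This mirrors what TableProviderFactory::create now does internally for Mode::File and Mode::Memory in the hunks above.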
