From fe4553f533da0838494ff215ac139627cbf2426c Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Tue, 1 Oct 2024 12:23:14 -0500 Subject: [PATCH] [FEAT]: sql `read_deltalake` function (#2974) depends on https://github.com/Eventual-Inc/Daft/pull/2954 --------- Co-authored-by: Kev Wang --- Cargo.lock | 1 + .../file-formats/src/file_format_config.rs | 11 + src/daft-dsl/src/lib.rs | 2 +- src/daft-dsl/src/lit.rs | 75 +++- src/daft-plan/src/builder.rs | 179 +++++++- src/daft-plan/src/lib.rs | 2 +- src/daft-scan/src/lib.rs | 2 +- src/daft-schema/src/time_unit.rs | 16 + src/daft-sql/Cargo.toml | 1 + src/daft-sql/src/error.rs | 9 + src/daft-sql/src/functions.rs | 35 +- src/daft-sql/src/lib.rs | 2 +- src/daft-sql/src/modules/config.rs | 391 ++++++++++++++++++ src/daft-sql/src/modules/image/resize.rs | 4 +- src/daft-sql/src/modules/mod.rs | 1 + src/daft-sql/src/planner.rs | 31 +- src/daft-sql/src/table_provider/mod.rs | 119 ++++++ .../src/table_provider/read_parquet.rs | 77 ++++ tests/sql/test_table_funcs.py | 7 + 19 files changed, 933 insertions(+), 32 deletions(-) create mode 100644 src/daft-sql/src/modules/config.rs create mode 100644 src/daft-sql/src/table_provider/mod.rs create mode 100644 src/daft-sql/src/table_provider/read_parquet.rs create mode 100644 tests/sql/test_table_funcs.py diff --git a/Cargo.lock b/Cargo.lock index 592a0793ba..8d60981bd3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2170,6 +2170,7 @@ version = "0.3.0-dev0" dependencies = [ "common-daft-config", "common-error", + "common-io-config", "daft-core", "daft-dsl", "daft-functions", diff --git a/src/common/file-formats/src/file_format_config.rs b/src/common/file-formats/src/file_format_config.rs index fe659bc444..6054907861 100644 --- a/src/common/file-formats/src/file_format_config.rs +++ b/src/common/file-formats/src/file_format_config.rs @@ -115,6 +115,17 @@ impl ParquetSourceConfig { } } +impl Default for ParquetSourceConfig { + fn default() -> Self { + Self { + coerce_int96_timestamp_unit: TimeUnit::Nanoseconds, + field_id_mapping: None, + row_groups: None, + chunk_size: None, + } + } +} + #[cfg(feature = "python")] #[pymethods] impl ParquetSourceConfig { diff --git a/src/daft-dsl/src/lib.rs b/src/daft-dsl/src/lib.rs index 754578eb6d..2fa99115e3 100644 --- a/src/daft-dsl/src/lib.rs +++ b/src/daft-dsl/src/lib.rs @@ -18,7 +18,7 @@ pub use expr::{ binary_op, col, has_agg, has_stateful_udf, is_partition_compatible, AggExpr, ApproxPercentileParams, Expr, ExprRef, Operator, SketchType, }; -pub use lit::{lit, literals_to_series, null_lit, Literal, LiteralValue}; +pub use lit::{lit, literal_value, literals_to_series, null_lit, Literal, LiteralValue}; #[cfg(feature = "python")] use pyo3::prelude::*; pub use resolve_expr::{ diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs index 5db0f05a3d..55888d73f8 100644 --- a/src/daft-dsl/src/lit.rs +++ b/src/daft-dsl/src/lit.rs @@ -14,6 +14,7 @@ use daft_core::{ display_timestamp, }, }; +use indexmap::IndexMap; use serde::{Deserialize, Serialize}; #[cfg(feature = "python")] @@ -68,6 +69,8 @@ pub enum LiteralValue { /// Python object. 
#[cfg(feature = "python")] Python(PyObjectWrapper), + + Struct(IndexMap), } impl Eq for LiteralValue {} @@ -112,6 +115,12 @@ impl Hash for LiteralValue { } #[cfg(feature = "python")] Python(py_obj) => py_obj.hash(state), + Struct(entries) => { + entries.iter().for_each(|(v, f)| { + v.hash(state); + f.hash(state); + }); + } } } } @@ -143,6 +152,16 @@ impl Display for LiteralValue { Python::with_gil(|py| pyobj.0.call_method0(py, pyo3::intern!(py, "__str__"))) .unwrap() }), + Struct(entries) => { + write!(f, "Struct(")?; + for (i, (field, v)) in entries.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}: {}", field.name, v)?; + } + write!(f, ")") + } } } } @@ -169,6 +188,7 @@ impl LiteralValue { Series(series) => series.data_type().clone(), #[cfg(feature = "python")] Python(_) => DataType::Python, + Struct(entries) => DataType::Struct(entries.keys().cloned().collect()), } } @@ -203,6 +223,13 @@ impl LiteralValue { Series(series) => series.clone().rename("literal"), #[cfg(feature = "python")] Python(val) => PythonArray::from(("literal", vec![val.0.clone()])).into_series(), + Struct(entries) => { + let struct_dtype = DataType::Struct(entries.keys().cloned().collect()); + let struct_field = Field::new("literal", struct_dtype); + + let values = entries.values().map(|v| v.to_series()).collect(); + StructArray::new(struct_field, values, None).into_series() + } }; result } @@ -235,6 +262,7 @@ impl LiteralValue { Decimal(..) | Series(..) | Time(..) | Binary(..) => display_sql_err, #[cfg(feature = "python")] Python(..) => display_sql_err, + Struct(..) => display_sql_err, } } @@ -304,49 +332,64 @@ impl LiteralValue { } } -pub trait Literal { +pub trait Literal: Sized { /// [Literal](Expr::Literal) expression. - fn lit(self) -> ExprRef; + fn lit(self) -> ExprRef { + Expr::Literal(self.literal_value()).into() + } + fn literal_value(self) -> LiteralValue; } impl Literal for String { - fn lit(self) -> ExprRef { - Expr::Literal(LiteralValue::Utf8(self)).into() + fn literal_value(self) -> LiteralValue { + LiteralValue::Utf8(self) } } impl<'a> Literal for &'a str { - fn lit(self) -> ExprRef { - Expr::Literal(LiteralValue::Utf8(self.to_owned())).into() + fn literal_value(self) -> LiteralValue { + LiteralValue::Utf8(self.to_owned()) } } macro_rules! 
make_literal { ($TYPE:ty, $SCALAR:ident) => { impl Literal for $TYPE { - fn lit(self) -> ExprRef { - Expr::Literal(LiteralValue::$SCALAR(self)).into() + fn literal_value(self) -> LiteralValue { + LiteralValue::$SCALAR(self) } } }; } impl<'a> Literal for &'a [u8] { - fn lit(self) -> ExprRef { - Expr::Literal(LiteralValue::Binary(self.to_vec())).into() + fn literal_value(self) -> LiteralValue { + LiteralValue::Binary(self.to_vec()) } } impl Literal for Series { - fn lit(self) -> ExprRef { - Expr::Literal(LiteralValue::Series(self)).into() + fn literal_value(self) -> LiteralValue { + LiteralValue::Series(self) } } #[cfg(feature = "python")] impl Literal for pyo3::PyObject { - fn lit(self) -> ExprRef { - Expr::Literal(LiteralValue::Python(PyObjectWrapper(self))).into() + fn literal_value(self) -> LiteralValue { + LiteralValue::Python(PyObjectWrapper(self)) + } +} + +impl Literal for Option +where + T: Literal, +{ + fn literal_value(self) -> LiteralValue { + match self { + Some(val) => val.literal_value(), + None => LiteralValue::Null, + } } } @@ -361,6 +404,10 @@ pub fn lit(t: L) -> ExprRef { t.lit() } +pub fn literal_value(t: L) -> LiteralValue { + t.literal_value() +} + pub fn null_lit() -> ExprRef { Arc::new(Expr::Literal(LiteralValue::Null)) } diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index 982a3634a9..a9e05ec6cb 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -1,17 +1,27 @@ use std::{ - collections::{HashMap, HashSet}, + collections::{BTreeMap, HashMap, HashSet}, sync::Arc, }; use common_daft_config::DaftPlanningConfig; use common_display::mermaid::MermaidDisplayOptions; use common_error::DaftResult; -use common_file_formats::FileFormat; +use common_file_formats::{FileFormat, FileFormatConfig, ParquetSourceConfig}; use common_io_config::IOConfig; -use daft_core::join::{JoinStrategy, JoinType}; +use daft_core::{ + join::{JoinStrategy, JoinType}, + prelude::TimeUnit, +}; use daft_dsl::{col, ExprRef}; -use daft_scan::{PhysicalScanInfo, Pushdowns, ScanOperatorRef}; -use daft_schema::schema::{Schema, SchemaRef}; +use daft_scan::{ + glob::GlobScanOperator, + storage_config::{NativeStorageConfig, StorageConfig}, + PhysicalScanInfo, Pushdowns, ScanOperatorRef, +}; +use daft_schema::{ + field::Field, + schema::{Schema, SchemaRef}, +}; #[cfg(feature = "python")] use { crate::sink_info::{CatalogInfo, IcebergCatalogInfo}, @@ -73,7 +83,29 @@ impl From<&LogicalPlanBuilder> for LogicalPlanRef { value.plan.clone() } } - +pub trait IntoGlobPath { + fn into_glob_path(self) -> Vec; +} +impl IntoGlobPath for Vec { + fn into_glob_path(self) -> Vec { + self + } +} +impl IntoGlobPath for String { + fn into_glob_path(self) -> Vec { + vec![self] + } +} +impl IntoGlobPath for &str { + fn into_glob_path(self) -> Vec { + vec![self.to_string()] + } +} +impl IntoGlobPath for Vec<&str> { + fn into_glob_path(self) -> Vec { + self.iter().map(|s| s.to_string()).collect() + } +} impl LogicalPlanBuilder { /// Replace the LogicalPlanBuilder's plan with the provided plan pub fn with_new_plan>>(&self, plan: LP) -> Self { @@ -105,9 +137,51 @@ impl LogicalPlanBuilder { )); let logical_plan: LogicalPlan = logical_ops::Source::new(schema.clone(), source_info.into()).into(); + Ok(Self::new(logical_plan.into(), None)) } + #[cfg(feature = "python")] + pub fn delta_scan>( + glob_path: T, + io_config: Option, + multithreaded_io: bool, + ) -> DaftResult { + use daft_scan::storage_config::PyStorageConfig; + + Python::with_gil(|py| { + let io_config = 
io_config.unwrap_or_default(); + + let native_storage_config = NativeStorageConfig { + io_config: Some(io_config), + multithreaded_io, + }; + + let py_storage_config: PyStorageConfig = + Arc::new(StorageConfig::Native(Arc::new(native_storage_config))).into(); + + // let py_io_config = PyIOConfig { config: io_config }; + let delta_lake_scan = PyModule::import_bound(py, "daft.delta_lake.delta_lake_scan")?; + let delta_lake_scan_operator = + delta_lake_scan.getattr(pyo3::intern!(py, "DeltaLakeScanOperator"))?; + let delta_lake_operator = delta_lake_scan_operator + .call1((glob_path.as_ref(), py_storage_config))? + .to_object(py); + let scan_operator_handle = + ScanOperatorHandle::from_python_scan_operator(delta_lake_operator, py)?; + Self::table_scan(scan_operator_handle.into(), None) + }) + } + + #[cfg(not(feature = "python"))] + pub fn delta_scan( + glob_path: T, + io_config: Option, + multithreaded_io: bool, + ) -> DaftResult { + panic!("Delta Lake scan requires the 'python' feature to be enabled.") + } + pub fn table_scan( scan_operator: ScanOperatorRef, pushdowns: Option, @@ -142,6 +216,10 @@ impl LogicalPlanBuilder { Ok(Self::new(logical_plan.into(), None)) } + pub fn parquet_scan(glob_path: T) -> ParquetScanBuilder { + ParquetScanBuilder::new(glob_path) + } + pub fn select(&self, to_select: Vec) -> DaftResult { let logical_plan: LogicalPlan = logical_ops::Project::try_new(self.plan.clone(), to_select)?.into(); @@ -498,6 +576,95 @@ impl LogicalPlanBuilder { } } +pub struct ParquetScanBuilder { + pub glob_paths: Vec, + pub infer_schema: bool, + pub coerce_int96_timestamp_unit: TimeUnit, + pub field_id_mapping: Option>>, + pub row_groups: Option>>>, + pub chunk_size: Option, + pub io_config: Option, + pub multithreaded: bool, + pub schema: Option, +} + +impl ParquetScanBuilder { + pub fn new(glob_paths: T) -> Self { + let glob_paths = glob_paths.into_glob_path(); + Self::new_impl(glob_paths) + } + + // concrete implementation to reduce LLVM code duplication + fn new_impl(glob_paths: Vec) -> Self { + Self { + glob_paths, + infer_schema: true, + coerce_int96_timestamp_unit: TimeUnit::Nanoseconds, + field_id_mapping: None, + row_groups: None, + chunk_size: None, + multithreaded: true, + schema: None, + io_config: None, + } + } + pub fn infer_schema(mut self, infer_schema: bool) -> Self { + self.infer_schema = infer_schema; + self + } + pub fn coerce_int96_timestamp_unit(mut self, unit: TimeUnit) -> Self { + self.coerce_int96_timestamp_unit = unit; + self + } + pub fn field_id_mapping(mut self, field_id_mapping: Arc>) -> Self { + self.field_id_mapping = Some(field_id_mapping); + self + } + pub fn row_groups(mut self, row_groups: Vec>>) -> Self { + self.row_groups = Some(row_groups); + self + } + pub fn chunk_size(mut self, chunk_size: usize) -> Self { + self.chunk_size = Some(chunk_size); + self + } + + pub fn io_config(mut self, io_config: IOConfig) -> Self { + self.io_config = Some(io_config); + self + } + + pub fn multithreaded(mut self, multithreaded: bool) -> Self { + self.multithreaded = multithreaded; + self + } + pub fn schema(mut self, schema: SchemaRef) -> Self { + self.schema = Some(schema); + self + } + + pub fn finish(self) -> DaftResult { + let cfg = ParquetSourceConfig { + coerce_int96_timestamp_unit: self.coerce_int96_timestamp_unit, + field_id_mapping: self.field_id_mapping, + row_groups: self.row_groups, + chunk_size: self.chunk_size, + }; + + let operator = Arc::new(GlobScanOperator::try_new( + self.glob_paths, + Arc::new(FileFormatConfig::Parquet(cfg)), + 
Arc::new(StorageConfig::Native(Arc::new( + NativeStorageConfig::new_internal(self.multithreaded, self.io_config), + ))), + self.infer_schema, + self.schema, + )?); + + LogicalPlanBuilder::table_scan(ScanOperatorRef(operator), None) + } +} + /// A Python-facing wrapper of the LogicalPlanBuilder. /// /// This lightweight proxy interface should hold as much of the Python-specific logic diff --git a/src/daft-plan/src/lib.rs b/src/daft-plan/src/lib.rs index 50e309916b..2541a143db 100644 --- a/src/daft-plan/src/lib.rs +++ b/src/daft-plan/src/lib.rs @@ -19,7 +19,7 @@ pub mod source_info; mod test; mod treenode; -pub use builder::{LogicalPlanBuilder, PyLogicalPlanBuilder}; +pub use builder::{LogicalPlanBuilder, ParquetScanBuilder, PyLogicalPlanBuilder}; pub use daft_core::join::{JoinStrategy, JoinType}; pub use logical_plan::{LogicalPlan, LogicalPlanRef}; pub use partitioning::ClusteringSpec; diff --git a/src/daft-scan/src/lib.rs b/src/daft-scan/src/lib.rs index 10cc0c6804..23191b1d11 100644 --- a/src/daft-scan/src/lib.rs +++ b/src/daft-scan/src/lib.rs @@ -22,7 +22,7 @@ use serde::{Deserialize, Serialize}; mod anonymous; pub use anonymous::AnonymousScanOperator; -mod glob; +pub mod glob; use common_daft_config::DaftExecutionConfig; pub mod scan_task_iters; diff --git a/src/daft-schema/src/time_unit.rs b/src/daft-schema/src/time_unit.rs index d4b17b0e7c..9b1afea2e5 100644 --- a/src/daft-schema/src/time_unit.rs +++ b/src/daft-schema/src/time_unit.rs @@ -1,4 +1,7 @@ +use std::str::FromStr; + use arrow2::datatypes::TimeUnit as ArrowTimeUnit; +use common_error::DaftError; use derive_more::Display; use serde::{Deserialize, Serialize}; @@ -33,6 +36,19 @@ impl TimeUnit { } } +impl FromStr for TimeUnit { + type Err = DaftError; + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "ns" | "nanoseconds" => Ok(Self::Nanoseconds), + "us" | "microseconds" => Ok(Self::Microseconds), + "ms" | "milliseconds" => Ok(Self::Milliseconds), + "s" | "seconds" => Ok(Self::Seconds), + _ => Err(DaftError::ValueError("Invalid time unit".to_string())), + } + } +} + impl From<&ArrowTimeUnit> for TimeUnit { fn from(tu: &ArrowTimeUnit) -> Self { match tu { diff --git a/src/daft-sql/Cargo.toml b/src/daft-sql/Cargo.toml index 2b80dda42c..81d7d36ff0 100644 --- a/src/daft-sql/Cargo.toml +++ b/src/daft-sql/Cargo.toml @@ -1,6 +1,7 @@ [dependencies] common-daft-config = {path = "../common/daft-config"} common-error = {path = "../common/error"} +common-io-config = {path = "../common/io-config", default-features = false} daft-core = {path = "../daft-core"} daft-dsl = {path = "../daft-dsl"} daft-functions = {path = "../daft-functions"} diff --git a/src/daft-sql/src/error.rs b/src/daft-sql/src/error.rs index 31f8a400ed..1fd9ae97e7 100644 --- a/src/daft-sql/src/error.rs +++ b/src/daft-sql/src/error.rs @@ -12,6 +12,8 @@ pub enum PlannerError { ParseError { message: String }, #[snafu(display("Invalid operation: {message}"))] InvalidOperation { message: String }, + #[snafu(display("Invalid argument ({message}) for function '{function}'"))] + InvalidFunctionArgument { message: String, function: String }, #[snafu(display("Table not found: {message}"))] TableNotFound { message: String }, #[snafu(display("Column {column_name} not found in {relation}"))] @@ -66,6 +68,13 @@ impl PlannerError { message: message.into(), } } + + pub fn invalid_argument, F: Into>(arg: S, function: F) -> Self { + Self::InvalidFunctionArgument { + message: arg.into(), + function: function.into(), + } + } } #[macro_export] diff --git 
a/src/daft-sql/src/functions.rs b/src/daft-sql/src/functions.rs index 6b456af17c..2a67d97c63 100644 --- a/src/daft-sql/src/functions.rs +++ b/src/daft-sql/src/functions.rs @@ -1,5 +1,6 @@ use std::{collections::HashMap, sync::Arc}; +use config::SQLModuleConfig; use daft_dsl::ExprRef; use hashing::SQLModuleHashing; use once_cell::sync::Lazy; @@ -31,6 +32,7 @@ pub(crate) static SQL_FUNCTIONS: Lazy = Lazy::new(|| { functions.register::(); functions.register::(); functions.register::(); + functions.register::(); functions }); @@ -110,7 +112,7 @@ pub(crate) struct SQLFunctionArguments { } impl SQLFunctionArguments { - pub fn get_unnamed(&self, idx: usize) -> Option<&ExprRef> { + pub fn get_positional(&self, idx: usize) -> Option<&ExprRef> { self.positional.get(&idx) } pub fn get_named(&self, name: &str) -> Option<&ExprRef> { @@ -123,6 +125,12 @@ impl SQLFunctionArguments { .map(|expr| T::from_expr(expr)) .transpose() } + pub fn try_get_positional(&self, idx: usize) -> Result, PlannerError> { + self.positional + .get(&idx) + .map(|expr| T::from_expr(expr)) + .transpose() + } } pub trait SQLLiteral { @@ -155,6 +163,17 @@ impl SQLLiteral for i64 { } } +impl SQLLiteral for usize { + fn from_expr(expr: &ExprRef) -> Result + where + Self: Sized, + { + expr.as_literal() + .and_then(|lit| lit.as_i64().map(|v| v as Self)) + .ok_or_else(|| PlannerError::invalid_operation("Expected an integer literal")) + } +} + impl SQLLiteral for bool { fn from_expr(expr: &ExprRef) -> Result where @@ -258,6 +277,15 @@ impl SQLPlanner { where T: TryFrom, { + self.parse_function_args(args, expected_named, expected_positional)? + .try_into() + } + pub(crate) fn parse_function_args( + &self, + args: &[FunctionArg], + expected_named: &'static [&'static str], + expected_positional: usize, + ) -> SQLPlannerResult { let mut positional_args = HashMap::new(); let mut named_args = HashMap::new(); for (idx, arg) in args.iter().enumerate() { @@ -282,11 +310,10 @@ impl SQLPlanner { } } - SQLFunctionArguments { + Ok(SQLFunctionArguments { positional: positional_args, named: named_args, - } - .try_into() + }) } pub(crate) fn plan_function_arg( diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs index 97fd91c280..6246e8b242 100644 --- a/src/daft-sql/src/lib.rs +++ b/src/daft-sql/src/lib.rs @@ -3,9 +3,9 @@ pub mod error; pub mod functions; mod modules; mod planner; - #[cfg(feature = "python")] pub mod python; +mod table_provider; #[cfg(feature = "python")] use pyo3::prelude::*; diff --git a/src/daft-sql/src/modules/config.rs b/src/daft-sql/src/modules/config.rs new file mode 100644 index 0000000000..9a540d3025 --- /dev/null +++ b/src/daft-sql/src/modules/config.rs @@ -0,0 +1,391 @@ +use common_io_config::{AzureConfig, GCSConfig, HTTPConfig, IOConfig, S3Config}; +use daft_core::prelude::{DataType, Field}; +use daft_dsl::{literal_value, Expr, ExprRef, LiteralValue}; + +use super::SQLModule; +use crate::{ + error::{PlannerError, SQLPlannerResult}, + functions::{SQLFunction, SQLFunctionArguments, SQLFunctions}, + unsupported_sql_err, +}; + +pub struct SQLModuleConfig; + +impl SQLModule for SQLModuleConfig { + fn register(parent: &mut SQLFunctions) { + parent.add_fn("S3Config", S3ConfigFunction); + parent.add_fn("HTTPConfig", HTTPConfigFunction); + parent.add_fn("AzureConfig", AzureConfigFunction); + parent.add_fn("GCSConfig", GCSConfigFunction); + } +} + +pub struct S3ConfigFunction; +macro_rules! 
item { + ($name:expr, $ty:ident) => { + ( + Field::new(stringify!($name), DataType::$ty), + literal_value($name), + ) + }; +} + +impl SQLFunction for S3ConfigFunction { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult { + // TODO(cory): Ideally we should use serde to deserialize the input arguments + let args: SQLFunctionArguments = planner.parse_function_args( + inputs, + &[ + "region_name", + "endpoint_url", + "key_id", + "session_token", + "access_key", + "credentials_provider", + "buffer_time", + "max_connections_per_io_thread", + "retry_initial_backoff_ms", + "connect_timeout_ms", + "read_timeout_ms", + "num_tries", + "retry_mode", + "anonymous", + "use_ssl", + "verify_ssl", + "check_hostname_ssl", + "requester_pays", + "force_virtual_addressing", + "profile_name", + ], + 0, + )?; + + let region_name = args.try_get_named::("region_name")?; + let endpoint_url = args.try_get_named::("endpoint_url")?; + let key_id = args.try_get_named::("key_id")?; + let session_token = args.try_get_named::("session_token")?; + + let access_key = args.try_get_named::("access_key")?; + let buffer_time = args.try_get_named("buffer_time")?.map(|t: i64| t as u64); + + let max_connections_per_io_thread = args + .try_get_named("max_connections_per_io_thread")? + .map(|t: i64| t as u32); + + let retry_initial_backoff_ms = args + .try_get_named("retry_initial_backoff_ms")? + .map(|t: i64| t as u64); + + let connect_timeout_ms = args + .try_get_named("connect_timeout_ms")? + .map(|t: i64| t as u64); + + let read_timeout_ms = args + .try_get_named("read_timeout_ms")? + .map(|t: i64| t as u64); + + let num_tries = args.try_get_named("num_tries")?.map(|t: i64| t as u32); + let retry_mode = args.try_get_named::("retry_mode")?; + let anonymous = args.try_get_named::("anonymous")?; + let use_ssl = args.try_get_named::("use_ssl")?; + let verify_ssl = args.try_get_named::("verify_ssl")?; + let check_hostname_ssl = args.try_get_named::("check_hostname_ssl")?; + let requester_pays = args.try_get_named::("requester_pays")?; + let force_virtual_addressing = args.try_get_named::("force_virtual_addressing")?; + let profile_name = args.try_get_named::("profile_name")?; + + let entries = vec![ + (Field::new("variant", DataType::Utf8), literal_value("s3")), + item!(region_name, Utf8), + item!(endpoint_url, Utf8), + item!(key_id, Utf8), + item!(session_token, Utf8), + item!(access_key, Utf8), + item!(buffer_time, UInt64), + item!(max_connections_per_io_thread, UInt32), + item!(retry_initial_backoff_ms, UInt64), + item!(connect_timeout_ms, UInt64), + item!(read_timeout_ms, UInt64), + item!(num_tries, UInt32), + item!(retry_mode, Utf8), + item!(anonymous, Boolean), + item!(use_ssl, Boolean), + item!(verify_ssl, Boolean), + item!(check_hostname_ssl, Boolean), + item!(requester_pays, Boolean), + item!(force_virtual_addressing, Boolean), + item!(profile_name, Utf8), + ] + .into_iter() + .collect::<_>(); + + Ok(Expr::Literal(LiteralValue::Struct(entries)).arced()) + } +} + +pub struct HTTPConfigFunction; + +impl SQLFunction for HTTPConfigFunction { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult { + let args: SQLFunctionArguments = + planner.parse_function_args(inputs, &["user_agent", "bearer_token"], 0)?; + + let user_agent = args.try_get_named::("user_agent")?; + let bearer_token = args.try_get_named::("bearer_token")?; + + let entries = vec![ + 
(Field::new("variant", DataType::Utf8), literal_value("http")), + item!(user_agent, Utf8), + item!(bearer_token, Utf8), + ] + .into_iter() + .collect::<_>(); + + Ok(Expr::Literal(LiteralValue::Struct(entries)).arced()) + } +} +pub struct AzureConfigFunction; +impl SQLFunction for AzureConfigFunction { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> crate::error::SQLPlannerResult { + let args: SQLFunctionArguments = planner.parse_function_args( + inputs, + &[ + "storage_account", + "access_key", + "sas_token", + "bearer_token", + "tenant_id", + "client_id", + "client_secret", + "use_fabric_endpoint", + "anonymous", + "endpoint_url", + "use_ssl", + ], + 0, + )?; + + let storage_account = args.try_get_named::("storage_account")?; + let access_key = args.try_get_named::("access_key")?; + let sas_token = args.try_get_named::("sas_token")?; + let bearer_token = args.try_get_named::("bearer_token")?; + let tenant_id = args.try_get_named::("tenant_id")?; + let client_id = args.try_get_named::("client_id")?; + let client_secret = args.try_get_named::("client_secret")?; + let use_fabric_endpoint = args.try_get_named::("use_fabric_endpoint")?; + let anonymous = args.try_get_named::("anonymous")?; + let endpoint_url = args.try_get_named::("endpoint_url")?; + let use_ssl = args.try_get_named::("use_ssl")?; + + let entries = vec![ + ( + Field::new("variant", DataType::Utf8), + literal_value("azure"), + ), + item!(storage_account, Utf8), + item!(access_key, Utf8), + item!(sas_token, Utf8), + item!(bearer_token, Utf8), + item!(tenant_id, Utf8), + item!(client_id, Utf8), + item!(client_secret, Utf8), + item!(use_fabric_endpoint, Boolean), + item!(anonymous, Boolean), + item!(endpoint_url, Utf8), + item!(use_ssl, Boolean), + ] + .into_iter() + .collect::<_>(); + + Ok(Expr::Literal(LiteralValue::Struct(entries)).arced()) + } +} + +pub struct GCSConfigFunction; + +impl SQLFunction for GCSConfigFunction { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + let args: SQLFunctionArguments = planner.parse_function_args( + inputs, + &["project_id", "credentials", "token", "anonymous"], + 0, + )?; + + let project_id = args.try_get_named::("project_id")?; + let credentials = args.try_get_named::("credentials")?; + let token = args.try_get_named::("token")?; + let anonymous = args.try_get_named::("anonymous")?; + + let entries = vec![ + (Field::new("variant", DataType::Utf8), literal_value("gcs")), + item!(project_id, Utf8), + item!(credentials, Utf8), + item!(token, Utf8), + item!(anonymous, Boolean), + ] + .into_iter() + .collect::<_>(); + + Ok(Expr::Literal(LiteralValue::Struct(entries)).arced()) + } +} + +pub(crate) fn expr_to_iocfg(expr: &ExprRef) -> SQLPlannerResult { + // TODO(CORY): use serde to deserialize this + let Expr::Literal(LiteralValue::Struct(entries)) = expr.as_ref() else { + unsupported_sql_err!("Invalid IOConfig"); + }; + + macro_rules! get_value { + ($field:literal, $type:ident) => { + entries + .get(&Field::new($field, DataType::$type)) + .and_then(|s| match s { + LiteralValue::$type(s) => Some(Ok(s.clone())), + LiteralValue::Null => None, + _ => Some(Err(PlannerError::invalid_argument($field, "IOConfig"))), + }) + .transpose() + }; + } + + let variant = get_value!("variant", Utf8)? 
+ .expect("variant is required for IOConfig, this indicates a programming error"); + + match variant.as_ref() { + "s3" => { + let region_name = get_value!("region_name", Utf8)?; + let endpoint_url = get_value!("endpoint_url", Utf8)?; + let key_id = get_value!("key_id", Utf8)?; + let session_token = get_value!("session_token", Utf8)?.map(|s| s.into()); + let access_key = get_value!("access_key", Utf8)?.map(|s| s.into()); + let buffer_time = get_value!("buffer_time", UInt64)?; + let max_connections_per_io_thread = + get_value!("max_connections_per_io_thread", UInt32)?; + let retry_initial_backoff_ms = get_value!("retry_initial_backoff_ms", UInt64)?; + let connect_timeout_ms = get_value!("connect_timeout_ms", UInt64)?; + let read_timeout_ms = get_value!("read_timeout_ms", UInt64)?; + let num_tries = get_value!("num_tries", UInt32)?; + let retry_mode = get_value!("retry_mode", Utf8)?; + let anonymous = get_value!("anonymous", Boolean)?; + let use_ssl = get_value!("use_ssl", Boolean)?; + let verify_ssl = get_value!("verify_ssl", Boolean)?; + let check_hostname_ssl = get_value!("check_hostname_ssl", Boolean)?; + let requester_pays = get_value!("requester_pays", Boolean)?; + let force_virtual_addressing = get_value!("force_virtual_addressing", Boolean)?; + let profile_name = get_value!("profile_name", Utf8)?; + let default = S3Config::default(); + let s3_config = S3Config { + region_name, + endpoint_url, + key_id, + session_token, + access_key, + credentials_provider: None, + buffer_time, + max_connections_per_io_thread: max_connections_per_io_thread + .unwrap_or(default.max_connections_per_io_thread), + retry_initial_backoff_ms: retry_initial_backoff_ms + .unwrap_or(default.retry_initial_backoff_ms), + connect_timeout_ms: connect_timeout_ms.unwrap_or(default.connect_timeout_ms), + read_timeout_ms: read_timeout_ms.unwrap_or(default.read_timeout_ms), + num_tries: num_tries.unwrap_or(default.num_tries), + retry_mode, + anonymous: anonymous.unwrap_or(default.anonymous), + use_ssl: use_ssl.unwrap_or(default.use_ssl), + verify_ssl: verify_ssl.unwrap_or(default.verify_ssl), + check_hostname_ssl: check_hostname_ssl.unwrap_or(default.check_hostname_ssl), + requester_pays: requester_pays.unwrap_or(default.requester_pays), + force_virtual_addressing: force_virtual_addressing + .unwrap_or(default.force_virtual_addressing), + profile_name, + }; + + Ok(IOConfig { + s3: s3_config, + ..Default::default() + }) + } + "http" => { + let default = HTTPConfig::default(); + let user_agent = get_value!("user_agent", Utf8)?.unwrap_or(default.user_agent); + let bearer_token = get_value!("bearer_token", Utf8)?.map(|s| s.into()); + + Ok(IOConfig { + http: HTTPConfig { + user_agent, + bearer_token, + }, + ..Default::default() + }) + } + "azure" => { + let storage_account = get_value!("storage_account", Utf8)?; + let access_key = get_value!("access_key", Utf8)?; + let sas_token = get_value!("sas_token", Utf8)?; + let bearer_token = get_value!("bearer_token", Utf8)?; + let tenant_id = get_value!("tenant_id", Utf8)?; + let client_id = get_value!("client_id", Utf8)?; + let client_secret = get_value!("client_secret", Utf8)?; + let use_fabric_endpoint = get_value!("use_fabric_endpoint", Boolean)?; + let anonymous = get_value!("anonymous", Boolean)?; + let endpoint_url = get_value!("endpoint_url", Utf8)?; + let use_ssl = get_value!("use_ssl", Boolean)?; + + let default = AzureConfig::default(); + + Ok(IOConfig { + azure: AzureConfig { + storage_account, + access_key: access_key.map(|s| s.into()), + sas_token, + bearer_token, + 
tenant_id, + client_id, + client_secret: client_secret.map(|s| s.into()), + use_fabric_endpoint: use_fabric_endpoint.unwrap_or(default.use_fabric_endpoint), + anonymous: anonymous.unwrap_or(default.anonymous), + endpoint_url, + use_ssl: use_ssl.unwrap_or(default.use_ssl), + }, + ..Default::default() + }) + } + "gcs" => { + let project_id = get_value!("project_id", Utf8)?; + let credentials = get_value!("credentials", Utf8)?; + let token = get_value!("token", Utf8)?; + let anonymous = get_value!("anonymous", Boolean)?; + let default = GCSConfig::default(); + + Ok(IOConfig { + gcs: GCSConfig { + project_id, + credentials: credentials.map(|s| s.into()), + token, + anonymous: anonymous.unwrap_or(default.anonymous), + }, + ..Default::default() + }) + } + _ => { + unreachable!("variant is required for IOConfig, this indicates a programming error") + } + } +} diff --git a/src/daft-sql/src/modules/image/resize.rs b/src/daft-sql/src/modules/image/resize.rs index e4c9804d39..ac6c12fd50 100644 --- a/src/daft-sql/src/modules/image/resize.rs +++ b/src/daft-sql/src/modules/image/resize.rs @@ -16,7 +16,7 @@ impl TryFrom for ImageResize { fn try_from(args: SQLFunctionArguments) -> Result { let width = args .get_named("w") - .or_else(|| args.get_unnamed(0)) + .or_else(|| args.get_positional(0)) .map(|arg| match arg.as_ref() { Expr::Literal(LiteralValue::Int64(i)) => Ok(*i), _ => unsupported_sql_err!("Expected width to be a number"), @@ -28,7 +28,7 @@ impl TryFrom for ImageResize { let height = args .get_named("h") - .or_else(|| args.get_unnamed(1)) + .or_else(|| args.get_positional(1)) .map(|arg| match arg.as_ref() { Expr::Literal(LiteralValue::Int64(i)) => Ok(*i), _ => unsupported_sql_err!("Expected height to be a number"), diff --git a/src/daft-sql/src/modules/mod.rs b/src/daft-sql/src/modules/mod.rs index 989c401393..af4cb731a3 100644 --- a/src/daft-sql/src/modules/mod.rs +++ b/src/daft-sql/src/modules/mod.rs @@ -1,6 +1,7 @@ use crate::functions::SQLFunctions; pub mod aggs; +pub mod config; pub mod float; pub mod hashing; pub mod image; diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index afeb14fa2d..1be5b724a9 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -397,7 +397,19 @@ impl SQLPlanner { fn plan_relation(&self, rel: &sqlparser::ast::TableFactor) -> SQLPlannerResult { match rel { - sqlparser::ast::TableFactor::Table { name, .. } => { + sqlparser::ast::TableFactor::Table { + name, + args: Some(args), + alias, + .. + } => { + let tbl_fn = name.0.first().unwrap().value.as_str(); + + self.plan_table_function(tbl_fn, args, alias) + } + sqlparser::ast::TableFactor::Table { + name, args: None, .. + } => { let table_name = name.to_string(); let plan = self .catalog @@ -728,7 +740,22 @@ impl SQLPlanner { } SQLExpr::Struct { .. } => unsupported_sql_err!("STRUCT"), SQLExpr::Named { .. 
} => unsupported_sql_err!("NAMED"), - SQLExpr::Dictionary(_) => unsupported_sql_err!("DICTIONARY"), + SQLExpr::Dictionary(dict) => { + let entries = dict + .iter() + .map(|entry| { + let key = entry.key.value.clone(); + let value = self.plan_expr(&entry.value)?; + let value = value.as_literal().ok_or_else(|| { + PlannerError::invalid_operation("Dictionary value is not a literal") + })?; + let struct_field = Field::new(key, value.get_type()); + Ok((struct_field, value.clone())) + }) + .collect::>()?; + + Ok(Expr::Literal(LiteralValue::Struct(entries)).arced()) + } SQLExpr::Map(_) => unsupported_sql_err!("MAP"), SQLExpr::Subscript { expr, subscript } => self.plan_subscript(expr, subscript.as_ref()), SQLExpr::Array(_) => unsupported_sql_err!("ARRAY"), diff --git a/src/daft-sql/src/table_provider/mod.rs b/src/daft-sql/src/table_provider/mod.rs new file mode 100644 index 0000000000..453fa4965f --- /dev/null +++ b/src/daft-sql/src/table_provider/mod.rs @@ -0,0 +1,119 @@ +pub mod read_parquet; +use std::{collections::HashMap, sync::Arc}; + +use daft_plan::LogicalPlanBuilder; +use once_cell::sync::Lazy; +use read_parquet::ReadParquetFunction; +use sqlparser::ast::{TableAlias, TableFunctionArgs}; + +use crate::{ + error::SQLPlannerResult, + modules::config::expr_to_iocfg, + planner::{Relation, SQLPlanner}, + unsupported_sql_err, +}; + +pub(crate) static SQL_TABLE_FUNCTIONS: Lazy = Lazy::new(|| { + let mut functions = SQLTableFunctions::new(); + functions.add_fn("read_parquet", ReadParquetFunction); + #[cfg(feature = "python")] + functions.add_fn("read_deltalake", ReadDeltalakeFunction); + + functions +}); + +/// TODOs +/// - Use multimap for function variants. +/// - Add more functions.. +pub struct SQLTableFunctions { + pub(crate) map: HashMap>, +} + +impl SQLTableFunctions { + /// Create a new [SQLFunctions] instance. + pub fn new() -> Self { + Self { + map: HashMap::new(), + } + } + /// Add a [FunctionExpr] to the [SQLFunctions] instance. + pub(crate) fn add_fn(&mut self, name: &str, func: F) { + self.map.insert(name.to_string(), Arc::new(func)); + } + + /// Get a function by name from the [SQLFunctions] instance. 
+ pub(crate) fn get(&self, name: &str) -> Option<&Arc> { + self.map.get(name) + } +} + +impl SQLPlanner { + pub(crate) fn plan_table_function( + &self, + fn_name: &str, + args: &TableFunctionArgs, + alias: &Option, + ) -> SQLPlannerResult { + let fns = &SQL_TABLE_FUNCTIONS; + + let Some(func) = fns.get(fn_name) else { + unsupported_sql_err!("Function `{}` not found", fn_name); + }; + + let builder = func.plan(self, args)?; + let name = alias + .as_ref() + .map(|a| a.name.value.clone()) + .unwrap_or_else(|| fn_name.to_string()); + + Ok(Relation::new(builder, name)) + } +} + +pub(crate) trait SQLTableFunction: Send + Sync { + fn plan( + &self, + planner: &SQLPlanner, + args: &TableFunctionArgs, + ) -> SQLPlannerResult; +} + +pub struct ReadDeltalakeFunction; + +#[cfg(feature = "python")] +impl SQLTableFunction for ReadDeltalakeFunction { + fn plan( + &self, + planner: &SQLPlanner, + args: &TableFunctionArgs, + ) -> SQLPlannerResult { + let (uri, io_config) = match args.args.as_slice() { + [uri] => (uri, None), + [uri, io_config] => { + let args = planner.parse_function_args(&[io_config.clone()], &["io_config"], 0)?; + let io_config = args.get_named("io_config").map(expr_to_iocfg).transpose()?; + + (uri, io_config) + } + _ => unsupported_sql_err!("Expected one or two arguments"), + }; + let uri = planner.plan_function_arg(uri)?; + + let Some(uri) = uri.as_literal().and_then(|lit| lit.as_str()) else { + unsupported_sql_err!("Expected a string literal for the first argument"); + }; + + LogicalPlanBuilder::delta_scan(uri, io_config, true).map_err(From::from) + } +} + +#[cfg(not(feature = "python"))] +impl SQLTableFunction for ReadDeltalakeFunction { + fn plan( + &self, + planner: &SQLPlanner, + args: &TableFunctionArgs, + ) -> SQLPlannerResult { + unsupported_sql_err!("`read_deltalake` function is not supported. 
Enable the `python` feature to use this function.") + } +} diff --git a/src/daft-sql/src/table_provider/read_parquet.rs b/src/daft-sql/src/table_provider/read_parquet.rs new file mode 100644 index 0000000000..36f84507a0 --- /dev/null +++ b/src/daft-sql/src/table_provider/read_parquet.rs @@ -0,0 +1,77 @@ +use daft_core::prelude::TimeUnit; +use daft_plan::{LogicalPlanBuilder, ParquetScanBuilder}; +use sqlparser::ast::TableFunctionArgs; + +use super::SQLTableFunction; +use crate::{ + error::{PlannerError, SQLPlannerResult}, + functions::SQLFunctionArguments, + modules::config::expr_to_iocfg, + planner::SQLPlanner, +}; + +pub(super) struct ReadParquetFunction; + +impl TryFrom for ParquetScanBuilder { + type Error = PlannerError; + + fn try_from(args: SQLFunctionArguments) -> Result { + let glob_paths: String = args.try_get_positional(0)?.ok_or_else(|| { + PlannerError::invalid_operation("path is required for `read_parquet`") + })?; + let infer_schema = args.try_get_named("infer_schema")?.unwrap_or(true); + let coerce_int96_timestamp_unit = + args.try_get_named::("coerce_int96_timestamp_unit")?; + let coerce_int96_timestamp_unit: TimeUnit = coerce_int96_timestamp_unit + .as_deref() + .unwrap_or("nanoseconds") + .parse::() + .map_err(|_| { + PlannerError::invalid_argument("coerce_int96_timestamp_unit", "read_parquet") + })?; + let chunk_size = args.try_get_named("chunk_size")?; + let multithreaded = args.try_get_named("multithreaded")?.unwrap_or(true); + + let field_id_mapping = None; // TODO + let row_groups = None; // TODO + let schema = None; // TODO + let io_config = args.get_named("io_config").map(expr_to_iocfg).transpose()?; + + Ok(Self { + glob_paths: vec![glob_paths], + infer_schema, + coerce_int96_timestamp_unit, + field_id_mapping, + row_groups, + chunk_size, + io_config, + multithreaded, + schema, + }) + } +} + +impl SQLTableFunction for ReadParquetFunction { + fn plan( + &self, + planner: &SQLPlanner, + args: &TableFunctionArgs, + ) -> SQLPlannerResult { + let builder: ParquetScanBuilder = planner.plan_function_args( + args.args.as_slice(), + &[ + "infer_schema", + "coerce_int96_timestamp_unit", + "chunk_size", + "multithreaded", + // "schema", + // "field_id_mapping", + // "row_groups", + "io_config", + ], + 1, // 1 positional argument (path) + )?; + + builder.finish().map_err(From::from) + } +} diff --git a/tests/sql/test_table_funcs.py b/tests/sql/test_table_funcs.py new file mode 100644 index 0000000000..16fbf040c7 --- /dev/null +++ b/tests/sql/test_table_funcs.py @@ -0,0 +1,7 @@ +import daft + + +def test_sql_read_parquet(): + df = daft.sql("SELECT * FROM read_parquet('tests/assets/parquet-data/mvp.parquet')").collect() + expected = daft.read_parquet("tests/assets/parquet-data/mvp.parquet").collect() + assert df.to_pydict() == expected.to_pydict()
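
For illustration only, a minimal sketch of how the table functions added by this patch could be exercised from Python. The bucket, glob path, region, and Delta table URI below are hypothetical; the named-argument form assumes the `name => value` syntax handled by sqlparser's FunctionArg::Named, and `read_deltalake` additionally requires the `python` feature plus a reachable Delta Lake table.

    import daft

    # Positional form, mirroring tests/sql/test_table_funcs.py.
    df = daft.sql(
        "SELECT * FROM read_parquet('tests/assets/parquet-data/mvp.parquet')"
    ).collect()

    # Named options plus an inline IOConfig literal (hypothetical bucket/region).
    df_s3 = daft.sql("""
        SELECT *
        FROM read_parquet(
            's3://my-bucket/data/*.parquet',
            coerce_int96_timestamp_unit => 'ms',
            io_config => S3Config(region_name => 'us-west-2', anonymous => true)
        )
    """).collect()

    # Delta Lake scan via the new read_deltalake table function
    # (hypothetical table URI; requires the 'python' feature).
    df_delta = daft.sql("""
        SELECT *
        FROM read_deltalake(
            's3://my-bucket/my-delta-table',
            io_config => S3Config(region_name => 'us-west-2')
        )
    """).collect()

The `S3Config(...)` call maps onto the new `SQLModuleConfig` functions, which lower the named arguments into a `LiteralValue::Struct` that `expr_to_iocfg` then converts back into a `common_io_config::IOConfig` when the table function is planned.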