From 5f58f8f9e1530c2ab7e8ae3e12dfef1d93725439 Mon Sep 17 00:00:00 2001 From: Evance Soumaoro Date: Thu, 8 Aug 2024 17:29:48 +0000 Subject: [PATCH 1/2] added strum to automatically provide as_ref & iter methods to enums --- Cargo.toml | 1 + src/duckdb/secret.rs | 140 +++++++++++++------------------------------ 2 files changed, 44 insertions(+), 97 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f029ca36..d25d07a6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ serde_json = "1.0.120" signal-hook = "0.3.17" signal-hook-async-std = "0.2.2" shared = { git = "https://github.com/paradedb/paradedb.git", rev = "4854652" } +strum = { version = "0.26.3", features = ["derive"] } supabase-wrappers = { git = "https://github.com/paradedb/wrappers.git", default-features = false, rev = "27af09b" } thiserror = "1.0.59" uuid = "1.9.1" diff --git a/src/duckdb/secret.rs b/src/duckdb/secret.rs index f3f484e9..16bae696 100644 --- a/src/duckdb/secret.rs +++ b/src/duckdb/secret.rs @@ -17,7 +17,9 @@ use anyhow::{anyhow, bail, Result}; use std::collections::HashMap; +use strum::{AsRefStr, EnumIter}; +#[derive(EnumIter, AsRefStr, PartialEq, Debug)] pub enum UserMappingOptions { // Universal Type, @@ -47,33 +49,6 @@ pub enum UserMappingOptions { } impl UserMappingOptions { - pub fn as_str(&self) -> &str { - match self { - Self::Type => "type", - Self::Provider => "provider", - Self::Scope => "scope", - Self::Chain => "chain", - Self::KeyId => "key_id", - Self::Secret => "secret", - Self::Region => "region", - Self::SessionToken => "session_token", - Self::Endpoint => "endpoint", - Self::UrlStyle => "url_style", - Self::UseSsl => "use_ssl", - Self::UrlCompatibilityMode => "url_compatibility_mode", - Self::AccountId => "account_id", - Self::ConnectionString => "connection_string", - Self::AccountName => "account_name", - Self::TenantId => "tenant_id", - Self::ClientId => "client_id", - Self::ClientSecret => "client_secret", - Self::ClientCertificatePath => "client_certificate_path", - Self::HttpProxy => "http_proxy", - Self::ProxyUserName => "proxy_user_name", - Self::ProxyPassword => "proxy_password", - } - } - #[allow(unused)] pub fn is_required(&self) -> bool { match self { @@ -101,35 +76,6 @@ impl UserMappingOptions { Self::ProxyPassword => false, } } - - #[allow(unused)] - pub fn iter() -> impl Iterator { - [ - Self::Type, - Self::Provider, - Self::Scope, - Self::Chain, - Self::KeyId, - Self::Secret, - Self::Region, - Self::SessionToken, - Self::Endpoint, - Self::UrlStyle, - Self::UseSsl, - Self::UrlCompatibilityMode, - Self::AccountId, - Self::ConnectionString, - Self::AccountName, - Self::TenantId, - Self::ClientId, - Self::ClientSecret, - Self::ClientCertificatePath, - Self::HttpProxy, - Self::ProxyUserName, - Self::ProxyPassword, - ] - .into_iter() - } } pub fn create_secret( @@ -143,95 +89,95 @@ pub fn create_secret( let secret_type = Some(format!( "TYPE {}", user_mapping_options - .get(UserMappingOptions::Type.as_str()) + .get(UserMappingOptions::Type.as_ref()) .ok_or_else(|| anyhow!("type option required for USER MAPPING"))? .as_str() )); let provider = user_mapping_options - .get(UserMappingOptions::Provider.as_str()) + .get(UserMappingOptions::Provider.as_ref()) .map(|provider| format!("PROVIDER {}", provider)); let scope = user_mapping_options - .get(UserMappingOptions::Scope.as_str()) + .get(UserMappingOptions::Scope.as_ref()) .map(|scope| format!("SCOPE {}", scope)); let chain = user_mapping_options - .get(UserMappingOptions::Chain.as_str()) + .get(UserMappingOptions::Chain.as_ref()) .map(|chain| format!("CHAIN '{}'", chain)); let key_id = user_mapping_options - .get(UserMappingOptions::KeyId.as_str()) + .get(UserMappingOptions::KeyId.as_ref()) .map(|key_id| format!("KEY_ID '{}'", key_id)); let secret = user_mapping_options - .get(UserMappingOptions::Secret.as_str()) + .get(UserMappingOptions::Secret.as_ref()) .map(|secret| format!("SECRET '{}'", secret)); let region = user_mapping_options - .get(UserMappingOptions::Region.as_str()) + .get(UserMappingOptions::Region.as_ref()) .map(|region| format!("REGION '{}'", region)); let session_token = user_mapping_options - .get(UserMappingOptions::SessionToken.as_str()) + .get(UserMappingOptions::SessionToken.as_ref()) .map(|session_token| format!("SESSION_TOKEN '{}'", session_token)); let endpoint = user_mapping_options - .get(UserMappingOptions::Endpoint.as_str()) + .get(UserMappingOptions::Endpoint.as_ref()) .map(|endpoint| format!("ENDPOINT '{}'", endpoint)); let url_style = user_mapping_options - .get(UserMappingOptions::UrlStyle.as_str()) + .get(UserMappingOptions::UrlStyle.as_ref()) .map(|url_style| format!("URL_STYLE '{}'", url_style)); let use_ssl = user_mapping_options - .get(UserMappingOptions::UseSsl.as_str()) + .get(UserMappingOptions::UseSsl.as_ref()) .map(|use_ssl| format!("USE_SSL {}", use_ssl)); let url_compatibility_mode = user_mapping_options - .get(UserMappingOptions::UrlCompatibilityMode.as_str()) + .get(UserMappingOptions::UrlCompatibilityMode.as_ref()) .map(|url_compatibility_mode| format!("URL_COMPATIBILITY_MODE {}", url_compatibility_mode)); let account_id = user_mapping_options - .get(UserMappingOptions::AccountId.as_str()) + .get(UserMappingOptions::AccountId.as_ref()) .map(|account_id| format!("ACCOUNT_ID '{}'", account_id)); let connection_string = user_mapping_options - .get(UserMappingOptions::ConnectionString.as_str()) + .get(UserMappingOptions::ConnectionString.as_ref()) .map(|connection_string| format!("CONNECTION_STRING '{}'", connection_string)); let account_name = user_mapping_options - .get(UserMappingOptions::AccountName.as_str()) + .get(UserMappingOptions::AccountName.as_ref()) .map(|account_name| format!("ACCOUNT_NAME '{}'", account_name)); let tenant_id = user_mapping_options - .get(UserMappingOptions::TenantId.as_str()) + .get(UserMappingOptions::TenantId.as_ref()) .map(|tenant_id| format!("TENANT_ID '{}'", tenant_id)); let client_id = user_mapping_options - .get(UserMappingOptions::ClientId.as_str()) + .get(UserMappingOptions::ClientId.as_ref()) .map(|client_id| format!("CLIENT_ID '{}'", client_id)); let client_secret = user_mapping_options - .get(UserMappingOptions::ClientSecret.as_str()) + .get(UserMappingOptions::ClientSecret.as_ref()) .map(|client_secret| format!("CLIENT_SECRET '{}'", client_secret)); let client_certificate_path = user_mapping_options - .get(UserMappingOptions::ClientCertificatePath.as_str()) + .get(UserMappingOptions::ClientCertificatePath.as_ref()) .map(|client_certificate_path| { format!("CLIENT_CERTIFICATE_PATH '{}'", client_certificate_path) }); let http_proxy = user_mapping_options - .get(UserMappingOptions::HttpProxy.as_str()) + .get(UserMappingOptions::HttpProxy.as_ref()) .map(|http_proxy| format!("HTTP_PROXY '{}'", http_proxy)); let proxy_user_name = user_mapping_options - .get(UserMappingOptions::ProxyUserName.as_str()) + .get(UserMappingOptions::ProxyUserName.as_ref()) .map(|proxy_user_name| format!("PROXY_USER_NAME '{}'", proxy_user_name)); let proxy_password = user_mapping_options - .get(UserMappingOptions::ProxyPassword.as_str()) + .get(UserMappingOptions::ProxyPassword.as_ref()) .map(|proxy_password| format!("PROXY_PASSWORD '{}'", proxy_password)); let secret_string = vec![ @@ -278,44 +224,44 @@ mod tests { let secret_name = "s3_secret"; let user_mapping_options = HashMap::from([ ( - UserMappingOptions::Type.as_str().to_string(), + UserMappingOptions::Type.as_ref().to_string(), "S3".to_string(), ), ( - UserMappingOptions::Provider.as_str().to_string(), + UserMappingOptions::Provider.as_ref().to_string(), "CONFIG".to_string(), ), ( - UserMappingOptions::KeyId.as_str().to_string(), + UserMappingOptions::KeyId.as_ref().to_string(), "key_id".to_string(), ), ( - UserMappingOptions::Secret.as_str().to_string(), + UserMappingOptions::Secret.as_ref().to_string(), "secret".to_string(), ), ( - UserMappingOptions::Region.as_str().to_string(), + UserMappingOptions::Region.as_ref().to_string(), "us-west-2".to_string(), ), ( - UserMappingOptions::SessionToken.as_str().to_string(), + UserMappingOptions::SessionToken.as_ref().to_string(), "session_token".to_string(), ), ( - UserMappingOptions::Endpoint.as_str().to_string(), + UserMappingOptions::Endpoint.as_ref().to_string(), "s3.amazonaws.com".to_string(), ), ( - UserMappingOptions::UrlStyle.as_str().to_string(), + UserMappingOptions::UrlStyle.as_ref().to_string(), "vhost".to_string(), ), ( - UserMappingOptions::UseSsl.as_str().to_string(), + UserMappingOptions::UseSsl.as_ref().to_string(), "true".to_string(), ), ( UserMappingOptions::UrlCompatibilityMode - .as_str() + .as_ref() .to_string(), "true".to_string(), ), @@ -336,11 +282,11 @@ mod tests { let secret_name = "s3_secret"; let user_mapping_options = HashMap::from([ ( - UserMappingOptions::Type.as_str().to_string(), + UserMappingOptions::Type.as_ref().to_string(), "S3".to_string(), ), ( - UserMappingOptions::Provider.as_str().to_string(), + UserMappingOptions::Provider.as_ref().to_string(), "TENANT_ID".to_string(), ), ]); @@ -358,27 +304,27 @@ mod tests { let secret_name = "azure_secret"; let user_mapping_options = HashMap::from([ ( - UserMappingOptions::Type.as_str().to_string(), + UserMappingOptions::Type.as_ref().to_string(), "AZURE".to_string(), ), ( - UserMappingOptions::Provider.as_str().to_string(), + UserMappingOptions::Provider.as_ref().to_string(), "CONFIG".to_string(), ), ( - UserMappingOptions::ConnectionString.as_str().to_string(), + UserMappingOptions::ConnectionString.as_ref().to_string(), "connection_string".to_string(), ), ( - UserMappingOptions::HttpProxy.as_str().to_string(), + UserMappingOptions::HttpProxy.as_ref().to_string(), "http_proxy".to_string(), ), ( - UserMappingOptions::ProxyUserName.as_str().to_string(), + UserMappingOptions::ProxyUserName.as_ref().to_string(), "proxy_user_name".to_string(), ), ( - UserMappingOptions::ProxyPassword.as_str().to_string(), + UserMappingOptions::ProxyPassword.as_ref().to_string(), "proxy_password".to_string(), ), ]); @@ -397,7 +343,7 @@ mod tests { fn test_create_type_invalid() { let secret_name = "invalid_secret"; let user_mapping_options = HashMap::from([( - UserMappingOptions::Type.as_str().to_string(), + UserMappingOptions::Type.as_ref().to_string(), "INVALID".to_string(), )]); From 269172f32b35d12f866e5c2c930fcd1348ce7f55 Mon Sep 17 00:00:00 2001 From: Ming Ying Date: Thu, 22 Aug 2024 10:54:23 -0400 Subject: [PATCH 2/2] convert to strum --- src/duckdb/csv.rs | 244 +++++++++++++++++------------------------- src/duckdb/delta.rs | 19 ++-- src/duckdb/iceberg.rs | 23 ++-- src/duckdb/parquet.rs | 76 +++++-------- src/duckdb/secret.rs | 22 ++++ src/fdw/csv.rs | 5 +- src/fdw/delta.rs | 5 +- src/fdw/iceberg.rs | 5 +- src/fdw/parquet.rs | 5 +- 9 files changed, 178 insertions(+), 226 deletions(-) diff --git a/src/duckdb/csv.rs b/src/duckdb/csv.rs index b13008bf..9b588bd9 100644 --- a/src/duckdb/csv.rs +++ b/src/duckdb/csv.rs @@ -17,84 +17,81 @@ use anyhow::{anyhow, Result}; use std::collections::HashMap; +use strum::{AsRefStr, EnumIter}; use super::utils; +#[derive(EnumIter, AsRefStr, PartialEq, Debug)] pub enum CsvOption { + #[strum(serialize = "all_varchar")] AllVarchar, + #[strum(serialize = "allow_quoted_nulls")] AllowQuotedNulls, + #[strum(serialize = "auto_detect")] AutoDetect, + #[strum(serialize = "auto_type_candidates")] AutoTypeCandidates, + #[strum(serialize = "columns")] Columns, + #[strum(serialize = "compression")] Compression, + #[strum(serialize = "dateformat")] Dateformat, + #[strum(serialize = "decimal_separator")] DecimalSeparator, + #[strum(serialize = "delim")] Delim, + #[strum(serialize = "escape")] Escape, + #[strum(serialize = "filename")] Filename, + #[strum(serialize = "files")] Files, + #[strum(serialize = "force_not_null")] ForceNotNull, + #[strum(serialize = "header")] Header, + #[strum(serialize = "hive_partitioning")] HivePartitioning, + #[strum(serialize = "hive_types")] HiveTypes, + #[strum(serialize = "hive_types_autocast")] HiveTypesAutocast, + #[strum(serialize = "ignore_errors")] IgnoreErrors, + #[strum(serialize = "max_line_size")] MaxLineSize, + #[strum(serialize = "names")] Names, + #[strum(serialize = "new_line")] NewLine, + #[strum(serialize = "normalize_names")] NormalizeNames, + #[strum(serialize = "null_padding")] NullPadding, + #[strum(serialize = "nullstr")] Nullstr, + #[strum(serialize = "parallel")] Parallel, + #[strum(serialize = "preserve_casing")] PreserveCasing, + #[strum(serialize = "quote")] Quote, + #[strum(serialize = "sample_size")] SampleSize, + #[strum(serialize = "sep")] Sep, + #[strum(serialize = "skip")] Skip, + #[strum(serialize = "timestampformat")] Timestampformat, + #[strum(serialize = "types")] Types, + #[strum(serialize = "union_by_name")] UnionByName, } impl CsvOption { - pub fn as_str(&self) -> &str { - match self { - Self::AllVarchar => "all_varchar", - Self::AllowQuotedNulls => "allow_quoted_nulls", - Self::AutoDetect => "auto_detect", - Self::AutoTypeCandidates => "auto_type_candidates", - Self::Columns => "columns", - Self::Compression => "compression", - Self::Dateformat => "dateformat", - Self::DecimalSeparator => "decimal_separator", - Self::Delim => "delim", - Self::Escape => "escape", - Self::Filename => "filename", - Self::Files => "files", - Self::ForceNotNull => "force_not_null", - Self::Header => "header", - Self::HivePartitioning => "hive_partitioning", - Self::HiveTypes => "hive_types", - Self::HiveTypesAutocast => "hive_types_autocast", - Self::IgnoreErrors => "ignore_errors", - Self::MaxLineSize => "max_line_size", - Self::Names => "names", - Self::NewLine => "new_line", - Self::NormalizeNames => "normalize_names", - Self::NullPadding => "null_padding", - Self::Nullstr => "nullstr", - Self::Parallel => "parallel", - Self::PreserveCasing => "preserve_casing", - Self::Quote => "quote", - Self::SampleSize => "sample_size", - Self::Sep => "sep", - Self::Skip => "skip", - Self::Timestampformat => "timestampformat", - Self::Types => "types", - Self::UnionByName => "union_by_name", - } - } - pub fn is_required(&self) -> bool { match self { Self::AllVarchar => false, @@ -132,45 +129,6 @@ impl CsvOption { Self::UnionByName => false, } } - - pub fn iter() -> impl Iterator { - [ - Self::AllVarchar, - Self::AllowQuotedNulls, - Self::AutoDetect, - Self::AutoTypeCandidates, - Self::Columns, - Self::Compression, - Self::Dateformat, - Self::DecimalSeparator, - Self::Delim, - Self::Escape, - Self::Filename, - Self::Files, - Self::ForceNotNull, - Self::Header, - Self::HivePartitioning, - Self::HiveTypes, - Self::HiveTypesAutocast, - Self::IgnoreErrors, - Self::MaxLineSize, - Self::Names, - Self::NewLine, - Self::NormalizeNames, - Self::NullPadding, - Self::Nullstr, - Self::Parallel, - Self::PreserveCasing, - Self::Quote, - Self::SampleSize, - Self::Sep, - Self::Skip, - Self::Timestampformat, - Self::Types, - Self::UnionByName, - ] - .into_iter() - } } pub fn create_view( @@ -180,132 +138,132 @@ pub fn create_view( ) -> Result { let files = Some(utils::format_csv( table_options - .get(CsvOption::Files.as_str()) + .get(CsvOption::Files.as_ref()) .ok_or_else(|| anyhow!("files option is required"))?, )); let all_varchar = table_options - .get(CsvOption::AllVarchar.as_str()) + .get(CsvOption::AllVarchar.as_ref()) .map(|option| format!("all_varchar = {option}")); let allow_quoted_nulls = table_options - .get(CsvOption::AllowQuotedNulls.as_str()) + .get(CsvOption::AllowQuotedNulls.as_ref()) .map(|option| format!("allow_quoted_nulls = {option}")); let auto_detect = table_options - .get(CsvOption::AutoDetect.as_str()) + .get(CsvOption::AutoDetect.as_ref()) .map(|option| format!("auto_detect = {option}")); let auto_type_candidates = table_options - .get(CsvOption::AutoTypeCandidates.as_str()) + .get(CsvOption::AutoTypeCandidates.as_ref()) .map(|option| format!("auto_type_candidates = {}", utils::format_csv(option))); let columns = table_options - .get(CsvOption::Columns.as_str()) + .get(CsvOption::Columns.as_ref()) .map(|option| format!("columns = {option}")); let compression = table_options - .get(CsvOption::Compression.as_str()) + .get(CsvOption::Compression.as_ref()) .map(|option| format!("compression = '{option}'")); let dateformat = table_options - .get(CsvOption::Dateformat.as_str()) + .get(CsvOption::Dateformat.as_ref()) .map(|option| format!("dateformat = '{option}'")); let decimal_separator = table_options - .get(CsvOption::DecimalSeparator.as_str()) + .get(CsvOption::DecimalSeparator.as_ref()) .map(|option| format!("decimal_separator = '{option}'")); let delim = table_options - .get(CsvOption::Delim.as_str()) + .get(CsvOption::Delim.as_ref()) .map(|option| format!("delim = '{option}'")); let escape = table_options - .get(CsvOption::Escape.as_str()) + .get(CsvOption::Escape.as_ref()) .map(|option| format!("escape = '{option}'")); let filename = table_options - .get(CsvOption::Filename.as_str()) + .get(CsvOption::Filename.as_ref()) .map(|option| format!("filename = {option}")); let force_not_null = table_options - .get(CsvOption::ForceNotNull.as_str()) + .get(CsvOption::ForceNotNull.as_ref()) .map(|option| format!("force_not_null = {}", utils::format_csv(option))); let header = table_options - .get(CsvOption::Header.as_str()) + .get(CsvOption::Header.as_ref()) .map(|option| format!("header = {option}")); let hive_partitioning = table_options - .get(CsvOption::HivePartitioning.as_str()) + .get(CsvOption::HivePartitioning.as_ref()) .map(|option| format!("hive_partitioning = {option}")); let hive_types = table_options - .get(CsvOption::HiveTypes.as_str()) + .get(CsvOption::HiveTypes.as_ref()) .map(|option| format!("hive_types = {option}")); let hive_types_autocast = table_options - .get(CsvOption::HiveTypesAutocast.as_str()) + .get(CsvOption::HiveTypesAutocast.as_ref()) .map(|option| format!("hive_types_autocast = {option}")); let ignore_errors = table_options - .get(CsvOption::IgnoreErrors.as_str()) + .get(CsvOption::IgnoreErrors.as_ref()) .map(|option| format!("ignore_errors = {option}")); let max_line_size = table_options - .get(CsvOption::MaxLineSize.as_str()) + .get(CsvOption::MaxLineSize.as_ref()) .map(|option| format!("max_line_size = {option}")); let names = table_options - .get(CsvOption::Names.as_str()) + .get(CsvOption::Names.as_ref()) .map(|option| format!("names = {}", utils::format_csv(option))); let new_line = table_options - .get(CsvOption::NewLine.as_str()) + .get(CsvOption::NewLine.as_ref()) .map(|option| format!("new_line = '{option}'")); let normalize_names = table_options - .get(CsvOption::NormalizeNames.as_str()) + .get(CsvOption::NormalizeNames.as_ref()) .map(|option| format!("normalize_names = {option}")); let null_padding = table_options - .get(CsvOption::NullPadding.as_str()) + .get(CsvOption::NullPadding.as_ref()) .map(|option| format!("null_padding = {option}")); let nullstr = table_options - .get(CsvOption::Nullstr.as_str()) + .get(CsvOption::Nullstr.as_ref()) .map(|option| format!("nullstr = {}", utils::format_csv(option))); let parallel = table_options - .get(CsvOption::Parallel.as_str()) + .get(CsvOption::Parallel.as_ref()) .map(|option| format!("parallel = {option}")); let quote = table_options - .get(CsvOption::Quote.as_str()) + .get(CsvOption::Quote.as_ref()) .map(|option| format!("quote = '{option}'")); let sample_size = table_options - .get(CsvOption::SampleSize.as_str()) + .get(CsvOption::SampleSize.as_ref()) .map(|option| format!("sample_size = {option}")); let sep = table_options - .get(CsvOption::Sep.as_str()) + .get(CsvOption::Sep.as_ref()) .map(|option| format!("sep = '{option}'")); let skip = table_options - .get(CsvOption::Skip.as_str()) + .get(CsvOption::Skip.as_ref()) .map(|option| format!("skip = {option}")); let timestampformat = table_options - .get(CsvOption::Timestampformat.as_str()) + .get(CsvOption::Timestampformat.as_ref()) .map(|option| format!("timestampformat = '{option}'")); let types = table_options - .get(CsvOption::Types.as_str()) + .get(CsvOption::Types.as_ref()) .map(|option| format!("types = {}", utils::format_csv(option))); let union_by_name = table_options - .get(CsvOption::UnionByName.as_str()) + .get(CsvOption::UnionByName.as_ref()) .map(|option| format!("union_by_name = {option}")); let create_csv_str = vec![ @@ -360,7 +318,7 @@ mod tests { let table_name = "test"; let schema_name = "main"; let table_options = HashMap::from([( - CsvOption::Files.as_str().to_string(), + CsvOption::Files.as_ref().to_string(), "/data/file.csv".to_string(), )]); let expected = @@ -381,7 +339,7 @@ mod tests { let table_name = "test"; let schema_name = "main"; let table_options = HashMap::from([( - CsvOption::Files.as_str().to_string(), + CsvOption::Files.as_ref().to_string(), "/data/file1.csv, /data/file2.csv".to_string(), )]); @@ -403,104 +361,104 @@ mod tests { let schema_name = "main"; let table_options = HashMap::from([ ( - CsvOption::Files.as_str().to_string(), + CsvOption::Files.as_ref().to_string(), "/data/file.csv".to_string(), ), ( - CsvOption::AllVarchar.as_str().to_string(), + CsvOption::AllVarchar.as_ref().to_string(), "true".to_string(), ), ( - CsvOption::AllowQuotedNulls.as_str().to_string(), + CsvOption::AllowQuotedNulls.as_ref().to_string(), "true".to_string(), ), ( - CsvOption::AutoDetect.as_str().to_string(), + CsvOption::AutoDetect.as_ref().to_string(), "true".to_string(), ), ( - CsvOption::AutoTypeCandidates.as_str().to_string(), + CsvOption::AutoTypeCandidates.as_ref().to_string(), "BIGINT, DATE".to_string(), ), ( - CsvOption::Columns.as_str().to_string(), + CsvOption::Columns.as_ref().to_string(), "{'col1': 'INTEGER', 'col2': 'VARCHAR'}".to_string(), ), ( - CsvOption::Compression.as_str().to_string(), + CsvOption::Compression.as_ref().to_string(), "gzip".to_string(), ), ( - CsvOption::Dateformat.as_str().to_string(), + CsvOption::Dateformat.as_ref().to_string(), "%d/%m/%Y".to_string(), ), ( - CsvOption::DecimalSeparator.as_str().to_string(), + CsvOption::DecimalSeparator.as_ref().to_string(), ".".to_string(), ), - (CsvOption::Delim.as_str().to_string(), ",".to_string()), - (CsvOption::Escape.as_str().to_string(), "\"".to_string()), - (CsvOption::Filename.as_str().to_string(), "true".to_string()), + (CsvOption::Delim.as_ref().to_string(), ",".to_string()), + (CsvOption::Escape.as_ref().to_string(), "\"".to_string()), + (CsvOption::Filename.as_ref().to_string(), "true".to_string()), ( - CsvOption::ForceNotNull.as_str().to_string(), + CsvOption::ForceNotNull.as_ref().to_string(), "col1, col2".to_string(), ), - (CsvOption::Header.as_str().to_string(), "true".to_string()), + (CsvOption::Header.as_ref().to_string(), "true".to_string()), ( - CsvOption::HivePartitioning.as_str().to_string(), + CsvOption::HivePartitioning.as_ref().to_string(), "true".to_string(), ), ( - CsvOption::HiveTypes.as_str().to_string(), + CsvOption::HiveTypes.as_ref().to_string(), "true".to_string(), ), ( - CsvOption::HiveTypesAutocast.as_str().to_string(), + CsvOption::HiveTypesAutocast.as_ref().to_string(), "true".to_string(), ), ( - CsvOption::IgnoreErrors.as_str().to_string(), + CsvOption::IgnoreErrors.as_ref().to_string(), "true".to_string(), ), ( - CsvOption::MaxLineSize.as_str().to_string(), + CsvOption::MaxLineSize.as_ref().to_string(), "1000".to_string(), ), ( - CsvOption::Names.as_str().to_string(), + CsvOption::Names.as_ref().to_string(), "col1, col2".to_string(), ), - (CsvOption::NewLine.as_str().to_string(), "\n".to_string()), + (CsvOption::NewLine.as_ref().to_string(), "\n".to_string()), ( - CsvOption::NormalizeNames.as_str().to_string(), + CsvOption::NormalizeNames.as_ref().to_string(), "true".to_string(), ), ( - CsvOption::NullPadding.as_str().to_string(), + CsvOption::NullPadding.as_ref().to_string(), "true".to_string(), ), ( - CsvOption::Nullstr.as_str().to_string(), + CsvOption::Nullstr.as_ref().to_string(), "none, null".to_string(), ), - (CsvOption::Parallel.as_str().to_string(), "true".to_string()), - (CsvOption::Quote.as_str().to_string(), "\"".to_string()), + (CsvOption::Parallel.as_ref().to_string(), "true".to_string()), + (CsvOption::Quote.as_ref().to_string(), "\"".to_string()), ( - CsvOption::SampleSize.as_str().to_string(), + CsvOption::SampleSize.as_ref().to_string(), "100".to_string(), ), - (CsvOption::Sep.as_str().to_string(), ",".to_string()), - (CsvOption::Skip.as_str().to_string(), "0".to_string()), + (CsvOption::Sep.as_ref().to_string(), ",".to_string()), + (CsvOption::Skip.as_ref().to_string(), "0".to_string()), ( - CsvOption::Timestampformat.as_str().to_string(), + CsvOption::Timestampformat.as_ref().to_string(), "yyyy-MM-dd HH:mm:ss".to_string(), ), ( - CsvOption::Types.as_str().to_string(), + CsvOption::Types.as_ref().to_string(), "BIGINT, VARCHAR".to_string(), ), ( - CsvOption::UnionByName.as_str().to_string(), + CsvOption::UnionByName.as_ref().to_string(), "true".to_string(), ), ]); diff --git a/src/duckdb/delta.rs b/src/duckdb/delta.rs index e4827045..5fef5c05 100644 --- a/src/duckdb/delta.rs +++ b/src/duckdb/delta.rs @@ -17,30 +17,23 @@ use anyhow::{anyhow, Result}; use std::collections::HashMap; +use strum::{AsRefStr, EnumIter}; +#[derive(EnumIter, AsRefStr, PartialEq, Debug)] pub enum DeltaOption { + #[strum(serialize = "files")] Files, + #[strum(serialize = "preserve_casing")] PreserveCasing, } impl DeltaOption { - pub fn as_str(&self) -> &str { - match self { - Self::Files => "files", - Self::PreserveCasing => "preserve_casing", - } - } - pub fn is_required(&self) -> bool { match self { Self::Files => true, Self::PreserveCasing => false, } } - - pub fn iter() -> impl Iterator { - [Self::Files, Self::PreserveCasing].into_iter() - } } pub fn create_view( @@ -51,7 +44,7 @@ pub fn create_view( let files = format!( "'{}'", table_options - .get(DeltaOption::Files.as_str()) + .get(DeltaOption::Files.as_ref()) .ok_or_else(|| anyhow!("files option is required"))? ); @@ -70,7 +63,7 @@ mod tests { let table_name = "test"; let schema_name = "main"; let table_options = HashMap::from([( - DeltaOption::Files.as_str().to_string(), + DeltaOption::Files.as_ref().to_string(), "/data/delta".to_string(), )]); diff --git a/src/duckdb/iceberg.rs b/src/duckdb/iceberg.rs index 14dcca87..98b0c093 100644 --- a/src/duckdb/iceberg.rs +++ b/src/duckdb/iceberg.rs @@ -17,22 +17,19 @@ use anyhow::{anyhow, Result}; use std::collections::HashMap; +use strum::{AsRefStr, EnumIter}; +#[derive(EnumIter, AsRefStr, PartialEq, Debug)] pub enum IcebergOption { + #[strum(serialize = "allow_moved_paths")] AllowMovedPaths, + #[strum(serialize = "files")] Files, + #[strum(serialize = "preserve_casing")] PreserveCasing, } impl IcebergOption { - pub fn as_str(&self) -> &str { - match self { - Self::AllowMovedPaths => "allow_moved_paths", - Self::Files => "files", - Self::PreserveCasing => "preserve_casing", - } - } - pub fn is_required(&self) -> bool { match self { Self::AllowMovedPaths => false, @@ -40,10 +37,6 @@ impl IcebergOption { Self::PreserveCasing => false, } } - - pub fn iter() -> impl Iterator { - [Self::AllowMovedPaths, Self::Files, Self::PreserveCasing].into_iter() - } } pub fn create_view( @@ -54,12 +47,12 @@ pub fn create_view( let files = Some(format!( "'{}'", table_options - .get(IcebergOption::Files.as_str()) + .get(IcebergOption::Files.as_ref()) .ok_or_else(|| anyhow!("files option is required"))? )); let allow_moved_paths = table_options - .get(IcebergOption::AllowMovedPaths.as_str()) + .get(IcebergOption::AllowMovedPaths.as_ref()) .map(|option| format!("allow_moved_paths = {option}")); let create_iceberg_str = [files, allow_moved_paths] @@ -81,7 +74,7 @@ mod tests { let table_name = "test"; let schema_name = "main"; let table_options = HashMap::from([( - IcebergOption::Files.as_str().to_string(), + IcebergOption::Files.as_ref().to_string(), "/data/iceberg".to_string(), )]); diff --git a/src/duckdb/parquet.rs b/src/duckdb/parquet.rs index 03f175a6..1aae0243 100644 --- a/src/duckdb/parquet.rs +++ b/src/duckdb/parquet.rs @@ -17,37 +17,34 @@ use anyhow::{anyhow, Result}; use std::collections::HashMap; +use strum::{AsRefStr, EnumIter}; use super::utils; +#[derive(EnumIter, AsRefStr, PartialEq, Debug)] pub enum ParquetOption { + #[strum(serialize = "binary_as_string")] BinaryAsString, + #[strum(serialize = "filename")] FileName, + #[strum(serialize = "file_row_number")] FileRowNumber, + #[strum(serialize = "files")] Files, + #[strum(serialize = "hive_partitioning")] HivePartitioning, + #[strum(serialize = "hive_types")] HiveTypes, + #[strum(serialize = "hive_types_autocast")] HiveTypesAutocast, + #[strum(serialize = "preserve_casing")] PreserveCasing, + #[strum(serialize = "union_by_name")] UnionByName, // TODO: EncryptionConfig } impl ParquetOption { - pub fn as_str(&self) -> &str { - match self { - Self::BinaryAsString => "binary_as_string", - Self::FileName => "file_name", - Self::FileRowNumber => "file_row_number", - Self::Files => "files", - Self::HivePartitioning => "hive_partitioning", - Self::HiveTypes => "hive_types", - Self::HiveTypesAutocast => "hive_types_autocast", - Self::PreserveCasing => "preserve_casing", - Self::UnionByName => "union_by_name", - } - } - pub fn is_required(&self) -> bool { match self { Self::BinaryAsString => false, @@ -61,21 +58,6 @@ impl ParquetOption { Self::UnionByName => false, } } - - pub fn iter() -> impl Iterator { - [ - Self::BinaryAsString, - Self::FileName, - Self::FileRowNumber, - Self::Files, - Self::HivePartitioning, - Self::HiveTypes, - Self::HiveTypesAutocast, - Self::PreserveCasing, - Self::UnionByName, - ] - .into_iter() - } } pub fn create_view( @@ -85,36 +67,36 @@ pub fn create_view( ) -> Result { let files = Some(utils::format_csv( table_options - .get(ParquetOption::Files.as_str()) + .get(ParquetOption::Files.as_ref()) .ok_or_else(|| anyhow!("files option is required"))?, )); let binary_as_string = table_options - .get(ParquetOption::BinaryAsString.as_str()) + .get(ParquetOption::BinaryAsString.as_ref()) .map(|option| format!("binary_as_string = {option}")); let file_name = table_options - .get(ParquetOption::FileName.as_str()) + .get(ParquetOption::FileName.as_ref()) .map(|option| format!("filename = {option}")); let file_row_number = table_options - .get(ParquetOption::FileRowNumber.as_str()) + .get(ParquetOption::FileRowNumber.as_ref()) .map(|option| format!("file_row_number = {option}")); let hive_partitioning = table_options - .get(ParquetOption::HivePartitioning.as_str()) + .get(ParquetOption::HivePartitioning.as_ref()) .map(|option| format!("hive_partitioning = {option}")); let hive_types = table_options - .get(ParquetOption::HiveTypes.as_str()) + .get(ParquetOption::HiveTypes.as_ref()) .map(|option| format!("hive_types = {option}")); let hive_types_autocast = table_options - .get(ParquetOption::HiveTypesAutocast.as_str()) + .get(ParquetOption::HiveTypesAutocast.as_ref()) .map(|option| format!("hive_types_autocast = {option}")); let union_by_name = table_options - .get(ParquetOption::UnionByName.as_str()) + .get(ParquetOption::UnionByName.as_ref()) .map(|option| format!("union_by_name = {option}")); let create_parquet_str = [ @@ -146,7 +128,7 @@ mod tests { let schema_name = "main"; let files = "/data/file.parquet"; let table_options = - HashMap::from([(ParquetOption::Files.as_str().to_string(), files.to_string())]); + HashMap::from([(ParquetOption::Files.as_ref().to_string(), files.to_string())]); let expected = "CREATE VIEW IF NOT EXISTS main.test AS SELECT * FROM read_parquet('/data/file.parquet')"; let actual = create_view(table_name, schema_name, table_options).unwrap(); @@ -165,7 +147,7 @@ mod tests { let schema_name = "main"; let files = "/data/file1.parquet, /data/file2.parquet"; let table_options = - HashMap::from([(ParquetOption::Files.as_str().to_string(), files.to_string())]); + HashMap::from([(ParquetOption::Files.as_ref().to_string(), files.to_string())]); let expected = "CREATE VIEW IF NOT EXISTS main.test AS SELECT * FROM read_parquet(['/data/file1.parquet', '/data/file2.parquet'])"; let actual = create_view(table_name, schema_name, table_options).unwrap(); @@ -185,35 +167,35 @@ mod tests { let schema_name = "main"; let table_options = HashMap::from([ ( - ParquetOption::Files.as_str().to_string(), + ParquetOption::Files.as_ref().to_string(), "/data/file.parquet".to_string(), ), ( - ParquetOption::BinaryAsString.as_str().to_string(), + ParquetOption::BinaryAsString.as_ref().to_string(), "true".to_string(), ), ( - ParquetOption::FileName.as_str().to_string(), + ParquetOption::FileName.as_ref().to_string(), "false".to_string(), ), ( - ParquetOption::FileRowNumber.as_str().to_string(), + ParquetOption::FileRowNumber.as_ref().to_string(), "true".to_string(), ), ( - ParquetOption::HivePartitioning.as_str().to_string(), + ParquetOption::HivePartitioning.as_ref().to_string(), "true".to_string(), ), ( - ParquetOption::HiveTypes.as_str().to_string(), + ParquetOption::HiveTypes.as_ref().to_string(), "{'release': DATE, 'orders': BIGINT}".to_string(), ), ( - ParquetOption::HiveTypesAutocast.as_str().to_string(), + ParquetOption::HiveTypesAutocast.as_ref().to_string(), "true".to_string(), ), ( - ParquetOption::UnionByName.as_str().to_string(), + ParquetOption::UnionByName.as_ref().to_string(), "true".to_string(), ), ]); diff --git a/src/duckdb/secret.rs b/src/duckdb/secret.rs index 16bae696..cce3c561 100644 --- a/src/duckdb/secret.rs +++ b/src/duckdb/secret.rs @@ -22,29 +22,51 @@ use strum::{AsRefStr, EnumIter}; #[derive(EnumIter, AsRefStr, PartialEq, Debug)] pub enum UserMappingOptions { // Universal + #[strum(serialize = "type")] Type, + #[strum(serialize = "provider")] Provider, + #[strum(serialize = "scope")] Scope, + #[strum(serialize = "chain")] Chain, // S3/GCS/R2 + #[strum(serialize = "key_id")] KeyId, + #[strum(serialize = "secret")] Secret, + #[strum(serialize = "region")] Region, + #[strum(serialize = "session_token")] SessionToken, + #[strum(serialize = "endpoint")] Endpoint, + #[strum(serialize = "url_style")] UrlStyle, + #[strum(serialize = "use_ssl")] UseSsl, + #[strum(serialize = "url_compatibility_mode")] UrlCompatibilityMode, + #[strum(serialize = "account_id")] AccountId, // Azure + #[strum(serialize = "connection_string")] ConnectionString, + #[strum(serialize = "account_name")] AccountName, + #[strum(serialize = "tenant_id")] TenantId, + #[strum(serialize = "client_id")] ClientId, + #[strum(serialize = "client_secret")] ClientSecret, + #[strum(serialize = "client_certificate_path")] ClientCertificatePath, + #[strum(serialize = "http_proxy")] HttpProxy, + #[strum(serialize = "proxy_user_name")] ProxyUserName, + #[strum(serialize = "proxy_password")] ProxyPassword, } diff --git a/src/fdw/csv.rs b/src/fdw/csv.rs index 6bdcd7c9..069804fd 100644 --- a/src/fdw/csv.rs +++ b/src/fdw/csv.rs @@ -20,6 +20,7 @@ use async_std::task; use duckdb::arrow::array::RecordBatch; use pgrx::*; use std::collections::HashMap; +use strum::IntoEnumIterator; use supabase_wrappers::prelude::*; use super::base::*; @@ -111,14 +112,14 @@ impl ForeignDataWrapper for CsvFdw { FOREIGN_SERVER_RELATION_ID => {} FOREIGN_TABLE_RELATION_ID => { let valid_options: Vec = CsvOption::iter() - .map(|opt| opt.as_str().to_string()) + .map(|opt| opt.as_ref().to_string()) .collect(); validate_options(opt_list.clone(), valid_options)?; for opt in CsvOption::iter() { if opt.is_required() { - check_options_contain(&opt_list, opt.as_str())?; + check_options_contain(&opt_list, opt.as_ref())?; } } } diff --git a/src/fdw/delta.rs b/src/fdw/delta.rs index 029b6e1b..3a06e584 100644 --- a/src/fdw/delta.rs +++ b/src/fdw/delta.rs @@ -20,6 +20,7 @@ use async_std::task; use duckdb::arrow::array::RecordBatch; use pgrx::*; use std::collections::HashMap; +use strum::IntoEnumIterator; use supabase_wrappers::prelude::*; use super::base::*; @@ -111,14 +112,14 @@ impl ForeignDataWrapper for DeltaFdw { FOREIGN_SERVER_RELATION_ID => {} FOREIGN_TABLE_RELATION_ID => { let valid_options: Vec = DeltaOption::iter() - .map(|opt| opt.as_str().to_string()) + .map(|opt| opt.as_ref().to_string()) .collect(); validate_options(opt_list.clone(), valid_options)?; for opt in DeltaOption::iter() { if opt.is_required() { - check_options_contain(&opt_list, opt.as_str())?; + check_options_contain(&opt_list, opt.as_ref())?; } } } diff --git a/src/fdw/iceberg.rs b/src/fdw/iceberg.rs index 2a821e08..33f3be86 100644 --- a/src/fdw/iceberg.rs +++ b/src/fdw/iceberg.rs @@ -20,6 +20,7 @@ use async_std::task; use duckdb::arrow::array::RecordBatch; use pgrx::*; use std::collections::HashMap; +use strum::IntoEnumIterator; use supabase_wrappers::prelude::*; use super::base::*; @@ -111,14 +112,14 @@ impl ForeignDataWrapper for IcebergFdw { FOREIGN_SERVER_RELATION_ID => {} FOREIGN_TABLE_RELATION_ID => { let valid_options: Vec = IcebergOption::iter() - .map(|opt| opt.as_str().to_string()) + .map(|opt| opt.as_ref().to_string()) .collect(); validate_options(opt_list.clone(), valid_options)?; for opt in IcebergOption::iter() { if opt.is_required() { - check_options_contain(&opt_list, opt.as_str())?; + check_options_contain(&opt_list, opt.as_ref())?; } } } diff --git a/src/fdw/parquet.rs b/src/fdw/parquet.rs index 1e0576b3..60fc3d83 100644 --- a/src/fdw/parquet.rs +++ b/src/fdw/parquet.rs @@ -20,6 +20,7 @@ use async_std::task; use duckdb::arrow::array::RecordBatch; use pgrx::*; use std::collections::HashMap; +use strum::IntoEnumIterator; use supabase_wrappers::prelude::*; use super::base::*; @@ -111,14 +112,14 @@ impl ForeignDataWrapper for ParquetFdw { FOREIGN_SERVER_RELATION_ID => {} FOREIGN_TABLE_RELATION_ID => { let valid_options: Vec = ParquetOption::iter() - .map(|opt| opt.as_str().to_string()) + .map(|opt| opt.as_ref().to_string()) .collect(); validate_options(opt_list.clone(), valid_options)?; for opt in ParquetOption::iter() { if opt.is_required() { - check_options_contain(&opt_list, opt.as_str())?; + check_options_contain(&opt_list, opt.as_ref())?; } } }