diff --git a/Cargo.toml b/Cargo.toml index f029ca36..d25d07a6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ serde_json = "1.0.120" signal-hook = "0.3.17" signal-hook-async-std = "0.2.2" shared = { git = "https://github.com/paradedb/paradedb.git", rev = "4854652" } +strum = { version = "0.26.3", features = ["derive"] } supabase-wrappers = { git = "https://github.com/paradedb/wrappers.git", default-features = false, rev = "27af09b" } thiserror = "1.0.59" uuid = "1.9.1" diff --git a/src/duckdb/secret.rs b/src/duckdb/secret.rs index f3f484e9..16bae696 100644 --- a/src/duckdb/secret.rs +++ b/src/duckdb/secret.rs @@ -17,7 +17,9 @@ use anyhow::{anyhow, bail, Result}; use std::collections::HashMap; +use strum::{AsRefStr, EnumIter}; +#[derive(EnumIter, AsRefStr, PartialEq, Debug)] pub enum UserMappingOptions { // Universal Type, @@ -47,33 +49,6 @@ pub enum UserMappingOptions { } impl UserMappingOptions { - pub fn as_str(&self) -> &str { - match self { - Self::Type => "type", - Self::Provider => "provider", - Self::Scope => "scope", - Self::Chain => "chain", - Self::KeyId => "key_id", - Self::Secret => "secret", - Self::Region => "region", - Self::SessionToken => "session_token", - Self::Endpoint => "endpoint", - Self::UrlStyle => "url_style", - Self::UseSsl => "use_ssl", - Self::UrlCompatibilityMode => "url_compatibility_mode", - Self::AccountId => "account_id", - Self::ConnectionString => "connection_string", - Self::AccountName => "account_name", - Self::TenantId => "tenant_id", - Self::ClientId => "client_id", - Self::ClientSecret => "client_secret", - Self::ClientCertificatePath => "client_certificate_path", - Self::HttpProxy => "http_proxy", - Self::ProxyUserName => "proxy_user_name", - Self::ProxyPassword => "proxy_password", - } - } - #[allow(unused)] pub fn is_required(&self) -> bool { match self { @@ -101,35 +76,6 @@ impl UserMappingOptions { Self::ProxyPassword => false, } } - - #[allow(unused)] - pub fn iter() -> impl Iterator { - [ - Self::Type, - Self::Provider, - Self::Scope, - Self::Chain, - Self::KeyId, - Self::Secret, - Self::Region, - Self::SessionToken, - Self::Endpoint, - Self::UrlStyle, - Self::UseSsl, - Self::UrlCompatibilityMode, - Self::AccountId, - Self::ConnectionString, - Self::AccountName, - Self::TenantId, - Self::ClientId, - Self::ClientSecret, - Self::ClientCertificatePath, - Self::HttpProxy, - Self::ProxyUserName, - Self::ProxyPassword, - ] - .into_iter() - } } pub fn create_secret( @@ -143,95 +89,95 @@ pub fn create_secret( let secret_type = Some(format!( "TYPE {}", user_mapping_options - .get(UserMappingOptions::Type.as_str()) + .get(UserMappingOptions::Type.as_ref()) .ok_or_else(|| anyhow!("type option required for USER MAPPING"))? .as_str() )); let provider = user_mapping_options - .get(UserMappingOptions::Provider.as_str()) + .get(UserMappingOptions::Provider.as_ref()) .map(|provider| format!("PROVIDER {}", provider)); let scope = user_mapping_options - .get(UserMappingOptions::Scope.as_str()) + .get(UserMappingOptions::Scope.as_ref()) .map(|scope| format!("SCOPE {}", scope)); let chain = user_mapping_options - .get(UserMappingOptions::Chain.as_str()) + .get(UserMappingOptions::Chain.as_ref()) .map(|chain| format!("CHAIN '{}'", chain)); let key_id = user_mapping_options - .get(UserMappingOptions::KeyId.as_str()) + .get(UserMappingOptions::KeyId.as_ref()) .map(|key_id| format!("KEY_ID '{}'", key_id)); let secret = user_mapping_options - .get(UserMappingOptions::Secret.as_str()) + .get(UserMappingOptions::Secret.as_ref()) .map(|secret| format!("SECRET '{}'", secret)); let region = user_mapping_options - .get(UserMappingOptions::Region.as_str()) + .get(UserMappingOptions::Region.as_ref()) .map(|region| format!("REGION '{}'", region)); let session_token = user_mapping_options - .get(UserMappingOptions::SessionToken.as_str()) + .get(UserMappingOptions::SessionToken.as_ref()) .map(|session_token| format!("SESSION_TOKEN '{}'", session_token)); let endpoint = user_mapping_options - .get(UserMappingOptions::Endpoint.as_str()) + .get(UserMappingOptions::Endpoint.as_ref()) .map(|endpoint| format!("ENDPOINT '{}'", endpoint)); let url_style = user_mapping_options - .get(UserMappingOptions::UrlStyle.as_str()) + .get(UserMappingOptions::UrlStyle.as_ref()) .map(|url_style| format!("URL_STYLE '{}'", url_style)); let use_ssl = user_mapping_options - .get(UserMappingOptions::UseSsl.as_str()) + .get(UserMappingOptions::UseSsl.as_ref()) .map(|use_ssl| format!("USE_SSL {}", use_ssl)); let url_compatibility_mode = user_mapping_options - .get(UserMappingOptions::UrlCompatibilityMode.as_str()) + .get(UserMappingOptions::UrlCompatibilityMode.as_ref()) .map(|url_compatibility_mode| format!("URL_COMPATIBILITY_MODE {}", url_compatibility_mode)); let account_id = user_mapping_options - .get(UserMappingOptions::AccountId.as_str()) + .get(UserMappingOptions::AccountId.as_ref()) .map(|account_id| format!("ACCOUNT_ID '{}'", account_id)); let connection_string = user_mapping_options - .get(UserMappingOptions::ConnectionString.as_str()) + .get(UserMappingOptions::ConnectionString.as_ref()) .map(|connection_string| format!("CONNECTION_STRING '{}'", connection_string)); let account_name = user_mapping_options - .get(UserMappingOptions::AccountName.as_str()) + .get(UserMappingOptions::AccountName.as_ref()) .map(|account_name| format!("ACCOUNT_NAME '{}'", account_name)); let tenant_id = user_mapping_options - .get(UserMappingOptions::TenantId.as_str()) + .get(UserMappingOptions::TenantId.as_ref()) .map(|tenant_id| format!("TENANT_ID '{}'", tenant_id)); let client_id = user_mapping_options - .get(UserMappingOptions::ClientId.as_str()) + .get(UserMappingOptions::ClientId.as_ref()) .map(|client_id| format!("CLIENT_ID '{}'", client_id)); let client_secret = user_mapping_options - .get(UserMappingOptions::ClientSecret.as_str()) + .get(UserMappingOptions::ClientSecret.as_ref()) .map(|client_secret| format!("CLIENT_SECRET '{}'", client_secret)); let client_certificate_path = user_mapping_options - .get(UserMappingOptions::ClientCertificatePath.as_str()) + .get(UserMappingOptions::ClientCertificatePath.as_ref()) .map(|client_certificate_path| { format!("CLIENT_CERTIFICATE_PATH '{}'", client_certificate_path) }); let http_proxy = user_mapping_options - .get(UserMappingOptions::HttpProxy.as_str()) + .get(UserMappingOptions::HttpProxy.as_ref()) .map(|http_proxy| format!("HTTP_PROXY '{}'", http_proxy)); let proxy_user_name = user_mapping_options - .get(UserMappingOptions::ProxyUserName.as_str()) + .get(UserMappingOptions::ProxyUserName.as_ref()) .map(|proxy_user_name| format!("PROXY_USER_NAME '{}'", proxy_user_name)); let proxy_password = user_mapping_options - .get(UserMappingOptions::ProxyPassword.as_str()) + .get(UserMappingOptions::ProxyPassword.as_ref()) .map(|proxy_password| format!("PROXY_PASSWORD '{}'", proxy_password)); let secret_string = vec![ @@ -278,44 +224,44 @@ mod tests { let secret_name = "s3_secret"; let user_mapping_options = HashMap::from([ ( - UserMappingOptions::Type.as_str().to_string(), + UserMappingOptions::Type.as_ref().to_string(), "S3".to_string(), ), ( - UserMappingOptions::Provider.as_str().to_string(), + UserMappingOptions::Provider.as_ref().to_string(), "CONFIG".to_string(), ), ( - UserMappingOptions::KeyId.as_str().to_string(), + UserMappingOptions::KeyId.as_ref().to_string(), "key_id".to_string(), ), ( - UserMappingOptions::Secret.as_str().to_string(), + UserMappingOptions::Secret.as_ref().to_string(), "secret".to_string(), ), ( - UserMappingOptions::Region.as_str().to_string(), + UserMappingOptions::Region.as_ref().to_string(), "us-west-2".to_string(), ), ( - UserMappingOptions::SessionToken.as_str().to_string(), + UserMappingOptions::SessionToken.as_ref().to_string(), "session_token".to_string(), ), ( - UserMappingOptions::Endpoint.as_str().to_string(), + UserMappingOptions::Endpoint.as_ref().to_string(), "s3.amazonaws.com".to_string(), ), ( - UserMappingOptions::UrlStyle.as_str().to_string(), + UserMappingOptions::UrlStyle.as_ref().to_string(), "vhost".to_string(), ), ( - UserMappingOptions::UseSsl.as_str().to_string(), + UserMappingOptions::UseSsl.as_ref().to_string(), "true".to_string(), ), ( UserMappingOptions::UrlCompatibilityMode - .as_str() + .as_ref() .to_string(), "true".to_string(), ), @@ -336,11 +282,11 @@ mod tests { let secret_name = "s3_secret"; let user_mapping_options = HashMap::from([ ( - UserMappingOptions::Type.as_str().to_string(), + UserMappingOptions::Type.as_ref().to_string(), "S3".to_string(), ), ( - UserMappingOptions::Provider.as_str().to_string(), + UserMappingOptions::Provider.as_ref().to_string(), "TENANT_ID".to_string(), ), ]); @@ -358,27 +304,27 @@ mod tests { let secret_name = "azure_secret"; let user_mapping_options = HashMap::from([ ( - UserMappingOptions::Type.as_str().to_string(), + UserMappingOptions::Type.as_ref().to_string(), "AZURE".to_string(), ), ( - UserMappingOptions::Provider.as_str().to_string(), + UserMappingOptions::Provider.as_ref().to_string(), "CONFIG".to_string(), ), ( - UserMappingOptions::ConnectionString.as_str().to_string(), + UserMappingOptions::ConnectionString.as_ref().to_string(), "connection_string".to_string(), ), ( - UserMappingOptions::HttpProxy.as_str().to_string(), + UserMappingOptions::HttpProxy.as_ref().to_string(), "http_proxy".to_string(), ), ( - UserMappingOptions::ProxyUserName.as_str().to_string(), + UserMappingOptions::ProxyUserName.as_ref().to_string(), "proxy_user_name".to_string(), ), ( - UserMappingOptions::ProxyPassword.as_str().to_string(), + UserMappingOptions::ProxyPassword.as_ref().to_string(), "proxy_password".to_string(), ), ]); @@ -397,7 +343,7 @@ mod tests { fn test_create_type_invalid() { let secret_name = "invalid_secret"; let user_mapping_options = HashMap::from([( - UserMappingOptions::Type.as_str().to_string(), + UserMappingOptions::Type.as_ref().to_string(), "INVALID".to_string(), )]);