Skip to content

Commit

Permalink
added strum to automatically provide as_ref & iter methods to enums
Browse files Browse the repository at this point in the history
  • Loading branch information
evanxg852000 authored and rebasedming committed Aug 22, 2024
1 parent 3ec508c commit 5f58f8f
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 97 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ serde_json = "1.0.120"
signal-hook = "0.3.17"
signal-hook-async-std = "0.2.2"
shared = { git = "https://github.com/paradedb/paradedb.git", rev = "4854652" }
strum = { version = "0.26.3", features = ["derive"] }
supabase-wrappers = { git = "https://github.com/paradedb/wrappers.git", default-features = false, rev = "27af09b" }
thiserror = "1.0.59"
uuid = "1.9.1"
Expand Down
140 changes: 43 additions & 97 deletions src/duckdb/secret.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

use anyhow::{anyhow, bail, Result};
use std::collections::HashMap;
use strum::{AsRefStr, EnumIter};

#[derive(EnumIter, AsRefStr, PartialEq, Debug)]
pub enum UserMappingOptions {
// Universal
Type,
Expand Down Expand Up @@ -47,33 +49,6 @@ pub enum UserMappingOptions {
}

impl UserMappingOptions {
pub fn as_str(&self) -> &str {
match self {
Self::Type => "type",
Self::Provider => "provider",
Self::Scope => "scope",
Self::Chain => "chain",
Self::KeyId => "key_id",
Self::Secret => "secret",
Self::Region => "region",
Self::SessionToken => "session_token",
Self::Endpoint => "endpoint",
Self::UrlStyle => "url_style",
Self::UseSsl => "use_ssl",
Self::UrlCompatibilityMode => "url_compatibility_mode",
Self::AccountId => "account_id",
Self::ConnectionString => "connection_string",
Self::AccountName => "account_name",
Self::TenantId => "tenant_id",
Self::ClientId => "client_id",
Self::ClientSecret => "client_secret",
Self::ClientCertificatePath => "client_certificate_path",
Self::HttpProxy => "http_proxy",
Self::ProxyUserName => "proxy_user_name",
Self::ProxyPassword => "proxy_password",
}
}

#[allow(unused)]
pub fn is_required(&self) -> bool {
match self {
Expand Down Expand Up @@ -101,35 +76,6 @@ impl UserMappingOptions {
Self::ProxyPassword => false,
}
}

#[allow(unused)]
pub fn iter() -> impl Iterator<Item = Self> {
[
Self::Type,
Self::Provider,
Self::Scope,
Self::Chain,
Self::KeyId,
Self::Secret,
Self::Region,
Self::SessionToken,
Self::Endpoint,
Self::UrlStyle,
Self::UseSsl,
Self::UrlCompatibilityMode,
Self::AccountId,
Self::ConnectionString,
Self::AccountName,
Self::TenantId,
Self::ClientId,
Self::ClientSecret,
Self::ClientCertificatePath,
Self::HttpProxy,
Self::ProxyUserName,
Self::ProxyPassword,
]
.into_iter()
}
}

pub fn create_secret(
Expand All @@ -143,95 +89,95 @@ pub fn create_secret(
let secret_type = Some(format!(
"TYPE {}",
user_mapping_options
.get(UserMappingOptions::Type.as_str())
.get(UserMappingOptions::Type.as_ref())
.ok_or_else(|| anyhow!("type option required for USER MAPPING"))?
.as_str()
));

let provider = user_mapping_options
.get(UserMappingOptions::Provider.as_str())
.get(UserMappingOptions::Provider.as_ref())
.map(|provider| format!("PROVIDER {}", provider));

let scope = user_mapping_options
.get(UserMappingOptions::Scope.as_str())
.get(UserMappingOptions::Scope.as_ref())
.map(|scope| format!("SCOPE {}", scope));

let chain = user_mapping_options
.get(UserMappingOptions::Chain.as_str())
.get(UserMappingOptions::Chain.as_ref())
.map(|chain| format!("CHAIN '{}'", chain));

let key_id = user_mapping_options
.get(UserMappingOptions::KeyId.as_str())
.get(UserMappingOptions::KeyId.as_ref())
.map(|key_id| format!("KEY_ID '{}'", key_id));

let secret = user_mapping_options
.get(UserMappingOptions::Secret.as_str())
.get(UserMappingOptions::Secret.as_ref())
.map(|secret| format!("SECRET '{}'", secret));

let region = user_mapping_options
.get(UserMappingOptions::Region.as_str())
.get(UserMappingOptions::Region.as_ref())
.map(|region| format!("REGION '{}'", region));

let session_token = user_mapping_options
.get(UserMappingOptions::SessionToken.as_str())
.get(UserMappingOptions::SessionToken.as_ref())
.map(|session_token| format!("SESSION_TOKEN '{}'", session_token));

let endpoint = user_mapping_options
.get(UserMappingOptions::Endpoint.as_str())
.get(UserMappingOptions::Endpoint.as_ref())
.map(|endpoint| format!("ENDPOINT '{}'", endpoint));

let url_style = user_mapping_options
.get(UserMappingOptions::UrlStyle.as_str())
.get(UserMappingOptions::UrlStyle.as_ref())
.map(|url_style| format!("URL_STYLE '{}'", url_style));

let use_ssl = user_mapping_options
.get(UserMappingOptions::UseSsl.as_str())
.get(UserMappingOptions::UseSsl.as_ref())
.map(|use_ssl| format!("USE_SSL {}", use_ssl));

let url_compatibility_mode = user_mapping_options
.get(UserMappingOptions::UrlCompatibilityMode.as_str())
.get(UserMappingOptions::UrlCompatibilityMode.as_ref())
.map(|url_compatibility_mode| format!("URL_COMPATIBILITY_MODE {}", url_compatibility_mode));

let account_id = user_mapping_options
.get(UserMappingOptions::AccountId.as_str())
.get(UserMappingOptions::AccountId.as_ref())
.map(|account_id| format!("ACCOUNT_ID '{}'", account_id));

let connection_string = user_mapping_options
.get(UserMappingOptions::ConnectionString.as_str())
.get(UserMappingOptions::ConnectionString.as_ref())
.map(|connection_string| format!("CONNECTION_STRING '{}'", connection_string));

let account_name = user_mapping_options
.get(UserMappingOptions::AccountName.as_str())
.get(UserMappingOptions::AccountName.as_ref())
.map(|account_name| format!("ACCOUNT_NAME '{}'", account_name));

let tenant_id = user_mapping_options
.get(UserMappingOptions::TenantId.as_str())
.get(UserMappingOptions::TenantId.as_ref())
.map(|tenant_id| format!("TENANT_ID '{}'", tenant_id));

let client_id = user_mapping_options
.get(UserMappingOptions::ClientId.as_str())
.get(UserMappingOptions::ClientId.as_ref())
.map(|client_id| format!("CLIENT_ID '{}'", client_id));

let client_secret = user_mapping_options
.get(UserMappingOptions::ClientSecret.as_str())
.get(UserMappingOptions::ClientSecret.as_ref())
.map(|client_secret| format!("CLIENT_SECRET '{}'", client_secret));

let client_certificate_path = user_mapping_options
.get(UserMappingOptions::ClientCertificatePath.as_str())
.get(UserMappingOptions::ClientCertificatePath.as_ref())
.map(|client_certificate_path| {
format!("CLIENT_CERTIFICATE_PATH '{}'", client_certificate_path)
});

let http_proxy = user_mapping_options
.get(UserMappingOptions::HttpProxy.as_str())
.get(UserMappingOptions::HttpProxy.as_ref())
.map(|http_proxy| format!("HTTP_PROXY '{}'", http_proxy));

let proxy_user_name = user_mapping_options
.get(UserMappingOptions::ProxyUserName.as_str())
.get(UserMappingOptions::ProxyUserName.as_ref())
.map(|proxy_user_name| format!("PROXY_USER_NAME '{}'", proxy_user_name));

let proxy_password = user_mapping_options
.get(UserMappingOptions::ProxyPassword.as_str())
.get(UserMappingOptions::ProxyPassword.as_ref())
.map(|proxy_password| format!("PROXY_PASSWORD '{}'", proxy_password));

let secret_string = vec![
Expand Down Expand Up @@ -278,44 +224,44 @@ mod tests {
let secret_name = "s3_secret";
let user_mapping_options = HashMap::from([
(
UserMappingOptions::Type.as_str().to_string(),
UserMappingOptions::Type.as_ref().to_string(),
"S3".to_string(),
),
(
UserMappingOptions::Provider.as_str().to_string(),
UserMappingOptions::Provider.as_ref().to_string(),
"CONFIG".to_string(),
),
(
UserMappingOptions::KeyId.as_str().to_string(),
UserMappingOptions::KeyId.as_ref().to_string(),
"key_id".to_string(),
),
(
UserMappingOptions::Secret.as_str().to_string(),
UserMappingOptions::Secret.as_ref().to_string(),
"secret".to_string(),
),
(
UserMappingOptions::Region.as_str().to_string(),
UserMappingOptions::Region.as_ref().to_string(),
"us-west-2".to_string(),
),
(
UserMappingOptions::SessionToken.as_str().to_string(),
UserMappingOptions::SessionToken.as_ref().to_string(),
"session_token".to_string(),
),
(
UserMappingOptions::Endpoint.as_str().to_string(),
UserMappingOptions::Endpoint.as_ref().to_string(),
"s3.amazonaws.com".to_string(),
),
(
UserMappingOptions::UrlStyle.as_str().to_string(),
UserMappingOptions::UrlStyle.as_ref().to_string(),
"vhost".to_string(),
),
(
UserMappingOptions::UseSsl.as_str().to_string(),
UserMappingOptions::UseSsl.as_ref().to_string(),
"true".to_string(),
),
(
UserMappingOptions::UrlCompatibilityMode
.as_str()
.as_ref()
.to_string(),
"true".to_string(),
),
Expand All @@ -336,11 +282,11 @@ mod tests {
let secret_name = "s3_secret";
let user_mapping_options = HashMap::from([
(
UserMappingOptions::Type.as_str().to_string(),
UserMappingOptions::Type.as_ref().to_string(),
"S3".to_string(),
),
(
UserMappingOptions::Provider.as_str().to_string(),
UserMappingOptions::Provider.as_ref().to_string(),
"TENANT_ID".to_string(),
),
]);
Expand All @@ -358,27 +304,27 @@ mod tests {
let secret_name = "azure_secret";
let user_mapping_options = HashMap::from([
(
UserMappingOptions::Type.as_str().to_string(),
UserMappingOptions::Type.as_ref().to_string(),
"AZURE".to_string(),
),
(
UserMappingOptions::Provider.as_str().to_string(),
UserMappingOptions::Provider.as_ref().to_string(),
"CONFIG".to_string(),
),
(
UserMappingOptions::ConnectionString.as_str().to_string(),
UserMappingOptions::ConnectionString.as_ref().to_string(),
"connection_string".to_string(),
),
(
UserMappingOptions::HttpProxy.as_str().to_string(),
UserMappingOptions::HttpProxy.as_ref().to_string(),
"http_proxy".to_string(),
),
(
UserMappingOptions::ProxyUserName.as_str().to_string(),
UserMappingOptions::ProxyUserName.as_ref().to_string(),
"proxy_user_name".to_string(),
),
(
UserMappingOptions::ProxyPassword.as_str().to_string(),
UserMappingOptions::ProxyPassword.as_ref().to_string(),
"proxy_password".to_string(),
),
]);
Expand All @@ -397,7 +343,7 @@ mod tests {
fn test_create_type_invalid() {
let secret_name = "invalid_secret";
let user_mapping_options = HashMap::from([(
UserMappingOptions::Type.as_str().to_string(),
UserMappingOptions::Type.as_ref().to_string(),
"INVALID".to_string(),
)]);

Expand Down

0 comments on commit 5f58f8f

Please sign in to comment.