Skip to content

Commit

Permalink
revert credential changes for base
Browse files Browse the repository at this point in the history
  • Loading branch information
Kvadratni committed Jan 31, 2025
1 parent db11336 commit 996e023
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 150 deletions.
3 changes: 0 additions & 3 deletions crates/goose/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,6 @@ once_cell = "1.20.2"
dirs = "6.0.0"
rand = "0.8.5"

[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = ["wincred"] }

[dev-dependencies]
criterion = "0.5"
tempfile = "3.15.0"
Expand Down
158 changes: 11 additions & 147 deletions crates/goose/src/config/base.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#[cfg(not(target_os = "windows"))]
use keyring::Entry;
use once_cell::sync::OnceCell;
use serde::Deserialize;
Expand Down Expand Up @@ -40,98 +39,12 @@ impl From<serde_yaml::Error> for ConfigError {
}
}

#[cfg(not(target_os = "windows"))]
impl From<keyring::Error> for ConfigError {
fn from(err: keyring::Error) -> Self {
ConfigError::KeyringError(err.to_string())
}
}

#[cfg(target_os = "windows")]
mod platform {
use super::*;
use std::ptr;
use winapi::shared::minwindef::DWORD;
use winapi::um::wincred::{CredReadW, CredWriteW, CREDENTIALW, CRED_TYPE_GENERIC};

pub struct WindowsCredential {
target_name: String,
}

impl WindowsCredential {
pub fn new(service: &str, username: &str) -> Self {
let target_name = format!("{}:{}", service, username);
WindowsCredential { target_name }
}

pub fn get_password(&self) -> Result<String, ConfigError> {
unsafe {
let mut pcred: *mut CREDENTIALW = ptr::null_mut();
let target = to_wide_chars(&self.target_name);

let result = CredReadW(target.as_ptr(), CRED_TYPE_GENERIC, 0, &mut pcred);

if result == 0 {
return Err(ConfigError::KeyringError(
"Failed to read credential".into(),
));
}

let credential = &*pcred;
let blob = std::slice::from_raw_parts(
credential.CredentialBlob as *const u8,
credential.CredentialBlobSize as usize,
);

String::from_utf8(blob.to_vec())
.map_err(|e| ConfigError::KeyringError(e.to_string()))
}
}

pub fn set_password(&self, password: &str) -> Result<(), ConfigError> {
unsafe {
let target = to_wide_chars(&self.target_name);
let blob = password.as_bytes();

let mut credential = CREDENTIALW {
Flags: 0,
Type: CRED_TYPE_GENERIC,
TargetName: target.as_ptr() as *mut _,
Comment: ptr::null_mut(),
LastWritten: winapi::shared::minwindef::FILETIME {
dwLowDateTime: 0,
dwHighDateTime: 0,
},
CredentialBlobSize: blob.len() as DWORD,
CredentialBlob: blob.as_ptr() as *mut _,
Persist: 2, // CRED_PERSIST_LOCAL_MACHINE
AttributeCount: 0,
Attributes: ptr::null_mut(),
TargetAlias: ptr::null_mut(),
UserName: ptr::null_mut(),
};

let result = CredWriteW(&mut credential, 0);

if result == 0 {
return Err(ConfigError::KeyringError(
"Failed to write credential".into(),
));
}

Ok(())
}
}
}

fn to_wide_chars(s: &str) -> Vec<u16> {
s.encode_utf16().chain(std::iter::once(0)).collect()
}
}

#[cfg(target_os = "windows")]
use platform::WindowsCredential;

/// Configuration management for Goose.
///
/// This module provides a flexible configuration system that supports:
Expand Down Expand Up @@ -269,30 +182,15 @@ impl Config {

// Load current secrets from the keyring
fn load_secrets(&self) -> Result<HashMap<String, Value>, ConfigError> {
#[cfg(target_os = "windows")]
{
let credential = WindowsCredential::new(&self.keyring_service, KEYRING_USERNAME);
match credential.get_password() {
Ok(content) => {
let values: HashMap<String, Value> = serde_json::from_str(&content)?;
Ok(values)
}
Err(ConfigError::KeyringError(_)) => Ok(HashMap::new()),
Err(e) => Err(e),
}
}
let entry = Entry::new(&self.keyring_service, KEYRING_USERNAME)?;

#[cfg(not(target_os = "windows"))]
{
let entry = Entry::new(&self.keyring_service, KEYRING_USERNAME)?;
match entry.get_password() {
Ok(content) => {
let values: HashMap<String, Value> = serde_json::from_str(&content)?;
Ok(values)
}
Err(keyring::Error::NoEntry) => Ok(HashMap::new()),
Err(e) => Err(ConfigError::KeyringError(e.to_string())),
match entry.get_password() {
Ok(content) => {
let values: HashMap<String, Value> = serde_json::from_str(&content)?;
Ok(values)
}
Err(keyring::Error::NoEntry) => Ok(HashMap::new()),
Err(e) => Err(ConfigError::KeyringError(e.to_string())),
}
}

Expand Down Expand Up @@ -422,19 +320,8 @@ impl Config {
values.insert(key.to_string(), value);

let json_value = serde_json::to_string(&values)?;

#[cfg(target_os = "windows")]
{
let credential = WindowsCredential::new(&self.keyring_service, KEYRING_USERNAME);
credential.set_password(&json_value)?;
}

#[cfg(not(target_os = "windows"))]
{
let entry = Entry::new(&self.keyring_service, KEYRING_USERNAME)?;
entry.set_password(&json_value)?;
}

let entry = Entry::new(&self.keyring_service, KEYRING_USERNAME)?;
entry.set_password(&json_value)?;
Ok(())
}

Expand All @@ -453,19 +340,8 @@ impl Config {
values.remove(key);

let json_value = serde_json::to_string(&values)?;

#[cfg(target_os = "windows")]
{
let credential = WindowsCredential::new(&self.keyring_service, KEYRING_USERNAME);
credential.set_password(&json_value)?;
}

#[cfg(not(target_os = "windows"))]
{
let entry = Entry::new(&self.keyring_service, KEYRING_USERNAME)?;
entry.set_password(&json_value)?;
}

let entry = Entry::new(&self.keyring_service, KEYRING_USERNAME)?;
entry.set_password(&json_value)?;
Ok(())
}
}
Expand All @@ -476,7 +352,6 @@ mod tests {
use serial_test::serial;
use tempfile::NamedTempFile;

#[cfg(not(target_os = "windows"))]
fn cleanup_keyring() -> Result<(), ConfigError> {
let entry = Entry::new(TEST_KEYRING_SERVICE, KEYRING_USERNAME)?;
match entry.delete_credential() {
Expand All @@ -486,16 +361,6 @@ mod tests {
}
}

#[cfg(target_os = "windows")]
fn cleanup_keyring() -> Result<(), ConfigError> {
let credential = WindowsCredential::new(TEST_KEYRING_SERVICE, KEYRING_USERNAME);
match credential.delete_credential() {
Ok(_) => Ok(()),
Err(ConfigError::KeyringError(_)) => Ok(()),
Err(e) => Err(e),
}
}

#[test]
fn test_basic_config() -> Result<(), ConfigError> {
let temp_file = NamedTempFile::new().unwrap();
Expand All @@ -512,7 +377,6 @@ mod tests {
std::env::set_var("TEST_KEY", "env_value");
let value: String = config.get("test_key")?;
assert_eq!(value, "env_value");
std::env::remove_var("TEST_KEY");

Ok(())
}
Expand Down

0 comments on commit 996e023

Please sign in to comment.