// Review notes for this patch. They sit ahead of the first `diff --git` header and
// are not part of any hunk.
//
// Purpose: move RAG persistence from bincode `.bin` files to YAML `.yaml` files.
//
// src/config/mod.rs
//   - `rag_file`, `agent_rag_file`, and the loop that collects names from the `rd`
//     entries now use the `.yaml` suffix where they used `.bin`.
// src/config/session.rs
//   - Error-message wording only: the session name and the path are now wrapped in
//     single quotes.
// src/rag/mod.rs
//   - Declares the new `serde_vectors` module; merges the two `anyhow` imports and the
//     two `std` imports (dropping `io::BufReader`, adding `fs`).
//   - `Rag::load` reads the file with `fs::read_to_string` and parses it with
//     `serde_yaml::from_str` (was `bincode::deserialize_from` over a `BufReader`).
//   - The save routine serializes `self.data` with `serde_yaml::to_string` and writes
//     it with `fs::write` (was `File::create` + `bincode::serialize_into`); the write
//     error now names the destination path.
//   - The `vectors` field of `RagData` gains `#[serde(with = "serde_vectors")]`.
// src/rag/serde_vectors.rs (new file)
//   - `serialize`: maps every entry to a `"{h}-{l}"` string key (parts come from
//     `split_document_id`) and a base64 `STANDARD` encoding of the vector's bytes,
//     then serializes that string map.
//   - `deserialize`: reads the string map back, rebuilds each key with
//     `split_once('-')` + `combine_document_id`, base64-decodes the value, rejects a
//     payload whose byte length is not a whole number of elements, and copies the
//     bytes into a zero-filled `f32` vector.
//
// NOTE(review): both directions copy the floats' in-memory bytes unchanged, so the
//   encoded payload appears to follow the host byte order — confirm whether stored
//   files must be portable across architectures. `f32::to_le_bytes` /
//   `f32::from_le_bytes` would pin the layout and remove the `unsafe` blocks.
// NOTE(review): neither `unsafe` block states the invariant that makes it sound.
// NOTE(review): nothing in the visible hunks reads the old `.bin` files; after this
//   change the name scan only matches `.yaml`, so existing `.bin` stores would
//   presumably stop being listed — confirm a migration path exists.
// NOTE(review): generic arguments look stripped throughout this copy (`-> Result {`,
//   `size_of::()`, `IndexMap>`) and the hunks are collapsed onto single lines; treat
//   it as a lossy rendering of the commit, not as an applicable patch.
diff --git a/src/config/mod.rs b/src/config/mod.rs index d2d7c2f1..25f0cd67 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -338,7 +338,7 @@ impl Config { pub fn rag_file(&self, name: &str) -> Result { let path = match &self.agent { Some(agent) => Self::agent_rag_file(agent.name(), name)?, - None => Self::rags_dir()?.join(format!("{name}.bin")), + None => Self::rags_dir()?.join(format!("{name}.yaml")), }; Ok(path) } @@ -359,7 +359,7 @@ impl Config { } pub fn agent_rag_file(agent_name: &str, rag_name: &str) -> Result { - Ok(Self::agent_config_dir(agent_name)?.join(format!("{rag_name}.bin"))) + Ok(Self::agent_config_dir(agent_name)?.join(format!("{rag_name}.yaml"))) } pub fn agent_variables_file(name: &str) -> Result { @@ -1186,7 +1186,7 @@ impl Config { let mut names = vec![]; for entry in rd.flatten() { let name = entry.file_name(); - if let Some(name) = name.to_string_lossy().strip_suffix(".bin") { + if let Some(name) = name.to_string_lossy().strip_suffix(".yaml") { names.push(name.to_string()); } } diff --git a/src/config/session.rs b/src/config/session.rs index 40519360..476ad512 100644 --- a/src/config/session.rs +++ b/src/config/session.rs @@ -326,10 +326,10 @@ impl Session { self.path = Some(session_path.display().to_string()); let content = serde_yaml::to_string(&self) - .with_context(|| format!("Failed to serde session {}", self.name))?; + .with_context(|| format!("Failed to serde session '{}'", self.name))?; write(session_path, content).with_context(|| { format!( - "Failed to write session {} to {}", + "Failed to write session '{}' to '{}'", self.name, session_path.display() ) diff --git a/src/rag/mod.rs b/src/rag/mod.rs index f6a511d7..e03b3f11 100644 --- a/src/rag/mod.rs +++ b/src/rag/mod.rs @@ -8,18 +8,17 @@ use crate::utils::*; mod bm25; mod loader; +mod serde_vectors; mod splitter; -use anyhow::bail; -use anyhow::{anyhow, Context, Result}; +use anyhow::{anyhow, bail, Context, Result}; use hnsw_rs::prelude::*; use indexmap::{IndexMap, 
IndexSet}; use inquire::{required, validator::Validation, Confirm, Select, Text}; use path_absolutize::Absolutize; use serde::{Deserialize, Serialize}; use serde_json::json; -use std::collections::HashMap; -use std::{fmt::Debug, io::BufReader, path::Path}; +use std::{collections::HashMap, fmt::Debug, fs, path::Path}; pub struct Rag { config: GlobalConfig, @@ -103,9 +102,8 @@ impl Rag { pub fn load(config: &GlobalConfig, name: &str, path: &Path) -> Result { let err = || format!("Failed to load rag '{name}' at '{}'", path.display()); - let file = std::fs::File::open(path).with_context(err)?; - let reader = BufReader::new(file); - let data: RagData = bincode::deserialize_from(reader).with_context(err)?; + let content = fs::read_to_string(path).with_context(err)?; + let data: RagData = serde_yaml::from_str(&content).with_context(err)?; Self::create(config, name, path, data) } @@ -236,9 +234,13 @@ impl Rag { } let path = Path::new(&self.path); ensure_parent_exists(path)?; - let mut file = std::fs::File::create(path)?; - bincode::serialize_into(&mut file, &self.data) - .with_context(|| format!("Failed to save rag '{}'", self.name))?; + + let content = serde_yaml::to_string(&self.data) + .with_context(|| format!("Failed to serde rag '{}'", self.name))?; + fs::write(path, content).with_context(|| { + format!("Failed to save rag '{}' to '{}'", self.name, path.display()) + })?; + Ok(true) } @@ -576,6 +578,7 @@ pub struct RagData { pub next_file_id: FileId, pub document_paths: Vec, pub files: IndexMap, + #[serde(with = "serde_vectors")] pub vectors: IndexMap>, } diff --git a/src/rag/serde_vectors.rs b/src/rag/serde_vectors.rs new file mode 100644 index 00000000..894c22ca --- /dev/null +++ b/src/rag/serde_vectors.rs @@ -0,0 +1,69 @@ +use super::*; + +use base64::{engine::general_purpose::STANDARD, Engine}; +use serde::{de, Deserializer, Serializer}; + +pub fn serialize( + vectors: &IndexMap>, + serializer: S, +) -> Result +where + S: Serializer, +{ + let encoded_map: IndexMap 
= vectors + .iter() + .map(|(key, vec)| { + let (h, l) = split_document_id(*key); + let byte_slice = unsafe { + std::slice::from_raw_parts( + vec.as_ptr() as *const u8, + vec.len() * std::mem::size_of::(), + ) + }; + (format!("{h}-{l}"), STANDARD.encode(byte_slice)) + }) + .collect(); + + encoded_map.serialize(serializer) +} + +pub fn deserialize<'de, D>(deserializer: D) -> Result>, D::Error> +where + D: Deserializer<'de>, +{ + let encoded_map: IndexMap = + IndexMap::::deserialize(deserializer)?; + + let mut decoded_map = IndexMap::new(); + for (key, base64_str) in encoded_map { + let decoded_key: DocumentId = key + .split_once('-') + .and_then(|(h, l)| { + let h = h.parse::().ok()?; + let l = l.parse::().ok()?; + Some(combine_document_id(h, l)) + }) + .ok_or_else(|| de::Error::custom(format!("Invalid key '{key}'")))?; + + let decoded_data = STANDARD.decode(&base64_str).map_err(de::Error::custom)?; + + if decoded_data.len() % std::mem::size_of::() != 0 { + return Err(de::Error::custom(format!("Invalid vector at '{key}'"))); + } + + let num_f32s = decoded_data.len() / std::mem::size_of::(); + + let mut vec_f32 = vec![0.0f32; num_f32s]; + unsafe { + std::ptr::copy_nonoverlapping( + decoded_data.as_ptr(), + vec_f32.as_mut_ptr() as *mut u8, + decoded_data.len(), + ); + } + + decoded_map.insert(decoded_key, vec_f32); + } + + Ok(decoded_map) +}