diff --git a/libsql/Cargo.toml b/libsql/Cargo.toml
index a744802b10..de96f759c3 100644
--- a/libsql/Cargo.toml
+++ b/libsql/Cargo.toml
@@ -99,6 +99,8 @@ sync = [
   "parser",
   "serde",
   "stream",
+  "remote",
+  "replication",
   "dep:tower",
   "dep:hyper",
   "dep:http",
diff --git a/libsql/src/database.rs b/libsql/src/database.rs
index 05eb67241a..5e0fa328eb 100644
--- a/libsql/src/database.rs
+++ b/libsql/src/database.rs
@@ -10,7 +10,6 @@ pub use libsql_sys::{Cipher, EncryptionConfig};
 use crate::{Connection, Result};
 use std::fmt;
 use std::sync::atomic::AtomicU64;
-use std::sync::Arc;
 
 cfg_core! {
     bitflags::bitflags! {
@@ -82,7 +81,14 @@ enum DbType {
         encryption_config: Option<EncryptionConfig>,
     },
     #[cfg(feature = "sync")]
-    Offline { db: crate::local::Database },
+    Offline {
+        db: crate::local::Database,
+        remote_writes: bool,
+        read_your_writes: bool,
+        url: String,
+        auth_token: String,
+        connector: crate::util::ConnectorService,
+    },
     #[cfg(feature = "remote")]
     Remote {
         url: String,
@@ -117,7 +123,7 @@ pub struct Database {
     db_type: DbType,
     /// The maximum replication index returned from a write performed using any connection created using this Database object.
     #[allow(dead_code)]
-    max_write_replication_index: Arc<AtomicU64>,
+    max_write_replication_index: std::sync::Arc<AtomicU64>,
 }
 
 cfg_core! {
@@ -375,7 +381,7 @@ cfg_replication! {
             #[cfg(feature = "replication")]
             DbType::Sync { db, encryption_config: _ } => db.sync().await,
             #[cfg(feature = "sync")]
-            DbType::Offline { db } => db.sync_offline().await,
+            DbType::Offline { db, .. } => db.sync_offline().await,
             _ => Err(Error::SyncNotSupported(format!("{:?}", self.db_type))),
         }
     }
@@ -642,13 +648,42 @@ impl Database {
             }
             #[cfg(feature = "sync")]
-            DbType::Offline { db } => {
-                use crate::local::impls::LibsqlConnection;
-
-                let conn = db.connect()?;
-
-                let conn = std::sync::Arc::new(LibsqlConnection { conn });
+            DbType::Offline {
+                db,
+                remote_writes,
+                read_your_writes,
+                url,
+                auth_token,
+                connector,
+            } => {
+                use crate::{
+                    hrana::{connection::HttpConnection, hyper::HttpSender},
+                    local::impls::LibsqlConnection,
+                    replication::connection::State,
+                    sync::connection::SyncedConnection,
+                };
+                use tokio::sync::Mutex;
+
+                let local = db.connect()?;
+
+                if *remote_writes {
+                    let synced = SyncedConnection {
+                        local,
+                        remote: HttpConnection::new(
+                            url.clone(),
+                            auth_token.clone(),
+                            HttpSender::new(connector.clone(), None),
+                        ),
+                        read_your_writes: *read_your_writes,
+                        context: db.sync_ctx.clone().unwrap(),
+                        state: std::sync::Arc::new(Mutex::new(State::Init)),
+                    };
+
+                    let conn = std::sync::Arc::new(synced);
+                    return Ok(Connection { conn });
+                }
+                let conn = std::sync::Arc::new(LibsqlConnection { conn: local });
 
                 Ok(Connection { conn })
             }
diff --git a/libsql/src/database/builder.rs b/libsql/src/database/builder.rs
index 9d9e0fdf8f..14df79ad6d 100644
--- a/libsql/src/database/builder.rs
+++ b/libsql/src/database/builder.rs
@@ -102,7 +102,9 @@ impl Builder<()> {
                     connector: None,
                     version: None,
                 },
-                connector:None,
+                connector: None,
+                read_your_writes: true,
+                remote_writes: false,
             },
         }
     }
@@ -463,6 +465,8 @@ cfg_sync! {
         flags: crate::OpenFlags,
         remote: Remote,
         connector: Option<crate::util::ConnectorService>,
+        remote_writes: bool,
+        read_your_writes: bool,
     }
 
     impl Builder<SyncedDatabase> {
@@ -472,6 +476,16 @@ cfg_sync! {
             self
         }
 
+        pub fn read_your_writes(mut self, v: bool) -> Builder<SyncedDatabase> {
+            self.inner.read_your_writes = v;
+            self
+        }
+
+        pub fn remote_writes(mut self, v: bool) -> Builder<SyncedDatabase> {
+            self.inner.remote_writes = v;
+            self
+        }
+
         /// Provide a custom http connector that will be used to create http connections.
         pub fn connector<C>(mut self, connector: C) -> Builder<SyncedDatabase>
         where
@@ -497,6 +511,8 @@ cfg_sync! {
                         version: _,
                     },
                 connector,
+                remote_writes,
+                read_your_writes,
             } = self.inner;
 
             let path = path.to_str().ok_or(crate::Error::InvalidUTF8Path)?.to_owned();
@@ -515,16 +531,23 @@ cfg_sync! {
             let connector = crate::util::ConnectorService::new(svc);
 
             let db = crate::local::Database::open_local_with_offline_writes(
-                connector,
+                connector.clone(),
                 path,
                 flags,
-                url,
-                auth_token,
+                url.clone(),
+                auth_token.clone(),
             )
             .await?;
 
             Ok(Database {
-                db_type: DbType::Offline { db },
+                db_type: DbType::Offline {
+                    db,
+                    remote_writes,
+                    read_your_writes,
+                    url,
+                    auth_token,
+                    connector,
+                },
                 max_write_replication_index: Default::default(),
             })
         }
diff --git a/libsql/src/hrana/hyper.rs b/libsql/src/hrana/hyper.rs
index b02ce0ce23..536a371765 100644
--- a/libsql/src/hrana/hyper.rs
+++ b/libsql/src/hrana/hyper.rs
@@ -305,14 +305,17 @@ impl Conn for HranaStream {
         let parse = crate::parser::Statement::parse(sql);
         for s in parse {
             let s = s?;
-            if s.kind == crate::parser::StmtKind::TxnBegin
-                || s.kind == crate::parser::StmtKind::TxnBeginReadOnly
-                || s.kind == crate::parser::StmtKind::TxnEnd
-            {
+
+            use crate::parser::StmtKind;
+            if matches!(
+                s.kind,
+                StmtKind::TxnBegin | StmtKind::TxnBeginReadOnly | StmtKind::TxnEnd
+            ) {
                 return Err(Error::TransactionalBatchError(
                     "Transactions forbidden inside transactional batch".to_string(),
                 ));
             }
+
             stmts.push(Stmt::new(s.stmt, false));
         }
         let res = self
diff --git a/libsql/src/hrana/mod.rs b/libsql/src/hrana/mod.rs
index 4a6fd0c63a..432ad6eea4 100644
--- a/libsql/src/hrana/mod.rs
+++ b/libsql/src/hrana/mod.rs
@@ -3,7 +3,7 @@
 pub mod connection;
 
 cfg_remote! {
-    mod hyper;
+    pub mod hyper;
 }
 
 mod cursor;
diff --git a/libsql/src/local/database.rs b/libsql/src/local/database.rs
index 48170c58ab..3b157e715d 100644
--- a/libsql/src/local/database.rs
+++ b/libsql/src/local/database.rs
@@ -20,10 +20,11 @@ cfg_replication!(
 cfg_sync! {
     use crate::sync::SyncContext;
+    use tokio::sync::Mutex;
+    use std::sync::Arc;
 }
 
-use crate::{database::OpenFlags, local::connection::Connection};
-use crate::{Error::ConnectionFailed, Result};
+use crate::{database::OpenFlags, local::connection::Connection, Error::ConnectionFailed, Result};
 use libsql_sys::ffi;
@@ -33,7 +34,7 @@ pub struct Database {
     #[cfg(feature = "replication")]
     pub replication_ctx: Option<ReplicationContext>,
     #[cfg(feature = "sync")]
-    pub sync_ctx: Option<tokio::sync::Mutex<SyncContext>>,
+    pub sync_ctx: Option<Arc<Mutex<SyncContext>>>,
 }
 
 impl Database {
@@ -222,7 +223,7 @@ impl Database {
 
         let sync_ctx =
             SyncContext::new(connector, db_path.into(), endpoint, Some(auth_token)).await?;
-        db.sync_ctx = Some(tokio::sync::Mutex::new(sync_ctx));
+        db.sync_ctx = Some(Arc::new(Mutex::new(sync_ctx)));
 
         Ok(db)
     }
@@ -463,137 +464,10 @@ impl Database {
     #[cfg(feature = "sync")]
     /// Sync WAL frames to remote.
     pub async fn sync_offline(&self) -> Result<crate::database::Replicated> {
-        use crate::sync::SyncError;
-        use crate::Error;
-
         let mut sync_ctx = self.sync_ctx.as_ref().unwrap().lock().await;
         let conn = self.connect()?;
 
-        let durable_frame_no = sync_ctx.durable_frame_num();
-        let max_frame_no = conn.wal_frame_count();
-
-        if max_frame_no > durable_frame_no {
-            match self.try_push(&mut sync_ctx, &conn).await {
-                Ok(rep) => Ok(rep),
-                Err(Error::Sync(err)) => {
-                    // Retry the sync because we are ahead of the server and we need to push some older
-                    // frames.
-                    if let Some(SyncError::InvalidPushFrameNoLow(_, _)) = err.downcast_ref() {
-                        tracing::debug!("got InvalidPushFrameNo, retrying push");
-                        self.try_push(&mut sync_ctx, &conn).await
-                    } else {
-                        Err(Error::Sync(err))
-                    }
-                }
-                Err(e) => Err(e),
-            }
-        } else {
-            self.try_pull(&mut sync_ctx, &conn).await
-        }
-        .or_else(|err| {
-            let Error::Sync(err) = err else {
-                return Err(err);
-            };
-
-            // TODO(levy): upcasting should be done *only* at the API boundary, doing this in
-            // internal code just sucks.
-            let Some(SyncError::HttpDispatch(_)) = err.downcast_ref() else {
-                return Err(Error::Sync(err));
-            };
-
-            Ok(crate::database::Replicated {
-                frame_no: None,
-                frames_synced: 0,
-            })
-        })
-    }
-
-    #[cfg(feature = "sync")]
-    async fn try_push(
-        &self,
-        sync_ctx: &mut SyncContext,
-        conn: &Connection,
-    ) -> Result<crate::database::Replicated> {
-        let page_size = {
-            let rows = conn
-                .query("PRAGMA page_size", crate::params::Params::None)?
-                .unwrap();
-            let row = rows.next()?.unwrap();
-            let page_size = row.get::<u32>(0)?;
-            page_size
-        };
-
-        let max_frame_no = conn.wal_frame_count();
-        if max_frame_no == 0 {
-            return Ok(crate::database::Replicated {
-                frame_no: None,
-                frames_synced: 0,
-            });
-        }
-
-        let generation = sync_ctx.generation(); // TODO: Probe from WAL.
-        let start_frame_no = sync_ctx.durable_frame_num() + 1;
-        let end_frame_no = max_frame_no;
-
-        let mut frame_no = start_frame_no;
-        while frame_no <= end_frame_no {
-            let frame = conn.wal_get_frame(frame_no, page_size)?;
-
-            // The server returns its maximum frame number. To avoid resending
-            // frames the server already knows about, we need to update the
-            // frame number to the one returned by the server.
-            let max_frame_no = sync_ctx
-                .push_one_frame(frame.freeze(), generation, frame_no)
-                .await?;
-
-            if max_frame_no > frame_no {
-                frame_no = max_frame_no;
-            }
-            frame_no += 1;
-        }
-
-        sync_ctx.write_metadata().await?;
-
-        // TODO(lucio): this can underflow if the server previously returned a higher max_frame_no
-        // than what we have stored here.
-        let frame_count = end_frame_no - start_frame_no + 1;
-        Ok(crate::database::Replicated {
-            frame_no: None,
-            frames_synced: frame_count as usize,
-        })
-    }
-
-    #[cfg(feature = "sync")]
-    async fn try_pull(
-        &self,
-        sync_ctx: &mut SyncContext,
-        conn: &Connection,
-    ) -> Result<crate::database::Replicated> {
-        let generation = sync_ctx.generation();
-        let mut frame_no = sync_ctx.durable_frame_num() + 1;
-
-        let insert_handle = conn.wal_insert_handle()?;
-
-        loop {
-            match sync_ctx.pull_one_frame(generation, frame_no).await {
-                Ok(Some(frame)) => {
-                    insert_handle.insert(&frame)?;
-                    frame_no += 1;
-                }
-                Ok(None) => {
-                    sync_ctx.write_metadata().await?;
-                    return Ok(crate::database::Replicated {
-                        frame_no: None,
-                        frames_synced: 1,
-                    });
-                }
-                Err(err) => {
-                    tracing::debug!("pull_one_frame error: {:?}", err);
-                    sync_ctx.write_metadata().await?;
-                    return Err(err);
-                }
-            }
-        }
+        crate::sync::sync_offline(&mut sync_ctx, &conn).await
     }
 
     pub(crate) fn path(&self) -> &str {
diff --git a/libsql/src/local/impls.rs b/libsql/src/local/impls.rs
index f1f8a6d153..223467d27d 100644
--- a/libsql/src/local/impls.rs
+++ b/libsql/src/local/impls.rs
@@ -91,7 +91,7 @@ impl Drop for LibsqlConnection {
     }
 }
 
-pub(crate) struct LibsqlStmt(pub(super) crate::local::Statement);
+pub(crate) struct LibsqlStmt(pub crate::local::Statement);
 
 #[async_trait::async_trait]
 impl Stmt for LibsqlStmt {
diff --git a/libsql/src/params.rs b/libsql/src/params.rs
index 21c0ebce0d..6921e0145e 100644
--- a/libsql/src/params.rs
+++ b/libsql/src/params.rs
@@ -141,6 +141,13 @@ impl IntoParams for Params {
     }
 }
 
+impl Sealed for &Params {}
+impl IntoParams for &Params {
+    fn into_params(self) -> Result<Params> {
+        Ok(self.clone())
+    }
+}
+
 impl Sealed for Vec<Value> {}
 impl IntoParams for Vec<Value> {
     fn into_params(self) -> Result<Params> {
diff --git a/libsql/src/replication/connection.rs b/libsql/src/replication/connection.rs
index f87e2bc595..7519dba82f 100644
--- a/libsql/src/replication/connection.rs
+++ b/libsql/src/replication/connection.rs
@@ -40,7 +40,7 @@ struct Inner {
 }
 
 #[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
-enum State {
+pub enum State {
     #[default]
     Init,
     Invalid,
@@ -106,7 +106,7 @@ fn predict_final_state<'a>(
 /// parsed. This means that we only take into account the entire passed sql statement set and
 /// for example will reject writes if we are in a readonly txn to start with even if we commit
 /// and start a new transaction with the write in it.
-fn should_execute_local(state: &mut State, stmts: &[parser::Statement]) -> Result<bool> {
+pub fn should_execute_local(state: &mut State, stmts: &[parser::Statement]) -> Result<bool> {
     let predicted_end_state = predict_final_state(*state, stmts.iter());
 
     let should_execute_local = match (*state, predicted_end_state) {
diff --git a/libsql/src/replication/mod.rs b/libsql/src/replication/mod.rs
index 00296fb532..57ef8fc6a5 100644
--- a/libsql/src/replication/mod.rs
+++ b/libsql/src/replication/mod.rs
@@ -30,7 +30,7 @@ use self::local_client::LocalClient;
 use self::remote_client::RemoteClient;
 
 pub(crate) mod client;
-mod connection;
+pub(crate) mod connection;
 pub(crate) mod local_client;
 pub(crate) mod remote_client;
diff --git a/libsql/src/sync.rs b/libsql/src/sync.rs
index ed625f7b68..da25c2f84f 100644
--- a/libsql/src/sync.rs
+++ b/libsql/src/sync.rs
@@ -1,4 +1,4 @@
-use crate::{util::ConnectorService, Result};
+use crate::{local::Connection, util::ConnectorService, Error, Result};
 
 use std::path::Path;
@@ -12,6 +12,10 @@ use uuid::Uuid;
 #[cfg(test)]
 mod test;
 
+pub mod connection;
+pub mod statement;
+pub mod transaction;
+
 const METADATA_VERSION: u32 = 0;
 const DEFAULT_MAX_RETRIES: usize = 5;
@@ -107,7 +111,7 @@ impl SyncContext {
     }
 
     #[tracing::instrument(skip(self))]
-    pub(crate) async fn pull_one_frame(&mut self, generation: u32, frame_no: u32) -> Result<Option<Bytes>> {
+    pub(crate) async fn pull_one_frame(
+        &mut self,
+        generation: u32,
+        frame_no: u32,
+    ) -> Result<Option<Bytes>> {
         let uri = format!(
             "{}/sync/{}/{}/{}",
             self.sync_url,
@@ -315,7 +323,10 @@ impl SyncContext {
     async fn read_metadata(&mut self) -> Result<()> {
         let path = format!("{}-info", self.db_path);
 
-        if !Path::new(&path).try_exists().map_err(SyncError::io("metadata file exists"))? {
+        if !Path::new(&path)
+            .try_exists()
+            .map_err(SyncError::io("metadata file exists"))?
+        {
             tracing::debug!("no metadata info file found");
             return Ok(());
         }
@@ -419,3 +430,131 @@ async fn atomic_write<P: AsRef<Path>>(path: P, data: &[u8]) -> Result<()> {
 
     Ok(())
 }
+
+/// Sync WAL frames to remote.
+pub async fn sync_offline(
+    sync_ctx: &mut SyncContext,
+    conn: &Connection,
+) -> Result<crate::database::Replicated> {
+    let durable_frame_no = sync_ctx.durable_frame_num();
+    let max_frame_no = conn.wal_frame_count();
+
+    if max_frame_no > durable_frame_no {
+        match try_push(sync_ctx, conn).await {
+            Ok(rep) => Ok(rep),
+            Err(Error::Sync(err)) => {
+                // Retry the sync because we are ahead of the server and we need to push some older
+                // frames.
+                if let Some(SyncError::InvalidPushFrameNoLow(_, _)) = err.downcast_ref() {
+                    tracing::debug!("got InvalidPushFrameNo, retrying push");
+                    try_push(sync_ctx, conn).await
+                } else {
+                    Err(Error::Sync(err))
+                }
+            }
+            Err(e) => Err(e),
+        }
+    } else {
+        try_pull(sync_ctx, conn).await
+    }
+    .or_else(|err| {
+        let Error::Sync(err) = err else {
+            return Err(err);
+        };
+
+        // TODO(levy): upcasting should be done *only* at the API boundary, doing this in
+        // internal code just sucks.
+        let Some(SyncError::HttpDispatch(_)) = err.downcast_ref() else {
+            return Err(Error::Sync(err));
+        };
+
+        Ok(crate::database::Replicated {
+            frame_no: None,
+            frames_synced: 0,
+        })
+    })
+}
+
+async fn try_push(
+    sync_ctx: &mut SyncContext,
+    conn: &Connection,
+) -> Result<crate::database::Replicated> {
+    let page_size = {
+        let rows = conn
+            .query("PRAGMA page_size", crate::params::Params::None)?
+            .unwrap();
+        let row = rows.next()?.unwrap();
+        let page_size = row.get::<u32>(0)?;
+        page_size
+    };
+
+    let max_frame_no = conn.wal_frame_count();
+    if max_frame_no == 0 {
+        return Ok(crate::database::Replicated {
+            frame_no: None,
+            frames_synced: 0,
+        });
+    }
+
+    let generation = sync_ctx.generation(); // TODO: Probe from WAL.
+    let start_frame_no = sync_ctx.durable_frame_num() + 1;
+    let end_frame_no = max_frame_no;
+
+    let mut frame_no = start_frame_no;
+    while frame_no <= end_frame_no {
+        let frame = conn.wal_get_frame(frame_no, page_size)?;
+
+        // The server returns its maximum frame number. To avoid resending
+        // frames the server already knows about, we need to update the
+        // frame number to the one returned by the server.
+        let max_frame_no = sync_ctx
+            .push_one_frame(frame.freeze(), generation, frame_no)
+            .await?;
+
+        if max_frame_no > frame_no {
+            frame_no = max_frame_no;
+        }
+        frame_no += 1;
+    }
+
+    sync_ctx.write_metadata().await?;
+
+    // TODO(lucio): this can underflow if the server previously returned a higher max_frame_no
+    // than what we have stored here.
+    let frame_count = end_frame_no - start_frame_no + 1;
+    Ok(crate::database::Replicated {
+        frame_no: None,
+        frames_synced: frame_count as usize,
+    })
+}
+
+async fn try_pull(
+    sync_ctx: &mut SyncContext,
+    conn: &Connection,
+) -> Result<crate::database::Replicated> {
+    let generation = sync_ctx.generation();
+    let mut frame_no = sync_ctx.durable_frame_num() + 1;
+
+    let insert_handle = conn.wal_insert_handle()?;
+
+    loop {
+        match sync_ctx.pull_one_frame(generation, frame_no).await {
+            Ok(Some(frame)) => {
+                insert_handle.insert(&frame)?;
+                frame_no += 1;
+            }
+            Ok(None) => {
+                sync_ctx.write_metadata().await?;
+                return Ok(crate::database::Replicated {
+                    frame_no: None,
+                    frames_synced: 1,
+                });
+            }
+            Err(err) => {
+                tracing::debug!("pull_one_frame error: {:?}", err);
+                sync_ctx.write_metadata().await?;
+                return Err(err);
+            }
+        }
+    }
+}
diff --git a/libsql/src/sync/connection.rs b/libsql/src/sync/connection.rs
new file mode 100644
index 0000000000..ba9b19d27f
--- /dev/null
+++ b/libsql/src/sync/connection.rs
@@ -0,0 +1,120 @@
+use crate::{
+    connection::Conn,
+    hrana::{connection::HttpConnection, hyper::HttpSender},
+    local::{self, impls::LibsqlStmt},
+    params::Params,
+    replication::connection::State,
+    sync::SyncContext,
+    BatchRows, Error, Result, Statement, Transaction, TransactionBehavior,
+};
+use std::sync::Arc;
+use tokio::sync::Mutex;
+
+use super::{statement::SyncedStatement, transaction::SyncedTx};
+
+#[derive(Clone)]
+pub struct SyncedConnection {
+    pub remote: HttpConnection<HttpSender>,
+    pub local: local::Connection,
+    pub read_your_writes: bool,
+    pub context: Arc<Mutex<SyncContext>>,
+    pub state: Arc<Mutex<State>>,
+}
+
+impl SyncedConnection {
+    async fn should_execute_local(&self, sql: &str) -> Result<bool> {
+        let stmts = crate::parser::Statement::parse(sql)
+            .collect::<Result<Vec<_>>>()
+            .or_else(|err| match err {
+                Error::Sqlite3UnsupportedStatement => Ok(vec![]),
+                err => Err(err),
+            })?;
+
+        let mut state = self.state.lock().await;
+
+        crate::replication::connection::should_execute_local(&mut state, stmts.as_slice())
+    }
+}
+
+#[async_trait::async_trait]
+impl Conn for SyncedConnection {
+    async fn execute(&self, sql: &str, params: Params) -> Result<u64> {
+        let mut stmt = self.prepare(sql).await?;
+        stmt.execute(params).await.map(|v| v as u64)
+    }
+
+    async fn execute_batch(&self, sql: &str) -> Result<BatchRows> {
+        if self.should_execute_local(sql).await? {
+            self.local.execute_batch(sql)
+        } else {
+            self.remote.execute_batch(sql).await
+        }
+    }
+
+    async fn execute_transactional_batch(&self, sql: &str) -> Result<BatchRows> {
+        if self.should_execute_local(sql).await? {
+            self.local.execute_transactional_batch(sql)?;
+            Ok(BatchRows::empty())
+        } else {
+            self.remote.execute_transactional_batch(sql).await
+        }
+    }
+
+    async fn prepare(&self, sql: &str) -> Result<Statement> {
+        if self.should_execute_local(sql).await? {
+            Ok(Statement {
+                inner: Box::new(LibsqlStmt(self.local.prepare(sql)?)),
+            })
+        } else {
+            let stmt = Statement {
+                inner: Box::new(self.remote.prepare(sql)?),
+            };
+
+            if self.read_your_writes {
+                Ok(Statement {
+                    inner: Box::new(SyncedStatement {
+                        conn: self.local.clone(),
+                        context: self.context.clone(),
+                        inner: stmt,
+                    }),
+                })
+            } else {
+                Ok(stmt)
+            }
+        }
+    }
+
+    async fn transaction(&self, tx_behavior: TransactionBehavior) -> Result<Transaction> {
+        let tx = SyncedTx::begin(self.clone(), tx_behavior).await?;
+
+        Ok(Transaction {
+            inner: Box::new(tx),
+            conn: crate::Connection {
+                conn: Arc::new(self.clone()),
+            },
+            close: None,
+        })
+    }
+
+    fn interrupt(&self) -> Result<()> {
+        Ok(())
+    }
+
+    fn is_autocommit(&self) -> bool {
+        self.remote.is_autocommit()
+    }
+
+    fn changes(&self) -> u64 {
+        self.remote.changes()
+    }
+
+    fn total_changes(&self) -> u64 {
+        self.remote.total_changes()
+    }
+
+    fn last_insert_rowid(&self) -> i64 {
+        self.remote.last_insert_rowid()
+    }
+
+    async fn reset(&self) {}
+}
diff --git a/libsql/src/sync/statement.rs b/libsql/src/sync/statement.rs
new file mode 100644
index 0000000000..b8576c73bb
--- /dev/null
+++ b/libsql/src/sync/statement.rs
@@ -0,0 +1,58 @@
+use crate::{
+    local::{self},
+    params::Params,
+    statement::Stmt,
+    sync::SyncContext, Column, Result, Rows, Statement,
+};
+use std::sync::Arc;
+use tokio::sync::Mutex;
+
+pub struct SyncedStatement {
+    pub conn: local::Connection,
+    pub context: Arc<Mutex<SyncContext>>,
+    pub inner: Statement,
+}
+
+#[async_trait::async_trait]
+impl Stmt for SyncedStatement {
+    fn finalize(&mut self) {
+        self.inner.finalize()
+    }
+
+    async fn execute(&mut self, params: &Params) -> Result<usize> {
+        let result = self.inner.execute(params).await;
+        let mut context = self.context.lock().await;
+        let _ = crate::sync::sync_offline(&mut context, &self.conn).await;
+        result
+    }
+
+    async fn query(&mut self, params: &Params) -> Result<Rows> {
+        let result = self.inner.query(params).await;
+        let mut context = self.context.lock().await;
+        let _ = crate::sync::sync_offline(&mut context, &self.conn).await;
+        result
+    }
+
+    async fn run(&mut self, params: &Params) -> Result<()> {
+        let result = self.inner.run(params).await;
+        let mut context = self.context.lock().await;
+        let _ = crate::sync::sync_offline(&mut context, &self.conn).await;
+        result
+    }
+
+    fn reset(&mut self) {
+        self.inner.reset()
+    }
+
+    fn parameter_count(&self) -> usize {
+        self.inner.parameter_count()
+    }
+
+    fn parameter_name(&self, idx: i32) -> Option<&str> {
+        self.inner.parameter_name(idx)
+    }
+
+    fn columns(&self) -> Vec<Column> {
+        self.inner.columns()
+    }
+}
diff --git a/libsql/src/sync/transaction.rs b/libsql/src/sync/transaction.rs
new file mode 100644
index 0000000000..ebfa1ee803
--- /dev/null
+++ b/libsql/src/sync/transaction.rs
@@ -0,0 +1,41 @@
+use crate::{
+    connection::Conn,
+    params::Params,
+    transaction::Tx, Result, TransactionBehavior,
+};
+
+use super::connection::SyncedConnection;
+
+pub struct SyncedTx(SyncedConnection);
+
+impl SyncedTx {
+    pub(crate) async fn begin(
+        conn: SyncedConnection,
+        tx_behavior: TransactionBehavior,
+    ) -> Result<Self> {
+        conn.execute(
+            match tx_behavior {
+                TransactionBehavior::Deferred => "BEGIN DEFERRED",
+                TransactionBehavior::Immediate => "BEGIN IMMEDIATE",
+                TransactionBehavior::Exclusive => "BEGIN EXCLUSIVE",
+                TransactionBehavior::ReadOnly => "BEGIN READONLY",
+            },
+            Params::None,
+        )
+        .await?;
+        Ok(Self(conn.clone()))
+    }
+}
+
+#[async_trait::async_trait]
+impl Tx for SyncedTx {
+    async fn commit(&mut self) -> Result<()> {
+        self.0.execute("COMMIT", Params::None).await?;
+        Ok(())
+    }
+
+    async fn rollback(&mut self) -> Result<()> {
+        self.0.execute("ROLLBACK", Params::None).await?;
+        Ok(())
+    }
+}
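
Usage sketch (not part of the patch): a minimal illustration of how the two builder options introduced above might be toggled, assuming the existing `Builder::new_synced_database(path, url, auth_token)` constructor from this crate. The method names and defaults (`remote_writes = false`, `read_your_writes = true`) come from the builder changes in this diff; the URL, token, and table are placeholders.

```rust
use libsql::Builder;

#[tokio::main]
async fn main() -> Result<(), libsql::Error> {
    // Offline-first database: a local SQLite file whose WAL frames sync to a remote.
    let db = Builder::new_synced_database(
        "local.db",                              // local database path
        "https://example.turso.io".to_string(),  // remote sync URL (placeholder)
        "auth-token".to_string(),                // auth token (placeholder)
    )
    // New in this change: route writes to the remote over Hrana instead of only
    // applying them to the local file (defaults to false).
    .remote_writes(true)
    // New in this change: after a remote write, pull frames back so subsequent
    // local reads observe that write (defaults to true).
    .read_your_writes(true)
    .build()
    .await?;

    let conn = db.connect()?;
    conn.execute("CREATE TABLE IF NOT EXISTS t (x INTEGER)", ()).await?;
    conn.execute("INSERT INTO t (x) VALUES (1)", ()).await?;

    // Explicit bidirectional sync is still available.
    db.sync().await?;
    Ok(())
}
```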