Skip to content

Commit

Permalink
refactor: simplify sync to async sqlx bridge
Browse files Browse the repository at this point in the history
Also reduce amount of threads and use multithreaded tokio runtime.
  • Loading branch information
boxdot committed Jan 20, 2025
1 parent d5304b8 commit ca84c00
Show file tree
Hide file tree
Showing 9 changed files with 264 additions and 376 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ phonenumber = "0.3.6"
postcard = { version = "1.1.1", features = ["alloc"] }
qr2term = "0.3.3"
ratatui = "0.29.0"
rayon = "1.10.0"
regex = "1.11.1"
scopeguard = "1.2.0"
serde = { version = "1.0.216", features = ["derive"] }
Expand All @@ -74,7 +73,6 @@ sqlx = { version = "0.8.2", features = [
] }
textwrap = "0.16.1"
thiserror = "2.0.9"
thread_local = "1.1.8"
tokio = { version = "1.42.0", default-features = false, features = [
"rt-multi-thread",
"macros",
Expand All @@ -94,6 +92,7 @@ tempfile = "3.14.0"
crokey = "1.1.0"
strum_macros = "0.26.4"
strum = { version = "0.26.3", features = ["derive"] }
tokio-util = { version = "0.7.13", features = ["rt"] }

[package.metadata.cargo-machete]
# not used directly; brings sqlcipher capabilities to sqlite
Expand Down
20 changes: 12 additions & 8 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use presage::libsignal_service::content::Content;
use ratatui::{backend::CrosstermBackend, Terminal};
use tokio::select;
use tokio_stream::StreamExt;
use tokio_util::task::LocalPoolHandle;
use tracing::debug;
use tracing::{error, info};

Expand All @@ -43,7 +44,7 @@ struct Args {
relink: bool,
}

#[tokio::main(flavor = "current_thread")]
#[tokio::main(worker_threads = 2)]
async fn main() -> anyhow::Result<()> {
let args = Args::parse();

Expand All @@ -62,9 +63,7 @@ async fn main() -> anyhow::Result<()> {

log_panics::init();

tokio::task::LocalSet::new()
.run_until(run_single_threaded(args.relink))
.await
run(args.relink).await
}

async fn is_online() -> bool {
Expand All @@ -88,8 +87,10 @@ pub enum Event {
AppEvent(gurk::event::Event),
}

async fn run_single_threaded(relink: bool) -> anyhow::Result<()> {
let (mut signal_manager, config) = signal::ensure_linked_device(relink).await?;
async fn run(relink: bool) -> anyhow::Result<()> {
let local_pool = LocalPoolHandle::new(2);
let (mut signal_manager, config) =
signal::ensure_linked_device(relink, local_pool.clone()).await?;

let mut storage: Box<dyn Storage> = if config.sqlite.enabled {
debug!(
Expand All @@ -102,6 +103,7 @@ async fn run_single_threaded(relink: bool) -> anyhow::Result<()> {
config.passphrase.clone(),
config.sqlite.preserve_unencrypted,
)
.await
.with_context(|| {
format!(
"failed to open sqlite data storage at: {}",
Expand Down Expand Up @@ -156,7 +158,8 @@ async fn run_single_threaded(relink: bool) -> anyhow::Result<()> {
});

let inner_tx = tx.clone();
tokio::task::spawn_local(async move {

local_pool.spawn_pinned(|| async move {
let mut backoff = Backoff::new();
loop {
let mut messages = if !is_online().await {
Expand All @@ -171,7 +174,8 @@ async fn run_single_threaded(relink: bool) -> anyhow::Result<()> {
Err(e) => {
let e = e.context(
"failed to initialize the stream of Signal messages.\n\
Maybe the device was unlinked? Please try to restart with '--relink` flag.",
Maybe the device was unlinked? Please try to restart with \
'--relink` flag.",
);
inner_tx
.send(Event::Quit(Some(e)))
Expand Down
24 changes: 16 additions & 8 deletions src/signal/impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use presage::{
use presage_store_sled::SledStore;
use tokio::sync::oneshot;
use tokio_stream::{Stream, StreamExt};
use tokio_util::task::LocalPoolHandle;
use tracing::{error, warn};
use uuid::Uuid;

Expand All @@ -33,18 +34,25 @@ use super::{

pub(super) struct PresageManager {
manager: presage::Manager<SledStore, Registered>,
local_pool: LocalPoolHandle,
}

impl PresageManager {
pub(super) fn new(manager: presage::Manager<SledStore, Registered>) -> Self {
Self { manager }
pub(super) fn new(
manager: presage::Manager<SledStore, Registered>,
local_pool: LocalPoolHandle,
) -> Self {
Self {
manager,
local_pool,
}
}
}

#[async_trait(?Send)]
impl SignalManager for PresageManager {
fn clone_boxed(&self) -> Box<dyn SignalManager> {
Box::new(Self::new(self.manager.clone()))
Box::new(Self::new(self.manager.clone(), self.local_pool.clone()))
}

fn user_id(&self) -> Uuid {
Expand Down Expand Up @@ -102,7 +110,7 @@ impl SignalManager for PresageManager {
};

let mut manager = self.manager.clone();
tokio::task::spawn_local(async move {
self.local_pool.spawn_pinned(move || async move {
let body = ContentBody::ReceiptMessage(data_message);
if let Err(error) = manager
.send_message(ServiceId::Aci(sender_uuid.into()), body, now_timestamp)
Expand Down Expand Up @@ -162,7 +170,7 @@ impl SignalManager for PresageManager {
match channel.id {
ChannelId::User(uuid) => {
let mut manager = self.manager.clone();
tokio::task::spawn_local(async move {
self.local_pool.spawn_pinned(move || async move {
if let Err(error) =
upload_attachments(&manager, attachments, &mut data_message).await
{
Expand Down Expand Up @@ -202,7 +210,7 @@ impl SignalManager for PresageManager {
..Default::default()
});

tokio::task::spawn_local(async move {
self.local_pool.spawn_pinned(move || async move {
if let Err(error) =
upload_attachments(&manager, attachments, &mut data_message).await
{
Expand Down Expand Up @@ -271,7 +279,7 @@ impl SignalManager for PresageManager {
(ChannelId::User(uuid), _) => {
let mut manager = self.manager.clone();
let body = ContentBody::DataMessage(data_message);
tokio::task::spawn_local(async move {
self.local_pool.spawn_pinned(move || async move {
if let Err(e) = manager
.send_message(ServiceId::Aci(uuid.into()), body, timestamp)
.await
Expand All @@ -291,7 +299,7 @@ impl SignalManager for PresageManager {
..Default::default()
});

tokio::task::spawn_local(async move {
self.local_pool.spawn_pinned(move || async move {
if let Err(e) = manager
.send_message_to_group(&master_key_bytes, data_message, timestamp)
.await
Expand Down
8 changes: 5 additions & 3 deletions src/signal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub mod test;
use anyhow::{bail, Context as _};
use presage::{libsignal_service::configuration::SignalServers, model::identity::OnNewIdentity};
use presage_store_sled::{MigrationConflictStrategy, SledStore};
use tokio_util::task::LocalPoolHandle;

use crate::config::{self, Config};

Expand All @@ -31,7 +32,8 @@ pub type GroupIdentifierBytes = [u8; GROUP_IDENTIFIER_LEN];
/// path.
pub async fn ensure_linked_device(
relink: bool,
) -> anyhow::Result<(Box<dyn SignalManager>, Config)> {
local_pool: LocalPoolHandle,
) -> anyhow::Result<(Box<dyn SignalManager + Send>, Config)> {
let config = Config::load_installed()?;

let db_path = config
Expand All @@ -54,7 +56,7 @@ pub async fn ensure_linked_device(
match presage::Manager::load_registered(store.clone()).await {
Ok(manager) => {
// done loading manager from store
return Ok((Box::new(PresageManager::new(manager)), config));
return Ok((Box::new(PresageManager::new(manager, local_pool)), config));
}
Err(e) => {
bail!("error loading manager. Try again later or run with --relink to force relink: {}", e)
Expand Down Expand Up @@ -127,5 +129,5 @@ pub async fn ensure_linked_device(
config
};

Ok((Box::new(PresageManager::new(manager)), config))
Ok((Box::new(PresageManager::new(manager, local_pool)), config))
}
83 changes: 35 additions & 48 deletions src/storage/sql/encrypt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use tempfile::tempdir;
use tracing::info;
use url::Url;

pub(super) fn encrypt_db(
pub(super) async fn encrypt_db(
url: &Url,
passphrase: &str,
preserve_unencrypted: bool,
Expand All @@ -22,46 +22,33 @@ pub(super) fn encrypt_db(

info!(%url, "encrypting db");

std::thread::scope(|s| {
s.spawn(|| -> anyhow::Result<()> {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async {
let tempdir = tempdir().context("failed to create temp dir")?;
let dest = tempdir.path().join("encrypted.db");

let mut conn = SqliteConnection::connect_with(&opts).await?;
sqlx::raw_sql(&format!(
"
ATTACH DATABASE '{}' AS encrypted KEY '{passphrase}';
SELECT sqlcipher_export('encrypted');
DETACH DATABASE encrypted;
",
dest.display(),
))
.execute(&mut conn)
.await
.context("failed to encrypt db")?;

let origin = url.path();
if preserve_unencrypted {
let backup = format!("{origin}.backup");
std::fs::copy(origin, &backup).with_context(|| {
format!("failed to backup the unencrypted database at: {backup}")
})?;
}

std::fs::copy(dest, origin)
.with_context(|| format!("failed to replace unencrypted db at: {origin}"))?;

Ok(())
})
})
.join()
.expect("encryption failed")
})
let tempdir = tempdir().context("failed to create temp dir")?;
let dest = tempdir.path().join("encrypted.db");

let mut conn = SqliteConnection::connect_with(&opts).await?;
sqlx::raw_sql(&format!(
"
ATTACH DATABASE '{}' AS encrypted KEY '{passphrase}';
SELECT sqlcipher_export('encrypted');
DETACH DATABASE encrypted;
",
dest.display(),
))
.execute(&mut conn)
.await
.context("failed to encrypt db")?;

let origin = url.path();
if preserve_unencrypted {
let backup = format!("{origin}.backup");
std::fs::copy(origin, &backup)
.with_context(|| format!("failed to backup the unencrypted database at: {backup}"))?;
}

std::fs::copy(dest, origin)
.with_context(|| format!("failed to replace unencrypted db at: {origin}"))?;

Ok(())
}

pub(super) fn is_sqlite_encrypted_heuristics(url: &Url) -> Option<bool> {
Expand All @@ -79,28 +66,28 @@ mod tests {

use super::*;

#[test]
fn test_encrypt_unencrypted() -> anyhow::Result<()> {
#[tokio::test]
async fn test_encrypt_unencrypted() {
let tempdir = tempdir().unwrap();
let path = tempdir.path().join("data.sqlite");
let url: Url = format!("sqlite://{}", path.display()).parse().unwrap();

let _ = SqliteStorage::open(&url, None).unwrap();
let _ = SqliteStorage::open(&url, None).await.unwrap();
assert!(path.exists());
assert_eq!(is_sqlite_encrypted_heuristics(&url), Some(false));

let preserve_unencrypted = true;
let passphrase = "secret".to_owned();
encrypt_db(&url, &passphrase, preserve_unencrypted).unwrap();
encrypt_db(&url, &passphrase, preserve_unencrypted)
.await
.unwrap();

assert!(path.exists());
assert_eq!(is_sqlite_encrypted_heuristics(&url), Some(true));

let backup_url: Url = format!("{url}.backup").parse().unwrap();
assert_eq!(is_sqlite_encrypted_heuristics(&backup_url), Some(false));

let _ = SqliteStorage::open(&url, Some(passphrase)).unwrap();

Ok(())
let _ = SqliteStorage::open(&url, Some(passphrase)).await.unwrap();
}
}
Loading

0 comments on commit ca84c00

Please sign in to comment.