Skip to content

Commit

Permalink
Improve rust readability.
Browse files Browse the repository at this point in the history
  • Loading branch information
johnspurlock-skymethod committed Dec 4, 2023
1 parent 88fc892 commit 87f2735
Showing 1 changed file with 94 additions and 35 deletions.
129 changes: 94 additions & 35 deletions napi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ use rusqlite::OpenFlags;
use std::collections::HashMap;
use std::io::Cursor;
use std::path::Path;
use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::sync::Mutex;
use tokio::sync::Notify;
Expand All @@ -29,22 +31,23 @@ pub fn open(path: String, in_memory: Option<bool>, debug: bool) -> u32 {
if debug {
println!("[napi] open: path={:#?} in_memory={:#?}", path, in_memory)
}
let flags = if in_memory.is_some() && in_memory.unwrap() {
let flags = if let Some(true) = in_memory {
OpenFlags::default() | OpenFlags::SQLITE_OPEN_MEMORY
} else {
OpenFlags::default()
};
let conn = Connection::open_with_flags(Path::new(&path), flags).unwrap();
let rng: Box<_> = Box::new(rand::rngs::StdRng::from_entropy());
let rng = Box::new(rand::rngs::StdRng::from_entropy());

let opened_path = conn.path().unwrap().to_owned();
// let notifier = SqliteNotifier::default(); // TODO waiting on watch pr
// let sqlite = Sqlite::new(conn, notifier, rng).unwrap();
let notify = Arc::new(Notify::new());
let sqlite = Sqlite::new(conn, notify, rng).unwrap();

let db_id = DBS.lock().unwrap().keys().max().unwrap_or(&0) + 1;
let db_id = DB_ID.fetch_add(1, Ordering::Relaxed);
DBS.lock().unwrap().insert(db_id, sqlite);

if debug {
println!(
"[napi] open: db_id={:#?} opened_path={:#?}",
Expand All @@ -59,15 +62,19 @@ pub fn close(db_id: u32, debug: bool) {
if debug {
println!("[napi] close: db_id={:#?}", db_id)
}
let mut dbs = DBS.lock().unwrap();
dbs.get(&db_id).unwrap().close();
dbs.remove(&db_id);
let db = DBS.lock().unwrap().remove(&db_id);
db.map(|db| db.close());
}

#[napi]
pub async fn snapshot_read(db_id: u32, snapshot_read_bytes: Buffer, debug: bool) -> Result<Buffer> {
let snapshot_read_pb = pb::SnapshotRead::decode(&mut Cursor::new(snapshot_read_bytes))
.map_err(convert_prost_decode_error_to_anyhow)?;
pub async fn snapshot_read(
db_id: u32,
snapshot_read_bytes: Buffer,
debug: bool,
) -> Result<Buffer> {
let snapshot_read_pb =
pb::SnapshotRead::decode(&mut Cursor::new(snapshot_read_bytes))
.map_err(convert_prost_decode_error_to_anyhow)?;
if debug {
println!(
"[napi] snapshot_read: db_id={:#?} snapshot_read_pb={:#?}",
Expand All @@ -83,7 +90,12 @@ pub async fn snapshot_read(db_id: u32, snapshot_read_bytes: Buffer, debug: bool)
.try_into()
.map_err(convert_error_to_anyhow)?;

let db = DBS.lock().unwrap().get(&db_id).unwrap().to_owned();
let db = DBS
.lock()
.unwrap()
.get(&db_id)
.map(|db| db.clone())
.ok_or_else(|| anyhow::anyhow!("db not found"))?;

let output_pb: pb::SnapshotReadOutput = db
.snapshot_read(requests, options)
Expand All @@ -99,9 +111,14 @@ pub async fn snapshot_read(db_id: u32, snapshot_read_bytes: Buffer, debug: bool)
}

#[napi]
pub async fn atomic_write(db_id: u32, atomic_write_bytes: Buffer, debug: bool) -> Result<Buffer> {
let atomic_write_pb = pb::AtomicWrite::decode(&mut Cursor::new(atomic_write_bytes))
.map_err(convert_prost_decode_error_to_anyhow)?;
pub async fn atomic_write(
db_id: u32,
atomic_write_bytes: Buffer,
debug: bool,
) -> Result<Buffer> {
let atomic_write_pb =
pb::AtomicWrite::decode(&mut Cursor::new(atomic_write_bytes))
.map_err(convert_prost_decode_error_to_anyhow)?;
if debug {
println!(
"[napi] atomic_write: db_id={:#?} atomic_write_pb={:#?}",
Expand All @@ -113,7 +130,12 @@ pub async fn atomic_write(db_id: u32, atomic_write_bytes: Buffer, debug: bool) -
.try_into()
.map_err(convert_error_to_anyhow)?;

let db = DBS.lock().unwrap().get(&db_id).unwrap().to_owned();
let db = DBS
.lock()
.unwrap()
.get(&db_id)
.map(|db| db.clone())
.ok_or_else(|| anyhow::anyhow!("db not found"))?;

let output_pb: pb::AtomicWriteOutput = db
.atomic_write(atomic_write)
Expand Down Expand Up @@ -146,31 +168,36 @@ pub async fn dequeue_next_message(
println!("[napi] dequeue_next_message: db_id={:#?}", db_id)
}

let db: Sqlite = DBS.lock().unwrap().get(&db_id).unwrap().to_owned();
let db = DBS
.lock()
.unwrap()
.get(&db_id)
.map(|db| db.clone())
.ok_or_else(|| anyhow::anyhow!("db not found"))?;

let opt_handle = db
.dequeue_next_message()
.await
.map_err(convert_sqlite_backend_error_to_anyhow)?;

if opt_handle.is_none() {
let Some(mut handle) = opt_handle else {
if debug {
println!(
"[napi] dequeue_next_message: no messages! db_id={:#?}",
db_id
)
}
return Ok(Either::B(()));
}
};

let mut handle = opt_handle.unwrap();
let payload = handle
.take_payload()
.await
.map_err(convert_sqlite_backend_error_to_anyhow)?;

let message_id: u32 = MSGS.lock().unwrap().keys().max().unwrap_or(&0) + 1;
let message_id = MSG_ID.fetch_add(1, Ordering::Relaxed);
MSGS.lock().unwrap().insert(message_id, handle);

if debug {
println!(
"[napi] dequeue_next_message: received message db_id={:#?} message_id={:#?}",
Expand All @@ -183,15 +210,23 @@ pub async fn dequeue_next_message(
}

#[napi]
pub async fn finish_message(db_id: u32, message_id: u32, success: bool, debug: bool) -> Result<()> {
pub async fn finish_message(
db_id: u32,
message_id: u32,
success: bool,
debug: bool,
) -> Result<()> {
if debug {
println!(
"[napi] finish_message db_id={:#?} message_id={:#?} success={:#?}",
db_id, message_id, success
)
}
let opt_handle = MSGS.lock().unwrap().remove(&message_id);
let handle = opt_handle.unwrap();
let handle = MSGS
.lock()
.unwrap()
.remove(&message_id)
.ok_or_else(|| anyhow::anyhow!("message not found"))?;

handle
.finish(success)
Expand All @@ -202,7 +237,11 @@ pub async fn finish_message(db_id: u32, message_id: u32, success: bool, debug: b
}

#[napi]
pub async fn start_watch(db_id: u32, watch_bytes: Buffer, debug: bool) -> Result<u32> {
pub async fn start_watch(
db_id: u32,
watch_bytes: Buffer,
debug: bool,
) -> Result<u32> {
if debug {
println!(
"[napi] start_watch: db_id={:#?} watch_bytes={:#?}",
Expand All @@ -221,7 +260,7 @@ pub async fn start_watch(db_id: u32, watch_bytes: Buffer, debug: bool) -> Result

// TODO store stream in WATCHES?

let watch_id: u32 = WATCHES.lock().unwrap().keys().max().unwrap_or(&0) + 1;
let watch_id = WATCH_ID.fetch_add(1, Ordering::Relaxed);
WATCHES.lock().unwrap().insert(watch_id, ());

Ok(watch_id)
Expand Down Expand Up @@ -263,12 +302,17 @@ pub fn end_watch(db_id: u32, watch_id: u32, debug: bool) {

//

static DBS: Lazy<Mutex<HashMap<u32, Sqlite>>> = Lazy::new(|| Mutex::new(HashMap::new()));
static DB_ID: AtomicU32 = AtomicU32::new(0);
static DBS: Lazy<Mutex<HashMap<u32, Sqlite>>> =
Lazy::new(|| Mutex::new(HashMap::new()));

static MSG_ID: AtomicU32 = AtomicU32::new(0);
static MSGS: Lazy<Mutex<HashMap<u32, SqliteMessageHandle>>> =
Lazy::new(|| Mutex::new(HashMap::new()));

static WATCHES: Lazy<Mutex<HashMap<u32, ()>>> = Lazy::new(|| Mutex::new(HashMap::new()));
static WATCH_ID: AtomicU32 = AtomicU32::new(0);
static WATCHES: Lazy<Mutex<HashMap<u32, ()>>> =
Lazy::new(|| Mutex::new(HashMap::new()));

fn to_buffer<T: prost::Message>(output: T) -> Buffer {
let mut buf = Vec::with_capacity(output.encoded_len());
Expand All @@ -285,27 +329,42 @@ fn convert_error_to_str(err: denokv_proto::ConvertError) -> String {
ConvertError::TooManyReadRanges => String::from("TooManyReadRanges"),
ConvertError::TooManyChecks => String::from("TooManyChecks"),
ConvertError::TooManyMutations => String::from("TooManyMutations"),
ConvertError::TooManyQueueUndeliveredKeys => String::from("TooManyQueueUndeliveredKeys"),
ConvertError::TooManyQueueBackoffIntervals => String::from("TooManyQueueBackoffIntervals"),
ConvertError::QueueBackoffIntervalTooLarge => String::from("QueueBackoffIntervalTooLarge"),
ConvertError::InvalidReadRangeLimit => String::from("InvalidReadRangeLimit"),
ConvertError::TooManyQueueUndeliveredKeys => {
String::from("TooManyQueueUndeliveredKeys")
}
ConvertError::TooManyQueueBackoffIntervals => {
String::from("TooManyQueueBackoffIntervals")
}
ConvertError::QueueBackoffIntervalTooLarge => {
String::from("QueueBackoffIntervalTooLarge")
}
ConvertError::InvalidReadRangeLimit => {
String::from("InvalidReadRangeLimit")
}
ConvertError::DecodeError => String::from("DecodeError"),
ConvertError::InvalidVersionstamp => String::from("InvalidVersionstamp"),
ConvertError::InvalidMutationKind => String::from("InvalidMutationKind"),
ConvertError::InvalidMutationExpireAt => String::from("InvalidMutationExpireAt"),
ConvertError::InvalidMutationEnqueueDeadline => String::from("InvalidMutationEnqueueDeadline"),
// ConvertError::TooManyWatchedKeys => String::from("TooManyWatchedKeys"), // TODO waiting on watch pr
ConvertError::InvalidMutationExpireAt => {
String::from("InvalidMutationExpireAt")
}
ConvertError::InvalidMutationEnqueueDeadline => {
String::from("InvalidMutationEnqueueDeadline")
}
}
}

fn convert_error_to_anyhow(err: denokv_proto::ConvertError) -> anyhow::Error {
anyhow::anyhow!(convert_error_to_str(err))
}

fn convert_sqlite_backend_error_to_anyhow(err: SqliteBackendError) -> anyhow::Error {
fn convert_sqlite_backend_error_to_anyhow(
err: SqliteBackendError,
) -> anyhow::Error {
err.into()
}

fn convert_prost_decode_error_to_anyhow(err: prost::DecodeError) -> anyhow::Error {
fn convert_prost_decode_error_to_anyhow(
err: prost::DecodeError,
) -> anyhow::Error {
err.into()
}

0 comments on commit 87f2735

Please sign in to comment.