diff --git a/Cargo.lock b/Cargo.lock index eb1be40..c3985d4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1790,6 +1790,7 @@ dependencies = [ name = "stateroom-wasm" version = "0.2.9" dependencies = [ + "bincode", "stateroom", "stateroom-wasm-macro", ] @@ -1799,6 +1800,7 @@ name = "stateroom-wasm-host" version = "0.2.9" dependencies = [ "anyhow", + "bincode", "byteorder", "stateroom", "tracing", diff --git a/examples/binary-echo/Cargo.lock b/examples/binary-echo/Cargo.lock index b632a40..c0eeb8b 100644 --- a/examples/binary-echo/Cargo.lock +++ b/examples/binary-echo/Cargo.lock @@ -9,6 +9,15 @@ dependencies = [ "stateroom-wasm", ] +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "proc-macro2" version = "1.0.81" @@ -27,14 +36,38 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "serde" +version = "1.0.200" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddc6f9cc94d67c0e21aaf7eda3a010fd3af78ebf6e096aa6e2e13c79749cce4f" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.200" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "856f046b9400cee3c8c94ed572ecdb752444c24528c035cd35882aad6f492bcb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "stateroom" version = "0.2.8" +dependencies = [ + "serde", +] [[package]] name = "stateroom-wasm" version = "0.2.9" dependencies = [ + "bincode", "stateroom", "stateroom-wasm-macro", ] diff --git a/examples/binary-echo/src/lib.rs b/examples/binary-echo/src/lib.rs index 0e62ce8..17701b7 100644 --- a/examples/binary-echo/src/lib.rs +++ b/examples/binary-echo/src/lib.rs @@ -1,5 +1,5 @@ use stateroom_wasm::{ - stateroom_wasm, ClientId, MessageRecipient, StateroomContext, StateroomService, + stateroom_wasm, ClientId, MessageRecipient, StateroomContext, StateroomService, MessagePayload }; #[stateroom_wasm] @@ -7,14 +7,11 @@ use stateroom_wasm::{ struct BinaryEcho; impl StateroomService for BinaryEcho { - fn message(&mut self, _: ClientId, message: &str, ctx: &impl StateroomContext) { - ctx.send_binary(MessageRecipient::Broadcast, message.as_bytes()); - } - - fn binary(&mut self, _: ClientId, message: &[u8], ctx: &impl StateroomContext) { - ctx.send_message( - MessageRecipient::Broadcast, - &format!("Received binary data: {:?}", &message), - ); + fn message(&mut self, _: ClientId, message: MessagePayload, ctx: &impl StateroomContext) { + let message = match message { + MessagePayload::Text(s) => MessagePayload::Bytes(s.as_bytes().to_vec()), + MessagePayload::Bytes(b) => MessagePayload::Text(format!("{:?}", b)), + }; + ctx.send_message(MessageRecipient::Broadcast, message); } } diff --git a/examples/clock/Cargo.lock b/examples/clock/Cargo.lock index c69cd5f..f5e83b5 100644 --- a/examples/clock/Cargo.lock +++ b/examples/clock/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "clock" version = "0.1.0" @@ -27,14 +36,38 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "serde" +version = "1.0.200" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddc6f9cc94d67c0e21aaf7eda3a010fd3af78ebf6e096aa6e2e13c79749cce4f" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.200" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "856f046b9400cee3c8c94ed572ecdb752444c24528c035cd35882aad6f492bcb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "stateroom" version = "0.2.8" +dependencies = [ + "serde", +] [[package]] name = "stateroom-wasm" version = "0.2.9" dependencies = [ + "bincode", "stateroom", "stateroom-wasm-macro", ] diff --git a/examples/clock/src/lib.rs b/examples/clock/src/lib.rs index 8af0019..1816282 100644 --- a/examples/clock/src/lib.rs +++ b/examples/clock/src/lib.rs @@ -10,7 +10,7 @@ impl StateroomService for ClockServer { } fn timer(&mut self, ctx: &impl StateroomContext) { - ctx.send_message(MessageRecipient::Broadcast, &format!("Timer @ {}", self.0)); + ctx.send_message(MessageRecipient::Broadcast, format!("Timer @ {}", self.0)); self.0 += 1; ctx.set_timer(4000); } diff --git a/examples/counter-service/Cargo.lock b/examples/counter-service/Cargo.lock index ee787a8..f8629a8 100644 --- a/examples/counter-service/Cargo.lock +++ b/examples/counter-service/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "counter" version = "0.1.0" @@ -27,14 +36,38 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "serde" +version = "1.0.200" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddc6f9cc94d67c0e21aaf7eda3a010fd3af78ebf6e096aa6e2e13c79749cce4f" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.200" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "856f046b9400cee3c8c94ed572ecdb752444c24528c035cd35882aad6f492bcb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "stateroom" version = "0.2.8" +dependencies = [ + "serde", +] [[package]] name = "stateroom-wasm" version = "0.2.9" dependencies = [ + "bincode", "stateroom", "stateroom-wasm-macro", ] diff --git a/examples/counter-service/src/lib.rs b/examples/counter-service/src/lib.rs index ca5e932..2b1d938 100644 --- a/examples/counter-service/src/lib.rs +++ b/examples/counter-service/src/lib.rs @@ -1,5 +1,5 @@ use stateroom_wasm::{ - stateroom_wasm, ClientId, MessageRecipient, StateroomContext, StateroomService, + stateroom_wasm, ClientId, MessageRecipient, StateroomContext, StateroomService, MessagePayload }; #[stateroom_wasm] @@ -7,8 +7,13 @@ use stateroom_wasm::{ struct SharedCounterServer(i32); impl StateroomService for SharedCounterServer { - fn message(&mut self, _: ClientId, message: &str, ctx: &impl StateroomContext) { - match message { + fn message(&mut self, _: ClientId, message: MessagePayload, ctx: &impl StateroomContext) { + let message = match message { + MessagePayload::Text(s) => s, + MessagePayload::Bytes(_) => return, + }; + + match &message[..] { "increment" => self.0 += 1, "decrement" => self.0 -= 1, _ => (), @@ -16,7 +21,7 @@ impl StateroomService for SharedCounterServer { ctx.send_message( MessageRecipient::Broadcast, - &format!("new value: {}", self.0), + format!("new value: {}", self.0), ); } } diff --git a/examples/cpu-hog/Cargo.lock b/examples/cpu-hog/Cargo.lock index fbd304d..c377e4f 100644 --- a/examples/cpu-hog/Cargo.lock +++ b/examples/cpu-hog/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "cpu-hog" version = "0.1.0" @@ -28,14 +37,38 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "serde" +version = "1.0.200" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddc6f9cc94d67c0e21aaf7eda3a010fd3af78ebf6e096aa6e2e13c79749cce4f" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.200" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "856f046b9400cee3c8c94ed572ecdb752444c24528c035cd35882aad6f492bcb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "stateroom" version = "0.2.8" +dependencies = [ + "serde", +] [[package]] name = "stateroom-wasm" version = "0.2.9" dependencies = [ + "bincode", "stateroom", "stateroom-wasm-macro", ] diff --git a/examples/cpu-hog/src/lib.rs b/examples/cpu-hog/src/lib.rs index 721491d..9c38902 100644 --- a/examples/cpu-hog/src/lib.rs +++ b/examples/cpu-hog/src/lib.rs @@ -15,7 +15,7 @@ fn get_time() -> u64 { impl StateroomService for CpuHog { fn connect(&mut self, _: ClientId, ctx: &impl StateroomContext) { - ctx.send_message(MessageRecipient::Broadcast, &format!("Connected.")); + ctx.send_message(MessageRecipient::Broadcast, format!("Connected.")); let init_time = get_time(); loop { @@ -25,6 +25,6 @@ impl StateroomService for CpuHog { } } - ctx.send_message(MessageRecipient::Broadcast, &format!("Finished.")); + ctx.send_message(MessageRecipient::Broadcast, format!("Finished.")); } } diff --git a/examples/echo-server/Cargo.lock b/examples/echo-server/Cargo.lock index aa6f38d..12b6057 100644 --- a/examples/echo-server/Cargo.lock +++ b/examples/echo-server/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "echo-server" version = "0.1.0" @@ -27,14 +36,38 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "serde" +version = "1.0.200" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddc6f9cc94d67c0e21aaf7eda3a010fd3af78ebf6e096aa6e2e13c79749cce4f" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.200" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "856f046b9400cee3c8c94ed572ecdb752444c24528c035cd35882aad6f492bcb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "stateroom" version = "0.2.8" +dependencies = [ + "serde", +] [[package]] name = "stateroom-wasm" version = "0.2.9" dependencies = [ + "bincode", "stateroom", "stateroom-wasm-macro", ] diff --git a/examples/echo-server/dist/server.wasm b/examples/echo-server/dist/server.wasm index 36e5df0..ed55e18 100755 Binary files a/examples/echo-server/dist/server.wasm and b/examples/echo-server/dist/server.wasm differ diff --git a/examples/echo-server/src/lib.rs b/examples/echo-server/src/lib.rs index c82299a..cd431c2 100644 --- a/examples/echo-server/src/lib.rs +++ b/examples/echo-server/src/lib.rs @@ -1,5 +1,5 @@ use stateroom_wasm::{ - stateroom_wasm, ClientId, MessageRecipient, StateroomContext, StateroomService, + stateroom_wasm, ClientId, MessageRecipient, StateroomContext, StateroomService, MessagePayload, }; #[stateroom_wasm] @@ -8,20 +8,25 @@ struct EchoServer; impl StateroomService for EchoServer { fn connect(&mut self, client_id: ClientId, ctx: &impl StateroomContext) { - ctx.send_message(client_id, &format!("User {:?} connected.", client_id)); + ctx.send_message(client_id, format!("User {:?} connected.", client_id)); } - fn message(&mut self, client_id: ClientId, message: &str, ctx: &impl StateroomContext) { + fn message(&mut self, client_id: ClientId, message: MessagePayload, ctx: &impl StateroomContext) { + let message = match message { + MessagePayload::Text(s) => s, + MessagePayload::Bytes(b) => unimplemented!(), + }; + ctx.send_message( MessageRecipient::Broadcast, - &format!("User {:?} sent '{}'", client_id, message), + format!("User {:?} sent '{}'", client_id, message), ); } fn disconnect(&mut self, client_id: ClientId, ctx: &impl StateroomContext) { ctx.send_message( MessageRecipient::Broadcast, - &format!("User {:?} left.", client_id), + format!("User {:?} left.", client_id), ); } } diff --git a/examples/randomness/Cargo.lock b/examples/randomness/Cargo.lock index f7ad06d..45e67a2 100644 --- a/examples/randomness/Cargo.lock +++ b/examples/randomness/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + [[package]] name = "bytemuck" version = "1.15.0" @@ -35,14 +44,38 @@ dependencies = [ "wasi", ] +[[package]] +name = "serde" +version = "1.0.200" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddc6f9cc94d67c0e21aaf7eda3a010fd3af78ebf6e096aa6e2e13c79749cce4f" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.200" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "856f046b9400cee3c8c94ed572ecdb752444c24528c035cd35882aad6f492bcb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "stateroom" version = "0.2.8" +dependencies = [ + "serde", +] [[package]] name = "stateroom-wasm" version = "0.2.9" dependencies = [ + "bincode", "stateroom", "stateroom-wasm-macro", ] diff --git a/examples/randomness/src/lib.rs b/examples/randomness/src/lib.rs index e1a0bdc..14675ee 100644 --- a/examples/randomness/src/lib.rs +++ b/examples/randomness/src/lib.rs @@ -1,6 +1,6 @@ use bytemuck::cast; use stateroom_wasm::{ - stateroom_wasm, ClientId, MessageRecipient, StateroomContext, StateroomService, + stateroom_wasm, ClientId, MessageRecipient, StateroomContext, StateroomService, MessagePayload }; #[stateroom_wasm] @@ -18,21 +18,26 @@ impl StateroomService for RandomServer { ctx.send_message( client_id, - &format!("User {:?} connected. Random number: {}", client_id, num[0]), + format!("User {:?} connected. Random number: {}", client_id, num[0]), ); } - fn message(&mut self, client_id: ClientId, message: &str, ctx: &impl StateroomContext) { + fn message(&mut self, client_id: ClientId, message: MessagePayload, ctx: &impl StateroomContext) { + let message = match message { + MessagePayload::Text(s) => s, + MessagePayload::Bytes(_) => return, + }; + ctx.send_message( MessageRecipient::Broadcast, - &format!("User {:?} sent '{}'", client_id, message), + format!("User {:?} sent '{}'", client_id, message), ); } fn disconnect(&mut self, client_id: ClientId, ctx: &impl StateroomContext) { ctx.send_message( MessageRecipient::Broadcast, - &format!("User {:?} left.", client_id), + format!("User {:?} left.", client_id), ); } } diff --git a/stateroom-server/src/server.rs b/stateroom-server/src/server.rs index 503a3c7..29484e3 100644 --- a/stateroom-server/src/server.rs +++ b/stateroom-server/src/server.rs @@ -1,7 +1,8 @@ use axum::extract::ws::Message; use dashmap::DashMap; use stateroom::{ - ClientId, MessageRecipient, StateroomContext, StateroomService, StateroomServiceFactory, + ClientId, MessagePayload, MessageRecipient, StateroomContext, StateroomService, + StateroomServiceFactory, }; use std::{ sync::{atomic::AtomicU32, Arc}, @@ -46,12 +47,17 @@ impl ServerStateroomContext { } impl StateroomContext for ServerStateroomContext { - fn send_message(&self, recipient: impl Into, message: &str) { - self.try_send(recipient.into(), Message::Text(message.to_string())); - } - - fn send_binary(&self, recipient: impl Into, message: &[u8]) { - self.try_send(recipient.into(), Message::Binary(message.to_vec())); + fn send_message( + &self, + recipient: impl Into, + message: impl Into, + ) { + let message: MessagePayload = message.into(); + let message: Message = match message { + MessagePayload::Text(s) => Message::Text(s), + MessagePayload::Bytes(b) => Message::Binary(b), + }; + self.try_send(recipient.into(), message); } fn set_timer(&self, ms_delay: u32) { @@ -101,8 +107,12 @@ impl ServerState { let msg = rx.recv().await; match msg { Some(Event::Message { client, message }) => match message { - Message::Text(msg) => service.message(client, &msg, context.as_ref()), - Message::Binary(msg) => service.binary(client, &msg, context.as_ref()), + Message::Text(msg) => { + service.message(client, MessagePayload::Text(msg), context.as_ref()) + } + Message::Binary(msg) => { + service.message(client, MessagePayload::Bytes(msg), context.as_ref()) + } Message::Close(_) => {} msg => tracing::warn!("Ignoring unhandled message: {:?}", msg), }, diff --git a/stateroom-wasm-host/Cargo.toml b/stateroom-wasm-host/Cargo.toml index bea641c..0194996 100644 --- a/stateroom-wasm-host/Cargo.toml +++ b/stateroom-wasm-host/Cargo.toml @@ -11,7 +11,8 @@ description = "A Stateroom service implementation that takes a WebAssembly modul [dependencies] anyhow = "1.0.45" byteorder = "1.4.3" -stateroom = {path="../stateroom", version="0.2.8"} +stateroom = {path="../stateroom", version="0.2.8", features=["serde"]} wasmtime = "20.0.0" tracing = "0.1.28" wasi-common = "20.0.0" +bincode = "1.3.3" diff --git a/stateroom-wasm-host/src/wasm_host.rs b/stateroom-wasm-host/src/wasm_host.rs index 2ab2cc7..e10e7e0 100644 --- a/stateroom-wasm-host/src/wasm_host.rs +++ b/stateroom-wasm-host/src/wasm_host.rs @@ -1,27 +1,22 @@ use crate::WasmRuntimeError; -use anyhow::Result; +use anyhow::{Context, Result}; use byteorder::{LittleEndian, ReadBytesExt}; -use stateroom::{ClientId, MessageRecipient, StateroomContext, StateroomService}; +use stateroom::{ + ClientId, MessageFromProcess, MessagePayload, MessageToProcess, StateroomContext, + StateroomService, +}; use std::{borrow::BorrowMut, sync::Arc}; use wasi_common::{sync::WasiCtxBuilder, WasiCtx}; use wasmtime::{Caller, Engine, Extern, Instance, Linker, Memory, Module, Store, TypedFunc, Val}; const ENV: &str = "env"; const EXT_MEMORY: &str = "memory"; -const EXT_FN_CONNECT: &str = "connect"; -const EXT_FN_DISCONNECT: &str = "disconnect"; -const EXT_FN_BINARY: &str = "binary"; -const EXT_FN_INIT: &str = "init"; -const EXT_FN_MESSAGE: &str = "message"; -const EXT_FN_SEND_MESSAGE: &str = "send_message"; -const EXT_FN_SEND_BINARY: &str = "send_binary"; -const EXT_FN_SET_TIMER: &str = "set_timer"; -const EXT_FN_TIMER: &str = "timer"; -const EXT_FN_INITIALIZE: &str = "initialize"; -const EXT_FN_MALLOC: &str = "jam_malloc"; -const EXT_FN_FREE: &str = "jam_free"; -const EXT_JAMSOCKET_VERSION: &str = "JAMSOCKET_API_VERSION"; -const EXT_JAMSOCKET_PROTOCOL: &str = "JAMSOCKET_API_PROTOCOL"; +const EXT_FN_SEND: &str = "stateroom_send"; +const EXT_FN_RECV: &str = "stateroom_recv"; +const EXT_FN_MALLOC: &str = "stateroom_malloc"; +const EXT_FN_FREE: &str = "stateroom_free"; +const EXT_STATEROOM_VERSION: &str = "STATEROOM_API_VERSION"; +const EXT_STATEROOM_PROTOCOL: &str = "STATEROOM_API_PROTOCOL"; const EXPECTED_API_VERSION: i32 = 1; const EXPECTED_PROTOCOL_VERSION: i32 = 0; @@ -33,12 +28,7 @@ pub struct WasmHost { fn_malloc: TypedFunc, fn_free: TypedFunc<(u32, u32), ()>, - fn_init: TypedFunc<(), ()>, - fn_message: TypedFunc<(u32, u32, u32), ()>, - fn_binary: TypedFunc<(u32, u32, u32), ()>, - fn_connect: TypedFunc, - fn_disconnect: TypedFunc, - fn_timer: TypedFunc<(), ()>, + fn_recv: TypedFunc<(u32, u32), ()>, } impl WasmHost { @@ -52,23 +42,11 @@ impl WasmHost { Ok((pt, len)) } - fn try_message(&mut self, client: ClientId, message: &str) -> Result<()> { - let (pt, len) = self.put_data(message.as_bytes())?; - - self.fn_message - .call(&mut self.store, (client.into(), pt, len))?; - - self.fn_free.call(&mut self.store, (pt, len))?; - - Ok(()) - } - - fn try_binary(&mut self, client: ClientId, message: &[u8]) -> Result<()> { - let (pt, len) = self.put_data(message)?; - - self.fn_binary - .call(&mut self.store, (client.into(), pt, len))?; + fn try_recv(&mut self, message: MessageToProcess) -> Result<()> { + let payload = bincode::serialize(&message).unwrap(); + let (pt, len) = self.put_data(&payload)?; + self.fn_recv.call(&mut self.store, (pt, len))?; self.fn_free.call(&mut self.store, (pt, len))?; Ok(()) @@ -77,39 +55,32 @@ impl WasmHost { impl StateroomService for WasmHost { fn init(&mut self, _: &impl StateroomContext) { - if let Err(error) = self.fn_init.call(&mut self.store, ()) { - tracing::error!(?error, "Error calling `init` on wasm host"); - } + let message = MessageToProcess::Init; + self.try_recv(message).unwrap(); } - fn message(&mut self, client: ClientId, message: &str, _: &impl StateroomContext) { - if let Err(error) = self.try_message(client, message) { - tracing::error!(?error, "Error calling `message` on wasm host"); - } + fn message(&mut self, sender: ClientId, message: MessagePayload, _: &impl StateroomContext) { + let message = MessageToProcess::Message { sender, message }; + self.try_recv(message).unwrap(); } fn connect(&mut self, client: ClientId, _: &impl StateroomContext) { - if let Err(error) = self.fn_connect.call(&mut self.store, client.into()) { - tracing::error!(?error, "Error calling `connect` on wasm host"); - } + let message = MessageToProcess::Connect { + client: client.into(), + }; + self.try_recv(message).unwrap(); } fn disconnect(&mut self, client: ClientId, _: &impl StateroomContext) { - if let Err(error) = self.fn_disconnect.call(&mut self.store, client.into()) { - tracing::error!(?error, "Error calling `disconnect` on wasm host"); + let message = MessageToProcess::Disconnect { + client: client.into(), }; + self.try_recv(message).unwrap(); } fn timer(&mut self, _: &impl StateroomContext) { - if let Err(error) = self.fn_timer.call(&mut self.store, ()) { - tracing::error!(?error, "Error calling `timer` on wasm host"); - }; - } - - fn binary(&mut self, client: ClientId, message: &[u8], _: &impl StateroomContext) { - if let Err(error) = self.try_binary(client, message) { - tracing::error!(?error, "Error calling `binary` on wasm host"); - }; + let message = MessageToProcess::Timer; + self.try_recv(message).unwrap(); } } @@ -121,17 +92,6 @@ fn get_memory(caller: &mut Caller<'_, T>) -> Memory { } } -#[inline] -fn get_string<'a, T>( - caller: &'a Caller<'_, T>, - memory: &'a Memory, - start: u32, - len: u32, -) -> Result<&'a str> { - let data = get_u8_vec(caller, memory, start, len); - std::str::from_utf8(data).map_err(|e| e.into()) -} - #[inline] fn get_u8_vec<'a, T>( caller: &'a Caller<'_, T>, @@ -193,43 +153,20 @@ impl WasmHost { let context = context.clone(); linker.func_wrap( ENV, - EXT_FN_SEND_MESSAGE, - move |mut caller: Caller<'_, WasiCtx>, client: i32, start: u32, len: u32| { - let memory = get_memory(&mut caller); - let message = get_string(&caller, &memory, start, len)?; - - context.send_message(MessageRecipient::decode_i32(client), message); - - Ok(()) - }, - )?; - } - - { - #[allow(clippy::redundant_clone)] - let context = context.clone(); - linker.func_wrap( - ENV, - EXT_FN_SEND_BINARY, - move |mut caller: Caller<'_, WasiCtx>, client: i32, start: u32, len: u32| { + EXT_FN_SEND, + move |mut caller: Caller<'_, WasiCtx>, start: u32, len: u32| { let memory = get_memory(&mut caller); let message = get_u8_vec(&caller, &memory, start, len); + let message: MessageFromProcess = bincode::deserialize(&message).unwrap(); - context.send_binary(MessageRecipient::decode_i32(client), message); - - Ok(()) - }, - )?; - } - - { - #[allow(clippy::redundant_clone)] - let context = context.clone(); - linker.func_wrap( - ENV, - EXT_FN_SET_TIMER, - move |_: Caller<'_, WasiCtx>, duration_ms: u32| { - context.set_timer(duration_ms); + match message { + MessageFromProcess::Message { recipient, message } => { + context.send_message(recipient, message); + } + MessageFromProcess::SetTimer { ms_delay } => { + context.set_timer(ms_delay); + } + }; Ok(()) }, @@ -238,13 +175,12 @@ impl WasmHost { let instance = linker.instantiate(&mut store, module)?; - let initialize = - instance.get_typed_func::<(u32, u32), ()>(&mut store, EXT_FN_INITIALIZE)?; - let fn_malloc = instance.get_typed_func::(&mut store, EXT_FN_MALLOC)?; let fn_free = instance.get_typed_func::<(u32, u32), ()>(&mut store, EXT_FN_FREE)?; + let fn_recv = instance.get_typed_func::<(u32, u32), ()>(&mut store, EXT_FN_RECV)?; + let mut memory = instance .get_memory(&mut store, EXT_MEMORY) .ok_or(WasmRuntimeError::CouldNotImportMemory)?; @@ -256,48 +192,30 @@ impl WasmHost { let pt = fn_malloc.call(&mut store, len)?; memory.write(&mut store, pt as usize, room_id)?; - initialize.call(&mut store, (pt, len))?; fn_free.call(&mut store, (pt, len))?; } - if get_global(&mut store, &mut memory, &instance, EXT_JAMSOCKET_VERSION)? + if get_global(&mut store, &mut memory, &instance, EXT_STATEROOM_VERSION) + .context("Stateroom version")? != EXPECTED_API_VERSION { return Err(WasmRuntimeError::InvalidApiVersion.into()); } - if get_global(&mut store, &mut memory, &instance, EXT_JAMSOCKET_PROTOCOL)? + if get_global(&mut store, &mut memory, &instance, EXT_STATEROOM_PROTOCOL) + .context("Stateroom protocol")? != EXPECTED_PROTOCOL_VERSION { return Err(WasmRuntimeError::InvalidProtocolVersion.into()); } - let fn_connect = instance.get_typed_func::(&mut store, EXT_FN_CONNECT)?; - - let fn_disconnect = instance.get_typed_func::(&mut store, EXT_FN_DISCONNECT)?; - - let fn_timer = instance.get_typed_func::<(), ()>(&mut store, EXT_FN_TIMER)?; - - let fn_init = instance.get_typed_func::<(), ()>(&mut store, EXT_FN_INIT)?; - - let fn_message = - instance.get_typed_func::<(u32, u32, u32), ()>(&mut store, EXT_FN_MESSAGE)?; - - let fn_binary = - instance.get_typed_func::<(u32, u32, u32), ()>(&mut store, EXT_FN_BINARY)?; - Ok(WasmHost { store, memory, fn_malloc, fn_free, - fn_init, - fn_message, - fn_binary, - fn_connect, - fn_disconnect, - fn_timer, + fn_recv, }) } } diff --git a/stateroom-wasm/Cargo.toml b/stateroom-wasm/Cargo.toml index 6300218..732ac25 100644 --- a/stateroom-wasm/Cargo.toml +++ b/stateroom-wasm/Cargo.toml @@ -10,4 +10,5 @@ description = "A macro for building a Stateroom service as a WebAssembly module. [dependencies] stateroom-wasm-macro = {path="./stateroom-wasm-macro", version="0.2.8"} -stateroom = {path="../stateroom", version="0.2.8"} +stateroom = {path="../stateroom", version="0.2.8", features=["serde"]} +bincode = "1.3.3" diff --git a/stateroom-wasm/src/lib.rs b/stateroom-wasm/src/lib.rs index fb1a55f..9de2b78 100644 --- a/stateroom-wasm/src/lib.rs +++ b/stateroom-wasm/src/lib.rs @@ -1,2 +1,73 @@ +use stateroom::MessageFromProcess; pub use stateroom::{ClientId, MessageRecipient, StateroomContext, StateroomService}; +pub use stateroom::{MessagePayload, MessageToProcess}; pub use stateroom_wasm_macro::stateroom_wasm; + +type Callback = unsafe extern "C" fn(*const u8, u32); + +pub struct WrappedStateroomService { + state: S, + context: WasmStateroomContext, +} + +impl WrappedStateroomService { + pub fn new(state: S, callback: Callback) -> Self { + Self { + state, + context: WasmStateroomContext { callback }, + } + } + + pub fn recv(&mut self, message_ptr: *const u8, message_len: u32) { + let message = unsafe { std::slice::from_raw_parts(message_ptr, message_len as usize) }; + let message: MessageToProcess = bincode::deserialize(message).unwrap(); + + match message { + MessageToProcess::Init => { + self.state.init(&self.context); + } + MessageToProcess::Connect { client } => { + self.state.connect(client.into(), &self.context); + } + MessageToProcess::Disconnect { client } => { + self.state.disconnect(client.into(), &self.context); + } + MessageToProcess::Message { sender, message } => { + self.state.message(sender, message, &self.context); + } + MessageToProcess::Timer => { + self.state.timer(&self.context); + } + } + } +} + +struct WasmStateroomContext { + callback: Callback, +} + +impl WasmStateroomContext { + pub fn send(&self, message: &MessageFromProcess) { + let message = bincode::serialize(message).unwrap(); + unsafe { + (self.callback)(message.as_ptr(), message.len() as u32); + } + } +} + +impl StateroomContext for WasmStateroomContext { + fn send_message( + &self, + recipient: impl Into, + message: impl Into, + ) { + let message: MessagePayload = message.into(); + let recipient: MessageRecipient = recipient.into(); + + self.send(&MessageFromProcess::Message { recipient, message }); + } + + fn set_timer(&self, ms_delay: u32) { + self.send(&MessageFromProcess::SetTimer { ms_delay }); + } +} diff --git a/stateroom-wasm/stateroom-wasm-macro/src/lib.rs b/stateroom-wasm/stateroom-wasm-macro/src/lib.rs index 7ac39fc..49e2f15 100644 --- a/stateroom-wasm/stateroom-wasm-macro/src/lib.rs +++ b/stateroom-wasm/stateroom-wasm-macro/src/lib.rs @@ -31,127 +31,39 @@ fn stateroom_wasm_impl(item: &proc_macro2::TokenStream) -> proc_macro2::TokenStr use super::#name; - // Instance-global stateroom service. - static mut SERVER_STATE: Option<#name> = None; - - #[no_mangle] - pub static JAMSOCKET_API_VERSION: i32 = 1; - - #[no_mangle] - pub static JAMSOCKET_API_PROTOCOL: i32 = 0; - - struct GlobalStateroomContext; - - impl stateroom_wasm::StateroomContext for GlobalStateroomContext { - fn set_timer(&self, ms_delay: u32) { - unsafe { - ffi::set_timer(ms_delay); - } - } - - fn send_message(&self, recipient: impl Into, message: &str) { - unsafe { - ffi::send_message( - recipient.into().encode_i32(), - &message.as_bytes()[0] as *const u8 as u32, - message.len() as u32, - ); - } - } - - fn send_binary(&self, recipient: impl Into, message: &[u8]) { - unsafe { - ffi::send_binary( - recipient.into().encode_i32(), - &message[0] as *const u8 as u32, - message.len() as u32, - ); - } - } - } - // Functions implemented by the host. mod ffi { extern "C" { - pub fn send_message(client: i32, message: u32, message_len: u32); - - pub fn send_binary(client: i32, message: u32, message_len: u32); - - pub fn set_timer(ms_delay: u32); + pub fn stateroom_send(message_ptr: *const u8, message_len: u32); } } - // Functions provided to the host. - #[no_mangle] - extern "C" fn initialize(room_id_ptr: *const u8, room_id_len: usize) { - let room_id = unsafe { - String::from_utf8(std::slice::from_raw_parts(room_id_ptr, room_id_len).to_vec()).map_err(|e| format!("Error parsing UTF-8 from host {:?}", e)).unwrap() - }; - let mut c = #name::default(); - - unsafe { - SERVER_STATE.replace(c); - } - } - - #[no_mangle] - extern "C" fn connect(client_id: stateroom_wasm::ClientId) { - match unsafe { SERVER_STATE.as_mut() } { - Some(st) => stateroom_wasm::StateroomService::connect(st, client_id.into(), &GlobalStateroomContext), - None => () - } - } - - #[no_mangle] - extern "C" fn disconnect(client_id: stateroom_wasm::ClientId) { - match unsafe { SERVER_STATE.as_mut() } { - Some(st) => stateroom_wasm::StateroomService::disconnect(st, client_id.into(), &GlobalStateroomContext), - None => () - } - } - - #[no_mangle] - extern "C" fn timer() { - match unsafe { SERVER_STATE.as_mut() } { - Some(st) => stateroom_wasm::StateroomService::timer(st, &GlobalStateroomContext), - None => () - } - } + // Instance-global stateroom service. + static mut SERVER_STATE: Option> = None; #[no_mangle] - extern "C" fn init() { - match unsafe { SERVER_STATE.as_mut() } { - Some(st) => stateroom_wasm::StateroomService::init(st, &GlobalStateroomContext), - None => () - } - } + pub static STATEROOM_API_VERSION: i32 = 1; #[no_mangle] - extern "C" fn message(client_id: stateroom_wasm::ClientId, ptr: *const u8, len: usize) { - unsafe { - let string = String::from_utf8(std::slice::from_raw_parts(ptr, len).to_vec()).expect("Error parsing UTF-8 from host {:?}"); - - match SERVER_STATE.as_mut() { - Some(st) => stateroom_wasm::StateroomService::message(st, client_id.into(), &string, &GlobalStateroomContext), - None => () - } - } - } + pub static STATEROOM_API_PROTOCOL: i32 = 0; #[no_mangle] - extern "C" fn binary(client_id: stateroom_wasm::ClientId, ptr: *const u8, len: usize) { - unsafe { - let data = std::slice::from_raw_parts(ptr, len); - + extern "C" fn stateroom_recv(message_ptr: *const u8, message_len: u32) { + let state = unsafe { match SERVER_STATE.as_mut() { - Some(st) => stateroom_wasm::StateroomService::binary(st, client_id.into(), data, &GlobalStateroomContext), - None => () + Some(s) => s, + None => { + let s = stateroom_wasm::WrappedStateroomService::new(#name::default(), ffi::stateroom_send); + SERVER_STATE.replace(s); + SERVER_STATE.as_mut().unwrap() + } } - } + }; + state.recv(message_ptr, message_len); } #[no_mangle] - pub unsafe extern "C" fn jam_malloc(size: u32) -> *mut u8 { + pub unsafe extern "C" fn stateroom_malloc(size: u32) -> *mut u8 { if size == 0 { return core::ptr::null_mut(); } @@ -160,7 +72,7 @@ fn stateroom_wasm_impl(item: &proc_macro2::TokenStream) -> proc_macro2::TokenStr } #[no_mangle] - pub unsafe extern "C" fn jam_free(ptr: *mut u8, size: u32) { + pub unsafe extern "C" fn stateroom_free(ptr: *mut u8, size: u32) { if size == 0 { return; } diff --git a/stateroom/src/lib.rs b/stateroom/src/lib.rs index 6c17d79..36daa5f 100644 --- a/stateroom/src/lib.rs +++ b/stateroom/src/lib.rs @@ -75,12 +75,11 @@ pub trait StateroomContext: Send + Sync + 'static { /// Recipient can be a `u32` representing an individual user to send a message to, or /// `MessageRecipient::Broadcast` to broadcast a message to all connected users. /// The message is a string which is sent verbatim to the user(s) indicated. - fn send_message(&self, recipient: impl Into, message: &str); - - /// Sends a binary message to a currently connected user, or broadcast a message to all users. - /// - /// See [StateroomContext::send_message] for details on the semantics of `recipient`. - fn send_binary(&self, recipient: impl Into, message: &[u8]); + fn send_message( + &self, + recipient: impl Into, + message: impl Into, + ); /// Sets a timer to wake up the service in the given number of milliseconds by invoking `timer()`. /// @@ -108,10 +107,13 @@ pub trait StateroomService: Send + Sync + 'static { fn disconnect(&mut self, client: ClientId, context: &impl StateroomContext) {} /// Called each time a client sends a text message to the service. - fn message(&mut self, client: ClientId, message: &str, context: &impl StateroomContext) {} - - /// Called each time a client sends a binary message to the service. - fn binary(&mut self, client: ClientId, message: &[u8], context: &impl StateroomContext) {} + fn message( + &mut self, + client: ClientId, + message: MessagePayload, + context: &impl StateroomContext, + ) { + } /// Called when [StateroomContext::set_timer] has been called on this service's context, /// after the provided duration. diff --git a/stateroom/src/messages.rs b/stateroom/src/messages.rs index 4ca9263..200a373 100644 --- a/stateroom/src/messages.rs +++ b/stateroom/src/messages.rs @@ -4,14 +4,22 @@ use serde::{Deserialize, Serialize}; use crate::{ClientId, MessageRecipient}; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] pub enum MessagePayload { Bytes(Vec), Text(String), } +impl Into for String { + fn into(self) -> MessagePayload { + MessagePayload::Text(self.to_string()) + } +} + #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(tag = "type"))] +#[derive(Debug)] pub enum MessageToProcess { + Init, Connect { client: ClientId, }, @@ -19,17 +27,19 @@ pub enum MessageToProcess { client: ClientId, }, Message { - client: ClientId, + sender: ClientId, message: MessagePayload, }, Timer, } #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[cfg_attr(feature = "serde", serde(tag = "type"))] pub enum MessageFromProcess { Message { recipient: MessageRecipient, message: MessagePayload, }, + SetTimer { + ms_delay: u32, + }, }