From 994ad6954216817443054919dd2901e88fdd5775 Mon Sep 17 00:00:00 2001 From: Mikkel Wienberg Madsen Date: Tue, 27 Aug 2024 13:17:06 +0200 Subject: [PATCH] Implement vm for python --- pycare/Cargo.lock | 32 ++++++----- pycare/Cargo.toml | 2 +- pycare/caring.pyi | 56 +++++++++++++++++- pycare/examples/vm1.py | 13 +++++ pycare/examples/vm2.py | 13 +++++ pycare/src/expr.rs | 106 +++++++++++++++++++++++++++++++++++ pycare/src/lib.rs | 39 ++++++++----- pycare/src/vm.rs | 104 ++++++++++++++++++++++++++++++++++ src/net/network.rs | 22 +++++--- src/vm/mod.rs | 11 ++++ src/vm/parsing.rs | 52 +++++++++++++++-- wecare/benches/spdz-25519.rs | 4 +- wecare/src/vm.rs | 60 +++++++++++--------- 13 files changed, 440 insertions(+), 74 deletions(-) create mode 100644 pycare/examples/vm1.py create mode 100644 pycare/examples/vm2.py create mode 100644 pycare/src/expr.rs create mode 100644 pycare/src/vm.rs diff --git a/pycare/Cargo.lock b/pycare/Cargo.lock index daf4bc3..869d4d8 100644 --- a/pycare/Cargo.lock +++ b/pycare/Cargo.lock @@ -408,9 +408,9 @@ checksum = "c007b1ae3abe1cb6f85a16305acd418b7ca6343b953633fee2b76d8f108b830f" [[package]] name = "fixed" -version = "2.0.0-alpha.27.0" +version = "2.0.0-alpha.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1bf398c70463a217e213bc751669e4c8509c5676df2444d7a5177722f8ddeaa" +checksum = "8276713fe97d959ae66a91bdac60a9a1b9e39d25513ccca4555fe1cf2571567f" dependencies = [ "az", "bytemuck", @@ -614,9 +614,9 @@ dependencies = [ [[package]] name = "heck" -version = "0.4.1" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "hermit-abi" @@ -895,15 +895,15 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.21.2" +version = "0.22.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8" +checksum = "831e8e819a138c36e212f3af3fd9eeffed6bf1510a805af35b0edee5ffa59433" dependencies = [ "cfg-if", "indoc", "libc", "memoffset", - "parking_lot", + "once_cell", "portable-atomic", "pyo3-build-config", "pyo3-ffi", @@ -913,9 +913,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.21.2" +version = "0.22.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50" +checksum = "1e8730e591b14492a8945cdff32f089250b05f5accecf74aeddf9e8272ce1fa8" dependencies = [ "once_cell", "python3-dll-a", @@ -924,9 +924,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.21.2" +version = "0.22.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403" +checksum = "5e97e919d2df92eb88ca80a037969f44e5e70356559654962cbb3316d00300c6" dependencies = [ "libc", "pyo3-build-config", @@ -934,9 +934,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.21.2" +version = "0.22.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c" +checksum = "eb57983022ad41f9e683a599f2fd13c3664d7063a3ac5714cae4b7bee7d3f206" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -946,9 +946,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.21.2" +version = "0.22.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c" +checksum = "ec480c0c51ddec81019531705acac51bcdbeae563557c982aa8263bb96880372" dependencies = [ "heck", "proc-macro2", @@ -1424,8 +1424,10 @@ name = "wecare" version = "0.1.0" dependencies = [ "caring", + "castaway", "curve25519-dalek", "enum_dispatch", + "ff", "fixed", "rand", "tokio", diff --git a/pycare/Cargo.toml b/pycare/Cargo.toml index 2da0a53..c93c2a4 100644 --- a/pycare/Cargo.toml +++ b/pycare/Cargo.toml @@ -9,5 +9,5 @@ name = "caring" crate-type = ["cdylib"] [dependencies] -pyo3 = { version = "0.21", features = ["abi3-py37", "generate-import-lib", "extension-module"]} +pyo3 = { version = "0.22", features = ["abi3-py37", "generate-import-lib", "extension-module"]} wecare = { path = "../wecare" } diff --git a/pycare/caring.pyi b/pycare/caring.pyi index 124df1c..6edbfeb 100644 --- a/pycare/caring.pyi +++ b/pycare/caring.pyi @@ -1,4 +1,58 @@ +# TODO: Add documentation + +class Id: + ... + +class Opened: + ... + +class Computed: + def as_float(self) -> list[float]: ... + def as_integer(self) -> list[int]: ... + + +class Expr: + @staticmethod + def share(num: float | int | list[float] | list[int]) -> Expr: ... + + @staticmethod + def recv(id: Id) -> Expr: ... + + @staticmethod + def symmetric_share(num: float | int | list[float] | list[int], id: Id, size: int) -> list[Expr]: ... + + def open(self) -> Opened: ... + + def __add__(self, other: Expr) -> Expr: ... + def __sub__(self, other: Expr) -> Expr: ... + def __mul__(self, other: Expr) -> Expr: ... + def __iadd__(self, other: Expr) -> None: ... + def __isub__(self, other: Expr) -> None: ... + def __imul__(self, other: Expr) -> None: ... + class Engine: + def __init__( + self, + scheme: str, + address: str, + peers: list[str], + multithreaded: bool = False, + threshold: int | None = None, + preprocessed: str | None = None, + ) -> None: ... + + def execute(self, script: Opened) -> Computed: ... + + def id(self) -> Id: ... + + def peers(self) -> list[Id]: ... + + +# +# Old stuff +# + +class OldEngine: """ Performs a summation with the connected parties. Returns the sum of all the numbers. @@ -18,7 +72,7 @@ class Engine: """ Takedown the MPC Engine, releasing the resources and dropping connections. """ - def takedown(self): ... + def takedown(self) -> None: ... """ diff --git a/pycare/examples/vm1.py b/pycare/examples/vm1.py new file mode 100644 index 0000000..e9d9518 --- /dev/null +++ b/pycare/examples/vm1.py @@ -0,0 +1,13 @@ +from caring import Expr, Engine + +engine = Engine(scheme="shamir-25519", address="localhost:1234", peers=["localhost:1235"], threshold=1) + +[a, b] = Expr.symmetric_share(23, id=engine.id(), size=2) + +c = a + b; + +script = c.open() + +res = engine.execute(script).as_float() + +print(res) diff --git a/pycare/examples/vm2.py b/pycare/examples/vm2.py new file mode 100644 index 0000000..0d48653 --- /dev/null +++ b/pycare/examples/vm2.py @@ -0,0 +1,13 @@ +from caring import Expr, Engine + +engine = Engine(scheme="shamir-25519", address="localhost:1235", peers=["localhost:1234"], threshold=1) + +[a, b] = Expr.symmetric_share(7, id=engine.id(), size=2) + +c = a + b; + +script = c.open() + +res = engine.execute(script).as_float() + +print(res) diff --git a/pycare/src/expr.rs b/pycare/src/expr.rs new file mode 100644 index 0000000..ef1dd6a --- /dev/null +++ b/pycare/src/expr.rs @@ -0,0 +1,106 @@ +use pyo3::{exceptions::PyTypeError, prelude::*}; +use wecare::vm; + +#[pyclass] +pub struct Expr(vm::Expr); + +#[pyclass(frozen)] +pub struct Opened(pub(crate) vm::Opened); + +#[pyclass(frozen)] +#[derive(Debug, Clone, Copy)] +pub struct Id(pub(crate) vm::Id); + +#[pymethods] +impl Expr { + /// Construct a new share expression + #[staticmethod] + fn share(num: &Bound<'_, PyAny>) -> PyResult { + let res = if let Ok(num) = num.extract::() { + let num = vm::Number::Float(num); + vm::Expr::share(num) + } else if let Ok(num) = num.extract::() { + // TODO: Consider signedness + let num = vm::Number::Integer(num); + vm::Expr::share(num) + } else if let Ok(num) = num.extract::>() { + let num: Vec<_> = num.into_iter().map(vm::Number::Float).collect(); + vm::Expr::share_vec(num) + } else if let Ok(num) = num.extract::>() { + // TODO: Consider signedness + let num: Vec<_> = num.into_iter().map(vm::Number::Integer).collect(); + vm::Expr::share_vec(num) + } else { + return Err(PyTypeError::new_err("num is not a number")); + }; + Ok(Self(res)) + } + + #[staticmethod] + fn symmetric_share(num: &Bound<'_, PyAny>, id: Id, size: usize) -> PyResult> { + let res = if let Ok(num) = num.extract::() { + let num = vm::Number::Float(num); + vm::Expr::symmetric_share(num) + } else if let Ok(num) = num.extract::() { + // TODO: Consider signedness + let num = vm::Number::Integer(num); + vm::Expr::symmetric_share(num) + } else if let Ok(num) = num.extract::>() { + let num: Vec<_> = num.into_iter().map(vm::Number::Float).collect(); + vm::Expr::symmetric_share_vec(num) + } else if let Ok(num) = num.extract::>() { + // TODO: Consider signedness + let num: Vec<_> = num.into_iter().map(vm::Number::Integer).collect(); + vm::Expr::symmetric_share_vec(num) + } else { + return Err(PyTypeError::new_err("num is not a number")); + }; + + let res = res.concrete(id.0 .0, size); + let res = res.into_iter().map(Expr).collect(); + Ok(res) + } + + /// recv from a given party + #[staticmethod] + fn recv(id: &Id) -> Self { + Self(vm::Expr::receive_input(id.0)) + } + + fn open(&self) -> Opened { + Opened(self.0.clone().open()) + } + + fn __iadd__(&mut self, other: &Self) { + let rhs: vm::Expr = other.0.clone(); + self.0 += rhs; + } + + fn __add__(&self, other: &Self) -> Self { + let lhs: vm::Expr = self.0.clone(); + let rhs: vm::Expr = other.0.clone(); + Self(lhs + rhs) + } + + fn __sub__(&self, other: &Self) -> Self { + let lhs: vm::Expr = self.0.clone(); + let rhs: vm::Expr = other.0.clone(); + Self(lhs - rhs) + } + + fn __isub__(&mut self, other: &Self) { + let rhs: vm::Expr = other.0.clone(); + self.0 -= rhs; + } + + fn __mul__(&self, other: &Self) -> Self { + let lhs: vm::Expr = self.0.clone(); + let rhs: vm::Expr = other.0.clone(); + Self(lhs * rhs) + } + + fn __imul__(&mut self, other: &Self) { + let rhs: vm::Expr = other.0.clone(); + self.0 *= rhs; + } +} diff --git a/pycare/src/lib.rs b/pycare/src/lib.rs index 0356ce2..33623da 100644 --- a/pycare/src/lib.rs +++ b/pycare/src/lib.rs @@ -1,15 +1,18 @@ use pyo3::{exceptions::PyIOError, prelude::*, types::PyTuple}; +pub mod expr; +pub mod vm; + use std::fs::File; use wecare::*; #[pyclass] -struct Engine(Option); +struct OldEngine(Option); /// Setup a MPC addition engine connected to the given sockets using SPDZ. #[pyfunction] #[pyo3(signature = (path_to_pre, my_addr, *others))] -fn spdz(path_to_pre: &str, my_addr: &str, others: &Bound<'_, PyTuple>) -> PyResult { +fn spdz(path_to_pre: &str, my_addr: &str, others: &Bound<'_, PyTuple>) -> PyResult { let others: Vec<_> = others .iter() .map(|x| x.extract::().unwrap().clone()) @@ -18,17 +21,17 @@ fn spdz(path_to_pre: &str, my_addr: &str, others: &Bound<'_, PyTuple>) -> PyResu match wecare::Engine::setup(my_addr) .add_participants(&others) .file_to_preprocessed(&mut file) - .build_spdz() { - Ok(e) => Ok(Engine(Some(e))), + .build_spdz() + { + Ok(e) => Ok(OldEngine(Some(e))), Err(e) => Err(PyIOError::new_err(e.0)), } } - /// Setup a MPC addition engine connected to the given sockets using shamir secret sharing. #[pyfunction] #[pyo3(signature = (threshold, my_addr, *others))] -fn shamir(threshold: u32, my_addr: &str, others: &Bound<'_, PyTuple>) -> PyResult { +fn shamir(threshold: u32, my_addr: &str, others: &Bound<'_, PyTuple>) -> PyResult { let others: Vec<_> = others .iter() .map(|x| x.extract::().unwrap().clone()) @@ -36,17 +39,17 @@ fn shamir(threshold: u32, my_addr: &str, others: &Bound<'_, PyTuple>) -> PyResul match wecare::Engine::setup(my_addr) .add_participants(&others) .threshold(threshold as u64) - .build_shamir() { - Ok(e) => Ok(Engine(Some(e))), + .build_shamir() + { + Ok(e) => Ok(OldEngine(Some(e))), Err(e) => Err(PyIOError::new_err(e.0)), } } - /// Setup a MPC addition engine connected to the given sockets using shamir secret sharing. #[pyfunction] #[pyo3(signature = (threshold, my_addr, *others))] -fn feldman(threshold: u32, my_addr: &str, others: &Bound<'_, PyTuple>) -> PyResult { +fn feldman(threshold: u32, my_addr: &str, others: &Bound<'_, PyTuple>) -> PyResult { let others: Vec<_> = others .iter() .map(|x| x.extract::().unwrap().clone()) @@ -54,8 +57,9 @@ fn feldman(threshold: u32, my_addr: &str, others: &Bound<'_, PyTuple>) -> PyResu match wecare::Engine::setup(my_addr) .add_participants(&others) .threshold(threshold as u64) - .build_feldman() { - Ok(e) => Ok(Engine(Some(e))), + .build_feldman() + { + Ok(e) => Ok(OldEngine(Some(e))), Err(e) => Err(PyIOError::new_err(e.0)), } } @@ -69,11 +73,11 @@ fn preproc(number_of_shares: usize, paths_to_pre: &Bound<'_, PyTuple>) { .map(|x| x.extract::().unwrap()) .map(|p| File::create(p).unwrap()) .collect(); - do_preproc(&mut files, vec![number_of_shares, number_of_shares], false); + do_preproc(&mut files, &[number_of_shares, number_of_shares], false); } #[pymethods] -impl Engine { +impl OldEngine { /// Run a sum procedure in which each party supplies a double floating point fn sum(&mut self, a: f64) -> f64 { self.0.as_mut().unwrap().mpc_sum(&[a]).unwrap()[0] @@ -98,6 +102,11 @@ fn caring(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(shamir, m)?)?; m.add_function(wrap_pyfunction!(feldman, m)?)?; m.add_function(wrap_pyfunction!(preproc, m)?)?; - m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/pycare/src/vm.rs b/pycare/src/vm.rs new file mode 100644 index 0000000..4b776a8 --- /dev/null +++ b/pycare/src/vm.rs @@ -0,0 +1,104 @@ +use std::sync::Mutex; + +use crate::expr::{Id, Opened}; +use pyo3::{exceptions::PyValueError, prelude::*, types::PyList}; +use wecare::vm; + +#[pyclass(frozen)] +pub struct Engine(Mutex); + +#[pyclass(frozen)] +pub struct Computed(vm::Value); + +#[pymethods] +impl Computed { + /// Parse the computed result as a float + fn as_float(&self) -> Vec { + self.0.clone().map(|s| s.to_f64()).to_vec() + } + + /// Parse the computed result as an integer + fn as_integer(&self) -> Vec { + self.0.clone().map(|s| s.to_u64()).to_vec() + } +} + +#[pymethods] +impl Engine { + /// Construct a new engine connected to the other parties + /// + /// * `scheme`: one of {'spdz-25519', 'shamir-25519', 'feldman-25519', 'spdz-32', 'shamir-32'} + /// * `address`: the address to bind to + /// * `peers`: the adresses of the other peers to connect to + /// * `multithreaded`: use a multithreaded runtime + /// * `threshold`: (optional) threshold if using a threshold scheme + /// * `preprocessed`: (optional) path to preprocessed material + #[new] + #[pyo3(signature = (scheme, address, peers, multithreaded=false, threshold=None, preprocessed=None))] + fn new( + scheme: &str, + address: &str, + peers: &Bound<'_, PyList>, + multithreaded: bool, + threshold: Option, + preprocessed: Option<&str>, + ) -> PyResult { + let peers = peers.iter().map(|x| x.extract::().unwrap().clone()); + + let (scheme, field) = match scheme { + "spdz-25519" => (vm::SchemeKind::Spdz, vm::FieldKind::Curve25519), + "shamir-25519" => (vm::SchemeKind::Shamir, vm::FieldKind::Curve25519), + "feldman-25519" => (vm::SchemeKind::Shamir, vm::FieldKind::Curve25519), + "spdz-32" => (vm::SchemeKind::Spdz, vm::FieldKind::Element32), + "shamir-32" => (vm::SchemeKind::Shamir, vm::FieldKind::Element32), + _ => return Err(PyValueError::new_err("Unknown scheme")), + }; + + let builder = vm::Engine::builder() + .address(address) + .participants_from(peers) + .scheme(scheme) + .field(field); + + let builder = builder.threshold(threshold.unwrap_or_default()); + let builder = match preprocessed { + Some(path) => { + let file = std::fs::File::open(path)?; + builder.preprocessed(file) + } + None => builder, + }; + + let builder = if multithreaded { + builder.multi_threaded_runtime() + } else { + builder.single_threaded_runtime() + }; + + let builder = builder.connect_blocking().unwrap(); + let engine = builder.build(); + Ok(Self(Mutex::new(engine))) + } + + /// Execute a script + /// + /// * `script`: list of expressions to evaluate + fn execute(&self, script: &Opened) -> Computed { + let res = { + let mut engine = self.0.lock().expect("Lock poisoned"); + let script: vm::Opened = script.0.clone(); + engine.execute(script) + }; + Computed(res) + } + + /// Your own Id + fn id(&self) -> Id { + Id(self.0.lock().unwrap().id()) + } + + /// Your own Id + fn peers(&self) -> Vec { + self.0.lock().unwrap().peers().into_iter().map(Id).collect() + } +} diff --git a/src/net/network.rs b/src/net/network.rs index 60ad467..d974ebb 100644 --- a/src/net/network.rs +++ b/src/net/network.rs @@ -83,6 +83,21 @@ impl Network { Id((self.index + n + 1) % n) } + pub fn peers(&self) -> Vec { + let n = self.connections.len(); + (0..=n) + .map(|i| Id(i)) + .filter(|id| *id != self.id()) + .collect_vec() + } + + /// Returns a range for representing the participants. + pub fn participants(&self) -> Range { + let n = self.connections.len() as u32; + let n = n + 1; // We need to count ourselves. + 0..n + } + /// Broadcast a message to all other parties. /// /// Asymmetric, non-waiting @@ -312,13 +327,6 @@ impl Network { Ok(()) } - /// Returns a range for representing the participants. - pub fn participants(&self) -> Range { - let n = self.connections.len() as u32; - let n = n + 1; // We need to count ourselves. - 0..n - } - async fn drop_party(_id: usize) -> Result<(), ()> { todo!("Initiate a drop vote"); } diff --git a/src/vm/mod.rs b/src/vm/mod.rs index 6e61bfd..d7181d4 100644 --- a/src/vm/mod.rs +++ b/src/vm/mod.rs @@ -38,6 +38,13 @@ impl Value { Value::Vector(a) => Value::Vector(a.into_iter().map(func).collect()), } } + + pub fn to_vec(self) -> Vec { + match self { + Value::Single(s) => vec![s], + Value::Vector(v) => v.into(), + } + } } impl From for Value { @@ -197,6 +204,10 @@ where self.network.id() } + pub fn peers(&self) -> Vec { + self.network.peers() + } + pub fn add_fuel(&mut self, fuel: &mut Vec>) { self.fueltank.append(fuel); } diff --git a/src/vm/parsing.rs b/src/vm/parsing.rs index 65489f0..709a2a0 100644 --- a/src/vm/parsing.rs +++ b/src/vm/parsing.rs @@ -3,7 +3,7 @@ use itertools::Itertools; use std::{ array, iter::Sum, - ops::{Add, Mul, Sub}, + ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign}, }; use crate::{ @@ -13,20 +13,20 @@ use crate::{ }; /// An expression stack -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Exp { constants: Vec>, instructions: Vec, } // A dynamicly sized list of expressions. -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct ExpList { constant: Value, } // An opened expression (last step) -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Opened(Exp); impl Exp { @@ -44,7 +44,7 @@ impl Exp { } } - fn add_constant(&mut self, value: impl Into>) -> Const { + fn append_constant(&mut self, value: impl Into>) -> Const { self.constants.push(value.into()); Const(self.constants.len() as u16 - 1) } @@ -112,6 +112,25 @@ impl Exp { }) } + /// Share and receive based on your given Id + /// + /// * `input`: Your input to secret-share + /// * `me`: Your Id + pub fn share_and_receive_n(input: impl Into, me: Id, n: usize) -> Vec { + let mut input: Option = Some(input.into()); + (0..n) + .map(|i| { + let id = Id(i); + if id == me { + let f = input.take().expect("We only do this once."); + Self::share(f) + } else { + Self::receive_input(id) + } + }) + .collect() + } + /// Open the secret value pub fn open(mut self) -> Opened { self.instructions.push(Instruction::Recombine); @@ -204,6 +223,27 @@ impl ExpList { } } +impl AddAssign for Exp { + fn add_assign(&mut self, rhs: Self) { + self.append(rhs); + self.instructions.push(Instruction::Add); + } +} + +impl SubAssign for Exp { + fn sub_assign(&mut self, rhs: Self) { + self.append(rhs); + self.instructions.push(Instruction::Sub); + } +} + +impl MulAssign for Exp { + fn mul_assign(&mut self, rhs: Self) { + self.append(rhs); + self.instructions.push(Instruction::Mul); + } +} + impl Add for Exp { type Output = Self; @@ -228,7 +268,7 @@ impl Mul for Exp { type Output = Self; fn mul(mut self, rhs: F) -> Self::Output { - let addr = self.add_constant(rhs); + let addr = self.append_constant(rhs); self.instructions.push(Instruction::MulCon(addr)); self } diff --git a/wecare/benches/spdz-25519.rs b/wecare/benches/spdz-25519.rs index dc534e4..37353c1 100644 --- a/wecare/benches/spdz-25519.rs +++ b/wecare/benches/spdz-25519.rs @@ -32,7 +32,7 @@ fn build_spdz_engines() -> (blocking::Engine, blocking::Engine) { Engine::builder() .address("127.0.0.1:1234") .participant("127.0.0.1:1235") - .preprocessed(&mut ctx1) + .preprocessed(ctx1) .scheme(SchemeKind::Spdz) .field(FieldKind::Curve25519) .single_threaded_runtime() @@ -45,7 +45,7 @@ fn build_spdz_engines() -> (blocking::Engine, blocking::Engine) { Engine::builder() .address("127.0.0.1:1235") .participant("127.0.0.1:1234") - .preprocessed(&mut ctx2) + .preprocessed(ctx2) .scheme(SchemeKind::Spdz) .field(FieldKind::Curve25519) .single_threaded_runtime() diff --git a/wecare/src/vm.rs b/wecare/src/vm.rs index bd89150..f1154ce 100644 --- a/wecare/src/vm.rs +++ b/wecare/src/vm.rs @@ -9,15 +9,15 @@ use rand::{rngs::StdRng, SeedableRng}; use caring::{ algebra::{element::Element32, math::Vector}, - net::{agency::Broadcast, connection::TcpConnection, network::TcpNetwork, Id}, + net::{agency::Broadcast, connection::TcpConnection, network::TcpNetwork}, schemes::{feldman, shamir, spdz}, - vm::{ - self, - parsing::{Exp, Opened}, - Value, - }, + vm::{self, parsing::Exp}, }; +pub use caring::net::Id; +pub use caring::vm::Value; + +#[derive(Clone, Copy)] pub enum Number { Float(f64), Integer(u64), @@ -155,6 +155,7 @@ type SpdzCurve25519Engine = SpdzEngine; type SpdzElement32Engine = SpdzEngine; pub type Expr = Exp; +pub type Opened = vm::parsing::Opened; pub enum Engine { Spdz25519(SpdzCurve25519Engine), @@ -189,11 +190,11 @@ macro_rules! delegate_await { } impl Engine { - pub fn builder<'a>() -> EngineBuilder<'a> { + pub fn builder() -> EngineBuilder { EngineBuilder::default() } - pub async fn execute(&mut self, expr: Opened) -> Value { + pub async fn execute(&mut self, expr: Opened) -> Value { let res: Value = match self { Engine::Spdz25519(engine) => { let res = engine.execute(&expr.try_finalize().unwrap()).await; @@ -223,6 +224,10 @@ impl Engine { delegate!(self, id) } + pub fn peers(&self) -> Vec { + delegate!(self, peers) + } + pub async fn sum(&mut self, nums: &[f64]) -> Vec { let nums: Vector<_> = nums.iter().map(|v| Number::Float(*v)).collect(); let program = { @@ -254,17 +259,17 @@ pub enum SchemeKind { } #[derive(Default)] -pub struct EngineBuilder<'a> { +pub struct EngineBuilder { own: Option, peers: Vec, network: Option, threshold: Option, - preprocesing: Option<&'a mut File>, + preprocesing: Option, field: Option, scheme: Option, } -impl<'a> EngineBuilder<'a> { +impl EngineBuilder { pub fn address(mut self, addr: impl ToSocketAddrs) -> Self { // TODO: Handle this better self.own @@ -301,7 +306,7 @@ impl<'a> EngineBuilder<'a> { self } - pub fn preprocessed(mut self, file: &'a mut File) -> Self { + pub fn preprocessed(mut self, file: File) -> Self { self.preprocesing = Some(file); self } @@ -351,13 +356,13 @@ impl<'a> EngineBuilder<'a> { Engine::Shamir32(vm::Engine::new(context, network, rng)) } (SchemeKind::Spdz, FieldKind::Curve25519) => { - let file = self.preprocesing.expect("Missing preproc!"); - let context = spdz::preprocessing::load_context(file); + let mut file = self.preprocesing.expect("Missing preproc!"); + let context = spdz::preprocessing::load_context(&mut file); Engine::Spdz25519(vm::Engine::new(context, network, rng)) } (SchemeKind::Spdz, FieldKind::Element32) => { - let file = self.preprocesing.expect("Missing preproc!"); - let context = spdz::preprocessing::load_context(file); + let mut file = self.preprocesing.expect("Missing preproc!"); + let context = spdz::preprocessing::load_context(&mut file); Engine::Spdz32(vm::Engine::new(context, network, rng)) } (SchemeKind::Feldman, FieldKind::Curve25519) => { @@ -376,10 +381,7 @@ impl<'a> EngineBuilder<'a> { } pub mod blocking { - use caring::{ - net::Id, - vm::{parsing::Opened, Value}, - }; + use caring::{net::Id, vm::Value}; use crate::vm::UnknownNumber; @@ -388,13 +390,13 @@ pub mod blocking { runtime: tokio::runtime::Runtime, } - pub struct EngineBuilder<'a> { - parent: super::EngineBuilder<'a>, + pub struct EngineBuilder { + parent: super::EngineBuilder, runtime: tokio::runtime::Runtime, } - impl<'a> super::EngineBuilder<'a> { - pub fn single_threaded_runtime(self) -> EngineBuilder<'a> { + impl super::EngineBuilder { + pub fn single_threaded_runtime(self) -> EngineBuilder { let runtime = tokio::runtime::Builder::new_current_thread() .enable_all() .build() @@ -404,7 +406,7 @@ pub mod blocking { runtime, } } - pub fn multi_threaded_runtime(self) -> EngineBuilder<'a> { + pub fn multi_threaded_runtime(self) -> EngineBuilder { let runtime = tokio::runtime::Builder::new_multi_thread() .enable_all() .build() @@ -416,7 +418,7 @@ pub mod blocking { } } - impl<'a> EngineBuilder<'a> { + impl EngineBuilder { pub fn connect_blocking(mut self) -> Result { let runtime = &mut self.runtime; let mut parent = self.parent; @@ -435,7 +437,7 @@ pub mod blocking { } impl Engine { - pub fn execute(&mut self, expr: Opened) -> Value { + pub fn execute(&mut self, expr: super::Opened) -> Value { self.runtime.block_on(self.parent.execute(expr)) } @@ -443,6 +445,10 @@ pub mod blocking { self.parent.id() } + pub fn peers(&self) -> Vec { + self.parent.peers() + } + pub fn sum(&mut self, nums: &[f64]) -> Vec { self.runtime.block_on(self.parent.sum(nums)) }