diff --git a/src/serde/de_tree.rs b/src/serde/de_tree.rs new file mode 100644 index 00000000..fc80ce30 --- /dev/null +++ b/src/serde/de_tree.rs @@ -0,0 +1,327 @@ +use std::convert::TryInto; +use std::io::{Error, Read, Result, Write}; + +use sha2::Digest; + +use crate::sha2::Sha256; + +use super::parse_atom::decode_size_with_offset; +use super::utils::{copy_exactly, skip_bytes}; + +const MAX_SINGLE_BYTE: u8 = 0x7f; +const CONS_BOX_MARKER: u8 = 0xff; + +struct ShaWrapper(Sha256); + +impl Write for ShaWrapper { + fn write(&mut self, blob: &[u8]) -> std::result::Result { + self.0.update(blob); + Ok(blob.len()) + } + fn flush(&mut self) -> std::result::Result<(), Error> { + Ok(()) + } +} + +/// This data structure is used with `parse_triples`, which returns a triple of +/// integer values for each clvm object in a tree. + +#[derive(Debug, PartialEq, Eq)] +pub enum ParsedTriple { + Atom { + start: u64, + end: u64, + atom_offset: u32, + }, + Pair { + start: u64, + end: u64, + right_index: u32, + }, +} + +enum ParseOpRef { + ParseObj, + SaveCursor(usize), + SaveIndex(usize), +} + +fn sha_blobs(blobs: &[&[u8]]) -> [u8; 32] { + let mut h = Sha256::new(); + for blob in blobs { + h.update(blob); + } + h.finalize() + .as_slice() + .try_into() + .expect("wrong slice length") +} + +fn tree_hash_for_byte(b: u8, calculate_tree_hashes: bool) -> Option<[u8; 32]> { + if calculate_tree_hashes { + Some(sha_blobs(&[&[1, b]])) + } else { + None + } +} + +fn skip_or_sha_bytes( + f: &mut R, + skip_size: u64, + calculate_tree_hashes: bool, +) -> Result> { + if calculate_tree_hashes { + let mut h = Sha256::new(); + h.update([1]); + let mut w = ShaWrapper(h); + copy_exactly(f, &mut w, skip_size)?; + let r: [u8; 32] = + w.0.finalize() + .as_slice() + .try_into() + .expect("wrong slice length"); + Ok(Some(r)) + } else { + skip_bytes(f, skip_size)?; + Ok(None) + } +} + +/// parse a serialized clvm object tree to an array of `ParsedTriple` objects + +/// This alternative mechanism of deserialization generates an array of +/// references to each clvm object. A reference contains three values: +/// a start offset within the blob, an end offset, and a third value that +/// is either: an atom offset (relative to the start offset) where the atom +/// data starts (and continues to the end offset); or an index in the array +/// corresponding to the "right" element of the pair (in which case, the +/// "left" element corresponds to the current index + 1). +/// +/// Since these values are offsets into the original buffer, that buffer needs +/// to be kept around to get the original atoms. + +type ParsedTriplesOutput = (Vec, Option>); + +pub fn parse_triples( + f: &mut R, + calculate_tree_hashes: bool, +) -> Result { + let mut r = Vec::new(); + let mut tree_hashes = Vec::new(); + let mut op_stack = vec![ParseOpRef::ParseObj]; + let mut cursor: u64 = 0; + loop { + match op_stack.pop() { + None => { + break; + } + Some(op) => match op { + ParseOpRef::ParseObj => { + let mut b: [u8; 1] = [0]; + f.read_exact(&mut b)?; + let start = cursor; + cursor += 1; + let b = b[0]; + if b == CONS_BOX_MARKER { + let index = r.len(); + let new_obj = ParsedTriple::Pair { + start, + end: 0, + right_index: 0, + }; + r.push(new_obj); + if calculate_tree_hashes { + tree_hashes.push([0; 32]) + } + op_stack.push(ParseOpRef::SaveCursor(index)); + op_stack.push(ParseOpRef::ParseObj); + op_stack.push(ParseOpRef::SaveIndex(index)); + op_stack.push(ParseOpRef::ParseObj); + } else { + let (start, end, atom_offset, tree_hash) = { + if b <= MAX_SINGLE_BYTE { + ( + start, + start + 1, + 0, + tree_hash_for_byte(b, calculate_tree_hashes), + ) + } else { + let (atom_offset, atom_size) = decode_size_with_offset(f, b)?; + let end = start + (atom_offset as u64) + atom_size; + let h = skip_or_sha_bytes(f, atom_size, calculate_tree_hashes)?; + (start, end, atom_offset as u32, h) + } + }; + if calculate_tree_hashes { + tree_hashes.push(tree_hash.expect("failed unwrap")) + } + let new_obj = ParsedTriple::Atom { + start, + end, + atom_offset, + }; + cursor = end; + r.push(new_obj); + } + } + ParseOpRef::SaveCursor(index) => { + if let ParsedTriple::Pair { + start, + end: _, + right_index, + } = r[index] + { + if calculate_tree_hashes { + let h = sha_blobs(&[ + &[2], + &tree_hashes[index + 1], + &tree_hashes[right_index as usize], + ]); + tree_hashes[index] = h; + } + r[index] = ParsedTriple::Pair { + start, + end: cursor, + right_index, + }; + } + } + ParseOpRef::SaveIndex(index) => { + if let ParsedTriple::Pair { + start, + end, + right_index: _, + } = r[index] + { + r[index] = ParsedTriple::Pair { + start, + end, + right_index: r.len() as u32, + }; + } + } + }, + } + } + Ok(( + r, + if calculate_tree_hashes { + Some(tree_hashes) + } else { + None + }, + )) +} + +#[cfg(test)] +use std::io::Cursor; + +#[cfg(test)] +use hex::FromHex; + +#[cfg(test)] +fn check_parse_tree(h: &str, expected: Vec, expected_sha_tree_hex: &str) -> () { + let b = Vec::from_hex(h).unwrap(); + println!("{:?}", b); + let mut f = Cursor::new(b); + let (p, tree_hash) = parse_triples(&mut f, false).unwrap(); + assert_eq!(p, expected); + assert_eq!(tree_hash, None); + + let b = Vec::from_hex(h).unwrap(); + let mut f = Cursor::new(b); + let (p, tree_hash) = parse_triples(&mut f, true).unwrap(); + assert_eq!(p, expected); + + let est = Vec::from_hex(expected_sha_tree_hex).unwrap(); + assert_eq!(tree_hash.unwrap()[0].to_vec(), est); +} + +#[test] +fn test_parse_tree() { + check_parse_tree( + "80", + vec![ParsedTriple::Atom { + start: 0, + end: 1, + atom_offset: 1, + }], + "4bf5122f344554c53bde2ebb8cd2b7e3d1600ad631c385a5d7cce23c7785459a", + ); + + check_parse_tree( + "ff648200c8", + vec![ + ParsedTriple::Pair { + start: 0, + end: 5, + right_index: 2, + }, + ParsedTriple::Atom { + start: 1, + end: 2, + atom_offset: 0, + }, + ParsedTriple::Atom { + start: 2, + end: 5, + atom_offset: 1, + }, + ], + "247f7d3f63b346ea93ca47f571cd0f4455392348b888a4286072bef0ac6069b5", + ); + + check_parse_tree( + "ff83666f6fff83626172ff8362617a80", // `(foo bar baz)` + vec![ + ParsedTriple::Pair { + start: 0, + end: 16, + right_index: 2, + }, + ParsedTriple::Atom { + start: 1, + end: 5, + atom_offset: 1, + }, + ParsedTriple::Pair { + start: 5, + end: 16, + right_index: 4, + }, + ParsedTriple::Atom { + start: 6, + end: 10, + atom_offset: 1, + }, + ParsedTriple::Pair { + start: 10, + end: 16, + right_index: 6, + }, + ParsedTriple::Atom { + start: 11, + end: 15, + atom_offset: 1, + }, + ParsedTriple::Atom { + start: 15, + end: 16, + atom_offset: 1, + }, + ], + "47f30bf9935e25e4262023124fb5e986d755b9ed65a28ac78925c933bfd57dbd", + ); + + let s = "c0a0".to_owned() + &hex::encode([0x31u8; 160]); + check_parse_tree( + &s, + vec![ParsedTriple::Atom { + start: 0, + end: 162, + atom_offset: 2, + }], + "d1c109981a9c5a3bbe2d98795a186a0f057dc9a3a7f5e1eb4dfb63a1636efa2d", + ); +} diff --git a/src/serde/mod.rs b/src/serde/mod.rs index 41dcf0f6..1b9b8c6b 100644 --- a/src/serde/mod.rs +++ b/src/serde/mod.rs @@ -1,6 +1,7 @@ mod bytes32; mod de; mod de_br; +mod de_tree; mod errors; mod object_cache; mod parse_atom; @@ -8,6 +9,7 @@ mod read_cache_lookup; mod ser; mod ser_br; mod tools; +mod utils; mod write_atom; #[cfg(test)] @@ -15,6 +17,7 @@ mod test; pub use de::node_from_bytes; pub use de_br::node_from_bytes_backrefs; +pub use de_tree::{parse_triples, ParsedTriple}; pub use ser::node_to_bytes; pub use ser_br::node_to_bytes_backrefs; -pub use tools::{serialized_length_from_bytes, tree_hash_from_stream}; +pub use tools::{parse_through_clvm_object, serialized_length_from_bytes, tree_hash_from_stream}; diff --git a/src/serde/tools.rs b/src/serde/tools.rs index 5a00f2a3..336b6888 100644 --- a/src/serde/tools.rs +++ b/src/serde/tools.rs @@ -1,8 +1,9 @@ use std::io; -use std::io::{Cursor, Read, Seek, SeekFrom}; +use std::io::{Cursor, Read}; use super::errors::bad_encoding; use super::parse_atom::decode_size; +use super::utils::skip_bytes; const MAX_SINGLE_BYTE: u8 = 0x7f; const CONS_BOX_MARKER: u8 = 0xff; @@ -15,43 +16,27 @@ enum ParseOp { pub fn serialized_length_from_bytes(b: &[u8]) -> io::Result { let mut f = Cursor::new(b); - let mut ops = vec![ParseOp::SExp]; + parse_through_clvm_object(&mut f).map_err(|_e| bad_encoding())?; + Ok(f.position()) +} + +pub fn parse_through_clvm_object(f: &mut R) -> io::Result<()> { + let mut to_parse_count = 1; let mut b = [0; 1]; loop { - let op = ops.pop(); - if op.is_none() { + if to_parse_count < 1 { break; + }; + f.read_exact(&mut b)?; + if b[0] == CONS_BOX_MARKER { + to_parse_count += 2; + } else if b[0] != 0x80 && b[0] > MAX_SINGLE_BYTE { + let blob_size = decode_size(f, b[0])?; + skip_bytes(f, blob_size)?; } - match op.unwrap() { - ParseOp::SExp => { - f.read_exact(&mut b)?; - if b[0] == CONS_BOX_MARKER { - // since all we're doing is to determing the length of the - // serialized buffer, we don't need to do anything about - // "cons". So we skip pushing it to lower the pressure on - // the op stack - //ops.push(ParseOp::Cons); - ops.push(ParseOp::SExp); - ops.push(ParseOp::SExp); - } else if b[0] == 0x80 || b[0] <= MAX_SINGLE_BYTE { - // This one byte we just read was the whole atom. - // or the - // special case of NIL - } else { - let blob_size = decode_size(&mut f, b[0])?; - f.seek(SeekFrom::Current(blob_size as i64))?; - if (f.get_ref().len() as u64) < f.position() { - return Err(bad_encoding()); - } - } - } - ParseOp::Cons => { - // cons. No need to construct any structure here. Just keep - // going - } - } + to_parse_count -= 1; } - Ok(f.position()) + Ok(()) } use crate::sha2::{Digest, Sha256}; diff --git a/src/serde/utils.rs b/src/serde/utils.rs new file mode 100644 index 00000000..954d6e09 --- /dev/null +++ b/src/serde/utils.rs @@ -0,0 +1,24 @@ +use std::io; +use std::io::{copy, sink, Error, Read, Write}; + +pub fn copy_exactly( + reader: &mut R, + writer: &mut W, + expected_size: u64, +) -> io::Result<()> { + let mut reader = reader.by_ref().take(expected_size); + + let count = copy(&mut reader, writer)?; + if count < expected_size { + Err(Error::new( + std::io::ErrorKind::UnexpectedEof, + "copy terminated early", + )) + } else { + Ok(()) + } +} + +pub fn skip_bytes(f: &mut R, skip_size: u64) -> io::Result<()> { + copy_exactly(f, &mut sink(), skip_size) +} diff --git a/wheel/python/benchmarks/deserialization.py b/wheel/python/benchmarks/deserialization.py new file mode 100644 index 00000000..ea7df11f --- /dev/null +++ b/wheel/python/benchmarks/deserialization.py @@ -0,0 +1,132 @@ +import io +import pathlib +import time + +from clvm_rs.program import Program +from clvm_rs.clvm_rs import serialized_length + + +def bench(f, name: str, allow_slow=False): + r, t = bench_w_speed(f, name) + if not allow_slow and t > 0.01: + print("*** TOO SLOW") + print() + return r + + +def bench_w_speed(f, name: str): + start = time.time() + r = f() + end = time.time() + d = end - start + print(f"{name}: {d:1.4f} s") + return r, d + + +def benchmark(): + block_path = pathlib.Path(__file__).parent / "block-2500014.compressed.bin" + obj = bench( + lambda: Program.parse(open(block_path, "rb")), + "obj = Program.parse(open([block_blob]))", + ) + bench(lambda: bytes(obj), "bytes(obj)") + + block_blob = open(block_path, "rb").read() + obj1 = bench( + lambda: Program.from_bytes(block_blob), + "obj = Program.from_bytes([block_blob])", + ) + bench(lambda: bytes(obj1), "bytes(obj)") + + cost, output = bench(lambda: obj.run_with_cost(0), "run", allow_slow=True) + print(f"cost = {cost}") + result_blob = bench( + lambda: bytes(output), + "serialize LazyNode", + allow_slow=True, + ) + print(f"output = {len(result_blob)}"), + + result_blob_2 = bench(lambda: bytes(output), "serialize LazyNode again") + assert result_blob == result_blob_2 + + bench( + lambda: print(output.tree_hash().hex()), + "tree hash LazyNode", + allow_slow=True, + ) + bench(lambda: print(output.tree_hash().hex()), "tree hash again LazyNode") + + des_output = bench( + lambda: Program.from_bytes(result_blob), + "from_bytes with tree hashing (fbwth)", + allow_slow=True, + ) + bench(lambda: des_output.tree_hash(), "tree hash (fbwth)") + bench(lambda: des_output.tree_hash(), "tree hash again (fbwth)") + + bench(lambda: serialized_length(result_blob), "serialized_length") + + des_output = bench( + lambda: Program.from_bytes(result_blob, calculate_tree_hash=False), + "from_bytes without tree hashing (fbwoth)", + allow_slow=True, + ) + bench( + lambda: des_output.tree_hash(), + "tree hash (fbwoth)", + allow_slow=True, + ) + bench( + lambda: des_output.tree_hash(), + "tree hash (fbwoth) again", + ) + + reparsed_output = bench( + lambda: Program.parse(io.BytesIO(result_blob)), + "parse with tree hashing (pwth)", + allow_slow=True, + ) + bench(lambda: reparsed_output.tree_hash(), "tree hash (pwth)") + bench( + lambda: reparsed_output.tree_hash(), + "tree hash again (pwth)", + ) + + reparsed_output = bench( + lambda: Program.parse(io.BytesIO(result_blob), calculate_tree_hash=False), + "parse without treehashing (pwowt)", + allow_slow=True, + ) + bench(lambda: reparsed_output.tree_hash(), "tree hash (pwowt)", allow_slow=True) + bench( + lambda: reparsed_output.tree_hash(), + "tree hash again (pwowt)", + ) + + foo = Program.to("foo") + o0 = Program.to((foo, obj)) + o1 = Program.to((foo, obj1)) + o2 = Program.to((foo, output)) + + def compare(): + assert o0 == o1 + + bench(compare, "compare constructed") + + bench(lambda: bytes(o0), "to_bytes constructed o0") + bench(lambda: bytes(o1), "to_bytes constructed o1") + + bench(lambda: print(o0.tree_hash().hex()), "o0 tree_hash") + bench(lambda: print(o0.tree_hash().hex()), "o0 tree_hash (again)") + + bench(lambda: print(o1.tree_hash().hex()), "o1 tree_hash") + bench(lambda: print(o1.tree_hash().hex()), "o1 tree_hash (again)") + + bench(lambda: bytes(o2), "to_bytes constructed o2") + bench(lambda: print(o2.tree_hash().hex()), "o2 tree_hash") + bench(lambda: print(o2.tree_hash().hex()), "o2 tree_hash (again)") + + +if __name__ == "__main__": + benchmark() diff --git a/wheel/python/clvm_rs/curry_and_treehash.py b/wheel/python/clvm_rs/curry_and_treehash.py index 04cf917c..3d47f4c9 100644 --- a/wheel/python/clvm_rs/curry_and_treehash.py +++ b/wheel/python/clvm_rs/curry_and_treehash.py @@ -2,7 +2,7 @@ from .at import at from .casts import CastableType -from .chia_dialect import Dialect +from .chia_dialect import Dialect, CHIA_DIALECT from .clvm_storage import CLVMStorage from .tree_hash import shatree_pair, shatree_atom @@ -154,3 +154,6 @@ def uncurry( # since "rrr" is not None, neither is rrf core = at(core, "rrf") return uncurried_function, core_items + + +CHIA_CURRY_TREEHASHER = CurryTreehasher(CHIA_DIALECT) diff --git a/wheel/python/clvm_rs/tree_hash.py b/wheel/python/clvm_rs/tree_hash.py index 4e39c174..8554bc89 100644 --- a/wheel/python/clvm_rs/tree_hash.py +++ b/wheel/python/clvm_rs/tree_hash.py @@ -5,7 +5,6 @@ This implementation goes to great pains to be non-recursive so we don't have to worry about blowing out the python stack. """ - from hashlib import sha256 from typing import Callable, List, Tuple, cast diff --git a/wheel/src/adapt_response.rs b/wheel/src/adapt_response.rs index 6b30886a..18d55209 100644 --- a/wheel/src/adapt_response.rs +++ b/wheel/src/adapt_response.rs @@ -15,11 +15,11 @@ pub fn adapt_response( ) -> PyResult<(u64, LazyNode)> { match response { Ok(reduction) => { - let val = LazyNode::new(Rc::new(allocator), reduction.1); + let val = LazyNode::new(py, Rc::new(allocator), reduction.1); Ok((reduction.0, val)) } Err(eval_err) => { - let sexp = LazyNode::new(Rc::new(allocator), eval_err.0).to_object(py); + let sexp = LazyNode::new(py, Rc::new(allocator), eval_err.0).to_object(py); let msg = eval_err.1.to_object(py); let tuple = PyTuple::new(py, [msg, sexp]); let value_error: PyErr = PyValueError::new_err(tuple.to_object(py)); diff --git a/wheel/src/api.rs b/wheel/src/api.rs index 1611462c..727d30e4 100644 --- a/wheel/src/api.rs +++ b/wheel/src/api.rs @@ -1,3 +1,5 @@ +use std::io; + use super::lazy_node::LazyNode; use crate::adapt_response::adapt_response; use clvmr::allocator::Allocator; @@ -5,16 +7,36 @@ use clvmr::chia_dialect::ChiaDialect; use clvmr::cost::Cost; use clvmr::reduction::Response; use clvmr::run_program::run_program; -use clvmr::serde::{node_from_bytes, serialized_length_from_bytes}; +use clvmr::serde::{ + node_from_bytes, parse_through_clvm_object, parse_triples, serialized_length_from_bytes, + ParsedTriple, +}; use clvmr::{LIMIT_HEAP, LIMIT_STACK, MEMPOOL_MODE, NO_UNKNOWN_OPS}; use pyo3::prelude::*; +use pyo3::types::{PyBytes, PyTuple}; use pyo3::wrap_pyfunction; +struct ReadPyAny<'py>(&'py PyAny); + +impl<'py> std::io::Read for ReadPyAny<'py> { + fn read(&mut self, b: &mut [u8]) -> std::result::Result { + let r: Vec = self.0.call1((b.len(),))?.extract()?; + let (p0, _p1) = b.split_at_mut(r.len()); + p0.copy_from_slice(&r); + Ok(r.len()) + } +} + #[pyfunction] pub fn serialized_length(program: &[u8]) -> PyResult { Ok(serialized_length_from_bytes(program)?) } +#[pyfunction] +pub fn skip_clvm_object(obj: &PyAny) -> PyResult<()> { + Ok(parse_through_clvm_object(&mut ReadPyAny(obj.getattr("read")?))?) +} + #[pyfunction] pub fn run_serialized_chia_program( py: Python, @@ -39,10 +61,41 @@ pub fn run_serialized_chia_program( adapt_response(py, allocator, r) } +fn tuple_for_parsed_triple(py: Python<'_>, p: &ParsedTriple) -> PyObject { + let tuple = match p { + ParsedTriple::Atom { + start, + end, + atom_offset, + } => PyTuple::new(py, [*start, *end, *atom_offset as u64]), + ParsedTriple::Pair { + start, + end, + right_index, + } => PyTuple::new(py, [*start, *end, *right_index as u64]), + }; + tuple.into_py(py) +} + +#[pyfunction] +fn deserialize_as_tree( + py: Python, + blob: &[u8], + calculate_tree_hashes: bool, +) -> PyResult<(Vec, Option>)> { + let mut cursor = io::Cursor::new(blob); + let (r, tree_hashes) = parse_triples(&mut cursor, calculate_tree_hashes)?; + let r = r.iter().map(|pt| tuple_for_parsed_triple(py, pt)).collect(); + let s = tree_hashes.map(|ths| ths.iter().map(|b| PyBytes::new(py, b).into()).collect()); + Ok((r, s)) +} + #[pymodule] fn clvm_rs(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(run_serialized_chia_program, m)?)?; m.add_function(wrap_pyfunction!(serialized_length, m)?)?; + m.add_function(wrap_pyfunction!(skip_clvm_object, m)?)?; + m.add_function(wrap_pyfunction!(deserialize_as_tree, m)?)?; m.add("NO_UNKNOWN_OPS", NO_UNKNOWN_OPS)?; m.add("LIMIT_HEAP", LIMIT_HEAP)?; diff --git a/wheel/src/lazy_node.rs b/wheel/src/lazy_node.rs index d3495eae..425664dc 100644 --- a/wheel/src/lazy_node.rs +++ b/wheel/src/lazy_node.rs @@ -9,6 +9,8 @@ use pyo3::types::{PyBytes, PyTuple}; pub struct LazyNode { allocator: Rc, node: NodePtr, + #[pyo3(get, set)] + _cached_sha256_treehash: PyObject, } impl ToPyObject for LazyNode { @@ -25,8 +27,8 @@ impl LazyNode { pub fn pair(&self, py: Python) -> PyResult> { match &self.allocator.sexp(self.node) { SExp::Pair(p1, p2) => { - let r1 = Self::new(self.allocator.clone(), *p1); - let r2 = Self::new(self.allocator.clone(), *p2); + let r1 = Self::new(py, self.allocator.clone(), *p1); + let r2 = Self::new(py, self.allocator.clone(), *p2); let v: &PyTuple = PyTuple::new(py, &[r1, r2]); Ok(Some(v.into())) } @@ -44,10 +46,11 @@ impl LazyNode { } impl LazyNode { - pub const fn new(a: Rc, n: NodePtr) -> Self { + pub fn new(py: Python, a: Rc, n: NodePtr) -> Self { Self { allocator: a, node: n, + _cached_sha256_treehash: py.None(), } } }