diff --git a/core/Cargo.lock b/core/Cargo.lock index eeeb5b18689b..84a2510d9462 100644 --- a/core/Cargo.lock +++ b/core/Cargo.lock @@ -10505,6 +10505,25 @@ dependencies = [ "zkevm_opcode_defs 0.150.20", ] +[[package]] +name = "zk_os_merkle_tree" +version = "26.4.0-non-semver-compat" +dependencies = [ + "anyhow", + "clap 4.5.23", + "leb128", + "once_cell", + "rand 0.8.5", + "rayon", + "tempfile", + "thiserror 1.0.69", + "tracing", + "tracing-subscriber", + "zksync_basic_types", + "zksync_crypto_primitives", + "zksync_storage", +] + [[package]] name = "zkevm_circuits" version = "0.150.20" diff --git a/core/Cargo.toml b/core/Cargo.toml index 4ec35fde54aa..c7e601a27807 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -77,6 +77,7 @@ members = [ "lib/crypto_primitives", "lib/external_price_api", "lib/task_management", + "lib/zk_os_merkle_tree", "lib/test_contracts", # Test infrastructure "tests/loadnext", diff --git a/core/lib/merkle_tree/examples/loadtest/main.rs b/core/lib/merkle_tree/examples/loadtest/main.rs index 6ac8425c0fc6..f144dcd1904d 100644 --- a/core/lib/merkle_tree/examples/loadtest/main.rs +++ b/core/lib/merkle_tree/examples/loadtest/main.rs @@ -43,7 +43,7 @@ struct Cli { /// Number of commits to perform. #[arg(name = "commits")] commit_count: u64, - /// Number of inserts / updates per commit. + /// Number of inserts per commit. #[arg(name = "ops")] writes_per_commit: usize, /// Generate Merkle proofs for each operation. diff --git a/core/lib/merkle_tree/src/consistency.rs b/core/lib/merkle_tree/src/consistency.rs index daf508692b4f..88990fefbf63 100644 --- a/core/lib/merkle_tree/src/consistency.rs +++ b/core/lib/merkle_tree/src/consistency.rs @@ -245,7 +245,7 @@ struct AtomicBitSet { } impl AtomicBitSet { - const BITS_PER_ATOMIC: usize = 8; + const BITS_PER_ATOMIC: usize = 64; fn new(len: usize) -> Self { let atomic_count = (len + Self::BITS_PER_ATOMIC - 1) / Self::BITS_PER_ATOMIC; diff --git a/core/lib/storage/src/db.rs b/core/lib/storage/src/db.rs index e16a3580ac1c..55d770a5e42d 100644 --- a/core/lib/storage/src/db.rs +++ b/core/lib/storage/src/db.rs @@ -77,6 +77,10 @@ impl WriteBatch<'_, CF> { let cf = self.db.column_family(cf); self.inner.delete_range_cf(cf, keys.start, keys.end); } + + pub fn size_in_bytes(&self) -> usize { + self.inner.size_in_bytes() + } } struct RocksDBCaches { @@ -487,10 +491,10 @@ impl RocksDB { self.inner.db.multi_get(keys) } - pub fn multi_get_cf( + pub fn multi_get_cf>( &self, cf: CF, - keys: impl Iterator>, + keys: impl Iterator, ) -> Vec>, rocksdb::Error>> { let cf = self.column_family(cf); self.inner.db.batched_multi_get_cf(cf, keys, false) @@ -597,12 +601,28 @@ impl RocksDB { pub fn from_iterator_cf( &self, cf: CF, - key_from: &[u8], + keys: ops::RangeFrom<&[u8]>, + ) -> impl Iterator, Box<[u8]>)> + '_ { + let cf = self.column_family(cf); + self.inner + .db + .iterator_cf(cf, IteratorMode::From(keys.start, Direction::Forward)) + .map(Result::unwrap) + .fuse() + // ^ unwrap() is safe for the same reasons as in `prefix_iterator_cf()`. + } + + /// Iterates over key-value pairs in the specified column family `cf` in the reverse lexical + /// key order starting from the given `key_from`. 
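+    /// I.e., iteration starts at `keys.end` (inclusive) and proceeds towards lexicographically smaller keys.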
+ pub fn to_iterator_cf( + &self, + cf: CF, + keys: ops::RangeToInclusive<&[u8]>, ) -> impl Iterator, Box<[u8]>)> + '_ { let cf = self.column_family(cf); self.inner .db - .iterator_cf(cf, IteratorMode::From(key_from, Direction::Forward)) + .iterator_cf(cf, IteratorMode::From(keys.end, Direction::Reverse)) .map(Result::unwrap) .fuse() // ^ unwrap() is safe for the same reasons as in `prefix_iterator_cf()`. diff --git a/core/lib/zk_os_merkle_tree/Cargo.toml b/core/lib/zk_os_merkle_tree/Cargo.toml new file mode 100644 index 000000000000..f8bf2d8dc383 --- /dev/null +++ b/core/lib/zk_os_merkle_tree/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "zk_os_merkle_tree" +description = "Persistent ZK OS Merkle tree" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true +repository.workspace = true +license.workspace = true +keywords.workspace = true +categories.workspace = true + +[dependencies] +zksync_basic_types.workspace = true +zksync_crypto_primitives.workspace = true +zksync_storage.workspace = true + +anyhow.workspace = true +leb128.workspace = true +once_cell.workspace = true +rayon.workspace = true +thiserror.workspace = true +tracing.workspace = true + +[dev-dependencies] +clap = { workspace = true, features = ["derive"] } +rand.workspace = true +tempfile.workspace = true +tracing-subscriber = { workspace = true, features = ["env-filter"] } diff --git a/core/lib/zk_os_merkle_tree/README.md b/core/lib/zk_os_merkle_tree/README.md new file mode 100644 index 000000000000..9ae0fa0bf2e2 --- /dev/null +++ b/core/lib/zk_os_merkle_tree/README.md @@ -0,0 +1,118 @@ +# Persistent ZK OS Merkle tree + +Dense, doubly linked Merkle tree implementation with parameterized depth and amortization factor. + +## Construction + +- The tree is a dense binary Merkle tree with parametric depth (the default depth is currently set to 64; i.e., up to + `2 ** 64` leaves). +- Hash function is parametric as well; the default one is Blake2s with 256-bit output. The tree is always considered to + have fixed depth (i.e., no reduced hashing for lightly populated trees). +- The order of leaves is the insertion order; leaves are never removed from the tree. +- Leaves emulate a linked list. I.e., each leaf holds beside a 32-byte key and 32-byte value, 0-based indices in the + tree to leaves with lexicographically previous and next keys. +- There are 2 pre-inserted guard leaves with min / max keys (i.e., `[0_u8; 32]` and `[u8::MAX; 32]`). As such, all + “real” leaves always have previous / next pointers well-defined. + +Hashing specification: + +```text +hash(leaf) = blake2s( + leaf.key ++ leaf.value ++ leaf.prev.to_le_bytes() ++ leaf.next.to_le_bytes() +); +hash(node) = blake2s(hash(node.left) ++ hash(node.right)); +``` + +where `++` is byte concatenation. + +## Storage layout + +RocksDB is used for tree persistence. The implementation uses versioning and parametric amortization strategy similar to +[Jellyfish Merkle tree] to reduce the amount of I/O at the cost of increased hashing. Here, parametric means that the +radix of internal nodes is configurable (obviously, it's fixed for a specific tree instance). More details on what +amortization means follow. + +A tree is _versioned_; a new version is created for each batch update, and all past versions are available. (This is +necessary to be able to provide Merkle proofs for past versions.) 
Internally, the forest of all tree versions is built +like an immutable data structure, with tree nodes reused where possible (i.e., if not changed in an update). + +As expected, the Merkle tree consists of leaves and internal nodes; the tree root is a special case of internal node +with additional data (for now, it's just the number of leaves). + +- A **leaf** consists of a key, value and prev / next indices as expected. +- **Internal nodes** consist of refs to children; each ref is a version + hash. To reduce the amount of I/O ops (at the + cost of read / write volume overhead), an internal node contains >2 child refs; that's what the radix mentioned above + means (e.g., in a radix-16 tree each internal node _usually_ contains 16 child refs, with the only possible exception + being the rightmost node on each tree level). + +E.g., here's a radix-16 amortized tree with 2 versions and toy depth 8 (i.e., 1 internal node level excluding the root, +and 1 leaf level). The first version inserts 17 leaves, and the second version updates the last leaf. + +```mermaid +--- +title: Tree structure +--- +flowchart TD + Root0[Root v0] + Root1[Root v1] + Internal0[Internal 0] + Internal1[Internal 1] + Internal1_1[Internal 1'] + Leaf0[Leaf 0] + Leaf1[Leaf 1] + Leaf15[Leaf 15] + Leaf16[Leaf 16] + Leaf16_1[Leaf 16'] + Root0-->Internal0 & Internal1 + Internal0-->Leaf0 & Leaf1 & Leaf15 + Internal1-->Leaf16 + + Root1-->Internal0 & Internal1_1 + Internal1_1-->Leaf16_1 +``` + +Tree nodes are mapped to the RocksDB column family (CF) using _node keys_ consisting of a version, the number of +_nibbles_ (root has 0, its children 1 etc.), and the 0-based index on the level. Without pruning, it's easy to see that +storage is append-only. + +Besides the tree, RocksDB also persists the key to leaf index lookup in a separate CF. This lookup is used during +updates and to get historic Merkle proofs. To accommodate for historic proofs, CF values contain the version at which +the leaf was inserted besides its index; leaves with future versions are skipped during lookup. The lookup CF is +insert-only even with pruning; the only exception is tree truncation. Unlike the tree CF, inserted entries are not +ordered though. + +## Benchmarking + +The `loadtest` example is a CLI app allowing to measure tree performance. It allows using the in-memory or RocksDB +storage backend, and Blake2s or no-op hashing functions. For example, the following command launches a benchmark with +1,000 batches each containing 4,000 insertions and 16,000 updates (= 20,000 writes / batch; 4M inserts in total), +generating an insertion proof for each batch. 
+ +```shell +RUST_LOG=debug cargo run --release \ + -p zk_os_merkle_tree --example loadtest -- \ + --updates=16000 --chunk-size=500 --proofs 1000 4000 +``` + +The order of timings should be as follows (measured on MacBook Pro with 12-core Apple M2 Max CPU and 32 GB DDR5 RAM +using the command line above): + +```text +2025-02-19T11:06:24.736870Z INFO loadtest: Processing block #999 +2025-02-19T11:06:24.813829Z DEBUG zk_os_merkle_tree::storage::patch: loaded lookup info, elapsed: 76.89375ms +2025-02-19T11:06:24.908340Z DEBUG zk_os_merkle_tree::storage::patch: loaded nodes, elapsed: 93.501125ms, distinct_indices.len: 23967 +2025-02-19T11:06:24.908994Z DEBUG zk_os_merkle_tree: loaded tree data, elapsed: 172.085ms, inserts: 4000, updates: 16000, loaded_internal_nodes: 36294 +2025-02-19T11:06:24.936667Z DEBUG zk_os_merkle_tree::storage::patch: collected hashes for batch proof, hash_latency: 15.131706ms, traverse_latency: 10.213624ms +2025-02-19T11:06:24.936756Z DEBUG zk_os_merkle_tree: created batch proof, elapsed: 27.751333ms, proof.leaves.len: 23967, proof.hashes.len: 156210 +2025-02-19T11:06:24.944054Z DEBUG zk_os_merkle_tree: updated tree structure, elapsed: 7.285209ms +2025-02-19T11:06:24.954820Z DEBUG zk_os_merkle_tree: hashed tree, elapsed: 10.747417ms +2025-02-19T11:06:25.017817Z DEBUG zk_os_merkle_tree: persisted tree, elapsed: 62.967083ms +2025-02-19T11:06:25.018655Z INFO loadtest: Processed block #999 in 281.765541ms, root hash = 0x12fa11d7742d67509c9a980e0fb62a1b64a478c9ff4d7596555e1f0d5cb2043f +2025-02-19T11:06:25.018669Z INFO loadtest: Verifying tree consistency... +2025-02-19T11:07:06.144174Z INFO loadtest: Verified tree consistency in 41.126574667s +``` + +I.e., latency is dominated by I/O (~30% for key–index lookup, ~30% for loading tree nodes, and ~20% for tree +persistence). + +[jellyfish merkle tree]: https://developers.diem.com/papers/jellyfish-merkle-tree/2021-01-14.pdf diff --git a/core/lib/zk_os_merkle_tree/examples/loadtest.rs b/core/lib/zk_os_merkle_tree/examples/loadtest.rs new file mode 100644 index 000000000000..6570b80a5b87 --- /dev/null +++ b/core/lib/zk_os_merkle_tree/examples/loadtest.rs @@ -0,0 +1,172 @@ +//! Load test for the Merkle tree. + +use std::{hint::black_box, time::Instant}; + +use anyhow::Context; +use clap::Parser; +use rand::{ + prelude::{IteratorRandom, StdRng}, + SeedableRng, +}; +use tempfile::TempDir; +use tracing_subscriber::EnvFilter; +use zk_os_merkle_tree::{ + Database, DefaultTreeParams, HashTree, MerkleTree, PatchSet, RocksDBWrapper, TreeEntry, + TreeParams, +}; +use zksync_basic_types::H256; +use zksync_crypto_primitives::hasher::{blake2::Blake2Hasher, Hasher}; +use zksync_storage::{RocksDB, RocksDBOptions}; + +#[derive(Debug)] +struct WithDynHasher; + +impl TreeParams for WithDynHasher { + type Hasher = &'static dyn HashTree; + const TREE_DEPTH: u8 = ::TREE_DEPTH; + const INTERNAL_NODE_DEPTH: u8 = ::INTERNAL_NODE_DEPTH; +} + +/// CLI for load-testing for the Merkle tree implementation. +#[derive(Debug, Parser)] +#[command(author, version, about, long_about = None)] +struct Cli { + /// Number of batches to insert into the tree. + #[arg(name = "batches")] + batch_count: u64, + /// Number of inserts per commit. + #[arg(name = "ops")] + writes_per_batch: usize, + /// Additional number of updates of previously written keys per commit. + #[arg(name = "updates", long, default_value = "0")] + updates_per_batch: usize, + /// Generate Merkle proofs for each operation. 
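+    /// (Proofs are generated per batch via `MerkleTree::extend_with_proof()`.)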
+ #[arg(name = "proofs", long)] + proofs: bool, + /// Use a no-op hashing function. + #[arg(name = "no-hash", long)] + no_hashing: bool, + /// Perform testing on in-memory DB rather than RocksDB (i.e., with focus on hashing logic). + #[arg(long = "in-memory", short = 'M')] + in_memory: bool, + /// Block cache capacity for RocksDB in bytes. + #[arg(long = "block-cache", conflicts_with = "in_memory")] + block_cache: Option, + /// If specified, RocksDB indices and Bloom filters will be managed by the block cache rather than + /// being loaded entirely into RAM. + #[arg(long = "cache-indices", conflicts_with = "in_memory")] + cache_indices: bool, + /// Chunk size for RocksDB multi-get operations. + #[arg(long = "chunk-size", conflicts_with = "in_memory")] + chunk_size: Option, + /// Seed to use in the RNG for reproducibility. + #[arg(long = "rng-seed", default_value = "0")] + rng_seed: u64, + // FIXME: restore missing options (proof, in-memory buffering) +} + +impl Cli { + fn init_logging() { + tracing_subscriber::fmt() + .pretty() + .with_env_filter(EnvFilter::from_default_env()) + .init(); + } + + fn run(self) -> anyhow::Result<()> { + Self::init_logging(); + tracing::info!("Launched with options: {self:?}"); + + let (mut mock_db, mut rocksdb); + let mut _temp_dir = None; + let db: &mut dyn Database = if self.in_memory { + mock_db = PatchSet::default(); + &mut mock_db + } else { + let dir = TempDir::new().context("failed creating temp dir for RocksDB")?; + tracing::info!( + "Created temp dir for RocksDB: {}", + dir.path().to_string_lossy() + ); + let db_options = RocksDBOptions { + block_cache_capacity: self.block_cache, + include_indices_and_filters_in_block_cache: self.cache_indices, + ..RocksDBOptions::default() + }; + let db = + RocksDB::with_options(dir.path(), db_options).context("failed creating RocksDB")?; + rocksdb = RocksDBWrapper::from(db); + + if let Some(chunk_size) = self.chunk_size { + rocksdb.set_multi_get_chunk_size(chunk_size); + } + + _temp_dir = Some(dir); + &mut rocksdb + }; + + let hasher: &dyn HashTree = if self.no_hashing { &() } else { &Blake2Hasher }; + let mut rng = StdRng::seed_from_u64(self.rng_seed); + + let mut tree = MerkleTree::<_, WithDynHasher>::with_hasher(db, hasher) + .context("cannot create tree")?; + let mut next_key_idx = 0_u64; + let mut next_value_idx = 0_u64; + for version in 0..self.batch_count { + let new_keys: Vec<_> = Self::generate_keys(next_key_idx..) + .take(self.writes_per_batch) + .collect(); + let updated_indices = + (0..next_key_idx).choose_multiple(&mut rng, self.updates_per_batch); + next_key_idx += new_keys.len() as u64; + + next_value_idx += (new_keys.len() + updated_indices.len()) as u64; + let updated_keys = Self::generate_keys(updated_indices.into_iter()); + let kvs = new_keys + .into_iter() + .chain(updated_keys) + .zip(next_value_idx..); + let kvs = kvs.map(|(key, idx)| TreeEntry { + key, + value: H256::from_low_u64_be(idx), + }); + let kvs = kvs.collect::>(); + + tracing::info!("Processing block #{version}"); + let start = Instant::now(); + let output = if self.proofs { + let (output, proof) = tree + .extend_with_proof(&kvs) + .context("failed extending tree")?; + black_box(proof); // Ensure that proof creation isn't optimized away + output + } else { + tree.extend(&kvs).context("failed extending tree")? 
+ }; + let root_hash = output.root_hash; + + let elapsed = start.elapsed(); + tracing::info!("Processed block #{version} in {elapsed:?}, root hash = {root_hash:?}"); + } + + tracing::info!("Verifying tree consistency..."); + let start = Instant::now(); + tree.verify_consistency(self.batch_count - 1) + .context("tree consistency check failed")?; + let elapsed = start.elapsed(); + tracing::info!("Verified tree consistency in {elapsed:?}"); + + Ok(()) + } + + fn generate_keys(key_indexes: impl Iterator) -> impl Iterator { + key_indexes.map(move |idx| { + let key = H256::from_low_u64_be(idx); + Blake2Hasher.hash_bytes(key.as_bytes()) + }) + } +} + +fn main() -> anyhow::Result<()> { + Cli::parse().run() +} diff --git a/core/lib/zk_os_merkle_tree/src/consistency.rs b/core/lib/zk_os_merkle_tree/src/consistency.rs new file mode 100644 index 000000000000..13b724d78acf --- /dev/null +++ b/core/lib/zk_os_merkle_tree/src/consistency.rs @@ -0,0 +1,483 @@ +use std::{ + fmt, + sync::atomic::{AtomicU64, Ordering}, +}; + +use zksync_basic_types::H256; + +use crate::{ + leaf_nibbles, max_nibbles_for_internal_node, max_node_children, + types::{InternalNode, KeyLookup, Leaf, Node, NodeKey}, + Database, DeserializeError, HashTree, MerkleTree, TreeParams, +}; + +#[derive(Debug, Clone, Copy)] +pub enum IndexKind { + This, + Prev, + Next, +} + +impl fmt::Display for IndexKind { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str(match self { + Self::This => "self", + Self::Prev => "previous", + Self::Next => "next", + }) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum ConsistencyError { + #[error("failed deserializing node from DB: {0}")] + Deserialize(#[from] DeserializeError), + #[error("tree version {0} does not exist")] + MissingVersion(u64), + #[error("missing root for tree version {0}")] + MissingRoot(u64), + #[error("missing min / max guards")] + NoGuards, + + #[error("internal node with key {key} has unexpected number of children: expected {expected}, actual {actual}")] + ChildCountMismatch { + key: NodeKey, + expected: usize, + actual: usize, + }, + #[error( + "internal node with key {key} should have version {expected_version} (max among child ref versions)" + )] + KeyVersionMismatch { key: NodeKey, expected_version: u64 }, + #[error("root node should have version >={max_child_version} (max among child ref versions)")] + RootVersionMismatch { max_child_version: u64 }, + + #[error("unexpected min guard (leaf with index 0)")] + UnexpectedMinGuard, + #[error("leaf {0} has minimum key")] + MinKey(u64), + #[error("unexpected max guard (leaf with index 1)")] + UnexpectedMaxGuard, + #[error("leaf {0} has maximum key")] + MaxKey(u64), + #[error("missing key {0:?} in lookup")] + MissingKeyLookup(H256), + #[error( + "{kind} index mismatch for key {key:?} in tree leaf ({in_tree}) and lookup ({in_lookup})" + )] + IndexMismatch { + kind: IndexKind, + key: H256, + in_lookup: u64, + in_tree: u64, + }, + #[error("{index_kind} index for leaf #{leaf_index} ({index}) >= leaf count ({leaf_count})")] + LeafIndexOverflow { + leaf_index: u64, + index_kind: IndexKind, + index: u64, + leaf_count: u64, + }, + #[error("{index_kind} index for leaf #{leaf_index} ({index}) is a duplicate")] + DuplicateLeafIndex { + leaf_index: u64, + index_kind: IndexKind, + index: u64, + }, + #[error("{index_kind} index for leaf #{leaf_index} points to the disallowed guard")] + IncorrectGuardRef { + leaf_index: u64, + index_kind: IndexKind, + }, + + #[error( + "internal node at {key} specifies that 
child hash at `{nibble}` \ + is {expected}, but it actually is {actual}" + )] + HashMismatch { + key: NodeKey, + nibble: u8, + expected: H256, + actual: H256, + }, +} + +impl MerkleTree { + /// Verifies the internal tree consistency as stored in the database. + /// + /// If `validate_indices` flag is set, it will be checked that indices for all tree leaves are unique + /// and are sequentially assigned starting from 1. + /// + /// # Errors + /// + /// Returns an error (the first encountered one if there are multiple). + pub fn verify_consistency(&self, version: u64) -> Result<(), ConsistencyError> { + let manifest = self.db.try_manifest()?; + let manifest = manifest.ok_or(ConsistencyError::MissingVersion(version))?; + if version >= manifest.version_count { + return Err(ConsistencyError::MissingVersion(version)); + } + + let root = self + .db + .try_root(version)? + .ok_or(ConsistencyError::MissingRoot(version))?; + + if root.leaf_count < 2 { + return Err(ConsistencyError::NoGuards); + } + let leaf_data = LeafConsistencyData::new(root.leaf_count); + + // We want to perform a depth-first walk of the tree in order to not keep + // much in memory. + let root_key = NodeKey::root(version); + self.validate_internal_node(&root.root_node, root_key, &leaf_data)?; + + Ok(()) + } + + fn validate_internal_node( + &self, + node: &InternalNode, + key: NodeKey, + leaf_data: &LeafConsistencyData, + ) -> Result { + use rayon::prelude::*; + + let leaf_count = leaf_data.expected_leaf_count; + assert!(leaf_count > 0); // checked during initialization + let child_depth = + (max_nibbles_for_internal_node::
<P>
() - key.nibble_count) * P::INTERNAL_NODE_DEPTH; + let last_child_index = (leaf_count - 1) >> child_depth; + let last_index_on_level = last_child_index / u64::from(max_node_children::
<P>
()); + + assert!(key.index_on_level <= last_index_on_level); + let expected_child_count = if key.index_on_level < last_index_on_level { + max_node_children::
<P>
().into() + } else { + (last_child_index % u64::from(max_node_children::
<P>
())) as usize + 1 + }; + + if node.children.len() != expected_child_count { + return Err(ConsistencyError::ChildCountMismatch { + key, + expected: expected_child_count, + actual: node.children.len(), + }); + } + + let expected_version = node + .children + .iter() + .map(|child_ref| child_ref.version) + .max() + .unwrap(); + if key.nibble_count != 0 && expected_version != key.version { + return Err(ConsistencyError::KeyVersionMismatch { + key, + expected_version, + }); + } else if key.nibble_count == 0 && expected_version > key.version { + return Err(ConsistencyError::RootVersionMismatch { + max_child_version: expected_version, + }); + } + + // `.into_par_iter()` below is the only place where `rayon`-based parallelism + // is used in tree verification. + node.children + .par_iter() + .enumerate() + .try_for_each(|(i, child_ref)| { + let child_key = NodeKey { + version: child_ref.version, + nibble_count: key.nibble_count + 1, + index_on_level: i as u64 + (key.index_on_level << P::INTERNAL_NODE_DEPTH), + }; + let children = self.db.try_nodes(&[child_key])?; + + // Assertions are used below because they are a part of the `Database` contract. + assert_eq!(children.len(), 1); + let child = children.into_iter().next().unwrap(); + + // Recursion here is OK; the tree isn't that deep (16 nibbles max, with upper levels most likely having a single child). + let child_hash = match &child { + Node::Internal(node) => { + assert!(child_key.nibble_count <= max_nibbles_for_internal_node::
<P>
()); + self.validate_internal_node(node, child_key, leaf_data)? + } + Node::Leaf(leaf) => { + assert_eq!(child_key.nibble_count, leaf_nibbles::
<P>
()); + self.validate_leaf(leaf, child_key, leaf_data)? + } + }; + + if child_hash == child_ref.hash { + Ok(()) + } else { + Err(ConsistencyError::HashMismatch { + key, + nibble: i as u8, + expected: child_ref.hash, + actual: child_hash, + }) + } + })?; + + Ok(node.hash::
<P>
(&self.hasher, child_depth)) + } + + fn validate_leaf( + &self, + leaf: &Leaf, + key: NodeKey, + leaf_data: &LeafConsistencyData, + ) -> Result { + let index = key.index_on_level; + + if index == 0 && (leaf.key != H256::zero() || leaf.prev_index != 0) { + return Err(ConsistencyError::UnexpectedMinGuard); + } + if index == 1 && (leaf.key != H256::repeat_byte(0xff) || leaf.next_index != 1) { + return Err(ConsistencyError::UnexpectedMaxGuard); + } + + leaf_data.insert_leaf(leaf, key.index_on_level)?; + + let lookup = self.db.indices(key.version, &[leaf.key])?; + assert_eq!(lookup.len(), 1); + let lookup_index = match lookup.into_iter().next().unwrap() { + KeyLookup::Existing(idx) => idx, + KeyLookup::Missing { .. } => { + return Err(ConsistencyError::MissingKeyLookup(leaf.key)); + } + }; + if lookup_index != index { + return Err(ConsistencyError::IndexMismatch { + kind: IndexKind::This, + key: leaf.key, + in_lookup: lookup_index, + in_tree: index, + }); + } + + if index != 0 { + let prev_key = prev_key(leaf.key).ok_or(ConsistencyError::MinKey(index))?; + let lookup = self.db.indices(key.version, &[prev_key])?; + assert_eq!(lookup.len(), 1); + + let lookup_prev_index = match lookup.into_iter().next().unwrap() { + KeyLookup::Existing(idx) => idx, + KeyLookup::Missing { + prev_key_and_index: (prev_key, idx), + .. + } => { + assert!(prev_key < leaf.key); + idx + } + }; + + if lookup_prev_index != leaf.prev_index { + return Err(ConsistencyError::IndexMismatch { + kind: IndexKind::Prev, + key: leaf.key, + in_lookup: lookup_prev_index, + in_tree: leaf.prev_index, + }); + } + } + + if index != 1 { + let next_key = next_key(leaf.key).ok_or(ConsistencyError::MaxKey(index))?; + let lookup = self.db.indices(key.version, &[next_key])?; + assert_eq!(lookup.len(), 1); + + let lookup_next_index = match lookup.into_iter().next().unwrap() { + KeyLookup::Existing(idx) => idx, + KeyLookup::Missing { + next_key_and_index: (next_key, idx), + .. + } => { + assert!(next_key > leaf.key); + idx + } + }; + + if lookup_next_index != leaf.next_index { + return Err(ConsistencyError::IndexMismatch { + kind: IndexKind::Next, + key: leaf.key, + in_lookup: lookup_next_index, + in_tree: leaf.next_index, + }); + } + } + + Ok(self.hasher.hash_leaf(leaf)) + } +} + +fn prev_key(key: H256) -> Option { + let mut bytes = key.0; + for pos in (0..32).rev() { + if bytes[pos] != 0 { + bytes[pos] -= 1; + for byte in &mut bytes[pos + 1..] { + *byte = 0xff; + } + return Some(H256(bytes)); + } + } + None +} + +fn next_key(key: H256) -> Option { + let mut bytes = key.0; + for pos in (0..32).rev() { + if bytes[pos] != u8::MAX { + bytes[pos] += 1; + for byte in &mut bytes[pos + 1..] 
{ + *byte = 0; + } + return Some(H256(bytes)); + } + } + None +} + +#[must_use = "Final checks should be performed in `finalize()`"] +#[derive(Debug)] +struct LeafConsistencyData { + expected_leaf_count: u64, + prev_indices_set: AtomicBitSet, + next_indices_set: AtomicBitSet, +} + +impl LeafConsistencyData { + fn new(expected_leaf_count: u64) -> Self { + Self { + expected_leaf_count, + prev_indices_set: AtomicBitSet::new(expected_leaf_count as usize), + next_indices_set: AtomicBitSet::new(expected_leaf_count as usize), + } + } + + fn insert_leaf(&self, leaf: &Leaf, leaf_index: u64) -> Result<(), ConsistencyError> { + if leaf_index != 0 { + self.insert_into_set( + &self.prev_indices_set, + IndexKind::Prev, + leaf_index, + leaf.prev_index, + )?; + } + if leaf_index != 1 { + self.insert_into_set( + &self.next_indices_set, + IndexKind::Next, + leaf_index, + leaf.next_index, + )?; + } + Ok(()) + } + + fn insert_into_set( + &self, + bit_set: &AtomicBitSet, + index_kind: IndexKind, + leaf_index: u64, + index: u64, + ) -> Result<(), ConsistencyError> { + if index >= self.expected_leaf_count { + return Err(ConsistencyError::LeafIndexOverflow { + leaf_index, + index_kind, + index, + leaf_count: self.expected_leaf_count, + }); + } + + match index_kind { + IndexKind::Prev if index == 1 => { + return Err(ConsistencyError::IncorrectGuardRef { + leaf_index, + index_kind, + }); + } + IndexKind::Next if index == 0 => { + return Err(ConsistencyError::IncorrectGuardRef { + leaf_index, + index_kind, + }); + } + _ => { /* do nothing */ } + } + + if bit_set.set(index as usize) { + return Err(ConsistencyError::DuplicateLeafIndex { + leaf_index, + index_kind, + index, + }); + } + Ok(()) + } +} + +/// Primitive atomic bit set implementation that only supports setting bits. +#[derive(Debug)] +struct AtomicBitSet { + bits: Vec, +} + +impl AtomicBitSet { + const BITS_PER_ATOMIC: usize = 64; + + fn new(len: usize) -> Self { + let atomic_count = (len + Self::BITS_PER_ATOMIC - 1) / Self::BITS_PER_ATOMIC; + let mut bits = Vec::with_capacity(atomic_count); + bits.resize_with(atomic_count, AtomicU64::default); + Self { bits } + } + + /// Returns the previous bit value. 
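+    ///
+    /// `true` means the bit was already set; `insert_into_set()` relies on this to detect duplicate leaf indices.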
+ fn set(&self, bit_index: usize) -> bool { + let atomic_index = bit_index / Self::BITS_PER_ATOMIC; + let shift_in_atomic = bit_index % Self::BITS_PER_ATOMIC; + let atomic = &self.bits[atomic_index]; + let mask = 1 << (shift_in_atomic as u64); + let prev_value = atomic.fetch_or(mask, Ordering::SeqCst); + prev_value & mask != 0 + } +} + +#[cfg(test)] +mod tests { + use zksync_basic_types::{u256_to_h256, U256}; + + use super::*; + + #[test] + fn prev_and_next_key_work_as_expected() { + let key = H256::zero(); + assert_eq!(prev_key(key), None); + assert_eq!(next_key(key), Some(H256::from_low_u64_be(1))); + + let key = H256::from_low_u64_be(10); + assert_eq!(prev_key(key), Some(H256::from_low_u64_be(9))); + assert_eq!(next_key(key), Some(H256::from_low_u64_be(11))); + + let key = H256::from_low_u64_be((1 << 32) - 1); + assert_eq!(prev_key(key), Some(H256::from_low_u64_be((1 << 32) - 2))); + assert_eq!(next_key(key), Some(H256::from_low_u64_be(1 << 32))); + + let key = H256::from_low_u64_be(1 << 32); + assert_eq!(prev_key(key), Some(H256::from_low_u64_be((1 << 32) - 1))); + assert_eq!(next_key(key), Some(H256::from_low_u64_be((1 << 32) + 1))); + + let key = H256::repeat_byte(0xff); + assert_eq!(prev_key(key), Some(u256_to_h256(U256::MAX - 1))); + assert_eq!(next_key(key), None); + } +} diff --git a/core/lib/zk_os_merkle_tree/src/errors.rs b/core/lib/zk_os_merkle_tree/src/errors.rs new file mode 100644 index 000000000000..032ccf507ae3 --- /dev/null +++ b/core/lib/zk_os_merkle_tree/src/errors.rs @@ -0,0 +1,103 @@ +//! Errors interacting with the Merkle tree. + +use std::{error, fmt, str::Utf8Error}; + +use crate::types::NodeKey; + +#[derive(Debug, thiserror::Error)] +pub(crate) enum DeserializeErrorKind { + #[error("Tree manifest was expected, but is missing")] + MissingManifest, + #[error("Node was expected, but is missing")] + MissingNode, + #[error("Unexpected end of input")] + UnexpectedEof, + #[error("data left after deserialization")] + Leftovers, + /// Error reading a LEB128-encoded value. + #[error("failed reading LEB128-encoded value: {0}")] + Leb128(#[source] leb128::read::Error), + /// Error reading a UTF-8 string. + #[error("failed reading UTF-8 string: {0}")] + Utf8(#[source] Utf8Error), + + /// Missing required tag in the tree manifest. + #[error("missing required tag `{0}` in tree manifest")] + MissingTag(&'static str), + /// Unknown tag in the tree manifest. + #[error("unknown tag `{0}` in tree manifest")] + UnknownTag(String), + /// Malformed tag in the tree manifest. + #[error("malformed tag `{name}` in tree manifest: {err}")] + MalformedTag { + /// Tag name. + name: &'static str, + /// Error that has occurred parsing the tag. + #[source] + err: Box, + }, +} + +#[derive(Debug)] +pub(crate) enum DeserializeContext { + Manifest, + Node(NodeKey), + ChildRef(u8), + LeafCount, + KeyIndex(Box<[u8]>), +} + +impl fmt::Display for DeserializeContext { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Node(key) => write!(formatter, "node at {key}"), + Self::ChildRef(idx) => write!(formatter, "child ref {idx}"), + Self::LeafCount => write!(formatter, "leaf count"), + Self::Manifest => write!(formatter, "manifest"), + Self::KeyIndex(key) => write!(formatter, "key index {key:?}"), + } + } +} + +/// Error that can occur during deserialization. 
+#[derive(Debug)] +pub struct DeserializeError { + kind: DeserializeErrorKind, + contexts: Vec, +} + +impl From for DeserializeError { + fn from(kind: DeserializeErrorKind) -> Self { + Self { + kind, + contexts: vec![], + } + } +} + +impl DeserializeError { + #[must_use] + pub(crate) fn with_context(mut self, context: DeserializeContext) -> Self { + self.contexts.push(context); + self + } +} + +impl fmt::Display for DeserializeError { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + // `self.contexts` are ordered from the most specific one to the most general one + if !self.contexts.is_empty() { + write!(formatter, "[in ")?; + for (i, context) in self.contexts.iter().enumerate() { + write!(formatter, "{context}")?; + if i + 1 < self.contexts.len() { + write!(formatter, ", ")?; + } + } + write!(formatter, "] ")?; + } + write!(formatter, "{}", self.kind) + } +} + +impl std::error::Error for DeserializeError {} diff --git a/core/lib/zk_os_merkle_tree/src/hasher/mod.rs b/core/lib/zk_os_merkle_tree/src/hasher/mod.rs new file mode 100644 index 000000000000..9a11951f738a --- /dev/null +++ b/core/lib/zk_os_merkle_tree/src/hasher/mod.rs @@ -0,0 +1,150 @@ +use std::iter; + +use once_cell::sync::Lazy; +use zksync_basic_types::H256; +use zksync_crypto_primitives::hasher::{blake2::Blake2Hasher, Hasher}; + +pub(crate) use self::nodes::InternalHashes; +pub use self::proofs::{BatchTreeProof, IntermediateHash, TreeOperation}; +use crate::types::{Leaf, MAX_TREE_DEPTH}; + +mod nodes; +mod proofs; + +/// Tree hashing functionality. +pub trait HashTree: Send + Sync { + /// Returns the unique name of the hasher. This is used in Merkle tree tags to ensure + /// that the tree remains consistent. + fn name(&self) -> &'static str; + + /// Hashes a leaf node. + fn hash_leaf(&self, leaf: &Leaf) -> H256; + /// Compresses hashes in an intermediate node of a binary Merkle tree. + fn hash_branch(&self, lhs: &H256, rhs: &H256) -> H256; + + /// Returns the hash of an empty subtree with the given depth. `depth == 0` corresponds to leaves. Implementations + /// are encouraged to cache the returned values. + /// + /// Guaranteed to never be called with `depth > 64` (i.e., exceeding the depth of the entire tree). + fn empty_subtree_hash(&self, depth: u8) -> H256; +} + +impl HashTree for &H { + fn name(&self) -> &'static str { + (**self).name() + } + + fn hash_leaf(&self, leaf: &Leaf) -> H256 { + (**self).hash_leaf(leaf) + } + + fn hash_branch(&self, lhs: &H256, rhs: &H256) -> H256 { + (**self).hash_branch(lhs, rhs) + } + + fn empty_subtree_hash(&self, depth: u8) -> H256 { + (**self).empty_subtree_hash(depth) + } +} + +/// No-op implementation. 
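+///
+/// Can be used in tests and benchmarks (e.g., the `loadtest` example) to exclude hashing costs from measurements.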
+impl HashTree for () { + fn name(&self) -> &'static str { + "no-op" + } + + fn hash_leaf(&self, _leaf: &Leaf) -> H256 { + H256::zero() + } + + fn hash_branch(&self, _lhs: &H256, _rhs: &H256) -> H256 { + H256::zero() + } + + fn empty_subtree_hash(&self, _depth: u8) -> H256 { + H256::zero() + } +} + +impl HashTree for Blake2Hasher { + fn name(&self) -> &'static str { + "Blake2s256" + } + + fn hash_leaf(&self, leaf: &Leaf) -> H256 { + let mut hashed_bytes = [0; 2 * 32 + 2 * 8]; + hashed_bytes[..32].copy_from_slice(leaf.key.as_bytes()); + hashed_bytes[32..64].copy_from_slice(leaf.value.as_bytes()); + hashed_bytes[64..72].copy_from_slice(&leaf.prev_index.to_le_bytes()); + hashed_bytes[72..].copy_from_slice(&leaf.next_index.to_le_bytes()); + self.hash_bytes(&hashed_bytes) + } + + fn hash_branch(&self, lhs: &H256, rhs: &H256) -> H256 { + self.compress(lhs, rhs) + } + + fn empty_subtree_hash(&self, depth: u8) -> H256 { + static EMPTY_TREE_HASHES: Lazy> = Lazy::new(compute_empty_tree_hashes); + EMPTY_TREE_HASHES[usize::from(depth)] + } +} + +fn compute_empty_tree_hashes() -> Vec { + let empty_leaf_hash = Blake2Hasher.hash_bytes(&[0_u8; 2 * 32 + 2 * 8]); + iter::successors(Some(empty_leaf_hash), |hash| { + Some(Blake2Hasher.hash_branch(hash, hash)) + }) + .take(usize::from(MAX_TREE_DEPTH) + 1) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn hashing_leaves_is_correct() { + let expected_empty_leaf_hash: H256 = + "0xc4fde76a8d68422c5fbafde250f492109fb29ac66753292e1153aa11adae1a3a" + .parse() + .unwrap(); + assert_eq!(Blake2Hasher.empty_subtree_hash(0), expected_empty_leaf_hash); + + let expected_level1_empty_hash: H256 = + "0xd53cc61c1aba0c548d73b0131e635e3110434a9c13c65cae08ed7da60ad2858f" + .parse() + .unwrap(); + assert_eq!( + Blake2Hasher.empty_subtree_hash(1), + expected_level1_empty_hash + ); + + let expected_level63_empty_hash: H256 = + "0x59841e10b053bb976a3a159af345e27cc4dbbb1f5424051b6d24f5c56b69e74d" + .parse() + .unwrap(); + assert_eq!( + Blake2Hasher.empty_subtree_hash(63), + expected_level63_empty_hash + ); + + let expected_min_guard_hash: H256 = + "0x4034715b557ca4bc5aef36ae5f28223ab27da4ac291cc63d0835ef2e0eba0c42" + .parse() + .unwrap(); + assert_eq!( + Blake2Hasher.hash_leaf(&Leaf::MIN_GUARD), + expected_min_guard_hash + ); + + let expected_max_guard_hash: H256 = + "0xb30053e4154d49d35b0005e3ee0d4e0fc9fd330aed004c86810b57cf40a28afa" + .parse() + .unwrap(); + assert_eq!( + Blake2Hasher.hash_leaf(&Leaf::MAX_GUARD), + expected_max_guard_hash + ); + } +} diff --git a/core/lib/zk_os_merkle_tree/src/hasher/nodes.rs b/core/lib/zk_os_merkle_tree/src/hasher/nodes.rs new file mode 100644 index 000000000000..7c02c718a44b --- /dev/null +++ b/core/lib/zk_os_merkle_tree/src/hasher/nodes.rs @@ -0,0 +1,190 @@ +//! Hashing for tree nodes. + +use std::{collections::HashMap, iter}; + +use zksync_basic_types::H256; + +use crate::{ + max_nibbles_for_internal_node, + types::{InternalNode, Root}, + HashTree, TreeParams, +}; + +/// Internal hashes for a single `InternalNode`. Ordered by ascending depth `1..internal_node_depth` +/// where `depth == 1` is just above child refs. I.e., the last entry contains 2 hashes (unless the internal node is incomplete), +/// the penultimate one 4 hashes, etc. +/// +/// To access hashes more efficiently, we keep a flat `Vec` and uniform offsets for `(depth, index_on_level)` pairs. +/// The latter requires potential padding for rightmost internal nodes; see [`InternalNode::internal_hashes()`]. 
+/// As a result of these efforts, generating proofs is ~2x more efficient than with layered `Vec>`. +#[derive(Debug)] +struct InternalNodeHashes(Vec); + +impl InternalNode { + pub(crate) fn hash(&self, hasher: &P::Hasher, depth: u8) -> H256 { + self.hash_inner::
<P>
(hasher, depth, true, |_| {}) + } + + fn hash_inner( + &self, + hasher: &P::Hasher, + depth: u8, + hash_last_level: bool, + mut on_level: impl FnMut(&[H256]), + ) -> H256 { + assert!(depth <= max_nibbles_for_internal_node::
<P>
() * P::INTERNAL_NODE_DEPTH); + + let mut hashes: Vec<_> = self.children.iter().map(|child| child.hash).collect(); + let mut level_count = P::INTERNAL_NODE_DEPTH.min(P::TREE_DEPTH - depth); + if !hash_last_level { + level_count = level_count.saturating_sub(1); + } + + for level_offset in 0..level_count { + let new_len = hashes.len().div_ceil(2); + for i in 0..new_len { + hashes[i] = if 2 * i + 1 < hashes.len() { + hasher.hash_branch(&hashes[2 * i], &hashes[2 * i + 1]) + } else { + hasher.hash_branch( + &hashes[2 * i], + &hasher.empty_subtree_hash(depth + level_offset), + ) + }; + } + hashes.truncate(new_len); + on_level(&hashes); + } + + hashes[0] + } + + fn internal_hashes(&self, hasher: &P::Hasher, depth: u8) -> InternalNodeHashes { + // capacity = 2 + 4 + ... + 2 ** (P::INTERNAL_NODE_DEPTH - 1) = 2 * (2 ** (P::INTERNAL_NODE_DEPTH - 1) - 1) = 2 ** P::INTERNAL_NODE_DEPTH - 2 + let capacity = (1 << P::INTERNAL_NODE_DEPTH) - 2; + let mut hashes = InternalNodeHashes(Vec::with_capacity(capacity)); + let mut full_level_len = 1 << (P::INTERNAL_NODE_DEPTH - 1); + self.hash_inner::
<P>
(hasher, depth, false, |level_hashes| { + hashes.0.extend_from_slice(level_hashes); + // Pad if necessary so that there are uniform offsets for each level. The padding should never be read. + // It doesn't waste that much space given that it may be required only for one internal node per level. + hashes + .0 + .extend(iter::repeat(H256::zero()).take(full_level_len - level_hashes.len())); + full_level_len /= 2; + }); + assert_eq!(hashes.0.len(), capacity); + hashes + } +} + +impl Root { + pub(crate) fn hash(&self, hasher: &P::Hasher) -> H256 { + self.root_node.hash::
<P>
( + hasher, + max_nibbles_for_internal_node::
<P>
() * P::INTERNAL_NODE_DEPTH, + ) + } +} + +/// Internal hashes for a level of `InternalNode`s. +#[derive(Debug)] +pub(crate) struct InternalHashes<'a> { + nodes: &'a HashMap, + /// Internal hashes for each node. + // TODO: `Vec<(u64, H256)>` for a level may be more efficient + internal_hashes: HashMap, + // `internal_node_depth` / `level_offsets` are constants w.r.t. `TreeParams`; we keep them as fields + // to avoid making `InternalHashes` parametric. (Also, `level_offsets` cannot be computed in compile time + // on stable Rust at the time.) + internal_node_depth: u8, + level_offsets: Vec, +} + +impl<'a> InternalHashes<'a> { + pub(crate) fn new( + nodes: &'a HashMap, + hasher: &P::Hasher, + depth: u8, + ) -> Self { + use rayon::prelude::*; + + let mut offset = 0; + let mut level_offsets = Vec::with_capacity(usize::from(P::INTERNAL_NODE_DEPTH) - 1); + for depth_in_node in 1..P::INTERNAL_NODE_DEPTH { + level_offsets.push(offset); + offset += 1_usize << (P::INTERNAL_NODE_DEPTH - depth_in_node); + } + + let internal_hashes = nodes + .par_iter() + .map(|(idx, node)| (*idx, node.internal_hashes::
<P>
(hasher, depth))) + .collect(); + + Self { + nodes, + internal_hashes, + internal_node_depth: P::INTERNAL_NODE_DEPTH, + level_offsets, + } + } + + pub(crate) fn get(&self, depth_in_node: u8, index_on_level: u64) -> H256 { + let bit_shift = self.internal_node_depth - depth_in_node; + let node_index = index_on_level >> bit_shift; + let index_in_node = (index_on_level % (1 << bit_shift)) as usize; + + if depth_in_node == 0 { + // Get the hash from a `ChildRef` + self.nodes[&node_index].child_ref(index_in_node).hash + } else { + let overall_idx = self.level_offsets[usize::from(depth_in_node) - 1] + index_in_node; + self.internal_hashes[&node_index].0[overall_idx] + } + } +} + +#[cfg(test)] +mod tests { + use zksync_crypto_primitives::hasher::blake2::Blake2Hasher; + + use super::*; + use crate::DefaultTreeParams; + + #[test] + fn constructing_internal_hashes() { + let nodes = HashMap::from([(0, InternalNode::new(16, 0)), (1, InternalNode::new(7, 0))]); + let internal_hashes = InternalHashes::new::(&nodes, &Blake2Hasher, 0); + + assert_eq!(internal_hashes.level_offsets, [0, 8, 12]); + + for i in 0..(16 + 7) { + assert_eq!(internal_hashes.get(0, i), H256::zero()); + } + assert_eq!(internal_hashes.internal_hashes.len(), 2); + + let expected_hash = Blake2Hasher.hash_branch(&H256::zero(), &H256::zero()); + for i in 0..(8 + 3) { + assert_eq!(internal_hashes.get(1, i), expected_hash); + } + let expected_boundary_hash = + Blake2Hasher.hash_branch(&H256::zero(), &Blake2Hasher.empty_subtree_hash(0)); + assert_eq!(internal_hashes.get(1, 11), expected_boundary_hash); + + let expected_boundary_hash = + Blake2Hasher.hash_branch(&expected_hash, &expected_boundary_hash); + let expected_hash = Blake2Hasher.hash_branch(&expected_hash, &expected_hash); + for i in 0..(4 + 1) { + assert_eq!(internal_hashes.get(2, i), expected_hash); + } + assert_eq!(internal_hashes.get(2, 5), expected_boundary_hash); + + let expected_boundary_hash = + Blake2Hasher.hash_branch(&expected_hash, &expected_boundary_hash); + let expected_hash = Blake2Hasher.hash_branch(&expected_hash, &expected_hash); + for i in 0..2 { + assert_eq!(internal_hashes.get(3, i), expected_hash); + } + assert_eq!(internal_hashes.get(3, 2), expected_boundary_hash); + } +} diff --git a/core/lib/zk_os_merkle_tree/src/hasher/proofs.rs b/core/lib/zk_os_merkle_tree/src/hasher/proofs.rs new file mode 100644 index 000000000000..bd86f30f6e87 --- /dev/null +++ b/core/lib/zk_os_merkle_tree/src/hasher/proofs.rs @@ -0,0 +1,375 @@ +use std::{collections::BTreeMap, iter}; + +use anyhow::Context; +use zksync_basic_types::H256; + +use crate::{types::Leaf, BatchOutput, HashTree, TreeEntry}; + +/// Operation on a Merkle tree entry used in [`BatchTreeProof`]. +#[derive(Debug, Clone, Copy)] +#[cfg_attr(test, derive(PartialEq))] +pub enum TreeOperation { + /// Update of an existing entry. + Update { index: u64 }, + /// Insertion of a new entry. + Insert { + /// Prev index before *batch* insertion (i.e., always points to an index existing before batch insertion). + prev_index: u64, + }, +} + +#[derive(Debug)] +pub struct IntermediateHash { + pub value: H256, + /// Level + index on level. Redundant and is only checked in tests. + #[cfg(test)] + pub location: (u8, u64), +} + +/// Merkle proof of batch insertion into [`MerkleTree`](crate::MerkleTree). +/// +/// # How it's verified +/// +/// Assumes that the tree before insertion is correctly constructed (in particular, leaves are correctly linked via prev / next index). +/// Given that, proof verification is as follows: +/// +/// 1. 
Check that all necessary leaves are present in `sorted_leaves`, and their keys match inserted / updated entries. +/// 2. Previous root hash of the tree is recreated using `sorted_leaves` and `hashes`. +/// 3. `sorted_leaves` are updated / extended as per inserted / updated entries. +/// 4. New root hash of the tree is recreated using updated `sorted_leaves` and (the same) `hashes`. +#[derive(Debug)] +pub struct BatchTreeProof { + /// Performed tree operations. Correspond 1-to-1 to [`TreeEntry`]s. + pub operations: Vec, + /// Sorted leaves from the tree before insertion sufficient to prove it. Contains all updated leaves + /// (incl. prev / next neighbors for the inserted leaves), and the last leaf in the tree if there are inserts. + pub sorted_leaves: BTreeMap, + /// Hashes necessary and sufficient to restore previous and updated root hashes. Provided in the ascending `(depth, index_on_level)` order, + /// where `depth == 0` are leaves, `depth == 1` are nodes aggregating leaf pairs etc. + pub hashes: Vec, +} + +impl BatchTreeProof { + #[cfg(test)] + fn empty() -> Self { + Self { + operations: vec![], + sorted_leaves: BTreeMap::new(), + hashes: vec![], + } + } + + /// Returns the new root hash of the tree on success. + pub fn verify( + mut self, + hasher: &dyn HashTree, + tree_depth: u8, + prev_output: Option, + entries: &[TreeEntry], + ) -> anyhow::Result { + let Some(prev_output) = prev_output else { + return self.verify_for_empty_tree(hasher, tree_depth, entries); + }; + + anyhow::ensure!( + self.operations.len() == entries.len(), + "Unexpected operations length" + ); + if let Some((max_idx, _)) = self.sorted_leaves.iter().next_back() { + anyhow::ensure!(*max_idx < prev_output.leaf_count, "Index is too large"); + } + + let mut index_by_key: BTreeMap<_, _> = self + .sorted_leaves + .iter() + .map(|(idx, leaf)| (leaf.key, *idx)) + .collect(); + + let mut next_tree_index = prev_output.leaf_count; + for (&operation, entry) in self.operations.iter().zip(entries) { + match operation { + TreeOperation::Update { index } => { + anyhow::ensure!( + index < prev_output.leaf_count, + "Updated non-existing index {index}" + ); + let existing_leaf = self + .sorted_leaves + .get(&index) + .with_context(|| format!("Update for index {index} is not proven"))?; + anyhow::ensure!( + existing_leaf.key == entry.key, + "Update for index {index} has unexpected key" + ); + } + TreeOperation::Insert { prev_index } => { + let prev_leaf = self.sorted_leaves.get(&prev_index).with_context(|| { + format!("prev leaf {prev_index} for {entry:?} is not proven") + })?; + anyhow::ensure!(prev_leaf.key < entry.key); + + let old_next_index = prev_leaf.next_index; + let old_next_leaf = + self.sorted_leaves.get(&old_next_index).with_context(|| { + format!("old next leaf {old_next_index} for {entry:?} is not proven") + })?; + anyhow::ensure!(old_next_leaf.prev_index == prev_index); + anyhow::ensure!(entry.key < old_next_leaf.key); + + index_by_key.insert(entry.key, next_tree_index); + next_tree_index += 1; + } + }; + } + + let restored_prev_hash = Self::zip_leaves( + hasher, + tree_depth, + prev_output.leaf_count, + self.sorted_leaves.iter().map(|(idx, leaf)| (*idx, leaf)), + self.hashes.iter(), + )?; + anyhow::ensure!( + restored_prev_hash == prev_output.root_hash, + "Mismatch for previous root hash: prev_output={prev_output:?}, restored={restored_prev_hash:?}" + ); + + // Expand `leaves` with the newly inserted leaves and update the existing leaves. 
+ for (&operation, entry) in self.operations.iter().zip(entries) { + match operation { + TreeOperation::Update { index } => { + // We've checked the key correspondence already. + self.sorted_leaves.get_mut(&index).unwrap().value = entry.value; + } + TreeOperation::Insert { .. } => { + let mut it = index_by_key.range(entry.key..); + let (_, &this_index) = it.next().unwrap(); + // `unwrap()`s below are safe: at least the pre-existing next index is greater, and the pre-existing prev index is lesser. + let (_, &next_index) = it.next().unwrap(); + let (_, &prev_index) = index_by_key.range(..entry.key).next_back().unwrap(); + + self.sorted_leaves.insert( + this_index, + Leaf { + key: entry.key, + value: entry.value, + prev_index, + next_index, + }, + ); + + // Prev / next leaves may be missing if they are inserted in the batch as well; + // in this case, prev / next index will be set correctly once the leaf is created. + if let Some(prev_leaf) = self.sorted_leaves.get_mut(&prev_index) { + prev_leaf.next_index = this_index; + } + if let Some(next_leaf) = self.sorted_leaves.get_mut(&next_index) { + next_leaf.prev_index = this_index; + } + } + } + } + + Self::zip_leaves( + hasher, + tree_depth, + next_tree_index, + self.sorted_leaves.iter().map(|(idx, leaf)| (*idx, leaf)), + self.hashes.iter(), + ) + } + + fn verify_for_empty_tree( + self, + hasher: &dyn HashTree, + tree_depth: u8, + entries: &[TreeEntry], + ) -> anyhow::Result { + // The proof must be entirely empty since we can get all data from `entries`. + anyhow::ensure!(self.sorted_leaves.is_empty()); + anyhow::ensure!(self.operations.is_empty()); + anyhow::ensure!(self.hashes.is_empty()); + + let index_by_key: BTreeMap<_, _> = entries + .iter() + .enumerate() + .map(|(i, entry)| (entry.key, i as u64 + 2)) + .collect(); + anyhow::ensure!( + index_by_key.len() == entries.len(), + "There are entries with duplicate keys" + ); + + let mut min_leaf_index = 1; + let mut max_leaf_index = 0; + let sorted_leaves = entries.iter().enumerate().map(|(i, entry)| { + let this_index = i as u64 + 2; + + // The key itself is guaranteed to be the first yielded item, hence `skip(1)`. 
+ let mut it = index_by_key.range(entry.key..).skip(1); + let next_index = it.next().map(|(_, idx)| *idx).unwrap_or_else(|| { + max_leaf_index = this_index; + 1 + }); + let prev_index = index_by_key + .range(..entry.key) + .map(|(_, idx)| *idx) + .next_back() + .unwrap_or_else(|| { + min_leaf_index = this_index; + 0 + }); + + Leaf { + key: entry.key, + value: entry.value, + prev_index, + next_index, + } + }); + let sorted_leaves: Vec<_> = sorted_leaves.collect(); + + let min_guard = Leaf { + next_index: min_leaf_index, + ..Leaf::MIN_GUARD + }; + let max_guard = Leaf { + prev_index: max_leaf_index, + ..Leaf::MAX_GUARD + }; + let leaves_with_guards = [(0, &min_guard), (1, &max_guard)] + .into_iter() + .chain((2..).zip(&sorted_leaves)); + + Self::zip_leaves( + hasher, + tree_depth, + 2 + entries.len() as u64, + leaves_with_guards, + iter::empty(), + ) + } + + fn zip_leaves<'a>( + hasher: &dyn HashTree, + tree_depth: u8, + leaf_count: u64, + sorted_leaves: impl Iterator, + mut hashes: impl Iterator, + ) -> anyhow::Result { + let mut node_hashes: Vec<_> = sorted_leaves + .map(|(idx, leaf)| (idx, hasher.hash_leaf(leaf))) + .collect(); + let mut last_idx_on_level = leaf_count - 1; + + for depth in 0..tree_depth { + let mut i = 0; + let mut next_level_i = 0; + while i < node_hashes.len() { + let (current_idx, current_hash) = node_hashes[i]; + let next_level_hash = if current_idx % 2 == 1 { + // The hash to the left is missing; get it from `hashes` + i += 1; + let lhs = hashes.next().context("ran out of hashes")?; + #[cfg(test)] + assert_eq!(lhs.location, (depth, current_idx - 1)); + + hasher.hash_branch(&lhs.value, ¤t_hash) + } else if let Some((_, next_hash)) = node_hashes + .get(i + 1) + .filter(|(next_idx, _)| *next_idx == current_idx + 1) + { + i += 2; + hasher.hash_branch(¤t_hash, next_hash) + } else { + // The hash to the right is missing; get it from `hashes`, or set to the empty subtree hash if appropriate. 
+ i += 1; + let rhs = if current_idx == last_idx_on_level { + hasher.empty_subtree_hash(depth) + } else { + let rhs = hashes.next().context("ran out of hashes")?; + #[cfg(test)] + assert_eq!(rhs.location, (depth, current_idx + 1)); + rhs.value + }; + hasher.hash_branch(¤t_hash, &rhs) + }; + + node_hashes[next_level_i] = (current_idx / 2, next_level_hash); + next_level_i += 1; + } + node_hashes.truncate(next_level_i); + last_idx_on_level /= 2; + } + + anyhow::ensure!(hashes.next().is_none(), "not all hashes consumed"); + + Ok(node_hashes[0].1) + } +} + +#[cfg(test)] +mod tests { + use zksync_crypto_primitives::hasher::blake2::Blake2Hasher; + + use super::*; + + #[test] + fn insertion_proof_for_empty_tree() { + let proof = BatchTreeProof::empty(); + let hash = proof.verify(&Blake2Hasher, 64, None, &[]).unwrap(); + assert_eq!( + hash, + "0x8a41011d351813c31088367deecc9b70677ecf15ffc24ee450045cdeaf447f63" + .parse() + .unwrap() + ); + + let proof = BatchTreeProof::empty(); + let entry = TreeEntry { + key: H256::repeat_byte(0x01), + value: H256::repeat_byte(0x10), + }; + let hash = proof.verify(&Blake2Hasher, 64, None, &[entry]).unwrap(); + assert_eq!( + hash, + "0x91a1688c802dc607125d0b5e5ab4d95d89a4a4fb8cca71a122db6076cb70f8f3" + .parse() + .unwrap() + ); + } + + #[test] + fn basic_insertion_proof() { + let proof = BatchTreeProof { + operations: vec![TreeOperation::Insert { prev_index: 0 }], + sorted_leaves: BTreeMap::from([(0, Leaf::MIN_GUARD), (1, Leaf::MAX_GUARD)]), + hashes: vec![], + }; + + let empty_tree_output = BatchOutput { + leaf_count: 2, + root_hash: "0x8a41011d351813c31088367deecc9b70677ecf15ffc24ee450045cdeaf447f63" + .parse() + .unwrap(), + }; + let new_tree_hash = proof + .verify( + &Blake2Hasher, + 64, + Some(empty_tree_output), + &[TreeEntry { + key: H256::repeat_byte(0x01), + value: H256::repeat_byte(0x10), + }], + ) + .unwrap(); + + assert_eq!( + new_tree_hash, + "0x91a1688c802dc607125d0b5e5ab4d95d89a4a4fb8cca71a122db6076cb70f8f3" + .parse() + .unwrap() + ); + } +} diff --git a/core/lib/zk_os_merkle_tree/src/lib.rs b/core/lib/zk_os_merkle_tree/src/lib.rs new file mode 100644 index 000000000000..d1bd7c7e8ccd --- /dev/null +++ b/core/lib/zk_os_merkle_tree/src/lib.rs @@ -0,0 +1,224 @@ +//! Persistent ZK OS Merkle tree. + +use std::{fmt, time::Instant}; + +use anyhow::Context as _; +use zksync_basic_types::H256; +pub use zksync_crypto_primitives::hasher::blake2::Blake2Hasher; + +pub use self::{ + errors::DeserializeError, + hasher::HashTree, + storage::{Database, MerkleTreeColumnFamily, PatchSet, RocksDBWrapper}, + types::{BatchOutput, TreeEntry}, +}; +use crate::{ + hasher::BatchTreeProof, + storage::{TreeUpdate, WorkingPatchSet}, + types::MAX_TREE_DEPTH, +}; + +mod consistency; +mod errors; +mod hasher; +mod storage; +#[cfg(test)] +mod tests; +mod types; + +/// Marker trait for tree parameters. 
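+///
+/// An implementation fixes the hasher type, the overall tree depth, and the number of levels covered by a single
+/// stored internal node (`INTERNAL_NODE_DEPTH`; its power of 2 is the amortization radix described in the README).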
+pub trait TreeParams: fmt::Debug + Send + Sync { + type Hasher: HashTree; + + const TREE_DEPTH: u8; + const INTERNAL_NODE_DEPTH: u8; +} + +#[inline(always)] +pub(crate) fn leaf_nibbles() -> u8 { + P::TREE_DEPTH.div_ceil(P::INTERNAL_NODE_DEPTH) +} + +#[inline(always)] +pub(crate) const fn max_nibbles_for_internal_node() -> u8 { + P::TREE_DEPTH.div_ceil(P::INTERNAL_NODE_DEPTH) - 1 +} + +#[inline(always)] +pub(crate) const fn max_node_children() -> u8 { + 1 << P::INTERNAL_NODE_DEPTH +} + +// TODO: internal node depth 3 looks slightly better from the I/O overhead & performance perspective +#[derive(Debug)] +pub struct DefaultTreeParams(()); + +impl TreeParams + for DefaultTreeParams +{ + type Hasher = Blake2Hasher; + const TREE_DEPTH: u8 = { + assert!(TREE_DEPTH > 0 && TREE_DEPTH <= MAX_TREE_DEPTH); + TREE_DEPTH + }; + const INTERNAL_NODE_DEPTH: u8 = { + assert!(INTERNAL_NODE_DEPTH > 0 && INTERNAL_NODE_DEPTH < 8); // to fit child count into `u8` + INTERNAL_NODE_DEPTH + }; +} + +#[derive(Debug)] +pub struct MerkleTree { + db: DB, + hasher: P::Hasher, +} + +impl MerkleTree { + /// Loads a tree with the default Blake2 hasher. + /// + /// # Errors + /// + /// Errors in the same situations as [`Self::with_hasher()`]. + pub fn new(db: DB) -> anyhow::Result { + Self::with_hasher(db, Blake2Hasher) + } +} + +impl MerkleTree { + /// Loads a tree with the specified hasher. + /// + /// # Errors + /// + /// Errors if the hasher or basic tree parameters (e.g., the tree depth) + /// do not match those of the tree loaded from the database. + pub fn with_hasher(db: DB, hasher: P::Hasher) -> anyhow::Result { + let maybe_manifest = db.try_manifest().context("failed reading tree manifest")?; + if let Some(manifest) = maybe_manifest { + manifest.tags.ensure_consistency::
<P>
(&hasher)?; + } + Ok(Self { db, hasher }) + } + + /// Returns the root hash of a tree at the specified `version`, or `None` if the version + /// was not written yet. + pub fn root_hash(&self, version: u64) -> anyhow::Result> { + let Some(root) = self.db.try_root(version)? else { + return Ok(None); + }; + Ok(Some(root.hash::
<P>
(&self.hasher))) + } + + /// Returns the latest version of the tree present in the database, or `None` if + /// no versions are present yet. + pub fn latest_version(&self) -> anyhow::Result> { + let Some(manifest) = self.db.try_manifest()? else { + return Ok(None); + }; + Ok(manifest.version_count.checked_sub(1)) + } + + pub fn latest_root_hash(&self) -> anyhow::Result> { + let Some(version) = self + .latest_version() + .context("failed getting latest version")? + else { + return Ok(None); + }; + self.root_hash(version) + } + + /// Extends this tree by creating its new version. + /// + /// All keys in the provided entries must be distinct. + /// + /// # Return value + /// + /// Returns information about the update such as the final tree hash. + /// + /// # Errors + /// + /// Proxies database I/O errors. + pub fn extend(&mut self, entries: &[TreeEntry]) -> anyhow::Result { + let (output, _) = self.extend_inner(entries, false)?; + Ok(output) + } + + #[tracing::instrument(level = "debug", name = "extend", skip_all, fields(latest_version))] + fn extend_inner( + &mut self, + entries: &[TreeEntry], + with_proof: bool, + ) -> anyhow::Result<(BatchOutput, Option)> { + let latest_version = self + .latest_version() + .context("failed getting latest version")?; + tracing::Span::current().record("latest_version", latest_version); + + let started_at = Instant::now(); + let (mut patch, mut update) = if let Some(version) = latest_version { + self.create_patch(version, entries) + .context("failed loading tree data")? + } else { + ( + WorkingPatchSet::
<P>
::empty(), + TreeUpdate::for_empty_tree(entries)?, + ) + }; + tracing::debug!( + elapsed = ?started_at.elapsed(), + inserts = update.inserts.len(), + updates = update.updates.len(), + loaded_internal_nodes = patch.total_internal_nodes(), + "loaded tree data" + ); + + let proof = if with_proof { + let started_at = Instant::now(); + let proof = patch.create_batch_proof(&self.hasher, update.take_operations()); + tracing::debug!( + elapsed = ?started_at.elapsed(), + proof.leaves.len = proof.sorted_leaves.len(), + proof.hashes.len = proof.hashes.len(), + "created batch proof" + ); + Some(proof) + } else { + None + }; + + let started_at = Instant::now(); + let update = patch.update(update); + tracing::debug!(elapsed = ?started_at.elapsed(), "updated tree structure"); + + let started_at = Instant::now(); + let (patch, output) = patch.finalize(&self.hasher, update); + tracing::debug!(elapsed = ?started_at.elapsed(), "hashed tree"); + + let started_at = Instant::now(); + self.db + .apply_patch(patch) + .context("failed persisting tree changes")?; + tracing::debug!(elapsed = ?started_at.elapsed(), "persisted tree"); + Ok((output, proof)) + } + + pub fn extend_with_proof( + &mut self, + entries: &[TreeEntry], + ) -> anyhow::Result<(BatchOutput, BatchTreeProof)> { + let (output, proof) = self.extend_inner(entries, true)?; + Ok((output, proof.unwrap())) + } + + pub fn truncate_recent_versions(&mut self, retained_version_count: u64) -> anyhow::Result<()> { + let mut manifest = self.db.try_manifest()?.unwrap_or_default(); + let current_version_count = manifest.version_count; + if current_version_count > retained_version_count { + // TODO: It is necessary to remove "future" stale keys since otherwise they may be used in future pruning and lead + // to non-obsolete tree nodes getting removed. + manifest.version_count = retained_version_count; + self.db.truncate(manifest, ..current_version_count)?; + } + Ok(()) + } +} diff --git a/core/lib/zk_os_merkle_tree/src/storage/mod.rs b/core/lib/zk_os_merkle_tree/src/storage/mod.rs new file mode 100644 index 000000000000..02ec78d3f2a6 --- /dev/null +++ b/core/lib/zk_os_merkle_tree/src/storage/mod.rs @@ -0,0 +1,242 @@ +use std::{ + cmp, + collections::{BTreeMap, HashMap}, + ops, +}; + +use zksync_basic_types::H256; + +pub(crate) use self::patch::{TreeUpdate, WorkingPatchSet}; +pub use self::rocksdb::{MerkleTreeColumnFamily, RocksDBWrapper}; +use crate::{ + errors::{DeserializeContext, DeserializeError, DeserializeErrorKind}, + types::{InternalNode, KeyLookup, Leaf, Manifest, Node, NodeKey, Root}, +}; + +mod patch; +mod rocksdb; +mod serialization; +#[cfg(test)] +mod tests; + +/// Generic database functionality. Its main implementation is [`RocksDB`]. +pub trait Database: Send + Sync { + fn indices(&self, version: u64, keys: &[H256]) -> Result, DeserializeError>; + + fn try_manifest(&self) -> Result, DeserializeError>; + + /// Tries to obtain a root from this storage. + /// + /// # Errors + /// + /// Returns a deserialization error if any. + fn try_root(&self, version: u64) -> Result, DeserializeError>; + + /// Obtains nodes with the specified key from the storage. Root nodes must be obtained + /// using [`Self::root()`], never this method. + /// + /// # Errors + /// + /// Returns a deserialization error if any. + fn try_nodes(&self, keys: &[NodeKey]) -> Result, DeserializeError>; + + /// Applies changes in the `patch` to this database. This operation should be atomic. + /// + /// # Errors + /// + /// Returns I/O errors. 
+ fn apply_patch(&mut self, patch: PatchSet) -> anyhow::Result<()>; + + /// Truncates the tree. `manifest` specifies the new number of tree versions, and `truncated_versions` + /// contains the last version *before* the truncation. This operation should be atomic. + fn truncate( + &mut self, + manifest: Manifest, + truncated_versions: ops::RangeTo, + ) -> anyhow::Result<()>; +} + +impl Database for &mut DB { + fn indices(&self, version: u64, keys: &[H256]) -> Result, DeserializeError> { + (**self).indices(version, keys) + } + + fn try_manifest(&self) -> Result, DeserializeError> { + (**self).try_manifest() + } + + fn try_root(&self, version: u64) -> Result, DeserializeError> { + (**self).try_root(version) + } + + fn try_nodes(&self, keys: &[NodeKey]) -> Result, DeserializeError> { + (**self).try_nodes(keys) + } + + fn apply_patch(&mut self, patch: PatchSet) -> anyhow::Result<()> { + (**self).apply_patch(patch) + } + + fn truncate( + &mut self, + manifest: Manifest, + truncated_versions: ops::RangeTo, + ) -> anyhow::Result<()> { + (**self).truncate(manifest, truncated_versions) + } +} + +#[derive(Debug, Clone, Copy)] +#[cfg_attr(test, derive(PartialEq))] +struct InsertedKeyEntry { + index: u64, + inserted_at: u64, +} + +#[derive(Debug, Clone)] +struct PartialPatchSet { + leaf_count: u64, + // TODO: maybe, a wrapper around `Vec<(_, _)>` would be more efficient? + /// `internal[0]` corresponds to a root, `internal[1]` to single-nibble nodes etc. + internal: Vec>, + /// Sorted by the index. + leaves: HashMap, +} + +impl PartialPatchSet { + fn root(&self) -> Root { + Root { + leaf_count: self.leaf_count, + root_node: self.internal[0][&0].clone(), + } + } + + fn node(&self, nibble_count: u8, index_on_level: u64) -> Option { + let nibble_count = usize::from(nibble_count); + let leaf_nibbles = self.internal.len(); + Some(match nibble_count.cmp(&leaf_nibbles) { + cmp::Ordering::Less => { + let level = &self.internal[nibble_count]; + level.get(&index_on_level)?.clone().into() + } + cmp::Ordering::Equal => (*self.leaves.get(&index_on_level)?).into(), + cmp::Ordering::Greater => return None, + }) + } + + fn total_internal_nodes(&self) -> usize { + self.internal.iter().map(HashMap::len).sum() + } +} + +/// Immutable in-memory changeset that can atomically applied to a [`Database`]. +#[derive(Debug, Clone, Default)] +pub struct PatchSet { + manifest: Manifest, + patches_by_version: HashMap, + // We maintain a joint index for all versions to make it easier to use `PatchSet` as a `Database` or in a `Patched` wrapper. + sorted_new_leaves: BTreeMap, + // TODO: stale keys +} + +impl PatchSet { + fn index(&self, version: u64, key: &H256) -> KeyLookup { + let (next_key, next_entry) = self + .sorted_new_leaves + .range(key..) 
+ .find(|(_, entry)| entry.inserted_at <= version) + .expect("guards must be inserted into a tree on initialization"); + if next_key == key { + return KeyLookup::Existing(next_entry.index); + } + + let (prev_key, prev_entry) = self + .sorted_new_leaves + .range(..key) + .rev() + .find(|(_, entry)| entry.inserted_at <= version) + .expect("guards must be inserted into a tree on initialization"); + KeyLookup::Missing { + prev_key_and_index: (*prev_key, prev_entry.index), + next_key_and_index: (*next_key, next_entry.index), + } + } + + #[cfg(test)] + pub(crate) fn manifest_mut(&mut self) -> &mut Manifest { + &mut self.manifest + } +} + +impl Database for PatchSet { + fn indices(&self, version: u64, keys: &[H256]) -> Result, DeserializeError> { + use rayon::prelude::*; + + let mut lookup = vec![]; + keys.par_iter() + .map(|key| self.index(version, key)) + .collect_into_vec(&mut lookup); + Ok(lookup) + } + + fn try_manifest(&self) -> Result, DeserializeError> { + // We consider manifest absent if there are no tree versions. This is important for tree tag checks on the empty tree. + Ok(if self.manifest.version_count == 0 { + None + } else { + Some(self.manifest.clone()) + }) + } + + fn try_root(&self, version: u64) -> Result, DeserializeError> { + Ok(self + .patches_by_version + .get(&version) + .map(PartialPatchSet::root)) + } + + fn try_nodes(&self, keys: &[NodeKey]) -> Result, DeserializeError> { + let nodes = keys.iter().map(|key| { + assert!(key.nibble_count > 0); + + let maybe_node = self + .patches_by_version + .get(&key.version) + .and_then(|patch| patch.node(key.nibble_count, key.index_on_level)); + maybe_node.ok_or_else(|| { + DeserializeError::from(DeserializeErrorKind::MissingNode) + .with_context(DeserializeContext::Node(*key)) + }) + }); + nodes.collect() + } + + fn apply_patch(&mut self, patch: PatchSet) -> anyhow::Result<()> { + let new_version_count = patch.manifest.version_count; + anyhow::ensure!( + new_version_count >= self.manifest.version_count, + "Use `truncate()` for truncating tree" + ); + + self.manifest = patch.manifest; + self.patches_by_version.extend(patch.patches_by_version); + self.sorted_new_leaves.extend(patch.sorted_new_leaves); + Ok(()) + } + + fn truncate( + &mut self, + manifest: Manifest, + _truncated_versions: ops::RangeTo, + ) -> anyhow::Result<()> { + let new_version_count = manifest.version_count; + self.patches_by_version + .retain(|&version, _| version < new_version_count); + // This requires a full scan, but we assume there aren't that many data in a patch (it's mostly used as a `Database` for testing). 
+ self.sorted_new_leaves + .retain(|_, entry| entry.inserted_at < new_version_count); + + self.manifest = manifest; + Ok(()) + } +} diff --git a/core/lib/zk_os_merkle_tree/src/storage/patch.rs b/core/lib/zk_os_merkle_tree/src/storage/patch.rs new file mode 100644 index 000000000000..270fdfb3cb49 --- /dev/null +++ b/core/lib/zk_os_merkle_tree/src/storage/patch.rs @@ -0,0 +1,607 @@ +use std::{ + collections::{BTreeMap, BTreeSet, HashMap}, + marker::PhantomData, + mem, ops, + time::{Duration, Instant}, +}; + +use anyhow::Context as _; +use zksync_basic_types::H256; + +use super::{Database, InsertedKeyEntry, PartialPatchSet, PatchSet}; +use crate::{ + errors::{DeserializeContext, DeserializeErrorKind}, + hasher::{BatchTreeProof, IntermediateHash, InternalHashes, TreeOperation}, + leaf_nibbles, max_nibbles_for_internal_node, max_node_children, + types::{InternalNode, KeyLookup, Leaf, Manifest, Node, NodeKey, Root, TreeTags}, + BatchOutput, DeserializeError, HashTree, MerkleTree, TreeEntry, TreeParams, +}; + +/// Information about an atomic tree update. +#[must_use = "Should be applied to a `PartialPatchSet`"] +#[derive(Debug)] +pub(crate) struct TreeUpdate { + pub(super) version: u64, + pub(super) sorted_new_leaves: BTreeMap, + pub(crate) updates: Vec<(u64, H256)>, + pub(crate) inserts: Vec, + operations: Vec, +} + +impl TreeUpdate { + pub(crate) fn for_empty_tree(entries: &[TreeEntry]) -> anyhow::Result { + let mut sorted_new_leaves = BTreeMap::from([ + ( + H256::zero(), + InsertedKeyEntry { + index: 0, + inserted_at: 0, + }, + ), + ( + H256::repeat_byte(0xff), + InsertedKeyEntry { + index: 1, + inserted_at: 0, + }, + ), + ]); + sorted_new_leaves.extend(entries.iter().enumerate().map(|(i, entry)| { + ( + entry.key, + InsertedKeyEntry { + index: i as u64 + 2, + inserted_at: 0, + }, + ) + })); + + anyhow::ensure!( + sorted_new_leaves.len() == entries.len() + 2, + "Attempting to insert duplicate keys into a tree; please deduplicate keys on the caller side" + ); + + let mut inserts = Vec::with_capacity(entries.len() + 2); + for entry in [&TreeEntry::MIN_GUARD, &TreeEntry::MAX_GUARD] + .into_iter() + .chain(entries) + { + let prev_index = match sorted_new_leaves.range(..entry.key).next_back() { + Some((_, prev_entry)) => prev_entry.index, + None => { + assert_eq!(entry.key, H256::zero()); + 0 + } + }; + + let next_range = (ops::Bound::Excluded(entry.key), ops::Bound::Unbounded); + let next_index = match sorted_new_leaves.range(next_range).next() { + Some((_, next_entry)) => next_entry.index, + None => { + assert_eq!(entry.key, H256::repeat_byte(0xff)); + 1 + } + }; + + inserts.push(Leaf { + key: entry.key, + value: entry.value, + prev_index, + next_index, + }); + } + + Ok(Self { + version: 0, + sorted_new_leaves, + updates: vec![], + inserts, + // Since the tree is empty, we expect a completely empty `BatchTreeProof` incl. `operations`; see its validation logic. + // This makes the proof occupy less space and doesn't require to specify bogus indices for `TreeOperation::Insert`. + operations: vec![], + }) + } + + pub(crate) fn take_operations(&mut self) -> Vec { + mem::take(&mut self.operations) + } +} + +#[must_use = "Should be finalized with a `PartialPatchSet`"] +#[derive(Debug)] +pub(crate) struct FinalTreeUpdate { + pub(super) version: u64, + pub(super) sorted_new_leaves: BTreeMap, +} + +impl PartialPatchSet { + /// Updates ancestor's `ChildRef` version for all loaded internal nodes. 
This should be called before adding new leaves + /// to the tree; it works because the loaded leaves are exactly the leaves for which ancestor versions must be updated. + fn update_ancestor_versions(&mut self, version: u64) { + let mut indices: Vec<_> = self.leaves.keys().copied().collect(); + indices.sort_unstable(); + + for internal_level in self.internal.iter_mut().rev() { + let mut prev_index = None; + indices = indices + .into_iter() + .filter_map(|idx| { + let parent_idx = idx >> P::INTERNAL_NODE_DEPTH; + let parent = internal_level.get_mut(&parent_idx).unwrap(); + parent + .child_mut((idx % u64::from(max_node_children::
<P>
())) as usize) + .version = version; + + if prev_index == Some(parent_idx) { + None + } else { + prev_index = Some(parent_idx); + Some(parent_idx) + } + }) + .collect(); + } + } +} + +#[derive(Debug)] +pub(crate) struct WorkingPatchSet
<P: TreeParams>
{ + inner: PartialPatchSet, + _params: PhantomData
<P>
, +} + +impl<P: TreeParams> WorkingPatchSet<P>
{ + pub(crate) fn empty() -> Self { + Self::new(Root { + leaf_count: 0, + root_node: InternalNode::empty(), + }) + } + + fn new(root: Root) -> Self { + let mut internal = vec![HashMap::new(); max_nibbles_for_internal_node::
<P>
() as usize + 1]; + internal[0].insert(0, root.root_node); + + Self { + inner: PartialPatchSet { + leaf_count: root.leaf_count, + internal, + leaves: HashMap::new(), + }, + _params: PhantomData, + } + } + + #[cfg(test)] + pub(super) fn inner(&self) -> &PartialPatchSet { + &self.inner + } + + /// `leaf_indices` must be sorted. + fn load_nodes( + &mut self, + db: &impl Database, + leaf_indices: impl Iterator + Clone, + ) -> anyhow::Result<()> { + let this = &mut self.inner; + for nibble_count in 1..=leaf_nibbles::
<P>
() { + let bit_shift = (leaf_nibbles::
<P>
() - nibble_count) * P::INTERNAL_NODE_DEPTH; + + let mut prev_index_on_level = None; + let parent_level = &this.internal[usize::from(nibble_count) - 1]; + + let requested_keys = leaf_indices.clone().filter_map(|idx| { + let index_on_level = idx >> bit_shift; + if prev_index_on_level == Some(index_on_level) { + None + } else { + prev_index_on_level = Some(index_on_level); + let parent_idx = index_on_level >> P::INTERNAL_NODE_DEPTH; + let parent = &parent_level[&parent_idx]; + let child_idx = index_on_level % u64::from(max_node_children::
<P>
()); + let this_ref = parent.child_ref(child_idx as usize); + let requested_key = NodeKey { + version: this_ref.version, + nibble_count, + index_on_level, + }; + Some((index_on_level, requested_key)) + } + }); + let (indices, requested_keys): (Vec<_>, Vec<_>) = requested_keys.unzip(); + let loaded_nodes = db.try_nodes(&requested_keys)?; + + if nibble_count == leaf_nibbles::
<P>
() { + this.leaves = loaded_nodes + .into_iter() + .zip(indices) + .map(|(node, idx)| { + ( + idx, + match node { + Node::Leaf(leaf) => leaf, + Node::Internal(_) => unreachable!(), + }, + ) + }) + .collect(); + } else { + this.internal[usize::from(nibble_count)] = loaded_nodes + .into_iter() + .zip(indices) + .map(|(node, idx)| { + ( + idx, + match node { + Node::Internal(node) => node, + Node::Leaf(_) => unreachable!(), + }, + ) + }) + .collect(); + } + } + + Ok(()) + } + + pub(crate) fn total_internal_nodes(&self) -> usize { + self.inner.total_internal_nodes() + } + + pub(crate) fn create_batch_proof( + &self, + hasher: &P::Hasher, + operations: Vec, + ) -> BatchTreeProof { + let sorted_leaves: BTreeMap<_, _> = self + .inner + .leaves + .iter() + .map(|(idx, leaf)| (*idx, *leaf)) + .collect(); + BatchTreeProof { + operations, + hashes: self.collect_hashes(sorted_leaves.keys().copied().collect(), hasher), + sorted_leaves, + } + } + + /// Provides necessary and sufficient hashes for a [`BatchTreeProof`]. Should be called before any modifying operations; + /// by design, leaves loaded for a batch update are exactly leaves included into a `BatchTreeProof`. + /// + /// `leaf_indices` is the sorted list of all loaded leaves. + fn collect_hashes(&self, leaf_indices: Vec, hasher: &P::Hasher) -> Vec { + let mut indices_on_level = leaf_indices; + if indices_on_level.is_empty() { + return vec![]; + } + + let this = &self.inner; + let mut hashes = vec![]; + // Should not underflow because `indices_on_level` is non-empty. + let mut last_idx_on_level = this.leaf_count - 1; + + let mut internal_hashes = None; + let mut internal_node_levels = this.internal.iter().rev(); + let mut hash_latency = Duration::ZERO; + let mut traverse_latency = Duration::ZERO; + + // The logic below essentially repeats `BatchTreeProof::zip_leaves()`, only instead of taking provided hashes, + // they are put in `hashes`. + for depth in 0..P::TREE_DEPTH { + let depth_in_internal_node = depth % P::INTERNAL_NODE_DEPTH; + if depth_in_internal_node == 0 { + // Initialize / update `internal_hashes`. Computing *all* internal hashes may be somewhat inefficient, + // but since it's parallelized, it doesn't look like a major concern. + let level = internal_node_levels + .next() + .expect("run out of internal node levels"); + let started_at = Instant::now(); + internal_hashes = Some(InternalHashes::new::
<P>
(level, hasher, depth)); + hash_latency += started_at.elapsed(); + } + + let started_at = Instant::now(); + // `unwrap()` is safe; `internal_hashes` is initialized on the first loop iteration + let internal_hashes = internal_hashes.as_ref().unwrap(); + + let mut i = 0; + let mut next_level_i = 0; + while i < indices_on_level.len() { + let current_idx = indices_on_level[i]; + if current_idx % 2 == 1 { + // The hash to the left is missing; get it from `hashes` + i += 1; + hashes.push(IntermediateHash { + value: internal_hashes.get(depth_in_internal_node, current_idx - 1), + #[cfg(test)] + location: (depth, current_idx - 1), + }); + } else if indices_on_level + .get(i + 1) + .map_or(false, |&next_idx| next_idx == current_idx + 1) + { + i += 2; + // Don't get the hash, it'll be available locally. + } else { + // The hash to the right is missing; get it from `hashes`, or set to the empty subtree hash. + i += 1; + if current_idx < last_idx_on_level { + hashes.push(IntermediateHash { + value: internal_hashes.get(depth_in_internal_node, current_idx + 1), + #[cfg(test)] + location: (depth, current_idx + 1), + }); + } + }; + + indices_on_level[next_level_i] = current_idx / 2; + next_level_i += 1; + } + indices_on_level.truncate(next_level_i); + last_idx_on_level /= 2; + traverse_latency += started_at.elapsed(); + } + + tracing::debug!( + ?hash_latency, + ?traverse_latency, + "collected hashes for batch proof" + ); + hashes + } + + pub(crate) fn update(&mut self, update: TreeUpdate) -> FinalTreeUpdate { + let this = &mut self.inner; + let version = update.version; + this.update_ancestor_versions::
<P>
(version); + + for (idx, value) in update.updates { + this.leaves.get_mut(&idx).unwrap().value = value; + } + + if !update.inserts.is_empty() { + let first_new_idx = this.leaf_count; + // Cannot underflow because `update.inserts.len() >= 1` + let new_indexes = first_new_idx..=(first_new_idx + update.inserts.len() as u64 - 1); + + // Update prev / next index pointers for neighbors. + for (idx, new_leaf) in new_indexes.clone().zip(&update.inserts) { + // Prev / next leaf may also be new, in which case, we'll insert it with the correct prev / next pointers, + // so we don't need to do anything here. + let prev_idx = new_leaf.prev_index; + if let Some(prev_leaf) = this.leaves.get_mut(&prev_idx) { + prev_leaf.next_index = idx; + } + + let next_idx = new_leaf.next_index; + if let Some(next_leaf) = this.leaves.get_mut(&next_idx) { + next_leaf.prev_index = idx; + } + } + + this.leaves.extend(new_indexes.clone().zip(update.inserts)); + this.leaf_count = *new_indexes.end() + 1; + + // Add / update internal nodes. + for (i, internal_level) in this.internal.iter_mut().enumerate() { + let nibble_count = i as u8; + let child_depth = + (max_nibbles_for_internal_node::
<P>
() - nibble_count) * P::INTERNAL_NODE_DEPTH; + let first_index_on_level = + (new_indexes.start() >> child_depth) / u64::from(max_node_children::
<P>
()); + let last_child_index = new_indexes.end() >> child_depth; + let last_index_on_level = last_child_index / u64::from(max_node_children::
<P>
()); + + // Only `first_index_on_level` may exist already; all others are necessarily new. + let mut start_idx = first_index_on_level; + if let Some(parent) = internal_level.get_mut(&first_index_on_level) { + let expected_len = if last_index_on_level == first_index_on_level { + (last_child_index % u64::from(max_node_children::
<P>
())) as usize + 1 + } else { + max_node_children::
<P>
().into() + }; + parent.ensure_len(expected_len, version); + start_idx += 1; + } + + let new_nodes = (start_idx..=last_index_on_level).map(|idx| { + let expected_len = if idx == last_index_on_level { + (last_child_index % u64::from(max_node_children::
<P>
())) as usize + 1 + } else { + max_node_children::
<P>
().into() + }; + (idx, InternalNode::new(expected_len, version)) + }); + internal_level.extend(new_nodes); + } + } + + FinalTreeUpdate { + version, + sorted_new_leaves: update.sorted_new_leaves, + } + } + + pub(crate) fn finalize( + self, + hasher: &P::Hasher, + update: FinalTreeUpdate, + ) -> (PatchSet, BatchOutput) { + use rayon::prelude::*; + + let mut this = self.inner; + let mut hashes: Vec<_> = this + .leaves + .par_iter() + .map(|(idx, leaf)| (*idx, hasher.hash_leaf(leaf))) + .collect(); + + for (nibble_depth, internal_level) in this.internal.iter_mut().rev().enumerate() { + for (idx, hash) in hashes { + // The parent node must exist by construction. + internal_level + .get_mut(&(idx >> P::INTERNAL_NODE_DEPTH)) + .unwrap() + .child_mut((idx % u64::from(max_node_children::
<P>
())) as usize) + .hash = hash; + } + + let depth = nibble_depth as u8 * P::INTERNAL_NODE_DEPTH; + hashes = internal_level + .par_iter() + .map(|(idx, node)| (*idx, node.hash::
<P>
(hasher, depth))) + .collect(); + } + + assert_eq!(hashes.len(), 1); + let (root_idx, root_hash) = hashes[0]; + assert_eq!(root_idx, 0); + let output = BatchOutput { + leaf_count: this.leaf_count, + root_hash, + }; + + let patch = PatchSet { + manifest: Manifest { + version_count: update.version + 1, + tags: TreeTags::for_params::
<P>
(hasher), + }, + patches_by_version: HashMap::from([(update.version, this)]), + sorted_new_leaves: update.sorted_new_leaves, + }; + (patch, output) + } +} + +impl MerkleTree { + /// Loads data for processing the specified entries into a patch set. + #[tracing::instrument( + level = "debug", + skip(self, entries), + fields(entries.len = entries.len()) + )] + pub(crate) fn create_patch( + &self, + latest_version: u64, + entries: &[TreeEntry], + ) -> anyhow::Result<(WorkingPatchSet
<P>
, TreeUpdate)> { + let root = self.db.try_root(latest_version)?.ok_or_else(|| { + DeserializeError::from(DeserializeErrorKind::MissingNode) + .with_context(DeserializeContext::Node(NodeKey::root(latest_version))) + })?; + let keys: Vec<_> = entries.iter().map(|entry| entry.key).collect(); + + let started_at = Instant::now(); + let lookup = self + .db + .indices(u64::MAX, &keys) + .context("failed loading indices")?; + tracing::debug!(elapsed = ?started_at.elapsed(), "loaded lookup info"); + + // Collect all distinct indices that need to be loaded. + let mut sorted_new_leaves = BTreeMap::new(); + let mut new_index = root.leaf_count; + let distinct_indices = + lookup + .iter() + .zip(entries) + .flat_map(|(lookup, entry)| match lookup { + KeyLookup::Existing(idx) => [*idx, *idx], + KeyLookup::Missing { + prev_key_and_index, + next_key_and_index, + } => { + sorted_new_leaves.insert( + entry.key, + InsertedKeyEntry { + index: new_index, + inserted_at: latest_version + 1, + }, + ); + new_index += 1; + + [prev_key_and_index.1, next_key_and_index.1] + } + }); + let mut distinct_indices: BTreeSet<_> = distinct_indices.collect(); + if !sorted_new_leaves.is_empty() { + // Need to load the latest existing leaf and its ancestors so that new ancestors can be correctly + // inserted for the new leaves. + distinct_indices.insert(root.leaf_count - 1); + } + + let started_at = Instant::now(); + let mut patch = WorkingPatchSet::new(root); + patch.load_nodes(&self.db, distinct_indices.iter().copied())?; + tracing::debug!( + elapsed = ?started_at.elapsed(), + distinct_indices.len = distinct_indices.len(), + "loaded nodes" + ); + + let mut updates = Vec::with_capacity(entries.len() - sorted_new_leaves.len()); + let mut inserts = Vec::with_capacity(sorted_new_leaves.len()); + let mut operations = Vec::with_capacity(entries.len()); + for (entry, lookup) in entries.iter().zip(lookup) { + match lookup { + KeyLookup::Existing(idx) => { + updates.push((idx, entry.value)); + operations.push(TreeOperation::Update { index: idx }); + } + + KeyLookup::Missing { + prev_key_and_index, + next_key_and_index, + } => { + // Adjust previous / next indices according to the data inserted in the same batch. 
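+ // For example (hypothetical keys): if this batch inserts `0x01…` and `0x02…` into a tree containing
+ // only the guard leaves, the persisted neighbors of `0x02…` are the guards, but its actual previous
+ // neighbor is the `0x01…` leaf inserted by the same batch, so `prev_index` must be adjusted to point to it.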
+ let mut prev_index = prev_key_and_index.1; + operations.push(TreeOperation::Insert { prev_index }); + + if let Some((&local_prev_key, inserted)) = + sorted_new_leaves.range(..entry.key).next_back() + { + if local_prev_key > prev_key_and_index.0 { + prev_index = inserted.index; + } + } + + let mut next_index = next_key_and_index.1; + let next_range = (ops::Bound::Excluded(entry.key), ops::Bound::Unbounded); + if let Some((&local_next_key, inserted)) = + sorted_new_leaves.range(next_range).next() + { + if local_next_key < next_key_and_index.0 { + next_index = inserted.index; + } + } + + inserts.push(Leaf { + key: entry.key, + value: entry.value, + prev_index, + next_index, + }); + } + } + } + + anyhow::ensure!( + sorted_new_leaves.len() == inserts.len(), + "Attempting to insert duplicate keys into a tree; please deduplicate keys on the caller side" + ); + // We don't check for duplicate updates since they don't lead to logical errors, they're just inefficient + + Ok(( + patch, + TreeUpdate { + version: latest_version + 1, + sorted_new_leaves, + updates, + inserts, + operations, + }, + )) + } +} diff --git a/core/lib/zk_os_merkle_tree/src/storage/rocksdb.rs b/core/lib/zk_os_merkle_tree/src/storage/rocksdb.rs new file mode 100644 index 000000000000..1cc64cbbaf65 --- /dev/null +++ b/core/lib/zk_os_merkle_tree/src/storage/rocksdb.rs @@ -0,0 +1,617 @@ +//! RocksDB implementation of [`Database`]. + +use std::{ops, path::Path}; + +use anyhow::Context as _; +use once_cell::sync::OnceCell; +use zksync_basic_types::H256; +use zksync_storage::{db::NamedColumnFamily, rocksdb, rocksdb::DBPinnableSlice, RocksDB}; + +use crate::{ + errors::{DeserializeContext, DeserializeErrorKind}, + storage::{InsertedKeyEntry, PartialPatchSet, PatchSet}, + types::{InternalNode, KeyLookup, Leaf, Manifest, Node, NodeKey, Root}, + Database, DeserializeError, +}; + +impl NodeKey { + const DB_KEY_LEN: usize = 8 + 1 + 8; + + fn as_db_key(&self) -> [u8; Self::DB_KEY_LEN] { + let mut buffer = [0_u8; Self::DB_KEY_LEN]; + buffer[..8].copy_from_slice(&self.version.to_be_bytes()); + buffer[8] = self.nibble_count; + buffer[9..].copy_from_slice(&self.index_on_level.to_be_bytes()); + buffer + } +} + +/// RocksDB column families used by the tree. +#[derive(Debug, Clone, Copy)] +pub enum MerkleTreeColumnFamily { + /// Column family containing versioned tree information in the form of + /// `NodeKey` -> `Node` mapping. + Tree, + /// Resolves keys to (index, version) tuples. + KeyIndices, + // TODO: stale keys +} + +impl NamedColumnFamily for MerkleTreeColumnFamily { + const DB_NAME: &'static str = "zkos_merkle_tree"; + const ALL: &'static [Self] = &[Self::Tree, Self::KeyIndices]; + + fn name(&self) -> &'static str { + match self { + Self::Tree => "default", + Self::KeyIndices => "key_indices", + } + } + + fn requires_tuning(&self) -> bool { + matches!(self, Self::Tree) + } +} + +/// Main [`Database`] implementation wrapping a [`RocksDB`] reference. +#[derive(Debug, Clone)] +pub struct RocksDBWrapper { + db: RocksDB, + multi_get_chunk_size: usize, + leaf_nibbles: OnceCell, +} + +impl RocksDBWrapper { + /// Key to store the tree [`Manifest`]. + // This key must not overlap with keys for nodes; easy to see that it's true, + // since the minimum node key is [0, 0, 0, 0, 0, 0, 0, 0]. + const MANIFEST_KEY: &'static [u8] = &[0]; + + /// Creates a new wrapper, initializing RocksDB at the specified directory. + /// + /// # Errors + /// + /// Propagates RocksDB I/O errors. 
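+ ///
+ /// # Example
+ ///
+ /// A minimal sketch of wiring the wrapper into a tree (the path is a placeholder):
+ ///
+ /// ```ignore
+ /// let db = RocksDBWrapper::new(Path::new("/path/to/tree/dir"))?;
+ /// let mut tree = MerkleTree::new(db)?;
+ /// tree.extend(&[])?; // writes version 0 containing only the two guard leaves
+ /// ```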
+ pub fn new(path: &Path) -> Result { + Ok(Self::from(RocksDB::new(path)?)) + } + + /// Sets the chunk size for multi-get operations. The requested keys will be split + /// into chunks of this size and requested in parallel using `rayon`. Setting chunk size + /// to a large value (e.g., `usize::MAX`) will effectively disable parallelism. + /// + /// [RocksDB docs] claim that multi-get operations may be parallelized internally, + /// but this seems to be dependent on the env; it may be the case that (single-threaded) + /// I/O parallelization is only achieved using `liburing`, which requires enabling + /// the `io-uring` feature of `rocksdb` crate and is only available on Linux. + /// Thus, setting this value to around `100..1_000` can still lead to substantial + /// performance boost (order of 2x) in some environments. + /// + /// [RocksDB docs]: https://github.com/facebook/rocksdb/wiki/MultiGet-Performance + pub fn set_multi_get_chunk_size(&mut self, chunk_size: usize) { + self.multi_get_chunk_size = chunk_size; + } + + fn set_leaf_nibbles(&mut self, manifest: &Manifest) -> anyhow::Result { + let tags = &manifest.tags; + let leaf_nibbles_from_manifest = tags.depth.div_ceil(tags.internal_node_depth); + if let Some(&leaf_nibbles) = self.leaf_nibbles.get() { + anyhow::ensure!( + leaf_nibbles_from_manifest == leaf_nibbles, + "Invalid manifest update" + ); + } else { + self.leaf_nibbles.set(leaf_nibbles_from_manifest).ok(); + } + Ok(leaf_nibbles_from_manifest) + } + + fn raw_node(&self, key: &[u8]) -> Option> { + self.db + .get_cf(MerkleTreeColumnFamily::Tree, key) + .expect("Failed reading from RocksDB") + } + + pub(crate) fn raw_nodes(&self, keys: &[NodeKey]) -> Vec>> { + use rayon::prelude::*; + + keys.par_chunks(self.multi_get_chunk_size) + .map(|chunk| { + let keys = chunk.iter().map(NodeKey::as_db_key); + let results = self.db.multi_get_cf(MerkleTreeColumnFamily::Tree, keys); + results + .into_iter() + .map(|result| result.expect("Failed reading from RocksDB")) + }) + .flatten_iter() + .collect() + } + + fn deserialize_node(&self, raw_node: &[u8], key: &NodeKey) -> Result { + let leaf_nibbles = *self.leaf_nibbles.get_or_try_init(|| { + let tags = self + .try_manifest()? + .ok_or(DeserializeErrorKind::MissingManifest)? + .tags; + Ok::<_, DeserializeError>(tags.depth.div_ceil(tags.internal_node_depth)) + })?; + + // If we didn't succeed with the patch set, or the key version is old, + // access the underlying storage. + let node = if key.nibble_count == leaf_nibbles { + Leaf::deserialize(raw_node).map(Node::Leaf) + } else { + InternalNode::deserialize(raw_node).map(Node::Internal) + }; + node.map_err(|err| err.with_context(DeserializeContext::Node(*key))) + } + + fn lookup_key(&self, key: H256, version: u64) -> Result { + let (next_key, next_entry) = self + .db + .from_iterator_cf(MerkleTreeColumnFamily::KeyIndices, key.as_bytes()..) 
+ .find_map(|(key, raw_entry)| { + let entry = InsertedKeyEntry::deserialize(&raw_entry) + .map_err(|err| err.with_context(DeserializeContext::KeyIndex(key.clone()))) + .unwrap(); + (entry.inserted_at <= version).then(|| (H256::from_slice(&key), entry)) + }) + .expect("guards must be inserted into a tree on initialization"); + + if next_key == key { + return Ok(KeyLookup::Existing(next_entry.index)); + } + + let (prev_key, prev_entry) = self + .db + .to_iterator_cf(MerkleTreeColumnFamily::KeyIndices, ..=key.as_bytes()) + .find_map(|(key, raw_entry)| { + let entry = InsertedKeyEntry::deserialize(&raw_entry) + .map_err(|err| err.with_context(DeserializeContext::KeyIndex(key.clone()))) + .unwrap(); + (entry.inserted_at <= version).then(|| (H256::from_slice(&key), entry)) + }) + .expect("guards must be inserted into a tree on initialization"); + + Ok(KeyLookup::Missing { + prev_key_and_index: (prev_key, prev_entry.index), + next_key_and_index: (next_key, next_entry.index), + }) + } + + /// Returns the wrapped RocksDB instance. + pub fn into_inner(self) -> RocksDB { + self.db + } +} + +impl From> for RocksDBWrapper { + fn from(db: RocksDB) -> Self { + Self { + db, + multi_get_chunk_size: usize::MAX, + leaf_nibbles: OnceCell::new(), + } + } +} + +impl Database for RocksDBWrapper { + // TODO: Try alternatives (e.g., reusing iterators) + fn indices(&self, version: u64, keys: &[H256]) -> Result, DeserializeError> { + use rayon::prelude::*; + + let mut results = vec![]; + keys.par_iter() + .map(|&key| self.lookup_key(key, version)) + .collect_into_vec(&mut results); + results.into_iter().collect() + } + + fn try_manifest(&self) -> Result, DeserializeError> { + let Some(raw_manifest) = self.raw_node(Self::MANIFEST_KEY) else { + return Ok(None); + }; + Manifest::deserialize(&raw_manifest) + .map(Some) + .map_err(|err| err.with_context(DeserializeContext::Manifest)) + } + + fn try_root(&self, version: u64) -> Result, DeserializeError> { + let node_key = NodeKey::root(version); + let Some(raw_root) = self.raw_node(&node_key.as_db_key()) else { + return Ok(None); + }; + Root::deserialize(&raw_root) + .map(Some) + .map_err(|err| err.with_context(DeserializeContext::Node(node_key))) + } + + fn try_nodes(&self, keys: &[NodeKey]) -> Result, DeserializeError> { + let raw_nodes = self.raw_nodes(keys).into_iter().zip(keys); + + let nodes = raw_nodes.map(|(maybe_node, key)| { + let raw_node = maybe_node.ok_or_else(|| { + DeserializeError::from(DeserializeErrorKind::MissingNode) + .with_context(DeserializeContext::Node(*key)) + })?; + self.deserialize_node(&raw_node, key) + }); + nodes.collect() + } + + fn apply_patch(&mut self, patch: PatchSet) -> anyhow::Result<()> { + let leaf_nibbles = self.set_leaf_nibbles(&patch.manifest)?; + let tree_cf = MerkleTreeColumnFamily::Tree; + let mut write_batch = self.db.new_write_batch(); + let mut node_bytes = Vec::with_capacity(128); + // ^ 128 looks somewhat reasonable as node capacity + + let new_leaves = patch.sorted_new_leaves.len(); + let total_leaves: usize = patch + .patches_by_version + .values() + .map(|patch| patch.leaves.len()) + .sum(); + let total_internal_nodes: usize = patch + .patches_by_version + .values() + .map(PartialPatchSet::total_internal_nodes) + .sum(); + + patch.manifest.serialize(&mut node_bytes); + write_batch.put_cf(tree_cf, Self::MANIFEST_KEY, &node_bytes); + + for (key, entry) in patch.sorted_new_leaves { + node_bytes.clear(); + entry.serialize(&mut node_bytes); + write_batch.put_cf( + MerkleTreeColumnFamily::KeyIndices, + key.as_bytes(), + 
&node_bytes, + ); + } + + for (version, sub_patch) in patch.patches_by_version { + let root_key = NodeKey::root(version); + // Delete the key range corresponding to the entire new version. This removes + // potential garbage left after reverting the tree to a previous version. + let next_root_key = NodeKey::root(version + 1); + let keys_to_delete = &root_key.as_db_key()[..]..&next_root_key.as_db_key()[..]; + write_batch.delete_range_cf(tree_cf, keys_to_delete); + + node_bytes.clear(); + sub_patch.root().serialize(&mut node_bytes); + write_batch.put_cf(tree_cf, &root_key.as_db_key(), &node_bytes); + + // The root is serialized above, hence `skip(1)` + for (i, level) in sub_patch.internal.into_iter().enumerate().skip(1) { + let nibble_count = i as u8; + for (index_on_level, node) in level { + let node_key = NodeKey { + version, + nibble_count, + index_on_level, + }; + node_bytes.clear(); + node.serialize(&mut node_bytes); + write_batch.put_cf(tree_cf, &node_key.as_db_key(), &node_bytes); + } + } + + for (index_on_level, leaf) in sub_patch.leaves { + let node_key = NodeKey { + version, + nibble_count: leaf_nibbles, + index_on_level, + }; + node_bytes.clear(); + leaf.serialize(&mut node_bytes); + write_batch.put_cf(tree_cf, &node_key.as_db_key(), &node_bytes); + } + } + + tracing::debug!( + total_size = write_batch.size_in_bytes(), + new_leaves, + total_leaves, + total_internal_nodes, + "writing to RocksDB" + ); + + self.db + .write(write_batch) + .context("Failed writing a batch to RocksDB")?; + Ok(()) + } + + fn truncate( + &mut self, + manifest: Manifest, + truncated_versions: ops::RangeTo, + ) -> anyhow::Result<()> { + let leaf_nibbles = self.set_leaf_nibbles(&manifest)?; + let mut write_batch = self.db.new_write_batch(); + let mut node_bytes = Vec::with_capacity(128); + // ^ 128 looks somewhat reasonable as node capacity + + manifest.serialize(&mut node_bytes); + write_batch.put_cf( + MerkleTreeColumnFamily::Tree, + Self::MANIFEST_KEY, + &node_bytes, + ); + + // Find out the retained number of leaves. + let last_retained_version = manifest + .version_count + .checked_sub(1) + .context("at least 1 tree version must be retained")?; + let last_retained_root = self.try_root(last_retained_version)?.ok_or_else(|| { + DeserializeError::from(DeserializeErrorKind::MissingNode).with_context( + DeserializeContext::Node(NodeKey::root(last_retained_version)), + ) + })?; + let mut first_new_leaf_index = last_retained_root.leaf_count; + + // For each truncated version, get keys for the new leaves and remove them from the `KeyIndices` CF. + for truncated_version in manifest.version_count..truncated_versions.end { + let truncated_root = self.try_root(truncated_version)?.ok_or_else(|| { + DeserializeError::from(DeserializeErrorKind::MissingNode) + .with_context(DeserializeContext::Node(NodeKey::root(truncated_version))) + })?; + let new_leaf_count = truncated_root.leaf_count; + + let start_leaf_key = NodeKey { + version: truncated_version, + nibble_count: leaf_nibbles, + index_on_level: first_new_leaf_index, + }; + let start_leaf_key = start_leaf_key.as_db_key(); + + let new_leaves = self + .db + .from_iterator_cf(MerkleTreeColumnFamily::Tree, start_leaf_key.as_slice()..) 
+ .take_while(|(raw_key, _)| { + // Otherwise, we're no longer iterating over leaves for `truncated_version` + raw_key[..9] == start_leaf_key[..9] + }) + .map(|(_, raw_leaf)| Leaf::deserialize(&raw_leaf)); + for new_leaf in new_leaves { + let new_key = new_leaf?.key; + write_batch.delete_cf(MerkleTreeColumnFamily::KeyIndices, new_key.as_bytes()); + } + + first_new_leaf_index = new_leaf_count; + } + + self.db + .write(write_batch) + .context("Failed writing a batch to RocksDB")?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::collections::{BTreeMap, HashMap}; + + use tempfile::TempDir; + use zksync_crypto_primitives::hasher::blake2::Blake2Hasher; + + use super::*; + use crate::{ + leaf_nibbles, max_nibbles_for_internal_node, max_node_children, storage::PartialPatchSet, + types::TreeTags, DefaultTreeParams, MerkleTree, TreeEntry, TreeParams, + }; + + #[test] + fn looking_up_keys() { + let temp_dir = TempDir::new().unwrap(); + let mut db = RocksDBWrapper::new(temp_dir.path()).unwrap(); + let patch = PatchSet { + sorted_new_leaves: BTreeMap::from([ + ( + H256::zero(), + InsertedKeyEntry { + index: 0, + inserted_at: 0, + }, + ), + ( + H256::repeat_byte(0xff), + InsertedKeyEntry { + index: 1, + inserted_at: 0, + }, + ), + ( + H256::repeat_byte(1), + InsertedKeyEntry { + index: 2, + inserted_at: 1, + }, + ), + ]), + ..PatchSet::default() + }; + db.apply_patch(patch).unwrap(); + + assert_eq!( + db.lookup_key(H256::repeat_byte(1), 0).unwrap(), + KeyLookup::Missing { + prev_key_and_index: (H256::zero(), 0), + next_key_and_index: (H256::repeat_byte(0xff), 1), + } + ); + for version in [1, 2] { + assert_eq!( + db.lookup_key(H256::repeat_byte(1), version).unwrap(), + KeyLookup::Existing(2) + ); + } + + assert_eq!( + db.lookup_key(H256::repeat_byte(2), 0).unwrap(), + KeyLookup::Missing { + prev_key_and_index: (H256::zero(), 0), + next_key_and_index: (H256::repeat_byte(0xff), 1), + } + ); + for version in [1, 2] { + assert_eq!( + db.lookup_key(H256::repeat_byte(2), version).unwrap(), + KeyLookup::Missing { + prev_key_and_index: (H256::repeat_byte(1), 2), + next_key_and_index: (H256::repeat_byte(0xff), 1), + } + ); + } + + assert_eq!( + db.lookup_key(H256::from_low_u64_be(u64::MAX), 0).unwrap(), + KeyLookup::Missing { + prev_key_and_index: (H256::zero(), 0), + next_key_and_index: (H256::repeat_byte(0xff), 1), + } + ); + for version in [1, 2] { + assert_eq!( + db.lookup_key(H256::from_low_u64_be(u64::MAX), version) + .unwrap(), + KeyLookup::Missing { + prev_key_and_index: (H256::zero(), 0), + next_key_and_index: (H256::repeat_byte(1), 2), + } + ); + } + } + + fn test_persisting_nodes>() { + let temp_dir = TempDir::new().unwrap(); + let mut db = RocksDBWrapper::new(temp_dir.path()).unwrap(); + let patch = PartialPatchSet { + leaf_count: 2, + internal: (0..leaf_nibbles::
<P>
()) + .map(|i| { + HashMap::from([( + 0, + InternalNode::new(usize::from(i % max_node_children::
<P>
()) + 1, 0), + )]) + }) + .collect(), + leaves: HashMap::from([(0, Leaf::MIN_GUARD), (1, Leaf::MAX_GUARD)]), + }; + let patch = PatchSet { + manifest: Manifest { + version_count: 1, + tags: TreeTags::for_params::
<P>
(&Blake2Hasher), + }, + patches_by_version: HashMap::from([(0, patch)]), + ..PatchSet::default() + }; + db.apply_patch(patch).unwrap(); + + let manifest = db.try_manifest().unwrap().expect("no manifest"); + assert_eq!(manifest.version_count, 1); + + let root = db.try_root(0).unwrap().expect("no root"); + assert_eq!(root.leaf_count, 2); + assert_eq!(root.root_node, InternalNode::new(1, 0)); + + for nibble_count in 1..=max_nibbles_for_internal_node::
<P>
() { + let node_key = NodeKey { + version: 0, + nibble_count, + index_on_level: 0, + }; + let nodes = db.try_nodes(&[node_key]).unwrap(); + assert_eq!(nodes.len(), 1); + let Node::Internal(node) = &nodes[0] else { + panic!("unexpected node: {nodes:?}"); + }; + + let expected_node_len = nibble_count % max_node_children::
<P>
() + 1; + assert_eq!(*node, InternalNode::new(expected_node_len.into(), 0)); + } + + let leaf_keys = [ + NodeKey { + version: 0, + nibble_count: leaf_nibbles::
<P>
(), + index_on_level: 0, + }, + NodeKey { + version: 0, + nibble_count: leaf_nibbles::
<P>
(), + index_on_level: 1, + }, + ]; + let leaves = db.try_nodes(&leaf_keys).unwrap(); + assert_eq!(leaves.len(), 2); + let [Node::Leaf(first_leaf), Node::Leaf(second_leaf)] = leaves.as_slice() else { + panic!("unexpected node: {leaves:?}"); + }; + assert_eq!(*first_leaf, Leaf::MIN_GUARD); + assert_eq!(*second_leaf, Leaf::MAX_GUARD); + } + + #[test] + fn persisting_nodes() { + println!("Default tree params"); + test_persisting_nodes::(); + println!("Default tree params"); + test_persisting_nodes::>(); + println!("Default tree params"); + test_persisting_nodes::>(); + } + + fn get_all_keys(db: &RocksDBWrapper) -> Vec { + db.db + .prefix_iterator_cf(MerkleTreeColumnFamily::KeyIndices, &[]) + .map(|(raw_key, _)| H256::from_slice(&raw_key)) + .collect() + } + + #[test] + fn truncating_tree_removes_key_indices() { + let temp_dir = TempDir::new().unwrap(); + let db = RocksDBWrapper::new(temp_dir.path()).unwrap(); + + let mut tree = MerkleTree::new(db).unwrap(); + tree.extend(&[]).unwrap(); + tree.extend(&[TreeEntry { + key: H256::repeat_byte(1), + value: H256::repeat_byte(2), + }]) + .unwrap(); + tree.extend(&[ + TreeEntry { + key: H256::repeat_byte(2), + value: H256::repeat_byte(3), + }, + TreeEntry { + key: H256::repeat_byte(3), + value: H256::repeat_byte(4), + }, + ]) + .unwrap(); + + let all_keys = get_all_keys(&tree.db); + assert_eq!( + all_keys, + [ + H256::zero(), + H256::repeat_byte(1), + H256::repeat_byte(2), + H256::repeat_byte(3), + H256::repeat_byte(0xff) + ] + ); + + tree.truncate_recent_versions(1).unwrap(); + + // Only guards should be retained. + let all_keys = get_all_keys(&tree.db); + assert_eq!(all_keys, [H256::zero(), H256::repeat_byte(0xff)]); + } +} diff --git a/core/lib/zk_os_merkle_tree/src/storage/serialization.rs b/core/lib/zk_os_merkle_tree/src/storage/serialization.rs new file mode 100644 index 000000000000..8b6da688d8a2 --- /dev/null +++ b/core/lib/zk_os_merkle_tree/src/storage/serialization.rs @@ -0,0 +1,220 @@ +//! Binary serialization of tree nodes. 
+ +use std::str; + +use zksync_basic_types::H256; + +use crate::{ + errors::{DeserializeContext, DeserializeErrorKind}, + storage::InsertedKeyEntry, + types::{ChildRef, InternalNode, Leaf, Manifest, Root, TreeTags}, + DeserializeError, +}; + +const HASH_SIZE: usize = 32; + +impl InsertedKeyEntry { + pub(super) fn deserialize(mut buffer: &[u8]) -> Result { + let index = leb128::read::unsigned(&mut buffer).map_err(DeserializeErrorKind::Leb128)?; + let inserted_at = + leb128::read::unsigned(&mut buffer).map_err(DeserializeErrorKind::Leb128)?; + Ok(Self { index, inserted_at }) + } + + pub(super) fn serialize(&self, buffer: &mut Vec) { + leb128::write::unsigned(buffer, self.index).unwrap(); + leb128::write::unsigned(buffer, self.inserted_at).unwrap(); + } +} + +impl Leaf { + pub(super) fn deserialize(mut buffer: &[u8]) -> Result { + if buffer.len() < 2 * HASH_SIZE { + return Err(DeserializeErrorKind::UnexpectedEof.into()); + } + let key = H256::from_slice(&buffer[..HASH_SIZE]); + let value = H256::from_slice(&buffer[HASH_SIZE..2 * HASH_SIZE]); + + buffer = &buffer[2 * HASH_SIZE..]; + let prev_index = + leb128::read::unsigned(&mut buffer).map_err(DeserializeErrorKind::Leb128)?; + let next_index = + leb128::read::unsigned(&mut buffer).map_err(DeserializeErrorKind::Leb128)?; + if !buffer.is_empty() { + return Err(DeserializeErrorKind::Leftovers.into()); + } + Ok(Self { + key, + value, + prev_index, + next_index, + }) + } + + pub(super) fn serialize(&self, buffer: &mut Vec) { + buffer.extend_from_slice(self.key.as_bytes()); + buffer.extend_from_slice(self.value.as_bytes()); + leb128::write::unsigned(buffer, self.prev_index).unwrap(); + leb128::write::unsigned(buffer, self.next_index).unwrap(); + } +} + +impl ChildRef { + fn deserialize(buffer: &mut &[u8]) -> Result { + if buffer.len() < HASH_SIZE { + return Err(DeserializeErrorKind::UnexpectedEof.into()); + } + let hash = H256::from_slice(&buffer[..HASH_SIZE]); + *buffer = &buffer[HASH_SIZE..]; + let version = leb128::read::unsigned(buffer).map_err(DeserializeErrorKind::Leb128)?; + Ok(Self { hash, version }) + } + + fn serialize(&self, buffer: &mut Vec) { + buffer.extend_from_slice(self.hash.as_bytes()); + leb128::write::unsigned(buffer, self.version).unwrap(); + } +} + +impl InternalNode { + pub(super) fn deserialize(mut buffer: &[u8]) -> Result { + if buffer.is_empty() { + return Err(DeserializeErrorKind::UnexpectedEof.into()); + } + let len = buffer[0]; + buffer = &buffer[1..]; + + let children: Vec<_> = (0..len) + .map(|i| { + ChildRef::deserialize(&mut buffer) + .map_err(|err| err.with_context(DeserializeContext::ChildRef(i))) + }) + .collect::>()?; + if !buffer.is_empty() { + return Err(DeserializeErrorKind::Leftovers.into()); + } + Ok(Self { children }) + } + + pub(super) fn serialize(&self, buffer: &mut Vec) { + buffer.push(self.children.len() as u8); + + for child_ref in &self.children { + child_ref.serialize(buffer); + } + } +} + +impl Root { + pub(super) fn deserialize(mut buffer: &[u8]) -> Result { + let leaf_count = leb128::read::unsigned(&mut buffer).map_err(|err| { + DeserializeError::from(DeserializeErrorKind::Leb128(err)) + .with_context(DeserializeContext::LeafCount) + })?; + Ok(Self { + leaf_count, + root_node: InternalNode::deserialize(buffer)?, + }) + } + + pub(super) fn serialize(&self, buffer: &mut Vec) { + leb128::write::unsigned(buffer, self.leaf_count).unwrap(); + self.root_node.serialize(buffer); + } +} + +impl TreeTags { + fn deserialize_str<'a>(bytes: &mut &'a [u8]) -> Result<&'a str, DeserializeErrorKind> { + let 
str_len = leb128::read::unsigned(bytes).map_err(DeserializeErrorKind::Leb128)?; + let str_len = usize::try_from(str_len).map_err(|_| DeserializeErrorKind::UnexpectedEof)?; + + if bytes.len() < str_len { + return Err(DeserializeErrorKind::UnexpectedEof); + } + let (s, rest) = bytes.split_at(str_len); + *bytes = rest; + str::from_utf8(s).map_err(DeserializeErrorKind::Utf8) + } + + fn serialize_str(bytes: &mut Vec, s: &str) { + leb128::write::unsigned(bytes, s.len() as u64).unwrap(); + bytes.extend_from_slice(s.as_bytes()); + } + + fn deserialize(bytes: &mut &[u8]) -> Result { + let tag_count = leb128::read::unsigned(bytes).map_err(DeserializeErrorKind::Leb128)?; + let mut architecture = None; + let mut hasher = None; + let mut depth = None; + let mut internal_node_depth = None; + + for _ in 0..tag_count { + let key = Self::deserialize_str(bytes)?; + let value = Self::deserialize_str(bytes)?; + match key { + "architecture" => architecture = Some(value.to_owned()), + "hasher" => hasher = Some(value.to_owned()), + "depth" => { + let parsed = + value + .parse::() + .map_err(|err| DeserializeErrorKind::MalformedTag { + name: "depth", + err: err.into(), + })?; + depth = Some(parsed); + } + "internal_node_depth" => { + let parsed = + value + .parse::() + .map_err(|err| DeserializeErrorKind::MalformedTag { + name: "internal_node_depth", + err: err.into(), + })?; + internal_node_depth = Some(parsed); + } + _ => return Err(DeserializeErrorKind::UnknownTag(key.to_owned()).into()), + } + } + Ok(Self { + architecture: architecture.ok_or(DeserializeErrorKind::MissingTag("architecture"))?, + depth: depth.ok_or(DeserializeErrorKind::MissingTag("depth"))?, + internal_node_depth: internal_node_depth + .ok_or(DeserializeErrorKind::MissingTag("internal_node_depth"))?, + hasher: hasher.ok_or(DeserializeErrorKind::MissingTag("hasher"))?, + }) + } + + fn serialize(&self, buffer: &mut Vec) { + let entry_count = 4; // custom tags aren't supported (yet?) 
+ leb128::write::unsigned(buffer, entry_count).unwrap(); + + Self::serialize_str(buffer, "architecture"); + Self::serialize_str(buffer, &self.architecture); + Self::serialize_str(buffer, "depth"); + Self::serialize_str(buffer, &self.depth.to_string()); + Self::serialize_str(buffer, "internal_node_depth"); + Self::serialize_str(buffer, &self.internal_node_depth.to_string()); + Self::serialize_str(buffer, "hasher"); + Self::serialize_str(buffer, &self.hasher); + } +} + +impl Manifest { + pub(super) fn deserialize(mut bytes: &[u8]) -> Result { + let version_count = + leb128::read::unsigned(&mut bytes).map_err(DeserializeErrorKind::Leb128)?; + let tags = TreeTags::deserialize(&mut bytes)?; + + Ok(Self { + version_count, + tags, + }) + } + + pub(super) fn serialize(&self, buffer: &mut Vec) { + leb128::write::unsigned(buffer, self.version_count).unwrap(); + self.tags.serialize(buffer); + } +} diff --git a/core/lib/zk_os_merkle_tree/src/storage/tests.rs b/core/lib/zk_os_merkle_tree/src/storage/tests.rs new file mode 100644 index 000000000000..39ec18ca290d --- /dev/null +++ b/core/lib/zk_os_merkle_tree/src/storage/tests.rs @@ -0,0 +1,434 @@ +use std::collections::HashSet; + +use zksync_crypto_primitives::hasher::blake2::Blake2Hasher; + +use super::*; +use crate::{DefaultTreeParams, MerkleTree, TreeEntry, TreeParams}; + +#[test] +fn creating_min_update_for_empty_tree() { + let update = TreeUpdate::for_empty_tree(&[]).unwrap(); + assert_eq!(update.version, 0); + assert!(update.updates.is_empty()); + + assert_eq!(update.inserts.len(), 2); + assert_eq!(update.inserts[0], Leaf::MIN_GUARD); + assert_eq!(update.inserts[1], Leaf::MAX_GUARD); + + assert_eq!(update.sorted_new_leaves.len(), 2); + assert_eq!( + update.sorted_new_leaves[&H256::zero()], + InsertedKeyEntry { + index: 0, + inserted_at: 0, + } + ); + assert_eq!( + update.sorted_new_leaves[&H256::repeat_byte(0xff)], + InsertedKeyEntry { + index: 1, + inserted_at: 0, + } + ); +} + +#[test] +fn creating_non_empty_update_for_empty_tree() { + let update = TreeUpdate::for_empty_tree(&[ + TreeEntry { + key: H256::repeat_byte(2), + value: H256::from_low_u64_be(1), + }, + TreeEntry { + key: H256::repeat_byte(1), + value: H256::from_low_u64_be(2), + }, + ]) + .unwrap(); + assert_eq!(update.version, 0); + assert!(update.updates.is_empty()); + + assert_eq!(update.inserts.len(), 4); + assert_eq!( + update.inserts[0], + Leaf { + next_index: 3, + ..Leaf::MIN_GUARD + } + ); + assert_eq!( + update.inserts[1], + Leaf { + prev_index: 2, + ..Leaf::MAX_GUARD + } + ); + assert_eq!( + update.inserts[2], + Leaf { + key: H256::repeat_byte(2), + value: H256::from_low_u64_be(1), + prev_index: 3, + next_index: 1, + } + ); + assert_eq!( + update.inserts[3], + Leaf { + key: H256::repeat_byte(1), + value: H256::from_low_u64_be(2), + prev_index: 0, + next_index: 2, + } + ); + + assert_eq!(update.sorted_new_leaves.len(), 4); + assert_eq!( + update.sorted_new_leaves[&H256::zero()], + InsertedKeyEntry { + index: 0, + inserted_at: 0, + } + ); + assert_eq!( + update.sorted_new_leaves[&H256::repeat_byte(0xff)], + InsertedKeyEntry { + index: 1, + inserted_at: 0, + } + ); + assert_eq!( + update.sorted_new_leaves[&H256::repeat_byte(2)], + InsertedKeyEntry { + index: 2, + inserted_at: 0, + } + ); + assert_eq!( + update.sorted_new_leaves[&H256::repeat_byte(1)], + InsertedKeyEntry { + index: 3, + inserted_at: 0, + } + ); +} + +fn test_creating_empty_tree>() { + const { + assert!(P::TREE_DEPTH == 64); + } + + let mut patch = WorkingPatchSet::
<P>
::empty(); + let final_update = patch.update(TreeUpdate::for_empty_tree(&[]).unwrap()); + assert_eq!(final_update.version, 0); + + { + let patch = patch.inner(); + assert_eq!(patch.leaves.len(), 2); + assert_eq!(patch.leaves[&0], Leaf::MIN_GUARD); + assert_eq!(patch.leaves[&1], Leaf::MAX_GUARD); + let last_level = patch.internal.last().unwrap(); + assert_eq!(last_level.len(), 1); + assert_eq!(last_level[&0].children.len(), 2); + + for level in patch.internal.iter().rev().skip(1) { + assert_eq!(level.len(), 1); + assert_eq!(level[&0].children.len(), 1); + } + + assert_eq!(patch.leaf_count, 2); + assert_eq!(patch.root().root_node.children.len(), 1); + } + + let (patch, ..) = patch.finalize(&Blake2Hasher, final_update); + assert_eq!(patch.manifest.version_count, 1); + assert_eq!(patch.patches_by_version.len(), 1); + let root = patch.try_root(0).unwrap().expect("no root"); + assert_eq!(root.leaf_count, 2); + + assert_eq!(root.root_node.children.len(), 1); + let expected_root_hash: H256 = + "0x8a41011d351813c31088367deecc9b70677ecf15ffc24ee450045cdeaf447f63" + .parse() + .unwrap(); + assert_eq!(root.hash::
<P>
(&Blake2Hasher), expected_root_hash); +} + +#[test] +fn creating_empty_tree() { + println!("Default tree params"); + test_creating_empty_tree::(); + println!("Node depth = 3"); + test_creating_empty_tree::>(); + println!("Node depth = 2"); + test_creating_empty_tree::>(); +} + +fn test_creating_tree_with_leaves_in_single_batch
<P>
() +where + P: TreeParams, +{ + const { + assert!(P::TREE_DEPTH == 64); + } + + let mut patch = WorkingPatchSet::
<P>
::empty(); + let update = TreeUpdate::for_empty_tree(&[TreeEntry { + key: H256::repeat_byte(0x01), + value: H256::repeat_byte(0x10), + }]) + .unwrap(); + let final_update = patch.update(update); + + assert_eq!(patch.inner().leaves.len(), 3); + + let (patch, ..) = patch.finalize(&Blake2Hasher, final_update); + let root = patch.try_root(0).unwrap().expect("no root"); + assert_eq!(root.leaf_count, 3); + + let expected_root_hash: H256 = + "0x91a1688c802dc607125d0b5e5ab4d95d89a4a4fb8cca71a122db6076cb70f8f3" + .parse() + .unwrap(); + assert_eq!(root.hash::
<P>
(&Blake2Hasher), expected_root_hash); +} + +#[test] +fn creating_tree_with_leaves_in_single_batch() { + println!("Default tree params"); + test_creating_tree_with_leaves_in_single_batch::(); + println!("Node depth = 3"); + test_creating_tree_with_leaves_in_single_batch::>(); + println!("Node depth = 2"); + test_creating_tree_with_leaves_in_single_batch::>(); +} + +fn test_creating_tree_with_leaves_incrementally
<P>
() +where + P: TreeParams, +{ + const { + assert!(P::TREE_DEPTH == 64); + } + + let mut patch = WorkingPatchSet::
<P>
::empty(); + let final_update = patch.update(TreeUpdate::for_empty_tree(&[]).unwrap()); + let (patch, ..) = patch.finalize(&Blake2Hasher, final_update); + + let merkle_tree = MerkleTree::<_, P>::with_hasher(patch, Blake2Hasher).unwrap(); + let new_entry = TreeEntry { + key: H256::repeat_byte(0x01), + value: H256::repeat_byte(0x10), + }; + let (mut patch, update) = merkle_tree.create_patch(0, &[new_entry]).unwrap(); + + assert_eq!(patch.inner().leaf_count, 2); + assert_eq!( + patch.inner().leaves, + HashMap::from([(0, Leaf::MIN_GUARD), (1, Leaf::MAX_GUARD)]) + ); + + assert!(update.updates.is_empty()); + assert_eq!(update.inserts.len(), 1); + assert_eq!(update.inserts[0].prev_index, 0); + assert_eq!(update.inserts[0].next_index, 1); + assert_eq!(update.sorted_new_leaves.len(), 1); + assert_eq!( + update.sorted_new_leaves[&new_entry.key], + InsertedKeyEntry { + index: 2, + inserted_at: 1 + } + ); + + let final_update = patch.update(update); + { + let patch = patch.inner(); + assert_eq!(patch.leaf_count, 3); + assert_eq!( + patch.leaves[&0], + Leaf { + next_index: 2, + ..Leaf::MIN_GUARD + } + ); + assert_eq!( + patch.leaves[&1], + Leaf { + prev_index: 2, + ..Leaf::MAX_GUARD + } + ); + assert_eq!( + patch.leaves[&2], + Leaf { + key: new_entry.key, + value: new_entry.value, + prev_index: 0, + next_index: 1, + } + ); + } + + assert_eq!(final_update.version, 1); + let (new_patch, ..) = patch.finalize(&Blake2Hasher, final_update); + assert_eq!(new_patch.manifest.version_count, 2); + assert_eq!(new_patch.patches_by_version.len(), 1); + let root = new_patch.patches_by_version[&1].root(); + let expected_root_hash: H256 = + "0x91a1688c802dc607125d0b5e5ab4d95d89a4a4fb8cca71a122db6076cb70f8f3" + .parse() + .unwrap(); + assert_eq!(root.hash::
<P>
(&Blake2Hasher), expected_root_hash); +} + +#[test] +fn creating_tree_with_leaves_incrementally() { + println!("Default tree params"); + test_creating_tree_with_leaves_incrementally::(); + println!("Node depth = 3"); + test_creating_tree_with_leaves_incrementally::>(); + println!("Node depth = 2"); + test_creating_tree_with_leaves_incrementally::>(); +} + +fn test_creating_tree_with_multiple_leaves_and_update
<P>
() +where + P: TreeParams, +{ + const { + assert!(P::TREE_DEPTH == 64); + } + + let mut patch = WorkingPatchSet::
<P>
::empty(); + let final_update = patch.update(TreeUpdate::for_empty_tree(&[]).unwrap()); + let (patch, ..) = patch.finalize(&Blake2Hasher, final_update); + + let mut merkle_tree = MerkleTree::<_, P>::with_hasher(patch, Blake2Hasher).unwrap(); + let first_entry = TreeEntry { + key: H256::repeat_byte(0x01), + value: H256::repeat_byte(0x10), + }; + let second_entry = TreeEntry { + key: H256::repeat_byte(0x02), + value: H256::repeat_byte(0x20), + }; + let (mut patch, update) = merkle_tree + .create_patch(0, &[first_entry, second_entry]) + .unwrap(); + + let final_update = patch.update(update); + let (new_patch, ..) = patch.finalize(&Blake2Hasher, final_update); + + merkle_tree.db.apply_patch(new_patch).unwrap(); + + let expected_root_hash: H256 = + "0x20881c4aa37e3be665cc078db2727f0fc821bc5d9f09f053bb9a93ebd2799fcf" + .parse() + .unwrap(); + assert_eq!(merkle_tree.root_hash(1).unwrap(), Some(expected_root_hash)); + + let updated_entry = TreeEntry { + key: first_entry.key, + value: H256::repeat_byte(0x33), + }; + let (mut patch, update) = merkle_tree.create_patch(1, &[updated_entry]).unwrap(); + + assert!(update.inserts.is_empty()); + assert_eq!(update.updates, [(2, updated_entry.value)]); + + { + let patch = patch.inner(); + // `patch` should only load the updated leaf + assert_eq!(patch.leaves.len(), 1); + assert_eq!(patch.leaves[&2].key, updated_entry.key); + for level in &patch.internal { + assert_eq!(level.len(), 1, "{level:?}"); + } + } + + let final_update = patch.update(update); + let (new_patch, ..) = patch.finalize(&Blake2Hasher, final_update); + merkle_tree.db.apply_patch(new_patch).unwrap(); + + let expected_root_hash: H256 = + "0x4b6bd61930a8dee1bc412d8a38780f098137be9edbf29c078546b7492748d251" + .parse() + .unwrap(); + assert_eq!(merkle_tree.root_hash(2).unwrap(), Some(expected_root_hash)); +} + +#[test] +fn creating_tree_with_multiple_leaves_and_update() { + println!("Default tree params"); + test_creating_tree_with_multiple_leaves_and_update::(); + println!("Node depth = 3"); + test_creating_tree_with_multiple_leaves_and_update::>(); + println!("Node depth = 2"); + test_creating_tree_with_multiple_leaves_and_update::>(); +} + +fn test_mixed_update_and_insert
<P>
() +where + P: TreeParams, +{ + const { + assert!(P::TREE_DEPTH == 64); + } + + let mut merkle_tree = + MerkleTree::<_, P>::with_hasher(PatchSet::default(), Blake2Hasher).unwrap(); + let first_entry = TreeEntry { + key: H256::repeat_byte(0x01), + value: H256::repeat_byte(0x10), + }; + merkle_tree.extend(&[first_entry]).unwrap(); + + let updated_entry = TreeEntry { + key: first_entry.key, + value: H256::repeat_byte(0x33), + }; + let second_entry = TreeEntry { + key: H256::repeat_byte(0x02), + value: H256::repeat_byte(0x20), + }; + let (mut patch, update) = merkle_tree + .create_patch(0, &[updated_entry, second_entry]) + .unwrap(); + + assert_eq!( + update.inserts, + [Leaf { + key: second_entry.key, + value: second_entry.value, + prev_index: 2, + next_index: 1, + }] + ); + assert_eq!(update.updates, [(2, updated_entry.value)]); + // Leaf 1 is updated as a neighbor for the inserted leaf. Leaf 0 is not updated. + assert_eq!( + patch.inner().leaves.keys().copied().collect::>(), + HashSet::from([1, 2]) + ); + + let final_update = patch.update(update); + let (new_patch, ..) = patch.finalize(&Blake2Hasher, final_update); + merkle_tree.db.apply_patch(new_patch).unwrap(); + + let expected_root_hash: H256 = + "0x4b6bd61930a8dee1bc412d8a38780f098137be9edbf29c078546b7492748d251" + .parse() + .unwrap(); + assert_eq!(merkle_tree.root_hash(1).unwrap(), Some(expected_root_hash)); +} + +#[test] +fn mixed_update_and_insert() { + println!("Default tree params"); + test_mixed_update_and_insert::(); + println!("Node depth = 3"); + test_mixed_update_and_insert::>(); + println!("Node depth = 2"); + test_mixed_update_and_insert::>(); +} diff --git a/core/lib/zk_os_merkle_tree/src/tests.rs b/core/lib/zk_os_merkle_tree/src/tests.rs new file mode 100644 index 000000000000..f356304dd671 --- /dev/null +++ b/core/lib/zk_os_merkle_tree/src/tests.rs @@ -0,0 +1,331 @@ +//! Tests for the public `MerkleTree` interface. 
+ +use std::collections::{BTreeMap, HashMap, HashSet}; + +use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; + +use super::*; +use crate::{ + hasher::TreeOperation, + storage::PatchSet, + types::{Leaf, TreeTags}, +}; + +#[test] +fn tree_depth_mismatch() { + let mut db = PatchSet::default(); + db.manifest_mut().version_count = 1; + db.manifest_mut().tags = TreeTags { + depth: 48, + ..TreeTags::for_params::(&Blake2Hasher) + }; + + let err = MerkleTree::new(db).unwrap_err().to_string(); + assert!( + err.contains("Unexpected tree depth: expected 64, got 48"), + "{err}" + ); +} + +#[test] +fn tree_internal_node_depth_mismatch() { + let mut db = PatchSet::default(); + db.manifest_mut().version_count = 1; + db.manifest_mut().tags = TreeTags { + internal_node_depth: 3, + ..TreeTags::for_params::(&Blake2Hasher) + }; + + let err = MerkleTree::new(db).unwrap_err().to_string(); + assert!( + err.contains("Unexpected internal node depth: expected 4, got 3"), + "{err}" + ); +} + +fn naive_hash_tree(entries: &[TreeEntry]) -> H256 { + let mut indices = BTreeMap::from([(H256::zero(), 0_u64), (H256::repeat_byte(0xff), 1)]); + indices.extend( + entries + .iter() + .enumerate() + .map(|(i, entry)| (entry.key, i as u64 + 2)), + ); + let prev_indices: Vec<_> = [0].into_iter().chain(indices.values().copied()).collect(); + let next_indices: Vec<_> = indices.values().skip(1).copied().chain([1]).collect(); + let prev_and_next_indices: HashMap<_, _> = indices + .into_keys() + .zip(prev_indices.into_iter().zip(next_indices)) + .collect(); + + let leaves = [&TreeEntry::MIN_GUARD, &TreeEntry::MAX_GUARD] + .into_iter() + .chain(entries) + .map(|entry| { + let (prev_index, next_index) = prev_and_next_indices[&entry.key]; + Leaf { + key: entry.key, + value: entry.value, + prev_index, + next_index, + } + }); + + let mut hashes: Vec<_> = leaves.map(|leaf| Blake2Hasher.hash_leaf(&leaf)).collect(); + for depth in 0..64 { + if hashes.len() % 2 == 1 { + hashes.push(Blake2Hasher.empty_subtree_hash(depth)); + } + hashes = hashes + .chunks(2) + .map(|chunk| match chunk { + [lhs, rhs] => Blake2Hasher.hash_branch(lhs, rhs), + _ => unreachable!(), + }) + .collect(); + } + hashes[0] +} + +fn test_comparing_tree_hash_against_naive_impl(mut create_db: impl FnMut() -> DB) { + const RNG_SEED: u64 = 42; + + let mut rng = StdRng::seed_from_u64(RNG_SEED); + let nodes = (0..100).map(|_| TreeEntry { + key: H256(rng.gen()), + value: H256(rng.gen()), + }); + let inserts: Vec<_> = nodes.collect(); + let expected_root_hash = naive_hash_tree(&inserts); + + for chunk_size in [1, 2, 3, 5, 8, 13, 21, 34, 55, 100] { + println!("Insert in {chunk_size}-sized chunks"); + let mut tree = MerkleTree::new(create_db()).unwrap(); + for chunk in inserts.chunks(chunk_size) { + tree.extend(chunk).unwrap(); + } + let root_hash = tree.latest_root_hash().unwrap().expect("tree is empty"); + assert_eq!(root_hash, expected_root_hash); + + let latest_version = tree.latest_version().unwrap().expect("no version"); + for version in 0..=latest_version { + println!("Verifying version {version}"); + tree.verify_consistency(version).unwrap(); + } + } +} + +#[test] +fn comparing_tree_hash_against_naive_impl() { + test_comparing_tree_hash_against_naive_impl(PatchSet::default); +} + +fn test_comparing_tree_hash_with_updates(db: impl Database) { + const RNG_SEED: u64 = 42; + + let mut rng = StdRng::seed_from_u64(RNG_SEED); + let nodes = (0..100).map(|_| TreeEntry { + key: H256(rng.gen()), + value: H256(rng.gen()), + }); + let inserts: Vec<_> = nodes.collect(); + let 
initial_root_hash = naive_hash_tree(&inserts); + + let mut tree = MerkleTree::new(db).unwrap(); + tree.extend(&inserts).unwrap(); + assert_eq!( + tree.latest_root_hash().unwrap().expect("tree is empty"), + initial_root_hash + ); + + let mut updates = inserts; + for update in &mut updates { + update.value = H256(rng.gen()); + } + let new_root_hash = naive_hash_tree(&updates); + updates.shuffle(&mut rng); + + for chunk_size in [1, 2, 3, 5, 8, 13, 21, 34, 55, 100] { + println!("Update in {chunk_size}-sized chunks"); + for chunk in updates.chunks(chunk_size) { + tree.extend(chunk).unwrap(); + } + let root_hash = tree.latest_root_hash().unwrap().expect("tree is empty"); + assert_eq!(root_hash, new_root_hash); + + let latest_version = tree.latest_version().unwrap().expect("no version"); + for version in 0..=latest_version { + println!("Verifying version {version}"); + tree.verify_consistency(version).unwrap(); + } + + tree.truncate_recent_versions(1).unwrap(); + assert_eq!(tree.latest_version().unwrap(), Some(0)); + assert_eq!(tree.latest_root_hash().unwrap(), Some(initial_root_hash)); + } +} + +#[test] +fn comparing_tree_hash_with_updates() { + test_comparing_tree_hash_with_updates(PatchSet::default()); +} + +fn test_extending_tree_with_proof(db: impl Database, inserts_count: usize, update_count: usize) { + const RNG_SEED: u64 = 42; + + let mut rng = StdRng::seed_from_u64(RNG_SEED); + let nodes = (0..inserts_count).map(|_| TreeEntry { + key: H256(rng.gen()), + value: H256(rng.gen()), + }); + let inserts: Vec<_> = nodes.collect(); + + let mut tree = MerkleTree::new(db).unwrap(); + let (inserts_output, proof) = tree.extend_with_proof(&inserts).unwrap(); + let root_hash_from_proof = proof.verify(&Blake2Hasher, 64, None, &inserts).unwrap(); + assert_eq!(root_hash_from_proof, inserts_output.root_hash); + + // Test a proof with only updates. + let updates: Vec<_> = inserts + .choose_multiple(&mut rng, update_count) + .map(|entry| TreeEntry { + key: entry.key, + value: H256::zero(), + }) + .collect(); + + let (output, proof) = tree.extend_with_proof(&updates).unwrap(); + let updates_tree_hash = output.root_hash; + + assert_eq!(proof.operations.len(), updates.len()); + let mut updated_indices = vec![]; + for op in &proof.operations { + match *op { + TreeOperation::Update { index } => updated_indices.push(index), + TreeOperation::Insert { .. 
} => panic!("unexpected operation: {op:?}"), + } + } + updated_indices.sort_unstable(); + + assert_eq!(proof.sorted_leaves.len(), updates.len()); + assert_eq!( + proof.sorted_leaves.keys().copied().collect::>(), + updated_indices + ); + + let root_hash_from_proof = proof + .verify(&Blake2Hasher, 64, Some(inserts_output), &updates) + .unwrap(); + assert_eq!(root_hash_from_proof, updates_tree_hash); +} + +#[test] +fn extending_tree_with_proof() { + for insert_count in [10, 20, 50, 100, 1_000] { + for update_count in HashSet::from([1, 2, insert_count / 4, insert_count / 2]) { + println!("insert_count={insert_count}, update_count={update_count}"); + test_extending_tree_with_proof(PatchSet::default(), insert_count, update_count); + } + } +} + +fn test_incrementally_extending_tree_with_proofs(db: impl Database, update_count: usize) { + const RNG_SEED: u64 = 123; + + let mut tree = MerkleTree::new(db).unwrap(); + let empty_tree_output = tree.extend(&[]).unwrap(); + + let mut rng = StdRng::seed_from_u64(RNG_SEED); + let nodes = (0..1_000).map(|_| TreeEntry { + key: H256(rng.gen()), + value: H256(rng.gen()), + }); + let inserts: Vec<_> = nodes.collect(); + + for chunk_size in [1, 2, 3, 5, 8, 13, 21, 34, 55, 100] { + println!("Update in {chunk_size}-sized chunks"); + + let mut tree_output = empty_tree_output; + for (i, chunk) in inserts.chunks(chunk_size).enumerate() { + let chunk_start_idx = i * chunk_size; + // Only choose updates from the previously inserted entries. + let mut updates: Vec<_> = inserts[..chunk_start_idx] + .choose_multiple(&mut rng, update_count.min(chunk_start_idx)) + .map(|entry| TreeEntry { + key: entry.key, + value: H256::repeat_byte(0xff), + }) + .collect(); + + updates.shuffle(&mut rng); + let mut entries = chunk.to_vec(); + entries.extend(updates); + + let (new_output, proof) = tree.extend_with_proof(&entries).unwrap(); + let proof_hash = proof + .verify(&Blake2Hasher, 64, Some(tree_output), &entries) + .unwrap(); + assert_eq!(proof_hash, new_output.root_hash); + tree_output = new_output; + } + + tree.truncate_recent_versions(1).unwrap(); + } +} + +#[test] +fn incrementally_extending_tree_with_proofs() { + for update_count in [0, 1, 2, 5, 10] { + println!("update_count={update_count}"); + test_incrementally_extending_tree_with_proofs(PatchSet::default(), update_count); + } +} + +mod rocksdb { + use tempfile::TempDir; + + use super::*; + + #[test] + fn comparing_tree_hash_against_naive_impl() { + let temp_dir = TempDir::new().unwrap(); + let mut i = 0; + test_comparing_tree_hash_against_naive_impl(|| { + i += 1; + RocksDBWrapper::new(&temp_dir.path().join(i.to_string())).unwrap() + }); + } + + #[test] + fn comparing_tree_hash_with_updates() { + let temp_dir = TempDir::new().unwrap(); + let db = RocksDBWrapper::new(temp_dir.path()).unwrap(); + test_comparing_tree_hash_with_updates(db); + } + + #[test] + fn extending_tree_with_proof() { + let temp_dir = TempDir::new().unwrap(); + for insert_count in [10, 20, 50, 100, 1_000] { + for update_count in HashSet::from([1, 2, insert_count / 4, insert_count / 2]) { + println!("insert_count={insert_count}, update_count={update_count}"); + let db_path = temp_dir + .path() + .join(format!("{insert_count}-{update_count}")); + let db = RocksDBWrapper::new(&db_path).unwrap(); + test_extending_tree_with_proof(db, insert_count, update_count); + } + } + } + + #[test] + fn incrementally_extending_tree_with_proofs() { + let temp_dir = TempDir::new().unwrap(); + for update_count in [0, 1, 2, 5, 10] { + println!("update_count={update_count}"); + 
let db_path = temp_dir.path().join(update_count.to_string()); + let db = RocksDBWrapper::new(&db_path).unwrap(); + test_incrementally_extending_tree_with_proofs(db, update_count); + } + } +} diff --git a/core/lib/zk_os_merkle_tree/src/types/mod.rs b/core/lib/zk_os_merkle_tree/src/types/mod.rs new file mode 100644 index 000000000000..c0b020aa3d09 --- /dev/null +++ b/core/lib/zk_os_merkle_tree/src/types/mod.rs @@ -0,0 +1,263 @@ +use std::fmt; + +use zksync_basic_types::H256; +use zksync_crypto_primitives::hasher::blake2::Blake2Hasher; + +use crate::{DefaultTreeParams, HashTree, TreeParams}; + +/// Maximum supported tree depth (to fit indexes into `u64`). +pub(crate) const MAX_TREE_DEPTH: u8 = 64; + +/// Tree leaf. +#[derive(Debug, Clone, Copy)] +#[cfg_attr(test, derive(PartialEq))] +pub struct Leaf { + pub key: H256, + pub value: H256, + /// 0-based index of a leaf with the lexicographically previous key. + pub prev_index: u64, + /// 0-based index of a leaf with the lexicographically next key. + pub next_index: u64, +} + +impl Leaf { + /// Minimum guard leaf inserted at the tree at its initialization. + pub const MIN_GUARD: Self = Self { + key: H256::zero(), + value: H256::zero(), + prev_index: 0, + next_index: 1, + }; + + /// Maximum guard leaf inserted at the tree at its initialization. + pub const MAX_GUARD: Self = Self { + key: H256::repeat_byte(0xff), + value: H256::zero(), + prev_index: 0, + next_index: 1, + }; +} + +#[derive(Debug, Clone, Copy)] +#[cfg_attr(test, derive(PartialEq))] +pub(crate) struct ChildRef { + pub(crate) version: u64, + pub(crate) hash: H256, +} + +/// Internal node of the tree, potentially amortized to have higher number of child references +/// (e.g., 8 or 16 instead of 2), depending on [`TreeParams`]. +#[derive(Debug, Clone)] +#[cfg_attr(test, derive(PartialEq))] +pub struct InternalNode { + pub(crate) children: Vec, +} + +impl InternalNode { + pub(crate) fn empty() -> Self { + Self { children: vec![] } + } + + pub(crate) fn new(len: usize, version: u64) -> Self { + Self { + children: vec![ + ChildRef { + version, + hash: H256::zero() + }; + len + ], + } + } + + /// Panics if the index doesn't exist. + pub(crate) fn child_ref(&self, index: usize) -> &ChildRef { + &self.children[index] + } + + pub(crate) fn child_mut(&mut self, index: usize) -> &mut ChildRef { + &mut self.children[index] + } + + pub(crate) fn ensure_len(&mut self, expected_len: usize, version: u64) { + self.children.resize_with(expected_len, || ChildRef { + version, + hash: H256::zero(), + }); + } +} + +/// Arbitrary tree node. +#[derive(Debug, Clone)] +pub enum Node { + Internal(InternalNode), + Leaf(Leaf), +} + +impl From for Node { + fn from(node: InternalNode) -> Self { + Self::Internal(node) + } +} + +impl From for Node { + fn from(leaf: Leaf) -> Self { + Self::Leaf(leaf) + } +} + +/// Result of a key lookup in the tree. +/// +/// Either a leaf with this key is already present in the tree, or there are neighbor leaves, which need to be updated during insertion +/// or included into the proof for missing reads. +#[derive(Debug)] +#[cfg_attr(test, derive(PartialEq))] +pub enum KeyLookup { + Existing(u64), + Missing { + prev_key_and_index: (H256, u64), + next_key_and_index: (H256, u64), + }, +} + +/// Unique key for a versioned tree node. +#[derive(Clone, Copy)] +pub struct NodeKey { + /// Tree version. + pub(crate) version: u64, + /// 0 is root, 1 is its children etc. 
+ pub(crate) nibble_count: u8, + pub(crate) index_on_level: u64, +} + +impl NodeKey { + pub(crate) const fn root(version: u64) -> Self { + Self { + version, + nibble_count: 0, + index_on_level: 0, + } + } +} + +impl fmt::Display for NodeKey { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + formatter, + "{}:{}nibs:{}", + self.version, self.nibble_count, self.index_on_level + ) + } +} + +impl fmt::Debug for NodeKey { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(self, formatter) + } +} + +/// Tree root: a node + additional metadata (for now, just the number of leaves in the tree). +#[derive(Debug, Clone)] +pub struct Root { + pub(crate) leaf_count: u64, + pub(crate) root_node: InternalNode, +} + +/// Entry in a Merkle tree associated with a key. Provided as an input for [`MerkleTree`](crate::MerkleTree) operations. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct TreeEntry { + /// Tree key. + pub key: H256, + /// Value associated with the key. + pub value: H256, +} + +impl TreeEntry { + pub(crate) const MIN_GUARD: Self = Self { + key: H256::zero(), + value: H256::zero(), + }; + + pub(crate) const MAX_GUARD: Self = Self { + key: H256::repeat_byte(0xff), + value: H256::zero(), + }; +} + +/// Persisted tags associated with a tree. +#[derive(Debug, Clone)] +#[cfg_attr(test, derive(PartialEq))] +pub(crate) struct TreeTags { + pub architecture: String, + pub depth: u8, + pub internal_node_depth: u8, + pub hasher: String, +} + +impl Default for TreeTags { + fn default() -> Self { + Self::for_params::(&Blake2Hasher) + } +} + +impl TreeTags { + const ARCHITECTURE: &'static str = "AmortizedLinkedListMT"; + + pub(crate) fn for_params(hasher: &P::Hasher) -> Self { + Self { + architecture: Self::ARCHITECTURE.to_owned(), + depth: P::TREE_DEPTH, + internal_node_depth: P::INTERNAL_NODE_DEPTH, + hasher: hasher.name().to_owned(), + } + } + + pub(crate) fn ensure_consistency( + &self, + hasher: &P::Hasher, + ) -> anyhow::Result<()> { + anyhow::ensure!( + self.architecture == Self::ARCHITECTURE, + "Unsupported tree architecture `{}`, expected `{}`", + self.architecture, + Self::ARCHITECTURE + ); + anyhow::ensure!( + self.depth == P::TREE_DEPTH, + "Unexpected tree depth: expected {expected}, got {got}", + expected = P::TREE_DEPTH, + got = self.depth + ); + anyhow::ensure!( + self.internal_node_depth == P::INTERNAL_NODE_DEPTH, + "Unexpected internal node depth: expected {expected}, got {got}", + expected = P::INTERNAL_NODE_DEPTH, + got = self.internal_node_depth + ); + anyhow::ensure!( + hasher.name() == self.hasher, + "Mismatch between the provided tree hasher `{}` and the hasher `{}` used \ + in the database", + hasher.name(), + self.hasher + ); + Ok(()) + } +} + +/// Version-independent information about the tree. +#[derive(Debug, Clone, Default)] +pub struct Manifest { + /// Number of tree versions stored in the database. + pub(crate) version_count: u64, + pub(crate) tags: TreeTags, +} + +/// Output of updating / inserting data in a [`MerkleTree`](crate::MerkleTree). +#[derive(Debug, Clone, Copy)] +pub struct BatchOutput { + /// New root hash of the tree. + pub root_hash: H256, + /// New leaf count (including 2 guard entries). + pub leaf_count: u64, +}
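A minimal usage sketch distilled from the new `src/tests.rs`, tying together the public types above (`TreeEntry`, `BatchOutput`) with `MerkleTree::extend` and `extend_with_proof`. It assumes the in-memory `PatchSet` backend exercised in the tests (which may be crate-private; `RocksDBWrapper` is the persistent alternative) and the default Blake2s-based tree parameters; names not appearing in the diff are illustrative only.

```rust
use zksync_basic_types::H256;
use zksync_crypto_primitives::hasher::blake2::Blake2Hasher;

use crate::{storage::PatchSet, MerkleTree, TreeEntry};

fn usage_sketch() {
    // Empty in-memory tree; `RocksDBWrapper` would be used for persistence.
    let mut tree = MerkleTree::new(PatchSet::default()).unwrap();

    // Each `extend()` call creates a new tree version and returns a `BatchOutput`.
    let entries = [
        TreeEntry { key: H256::repeat_byte(0x01), value: H256::repeat_byte(0x10) },
        TreeEntry { key: H256::repeat_byte(0x02), value: H256::repeat_byte(0x20) },
    ];
    let output = tree.extend(&entries).unwrap();
    // The leaf count includes the 2 pre-inserted guard leaves.
    assert_eq!(output.leaf_count, 4);
    assert_eq!(tree.latest_root_hash().unwrap(), Some(output.root_hash));

    // Updates to existing keys go through the same interface; `extend_with_proof()`
    // additionally returns a batch proof verifiable against the previous output.
    let update = TreeEntry { key: H256::repeat_byte(0x01), value: H256::repeat_byte(0x33) };
    let (new_output, proof) = tree.extend_with_proof(&[update]).unwrap();
    let hash_from_proof = proof
        .verify(&Blake2Hasher, 64, Some(output), &[update])
        .unwrap();
    assert_eq!(hash_from_proof, new_output.root_hash);
}
```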