Skip to content

Commit

Permalink
Merge pull request #936 from Chia-Network/datalayer_helpers_for_python
Browse files Browse the repository at this point in the history
add a few missing datalayer helpers for python
  • Loading branch information
altendky authored Feb 19, 2025
2 parents cddd30a + 4568738 commit 8acbf89
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 25 deletions.
127 changes: 104 additions & 23 deletions crates/chia-datalayer/src/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,12 @@ create_errors!(
BlockIndexOutOfBoundsError,
"block index out of bounds: {0}",
(TreeIndex)
),
(
LeafHashNotFound,
LeafHashNotFoundError,
"leaf hash not found: {0:?}",
(Hash)
)
)
);
Expand Down Expand Up @@ -641,20 +647,31 @@ impl Block {
}
}

// TODO: take the encouragement to clean up and make clippy happy again
#[allow(clippy::type_complexity)]
fn get_free_indexes_and_keys_values_indexes(
blob: &Vec<u8>,
) -> Result<(HashSet<TreeIndex>, HashMap<KeyId, TreeIndex>), Error> {
) -> Result<
(
HashSet<TreeIndex>,
HashMap<KeyId, TreeIndex>,
HashMap<Hash, TreeIndex>,
),
Error,
> {
let index_count = blob.len() / BLOCK_SIZE;

let mut seen_indexes: Vec<bool> = vec![false; index_count];
let mut key_to_index: HashMap<KeyId, TreeIndex> = HashMap::default();
let mut leaf_hash_to_index: HashMap<Hash, TreeIndex> = HashMap::default();

for item in MerkleBlobLeftChildFirstIterator::new(blob) {
for item in MerkleBlobLeftChildFirstIterator::new(blob, None) {
let (index, block) = item?;
seen_indexes[index.0 as usize] = true;

if let Node::Leaf(leaf) = block.node {
key_to_index.insert(leaf.key, index);
leaf_hash_to_index.insert(leaf.hash, index);
}
}

Expand All @@ -665,7 +682,7 @@ fn get_free_indexes_and_keys_values_indexes(
}
}

Ok((free_indexes, key_to_index))
Ok((free_indexes, key_to_index, leaf_hash_to_index))
}

/// Stores a DataLayer merkle tree in bytes and provides serialization on each access so that only
Expand All @@ -682,6 +699,7 @@ pub struct MerkleBlob {
// TODO: would be nice for this to be deterministic ala a fifo set
free_indexes: HashSet<TreeIndex>,
key_to_index: HashMap<KeyId, TreeIndex>,
leaf_hash_to_index: HashMap<Hash, TreeIndex>,
// TODO: used by fuzzing, some cleaner way? making it cfg-dependent is annoying with
// the type stubs
pub check_integrity_on_drop: bool,
Expand All @@ -696,12 +714,14 @@ impl MerkleBlob {
}

// TODO: maybe integrate integrity check here if quick enough
let (free_indexes, key_to_index) = get_free_indexes_and_keys_values_indexes(&blob)?;
let (free_indexes, key_to_index, leaf_hash_to_index) =
get_free_indexes_and_keys_values_indexes(&blob)?;

let self_ = Self {
blob,
free_indexes,
key_to_index,
leaf_hash_to_index,
check_integrity_on_drop: true,
};

Expand All @@ -712,6 +732,7 @@ impl MerkleBlob {
self.blob.clear();
self.key_to_index.clear();
self.free_indexes.clear();
self.leaf_hash_to_index.clear();
}

pub fn insert(
Expand Down Expand Up @@ -1080,7 +1101,7 @@ impl MerkleBlob {
}

fn get_min_height_leaf(&self) -> Result<LeafNode, Error> {
let (_index, block) = MerkleBlobBreadthFirstIterator::new(&self.blob)
let (_index, block) = MerkleBlobBreadthFirstIterator::new(&self.blob, None)
.next()
.ok_or(Error::UnableToFindALeaf())??;

Expand Down Expand Up @@ -1169,7 +1190,7 @@ impl MerkleBlob {
let mut internal_count: usize = 0;
let mut child_to_parent: HashMap<TreeIndex, TreeIndex> = HashMap::new();

for item in MerkleBlobParentFirstIterator::new(&self.blob) {
for item in MerkleBlobParentFirstIterator::new(&self.blob, None) {
let (index, block) = item?;
if let Some(parent) = block.node.parent().0 {
if child_to_parent.remove(&index) != Some(parent) {
Expand Down Expand Up @@ -1433,7 +1454,7 @@ impl MerkleBlob {

pub fn calculate_lazy_hashes(&mut self) -> Result<(), Error> {
// OPT: yeah, storing the whole set of blocks via collect is not great
for item in MerkleBlobLeftChildFirstIterator::new(&self.blob).collect::<Vec<_>>() {
for item in MerkleBlobLeftChildFirstIterator::new(&self.blob, None).collect::<Vec<_>>() {
let (index, mut block) = item?;
// OPT: really want a pruned traversal, not filter
if !block.metadata.dirty {
Expand Down Expand Up @@ -1505,14 +1526,46 @@ impl MerkleBlob {
layers,
})
}

pub fn get_node_by_hash(&self, node_hash: Hash) -> Result<(KeyId, ValueId), Error> {
let Some(index) = self.leaf_hash_to_index.get(&node_hash) else {
return Err(Error::LeafHashNotFound(node_hash));
};

let node = self
.get_node(*index)?
.expect_leaf("should only have leaves in the leaf hash to index cache");

Ok((node.key, node.value))
}

/// Builds a map from node hash to tree index by walking the tree parent-first.
///
/// When `leafs_only` is `true`, internal nodes are omitted from the result.
pub fn get_hashes_indexes(&self, leafs_only: bool) -> Result<HashMap<Hash, TreeIndex>, Error> {
    let mut hash_to_index = HashMap::new();

    // An empty blob has no nodes to walk; return the empty map immediately.
    if self.blob.is_empty() {
        return Ok(hash_to_index);
    }

    for item in MerkleBlobParentFirstIterator::new(&self.blob, None) {
        let (index, block) = item?;

        // Record every node, or only leaves when the caller asked for leaves only.
        if !leafs_only || block.metadata.node_type == NodeType::Leaf {
            hash_to_index.insert(block.node.hash(), index);
        }
    }

    Ok(hash_to_index)
}
}

impl PartialEq for MerkleBlob {
fn eq(&self, other: &Self) -> bool {
// NOTE: this is checking tree structure equality, not serialized bytes equality
for item in zip(
MerkleBlobLeftChildFirstIterator::new(&self.blob),
MerkleBlobLeftChildFirstIterator::new(&other.blob),
MerkleBlobLeftChildFirstIterator::new(&self.blob, None),
MerkleBlobLeftChildFirstIterator::new(&other.blob, None),
) {
let (Ok((_, self_block)), Ok((_, other_block))) = item else {
// TODO: it's an error though, hmm
Expand Down Expand Up @@ -1631,11 +1684,15 @@ impl MerkleBlob {
Ok(list.into())
}

#[pyo3(name = "get_nodes_with_indexes")]
pub fn py_get_nodes_with_indexes(&self, py: Python<'_>) -> PyResult<pyo3::PyObject> {
#[pyo3(name = "get_nodes_with_indexes", signature = (index=None))]
pub fn py_get_nodes_with_indexes(
&self,
index: Option<TreeIndex>,
py: Python<'_>,
) -> PyResult<pyo3::PyObject> {
let list = pyo3::types::PyList::empty(py);

for item in MerkleBlobParentFirstIterator::new(&self.blob) {
for item in MerkleBlobParentFirstIterator::new(&self.blob, index) {
let (index, block) = item?;
list.append((index.into_pyobject(py)?, block.node.into_pyobject(py)?))?;
}
Expand Down Expand Up @@ -1711,6 +1768,27 @@ impl MerkleBlob {
pub fn py_get_proof_of_inclusion(&self, key: KeyId) -> PyResult<ProofOfInclusion> {
Ok(self.get_proof_of_inclusion(key)?)
}

/// Python binding for [`MerkleBlob::get_node_by_hash`]: returns the `(key, value)`
/// pair of the leaf with the given hash, with the internal error converted to a
/// Python exception via `?`.
#[pyo3(name = "get_node_by_hash")]
pub fn py_get_node_by_hash(&self, node_hash: Hash) -> PyResult<(KeyId, ValueId)> {
    Ok(self.get_node_by_hash(node_hash)?)
}

/// Python binding for [`MerkleBlob::get_hashes_indexes`]: maps node hashes to tree
/// indexes. The `signature` attribute gives `leafs_only` a Python-side default of
/// `False`, matching the stub in `datalayer.pyi`.
#[pyo3(name = "get_hashes_indexes", signature = (leafs_only=false))]
pub fn py_get_hashes_indexes(&self, leafs_only: bool) -> PyResult<HashMap<Hash, TreeIndex>> {
    Ok(self.get_hashes_indexes(leafs_only)?)
}

/// Python binding: picks a pseudo-random existing leaf, deterministically derived
/// from `seed`, and returns it.
///
/// Raises `ValueError` when the seed does not resolve to an existing leaf
/// (e.g. there are no leaves to land on — presumably an empty or degenerate tree;
/// exact non-leaf cases depend on `get_random_insert_location_by_seed`).
#[pyo3(name = "get_random_leaf_node")]
pub fn py_get_random_leaf_node(&self, seed: &[u8]) -> PyResult<LeafNode> {
    let insert_location = self.get_random_insert_location_by_seed(seed)?;
    // Only a `Leaf` insert location identifies an existing leaf node; anything
    // else means there is no leaf to return. Report that instead of an empty
    // error message so Python callers can tell what went wrong.
    let InsertLocation::Leaf { index, side: _ } = insert_location else {
        return Err(PyValueError::new_err(
            "seed did not resolve to an existing leaf node",
        ));
    };

    Ok(self.get_node(index)?.expect_leaf("matched leaf above"))
}
}

fn try_get_block(blob: &[u8], index: TreeIndex) -> Result<Block, Error> {
Expand All @@ -1737,12 +1815,13 @@ pub struct MerkleBlobLeftChildFirstIterator<'a> {
}

impl<'a> MerkleBlobLeftChildFirstIterator<'a> {
fn new(blob: &'a Vec<u8>) -> Self {
fn new(blob: &'a Vec<u8>, from_index: Option<TreeIndex>) -> Self {
let mut deque = VecDeque::new();
let from_index = from_index.unwrap_or(TreeIndex(0));
if blob.len() / BLOCK_SIZE > 0 {
deque.push_back(MerkleBlobLeftChildFirstIteratorItem {
visited: false,
index: TreeIndex(0),
index: from_index,
});
}

Expand Down Expand Up @@ -1804,10 +1883,11 @@ pub struct MerkleBlobParentFirstIterator<'a> {
}

impl<'a> MerkleBlobParentFirstIterator<'a> {
fn new(blob: &'a Vec<u8>) -> Self {
fn new(blob: &'a Vec<u8>, from_index: Option<TreeIndex>) -> Self {
let mut deque = VecDeque::new();
let from_index = from_index.unwrap_or(TreeIndex(0));
if blob.len() / BLOCK_SIZE > 0 {
deque.push_back(TreeIndex(0));
deque.push_back(from_index);
}

Self {
Expand Down Expand Up @@ -1852,10 +1932,11 @@ pub struct MerkleBlobBreadthFirstIterator<'a> {

impl<'a> MerkleBlobBreadthFirstIterator<'a> {
#[allow(unused)]
fn new(blob: &'a Vec<u8>) -> Self {
fn new(blob: &'a Vec<u8>, from_index: Option<TreeIndex>) -> Self {
let mut deque = VecDeque::new();
let from_index = from_index.unwrap_or(TreeIndex(0));
if blob.len() / BLOCK_SIZE > 0 {
deque.push_back(TreeIndex(0));
deque.push_back(from_index);
}

Self {
Expand Down Expand Up @@ -2307,7 +2388,7 @@ mod tests {
let mut blob = small_blob.blob.clone();
let expected_free_index = TreeIndex((blob.len() / BLOCK_SIZE) as u32);
blob.extend_from_slice(&[0; BLOCK_SIZE]);
let (free_indexes, _) = get_free_indexes_and_keys_values_indexes(&blob).unwrap();
let (free_indexes, _, _) = get_free_indexes_and_keys_values_indexes(&blob).unwrap();
assert_eq!(free_indexes, HashSet::from([expected_free_index]));
}

Expand Down Expand Up @@ -2338,15 +2419,15 @@ mod tests {
#[rstest]
fn test_upsert_upserts(mut small_blob: MerkleBlob) {
let before_blocks =
MerkleBlobLeftChildFirstIterator::new(&small_blob.blob).collect::<Vec<_>>();
MerkleBlobLeftChildFirstIterator::new(&small_blob.blob, None).collect::<Vec<_>>();
let (key, index) = small_blob.key_to_index.iter().next().unwrap();
let original = small_blob.get_node(*index).unwrap().expect_leaf("<<self>>");
let new_value = ValueId(original.value.0 + 1);

small_blob.upsert(*key, new_value, &original.hash).unwrap();

let after_blocks =
MerkleBlobLeftChildFirstIterator::new(&small_blob.blob).collect::<Vec<_>>();
MerkleBlobLeftChildFirstIterator::new(&small_blob.blob, None).collect::<Vec<_>>();

assert_eq!(before_blocks.len(), after_blocks.len());
for item in zip(before_blocks, after_blocks) {
Expand Down Expand Up @@ -2656,7 +2737,7 @@ mod tests {
#[case] expected: Expect,
#[by_ref] traversal_blob: &'a MerkleBlob,
) where
F: Fn(&'a Vec<u8>) -> T,
F: Fn(&'a Vec<u8>, Option<TreeIndex>) -> T,
T: Iterator<Item = Result<(TreeIndex, Block), Error>>,
{
let mut dot_actual = traversal_blob.to_dot().unwrap();
Expand All @@ -2665,7 +2746,7 @@ mod tests {
let mut actual = vec![];
{
let blob: &Vec<u8> = &traversal_blob.blob;
for item in iterator_new(blob) {
for item in iterator_new(blob, None) {
let (index, block) = item.unwrap();
actual.push(iterator_test_reference(index, &block));
dot_actual.push_traversal(index);
Expand Down
2 changes: 1 addition & 1 deletion crates/chia-datalayer/src/merkle/dot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ impl Node {
impl MerkleBlob {
pub fn to_dot(&self) -> Result<DotLines, Error> {
let mut result = DotLines::new();
for item in MerkleBlobLeftChildFirstIterator::new(&self.blob) {
for item in MerkleBlobLeftChildFirstIterator::new(&self.blob, None) {
let (index, block) = item?;
result.push(block.node.to_dot(index));
}
Expand Down
8 changes: 7 additions & 1 deletion wheel/python/chia_rs/datalayer.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class StreamingError(Exception): ...
class IndexIsNotAChildError(Exception): ...
class CycleFoundError(Exception): ...
class BlockIndexOutOfBoundsError(Exception): ...
class LeafHashNotFoundError(Exception): ...

@final
class KeyId:
Expand Down Expand Up @@ -202,6 +203,8 @@ class MerkleBlob:
@property
def key_to_index(self) -> Mapping[KeyId, TreeIndex]: ...
@property
def leaf_hash_to_index(self) -> Mapping[bytes32, TreeIndex]: ...
@property
def check_integrity_on_drop(self) -> bool: ...

def __init__(
Expand All @@ -215,14 +218,17 @@ class MerkleBlob:
def get_raw_node(self, index: TreeIndex) -> Union[InternalNode, LeafNode]: ...
def calculate_lazy_hashes(self) -> None: ...
def get_lineage_with_indexes(self, index: TreeIndex) -> list[tuple[TreeIndex, Union[InternalNode, LeafNode]]]:...
def get_nodes_with_indexes(self) -> list[tuple[TreeIndex, Union[InternalNode, LeafNode]]]: ...
def get_nodes_with_indexes(self, index: Optional[TreeIndex] = ...) -> list[tuple[TreeIndex, Union[InternalNode, LeafNode]]]: ...
def empty(self) -> bool: ...
def get_root_hash(self) -> bytes32: ...
def batch_insert(self, keys_values: list[tuple[KeyId, ValueId]], hashes: list[bytes32]): ...
def get_hash_at_index(self, index: TreeIndex): ...
def get_keys_values(self) -> dict[KeyId, ValueId]: ...
def get_key_index(self, key: KeyId) -> TreeIndex: ...
def get_proof_of_inclusion(self, key: KeyId) -> ProofOfInclusion: ...
def get_node_by_hash(self, node_hash: bytes32) -> tuple[KeyId, ValueId]: ...
def get_hashes_indexes(self, leafs_only: bool = ...) -> dict[bytes32, TreeIndex]: ...
def get_random_leaf_node(self, seed: bytes) -> LeafNode: ...

def __len__(self) -> int: ...

Expand Down

0 comments on commit 8acbf89

Please sign in to comment.