add a few missing datalayer helpers for python #936

Merged · 5 commits · Feb 19, 2025

127 changes: 104 additions & 23 deletions crates/chia-datalayer/src/merkle.rs
@@ -299,6 +299,12 @@ create_errors!(
BlockIndexOutOfBoundsError,
"block index out of bounds: {0}",
(TreeIndex)
),
(
LeafHashNotFound,
LeafHashNotFoundError,
"leaf hash not found: {0:?}",
(Hash)
)
)
);
@@ -641,20 +647,31 @@ impl Block {
}
}

// TODO: take the encouragement to clean up and make clippy happy again
#[allow(clippy::type_complexity)]
fn get_free_indexes_and_keys_values_indexes(
blob: &Vec<u8>,
) -> Result<(HashSet<TreeIndex>, HashMap<KeyId, TreeIndex>), Error> {
) -> Result<
(
HashSet<TreeIndex>,
HashMap<KeyId, TreeIndex>,
HashMap<Hash, TreeIndex>,
),
Error,
> {
let index_count = blob.len() / BLOCK_SIZE;

let mut seen_indexes: Vec<bool> = vec![false; index_count];
let mut key_to_index: HashMap<KeyId, TreeIndex> = HashMap::default();
let mut leaf_hash_to_index: HashMap<Hash, TreeIndex> = HashMap::default();

for item in MerkleBlobLeftChildFirstIterator::new(blob) {
for item in MerkleBlobLeftChildFirstIterator::new(blob, None) {
let (index, block) = item?;
seen_indexes[index.0 as usize] = true;

if let Node::Leaf(leaf) = block.node {
key_to_index.insert(leaf.key, index);
leaf_hash_to_index.insert(leaf.hash, index);
}
}

@@ -665,7 +682,7 @@ fn get_free_indexes_and_keys_values_indexes(
}
}

Ok((free_indexes, key_to_index))
Ok((free_indexes, key_to_index, leaf_hash_to_index))
}

/// Stores a DataLayer merkle tree in bytes and provides serialization on each access so that only
@@ -682,6 +699,7 @@ pub struct MerkleBlob {
// TODO: would be nice for this to be deterministic ala a fifo set
free_indexes: HashSet<TreeIndex>,
key_to_index: HashMap<KeyId, TreeIndex>,
leaf_hash_to_index: HashMap<Hash, TreeIndex>,
// TODO: used by fuzzing, some cleaner way? making it cfg-dependent is annoying with
// the type stubs
pub check_integrity_on_drop: bool,
@@ -696,12 +714,14 @@ impl MerkleBlob {
}

// TODO: maybe integrate integrity check here if quick enough
let (free_indexes, key_to_index) = get_free_indexes_and_keys_values_indexes(&blob)?;
let (free_indexes, key_to_index, leaf_hash_to_index) =
get_free_indexes_and_keys_values_indexes(&blob)?;

let self_ = Self {
blob,
free_indexes,
key_to_index,
leaf_hash_to_index,
check_integrity_on_drop: true,
};

@@ -712,6 +732,7 @@ impl MerkleBlob {
self.blob.clear();
self.key_to_index.clear();
self.free_indexes.clear();
self.leaf_hash_to_index.clear();
}

pub fn insert(
@@ -1080,7 +1101,7 @@ impl MerkleBlob {
}

fn get_min_height_leaf(&self) -> Result<LeafNode, Error> {
let (_index, block) = MerkleBlobBreadthFirstIterator::new(&self.blob)
let (_index, block) = MerkleBlobBreadthFirstIterator::new(&self.blob, None)
.next()
.ok_or(Error::UnableToFindALeaf())??;

@@ -1169,7 +1190,7 @@ impl MerkleBlob {
let mut internal_count: usize = 0;
let mut child_to_parent: HashMap<TreeIndex, TreeIndex> = HashMap::new();

for item in MerkleBlobParentFirstIterator::new(&self.blob) {
for item in MerkleBlobParentFirstIterator::new(&self.blob, None) {
let (index, block) = item?;
if let Some(parent) = block.node.parent().0 {
if child_to_parent.remove(&index) != Some(parent) {
@@ -1433,7 +1454,7 @@ impl MerkleBlob {

pub fn calculate_lazy_hashes(&mut self) -> Result<(), Error> {
// OPT: yeah, storing the whole set of blocks via collect is not great
for item in MerkleBlobLeftChildFirstIterator::new(&self.blob).collect::<Vec<_>>() {
for item in MerkleBlobLeftChildFirstIterator::new(&self.blob, None).collect::<Vec<_>>() {
let (index, mut block) = item?;
// OPT: really want a pruned traversal, not filter
if !block.metadata.dirty {
@@ -1505,14 +1526,46 @@ impl MerkleBlob {
layers,
})
}

pub fn get_node_by_hash(&self, node_hash: Hash) -> Result<(KeyId, ValueId), Error> {
let Some(index) = self.leaf_hash_to_index.get(&node_hash) else {
return Err(Error::LeafHashNotFound(node_hash));
};

let node = self
.get_node(*index)?
.expect_leaf("should only have leaves in the leaf hash to index cache");

Ok((node.key, node.value))
}

pub fn get_hashes_indexes(&self, leafs_only: bool) -> Result<HashMap<Hash, TreeIndex>, Error> {
let mut hash_to_index = HashMap::new();

if self.blob.is_empty() {
return Ok(hash_to_index);
}

for item in MerkleBlobParentFirstIterator::new(&self.blob, None) {
let (index, block) = item?;

if leafs_only && block.metadata.node_type != NodeType::Leaf {
continue;
}

hash_to_index.insert(block.node.hash(), index);
}

Ok(hash_to_index)
}
}

impl PartialEq for MerkleBlob {
fn eq(&self, other: &Self) -> bool {
// NOTE: this is checking tree structure equality, not serialized bytes equality
for item in zip(
MerkleBlobLeftChildFirstIterator::new(&self.blob),
MerkleBlobLeftChildFirstIterator::new(&other.blob),
MerkleBlobLeftChildFirstIterator::new(&self.blob, None),
MerkleBlobLeftChildFirstIterator::new(&other.blob, None),
) {
let (Ok((_, self_block)), Ok((_, other_block))) = item else {
// TODO: it's an error though, hmm
@@ -1631,11 +1684,15 @@ impl MerkleBlob {
Ok(list.into())
}

#[pyo3(name = "get_nodes_with_indexes")]
pub fn py_get_nodes_with_indexes(&self, py: Python<'_>) -> PyResult<pyo3::PyObject> {
#[pyo3(name = "get_nodes_with_indexes", signature = (index=None))]
pub fn py_get_nodes_with_indexes(
&self,
index: Option<TreeIndex>,
py: Python<'_>,
) -> PyResult<pyo3::PyObject> {
let list = pyo3::types::PyList::empty(py);

for item in MerkleBlobParentFirstIterator::new(&self.blob) {
for item in MerkleBlobParentFirstIterator::new(&self.blob, index) {
let (index, block) = item?;
list.append((index.into_pyobject(py)?, block.node.into_pyobject(py)?))?;
}
@@ -1711,6 +1768,27 @@ impl MerkleBlob {
pub fn py_get_proof_of_inclusion(&self, key: KeyId) -> PyResult<ProofOfInclusion> {
Ok(self.get_proof_of_inclusion(key)?)
}

#[pyo3(name = "get_node_by_hash")]
pub fn py_get_node_by_hash(&self, node_hash: Hash) -> PyResult<(KeyId, ValueId)> {
Ok(self.get_node_by_hash(node_hash)?)
}

#[pyo3(name = "get_hashes_indexes", signature = (leafs_only=false))]
pub fn py_get_hashes_indexes(&self, leafs_only: bool) -> PyResult<HashMap<Hash, TreeIndex>> {
Ok(self.get_hashes_indexes(leafs_only)?)
}

#[pyo3(name = "get_random_leaf_node")]
pub fn py_get_random_leaf_node(&self, seed: &[u8]) -> PyResult<LeafNode> {
let insert_location = self.get_random_insert_location_by_seed(seed)?;
let InsertLocation::Leaf { index, side: _ } = insert_location else {
// TODO: real error
return Err(PyValueError::new_err(""));
};

Ok(self.get_node(index)?.expect_leaf("matched leaf above"))
}
}

fn try_get_block(blob: &[u8], index: TreeIndex) -> Result<Block, Error> {
@@ -1737,12 +1815,13 @@ pub struct MerkleBlobLeftChildFirstIterator<'a> {
}

impl<'a> MerkleBlobLeftChildFirstIterator<'a> {
fn new(blob: &'a Vec<u8>) -> Self {
fn new(blob: &'a Vec<u8>, from_index: Option<TreeIndex>) -> Self {
let mut deque = VecDeque::new();
let from_index = from_index.unwrap_or(TreeIndex(0));
if blob.len() / BLOCK_SIZE > 0 {
deque.push_back(MerkleBlobLeftChildFirstIteratorItem {
visited: false,
index: TreeIndex(0),
index: from_index,
});
}

@@ -1804,10 +1883,11 @@ pub struct MerkleBlobParentFirstIterator<'a> {
}

impl<'a> MerkleBlobParentFirstIterator<'a> {
fn new(blob: &'a Vec<u8>) -> Self {
fn new(blob: &'a Vec<u8>, from_index: Option<TreeIndex>) -> Self {
let mut deque = VecDeque::new();
let from_index = from_index.unwrap_or(TreeIndex(0));
if blob.len() / BLOCK_SIZE > 0 {
deque.push_back(TreeIndex(0));
deque.push_back(from_index);
}

Self {
@@ -1852,10 +1932,11 @@ pub struct MerkleBlobBreadthFirstIterator<'a> {

impl<'a> MerkleBlobBreadthFirstIterator<'a> {
#[allow(unused)]
fn new(blob: &'a Vec<u8>) -> Self {
fn new(blob: &'a Vec<u8>, from_index: Option<TreeIndex>) -> Self {
let mut deque = VecDeque::new();
let from_index = from_index.unwrap_or(TreeIndex(0));
if blob.len() / BLOCK_SIZE > 0 {
deque.push_back(TreeIndex(0));
deque.push_back(from_index);
}

Self {
@@ -2307,7 +2388,7 @@ mod tests {
let mut blob = small_blob.blob.clone();
let expected_free_index = TreeIndex((blob.len() / BLOCK_SIZE) as u32);
blob.extend_from_slice(&[0; BLOCK_SIZE]);
let (free_indexes, _) = get_free_indexes_and_keys_values_indexes(&blob).unwrap();
let (free_indexes, _, _) = get_free_indexes_and_keys_values_indexes(&blob).unwrap();
assert_eq!(free_indexes, HashSet::from([expected_free_index]));
}

@@ -2338,15 +2419,15 @@ mod tests {
#[rstest]
fn test_upsert_upserts(mut small_blob: MerkleBlob) {
let before_blocks =
MerkleBlobLeftChildFirstIterator::new(&small_blob.blob).collect::<Vec<_>>();
MerkleBlobLeftChildFirstIterator::new(&small_blob.blob, None).collect::<Vec<_>>();
let (key, index) = small_blob.key_to_index.iter().next().unwrap();
let original = small_blob.get_node(*index).unwrap().expect_leaf("<<self>>");
let new_value = ValueId(original.value.0 + 1);

small_blob.upsert(*key, new_value, &original.hash).unwrap();

let after_blocks =
MerkleBlobLeftChildFirstIterator::new(&small_blob.blob).collect::<Vec<_>>();
MerkleBlobLeftChildFirstIterator::new(&small_blob.blob, None).collect::<Vec<_>>();

assert_eq!(before_blocks.len(), after_blocks.len());
for item in zip(before_blocks, after_blocks) {
@@ -2656,7 +2737,7 @@ mod tests {
#[case] expected: Expect,
#[by_ref] traversal_blob: &'a MerkleBlob,
) where
F: Fn(&'a Vec<u8>) -> T,
F: Fn(&'a Vec<u8>, Option<TreeIndex>) -> T,
T: Iterator<Item = Result<(TreeIndex, Block), Error>>,
{
let mut dot_actual = traversal_blob.to_dot().unwrap();
@@ -2665,7 +2746,7 @@ mod tests {
let mut actual = vec![];
{
let blob: &Vec<u8> = &traversal_blob.blob;
for item in iterator_new(blob) {
for item in iterator_new(blob, None) {
let (index, block) = item.unwrap();
actual.push(iterator_test_reference(index, &block));
dot_actual.push_traversal(index);
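The iterator changes above thread an optional starting index through to the pyo3 binding, so `get_nodes_with_indexes` can now walk just a subtree instead of the whole tree. A minimal Python sketch of that, not taken from this diff: the `tree.bin` path and the `TreeIndex(0)` literal are placeholders, and the `MerkleBlob` constructor is assumed to accept the serialized blob bytes directly.

```python
# Sketch only: exercises the new optional `index` argument on
# get_nodes_with_indexes(); the file name and index value are placeholders.
from chia_rs.datalayer import MerkleBlob, TreeIndex

with open("tree.bin", "rb") as f:   # hypothetical serialized merkle blob
    blob = MerkleBlob(f.read())     # assumes the constructor takes raw bytes

# Parent-first traversal of the whole tree (existing behavior).
for index, node in blob.get_nodes_with_indexes():
    print(index, node)

# New in this PR: start the same traversal at an arbitrary subtree root.
subtree_root = TreeIndex(0)  # assumed constructible from a raw index
for index, node in blob.get_nodes_with_indexes(index=subtree_root):
    print(index, node)
```
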
2 changes: 1 addition & 1 deletion crates/chia-datalayer/src/merkle/dot.rs
@@ -112,7 +112,7 @@ impl Node {
impl MerkleBlob {
pub fn to_dot(&self) -> Result<DotLines, Error> {
let mut result = DotLines::new();
for item in MerkleBlobLeftChildFirstIterator::new(&self.blob) {
for item in MerkleBlobLeftChildFirstIterator::new(&self.blob, None) {
let (index, block) = item?;
result.push(block.node.to_dot(index));
}
8 changes: 7 additions & 1 deletion wheel/python/chia_rs/datalayer.pyi
@@ -31,6 +31,7 @@ class StreamingError(Exception): ...
class IndexIsNotAChildError(Exception): ...
class CycleFoundError(Exception): ...
class BlockIndexOutOfBoundsError(Exception): ...
class LeafHashNotFoundError(Exception): ...

@final
class KeyId:
@@ -202,6 +203,8 @@ class MerkleBlob:
@property
def key_to_index(self) -> Mapping[KeyId, TreeIndex]: ...
@property
def leaf_hash_to_index(self) -> Mapping[bytes32, TreeIndex]: ...
@property
def check_integrity_on_drop(self) -> bool: ...

def __init__(
@@ -215,14 +218,17 @@ class MerkleBlob:
def get_raw_node(self, index: TreeIndex) -> Union[InternalNode, LeafNode]: ...
def calculate_lazy_hashes(self) -> None: ...
def get_lineage_with_indexes(self, index: TreeIndex) -> list[tuple[TreeIndex, Union[InternalNode, LeafNode]]]:...
def get_nodes_with_indexes(self) -> list[tuple[TreeIndex, Union[InternalNode, LeafNode]]]: ...
def get_nodes_with_indexes(self, index: Optional[TreeIndex] = ...) -> list[tuple[TreeIndex, Union[InternalNode, LeafNode]]]: ...
def empty(self) -> bool: ...
def get_root_hash(self) -> bytes32: ...
def batch_insert(self, keys_values: list[tuple[KeyId, ValueId]], hashes: list[bytes32]): ...
def get_hash_at_index(self, index: TreeIndex): ...
def get_keys_values(self) -> dict[KeyId, ValueId]: ...
def get_key_index(self, key: KeyId) -> TreeIndex: ...
def get_proof_of_inclusion(self, key: KeyId) -> ProofOfInclusion: ...
def get_node_by_hash(self, node_hash: bytes32) -> tuple[KeyId, ValueId]: ...
def get_hashes_indexes(self, leafs_only: bool = ...) -> dict[bytes32, TreeIndex]: ...
def get_random_leaf_node(self, seed: bytes) -> LeafNode: ...

def __len__(self) -> int: ...

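A short usage sketch of the three new stubs above, with the same caveats: the `tree.bin` input and the all-zero seed are placeholders, and the constructor is assumed to accept the serialized blob bytes directly.

```python
# Sketch only: demonstrates get_hashes_indexes(), get_random_leaf_node(),
# and get_node_by_hash(); inputs are placeholders.
from chia_rs.datalayer import LeafHashNotFoundError, MerkleBlob

with open("tree.bin", "rb") as f:   # hypothetical serialized merkle blob
    blob = MerkleBlob(f.read())     # assumes the constructor takes raw bytes

# Hash -> index for every node, or for leaves only.
all_hashes = blob.get_hashes_indexes()
leaf_hashes = blob.get_hashes_indexes(leafs_only=True)

# Deterministically pick a leaf from a seed, then resolve its key/value
# pair through the new leaf-hash lookup.
leaf = blob.get_random_leaf_node(b"\x00" * 32)
try:
    key, value = blob.get_node_by_hash(leaf.hash)  # assumes LeafNode exposes .hash
except LeafHashNotFoundError:
    print("hash does not correspond to a leaf in this tree")
```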