diff --git a/src/run_program.rs b/src/run_program.rs index c9f51ddf..d8cfb7fe 100644 --- a/src/run_program.rs +++ b/src/run_program.rs @@ -1,5 +1,5 @@ -use super::traverse_path::traverse_path; -use crate::allocator::{Allocator, Checkpoint, NodePtr, SExp}; +use super::traverse_path::{traverse_path, traverse_path_fast}; +use crate::allocator::{Allocator, Checkpoint, NodePtr, NodeVisitor, SExp}; use crate::cost::Cost; use crate::dialect::{Dialect, OperatorSet}; use crate::err_utils::err; @@ -279,7 +279,15 @@ impl<'a, D: Dialect> RunProgramContext<'a, D> { // put a bunch of ops on op_stack let SExp::Pair(op_node, op_list) = self.allocator.sexp(program) else { // the program is just a bitfield path through the env tree - let r: Reduction = traverse_path(self.allocator, self.allocator.atom(program), env)?; + let r: Reduction = self.allocator.visit_node(program, |node| -> Response { + match node { + NodeVisitor::Buffer(buf) => traverse_path(self.allocator, buf, env), + NodeVisitor::U32(val) => traverse_path_fast(self.allocator, *val, env), + NodeVisitor::Pair(_, _) => { + panic!("expected atom, got pair"); + } + } + })?; self.push(r.1)?; return Ok(r.0); }; diff --git a/src/traverse_path.rs b/src/traverse_path.rs index 27127397..fa47b0dc 100644 --- a/src/traverse_path.rs +++ b/src/traverse_path.rs @@ -72,6 +72,42 @@ pub fn traverse_path(allocator: &Allocator, node_index: &[u8], args: NodePtr) -> Ok(Reduction(cost, arg_list)) } +// The cost calculation for this version of traverse_path assumes the node_index has the canonical +// integer representation (which is true for SmallAtom in the allocator). If there are any +// redundant leading zeros, the slow path must be used +pub fn traverse_path_fast(allocator: &Allocator, mut node_index: u32, args: NodePtr) -> Response { + if node_index == 0 { + return Ok(Reduction( + TRAVERSE_BASE_COST + TRAVERSE_COST_PER_BIT, + allocator.nil(), + )); + } + + let mut arg_list: NodePtr = args; + + let mut cost: Cost = TRAVERSE_BASE_COST + TRAVERSE_COST_PER_BIT; + let mut num_bits = 0; + while node_index != 1 { + let SExp::Pair(left, right) = allocator.sexp(arg_list) else { + return Err(EvalErr(arg_list, "path into atom".into())); + }; + + let is_bit_set: bool = (node_index & 0x01) != 0; + arg_list = if is_bit_set { right } else { left }; + node_index >>= 1; + num_bits += 1 + } + + cost += num_bits * TRAVERSE_COST_PER_BIT; + // since positive numbers sometimes need a leading zero, e.g. 0x80, 0x8000 etc. We also + // need to add the cost of that leading zero byte + if num_bits == 7 || num_bits == 15 || num_bits == 23 || num_bits == 31 { + cost += TRAVERSE_COST_PER_ZERO_BYTE; + } + + Ok(Reduction(cost, arg_list)) +} + #[test] fn test_msb_mask() { assert_eq!(msb_mask(0x0), 0x0); @@ -166,3 +202,61 @@ fn test_traverse_path() { EvalErr(n2, "path into atom".to_string()) ); } + +#[test] +fn test_traverse_path_fast_fast() { + use crate::allocator::Allocator; + + let mut a = Allocator::new(); + let nul = a.nil(); + let n1 = a.new_atom(&[0, 1, 2]).unwrap(); + let n2 = a.new_atom(&[4, 5, 6]).unwrap(); + + assert_eq!(traverse_path_fast(&a, 0, n1).unwrap(), Reduction(44, nul)); + assert_eq!(traverse_path_fast(&a, 0b1, n1).unwrap(), Reduction(44, n1)); + assert_eq!(traverse_path_fast(&a, 0b1, n2).unwrap(), Reduction(44, n2)); + + let n3 = a.new_pair(n1, n2).unwrap(); + assert_eq!(traverse_path_fast(&a, 0b1, n3).unwrap(), Reduction(44, n3)); + assert_eq!(traverse_path_fast(&a, 0b10, n3).unwrap(), Reduction(48, n1)); + assert_eq!(traverse_path_fast(&a, 0b11, n3).unwrap(), Reduction(48, n2)); + assert_eq!(traverse_path_fast(&a, 0b11, n3).unwrap(), Reduction(48, n2)); + + let list = a.new_pair(n1, nul).unwrap(); + let list = a.new_pair(n2, list).unwrap(); + + assert_eq!( + traverse_path_fast(&a, 0b10, list).unwrap(), + Reduction(48, n2) + ); + assert_eq!( + traverse_path_fast(&a, 0b101, list).unwrap(), + Reduction(52, n1) + ); + assert_eq!( + traverse_path_fast(&a, 0b111, list).unwrap(), + Reduction(52, nul) + ); + + // errors + assert_eq!( + traverse_path_fast(&a, 0b1011, list).unwrap_err(), + EvalErr(nul, "path into atom".to_string()) + ); + assert_eq!( + traverse_path_fast(&a, 0b1101, list).unwrap_err(), + EvalErr(n1, "path into atom".to_string()) + ); + assert_eq!( + traverse_path_fast(&a, 0b1001, list).unwrap_err(), + EvalErr(n1, "path into atom".to_string()) + ); + assert_eq!( + traverse_path_fast(&a, 0b1010, list).unwrap_err(), + EvalErr(n2, "path into atom".to_string()) + ); + assert_eq!( + traverse_path_fast(&a, 0b1110, list).unwrap_err(), + EvalErr(n2, "path into atom".to_string()) + ); +}