diff --git a/src/lib.rs b/src/lib.rs index 6fb2e0e1..2b9e15ff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,17 +8,17 @@ use std::thread; use fancy_regex::Regex; use pyo3::exceptions; use pyo3::prelude::*; +use pyo3::pyclass; use pyo3::PyResult; use pyo3::types::{PyBytes, PyList, PyTuple}; use rustc_hash::FxHashMap as HashMap; type Rank = u32; -fn _byte_pair_merge( - piece: &[u8], +fn _byte_pair_merge( ranks: &HashMap, Rank>, - f: impl Fn(std::ops::Range) -> T, -) -> Vec { + piece: &[u8], +) -> Vec<(usize, Rank)> { // This is a vector of (start, rank). // The rank is of the byte pair starting at position start. // The rank of the last item in the vector is not a valid value. @@ -93,25 +93,24 @@ fn _byte_pair_merge( break; } } - let mut out: Vec = Vec::with_capacity(parts.len() - 1); - for i in 0..parts.len() - 1 { - out.push(f(parts[i].0..parts[i + 1].0)); - } - out + + parts } pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap, Rank>) -> Vec { - if piece.len() == 1 { - return vec![ranks[piece]]; - } - _byte_pair_merge(piece, ranks, |p| ranks[&piece[p.start..p.end]]) + assert!(piece.len() > 1); + _byte_pair_merge(&ranks, &piece) + .windows(2) + .map(|part| ranks[&piece[part[0].0..part[1].0]]) + .collect() } pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap, Rank>) -> Vec<&'a [u8]> { - if piece.len() == 1 { - return vec![piece]; - } - _byte_pair_merge(piece, ranks, |p| &piece[p.start..p.end]) + assert!(piece.len() > 1); + _byte_pair_merge(&ranks, &piece) + .windows(2) + .map(|part| &piece[part[0].0..part[1].0]) + .collect() } // Various performance notes: