diff --git a/src/lib.rs b/src/lib.rs index 3d930655..54210bd5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,608 +1,3 @@ // This check is new and seems buggy (possibly with PyO3 interaction) -#![allow(clippy::borrow_deref_ref)] - -use std::collections::HashSet; -use std::thread; - -use fancy_regex::Regex; -use pyo3::exceptions; -use pyo3::prelude::*; -use pyo3::types::{PyBytes, PyList, PyTuple}; -use pyo3::PyResult; -use rustc_hash::FxHashMap as HashMap; - -fn _byte_pair_merge( - piece: &[u8], - ranks: &HashMap, usize>, - f: impl Fn(std::ops::Range) -> T, -) -> Vec { - // This is a vector of (start, rank). - // The rank is of the byte pair starting at position start. - // The rank of the last item in the vector is not a valid value. - let mut parts: Vec<(usize, usize)> = (0..piece.len() + 1).map(|i| (i, usize::MAX)).collect(); - - let get_rank = { - #[inline(always)] - |parts: &Vec<(usize, usize)>, start_idx: usize, skip: usize| { - if (start_idx + skip + 2) < parts.len() { - ranks - .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0]) - .copied() - } else { - None - } - } - }; - - // We look up the ranks once in the beginning and iteratively update - // them during each merge, which reduces the number of rank lookups. - for i in 0..parts.len() - 2 { - match get_rank(&parts, i, 0) { - Some(rank) => { - // usize::MAX is a sentinel value and cannot be a valid rank - debug_assert!(rank != usize::MAX); - parts[i].1 = rank; - } - None => { - continue; - } - }; - } - - // If you have n parts and m merges, this does O(mn) work. - // We could do something with a heap and do O(m log n) work. - // It is important to consider that n is often small (<100), and as such - // the cache-locality benefits outweigh the algorithmic complexity downsides - // of the `parts` vector data structure above. - - // Note that we hash bytes, not token pairs. As long as we train BPE the way we - // currently do, this is equivalent. An easy way to break this would be to decouple - // merge priority from token index or to prevent specific token merges. - loop { - if parts.len() == 1 { - break; - } - - // usize::MAX is a sentinel rank value allowing us to - // take the min more quickly - let mut min_rank: (usize, usize) = (usize::MAX, 0); - for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() { - if rank < min_rank.0 { - min_rank = (rank, i); - } - } - - if min_rank.0 != usize::MAX { - let i = min_rank.1; - - // NOTE: We are about to remove parts[i + 1]. We do not do it - // yet because there are cache-locality benefits to updating - // parts[i] and parts[i-1] before removing, which could thrash - // the cache. Thus, we update the rank calculation by skipping over - // parts[i + 1], by invoking `get_rank!` with `skip = 1`. 
- parts[i].1 = get_rank(&parts, i, 1).unwrap_or(usize::MAX); - if i > 0 { - parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(usize::MAX); - } - - parts.remove(i + 1); - } else { - break; - } - } - let mut out: Vec = Vec::with_capacity(parts.len() - 1); - for i in 0..parts.len() - 1 { - out.push(f(parts[i].0..parts[i + 1].0)); - } - out -} - -pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap, usize>) -> Vec { - if piece.len() == 1 { - return vec![ranks[piece]]; - } - _byte_pair_merge(piece, ranks, |p| ranks[&piece[p.start..p.end]]) -} - -pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap, usize>) -> Vec<&'a [u8]> { - if piece.len() == 1 { - return vec![piece]; - } - _byte_pair_merge(piece, ranks, |p| &piece[p.start..p.end]) -} - -// Various performance notes: -// -// Regex -// ===== -// Most of the time is spent in regex. The easiest way to speed this up is by using less fancy -// regex features. For instance, using a regex parse-able by `regex` crate is 3x faster than -// the usual regex we use. -// -// However, given that we're using a regex parse-able by `regex`, there isn't much difference -// between using the `regex` crate and using the `fancy_regex` crate. -// -// There is an important interaction between threading, `regex` and `fancy_regex`. -// When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on -// some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain -// old `regex`, we don't hit this, because `find_iter` has a different code path. -// Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md -// Anyway, the way we get around this is with having a (mostly) thread local clone of the regex for -// each thread. -// -// Threading -// ========= -// I tried using `rayon`. It wasn't really faster than using Python threads and releasing the GIL. -// So goodbye `rayon`! Let thread count etc be in control of our Python users. -// -// Caching -// ======= -// The reference tokeniser has an lru cache over the equivalent of `byte_pair_encode`. -// Originally, we had one too! Without it, we were only vaguely faster than Python. -// I used an RWLock to protect the cache. This didn't seem to hurt single threaded performance -// noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect -// multi-threaded performance even when I only had readers (maybed I messed something up?). -// Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache! -// These are exactly the set or merges that are likely to be hot. And now we don't have to think -// about interior mutability, memory use, or cloning. -// -// Hashing -// ======= -// We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win? -// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made -// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster. - -use std::num::NonZeroU64; -pub struct FakeThreadId(NonZeroU64); - -fn hash_current_thread() -> usize { - // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter - // that works great for our use case of avoiding collisions in our array. Unfortunately, - // it's private. 
However, there are only so many ways you can layout a u64, so just transmute - // https://github.com/rust-lang/rust/issues/67939 - const _: [u8; 8] = [0; std::mem::size_of::()]; - const _: [u8; 8] = [0; std::mem::size_of::()]; - let x = unsafe { - std::mem::transmute::(thread::current().id()).0 - }; - u64::from(x) as usize -} - -const MAX_NUM_THREADS: usize = 128; -#[pyclass] -struct CoreBPE { - encoder: HashMap, usize>, - special_tokens_encoder: HashMap, - decoder: HashMap>, - special_tokens_decoder: HashMap>, - regex_tls: Vec, - special_regex_tls: Vec, - sorted_token_bytes: Vec>, -} - -impl CoreBPE { - fn _get_tl_regex(&self) -> &Regex { - // See performance notes above for what this is about - // It's also a little janky, please make a better version of it! - // However, it's nice that this doesn't leak memory to short-lived threads - &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS] - } - - fn _get_tl_special_regex(&self) -> &Regex { - &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS] - } - - fn _decode_native(&self, tokens: &[usize]) -> Vec { - let mut ret = Vec::with_capacity(tokens.len() * 2); - for token in tokens { - let token_bytes = self - .decoder - .get(token) - .unwrap_or_else(|| &self.special_tokens_decoder[token]); - ret.extend(token_bytes); - } - ret - } - - fn _encode_ordinary_native(&self, text: &str) -> Vec { - // This is the core of the encoding logic; the other functions in here - // just make things complicated :-) - let regex = self._get_tl_regex(); - let mut ret = vec![]; - for mat in regex.find_iter(text) { - let piece = mat.unwrap().as_str().as_bytes(); - if let Some(token) = self.encoder.get(piece) { - ret.push(*token); - continue; - } - ret.extend(&byte_pair_encode(piece, &self.encoder)); - } - ret - } - - fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec, usize) { - let special_regex = self._get_tl_special_regex(); - let regex = self._get_tl_regex(); - let mut ret = vec![]; - - let mut start = 0; - let mut last_piece_token_len = 0; - loop { - let mut next_special; - let mut start_find = start; - loop { - // Find the next allowed special token, if any - next_special = special_regex.find_from_pos(text, start_find).unwrap(); - match next_special { - Some(m) => { - if allowed_special.contains(&text[m.start()..m.end()]) { - break; - } - start_find = m.start() + 1; - } - None => break, - } - } - let end = next_special.map_or(text.len(), |m| m.start()); - - // Okay, here we go, compare this logic to _encode_ordinary_native - for mat in regex.find_iter(&text[start..end]) { - let piece = mat.unwrap().as_str().as_bytes(); - if let Some(token) = self.encoder.get(piece) { - last_piece_token_len = 1; - ret.push(*token); - continue; - } - let tokens = byte_pair_encode(piece, &self.encoder); - last_piece_token_len = tokens.len(); - ret.extend(&tokens); - } - - match next_special { - // And here we push the special token - Some(m) => { - let piece = m.as_str(); - let token = self.special_tokens_encoder[piece]; - ret.push(token); - start = m.end(); - last_piece_token_len = 0; - } - None => break, - } - } - - // last_piece_token_len is how many tokens came from the last regex split. This is used - // for determining unstable tokens, since you can't merge across (stable) regex splits - (ret, last_piece_token_len) - } - - fn _increase_last_piece_token_len( - &self, - tokens: Vec, - mut last_piece_token_len: usize, - ) -> (Vec, usize) { - // Unfortunately, the locations where our regex splits can be unstable. 
- // For the purposes of determining unstable tokens, unstable regex splitting - // is only a problem if a split that was present disappears, since this can - // lead to merging of tokens otherwise thought to be stable. - // cl100k_base makes our life hard by including the \s*[\r\n]+ - // pattern. This can e.g. cause "\n" + " " to become "\n \n". - // Here is a quick and dirty fix: - { - let token_is_all_space = |token| { - self.decoder - .get(token) - .map(|token_bytes| { - token_bytes - .iter() - .rev() - .all(|&b| [b' ', b'\n', b'\t'].contains(&b)) - }) - .unwrap_or(false) - }; - if last_piece_token_len > 0 - && token_is_all_space(&tokens[tokens.len() - last_piece_token_len]) - { - while (last_piece_token_len < tokens.len()) - && token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1]) - { - last_piece_token_len += 1; - } - } - } - debug_assert!(last_piece_token_len <= tokens.len()); - - (tokens, last_piece_token_len) - } - - fn _encode_unstable_native( - &self, - text: &str, - allowed_special: &HashSet<&str>, - ) -> (Vec, HashSet>) { - let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special); - if last_piece_token_len == 0 { - // If last_piece_token_len is zero, the last token was a special token and we have - // no unstable bytes - return (tokens, HashSet::new()); - } - let (mut tokens, last_piece_token_len) = - self._increase_last_piece_token_len(tokens, last_piece_token_len); - - let unstable_bytes = self._decode_native(&tokens[tokens.len() - last_piece_token_len..]); - tokens.truncate(tokens.len() - last_piece_token_len); - - // TODO: we should try harder to find additional stable tokens - // This would reduce the amount of retokenising when determining completions - // Refer to the logic in an older version of this file - - let mut completions = HashSet::new(); - if unstable_bytes.is_empty() { - return (tokens, completions); - } - - // This is the easy bit. Just find all single tokens that start with unstable_bytes - // (including tokens that exactly match unstable_bytes) - // Separating this from the loop below helps with performance in a common case. - let mut point = self - .sorted_token_bytes - .partition_point(|x| x.as_slice() < unstable_bytes.as_slice()); - while point < self.sorted_token_bytes.len() - && self.sorted_token_bytes[point].starts_with(&unstable_bytes) - { - completions.insert(vec![ - self.encoder[self.sorted_token_bytes[point].as_slice()], - ]); - point += 1; - } - - // Now apply even more brute force. At every (other) possible position for the straddling - // token, concatenate additional bytes from that token (if any) to unstable_bytes, - // and retokenise the whole thing and see what we get. - for i in 1..unstable_bytes.len() { - let prefix = &unstable_bytes[..i]; - let suffix = &unstable_bytes[i..]; - let mut point = self - .sorted_token_bytes - .partition_point(|x| x.as_slice() < suffix); - // TODO: Perf optimisation if suffix starts with " "? - while point < self.sorted_token_bytes.len() - && self.sorted_token_bytes[point].starts_with(suffix) - { - let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat(); - let encoded = match std::str::from_utf8(&possibility) { - // Morally, this is byte_pair_encode(&possibility, &self.encoder) - // But we might have introduced a regex split which would prevent merges. - // (particularly possible in the presence of unstable regex splits) - // So convert to UTF-8 and do regex splitting. - // E.g. with cl100k_base " !" 
gets split to " " + " !", - // but byte_pair_encode(" !") != byte_pair_encode(" ") - Ok(s) => self._encode_ordinary_native(s), - - // Technically, whether or not this arm is correct depends on whether there - // would be a regex split before the UTF-8 truncation point. - // Probably niche enough that no one will ever notice (after all, people didn't - // notice all the big holes in the previous unstable token implementation) - Err(_) => byte_pair_encode(&possibility, &self.encoder), - // Something like the following is intriguing but incorrect: - // Err(e) => self._encode_ordinary_native(unsafe { - // std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()]) - // }), - }; - let mut seq = Vec::new(); - let mut seq_len = 0; - for token in encoded { - seq.push(token); - seq_len += self.decoder[&token].len(); - if seq_len >= unstable_bytes.len() { - break; - } - } - completions.insert(seq); - point += 1; - } - } - - // This is also not straightforward. While we generally assume that regex splits are stable, - // unfortunately, they are not. That is, if adding bytes were to make a split appear in - // unstable_bytes, this could make tokens possible which our logic would otherwise think - // would be merged. - // For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could - // develop a split, e.g. "\n\n0" splits into "\n"+"\n"+"0", making "\n" a possible token. - // Here is a quick and dirty fix: - // This isn't right if we ever remove \s+(?!\S) - if unstable_bytes.len() > 1 { - let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice()); - if unstable_bytes.len() - last_decoded.1 > 0 - && last_decoded.0.map_or(false, |c| c.is_whitespace()) - { - let mut reencoded = byte_pair_encode( - &unstable_bytes[..unstable_bytes.len() - last_decoded.1], - &self.encoder, - ); - reencoded.extend(byte_pair_encode( - &unstable_bytes[unstable_bytes.len() - last_decoded.1..], - &self.encoder, - )); - completions.insert(reencoded); - } - } - - (tokens, completions) - } -} - -#[pymethods] -impl CoreBPE { - #[new] - fn new( - encoder: HashMap, usize>, - special_tokens_encoder: HashMap, - pattern: &str, - ) -> PyResult { - let regex = Regex::new(pattern) - .map_err(|e| PyErr::new::(e.to_string()))?; - - let special_regex = { - let _parts = special_tokens_encoder - .keys() - .map(|s| fancy_regex::escape(s)) - .collect::>(); - Regex::new(&_parts.join("|")) - .map_err(|e| PyErr::new::(e.to_string()))? - }; - - let decoder: HashMap> = - encoder.iter().map(|(k, v)| (*v, k.clone())).collect(); - - assert!( - encoder.len() == decoder.len(), - "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?" 
- ); - - let special_tokens_decoder: HashMap> = special_tokens_encoder - .iter() - .map(|(k, v)| (*v, k.as_bytes().to_vec())) - .collect(); - - // Clone because I don't know how to tell Rust I'm not going to change the map - let mut sorted_token_bytes: Vec> = encoder.keys().cloned().collect(); - sorted_token_bytes.sort(); - - Ok(CoreBPE { - encoder, - special_tokens_encoder, - decoder, - special_tokens_decoder, - regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(), - special_regex_tls: (0..MAX_NUM_THREADS) - .map(|_| special_regex.clone()) - .collect(), - sorted_token_bytes, - }) - } - - // ==================== - // Encoding - // ==================== - - fn encode_ordinary(&self, py: Python, text: &str) -> Vec { - py.allow_threads(|| self._encode_ordinary_native(text)) - } - - fn encode(&self, py: Python, text: &str, allowed_special: HashSet<&str>) -> Vec { - py.allow_threads(|| self._encode_native(text, &allowed_special).0) - } - - fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec { - py.allow_threads(|| { - match std::str::from_utf8(bytes) { - Ok(text) => self._encode_ordinary_native(text), - Err(e) => { - let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) }; - let (tokens, last_piece_token_len) = self._encode_native(text, &HashSet::new()); - let (mut tokens, last_piece_token_len) = - self._increase_last_piece_token_len(tokens, last_piece_token_len); - if !tokens.is_empty() && last_piece_token_len > 0 { - // Lop off the tokens from the last piece and run BPE on the remaining bytes - // Somewhat niche, but this may not be correct if we'd have had a regex - // split between the valid UTF-8 and the invalid bytes, which is why this - // method is private - let mut unstable_bytes = - self._decode_native(&tokens[tokens.len() - last_piece_token_len..]); - unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]); - - tokens.truncate(tokens.len() - last_piece_token_len); - tokens.extend(byte_pair_encode(&unstable_bytes, &self.encoder)); - } - tokens - } - } - }) - } - - fn encode_with_unstable( - &self, - py: Python, - text: &str, - allowed_special: HashSet<&str>, - ) -> Py { - let (tokens, completions) = - py.allow_threads(|| self._encode_unstable_native(text, &allowed_special)); - let py_completions = - PyList::new(py, completions.iter().map(|seq| PyList::new(py, &seq[..]))); - (tokens, py_completions).into_py(py) - } - - fn encode_single_token(&self, piece: &[u8]) -> PyResult { - if let Some(token) = self.encoder.get(piece).copied() { - return Ok(token); - } - if let Ok(piece_str) = std::str::from_utf8(piece) { - if let Some(token) = self.special_tokens_encoder.get(piece_str).copied() { - return Ok(token); - } - } - Err(PyErr::new::(piece.to_owned())) - } - - fn encode_single_piece(&self, piece: &[u8]) -> Vec { - if let Some(token) = self.encoder.get(piece) { - return vec![*token]; - } - byte_pair_encode(piece, &self.encoder) - } - - // ==================== - // Decoding - // ==================== - - fn decode_bytes(&self, py: Python, tokens: Vec) -> Py { - let bytes = py.allow_threads(|| self._decode_native(&tokens)); - PyBytes::new(py, &bytes).into() - } - - fn decode_single_token_bytes(&self, py: Python, token: usize) -> PyResult> { - if let Some(bytes) = self.decoder.get(&token) { - return Ok(PyBytes::new(py, bytes).into()); - } - if let Some(bytes) = self.special_tokens_decoder.get(&token) { - return Ok(PyBytes::new(py, bytes).into()); - } - Err(PyErr::new::(token.to_string())) - } - - // ==================== - // Miscellaneous - // 
==================== - - fn token_byte_values(&self, py: Python) -> Vec> { - self.sorted_token_bytes - .iter() - .map(|x| PyBytes::new(py, x).into()) - .collect() - } -} - -#[pymodule] -fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> { - m.add_class::()?; - Ok(()) -} - -#[cfg(test)] -mod tests { - use rustc_hash::FxHashMap as HashMap; - - use crate::byte_pair_split; - - #[test] - fn very_simple_test() { - let mut ranks = HashMap::default(); - ranks.insert(b"ab".to_vec(), 1); - ranks.insert(b"cd".to_vec(), 2); - - let res = byte_pair_split(b"abcd", &ranks); - assert_eq!(res, vec![b"ab", b"cd"]); - } -} +pub mod tiktoken_py; +pub mod tiktoken; \ No newline at end of file diff --git a/src/tiktoken.rs b/src/tiktoken.rs new file mode 100644 index 00000000..53c1a075 --- /dev/null +++ b/src/tiktoken.rs @@ -0,0 +1,429 @@ +// This check is new and seems buggy (possibly with PyO3 interaction) +#![allow(clippy::borrow_deref_ref)] + +use std::collections::HashSet; +use fancy_regex::Regex; +use std::num::NonZeroU64; +use std::thread; + +use rustc_hash::FxHashMap as HashMap; + + +fn _byte_pair_merge( + piece: &[u8], + ranks: &HashMap, usize>, + f: impl Fn(std::ops::Range) -> T, +) -> Vec { + // This is a vector of (start, rank). + // The rank is of the byte pair starting at position start. + // The rank of the last item in the vector is not a valid value. + let mut parts: Vec<(usize, usize)> = (0..piece.len() + 1).map(|i| (i, usize::MAX)).collect(); + + let get_rank = { + #[inline(always)] + |parts: &Vec<(usize, usize)>, start_idx: usize, skip: usize| { + if (start_idx + skip + 2) < parts.len() { + ranks + .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0]) + .copied() + } else { + None + } + } + }; + + // We look up the ranks once in the beginning and iteratively update + // them during each merge, which reduces the number of rank lookups. + for i in 0..parts.len() - 2 { + match get_rank(&parts, i, 0) { + Some(rank) => { + // usize::MAX is a sentinel value and cannot be a valid rank + debug_assert!(rank != usize::MAX); + parts[i].1 = rank; + } + None => { + continue; + } + }; + } + + // If you have n parts and m merges, this does O(mn) work. + // We could do something with a heap and do O(m log n) work. + // It is important to consider that n is often small (<100), and as such + // the cache-locality benefits outweigh the algorithmic complexity downsides + // of the `parts` vector data structure above. + + // Note that we hash bytes, not token pairs. As long as we train BPE the way we + // currently do, this is equivalent. An easy way to break this would be to decouple + // merge priority from token index or to prevent specific token merges. + loop { + if parts.len() == 1 { + break; + } + + // usize::MAX is a sentinel rank value allowing us to + // take the min more quickly + let mut min_rank: (usize, usize) = (usize::MAX, 0); + for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() { + if rank < min_rank.0 { + min_rank = (rank, i); + } + } + + if min_rank.0 != usize::MAX { + let i = min_rank.1; + + // NOTE: We are about to remove parts[i + 1]. We do not do it + // yet because there are cache-locality benefits to updating + // parts[i] and parts[i-1] before removing, which could thrash + // the cache. Thus, we update the rank calculation by skipping over + // parts[i + 1], by invoking `get_rank!` with `skip = 1`. 
+ parts[i].1 = get_rank(&parts, i, 1).unwrap_or(usize::MAX); + if i > 0 { + parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(usize::MAX); + } + + parts.remove(i + 1); + } else { + break; + } + } + let mut out: Vec<T> = Vec::with_capacity(parts.len() - 1); + for i in 0..parts.len() - 1 { + out.push(f(parts[i].0..parts[i + 1].0)); + } + out +} + +pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> { + if piece.len() == 1 { + return vec![ranks[piece]]; + } + _byte_pair_merge(piece, ranks, |p| ranks[&piece[p.start..p.end]]) +} + +pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> { + if piece.len() == 1 { + return vec![piece]; + } + _byte_pair_merge(piece, ranks, |p| &piece[p.start..p.end]) +} + +// Various performance notes: +// +// Regex +// ===== +// Most of the time is spent in regex. The easiest way to speed this up is by using less fancy +// regex features. For instance, using a regex parse-able by the `regex` crate is 3x faster than +// the usual regex we use. +// +// However, given that we're using a regex parse-able by `regex`, there isn't much difference +// between using the `regex` crate and using the `fancy_regex` crate. +// +// There is an important interaction between threading, `regex` and `fancy_regex`. +// When using `fancy_regex`, we hit `regex.find_at`. It turns out that this causes contention on +// some mutable scratch space inside of `regex`. This absolutely kills performance. When using plain +// old `regex`, we don't hit this, because `find_iter` has a different code path. +// Related: https://github.com/rust-lang/regex/blob/master/PERFORMANCE.md +// Anyway, the way we get around this is with having a (mostly) thread local clone of the regex for +// each thread. +// +// Threading +// ========= +// I tried using `rayon`. It wasn't really faster than using Python threads and releasing the GIL. +// So goodbye `rayon`! Let our Python users stay in control of thread count etc. +// +// Caching +// ======= +// The reference tokeniser has an LRU cache over the equivalent of `byte_pair_encode`. +// Originally, we had one too! Without it, we were only vaguely faster than Python. +// I used an RWLock to protect the cache. This didn't seem to hurt single threaded performance +// noticeably, but it did affect multi-threaded performance. Weirdly, it seemed to affect +// multi-threaded performance even when I only had readers (maybe I messed something up?). +// Anyway, I realised that we could get rid of the cache, if we treat the set of tokens as a cache! +// These are exactly the set of merges that are likely to be hot. And now we don't have to think +// about interior mutability, memory use, or cloning. +// +// Hashing +// ======= +// We use FxHashMap instead of the standard HashMap. This is maybe like a 5-10% win? +// The current implementation ends up doing a lot of hashing of bytes. In theory, this could be made +// to be hashing of two-tuples of ints, which looks like it may also be a couple percent faster. + +pub struct FakeThreadId(NonZeroU64); + +pub fn hash_current_thread() -> usize { + // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter + // that works great for our use case of avoiding collisions in our array. Unfortunately, + // it's private.
However, there are only so many ways you can lay out a u64, so just transmute + // https://github.com/rust-lang/rust/issues/67939 + const _: [u8; 8] = [0; std::mem::size_of::<thread::ThreadId>()]; + const _: [u8; 8] = [0; std::mem::size_of::<FakeThreadId>()]; + let x = unsafe { + std::mem::transmute::<thread::ThreadId, FakeThreadId>(thread::current().id()).0 + }; + u64::from(x) as usize +} + +pub const MAX_NUM_THREADS: usize = 128; +pub struct CoreBPE { + pub encoder: HashMap<Vec<u8>, usize>, + pub special_tokens_encoder: HashMap<String, usize>, + pub decoder: HashMap<usize, Vec<u8>>, + pub special_tokens_decoder: HashMap<usize, Vec<u8>>, + pub regex_tls: Vec<Regex>, + pub special_regex_tls: Vec<Regex>, + pub sorted_token_bytes: Vec<Vec<u8>>, +} + +impl CoreBPE { + pub fn _get_tl_regex(&self) -> &Regex { + // See performance notes above for what this is about + // It's also a little janky, please make a better version of it! + // However, it's nice that this doesn't leak memory to short-lived threads + &self.regex_tls[hash_current_thread() % MAX_NUM_THREADS] + } + + pub fn _get_tl_special_regex(&self) -> &Regex { + &self.special_regex_tls[hash_current_thread() % MAX_NUM_THREADS] + } + + pub fn _decode_native(&self, tokens: &[usize]) -> Vec<u8> { + let mut ret = Vec::with_capacity(tokens.len() * 2); + for token in tokens { + let token_bytes = self + .decoder + .get(token) + .unwrap_or_else(|| &self.special_tokens_decoder[token]); + ret.extend(token_bytes); + } + ret + } + + pub fn _encode_ordinary_native(&self, text: &str) -> Vec<usize> { + // This is the core of the encoding logic; the other functions in here + // just make things complicated :-) + let regex = self._get_tl_regex(); + let mut ret = vec![]; + for mat in regex.find_iter(text) { + let piece = mat.unwrap().as_str().as_bytes(); + if let Some(token) = self.encoder.get(piece) { + ret.push(*token); + continue; + } + ret.extend(&byte_pair_encode(piece, &self.encoder)); + } + ret + } + + pub fn _encode_native(&self, text: &str, allowed_special: &HashSet<&str>) -> (Vec<usize>, usize) { + let special_regex = self._get_tl_special_regex(); + let regex = self._get_tl_regex(); + let mut ret = vec![]; + + let mut start = 0; + let mut last_piece_token_len = 0; + loop { + let mut next_special; + let mut start_find = start; + loop { + // Find the next allowed special token, if any + next_special = special_regex.find_from_pos(text, start_find).unwrap(); + match next_special { + Some(m) => { + if allowed_special.contains(&text[m.start()..m.end()]) { + break; + } + start_find = m.start() + 1; + } + None => break, + } + } + let end = next_special.map_or(text.len(), |m| m.start()); + + // Okay, here we go, compare this logic to _encode_ordinary_native + for mat in regex.find_iter(&text[start..end]) { + let piece = mat.unwrap().as_str().as_bytes(); + if let Some(token) = self.encoder.get(piece) { + last_piece_token_len = 1; + ret.push(*token); + continue; + } + let tokens = byte_pair_encode(piece, &self.encoder); + last_piece_token_len = tokens.len(); + ret.extend(&tokens); + } + + match next_special { + // And here we push the special token + Some(m) => { + let piece = m.as_str(); + let token = self.special_tokens_encoder[piece]; + ret.push(token); + start = m.end(); + last_piece_token_len = 0; + } + None => break, + } + } + + // last_piece_token_len is how many tokens came from the last regex split.
This is used + // for determining unstable tokens, since you can't merge across (stable) regex splits + (ret, last_piece_token_len) + } + + pub fn _increase_last_piece_token_len( + &self, + tokens: Vec, + mut last_piece_token_len: usize, + ) -> (Vec, usize) { + // Unfortunately, the locations where our regex splits can be unstable. + // For the purposes of determining unstable tokens, unstable regex splitting + // is only a problem if a split that was present disappears, since this can + // lead to merging of tokens otherwise thought to be stable. + // cl100k_base makes our life hard by including the \s*[\r\n]+ + // pattern. This can e.g. cause "\n" + " " to become "\n \n". + // Here is a quick and dirty fix: + { + let token_is_all_space = |token| { + self.decoder + .get(token) + .map(|token_bytes| { + token_bytes + .iter() + .rev() + .all(|&b| [b' ', b'\n', b'\t'].contains(&b)) + }) + .unwrap_or(false) + }; + if last_piece_token_len > 0 + && token_is_all_space(&tokens[tokens.len() - last_piece_token_len]) + { + while (last_piece_token_len < tokens.len()) + && token_is_all_space(&tokens[tokens.len() - last_piece_token_len - 1]) + { + last_piece_token_len += 1; + } + } + } + debug_assert!(last_piece_token_len <= tokens.len()); + + (tokens, last_piece_token_len) + } + + pub fn _encode_unstable_native( + &self, + text: &str, + allowed_special: &HashSet<&str>, + ) -> (Vec, HashSet>) { + let (tokens, last_piece_token_len) = self._encode_native(text, allowed_special); + if last_piece_token_len == 0 { + // If last_piece_token_len is zero, the last token was a special token and we have + // no unstable bytes + return (tokens, HashSet::new()); + } + let (mut tokens, last_piece_token_len) = + self._increase_last_piece_token_len(tokens, last_piece_token_len); + + let unstable_bytes = self._decode_native(&tokens[tokens.len() - last_piece_token_len..]); + tokens.truncate(tokens.len() - last_piece_token_len); + + // TODO: we should try harder to find additional stable tokens + // This would reduce the amount of retokenising when determining completions + // Refer to the logic in an older version of this file + + let mut completions = HashSet::new(); + if unstable_bytes.is_empty() { + return (tokens, completions); + } + + // This is the easy bit. Just find all single tokens that start with unstable_bytes + // (including tokens that exactly match unstable_bytes) + // Separating this from the loop below helps with performance in a common case. + let mut point = self + .sorted_token_bytes + .partition_point(|x| x.as_slice() < unstable_bytes.as_slice()); + while point < self.sorted_token_bytes.len() + && self.sorted_token_bytes[point].starts_with(&unstable_bytes) + { + completions.insert(vec![ + self.encoder[self.sorted_token_bytes[point].as_slice()], + ]); + point += 1; + } + + // Now apply even more brute force. At every (other) possible position for the straddling + // token, concatenate additional bytes from that token (if any) to unstable_bytes, + // and retokenise the whole thing and see what we get. + for i in 1..unstable_bytes.len() { + let prefix = &unstable_bytes[..i]; + let suffix = &unstable_bytes[i..]; + let mut point = self + .sorted_token_bytes + .partition_point(|x| x.as_slice() < suffix); + // TODO: Perf optimisation if suffix starts with " "? 
+ while point < self.sorted_token_bytes.len() + && self.sorted_token_bytes[point].starts_with(suffix) + { + let possibility = [prefix, self.sorted_token_bytes[point].as_slice()].concat(); + let encoded = match std::str::from_utf8(&possibility) { + // Morally, this is byte_pair_encode(&possibility, &self.encoder) + // But we might have introduced a regex split which would prevent merges. + // (particularly possible in the presence of unstable regex splits) + // So convert to UTF-8 and do regex splitting. + // E.g. with cl100k_base " !" gets split to " " + " !", + // but byte_pair_encode(" !") != byte_pair_encode(" ") + Ok(s) => self._encode_ordinary_native(s), + + // Technically, whether or not this arm is correct depends on whether there + // would be a regex split before the UTF-8 truncation point. + // Probably niche enough that no one will ever notice (after all, people didn't + // notice all the big holes in the previous unstable token implementation) + Err(_) => byte_pair_encode(&possibility, &self.encoder), + // Something like the following is intriguing but incorrect: + // Err(e) => self._encode_ordinary_native(unsafe { + // std::str::from_utf8_unchecked(&possibility[..e.valid_up_to()]) + // }), + }; + let mut seq = Vec::new(); + let mut seq_len = 0; + for token in encoded { + seq.push(token); + seq_len += self.decoder[&token].len(); + if seq_len >= unstable_bytes.len() { + break; + } + } + completions.insert(seq); + point += 1; + } + } + + // This is also not straightforward. While we generally assume that regex splits are stable, + // unfortunately, they are not. That is, if adding bytes were to make a split appear in + // unstable_bytes, this could make tokens possible which our logic would otherwise think + // would be merged. + // For example, with gpt2, the use of \s+(?!\S) means that "\n\n" could + // develop a split, e.g. "\n\n0" splits into "\n"+"\n"+"0", making "\n" a possible token. 
+ // Here is a quick and dirty fix: + // This isn't right if we ever remove \s+(?!\S) + if unstable_bytes.len() > 1 { + let last_decoded = bstr::decode_last_utf8(unstable_bytes.as_slice()); + if unstable_bytes.len() - last_decoded.1 > 0 + && last_decoded.0.map_or(false, |c| c.is_whitespace()) + { + let mut reencoded = byte_pair_encode( + &unstable_bytes[..unstable_bytes.len() - last_decoded.1], + &self.encoder, + ); + reencoded.extend(byte_pair_encode( + &unstable_bytes[unstable_bytes.len() - last_decoded.1..], + &self.encoder, + )); + completions.insert(reencoded); + } + } + + (tokens, completions) + } +} diff --git a/src/tiktoken_py.rs b/src/tiktoken_py.rs new file mode 100644 index 00000000..cb7cc6df --- /dev/null +++ b/src/tiktoken_py.rs @@ -0,0 +1,195 @@ +// This check is new and seems buggy (possibly with PyO3 interaction) +#![allow(clippy::borrow_deref_ref)] + +use std::collections::HashSet; + +use fancy_regex::Regex; +use pyo3::exceptions; +use pyo3::prelude::*; +use pyo3::PyResult; +use pyo3::types::{PyBytes, PyList, PyTuple}; +use rustc_hash::FxHashMap as HashMap; + +use crate::tiktoken::{byte_pair_encode, CoreBPE, MAX_NUM_THREADS}; + +#[pyclass] +pub struct PyCoreBPE { + pub core_bpe: CoreBPE, +} + + +#[pymethods] +impl PyCoreBPE { + #[new] + fn new( + encoder: HashMap, usize>, + special_tokens_encoder: HashMap, + pattern: &str, + ) -> PyResult { + let regex = Regex::new(pattern) + .map_err(|e| PyErr::new::(e.to_string()))?; + + let special_regex = { + let _parts = special_tokens_encoder + .keys() + .map(|s| fancy_regex::escape(s)) + .collect::>(); + Regex::new(&_parts.join("|")) + .map_err(|e| PyErr::new::(e.to_string()))? + }; + + let decoder: HashMap> = + encoder.iter().map(|(k, v)| (*v, k.clone())).collect(); + + assert!( + encoder.len() == decoder.len(), + "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?" 
+ ); + + let special_tokens_decoder: HashMap> = special_tokens_encoder + .iter() + .map(|(k, v)| (*v, k.as_bytes().to_vec())) + .collect(); + + // Clone because I don't know how to tell Rust I'm not going to change the map + let mut sorted_token_bytes: Vec> = encoder.keys().cloned().collect(); + sorted_token_bytes.sort(); + + let core_bpe = CoreBPE { + encoder, + special_tokens_encoder, + decoder, + special_tokens_decoder, + regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(), + special_regex_tls: (0..MAX_NUM_THREADS) + .map(|_| special_regex.clone()) + .collect(), + sorted_token_bytes, + }; + Ok(PyCoreBPE { core_bpe }) + } + + // ==================== + // Encoding + // ==================== + + fn encode_ordinary(&self, py: Python, text: &str) -> Vec { + py.allow_threads(|| self.core_bpe._encode_ordinary_native(text)) + } + + fn encode(&self, py: Python, text: &str, allowed_special: HashSet<&str>) -> Vec { + py.allow_threads(|| self.core_bpe._encode_native(text, &allowed_special).0) + } + + fn _encode_bytes(&self, py: Python, bytes: &[u8]) -> Vec { + py.allow_threads(|| { + match std::str::from_utf8(bytes) { + Ok(text) => self.core_bpe._encode_ordinary_native(text), + Err(e) => { + let text = unsafe { std::str::from_utf8_unchecked(&bytes[..e.valid_up_to()]) }; + let (tokens, last_piece_token_len) = self.core_bpe._encode_native(text, &HashSet::new()); + let (mut tokens, last_piece_token_len) = + self.core_bpe._increase_last_piece_token_len(tokens, last_piece_token_len); + if !tokens.is_empty() && last_piece_token_len > 0 { + // Lop off the tokens from the last piece and run BPE on the remaining bytes + // Somewhat niche, but this may not be correct if we'd have had a regex + // split between the valid UTF-8 and the invalid bytes, which is why this + // method is private + let mut unstable_bytes = + self.core_bpe._decode_native(&tokens[tokens.len() - last_piece_token_len..]); + unstable_bytes.extend_from_slice(&bytes[e.valid_up_to()..]); + + tokens.truncate(tokens.len() - last_piece_token_len); + tokens.extend(byte_pair_encode(&unstable_bytes, &self.core_bpe.encoder)); + } + tokens + } + } + }) + } + + fn encode_with_unstable( + &self, + py: Python, + text: &str, + allowed_special: HashSet<&str>, + ) -> Py { + let (tokens, completions) = + py.allow_threads(|| self.core_bpe._encode_unstable_native(text, &allowed_special)); + let py_completions = + PyList::new(py, completions.iter().map(|seq| PyList::new(py, &seq[..]))); + (tokens, py_completions).into_py(py) + } + + fn encode_single_token(&self, piece: &[u8]) -> PyResult { + if let Some(token) = self.core_bpe.encoder.get(piece).copied() { + return Ok(token); + } + if let Ok(piece_str) = std::str::from_utf8(piece) { + if let Some(token) = self.core_bpe.special_tokens_encoder.get(piece_str).copied() { + return Ok(token); + } + } + Err(PyErr::new::(piece.to_owned())) + } + + fn encode_single_piece(&self, piece: &[u8]) -> Vec { + if let Some(token) = self.core_bpe.encoder.get(piece) { + return vec![*token]; + } + byte_pair_encode(piece, &self.core_bpe.encoder) + } + + // ==================== + // Decoding + // ==================== + + fn decode_bytes(&self, py: Python, tokens: Vec) -> Py { + let bytes = py.allow_threads(|| self.core_bpe._decode_native(&tokens)); + PyBytes::new(py, &bytes).into() + } + + fn decode_single_token_bytes(&self, py: Python, token: usize) -> PyResult> { + if let Some(bytes) = self.core_bpe.decoder.get(&token) { + return Ok(PyBytes::new(py, bytes).into()); + } + if let Some(bytes) = 
self.core_bpe.special_tokens_decoder.get(&token) { + return Ok(PyBytes::new(py, bytes).into()); + } + Err(PyErr::new::(token.to_string())) + } + + // ==================== + // Miscellaneous + // ==================== + + fn token_byte_values(&self, py: Python) -> Vec> { + self.core_bpe.sorted_token_bytes + .iter() + .map(|x| PyBytes::new(py, x).into()) + .collect() + } +} + +#[pymodule] +pub fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use rustc_hash::FxHashMap as HashMap; + + use crate::tiktoken::byte_pair_split; + + #[test] + fn very_simple_test() { + let mut ranks = HashMap::default(); + ranks.insert(b"ab".to_vec(), 1); + ranks.insert(b"cd".to_vec(), 2); + + let res = byte_pair_split(b"abcd", &ranks); + assert_eq!(res, vec![b"ab", b"cd"]); + } +} diff --git a/tiktoken/core.py b/tiktoken/core.py index aa72c9d7..bbb1d4fa 100644 --- a/tiktoken/core.py +++ b/tiktoken/core.py @@ -47,7 +47,7 @@ def __init__( assert len(mergeable_ranks) + len(special_tokens) == explicit_n_vocab assert self.max_token_value == explicit_n_vocab - 1 - self._core_bpe = _tiktoken.CoreBPE(mergeable_ranks, special_tokens, pat_str) + self._core_bpe = _tiktoken.PyCoreBPE(mergeable_ranks, special_tokens, pat_str) def __repr__(self) -> str: return f""
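Note: the sketch below is not part of the diff. It illustrates the point of the split: CoreBPE in src/tiktoken.rs no longer depends on PyO3, so it can be constructed and exercised from plain Rust, for example in a #[cfg(test)] module anywhere in the crate. The toy vocabulary, the \S+|\s+ split pattern, the never-matching special-token pattern, and the module/test names are all invented for illustration; only the CoreBPE field and method names come from the diff above.

#[cfg(test)]
mod pure_rust_smoke_test {
    // Hypothetical test exercising the pure-Rust core without PyO3.
    use std::collections::HashSet;

    use fancy_regex::Regex;
    use rustc_hash::FxHashMap as HashMap;

    use crate::tiktoken::{CoreBPE, MAX_NUM_THREADS};

    #[test]
    fn encode_and_decode_without_pyo3() {
        // Toy three-token vocabulary: "a", "b", and the merged "ab".
        let mut encoder: HashMap<Vec<u8>, usize> = HashMap::default();
        encoder.insert(b"a".to_vec(), 0);
        encoder.insert(b"b".to_vec(), 1);
        encoder.insert(b"ab".to_vec(), 2);

        let decoder: HashMap<usize, Vec<u8>> =
            encoder.iter().map(|(k, v)| (*v, k.clone())).collect();

        let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
        sorted_token_bytes.sort();

        // Toy split pattern; real encodings use far more involved regexes.
        let regex = Regex::new(r"\S+|\s+").unwrap();
        // No special tokens in this sketch, so use a pattern that never matches ordinary text.
        let special_regex = Regex::new(r"\x00never\x00").unwrap();

        let bpe = CoreBPE {
            encoder,
            special_tokens_encoder: HashMap::default(),
            decoder,
            special_tokens_decoder: HashMap::default(),
            regex_tls: (0..MAX_NUM_THREADS).map(|_| regex.clone()).collect(),
            special_regex_tls: (0..MAX_NUM_THREADS).map(|_| special_regex.clone()).collect(),
            sorted_token_bytes,
        };

        assert_eq!(bpe._encode_ordinary_native("ab"), vec![2]);
        assert_eq!(bpe._decode_native(&[0, 1]), b"ab".to_vec());
        let (tokens, _) = bpe._encode_native("ab", &HashSet::new());
        assert_eq!(tokens, vec![2]);
    }
}

Because the struct's fields are public, a caller is responsible for building the regex_tls and special_regex_tls vectors of length MAX_NUM_THREADS, exactly as PyCoreBPE::new does in src/tiktoken_py.rs.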