diff --git a/Cargo.toml b/Cargo.toml index 948b9f13..908492d3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,21 +1,5 @@ -[package] -name = "tiktoken" -version = "0.4.0" -edition = "2021" -rust-version = "1.57.0" - -[lib] -name = "_tiktoken" -crate-type = ["cdylib"] - -[dependencies] -pyo3 = { version = "0.19.0", features = ["extension-module"] } - -# tiktoken dependencies -fancy-regex = "0.11.0" -regex = "1.8.3" -rustc-hash = "1.1.0" -bstr = "1.5.0" - -[profile.release] -incremental = true +[workspace] +members = [ + "rs-tiktoken", + "py-tiktoken", +] \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in index 7f25b271..63adb28c 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -5,4 +5,5 @@ include Makefile global-include py.typed recursive-include scripts *.py recursive-include tests *.py -recursive-include src *.rs +recursive-include py-tiktoken *.rs +recursive-include rs-tiktoken *.rs diff --git a/py-tiktoken/Cargo.toml b/py-tiktoken/Cargo.toml new file mode 100644 index 00000000..e02a8121 --- /dev/null +++ b/py-tiktoken/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "py-tiktoken" +version = "0.4.0" +edition = "2021" +rust-version = "1.57.0" + +[lib] +name = "_tiktoken" +crate-type = ["cdylib"] + +[dependencies] +tiktoken = { path = "../rs-tiktoken" } +pyo3 = { version = "0.19.0", features = ["extension-module"] } + +# tiktoken dependencies +fancy-regex = "0.11.0" +regex = "1.8.3" +rustc-hash = "1.1.0" +bstr = "1.5.0" + +[profile.release] +incremental = true diff --git a/py-tiktoken/src/lib.rs b/py-tiktoken/src/lib.rs new file mode 100644 index 00000000..e13657a2 --- /dev/null +++ b/py-tiktoken/src/lib.rs @@ -0,0 +1 @@ +pub mod tiktoken_py; diff --git a/src/tiktoken_py.rs b/py-tiktoken/src/tiktoken_py.rs similarity index 98% rename from src/tiktoken_py.rs rename to py-tiktoken/src/tiktoken_py.rs index cb7cc6df..90157116 100644 --- a/src/tiktoken_py.rs +++ b/py-tiktoken/src/tiktoken_py.rs @@ -10,7 +10,7 @@ use pyo3::PyResult; use pyo3::types::{PyBytes, PyList, PyTuple}; use rustc_hash::FxHashMap as HashMap; -use crate::tiktoken::{byte_pair_encode, CoreBPE, MAX_NUM_THREADS}; +use tiktoken::core::{byte_pair_encode, CoreBPE, MAX_NUM_THREADS}; #[pyclass] pub struct PyCoreBPE { @@ -181,7 +181,7 @@ pub fn _tiktoken(_py: Python, m: &PyModule) -> PyResult<()> { mod tests { use rustc_hash::FxHashMap as HashMap; - use crate::tiktoken::byte_pair_split; + use crate::core::byte_pair_split; #[test] fn very_simple_test() { diff --git a/rs-tiktoken/Cargo.toml b/rs-tiktoken/Cargo.toml new file mode 100644 index 00000000..520a6ebf --- /dev/null +++ b/rs-tiktoken/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "tiktoken" +version = "0.4.0" +edition = "2021" +rust-version = "1.57.0" + +[dependencies] +fancy-regex = "0.11.0" +regex = "1.8.3" +rustc-hash = "1.1.0" +bstr = "1.5.0" +once_cell = "1.18.0" + +[profile.release] +incremental = true diff --git a/src/tiktoken.rs b/rs-tiktoken/src/core.rs similarity index 99% rename from src/tiktoken.rs rename to rs-tiktoken/src/core.rs index 53c1a075..e32a37c3 100644 --- a/src/tiktoken.rs +++ b/rs-tiktoken/src/core.rs @@ -152,7 +152,7 @@ pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap, usize>) -> pub struct FakeThreadId(NonZeroU64); -pub fn hash_current_thread() -> usize { +fn hash_current_thread() -> usize { // It's easier to use unsafe than to use nightly. Rust has this nice u64 thread id counter // that works great for our use case of avoiding collisions in our array. Unfortunately, // it's private. However, there are only so many ways you can layout a u64, so just transmute diff --git a/rs-tiktoken/src/encoding.rs b/rs-tiktoken/src/encoding.rs new file mode 100644 index 00000000..14908fd2 --- /dev/null +++ b/rs-tiktoken/src/encoding.rs @@ -0,0 +1,66 @@ +//! WARNING: This code is under active development. Functionality, +//! behavior, and the interface may change in future updates. + +use std::collections::HashMap; +use once_cell::sync::Lazy; +use regex::Regex; + + +pub struct Encoding { + /// The name of the encoding. It should be clear from the name of the encoding + /// what behaviour to expect, in particular, encodings with different special tokens + /// should have different names. + pub name: &'static str, + /// A regex pattern string that is used to split the input text. + pub pat_str: Regex, + /// A dictionary mapping mergeable token bytes to their ranks. The ranks + /// must correspond to merge priority. + pub mergeable_ranks: HashMap<&'static str, u32>, + /// A dictionary mapping special token strings to their token values. + pub special_tokens: HashMap<&'static str, u32>, + /// The number of tokens in the vocabulary. If provided, it is checked + /// that the number of mergeable tokens and special tokens is equal to this number. + pub explicit_n_vocab: Option, +} + +pub static GPT2: Lazy = Lazy::new(|| { + let mergeable_ranks = Default::default(); + let special_tokens = [ + ("<|endoftext|>", 50256) + ].iter().cloned().collect(); + + Encoding{ + name: "gpt2", + pat_str: Regex::new(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+").unwrap(), + mergeable_ranks, + special_tokens, + explicit_n_vocab: Some(50257), + } +}); + +pub fn get_encoding() { + +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_simple() { + // enc = tiktoken.get_encoding("gpt2") + // assert enc.encode("hello world") == [31373, 995] + // assert enc.decode([31373, 995]) == "hello world" + // assert enc.encode("hello <|endoftext|>", allowed_special="all") == [31373, 220, 50256] + // + // enc = tiktoken.get_encoding("cl100k_base") + // assert enc.encode("hello world") == [15339, 1917] + // assert enc.decode([15339, 1917]) == "hello world" + // assert enc.encode("hello <|endoftext|>", allowed_special="all") == [15339, 220, 100257] + // + // for enc_name in tiktoken.list_encoding_names(): + // enc = tiktoken.get_encoding(enc_name) + // for token in range(10_000): + // assert enc.encode_single_token(enc.decode_single_token_bytes(token)) == token + } +} \ No newline at end of file diff --git a/src/lib.rs b/rs-tiktoken/src/lib.rs similarity index 62% rename from src/lib.rs rename to rs-tiktoken/src/lib.rs index 54210bd5..e42cd75a 100644 --- a/src/lib.rs +++ b/rs-tiktoken/src/lib.rs @@ -1,3 +1,4 @@ // This check is new and seems buggy (possibly with PyO3 interaction) -pub mod tiktoken_py; -pub mod tiktoken; \ No newline at end of file +pub mod core; +pub mod encoding; +mod model; \ No newline at end of file diff --git a/rs-tiktoken/src/model.rs b/rs-tiktoken/src/model.rs new file mode 100644 index 00000000..e90d9725 --- /dev/null +++ b/rs-tiktoken/src/model.rs @@ -0,0 +1,3 @@ +//! WARNING: This code is under active development. Functionality, +//! behavior, and the interface may change in future updates. + diff --git a/setup.py b/setup.py index a22e8e5d..b0b42967 100644 --- a/setup.py +++ b/setup.py @@ -5,11 +5,12 @@ name="tiktoken", rust_extensions=[ RustExtension( - "tiktoken._tiktoken", + target="tiktoken._tiktoken", binding=Binding.PyO3, # Between our use of editable installs and wanting to use Rust for performance sensitive # code, it makes sense to just always use --release debug=False, + path="py-tiktoken/Cargo.toml", ) ], package_data={"tiktoken": ["py.typed"]},