diff --git a/tiktoken/load.py b/tiktoken/load.py index cc0a6a6d..11032dfd 100644 --- a/tiktoken/load.py +++ b/tiktoken/load.py @@ -12,6 +12,19 @@ def read_file(blobpath: str) -> bytes: + """ + Reads the contents of a file specified by the given blobpath. + + Parameters + ---------- + blobpath : str + The path or URL to the file to be read. + + Returns + ------- + bytes + The binary content of the file. + """ if not blobpath.startswith("http://") and not blobpath.startswith("https://"): try: import blobfile @@ -28,11 +41,44 @@ def read_file(blobpath: str) -> bytes: def check_hash(data: bytes, expected_hash: str) -> bool: + """ + Checks if the hash of the given data matches the expected hash. + + Parameters + ---------- + data : bytes + The binary data to be hashed. + + expected_hash : str + The expected hash value. + + Returns + ------- + bool + True if the actual hash matches the expected hash, False otherwise. + """ actual_hash = hashlib.sha256(data).hexdigest() return actual_hash == expected_hash def read_file_cached(blobpath: str, expected_hash: Optional[str] = None) -> bytes: + """ + Reads the contents of a file specified by the given blobpath from cache if available, + otherwise fetches it from the source, caches it, and returns the content. + + Parameters + ---------- + blobpath : str + The path or URL to the file to be read. + + expected_hash : str, optional + The expected hash value of the file content. Default is None. + + Returns + ------- + bytes + The binary content of the file. + """ user_specified_cache = True if "TIKTOKEN_CACHE_DIR" in os.environ: cache_dir = os.environ["TIKTOKEN_CACHE_DIR"] @@ -88,6 +134,28 @@ def data_gym_to_mergeable_bpe_ranks( vocab_bpe_hash: Optional[str] = None, encoder_json_hash: Optional[str] = None, ) -> dict[bytes, int]: + """ + Converts a vocab BPE file and an encoder JSON file into mergeable BPE ranks. + + Parameters + ---------- + vocab_bpe_file : str + The path to the vocabulary BPE file. 
+ + encoder_json_file : str + The path to the encoder JSON file. + + vocab_bpe_hash : str, optional + The expected hash value of the vocabulary BPE file. Default is None. + + encoder_json_hash : str, optional + The expected hash value of the encoder JSON file. Default is None. + + Returns + ------- + dict[bytes, int] + A dictionary mapping mergeable BPE tokens to their ranks. + """ # NB: do not add caching to this function rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "] @@ -129,6 +197,21 @@ def decode_data_gym(value: str) -> bytes: def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> None: + """ + Dumps the mergeable BPE ranks to a TikToken BPE file. + + Parameters + ---------- + bpe_ranks : dict[bytes, int] + A dictionary mapping mergeable BPE tokens to their ranks. + + tiktoken_bpe_file : str + The path to the TikToken BPE file. + + Returns + ------- + None + """ try: import blobfile except ImportError as e: @@ -143,6 +226,22 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> No def load_tiktoken_bpe( tiktoken_bpe_file: str, expected_hash: Optional[str] = None ) -> dict[bytes, int]: + """ + Loads mergeable BPE ranks from a TikToken BPE file. + + Parameters + ---------- + tiktoken_bpe_file : str + The path to the TikToken BPE file. + + expected_hash : str, optional + The expected hash value of the file content. Default is None. + + Returns + ------- + dict[bytes, int] + A dictionary mapping mergeable BPE tokens to their ranks. + """ # NB: do not add caching to this function contents = read_file_cached(tiktoken_bpe_file, expected_hash) return { diff --git a/tiktoken/model.py b/tiktoken/model.py index 17532aee..6711e00e 100644 --- a/tiktoken/model.py +++ b/tiktoken/model.py @@ -69,9 +69,44 @@ def encoding_name_for_model(model_name: str) -> str: - """Returns the name of the encoding used by a model. + """ + Returns the name of the encoding used by a model. 
+ + Parameters + ---------- + model_name : str + The name of the model. + + Returns + ------- + encoding_name : str + The name of the encoding used by the model. + + Raises + ------ + KeyError + If the model name is not recognized or cannot be mapped to an encoding. + + Notes + ----- + This function checks if the provided model name is directly mapped to an encoding in MODEL_TO_ENCODING. + If not, it attempts to match the model name with known prefixes in MODEL_PREFIX_TO_ENCODING. + If a match is found, it returns the corresponding encoding name. + + If the model name cannot be mapped to any encoding, it raises a KeyError. - Raises a KeyError if the model name is not recognised. + Examples + -------- + >>> encoding_name_for_model("gpt2") + 'gpt2' + + >>> encoding_name_for_model("gpt-4") + 'cl100k_base' + + >>> encoding_name_for_model("nonexistent-model") + Traceback (most recent call last): + ... + KeyError: "Could not automatically map nonexistent-model to a tokeniser. Please use `tiktoken.get_encoding` to explicitly get the tokeniser you expect." """ encoding_name = None if model_name in MODEL_TO_ENCODING: @@ -94,8 +129,22 @@ def encoding_for_model(model_name: str) -> Encoding: - """Returns the encoding used by a model. + """ + Returns the encoding used by a model. + + Parameters + ---------- + model_name : str + The name of the model. + + Returns + ------- + encoding : Encoding + The encoding used by the model. - Raises a KeyError if the model name is not recognised. + Raises + ------ + KeyError + If the model name is not recognized or cannot be mapped to an encoding. 
""" return get_encoding(encoding_name_for_model(model_name)) diff --git a/tiktoken/registry.py b/tiktoken/registry.py index a753ce67..6d06f449 100644 --- a/tiktoken/registry.py +++ b/tiktoken/registry.py @@ -17,6 +17,14 @@ @functools.lru_cache() def _available_plugin_modules() -> Sequence[str]: + """ + Returns a sequence of available plugin modules. + + Returns + ------- + Sequence[str] + A sequence of available plugin modules. + """ # tiktoken_ext is a namespace package # submodules inside tiktoken_ext will be inspected for ENCODING_CONSTRUCTORS attributes # - we use namespace package pattern so `pkgutil.iter_modules` is fast @@ -30,6 +38,22 @@ def _available_plugin_modules() -> Sequence[str]: def _find_constructors() -> None: + """ + Finds encoding constructors from available plugin modules and populates the ENCODING_CONSTRUCTORS dictionary. + + Parameters + ---------- + None + + Returns + ------- + None + + Raises + ------ + ValueError + If a plugin module does not define ENCODING_CONSTRUCTORS or if there are duplicate encoding names. + """ global ENCODING_CONSTRUCTORS with _lock: if ENCODING_CONSTRUCTORS is not None: @@ -53,6 +77,24 @@ def _find_constructors() -> None: def get_encoding(encoding_name: str) -> Encoding: + """ + Retrieves an Encoding object for the specified encoding name. + + Parameters + ---------- + encoding_name : str + The name of the encoding. + + Returns + ------- + Encoding + The Encoding object for the specified encoding name. + + Raises + ------ + ValueError + If the specified encoding name is unknown. + """ if encoding_name in ENCODINGS: return ENCODINGS[encoding_name] @@ -76,6 +118,18 @@ def get_encoding(encoding_name: str) -> Encoding: def list_encoding_names() -> list[str]: + """ + Lists available encoding names. + + Parameters + ---------- + None + + Returns + ------- + list[str] + A list of available encoding names. + """ with _lock: if ENCODING_CONSTRUCTORS is None: _find_constructors()