Add support for checking hash of downloaded files before use. #230

Merged · 2 commits · Jan 30, 2024
2 changes: 1 addition & 1 deletion — pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "tiktoken"
-version = "0.5.2"
+version = "0.6.0"
 description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models"
 readme = "README.md"
 license = {file = "LICENSE"}
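Since the hash verification first ships with this version bump, downstream code that relies on it can assert the installed version at runtime. A minimal sketch, assuming the third-party packaging library for version comparison (nothing in this PR adds or requires it):

    # Hypothetical guard, assuming hash checking first ships in tiktoken 0.6.0,
    # as the version bump above suggests.
    from importlib.metadata import version  # stdlib, Python 3.8+

    from packaging.version import Version  # third-party "packaging" (assumption)

    assert Version(version("tiktoken")) >= Version("0.6.0")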
31 changes: 24 additions & 7 deletions — tiktoken/load.py

@@ -6,6 +6,7 @@
 import os
 import tempfile
 import uuid
+from typing import Optional

 import requests
@@ -26,7 +27,12 @@ def read_file(blobpath: str) -> bytes:
     return resp.content


-def read_file_cached(blobpath: str) -> bytes:
+def check_hash(data: bytes, hash: str) -> bool:
+    data_hash = hashlib.sha256(data).hexdigest()
+    return data_hash == hash
+
+
+def read_file_cached(blobpath: str, expected_hash: Optional[str] = None) -> bytes:
     user_specified_cache = True
     if "TIKTOKEN_CACHE_DIR" in os.environ:
         cache_dir = os.environ["TIKTOKEN_CACHE_DIR"]
@@ -45,9 +51,20 @@ def read_file_cached(blobpath: str) -> bytes:
     cache_path = os.path.join(cache_dir, cache_key)
     if os.path.exists(cache_path):
         with open(cache_path, "rb") as f:
-            return f.read()
+            data = f.read()
+        if expected_hash and not check_hash(data, expected_hash):
+            raise ValueError(
+                f"Hash mismatch for cached data from {blobpath} (expected {expected_hash}). "
+                f"Please delete the cache file at {cache_path} and try again."
+            )
+        return data

     contents = read_file(blobpath)
+    if expected_hash and not check_hash(contents, expected_hash):
+        raise ValueError(
+            f"Hash mismatch for data downloaded from {blobpath} (expected {expected_hash}). "
+            f"This may indicate a corrupted download. Please try again."
+        )

     try:
         os.makedirs(cache_dir, exist_ok=True)
@@ -64,7 +81,7 @@ def read_file_cached(blobpath: str) -> bytes:


 def data_gym_to_mergeable_bpe_ranks(
-    vocab_bpe_file: str, encoder_json_file: str
+    vocab_bpe_file: str, encoder_json_file: str, vocab_bpe_hash: Optional[str] = None, encoder_json_hash: Optional[str] = None
 ) -> dict[bytes, int]:
     # NB: do not add caching to this function
     rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]
@@ -79,7 +96,7 @@ def data_gym_to_mergeable_bpe_ranks(
     assert len(rank_to_intbyte) == 2**8

     # vocab_bpe contains the merges along with associated ranks
-    vocab_bpe_contents = read_file_cached(vocab_bpe_file).decode()
+    vocab_bpe_contents = read_file_cached(vocab_bpe_file, vocab_bpe_hash).decode()
     bpe_merges = [tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]]

     def decode_data_gym(value: str) -> bytes:
@@ -96,7 +113,7 @@ def decode_data_gym(value: str) -> bytes:
     # check that the encoder file matches the merges file
     # this sanity check is important since tiktoken assumes that ranks are ordered the same
     # as merge priority
-    encoder_json = json.loads(read_file_cached(encoder_json_file))
+    encoder_json = json.loads(read_file_cached(encoder_json_file, encoder_json_hash))
     encoder_json_loaded = {decode_data_gym(k): v for k, v in encoder_json.items()}
     # drop these two special tokens if present, since they're not mergeable bpe tokens
     encoder_json_loaded.pop(b"<|endoftext|>", None)
@@ -118,9 +135,9 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> No
         f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n")


-def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
+def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: Optional[str] = None) -> dict[bytes, int]:
     # NB: do not add caching to this function
-    contents = read_file_cached(tiktoken_bpe_file)
+    contents = read_file_cached(tiktoken_bpe_file, expected_hash)
     return {
         base64.b64decode(token): int(rank)
         for token, rank in (line.split() for line in contents.splitlines() if line)
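As a usage sketch (not part of this diff): a caller can pass a pinned digest straight to the patched loader. The URL and digest below are the r50k_base values pinned in tiktoken_ext/openai_public.py further down:

    from tiktoken.load import load_tiktoken_bpe

    # The loader SHA-256s the bytes via check_hash, whether they came from
    # the on-disk cache or a fresh download, and raises ValueError on mismatch.
    ranks = load_tiktoken_bpe(
        "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken",
        expected_hash="306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930",
    )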
14 changes: 10 additions & 4 deletions — tiktoken_ext/openai_public.py

@@ -11,6 +11,8 @@ def gpt2():
     mergeable_ranks = data_gym_to_mergeable_bpe_ranks(
         vocab_bpe_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe",
         encoder_json_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/encoder.json",
+        vocab_bpe_hash="1ce1664773c50f3e0cc8842619a93edc4624525b728b188a9e0be33b7726adc5",
+        encoder_json_hash="196139668be63f3b5d6574427317ae82f612a97c5d1cdaf36ed2256dbf636783",
     )
     return {
         "name": "gpt2",
@@ -23,7 +25,8 @@ def gpt2():

 def r50k_base():
     mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken"
+        "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken",
+        expected_hash="306cd27f03c1a714eca7108e03d66b7dc042abe8c258b44c199a7ed9838dd930",
     )
     return {
         "name": "r50k_base",
@@ -36,7 +39,8 @@ def r50k_base():

 def p50k_base():
     mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken"
+        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken",
+        expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
     )
     return {
         "name": "p50k_base",
@@ -49,7 +53,8 @@ def p50k_base():

 def p50k_edit():
     mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken"
+        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken",
+        expected_hash="94b5ca7dff4d00767bc256fdd1b27e5b17361d7b8a5f968547f9f23eb70d2069",
     )
     special_tokens = {ENDOFTEXT: 50256, FIM_PREFIX: 50281, FIM_MIDDLE: 50282, FIM_SUFFIX: 50283}
     return {
@@ -62,7 +67,8 @@ def p50k_edit():

 def cl100k_base():
     mergeable_ranks = load_tiktoken_bpe(
-        "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken"
+        "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken",
+        expected_hash="223921b76ee99bde995b7ff738513eef100fb51d18c93597a113bcffe865b2a7",
     )
     special_tokens = {
         ENDOFTEXT: 100257,
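The pinned digests above should be reproducible by hashing the published blobs directly, mirroring what check_hash computes. A quick sketch for reviewers (assumes network access; requests is already a tiktoken dependency):

    import hashlib

    import requests

    for url in [
        "https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe",
        "https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/encoder.json",
        "https://openaipublic.blob.core.windows.net/encodings/r50k_base.tiktoken",
        "https://openaipublic.blob.core.windows.net/encodings/p50k_base.tiktoken",
        "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken",
    ]:
        # Print "<sha256> <url>", sha256sum-style, for comparison against the diff.
        print(hashlib.sha256(requests.get(url).content).hexdigest(), url)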