Skip to content

Commit

Permalink
llama3 8B support, tiktoken tokenizer (#158)
Browse files Browse the repository at this point in the history
* WIP: llama3 support, tiktoken tokenizer

* Finalizing
  • Loading branch information
Artyom17 authored Apr 29, 2024
1 parent c21a889 commit 30d69b3
Show file tree
Hide file tree
Showing 8 changed files with 210 additions and 59 deletions.
6 changes: 3 additions & 3 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
torch._inductor.config.triton.cudagraphs = True
torch._dynamo.config.cache_size_limit = 100000

from sentencepiece import SentencePieceProcessor
from tokenizer import get_tokenizer

from model import Transformer

Expand Down Expand Up @@ -217,7 +217,7 @@ def main(
assert checkpoint_path.is_file(), checkpoint_path

tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
assert tokenizer_path.is_file(), str(tokenizer_path)

device = 'cuda'
precision = torch.bfloat16
Expand All @@ -231,7 +231,7 @@ def main(

model.eval()

tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

torch.manual_seed(1234)

Expand Down
9 changes: 4 additions & 5 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,8 @@ def device_sync(device):
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from sentencepiece import SentencePieceProcessor

from model import Transformer

from tokenizer import get_tokenizer

def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
q = torch.empty_like(probs_sort).exponential_(1)
Expand Down Expand Up @@ -269,7 +267,7 @@ def main(
assert checkpoint_path.is_file(), checkpoint_path

tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
assert tokenizer_path.is_file(), str(tokenizer_path)

global print
from tp import maybe_init_dist
Expand Down Expand Up @@ -297,7 +295,8 @@ def main(
device_sync(device=device) # MKG
print(f"Time to load model: {time.time() - t0:.02f} seconds")

tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
prompt_length = encoded.size(0)

Expand Down
2 changes: 1 addition & 1 deletion mixtral-moe/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def main(
assert checkpoint_path.is_file(), checkpoint_path

tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
assert tokenizer_path.is_file(), str(tokenizer_path)

global print
rank = maybe_init_dist()
Expand Down
1 change: 1 addition & 0 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def from_name(cls, name: str):
"Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
"stories15M": dict(n_layer=6, n_head=6, dim=288),
"stories110M": dict(n_layer=12, n_head=12, dim=768),
"Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256),
}

class KVCache(nn.Module):
Expand Down
6 changes: 3 additions & 3 deletions quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from sentencepiece import SentencePieceProcessor
from tokenizer import get_tokenizer

try:
from GPTQ import GenericGPTQRunner, InputRecorder
Expand Down Expand Up @@ -578,8 +578,8 @@ def quantize(
quant_handler = WeightOnlyInt4GPTQQuantHandler(model, groupsize)

tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
assert tokenizer_path.is_file(), str(tokenizer_path)
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

quantized_state_dict = quant_handler.create_quantized_state_dict(
tokenizer,
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
torch
sentencepiece
tiktoken
133 changes: 86 additions & 47 deletions scripts/convert_hf_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
import json
import re
import shutil
import sys
from pathlib import Path
from typing import Optional
Expand All @@ -27,33 +28,62 @@ def convert_hf_checkpoint(
if model_name is None:
model_name = checkpoint_dir.name

# Llama 3 8B doesn't need conversion; instead, the original/consolidated.NN.pth files
# need to be copied into model.pth.
# Llama 3 70B can't be easily merged into one model.pth file, though, since names of the
# weights is state dict are the same in each consolidated.NN.pth file. Thus, it is not
# currently supported.
# Along this, we need to copy the original/tokenizer.model file to tokenizer.model.tiktoken
is_llama3 = "Llama-3" in model_name
if is_llama3:
# Check if we have multiple original/consolidated.NN.pth files and report error
# if we do for Llama 3.
original_dir = checkpoint_dir / "original"
pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
bin_files = [bin for bin in original_dir.iterdir() if pattern.match(bin.name)]
if len(bin_files) > 1:
raise ValueError(
f"Multiple consolidated.NN.pth files found in {original_dir}. "
"Merging them into one model.pth file is not supported for Llama 3.")


config = ModelArgs.from_name(model_name)
print(f"Model config {config.__dict__}")

# Load the json file containing weight mapping
model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"

assert model_map_json.is_file()

with open(model_map_json) as json_map:
bin_index = json.load(json_map)

weight_map = {
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
"model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
"model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
"model.norm.weight": "norm.weight",
"lm_head.weight": "output.weight",
}
bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
if not is_llama3:
model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"

assert model_map_json.is_file()

with open(model_map_json) as json_map:
bin_index = json.load(json_map)

weight_map = {
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
"model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
"model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
"model.norm.weight": "norm.weight",
"lm_head.weight": "output.weight",
}
bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
else:
# There is no separate pytorch_model.bin.index.json file for llama3.
# Instead, we will just use all original/consolidated.NN.pth files.
# so, we use model.safetensors.index.json
weight_map = None
original_dir = checkpoint_dir / "original"
pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)}


def permute(w, n_head):
dim = config.dim
Expand All @@ -68,32 +98,41 @@ def permute(w, n_head):
state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
merged_result.update(state_dict)
final_result = {}
for key, value in merged_result.items():
if "layers" in key:
abstract_key = re.sub(r'(\d+)', '{}', key)
layer_num = re.search(r'\d+', key).group(0)
new_key = weight_map[abstract_key]
if new_key is None:
continue
new_key = new_key.format(layer_num)
else:
new_key = weight_map[key]

final_result[new_key] = value

for key in tuple(final_result.keys()):
if "wq" in key:
q = final_result[key]
k = final_result[key.replace("wq", "wk")]
v = final_result[key.replace("wq", "wv")]
q = permute(q, config.n_head)
k = permute(k, config.n_local_heads)
final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
del final_result[key]
del final_result[key.replace("wq", "wk")]
del final_result[key.replace("wq", "wv")]
if weight_map is not None:
for key, value in merged_result.items():
if "layers" in key:
abstract_key = re.sub(r'(\d+)', '{}', key)
layer_num = re.search(r'\d+', key).group(0)
new_key = weight_map[abstract_key]
if new_key is None:
continue
new_key = new_key.format(layer_num)
else:
new_key = weight_map[key]

final_result[new_key] = value

for key in tuple(final_result.keys()):
if "wq" in key:
q = final_result[key]
k = final_result[key.replace("wq", "wk")]
v = final_result[key.replace("wq", "wv")]
q = permute(q, config.n_head)
k = permute(k, config.n_local_heads)
final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
del final_result[key]
del final_result[key.replace("wq", "wk")]
del final_result[key.replace("wq", "wv")]
else:
final_result = merged_result
print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
torch.save(final_result, checkpoint_dir / "model.pth")
if is_llama3:
original_dir = checkpoint_dir / "original"
tokenizer_model = original_dir / "tokenizer.model"
tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
shutil.copy(tokenizer_model, tokenizer_model_tiktoken)

if __name__ == '__main__':
import argparse
Expand Down
111 changes: 111 additions & 0 deletions tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import os
import sentencepiece as spm
import tiktoken
from tiktoken.load import load_tiktoken_bpe
from pathlib import Path
from typing import Dict

class TokenizerInterface:
def __init__(self, model_path):
self.model_path = model_path

def encode(self, text):
raise NotImplementedError("This method should be overridden by subclasses.")

def decode(self, tokens):
raise NotImplementedError("This method should be overridden by subclasses.")

def bos_id(self):
raise NotImplementedError("This method should be overridden by subclasses.")

def eos_id(self):
raise NotImplementedError("This method should be overridden by subclasses.")

class SentencePieceWrapper(TokenizerInterface):
def __init__(self, model_path):
super().__init__(model_path)
self.processor = spm.SentencePieceProcessor(str(model_path))

def encode(self, text):
return self.processor.EncodeAsIds(text)

def decode(self, tokens):
return self.processor.DecodeIds(tokens)

def bos_id(self):
return self.processor.bos_id()

def eos_id(self):
return self.processor.eos_id()

class TiktokenWrapper(TokenizerInterface):
"""
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
"""

special_tokens: Dict[str, int]

num_reserved_special_tokens = 256

pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501

def __init__(self, model_path):
super().__init__(model_path)
assert os.path.isfile(model_path), str(model_path)
mergeable_ranks = load_tiktoken_bpe(str(model_path))
num_base_tokens = len(mergeable_ranks)
special_tokens = [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>",
"<|reserved_special_token_1|>",
"<|reserved_special_token_2|>",
"<|reserved_special_token_3|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|reserved_special_token_4|>",
"<|eot_id|>", # end of turn
] + [
f"<|reserved_special_token_{i}|>"
for i in range(5, self.num_reserved_special_tokens - 5)
]
self.special_tokens = {
token: num_base_tokens + i for i, token in enumerate(special_tokens)
}
self.model = tiktoken.Encoding(
name=Path(model_path).name,
pat_str=self.pat_str,
mergeable_ranks=mergeable_ranks,
special_tokens=self.special_tokens,
)
# BOS / EOS token IDs
self._bos_id: int = self.special_tokens["<|begin_of_text|>"]
self._eos_id: int = self.special_tokens["<|end_of_text|>"]

def encode(self, text):
return self.model.encode(text)

def decode(self, tokens):
return self.model.decode(tokens)

def bos_id(self):
return self._bos_id

def eos_id(self):
return self._eos_id

def get_tokenizer(tokenizer_model_path, model_name):
"""
Factory function to get the appropriate tokenizer based on the model name.
Args:
- tokenizer_model_path (str): The file path to the tokenizer model.
- model_name (str): The name of the model, used to determine the tokenizer type.
Returns:
- TokenizerInterface: An instance of a tokenizer.
"""
if "Llama-3" in str(model_name):
return TiktokenWrapper(tokenizer_model_path)
else:
return SentencePieceWrapper(tokenizer_model_path)

0 comments on commit 30d69b3

Please sign in to comment.