Also fixes a bug uncovered by said tests.
Showing 4 changed files with 255 additions and 48 deletions.
common.py (new file, +52 lines):
import os
import tempfile
from typing import Callable, Optional

from transformers import LlamaConfig, LlamaForCausalLM

from mergekit.config import MergeConfiguration
from mergekit.io.lazy_tensor_loader import LazyTensorLoader, ShardedTensorIndex
from mergekit.merge import MergeOptions, run_merge


def run_and_check_merge(
    config: MergeConfiguration,
    check_nan: bool = True,
    validate: Optional[Callable[[str], None]] = None,
):
    # Run the merge into a temporary directory and sanity-check its output.
    with tempfile.TemporaryDirectory() as tmpdir:
        run_merge(config, out_path=tmpdir, options=MergeOptions())
        assert os.path.exists(
            os.path.join(tmpdir, "model.safetensors.index.json")
        ), "No index file for merge"
        assert os.path.exists(
            os.path.join(tmpdir, "config.json")
        ), "No config json produced by merge"

        if check_nan:
            # check for NaN in output
            loader = LazyTensorLoader(
                ShardedTensorIndex.from_disk(tmpdir), lazy_unpickle=False
            )
            tp = loader.index.tensor_paths
            sorted_tensors = sorted(tp.keys(), key=lambda k: tp[k])
            for tensor_name in sorted_tensors:
                tensor = loader.get_tensor(tensor_name)
                has_nan = tensor.view(-1).isnan().any()
                assert not has_nan, "Output contains NaN"

        if validate:
            validate(tmpdir)


def make_picollama(path: str, vocab_size: int = 64):
    # Build a tiny randomly initialized Llama model and save it to disk,
    # so merges complete in seconds during tests.
    cfg = LlamaConfig(
        vocab_size=vocab_size,
        hidden_size=32,
        intermediate_size=48,
        num_attention_heads=16,
        num_hidden_layers=2,
    )
    model = LlamaForCausalLM(cfg)
    model.save_pretrained(path, safe_serialization=True)
    return str(path)
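
For orientation, here is a minimal sketch (not part of the commit) of how these helpers compose: build two throwaway models, merge them linearly, and let run_and_check_merge assert that the output directory is well formed and NaN-free. The 0.5 weights are arbitrary, and the real tests below also attach a tokenizer to each model path, which some merge code paths expect.

import os
import tempfile

from common import make_picollama, run_and_check_merge

from mergekit.config import InputModelDefinition, MergeConfiguration

with tempfile.TemporaryDirectory() as tmp:
    # Two tiny random models to merge; paths are scratch space.
    model_a = make_picollama(os.path.join(tmp, "a"))
    model_b = make_picollama(os.path.join(tmp, "b"))
    config = MergeConfiguration(
        merge_method="linear",
        models=[
            InputModelDefinition(model=m, parameters={"weight": 0.5})
            for m in (model_a, model_b)
        ],
        dtype="bfloat16",
    )
    # Asserts the safetensors index and config.json exist and no tensor has NaNs.
    run_and_check_merge(config)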
Tokenizer merge tests (new file, +189 lines):
import json
import os
import tempfile
from typing import List, Optional, Union

import pytest
import tokenizers
from common import make_picollama, run_and_check_merge
from transformers import LlamaTokenizerFast, PreTrainedTokenizerBase

from mergekit.config import InputModelDefinition, MergeConfiguration, ParameterSetting


@pytest.fixture(scope="session")
def model_base(tmp_path_factory):
    model_path = make_picollama(tmp_path_factory.mktemp("model_base"), vocab_size=64)
    make_tokenizer(vocab_size=64, added_tokens=[]).save_pretrained(model_path)
    return model_path


@pytest.fixture(scope="session")
def model_chatml(tmp_path_factory):
    # 66 embedding rows: 64 base tokens plus the two ChatML special tokens.
    model_path = make_picollama(tmp_path_factory.mktemp("model_chatml"), vocab_size=66)
    make_tokenizer(
        vocab_size=64, added_tokens=["<|im_start|>", "<|im_end|>"]
    ).save_pretrained(model_path)
    return model_path


@pytest.fixture(scope="session")
def model_padded(tmp_path_factory):
    model_path = make_picollama(tmp_path_factory.mktemp("model_padded"), vocab_size=64)
    make_tokenizer(
        vocab_size=64,
        added_tokens=["<UNUSED_0>", "<UNUSED_1>", "<UNUSED_2>", "<UNUSED_3>"],
    ).save_pretrained(model_path)
    return model_path


def make_tokenizer(
    vocab_size: int, added_tokens: List[Union[str, tokenizers.AddedToken]]
) -> PreTrainedTokenizerBase:
    # Synthesize a minimal BPE tokenizer: three special tokens followed by
    # placeholder tokens up to vocab_size, then any extra added tokens.
    tokens = ["<unk>", "<s>", "</s>"] + [f"_tok_{idx}" for idx in range(3, vocab_size)]
    tokens = tokens[:vocab_size]
    tok_data = {
        "version": "1.0",
        "model": {
            "type": "BPE",
            "vocab": dict(zip(tokens, range(vocab_size))),
            "merges": [],
        },
        "added_tokens": [],
    }
    tok = tokenizers.Tokenizer.from_str(json.dumps(tok_data))
    with tempfile.TemporaryDirectory() as p:
        tok_path = os.path.join(p, "tokenizer.json")
        tok.save(tok_path)
        res = LlamaTokenizerFast(tokenizer_file=tok_path)

    res.add_tokens(added_tokens)
    return res
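
As a quick illustration of what make_tokenizer produces, a hypothetical call (not part of the commit): with vocab_size=4 the base vocabulary is <unk>=0, <s>=1, </s>=2, _tok_3=3, and added tokens are appended after it.

# Hypothetical sanity check of make_tokenizer (not part of the commit).
tok = make_tokenizer(vocab_size=4, added_tokens=["<|pad|>"])
vocab = tok.get_vocab()
assert vocab["<unk>"] == 0 and vocab["_tok_3"] == 3
assert vocab["<|pad|>"] == 4  # added tokens land after the base vocabulary
assert len(vocab) == 5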

def check_tokenizer(
    expected_size: int,
    expected_added_ct: Optional[int] = None,
    must_contain: Optional[List[str]] = None,
    must_not_contain: Optional[List[str]] = None,
):
    # Returns a callback suitable for run_and_check_merge's `validate` hook.
    def _cb(model_path: str):
        tok: LlamaTokenizerFast = LlamaTokenizerFast.from_pretrained(model_path)
        vocab = tok.get_vocab()
        print(vocab)
        assert len(vocab) == expected_size

        if expected_added_ct is not None:
            assert len(tok.added_tokens_decoder) == expected_added_ct

        if must_contain:
            for token in must_contain:
                assert token in vocab

        if must_not_contain:
            for token in must_not_contain:
                assert token not in vocab

    return _cb

class TestTokenizerMerges:
    def test_legacy_mode(self, model_base: str, model_padded: str, model_chatml: str):
        config = self.make_config(
            [model_base, model_padded, model_chatml], base_model=model_base
        )
        # when no tokenizer_source is set, expect output tokenizer to be from base_model
        run_and_check_merge(
            config, validate=check_tokenizer(expected_size=64, expected_added_ct=3)
        )

    def test_source_base(self, model_base: str, model_padded: str, model_chatml: str):
        config = self.make_config(
            [model_base, model_padded, model_chatml],
            base_model=model_base,
            tokenizer_source="base",
        )
        # expect the same output as legacy mode, but via a different code path
        run_and_check_merge(
            config, validate=check_tokenizer(expected_size=64, expected_added_ct=3)
        )

    def test_source_union(self, model_base: str, model_padded: str, model_chatml: str):
        config = self.make_config(
            [model_base, model_padded, model_chatml],
            base_model=model_base,
            tokenizer_source="union",
        )

        # output should have all tokens used by any model, but no unused tokens:
        # 66 = 64 base tokens + <|im_start|> and <|im_end|>; the 5 added tokens
        # are <unk>, <s>, </s> plus the two ChatML tokens.
        run_and_check_merge(
            config,
            validate=check_tokenizer(
                expected_size=66,
                expected_added_ct=5,
                must_contain=["<|im_start|>", "<|im_end|>"],
                must_not_contain=[f"<UNUSED_{idx}>" for idx in range(4)],
            ),
        )

    def test_source_model(self, model_base: str, model_padded: str, model_chatml: str):
        config = self.make_config(
            [model_base, model_padded, model_chatml],
            base_model=model_base,
            tokenizer_source="model:" + model_chatml,
        )
        # tokenizer should match model_chatml
        run_and_check_merge(
            config,
            validate=check_tokenizer(
                expected_size=66, must_contain=["<|im_start|>", "<|im_end|>"]
            ),
        )

    def test_slerp_union(self, model_base: str, model_chatml: str):
        config = self.make_config(
            [model_base, model_chatml],
            base_model=model_base,
            tokenizer_source="union",
            merge_method="slerp",
            embed_slerp=True,
            t="0.5",
        )

        run_and_check_merge(
            config,
            validate=check_tokenizer(
                expected_size=66,
                must_contain=["<|im_start|>", "<|im_end|>"],
            ),
        )

    def make_config(
        self,
        models: List[str],
        base_model: Optional[str] = None,
        merge_method: str = "linear",
        tokenizer_source: Optional[str] = None,
        embed_slerp: bool = False,
        t: Optional[ParameterSetting] = None,
    ):
        parameters = {"embed_slerp": embed_slerp}
        if t is not None:
            parameters["t"] = t

        config = MergeConfiguration(
            merge_method=merge_method,
            base_model=base_model,
            models=[
                InputModelDefinition(
                    model=m,
                    parameters={"weight": 1.0},
                )
                for m in models
            ],
            dtype="bfloat16",
            tokenizer_source=tokenizer_source,
            parameters=parameters,
        )
        return config
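
Mergekit configurations are more commonly written as YAML; a rough YAML equivalent of what make_config builds for the union tests is sketched below (not part of the commit). The ./model_* paths are placeholders, and the keyword-argument construction assumes MergeConfiguration behaves like a standard pydantic model.

import yaml

from mergekit.config import MergeConfiguration

# Approximate YAML form of make_config([base, chatml], base_model=base,
# tokenizer_source="union"); the ./model_* paths are placeholders.
CONFIG_YAML = """
merge_method: linear
base_model: ./model_base
models:
  - model: ./model_base
    parameters:
      weight: 1.0
  - model: ./model_chatml
    parameters:
      weight: 1.0
dtype: bfloat16
tokenizer_source: union
"""

# Assumption: MergeConfiguration is a pydantic model, so nested dicts (like the
# model entries above) are coerced into their field types on construction.
config = MergeConfiguration(**yaml.safe_load(CONFIG_YAML))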