Add tokenizer merging tests (#79)
Also fixes a bug uncovered by said tests.
cg123 authored Jan 6, 2024
1 parent 1011ef3 commit 190cf2f
Showing 4 changed files with 255 additions and 48 deletions.
6 changes: 5 additions & 1 deletion mergekit/tokenizer.py
@@ -128,7 +128,11 @@ def build_union_tokenizer(
             if tok not in out_vocab:
                 out_vocab[tok] = len(out_vocab)
 
-        for tok, info in tokenizer.added_tokens_decoder.items():
+        for tok_idx, info in tokenizer.added_tokens_decoder.items():
+            tok = info.content
+            if tok_idx >= vocab_size:
+                continue
+
             if tok in out_added_tokens:
                 if (out_added_tokens[tok] != info) and tok not in warned_added_tokens:
                     logging.warning(
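The bug the new tests surfaced: on a Hugging Face tokenizer, added_tokens_decoder maps integer token IDs to AddedToken objects, so the old loop bound an ID where a token string was expected, and tokens at or past the model's vocab_size were never skipped. A minimal sketch of the failure mode; the IDs, tokens, and vocab_size below are made up for illustration:

# Sketch only: IDs, tokens, and vocab_size here are hypothetical.
from tokenizers import AddedToken

added_tokens_decoder = {64: AddedToken("<|im_start|>"), 65: AddedToken("<|im_end|>")}
out_added_tokens = {"<|im_start|>": AddedToken("<|im_start|>")}

# Old loop: `tok` is the integer key, so membership tests against a dict of
# strings are always False and every added token looks new.
for tok, info in added_tokens_decoder.items():
    assert tok not in out_added_tokens

# Fixed loop: take the text from info.content and skip IDs at or past vocab_size.
vocab_size = 65
for tok_idx, info in added_tokens_decoder.items():
    tok = info.content
    if tok_idx >= vocab_size:
        continue
    assert tok in out_added_tokens  # "<|im_start|>" is now matched correctly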
52 changes: 52 additions & 0 deletions tests/common.py
@@ -0,0 +1,52 @@
import os
import tempfile
from typing import Callable, Optional

from transformers import LlamaConfig, LlamaForCausalLM

from mergekit.config import MergeConfiguration
from mergekit.io.lazy_tensor_loader import LazyTensorLoader, ShardedTensorIndex
from mergekit.merge import MergeOptions, run_merge


def run_and_check_merge(
    config: MergeConfiguration,
    check_nan: bool = True,
    validate: Optional[Callable[[str], None]] = None,
):
    with tempfile.TemporaryDirectory() as tmpdir:
        run_merge(config, out_path=tmpdir, options=MergeOptions())
        assert os.path.exists(
            os.path.join(tmpdir, "model.safetensors.index.json")
        ), "No index file for merge"
        assert os.path.exists(
            os.path.join(tmpdir, "config.json")
        ), "No config json produced by merge"

        if check_nan:
            # check for NaN in output
            loader = LazyTensorLoader(
                ShardedTensorIndex.from_disk(tmpdir), lazy_unpickle=False
            )
            tp = loader.index.tensor_paths
            sorted_tensors = sorted(tp.keys(), key=lambda k: tp[k])
            for tensor_name in sorted_tensors:
                tensor = loader.get_tensor(tensor_name)
                has_nan = tensor.view(-1).isnan().any()
                assert not has_nan, "Output contains NaN"

        if validate:
            validate(tmpdir)


def make_picollama(path: str, vocab_size: int = 64):
    cfg = LlamaConfig(
        vocab_size=vocab_size,
        hidden_size=32,
        intermediate_size=48,
        num_attention_heads=16,
        num_hidden_layers=2,
    )
    model = LlamaForCausalLM(cfg)
    model.save_pretrained(path, safe_serialization=True)
    return str(path)
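Together these helpers make an end-to-end merge test a few lines long. A rough usage sketch (paths and weights here are illustrative, not from the test suite):

from mergekit.config import InputModelDefinition, MergeConfiguration

path_a = make_picollama("/tmp/pico_a")
path_b = make_picollama("/tmp/pico_b")

config = MergeConfiguration(
    merge_method="linear",
    models=[
        InputModelDefinition(model=path_a, parameters={"weight": 0.7}),
        InputModelDefinition(model=path_b, parameters={"weight": 0.3}),
    ],
    dtype="bfloat16",
)
run_and_check_merge(config, validate=lambda outdir: print("merged into", outdir))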
56 changes: 9 additions & 47 deletions tests/test_merges.py → tests/test_basic_merges.py
@@ -1,9 +1,7 @@
-import os
-import tempfile
 from typing import Dict, Optional
 
 import pytest
-from transformers import LlamaConfig, LlamaForCausalLM
+from common import make_picollama, run_and_check_merge
 
 from mergekit.config import (
     InputModelDefinition,
@@ -12,21 +10,6 @@
     OutputSliceDefinition,
     ParameterSetting,
 )
-from mergekit.io.lazy_tensor_loader import LazyTensorLoader, ShardedTensorIndex
-from mergekit.merge import MergeOptions, run_merge
-
-
-def make_picollama(path: str):
-    cfg = LlamaConfig(
-        vocab_size=64,
-        hidden_size=32,
-        intermediate_size=48,
-        num_attention_heads=16,
-        num_hidden_layers=2,
-    )
-    model = LlamaForCausalLM(cfg)
-    model.save_pretrained(path, safe_serialization=True)
-    return str(path)
-
-
 @pytest.fixture(scope="session")
Expand All @@ -44,14 +27,14 @@ def model_c(tmp_path_factory):
return make_picollama(tmp_path_factory.mktemp("model_c"))


class TestMerges:
class TestBasicMerges:
def test_gpt2_copy(self):
config = MergeConfiguration(
merge_method="passthrough",
models=[InputModelDefinition(model="gpt2")],
dtype="bfloat16",
)
self.run_and_check_merge(config)
run_and_check_merge(config)

     def test_gpt2_stack(self):
         config = MergeConfiguration(
@@ -64,24 +47,24 @@ def test_gpt2_stack(self):
             ],
             dtype="bfloat16",
         )
-        self.run_and_check_merge(config)
+        run_and_check_merge(config)

     def test_linear_merge(self, model_a, model_b):
         config = self.two_model_config(model_a, model_b, merge_method="linear")
-        self.run_and_check_merge(config)
+        run_and_check_merge(config)
 
     def test_slerp_merge(self, model_a, model_b):
         config = self.two_model_config(
             model_a, model_b, merge_method="slerp", base_model=model_a
         )
         config.parameters = {"t": 0.35}
-        self.run_and_check_merge(config)
+        run_and_check_merge(config)
 
     def test_task_arithmetic_merge(self, model_a, model_b, model_c):
         config = self.two_model_config(
             model_a, model_b, merge_method="task_arithmetic", base_model=model_c
         )
-        self.run_and_check_merge(config)
+        run_and_check_merge(config)

     def test_ties_merge(self, model_a, model_b, model_c):
         config = self.two_model_config(
@@ -91,7 +74,7 @@ def test_ties_merge(self, model_a, model_b, model_c):
             base_model=model_c,
             params={"density": 0.3},
         )
-        self.run_and_check_merge(config)
+        run_and_check_merge(config)

     def test_dare_ties_merge(self, model_a, model_b, model_c):
         config = self.two_model_config(
@@ -101,28 +84,7 @@ def test_dare_ties_merge(self, model_a, model_b, model_c):
             base_model=model_c,
             params={"density": 0.66},
        )
-        self.run_and_check_merge(config)
-
-    def run_and_check_merge(self, config: MergeConfiguration):
-        with tempfile.TemporaryDirectory() as tmpdir:
-            run_merge(config, out_path=tmpdir, options=MergeOptions())
-            assert os.path.exists(
-                os.path.join(tmpdir, "model.safetensors.index.json")
-            ), "No index file for merge"
-            assert os.path.exists(
-                os.path.join(tmpdir, "config.json")
-            ), "No config json produced by merge"
-
-            # check for NaN in output
-            loader = LazyTensorLoader(
-                ShardedTensorIndex.from_disk(tmpdir), lazy_unpickle=False
-            )
-            tp = loader.index.tensor_paths
-            sorted_tensors = sorted(tp.keys(), key=lambda k: tp[k])
-            for tensor_name in sorted_tensors:
-                tensor = loader.get_tensor(tensor_name)
-                has_nan = tensor.view(-1).isnan().any()
-                assert not has_nan, "Output contains NaN"
+        run_and_check_merge(config)

     def two_model_config(
         self,
189 changes: 189 additions & 0 deletions tests/test_tokenizer.py
@@ -0,0 +1,189 @@
import json
import os
import tempfile
from typing import List, Optional, Union

import pytest
import tokenizers
from common import make_picollama, run_and_check_merge
from transformers import LlamaTokenizerFast, PreTrainedTokenizerBase

from mergekit.config import InputModelDefinition, MergeConfiguration, ParameterSetting


@pytest.fixture(scope="session")
def model_base(tmp_path_factory):
    model_path = make_picollama(tmp_path_factory.mktemp("model_base"), vocab_size=64)
    make_tokenizer(vocab_size=64, added_tokens=[]).save_pretrained(model_path)
    return model_path


@pytest.fixture(scope="session")
def model_chatml(tmp_path_factory):
    model_path = make_picollama(tmp_path_factory.mktemp("model_chatml"), vocab_size=66)
    make_tokenizer(
        vocab_size=64, added_tokens=["<|im_start|>", "<|im_end|>"]
    ).save_pretrained(model_path)
    return model_path


@pytest.fixture(scope="session")
def model_padded(tmp_path_factory):
    model_path = make_picollama(tmp_path_factory.mktemp("model_padded"), vocab_size=64)
    make_tokenizer(
        vocab_size=64,
        added_tokens=["<UNUSED_0>", "<UNUSED_1>", "<UNUSED_2>", "<UNUSED_3>"],
    ).save_pretrained(model_path)
    return model_path


def make_tokenizer(
    vocab_size: int, added_tokens: List[Union[str, tokenizers.AddedToken]]
) -> PreTrainedTokenizerBase:
    tokens = ["<unk>", "<s>", "</s>"] + [f"_tok_{idx}" for idx in range(3, vocab_size)]
    tokens = tokens[:vocab_size]
    tok_data = {
        "version": "1.0",
        "model": {
            "type": "BPE",
            "vocab": dict(zip(tokens, range(vocab_size))),
            "merges": [],
        },
        "added_tokens": [],
    }
    tok = tokenizers.Tokenizer.from_str(json.dumps(tok_data))
    with tempfile.TemporaryDirectory() as p:
        tok_path = os.path.join(p, "tokenizer.json")
        tok.save(tok_path)
        res = LlamaTokenizerFast(tokenizer_file=tok_path)

    res.add_tokens(added_tokens)
    return res
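A quick sanity sketch of what the factory yields, using a hypothetical 4-token vocab rather than one of the fixtures above:

tok = make_tokenizer(vocab_size=4, added_tokens=["<|pad|>"])
vocab = tok.get_vocab()
assert vocab["<unk>"] == 0 and vocab["_tok_3"] == 3
assert vocab["<|pad|>"] == 4  # added tokens are appended after the base vocab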


def check_tokenizer(
    expected_size: int,
    expected_added_ct: Optional[int] = None,
    must_contain: Optional[List[str]] = None,
    must_not_contain: Optional[List[str]] = None,
):
    def _cb(model_path: str):
        tok: LlamaTokenizerFast = LlamaTokenizerFast.from_pretrained(model_path)
        vocab = tok.get_vocab()
        print(vocab)
        assert len(vocab) == expected_size

        if expected_added_ct is not None:
            assert len(tok.added_tokens_decoder) == expected_added_ct

        if must_contain:
            for tok in must_contain:
                assert tok in vocab

        if must_not_contain:
            for tok in must_not_contain:
                assert tok not in vocab

    return _cb
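The returned callback plugs into run_and_check_merge's validate hook, and can also be invoked directly against any saved model directory (path hypothetical):

validate = check_tokenizer(expected_size=66, must_contain=["<|im_start|>"])
validate("/path/to/merged-model")  # raises AssertionError on any mismatch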


class TestTokenizerMerges:
    def test_legacy_mode(self, model_base: str, model_padded: str, model_chatml: str):
        config = self.make_config(
            [model_base, model_padded, model_chatml], base_model=model_base
        )
        # when no tokenizer_source is set, expect output tokenizer to be from base_model
        run_and_check_merge(
            config, validate=check_tokenizer(expected_size=64, expected_added_ct=3)
        )

    def test_source_base(self, model_base: str, model_padded: str, model_chatml: str):
        config = self.make_config(
            [model_base, model_padded, model_chatml],
            base_model=model_base,
            tokenizer_source="base",
        )
        # expect the same output but it's a different code path
        run_and_check_merge(
            config, validate=check_tokenizer(expected_size=64, expected_added_ct=3)
        )

    def test_source_union(self, model_base: str, model_padded: str, model_chatml: str):
        config = self.make_config(
            [model_base, model_padded, model_chatml],
            base_model=model_base,
            tokenizer_source="union",
        )

        # output should have all tokens used by any model
        # but not include any unused tokens
        run_and_check_merge(
            config,
            validate=check_tokenizer(
                expected_size=66,
                expected_added_ct=5,
                must_contain=["<|im_start|>", "<|im_end|>"],
                must_not_contain=[f"<UNUSED_{idx}>" for idx in range(4)],
            ),
        )

    def test_source_model(self, model_base: str, model_padded: str, model_chatml: str):
        config = self.make_config(
            [model_base, model_padded, model_chatml],
            base_model=model_base,
            tokenizer_source="model:" + model_chatml,
        )
        # tokenizer should match model_chatml
        run_and_check_merge(
            config,
            validate=check_tokenizer(
                expected_size=66, must_contain=["<|im_start|>", "<|im_end|>"]
            ),
        )

    def test_slerp_union(self, model_base: str, model_chatml: str):
        config = self.make_config(
            [model_base, model_chatml],
            base_model=model_base,
            tokenizer_source="union",
            merge_method="slerp",
            embed_slerp=True,
            t="0.5",
        )

        run_and_check_merge(
            config,
            validate=check_tokenizer(
                expected_size=66,
                must_contain=["<|im_start|>", "<|im_end|>"],
            ),
        )

    def make_config(
        self,
        models: List[str],
        base_model: Optional[str] = None,
        merge_method: str = "linear",
        tokenizer_source: Optional[str] = None,
        embed_slerp: bool = False,
        t: Optional[ParameterSetting] = None,
    ):
        parameters = {"embed_slerp": embed_slerp}
        if t is not None:
            parameters["t"] = t

        config = MergeConfiguration(
            merge_method=merge_method,
            base_model=base_model,
            models=[
                InputModelDefinition(
                    model=m,
                    parameters={"weight": 1.0},
                )
                for m in models
            ],
            dtype="bfloat16",
            tokenizer_source=tokenizer_source,
            parameters=parameters,
        )
        return config
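For reference, the union case these tests exercise boils down to a configuration like the following; the model paths are hypothetical, and "base", "union", and "model:<path>" are the three tokenizer_source values covered above:

config = MergeConfiguration(
    merge_method="linear",
    base_model="/models/base",
    models=[
        InputModelDefinition(model="/models/base", parameters={"weight": 1.0}),
        InputModelDefinition(model="/models/chatml", parameters={"weight": 1.0}),
    ],
    dtype="bfloat16",
    tokenizer_source="union",  # or "base", or "model:/models/chatml"
    parameters={"embed_slerp": False},
)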
