Skip to content

Commit

Permalink
Make work
Browse files Browse the repository at this point in the history
  • Loading branch information
cg123 committed Dec 12, 2023
1 parent f3e328e commit 49ad236
Show file tree
Hide file tree
Showing 16 changed files with 176 additions and 109 deletions.
54 changes: 52 additions & 2 deletions mergekit/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,18 @@
import logging
import os
import os.path
from typing import List, Optional, Union
from typing import Generic, Iterator, List, Optional, Tuple, Union

import huggingface_hub
import immutables
import numpy as np
import peft
import pydantic
import torch
import transformers
from pydantic import BaseModel
from pydantic import BaseModel, SerializerFunctionWrapHandler
from transformers import AutoConfig, PretrainedConfig
from typing_extensions import TypeVar

from mergekit.io import ShardedTensorIndex

Expand Down Expand Up @@ -172,3 +175,50 @@ def parse_kmb(value: Union[str, int]) -> int:
return int(value[:-1]) * 1000 * 1000 * 1000
else:
raise ValueError(value)


class MergeOptions(BaseModel):
    """Runtime options controlling how a merge is executed.

    These are orthogonal to the merge configuration itself: they cover
    caching, device placement, output sharding, and similar mechanics.
    """

    # NOTE(review): the original declared ``allow_crimes`` twice; the
    # duplicate has been removed.
    allow_crimes: bool = False
    transformers_cache: Optional[str] = None  # override cache dir for model downloads
    lora_merge_cache: Optional[str] = None  # cache dir for merged LoRA models -- TODO confirm
    cuda: bool = False  # perform tensor math on GPU
    low_cpu_memory: bool = False
    out_shard_size: int = parse_kmb("5B")  # 5 * 10**9 via parse_kmb; shard size limit
    copy_tokenizer: bool = True  # copy a tokenizer from an input model into the output
    clone_tensors: bool = False
    trust_remote_code: bool = False
    random_seed: Optional[int] = None
    lazy_unpickle: bool = False


# Key/value type parameters for ImmutableMap. Per PEP 484 a TypeVar's string
# name must match the variable it is assigned to (the original used "KT"/"VT",
# which static type checkers reject).
T_K = TypeVar("T_K")
T_V = TypeVar("T_V")


class ImmutableMap(
    Generic[T_K, T_V], BaseModel, frozen=True, arbitrary_types_allowed=True
):
    """A frozen pydantic model wrapping an ``immutables.Map``.

    Accepts any mapping on construction and coerces it to an
    ``immutables.Map``, giving a hashable, immutable mapping field that can
    safely be used inside frozen models. Serializes as a plain dict.
    """

    # The underlying immutable mapping; inputs are coerced by validate_data.
    data: immutables.Map[T_K, T_V]

    # The file requires pydantic v2 (it imports SerializerFunctionWrapHandler
    # and uses field_serializer), so use the v2 validator API instead of the
    # deprecated v1-style ``@pydantic.validator(..., pre=True)``.
    @pydantic.field_validator("data", mode="before")
    @classmethod
    def validate_data(cls, data):
        # Coerce dicts (or any mapping) into an immutables.Map before
        # field validation so the stored value is genuinely immutable.
        return immutables.Map(data)

    @pydantic.field_serializer("data", mode="wrap")
    def serialize_data(
        self,
        data: immutables.Map[T_K, T_V],
        nxt: SerializerFunctionWrapHandler,
    ):
        # immutables.Map is not natively serializable; hand the default
        # serializer an equivalent plain dict.
        return nxt(dict(data.items()))

    def __iter__(self):
        # Iterate keys of the underlying map (overrides BaseModel.__iter__).
        return self.data.__iter__()

    def __getitem__(self, key: T_K) -> T_V:
        return self.data[key]

    def items(self) -> Iterator[Tuple[T_K, T_V]]:
        return self.data.items()
30 changes: 26 additions & 4 deletions mergekit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,18 +130,40 @@ def base_model(self) -> Optional[ModelReference]:
return None

def for_out_slice(self, slice: OutputSliceDefinition) -> "ConfigReader":
    """Return a copy of this reader with ``slice_out`` set to ``slice``."""
    kwargs = {
        "config": self.config,
        "t": self.t,
        "tensor_name": self.tensor_name,
        "slice_out": slice,
        "slices_in": self.slices_in,
    }
    return ConfigReader(**kwargs)

def for_in_slices(self, slices: List[InputSliceDefinition]) -> "ConfigReader":
    """Return a copy of this reader with ``slices_in`` set to ``slices``."""
    kwargs = {
        "config": self.config,
        "t": self.t,
        "tensor_name": self.tensor_name,
        "slice_out": self.slice_out,
        "slices_in": slices,
    }
    return ConfigReader(**kwargs)

def for_tensor(self, tensor_name: str) -> "ConfigReader":
    """Return a copy of this reader scoped to the named tensor."""
    kwargs = {
        "config": self.config,
        "t": self.t,
        "tensor_name": tensor_name,
        "slice_out": self.slice_out,
        "slices_in": self.slices_in,
    }
    return ConfigReader(**kwargs)

def with_t(self, t: float) -> "ConfigReader":
    """Return a copy of this reader with parameter ``t`` replaced."""
    kwargs = {
        "config": self.config,
        "t": t,
        "tensor_name": self.tensor_name,
        "slice_out": self.slice_out,
        "slices_in": self.slices_in,
    }
    return ConfigReader(**kwargs)

def parameter(
self,
Expand Down
6 changes: 3 additions & 3 deletions mergekit/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,13 @@ def run(self) -> Iterator[Tuple[Task, Any]]:
last_use_index[task] = j

values: Dict[Task, Any] = {}
for idx, task in tqdm.tqdm(enumerate(self.schedule)):
for idx, task in tqdm.tqdm(enumerate(self.schedule), total=len(self.schedule)):
arguments = {}
for dep in self.dependencies[task]:
for name, dep in task.arguments().items():
value = values[dep]
if isinstance(value, torch.Tensor) and value.device != self.math_device:
value = value.to(self.math_device)
arguments[dep] = value
arguments[name] = value

res = task.execute(**arguments)
del arguments
Expand Down
4 changes: 3 additions & 1 deletion mergekit/io/tensor_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def flush_current_shard(self):
if not self.current_shard:
return

logging.info(f"writing shard #{self.shards_written+1} to disk")
logging.info(f"Writing shard #{self.shards_written+1} to disk")

shard_name = f"model-{self.shards_written+1}.safetensors"
for key in self.current_shard:
Expand All @@ -77,6 +77,8 @@ def flush_current_shard(self):
def finalize(self):
self.flush_current_shard()

logging.info("Finalizing shard names")

# standardize shard names to hf format
total_shards = self.shards_written
name_remap = {}
Expand Down
27 changes: 6 additions & 21 deletions mergekit/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,17 @@
# along with this program. If not, see http://www.gnu.org/licenses/.

import logging
from typing import Optional

import tqdm
import transformers
from pydantic import BaseModel

from mergekit.architecture import get_architecture_info
from mergekit.common import ModelReference, parse_kmb
from mergekit.common import MergeOptions, ModelReference
from mergekit.config import MergeConfiguration
from mergekit.graph import Executor
from mergekit.plan import MergePlanner
from mergekit.tasks import LoaderCache, TokenizerInfo


class MergeOptions(BaseModel):
allow_crimes: bool = False
transformers_cache: Optional[str] = None
lora_merge_cache: Optional[str] = None
cuda: bool = False
low_cpu_memory: bool = False
out_shard_size: int = parse_kmb("5B")
copy_tokenizer: bool = True
allow_crimes: bool = False
clone_tensors: bool = False
trust_remote_code: bool = False
random_seed: Optional[int] = None
lazy_unpickle: bool = False
from mergekit.tasks import LoaderCache
from mergekit.tokenizer import TokenizerInfo


def run_merge(merge_config: MergeConfiguration, out_path: str, options: MergeOptions):
Expand Down Expand Up @@ -70,8 +54,7 @@ def run_merge(merge_config: MergeConfiguration, out_path: str, options: MergeOpt
merge_config,
arch_info,
out_path=out_path,
max_shard_size=options.out_shard_size,
clone_tensors=options.clone_tensors,
options=options,
).plan()

# warm up loader cache
Expand Down Expand Up @@ -112,12 +95,14 @@ def run_merge(merge_config: MergeConfiguration, out_path: str, options: MergeOpt
"Unable to set number of layers in output config - you may need to manually correct it.",
exc_info=e,
)
logging.info("Saving config")
cfg_out.save_pretrained(out_path)

if tokenizer is None and options.copy_tokenizer:
tokenizer = _get_donor_tokenizer(merge_config)

if tokenizer:
logging.info("Saving tokenizer")
tokenizer.save_pretrained(out_path, safe_serialization=True)


Expand Down
8 changes: 4 additions & 4 deletions mergekit/merge_methods/generalized_task_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pydantic import BaseModel
from typing_extensions import Literal

from mergekit.common import ModelReference
from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
from mergekit.sparsify import SparsificationMethod, sparsify
Expand Down Expand Up @@ -64,7 +64,7 @@ def make_task(
method=self,
tensors=tensors,
base_model=base_model,
tensor_parameters=tensor_parameters,
tensor_parameters=ImmutableMap(data=tensor_parameters),
int8_mask=parameters["int8_mask"],
normalize=parameters["normalize"],
out_tensor_name=output_tensor_name,
Expand All @@ -76,7 +76,7 @@ class GTATask(Task[torch.Tensor]):
tensors: GatherTensors
base_model: ModelReference
out_tensor_name: str
tensor_parameters: Dict[ModelReference, Any]
tensor_parameters: ImmutableMap[ModelReference, Any]
int8_mask: bool
normalize: bool

Expand All @@ -93,7 +93,7 @@ def execute(
self.out_tensor_name,
self.base_model,
tensors,
tensor_parameters=self.tensor_parameters,
tensor_parameters=self.tensor_parameters.data,
)
if not tvs:
return base
Expand Down
6 changes: 3 additions & 3 deletions mergekit/merge_methods/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
import torch
from torch._tensor import Tensor

from mergekit.common import ModelReference, rectify_embed_sizes
from mergekit.common import ImmutableMap, ModelReference, rectify_embed_sizes
from mergekit.graph import Task
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
from mergekit.tasks import GatherTensors


class LinearMergeTask(Task[torch.Tensor]):
gather_tensors: GatherTensors
tensor_parameters: Dict[ModelReference, Dict[str, Any]]
tensor_parameters: ImmutableMap[ModelReference, Dict[str, Any]]
normalize: bool
parameter_name: str

Expand Down Expand Up @@ -79,7 +79,7 @@ def make_task(
) -> Task:
return LinearMergeTask(
gather_tensors=tensors,
tensor_parameters=tensor_parameters,
tensor_parameters=ImmutableMap(data=tensor_parameters),
normalize=parameters["normalize"],
parameter_name=output_tensor_name,
)
2 changes: 2 additions & 0 deletions mergekit/merge_methods/passthrough.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> Tensor:
if self.scale is not None:
res *= self.scale

return res


class PassthroughMerge(MergeMethod):
def parameters(self) -> List[ConfigParameterDef]:
Expand Down
9 changes: 5 additions & 4 deletions mergekit/merge_methods/tokenizer_permute.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
from pydantic import BaseModel
from torch._tensor import Tensor

from mergekit.common import ModelReference
from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.merge_methods.base import ConfigParameterDef, MergeMethod
from mergekit.merge_methods.slerp import slerp
from mergekit.tasks import BuildTokenizer, GatherTensors, TokenizerInfo
from mergekit.tasks import GatherTensors
from mergekit.tokenizer import BuildTokenizer, TokenizerInfo


class TokenizerPermutationMergeTask(Task[torch.Tensor]):
Expand All @@ -32,7 +33,7 @@ class TokenizerPermutationMergeTask(Task[torch.Tensor]):
base_model: Optional[ModelReference]
use_slerp: bool
slerp_t: float
tensor_parameters: Dict[ModelReference, Any]
tensor_parameters: ImmutableMap[ModelReference, Any]

def arguments(self) -> Dict[str, Task]:
return {"tokenizer_info": self.tokenizer_task, "tensors": self.gather_tensors}
Expand Down Expand Up @@ -132,5 +133,5 @@ def make_task(
gather_tensors=tensors,
use_slerp=parameters["embed_slerp"],
slerp_t=parameters["t"],
tensor_parameters=tensor_parameters,
tensor_parameters=ImmutableMap(data=tensor_parameters),
)
Loading

0 comments on commit 49ad236

Please sign in to comment.