Merge arbitrary pytorch models #335

Draft · wants to merge 10 commits into main
2 changes: 1 addition & 1 deletion mergekit/config.py
@@ -37,7 +37,7 @@ class ConditionalParameter(BaseModel):

 def evaluate_setting(
     tensor_name: str, setting: ParameterSetting, t: float = 0
-) -> float:
+) -> Optional[float]:
     if isinstance(setting, (float, int, bool, str)):
         return setting
     elif isinstance(setting, list):
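A note on this change: `evaluate_setting` could already fall through and return `None` (presumably when no conditional clause matches the tensor name); the annotation now reflects that. A minimal sketch, not part of the diff:

```python
# Sketch only: scalar settings pass straight through, while a list of
# conditional settings with no matching clause yields None -- hence the
# Optional[float] return annotation.
from mergekit.config import evaluate_setting

assert evaluate_setting("lm_head.weight", 0.5) == 0.5
```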
51 changes: 31 additions & 20 deletions mergekit/io/lazy_tensor_loader.py
@@ -76,30 +76,41 @@ def from_disk(cls, base_path: str) -> "ShardedTensorIndex":
                 )
                 shards.append(info)
 
-        elif os.path.exists(model_path):
-            shard_name = os.path.basename(model_path)
-
-            # get list of tensors contained in single-file checkpoint
-            if model_path.lower().endswith(".safetensors"):
-                with safetensors.safe_open(model_path, framework="pt") as st:
-                    tensor_paths = {key: shard_name for key in st.keys()}
-            else:
-                # this is ugly but not much else can be done
-                shard = torch.load(model_path, map_location="meta")
-                if "state_dict" in shard:
-                    shard = shard["state_dict"]
-
-                tensor_paths = {key: shard_name for key in shard}
-
-            shards.append(
-                ShardInfo(os.path.basename(model_path), list(tensor_paths.keys()))
+            return ShardedTensorIndex(
+                base_path=base_path,
+                is_safetensors=is_safetensors,
+                tensor_paths=tensor_paths,
+                shards=shards,
             )
+
+        elif os.path.exists(model_path):
+            return ShardedTensorIndex.from_file(model_path)
+
+        else:
+            raise RuntimeError(f"Unable to find model files at {base_path}")
+
+    @classmethod
+    def from_file(cls, file_path: str) -> "ShardedTensorIndex":
+        if not os.path.exists(file_path):
+            raise FileNotFoundError(file_path)
+
+        lower = file_path.lower()
+        shard_name = os.path.basename(file_path)
+        if lower.endswith(".safetensors"):
+            with safetensors.safe_open(file_path, framework="pt") as st:
+                tensor_paths = {key: shard_name for key in st.keys()}
+        else:
+            shard = torch.load(file_path, map_location="meta")
+            if "state_dict" in shard:
+                shard = shard["state_dict"]
+
+            tensor_paths = {key: shard_name for key in shard}
 
         return ShardedTensorIndex(
-            base_path=base_path,
-            is_safetensors=is_safetensors,
+            base_path=os.path.dirname(file_path),
+            is_safetensors=lower.endswith(".safetensors"),
             tensor_paths=tensor_paths,
-            shards=shards,
+            shards=[ShardInfo(shard_name, list(tensor_paths.keys()))],
         )


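For reference, a minimal sketch of the new entry point (the path and tensor name below are hypothetical):

```python
# from_file indexes a single-file checkpoint so it can be read through
# LazyTensorLoader just like a sharded model directory.
from mergekit.io import LazyTensorLoader, ShardedTensorIndex

index = ShardedTensorIndex.from_file("/models/a.safetensors")  # hypothetical path
loader = LazyTensorLoader(index, lazy_unpickle=False)
tensor = loader.get_tensor("lm_head.weight")  # hypothetical tensor name
```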
6 changes: 6 additions & 0 deletions mergekit/io/tensor_writer.py
@@ -104,12 +104,18 @@ def finalize(self):
                 f"{prefix}-{idx+1}.{extension}"
             ] = f"{prefix}-{idx+1:05d}-of-{total_shards:05d}.{extension}"
 
+        if total_shards < 2:
+            name_remap[f"{prefix}-1.{extension}"] = f"{prefix}.{extension}"
+
         for old_name, new_name in name_remap.items():
             os.rename(
                 os.path.join(self.out_path, old_name),
                 os.path.join(self.out_path, new_name),
             )
 
+        if total_shards < 2:
+            return
+
         for key in self.weight_map:
             self.weight_map[key] = name_remap[self.weight_map[key]]
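With these additions, a single-shard save drops the shard-count suffix (e.g. `model-1.safetensors` is renamed to `model.safetensors` rather than `model-00001-of-00001.safetensors`, overriding the mapping set in the loop above), and the weight-map rewrite used for the sharded index is skipped, since a lone shard needs no index.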
23 changes: 21 additions & 2 deletions mergekit/merge_methods/base.py
@@ -14,8 +14,9 @@
 # along with this program. If not, see http://www.gnu.org/licenses/.
 
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
+import torch
 from pydantic import BaseModel
 from typing_extensions import TypeAlias
 
@@ -25,7 +26,25 @@
 from mergekit.io.tasks import GatherTensors
 from mergekit.tokenizer import PermutedEmbeddings
 
-MergeTensorInput: TypeAlias = Union[GatherTensors, PermutedEmbeddings]
+
+class TensorDictWrapper(Task[Dict[ModelReference, torch.Tensor]]):
+    tensors: ImmutableMap[ModelReference, Task[torch.Tensor]]
+
+    def arguments(self) -> Dict[str, Task]:
+        return {
+            k.model_dump_json(
+                exclude_none=True, exclude_defaults=True, round_trip=True
+            ): v
+            for k, v in self.tensors.items()
+        }
+
+    def execute(self, **kwargs) -> Dict[ModelReference, torch.Tensor]:
+        return {ModelReference.model_validate_json(k): v for k, v in kwargs.items()}
+
+
+MergeTensorInput: TypeAlias = Union[
+    GatherTensors, PermutedEmbeddings, TensorDictWrapper
+]
 
 
 class ConfigParameterDef(BaseModel):
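A sketch of the key round-trip `TensorDictWrapper` relies on (the path below is hypothetical): `ModelReference` keys are serialized to JSON strings so they can serve as task argument names, then reconstructed in `execute`:

```python
from mergekit.common import ModelReference

ref = ModelReference.model_validate({"model": {"path": "/models/a.safetensors"}})
key = ref.model_dump_json(exclude_none=True, exclude_defaults=True, round_trip=True)
assert ModelReference.model_validate_json(key) == ref
```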
266 changes: 266 additions & 0 deletions mergekit/scripts/merge_raw_pytorch.py
@@ -0,0 +1,266 @@
# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

import logging
from typing import Dict, List, Optional

import click
import torch
import tqdm
import yaml
from pydantic import BaseModel

import mergekit.merge_methods as merge_methods
from mergekit.architecture import WeightInfo
from mergekit.common import ImmutableMap, ModelReference, dtype_from_name
from mergekit.config import ParameterSetting, evaluate_setting
from mergekit.graph import Executor, Task
from mergekit.io import LazyTensorLoader, ShardedTensorIndex
from mergekit.io.tasks import FinalizeModel, SaveTensor, TensorWriterTask
from mergekit.merge_methods.base import MergeMethod, TensorDictWrapper
from mergekit.options import MergeOptions, add_merge_options


class InputModelDefinition(BaseModel, frozen=True):
    model: str
    parameters: Optional[Dict[str, ParameterSetting]] = None


class RawPyTorchMergeConfig(BaseModel, frozen=True):
    merge_method: str
    parameters: Optional[Dict[str, ParameterSetting]] = None
    models: List[InputModelDefinition]
    dtype: Optional[str] = None
    base_model: Optional[str] = None


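# Process-wide singleton cache mapping checkpoint paths to LazyTensorLoader
# instances, so each input file is opened and indexed only once.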
class SimpleLoaderCache:
    loaders: Dict[str, LazyTensorLoader]
    lazy_unpickle: bool = False
    _instance: Optional["SimpleLoaderCache"] = None

    def __new__(cls) -> "SimpleLoaderCache":
        if cls._instance is None:
            cls._instance = super(SimpleLoaderCache, cls).__new__(cls)
            cls._instance.loaders = {}
        return cls._instance

    def get(self, model: str) -> LazyTensorLoader:
        if model not in self.loaders:
            self.loaders[model] = LazyTensorLoader(
                ShardedTensorIndex.from_file(model), lazy_unpickle=self.lazy_unpickle
            )
        return self.loaders[model]


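# Leaf task that loads a single named tensor from one input file, optionally
# casting its dtype; returns None when the tensor is absent from the file.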
class SimpleLoadTensor(Task[torch.Tensor]):
    model: str
    tensor_name: str
    dtype: Optional[str] = None
    device: Optional[str] = None

    def arguments(self) -> Dict[str, Task]:
        return {}

    def execute(self) -> torch.Tensor:
        loader = SimpleLoaderCache().get(self.model)
        tensor = loader.get_tensor(self.tensor_name, device=self.device or "cpu")
        if tensor is None:
            return None
        if dt := dtype_from_name(self.dtype):
            tensor = tensor.to(dtype=dt)
        return tensor


def plan_flat_merge(
    config: RawPyTorchMergeConfig,
    out_path: str,
    tensor_union: bool,
    tensor_intersection: bool,
    options: MergeOptions,
) -> List[Task[torch.Tensor]]:
    merge_method = merge_methods.get(config.merge_method)

    loaders = SimpleLoaderCache()
    loaders.lazy_unpickle = options.lazy_unpickle
    all_tensor_names = set()
    for model_def in tqdm.tqdm(config.models, desc="Preparing model loaders"):
        loader = loaders.get(model_def.model)
        all_tensor_names.update(loader.index.tensor_paths.keys())

    writer_task = TensorWriterTask(
        out_path=out_path,
        max_shard_size=options.out_shard_size,
        safe_serialization=options.safe_serialization,
    )

    save_tasks = []
    for tensor_name in tqdm.tqdm(list(all_tensor_names), desc="Planning operations"):
        inputs = {
            model_def.model: SimpleLoadTensor(
                model=model_def.model, tensor_name=tensor_name, dtype=config.dtype
            )
            for model_def in config.models
        }
        if config.base_model is not None and config.base_model not in inputs:
            inputs[config.base_model] = SimpleLoadTensor(
                model=config.base_model, tensor_name=tensor_name, dtype=config.dtype
            )

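        # Determine which inputs actually contain this tensor: a partially
        # present tensor is merged from the models that have it (--tensor-union),
        # skipped entirely (--tensor-intersection), or treated as a hard error.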
        has_tensor = [
            lt.model
            for lt in inputs.values()
            if lt.tensor_name in loaders.get(lt.model).index.tensor_paths
        ]
        if len(has_tensor) < len(inputs):
            if tensor_intersection:
                continue
            elif tensor_union:
                pass
            else:
                missing = set(inputs) - set(has_tensor)
                logging.warning(f"Tensor {tensor_name} not found in models:")
                for model in missing:
                    logging.warning(f"  {model}")
                logging.warning("Was found in:")
                for model in has_tensor:
                    logging.warning(f"  {model}")
                raise RuntimeError("Missing tensors")

        inputs = {
            ModelReference.model_validate({"model": {"path": k}}): v
            for k, v in inputs.items()
        }

        global_params, tensor_params = construct_param_dicts(
            config, merge_method, tensor_name
        )

        tensor_task = merge_method.make_task(
            output_weight=WeightInfo(name=tensor_name),
            tensors=TensorDictWrapper(tensors=inputs),
            parameters=ImmutableMap(global_params),
            tensor_parameters=ImmutableMap(
                data={
                    key: ImmutableMap(data=tensor_params[key]) for key in tensor_params
                }
            ),
            base_model=(
                ModelReference.model_validate({"model": {"path": config.base_model}})
                if config.base_model is not None
                else None
            ),
        )
        save_task = SaveTensor(
            tensor_name=tensor_name,
            tensor_task=tensor_task,
            writer_task=writer_task,
            clone=options.clone_tensors,
            dtype=config.dtype,
        )
        save_tasks.append(save_task)

    finalize = FinalizeModel(tensor_save_tasks=save_tasks, writer_task=writer_task)
    return save_tasks + [finalize]


def construct_param_dicts(
    config: RawPyTorchMergeConfig, merge_method: MergeMethod, tensor_name: str
):
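    # Resolve this merge method's parameters for a single tensor: per-model
    # settings take precedence over global ones, then declared defaults;
    # a required parameter with no value raises.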
    global_params = {}
    for param_def in merge_method.parameters():
        if param_def.name in (config.parameters or {}):
            value = evaluate_setting(tensor_name, config.parameters[param_def.name])
            if value is not None:
                global_params[param_def.name] = value

        if param_def.name not in global_params:
            if param_def.required:
                raise RuntimeError(
                    f"Missing required parameter {param_def.name} for merge method {merge_method}"
                )
            else:
                global_params[param_def.name] = param_def.default_value

    tensor_params = {}
    for param_def in merge_method.tensor_parameters():
        for model_def in config.models:
            mr = ModelReference.model_validate({"model": {"path": model_def.model}})
            tensor_params[mr] = tensor_params.get(mr, {})
            if (
                value := evaluate_setting(
                    tensor_name, (model_def.parameters or {}).get(param_def.name, [])
                )
            ) is not None:
                tensor_params[mr][param_def.name] = value
            elif (
                value := evaluate_setting(
                    tensor_name, (config.parameters or {}).get(param_def.name, [])
                )
            ) is not None:
                tensor_params[mr][param_def.name] = value
            elif param_def.required:
                raise RuntimeError(
                    f"Missing required parameter {param_def.name} for model {mr} tensor {tensor_name}"
                )
            else:
                tensor_params[mr][param_def.name] = param_def.default_value
    return global_params, tensor_params


@click.command("mergekit-pytorch")
@click.argument("config_path", type=click.Path(exists=True))
@click.argument("out_path", type=click.Path())
@click.option(
    "--tensor-intersection",
    "-i",
    type=bool,
    default=False,
    is_flag=True,
    help="Only merge tensors that are present in all input models",
)
@click.option(
    "--tensor-union",
    "-u",
    type=bool,
    default=False,
    is_flag=True,
    help="Merge all tensors present in any input model",
)
@add_merge_options
def main(
    config_path: str,
    out_path: str,
    tensor_union: bool,
    tensor_intersection: bool,
    merge_options: MergeOptions,
):
    """Merge arbitrary PyTorch models.

    Uses similar configuration syntax to `mergekit-yaml`, minus the
    `slices` sections. Each input model should be the path on disk to a
    pytorch pickle file or safetensors file."""
    with open(config_path, "r", encoding="utf-8") as file:
        config_source = file.read()

    config = RawPyTorchMergeConfig.model_validate(yaml.safe_load(config_source))
    tasks = plan_flat_merge(
        config, out_path, tensor_union, tensor_intersection, merge_options
    )

    executor = Executor(
        tasks,
        math_device="cuda" if merge_options.cuda else "cpu",
        storage_device=(
            "cuda" if (merge_options.cuda and merge_options.low_cpu_memory) else "cpu"
        ),
    )
    executor.execute()
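For reference, a hypothetical config for the new entry point (the file paths and the `weight` value are invented for illustration; `linear` is an existing mergekit merge method):

```yaml
merge_method: linear
parameters:
  weight: 0.5
models:
  - model: /models/a.safetensors
  - model: /models/b.safetensors
dtype: float32
```

This would be invoked as, e.g., `mergekit-pytorch config.yml ./merged --tensor-union`.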