From 53e852b7a97566c6d13e6a9c7cec469f1637721a Mon Sep 17 00:00:00 2001 From: Charles Goddard Date: Sat, 13 Jan 2024 18:57:05 -0800 Subject: [PATCH] Add back safe_serialization flag --- mergekit/io/tasks.py | 7 ++++++- mergekit/options.py | 1 + mergekit/plan.py | 4 +++- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/mergekit/io/tasks.py b/mergekit/io/tasks.py index b4a71866..cae703aa 100644 --- a/mergekit/io/tasks.py +++ b/mergekit/io/tasks.py @@ -94,12 +94,17 @@ def execute(self, **kwargs) -> Dict[ModelReference, Tensor]: class TensorWriterTask(Task[TensorWriter]): out_path: str max_shard_size: int + safe_serialization: bool = True def arguments(self) -> Dict[str, Task]: return {} def execute(self, **_kwargs) -> TensorWriter: - return TensorWriter(self.out_path, self.max_shard_size) + return TensorWriter( + self.out_path, + max_shard_size=self.max_shard_size, + safe_serialization=self.safe_serialization, + ) class SaveTensor(Task[None]): diff --git a/mergekit/options.py b/mergekit/options.py index aae8d9de..afb7229c 100644 --- a/mergekit/options.py +++ b/mergekit/options.py @@ -36,6 +36,7 @@ class MergeOptions(BaseModel): random_seed: Optional[int] = None lazy_unpickle: bool = False write_model_card: bool = True + safe_serialization: bool = True OPTION_HELP = { diff --git a/mergekit/plan.py b/mergekit/plan.py index 5fa869c8..53e4ca21 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -57,7 +57,9 @@ def __init__( self.trust_remote_code = options.trust_remote_code self._method = merge_methods.get(config.merge_method) self._writer_task = TensorWriterTask( - out_path=out_path, max_shard_size=options.out_shard_size + out_path=out_path, + max_shard_size=options.out_shard_size, + safe_serialization=options.safe_serialization, ) if config.tokenizer_source: