Skip to content

Commit

Permalink
Add back safe_serialization flag
Browse files Browse the repository at this point in the history
  • Loading branch information
cg123 committed Jan 14, 2024
1 parent 1316bd9 commit 53e852b
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
7 changes: 6 additions & 1 deletion mergekit/io/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,17 @@ def execute(self, **kwargs) -> Dict[ModelReference, Tensor]:
class TensorWriterTask(Task[TensorWriter]):
out_path: str
max_shard_size: int
safe_serialization: bool = True

def arguments(self) -> Dict[str, Task]:
return {}

def execute(self, **_kwargs) -> TensorWriter:
return TensorWriter(self.out_path, self.max_shard_size)
return TensorWriter(
self.out_path,
max_shard_size=self.max_shard_size,
safe_serialization=self.safe_serialization,
)


class SaveTensor(Task[None]):
Expand Down
1 change: 1 addition & 0 deletions mergekit/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class MergeOptions(BaseModel):
random_seed: Optional[int] = None
lazy_unpickle: bool = False
write_model_card: bool = True
safe_serialization: bool = True


OPTION_HELP = {
Expand Down
4 changes: 3 additions & 1 deletion mergekit/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def __init__(
self.trust_remote_code = options.trust_remote_code
self._method = merge_methods.get(config.merge_method)
self._writer_task = TensorWriterTask(
out_path=out_path, max_shard_size=options.out_shard_size
out_path=out_path,
max_shard_size=options.out_shard_size,
safe_serialization=options.safe_serialization,
)

if config.tokenizer_source:
Expand Down

0 comments on commit 53e852b

Please sign in to comment.