I have been investigating whether checkpoints can be stored in a format that does not depend on pickle. Pickle is inherently unsafe because loading a pickle can execute arbitrary code, and the .pt file format created by torch.save uses pickle internally.
Would this functionality be something that you would be interested in? If so, I can post a PR with my implementation for consideration.
Implementation
My solution is a custom .safe format, which is a .tar archive containing a safetensors file, an npz file, and a json file.
To save a state dictionary, I pre-process it so that anything not json-serializable is replaced by a placeholder dictionary such as {"__safe_obj_type": "tensor", "id": "123"}. During loading, the __safe_obj_type key is used to restore the object to what it was before. The gathered tensors and numpy arrays are stored as safetensors and npz respectively, and everything else is stored in the json file.
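To make the placeholder idea concrete, here is a minimal sketch using only the standard library. It is not the implementation from my PR: bytes values stand in for tensors and arrays (which would really go to safetensors and npz), and the "blob" type tag, the member names state.json and blobs/<id>, and all function names are hypothetical choices for illustration.

```python
import io
import json
import tarfile

TYPE_KEY = "__safe_obj_type"  # key name taken from the description above


def gather(obj, blobs):
    """Return a JSON-serializable copy of obj, moving bytes values into blobs."""
    if isinstance(obj, bytes):
        blob_id = str(len(blobs))  # hypothetical id scheme: a running counter
        blobs[blob_id] = obj
        return {TYPE_KEY: "blob", "id": blob_id}
    if isinstance(obj, dict):  # assumes string keys, as JSON requires
        return {key: gather(value, blobs) for key, value in obj.items()}
    if isinstance(obj, list):
        return [gather(value, blobs) for value in obj]
    return obj


def restore(obj, blobs):
    """Inverse of gather: swap each placeholder for the payload it refers to."""
    if isinstance(obj, dict):
        if obj.get(TYPE_KEY) == "blob":
            return blobs[obj["id"]]
        return {key: restore(value, blobs) for key, value in obj.items()}
    if isinstance(obj, list):
        return [restore(value, blobs) for value in obj]
    return obj


def _add_member(archive, name, payload):
    """Write an in-memory payload into the tar archive under the given name."""
    info = tarfile.TarInfo(name=name)
    info.size = len(payload)
    archive.addfile(info, io.BytesIO(payload))


def save_archive(state, fileobj):
    """Write the JSON tree and the gathered payloads into one tar archive."""
    blobs = {}
    tree = gather(state, blobs)
    with tarfile.open(fileobj=fileobj, mode="w") as archive:
        _add_member(archive, "state.json", json.dumps(tree).encode("utf-8"))
        for blob_id, payload in blobs.items():
            _add_member(archive, f"blobs/{blob_id}", payload)


def load_archive(fileobj):
    """Read members as plain data (nothing is extracted to disk or unpickled)."""
    blobs = {}
    tree = None
    with tarfile.open(fileobj=fileobj, mode="r") as archive:
        for member in archive.getmembers():
            data = archive.extractfile(member).read()
            if member.name == "state.json":
                tree = json.loads(data)
            elif member.name.startswith("blobs/"):
                blobs[member.name[len("blobs/"):]] = data
    return restore(tree, blobs)
```

Calling save_archive(state, buf) on an io.BytesIO buffer, rewinding it, and passing it to load_archive should give back an equal dictionary. Loading only parses JSON and reads raw bytes, so no code from the file is executed.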
At least for the state-dict I was trying to serialize, I found that only a few non-json-serializable data types needed to be supported: torch.Tensor, np.ndarray, datetime.timedelta, and tuple. A few others may be needed if this were implemented beyond a proof of concept.
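The two types that need no separate file, datetime.timedelta and tuple, can be encoded inline in the json. Below is a rough sketch of how that might look. The field names ("days", "seconds", "microseconds", "items") and the function names are assumptions made for illustration, not necessarily what my implementation uses. Tuples need explicit handling because json.dumps silently writes them as lists.

```python
import datetime
import json

TYPE_KEY = "__safe_obj_type"  # key name taken from the description above


def encode(obj):
    """Replace timedelta and tuple values with tagged, JSON-friendly dicts."""
    if isinstance(obj, datetime.timedelta):
        return {
            TYPE_KEY: "timedelta",
            "days": obj.days,
            "seconds": obj.seconds,
            "microseconds": obj.microseconds,
        }
    if isinstance(obj, tuple):
        return {TYPE_KEY: "tuple", "items": [encode(item) for item in obj]}
    if isinstance(obj, dict):  # assumes string keys, as JSON requires
        return {key: encode(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [encode(value) for value in obj]
    return obj


def decode(obj):
    """Inverse of encode: rebuild objects from their tagged dicts."""
    if isinstance(obj, dict):
        kind = obj.get(TYPE_KEY)
        if kind == "timedelta":
            return datetime.timedelta(
                days=obj["days"],
                seconds=obj["seconds"],
                microseconds=obj["microseconds"],
            )
        if kind == "tuple":
            return tuple(decode(item) for item in obj["items"])
        return {key: decode(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [decode(value) for value in obj]
    return obj
```

With this, decode(json.loads(json.dumps(encode(state)))) should return a dictionary equal to state with tuples preserved. Tuple subclasses such as namedtuples would come back as plain tuples in this sketch.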
@mbway Thanks for flagging this -- we're currently actively working on a major redesign of checkpointing that will include migrating to a safer format. We've been hesitant to change how checkpointing works as we deeply care about backwards compatibility, so we have been buffering up a long list of features we will now include in the revamp. You should start seeing things roll out this month.
CC: @eracah it might be good to publicly share a version of the roadmap as an RFC when it's ready
If you already have a PR, please feel free to post it -- it's probably helpful for us to view a design and gain some inspiration. Given the refactor though, we likely wouldn't accept it as this seems like a big change.