Error while converting checkpoints to Flax format #17
Same issue here. @yaraksen, were you able to get past the error?

@PootieT, unfortunately not; I'm waiting for the author's response.
@yaraksen I worked around it by patching `load_checkpoint` to load the whole checkpoint with `torch.load()` instead of streaming it through msgpack:

```python
import torch
from flax.serialization import from_state_dict, to_state_dict
from flax.traverse_util import empty_node, flatten_dict, unflatten_dict

@staticmethod
def load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None):
    if shard_fns is not None:
        shard_fns = flatten_dict(to_state_dict(shard_fns))
    if remove_dict_prefix is not None:
        remove_dict_prefix = tuple(remove_dict_prefix)
    flattened_train_state = {}
    with utils.open_file(path) as fin:
        # The original streaming reader used msgpack with an 80 MB read size
        # (83886080 bytes = 16 blocks on GCS):
        #   unpacker = msgpack.Unpacker(fin, read_size=83886080,
        #                               max_buffer_size=0, use_list=False)
        #   for key, value in unpacker:
        #       key = tuple(key)
        # TODO: bug here where the unpacker returns a stream of integers while
        # the code expects (key, value) pairs of parameter names and values.
        # This is not a catch-all solution; instead we load everything into
        # memory with torch.load() and iterate through it.
        weight_dict = torch.load(fin)
        for key, value in weight_dict.items():
            # Split the dotted PyTorch name into a tuple first, so the prefix
            # comparison below works on tuples as in the original code.
            key = tuple(key.split("."))
            if remove_dict_prefix is not None:
                if key[:len(remove_dict_prefix)] == remove_dict_prefix:
                    key = key[len(remove_dict_prefix):]
                else:
                    continue
            # tensor = from_bytes(None, buff)  # original msgpack decoding
            tensor = value.tolist()  # torch.Tensor -> nested List[float]
            if shard_fns is not None:
                tensor = shard_fns[key](tensor)
            flattened_train_state[key] = tensor
    if target is not None:
        # Carry over empty nodes from the target so the structure round-trips.
        flattened_target = flatten_dict(
            to_state_dict(target), keep_empty_nodes=True
        )
        for key, value in flattened_target.items():
            if key not in flattened_train_state and value == empty_node:
                flattened_train_state[key] = value
    train_state = unflatten_dict(flattened_train_state)
    if target is None:
        return train_state
    return from_state_dict(target, train_state)
```

From here I can at least get the weights loaded, but I am not entirely sure they end up in the correct format, since I am running into other JAX-related issues (#18, and more issues further down the line). If you can confirm that this weight format is correct, that might help both of us.
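To sanity-check the result, here is a minimal sketch (my own, not from the repo) that loads the checkpoint without a `target` and diffs the flattened keys and shapes against a freshly initialized Flax parameter tree; `init_params` is a hypothetical tree obtained from `model.init(...)`:

```python
import jax.numpy as jnp
from flax.traverse_util import flatten_dict

# Hypothetical usage: call the patched load_checkpoint above (via its
# checkpointer class in the actual codebase) and compare against
# init_params from model.init(...) on the Flax side.
loaded = load_checkpoint("consolidated.00.pth")
flat_loaded = flatten_dict(loaded)
flat_init = flatten_dict(init_params)

print("missing keys:", set(flat_init) - set(flat_loaded))
print("unexpected keys:", set(flat_loaded) - set(flat_init))
for key in set(flat_init) & set(flat_loaded):
    got = jnp.asarray(flat_loaded[key]).shape
    want = jnp.asarray(flat_init[key]).shape
    if got != want:
        print("shape mismatch at", key, ":", got, "vs", want)
```

If the key sets line up and no shapes mismatch, the converted weights are at least structurally consistent with the model definition.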
The directory with the official LLaMA 2 weights consists of checklist.chk, consolidated.00.pth, and params.json. I want to use it to train a CoH model, and as a first step I tried to convert the .pth model to JAX weights using your script, but it fails with an error. I created the conda environment using your .yml file.
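For reference, a quick way to inspect what the official shard contains before conversion (a minimal sketch; the path is illustrative):

```python
import torch

# Print the first few parameter names, shapes, and dtypes in the shard.
state = torch.load("llama-2-7b/consolidated.00.pth", map_location="cpu")
for name, tensor in list(state.items())[:10]:
    print(name, tuple(tensor.shape), tensor.dtype)
```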