
Error while converting checkpoints to Flax format #17

Open
yaraksen opened this issue Oct 20, 2023 · 3 comments

Comments

yaraksen commented Oct 20, 2023

The directory with the official LLaMA 2 weights contains checklist.chk, consolidated.00.pth, and params.json. I want to use it to train a CoH model, so I first try to convert the .pth model to JAX weights using your script:

python3 -m coh.scripts.convert_checkpoint \
    --load_checkpoint='params::llama-2-7b/consolidated.00.pth' \
    --output_file='llama-2-7b-jax/' \
    --streaming=True

But it leads to the following error:

Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/app/src/coh/scripts/convert_checkpoint.py", line 37, in <module>
    utils.run(main)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/app/src/coh/scripts/convert_checkpoint.py", line 22, in main
    params = StreamingCheckpointer.load_trainstate_checkpoint(
  File "/app/src/coh/tools/checkpoint.py", line 191, in load_trainstate_checkpoint
    restored_params = cls.load_checkpoint(
  File "/app/src/coh/tools/checkpoint.py", line 107, in load_checkpoint
    for key, value in unpacker:
TypeError: cannot unpack non-iterable int object

I created a conda environment using your .yml file.
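For what it's worth, the error is consistent with the streaming loader being handed a file that is not in its msgpack format: `consolidated.00.pth` is a PyTorch checkpoint, which `torch.save` (since PyTorch 1.6) writes as a ZIP archive. The first byte of a ZIP file, `b"P"` (0x50), decodes in msgpack as the positive fixint 80, so iterating the unpacker yields bare integers and `for key, value in unpacker:` raises exactly this `TypeError`. A quick stdlib-only sanity check (the helper name is mine, just for illustration):

```python
import zipfile

def looks_like_torch_zip(path: str) -> bool:
    """Return True if `path` is a ZIP archive, as written by
    torch.save() in PyTorch >= 1.6.

    Such a file is not msgpack, so msgpack.Unpacker decodes its
    raw bytes into stray integers (b"P" == 0x50 is the msgpack
    positive fixint 80) instead of (key, value) pairs.
    """
    return zipfile.is_zipfile(path)
```

If this returns True for your checkpoint, the file needs to go through a torch-aware loading path rather than the msgpack streaming reader.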


PootieT commented Oct 23, 2023

Same issue here, @yaraksen were you able to get past the error?

yaraksen (Author)

@PootieT, unfortunately not; I'm waiting for the author's response.


PootieT commented Oct 27, 2023

@yaraksen
Not sure if this is exactly the solution, but I was able to get something working by changing the `StreamingCheckpointer.load_checkpoint()` method in `coh/tools/checkpoint.py`:

    @staticmethod
    def load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None):
        if shard_fns is not None:
            shard_fns = flatten_dict(
                to_state_dict(shard_fns)
            )
        if remove_dict_prefix is not None:
            remove_dict_prefix = tuple(remove_dict_prefix)
        flattened_train_state = {}
        with utils.open_file(path) as fin:
            # Original streaming code expected a msgpack stream of
            # (key, value) pairs:
            #     # 83886080 bytes = 80 MB, which is 16 blocks on GCS
            #     unpacker = msgpack.Unpacker(fin, read_size=83886080,
            #                                 max_buffer_size=0, use_list=False)
            #     for key, value in unpacker:
            #         key = tuple(key)
            # Bug here: on this file the unpacker yields a stream of bare
            # integers, while the code expects (key, value) pairs of
            # parameter names and values. This is not a catch-all fix;
            # as a workaround we load the whole file into memory with
            # torch.load() and iterate over the resulting state dict.
            weight_dict = torch.load(fin)
            for key, value in weight_dict.items():
                if remove_dict_prefix is not None:
                    if key[:len(remove_dict_prefix)] == remove_dict_prefix:
                        key = key[len(remove_dict_prefix):]
                    else:
                        continue

                key = tuple(key.split("."))
                tensor = value.tolist()  # torch.Tensor -> nested List[float]
                if shard_fns is not None:
                    tensor = shard_fns[key](tensor)
                flattened_train_state[key] = tensor

        if target is not None:
            flattened_target = flatten_dict(
                to_state_dict(target), keep_empty_nodes=True
            )
            for key, value in flattened_target.items():
                if key not in flattened_train_state and value == empty_node:
                    flattened_train_state[key] = value
        train_state = unflatten_dict(flattened_train_state)
        if target is None:
            return train_state

        return from_state_dict(target, train_state)

From here I can at least get the weights saved, but I am not entirely sure they are saved in the correct format, since I am running into other JAX-related issues (#18, and even more issues down the line). If you can confirm this weight format is correct, that might help both of us.
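In case it helps with verification: since the workaround stores each tensor via `value.tolist()`, the leaves of the saved tree are nested Python lists, so one cheap sanity check is to compare flattened keys and shapes against the original torch state dict. A minimal sketch (helper names are mine, assuming plain nested dicts with list leaves):

```python
def flatten(tree, prefix=()):
    """Flatten a nested dict into {tuple_key: leaf} pairs,
    mirroring the tuple keys built by key.split(".") above."""
    out = {}
    for k, v in tree.items():
        key = prefix + (k,)
        if isinstance(v, dict):
            out.update(flatten(v, key))
        else:
            out[key] = v
    return out

def shape_of(x):
    """Shape of a nested list (what tensor.tolist() produces)."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return tuple(shape)
```

Running `flatten` over both the reloaded tree and the original checkpoint (after splitting its dotted keys) and comparing `shape_of` per key would at least confirm nothing was dropped or transposed during conversion.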
