diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index 5070bdd54b..9dee5d9674 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -222,14 +222,17 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler): if not args.deepspeed: model = unwrap_model(model) - print_rank_0('saving checkpoint at iteration {:7d} to {}'.format( + print_rank_0('saving checkpoint at iteration {} to {}'.format( iteration, args.save)) # Collect rng state across data parallel ranks. rng_state = get_rng_state() # Checkpoint name. - checkpoint_name = get_checkpoint_name(args.save, iteration) + if iteration == 'release': + checkpoint_name = get_checkpoint_name(args.save, iteration, release=True) + else: + checkpoint_name = get_checkpoint_name(args.save, iteration) # Save distributed optimizer's custom parameter state. if args.use_distributed_optimizer: @@ -300,7 +303,7 @@ def state_dict_for_save_checkpoint_deepspeed(destination=None, prefix='', keep_v if torch.distributed.is_initialized(): torch.distributed.barrier() - print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}' \ + print_rank_0(' successfully saved checkpoint at iteration {} to {}' \ .format(iteration, args.save)) # And update the latest iteration @@ -509,6 +512,7 @@ def _set_arg(arg_name, old_arg_name=None, force=False): _set_arg('apply_layernorm_1p', force=True) _set_arg('tokenizer_type') _set_arg('padded_vocab_size') + _set_arg('normalization', force=True) if checkpoint_version < 3.0: _set_arg('tensor_model_parallel_size', 'model_parallel_size') diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index d488a892aa..ec9dd57bf0 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -589,9 +589,13 @@ def __init__(self, config, layer_number, local_attn = FlashSelfAttention(causal=True, attention_dropout=config.attention_dropout) else: local_attn = CoreAttention(self.layer_number, config, self.attn_mask_type) - + # if hasattr(args, 'ckpt_transfer') and args.ckpt_transfer: + # self.enable_ds_sequence_parallel = False + # else: + # self.enable_ds_sequence_parallel = parallel_state.get_sequence_parallel_world_size() > 1 \ + # or args.force_ds_sequence_parallel self.enable_ds_sequence_parallel = parallel_state.get_sequence_parallel_world_size() > 1 \ - or args.force_ds_sequence_parallel + or args.force_ds_sequence_parallel if self.enable_ds_sequence_parallel: assert dist_attn_supported, 'Distributed attention is not supported in this DeepSpeed version' assert args.num_attention_heads % parallel_state.get_sequence_parallel_world_size() == 0 diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 149414848c..7122bfa1fa 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -29,7 +29,7 @@ import torch.nn.functional as F -def model_provider(pre_process=True, post_process=True): +def model_provider(pre_process=True, post_process=True, ckpt_transfer_model=False): """Build the model.""" print_rank_0('building GPT model ...') @@ -37,6 +37,14 @@ def model_provider(pre_process=True, post_process=True): args = get_args() config = core_transformer_config_from_args(args) + + if ckpt_transfer_model: + return GPTModel(config=config, + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process) + with deepspeed.zero.Init(sequence_data_parallel_group=mpu.get_sequence_data_parallel_group(), remote_device=None if args.remote_device == 'none' else args.remote_device, config_dict_or_path=args.deepspeed_config, diff --git 
a/tools/checkpoint_loader_megatron.py b/tools/checkpoint_loader_megatron.py index 1cd4937152..cde3e4f158 100644 --- a/tools/checkpoint_loader_megatron.py +++ b/tools/checkpoint_loader_megatron.py @@ -56,6 +56,9 @@ def _load_checkpoint(queue, args): margs = parse_args() margs, checkpoint_args = load_args_from_checkpoint(margs) + if args.tokenizer_model: + margs.tokenizer_model = args.tokenizer_model + margs.ckpt_transfer = True # Arguments do sanity checks on the world size, but we don't care, # so trick it into thinking we are plenty of processes @@ -124,14 +127,15 @@ def get_models(count, dtype): post_process = mpu.is_pipeline_last_stage() this_model = model_provider( pre_process=pre_process, - post_process=post_process + post_process=post_process, + ckpt_transfer_model=True ).to(dtype) model_.append(this_model) else: pre_process = mpu.is_pipeline_first_stage() post_process = mpu.is_pipeline_last_stage() model_rank = 0 - model_ = [model_provider(pre_process, post_process).to(dtype)] + model_ = [model_provider(pre_process, post_process, ckpt_transfer_model=True).to(dtype)] margs.consumed_train_samples = 0 margs.consumed_valid_samples = 0 load_checkpoint(model_, None, None) @@ -236,9 +240,11 @@ def queue_put(name, msg): # Get non-parallel tensors from tp_rank 0 layer = models[0].language_model.encoder.layers[layer_num] message["input layernorm weight"] = layer.input_layernorm.weight.data - message["input layernorm bias"] = layer.input_layernorm.bias.data message["post layernorm weight"] = layer.post_attention_layernorm.weight.data - message["post layernorm bias"] = layer.post_attention_layernorm.bias.data + if margs.normalization != 'rmsnorm': + message["input layernorm bias"] = layer.input_layernorm.bias.data + message["post layernorm bias"] = layer.post_attention_layernorm.bias.data + if md.linear_bias: message["dense bias"] = layer.self_attention.dense.bias.data message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data @@ -291,8 +297,9 @@ def queue_put(name, msg): # Send final layernorm from tp_rank 0 message = { "weight": models[0].language_model.encoder.final_layernorm.weight.data, - "bias": models[0].language_model.encoder.final_layernorm.bias.data } + if margs.normalization != 'rmsnorm': + message["bias"] = models[0].language_model.encoder.final_layernorm.bias.data queue_put("final layernorm", message) if md.output_layer: @@ -334,3 +341,4 @@ def load_checkpoint(queue, args): except: queue.put("exit") raise + diff --git a/tools/checkpoint_saver_megatron.py b/tools/checkpoint_saver_megatron.py index 0ff8c55b1f..71cdf9efb7 100644 --- a/tools/checkpoint_saver_megatron.py +++ b/tools/checkpoint_saver_megatron.py @@ -162,12 +162,15 @@ def check_message(msg): setattr(margs, arg, value) validate_args(margs) - + margs.ckpt_transfer = True + if args.tokenizer_model: + margs.tokenizer_model = args.tokenizer_model set_global_variables(margs) # margs = megatron args margs = get_args() + print("args.tokenizer_model", args.tokenizer_model) if hasattr(md, 'consumed_train_samples'): margs.consumed_train_samples = md.consumed_train_samples margs.consumed_valid_samples = md.consumed_valid_samples @@ -187,7 +190,7 @@ def check_message(msg): raise Exception(f'unrecognized model type: {args.model_type}') def get_models(count, dtype, pre_process, post_process): - models = [model_provider(pre_process, post_process).to(dtype) for _ in range(count)] + models = [model_provider(pre_process, post_process, ckpt_transfer_model=True).to(dtype) for _ in range(count)] return models # fake initializing distributed @@ 
-262,9 +265,11 @@ def get_models(count, dtype, pre_process, post_process): # duplicated tensors input_layernorm_weight = msg.pop("input layernorm weight") - input_layernorm_bias = msg.pop("input layernorm bias") post_layernorm_weight = msg.pop("post layernorm weight") - post_layernorm_bias = msg.pop("post layernorm bias") + if margs.normalization != 'rmsnorm': + post_layernorm_bias = msg.pop("post layernorm bias") + input_layernorm_bias = msg.pop("input layernorm bias") + if md.linear_bias: dense_bias = msg.pop("dense bias") mlp_l1_bias = msg.pop("mlp l1 bias") @@ -295,11 +300,12 @@ def get_models(count, dtype, pre_process, post_process): for tp_rank in range(args.target_tensor_parallel_size): l = models[tp_rank].language_model.encoder.layers[layer] l.input_layernorm.weight.data.copy_(input_layernorm_weight) - l.input_layernorm.bias.data.copy_(input_layernorm_bias) + if margs.normalization != 'rmsnorm': + l.input_layernorm.bias.data.copy_(input_layernorm_bias) + l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias) l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank]) l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank]) l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight) - l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias) l.mlp.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank]) l.mlp.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank]) if md.linear_bias: @@ -315,15 +321,18 @@ def get_models(count, dtype, pre_process, post_process): if post_process: msg = queue_get("final layernorm") final_layernorm_weight = msg.pop("weight") - final_layernorm_bias = msg.pop("bias") + if margs.normalization != 'rmsnorm': + final_layernorm_bias = msg.pop("bias") for tp_rank in range(args.target_tensor_parallel_size): models[tp_rank].language_model.encoder.final_layernorm.weight.data.copy_(final_layernorm_weight) - models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias) + if margs.normalization != 'rmsnorm': + models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias) if pp_rank != 0 and not md.output_layer: # Copy word embeddings to final pipeline rank models[tp_rank].word_embeddings.weight.data.copy_(out_word_embed[tp_rank]) del final_layernorm_weight - del final_layernorm_bias + if margs.normalization != 'rmsnorm': + del final_layernorm_bias check_message(msg) if md.output_layer: @@ -361,12 +370,14 @@ def get_models(count, dtype, pre_process, post_process): lm_head_dense_weight = msg.pop("dense weight") lm_head_dense_bias = msg.pop("dense bias") lm_head_layernorm_weight = msg.pop("layernorm weight") - lm_head_layernorm_bias = msg.pop("layernorm bias") + if margs.normalization != 'rmsnorm': + lm_head_layernorm_bias = msg.pop("layernorm bias") for tp_rank in range(args.target_tensor_parallel_size): models[tp_rank].lm_head.dense.weight.data.copy_(lm_head_dense_weight) models[tp_rank].lm_head.dense.bias.data.copy_(lm_head_dense_bias) models[tp_rank].lm_head.layernorm.weight.data.copy_(lm_head_layernorm_weight) - models[tp_rank].lm_head.layernorm.bias.data.copy_(lm_head_layernorm_bias) + if margs.normalization != 'rmsnorm': + models[tp_rank].lm_head.layernorm.bias.data.copy_(lm_head_layernorm_bias) check_message(msg) msg = queue_get() diff --git a/tools/checkpoint_util.py b/tools/checkpoint_util.py index 628ce47c62..e8e23993d2 100644 --- a/tools/checkpoint_util.py +++ b/tools/checkpoint_util.py @@ -124,8 +124,12 @@ def main(): parser.add_argument('--no-checking', 
action='store_false', help='Do not perform checking on the name and ordering of weights', dest='checking') + parser.add_argument('--tokenizer-model', type=str, default=None, + help='Path to the tokenizer model (tokenizer.model) file') + known_args, _ = parser.parse_known_args() + loader = load_plugin('loader', known_args.loader) saver = load_plugin('saver', known_args.saver) @@ -133,7 +137,8 @@ def main(): saver.add_arguments(parser) args = parser.parse_args() - + if args.tokenizer_model is None: + args.tokenizer_model = args.load_dir+"/tokenizer.model" queue = mp.Queue(maxsize=args.max_queue_size) print("Starting saver...") diff --git a/tools/convert_checkpoint/README.md b/tools/convert_checkpoint/README.md index 3f74bb1aa4..eda4b705e1 100644 --- a/tools/convert_checkpoint/README.md +++ b/tools/convert_checkpoint/README.md @@ -76,3 +76,27 @@ cd /hf/transformers python src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py \ /path/to/Megatron/checkpoint/iter_0097500/mp_rank_00/model_optim_rng.pt ``` + +## HF Transformers to Megatron-DeepSpeed (currently only supports LLaMA) + +To convert a LLaMA model from HF Transformers to Megatron-DeepSpeed, run steps 1 and 2 below; step 3 converts a Megatron-DeepSpeed checkpoint back to the HF Transformers format: + +```bash +# 1. Convert LLaMA weights from HF Transformers to Megatron +python tools/convert_checkpoint/transformers_to_megatron_llama.py \ +--out=/path/to/Megatron-Deepspeed/checkpoint/ \ +--cache-dir=/path/to/hf/transformers/llama_checkpoint + +# 2. Convert the Megatron-DeepSpeed checkpoint to a distributed version +python3 tools/checkpoint_util.py \ +--target-tensor-parallel-size 4 \ +--target-pipeline-parallel-size 2 \ +--load-dir /path/to/Megatron-Deepspeed/checkpoint/ \ +--save-dir /path/to/Megatron-Deepspeed/distribute_checkpoint/ \ +--model-type GPT + +# 3. Convert a Megatron-DeepSpeed checkpoint back to the HF Transformers format +python3 tools/convert_checkpoint/deepspeed_to_transformers_llama.py \ +--input_dir /path/to/Deepspeed/checkpoint/ \ +--output_dir /path/to/hf/transformers/checkpoint/ +``` diff --git a/tools/convert_checkpoint/deepspeed_to_transformers_llama.py b/tools/convert_checkpoint/deepspeed_to_transformers_llama.py new file mode 100644 index 0000000000..c1edc8b374 --- /dev/null +++ b/tools/convert_checkpoint/deepspeed_to_transformers_llama.py @@ -0,0 +1,237 @@ +# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+import gc +import os +import sys +import json +from pathlib import Path +from tempfile import TemporaryDirectory +from argparse import ArgumentParser, Namespace +sys.path.append(str(Path(__file__).parent.parent.absolute())) # megatron is importable + +import torch +from tqdm.auto import trange +from transformers import LlamaConfig, LlamaForCausalLM + +from permute_qkv import permute_qkv + + +""" +Modified from https://github.com/epfLLM/Megatron-LLM/tree/main/weights_conversion + +Sample usage: + +``` +python3 tools/convert_checkpoint/deepspeed_to_transformers_llama.py \ + --input_dir /path/to/Megatron-Deepspeed/checkpoint/ --output_dir /path/to/hf/transformers/checkpoint/ +``` + +Thereafter, models can be loaded via: + +```py +from transformers import LlamaForCausalLM, LlamaTokenizer + +model = LlamaForCausalLM.from_pretrained("/output/path") +tokenizer = LlamaTokenizer.from_pretrained("/output/path") ``` + +Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions +come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). +""" + + +def write_json(text, path): + with open(path, "w") as f: + json.dump(text, f) + + +def convert_wqkv(llama_mega, layer_idx=0, n_heads=32, n_heads_kv=8): + qkv_w = llama_mega["transformer"][f'layers.{layer_idx}.attention.query_key_value.weight'] + n_hidden = qkv_w.size(1) + hidden_dim = n_hidden//n_heads + qkv_w = permute_qkv(qkv_w, n_hidden, n_heads, n_heads_kv, revert=True) + + n_qs_per_kv = n_heads//n_heads_kv + n_groups = qkv_w.size(0)//hidden_dim//(n_qs_per_kv + 2) + qkv_w = list(torch.split(qkv_w, hidden_dim, dim=0)) + + wq, wk, wv = [], [], [] + for group in range(n_groups): + for qs in range(n_qs_per_kv): + wq.append(qkv_w[0]) + del qkv_w[0] + wk.append(qkv_w[0]) + del qkv_w[0] + wv.append(qkv_w[0]) + del qkv_w[0] + assert len(qkv_w) == 0 + + wq = torch.concat(wq, dim=0) + wk = torch.concat(wk, dim=0) + wv = torch.concat(wv, dim=0) + return wq, wk, wv + + +def convert_ffn(llama_mega, layer_idx=0, n_dense=11008): + mega_ffn = llama_mega["transformer"][f'layers.{layer_idx}.mlp.dense_h_to_4h.weight'] + ffn_w3, ffn_w1 = mega_ffn.split(n_dense, dim=0) + return ffn_w1, ffn_w3 + + +def write_llama_model(model_path, + input_base_path, + num_output_shards=2, + norm_eps=1e-05): + + + # Preliminaries + print(f"Fetching all parameters from the checkpoint at {input_base_path}.") + os.makedirs(model_path, exist_ok=True) + base = 10000.0 + with open(os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')) as f: + iteration = f.read() + if iteration != "release": + iteration = f"iter_{int(iteration):07d}" + print(f"Fetching iteration {iteration}") + + # Load weights + base_path = Path(input_base_path)/iteration + assert len(list(base_path.glob("mp_rank_*"))) == 1, "Unshard your model with checkpoint_util.py first!"
+ loaded = torch.load(base_path/"mp_rank_00"/"model_optim_rng.pt", map_location="cpu") + args = loaded['args'] + loaded = loaded['model']['language_model'] + if 'transformer' not in loaded: # normalize key names + loaded["transformer"] = loaded.pop("encoder") + for key in list(loaded["transformer"].keys()): + loaded["transformer"][key.replace("self_attention", "attention")] = loaded["transformer"].pop(key) + loaded["embedding"]["word_embeddings.weight"] = loaded["embedding"].pop("word_embeddings")["weight"] + + # Load arguments + n_layers = args.num_layers + n_heads = args.num_attention_heads + n_heads_kv = getattr(args, "num_attention_heads_kv", n_heads) + n_dense = args.ffn_hidden_size + n_hidden = args.hidden_size + hidden_per_head = n_hidden // n_heads + intermediate_size = args.ffn_hidden_size + inv_freq = 1.0 / (base ** (torch.arange(0, hidden_per_head, 2).float() / hidden_per_head)) + + print('Llama-Megatron Loaded!') + param_count = 0 + index_dict = {"weight_map": {}} + + # Start conversion + with TemporaryDirectory() as tmp_model_path: + print(f'Weighted Converting for {n_layers} layers...') + for layer_i in range(n_layers): + filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" + wq_proj, wk_proj, wv_proj = convert_wqkv(llama_mega=loaded, + layer_idx=layer_i, n_heads=n_heads, + n_heads_kv=n_heads_kv) + ffn_w1, ffn_w3 = convert_ffn(llama_mega=loaded, + layer_idx=layer_i, + n_dense=n_dense) + state_dict = { + f"model.layers.{layer_i}.self_attn.q_proj.weight": wq_proj, + f"model.layers.{layer_i}.self_attn.k_proj.weight": wk_proj, + f"model.layers.{layer_i}.self_attn.v_proj.weight": wv_proj, + f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded["transformer"][f"layers.{layer_i}.attention.dense.weight"], + f"model.layers.{layer_i}.mlp.gate_proj.weight": ffn_w1, + f"model.layers.{layer_i}.mlp.down_proj.weight": loaded["transformer"][f"layers.{layer_i}.mlp.dense_4h_to_h.weight"], + f"model.layers.{layer_i}.mlp.up_proj.weight": ffn_w3, + f"model.layers.{layer_i}.input_layernorm.weight": loaded["transformer"][f"layers.{layer_i}.input_layernorm.weight"], + f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded["transformer"][f"layers.{layer_i}.post_attention_layernorm.weight"], + f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq": inv_freq + } + + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + print(os.path.join(tmp_model_path, filename)) + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + print(f'Sharded file saved to {filename}') + + filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin" + state_dict = { + "model.norm.weight": loaded["transformer"]['final_layernorm.weight'], + "lm_head.weight": loaded['lm_head'], + "model.embed_tokens.weight": loaded['embedding']["word_embeddings.weight"] + } + + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch_dtype = state_dict["lm_head.weight"].dtype + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + print(f'Sharded file saved to {filename}') + + # Write configs and save + index_dict["metadata"] = {"total_size": param_count * 2} + write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) + config = LlamaConfig( + vocab_size=args.padded_vocab_size, + hidden_size=n_hidden, + intermediate_size=intermediate_size, + num_attention_heads=n_heads, + num_hidden_layers=n_layers, + rms_norm_eps=norm_eps, + num_key_value_heads=n_heads_kv, + 
max_position_embeddings=args.seq_length, + ) + config.save_pretrained(tmp_model_path) + + # Make space so we can load the model properly now. + del state_dict + del loaded + gc.collect() + + print("Loading the checkpoint in a Llama model...") + model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch_dtype) + # Avoid saving this as part of the config. + del model.config._name_or_path + + print("Saving in the Transformers format.") + max_num_params_per_shard = param_count*2 // max(1,(num_output_shards-1)) + model.save_pretrained(model_path, max_shard_size=max_num_params_per_shard) + + +def main(): + # make sure megatron is importable + + parser = ArgumentParser() + parser.add_argument("--input_dir", help="Location of LLaMA_Megatron weights", + required=True) + parser.add_argument("--num_output_shards", type=int, default=1) + parser.add_argument("--model", choices={"falcon", "llama", "llama2"}, + default="llama2") + parser.add_argument("--output_dir", help="Location to write HF model and tokenizer", + required=True) + parser.add_argument("--cache_dir", help="Huggingface cache_dir (optional)") + parser.add_argument("--vocab_file", type=str, help="Path to the vocab file") + parser.add_argument("--vocab_extra_ids_list", + help="comma separated list of special vocab ids to add to the tokenizer") + parser.add_argument("--override_special_tokens", nargs="*", + help=("One or more arguments to override special tokens. " + "Syntax set as `key=value`, e.g. `eos=<|im_end|>`. " + "Overrides available only bos, cls, eos, mask, pad, sep, unk.")) + + args = parser.parse_args() + write_llama_model( + model_path=args.output_dir, + input_base_path=args.input_dir, + num_output_shards=args.num_output_shards + ) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tools/convert_checkpoint/merge_llama.py b/tools/convert_checkpoint/merge_llama.py new file mode 100644 index 0000000000..6235211a21 --- /dev/null +++ b/tools/convert_checkpoint/merge_llama.py @@ -0,0 +1,111 @@ +import os +import re +from pathlib import Path +from typing import Optional +from collections import OrderedDict + +import torch +from tqdm.auto import tqdm +from transformers import LlamaForCausalLM, AutoTokenizer + + +scale2emb = { + '7B': 4096, + '13B': 5120, + '30B': 6656, + '65B': 8192, + '70B': 8192, +} + + +key_to_dim = { + "w1": 0, + "w2": -1, + "w3": 0, + "wo": -1, + "wq": 0, + "wk": 0, + "wv": 0, + "output": 0, + "tok_embeddings": -1, + "ffn_norm": None, + "attention_norm": None, + "norm": None, + "rope": None, +} + + +def init_merged_ckpt(pth_00, num_pth=8, emb_dim=8192): + merged_ckpt = OrderedDict() + for parameter_name, parameter in pth_00.items(): + short_name = parameter_name.split(".")[-2] + if key_to_dim[short_name] is None: + merged_ckpt[parameter_name] = parameter + del parameter + elif key_to_dim[short_name] == 0: + size = parameter.shape[0] + merged_param_shape = [ parameter.shape[0] * num_pth, parameter.shape[1] ] + merged_ckpt[parameter_name] = torch.zeros(merged_param_shape) + merged_ckpt[parameter_name][0 : size, :] = parameter + del parameter + elif key_to_dim[short_name] == -1: + size = parameter.shape[-1] + merged_param_shape = [ parameter.shape[0], parameter.shape[1] * num_pth] + merged_ckpt[parameter_name] = torch.zeros(merged_param_shape) + merged_ckpt[parameter_name][:, 0 : size] = parameter + del parameter + return merged_ckpt + + +def merge_meta_llama(size: int, root_dir: Path): + paths = sorted(path for path in root_dir.iterdir() + if 
re.match(r"^consolidated\.[0-9]+\.pth$", path.name)) + if len(paths) == 1: # no sharded checkpoints, return everything + return torch.load(paths[0], map_location=torch.device("cpu")) + + num_pth = len(paths) + for i, ckpt_path in enumerate(tqdm(paths, desc="Merging llama")): + llama_config = torch.load(ckpt_path, map_location=torch.device('cpu')) + if i == 0: + merged_ckpt = init_merged_ckpt(llama_config, num_pth=num_pth, + emb_dim=scale2emb[f"{size}B"]) + else: + for parameter_name, parameter in llama_config.items(): + short_name = parameter_name.split(".")[-2] + if key_to_dim[short_name] == 0: + size = parameter.shape[0] + merged_param_shape = [ parameter.shape[0] * num_pth, parameter.shape[1] ] + merged_ckpt[parameter_name][size * i : size * (i + 1), :] = parameter + del parameter + if key_to_dim[short_name] == -1: + size = parameter.shape[-1] + merged_param_shape = [ parameter.shape[0], parameter.shape[1] * num_pth] + merged_ckpt[parameter_name][:, size * i : size * (i + 1)] = parameter + del parameter + del llama_config + return merged_ckpt + + +def merge_hf_llama(cache_dir: Optional[Path] = None): + # assert version == 2, "Only llama v2 available using huggingface" + model = LlamaForCausalLM.from_pretrained(cache_dir, cache_dir=cache_dir, local_files_only=True, use_safetensors=False) + weights = model.state_dict() + weights["tok_embeddings.weight"] = weights.pop("model.embed_tokens.weight") + weights["norm.weight"] = weights.pop("model.norm.weight") + weights["output.weight"] = weights.pop("lm_head.weight") + for key in list(weights.keys()): + if rmatch := re.match(r"^model\.(layers\.[0-9]+\.)(.+)(\.weight)$", key): + new_key = { + "self_attn.q_proj": "attention.wq", + "self_attn.k_proj": "attention.wk", + "self_attn.v_proj": "attention.wv", + "self_attn.o_proj": "attention.wo", + "mlp.gate_proj": "feed_forward.w1", + "mlp.down_proj": "feed_forward.w2", + "mlp.up_proj": "feed_forward.w3", + "input_layernorm": "attention_norm", + "post_attention_layernorm": "ffn_norm" + }[rmatch.group(2)] + weights[rmatch.group(1) + new_key + rmatch.group(3)] = weights.pop(key) + return weights, model.config + diff --git a/tools/convert_checkpoint/permute_qkv.py b/tools/convert_checkpoint/permute_qkv.py new file mode 100644 index 0000000000..69159e94e2 --- /dev/null +++ b/tools/convert_checkpoint/permute_qkv.py @@ -0,0 +1,81 @@ +import re +import sys +import os +import shutil +from pathlib import Path +from argparse import ArgumentParser + +import torch +from tqdm.auto import tqdm + + +def permute_qkv(qkv_w: torch.Tensor, dim: int, n_heads: int, + n_heads_kv: int, revert: bool = False) -> torch.Tensor: + + def permute(x): + if revert: + return x.view(head_dim//2, 2, dim).transpose(0, 1).reshape(head_dim, dim) + return x.view(2, head_dim//2, dim).transpose(0, 1).reshape(head_dim, dim) + + head_dim = dim//n_heads + n_qs_per_kv = n_heads//n_heads_kv + n_groups = qkv_w.size(0)//head_dim//(n_qs_per_kv + 2) + groups = torch.chunk(qkv_w, n_groups, dim=0) + new = [] + for group in groups: + *qs, k, v = torch.split(group, head_dim, dim=0) + assert len(qs) == n_qs_per_kv, f"{len(qs)}, {n_qs_per_kv}" + new += list(map(permute, qs)) + [permute(k), v] + return torch.cat(new, dim=0) + + +def update_checkpoint(input_dir: Path, output_dir: Path, overwrite_ok: bool = False): + # make sure megatron is importable + sys.path.append(os.path.abspath( + os.path.join(os.path.dirname(__file__), + os.path.pardir))) + + + # prepare output dir + if output_dir.exists(): + if not overwrite_ok: + raise FileExistsError(f"Output 
directory {output_dir} already exists") + print(f"Removing {output_dir}") + shutil.rmtree(output_dir) + output_dir.mkdir(exist_ok=True) + + # determine realease + with open(input_dir/"latest_checkpointed_iteration.txt") as f: + it = f.read() + print("Updating weights of iteration", it) + with open(output_dir/"latest_checkpointed_iteration.txt", "w+") as f: + f.write(it) + (output_dir/it).mkdir() + + # convert weights + for fname in tqdm(list((input_dir/it).iterdir())): + checkpoint = torch.load(fname/"model_optim_rng.pt") + args = checkpoint["args"] + args = (args.hidden_size, args.num_attention_heads, + args.num_attention_heads_kv) + if "transformer" in checkpoint["model"]["language_model"]: + key = "transformer" + attn_key = "attention" + else: + key = "encoder" + attn_key = "self_attention" + states = checkpoint["model"]["language_model"][key] + for name, weight in states.items(): + if re.match(rf"^layers\.[0-9]+\.{attn_key}\.query_key_value\.weight$", name): + states[name] = permute_qkv(weight, *args) + (output_dir/it/fname.stem).mkdir() + torch.save(checkpoint, output_dir/it/fname.stem/"model_optim_rng.pt") + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--input-dir", type=Path) + parser.add_argument("--output-dir", type=Path) + parser.add_argument("--overwrite-ok", action="store_true") + args = parser.parse_args() + update_checkpoint(args.input_dir, args.output_dir, args.overwrite_ok) diff --git a/tools/convert_checkpoint/transformers_to_megatron_llama.py b/tools/convert_checkpoint/transformers_to_megatron_llama.py new file mode 100644 index 0000000000..65a94f2406 --- /dev/null +++ b/tools/convert_checkpoint/transformers_to_megatron_llama.py @@ -0,0 +1,160 @@ +import os +import sys +import shutil +from pathlib import Path +from typing import Optional +from argparse import ArgumentParser, Namespace + +import torch +from tqdm.auto import trange +from transformers import AutoModelForCausalLM, LlamaTokenizer +from transformers import LlamaConfig + +from permute_qkv import permute_qkv +from merge_llama import merge_hf_llama + +def llama_to_megatron(weights: dict, llama_config: LlamaConfig = None) -> dict: + def permute(qkv_w): + return permute_qkv(qkv_w, hidden, n_heads, n_kv_heads) + + def rearrange_qkv(wq, wk, wv): + wq = torch.split(wq, n_hidden_per_head, dim=0) + wk = torch.split(wk, n_hidden_per_head, dim=0) + wv = torch.split(wv, n_hidden_per_head, dim=0) + assert len(wq) == n_heads + assert len(wk) == n_kv_heads + assert len(wv) == n_kv_heads + n_qs_per_kv = n_heads//n_kv_heads + w_qkv = [] + for i in range(n_kv_heads): + w_qkv += [wq[i*n_qs_per_kv + j] for j in range(n_qs_per_kv)] + w_qkv += [wk[i], wv[i]] + return permute(torch.concat(w_qkv)) + + # config + n_layer = llama_config.num_hidden_layers + hidden = llama_config.hidden_size + n_heads = llama_config.num_attention_heads + n_hidden_per_head = hidden//n_heads + n_kv_heads = llama_config.num_key_value_heads + # weights independent of layers + embedding = {"word_embeddings": {"weight": weights["tok_embeddings.weight"]}} + transformer = {"final_layernorm.weight": weights["norm.weight"]} + lm_head = weights["output.weight"] + # get all the other weights + for layer in trange(n_layer, desc="Converting weights"): + prefix = f"layers.{layer}" + # identical weights + transformer[f"{prefix}.attention.dense.weight"] = \ + weights[f"{prefix}.attention.wo.weight"] + transformer[f"{prefix}.post_attention_layernorm.weight"] = \ + weights[f"{prefix}.ffn_norm.weight"] + 
transformer[f"{prefix}.input_layernorm.weight"] = \ + weights[f"{prefix}.attention_norm.weight"] + transformer[f"{prefix}.mlp.dense_4h_to_h.weight"] = \ + weights[f"{prefix}.feed_forward.w2.weight"] + # concatenate up, gate mlp weights + transformer[f"{prefix}.mlp.dense_h_to_4h.weight"] = torch.concat([ + weights[f"{prefix}.feed_forward.w3.weight"], + weights[f"{prefix}.feed_forward.w1.weight"] + ]) + # finally, qkv requires serious manipulation to get right + transformer[f"{prefix}.attention.query_key_value.weight"] = rearrange_qkv( + weights[f"{prefix}.attention.wq.weight"], + weights[f"{prefix}.attention.wk.weight"], + weights[f"{prefix}.attention.wv.weight"] + ) + + # release references to original weights (free mem) + del weights[f"{prefix}.feed_forward.w3.weight"] + del weights[f"{prefix}.feed_forward.w1.weight"] + del weights[f"{prefix}.attention.wq.weight"] + del weights[f"{prefix}.attention.wk.weight"] + del weights[f"{prefix}.attention.wv.weight"] + + return {"embedding": embedding, "encoder": transformer, + "lm_head": lm_head} + +def main(out: Optional[Path] = None, + cache_dir: Optional[Path] = None, megatron_path: Optional[Path] = None): + + if megatron_path: + print("Add megatron to os path") + os.path.append(megatron_path) + # get weights from or specified directory + print("Getting llama...") + hf_weights, llama_config = merge_hf_llama(cache_dir) + + # convert state dict to be megatron-compatible + megatron_weights = llama_to_megatron(hf_weights, llama_config=llama_config) + + # set args + # llama1, llama2 + args = {"num_layers": llama_config.num_hidden_layers, + "hidden_size": llama_config.hidden_size, + "num_attention_heads": llama_config.num_attention_heads, + "ffn_hidden_size": llama_config.intermediate_size, + "num_key_value_heads": llama_config.num_key_value_heads, + "parallel_attn": False, + "make_vocab_size_divisible_by": 1, + "glu_activation": "swiglu", + "max_position_embeddings": llama_config.max_length, # should use max_length rather than max_position_embeddings, detail in https://github.com/lm-sys/FastChat/issues/2046#issuecomment-1645265800 + "seq_length": llama_config.max_length, + "layernorm_epsilon": llama_config.rms_norm_eps, + # llama args + "padded_vocab_size": llama_config.vocab_size, + "tokenizer_type": "GPTSentencePieceTokenizer", + "no-query-key-layer-scaling": True, + "attention-dropout": 0, + "hidden-dropout": 0, + "use-rotary-position-embeddings": True, + "untie-embeddings-and-output-weights": True, + "swiglu": True, + "normalization": "rmsnorm", + "disable-bias-linear": True, + "add_position_embedding": False, + "add_bias_linear": False, + } + if llama_config.num_key_value_heads: + args.update({"num_attention_heads_kv": llama_config.num_key_value_heads}) + + args.update({ + "tensor_model_parallel_size": 1, + "pipeline_model_parallel_size": 1, + "iteration": 0, + "bias_gelu_fusion": False, + "bias_droput_fusion": False, + }) + + # save converted weights in specified out + (out/"release"/"mp_rank_00").mkdir(parents=True) + with open(out/"latest_checkpointed_iteration.txt", "w+") as f: + f.write("release") + final_dict = {"iteration": 'release', "model": {"language_model": megatron_weights}, + "checkpoint_version": 3.0, "args": Namespace(**args)} + torch.save(final_dict, out/"release"/"mp_rank_00"/"model_optim_rng.pt") + print("Saved weights in", out) + + tokenizer = LlamaTokenizer.from_pretrained( + cache_dir, cache_dir=cache_dir, local_files_only=True, + ) + token_path = out/"tokenizer.model" + vocab_file = tokenizer.vocab_file + 
shutil.copy(vocab_file, token_path) + print("Saved tokenizer.model in", token_path) + print("Done") + +if __name__ == "__main__": + parser = ArgumentParser(description="Convert Huggingface llama weights to " + "megatron-compatible weights") + parser.add_argument("--out", type=Path, + help="Directory to store the megatron weights (as checkpoint)") + parser.add_argument("--cache-dir", type=Path, + help=("Directory to store the huggingface weights, or " + "in case of the llama model, where to look for " + "the consolidated.xx.pth")) + parser.add_argument("--megatron-path", type=Path, default=None, + help="Path where to find megatron code") + args = parser.parse_args() + + main(args.out, args.cache_dir, args.megatron_path)
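The QKV conversion in this patch relies on `permute_qkv` (tools/convert_checkpoint/permute_qkv.py) being exactly invertible: `transformers_to_megatron_llama.py` applies the forward permutation when building the fused `query_key_value.weight`, and `deepspeed_to_transformers_llama.py` undoes it with `revert=True`. A minimal round-trip sanity check is sketched below; it assumes `tools/convert_checkpoint` is on `PYTHONPATH`, and the shapes (hidden size 4096, 32 heads, 8 KV heads as a grouped-query setting) are illustrative rather than tied to any specific checkpoint.

```python
# Hypothetical sanity check for permute_qkv; shapes are illustrative only.
import torch

from permute_qkv import permute_qkv  # assumes tools/convert_checkpoint is on PYTHONPATH

dim, n_heads, n_heads_kv = 4096, 32, 8            # hidden size, attention heads, assumed KV heads
head_dim = dim // n_heads                         # per-head dimension
n_qs_per_kv = n_heads // n_heads_kv               # queries per KV head in the grouped [q...q, k, v] layout
rows = n_heads_kv * (n_qs_per_kv + 2) * head_dim  # rows of the fused query_key_value.weight

qkv = torch.randn(rows, dim)
permuted = permute_qkv(qkv, dim, n_heads, n_heads_kv)
restored = permute_qkv(permuted, dim, n_heads, n_heads_kv, revert=True)
assert torch.equal(restored, qkv), "revert=True should exactly undo the forward permutation"
```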