support transfer llama hf weight to megatron weight #246

Open · uygnef wants to merge 10 commits into base: main
10 changes: 7 additions & 3 deletions megatron/checkpointing.py
@@ -222,14 +222,17 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
if not args.deepspeed:
model = unwrap_model(model)

print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
print_rank_0('saving checkpoint at iteration {} to {}'.format(
iteration, args.save))

# Collect rng state across data parallel ranks.
rng_state = get_rng_state()

# Checkpoint name.
checkpoint_name = get_checkpoint_name(args.save, iteration)
if iteration == 'release':
checkpoint_name = get_checkpoint_name(args.save, iteration, release=True)
else:
checkpoint_name = get_checkpoint_name(args.save, iteration)

# Save distributed optimizer's custom parameter state.
if args.use_distributed_optimizer:
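
For context on the hunk above: Megatron builds the checkpoint directory name from the iteration, which is why the old '{:7d}' format string and the unconditional get_checkpoint_name call break once iteration can be the string 'release'. Below is a minimal sketch of the naming convention this relies on; it is an assumption based on the upstream helper, and the real get_checkpoint_name also encodes tensor/pipeline ranks.

```python
import os

def get_checkpoint_name(save_dir, iteration, release=False):
    # Sketch only: numeric iterations map to iter_XXXXXXX directories,
    # while a release checkpoint lives in a plain 'release' directory.
    directory = 'release' if release else 'iter_{:07d}'.format(iteration)
    # The upstream helper also appends per-rank subdirectories such as mp_rank_00.
    return os.path.join(save_dir, directory)
```
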
@@ -300,7 +303,7 @@ def state_dict_for_save_checkpoint_deepspeed(destination=None, prefix='', keep_v
if torch.distributed.is_initialized():
torch.distributed.barrier()

print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}' \
print_rank_0(' successfully saved checkpoint at iteration {} to {}' \
.format(iteration, args.save))

# And update the latest iteration
@@ -509,6 +512,7 @@ def _set_arg(arg_name, old_arg_name=None, force=False):
_set_arg('apply_layernorm_1p', force=True)
_set_arg('tokenizer_type')
_set_arg('padded_vocab_size')
_set_arg('normalization', force=True)
if checkpoint_version < 3.0:
_set_arg('tensor_model_parallel_size',
'model_parallel_size')
8 changes: 6 additions & 2 deletions megatron/model/transformer.py
@@ -589,9 +589,13 @@ def __init__(self, config, layer_number,
local_attn = FlashSelfAttention(causal=True, attention_dropout=config.attention_dropout)
else:
local_attn = CoreAttention(self.layer_number, config, self.attn_mask_type)

# if hasattr(args, 'ckpt_transfer') and args.ckpt_transfer:
# self.enable_ds_sequence_parallel = False
# else:
# self.enable_ds_sequence_parallel = parallel_state.get_sequence_parallel_world_size() > 1 \
# or args.force_ds_sequence_parallel
self.enable_ds_sequence_parallel = parallel_state.get_sequence_parallel_world_size() > 1 \
or args.force_ds_sequence_parallel
if self.enable_ds_sequence_parallel:
assert dist_attn_supported, 'Distributed attention is not supported in this DeepSpeed version'
assert args.num_attention_heads % parallel_state.get_sequence_parallel_world_size() == 0
10 changes: 9 additions & 1 deletion pretrain_gpt.py
@@ -29,14 +29,22 @@
import torch.nn.functional as F


def model_provider(pre_process=True, post_process=True):
def model_provider(pre_process=True, post_process=True, ckpt_transfer_model=False):
"""Build the model."""

print_rank_0('building GPT model ...')
see_memory_usage(f"Before Building Model", force=True)

args = get_args()
config = core_transformer_config_from_args(args)

if ckpt_transfer_model:
return GPTModel(config=config,
num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process)

with deepspeed.zero.Init(sequence_data_parallel_group=mpu.get_sequence_data_parallel_group(),
Author (uygnef) commented:
There must be a better solution to initialize the model without initializing the distributed group. Please help me.

Reviewer replied:

The distributed initialization only occurs for args.zero_stage==3. Have you tried a different stage value on the command line?

Author (uygnef) replied:

> The distributed initialization only occurs for args.zero_stage==3. Have you tried a different stage value on the command line?

The problem is mpu.get_sequence_data_parallel_group(). How can I solve this problem?

  File "/mnt/megatron-deepspeed/pretrain_gpt.py", line 48, in model_provider
    with deepspeed.zero.Init(sequence_data_parallel_group=mpu.get_sequence_data_parallel_group(),
  File "/mnt/megatron-deepspeed/megatron/core/parallel_state.py", line 369, in get_sequence_data_parallel_group
    assert _SEQUENCE_DATA_PARALLEL_GROUP is not None, \
AssertionError: sequence data parallel group is not initialized

remote_device=None if args.remote_device == 'none' else args.remote_device,
config_dict_or_path=args.deepspeed_config,
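
One possible workaround for the initialization problem discussed in this thread (a sketch only, not the approach the PR settles on; the helper name zero_init_context is hypothetical): fall back to a no-op context when the sequence-data-parallel group has not been initialized, so the model can still be built for offline checkpoint conversion.

```python
import contextlib

def zero_init_context(args, mpu, deepspeed):
    """Return deepspeed.zero.Init when parallel groups exist, else a no-op context."""
    try:
        group = mpu.get_sequence_data_parallel_group()
    except AssertionError:
        # Parallel state is not initialized (e.g. during offline weight conversion).
        return contextlib.nullcontext()
    return deepspeed.zero.Init(
        sequence_data_parallel_group=group,
        remote_device=None if args.remote_device == 'none' else args.remote_device,
        config_dict_or_path=args.deepspeed_config,
        enabled=args.zero_stage == 3)  # assumption: zero.Init only matters for ZeRO stage 3
```

model_provider could then wrap model construction in `with zero_init_context(args, mpu, deepspeed):`; the PR instead adds the ckpt_transfer_model flag shown above, which returns the GPTModel directly without entering zero.Init.
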
18 changes: 13 additions & 5 deletions tools/checkpoint_loader_megatron.py
@@ -56,6 +56,9 @@ def _load_checkpoint(queue, args):

margs = parse_args()
margs, checkpoint_args = load_args_from_checkpoint(margs)
if args.tokenizer_model:
margs.tokenizer_model = args.tokenizer_model
margs.ckpt_transfer = True

# Arguments do sanity checks on the world size, but we don't care,
# so trick it into thinking we are plenty of processes
@@ -124,14 +127,15 @@ def get_models(count, dtype):
post_process = mpu.is_pipeline_last_stage()
this_model = model_provider(
pre_process=pre_process,
post_process=post_process
post_process=post_process,
ckpt_transfer_model=True
).to(dtype)
model_.append(this_model)
else:
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
model_rank = 0
model_ = [model_provider(pre_process, post_process).to(dtype)]
model_ = [model_provider(pre_process, post_process, ckpt_transfer_model=True).to(dtype)]
margs.consumed_train_samples = 0
margs.consumed_valid_samples = 0
load_checkpoint(model_, None, None)
@@ -236,9 +240,11 @@ def queue_put(name, msg):
# Get non-parallel tensors from tp_rank 0
layer = models[0].language_model.encoder.layers[layer_num]
message["input layernorm weight"] = layer.input_layernorm.weight.data
message["input layernorm bias"] = layer.input_layernorm.bias.data
message["post layernorm weight"] = layer.post_attention_layernorm.weight.data
message["post layernorm bias"] = layer.post_attention_layernorm.bias.data
if margs.normalization != 'rmsnorm':
message["input layernorm bias"] = layer.input_layernorm.bias.data
message["post layernorm bias"] = layer.post_attention_layernorm.bias.data

if md.linear_bias:
message["dense bias"] = layer.self_attention.dense.bias.data
message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data
@@ -291,8 +297,9 @@ def queue_put(name, msg):
# Send final layernorm from tp_rank 0
message = {
"weight": models[0].language_model.encoder.final_layernorm.weight.data,
"bias": models[0].language_model.encoder.final_layernorm.bias.data
}
if margs.normalization != 'rmsnorm':
message["bias"] = models[0].language_model.encoder.final_layernorm.bias.data
queue_put("final layernorm", message)

if md.output_layer:
Expand Down Expand Up @@ -334,3 +341,4 @@ def load_checkpoint(queue, args):
except:
queue.put("exit")
raise

33 changes: 22 additions & 11 deletions tools/checkpoint_saver_megatron.py
@@ -162,12 +162,15 @@ def check_message(msg):
setattr(margs, arg, value)

validate_args(margs)

margs.ckpt_transfer = True
if args.tokenizer_model:
margs.tokenizer_model = args.tokenizer_model
set_global_variables(margs)

# margs = megatron args
margs = get_args()

print("args.tokenizer_model", args.tokenizer_model)
if hasattr(md, 'consumed_train_samples'):
margs.consumed_train_samples = md.consumed_train_samples
margs.consumed_valid_samples = md.consumed_valid_samples
@@ -187,7 +190,7 @@ def check_message(msg):
raise Exception(f'unrecognized model type: {args.model_type}')

def get_models(count, dtype, pre_process, post_process):
models = [model_provider(pre_process, post_process).to(dtype) for _ in range(count)]
models = [model_provider(pre_process, post_process, ckpt_transfer_model=True).to(dtype) for _ in range(count)]
return models

# fake initializing distributed
@@ -262,9 +265,11 @@ def get_models(count, dtype, pre_process, post_process):

# duplicated tensors
input_layernorm_weight = msg.pop("input layernorm weight")
input_layernorm_bias = msg.pop("input layernorm bias")
post_layernorm_weight = msg.pop("post layernorm weight")
post_layernorm_bias = msg.pop("post layernorm bias")
if margs.normalization != 'rmsnorm':
post_layernorm_bias = msg.pop("post layernorm bias")
input_layernorm_bias = msg.pop("input layernorm bias")

if md.linear_bias:
dense_bias = msg.pop("dense bias")
mlp_l1_bias = msg.pop("mlp l1 bias")
@@ -295,11 +300,12 @@ def get_models(count, dtype, pre_process, post_process):
for tp_rank in range(args.target_tensor_parallel_size):
l = models[tp_rank].language_model.encoder.layers[layer]
l.input_layernorm.weight.data.copy_(input_layernorm_weight)
l.input_layernorm.bias.data.copy_(input_layernorm_bias)
if margs.normalization != 'rmsnorm':
l.input_layernorm.bias.data.copy_(input_layernorm_bias)
l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias)
l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank])
l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight)
l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias)
l.mlp.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank])
l.mlp.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank])
if md.linear_bias:
@@ -315,15 +321,18 @@ def get_models(count, dtype, pre_process, post_process):
if post_process:
msg = queue_get("final layernorm")
final_layernorm_weight = msg.pop("weight")
final_layernorm_bias = msg.pop("bias")
if margs.normalization != 'rmsnorm':
final_layernorm_bias = msg.pop("bias")
for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].language_model.encoder.final_layernorm.weight.data.copy_(final_layernorm_weight)
models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias)
if margs.normalization != 'rmsnorm':
models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias)
if pp_rank != 0 and not md.output_layer:
# Copy word embeddings to final pipeline rank
models[tp_rank].word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
del final_layernorm_weight
del final_layernorm_bias
if margs.normalization != 'rmsnorm':
del final_layernorm_bias
check_message(msg)

if md.output_layer:
@@ -361,12 +370,14 @@ def get_models(count, dtype, pre_process, post_process):
lm_head_dense_weight = msg.pop("dense weight")
lm_head_dense_bias = msg.pop("dense bias")
lm_head_layernorm_weight = msg.pop("layernorm weight")
lm_head_layernorm_bias = msg.pop("layernorm bias")
if margs.normalization != 'rmsnorm':
lm_head_layernorm_bias = msg.pop("layernorm bias")
for tp_rank in range(args.target_tensor_parallel_size):
models[tp_rank].lm_head.dense.weight.data.copy_(lm_head_dense_weight)
models[tp_rank].lm_head.dense.bias.data.copy_(lm_head_dense_bias)
models[tp_rank].lm_head.layernorm.weight.data.copy_(lm_head_layernorm_weight)
models[tp_rank].lm_head.layernorm.bias.data.copy_(lm_head_layernorm_bias)
if margs.normalization != 'rmsnorm':
models[tp_rank].lm_head.layernorm.bias.data.copy_(lm_head_layernorm_bias)
check_message(msg)
msg = queue_get()

7 changes: 6 additions & 1 deletion tools/checkpoint_util.py
@@ -124,16 +124,21 @@ def main():
parser.add_argument('--no-checking', action='store_false',
help='Do not perform checking on the name and ordering of weights',
dest='checking')
parser.add_argument('--tokenizer-model', type=str, default=None,
help='Path to the tokenizer model file; defaults to <load-dir>/tokenizer.model')


known_args, _ = parser.parse_known_args()

loader = load_plugin('loader', known_args.loader)
saver = load_plugin('saver', known_args.saver)

loader.add_arguments(parser)
saver.add_arguments(parser)

args = parser.parse_args()

if args.tokenizer_model is None:
args.tokenizer_model = args.load_dir+"/tokenizer.model"
queue = mp.Queue(maxsize=args.max_queue_size)

print("Starting saver...")
24 changes: 24 additions & 0 deletions tools/convert_checkpoint/README.md
@@ -76,3 +76,27 @@ cd /hf/transformers
python src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py \
/path/to/Megatron/checkpoint/iter_0097500/mp_rank_00/model_optim_rng.pt
```

## HF Transformers to Megatron-DeepSpeed (currently only supports LLaMA)

Converting a LLaMA model from HF Transformers to Megatron-DeepSpeed takes two steps; a third command converts a DeepSpeed checkpoint back to the HF Transformers format:

```bash
# 1. Convert LLaMA weights from HF to Megatron
python tools/convert_checkpoint/transformers_to_megatron_llama.py \
--out=/path/to/Megatron-Deepspeed/checkpoint/ \
--cache-dir=/path/to/hf/transformers/llama_checkpoint

# 2. Convert Megatron-DeepSpeed checkpoint to distributed version
python3 tools/checkpoint_util.py \
--target-tensor-parallel-size 4 \
--target-pipeline-parallel-size 2 \
--load-dir /path/to/Megatron-Deepspeed/checkpoint/ \
--save-dir /path/to/Megatron-Deepspeed/distribute_checkpoint/ \
--model-type GPT

# 3. Convert a DeepSpeed checkpoint back to the Hugging Face Transformers format
python3 tools/convert_checkpoint/deepspeed_to_megatron_llama.py \
--input_dir /path/to/Deepspeed/checkpoint/ \
  --output_dir /path/to/hf/transformers/checkpoint/
```
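
As a quick sanity check after step 3 (a sketch that assumes the output directory is a complete Hugging Face LLaMA checkpoint, i.e. it also contains the config and tokenizer files), the converted model can be loaded with Transformers:

```python
from transformers import LlamaForCausalLM, LlamaTokenizer

# Illustrative path: point this at the --output_dir used in step 3.
ckpt_dir = "/path/to/hf/transformers/checkpoint/"

tokenizer = LlamaTokenizer.from_pretrained(ckpt_dir)
model = LlamaForCausalLM.from_pretrained(ckpt_dir)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
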