feat: Add ALCF/examples/finetune_llama3/* #74

Open · saforem2 wants to merge 9 commits into base: main

Conversation

saforem2 (Member) commented:

Copilot Generated Summary:

This pull request includes several updates to the ALCF/examples/finetune_llama3 directory, focusing on adding a new README file, configuration files, and a shell script for fine-tuning the Llama3 model. Additionally, there is a minor update in the megatron/checkpointing.py file to ensure directory creation when saving the learning rate state dictionary.

Documentation and setup:

  • ALCF/examples/finetune_llama3/README.md: Added comprehensive instructions for setting up the environment, installing dependencies, downloading data, and converting Hugging Face checkpoints for fine-tuning Llama3.

Configuration files:

Shell script:

  • ALCF/examples/finetune_llama3/finetune_llama.sh: Added a script for setting up the environment, configuring DeepSpeed, and running the fine-tuning process for Llama3. This script includes logic for converting Hugging Face models to Megatron-Deepspeed format and vice versa.

Minor update:

  • megatron/checkpointing.py: Ensured the parent directory is created if it does not exist when saving the learning rate state dictionary.
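
  A minimal sketch of the kind of guard described above, under assumed names (save_lr_state, opt_param_scheduler, and lr_state_path are illustrative; the actual identifiers in megatron/checkpointing.py may differ):

  from pathlib import Path
  import torch

  def save_lr_state(opt_param_scheduler, lr_state_path):
      # Create the parent directory if it does not already exist, then save the
      # learning-rate scheduler state dict.
      path = Path(lr_state_path)
      path.parent.mkdir(parents=True, exist_ok=True)
      torch.save(opt_param_scheduler.state_dict(), path)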

@saforem2 changed the title feat: Add ALCF/examples/finetune_llama3/* on Jan 15, 2025
Comment on lines 130 to +135

for name, param in hf_auto_model.named_parameters():
    hf_model[name] = param.clone()
    log.info(name)
hf_model = {}
for name, submodule in hf_auto_model.named_children():
    for pname, param in submodule.named_parameters():
        logger.info(f'[{name}.{pname}] shape={param.shape}')
        hf_model[f'{name}.{pname}'] = param.clone()
saforem2 (Member, Author) commented:

There was an issue caused by the following block:

for name, param in hf_auto_model.named_parameters():

which failed to capture hf_auto_model.lm_head.weight and therefore prevented the checkpoint from being converted successfully.

Replacing this block with

>>> for name, submodule in lmodel.named_children():
...     for pname, param in submodule.named_parameters():
...         named_submods[f'{name}.{pname}'] = param.clone()

fixes this issue, as shown explicitly below:

>>> from transformers import AutoModelForCausalLM, LlamaConfig, AutoTokenizer, LlamaForCausalLM, AutoConfig
>>> lmodel = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-3.2-1B')
>>> named_params = {}
>>> for name, param in lmodel.named_parameters():
...     named_params[name] = param.clone()

>>> named_submods = {}
>>> for name, submodule in lmodel.named_children():
...     for pname, param in submodule.named_parameters():
...         named_submods[f'{name}.{pname}'] = param.clone()

>>> len(named_submods.keys())
147

>>> len(named_params.keys())
146

>>> list(named_submods.keys())[-3:]
['model.layers.15.post_attention_layernorm.weight',
 'model.norm.weight',
 'lm_head.weight']

>>> list(named_params.keys())[-3:]
['model.layers.15.input_layernorm.weight',
 'model.layers.15.post_attention_layernorm.weight',
 'model.norm.weight']
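
The likely explanation (my reading, not something verified in this PR) is that Llama-3.2-1B ties lm_head.weight to model.embed_tokens.weight, and nn.Module.named_parameters() de-duplicates shared parameters by default, so the tied tensor only shows up once under its first name. Continuing the session above, these checks (outputs omitted) should confirm that reading:

>>> # Both of these should be True if the missing key comes from weight tying:
>>> lmodel.config.tie_word_embeddings
>>> lmodel.lm_head.weight is lmodel.model.embed_tokens.weight
>>> # Asking named_parameters() not to de-duplicate should surface lm_head.weight as well:
>>> len(dict(lmodel.named_parameters(remove_duplicate=False)))
>>> sorted(set(named_submods) - set(named_params))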

Comment on lines +167 to +174
self.tokenizer = get_tokenizer()
if args.tokenizer_type == 'HFTokenizer':
    self.hf_tokenizer = get_hf_tokenizer(args.tokenizer_model)
    self.token_vocab = len(self.hf_tokenizer)
else:
    self.hf_tokenizer = None
    assert self.tokenizer is not None
    self.token_vocab = self.tokenizer.vocab_size
saforem2 (Member, Author) commented:

Mismatch between self.tokenizer.vocab_size and hf_w.shape[0] when using Llama3 tokenizers.

This discrepancy causes the following assertion:

assert hf_w.shape[0] == self.padded_vocab_size

to fail since hf_w.shape[0] = 128256 but self.tokenizer.vocab_size = 128000.

Explicitly:

>>> type(self.hf_tokenizer)
<class 'transformers.tokenization_utils_fast.PreTrainedTokenizerFast'>
>>> type(self.tokenizer)
<class 'megatron.tokenizer.tokenizer._HFTokenizer'>
>>> self.tokenizer.vocab_size
128000
>>> len(self.hf_tokenizer)
128256

So, setting self.token_vocab = len(self.hf_tokenizer) (instead of self.tokenizer.vocab_size) seems to resolve this issue.
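
This matches the usual Hugging Face convention (my understanding, not something stated in the PR): tokenizer.vocab_size counts only the base vocabulary, while len(tokenizer) also counts added/special tokens, which is presumably where the extra 256 entries come from for Llama3. Continuing with the objects above (outputs omitted):

>>> # The gap between 128000 and 128256 should be accounted for by added tokens:
>>> len(self.hf_tokenizer.get_added_vocab())
>>> self.hf_tokenizer.vocab_size + len(self.hf_tokenizer.get_added_vocab()) == len(self.hf_tokenizer)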
