MoE support (#48)
Co-authored-by: Ammar Ahmad Awan <[email protected]>
Co-authored-by: Conglong Li <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
Co-authored-by: Minjia Zhang <[email protected]>
Co-authored-by: Reza Yazdani <[email protected]>
Co-authored-by: yaozhewei <[email protected]>
Co-authored-by: Adam Moody <[email protected]>
7 people authored Jun 7, 2022
1 parent 87476ce commit 50b3251
Showing 40 changed files with 5,840 additions and 215 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -1,3 +1,8 @@
## Megatron-DeepSpeed
DeepSpeed version of NVIDIA's Megatron-LM that adds support for several features such as MoE, Curriculum Learning, and 3D Parallelism.

------

Megatron ([1](https://arxiv.org/pdf/1909.08053.pdf) and [2](https://arxiv.org/pdf/2104.04473.pdf)) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, model-parallel (tensor and pipeline), and multi-node pre-training of transformer based models such as [GPT](https://arxiv.org/abs/2005.14165), [BERT](https://arxiv.org/pdf/1810.04805.pdf), and [T5](https://arxiv.org/abs/1910.10683) using mixed precision.

Below are some of the projects where we have directly used Megatron:
Expand Down Expand Up @@ -68,6 +73,12 @@ GPT-345M: wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia

The models require vocabulary files to run. The BERT WordPiece vocab file can be extracted from Google's pretrained BERT models: [uncased](https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt), [cased](https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt). The GPT [vocab file](https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json) and [merge table](https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt) can be downloaded directly.

Additional notes for DeepSpeed: we have added a helper script to download the checkpoint and make the text-generation example runnable.

Steps to follow (a combined sketch appears after this list):
- `bash ds_download_ckpt.sh` -- downloads and extracts the 345M GPT checkpoint along with the GPT-2 vocab and merge files.
- `bash examples/generate_text.sh` -- generates text samples using the 345M GPT model.
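A minimal end-to-end sketch of the two steps above, run from the repository root (the `ls` check and the directory layout it reveals are illustrative, not guaranteed):

```bash
# Fetch the GPT-2 vocab/merge files and the 345M checkpoint.
bash ds_download_ckpt.sh

# Optional sanity check: see what was extracted under the checkpoint directory.
ls checkpoints/gpt2_345m

# Generate text samples with the 345M GPT model.
bash examples/generate_text.sh
```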

# Usage

After installation, there are several possible workflows. The most comprehensive is:
12 changes: 12 additions & 0 deletions ds_download_ckpt.sh
@@ -0,0 +1,12 @@

# Download the GPT-2 vocab and merge files used by the tokenizer.
wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json
wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt

# Download and extract the Megatron-LM 345M GPT checkpoint from NVIDIA NGC.
mkdir -p checkpoints/gpt2_345m

cd checkpoints/gpt2_345m
wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_lm_345m/versions/v0.0/zip -O megatron_lm_345m_v0.0.zip
unzip megatron_lm_345m_v0.0.zip
rm megatron_lm_345m_v0.0.zip
cd ../..

39 changes: 39 additions & 0 deletions examples/MoE/ds_config_gpt_TEMPLATE.json
@@ -0,0 +1,39 @@
{
"train_batch_size" : CONFIG_BATCH_SIZE,
"train_micro_batch_size_per_gpu": CONFIG_MBSIZE,
"steps_per_print": LOG_INTERVAL,

"zero_optimization": {
"stage": ZERO_STAGE,
"elastic_checkpoint": true
},

"gradient_clipping": 1.0,
"prescale_gradients": PRESCALE_GRAD,

"fp16": {
"enabled": CONFIG_FP16_ENABLED,
"loss_scale": 0,
"loss_scale_window": 500,
"hysteresis": 2,
"min_loss_scale": 1,
"initial_scale_power": 11
},

"bf16": {
"enabled": CONFIG_BF16_ENABLED
},
"curriculum_learning": {
"enabled": CONFIG_CL_ENABLED,
"curriculum_type": "seqlen",
"min_difficulty": CONFIG_CL_MIN,
"max_difficulty": CONFIG_CL_MAX,
"schedule_type": "fixed_linear",
"schedule_config": {
"total_curriculum_step": CONFIG_CL_DURATION,
"difficulty_step": 8
}
},

"wall_clock_breakdown" : false
}
38 changes: 38 additions & 0 deletions examples/MoE/ds_config_gpt_Zero2_TEMPLATE.json
@@ -0,0 +1,38 @@
{
"train_batch_size" : CONFIG_BATCH_SIZE,
"train_micro_batch_size_per_gpu": CONFIG_MBSIZE,
"steps_per_print": LOG_INTERVAL,

"zero_optimization": {
"stage": 2
},

"gradient_clipping": 1.0,
"prescale_gradients": false,

"fp16": {
"enabled": CONFIG_FP16_ENABLED,
"loss_scale": 0,
"loss_scale_window": 500,
"hysteresis": 2,
"min_loss_scale": 1,
"initial_scale_power": 11
},

"bf16": {
"enabled": CONFIG_BF16_ENABLED
},
"curriculum_learning": {
"enabled": CONFIG_CL_ENABLED,
"curriculum_type": "seqlen",
"min_difficulty": CONFIG_CL_MIN,
"max_difficulty": CONFIG_CL_MAX,
"schedule_type": "fixed_linear",
"schedule_config": {
"total_curriculum_step": CONFIG_CL_DURATION,
"difficulty_step": 8
}
},

"wall_clock_breakdown" : false
}
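Both templates above use placeholder tokens (`CONFIG_BATCH_SIZE`, `CONFIG_MBSIZE`, `LOG_INTERVAL`, and so on) that must be replaced with concrete values before DeepSpeed can parse the file. Below is a minimal sketch of how a launch script might instantiate the first template with `sed`; the variable names, example values, and output file name are assumptions for illustration, not the exact logic of the MoE example scripts.

```bash
# Hypothetical values; a real launch script would set these from its own variables.
GLOBAL_BATCH_SIZE=256
MICRO_BATCH_SIZE=4
LOG_INTERVAL=10
ZERO_STAGE=0
PRESCALE_GRAD="true"

TEMPLATE=ds_config_gpt_TEMPLATE.json
CONFIG=ds_config_gpt_bs${GLOBAL_BATCH_SIZE}.json   # generated config, passed via --deepspeed_config

# Substitute each placeholder token in the template to produce a runnable config.
sed -e "s/CONFIG_BATCH_SIZE/${GLOBAL_BATCH_SIZE}/" \
    -e "s/CONFIG_MBSIZE/${MICRO_BATCH_SIZE}/" \
    -e "s/LOG_INTERVAL/${LOG_INTERVAL}/" \
    -e "s/ZERO_STAGE/${ZERO_STAGE}/" \
    -e "s/PRESCALE_GRAD/${PRESCALE_GRAD}/" \
    -e "s/CONFIG_FP16_ENABLED/true/" \
    -e "s/CONFIG_BF16_ENABLED/false/" \
    -e "s/CONFIG_CL_ENABLED/false/" \
    -e "s/CONFIG_CL_MIN/1/" \
    -e "s/CONFIG_CL_MAX/1/" \
    -e "s/CONFIG_CL_DURATION/1/" \
    ${TEMPLATE} > ${CONFIG}
```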
71 changes: 71 additions & 0 deletions examples/MoE/ds_evalharness.sh
@@ -0,0 +1,71 @@
# This is an example zero-shot eval script. Please first read readme_evalharness.md in the same directory.

CHECKPOINT_PATH=/blob/users/conglli/project/gpt3_with_pile/checkpoint/gpt3-with-pile-0.125B-lr-2.4e-3-minlr-6.0e-5-bs-2048-gpus-128-zero-0-mp-1-pp-1-no_pp-cl-startseqlen-72-step-20728-token-45B/global_step81566/
CONFIG_PATH=ds_config_gpt3-with-pile-0.125B-lr-2.4e-3-minlr-6.0e-5-bs-2048-gpus-128-zero-0-mp-1-pp-1-no_pp-cl-startseqlen-72-step-20728-token-45B.json
RESULT_PATH=gpt3-with-pile-0.125B-lr-2.4e-3-minlr-6.0e-5-bs-2048-gpus-128-zero-0-mp-1-pp-1-no_pp-cl-startseqlen-72-step-20728-token-45B_global_step81566.log

PP_SIZE=1
TP_SIZE=1
NO_PP="true"
EP_PARALLEL_SIZE=1
# Currently the eval harness does not support data parallelism.
# However, for MoE models it is possible to enable a "fake" data parallelism
# in order to load the experts on multiple GPUs. It is not real data
# parallelism because the same data is loaded on all GPUs.
# On the other hand, it is better to use fewer GPUs than during training
# to reduce communication overhead.
NUM_NODE=1
NUM_GPU_PER_NODE=1

TASKS="lambada"
# WikiText-2: not used in the GPT-3 paper but used in the GPT-2 paper.
# TASKS="wikitext"
# Tasks that appeared in the GPT-3 paper (sorted in the order used in the paper), plus WikiText-2.
# TASKS="hellaswag,lambada,triviaqa,webqs,winogrande,piqa,arc_challenge,arc_easy,openbookqa,race,boolq,cb,copa,rte,wic,wsc,multirc,record,anli_r1,anli_r2,anli_r3,wikitext"
# All tasks confirmed to work; there are more tasks at https://github.com/EleutherAI/lm-evaluation-harness that we did not test.
# TASKS="hellaswag,lambada,triviaqa,webqs,winogrande,piqa,arc_challenge,arc_easy,openbookqa,race,boolq,cb,copa,rte,wic,wsc,multirc,record,anli_r1,anli_r2,anli_r3,wikitext,logiqa,mathqa,mc_taco,mrpc,prost,pubmedqa,qnli,qqp,sciq,sst,wnli"

VOCAB_FILE=/data/Megatron-LM/data/gpt2-vocab.json
MERGE_FILE=/data/Megatron-LM/data/gpt2-merges.txt

export HF_DATASETS_OFFLINE=1

# Dummy arguments to make Megatron happy. There is no need to configure them:
# these and many other arguments are read from the checkpoint file by the
# eval framework.
MEGATRON_REQUIRED_ARGS="\
--num-layers -1\
--hidden-size -1\
--num-attention-heads -1\
--seq-length -1 \
--max-position-embeddings -1
"

CMD="../../tasks/eval_harness/evaluate.py \
--load $CHECKPOINT_PATH\
--tensor-model-parallel-size $TP_SIZE \
--pipeline-model-parallel-size $PP_SIZE\
--moe-expert-parallel-size ${EP_PARALLEL_SIZE} \
--vocab-file $VOCAB_FILE\
--merge-file $MERGE_FILE\
--micro-batch-size 12\
--no-load-optim \
--no-load-rng \
--inference \
--disable-moe-token-dropping \
--adaptive_seq_len\
--eval_fp32\
--task_list $TASKS\
--results_path $RESULT_PATH \
--deepspeed \
--deepspeed_config $CONFIG_PATH \
$MEGATRON_REQUIRED_ARGS\
"

if [[ "${NO_PP}" = "true" ]]; then
CMD="${CMD} \
--no-pipeline-parallel"
fi

LAUNCHER="deepspeed --num_nodes $NUM_NODE --num_gpus $NUM_GPU_PER_NODE"
$LAUNCHER $CMD
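A minimal usage sketch for the script above, assuming the lm-evaluation-harness setup described in readme_evalharness.md is already in place; the comment block simply restates which variables to point at your own run.

```bash
# Before running, edit the variables at the top of ds_evalharness.sh:
#   CHECKPOINT_PATH -- the global_step* directory of the checkpoint to evaluate
#   CONFIG_PATH     -- the DeepSpeed config JSON used for that training run
#   RESULT_PATH     -- where the evaluation log should be written
#   VOCAB_FILE / MERGE_FILE -- the GPT-2 tokenizer files
# Then launch it from this directory:
cd examples/MoE
bash ds_evalharness.sh
```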