Skip to content

Commit

Permalink
Add Megatron-LM pretrain function for the core.
Browse files Browse the repository at this point in the history
  • Loading branch information
jaredcasper committed May 25, 2023
1 parent 9d83398 commit 8a85d59
Showing 1 changed file with 127 additions and 0 deletions.
127 changes: 127 additions & 0 deletions pretrain_gpt_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Pretrain GPT"""

import torch
from functools import partial
from megatron import get_args
from megatron.arguments import core_transformer_config_from_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron.core import tensor_parallel
from megatron.core.enums import ModelType
from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.core.models.gpt import GPTModel
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group

def model_provider(pre_process=True, post_process=True):
"""Build the model."""

args = get_args()
config = core_transformer_config_from_args(args)

print_rank_0('building GPT model ...')
model = GPTModel(
config=config,
vocab_size=args.padded_vocab_size,
max_sequence_length=args.max_position_embeddings,
pre_process=pre_process,
post_process=post_process,
fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
parallel_output=True,
share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights
)
return model


def get_batch(data_iterator):
"""Generate a batch"""
args = get_args()
tokenizer = get_tokenizer()

# Items and their type.
keys = ['text']
datatype = torch.int64

# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = tensor_parallel.broadcast_data(keys, data, datatype)

# Unpack.
tokens_ = data_b['text'].long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()

# Get the masks and postition ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss)

return tokens, labels, loss_mask, attention_mask, position_ids

def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])

return loss, {'lm loss': averaged_loss[0]}


def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()

# Get the batch.
timers('batch-generator', log_level=2).start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
timers('batch-generator').stop()

output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)

return output_tensor, partial(loss_func, loss_mask)


def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()

print_rank_0('> building train, validation, and test datasets '
'for GPT ...')
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
seq_length=args.seq_length,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
train_data_prefix=args.train_data_path,
valid_data_prefix=args.valid_data_path,
test_data_prefix=args.test_data_path)
print_rank_0("> finished creating GPT datasets ...")

return train_ds, valid_ds, test_ds


if __name__ == "__main__":

pretrain(train_valid_test_datasets_provider, model_provider,
ModelType.encoder_or_decoder,
forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}
)

0 comments on commit 8a85d59

Please sign in to comment.