add support to run custom Hf tokenizer for training and dataset pre-processing (#421)

* Update arguments.py

* Update tokenizer.py

* Update preprocess_data.py
polisettyvarma authored Jul 18, 2024
1 parent fc989b8 commit 7d23e33
Showing 3 changed files with 18 additions and 4 deletions.
2 changes: 2 additions & 0 deletions megatron/arguments.py
@@ -1299,6 +1299,8 @@ def _add_data_args(parser):
                        help='What type of tokenizer to use.')
     group.add_argument('--tokenizer-model', type=str, default=None,
                        help='Sentencepiece tokenizer model.')
+    group.add_argument('--trust-remote-code', action='store_true', default=False,
+                       help='To run HFTokenizer model from local path.')
     group.add_argument('--data-impl', type=str, default='infer',
                        choices=['mmap', 'infer'],
                        help='Implementation of indexed datasets.')
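Note: the new option is a plain boolean switch. Below is a minimal standalone argparse sketch (not the Megatron parser itself) showing how it behaves; since action='store_true' already defaults to False, the explicit default=False above is redundant but harmless.

import argparse

parser = argparse.ArgumentParser()
# Same flag definition as the one added to _add_data_args above.
parser.add_argument('--trust-remote-code', action='store_true', default=False,
                    help='To run HFTokenizer model from local path.')

print(parser.parse_args([]).trust_remote_code)                       # False
print(parser.parse_args(['--trust-remote-code']).trust_remote_code)  # True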
12 changes: 9 additions & 3 deletions megatron/tokenizer/tokenizer.py
@@ -1,3 +1,4 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 
 """Megatron tokenizers."""
@@ -40,7 +41,9 @@ def build_tokenizer(args):
         tokenizer = _NullTokenizer(args.vocab_size)
     elif args.tokenizer_type == 'HFTokenizer':
         assert args.tokenizer_model is not None
-        tokenizer = _HFTokenizer(args.tokenizer_model,args.seq_length)
+        tokenizer = _HFTokenizer(args.tokenizer_model,
+                                 args.seq_length,
+                                 args.trust_remote_code)
     else:
         raise NotImplementedError('{} tokenizer is not '
                                   'implemented.'.format(args.tokenizer_type))
@@ -540,10 +543,13 @@ def additional_special_tokens_ids(self):
 
 class _HFTokenizer(AbstractTokenizer):
     """HF Tokenizer"""
-    def __init__(self, tokenizer_name_or_path,max_seq_len):
+    def __init__(self, tokenizer_name_or_path, max_seq_len, trust_remote_code):
         name = tokenizer_name_or_path
         super().__init__(name)
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path,padding_side="right",use_fast=False)
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path,
+                                                       padding_side="right",
+                                                       trust_remote_code=trust_remote_code,
+                                                       use_fast=False)
 
         DEFAULT_PAD_TOKEN = "[PAD]"
         DEFAULT_EOS_TOKEN = "</s>"
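For context, the loading path _HFTokenizer now takes boils down to a single AutoTokenizer call with trust_remote_code passed through. A minimal standalone sketch, assuming transformers is installed; "my_custom_tokenizer/" is a hypothetical local directory (or Hub id) whose repository ships its own tokenizer implementation.

from transformers import AutoTokenizer

# Mirrors the from_pretrained call in _HFTokenizer.__init__ above.
tokenizer = AutoTokenizer.from_pretrained(
    "my_custom_tokenizer/",   # hypothetical local path or Hub id
    padding_side="right",
    use_fast=False,
    trust_remote_code=True,   # opt in to running the repo's custom tokenizer code
)
print(tokenizer.encode("hello world"))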
8 changes: 7 additions & 1 deletion tools/preprocess_data.py
@@ -1,3 +1,4 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 
 """Processing large data for pretraining."""
@@ -193,10 +194,15 @@ def get_args():
     group.add_argument('--tokenizer-type', type=str, required=True,
                        choices=['BertWordPieceLowerCase','BertWordPieceCase',
                                 'GPT2BPETokenizer', 'SentencePieceTokenizer',
-                                'GPTSentencePieceTokenizer', 'NullTokenizer'],
+                                'GPTSentencePieceTokenizer', 'HFTokenizer',
+                                'NullTokenizer'],
                        help='What type of tokenizer to use.')
     group.add_argument('--tokenizer-model', type=str, default=None,
                        help='YTTM tokenizer model.')
+    group.add_argument('--seq-length', type=int, default=None,
+                       help='Maximum sequence length to process.')
+    group.add_argument('--trust-remote-code', action='store_true',
+                       help='To run HFTokenizer model from local path.')
     group.add_argument('--vocab-file', type=str, default=None,
                        help='Path to the vocab file')
     group.add_argument('--vocab-size', default=786,
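A hedged sketch of how the preprocessing script might be invoked after this change. Only --tokenizer-type, --tokenizer-model, --seq-length and --trust-remote-code are introduced or confirmed by this commit; --input, --output-prefix and --workers are assumed from the usual preprocess_data.py CLI, and the file paths are hypothetical.

import subprocess

# Hypothetical run of the updated preprocessing script with an HF tokenizer
# that requires custom code from its repository.
subprocess.run([
    "python", "tools/preprocess_data.py",
    "--input", "my_corpus.jsonl",                  # hypothetical input file
    "--output-prefix", "my_corpus_hf",             # hypothetical output prefix
    "--tokenizer-type", "HFTokenizer",
    "--tokenizer-model", "my_custom_tokenizer/",   # hypothetical local tokenizer path
    "--seq-length", "2048",
    "--trust-remote-code",
    "--workers", "8",
], check=True)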
