train_tokenizer.py
import argparse
from pathlib import Path

from tokenizers.implementations import BertWordPieceTokenizer

from src.utils.common import load_json


def train_tokenizer(
    data_path: str = "resources/data/noisy_dataset.jsonl",
    save_to: str = "tokenizer",
    vocab_size: int = 30_000,
):
"""
Trains the WordPiece tokenizer. Used to operate a classic transformer encoder.
:param data_path: the data on which the tokenizer is trained (default: resources/data/noisy_dataset.jsonl)
:param save_to: where to save the tokenizer (default: tokenizer)
:return:
"""
    # Cased WordPiece tokenizer; text cleaning and Chinese-character splitting
    # are disabled so the noisy input is tokenized as-is.
    bert_tokenizer = BertWordPieceTokenizer(
        unk_token="[UNK]",
        sep_token="[SEP]",
        cls_token="[CLS]",
        clean_text=False,
        handle_chinese_chars=False,
        lowercase=False,
        wordpieces_prefix="##",
    )
    # Extract the raw text field from every record in the dataset.
    data = load_json(data_path)
    texts = [record["text"] for record in data]
    # Train the vocabulary from the in-memory texts; the special tokens are
    # reserved at the start of the vocabulary.
    bert_tokenizer.train_from_iterator(
        texts,
        vocab_size=vocab_size,
        wordpieces_prefix="##",
        special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"],
    )
    # save_model expects an existing directory; exist_ok=True makes a separate
    # is_dir() check unnecessary.
    Path(save_to).mkdir(parents=True, exist_ok=True)
    bert_tokenizer.save_model(save_to)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", type=str, default="resources/data/noisy_dataset.jsonl")
    parser.add_argument("--save-to", type=str, default="tokenizer")
    parser.add_argument("--vocab-size", type=int, default=30_000)
    args = parser.parse_args()
    train_tokenizer(data_path=args.data, save_to=args.save_to, vocab_size=args.vocab_size)
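
# Usage sketch (an assumption, not part of the original script): in recent
# versions of the tokenizers library, the saved vocab.txt can be reloaded via
# BertWordPieceTokenizer.from_file, passing the same casing/cleaning options
# used at training time. For example:
#
#   from tokenizers.implementations import BertWordPieceTokenizer
#
#   tokenizer = BertWordPieceTokenizer.from_file(
#       "tokenizer/vocab.txt",
#       clean_text=False,
#       handle_chinese_chars=False,
#       lowercase=False,
#   )
#   encoding = tokenizer.encode("an example sentence")
#   print(encoding.tokens)  # WordPiece tokens, wrapped in [CLS] ... [SEP]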