# train_distilbert.py
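"""Fine-tune DistilBERT for binary sentiment classification on a dataset
saved to local disk (e.g. IMDb). Each pipeline stage is wrapped in a
timing/tracing decorator that prints a stage banner and its running time."""
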
import os
import argparse
import functools
import time

import torch
from datasets import load_from_disk
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, Trainer, TrainingArguments


def tokenize_function(data, tokenizer, textfield):
    # Pad/truncate each example to a fixed length of 512 tokens.
    return tokenizer(data[textfield], padding='max_length', truncation=True, max_length=512)


def compute_metrics(eval_pred):
    # Convert logits to class predictions and report plain accuracy.
    logits, labels = eval_pred
    predictions = torch.argmax(torch.tensor(logits), dim=-1)
    return {"accuracy": (predictions == torch.tensor(labels)).float().mean().item()}


def timing_tracing_wrapper(func):
    """Print a stage separator, then time the wrapped pipeline stage."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Simulate filesystem activity tracing keyed on the stage name.
        separator = kwargs.pop("separator", func.__name__)
        print(f'================ {separator} ==============')
        os.path.isfile(f'/mnt/fs1/lroc/{separator}.txt')
        # Run the stage and report its execution time.
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        running_time = end_time - start_time
        print(f"Running time for {func.__name__}: {running_time:.4f} seconds")
        return result
    return wrapper


@timing_tracing_wrapper
def load_model(model_dir):
    # Load the tokenizer and model from local disk.
    tokenizer = DistilBertTokenizer.from_pretrained(model_dir)
    model = DistilBertForSequenceClassification.from_pretrained(model_dir, num_labels=2)  # 2 labels for binary sentiment analysis
    return [tokenizer, model]


@timing_tracing_wrapper
def load_dataset_for_distilbert(dataset_path):
    # https://huggingface.co/docs/datasets/v1.12.0/package_reference/loading_methods.html
    dataset = load_from_disk(dataset_path)
    # Alternative: load JSON splits directly with datasets.load_dataset, e.g.
    # load_dataset('json', data_files={'train': 'path/to/train.json', 'test': 'path/to/test.json'})
    return [dataset]
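# Note: the pipeline below assumes the object returned by load_from_disk is a
# DatasetDict with 'train' and 'test' splits (as in the saved IMDb dataset);
# tokenize_and_prepare_dataset indexes those splits by name.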


@timing_tracing_wrapper
def tokenize_and_prepare_dataset(dataset, tokenizer, args):
    # Tokenize every split of the dataset.
    tokenized_datasets = dataset.map(lambda data: tokenize_function(data, tokenizer, args.textfield), batched=True)
    # Shuffle and subsample the splits used for training and evaluation.
    train_dataset = tokenized_datasets['train'].shuffle(seed=42).select(range(args.train_samples))
    eval_dataset = tokenized_datasets['test'].shuffle(seed=42).select(range(args.eval_samples))
    return [train_dataset, eval_dataset]


@timing_tracing_wrapper
def train_arguments(args):
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        weight_decay=0.01,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        ddp_find_unused_parameters=False,
    )
    return [training_args]


@timing_tracing_wrapper
def train(model, training_args, train_dataset, eval_dataset):
    # Initialize the Trainer and run fine-tuning.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    return [trainer]


@timing_tracing_wrapper
def eval_model(trainer):
    # Evaluate the model on the eval split and return the metrics dict.
    return [trainer.evaluate()]


@timing_tracing_wrapper
def save_model(args, model, tokenizer):
    # Persist the fine-tuned model and its tokenizer together.
    model.save_pretrained(os.path.join(args.output_dir, "final_model"))
    tokenizer.save_pretrained(os.path.join(args.output_dir, "final_model"))
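# The saved artifacts can later be reloaded with from_pretrained, e.g.:
#   DistilBertForSequenceClassification.from_pretrained(os.path.join(args.output_dir, "final_model"))
#   DistilBertTokenizer.from_pretrained(os.path.join(args.output_dir, "final_model"))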


def main(args):
    # Run the full pipeline; each stage prints its separator and running time.
    [tokenizer, model] = load_model(args.model_dir, separator="load_model")
    [dataset] = load_dataset_for_distilbert(args.data_dir, separator="load_dataset")
    [train_dataset, eval_dataset] = tokenize_and_prepare_dataset(dataset, tokenizer, args, separator='tokenize_dataset')
    [training_args] = train_arguments(args, separator='training_arguments')
    [trainer] = train(model, training_args, train_dataset, eval_dataset, separator='start_training')
    [eval_metrics] = eval_model(trainer, separator="eval_model")
    print(eval_metrics)
    save_model(args, model, tokenizer, separator='save_model_tokenizer')


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train a DistilBERT model on an IMDb-style dataset from local disk.')
    parser.add_argument('--model_dir', type=str, required=True, help='Directory where the pretrained model is saved')
    parser.add_argument('--data_dir', type=str, required=True, help='Directory where the dataset is saved')
    parser.add_argument('--output_dir', type=str, default='./results', help='Directory to save the training results')
    parser.add_argument('--train_samples', type=int, default=25000, help='Number of training samples to use')
    parser.add_argument('--eval_samples', type=int, default=5000, help='Number of evaluation samples to use')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training and evaluation')
    parser.add_argument('--epochs', type=int, default=3, help='Number of training epochs')
    parser.add_argument('--learning_rate', type=float, default=5e-5, help='Learning rate')
    parser.add_argument('--textfield', type=str, default="reviewText", help='Field name in the dataset that contains the training text')
    args = parser.parse_args()
    main(args)
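# Example invocation (paths are illustrative, not taken from this repository):
#   python train_distilbert.py --model_dir /models/distilbert --data_dir /data/imdb --textfield text
# Since ddp_find_unused_parameters is set in TrainingArguments, the same script
# can also be launched for multi-GPU DDP training via torchrun, e.g.:
#   torchrun --nproc_per_node=2 train_distilbert.py --model_dir /models/distilbert --data_dir /data/imdb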