forked from yjang43/pushingonreadability_transformers
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference.py
150 lines (114 loc) · 4.54 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# preliminary imports
import pandas as pd
import torch
import os
import argparse
import torch
import ast
from copy import deepcopy
from torch.utils.data import DataLoader
from easydict import EasyDict as edict
from transformers import (
BertForSequenceClassification,
BertTokenizer,
XLNetForSequenceClassification,
XLNetTokenizer,
RobertaForSequenceClassification,
RobertaTokenizer,
BartForSequenceClassification,
BartTokenizer
)
from tqdm import tqdm
from dataloader import LingFeatBatchGenerator, LingFeatDataset
from utils import get_logger, set_seed
# ---- command-line interface ----------------------------------------------
arg_parser = argparse.ArgumentParser()

# required arguments
arg_parser.add_argument('--checkpoint_path',
                        type=str,
                        help="path to model checkpoint")
arg_parser.add_argument('--data_path',
                        type=str,
                        help="path to add neural network feature to")

# optional arguments
arg_parser.add_argument('--seed',
                        default=0,
                        type=int,
                        help="seed value")
arg_parser.add_argument('--batch_size',
                        default=8,
                        type=int,
                        help="number of batch to infer at once")
arg_parser.add_argument('--device',
                        default='cuda',
                        type=str,
                        help="set to 'cuda' to use GPU. 'cpu' otherwise")

args = arg_parser.parse_args()

# reproducibility and logging setup
logger = get_logger()
set_seed(args.seed)
logger.info(f'args: {args}')
# define model/tokenizer class according to model_name
# checkpoint strictly needs to be in the following format:
# {checkpoint_dir}/{task}.{model_name}.{k-fold}.{n-eval}
# BUG FIX: parse the basename, not the full path — splitting the whole path
# on '.' made corpus_name absorb the directory prefix, and any '.' inside
# {checkpoint_dir} shifted every field. Split once instead of four times.
checkpoint_name = os.path.basename(os.path.normpath(args.checkpoint_path))
corpus_name, model_name, k, l = checkpoint_name.split('.')[:4]
# load model and tokenizer
# Dispatch table: model family -> (tokenizer class, model class, pretrained
# vocab name the tokenizer was trained with). The fine-tuned weights always
# come from args.checkpoint_path.
_MODEL_ZOO = {
    'bert': (BertTokenizer, BertForSequenceClassification, 'bert-base-uncased'),
    'xlnet': (XLNetTokenizer, XLNetForSequenceClassification, 'xlnet-base-cased'),
    'roberta': (RobertaTokenizer, RobertaForSequenceClassification, 'roberta-base'),
    'bart': (BartTokenizer, BartForSequenceClassification, 'facebook/bart-base'),
}
try:
    tokenizer_cls, model_cls, pretrained_name = _MODEL_ZOO[model_name.lower()]
except KeyError:
    # BUG FIX: original message omitted BART even though BART is supported
    raise ValueError("Model must be either BERT or XLNet or RoBERTa or BART")
tokenizer = tokenizer_cls.from_pretrained(pretrained_name)
model = model_cls.from_pretrained(args.checkpoint_path)
# # DEPRECATED: an experiment to observe a trend depending on the size of data is pushed back
# # to check trend on 20p, 40p, 60p, 80p
# if args.checkpoint_path.split('.')[0][-1] == 'p': # weebit20p
#     corpus_name = args.checkpoint_path.split('.')[0][:-3] # weebit
#     model_name = args.checkpoint_path.split('.')[1] + args.checkpoint_path.split('.')[0][-3:] # bert + 20p
model.to(args.device)
model.eval()  # inference only — disable dropout etc.
# load the csv whose rows will be annotated with model predictions
df = pd.read_csv(args.data_path)
ling_dataset = LingFeatDataset(df)
collate = LingFeatBatchGenerator(tokenizer)
dataloader = DataLoader(ling_dataset, collate_fn=collate, batch_size=args.batch_size)

# column names for the two new features added by this script
pred_label = f'{model_name}.{l}.pred'
prob_label = f'{model_name}.{l}.prob'
df[pred_label] = -1     # expected values: int value in between 0 and num_class-1
df[prob_label] = 'nan'  # expected values: string of list of softmax values

# make inference from here
softmax = torch.nn.Softmax(dim=1)
progress = tqdm(range(len(dataloader)))
# Run the model batch-by-batch and write predictions back into df by the
# original row indices supplied by the collate fn.
for batch_idx, batch_item in enumerate(dataloader):
    inputs, labels, indices = batch_item
    # ROBUSTNESS FIX: keep the return value of .to() instead of relying on
    # in-place mutation — this is the documented pattern and also works if
    # the collate fn ever returns an encoding whose .to() is not in-place.
    inputs = inputs.to(args.device)
    with torch.no_grad():
        logits = model(**inputs)[0].detach().cpu()
    probs = softmax(logits)
    preds = torch.argmax(probs, dim=1).tolist()
    # store softmax vectors as their string repr so they survive the csv round-trip
    probs = [str(x) for x in probs.tolist()]
    df.loc[indices, pred_label] = preds
    df.loc[indices, prob_label] = probs
    progress.update()
# Expand the stringified softmax list into one float column per class.
# BUG FIX: the original sized the columns from `probs[0]`, a leftover loop
# variable — undefined (NameError) when the dataloader yields no batches.
# Derive the class count from the dataframe column itself, and parse each
# row's string once instead of per-column.
parsed_probs = df[prob_label].apply(ast.literal_eval)
num_class = len(parsed_probs.iloc[0]) if len(parsed_probs) else 0
for i in range(num_class):
    # column j holds the softmax probability of class j-1 (1-indexed suffix)
    df[f"{prob_label}.{i + 1}"] = [p[i] for p in parsed_probs]
df.drop(prob_label, inplace=True, axis=1)
df.to_csv(args.data_path, index=False)
logger.info(f'new features are created and can be found at: "{args.data_path}"')