ner_model.py
#!/usr/bin/env python2.7
# -*- coding: utf-8 -*-
"""
A model for named entity recognition.
"""
import logging

import tensorflow as tf

from util import ConfusionMatrix, Progbar, minibatches
from data_util import get_chunks
from model import Model
from defs import LBLS

logger = logging.getLogger("hw3")
logger.setLevel(logging.DEBUG)
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)


class NERModel(Model):
    """
    Implements special functionality for NER models.
    """

    def __init__(self, helper, config, report=None):
        self.helper = helper
        self.config = config
        self.report = report

    def preprocess_sequence_data(self, examples):
        """Preprocess sequence data for the model.

        Args:
            examples: A list of vectorized input/output sequences.
        Returns:
            A new list of vectorized input/output pairs appropriate for the model.
        """
        raise NotImplementedError("Each Model must re-implement this method.")
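
    # A concrete subclass reshapes each example into fixed-size model inputs.
    # A minimal sketch of an override (hypothetical: `make_windows` and
    # `self.config.window_size` are assumptions, not part of this base class):
    #
    #     def preprocess_sequence_data(self, examples):
    #         return [(window, label)
    #                 for sentence, labels in examples
    #                 for window, label in make_windows(
    #                     sentence, labels, self.config.window_size)]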

    def consolidate_predictions(self, data_raw, data, preds):
        """
        Convert a sequence of predictions according to the batching
        process back into the original sequence.
        """
        raise NotImplementedError("Each Model must re-implement this method.")
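
    # The inverse of preprocess_sequence_data: regroup the flat predictions by
    # sentence. A minimal sketch (hypothetical; assumes one prediction per
    # token and that examples stay in their original order):
    #
    #     def consolidate_predictions(self, data_raw, data, preds):
    #         ret, i = [], 0
    #         for sentence, labels in data_raw:
    #             n = len(sentence)
    #             ret.append((sentence, labels, preds[i:i + n]))
    #             i += n
    #         return ret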

    def evaluate(self, sess, examples, examples_raw):
        """Evaluates model performance on @examples.

        This function uses the model to predict labels for @examples and
        constructs a confusion matrix.

        Args:
            sess: the current TensorFlow session.
            examples: A list of vectorized input/output pairs.
            examples_raw: A list of the original input/output sequence pairs.
        Returns:
            The token-level confusion matrix and the entity-level
            (precision, recall, F1) scores.
        """
        token_cm = ConfusionMatrix(labels=LBLS)

        correct_preds, total_correct, total_preds = 0., 0., 0.
        for _, labels, labels_ in self.output(sess, examples_raw, examples):
            for l, l_ in zip(labels, labels_):
                token_cm.update(l, l_)
            gold = set(get_chunks(labels))
            pred = set(get_chunks(labels_))
            correct_preds += len(gold.intersection(pred))
            total_preds += len(pred)
            total_correct += len(gold)

        p = correct_preds / total_preds if correct_preds > 0 else 0
        r = correct_preds / total_correct if correct_preds > 0 else 0
        f1 = 2 * p * r / (p + r) if correct_preds > 0 else 0
        return token_cm, (p, r, f1)
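
    # Worked example of the chunk-level scoring above, assuming get_chunks
    # encodes each entity as a (label, start, end) tuple:
    #
    #     gold = {("PER", 0, 2), ("LOC", 5, 6)}
    #     pred = {("PER", 0, 2), ("ORG", 5, 6)}
    #     # 1 correct chunk out of 2 predicted and 2 gold chunks:
    #     # p = 0.5, r = 0.5, f1 = 2 * p * r / (p + r) = 0.5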

    def run_epoch(self, sess, train_examples, dev_set, train_examples_raw, dev_set_raw):
        """Runs a single epoch of training, then evaluates on the dev set.

        Returns:
            The entity-level F1 score on the development set.
        """
        prog = Progbar(target=1 + int(len(train_examples) / self.config.batch_size))
        for i, batch in enumerate(minibatches(train_examples, self.config.batch_size)):
            loss = self.train_on_batch(sess, *batch)
            prog.update(i + 1, [("train loss", loss)])
            if self.report:
                self.report.log_train_loss(loss)
        print("")

        #logger.info("Evaluating on training data")
        #token_cm, entity_scores = self.evaluate(sess, train_examples, train_examples_raw)
        #logger.debug("Token-level confusion matrix:\n" + token_cm.as_table())
        #logger.debug("Token-level scores:\n" + token_cm.summary())
        #logger.info("Entity level P/R/F1: %.2f/%.2f/%.2f", *entity_scores)

        logger.info("Evaluating on development data")
        token_cm, entity_scores = self.evaluate(sess, dev_set, dev_set_raw)
        logger.debug("Token-level confusion matrix:\n" + token_cm.as_table())
        logger.debug("Token-level scores:\n" + token_cm.summary())
        logger.info("Entity level P/R/F1: %.2f/%.2f/%.2f", *entity_scores)

        f1 = entity_scores[-1]
        return f1

    def output(self, sess, inputs_raw, inputs=None):
        """
        Reports the output of the model on examples (uses helper to featurize each example).
        """
        if inputs is None:
            inputs = self.preprocess_sequence_data(self.helper.vectorize(inputs_raw))

        preds = []
        prog = Progbar(target=1 + int(len(inputs) / self.config.batch_size))
        for i, batch in enumerate(minibatches(inputs, self.config.batch_size, shuffle=False)):
            # Drop the gold labels (the second element of the batch) so that
            # predict_on_batch only receives model inputs.
            batch = batch[:1] + batch[2:]
            preds_ = self.predict_on_batch(sess, *batch)
            preds += list(preds_)
            prog.update(i + 1, [])
        return self.consolidate_predictions(inputs_raw, inputs, preds)
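
    # Typical inference use (a sketch; assumes a trained concrete subclass,
    # an open TensorFlow session, and raw sentences as token lists):
    #
    #     for sentence, _, predicted in model.output(session, raw_sentences):
    #         print(list(zip(sentence, [LBLS[l] for l in predicted])))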

    def fit(self, sess, saver, train_examples_raw, dev_set_raw):
        """Trains for config.n_epochs epochs, checkpointing whenever the
        development-set F1 improves.

        Returns:
            The best entity-level F1 achieved on the development set.
        """
        best_score = 0.

        train_examples = self.preprocess_sequence_data(train_examples_raw)
        dev_set = self.preprocess_sequence_data(dev_set_raw)

        for epoch in range(self.config.n_epochs):
            logger.info("Epoch %d out of %d", epoch + 1, self.config.n_epochs)
            score = self.run_epoch(sess, train_examples, dev_set, train_examples_raw, dev_set_raw)
            if score > best_score:
                best_score = score
                if saver:
                    logger.info("New best score! Saving model in %s", self.config.model_output)
                    saver.save(sess, self.config.model_output)
            print("")
            if self.report:
                self.report.log_epoch()
                self.report.save()
        return best_score
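
# A minimal end-to-end driver sketch (hypothetical: `WindowModel` stands in
# for an assumed concrete subclass, and `helper`, `config`, `train_raw`, and
# `dev_raw` must be constructed by the caller; this abstract base class
# cannot be trained directly):
#
#     with tf.Graph().as_default():
#         model = WindowModel(helper, config)
#         saver = tf.train.Saver()
#         with tf.Session() as session:
#             session.run(tf.global_variables_initializer())
#             best_f1 = model.fit(session, saver, train_raw, dev_raw)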