-
Notifications
You must be signed in to change notification settings - Fork 35
/
xer.py
75 lines (65 loc) · 2.3 KB
/
xer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import logging
# Configure the root logger once at import time so that every message is
# rendered as "LEVEL(file:line): message".
logging.basicConfig(
    format='%(levelname)s(%(filename)s:%(lineno)d): %(message)s',
)
def levenshtein(u, v):
    """Return the edit distance between two sequences and its breakdown.

    The distance is computed row by row, keeping only the previous and the
    current row of the dynamic-programming table.  Alongside each cost the
    number of operations of each type on the chosen path is tracked.

    Args:
        u: Source sequence (e.g. a string or a list of words).
        v: Target sequence of the same kind.

    Returns:
        A pair ``(distance, (substitutions, deletions, insertions))``
        describing how ``u`` is turned into ``v``.  When several operations
        reach the same minimal cost, a substitution/match is preferred over
        a deletion, and a deletion over an insertion.
    """
    width = len(v)
    # Row 0: the empty prefix of u becomes v[:j] through j insertions.
    cost_row = list(range(width + 1))
    ops_row = [(0, 0, j) for j in range(width + 1)]  # (SUB, DEL, INS)

    for row, u_item in enumerate(u, start=1):
        above_cost, above_ops = cost_row, ops_row
        # Column 0: u[:row] becomes the empty sequence through row deletions.
        cost_row = [row] + [None] * width
        ops_row = [(0, row, 0)] + [None] * width

        for col, v_item in enumerate(v, start=1):
            mismatch = int(u_item != v_item)
            via_sub = above_cost[col - 1] + mismatch
            via_del = above_cost[col] + 1
            via_ins = cost_row[col - 1] + 1
            best = min(via_sub, via_del, via_ins)
            cost_row[col] = best

            # The order of these checks fixes the tie-breaking rule.
            if best == via_sub:
                subs, dels, inss = above_ops[col - 1]
                ops_row[col] = (subs + mismatch, dels, inss)
            elif best == via_del:
                subs, dels, inss = above_ops[col]
                ops_row[col] = (subs, dels + 1, inss)
            else:
                subs, dels, inss = ops_row[col - 1]
                ops_row[col] = (subs, dels, inss + 1)

    return cost_row[width], ops_row[width]
def load_file(fname, encoding):
    """Read a text file and return its lines without line terminators.

    Args:
        fname: Path of the file to read.
        encoding: Name of the text encoding used to decode the file.

    Returns:
        A list of ``str``, one entry per line, with the trailing newline and
        any carriage-return characters in front of it removed.

    Raises:
        SystemExit: With exit status 1, after logging an error, when the
            file cannot be opened, read or decoded, or the encoding name is
            unknown.
    """
    try:
        # Decode while reading.  On Python 3 a text-mode file already yields
        # ``str`` lines, so calling ``.decode(encoding)`` on them raised
        # AttributeError and every call ended in the error path below.
        # newline='\n' splits on '\n' only and leaves '\r' untranslated, so a
        # CRLF terminator is removed by the explicit rstrip calls.
        with open(fname, 'r', encoding=encoding, newline='\n') as f:
            data = [line.rstrip('\n').rstrip('\r') for line in f]
    except (OSError, UnicodeError, LookupError):
        # Only I/O, decoding and unknown-encoding failures are treated as
        # "cannot read the file"; other exceptions indicate programming
        # errors and are allowed to propagate.
        logging.error('Error reading file "%s"', fname)
        raise SystemExit(1)
    return data
def cer_function(ref, hyp):
    """Compute the character error rate (CER) of hypotheses against references.

    Each ``ref[n]`` is aligned with ``hyp[n]`` using ``levenshtein`` and the
    substitution, deletion and insertion counts are summed over all pairs.
    The four totals (substitutions, deletions, insertions, reference length)
    are printed to stdout, in that order, before returning.

    Args:
        ref: Sequence of reference strings.
        hyp: Sequence of hypothesis strings, indexed in parallel with
            ``ref``.  It must have at least as many entries as ``ref``; any
            extra entries are ignored.

    Returns:
        ``(substitutions + deletions + insertions) / reference length`` as a
        float, or ``0.0`` when the references are empty and no edits were
        needed.

    Raises:
        IndexError: If ``hyp`` has fewer entries than ``ref``.
        ZeroDivisionError: If the references have total length zero but the
            hypotheses are not empty, so the rate is undefined.
    """
    total_sub = total_del = total_ins = total_len = 0
    for n, ref_line in enumerate(ref):
        # levenshtein reports its counts in (SUB, DEL, INS) order.
        _, (n_sub, n_del, n_ins) = levenshtein(ref_line, hyp[n])
        total_sub += n_sub
        total_del += n_del
        total_ins += n_ins
        total_len += len(ref_line)
    # Word- and sentence-level statistics used to be accumulated here as
    # well, but they were never read; the extra alignment pass was dropped.

    print(total_sub, total_del, total_ins, total_len)

    total_errors = total_sub + total_del + total_ins
    if total_len == 0:
        if total_errors == 0:
            # Nothing to recognise and nothing hypothesised: no errors.
            return 0.0
        raise ZeroDivisionError(
            'character error rate is undefined for empty references')
    return total_errors / total_len
if __name__ == '__main__':
    # Small demo: score one hypothesis sentence against its reference.
    ref = ['天然气用户为优先允许限制类和禁止类']
    hyp = ['天然气用户为优先允许限制类和禁止量内']
    # Keep the result in its own name; assigning it to ``cer_function``
    # would replace the function with a float for the rest of the module.
    cer = cer_function(ref, hyp)
    print(cer)