-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
99 lines (83 loc) · 2.72 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import numpy as np
import matplotlib.pyplot as plt
def accuracy(y_pred, y_true):
    """
    Compute the accuracy of the provided predictions, i.e. the fraction of
    positions where the prediction equals the true value.

    y_pred (n) : prediction
    y_true (n) : true value to predict
    Both arguments may be any array-like (plain lists included). They are
    flattened, so an (n, 1) column vector is accepted alongside an (n,) one.
    returns : accuracy in [0, 1] (nan, with a numpy RuntimeWarning, if n == 0)
    raises KeyError : if prediction and truth do not have the same size
    """
    # Convert first: with plain lists ``y_pred == y_true`` compares the two
    # lists as a whole and returns a single bool instead of element-wise.
    # Flattening also prevents (n,) vs (n, 1) from broadcasting to (n, n).
    y_pred = np.asarray(y_pred).ravel()
    y_true = np.asarray(y_true).ravel()
    if y_pred.size != y_true.size:
        # KeyError (rather than ValueError) is kept for existing callers.
        raise KeyError("prediction and truth must be the same size")
    return np.sum(y_pred == y_true) / y_true.size
def error(y_pred, y_true):
    """
    Compute the error of the provided predictions, i.e. the fraction of
    positions where the prediction differs from the true value.

    y_pred (n) : prediction
    y_true (n) : true value to predict
    Both arguments may be any array-like (plain lists included). They are
    flattened, so an (n, 1) column vector is accepted alongside an (n,) one.
    returns : error rate in [0, 1] (nan, with a numpy RuntimeWarning, if n == 0)
    raises KeyError : if prediction and truth do not have the same size
    """
    # Convert first: with plain lists ``y_pred != y_true`` compares the two
    # lists as a whole and returns a single bool instead of element-wise.
    # Flattening also prevents (n,) vs (n, 1) from broadcasting to (n, n).
    y_pred = np.asarray(y_pred).ravel()
    y_true = np.asarray(y_true).ravel()
    if y_pred.size != y_true.size:
        # KeyError (rather than ValueError) is kept for existing callers.
        raise KeyError("prediction and truth must be the same size")
    return np.sum(y_pred != y_true) / y_true.size
def plot_loss(loss, graph_title=None):
    """
    A useful function to plot curves, each one against its index 0..n-1.

    :param loss: sequence (n) of values, or mapping {label: sequence (n)} to
        draw one labelled curve per key (a legend is then added)
    :param graph_title: (str) optional title of the graph
    :return: the list of Line2D returned by the last ``plt.plot`` call
        (None when ``loss`` is an empty mapping)
    """
    from collections.abc import Mapping

    if isinstance(loss, Mapping):
        lines = None
        for key, vals in loss.items():
            # ``plt.plot`` forwards keywords to Line2D, which has no
            # ``legend`` property: label each line and build the legend below.
            lines = plt.plot(range(len(vals)), vals, label=key)
        if lines is not None:
            plt.legend()
    else:
        lines = plt.plot(range(len(loss)), loss)
    # Line2D has no ``title`` property either; the title belongs to the axes.
    if graph_title is not None:
        plt.title(graph_title)
    return lines
def compute_accuracies(wts, X, y_true, average=True):
    """
    Compute the accuracy of every iterate of an online algorithm.

    Sample x is predicted as sign(x . w). A score of exactly 0 gives a
    prediction of 0, which counts as wrong unless the label itself is 0.

    wts (txm) : weights at each time step of the algo; any array-like is
        accepted and a single (m) vector is treated as t = 1
    X (nxm) : data to be predicted (anything exposing ``.dot`` and ``.shape``)
    y_true (n) : true value to predict; flattened, so an (n, 1) column works
    average : if True (default), step k is scored with the running mean
        (wts[0] + ... + wts[k]) / (k + 1) instead of the raw weights wts[k]
    returns : list (t) of accuracies, one per time step
    raises KeyError : if X and y_true do not hold the same number of samples
    """
    wts = np.atleast_2d(wts)
    # Flatten so an (n, 1) label column does not broadcast against (n,).
    y_true = np.asarray(y_true).ravel()
    n_samples = y_true.size
    if X.shape[0] != n_samples:
        raise KeyError("prediction and truth must be the same size")
    if average:
        # Online mean of the weights: row k is the mean of rows 0..k.
        steps = np.arange(1, wts.shape[0] + 1)[:, np.newaxis]
        wts = np.cumsum(wts, axis=0) / steps
    accs = []
    for w in wts:
        y_pred = np.sign(X.dot(w))
        accs.append(np.sum(y_pred == y_true) / n_samples)
    return accs
def rate(wts, X, y):
    """
    Score each raw (non time-averaged) weight vector by the fraction of test
    samples it classifies with a strictly positive margin y * (x . w) > 0.

    wts : weight vectors provided during the online fitting, iterated in order
    X : test data
    y : test labels
    returns : list with one accuracy per weight vector
    """
    return [np.mean(y * X.dot(weights) > 0) for weights in wts]
def compute_errors(wts, X, y_true, average=True):
    """
    Compute the error rate (1 - accuracy) of every iterate of an online
    algorithm, i.e. the error with respect to time.

    Sample x is predicted as sign(x . w). A score of exactly 0 gives a
    prediction of 0, which counts as wrong unless the label itself is 0.

    wts (txm) : weights at each time step of the algo; any array-like is
        accepted and a single (m) vector is treated as t = 1
    X (nxm) : data to be predicted (anything exposing ``.dot`` and ``.shape``)
    y_true (n) : true value to predict; flattened, so an (n, 1) column works
    average : if True (default), step k is scored with the running mean
        (wts[0] + ... + wts[k]) / (k + 1) instead of the raw weights wts[k]
    returns : list (t) of error rates, one per time step
    raises KeyError : if X and y_true do not hold the same number of samples
    """
    wts = np.atleast_2d(wts)
    # Flatten so an (n, 1) label column does not broadcast against (n,).
    y_true = np.asarray(y_true).ravel()
    n_samples = y_true.size
    if X.shape[0] != n_samples:
        raise KeyError("prediction and truth must be the same size")
    if average:
        # Online mean of the weights: row k is the mean of rows 0..k.
        steps = np.arange(1, wts.shape[0] + 1)[:, np.newaxis]
        wts = np.cumsum(wts, axis=0) / steps
    errs = []
    for w in wts:
        y_pred = np.sign(X.dot(w))
        # Computed as 1 - accuracy to keep the exact values of the original.
        errs.append(1 - np.sum(y_pred == y_true) / n_samples)
    return errs