-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathplot_SHD_results.py
46 lines (42 loc) · 1.28 KB
/
plot_SHD_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import numpy as np
import matplotlib.pyplot as plt
import sys
# Plot the per-epoch training log stored in "<prefix>_results.txt".
#
# Usage:
#     python plot_SHD_results.py <prefix> [output.png]
#
# The figure is saved to the optional second argument (default "test2.png")
# and then displayed. Column meanings below are taken from the labels this
# script prints/used to print; columns not referenced here are not plotted.
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} <results-prefix> [output.png]")
prefix = sys.argv[1]
out_path = sys.argv[2] if len(sys.argv) > 2 else "test2.png"

# ndmin=2 keeps a (rows, columns) layout even when the file has a single
# row; without it np.loadtxt returns a 1-D array and res[:, k] fails.
res = np.loadtxt(prefix + "_results.txt", dtype=np.float32, ndmin=2)
print(res.shape)

fig = plt.figure(figsize=(20, 15))
fig.canvas.manager.set_window_title(prefix)

# Panel 1 (top-left): error = 1 - performance, train (col 1) and eval (col 3).
plt.subplot(2, 2, 1)
plt.plot(1 - res[:, 1])
plt.plot(1 - res[:, 3])
# Optional log-scale view of the error curves:
# plt.yscale("log")
# plt.xscale("log")
# plt.ylim(0.0001, 1)

# Panel 3 (bottom-left): loss, train (col 2) and eval (col 4), log y-axis.
plt.subplot(2, 2, 3)
plt.plot(res[:, 2])
plt.plot(res[:, 4])
plt.yscale("log")
# plt.xscale("log")
# plt.ylim(0.001, 20)
plt.xlabel("epochs")

# Panel 2 (top-right): col 9 with col 10 as its error bars, plus cols 11 and
# 12. Previously labelled "sNSum" -- presumably spike-count statistics; TODO
# confirm against the training script that writes this file.
plt.subplot(2, 2, 2)
# NOTE(review): this draws the same series that errorbar() draws below; kept
# so the colour assignment of the following lines does not change.
plt.plot(res[:, 9])
plt.errorbar(np.arange(res.shape[0]), res[:, 9], res[:, 10])
plt.plot(res[:, 11])
plt.plot(res[:, 12])

# Panel 4 (bottom-right): col 13, previously labelled "n_rewire".
plt.subplot(2, 2, 4)
plt.plot(res[:, 13])
plt.xlabel("epochs")

# Summary over the last (up to) 10 logged epochs.
last = res[-10:, :]
avg = np.mean(last, axis=0)
print(last[:, 1])
print(last[:, 3])
print(last[:, 2])
print(last[:, 4])
print(f"Performance: train {avg[1]}, eval {avg[3]}")
print(f"Loss: train {avg[2]}, eval {avg[4]}")

# Save BEFORE show(): once show() returns the window has been closed and the
# figure torn down, so saving afterwards writes an empty image.
fig.savefig(out_path, dpi=300)
plt.show()