-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path: plot_multitask_training_log.py
83 lines (79 loc) · 2.6 KB
/
plot_multitask_training_log.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import pandas as pd
from matplotlib import pyplot as plt
import os
import numpy as np
import typer
# Minimum number of training steps a run must have reached (last logged step)
# to be included in the paper figure.
_PAPER_MIN_STEPS = 80000


def _paper_label(log):
    """Return the publication legend label for the log file name ``log``.

    The label is chosen by substring match, checked in this order:
    "baseline", "E2E", "N3D".

    Raises:
        NotImplementedError: if ``log`` contains none of those substrings.
    """
    if "baseline" in log:
        return "Task-specific features"
    if "E2E" in log:
        return "End-to-end features"
    if "N3D" in log:
        return "World-Model features"
    raise NotImplementedError(f"no paper label known for log file '{log}'")


def _task_axis(log, seg_ax, depth_ax):
    """Return the subplot the log ``log`` belongs to, setting its title/y-label.

    Names containing "segmenter" go to ``seg_ax``; otherwise names containing
    "depth" go to ``depth_ax``.

    Raises:
        NotImplementedError: if the log name matches neither task.
    """
    if "segmenter" in log:
        seg_ax.set_title("Segmentation Error")
        seg_ax.set_ylabel("Segmentation BCE Error")
        return seg_ax
    if "depth" in log:
        depth_ax.set_title("Depth Error")
        depth_ax.set_ylabel("Depth Mean Square Proportional Error")
        return depth_ax
    raise NotImplementedError(
        f"cannot tell which task log file '{log}' belongs to")


def _draw_training_log(logdir, paper):
    """(Re)draw the "training log" figure from every ``*.csv`` in ``logdir``.

    Each CSV is read with pandas and its "test_error" column is plotted
    against its "step" column, on the top (segmentation) or bottom (depth)
    subplot. A faint horizontal line marks each run's lowest test error.
    """
    plt.figure("training log")
    plt.clf()
    _, (ax1, ax2) = plt.subplots(2, 1, sharex=True, num="training log")
    for log in sorted(os.listdir(logdir)):
        if not log.endswith(".csv"):
            continue
        # Sequence runs never appear in the paper figure. Skip them before
        # labelling, so an unlabelled sequence log cannot abort the plot.
        if paper and "sequence" in log:
            continue
        name = _paper_label(log) if paper else log
        ax = _task_axis(log, ax1, ax2)
        data = pd.read_csv(os.path.join(logdir, log))
        x = data["step"].to_numpy()
        y = data["test_error"].to_numpy()
        if len(x) == 0:
            # Header-only CSV: nothing to plot, and x[-1] / np.min(y) would
            # raise on an empty array.
            continue
        if paper and x[-1] < _PAPER_MIN_STEPS:
            continue
        if paper and "baseline" in log:
            style, color = "dashed", "k"
        else:
            style, color = "solid", None  # None -> matplotlib color cycle
        line, = ax.plot(x, y, linestyle=style, color=color, label=name)
        # Faint reference line at this run's best (lowest) test error.
        ax.axhline(np.min(y), alpha=0.3, linewidth=1, color=line.get_color())
    # The x axis is shared, so label it once, on the bottom subplot.
    ax2.set_xlabel("training steps")
    ax1.axhline(0, linewidth=1, color='k')
    ax2.axhline(0, linewidth=1, color='k')
    if paper:
        ax1.set_ylim((0, 0.05))
        ax2.set_ylim((0, 0.1))
    # The paper figure has a single legend (top subplot); otherwise each
    # subplot gets its own. Skip axes with no labelled lines to avoid
    # matplotlib's "No artists with labels found" warning.
    for ax in ((ax1,) if paper else (ax1, ax2)):
        if ax.get_legend_handles_labels()[0]:
            ax.legend()


def main(refresh : bool = True, logdir : str = None, paper : bool = False):
    """Plot multitask test-error curves from the CSV logs in ``logdir``.

    Args:
        refresh: if True, redraw the figure every 10 seconds in an endless
            loop (interactive mode); if False, draw once and call plt.show().
        logdir: directory holding the ``*.csv`` logs; defaults to
            ``~/navdreams_data/results/logs/multitask``.
        paper: draw the publication figure. This uses paper legend labels,
            draws the baseline dashed black, and skips "sequence" runs and
            runs shorter than ``_PAPER_MIN_STEPS`` steps. It also applies
            fixed y-limits and shows a single legend.

    Raises:
        NotImplementedError: if a log name cannot be mapped to a task
            (segmenter/depth) or, in paper mode, to a paper label.
    """
    if logdir is None:
        logdir = os.path.expanduser("~/navdreams_data/results/logs/multitask")
    plt.close('all')
    while True:
        _draw_training_log(logdir, paper)
        if not refresh:
            plt.show()
            break
        plt.ion()
        plt.pause(10.)


if __name__ == "__main__":
    typer.run(main)