-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathMakeMIEvolutionPlot.py
72 lines (55 loc) · 2.56 KB
/
MakeMIEvolutionPlot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os, pickle
from argparse import ArgumentParser
from plotting.PerformancePlotter import PerformancePlotter
def _load_dict(path):
    """Unpickle and return the object stored in the file at *path*.

    A missing file is not an error: a notice is printed and an empty dict is
    returned, so callers can test the result for truthiness. Any other failure
    (permission problems, a corrupt pickle, ...) propagates unchanged.

    NOTE(review): ``pickle.load`` can execute arbitrary code; this assumes the
    files come from this project's own training runs.
    """
    try:
        with open(path, 'rb') as handle:
            return pickle.load(handle)
    except FileNotFoundError:
        print("file {} not found, ignoring".format(path))
    return {}
def _load_metadata(path, section):
    """Read one section of an INI-style config file into a plain dict.

    Args:
        path: path of the config file to read.
        section: name of the section to extract.

    Returns:
        dict mapping option names to their string values. ConfigParser's
        default ``optionxform`` lower-cases option names, and options from the
        file's ``[DEFAULT]`` section are included.

    Raises:
        FileNotFoundError (or another OSError): if *path* cannot be opened.
            ``ConfigParser.read`` would silently skip such a file, which used to
            surface only as a confusing ``KeyError`` for *section*.
        KeyError: if *section* is not present in the file.
    """
    from configparser import ConfigParser
    gconfig = ConfigParser()
    # Open explicitly so a missing/unreadable file raises its real OS error
    # instead of being ignored by ConfigParser.read().
    with open(path) as infile:
        gconfig.read_file(infile, source = path)
    return dict(gconfig[section])
def MakeMIEvolutionPlot(plotdir, workdirs):
    """Plot how several MI estimates evolve over training, for a set of runs.

    For every run directory in *workdirs*:

    - the per-minibatch traces are read from ``training_evolution.pkl``;
    - one curve is added per MI estimator (Tukey, C-A-R, Bendat-Piersol,
      MINE), with the line style chosen by estimator and the colour by run.

    Each run is labelled by its ``lambda`` value. That value is read from
    ``anadict.pkl``, or if that is missing or empty, from the
    ``[AdversarialEnvironment]`` section of ``meta.conf``.

    The figure is written to ``<plotdir>/MI_evolution.pdf`` via
    ``PerformancePlotter._simple_plot``.

    Args:
        plotdir: directory the PDF is written into.
        workdirs: run directories to compare. Only as many runs as there are
            colours in the palette (three) are plotted; extra directories are
            reported and ignored. Runs without training traces are reported by
            ``_load_dict`` and skipped.
    """
    traces_to_plot = ["binned_MI_tukey", "binned_MI_cellucci", "binned_MI_bendat_piersol", "neural_MI"]
    trace_labels = ["Tukey", "C-A-R", "Bendat-Piersol", "MINE"]
    style_library = ["--", "-.", "dotted", "-"]
    color_library = ['orange', 'royalblue', 'green']

    plotdata_x = []
    plotdata_y = []
    styles = []
    colors = []
    style_labels = {}  # line style -> estimator name
    color_labels = {}  # colour -> lambda label of the run

    workdirs = list(workdirs)
    # zip() below pairs each run with one colour; runs beyond the palette
    # would otherwise be dropped without any notice.
    for workdir in workdirs[len(color_library):]:
        print("no colour left for {}, ignoring".format(workdir))

    # load the traces as well as the metadata (lambda of the run)
    for workdir, cur_color in zip(workdirs, color_library):
        tracedict = _load_dict(os.path.join(workdir, "training_evolution.pkl"))
        if not tracedict:
            # _load_dict already announced it is ignoring the missing file;
            # skip the run instead of failing on the "batch" lookup below
            continue

        anadict = _load_dict(os.path.join(workdir, "anadict.pkl"))
        if not anadict:
            # did not find it in this format, look into the metadata directly
            anadict = _load_metadata(os.path.join(workdir, "meta.conf"), section = "AdversarialEnvironment")

        x_data = tracedict["batch"]
        for trace, trace_label, style in zip(traces_to_plot, trace_labels, style_library):
            plotdata_x.append(x_data)
            plotdata_y.append(tracedict[trace])
            styles.append(style)
            colors.append(cur_color)
            style_labels[style] = trace_label

        cur_lambda = anadict["lambda"]
        # raw string: "\l" and "\m" are invalid escape sequences in a normal
        # literal (DeprecationWarning; SyntaxWarning since Python 3.12)
        color_labels[cur_color] = r'$\lambda_{{\mathrm{{MIND}}}} = {}$'.format(cur_lambda)

    outfile_path = os.path.join(plotdir, "MI_evolution.pdf")
    PerformancePlotter._simple_plot(plotdata_x, plotdata_y, colors, styles, style_labels, color_labels, outfile_path, xlabel = "minibatch", ylabel = r'$\widehat{\mathrm{MI}}_{\mathrm{bkg}}(\hat{y}, m_{bb})$')
if __name__ == "__main__":
    # Command-line entry point. Both options are mandatory: without them
    # MakeMIEvolutionPlot would receive None and fail with an opaque TypeError.
    parser = ArgumentParser(description = "show evolution of MI as the training progresses")
    parser.add_argument("--plotdir", action = "store", required = True,
                        help = "directory to write MI_evolution.pdf into")
    parser.add_argument("--workdirs", nargs = '+', action = "store", required = True,
                        help = "training run directories to compare (at most three are plotted)")
    args = vars(parser.parse_args())
    MakeMIEvolutionPlot(**args)