-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathplot_heatmap.py
40 lines (34 loc) · 1.02 KB
/
plot_heatmap.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# Compute Plots for different purposes
#
# This script renders a fixed 4x4 matrix of activations as a heatmap
# (rows labelled "Rule", columns labelled "Module") and writes it to
# Main_Plots/collapse_eg.pdf.
import os
import seaborn as sns
import matplotlib as mpl
from matplotlib import font_manager
from scipy.optimize import linear_sum_assignment

# `font_manager._rebuild()` is a private matplotlib helper that is not
# available in every matplotlib release. Call it only when it exists so the
# script does not die with AttributeError before any plotting happens.
_rebuild_font_cache = getattr(font_manager, "_rebuild", None)
if callable(_rebuild_font_cache):
    _rebuild_font_cache()

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import pickle

EPS = 1e-8

# Set Plotting Variables
# NOTE(review): the return value of color_palette() is discarded, so this
# call does not change the active palette; kept as in the original.
sns.color_palette("dark", as_cmap=True)
sns.set(style="darkgrid")  # , font_scale=1.1)
font = {
    'family': 'Open Sans',
    'size': 32,
}
# Applied after sns.set() so the font settings are not overwritten by it.
mpl.rc('font', **font)

# Font size used for the axis labels and the title.
size = 16

# Output location of the rendered figure.
output_path = 'Main_Plots/collapse_eg.pdf'

# Hard-coded example matrix; each row sums to 1.
activations = np.array([
    [0.95, 0., 0., 0.05],
    [0.0, 0., 0., 1.0],
    [0.35, 0.1, 0., 0.55],
    [0., 1., 0., 0.],
])

ax = sns.heatmap(
    activations,
    cmap="YlGnBu",
    linewidths=2.,
    xticklabels=range(1, 5),
    yticklabels=range(1, 5),
    vmin=0.,
    vmax=1.,
)
ax.set_xlabel('Module', fontsize=size)
ax.set_ylabel('Rule', fontsize=size)
ax.set_title('Ground Truth Rules: 4 | Modules: 4', fontsize=size)

# savefig() does not create missing directories; make sure the target
# folder exists so a fresh checkout does not fail with FileNotFoundError.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, bbox_inches='tight')
plt.close()