-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathWimpPlotClass.py
153 lines (115 loc) · 5.33 KB
/
WimpPlotClass.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import numpy as np
import matplotlib.pyplot as plt
import DataClasses as dc
from buildDataBase import buildDataBase
# Folder prefix prepended to every file name written by WimpPlot.savePlot.
PATH_FIGURE_FOLDER = "./plots/"
# Pixel coordinates (x, y) of the most recent click seen by WimpPlot.onclick;
# stays None until the first click.
PREVIOUS_CLICK = None
class WimpPlot:
    """Log-log plot of spin-independent WIMP-nucleon cross-section curves.

    On construction the figure is styled, every entry of the database is
    drawn onto it, the region above the combined (lowest) solid curve is
    optionally shaded, the figure is optionally saved to disk and finally
    optionally shown in an interactive window.

    Parameters
    ----------
    x_limits : tuple
        (x_min, x_max) range of the WIMP-mass axis, in GeV/c^2.
    y_limits : tuple
        (y_min, y_max) range of the cross-section axis, in cm^2.
    database : dict or None
        Mapping whose values are DataClasses objects exposing
        ``plot(fig, ax)``. When None, ``buildDataBase()`` builds it.
    show_excludedregion : bool
        Shade the parameter space above the combined upper limit.
    show_plot : bool
        Open a window showing the figure (blocks until it is closed).
    save_plotname : str or None
        If a non-empty string, the figure is written via :meth:`savePlot`.
    """

    def __init__(self,
                 x_limits=(0.1, 20),
                 y_limits=(1e-46, 1e-34),
                 database=None,
                 show_excludedregion=True,
                 show_plot=True,
                 save_plotname=None,
                 ):
        self.x_limits = x_limits  # tuple (x_min, x_max)
        self.y_limits = y_limits  # tuple (y_min, y_max)
        self.DB = database  # dictionary containing DataClass objects

        ## ===== Define the plotting style options =====
        # NOTE: these change matplotlib's global rcParams, so they also affect
        # any figure created later in the same process.
        plt.rcParams.update({'font.size': 18})  # global font size
        plt.rc('text', usetex=True)  # draw text with LaTeX (needs a LaTeX install)
        plt.rc('font', family='serif')  # typeface

        self.fig, self.ax = plt.subplots(1, 1, figsize=(9, 7))
        self.ax.set_xscale('log')
        self.ax.set_yscale('log')
        self.ax.set_xlim(x_limits)
        self.ax.set_ylim(y_limits)
        self.ax.set_xlabel('WIMP mass [GeV/c$^{2}$]')
        self.ax.set_ylabel(r'SI WIMP-nucleon cross section $\sigma_{\chi n}^\mathrm{SI}$ [cm$^{2}$]')
        # Mirror ticks on the opposite spines: major and minor on top,
        # major only on the right.
        self.ax.xaxis.set_tick_params(top=True, which='both')
        self.ax.yaxis.set_tick_params(right=True, which='major')

        if self.DB is None:
            self.DB = buildDataBase()
        self.addCurves(show_excludedregion)

        # Save BEFORE showing: plt.show() blocks until the window is closed,
        # and resizing the window meanwhile would change the saved figure.
        if isinstance(save_plotname, str) and save_plotname:
            print('saving...')
            self.savePlot(save_plotname)
            print('done')
        if show_plot:
            self.showPlot()

    def getExcludedRegion(self):
        """Compute the combined upper limit of all solid (non-projection) curves.

        Every ``dc.Curve`` in the database whose ``style`` is ``'-'`` or
        ``'solid'`` is evaluated through its ``interpolator`` on 1000
        log-spaced masses spanning ``self.x_limits``. The point-wise minimum
        across those curves is the combined limit.

        Returns
        -------
        tuple
            ``(masses, upper_limit)``. ``upper_limit`` is an empty list when
            no eligible curve exists; a warning is printed in that case.
        """
        x_val_arr = np.logspace(np.log10(self.x_limits[0]),
                                np.log10(self.x_limits[1]),
                                num=1000)
        interp_array = [
            item.interpolator(x_val_arr)
            for item in self.DB.values()
            # Exact type check: plain Curve objects only. Non-solid line
            # styles denote projections and are excluded.
            if type(item) is dc.Curve and item.style in ('-', 'solid')
        ]
        if not interp_array:
            print('Warning: no available curves (not projection) for computing excluded region.')
            return (x_val_arr, [])
        # NOTE(review): np.min propagates NaN. If an interpolator returns NaN
        # outside its data range, the shaded region will have gaps there.
        exp_upper_lim = np.min(interp_array, axis=0)
        return (x_val_arr, exp_upper_lim)

    def addCurves(self, excludedRegion=True):
        """Draw every database entry and optionally shade the excluded region.

        Each value of ``self.DB`` is drawn with ``item.plot(self.fig, self.ax)``.
        When *excludedRegion* is true, the area between the combined limit
        from :meth:`getExcludedRegion` and the top of the y range is filled
        behind the curves (``zorder=0``).
        """
        for item in self.DB.values():
            item.plot(self.fig, self.ax)
        if not excludedRegion:
            return
        x_excluded, y_excluded = self.getExcludedRegion()
        if len(y_excluded) > 0:
            self.ax.fill_between(x_excluded, y_excluded, self.y_limits[1],
                                 color='#aaffc3',
                                 zorder=0,
                                 alpha=0.5,
                                 lw=0)

    def showPlot(self):
        """Show the figure in a blocking window and report mouse clicks.

        Interactive mode is switched OFF so that ``plt.show()`` blocks until
        the window is closed. While it is open, mouse presses are forwarded to
        :meth:`onclick`. The handler is always disconnected afterwards.
        """
        cid = self.fig.canvas.mpl_connect('button_press_event', self.onclick)
        try:
            plt.ioff()
            plt.show()
        finally:
            self.fig.canvas.mpl_disconnect(cid)

    def onclick(self, event):
        """Print where the mouse was clicked.

        Also prints the angle, in degrees and pixel coordinates, of the segment
        joining this click to the previous one. This helps choose a rotation
        for text labels placed along a curve.
        """
        global PREVIOUS_CLICK
        # xdata/ydata are None when the click lands outside the axes.
        xdata = 'None' if event.xdata is None else '%g' % event.xdata
        ydata = 'None' if event.ydata is None else '%g' % event.ydata
        print('%s click: button=%d, x=%d, y=%d, xdata=%s, ydata=%s' %
              ('double' if event.dblclick else 'single', event.button,
               event.x, event.y, xdata, ydata))

        if PREVIOUS_CLICK is not None:
            dx = event.x - PREVIOUS_CLICK[0]
            dy = event.y - PREVIOUS_CLICK[1]
            # arctan (not arctan2) keeps the angle within (-90, 90] deg, which
            # is the useful range for a text rotation.
            if dx != 0:
                print('Rotation angle: %.1f deg' % np.rad2deg(np.arctan(dy / dx)))
            elif dy != 0:
                print('Rotation angle: %.1f deg' % 90.0)
            # dx == dy == 0 (e.g. the two presses of a double click): no angle.
        PREVIOUS_CLICK = (event.x, event.y)

    def savePlot(self, plotname):
        """Save the figure to ``PATH_FIGURE_FOLDER + plotname``.

        A ``.pdf`` extension is appended unless *plotname* already ends in
        ``.pdf``, ``.png`` or ``.svg``. If the figure folder does not exist,
        the file is written to the current working directory instead.
        """
        if not plotname.endswith(('.pdf', '.png', '.svg')):
            plotname = plotname + '.pdf'
        filename = PATH_FIGURE_FOLDER + plotname
        try:
            self.fig.savefig(filename, bbox_inches='tight')
        except FileNotFoundError:
            # Figure folder missing: fall back to the working directory.
            # The original only renamed the file here and never saved it.
            filename = plotname
            self.fig.savefig(filename, bbox_inches='tight')
        print(filename + " saved.")