forked from IvanMikhailovIMCRAS/spectrum
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoutput.py
155 lines (131 loc) · 4.8 KB
/
output.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import numpy as np
import matplotlib.pyplot as plt
from miscellaneous import voigt
# Font sizes handed to ``plt.rc`` when styling figures.
# NOTE(review): SMALL_SIZE is not referenced in the visible part of this
# module - confirm it is still needed elsewhere.
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE = 16, 18, 22
def show_spectra(spectra, save_path='', wavenumbers=None):
    """
    Plot a collection of spectra on one canvas, one colour per class.

    :param spectra: iterable of Spectrum - the spectra to be plotted. Each item
        is read through ``clss``, ``wavenums`` and ``data``; ``range`` is
        called on it when ``wavenumbers`` is given.
    :param save_path: str - filename. If specified the figure is saved there
        (dpi=600) and closed, and the view window of pyplot isn't shown.
    :param wavenumbers: (float, float) - limits of the spectra; each spectrum
        is cropped with ``spc.range(*wavenumbers)`` before it is drawn.

    Classes are sorted and mapped onto evenly spaced colours of the
    ``rainbow`` colormap. A legend is drawn only when there is more than one
    class. The font sizes are applied with ``plt.rc`` and therefore remain in
    effect for figures created afterwards.
    """
    # Materialize once: the class scan and the plotting loop both iterate,
    # which would silently exhaust a generator/iterator after the first pass
    # (and ``not array`` is ambiguous for numpy arrays of spectra).
    spectra = list(spectra)
    if not spectra:
        return
    # Apply fonts *before* creating the figure: tick labels are created with
    # the rcParams in effect when the axes is built, so setting them
    # afterwards leaves the first figure with default-sized ticks.
    plt.rc('font', size=MEDIUM_SIZE)  # controls default text sizes
    plt.rc('axes', titlesize=BIGGER_SIZE)  # fontsize of the axes title
    plt.rc('axes', labelsize=BIGGER_SIZE)  # fontsize of the x and y labels
    plt.rc('xtick', labelsize=MEDIUM_SIZE)  # fontsize of the tick labels
    plt.rc('ytick', labelsize=MEDIUM_SIZE)  # fontsize of the tick labels
    plt.rc('legend', fontsize=BIGGER_SIZE)  # legend fontsize
    plt.rc('figure', titlesize=MEDIUM_SIZE)  # fontsize of the figure title

    classes = sorted({spc.clss for spc in spectra})
    palette = plt.cm.rainbow(np.linspace(0, 1, len(classes)))
    colors = dict(zip(classes, palette))
    plt.figure(figsize=(20, 8))
    # Legend handle per class (the last line drawn for it), kept in sorted
    # class order so the legend is ordered like ``classes``.
    handles = dict.fromkeys(classes)
    for spc in spectra:
        if wavenumbers:
            spc = spc.range(*wavenumbers)
        line, = plt.plot(spc.wavenums, spc.data,
                         c=colors[spc.clss], linewidth=0.5)
        handles[spc.clss] = line
    if len(handles) > 1:
        plt.legend(list(handles.values()), list(handles.keys()), loc=0)
    plt.xlabel('Wavenumber, cm-1')
    plt.ylabel('ATR units')
    if save_path:
        plt.savefig(fname=save_path, dpi=600)
        # Release the figure; otherwise repeated saves accumulate open figures.
        plt.close()
    else:
        plt.show()
def show_curve_approx(spc, peaks, *, path=None):
    """
    Plot the spectrum and its band-wise decomposition on the same canvas.

    :param spc: Spectrum - ``wavenums`` is the x axis, ``data`` the measured
        curve.
    :param peaks: iterable of (amplitude, position, width, gauss proportion)
        tuples; each one is evaluated with ``voigt`` on the spectrum's
        wavenumbers and drawn as a separate curve.
    :param path: str - filename. If specified the figure is saved there and
        closed, and the view window of pyplot isn't shown.
    """
    # Start a fresh figure so this plot is not drawn on top of a figure left
    # open by an earlier call (e.g. one that was only saved to disk).
    plt.figure()
    x = spc.wavenums
    plt.plot(x, spc.data)
    for amp, mu, w, g in peaks:
        plt.plot(x, voigt(x, amp, mu, w, g))
    # Same quantity (``spc.wavenums``) and label as in show_spectra.
    plt.xlabel('Wavenumber, cm-1')
    plt.ylabel('Intensity')
    if path:
        plt.savefig(path)
        plt.close()
    else:
        plt.show()
def spectra_log(spectra_dict, path='log.txt'):
    """
    Write every spectrum of a collection to a text file, one ``str()`` per line.

    :param spectra_dict: dict (hashable : Spectrum) - only the values are
        written, in the dict's iteration order; the keys are not written.
    :param path: str - target file; it is overwritten if it already exists.
    """
    # Explicit UTF-8 so the log's bytes do not depend on the platform locale.
    with open(path, 'w', encoding='utf-8') as f:
        for spectrum in spectra_dict.values():
            print(spectrum, file=f)
def heatmap(data, ax=None,
            cbar_kw=None, cbarlabel="", path='', *, show=True, **kwargs):
    """
    Draw ``data`` as an image with a colorbar, then save or show the figure.

    :param data: array-like - image data forwarded to ``ax.imshow``.
    :param ax: matplotlib Axes to draw on; the current axes (``plt.gca()``)
        when omitted.
    :param cbar_kw: dict - extra keyword arguments for ``figure.colorbar``.
    :param cbarlabel: str - colorbar label, rotated by -90 degrees.
    :param path: str - filename. If specified the figure owning ``ax`` is
        saved there at dpi=600 and not shown.
    :param show: bool - when False and no ``path`` is given, the figure is
        neither shown nor saved, so the caller can still customize the
        returned axes. The default True keeps the previous behaviour.
    :param kwargs: forwarded to ``ax.imshow``.
    :return: (image, colorbar, axes) tuple.
    """
    if ax is None:
        ax = plt.gca()
    if cbar_kw is None:
        cbar_kw = {}
    # Plot the heatmap
    im = ax.imshow(data, **kwargs)
    # Create colorbar
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")
    # White separators on *minor* ticks. Nothing here defines minor ticks, so
    # these only have a visible effect when minor ticks already exist on
    # ``ax`` (e.g. configured by the caller).
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)
    if path:
        # Save the figure that owns ``ax``, which is not necessarily the
        # "current" figure that plt.savefig would pick.
        ax.figure.savefig(path, dpi=600)
    elif show:
        plt.show()
    return im, cbar, ax
def auto_heatmap(spc, step=100, path=''):
    """
    Plot correlations between all circular shifts of a spectrum.

    Row ``i`` of the stacked matrix is ``spc.data`` rolled by ``i`` points, so
    entry ``(i, j)`` of the plotted ``np.corrcoef`` matrix is the Pearson
    correlation between the spectrum shifted by ``i`` and by ``j`` points
    (assuming ``len(spc) == len(spc.data)``; it then depends only on
    ``(j - i) mod len(spc)``).

    :param spc: Spectrum - read through ``len()``, ``data`` and ``wavenums``.
    :param step: int - a tick, labelled with the matching wavenumber, is put
        on every ``step``-th row/column.
    :param path: str - filename. If specified the figure is saved there
        instead of being shown.
    """
    n_points = len(spc)
    # n_points x n_points matrices: memory grows quadratically with length.
    shifted = np.vstack([np.roll(spc.data, i) for i in range(n_points)])
    corrcoefs = np.corrcoef(shifted)

    _, ax = plt.subplots()
    positions = np.arange(0, n_points, step)
    # NOTE(review): rows/columns index shift amounts, yet the ticks are
    # labelled with wavenumbers - confirm this labelling is intended.
    tick_labels = spc.wavenums[::step]
    # Configure ticks BEFORE heatmap(): it shows/saves the figure right away,
    # so anything set on ``ax`` afterwards would never be displayed.
    ax.set_xticks(positions, labels=tick_labels)
    ax.set_yticks(positions, labels=tick_labels)
    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=30, ha="right",
             rotation_mode="anchor")
    heatmap(corrcoefs, ax=ax, path=path)
def plot_margins(X, y, margins, path='', cm=None):
    """
    Scatter each object's first feature against its margin, coloured by label.

    :param X: 2-D array-like (numpy array or pandas.DataFrame) - population;
        only the first column is plotted, on the x axis.
    :param y: iterable - labels, one per row of ``X``.
    :param margins: array-like - objects margins, one per row of ``X``
        (plotted on the y axis, matched to rows by position).
    :param path: str - filename. If specified the figure is saved there and
        closed instead of being shown.
    :param cm: colormap - defaults to ``tab20``. Label groups take colours
        ``cm(0)``, ``cm(1)``, ... in order of first appearance, cycling after
        ``cm.N`` groups.
    """
    if not cm:
        # matplotlib.cm.get_cmap was removed in matplotlib 3.9.
        cm = plt.get_cmap('tab20')
    # np.asarray gives positional access for both ndarrays and DataFrames
    # (``DataFrame[i, 0]`` would be a column lookup and fail).
    first_feature = np.asarray(X)[:, 0]
    margin_values = np.asarray(margins)
    # Row indices per label in order of first appearance, so every label gets
    # exactly one colour and one legend entry even if ``y`` is unsorted.
    groups = {}
    for row, label in enumerate(y):
        groups.setdefault(label, []).append(row)

    plt.figure(figsize=(20, 16))
    plt.axhline(0)
    for group_index, (label, rows) in enumerate(groups.items()):
        # One plot call per group instead of per point; 'o' = markers only.
        plt.plot(first_feature[rows], margin_values[rows], 'o',
                 label=label, color=cm(group_index % cm.N))
    if groups:
        plt.legend()
    if path:
        plt.savefig(path)
        plt.close()
    else:
        plt.show()