Skip to content

Commit

Permalink
Add masks to avoid division through 0 (or taking the log of 0)
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlueter committed Jan 10, 2024
1 parent 3a473f9 commit f76d01b
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions NuRadioReco/detector/RNO_G/rnog_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,7 +1168,14 @@ def __call__(self, freq):
response = np.ones_like(freq, dtype=np.complex128)

for gain, phase, weight in zip(self.__gains, self.__phases, self.__weights):
response *= (gain(freq / units.GHz) * np.exp(1j * phase(freq / units.GHz))) ** weight

_gain = gain(freq / units.GHz)

# to avoid RunTime warning and NANs in total reponse
if weight == -1:
_gain = np.where(_gain > 0, _gain, 1e-6)

response *= (_gain * np.exp(1j * phase(freq / units.GHz))) ** weight

return response

Expand Down Expand Up @@ -1212,7 +1219,7 @@ def __rmul__(self, other):
def __str__(self):
ampl = 20 * np.log10(np.abs(self(0.5 * units.GHz)))
return "Response of " + ", ".join([f"{name} ({weight})" for name, weight in zip(self.get_names(), self.__weights)]) \
+ f": |R(0.5 GHz)| = {ampl} dB (amplitude)"
+ f": |R(0.5 GHz)| = {ampl:.2f} dB (amplitude)"

def plot(self, show=False, in_dB=True):
import matplotlib.pyplot as plt
Expand All @@ -1221,15 +1228,20 @@ def plot(self, show=False, in_dB=True):

fig, ax = plt.subplots()
for gain, weight, name in zip(self.__gains, self.__weights, self.__names):
_gain = gain(freqs)

if in_dB:
ax.plot(freqs / units.MHz, weight * 20 * np.log10(gain(freqs)), label=name)
mask = _gain > 0 # to avoid RunTime warning
ax.plot(freqs[mask] / units.MHz, weight * 20 * np.log10(_gain[mask]), label=name)
else:
ax.plot(freqs / units.MHz, gain(freqs), label=name)
ax.plot(freqs / units.MHz, _gain, label=name)

_gain = np.abs(self(freqs))
if in_dB:
ax.plot(freqs / units.MHz, 20 * np.log10(np.abs(self(freqs))), color="k", label="total")
mask = _gain > 0 # to avoid RunTime warning
ax.plot(freqs[mask] / units.MHz, 20 * np.log10(_gain[mask]), color="k", label="total")
else:
ax.plot(freqs / units.MHz, np.abs(self(freqs)), color="k", label="total")
ax.plot(freqs / units.MHz, _gain, color="k", label="total")

ax.set_xlabel("frequency / MHz")
if in_dB:
Expand Down

0 comments on commit f76d01b

Please sign in to comment.