Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add missing run decorators #813

Merged
merged 8 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/run_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,9 @@ jobs:
export GSLDIR=$(gsl-config --prefix)
export PYTHONPATH=$(pwd):$PYTHONPATH
NuRadioReco/test/test_examples.sh
- name: "Test module structure of all modules in NuRadioReco.modules"
if: always()
run: |
export GSLDIR=$(gsl-config --prefix)
export PYTHONPATH=$(pwd):$PYTHONPATH
python NuRadioReco/test/check_modules.py -r
2 changes: 2 additions & 0 deletions NuRadioReco/modules/RNO_G/crRNOGTemplateCreator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pickle
import os
import NuRadioReco.modules.channelBandPassFilter
from NuRadioReco.modules.base.module import register_run


class crRNOGTemplateCreator:
Expand Down Expand Up @@ -135,6 +136,7 @@ def set_template_parameter(self, template_run_id:list[int]=[0, 0, 0], template_e
self.__cr_azimuth = cr_azimuth


@register_run()
def run(self, template_filename:str='templates_cr_station_101.pickle', include_hardware_response:bool=True, hardware_response_source:str='json',
return_templates:bool=False, bandpass_filter:None|dict[str,Any]=None) -> None|list[Event]:
"""
Expand Down
2 changes: 2 additions & 0 deletions NuRadioReco/modules/beamFormingDirectionFitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from NuRadioReco.framework.parameters import stationParameters as stnp
import NuRadioReco.modules.voltageToEfieldConverterPerChannel
import NuRadioReco.modules.electricFieldBandPassFilter
from NuRadioReco.modules.base.module import register_run


electricFieldBandPassFilter = NuRadioReco.modules.electricFieldBandPassFilter.electricFieldBandPassFilter()
Expand Down Expand Up @@ -98,6 +99,7 @@ def begin(self, debug=False, log_level=logging.NOTSET):
self.logger.setLevel(log_level)
self.__debug = debug

@register_run()
def run(self, evt, station, det, polarization, n_index=None, channels=None, ZenLim=None,
AziLim=None):
"""
Expand Down
12 changes: 7 additions & 5 deletions NuRadioReco/modules/channelCWNotchFilter.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
"""
Contains module to filter continuous wave out of the signal using notch filters
on peaks in frequency spectrum
"""

import logging
logger = logging.getLogger("NuRadioReco.channelCWNotchFilter")
import time
import numpy as np
from scipy import signal
from NuRadioReco.utilities import units
from NuRadioReco.utilities import fft

"""
Contains module to filter continuous wave out of the signal using notch filters
on peaks in frequency spectrum
"""
from NuRadioReco.modules.base.module import register_run


def find_frequency_peaks_from_trace(trace : np.ndarray, fs : float, threshold : float = 4):
Expand Down Expand Up @@ -222,6 +223,7 @@ def begin(self, quality_factor=1e3, threshold=4, save_filters=False):
# dictionary to cache known notch filters at specific frequencies
self.filter_cache = {}

@register_run()
def run(self, event, station, det):
for channel in station.iter_channels():
fs = channel.get_sampling_rate()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import NuRadioReco.modules.iftElectricFieldReconstructor.operators
import NuRadioReco.framework.base_trace
import NuRadioReco.framework.electric_field
from NuRadioReco.modules.base.module import register_run

import scipy
import nifty5 as ift
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -188,6 +190,7 @@ def make_priors_plot(self, event, station, detector, channel_ids):
)
self.__draw_priors(event, station, frequency_domain)

@register_run()
def run(self, event, station, detector, channel_ids, efield_scaling, use_sim=False):
"""
Run the electric field reconstruction
Expand Down
2 changes: 2 additions & 0 deletions NuRadioReco/modules/io/coreas/readCoREASShower.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import NuRadioReco.framework.event
import NuRadioReco.framework.station
from NuRadioReco.framework.parameters import showerParameters as shp
from NuRadioReco.modules.base.module import register_run
from NuRadioReco.modules.io.coreas import coreas
from NuRadioReco.utilities import units
from radiotools import coordinatesystems
Expand Down Expand Up @@ -61,6 +62,7 @@ def begin(self, input_files, det=None, logger_level=logging.NOTSET, set_ascendin

self.__ascending_run_and_event_number = 1 if set_ascending_run_and_event_number else 0

@register_run()
def run(self):
"""
Reads in CoREAS file(s) and returns one event containing all simulated observer positions as stations.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from scipy import optimize as opt
import matplotlib.pyplot as plt

from NuRadioReco.modules.base.module import register_run
from NuRadioReco.detector import antennapattern
from NuRadioReco.utilities import units, fft, geometryUtilities as geo_utl
from NuRadioMC.SignalGen import askaryan as ask
Expand Down Expand Up @@ -46,6 +47,7 @@ def begin(self):
self.antenna_provider = antennapattern.AntennaPatternProvider()
pass

@register_run()
def run(self, evt, station, det, icemodel, shower_type='HAD', use_channels=[0,1,2,3], attenuation_model='SP1',
parametrization='Alvarez2000', hilbert=False, use_bandpass_filter=False, passband_low={}, passband_high={},
include_focusing=False, use_MC=True, n_samples_multiplication_factor=1, plot_traces_with_true_input=False, debug=False):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import scipy.signal
import matplotlib.pyplot as plt
from NuRadioReco.utilities import units, fft
from NuRadioReco.modules.base.module import register_run
import NuRadioReco.utilities.io_utilities
import NuRadioReco.framework.electric_field
from NuRadioReco.framework.parameters import stationParameters as stnp
Expand Down Expand Up @@ -111,6 +112,7 @@ def begin(self, station_id, channel_ids, detector, passband=None, template=None,
self.__template = template
self.__output_path = output_path

@register_run()
def run(self, event, station, max_distance, z_width, grid_spacing, direction_guess=None, debug=False, use_dnr=False):
"""
Execute the 2D vertex reconstruction
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from matplotlib import cm
import matplotlib.pyplot as plt
from NuRadioReco.utilities import units
from NuRadioReco.modules.base.module import register_run
import NuRadioReco.utilities.io_utilities
import NuRadioReco.framework.electric_field
import NuRadioReco.detector.antennapattern
Expand Down Expand Up @@ -165,6 +166,7 @@ def begin(
self.__header[int(channel_z)] = f['header']
self.__lookup_table[int(abs(channel_z))] = f['antenna_{}'.format(channel_z)]

@register_run()
def run(
self,
event,
Expand Down
3 changes: 2 additions & 1 deletion NuRadioReco/modules/sphericalWaveFitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import scipy.signal
from scipy import constants
from mpl_toolkits.axes_grid1 import make_axes_locatable
from NuRadioReco.modules.base.module import register_run


class sphericalWaveFitter:
Expand All @@ -22,7 +23,7 @@ def begin(self, channel_ids = [0, 3, 9, 10]):
self.__channel_ids = channel_ids
pass


@register_run()
def run(self, evt, station, det, start_pulser_position, n_index = None, debug = True):

print("channels used for this reconstruction:", self.__channel_ids)
Expand Down
106 changes: 106 additions & 0 deletions NuRadioReco/test/check_modules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#! /usr/bin/env python3


# This script will walk the module directory, load everything that looks like a
# module and try to import (possibly failing if for example, you don't have all
# the optional things for each module)

# Then it will find all classes defined in the module with a run method, and
# complain if the time attribute is not present (this is an attribute added by
# the register_run decorator) We also check for classes that have run defined but no begining/end

# It prints out a bunch of chatty output, then at the end, prints a list of
# modules that couldn't be imported, classes that have run but no decorator,
# and classes that have run but no begin/end.
#




import os
import pathlib
import inspect
import importlib
import argparse


if __name__ == "__main__":
argparser = argparse.ArgumentParser(
description='Simple script to check NuRadioReco classes in /modules conform to the convention.'
)
argparser.add_argument('-r', '--run', action='store_true', help="Raise error on missing/broken `run` method.")
argparser.add_argument('-b', '--begin', action='store_true', help="Raise error on missing/broken `begin` method.")
argparser.add_argument('-e', '--end', action='store_true', help="Raise error on missing/broken `end` method.")
argparser.add_argument('-i', '--import',action='store_true', help="Raise error if module failed to import.", dest='broken') # args.import is not a valid name

args = argparser.parse_args()

broken = []
unregistered_runs = []
missing_begin = []
missing_end = []

# switch to the NuRadioMC parent directory
os.chdir(pathlib.Path(__file__).parents[2])

for dirpath,folders,files in os.walk(os.path.join('NuRadioReco','modules')):
for f in files:
if not f.endswith(".py") or f == "__init__.py":
continue
try:
mname = dirpath.replace('/','.')+'.'+f[:-3]
print("Trying ", mname)

m = importlib.import_module(mname)

for name,obj in inspect.getmembers(m, lambda member: inspect.isclass(member) and member.__module__ == mname):
print("Found class ",name, obj)
if hasattr(obj,'run') and not hasattr(obj.run,'time'):
print('Has run method but not registered properly! Public flogging will be scheduled.')
unregistered_runs.append(mname + '.' +name)

if hasattr(obj,'run') and not hasattr(obj,'begin'):
print ('Has run but no begin...')
missing_begin.append(mname + '.' + name)
if hasattr(obj,'run') and not hasattr(obj,'end'):
print ('Has run but no end...')
missing_end.append(mname + '.' + name)


except Exception as e:
print("Couldn't load module... maybe it's broken, oh well. Exception below:")
print('\t', e)
broken.append(mname)



print("\n\n\n.........................................\n")
print ("Broken modules:\n\t" + '\n\t'.join(broken))
print()
print ("Unregistered runs:\n\t" + '\n\t'.join(unregistered_runs))
print()
print ("Missing end:\n\t" + '\n\t'.join(missing_end))
print()
print ("Missing begin:\n\t" + '\n\t'.join(missing_begin))

exit_code = (
bool(args.run and len(unregistered_runs))
+ 2 * bool(args.begin and len(missing_begin))
+ 4 * bool(args.end and len(missing_end))
+ 8 * bool(args.broken and len(broken))
)

if exit_code: # it seems sys.exit(0) is sometimes still treated as an error, so we avoid calling it if there was no error.
print('\n\n' + 80*'!' + '\n' + f"One or more problems found, exiting with code {exit_code}")
exit(exit_code)