Skip to content

Commit

Permalink
Merge pull request #572 from nu-radio/run-decorator-and-serialization
Browse files Browse the repository at this point in the history
Remove non built-in objects for the __modules_event list which gets s…
  • Loading branch information
sjoerd-bouma authored Nov 16, 2023
2 parents 4bffb11 + 5553f69 commit e642c8b
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 41 deletions.
32 changes: 19 additions & 13 deletions NuRadioReco/framework/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def register_module_event(self, instance, name, kwargs):
kwargs:
the key word arguments of the run method
"""

self.__modules_event.append([name, instance, kwargs])

def register_module_station(self, station_id, instance, name, kwargs):
Expand All @@ -59,8 +59,9 @@ def register_module_station(self, station_id, instance, name, kwargs):
kwargs:
the key word arguments of the run method
"""
if(station_id not in self.__modules_station):
if station_id not in self.__modules_station:
self.__modules_station[station_id] = []

iE = len(self.__modules_event)
self.__modules_station[station_id].append([iE, name, instance, kwargs])

Expand Down Expand Up @@ -427,30 +428,33 @@ def serialize(self, mode):
commit_hash = NuRadioReco.utilities.version.get_NuRadioMC_commit_hash()
self.set_parameter(parameters.eventParameters.hash_NuRadioMC, commit_hash)
except:
logger.warning("Event is serialized without commit hash!")
self.set_parameter(parameters.eventParameters.hash_NuRadioMC, None)

for station in self.get_stations():
stations_pkl.append(station.serialize(mode))

showers_pkl = []
for shower in self.get_showers():
showers_pkl.append(shower.serialize())
sim_showers_pkl = []
for shower in self.get_sim_showers():
sim_showers_pkl.append(shower.serialize())
particles_pkl = []
for particle in self.get_particles():
particles_pkl.append(particle.serialize())
showers_pkl = [shower.serialize() for shower in self.get_showers()]
sim_showers_pkl = [shower.serialize() for shower in self.get_sim_showers()]
particles_pkl = [particle.serialize() for particle in self.get_particles()]

hybrid_info = self.__hybrid_information.serialize()

modules_out_event = []
for value in self.__modules_event: # remove module instances (this will just blow up the file size)
modules_out_event.append([value[0], None, value[2]])
invalid_keys = [key for key,val in value[2].items() if isinstance(val, BaseException)]
if len(invalid_keys):
logger.warning(f"The following arguments to module {value[0]} could not be serialized and will not be stored: {invalid_keys}")

modules_out_station = {}
for key in self.__modules_station: # remove module instances (this will just blow up the file size)
modules_out_station[key] = []
for value in self.__modules_station[key]:
modules_out_station[key].append([value[0], value[1], None, value[3]])
invalid_keys = [key for key,val in value[3].items() if isinstance(val, BaseException)]
if len(invalid_keys):
logger.warning(f"The following arguments to module {value[0]} could not be serialized and will not be stored: {invalid_keys}")

data = {'_parameters': self._parameters,
'__run_number': self.__run_number,
Expand Down Expand Up @@ -489,9 +493,11 @@ def deserialize(self, data_pkl):
particle = NuRadioReco.framework.particle.Particle(None)
particle.deserialize(particle_pkl)
self.add_particle(particle)

self.__hybrid_information = NuRadioReco.framework.hybrid_information.HybridInformation()
if 'hybrid_info' in data.keys():
self.__hybrid_information.deserialize(data['hybrid_info'])

self._parameters = data['_parameters']
self.__run_number = data['__run_number']
self._id = data['_id']
Expand All @@ -500,7 +506,7 @@ def deserialize(self, data_pkl):
if 'generator_info' in data.keys():
self._generator_info = data['generator_info']

if("__modules_event" in data):
if "__modules_event" in data:
self.__modules_event = data['__modules_event']
if("__modules_station" in data):
if "__modules_station" in data:
self.__modules_station = data['__modules_station']
74 changes: 46 additions & 28 deletions NuRadioReco/modules/base/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from timeit import default_timer as timer
import NuRadioReco.framework.event
import NuRadioReco.framework.base_station
import NuRadioReco.detector.detector_base
import logging

import inspect
import pickle

def setup_logger(name="NuRadioReco", level=logging.WARNING):

Expand Down Expand Up @@ -39,43 +41,59 @@ def register_run_method(self, *args, **kwargs):
# generator, so not sure how to access the event.
evt = None
station = None
# find out type of module automatically
if(len(args) == 1):
if(isinstance(args[0], NuRadioReco.framework.event.Event)):
module_level = "event"
evt = args[0]
else:
# this is a module that creates events
module_level = "reader"
elif(len(args) >= 2):
if(isinstance(args[0], NuRadioReco.framework.event.Event) and isinstance(args[1], NuRadioReco.framework.base_station.BaseStation)):
module_level = "station"
evt = args[0]
station = args[1]
elif(isinstance(args[0], NuRadioReco.framework.event.Event)):
module_level = "event"
evt = args[0]
else:
# this is a module that creates events
module_level = "reader"
raise AttributeError("first argument of run method is not of type NuRadioReco.framework.event.Event")

signature = inspect.signature(run)
parameters = signature.parameters
# convert args to kwargs to facilitate easier bookkeeping
keys = [key for key in parameters.keys() if key != 'self']
all_kwargs = {key:value for key,value in zip(keys, args)}
all_kwargs.update(kwargs) # this silently overwrites positional args with kwargs, but this is probably okay as we still raise an error later

# include parameters with default values
for key,value in parameters.items():
if key not in all_kwargs.keys():
if value.default is not inspect.Parameter.empty:
all_kwargs[key] = value.default

store_kwargs = {}
for idx, (key,value) in enumerate(all_kwargs.items()):
if isinstance(value, NuRadioReco.framework.event.Event) and idx == 0: # event should be the first argument
evt = value
elif isinstance(value, NuRadioReco.framework.base_station.BaseStation) and idx == 1: # station should be second argument
station = value
elif isinstance(value, NuRadioReco.detector.detector_base.DetectorBase):
pass # we don't try to store detectors
else: # we try to store other arguments IF they are pickleable
try:
pickle.dumps(value, protocol=4)
store_kwargs[key] = value
except (TypeError, AttributeError): # object couldn't be pickled - we store the error instead
store_kwargs[key] = TypeError(f"Argument of type {type(value)} could not be serialized")
if station is not None:
module_level = "station"
elif evt is not None:
module_level = "event"
else:
# this is a module that creates events
module_level = "reader"

start = timer()
res = run(self, *args, **kwargs)
if(module_level == "event"):
evt.register_module_event(self, self.__class__.__name__, kwargs)
elif(module_level == "station"):
evt.register_module_station(station.get_id(), self, self.__class__.__name__, kwargs)
elif(module_level == "reader"):

if module_level == "event":
evt.register_module_event(self, self.__class__.__name__, store_kwargs)
elif module_level == "station":
evt.register_module_station(station.get_id(), self, self.__class__.__name__, store_kwargs)
elif module_level == "reader":
# not sure what to do... function returns generator, not sure how to access the event...
pass

res = run(self, *args, **kwargs)

end = timer()

if self not in register_run_method.time: # keep track of timing of modules. We use the module instance as key to time different module instances separately.
register_run_method.time[self] = 0
register_run_method.time[self] += (end - start)

return res

register_run_method.time = {}
Expand Down

0 comments on commit e642c8b

Please sign in to comment.