From 69265ff85a3d7fc8d9e57b32ad02f71605f5d501 Mon Sep 17 00:00:00 2001 From: Dacheng Xu Date: Fri, 27 Dec 2024 19:51:14 -0500 Subject: [PATCH] Collect SOM dtype at one place (#1511) --- straxen/plugins/events/event_basics_som.py | 12 ++------ .../peaklets/peaklet_classification_som.py | 30 ++++++++++++------- straxen/plugins/peaks/peak_basics_som.py | 13 ++------ straxen/plugins/peaks/peaks_som.py | 2 +- 4 files changed, 25 insertions(+), 32 deletions(-) diff --git a/straxen/plugins/events/event_basics_som.py b/straxen/plugins/events/event_basics_som.py index ed653fd96..f236e5fb3 100644 --- a/straxen/plugins/events/event_basics_som.py +++ b/straxen/plugins/events/event_basics_som.py @@ -1,7 +1,6 @@ import strax -import numpy as np - from straxen.plugins.events.event_basics_vanilla import EventBasicsVanilla +from straxen.plugins.peaklets.peaklet_classification_som import som_additional_fields export, __all__ = strax.exporter() @@ -17,11 +16,4 @@ def _set_dtype_requirements(self): # Properties to store for each peak (main and alternate S1 and S2) # Add here SOM types: super()._set_dtype_requirements() - self.peak_properties = list(self.peak_properties) - self.peak_properties += [ - ("som_sub_type", np.int32, "SOM subtype of the peak(let)"), - ("old_type", np.int8, "Old type of the peak(let)"), - ("loc_x_som", np.int16, "x location of the peak(let) in the SOM"), - ("loc_y_som", np.int16, "y location of the peak(let) in the SOM"), - ] - self.peak_properties = tuple(self.peak_properties) + self.peak_properties += tuple(som_additional_fields) diff --git a/straxen/plugins/peaklets/peaklet_classification_som.py b/straxen/plugins/peaklets/peaklet_classification_som.py index fad9fa8fd..81b693839 100644 --- a/straxen/plugins/peaklets/peaklet_classification_som.py +++ b/straxen/plugins/peaklets/peaklet_classification_som.py @@ -8,6 +8,16 @@ export, __all__ = strax.exporter() +__all__.extend(["som_additional_fields"]) + + +som_additional_fields = [ + ("som_sub_type", np.int32, "SOM subtype of the peak(let)"), + ("vanilla_type", np.int8, "Vanilla type of the peak(let)"), + ("loc_x_som", np.int16, "x location of the peak(let) in the SOM"), + ("loc_y_som", np.int16, "y location of the peak(let) in the SOM"), +] + @export class PeakletClassificationSOM(PeakletClassificationVanilla): @@ -30,14 +40,13 @@ class PeakletClassificationSOM(PeakletClassificationVanilla): __version__ = "0.2.0" child_plugin = True - dtype = strax.peak_interval_dtype + [ - ("type", np.int8, "Classification of the peak(let)"), - ("som_sub_type", np.int32, "SOM subtype of the peak(let)"), - ("old_type", np.int8, "Old type of the peak(let)"), - ("som_type", np.int8, "SOM type of the peak(let)"), - ("loc_x_som", np.int16, "x location of the peak(let) in the SOM"), - ("loc_y_som", np.int16, "y location of the peak(let) in the SOM"), - ] + dtype = ( + strax.peak_interval_dtype + + [ + ("type", np.int8, "Classification of the peak(let)"), + ] + + som_additional_fields + ) som_files = straxen.URLConfig( default="resource://xedocs://som_classifiers?attr=value&version=v1&run_id=045000&fmt=npy" @@ -67,7 +76,7 @@ def compute(self, peaklets): peaklet_with_som = np.zeros(len(peaklets_classifcation), dtype=self.dtype) strax.copy_to_buffer(peaklets_classifcation, peaklet_with_som, "_copy_peaklets_information") - peaklet_with_som["old_type"] = peaklets_classifcation["type"] + peaklet_with_som["vanilla_type"] = peaklets_classifcation["type"] del peaklets_classifcation # SOM classification @@ -86,11 +95,10 @@ def compute(self, peaklets): peaklet_with_som["som_sub_type"][_is_s1_or_s2] = som_sub_type peaklet_with_som["loc_x_som"][_is_s1_or_s2] = x_som peaklet_with_som["loc_y_som"][_is_s1_or_s2] = y_som - peaklet_with_som["som_type"][_is_s1_or_s2] = strax_type if self.use_som_as_default: peaklet_with_som["type"][_is_s1_or_s2] = strax_type else: - peaklet_with_som["type"] = peaklet_with_som["old_type"] + peaklet_with_som["type"] = peaklet_with_som["vanilla_type"] return peaklet_with_som diff --git a/straxen/plugins/peaks/peak_basics_som.py b/straxen/plugins/peaks/peak_basics_som.py index b4df54e3a..105f8ac99 100644 --- a/straxen/plugins/peaks/peak_basics_som.py +++ b/straxen/plugins/peaks/peak_basics_som.py @@ -1,5 +1,5 @@ -import numpy as np import strax +from straxen.plugins.peaklets.peaklet_classification_som import som_additional_fields from straxen.plugins.peaks.peak_basics_vanilla import PeakBasicsVanilla export, __all__ = strax.exporter() @@ -14,17 +14,10 @@ class PeakBasicsSOM(PeakBasicsVanilla): def infer_dtype(self): dtype = super().infer_dtype() - additional_fields = [ - ("som_sub_type", np.int32, "SOM subtype of the peak(let)"), - ("old_type", np.int8, "Old type of the peak(let)"), - ("loc_x_som", np.int16, "x location of the peak(let) in the SOM"), - ("loc_y_som", np.int16, "y location of the peak(let) in the SOM"), - ] - - return dtype + additional_fields + return dtype + som_additional_fields def compute(self, peaks): peak_basics = super().compute(peaks) - fields_to_copy = ("som_sub_type", "old_type", "loc_x_som", "loc_y_som") + fields_to_copy = strax.to_numpy_dtype(som_additional_fields).names strax.copy_to_buffer(peaks, peak_basics, "_copy_som_information", fields_to_copy) return peak_basics diff --git a/straxen/plugins/peaks/peaks_som.py b/straxen/plugins/peaks/peaks_som.py index fb428a97e..8a879cd7c 100644 --- a/straxen/plugins/peaks/peaks_som.py +++ b/straxen/plugins/peaks/peaks_som.py @@ -34,7 +34,7 @@ def compute(self, peaklets, merged_s2s): _is_merged_s2 = np.isin(result["time"], merged_s2s["time"]) & np.isin( strax.endtime(result), strax.endtime(merged_s2s) ) - result["old_type"][_is_merged_s2] = -1 + result["vanilla_type"][_is_merged_s2] = -1 result["som_sub_type"][_is_merged_s2] = -1 return result