Skip to content

Commit

Permalink
Collect SOM dtype at one place (#1511)
Browse files Browse the repository at this point in the history
  • Loading branch information
dachengx authored Dec 28, 2024
1 parent 9815352 commit 69265ff
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 32 deletions.
12 changes: 2 additions & 10 deletions straxen/plugins/events/event_basics_som.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import strax
import numpy as np

from straxen.plugins.events.event_basics_vanilla import EventBasicsVanilla
from straxen.plugins.peaklets.peaklet_classification_som import som_additional_fields

export, __all__ = strax.exporter()

Expand All @@ -17,11 +16,4 @@ def _set_dtype_requirements(self):
# Properties to store for each peak (main and alternate S1 and S2)
# Add here SOM types:
super()._set_dtype_requirements()
self.peak_properties = list(self.peak_properties)
self.peak_properties += [
("som_sub_type", np.int32, "SOM subtype of the peak(let)"),
("old_type", np.int8, "Old type of the peak(let)"),
("loc_x_som", np.int16, "x location of the peak(let) in the SOM"),
("loc_y_som", np.int16, "y location of the peak(let) in the SOM"),
]
self.peak_properties = tuple(self.peak_properties)
self.peak_properties += tuple(som_additional_fields)
30 changes: 19 additions & 11 deletions straxen/plugins/peaklets/peaklet_classification_som.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@

export, __all__ = strax.exporter()

__all__.extend(["som_additional_fields"])


som_additional_fields = [
("som_sub_type", np.int32, "SOM subtype of the peak(let)"),
("vanilla_type", np.int8, "Vanilla type of the peak(let)"),
("loc_x_som", np.int16, "x location of the peak(let) in the SOM"),
("loc_y_som", np.int16, "y location of the peak(let) in the SOM"),
]


@export
class PeakletClassificationSOM(PeakletClassificationVanilla):
Expand All @@ -30,14 +40,13 @@ class PeakletClassificationSOM(PeakletClassificationVanilla):
__version__ = "0.2.0"
child_plugin = True

dtype = strax.peak_interval_dtype + [
("type", np.int8, "Classification of the peak(let)"),
("som_sub_type", np.int32, "SOM subtype of the peak(let)"),
("old_type", np.int8, "Old type of the peak(let)"),
("som_type", np.int8, "SOM type of the peak(let)"),
("loc_x_som", np.int16, "x location of the peak(let) in the SOM"),
("loc_y_som", np.int16, "y location of the peak(let) in the SOM"),
]
dtype = (
strax.peak_interval_dtype
+ [
("type", np.int8, "Classification of the peak(let)"),
]
+ som_additional_fields
)

som_files = straxen.URLConfig(
default="resource://xedocs://som_classifiers?attr=value&version=v1&run_id=045000&fmt=npy"
Expand Down Expand Up @@ -67,7 +76,7 @@ def compute(self, peaklets):

peaklet_with_som = np.zeros(len(peaklets_classifcation), dtype=self.dtype)
strax.copy_to_buffer(peaklets_classifcation, peaklet_with_som, "_copy_peaklets_information")
peaklet_with_som["old_type"] = peaklets_classifcation["type"]
peaklet_with_som["vanilla_type"] = peaklets_classifcation["type"]
del peaklets_classifcation

# SOM classification
Expand All @@ -86,11 +95,10 @@ def compute(self, peaklets):
peaklet_with_som["som_sub_type"][_is_s1_or_s2] = som_sub_type
peaklet_with_som["loc_x_som"][_is_s1_or_s2] = x_som
peaklet_with_som["loc_y_som"][_is_s1_or_s2] = y_som
peaklet_with_som["som_type"][_is_s1_or_s2] = strax_type
if self.use_som_as_default:
peaklet_with_som["type"][_is_s1_or_s2] = strax_type
else:
peaklet_with_som["type"] = peaklet_with_som["old_type"]
peaklet_with_som["type"] = peaklet_with_som["vanilla_type"]

return peaklet_with_som

Expand Down
13 changes: 3 additions & 10 deletions straxen/plugins/peaks/peak_basics_som.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import strax
from straxen.plugins.peaklets.peaklet_classification_som import som_additional_fields
from straxen.plugins.peaks.peak_basics_vanilla import PeakBasicsVanilla

export, __all__ = strax.exporter()
Expand All @@ -14,17 +14,10 @@ class PeakBasicsSOM(PeakBasicsVanilla):

def infer_dtype(self):
dtype = super().infer_dtype()
additional_fields = [
("som_sub_type", np.int32, "SOM subtype of the peak(let)"),
("old_type", np.int8, "Old type of the peak(let)"),
("loc_x_som", np.int16, "x location of the peak(let) in the SOM"),
("loc_y_som", np.int16, "y location of the peak(let) in the SOM"),
]

return dtype + additional_fields
return dtype + som_additional_fields

def compute(self, peaks):
peak_basics = super().compute(peaks)
fields_to_copy = ("som_sub_type", "old_type", "loc_x_som", "loc_y_som")
fields_to_copy = strax.to_numpy_dtype(som_additional_fields).names
strax.copy_to_buffer(peaks, peak_basics, "_copy_som_information", fields_to_copy)
return peak_basics
2 changes: 1 addition & 1 deletion straxen/plugins/peaks/peaks_som.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def compute(self, peaklets, merged_s2s):
_is_merged_s2 = np.isin(result["time"], merged_s2s["time"]) & np.isin(
strax.endtime(result), strax.endtime(merged_s2s)
)
result["old_type"][_is_merged_s2] = -1
result["vanilla_type"][_is_merged_s2] = -1
result["som_sub_type"][_is_merged_s2] = -1

return result

0 comments on commit 69265ff

Please sign in to comment.