Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New template-based PID + big speed ups in truth-matching #53

Merged
merged 8 commits into from
Feb 16, 2025
10 changes: 9 additions & 1 deletion spine/ana/metric/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
"""Reconstruction quality evaluation module."""
"""Reconstruction quality evaluation module.

This submodule is used to evaluate reconstruction quality metrics, such as:
- Semantic segmentation accuracy
- Clustering accuracy
- Flash matching efficiency
- ...
"""

from .segment import *
from .point import *
from .cluster import *
from .optical import *
151 changes: 151 additions & 0 deletions spine/ana/metric/optical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
"""Analysis script used to evaluate the semantic segmentation accuracy."""

import numpy as np

from spine.data.out import RecoInteraction, TruthInteraction

from spine.ana.base import AnaBase

__all__ = ['FlashMatchingAna']


class FlashMatchingAna(AnaBase):
    """Evaluate the performance of the flash matching algorithm.

    For every (source, target) interaction match pair, this script records
    the flash-association attributes of the reconstructed interaction
    (match flag, flash IDs/times/scores, total and hypothesized PE) next to
    basic truth interaction attributes, so that flash-matching efficiency
    and purity can be computed downstream from the output log files.
    """

    # Name of the analysis script (as specified in the configuration)
    name = 'flash_match_eval'

    # Valid match modes
    _match_modes = ('reco_to_truth', 'truth_to_reco', 'both', 'all')

    # Default object types when a match is not found (used to emit a row of
    # default attribute values so that every row has the same columns)
    _default_objs = (
        ('reco_interactions', RecoInteraction()),
        ('truth_interactions', TruthInteraction())
    )

    def __init__(self, time_window=None, neutrino_only=True, max_num_flashes=1,
                 match_mode='both', **kwargs):
        """Initialize the analysis script.

        Parameters
        ----------
        time_window : List[float], optional
            Time window (in ns) for which interactions must have matched flash
        neutrino_only : bool, default True
            If `True`, only check if neutrino in-time activity is matched for
            the efficiency measurement (as opposed to any in-time activity)
        max_num_flashes : int, default 1
            Maximum number of flash matches to store
        match_mode : str, default 'both'
            If reconstructed and truth are available, specifies which matching
            direction(s) should be saved to the log file.
        **kwargs : dict, optional
            Additional arguments to pass to :class:`AnaBase`
        """
        # Initialize the parent class
        super().__init__('interaction', 'both', **kwargs)

        # Store basic parameters
        self.time_window = time_window
        self.neutrino_only = neutrino_only

        # Store default objects as a dictionary
        self.default_objs = dict(self._default_objs)

        # Store the matching mode
        self.match_mode = match_mode
        assert match_mode in self._match_modes, (
            f"Invalid matching mode: {self.match_mode}. Must be one "
            f"of {self._match_modes}.")

        # Make sure the matches are loaded, initialize the output files
        # (one writer per requested matching direction)
        keys = {}
        for prefix in self.prefixes:
            if prefix == 'reco' and match_mode != 'truth_to_reco':
                keys['interaction_matches_r2t'] = True
                keys['interaction_matches_r2t_overlap'] = True
                self.initialize_writer('reco')
            if prefix == 'truth' and match_mode != 'reco_to_truth':
                keys['interaction_matches_t2r'] = True
                keys['interaction_matches_t2r_overlap'] = True
                self.initialize_writer('truth')

        self.update_keys(keys)

        # List the interaction attributes to be stored. Neutrino-specific
        # attributes are only fetched when neutrino_only is requested.
        nu_attrs = ('energy_init',) if neutrino_only else ()
        flash_attrs = (
            'is_flash_matched', 'flash_ids', 'flash_times', 'flash_scores',
            'flash_total_pe', 'flash_hypo_pe'
        )
        # Vector-valued flash attributes are truncated/padded to a fixed
        # length so that each row has a fixed number of columns
        flash_lengths = {
            k: max_num_flashes for k in ['flash_ids', 'flash_times', 'flash_scores']
        }

        self.reco_attrs = ('id', 'size', 'is_contained', 'topology', *flash_attrs)
        self.truth_attrs = ('id', 'size', 'is_contained', 'nu_id', 't', 'topology', *nu_attrs)

        self.reco_lengths = flash_lengths
        self.truth_lengths = None

    def process(self, data):
        """Store the flash matching metrics for one entry.

        Parameters
        ----------
        data : dict
            Dictionary of data products
        """
        # Loop over the matching directions (reco-to-truth, truth-to-reco)
        prefixes = {'reco': 'truth', 'truth': 'reco'}
        for source, target in prefixes.items():
            # Loop over the match pairs
            src_attrs = getattr(self, f'{source}_attrs')
            src_lengths = getattr(self, f'{source}_lengths')
            tgt_attrs = getattr(self, f'{target}_attrs')
            tgt_lengths = getattr(self, f'{target}_lengths')
            match_suffix = f'{source[0]}2{target[0]}'
            match_key = f'interaction_matches_{match_suffix}'
            for idx, (obj_i, obj_j) in enumerate(data[match_key]):
                # Check that the source interaction is of interest
                if obj_i.is_truth:
                    # If the source object is a true interaction, check if it
                    # should be matched or not (in time or not)
                    if (self.time_window is not None and
                        (obj_i.t < self.time_window[0] or
                         obj_i.t > self.time_window[1])):
                        continue

                    # If requested, check that the in-time activity is a neutrino
                    if self.neutrino_only and obj_i.nu_id < 0:
                        continue

                else:
                    # If the source object is a reco interaction, check if it
                    # is matched to a flash or not
                    if not obj_i.is_flash_matched:
                        continue

                # Store information about the corresponding reco interaction
                # and the flash associated with it (if any); fall back on the
                # default object when the match is missing so that the row
                # keeps a consistent set of columns
                src_dict = obj_i.scalar_dict(src_attrs, src_lengths)
                if obj_j is not None:
                    tgt_dict = obj_j.scalar_dict(tgt_attrs, tgt_lengths)
                else:
                    default_obj = self.default_objs[f'{target}_interactions']
                    tgt_dict = default_obj.scalar_dict(tgt_attrs, tgt_lengths)

                src_dict = {f'{source}_{k}':v for k, v in src_dict.items()}
                tgt_dict = {f'{target}_{k}':v for k, v in tgt_dict.items()}

                # Get the match quality
                overlap = data[f'{match_key}_overlap'][idx]

                # Build row dictionary and store
                row_dict = {**src_dict, **tgt_dict}
                row_dict.update({'match_overlap': overlap})

                self.append(source, **row_dict)
5 changes: 5 additions & 0 deletions spine/data/neutrino.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class Neutrino(PosDataBase):
True amount of distance traveled by the neutrino before interacting
theta : float
Angle between incoming and outgoing leptons in radians
t : float
Interaction time (ns)
creation_process : str
Creation process of the neutrino
position : np.ndarray
Expand Down Expand Up @@ -100,6 +102,7 @@ class Neutrino(PosDataBase):
lepton_p: float = -1.
distance_travel: float = -1.
theta: float = -1.
t: float = -np.inf
creation_process: str = ''
position: np.ndarray = None
momentum: np.ndarray = None
Expand Down Expand Up @@ -159,6 +162,8 @@ def from_larcv(cls, neutrino):
else:
obj_dict['track_id'] = getattr(neutrino, key)()

obj_dict['t'] = neutrino.position().t()

# Load the positional attribute
pos_attrs = ['x', 'y', 'z']
for key in cls._pos_attrs:
Expand Down
16 changes: 11 additions & 5 deletions spine/post/optical/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@ def get_matches(self, interactions, flashes):
# Build result, return
result = []
for m in self.matches:
result.append((interactions[m.tpc_id], flashes[m.flash_id], m))
tpc_id = self.qcluster_v[m.tpc_id].idx
flash_id = self.flash_v[m.flash_id].idx
result.append((interactions[tpc_id], flashes[flash_id], m))

return result

Expand All @@ -173,13 +175,17 @@ def make_qcluster_list(self, interactions):
# Loop over the interactions
from flashmatch import flashmatch
qcluster_v = []
for inter in interactions:
for idx, inter in enumerate(interactions):
# Produce a mask to remove negative value points (can happen)
valid_mask = np.where(inter.depositions > 0.)[0]

# Skip interactions with less than 2 points
if len(valid_mask) < 2:
continue

# Initialize qcluster
qcluster = flashmatch.QCluster_t()
qcluster.idx = int(inter.id)
qcluster.idx = idx
qcluster.time = 0

# Get the point coordinates
Expand Down Expand Up @@ -238,8 +244,8 @@ def make_flash_list(self, flashes):
for idx, f in enumerate(flashes):
# Initialize the Flash_t object
flash = flashmatch.Flash_t()
flash.idx = int(f.id) # Assign a unique index
flash.time = f.time # Flash timing, a candidate T0
flash.idx = idx
flash.time = f.time

# Assign the flash position and error on this position
flash.x, flash.y, flash.z = 0, 0, 0
Expand Down
78 changes: 78 additions & 0 deletions spine/post/reco/pid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""Particle identification modules."""

import numpy as np

from spine.utils.globals import TRACK_SHP
from spine.utils.pid import TemplateParticleIdentifier

from spine.post.base import PostBase

__all__ = ['PIDTemplateProcessor']


class PIDTemplateProcessor(PostBase):
    """Produces particle species classification estimates based on dE/dx vs
    residual range templates of tracks.
    """

    # Name of the post-processor (as specified in the configuration)
    name = 'pid_template'

    def __init__(self, fill_per_pid=False, obj_type='particle', run_mode='reco',
                 truth_point_mode='points', truth_dep_mode='depositions',
                 **identifier):
        """Store the necessary attributes to do template-based PID prediction.

        Parameters
        ----------
        fill_per_pid : bool, default False
            If `True`, stores the scores associated with each PID candidate
        obj_type : str, default 'particle'
            Type of object to run the post-processor on
        run_mode : str, default 'reco'
            Which types of object to run the post-processor on (reco/truth)
        truth_point_mode : str, default 'points'
            Attribute used to fetch point coordinates from truth objects
        truth_dep_mode : str, default 'depositions'
            Attribute used to fetch depositions from truth objects
        **identifier : dict, optional
            Particle template identifier configuration parameters
        """
        # Initialize the parent class
        super().__init__(obj_type, run_mode, truth_point_mode, truth_dep_mode)

        # Store additional parameter
        self.fill_per_pid = fill_per_pid

        # Initialize the underlying template-fitter class
        self.identifier = TemplateParticleIdentifier(**identifier)

    def process(self, data):
        """Assign a template-based PID to each particle/fragment in one entry.

        Parameters
        ----------
        data : dict
            Dictionary of data products
        """
        # Loop over particle/fragment objects
        for k in self.fragment_keys + self.particle_keys:
            for obj in data[k]:
                # Templates are defined for track-like objects only
                if obj.shape != TRACK_SHP:
                    continue

                # Make sure the object coordinates are expressed in cm
                self.check_units(obj)

                # Get point coordinates and depositions; nothing to fit
                # if the object has no points
                points = self.get_points(obj)
                values = self.get_depositions(obj)
                if not len(points):
                    continue

                # Run the particle identifier
                pid, chi2_scores = self.identifier(
                        points, values, obj.end_point, obj.start_point)

                # Store the PID prediction
                obj.pid = pid
                if self.fill_per_pid:
                    # Spread the per-candidate chi2 scores into the full
                    # pid_scores vector (-1 marks PIDs not considered).
                    # NOTE: the candidate loop variable must not shadow the
                    # predicted `pid` above.
                    chi2_per_pid = np.full(
                            len(obj.pid_scores), -1., dtype=obj.pid_scores.dtype)
                    for i, cand_pid in enumerate(self.identifier.include_pids):
                        chi2_per_pid[cand_pid] = chi2_scores[i]

                    obj.pid_scores = chi2_per_pid
Loading
Loading