From 3d83bf118d5049c7343561ced1157a30d1f9cb31 Mon Sep 17 00:00:00 2001 From: Francois Drielsma Date: Fri, 28 Feb 2025 12:58:19 -0800 Subject: [PATCH] Fixed full chain pre-calibration (was completely wrong) --- spine/model/full_chain.py | 71 ++++++++++++++++++------------------ spine/utils/calib/manager.py | 15 +++++++- 2 files changed, 49 insertions(+), 37 deletions(-) diff --git a/spine/model/full_chain.py b/spine/model/full_chain.py index e8c76f45..3b5f3732 100644 --- a/spine/model/full_chain.py +++ b/spine/model/full_chain.py @@ -232,7 +232,7 @@ def __init__(self, chain, uresnet_deghost=None, uresnet=None, assert 'stage' in calibration, ( "If the calibration is to be applied, must provide the " "`stage` to specify where to apply it.") - self.calibration_stage = calibration.pop(stage) + self.calibration_stage = calibration.pop('stage') self.calibrator = CalibrationManager(**calibration) calibration['stage'] = self.calibration_stage @@ -249,7 +249,7 @@ def modes(self): return dict(self._modes) def forward(self, data, sources=None, seg_label=None, clust_label=None, - coord_label=None, energy_label=None, run_info=None): + coord_label=None, energy_label=None, meta=None, run_info=None): """Run a batch of data through the full chain. Parameters @@ -274,6 +274,8 @@ def forward(self, data, sources=None, seg_label=None, clust_label=None, energy_label : TensorBatch, optional (N, 1 + D + 1) Tensor of true energy deposition values - 1 is the energy deposition value in each voxel + meta : Meta, optional + Image metadata information run_info : List[RunInfo], optional Object containing information about the run, subrun and event @@ -289,27 +291,27 @@ def forward(self, data, sources=None, seg_label=None, clust_label=None, # Run the semantic segmentation (and point proposal) stage if self.calibration_stage == 'segmentation': - data = self.run_calibration(data, sources, energy_label, run_info) + data = self.run_calibration(data, sources, energy_label, meta, run_info) clust_label = self.run_segmentation_ppn(data, seg_label, clust_label) # Run the fragmentation stage if self.calibration_stage == 'fragmentation': - data = self.run_calibration(data, sources, energy_label, run_info) + data = self.run_calibration(data, sources, energy_label, meta, run_info) self.run_fragmentation(data, clust_label) # Run the particle aggregation if self.calibration_stage == 'particle_aggregation': - data = self.run_calibration(data, sources, energy_label, run_info) + data = self.run_calibration(data, sources, energy_label, meta, run_info) self.run_part_aggregation(data, clust_label, coord_label) # Run the interaction aggregation if self.calibration_stage == 'inter_aggregation': - data = self.run_calibration(data, sources, energy_label, run_info) + data = self.run_calibration(data, sources, energy_label, meta, run_info) self.run_inter_aggregation(data, clust_label, coord_label) # Run an independant particle classification stage if self.calibration_stage == 'particle_classification': - data = self.run_calibration(data, sources, energy_label, run_info) + data = self.run_calibration(data, sources, energy_label, meta, run_info) # TODO # Run the interaction classification @@ -720,7 +722,7 @@ def run_inter_aggregation(self, data, clust_label=None, coord_label=None): if self.inter_aggregation is not None: self.result['interaction_clusts'] = interactions - def run_calibration(self, data, sources=None, energy_label=None, + def run_calibration(self, data, sources=None, energy_label=None, meta=None, run_info=None): """Run the calibration algorithm. @@ -737,6 +739,8 @@ def run_calibration(self, data, sources=None, energy_label=None, energy_label : TensorBatch, optional (N, 1 + D + 1) Tensor of true energy deposition values - 1 is the energy deposition value in each voxel + meta : Meta, optional + Image metadata information run_info : List[RunInfo], optional Object containing information about the run, subrun and event @@ -746,39 +750,34 @@ def run_calibration(self, data, sources=None, energy_label=None, (N, 1 + D + N_f) tensor of calibrated voxel/value pairs """ if self.calibration == 'apply': - # Apply calibration routines + # Check that the metadata is provided + assert meta is not None, ( + "Must provide the metadata to convert pixel coordinates " + "to detector coordinates and apply calibrations.") + + # Loop over entries in the batch (might have different meta/run IDs) data_np = data.to_numpy().tensor sources = sources.to_numpy().tensor if sources is not None else None - if run_info is None: - # Fetch points for the whole batch - voxels = data_np[:, COORD_COLS] - values = data_np[:, VALUE_COL] + rep = data.batch_size//len(meta) + for b in range(data.batch_size): + # Fetch necessary information for this batch entry + lower, upper = data.edges[b], data.edges[b+1] + data_b = data_np[lower:upper] + voxels_b = data_b[:, COORD_COLS] + values_b = data_b[:, VALUE_COL] + sources_b = sources[lower:upper] if sources is not None else None + + # Fetch meta/run information for this batch entry + meta_b = meta[b//rep] + run_id = run_info[b//rep].run if run_info is not None else None # Calibrate voxel values - values = self.calibrator(voxels, values, sources) - data.tensor[:, VALUE_COL] = torch.tensor( - values, dtype=data.dtype, device=data.device) + values_b = self.calibrator( + voxels_b, values_b, sources_b, run_id, + meta=meta_b, module_id=b%rep) - else: - # Loop over entries in the batch (might have different run IDs) - rep = data.batch_size//len(run_info) - for b in range(data.batch_size): - # Fetch points for this batch entry - lower, upper = data.edges[b], data.edges[b+1] - data_b = data_np[lower:upper] - voxels_b = data_b[:, COORD_COLS] - values_b = data_b[:, VALUE_COL] - - # Fetch run ID for this batch entry - run_id = run_info[b//rep].run - - # Calibrate voxel values - sources_b = sources[lower:upper] if sources is not None else None - values_b = self.calibrator( - voxels_b, values_b, sources_b, run_id) - - data.tensor[lower:upper, VALUE_COL] = torch.tensor( - values_b, dtype=data.dtype, device=data.device) + data.tensor[lower:upper, VALUE_COL] = torch.tensor( + values_b, dtype=data.dtype, device=data.device) self.result['data_adapt'] = data diff --git a/spine/utils/calib/manager.py b/spine/utils/calib/manager.py index 6bb3f015..c6cf6a01 100644 --- a/spine/utils/calib/manager.py +++ b/spine/utils/calib/manager.py @@ -49,7 +49,8 @@ def __init__(self, geometry, gain_applied=False, **cfg): # Append self.modules[key] = calibrator_factory(key, value) - def __call__(self, points, values, sources=None, run_id=None, track=None): + def __call__(self, points, values, sources=None, run_id=None, track=None, + meta=None, module_id=None): """Main calibration driver. Parameters @@ -67,12 +68,24 @@ def __call__(self, points, values, sources=None, run_id=None, track=None): track : bool, defaut `False` Whether the object is a track or not. If it is, the track gets segmented to evaluate local dE/dx and track angle. + meta : Meta, optional + If provided, use to convert the coordinates from image pixel + coordinates to detector coordinates + module_id : int, optional + If provided, shift points to the requested module assuming that the + points currently live in module ID 0 Returns ------- np.ndarray (N) array of calibrated depositions in ADC, e- or MeV """ + # If necessary, convert all points to detector coordinates + if meta is not None: + points = meta.to_cm(points, center=True) + if module_id is not None: + points = self.geo.translate(points, 0, module_id) + # Create a mask for each of the TPC volume in the detector if sources is not None: tpc_indexes = []