Fixed full chain pre-calibration (was completely wrong) #65

Merged
71 changes: 35 additions & 36 deletions spine/model/full_chain.py
@@ -232,7 +232,7 @@ def __init__(self, chain, uresnet_deghost=None, uresnet=None,
             assert 'stage' in calibration, (
                     "If the calibration is to be applied, must provide the "
                     "`stage` to specify where to apply it.")
-            self.calibration_stage = calibration.pop(stage)
+            self.calibration_stage = calibration.pop('stage')
             self.calibrator = CalibrationManager(**calibration)
             calibration['stage'] = self.calibration_stage
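
The first hunk fixes a plain bug: `calibration.pop(stage)` references a bare name `stage` rather than the string key `'stage'`, so the pre-calibration setup could never read its configuration correctly. A minimal sketch of the intended pop-and-restore pattern, using a hypothetical calibration block (the keys other than 'stage' are illustrative, not taken from the repository):

    # Hypothetical calibration config; only the 'stage' key is required by the assert above
    calibration = {'stage': 'inter_aggregation', 'gain_applied': False}

    stage = calibration.pop('stage')    # pop the string key, not an undefined name
    manager_cfg = dict(calibration)     # what would be passed as CalibrationManager(**calibration)
    calibration['stage'] = stage        # restore the key so the caller's dict is left intact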

@@ -249,7 +249,7 @@ def modes(self):
         return dict(self._modes)
 
     def forward(self, data, sources=None, seg_label=None, clust_label=None,
-                coord_label=None, energy_label=None, run_info=None):
+                coord_label=None, energy_label=None, meta=None, run_info=None):
         """Run a batch of data through the full chain.
 
         Parameters
@@ -274,6 +274,8 @@ def forward(self, data, sources=None, seg_label=None, clust_label=None,
         energy_label : TensorBatch, optional
             (N, 1 + D + 1) Tensor of true energy deposition values
             - 1 is the energy deposition value in each voxel
+        meta : Meta, optional
+            Image metadata information
         run_info : List[RunInfo], optional
            Object containing information about the run, subrun and event
 
@@ -289,27 +291,27 @@ def forward(self, data, sources=None, seg_label=None, clust_label=None,
 
         # Run the semantic segmentation (and point proposal) stage
         if self.calibration_stage == 'segmentation':
-            data = self.run_calibration(data, sources, energy_label, run_info)
+            data = self.run_calibration(data, sources, energy_label, meta, run_info)
         clust_label = self.run_segmentation_ppn(data, seg_label, clust_label)
 
         # Run the fragmentation stage
         if self.calibration_stage == 'fragmentation':
-            data = self.run_calibration(data, sources, energy_label, run_info)
+            data = self.run_calibration(data, sources, energy_label, meta, run_info)
         self.run_fragmentation(data, clust_label)
 
         # Run the particle aggregation
         if self.calibration_stage == 'particle_aggregation':
-            data = self.run_calibration(data, sources, energy_label, run_info)
+            data = self.run_calibration(data, sources, energy_label, meta, run_info)
         self.run_part_aggregation(data, clust_label, coord_label)
 
         # Run the interaction aggregation
         if self.calibration_stage == 'inter_aggregation':
-            data = self.run_calibration(data, sources, energy_label, run_info)
+            data = self.run_calibration(data, sources, energy_label, meta, run_info)
         self.run_inter_aggregation(data, clust_label, coord_label)
 
         # Run an independant particle classification stage
         if self.calibration_stage == 'particle_classification':
-            data = self.run_calibration(data, sources, energy_label, run_info)
+            data = self.run_calibration(data, sources, energy_label, meta, run_info)
         # TODO
 
         # Run the interaction classification
@@ -720,7 +722,7 @@ def run_inter_aggregation(self, data, clust_label=None, coord_label=None):
         if self.inter_aggregation is not None:
             self.result['interaction_clusts'] = interactions
 
-    def run_calibration(self, data, sources=None, energy_label=None,
+    def run_calibration(self, data, sources=None, energy_label=None, meta=None,
                         run_info=None):
         """Run the calibration algorithm.
 
@@ -737,6 +739,8 @@ def run_calibration(self, data, sources=None, energy_label=None,
         energy_label : TensorBatch, optional
             (N, 1 + D + 1) Tensor of true energy deposition values
             - 1 is the energy deposition value in each voxel
+        meta : Meta, optional
+            Image metadata information
         run_info : List[RunInfo], optional
            Object containing information about the run, subrun and event
 
@@ -746,39 +750,34 @@ def run_calibration(self, data, sources=None, energy_label=None,
             (N, 1 + D + N_f) tensor of calibrated voxel/value pairs
         """
         if self.calibration == 'apply':
-            # Apply calibration routines
+            # Check that the metadata is provided
+            assert meta is not None, (
+                    "Must provide the metadata to convert pixel coordinates "
+                    "to detector coordinates and apply calibrations.")
+
+            # Loop over entries in the batch (might have different meta/run IDs)
             data_np = data.to_numpy().tensor
             sources = sources.to_numpy().tensor if sources is not None else None
-            if run_info is None:
-                # Fetch points for the whole batch
-                voxels = data_np[:, COORD_COLS]
-                values = data_np[:, VALUE_COL]
+            rep = data.batch_size//len(meta)
+            for b in range(data.batch_size):
+                # Fetch necessary information for this batch entry
+                lower, upper = data.edges[b], data.edges[b+1]
+                data_b = data_np[lower:upper]
+                voxels_b = data_b[:, COORD_COLS]
+                values_b = data_b[:, VALUE_COL]
+                sources_b = sources[lower:upper] if sources is not None else None
+
+                # Fetch meta/run information for this batch entry
+                meta_b = meta[b//rep]
+                run_id = run_info[b//rep].run if run_info is not None else None
 
                 # Calibrate voxel values
-                values = self.calibrator(voxels, values, sources)
-                data.tensor[:, VALUE_COL] = torch.tensor(
-                        values, dtype=data.dtype, device=data.device)
+                values_b = self.calibrator(
+                        voxels_b, values_b, sources_b, run_id,
+                        meta=meta_b, module_id=b%rep)
 
-            else:
-                # Loop over entries in the batch (might have different run IDs)
-                rep = data.batch_size//len(run_info)
-                for b in range(data.batch_size):
-                    # Fetch points for this batch entry
-                    lower, upper = data.edges[b], data.edges[b+1]
-                    data_b = data_np[lower:upper]
-                    voxels_b = data_b[:, COORD_COLS]
-                    values_b = data_b[:, VALUE_COL]
-
-                    # Fetch run ID for this batch entry
-                    run_id = run_info[b//rep].run
-
-                    # Calibrate voxel values
-                    sources_b = sources[lower:upper] if sources is not None else None
-                    values_b = self.calibrator(
-                            voxels_b, values_b, sources_b, run_id)
-
-                    data.tensor[lower:upper, VALUE_COL] = torch.tensor(
-                            values_b, dtype=data.dtype, device=data.device)
+                data.tensor[lower:upper, VALUE_COL] = torch.tensor(
+                        values_b, dtype=data.dtype, device=data.device)
 
         self.result['data_adapt'] = data
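
A note on the indexing convention in the rewritten run_calibration: it assumes the batch is laid out as `len(meta)` images, each contributing the same number of consecutive entries, `rep = data.batch_size//len(meta)` (for example, one entry per detector module). Entry `b` then picks its metadata and run number by integer division and its module by the remainder. A short sketch of that indexing with made-up sizes:

    # Assumed layout: 2 images x 2 modules per image = 4 batch entries
    batch_size, n_images = 4, 2
    rep = batch_size // n_images       # entries per image (here 2)
    for b in range(batch_size):
        image_id = b // rep            # 0, 0, 1, 1 -> index into meta/run_info
        module_id = b % rep            # 0, 1, 0, 1 -> module the entry is shifted to
        print(b, image_id, module_id)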

15 changes: 14 additions & 1 deletion spine/utils/calib/manager.py
@@ -49,7 +49,8 @@ def __init__(self, geometry, gain_applied=False, **cfg):
             # Append
             self.modules[key] = calibrator_factory(key, value)
 
-    def __call__(self, points, values, sources=None, run_id=None, track=None):
+    def __call__(self, points, values, sources=None, run_id=None, track=None,
+                 meta=None, module_id=None):
         """Main calibration driver.
 
         Parameters
@@ -67,12 +68,24 @@ def __call__(self, points, values, sources=None, run_id=None, track=None):
         track : bool, defaut `False`
             Whether the object is a track or not. If it is, the track gets
             segmented to evaluate local dE/dx and track angle.
+        meta : Meta, optional
+            If provided, use to convert the coordinates from image pixel
+            coordinates to detector coordinates
+        module_id : int, optional
+            If provided, shift points to the requested module assuming that the
+            points currently live in module ID 0
 
         Returns
         -------
         np.ndarray
             (N) array of calibrated depositions in ADC, e- or MeV
         """
+        # If necessary, convert all points to detector coordinates
+        if meta is not None:
+            points = meta.to_cm(points, center=True)
+        if module_id is not None:
+            points = self.geo.translate(points, 0, module_id)
+
         # Create a mask for each of the TPC volume in the detector
         if sources is not None:
             tpc_indexes = []
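
For readers unfamiliar with the Meta object: the added block assumes the incoming points are image (pixel) indices and converts them to detector coordinates before optionally shifting them to the requested module. A standalone illustration of what such a conversion typically amounts to, assuming a lower detector corner and a fixed voxel size per axis (a sketch, not the library's implementation of Meta.to_cm):

    import numpy as np

    def pixel_to_cm(points, lower, size, center=True):
        # Map pixel indices to detector coordinates; center=True lands on voxel centers
        points = np.asarray(points, dtype=float)
        offset = 0.5 if center else 0.0
        return lower + (points + offset) * size

    # Example: 0.3 cm voxels with the image origin at -200 cm on every axis
    coords = pixel_to_cm(np.array([[0, 10, 20]]), lower=-200.0, size=0.3)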