Skip to content

Commit

Permalink
Merge branch 'main' into pre-commit-ci-update-config
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps authored Feb 3, 2025
2 parents 35630fc + adacfe6 commit 9339825
Show file tree
Hide file tree
Showing 8 changed files with 231 additions and 190 deletions.
121 changes: 64 additions & 57 deletions src/careamics/lvae_training/calibration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Union, Optional

import numpy as np
import torch
Expand Down Expand Up @@ -34,9 +34,6 @@ def __init__(self, num_bins: int = 15):
self._bins = num_bins
self._bin_boundaries = None

def logvar_to_std(self, logvar: np.ndarray) -> np.ndarray:
    """Convert a log-variance array into a standard deviation array.

    Parameters
    ----------
    logvar : np.ndarray
        Log-variance values.

    Returns
    -------
    np.ndarray
        Element-wise ``exp(logvar / 2)``.
    """
    # std = sqrt(var) = sqrt(exp(logvar)) = exp(logvar / 2)
    half_logvar = logvar / 2
    return np.exp(half_logvar)

def compute_bin_boundaries(self, predict_std: np.ndarray) -> np.ndarray:
"""Compute the bin boundaries for `num_bins` bins and predicted std values."""
min_std = np.min(predict_std)
Expand Down Expand Up @@ -104,81 +101,91 @@ def compute_stats(
)
rmse_stderr = np.sqrt(stderr) if stderr is not None else None

bin_var = np.mean((std_ch[bin_mask] ** 2))
bin_var = np.mean(std_ch[bin_mask] ** 2)
stats_dict[ch_idx]["rmse"].append(bin_error)
stats_dict[ch_idx]["rmse_err"].append(rmse_stderr)
stats_dict[ch_idx]["rmv"].append(np.sqrt(bin_var))
stats_dict[ch_idx]["bin_count"].append(bin_size)
self.stats_dict = stats_dict
return stats_dict

def get_calibrated_factor_for_stdev(
self,
pred: Optional[np.ndarray] = None,
pred_std: Optional[np.ndarray] = None,
target: Optional[np.ndarray] = None,
q_s: float = 0.00001,
q_e: float = 0.99999,
) -> dict[str, float]:
"""Calibrate the uncertainty by multiplying the predicted std with a scalar.
def get_calibrated_factor_for_stdev(
pred: Union[np.ndarray, torch.Tensor],
pred_std: Union[np.ndarray, torch.Tensor],
target: Union[np.ndarray, torch.Tensor],
q_s: float = 0.00001,
q_e: float = 0.99999,
num_bins: int = 30,
) -> dict[str, float]:
"""Calibrate the uncertainty by multiplying the predicted std with a scalar.
Parameters
----------
pred : Union[np.ndarray, torch.Tensor]
Predicted image, shape (n, h, w, c).
pred_std : Union[np.ndarray, torch.Tensor]
Predicted std, shape (n, h, w, c).
target : Union[np.ndarray, torch.Tensor]
Target image, shape (n, h, w, c).
q_s : float, optional
Start quantile, by default 0.00001.
q_e : float, optional
End quantile, by default 0.99999.
num_bins : int, optional
Number of bins to use for calibration, by default 30.
Returns
-------
dict[str, float]
Calibrated factor for each channel (slope + intercept).
"""
calib = Calibration(num_bins=num_bins)
stats_dict = calib.compute_stats(pred, pred_std, target)
outputs = {}
for ch_idx in stats_dict.keys():
y = stats_dict[ch_idx]["rmse"]
x = stats_dict[ch_idx]["rmv"]
count = stats_dict[ch_idx]["bin_count"]

first_idx = get_first_index(count, q_s)
last_idx = get_last_index(count, q_e)
x = x[first_idx:-last_idx]
y = y[first_idx:-last_idx]
slope, intercept, *_ = stats.linregress(x, y)
output = {"scalar": slope, "offset": intercept}
outputs[ch_idx] = output
return outputs
Parameters
----------
stats_dict : dict[int, dict[str, Union[np.ndarray, list]]]
Dictionary containing the stats for each channel.
q_s : float, optional
Start quantile, by default 0.00001.
q_e : float, optional
End quantile, by default 0.99999.
Returns
-------
dict[str, float]
Calibrated factor for each channel (slope + intercept).
"""
if not hasattr(self, "stats_dict"):
print("No stats found. Computing stats...")
if any(v is None for v in [pred, pred_std, target]):
raise ValueError("pred, pred_std, and target must be provided.")
self.stats_dict = self.compute_stats(
pred=pred, pred_std=pred_std, target=target
)
outputs = {}
for ch_idx in self.stats_dict.keys():
y = self.stats_dict[ch_idx]["rmse"]
x = self.stats_dict[ch_idx]["rmv"]
count = self.stats_dict[ch_idx]["bin_count"]

first_idx = get_first_index(count, q_s)
last_idx = get_last_index(count, q_e)
x = x[first_idx:-last_idx]
y = y[first_idx:-last_idx]
slope, intercept, *_ = stats.linregress(x, y)
output = {"scalar": slope, "offset": intercept}
outputs[ch_idx] = output
factors = self.get_factors_array(factors_dict=outputs)
return outputs, factors

def get_factors_array(
    self, factors_dict: Union[dict[int, dict], list[dict]]
) -> dict[str, np.ndarray]:
    """Get the calibration factors as numpy arrays.

    Parameters
    ----------
    factors_dict : Union[dict[int, dict], list[dict]]
        Per-channel calibration factors, indexable by the consecutive integers
        ``0 .. len(factors_dict) - 1``. Each entry must contain a "scalar" key
        and may contain an "offset" key (0.0 is used when it is missing).

    Returns
    -------
    dict[str, np.ndarray]
        Dictionary with keys "scalar" and "offset", each an array of shape
        (1, 1, 1, n_channels).
    """
    # Index by position rather than iterating the container, so that both a
    # list and a dict keyed by channel index yield channels in order 0..n-1.
    per_channel = [factors_dict[ch_idx] for ch_idx in range(len(factors_dict))]

    # Trailing axis holds the channels; the leading singleton axes presumably
    # allow broadcasting against (n, h, w, c) arrays -- confirm with callers.
    calib_scalar = np.array(
        [factors["scalar"] for factors in per_channel]
    ).reshape(1, 1, 1, -1)
    calib_offset = np.array(
        [factors.get("offset", 0.0) for factors in per_channel]
    ).reshape(1, 1, 1, -1)
    return {"scalar": calib_scalar, "offset": calib_offset}


def plot_calibration(ax, calibration_stats):
first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.001)
last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.999)
first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.0001)
last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.9999)
ax.plot(
calibration_stats[0]["rmv"][first_idx:-last_idx],
calibration_stats[0]["rmse"][first_idx:-last_idx],
"o",
label=r"$\hat{C}_0$: Ch1",
)

first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.001)
last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.999)
first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.0001)
last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.9999)
ax.plot(
calibration_stats[1]["rmv"][first_idx:-last_idx],
calibration_stats[1]["rmse"][first_idx:-last_idx],
"o",
label=r"$\hat{C}_1: : Ch2$",
label=r"$\hat{C}_1$: Ch2",
)

# TODO add multichannel
ax.set_xlabel("RMV")
ax.set_ylabel("RMSE")
ax.legend()
3 changes: 2 additions & 1 deletion src/careamics/lvae_training/dataset/lc_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ def reduce_data(
]

self.N = len(t_list)
self.set_img_sz(self._img_sz, self._grid_sz)
# TODO where tf is self._img_sz defined?
self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz)
print(
f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
)
Expand Down
4 changes: 2 additions & 2 deletions src/careamics/lvae_training/dataset/multich_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,8 @@ def reduce_data(
self._noise_data = self._noise_data[
t_list, h_start:h_end, w_start:w_end, :
].copy()

self.set_img_sz(self._img_sz, self._grid_sz)
# TODO where tf is self._img_sz defined?
self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz)
print(
f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
)
Expand Down
2 changes: 1 addition & 1 deletion src/careamics/lvae_training/dataset/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

class DataType(Enum):
Elisa3DData = 0
NicolaData = 1
HTLIF24Data = 1
Pavia3SeqData = 2
TavernaSox2GolgiV2 = 3
Dao3ChannelWithInput = 4
Expand Down
Loading

0 comments on commit 9339825

Please sign in to comment.