Skip to content

Commit

Permalink
MicroSplit submission (#386)
Browse files Browse the repository at this point in the history
## Description

<!-- This section provides the necessary background and information for
reviewers to
understand the code and have the correct mindset when examining changes.
-->

> [!NOTE]  
> **tldr**: <!-- Write a one sentence summary. --> Contains all the
stuff we need for MicroSplit submission notebooks

### Background - why do we need this PR?

<!-- What problem are you solving? Describe in a few sentences the state
before
this PR. Use code examples if useful. --> 

A bunch of functionality required for notebooks in the MicroSplit
reproducibility repo. Unfortunately, it also contains a lot of legacy
code, or code in desperate need of refactoring.

### Overview - what changed?

<!-- What aspects and mechanisms of the code base changed? Describe only
the general
idea and overarching features. -->

### Implementation - how did you implement the changes?

<!-- How did you solve the issue technically? Explain why you chose this
approach and
provide code examples if applicable (e.g. change in the API for users).
-->


## Changes Made

<!-- This section highlights the important features and files that
reviewers should
pay attention to when reviewing. Only list important features or files,
this is useful for
reviewers to correctly assess how deeply the modifications impact the
code base.

For instance:

### New features or files
- `NewClass` added to `new_file.py`
- `new_function` added to `existing_file.py`

...
-->

### New features or files

<!-- List new features or files added. -->
-

### Modified features or files

<!-- List important modified features or files. -->
-

### Removed features or files

<!-- List removed features or files. -->
-

## How has this been tested?

<!-- Describe the tests that you ran to verify your changes. This can be
a short
description of the tests added to the PR or code snippet to reproduce
the change
in behaviour. -->

So far it hasn't been. We need to run the notebooks to check whether
anything breaks.

## Related Issues

<!-- Link to any related issues or discussions. Use keywords like
"Fixes", "Resolves",
or "Closes" to link to issues automatically. -->

- Resolves #

## Breaking changes

<!-- Describe any breaking changes introduced by this PR. -->

## Additional Notes and Examples

<!-- Provide any additional information that will help reviewers
understand the
changes. This can be links to documentations, forum posts, past
discussions etc. -->

---

**Please ensure your PR meets the following requirements:**

- [ ] Code builds and passes tests locally, including doctests
- [ ] New tests have been added (for bug fixes/features)
- [ ] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)
  • Loading branch information
CatEek authored Feb 3, 2025
1 parent 426ba1d commit adacfe6
Show file tree
Hide file tree
Showing 8 changed files with 231 additions and 190 deletions.
121 changes: 64 additions & 57 deletions src/careamics/lvae_training/calibration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Union, Optional

import numpy as np
import torch
Expand Down Expand Up @@ -34,9 +34,6 @@ def __init__(self, num_bins: int = 15):
self._bins = num_bins
self._bin_boundaries = None

def logvar_to_std(self, logvar: np.ndarray) -> np.ndarray:
return np.exp(logvar / 2)

def compute_bin_boundaries(self, predict_std: np.ndarray) -> np.ndarray:
"""Compute the bin boundaries for `num_bins` bins and predicted std values."""
min_std = np.min(predict_std)
Expand Down Expand Up @@ -104,81 +101,91 @@ def compute_stats(
)
rmse_stderr = np.sqrt(stderr) if stderr is not None else None

bin_var = np.mean((std_ch[bin_mask] ** 2))
bin_var = np.mean(std_ch[bin_mask] ** 2)
stats_dict[ch_idx]["rmse"].append(bin_error)
stats_dict[ch_idx]["rmse_err"].append(rmse_stderr)
stats_dict[ch_idx]["rmv"].append(np.sqrt(bin_var))
stats_dict[ch_idx]["bin_count"].append(bin_size)
self.stats_dict = stats_dict
return stats_dict

def get_calibrated_factor_for_stdev(
self,
pred: Optional[np.ndarray] = None,
pred_std: Optional[np.ndarray] = None,
target: Optional[np.ndarray] = None,
q_s: float = 0.00001,
q_e: float = 0.99999,
) -> dict[str, float]:
"""Calibrate the uncertainty by multiplying the predicted std with a scalar.
def get_calibrated_factor_for_stdev(
pred: Union[np.ndarray, torch.Tensor],
pred_std: Union[np.ndarray, torch.Tensor],
target: Union[np.ndarray, torch.Tensor],
q_s: float = 0.00001,
q_e: float = 0.99999,
num_bins: int = 30,
) -> dict[str, float]:
"""Calibrate the uncertainty by multiplying the predicted std with a scalar.
Parameters
----------
pred : Union[np.ndarray, torch.Tensor]
Predicted image, shape (n, h, w, c).
pred_std : Union[np.ndarray, torch.Tensor]
Predicted std, shape (n, h, w, c).
target : Union[np.ndarray, torch.Tensor]
Target image, shape (n, h, w, c).
q_s : float, optional
Start quantile, by default 0.00001.
q_e : float, optional
End quantile, by default 0.99999.
num_bins : int, optional
Number of bins to use for calibration, by default 30.
Returns
-------
dict[str, float]
Calibrated factor for each channel (slope + intercept).
"""
calib = Calibration(num_bins=num_bins)
stats_dict = calib.compute_stats(pred, pred_std, target)
outputs = {}
for ch_idx in stats_dict.keys():
y = stats_dict[ch_idx]["rmse"]
x = stats_dict[ch_idx]["rmv"]
count = stats_dict[ch_idx]["bin_count"]

first_idx = get_first_index(count, q_s)
last_idx = get_last_index(count, q_e)
x = x[first_idx:-last_idx]
y = y[first_idx:-last_idx]
slope, intercept, *_ = stats.linregress(x, y)
output = {"scalar": slope, "offset": intercept}
outputs[ch_idx] = output
return outputs
Parameters
----------
stats_dict : dict[int, dict[str, Union[np.ndarray, list]]]
Dictionary containing the stats for each channel.
q_s : float, optional
Start quantile, by default 0.00001.
q_e : float, optional
End quantile, by default 0.99999.
Returns
-------
dict[str, float]
Calibrated factor for each channel (slope + intercept).
"""
if not hasattr(self, "stats_dict"):
print("No stats found. Computing stats...")
if any(v is None for v in [pred, pred_std, target]):
raise ValueError("pred, pred_std, and target must be provided.")
self.stats_dict = self.compute_stats(
pred=pred, pred_std=pred_std, target=target
)
outputs = {}
for ch_idx in self.stats_dict.keys():
y = self.stats_dict[ch_idx]["rmse"]
x = self.stats_dict[ch_idx]["rmv"]
count = self.stats_dict[ch_idx]["bin_count"]

first_idx = get_first_index(count, q_s)
last_idx = get_last_index(count, q_e)
x = x[first_idx:-last_idx]
y = y[first_idx:-last_idx]
slope, intercept, *_ = stats.linregress(x, y)
output = {"scalar": slope, "offset": intercept}
outputs[ch_idx] = output
factors = self.get_factors_array(factors_dict=outputs)
return outputs, factors

def get_factors_array(
    self, factors_dict: dict[int, dict[str, float]]
) -> dict[str, np.ndarray]:
    """Return the per-channel calibration factors as broadcastable arrays.

    Parameters
    ----------
    factors_dict : dict[int, dict[str, float]]
        Per-channel calibration factors keyed by consecutive channel
        indices 0..C-1. Each entry holds a "scalar" (slope) and,
        optionally, an "offset" (intercept; defaults to 0.0 when absent).

    Returns
    -------
    dict[str, np.ndarray]
        "scalar" and "offset" arrays of shape (1, 1, 1, C), ready to be
        broadcast against (n, h, w, c)-shaped predictions.
    """
    # NOTE(review): the original annotation said `list[dict]`, but the
    # caller passes a dict keyed by channel index — annotated accordingly.
    # Both are indexed identically via factors_dict[i], so behavior is
    # unchanged.
    num_channels = len(factors_dict)
    calib_scalar = np.array(
        [factors_dict[i]["scalar"] for i in range(num_channels)]
    ).reshape(1, 1, 1, -1)
    # Missing offsets default to 0.0 (pure scaling, no shift).
    calib_offset = np.array(
        [factors_dict[i].get("offset", 0.0) for i in range(num_channels)]
    ).reshape(1, 1, 1, -1)
    return {"scalar": calib_scalar, "offset": calib_offset}


def plot_calibration(ax, calibration_stats):
first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.001)
last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.999)
first_idx = get_first_index(calibration_stats[0]["bin_count"], 0.0001)
last_idx = get_last_index(calibration_stats[0]["bin_count"], 0.9999)
ax.plot(
calibration_stats[0]["rmv"][first_idx:-last_idx],
calibration_stats[0]["rmse"][first_idx:-last_idx],
"o",
label=r"$\hat{C}_0$: Ch1",
)

first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.001)
last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.999)
first_idx = get_first_index(calibration_stats[1]["bin_count"], 0.0001)
last_idx = get_last_index(calibration_stats[1]["bin_count"], 0.9999)
ax.plot(
calibration_stats[1]["rmv"][first_idx:-last_idx],
calibration_stats[1]["rmse"][first_idx:-last_idx],
"o",
label=r"$\hat{C}_1: : Ch2$",
label=r"$\hat{C}_1$: Ch2",
)

# TODO add multichannel
ax.set_xlabel("RMV")
ax.set_ylabel("RMSE")
ax.legend()
3 changes: 2 additions & 1 deletion src/careamics/lvae_training/dataset/lc_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ def reduce_data(
]

self.N = len(t_list)
self.set_img_sz(self._img_sz, self._grid_sz)
# TODO where tf is self._img_sz defined?
self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz)
print(
f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
)
Expand Down
4 changes: 2 additions & 2 deletions src/careamics/lvae_training/dataset/multich_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,8 @@ def reduce_data(
self._noise_data = self._noise_data[
t_list, h_start:h_end, w_start:w_end, :
].copy()

self.set_img_sz(self._img_sz, self._grid_sz)
# TODO where tf is self._img_sz defined?
self.set_img_sz([self._img_sz, self._img_sz], self._grid_sz)
print(
f"[{self.__class__.__name__}] Data reduced. New data shape: {self._data.shape}"
)
Expand Down
2 changes: 1 addition & 1 deletion src/careamics/lvae_training/dataset/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

class DataType(Enum):
Elisa3DData = 0
NicolaData = 1
HTLIF24Data = 1
Pavia3SeqData = 2
TavernaSox2GolgiV2 = 3
Dao3ChannelWithInput = 4
Expand Down
Loading

0 comments on commit adacfe6

Please sign in to comment.