Skip to content

Commit

Permalink
style(pre-commit.ci): auto fixes [...]
Browse files Browse the repository at this point in the history
  • Loading branch information
pre-commit-ci[bot] committed Nov 27, 2024
1 parent 0b1960e commit 2a70d98
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 25 deletions.
54 changes: 30 additions & 24 deletions src/careamics/lvae_training/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,35 +68,41 @@ def compute_stats(
stats_dict = {}
for ch_idx in range(pred.shape[-1]):
stats_dict[ch_idx] = {
'bin_count': [],
'rmv': [],
'rmse': [],
'bin_boundaries': None,
'bin_matrix': [],
'rmse_err': []
"bin_count": [],
"rmv": [],
"rmse": [],
"bin_boundaries": None,
"bin_matrix": [],
"rmse_err": [],
}
pred_ch = pred[..., ch_idx]
std_ch = pred_std[..., ch_idx]
target_ch = target[..., ch_idx]
boundaries = self.compute_bin_boundaries(std_ch)
stats_dict[ch_idx]['bin_boundaries'] = boundaries
stats_dict[ch_idx]["bin_boundaries"] = boundaries
bin_matrix = np.digitize(std_ch.reshape(-1), boundaries)
bin_matrix = bin_matrix.reshape(std_ch.shape)
stats_dict[ch_idx]['bin_matrix'] = bin_matrix
error = (pred_ch - target_ch)**2
for bin_idx in range(1, 1+self._bins):
stats_dict[ch_idx]["bin_matrix"] = bin_matrix
error = (pred_ch - target_ch) ** 2
for bin_idx in range(1, 1 + self._bins):
bin_mask = bin_matrix == bin_idx
bin_error = error[bin_mask]
bin_size = np.sum(bin_mask)
bin_error = np.sqrt(np.sum(bin_error) / bin_size) if bin_size > 0 else None
stderr = np.std(error[bin_mask]) / np.sqrt(bin_size) if bin_size > 0 else None
bin_error = (
np.sqrt(np.sum(bin_error) / bin_size) if bin_size > 0 else None
)
stderr = (
np.std(error[bin_mask]) / np.sqrt(bin_size)
if bin_size > 0
else None
)
rmse_stderr = np.sqrt(stderr) if stderr is not None else None

bin_var = np.mean((std_ch[bin_mask]**2))
stats_dict[ch_idx]['rmse'].append(bin_error)
stats_dict[ch_idx]['rmse_err'].append(rmse_stderr)
stats_dict[ch_idx]['rmv'].append(np.sqrt(bin_var))
stats_dict[ch_idx]['bin_count'].append(bin_size)
bin_var = np.mean((std_ch[bin_mask] ** 2))
stats_dict[ch_idx]["rmse"].append(bin_error)
stats_dict[ch_idx]["rmse_err"].append(rmse_stderr)
stats_dict[ch_idx]["rmv"].append(np.sqrt(bin_var))
stats_dict[ch_idx]["bin_count"].append(bin_size)
return stats_dict


Expand All @@ -109,7 +115,7 @@ def get_calibrated_factor_for_stdev(
num_bins: int = 30,
) -> dict[str, float]:
"""Calibrate the uncertainty by multiplying the predicted std with a scalar.
Parameters
----------
pred : Union[np.ndarray, torch.Tensor]
Expand All @@ -124,7 +130,7 @@ def get_calibrated_factor_for_stdev(
End quantile, by default 0.99999.
num_bins : int, optional
Number of bins to use for calibration, by default 30.
Returns
-------
dict[str, float]
Expand All @@ -134,16 +140,16 @@ def get_calibrated_factor_for_stdev(
stats_dict = calib.compute_stats(pred, pred_std, target)
outputs = {}
for ch_idx in stats_dict.keys():
y = stats_dict[ch_idx]['rmse']
x = stats_dict[ch_idx]['rmv']
count = stats_dict[ch_idx]['bin_count']
y = stats_dict[ch_idx]["rmse"]
x = stats_dict[ch_idx]["rmv"]
count = stats_dict[ch_idx]["bin_count"]

first_idx = get_first_index(count, q_s)
last_idx = get_last_index(count, q_e)
x = x[first_idx:-last_idx]
y = y[first_idx:-last_idx]
slope, intercept, *_ = stats.linregress(x,y)
output = {'scalar': slope, 'offset': intercept}
slope, intercept, *_ = stats.linregress(x, y)
output = {"scalar": slope, "offset": intercept}
outputs[ch_idx] = output
return outputs

Expand Down
3 changes: 2 additions & 1 deletion src/careamics/lvae_training/eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
- quantify the performance of the model
- create plots to visualize the results.
"""

import os
from typing import List, Literal, Union

Expand Down Expand Up @@ -858,4 +859,4 @@ def stitch_predictions_new(predictions, dset):
else:
raise ValueError(f"Unsupported shape {output.shape}")

return output
return output

0 comments on commit 2a70d98

Please sign in to comment.