Skip to content

Commit

Permalink
Feature: Sampling from noise model, noise model refactoring (#340)
Browse files Browse the repository at this point in the history
### Description

- **What**: 
Added new method `sample_observation_from_signal` to the
GaussianMixtureNoiseModel class, refactored the
GaussianMixtureNoiseModel initialization and tensor device allocation,
added noise model plotting function
- **Why**: 
The `sample_observation_from_signal` function is necessary for the
AI4Life project that I'm working on. The new tensor device allocation
logic for GaussianMixtureNoiseModel was necessary because of the weights
copying bug in the training.
- **How**: 
1. The `sample_observation_from_signal` function of
GaussianMixtureNoiseModel class allows the creation of a noisy image
based on an input clean "signal" image using the learned noise model
statistics. It predicts means, standard deviations, and the probability
of gaussian components for every pixel in "signal", then selects the
gaussian component with the predicted probability and samples from the
selected gaussian with the predicted mean and standard deviations.
2. The GaussianMixtureNoiseModel's parameters are now moved to the GPU
before training and back to the CPU after training is finished. The
`weights` parameter's `requires_grad` is set to `True` before training
begins and it is detached after training is completed.
3. Added a slightly refactored version of
`plot_probability_distribution` function into the utils folder, added
matplotlib dependency

### Changes Made

- **Added**: 
1. `sample_observation_from_signal` function of
GaussianMixtureNoiseModel
2. Added `_set_model_mode` functionality to move parameters between GPU
and CPU for training
3. Added a new file `utils/plotting.py` with
`plot_noise_model_probability_distribution` function for noise model
visualization
  4. Addded `matplotlib` dependency
- **Modified**: Describe existing features or files modified.
1. Refactored GaussianMixtureNoiseModel class initialization 
2. Added type annotations and missing docstrings, slightly changed
namings to python standards
3. Refactored `create_histogram` function

### Breaking changes

1. Added the `matplotlib` dependency to the package.
2. `create_histogram` function's output is now slightly numerically
different than before due to a logical error in the original code

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)
  • Loading branch information
veegalinova authored Jan 16, 2025
1 parent 2bdd021 commit 8f03c88
Show file tree
Hide file tree
Showing 5 changed files with 446 additions and 226 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ dependencies = [
'scikit-image<=0.25.0',
'zarr<3.0.0',
'pillow<=11.1.0',
'matplotlib<=3.9.0'
]

[project.optional-dependencies]
Expand Down
2 changes: 2 additions & 0 deletions src/careamics/models/lvae/likelihoods.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,8 @@ def _set_params_to_same_device_as(
if self.data_mean.device != correct_device_tensor.device:
self.data_mean = self.data_mean.to(correct_device_tensor.device)
self.data_std = self.data_std.to(correct_device_tensor.device)
if correct_device_tensor.device != self.noiseModel.device:
self.noiseModel.to_device(correct_device_tensor.device)

def get_mean_lv(self, x: torch.Tensor) -> tuple[torch.Tensor, None]:
return x, None
Expand Down
Loading

0 comments on commit 8f03c88

Please sign in to comment.