Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feature: Sampling from noise model, noise model refactoring (#340)
### Description - **What**: Added new method `sample_observation_from_signal` to the GaussianMixtureNoiseModel class, refactored the GaussianMixtureNoiseModel initialization and tensor device allocation, added noise model plotting function - **Why**: The `sample_observation_from_signal` function is necessary for the AI4Life project that I'm working on. The new tensor device allocation logic for GaussianMixtureNoiseModel was necessary because of the weights copying bug in the training. - **How**: 1. The `sample_observation_from_signal` function of GaussianMixtureNoiseModel class allows the creation of a noisy image based on an input clean "signal" image using the learned noise model statistics. It predicts means, standard deviations, and the probability of gaussian components for every pixel in "signal", then selects the gaussian component with the predicted probability and samples from the selected gaussian with the predicted mean and standard deviations. 2. The GaussianMixtureNoiseModel's parameters are now moved to the GPU before training and back to the CPU after training is finished. The `weights` parameter's `requires_grad` is set to `True` before training begins and it is detached after training is completed. 3. Added a slightly refactored version of `plot_probability_distribution` function into the utils folder, added matplotlib dependency ### Changes Made - **Added**: 1. `sample_observation_from_signal` function of GaussianMixtureNoiseModel 2. Added `_set_model_mode` functionality to move parameters between GPU and CPU for training 3. Added a new file `utils/plotting.py` with `plot_noise_model_probability_distribution` function for noise model visualization 4. Addded `matplotlib` dependency - **Modified**: Describe existing features or files modified. 1. Refactored GaussianMixtureNoiseModel class initialization 2. Added type annotations and missing docstrings, slightly changed namings to python standards 3. Refactored `create_histogram` function ### Breaking changes 1. Added the `matplotlib` dependency to the package. 2. `create_histogram` function's output is now slightly numerically different than before due to a logical error in the original code --- **Please ensure your PR meets the following requirements:** - [x] Code builds and passes tests locally, including doctests - [x] New tests have been added (for bug fixes/features) - [x] Pre-commit passes - [ ] PR to the documentation exists (for bug fixes / features)
- Loading branch information