Skip to content

Commit

Permalink
various tests fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
CatEek committed Jan 21, 2025
1 parent 93a0eff commit 425d0bf
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 31 deletions.
1 change: 0 additions & 1 deletion src/careamics/careamist.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ def __init__(
self.cfg = load_configuration(source)

# instantiate model
# TODO call model factory here
if isinstance(self.cfg.algorithm_config, UNetBasedAlgorithm):
self.model = FCNModule(
algorithm_config=self.cfg.algorithm_config,
Expand Down
67 changes: 64 additions & 3 deletions src/careamics/config/algorithms/n2v_algorithm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@

from typing import Annotated, Literal

from pydantic import ConfigDict, AfterValidator
from pydantic import AfterValidator, ConfigDict

from careamics.config.architectures import UNetModel
from careamics.config.transformations import N2VManipulateModel
from careamics.config.validators import (
model_matching_in_out_channels,
model_without_final_activation,
)

from careamics.config.transformations import N2VManipulateModel

from .unet_algorithm_model import UNetBasedAlgorithm


Expand All @@ -33,3 +32,65 @@ class N2VAlgorithm(UNetBasedAlgorithm):
AfterValidator(model_matching_in_out_channels),
AfterValidator(model_without_final_activation),
]

def get_masking_strategy(self) -> str:
"""Get the masking strategy for N2V."""
return self.n2v_masking.strategy

def set_masking_strategy(self, strategy: Literal["uniform", "median"]) -> None:
"""
Set masking strategy.
Parameters
----------
strategy : "uniform" or "median"
Strategy to use for N2V2.
Raises
------
ValueError
If the N2V pixel manipulate transform is not found in the transforms.
"""
self.model.n2v_masking.strategy = strategy


def set_n2v2(self, use_n2v2: bool) -> None:
"""
Set the configuration to use N2V2 or the vanilla Noise2Void.
Parameters
----------
use_n2v2 : bool
Whether to use N2V2.
"""
if use_n2v2:
self.set_masking_strategy("median")
else:
self.set_masking_strategy("uniform")

def is_using_struct_n2v(self) -> bool:
"""Check if the configuration is using structN2V."""
return self.n2v_masking.struct_mask_axis != "none" # TODO change!

def set_structN2V_mask(
self, mask_axis: Literal["horizontal", "vertical", "none"], mask_span: int
) -> None:
"""
Set structN2V mask parameters.
Setting `mask_axis` to `none` will disable structN2V.
Parameters
----------
mask_axis : Literal["horizontal", "vertical", "none"]
Axis along which to apply the mask. `none` will disable structN2V.
mask_span : int
Total span of the mask in pixels.
Raises
------
ValueError
If the N2V pixel manipulate transform is not found in the transforms.
"""
self.n2v_masking.struct_mask_axis = mask_axis
self.n2v_masking.struct_mask_span = mask_span
20 changes: 16 additions & 4 deletions src/careamics/lightning/lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class FCNModule(L.LightningModule):
Learning rate scheduler name.
"""

def __init__(self, algorithm_config: UNetBasedAlgorithm) -> None:
def __init__(self, algorithm_config: Union[UNetBasedAlgorithm, dict]) -> None:
"""Lightning module for CAREamics.
This class encapsulates the a PyTorch model along with the training, validation,
Expand All @@ -78,11 +78,15 @@ def __init__(self, algorithm_config: UNetBasedAlgorithm) -> None:
"""
super().__init__()

if isinstance(algorithm_config, dict):
algorithm_config = algorithm_factory(algorithm_config)

# create preprocessing, model and loss function
# TODO should we use compose here ?
# TODO should we use compose here ? should we even have this?
self.preprocess = preprocess_factory(
getattr(algorithm_config, "n2v_masking", [])
)
self.algorithm = algorithm_config.algorithm
self.model: nn.Module = model_factory(algorithm_config.model)
self.loss_func = loss_factory(algorithm_config.loss)

Expand Down Expand Up @@ -123,7 +127,11 @@ def training_step(self, batch: Tensor, batch_idx: Any) -> Any:
Loss value.
"""
x, *targets = batch
x_preprocessed, *aux = self.preprocess(x)
if self.algorithm == "n2v":
x_preprocessed, *aux = self.preprocess(x)
else:
x_preprocessed = x
aux = []
out = self.model(x_preprocessed)
loss = self.loss_func(out, *aux, *targets)
self.log(
Expand All @@ -142,7 +150,11 @@ def validation_step(self, batch: Tensor, batch_idx: Any) -> None:
Batch index.
"""
x, *targets = batch
x_preprocessed, *aux = self.preprocess(x)
if self.algorithm == "n2v":
x_preprocessed, *aux = self.preprocess(x)
else:
x_preprocessed = x
aux = []
out = self.model(x_preprocessed)
val_loss = self.loss_func(out, *aux, *targets)

Expand Down
19 changes: 5 additions & 14 deletions tests/config/test_configuration_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,23 +457,14 @@ def test_n2v_configuration_n2v2_structn2v():
struct_n2v_axis=struct_mask_axis,
struct_n2v_span=struct_n2v_span,
)
assert len(config.data_config.transforms) == 3
assert (
config.data_config.transforms[-1].name
== SupportedTransform.N2V_MANIPULATE.value
)
assert (
config.data_config.transforms[-1].name
== SupportedTransform.N2V_MANIPULATE.value
)
assert (
config.data_config.transforms[-1].strategy
config.algorithm_config.n2v_masking.strategy
== SupportedPixelManipulation.MEDIAN.value
)
assert config.data_config.transforms[-1].roi_size == roi_size
assert config.algorithm_config.n2v_masking.roi_size == roi_size
assert (
config.data_config.transforms[-1].masked_pixel_percentage
config.algorithm_config.n2v_masking.masked_pixel_percentage
== masked_pixel_percentage
)
assert config.data_config.transforms[-1].struct_mask_axis == struct_mask_axis
assert config.data_config.transforms[-1].struct_mask_span == struct_n2v_span
assert config.algorithm_config.n2v_masking.struct_mask_axis == struct_mask_axis
assert config.algorithm_config.n2v_masking.struct_mask_span == struct_n2v_span
8 changes: 7 additions & 1 deletion tests/transforms/test_manipulate_n2v.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from careamics.config.support import SupportedPixelManipulation
from careamics.transforms import N2VManipulate, N2VManipulateTorch
from careamics.config.transformations import N2VManipulateModel



@pytest.mark.parametrize(
Expand Down Expand Up @@ -39,8 +41,12 @@ def test_manipulate_n2v_torch(strategy):
# Create tensor, adding a channel to simulate a 2D image with channel first
array = torch.arange(16 * 16).reshape(1, 16, 16).float()

# create configuration
config = N2VManipulateModel(
roi_size=5, masked_pixel_percentage=5, strategy=strategy.value
)
# Create augmentation
aug = N2VManipulateTorch(roi_size=5, masked_pixel_percentage=5, strategy=strategy)
aug = N2VManipulateTorch(config)

# Apply augmentation
augmented = aug(array)
Expand Down
8 changes: 0 additions & 8 deletions tests/transforms/test_supported_transforms.py

This file was deleted.

0 comments on commit 425d0bf

Please sign in to comment.