
Refactoring: lightning API package and smoke tests #161

Merged
jdeschamps merged 15 commits into main from jd/refac/lightning_api on Jul 4, 2024

Conversation

@jdeschamps (Member) commented Jun 21, 2024

Description

I moved the Lightning modules into their own package to clarify the imports for the different APIs, and added smoke tests for the Lightning API. As a result, I needed to implement a new function to retrieve the statistics from CAREamicsTrainData, which led me to refactor the datasets a bit and standardize how the statistics are recorded across the two datasets.

The smoke tests check that the Lightning API works for tiled and un-tiled data.

Finally, the Lightning module wrappers are gone; convenience functions now make the parameters explicit. This led to renaming the Lightning modules.

  • What:
    • Moved Lightning modules and callbacks to careamics.lightning
    • Added tests for the Lightning API (tiled and un-tiled prediction included)
    • Normalized the way the datasets keep the image statistics.
    • Fixed some errors in CAREamicsTrainData
  • Why: Since we are differentiating between CAREamist and Lightning APIs, they should be imported from different packages.
  • How:
    • Moved Lightning modules and callbacks to careamics.lightning
    • Added tests for the Lightning API.
    • Removed StatsOutput; both datasets now have a self.image_stats: Stats and a self.target_stats: Stats
    • Added get_data_statistics to the datasets and CAREamicsTrainData, with a corresponding test (see the sketch below)
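
For reference, a minimal sketch of what the standardized statistics handling could look like. The Stats field names and the bare Dataset class below are assumptions for illustration, not the exact CAREamics implementation:

```python
from dataclasses import dataclass, field
from typing import List, Tuple

@dataclass
class Stats:
    """Per-channel mean and standard deviation (illustrative field names)."""
    means: List[float] = field(default_factory=list)
    stds: List[float] = field(default_factory=list)

class Dataset:
    """Stands in for the in-memory and iterable datasets."""

    def __init__(self, image_stats: Stats, target_stats: Stats) -> None:
        self.image_stats = image_stats    # statistics of the input images
        self.target_stats = target_stats  # statistics of the targets, if any

    def get_data_statistics(self) -> Tuple[List[float], List[float]]:
        """Return the image means and stds, e.g. for the prediction module."""
        return self.image_stats.means, self.image_stats.stds
```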

Changes Made

  • Added: test_lightning_api.py.
  • Modified:
    • lightning_data_module.py
    • iterable_dataset.py
    • in_memory_dataset.py
    • patching.py
    • imports throughout careamist.py and other files

Additional Notes and Examples

To use the Lightning API, users now need the following imports:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from careamics.lightning import (
    CAREamicsModuleWrapper,
    TrainingDataWrapper,
    PredictDataWrapper,
)
from careamics.prediction_utils import convert_outputs  # if tiling is required
```

(see full example here)

Since the PredictDataWrapper requires passing the statistics, users can now simply call:

```python
means, stds = train_data_wrapper.get_data_statistics()
```
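
These statistics can then be handed to the prediction data module. A minimal sketch, reusing the names from the imports above; the PredictDataWrapper parameter names (pred_data, image_means, image_stds) are assumptions for illustration, not confirmed API:

```python
# Illustrative only: parameter names below are assumptions, not confirmed API.
means, stds = train_data_wrapper.get_data_statistics()

predict_data_wrapper = PredictDataWrapper(
    pred_data=test_array,    # hypothetical input array
    data_type="array",
    image_means=means,
    image_stds=stds,
)
```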

While building the notebook, I uncovered the following error: #162



@jdeschamps (Member, Author)

This PR might interact with #153

@jdeschamps jdeschamps requested review from CatEek and melisande-c and removed request for CatEek June 21, 2024 17:38
melisande-c added a commit that referenced this pull request Jun 24, 2024
### Description

Fixes a bug with stitching multiple prediction outputs

- **What**: There was a mistake when splitting tiles into groups for
their respective images. This only became a problem with more than one image.
- **How**: Very simple fix; the last tile of the previous group was being
added to the next one. Just needed a +1 to the index (see the sketch below).
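
A rough illustration of the off-by-one (variable names are hypothetical, not the actual `stitch_prediction` code):

```python
# 6 tiles belonging to 2 images; each image's last tile sits at these indices.
tiles = list(range(6))
last_tile_indices = [2, 5]

image_slices, start = [], 0
for last in last_tile_indices:
    image_slices.append(tiles[start : last + 1])  # without the "+1", the last
    start = last + 1                              # tile leaked into the next group

print(image_slices)  # [[0, 1, 2], [3, 4, 5]]
```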

### Changes Made

- **Modified**: `stitch_prediction` function. +1 to index when creating
`image_slices`.

### Related Issues

- Fixes #162 

### Additional Notes and Examples

I did not add an extra test because I think they are added in #161


---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [ ] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)
@melisande-c (Member) left a comment

I have a couple of notes on the Lightning API that are maybe not relevant to this PR, since these parts weren't altered here.

  1. This is just my opinion so feel free to ignore: but I find all the "Wrapper" type classes unnecessary and confusing. These include the TrainingDataWrapper, CAREamicsModuleWrapper and PredictDataWrapper. The purpose of these classes is to initialise the classes they claim to wrap with a different set of parameters, and this can easily be achieved with a function. I feel subclassing should primarily be reserved for extending or modifying the functionality of the parent class.

For example, a CAREamicsTrainData factory function could reimplement what is in the __init__ of TrainingDataWrapper, but return a CAREamicsTrainData object, as such:

```python
from pathlib import Path
from typing import Callable, List, Literal, Optional, Union

import numpy as np

# CAREamics-internal names (SupportedData, TRANSFORMS_UNION, CAREamicsTrainData)
# are assumed to be importable from the package.

def create_train_datamodule(
    train_data: Union[str, Path, np.ndarray],
    data_type: Union[Literal["array", "tiff", "custom"], SupportedData],
    patch_size: List[int],
    axes: str,
    batch_size: int,
    val_data: Optional[Union[str, Path]] = None,
    transforms: Optional[List[TRANSFORMS_UNION]] = None,
    train_target_data: Optional[Union[str, Path]] = None,
    val_target_data: Optional[Union[str, Path]] = None,
    read_source_func: Optional[Callable] = None,
    extension_filter: str = "",
    val_percentage: float = 0.1,
    val_minimum_patches: int = 5,
    dataloader_params: Optional[dict] = None,
    use_in_memory: bool = True,
    use_n2v2: bool = False,
    struct_n2v_axis: Literal["horizontal", "vertical", "none"] = "none",
    struct_n2v_span: int = 5,
) -> CAREamicsTrainData:
    """Create a `CAREamicsTrainData` data module."""
    # logic from `TrainingDataWrapper.__init__` builds `data_config` here
    ...
    return CAREamicsTrainData(
        data_config=data_config,
        train_data=train_data,
        val_data=val_data,
        train_data_target=train_target_data,
        val_data_target=val_target_data,
        read_source_func=read_source_func,
        extension_filter=extension_filter,
        val_percentage=val_percentage,
        val_minimum_split=val_minimum_patches,
        use_in_memory=use_in_memory,
    )
```
  2. Not a priority but I will mention anyway: I think the CAREamicsTrainData.setup method is pretty difficult to read with all the nested "if else" statements 😅. Maybe a strategy pattern could help here in future, especially if there are more dataset types to come, e.g. zarr! (A rough sketch follows below.)
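
As a rough sketch of that suggestion (the dataset classes below are placeholders, not the actual CAREamics datasets):

```python
from typing import Callable, Dict

class InMemoryDataset:
    """Placeholder for the real in-memory dataset."""
    def __init__(self, data):
        self.data = data

class PathIterableDataset:
    """Placeholder for the real iterable dataset."""
    def __init__(self, data):
        self.data = data

# A registry lookup replaces the nested if/else branches in `setup`;
# future dataset types (e.g. zarr) would slot in as new entries.
DATASET_FACTORIES: Dict[str, Callable] = {
    "array": InMemoryDataset,
    "tiff": PathIterableDataset,
}

def make_dataset(data_type: str, data):
    return DATASET_FACTORIES[data_type](data)
```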

(Two resolved review threads on src/careamics/dataset/patching/patching.py)
@jdeschamps (Member, Author)

> 1. This is just my opinion so feel free to ignore: but I find all the "`Wrapper`" type classes unnecessary and confusing. […]
>
> 2. Not a priority but I will mention anyway: I think the `CAREamicsTrainData.setup` method is pretty difficult to read with all the nested "if else" statements 😅. Maybe a strategy pattern could help here in future, especially if there are more dataset types to come, e.g. zarr!

I absolutely agree; the whole dataset handling has been a mess since it was implemented, and we are slowly refactoring things...

I like the idea of a function rather than subclassing. It would also allow us to come up with better names for the classes (since there would now be only a single one).

@jdeschamps (Member, Author)

Update: the wrappers have made way for functions; this created a duplication with a method used in CAREamics, which was resolved via some refactoring.

@jdeschamps jdeschamps requested a review from melisande-c June 26, 2024 09:40
codecov bot commented Jun 26, 2024

Codecov Report

Attention: Patch coverage is 96.93878% with 3 lines in your changes missing coverage. Please review.

Project coverage is 41.65%. Comparing base (acb456f) to head (954599d).
Report is 14 commits behind head on main.

| Files | Patch % | Lines |
|-------|---------|-------|
| src/careamics/lightning/train_data_module.py | 92.59% | 2 Missing ⚠️ |
| src/careamics/dataset/patching/patching.py | 94.11% | 1 Missing ⚠️ |

❗ There is a different number of reports uploaded between BASE (acb456f) and HEAD (954599d): HEAD has 1 more upload than BASE.

| Flag | BASE (acb456f) | HEAD (954599d) |
|------|----------------|----------------|
|      | 2              | 3              |
Additional details and impacted files
```
@@             Coverage Diff             @@
##             main     #161       +/-   ##
===========================================
- Coverage   91.30%   41.65%   -49.66%
===========================================
  Files         104      119       +15
  Lines        2633     5786     +3153
===========================================
+ Hits         2404     2410        +6
- Misses        229     3376     +3147
```


(Outdated, resolved review threads on src/careamics/config/configuration_factory.py, src/careamics/careamist.py, and src/careamics/lightning/lightning_module.py)
@jdeschamps (Member, Author)

Removed create_inference_config and moved its code into the CAREamics predict function; InferenceConfig is now only instantiated in create_predict_datamodule.

And I corrected the vocab. 😉

Comment on lines +621 to +623
```python
if model.architecture == SupportedArchitecture.UNET.value:
    # tile size must be equal to k*2^n, where n is the number of pooling
    # layers (equal to the depth) and k is an integer
```
@melisande-c (Member)

Ah, I realise now the InferenceConfig doesn't have access to the type of architecture, so it can't be validated in there 😅. I guess it might be nice to isolate this check as a separate function somewhere, but otherwise looks good!

@jdeschamps (Member, Author)

Not only the type of architecture, but also the depth of the UNet...

But you had a good point with the double Inference instantiation; that was convoluted. Let's see later if we want to move that to some utils package or so.
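
A minimal sketch of such an isolated check, assuming only the rule quoted above (tile sizes must equal k * 2**depth); the function name and signature are hypothetical:

```python
def validate_tile_size(tile_size: int, depth: int) -> None:
    """Check that tile_size = k * 2**depth for some integer k (UNet pooling)."""
    factor = 2 ** depth
    if tile_size % factor != 0:
        raise ValueError(
            f"Tile size {tile_size} must be a multiple of 2**{depth} = {factor}."
        )

validate_tile_size(64, depth=2)  # ok: 64 = 16 * 2**2
```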

@jdeschamps jdeschamps merged commit ff20596 into main Jul 4, 2024
15 checks passed
@jdeschamps jdeschamps deleted the jd/refac/lightning_api branch July 4, 2024 12:18