Enforce specific tile size in UNet to reduce artifacts #121

Merged
merged 2 commits into main from jd/fix/fix_tile_size_unet on May 28, 2024

Conversation

jdeschamps (Member)

Description

Pooling layers in the UNet break translation invariance, which can lead to artifacts during tiled prediction if the tiles have the wrong size. This PR adds a check that the tile size is equal to k*2**depth, to ensure artifact-free prediction.

  • What: Enforce tile size that prevents artifacts.
  • Why: Users are likely to choose a tile size leading to artifacts.
  • How: The inference configuration convenience function now checks the UNet depth and raises an error if the tile size is not k*2**depth (a sketch of the check follows this list).
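
As a minimal sketch, the check amounts to the following (hypothetical helper name; the actual check lives in the inference configuration convenience function):

```python
def validate_tile_size(tile_size: tuple[int, ...], depth: int) -> None:
    """Raise if any tile dimension is not a non-zero multiple of 2**depth."""
    factor = 2**depth  # pooling halves each spatial dimension `depth` times
    for dim in tile_size:
        if dim <= 0 or dim % factor != 0:
            raise ValueError(
                f"Tile size {tile_size} is invalid: each dimension must be a "
                f"non-zero multiple of 2**depth = {factor} (UNet depth {depth})."
            )


validate_tile_size((64, 64), depth=3)    # ok: 64 == 8 * 8
# validate_tile_size((60, 60), depth=3)  # raises: 60 is not a multiple of 8
```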

Changes Made

  • Added: Tile size check in the inference configuration convenience function.

Additional Notes and Examples

Two important consequences arise:

  • If the image's total extent in Z is smaller than 2**depth (e.g. < 8 for a UNet of depth == 3), then users will not be able to predict. There are then two solutions: pad the image in Z, or bypass the CAREamist and use its trainer object directly with a PredictDataWrapper (see the padding sketch after this list).
  • Using the "Lightning API", there is no way to compare the UNet depth and the tile size. It is then up to users to choose the tile size correctly. A note has been added to the docstring.
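
For the padding workaround, a minimal sketch (assumed user-side code, not CAREamics API):

```python
import numpy as np


def pad_z_to_multiple(image: np.ndarray, depth: int) -> np.ndarray:
    """Pad a (Z, Y, X) image in Z up to the next multiple of 2**depth."""
    factor = 2**depth
    pad_z = (-image.shape[0]) % factor  # amount missing to reach the next multiple
    # "reflect" assumes Z is larger than the padding; use "edge" otherwise.
    return np.pad(image, ((0, pad_z), (0, 0), (0, 0)), mode="reflect")


image = np.random.rand(5, 256, 256)         # Z == 5 < 8 == 2**3
padded = pad_z_to_multiple(image, depth=3)  # Z padded to 8
```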

  • Code builds and passes tests locally, including doctests
  • New tests have been added (for bug fixes/features)
  • Pre-commit passes
  • PR to the documentation exists (for bug fixes / features): no need

Translation invariance is broken in UNets due to the pooling layers. To avoid artifacts during prediction, `tile_size` is now forced to be a multiple of 2**depth.

This does not impact the "Lightning API", but it can be an issue for images with a small Z extent during prediction. In such a case, padding can be added, or prediction can be run while bypassing the CAREamist.
@jdeschamps jdeschamps requested review from CatEek and melisande-c May 27, 2024 10:54
@jdeschamps jdeschamps changed the title Jd/fix/fix tile size unet Enforce specific tile size in UNet to reduce artifacts May 27, 2024
@melisande-c (Member)

Is it possible for Pydantic to validate a field based on another field? If so, the tile size could be validated directly in the InferenceModel class, since it has access to the model_config. However, this would only be beneficial if there is any chance of the InferenceModel class being initialised without the convenience function.

@jdeschamps (Member, Author)

jdeschamps commented May 28, 2024

> Is it possible for Pydantic to validate a field based on another field? If so, the tile size could be validated directly in the InferenceModel class, since it has access to the model_config. However, this would only be beneficial if there is any chance of the InferenceModel class being initialised without the convenience function.

Yes, Pydantic supports validation between fields, but unfortunately model_config is a Pydantic ConfigDict, not the CAREamics (deep-learning) model configuration. We've been struggling a bit with the naming conventions, since Pydantic classes are called models (and anything model_* is actually a reserved namespace in Pydantic)...
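
(For reference, cross-field validation in Pydantic v2 would look something like the sketch below, with illustrative names rather than our actual classes; the catch is that the UNet depth would first have to become a field of the inference model.)

```python
from pydantic import BaseModel, model_validator


class TiledInference(BaseModel):
    depth: int
    tile_size: tuple[int, int]

    @model_validator(mode="after")
    def check_tile_size(self) -> "TiledInference":
        factor = 2**self.depth
        if any(dim % factor != 0 for dim in self.tile_size):
            raise ValueError(f"Tile dimensions must be multiples of {factor}.")
        return self


TiledInference(depth=3, tile_size=(64, 64))    # validates
# TiledInference(depth=3, tile_size=(60, 60))  # raises ValidationError
```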

The problem I faced here is that the DL-model configuration lives in a different Pydantic class than the tiling... So two choices:

  • Validation outside of the Pydantic model (which I chose here)
  • Validation inside InferenceModel by adding additional fields (e.g. adding the whole AlgorithmModel from the overall configuration)

I agree the second one is more elegant, but I think it is not compatible with the so-called "Lightning API", where people might be using the PredictDataWrapper outside of the CAREamist. There, they would then need to pass the same fields to the PredictDataWrapper as well, since it creates an internal InferenceModel. In the end, the InferenceModel is only created internally, inside the PredictDataWrapper and the convenience function.

A middle ground would be to minimize the number of fields, maybe something like:

```python
from typing import Optional

from pydantic import BaseModel


class InferenceModel(BaseModel):
    ...
    unet_depth: Optional[int] = None
    ...
```

That is a bit odd, but at least it is a single parameter that can be left as None.

In the end, I still prefer the first solution because it does not add a strange parameter to the PredictDataWrapper constructor (though not a very strong opinion...). But maybe that's not a big deal? What do you think?

@melisande-c (Member)

melisande-c commented May 28, 2024

> Yes, Pydantic supports validation between fields, but unfortunately model_config is a Pydantic ConfigDict, not the CAREamics (deep-learning) model configuration. We've been struggling a bit with the naming conventions, since Pydantic classes are called models (and anything model_* is actually a reserved namespace in Pydantic)...

My bad, I did know that and I think I just read through the code too quickly. Still getting used to Pydantic!

This all makes sense then! Thank you for the detailed explanation.

I agree that having the unet_depth passed to the InferenceModel is a bit odd, and your solution makes the most sense since the InferenceModel has limited reuse anyway.

@jdeschamps jdeschamps merged commit eea3ecd into main May 28, 2024
14 of 15 checks passed
@jdeschamps jdeschamps deleted the jd/fix/fix_tile_size_unet branch May 28, 2024 12:11