Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Tiling applied post prediction (in contrast to during loop) #141

Merged
merged 37 commits into from
Jun 20, 2024

Conversation

melisande-c
Copy link
Member

@melisande-c melisande-c commented Jun 12, 2024

Description

Using the regular Lightning prediction loop, and tiled datasets, the tile information is also returned when calling trainer.predict. This means there is no reason tiling cannot be applied after prediction instead of during the prediction loop. This means we don't have to write a custom loop, simplifying the code and reducing coupling with Lightning. This will also make life easier when we add saving predictions (tiff or zarr).

What: Removes CAREamicsPredictionLoop (where tiling was previously implemented) and applies tiling in CAREamist.predict after calling trainer.predict with the regular Lightning prediction loop.

Why: Altering the Lightning prediction loop was overcomplicated, hard to maintain and made adding an option to save predictions more difficult.

How: predictions are returned with tiling information so predicted tiles are stitched together at the end of prediction.

Changes Made

  • Added:
    • multi-image stitch prediction function (takes name of old stitching function, stitched_prediction)
    • New prediction_utils package.
    • module prediction_outputs in new prediction_utils package.
  • Modified:
    • Trainer in CAREamist no longer has prediction_loop replaced
    • At the end of CAREamist.predict predictions are stitched and/or converted to match old CAREamistPredictionLoop outputs.
    • Creation of CAREamicsPredictData has been moved to new prediction_utils package
    • stitch_prediction function has been moved to prediction_utils package
  • Removed: CAREamistPredictionLoop
  • Tests
    • Added: test for new stitched_prediction function.
    • Added: test for prediction output conversion.
    • Modified: moved stitched_prediction tests to match new file structure of src.
    • Modified: test_predict_on_array_tiled and test_predict_arrays_no_tiling – parametrised with samples and batch_size; squeeze train_array when asserting size equality.

Related Issues

Breaking changes

Any code that instantiated a Lightning Trainer and added the CAREamistPredictionLoop .
There might be some unforeseen changes to dimensions of prediction outputs that the tests do not catch, i.e. adding S & C dims.

TODO: (for future)

  • Change prediction so that list is always output (currently, if there is only 1 prediction it will not output a list).
  • Dimensions of outputs always have all SC(Z)YX, or match input dimensions. Currently, there is some inconsistency between tiled and not tiled prediction output dimensions.

Please ensure your PR meets the following requirements:

  • Code builds and passes tests locally, including doctests
  • New tests have been added (for bug fixes/features)
  • Pre-commit passes
  • PR to the documentation exists (for bug fixes / features)

@melisande-c melisande-c requested review from jdeschamps and CatEek June 12, 2024 15:17
Copy link

codecov bot commented Jun 12, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 91.23%. Comparing base (acb456f) to head (519c1df).
Report is 2 commits behind head on main.

Current head 519c1df differs from pull request most recent head bc916d3

Please upload reports for the commit bc916d3 to get more accurate results.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #141      +/-   ##
==========================================
- Coverage   91.30%   91.23%   -0.07%     
==========================================
  Files         104      103       -1     
  Lines        2633     2510     -123     
==========================================
- Hits         2404     2290     -114     
+ Misses        229      220       -9     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

src/careamics/careamist.py Outdated Show resolved Hide resolved
src/careamics/careamist.py Outdated Show resolved Hide resolved
src/careamics/careamist.py Outdated Show resolved Hide resolved
@conradkun conradkun mentioned this pull request Jun 18, 2024
@melisande-c melisande-c marked this pull request as ready for review June 18, 2024 16:05
@melisande-c melisande-c requested review from jdeschamps and CatEek June 18, 2024 16:06
Copy link
Member

@jdeschamps jdeschamps left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good!

I am wondering a bit what the consequences are for the Lightning API. I guess they just need to call the stitch method. Since I anyway need to create examples and tests, I will make a PR in that direction as soon as this is merged.

else:
predictions_output = combine_batches(predictions, tiled)

# TODO: add this in? Returns output with same axes as input
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the result of stitch prediction have a "S" dimensions if the input had any? If singleton then I guess it doesn't?

The output should already be shaped correctly. Now what we could do is to reshape back to the real input (which may have the axes in a weird order)... But for now it is better to stick to that.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see now, this is not the reshape that we have elsewhere in the code. This is to add S and C back.

If the images are just returned to the user or saved to the disk, we are probably better off not returning singleton dims. It is annoying in terms of coherence, but much nicer for downstream.

Copy link
Member Author

@melisande-c melisande-c Jun 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jdeschamps

I think there are some inconsistencies with output dims depending on whether the prediction was tiled or not. If the prediction was tiled, then because the function stitch_prediction doesn't expect the S dim, if there was only one image, then the output will also not have an S dim. However, if the prediction is not tiled, then the S dimension is not removed, even if there is only one image.

I have left it like this because I am pretty sure it is recreating how the code was previously behaving, but I can merge this PR and then create a new PR to fix this issue.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make an issue out of it so that we don't loose track of it! Also with respect to the list or not list return value.

I think for now it is ok to merge!

Copy link
Contributor

@CatEek CatEek left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@CatEek CatEek merged commit 0a29ea2 into main Jun 20, 2024
15 checks passed
@CatEek CatEek deleted the mc/refac/post_prediction_tiling branch June 20, 2024 12:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Batched prediction [BUG] Refactoring: Decouple tiling from prediction loop
3 participants