Refactoring: Decouple tiling from prediction loop #140

Closed
melisande-c opened this issue Jun 11, 2024 · 5 comments · Fixed by #141
Labels
feature New feature or request

Comments

@melisande-c
Member

melisande-c commented Jun 11, 2024

Description

I think the tiling logic should probably be removed from the prediction loop.

Why

  • Currently the prediction loop's run function is copied from the original Lightning _PredictionLoop class, with the tiling logic inserted in between the Lightning code. This will be hard to maintain if Lightning ever changes its code.
  • I'm fairly sure CAREamicsPredictionLoop is currently incompatible with trainer.predict(*args, **kwargs, return_predictions=False).
  • Once the save to disk function is added, when a full (stitched) prediction is saved we want to free up memory so that the prediction can be run on a large number of files. This is hard to add into the current implementation.
  • When zarr datasets are added, tiles can be written into the correct place in the file without waiting for the complete set of tiles. This is hard to add into the current implementation.

Two solutions

Tiling as a Callback

Lightning already has a BasePredictionWriter Callback. There is a write_on_batch_end hook that could handle writing to zarr files or caching tiles until the last tile to save to tiff. The outputs of trainer.predict can also be changed so that it is the full prediction and not the tiles. I have implemented a version of this as a demo in this branch, where I move the current tiling logic to a callback. (This would have to change a lot to accommodate writing predictions).

Tiling in CAREamicsModule

I don't like this option as much because it doesn't feel like what LightningModules are for. LightningModules have all the same hooks as Callback (on_predict_batch_end etc.), so all the tiling logic could move there.

Discussion points

Please comment likes/dislikes with either solution or a new solution and any other thoughts!

  • Does trainer.predict have to output the stitched predictions?

    • To change the outputs of trainer.predict, in on_predict_epoch_end we have to do:
      trainer.predict_loop._predictions = stitched_predictions
      (see implementation in mentioned branch above) which I don't like because it feels a bit hacky.
    • predictions keep the tiling information alongside them, why not just do:
      predictions = trainer.predict(model=model, datamodule=datamodule, ckpt_path=checkpoint)
      stitched_preds = stitched_predictions(predictions)
    • Users using the CAREamist class never need to see this (happens in CAREamist.predict). Users not using the CAREamist class have more control, which is why they might not be using it.
  • Something to consider:

    • Users not using the CAREamist class, if they want tiled predictions, currently have to do:
      trainer = Trainer(*args, **kwargs)
      trainer.predict_loop = CAREamicsPredictionLoop(trainer)
      predictions = trainer.predict(model=model, datamodule=datamodule)  # datamodule is a CAREamicsDataModule
    • With the Callback option:
      trainer = Trainer(*args, **kwargs, callbacks=[TiledPredictionCallback(), ...])
      predictions = trainer.predict(model=model, datamodule=datamodule)  # datamodule is a CAREamicsDataModule
    • With the LightningModule option:
      trainer = Trainer(*args, **kwargs)
      predictions = trainer.predict(model=model, datamodule=datamodule)  # datamodule is a CAREamicsDataModule
    • Although I have just realised this is redundant if stitching is applied afterwards (as described in the point above). However, they would still have to add the PredictionWriterCallback for saving tiles.
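
A minimal sketch of the "stitch after trainer.predict" idea discussed above, assuming each predicted tile comes back paired with its tiling information. The `TileInfo` dataclass, its fields and the `stitch_tiles` function are hypothetical names for illustration only (not the actual CAREamics API), and 1D lists stand in for image arrays to keep it short:

```python
from dataclasses import dataclass


@dataclass
class TileInfo:
    """Hypothetical tiling metadata carried alongside each predicted tile."""

    array_length: int  # length of the full (untiled) array
    crop: tuple  # (start, stop) of the tile region to keep, i.e. without overlap
    stitch: tuple  # (start, stop) of where the cropped tile goes in the full array
    last_tile: bool = False  # True for the final tile of an array


def stitch_tiles(tiles):
    """Stitch (tile, TileInfo) pairs into full arrays, one per `last_tile`."""
    stitched, current = [], None
    for tile, info in tiles:
        if current is None:
            # First tile of a new array: allocate the full output.
            current = [0] * info.array_length
        start, stop = info.stitch
        current[start:stop] = tile[info.crop[0] : info.crop[1]]
        if info.last_tile:
            stitched.append(current)
            current = None
    return stitched


# Two overlapping tiles of length 6 covering one array of length 8.
tiles = [
    ([0, 1, 2, 3, 4, 5], TileInfo(8, crop=(0, 4), stitch=(0, 4))),
    ([2, 3, 4, 5, 6, 7], TileInfo(8, crop=(2, 6), stitch=(4, 8), last_tile=True)),
]
full_arrays = stitch_tiles(tiles)
```

With something along these lines, the output of trainer.predict would not need to be overwritten inside the loop; the stitching becomes an ordinary function call on the returned tiles.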
@melisande-c melisande-c added the feature New feature or request label Jun 11, 2024
@melisande-c melisande-c changed the title Refactoring: Decoupling tiling from prediction loop Refactoring: Decouple tiling from prediction loop Jun 11, 2024
@melisande-c
Member Author

I am now leaning towards having stitching applied after prediction, which means we don't have to mess around with the Lightning implementation too much and don't have to do the hacky trainer.predict_loop._predictions = stitched_predictions.

Then we can write a PredictionWriterCallback that will handle caching tiles until the last tile for saving to tiff, and writing tiles directly into zarr files.

I feel this will remove some complexity from the code.

@jdeschamps
Member

jdeschamps commented Jun 13, 2024

First off, totally agree!! We need something easier to maintain and with a more straightforward data flow.

> Two solutions
>
> Tiling as a Callback
>
> Lightning already has a BasePredictionWriter Callback. There is a write_on_batch_end hook that could handle writing to zarr files or caching tiles until the last tile to save to tiff. The outputs of trainer.predict can also be changed so that it is the full prediction and not the tiles. I have implemented a version of this as a demo in this branch, where I move the current tiling logic to a callback. (This would have to change a lot to accommodate writing predictions).

What bothers me about callbacks is that they are instantiated with the trainer, which means that the prediction writer needs to handle all these cases:

  • prediction with tiling returned to the user
  • prediction without tiling returned to the user
  • prediction with tiling written to the disk
  • prediction without tiling written to the disk
  • prediction with tiling written to a zarr (where we don't have to keep the tiles in memory; we can just write each tile straight away using the crop and stitch coords)
  • prediction without tiling written to a zarr

Either the callback can easily handle all that, or we need separate callbacks and a way to switch between them depending on the user call (careamist.predict or careamist.predict_to_disk, regardless of whether these two exist as separate methods or are just different parameters passed to a predict function).

> Tiling in CAREamicsModule
>
> I don't like this option as much because it doesn't feel like what LightningModules are for. LightningModules have all the same hooks as Callback (on_predict_batch_end etc.), so all the tiling logic could move there.

Is this the solution you are exploring in #141 ?

> * Does `trainer.predict` have to output the stitched predictions?
>
>   * To change the outputs of `trainer.predict`, in `on_predict_epoch_end` we have to do:
>     ```python
>     trainer.predict_loop._predictions = stitched_predictions
>     ```
>     (see implementation in mentioned branch above) which I don't like because it feels a bit hacky.
>   * predictions keep the tiling information alongside them, why not just do:
>     ```python
>     predictions = trainer.predict(model=model, datamodule=datamodule, ckpt_path=checkpoint)
>     stitched_preds = stitched_predictions(predictions)
>     ```
>   * Users using the `CAREamist` class never need to see this (happens in `CAREamist.predict`). Users not using the `CAREamist` class have more control, which is why they might not be using it.

I like stitching the predictions after trainer.predict, but that's not compatible with writing to disk, is it?

As for the people using the Lightning API, they are already replacing the prediction loop, so they would replace that one line with a call to the stitching somewhere else!

I might be late to the discussion here (sorry for the delay). If I understand correctly:

  • Refactor: Tiling applied post prediction (in contrast to during loop) #141: shows how to just return the tiles and stitch them in the CAREamist
    • This is not incompatible with having a callback for writing to disk, but how would we switch between writing to disk and returning data to the user? Somehow the callback needs to be disabled in the case where we return data to users.
    • Tiling for the LightningAPI would require that the logic for the stitching resides in an easy to call function when there are multiple images (which I think is what stitch_prediction is in the PR, although it might be better called stitch_predictions plural now).
  • For saving to the disk, we would need to add all the callback logic. Is that still the plan in Feature: Add option to save predictions to disk #136? Just for clarity!

@melisande-c
Member Author

melisande-c commented Jun 13, 2024

> I like stitching the predictions after trainer.predict, but that's not compatible with writing to disk, is it?

They actually are compatible!
The original description in this issue is actually a bit out of date now. When I was thinking about it, I was conflating the problems of writing predictions and removing tiling from the prediction loop but they are not as tightly coupled as I originally thought. As I wrote this issue the problem became clearer to me 😅 and now my solution is as follows:

  • For in memory predictions (CAREamist.predict) tiling will be applied after prediction, as implemented in Refactoring: Decouple tiling from prediction loop #140.
    • This can also be fully extracted as a separate function for users who are not using the CAREamist class.
  • For writing predictions to disk (CAREamist.predict_to_disk) we can subclass the BasePredictionWriter.
    • For tiffs, the prediction writer will cache tiles until the last tile then call the stitch_prediction function, write the result to disk and then clear the tile cache. (This will also be the logic for custom file types).
    • For zarrs, tiles will be written directly to disk and not cached.
    • This can potentially be two different writers, but then choosing which one to call will be annoying, since they would both have to be added to the trainer in CAREamist.__init__.
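
A rough sketch of the tile-caching behaviour described for tiffs, under some loud assumptions: `TileCachingWriter` is a plain-Python stand-in that only mirrors the `write_on_batch_end` hook signature of Lightning's BasePredictionWriter (it does not import or subclass it), each `prediction` is assumed to be a `(tile, is_last_tile)` pair, and `stitch_fn` / `write_fn` are hypothetical placeholders for the real stitching and tiff-writing functions:

```python
class TileCachingWriter:
    """Stand-in for a BasePredictionWriter subclass (no Lightning import)."""

    def __init__(self, stitch_fn, write_fn):
        self.stitch_fn = stitch_fn  # hypothetical: turns cached tiles into one array
        self.write_fn = write_fn  # hypothetical: saves a stitched array, e.g. to tiff
        self.tile_cache = []

    def write_on_batch_end(
        self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx
    ):
        # Assumption: every prediction is a (tile, is_last_tile) pair.
        tile, is_last_tile = prediction
        self.tile_cache.append(tile)
        if is_last_tile:
            # Stitch, write, then free the memory held by the cached tiles.
            self.write_fn(self.stitch_fn(self.tile_cache))
            self.tile_cache.clear()


written = []  # stands in for files on disk
writer = TileCachingWriter(
    stitch_fn=lambda tiles: [value for tile in tiles for value in tile],
    write_fn=written.append,
)

# Two "images": the first split into two tiles, the second a single tile.
predictions = [([1, 2], False), ([3, 4], True), ([5, 6], True)]
for batch_idx, prediction in enumerate(predictions):
    writer.write_on_batch_end(None, None, prediction, None, None, batch_idx, 0)
```

A zarr variant would skip the cache entirely and write each tile to its stitch coordinates as soon as it arrives, which is why a single writer with two code paths and two separate writers are both plausible designs.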

@jdeschamps
Member

jdeschamps commented Jun 13, 2024

Great, thanks for the summary!

Just to clarify after the discussion here: since the BasePredictionWriter is added to the trainer in CAREamist.__init__, do we switch from "writing to disk" to "return predictions" using settings given to the BasePredictionWriter?

Does Lightning have a mechanism to disable callbacks after instantiation? Or to update them?

@melisande-c
Member Author

melisande-c commented Jun 13, 2024

> Does Lightning have a mechanism to disable callbacks after instantiation? Or to update them?

No, annoyingly 😔

> Just to clarify after the discussion here: since the BasePredictionWriter is added to the trainer in CAREamist.__init__, do we switch from "writing to disk" to "return predictions" using settings given to the BasePredictionWriter?

Following that discussion, if CAREamist keeps a reference to the callbacks, as you suggested, we can update attributes in the callbacks. So, when not writing to disk (in CAREamist.predict) we would do something like self.prediction_writer.save_predictions = False. And in the write-to-disk function we would set it to True.
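
A toy illustration of the flag-toggling idea, assuming the writer exposes a `save_predictions` attribute as suggested above. `PredictionWriter` and `MiniCAREamist` are hypothetical stand-ins (no Lightning or CAREamics imports), and the "model" is just a doubling function:

```python
class PredictionWriter:
    """Stand-in for the prediction-writer callback."""

    def __init__(self):
        self.save_predictions = False
        self.written = []  # stands in for files on disk

    def write_on_batch_end(self, prediction):
        # The callback stays attached; it simply does nothing when disabled.
        if self.save_predictions:
            self.written.append(prediction)


class MiniCAREamist:
    """Stand-in that keeps a reference to its callback, as discussed."""

    def __init__(self):
        self.prediction_writer = PredictionWriter()
        self.callbacks = [self.prediction_writer]  # would be passed to the Trainer

    def _run_prediction(self, batches):
        outputs = []
        for batch in batches:
            prediction = batch * 2  # placeholder for the real model
            for callback in self.callbacks:
                callback.write_on_batch_end(prediction)
            outputs.append(prediction)
        return outputs

    def predict(self, batches):
        self.prediction_writer.save_predictions = False
        return self._run_prediction(batches)

    def predict_to_disk(self, batches):
        self.prediction_writer.save_predictions = True
        self._run_prediction(batches)


careamist = MiniCAREamist()
in_memory = careamist.predict([1, 2])
written_after_predict = list(careamist.prediction_writer.written)
careamist.predict_to_disk([3])
```

The same callback instance is reused across both calls; only its attribute changes, so nothing has to be added to or removed from the trainer after `__init__`.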
