Export to BMZ and other changes #108

Merged · 14 commits · Apr 27, 2024
Changes from all commits
10 changes: 7 additions & 3 deletions examples/2D/n2n/example_SEM_careamist.ipynb
@@ -14,7 +14,7 @@
     "import tifffile\n",
     "from careamics_portfolio import PortfolioManager\n",
     "\n",
-    "from careamics import CAREamist\n"
+    "from careamics import CAREamist"
 ]
},
{
@@ -156,8 +156,12 @@
     "metadata": {},
     "outputs": [],
     "source": [
-    "engine.train(train_source=train_image[0], val_source=train_image[1],\n",
-    "             train_target=train_image[2], val_target=train_image[3],)"
+    "engine.train(\n",
+    "    train_source=train_image[0],\n",
+    "    val_source=train_image[1],\n",
+    "    train_target=train_image[2],\n",
+    "    val_target=train_image[3],\n",
+    ")"
 ]
},
{
2 changes: 1 addition & 1 deletion examples/2D/n2v/example_BSD68_careamist.ipynb
@@ -205,7 +205,7 @@
     "source": [
     "# Create a list of ground truth images\n",
     "\n",
-    "gts = [tifffile.imread(f) for f in sorted(gt_path.glob(\"*.tiff\"))]\n"
+    "gts = [tifffile.imread(f) for f in sorted(gt_path.glob(\"*.tiff\"))]"
 ]
},
{
10 changes: 5 additions & 5 deletions examples/2D/n2v/example_SEM_lightning.ipynb
@@ -6,20 +6,20 @@
     "metadata": {},
     "outputs": [],
     "source": [
-    "from pathlib import Path\n",
     "import shutil\n",
-    "import albumentations as Aug\n",
+    "from pathlib import Path\n",
+    "\n",
     "import matplotlib.pyplot as plt\n",
     "import tifffile\n",
     "from careamics_portfolio import PortfolioManager\n",
     "from pytorch_lightning import Trainer\n",
     "\n",
     "from careamics import CAREamicsModule\n",
-    "from careamics.lightning_prediction import CAREamicsPredictionLoop\n",
     "from careamics.lightning_datamodule import (\n",
     "    CAREamicsPredictDataModule,\n",
     "    CAREamicsTrainDataModule,\n",
-    ")"
+    ")\n",
+    "from careamics.lightning_prediction import CAREamicsPredictionLoop"
 ]
},
{
@@ -142,7 +142,7 @@
     "    model_parameters={\"n2v2\": False},\n",
     "    optimizer_parameters={\"lr\": 1e-3},\n",
     "    lr_scheduler_parameters={\"factor\": 0.5, \"patience\": 10},\n",
-    ")\n"
+    ")"
 ]
},
{
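Note: this notebook drives training through the bare Lightning API rather than CAREamist, which is why it imports CAREamicsModule, the two datamodules, and CAREamicsPredictionLoop directly. A minimal sketch of the wiring those imports suggest; only the model/optimizer/scheduler kwargs are confirmed by this diff, everything else here is an assumption:

# Sketch only -- wiring inferred from the imports above, not taken verbatim from the notebook.
model = CAREamicsModule(
    algorithm="n2v",  # assumed kwarg; this diff only confirms the three below
    model_parameters={"n2v2": False},
    optimizer_parameters={"lr": 1e-3},
    lr_scheduler_parameters={"factor": 0.5, "patience": 10},
)
trainer = Trainer(max_epochs=100)
trainer.fit(model, datamodule=train_data)  # train_data: a CAREamicsTrainDataModule built as in the notebook

# Swap in CAREamics' tiled prediction loop before predicting (assumed usage).
trainer.predict_loop = CAREamicsPredictionLoop(trainer)
predictions = trainer.predict(model, datamodule=pred_data)  # pred_data: a CAREamicsPredictDataModule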
1 change: 1 addition & 0 deletions examples/3D/example_flywing_3D.ipynb
@@ -12,6 +12,7 @@
     "import numpy as np\n",
     "import tifffile\n",
     "from careamics_portfolio import PortfolioManager\n",
+    "\n",
     "# from itkwidgets import compare, view  # \"pip install itkwidgets\" if necessary\n",
     "from pytorch_lightning import Trainer\n",
18 changes: 11 additions & 7 deletions pyproject.toml
@@ -37,12 +37,13 @@ classifiers = [
     "Typing :: Typed",
 ]
 dependencies = [
+    # pytorch should be installed via conda/pip beforehand
     'albumentations',
-    'bioimageio.core',
+    'bioimageio.core>=0.6.0',
     'tifffile',
     'psutil',
     'pydantic>=2.5',
-    'pytorch_lightning',
+    'pytorch_lightning>=2.2.0',
     'pyyaml',
     'scikit-image',
     'zarr',
@@ -166,11 +167,14 @@
 [tool.numpydoc_validation]
 checks = [
     "all",  # report on all checks, except the below
-    "EX01",
-    "SA01",
-    "ES01",
-    "GL02",
-    "GL03",
+    "EX01",  # Example section not found
+    "SA01",  # See Also section not found
+    "ES01",  # Extended Summary not found
+    "GL01",  # Docstring text (summary) should start in the line immediately
+             # after the opening quotes
+    "GL02",  # Closing quotes should be placed in the line after the last text
+             # in the docstring
+    "GL03",  # Double line break found
 ]
 exclude = [  # don't report on objects that match any of these regex
     "test_*",
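The bioimageio.core floor presumably tracks the new BMZ export path in this PR, and the pytorch_lightning floor its 2.2-era API. An illustrative runtime check of an existing environment against the new minima (uses packaging, which these dependencies already pull in):

# Illustrative only: verify installed versions meet the new floors.
from importlib.metadata import version
from packaging.version import Version

assert Version(version("bioimageio.core")) >= Version("0.6.0")
assert Version(version("pytorch-lightning")) >= Version("2.2.0")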
6 changes: 2 additions & 4 deletions src/careamics/__init__.py
@@ -19,8 +19,6 @@

 from .careamist import CAREamist
 from .config import Configuration, load_configuration, save_configuration
-from .lightning_datamodule import (
-    CAREamicsPredictDataModule,
-    CAREamicsTrainDataModule,
-)
+from .lightning_datamodule import CAREamicsTrainDataModule
 from .lightning_module import CAREamicsModule
+from .lightning_prediction_datamodule import CAREamicsPredictDataModule
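Since both names stay re-exported at the top level, user code importing from careamics directly is unaffected; only deep imports of CAREamicsPredictDataModule need to follow the module split. Both styles after this change:

# Top-level imports are unchanged by the reshuffle.
from careamics import CAREamicsPredictDataModule, CAREamicsTrainDataModule

# Deep imports now come from two separate modules.
from careamics.lightning_datamodule import CAREamicsTrainDataModule
from careamics.lightning_prediction_datamodule import CAREamicsPredictDataModule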
7 changes: 4 additions & 3 deletions src/careamics/callbacks/__init__.py
@@ -1,5 +1,6 @@
-"""Callback module."""
+"""Callbacks module."""

-__all__ = ["ProgressBarCallback"]
+__all__ = ["HyperParametersCallback", "ProgressBarCallback"]

-from .progress_bar import ProgressBarCallback
+from .hyperparameters_callback import HyperParametersCallback
+from .progress_bar_callback import ProgressBarCallback
42 changes: 42 additions & 0 deletions src/careamics/callbacks/hyperparameters_callback.py
@@ -0,0 +1,42 @@
+from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning.callbacks import Callback
+
+from careamics.config import Configuration
+
+
+class HyperParametersCallback(Callback):
+    """
+    Callback saving the CAREamics configuration as hyperparameters in the model.
+
+    This allows saving the configuration as a dictionary in the checkpoints, and
+    loading it subsequently in a CAREamist instance.
+
+    Attributes
+    ----------
+    config : Configuration
+        CAREamics configuration to be saved as hyperparameters in the model.
+    """
+
+    def __init__(self, config: Configuration):
+        """
+        Constructor.
+
+        Parameters
+        ----------
+        config : Configuration
+            CAREamics configuration to be saved as hyperparameters in the model.
+        """
+        self.config = config
+
+    def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
+        """
+        Update the hyperparameters of the model with the configuration on train start.
+
+        Parameters
+        ----------
+        trainer : Trainer
+            PyTorch Lightning trainer.
+        pl_module : LightningModule
+            PyTorch Lightning module.
+        """
+        pl_module.hparams.update(self.config.model_dump())
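A minimal usage sketch for the new callback (the config path and Trainer arguments below are illustrative, not from this PR): attach it to a Trainer so the full configuration lands in every checkpoint's hparams, from which a CAREamist can later be reconstructed.

# Illustrative usage; the path and trainer settings are made up.
from pytorch_lightning import Trainer

from careamics.callbacks import HyperParametersCallback
from careamics.config import load_configuration

config = load_configuration("n2v_config.yml")  # hypothetical config file
trainer = Trainer(
    max_epochs=10,
    callbacks=[HyperParametersCallback(config)],  # hparams updated on train start
)
# After trainer.fit(...), config.model_dump() is stored in the checkpoint hyperparameters.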
src/careamics/callbacks/{progress_bar.py → progress_bar_callback.py}

@@ -12,7 +12,7 @@ class ProgressBarCallback(TQDMProgressBar):
     def init_train_tqdm(self) -> tqdm:
         """Override this to customize the tqdm bar for training."""
         bar = tqdm(
-            desc='Training',
+            desc="Training",
             position=(2 * self.process_position),
             disable=self.is_disabled,
             leave=True,
@@ -27,12 +27,12 @@ def init_validation_tqdm(self) -> tqdm:
         # The main progress bar doesn't exist in `trainer.validate()`
         has_main_bar = self.train_progress_bar is not None
         bar = tqdm(
-            desc='Validating',
+            desc="Validating",
             position=(2 * self.process_position + has_main_bar),
             disable=self.is_disabled,
             leave=False,
             dynamic_ncols=True,
-            file=sys.stdout
+            file=sys.stdout,
         )
         return bar

@@ -45,7 +45,7 @@ def init_test_tqdm(self) -> tqdm:
             leave=True,
             dynamic_ncols=False,
             ncols=100,
-            file=sys.stdout
+            file=sys.stdout,
         )
         return bar

@@ -55,4 +55,3 @@ def get_metrics(
         """Override this to customize the metrics displayed in the progress bar."""
         pbar_metrics = trainer.progress_bar_metrics
         return {**pbar_metrics}
-