Skip to content

Commit

Permalink
Merge branch 'main' into iz/feat/separate_norm
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps authored Jun 6, 2024
2 parents 15ccfb8 + 763d965 commit 7903ef0
Show file tree
Hide file tree
Showing 96 changed files with 1,935 additions and 1,616 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
# https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners/about-github-hosted-runners#standard-github-hosted-runners-for-public-repositories
platform: [ubuntu-latest, macos-13, windows-latest]

Expand Down
20 changes: 10 additions & 10 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ ci:

repos:
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.16
rev: v0.18
hooks:
- id: validate-pyproject

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.3
rev: v0.4.7
hooks:
- id: ruff
args: [--fix, --target-version, py38]
Expand All @@ -40,14 +40,14 @@ repos:
hooks:
- id: numpydoc-validation

# jupyter linting and formatting
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.8.5
hooks:
- id: nbqa-ruff
args: [--fix]
- id: nbqa-black
#- id: nbqa-mypy
# # jupyter linting and formatting
# - repo: https://github.com/nbQA-dev/nbQA
# rev: 1.8.5
# hooks:
# - id: nbqa-ruff
# args: [--fix]
# - id: nbqa-black
# #- id: nbqa-mypy

# strip out jupyter notebooks
- repo: https://github.com/kynan/nbstripout
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
</a>
</p>

# CAREamics Restoration
# CAREamics

[![License](https://img.shields.io/pypi/l/careamics.svg?color=green)](https://github.com/CAREamics/careamics/blob/main/LICENSE)
[![PyPI](https://img.shields.io/pypi/v/careamics.svg?color=green)](https://pypi.org/project/careamics)
Expand All @@ -31,4 +31,4 @@ tutorials on how to best apply these methods in a scientific context.

## Installation and use

Check out the [documentation](https://careamics.github.io/) for installation instructions and guides!
Check out the [documentation](https://careamics.github.io/) for installation instructions and guides!
8 changes: 0 additions & 8 deletions examples/2D/n2v/example_BSD68_lightning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,7 @@
" CAREamicsPredictDataModule,\n",
" CAREamicsTrainDataModule,\n",
")\n",
<<<<<<< Updated upstream
"from careamics.lightning_prediction import CAREamicsPredictionLoop\n",
=======
"from careamics.lightning_prediction import CAREamicsFiring\n",
>>>>>>> Stashed changes
"from careamics.utils.metrics import psnr"
]
},
Expand Down Expand Up @@ -172,11 +168,7 @@
"train_data_module = CAREamicsTrainDataModule(\n",
" train_data=train_path,\n",
" val_data=val_path,\n",
<<<<<<< Updated upstream
" data_type=\"tiff\", # to use np.ndarray, set data_type to \"array\"\n",
=======
" data_type=\"tiff\",\n",
>>>>>>> Stashed changes
" patch_size=(64, 64),\n",
" axes=\"SYX\",\n",
" batch_size=128,\n",
Expand Down
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ classifiers = [
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"License :: OSI Approved :: BSD License",
"Typing :: Typed",
]
Expand Down Expand Up @@ -73,7 +74,7 @@ repository = "https://github.com/CAREamics/careamics"
line-length = 88
target-version = "py38"
src = ["src"]
select = [
lint.select = [
"E", # style errors
"W", # style warnings
"F", # flakes
Expand All @@ -86,7 +87,7 @@ select = [
"A001", # flake8-builtins
"RUF", # ruff-specific rules
]
ignore = [
lint.ignore = [
"D100", # Missing docstring in public module
"D107", # Missing docstring in __init__
"D203", # 1 blank line required before class docstring
Expand All @@ -103,13 +104,12 @@ ignore = [
"UP006", # Replace typing.List by list, mandatory for py3.8
"UP007", # Replace Union by |, mandatory for py3.9
]
ignore-init-module-imports = true
show-fixes = true

[tool.ruff.pydocstyle]
[tool.ruff.lint.pydocstyle]
convention = "numpy"

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"tests/*.py" = ["D", "S"]
"setup.py" = ["D"]

Expand Down
13 changes: 10 additions & 3 deletions src/careamics/callbacks/hyperparameters_callback.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Callback saving CAREamics configuration as hyperparameters in the model."""

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback

Expand All @@ -11,13 +13,18 @@ class HyperParametersCallback(Callback):
This allows saving the configuration as dictionnary in the checkpoints, and
loading it subsequently in a CAREamist instance.
Parameters
----------
config : Configuration
CAREamics configuration to be saved as hyperparameter in the model.
Attributes
----------
config : Configuration
CAREamics configuration to be saved as hyperparameter in the model.
"""

def __init__(self, config: Configuration):
def __init__(self, config: Configuration) -> None:
"""
Constructor.
Expand All @@ -28,14 +35,14 @@ def __init__(self, config: Configuration):
"""
self.config = config

def on_train_start(self, trainer: Trainer, pl_module: LightningModule):
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""
Update the hyperparameters of the model with the configuration on train start.
Parameters
----------
trainer : Trainer
PyTorch Lightning trainer.
PyTorch Lightning trainer, unused.
pl_module : LightningModule
PyTorch Lightning module.
"""
Expand Down
41 changes: 37 additions & 4 deletions src/careamics/callbacks/progress_bar_callback.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Progressbar callback."""

import sys
from typing import Dict, Union

Expand All @@ -10,7 +12,13 @@ class ProgressBarCallback(TQDMProgressBar):
"""Progress bar for training and validation steps."""

def init_train_tqdm(self) -> tqdm:
"""Override this to customize the tqdm bar for training."""
"""Override this to customize the tqdm bar for training.
Returns
-------
tqdm
A tqdm bar.
"""
bar = tqdm(
desc="Training",
position=(2 * self.process_position),
Expand All @@ -23,7 +31,13 @@ def init_train_tqdm(self) -> tqdm:
return bar

def init_validation_tqdm(self) -> tqdm:
"""Override this to customize the tqdm bar for validation."""
"""Override this to customize the tqdm bar for validation.
Returns
-------
tqdm
A tqdm bar.
"""
# The main progress bar doesn't exist in `trainer.validate()`
has_main_bar = self.train_progress_bar is not None
bar = tqdm(
Expand All @@ -37,7 +51,13 @@ def init_validation_tqdm(self) -> tqdm:
return bar

def init_test_tqdm(self) -> tqdm:
"""Override this to customize the tqdm bar for testing."""
"""Override this to customize the tqdm bar for testing.
Returns
-------
tqdm
A tqdm bar.
"""
bar = tqdm(
desc="Testing",
position=(2 * self.process_position),
Expand All @@ -52,6 +72,19 @@ def init_test_tqdm(self) -> tqdm:
def get_metrics(
self, trainer: Trainer, pl_module: LightningModule
) -> Dict[str, Union[int, str, float, Dict[str, float]]]:
"""Override this to customize the metrics displayed in the progress bar."""
"""Override this to customize the metrics displayed in the progress bar.
Parameters
----------
trainer : Trainer
The trainer object.
pl_module : LightningModule
The LightningModule object, unused.
Returns
-------
dict
A dictionary with the metrics to display in the progress bar.
"""
pbar_metrics = trainer.progress_bar_metrics
return {**pbar_metrics}
Loading

0 comments on commit 7903ef0

Please sign in to comment.