Skip to content

Commit

Permalink
Axes (#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps authored Mar 5, 2024
2 parents 14ff842 + 3280564 commit 922a829
Show file tree
Hide file tree
Showing 126 changed files with 6,484 additions and 4,751 deletions.
70 changes: 30 additions & 40 deletions examples/2D/n2v/example_BSD68_lightning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,18 @@
"source": [
"from pathlib import Path\n",
"\n",
"import tifffile\n",
"import matplotlib.pyplot as plt\n",
"import tifffile\n",
"from careamics_portfolio import PortfolioManager\n",
"from pytorch_lightning import Trainer\n",
"import albumentations as Aug\n",
"\n",
"from careamics_portfolio import PortfolioManager\n",
"from careamics.lightning_module import (\n",
" CAREamicsModule,\n",
" CAREamicsTrainDataModule,\n",
"from careamics import CAREamicsModule\n",
"from careamics.lightning_prediction import CAREamicsFiring\n",
"from careamics.ligthning_datamodule import (\n",
" CAREamicsPredictDataModule,\n",
" CAREamicsFiring,\n",
" CAREamicsTrainDataModule,\n",
")\n",
"from careamics.utils.metrics import psnr\n",
"from careamics.transforms import ManipulateN2V"
"from careamics.utils.metrics import psnr"
]
},
{
Expand All @@ -38,8 +36,18 @@
"metadata": {},
"outputs": [],
"source": [
"# Download and unzip the files\n",
"# Explore portfolio\n",
"portfolio = PortfolioManager()\n",
"print(portfolio.denoising)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Download and unzip the files\n",
"root_path = Path(\"data\")\n",
"files = portfolio.denoising.N2V_BSD68.download(root_path)\n",
"print(f\"List of downloaded files: {files}\")"
Expand All @@ -55,7 +63,12 @@
"train_path = data_path / \"train\"\n",
"val_path = data_path / \"val\"\n",
"test_path = data_path / \"test\" / \"images\"\n",
"gt_path = data_path / \"test\" / \"gt\""
"gt_path = data_path / \"test\" / \"gt\"\n",
"\n",
"train_path.mkdir(parents=True, exist_ok=True)\n",
"val_path.mkdir(parents=True, exist_ok=True)\n",
"test_path.mkdir(parents=True, exist_ok=True)\n",
"gt_path.mkdir(parents=True, exist_ok=True)"
]
},
{
Expand Down Expand Up @@ -118,7 +131,9 @@
" algorithm=\"n2v\",\n",
" loss=\"n2v\",\n",
" architecture=\"UNet\",\n",
")\n"
" optimizer_parameters={\"lr\": 1e-4},\n",
" lr_scheduler_parameters={\"factor\": 0.5, \"patience\": 10},\n",
")"
]
},
{
Expand All @@ -129,17 +144,6 @@
"### Define the Transforms"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"transforms = Aug.Compose(\n",
" [Aug.Flip(), Aug.RandomRotate90(), Aug.Normalize(), ManipulateN2V()],\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
Expand All @@ -161,7 +165,6 @@
" patch_size=(64, 64),\n",
" axes=\"SYX\",\n",
" batch_size=128,\n",
" transforms=transforms,\n",
" num_workers=4,\n",
")"
]
Expand All @@ -182,7 +185,7 @@
"metadata": {},
"outputs": [],
"source": [
"trainer = Trainer(max_epochs=1)"
"trainer = Trainer(max_epochs=50)"
]
},
{
Expand All @@ -202,17 +205,6 @@
"### Define a prediction datamodule"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"transforms_predict = Aug.Compose(\n",
" [Aug.Normalize()],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -225,8 +217,6 @@
" tile_size=(256, 256),\n",
" axes=\"YX\",\n",
" batch_size=1,\n",
" num_workers=0,\n",
" transforms=transforms_predict,\n",
")"
]
},
Expand Down Expand Up @@ -324,7 +314,7 @@
"psnr_total = 0\n",
"\n",
"for pred, gt in zip(preds, gts):\n",
" psnr_total += psnr(gt, pred)\n",
" psnr_total += psnr(gt, pred.squeeze())\n",
"\n",
"print(f\"PSNR total: {psnr_total / len(preds)}\")"
]
Expand Down Expand Up @@ -353,7 +343,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.9.18"
},
"vscode": {
"interpreter": {
Expand Down
9 changes: 4 additions & 5 deletions examples/careamics_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from careamics import CAREamist, Configuration


def main():
config_dict ={
config_dict = {
"experiment_name": "ConfigTest",
"working_directory": ".",
"algorithm": {
Expand All @@ -14,9 +15,7 @@ def main():
"optimizer": {
"name": "Adam",
},
"lr_scheduler": {
"name": "ReduceLROnPlateau"
},
"lr_scheduler": {"name": "ReduceLROnPlateau"},
},
"training": {
"num_epochs": 1,
Expand All @@ -42,5 +41,5 @@ def main():
# print(pred.shape)


if __name__ == '__main__':
if __name__ == "__main__":
main()
18 changes: 10 additions & 8 deletions examples/careamics_lightning_api.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@
"import albumentations as Aug\n",
"from pytorch_lightning import Trainer\n",
"\n",
"\n",
"from careamics import (\n",
" CAREamicsModule,\n",
" CAREamicsTrainDataModule,\n",
")\n",
"from careamics.transforms import ManipulateN2V\n"
"from careamics.transforms import ManipulateN2V"
]
},
{
Expand All @@ -26,20 +25,20 @@
"# Instantiate ligthning module\n",
"model = CAREamicsModule(\n",
" algorithm=\"n2v\",\n",
" loss=\"n2v\", \n",
" loss=\"n2v\",\n",
" architecture=\"UNet\",\n",
" model_parameters={\n",
" # parameters such as depth, n2v2, etc. See UNet definition.\n",
" },\n",
" optimizer=\"Adam\", # see SupportedOptimizer\n",
" optimizer=\"Adam\", # see SupportedOptimizer\n",
" optimizer_parameters={\n",
" \"lr\": 1e-4,\n",
" # parameters from torch.optim\n",
" },\n",
" lr_scheduler=\"ReduceLROnPlateau\", # see SupportedLRScheduler\n",
" lr_scheduler=\"ReduceLROnPlateau\", # see SupportedLRScheduler\n",
" lr_scheduler_parameters={\n",
" # parameters from torch.optim.lr_scheduler\n",
" }\n",
" },\n",
")"
]
},
Expand Down Expand Up @@ -68,9 +67,12 @@
"outputs": [],
"source": [
"# define function to read data\n",
"\n",
"\n",
"def read_my_data_type(file):\n",
" pass\n",
"\n",
"\n",
"# Create your transforms using albumentations\n",
"transforms = Aug.Compose(\n",
" [Aug.Flip(), Aug.RandomRotate90(), Aug.Normalize(), ManipulateN2V()],\n",
Expand All @@ -83,13 +85,13 @@
"train_data_module = CAREamicsTrainDataModule(\n",
" train_path=train_path,\n",
" val_path=val_path,\n",
" data_type=\"custom\", # this forces read_source_func to be specified\n",
" data_type=\"custom\", # this forces read_source_func to be specified\n",
" patch_size=(64, 64),\n",
" axes=\"SYX\",\n",
" batch_size=128,\n",
" transforms=transforms,\n",
" num_workers=4,\n",
" read_source_func = read_my_data_type # function to read data\n",
" read_source_func=read_my_data_type, # function to read data\n",
")"
]
},
Expand Down
Loading

0 comments on commit 922a829

Please sign in to comment.