Skip to content

Commit

Permalink
feat: development notebook for patch_extractor factory
Browse files Browse the repository at this point in the history
  • Loading branch information
melisande-c committed Jan 27, 2025
1 parent 40a4cc6 commit 48d98ba
Showing 1 changed file with 105 additions and 0 deletions.
105 changes: 105 additions & 0 deletions src/careamics/dataset_ng/patch_extractor_factory_dev.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections.abc import Sequence\n",
"from pathlib import Path\n",
"from typing import Optional, Union\n",
"\n",
"from numpy.typing import NDArray\n",
"\n",
"from careamics.config import GeneralDataConfig\n",
"from careamics.config.support import SupportedData\n",
"from careamics.dataset_ng.patch_extractor import (\n",
" PatchExtractor,\n",
" PatchExtractorConstructor,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def build_patch_extractor_constructor_kwargs(\n",
" data_config: GeneralDataConfig, **custom_kwargs\n",
"):\n",
" if data_config.data_type == SupportedData.ARRAY:\n",
" return {\"axes\": data_config.axes}\n",
" elif data_config.data_type == SupportedData.TIFF:\n",
" return {\"axes\": data_config.axes}\n",
" elif data_config.data_type == SupportedData.CUSTOM:\n",
" return {\"axes\": data_config.axes, **custom_kwargs}\n",
" else:\n",
" raise ValueError(f\"Data type {data_config.data_type} is not supported.\")\n",
"\n",
"\n",
"def create_patch_extractors(\n",
" data_config: GeneralDataConfig,\n",
" train_data: Union[Sequence[NDArray], Sequence[Path]],\n",
" val_data: Optional[Union[Sequence[NDArray], Sequence[Path]]] = None,\n",
" train_data_target: Optional[Union[Sequence[NDArray], Sequence[Path]]] = None,\n",
" val_data_target: Optional[Union[Sequence[NDArray], Sequence[Path]]] = None,\n",
" **custom_kwargs,\n",
") -> PatchExtractor:\n",
"\n",
" CONSTRUCTORS: dict[SupportedData, PatchExtractorConstructor] = {\n",
" SupportedData.ARRAY: PatchExtractor.from_arrays,\n",
" SupportedData.TIFF: PatchExtractor.from_tiff_files,\n",
" SupportedData.CUSTOM: PatchExtractor.from_custom_file_type,\n",
" }\n",
"\n",
" # get correct constructor\n",
" constructor = CONSTRUCTORS[data_config.data_type]\n",
"\n",
" constructor_kwargs = build_patch_extractor_constructor_kwargs(\n",
" data_config, **custom_kwargs\n",
" )\n",
" # --- train images\n",
" train_patch_extractor: PatchExtractor = constructor(\n",
" source=train_data, **constructor_kwargs\n",
" )\n",
"\n",
" additional_patch_extractors: list[Union[PatchExtractor, None]] = []\n",
" additional_data_sources = [val_data, train_data_target, val_data_target]\n",
" for data_source in additional_data_sources:\n",
" if data_source is not None:\n",
" additional_patch_extractor: Optional[PatchExtractor] = constructor(\n",
" source=data_source, **constructor_kwargs\n",
" )\n",
" else:\n",
" additional_patch_extractor = None\n",
" additional_patch_extractors.append(additional_patch_extractor)\n",
"\n",
"\n",
" return train_patch_extractor, *additional_patch_extractors"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 48d98ba

Please sign in to comment.