From ec88b0cd8ebe9a6b44f8f08680b42066f03594eb Mon Sep 17 00:00:00 2001 From: pattonw Date: Fri, 6 Oct 2023 13:37:04 -0700 Subject: [PATCH 01/74] remove duplicated for loop --- gunpowder/torch/nodes/train.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/gunpowder/torch/nodes/train.py b/gunpowder/torch/nodes/train.py index ed2df002..ae3f184e 100644 --- a/gunpowder/torch/nodes/train.py +++ b/gunpowder/torch/nodes/train.py @@ -278,13 +278,6 @@ def train_step(self, batch, request): spec.roi = request[array_key].roi batch.arrays[array_key] = Array(tensor.grad.cpu().detach().numpy(), spec) - for array_key, array_name in requested_outputs.items(): - spec = self.spec[array_key].copy() - spec.roi = request[array_key].roi - batch.arrays[array_key] = Array( - outputs[array_name].cpu().detach().numpy(), spec - ) - batch.loss = loss.cpu().detach().numpy() self.iteration += 1 batch.iteration = self.iteration From abcecf2f8eb6a23c28366c853f2af612ed1890e0 Mon Sep 17 00:00:00 2001 From: pattonw Date: Fri, 6 Oct 2023 13:37:43 -0700 Subject: [PATCH 02/74] increment patch number --- gunpowder/version_info.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gunpowder/version_info.py b/gunpowder/version_info.py index 01724d07..e45efbfd 100644 --- a/gunpowder/version_info.py +++ b/gunpowder/version_info.py @@ -1,6 +1,6 @@ __major__ = 1 __minor__ = 3 -__patch__ = 1 +__patch__ = 2 __tag__ = "" __version__ = "{}.{}.{}{}".format(__major__, __minor__, __patch__, __tag__).strip(".") From b557548bcfd73b09340ff61eca4d727e87f52446 Mon Sep 17 00:00:00 2001 From: pattonw Date: Fri, 6 Oct 2023 14:00:33 -0700 Subject: [PATCH 03/74] ArraySpec docs fix documentation to be more accurate around nonspatial arrays --- gunpowder/array_spec.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/gunpowder/array_spec.py b/gunpowder/array_spec.py index ec271488..1400b7d4 100644 --- a/gunpowder/array_spec.py +++ b/gunpowder/array_spec.py @@ -14,13 +14,12 @@ class ArraySpec(Freezable): roi (:class:`Roi`): The region of interested represented by this array spec. Can be - ``None`` for :class:`BatchProviders` that allow - requests for arrays everywhere, but will always be set for array - specs that are part of a :class:`Array`. + ``None`` for nonspatial arrays but must otherwise always be set. voxel_size (:class:`Coordinate`): - The size of the spatial axises in world units. + The size of the spatial axises in world units. Can be ``None`` for + nonspatial arrays but must otherwise always be set. interpolatable (``bool``): @@ -55,7 +54,10 @@ def __init__( if nonspatial: assert roi is None, "Non-spatial arrays can not have a ROI" - assert voxel_size is None, "Non-spatial arrays can not " "have a voxel size" + assert voxel_size is None, "Non-spatial arrays can not have a voxel size" + else: + assert roi is not None, "Spatial arrays must have a ROI" + assert voxel_size is not None, "Spatial arrays must have a voxel size" self.freeze() From 80033bd0986582b5dddb326f478588ab19ecfdef Mon Sep 17 00:00:00 2001 From: pattonw Date: Fri, 6 Oct 2023 14:05:58 -0700 Subject: [PATCH 04/74] ArraySpec bug fix: allow None roi/voxel size for spatial arrays --- gunpowder/array_spec.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/gunpowder/array_spec.py b/gunpowder/array_spec.py index 1400b7d4..9002ae4f 100644 --- a/gunpowder/array_spec.py +++ b/gunpowder/array_spec.py @@ -14,12 +14,12 @@ class ArraySpec(Freezable): roi (:class:`Roi`): The region of interested represented by this array spec. Can be - ``None`` for nonspatial arrays but must otherwise always be set. + ``None`` for nonspatial arrays or to indicate the true value is unknown. voxel_size (:class:`Coordinate`): The size of the spatial axises in world units. Can be ``None`` for - nonspatial arrays but must otherwise always be set. + nonspatial arrays or to indicate the true value is unknown. interpolatable (``bool``): @@ -55,9 +55,6 @@ def __init__( if nonspatial: assert roi is None, "Non-spatial arrays can not have a ROI" assert voxel_size is None, "Non-spatial arrays can not have a voxel size" - else: - assert roi is not None, "Spatial arrays must have a ROI" - assert voxel_size is not None, "Spatial arrays must have a voxel size" self.freeze() From 29497bbd65c56f3c2367698966874db4891e0447 Mon Sep 17 00:00:00 2001 From: sheridana Date: Tue, 17 Oct 2023 13:15:32 -0700 Subject: [PATCH 05/74] Add probability option to aug nodes --- gunpowder/nodes/batch_filter.py | 12 +++++-- gunpowder/nodes/defect_augment.py | 12 +++++++ gunpowder/nodes/deform_augment.py | 14 +++++++- gunpowder/nodes/elastic_augment.py | 13 ++++++- gunpowder/nodes/intensity_augment.py | 13 +++++++ gunpowder/nodes/noise_augment.py | 16 +++++++-- gunpowder/nodes/shift_augment.py | 7 ++-- gunpowder/nodes/simple_augment.py | 13 ++++++- tests/cases/batch_filter.py | 53 ++++++++++++++++++++++++++++ 9 files changed, 143 insertions(+), 10 deletions(-) create mode 100644 tests/cases/batch_filter.py diff --git a/gunpowder/nodes/batch_filter.py b/gunpowder/nodes/batch_filter.py index 2ba954c1..e2938e8a 100644 --- a/gunpowder/nodes/batch_filter.py +++ b/gunpowder/nodes/batch_filter.py @@ -137,7 +137,7 @@ def autoskip_enabled(self): return self._autoskip_enabled def provide(self, request): - skip = self.__can_skip(request) + skip = self._can_skip(request) or self.skip_node(request) timing_prepare = Timing(self, "prepare") timing_prepare.start() @@ -190,7 +190,7 @@ def provide(self, request): return batch - def __can_skip(self, request): + def _can_skip(self, request): """Check if this filter needs to be run for the given request.""" if not self.autoskip_enabled: @@ -206,6 +206,14 @@ def __can_skip(self, request): return True + def skip_node(self, request): + """To be implemented in subclasses. + + Skip a node if a condition is met. Can be useful if using a probability + to determine whether to use an augmentation, for example. + """ + pass + def setup(self): """To be implemented in subclasses. diff --git a/gunpowder/nodes/defect_augment.py b/gunpowder/nodes/defect_augment.py index 8f8fefd3..13f0fee6 100644 --- a/gunpowder/nodes/defect_augment.py +++ b/gunpowder/nodes/defect_augment.py @@ -67,6 +67,13 @@ class DefectAugment(BatchFilter): axis (``int``, optional): Along which axis sections are cut. + + p (``float``, optional): + + Probability applying the augmentation. Default is 1.0 (always + apply). Should be a float value between 0 and 1. Lowering this value + could be useful for computational efficiency and increasing + augmentation space. """ def __init__( @@ -82,6 +89,7 @@ def __init__( artifacts_mask=None, deformation_strength=20, axis=0, + p=1.0, ): self.intensities = intensities self.prob_missing = prob_missing @@ -94,6 +102,7 @@ def __init__( self.artifacts_mask = artifacts_mask self.deformation_strength = deformation_strength self.axis = axis + self.p = p def setup(self): if self.artifact_source is not None: @@ -103,6 +112,9 @@ def teardown(self): if self.artifact_source is not None: self.artifact_source.teardown() + def skip_node(self, request): + return random.random() > self.p + # send roi request to data-source upstream def prepare(self, request): deps = BatchRequest() diff --git a/gunpowder/nodes/deform_augment.py b/gunpowder/nodes/deform_augment.py index cdf5eeff..13826909 100644 --- a/gunpowder/nodes/deform_augment.py +++ b/gunpowder/nodes/deform_augment.py @@ -81,6 +81,14 @@ class DeformAugment(BatchFilter): Whether or not to compute the elastic transform node wise for nodes that were lossed during the fast elastic transform process. + + + p (``float``, optional): + + Probability applying the augmentation. Default is 1.0 (always + apply). Should be a float value between 0 and 1. Lowering this value + could be useful for computational efficiency and increasing + augmentation space. """ def __init__( @@ -95,6 +103,7 @@ def __init__( recompute_missing_points=True, transform_key: ArrayKey = None, graph_raster_voxel_size: Coordinate = None, + p: float = 1.0, ): self.control_point_spacing = Coordinate(control_point_spacing) self.jitter_sigma = Coordinate(jitter_sigma) @@ -107,6 +116,7 @@ def __init__( self.recompute_missing_points = recompute_missing_points self.transform_key = transform_key self.graph_raster_voxel_size = Coordinate(graph_raster_voxel_size) + self.p = p assert ( self.control_point_spacing.dims == self.jitter_sigma.dims @@ -128,8 +138,10 @@ def setup(self): self.provides(self.transform_key, spec) - def prepare(self, request): + def skip_node(self, request): + return random.random() > self.p + def prepare(self, request): # get the total ROI of all requests total_roi = request.get_total_roi() logger.debug("total ROI is %s" % total_roi) diff --git a/gunpowder/nodes/elastic_augment.py b/gunpowder/nodes/elastic_augment.py index a70f7866..c40cb8d7 100644 --- a/gunpowder/nodes/elastic_augment.py +++ b/gunpowder/nodes/elastic_augment.py @@ -88,6 +88,13 @@ class ElasticAugment(BatchFilter): Whether or not to compute the elastic transform node wise for nodes that were lossed during the fast elastic transform process. + + p (``float``, optional): + + Probability applying the augmentation. Default is 1.0 (always + apply). Should be a float value between 0 and 1. Lowering this value + could be useful for computational efficiency and increasing + augmentation space. """ def __init__( @@ -103,6 +110,7 @@ def __init__( spatial_dims=3, use_fast_points_transform=False, recompute_missing_points=True, + p=1.0, ): warnings.warn( "ElasticAugment is deprecated, please use the DeformAugment", @@ -122,9 +130,12 @@ def __init__( self.spatial_dims = spatial_dims self.use_fast_points_transform = use_fast_points_transform self.recompute_missing_points = recompute_missing_points + self.p = p - def prepare(self, request): + def skip_node(self, request): + return random.random() > self.p + def prepare(self, request): # get the voxel size self.voxel_size = self.__get_common_voxel_size(request) diff --git a/gunpowder/nodes/intensity_augment.py b/gunpowder/nodes/intensity_augment.py index 771f57bb..1055549f 100644 --- a/gunpowder/nodes/intensity_augment.py +++ b/gunpowder/nodes/intensity_augment.py @@ -1,4 +1,5 @@ import numpy as np +import random from gunpowder.batch_request import BatchRequest @@ -34,6 +35,13 @@ class IntensityAugment(BatchFilter): Set to False if modified values should not be clipped to [0, 1] Disables range check! + + p (``float``, optional): + + Probability applying the augmentation. Default is 1.0 (always + apply). Should be a float value between 0 and 1. Lowering this value + could be useful for computational efficiency and increasing + augmentation space. """ def __init__( @@ -45,6 +53,7 @@ def __init__( shift_max, z_section_wise=False, clip=True, + p=1.0, ): self.array = array self.scale_min = scale_min @@ -53,11 +62,15 @@ def __init__( self.shift_max = shift_max self.z_section_wise = z_section_wise self.clip = clip + self.p = p def setup(self): self.enable_autoskip() self.updates(self.array, self.spec[self.array]) + def skip_node(self, request): + return random.random() > self.p + def prepare(self, request): deps = BatchRequest() deps[self.array] = request[self.array].copy() diff --git a/gunpowder/nodes/noise_augment.py b/gunpowder/nodes/noise_augment.py index f4bfb5ba..c2ff223f 100644 --- a/gunpowder/nodes/noise_augment.py +++ b/gunpowder/nodes/noise_augment.py @@ -1,4 +1,5 @@ import numpy as np +import random import skimage from gunpowder.batch_request import BatchRequest @@ -24,18 +25,29 @@ class NoiseAugment(BatchFilter): Whether to preserve the image range (either [-1, 1] or [0, 1]) by clipping values in the end, see scikit-image documentation + + p (``float``, optional): + + Probability applying the augmentation. Default is 1.0 (always + apply). Should be a float value between 0 and 1. Lowering this value + could be useful for computational efficiency and increasing + augmentation space. """ - def __init__(self, array, mode="gaussian", clip=True, **kwargs): + def __init__(self, array, mode="gaussian", clip=True, p=1.0, **kwargs): self.array = array self.mode = mode self.clip = clip + self.p = p self.kwargs = kwargs def setup(self): self.enable_autoskip() self.updates(self.array, self.spec[self.array]) + def skip_node(self, request): + return random.random() > self.p + def prepare(self, request): deps = BatchRequest() deps[self.array] = request[self.array].copy() @@ -57,13 +69,11 @@ def process(self, batch, request): seed = request.random_seed try: - raw.data = skimage.util.random_noise( raw.data, mode=self.mode, rng=seed, clip=self.clip, **self.kwargs ).astype(raw.data.dtype) except ValueError: - # legacy version of skimage random_noise raw.data = skimage.util.random_noise( raw.data, mode=self.mode, seed=seed, clip=self.clip, **self.kwargs diff --git a/gunpowder/nodes/shift_augment.py b/gunpowder/nodes/shift_augment.py index 8fe6524b..4fd16d54 100644 --- a/gunpowder/nodes/shift_augment.py +++ b/gunpowder/nodes/shift_augment.py @@ -12,19 +12,22 @@ class ShiftAugment(BatchFilter): - def __init__(self, prob_slip=0, prob_shift=0, sigma=0, shift_axis=0): + def __init__(self, prob_slip=0, prob_shift=0, sigma=0, shift_axis=0, p=1.0): self.prob_slip = prob_slip self.prob_shift = prob_shift self.sigma = sigma self.shift_axis = shift_axis + self.p = p self.ndim = None self.shift_sigmas = None self.shift_array = None self.lcm_voxel_size = None - def prepare(self, request): + def skip_node(self, request): + return random.random() > self.p + def prepare(self, request): self.ndim = request.get_total_roi().dims assert self.shift_axis in range(self.ndim) diff --git a/gunpowder/nodes/simple_augment.py b/gunpowder/nodes/simple_augment.py index dc756e3a..c78aea1c 100644 --- a/gunpowder/nodes/simple_augment.py +++ b/gunpowder/nodes/simple_augment.py @@ -47,6 +47,13 @@ class SimpleAugment(BatchFilter): and attempt to weight them appropriately. A weight of 0 means this axis will never be transposed, a weight of 1 means this axis will always be transposed. + + p (``float``, optional): + + Probability applying the augmentation. Default is 1.0 (always + apply). Should be a float value between 0 and 1. Lowering this value + could be useful for computational efficiency and increasing + augmentation space. """ def __init__( @@ -55,6 +62,7 @@ def __init__( transpose_only=None, mirror_probs=None, transpose_probs=None, + p=1.0, ): self.mirror_only = mirror_only self.mirror_probs = mirror_probs @@ -63,6 +71,7 @@ def __init__( self.mirror_mask = None self.dims = None self.transpose_dims = None + self.p = p def setup(self): self.dims = self.spec.get_total_roi().dims @@ -105,8 +114,10 @@ def setup(self): if valid: self.permutation_dict[k] = v - def prepare(self, request): + def skip_node(self, request): + return random.random() > self.p + def prepare(self, request): self.mirror = [ random.random() < self.mirror_probs[d] if self.mirror_mask[d] else 0 for d in range(self.dims) diff --git a/tests/cases/batch_filter.py b/tests/cases/batch_filter.py new file mode 100644 index 00000000..63288bd5 --- /dev/null +++ b/tests/cases/batch_filter.py @@ -0,0 +1,53 @@ +from .helper_sources import ArraySource +from gunpowder import ( + ArrayKey, + build, + Array, + ArraySpec, + Roi, + Coordinate, + BatchRequest, + BatchFilter, +) + +import numpy as np +import random + + +class DummyNode(BatchFilter): + def __init__(self, array, p=1.0): + self.array = array + self.p = p + + def skip_node(self, request): + return random.random() > self.p + + def process(self, batch, request): + batch[self.array].data = batch[self.array].data + 1 + + +def test_skip(): + raw_key = ArrayKey("RAW") + array = Array( + np.ones((10, 10)), + ArraySpec(Roi((0, 0), (10, 10)), Coordinate(1, 1)), + ) + source = ArraySource(raw_key, array) + + request_1 = BatchRequest(random_seed=1) + request_2 = BatchRequest(random_seed=2) + + request_1.add(raw_key, Coordinate(10, 10)) + request_2.add(raw_key, Coordinate(10, 10)) + + pipeline = source + DummyNode(raw_key, p=0.5) + + with build(pipeline): + batch_1 = pipeline.request_batch(request_1) + batch_2 = pipeline.request_batch(request_2) + + x_1 = batch_1.arrays[raw_key].data + x_2 = batch_2.arrays[raw_key].data + + assert x_1.max() == 2 + assert x_2.max() == 1 From f0c62afbf656f995d33638db0031226c88c66130 Mon Sep 17 00:00:00 2001 From: sheridana Date: Wed, 18 Oct 2023 09:43:15 -0700 Subject: [PATCH 06/74] Revert can_skip to private method --- gunpowder/nodes/batch_filter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gunpowder/nodes/batch_filter.py b/gunpowder/nodes/batch_filter.py index e2938e8a..4f1e4da2 100644 --- a/gunpowder/nodes/batch_filter.py +++ b/gunpowder/nodes/batch_filter.py @@ -137,7 +137,7 @@ def autoskip_enabled(self): return self._autoskip_enabled def provide(self, request): - skip = self._can_skip(request) or self.skip_node(request) + skip = self.__can_skip(request) or self.skip_node(request) timing_prepare = Timing(self, "prepare") timing_prepare.start() @@ -190,7 +190,7 @@ def provide(self, request): return batch - def _can_skip(self, request): + def __can_skip(self, request): """Check if this filter needs to be run for the given request.""" if not self.autoskip_enabled: From 7f8b8776d7c5d2a0adca578d300f327e1399a9fd Mon Sep 17 00:00:00 2001 From: William Patton Date: Wed, 1 Nov 2023 11:19:14 -0700 Subject: [PATCH 07/74] fix the deform augment test no longer assumes a deformed label will still exist in an array --- tests/cases/deform_augment.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/cases/deform_augment.py b/tests/cases/deform_augment.py index 2134a708..f722b0bb 100644 --- a/tests/cases/deform_augment.py +++ b/tests/cases/deform_augment.py @@ -160,6 +160,9 @@ def test_3d_basics(rotate, spatial_dims, fast_points): loc = (loc - labels.spec.roi.begin) / labels.spec.voxel_size loc = np.array(loc) com = center_of_mass(labels.data == node.id) + if any(np.isnan(com)): + # cannot assume that the rasterized data will exist after defomation + continue assert ( np.linalg.norm(com - loc) < np.linalg.norm(labels.spec.voxel_size) * 2 From f1fd63af3498360f938346f67063d75b9024b6cc Mon Sep 17 00:00:00 2001 From: William Patton Date: Wed, 1 Nov 2023 11:20:31 -0700 Subject: [PATCH 08/74] better bounds on required packages --- pyproject.toml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 11ab82bc..e187ff4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,16 +32,16 @@ keywords = [] requires-python = ">=3.7" dependencies = [ - "numpy", - "scipy", - "h5py", + "numpy>=1.24", + "scipy>=1.6", + "h5py>=3.10", "scikit-image", "requests", "augment-nd>=0.1.3", "tqdm", - "funlib.geometry", + "funlib.geometry>=0.2", "zarr", - "networkx", + "networkx>=3.1", ] [project.optional-dependencies] From f96aa78bb8c65c1aaffd96783f0cdbff9bb1fd18 Mon Sep 17 00:00:00 2001 From: William Patton Date: Wed, 1 Nov 2023 20:51:43 -0700 Subject: [PATCH 09/74] ignore missing imports from packages that don't provide type hints --- mypy.ini | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 mypy.ini diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..6daa39e0 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,37 @@ +[mypy] + +# ext +[mypy-dvision.*] +ignore_missing_imports = True +[mypy-pyklb.*] +ignore_missing_imports = True +[mypy-malis.*] +ignore_missing_imports = True +[mypy-haiku.*] +ignore_missing_imports = True +[mypy-optax.*] +ignore_missing_imports = True + +# dependencies +[mypy-tensorflow.*] +ignore_missing_imports = True +[mypy-tensorboardX.*] +ignore_missing_imports = True +[mypy-torch.*] +ignore_missing_imports = True +[mypy-jax.*] +ignore_missing_imports = True +[mypy-daisy.*] +ignore_missing_imports = True +[mypy-scipy.*] +ignore_missing_imports = True +[mypy-h5py.*] +ignore_missing_imports = True +[mypy-augment.*] +ignore_missing_imports = True +[mypy-zarr.*] +ignore_missing_imports = True +[mypy-networkx.*] +ignore_missing_imports = True +[mypy-Queue.*] +ignore_missing_imports = True \ No newline at end of file From 98bbe7165c8ceca7c4b86ccc08a20404fc6e7d33 Mon Sep 17 00:00:00 2001 From: William Patton Date: Wed, 1 Nov 2023 20:57:03 -0700 Subject: [PATCH 10/74] fix typehint mistakes --- gunpowder/ext/__init__.py | 2 ++ gunpowder/jax/nodes/predict.py | 6 +++--- gunpowder/jax/nodes/train.py | 6 +++--- gunpowder/nodes/deform_augment.py | 5 +++-- gunpowder/nodes/zarr_write.py | 4 ++-- gunpowder/torch/nodes/predict.py | 8 ++++---- gunpowder/torch/nodes/train.py | 19 ++++++++++++------- 7 files changed, 29 insertions(+), 21 deletions(-) diff --git a/gunpowder/ext/__init__.py b/gunpowder/ext/__init__.py index 7aec50c9..fdfcfa02 100644 --- a/gunpowder/ext/__init__.py +++ b/gunpowder/ext/__init__.py @@ -3,6 +3,7 @@ import traceback import sys +from typing import Optional, Any logger = logging.getLogger(__name__) @@ -58,6 +59,7 @@ def __getattr__(self, item): except ImportError as e: augment = NoSuchModule("augment") +ZarrFile: Optional[Any] = None try: import zarr from .zarr_file import ZarrFile diff --git a/gunpowder/jax/nodes/predict.py b/gunpowder/jax/nodes/predict.py index 4c46f233..496d0fd0 100644 --- a/gunpowder/jax/nodes/predict.py +++ b/gunpowder/jax/nodes/predict.py @@ -6,7 +6,7 @@ import pickle import logging -from typing import Dict, Union +from typing import Dict, Union, Optional logger = logging.getLogger(__name__) @@ -52,8 +52,8 @@ def __init__( model: GenericJaxModel, inputs: Dict[str, ArrayKey], outputs: Dict[Union[str, int], ArrayKey], - array_specs: Dict[ArrayKey, ArraySpec] = None, - checkpoint: str = None, + array_specs: Optional[Dict[ArrayKey, ArraySpec]] = None, + checkpoint: Optional[str] = None, spawn_subprocess=False, ): self.array_specs = array_specs if array_specs is not None else {} diff --git a/gunpowder/jax/nodes/train.py b/gunpowder/jax/nodes/train.py index 4d1f17a3..9621b129 100644 --- a/gunpowder/jax/nodes/train.py +++ b/gunpowder/jax/nodes/train.py @@ -11,7 +11,7 @@ from gunpowder.nodes.generic_train import GenericTrain from gunpowder.jax import GenericJaxModel -from typing import Dict, Union, Optional +from typing import Dict, Union, Optional, Any logger = logging.getLogger(__name__) @@ -108,7 +108,7 @@ def __init__( checkpoint_basename: str = "model", save_every: int = 2000, keep_n_checkpoints: Optional[int] = None, - log_dir: str = None, + log_dir: Optional[str] = None, log_every: int = 1, spawn_subprocess: bool = False, n_devices: Optional[int] = None, @@ -141,7 +141,7 @@ def __init__( if log_dir is not None: logger.warning("log_dir given, but tensorboardX is not installed") - self.intermediate_layers = {} + self.intermediate_layers: dict[ArrayKey, Any] = {} self.validate_fn = validate_fn self.validate_every = validate_every diff --git a/gunpowder/nodes/deform_augment.py b/gunpowder/nodes/deform_augment.py index cdf5eeff..0291a367 100644 --- a/gunpowder/nodes/deform_augment.py +++ b/gunpowder/nodes/deform_augment.py @@ -21,6 +21,7 @@ import logging import math import random +from typing import Optional logger = logging.getLogger(__name__) @@ -93,8 +94,8 @@ def __init__( spatial_dims=3, use_fast_points_transform=False, recompute_missing_points=True, - transform_key: ArrayKey = None, - graph_raster_voxel_size: Coordinate = None, + transform_key: Optional[ArrayKey] = None, + graph_raster_voxel_size: Optional[Coordinate] = None, ): self.control_point_spacing = Coordinate(control_point_spacing) self.jitter_sigma = Coordinate(jitter_sigma) diff --git a/gunpowder/nodes/zarr_write.py b/gunpowder/nodes/zarr_write.py index 35965b6d..3beba3ae 100644 --- a/gunpowder/nodes/zarr_write.py +++ b/gunpowder/nodes/zarr_write.py @@ -5,10 +5,10 @@ from zarr import N5FSStore, N5Store from .batch_filter import BatchFilter +from gunpowder.array import ArrayKey from gunpowder.batch_request import BatchRequest from gunpowder.coordinate import Coordinate from gunpowder.roi import Roi -from gunpowder.coordinate import Coordinate from gunpowder.ext import ZarrFile import logging @@ -71,7 +71,7 @@ def __init__( else: self.dataset_dtypes = dataset_dtypes - self.dataset_offsets = {} + self.dataset_offsets: dict[ArrayKey, Coordinate] = {} def _get_voxel_size(self, dataset): if "resolution" not in dataset.attrs: diff --git a/gunpowder/torch/nodes/predict.py b/gunpowder/torch/nodes/predict.py index 3e5ba8f1..9db42a4c 100644 --- a/gunpowder/torch/nodes/predict.py +++ b/gunpowder/torch/nodes/predict.py @@ -4,7 +4,7 @@ from gunpowder.nodes.generic_predict import GenericPredict import logging -from typing import Dict, Union +from typing import Dict, Union, Optional, Any logger = logging.getLogger(__name__) @@ -60,8 +60,8 @@ def __init__( model, inputs: Dict[str, ArrayKey], outputs: Dict[Union[str, int], ArrayKey], - array_specs: Dict[ArrayKey, ArraySpec] = None, - checkpoint: str = None, + array_specs: Optional[Dict[ArrayKey, ArraySpec]] = None, + checkpoint: Optional[str] = None, device="cuda", spawn_subprocess=False, ): @@ -82,7 +82,7 @@ def __init__( self.model = model self.checkpoint = checkpoint - self.intermediate_layers = {} + self.intermediate_layers: dict[ArrayKey, Any] = {} self.register_hooks() def start(self): diff --git a/gunpowder/torch/nodes/train.py b/gunpowder/torch/nodes/train.py index ae3f184e..1103ce86 100644 --- a/gunpowder/torch/nodes/train.py +++ b/gunpowder/torch/nodes/train.py @@ -6,7 +6,8 @@ from gunpowder.ext import torch, tensorboardX, NoSuchModule from gunpowder.nodes.generic_train import GenericTrain -from typing import Dict, Union, Optional +from typing import Dict, Union, Optional, Any +import itertools logger = logging.getLogger(__name__) @@ -92,7 +93,7 @@ def __init__( array_specs: Optional[Dict[ArrayKey, ArraySpec]] = None, checkpoint_basename: str = "model", save_every: int = 2000, - log_dir: str = None, + log_dir: Optional[str] = None, log_every: int = 1, spawn_subprocess: bool = False, ): @@ -104,12 +105,16 @@ def __init__( # not yet implemented gradients = gradients - inputs.update( - {k: v for k, v in loss_inputs.items() if v not in outputs.values()} - ) + all_inputs = { + { + k: v + for k, v in itertools.chain(inputs.items(), loss_inputs.items()) + if v not in outputs.values() + } + } super(Train, self).__init__( - inputs, outputs, gradients, array_specs, spawn_subprocess=spawn_subprocess + all_inputs, outputs, gradients, array_specs, spawn_subprocess=spawn_subprocess ) self.model = model @@ -129,7 +134,7 @@ def __init__( if log_dir is not None: logger.warning("log_dir given, but tensorboardX is not installed") - self.intermediate_layers = {} + self.intermediate_layers: dict[ArrayKey, Any] = {} self.register_hooks() def register_hooks(self): From 3ba99da7c5b759b690f4b1726d79c74656a79b10 Mon Sep 17 00:00:00 2001 From: William Patton Date: Wed, 1 Nov 2023 20:57:16 -0700 Subject: [PATCH 11/74] format pyproject.toml --- pyproject.toml | 90 +++++++++++++++++++++++--------------------------- 1 file changed, 42 insertions(+), 48 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e187ff4f..7168e7b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,74 +6,68 @@ requires = ["setuptools", "wheel"] name = "gunpowder" description = "A library to facilitate machine learning on large, multi-dimensional images." authors = [ - {name = "Jan Funke", email = "funkej@hhmi.org"}, - {name = "William Patton", email = "pattonw@hhmi.org"}, - {name = "Renate Krause"}, - {name = "Julia Buhmann"}, - {name = "Rodrigo Ceballos Lentini"}, - {name = "William Grisaitis"}, - {name = "Chris Barnes"}, - {name = "Caroline Malin-Mayor"}, - {name = "Larissa Heinrich"}, - {name = "Philipp Hanslovsky"}, - {name = "Sherry Ding"}, - {name = "Andrew Champion"}, - {name = "Arlo Sheridan"}, - {name = "Constantin Pape"}, + { name = "Jan Funke", email = "funkej@hhmi.org" }, + { name = "William Patton", email = "pattonw@hhmi.org" }, + { name = "Renate Krause" }, + { name = "Julia Buhmann" }, + { name = "Rodrigo Ceballos Lentini" }, + { name = "William Grisaitis" }, + { name = "Chris Barnes" }, + { name = "Caroline Malin-Mayor" }, + { name = "Larissa Heinrich" }, + { name = "Philipp Hanslovsky" }, + { name = "Sherry Ding" }, + { name = "Andrew Champion" }, + { name = "Arlo Sheridan" }, + { name = "Constantin Pape" }, ] -license = {text = "MIT"} +license = { text = "MIT" } readme = "README.md" dynamic = ["version"] -classifiers = [ - "Programming Language :: Python :: 3", -] +classifiers = ["Programming Language :: Python :: 3"] keywords = [] requires-python = ">=3.7" dependencies = [ - "numpy>=1.24", - "scipy>=1.6", - "h5py>=3.10", - "scikit-image", - "requests", - "augment-nd>=0.1.3", - "tqdm", - "funlib.geometry>=0.2", - "zarr", - "networkx>=3.1", + "numpy>=1.24", + "scipy>=1.6", + "h5py>=3.10", + "scikit-image", + "requests", + "augment-nd>=0.1.3", + "tqdm", + "funlib.geometry>=0.2", + "zarr", + "networkx>=3.1", ] [project.optional-dependencies] -dev = [ - "pytest", - "pytest-cov", - "flake8", -] +dev = ["pytest", "pytest-cov", "flake8", "mypy"] docs = [ - "sphinx", - "sphinx_rtd_theme", - "sphinx_togglebutton", - "tomli", - "jupyter_sphinx", - "ipykernel", - "matplotlib", - "torch", + "sphinx", + "sphinx_rtd_theme", + "sphinx_togglebutton", + "tomli", + "jupyter_sphinx", + "ipykernel", + "matplotlib", + "torch", ] pytorch = ['torch'] tensorflow = [ - # TF doesn't provide <2.0 wheels for py>=3.8 on pypi - 'tensorflow<2.0; python_version<"3.8"', # https://stackoverflow.com/a/72493690 - 'protobuf==3.20.*; python_version=="3.7"', + # TF doesn't provide <2.0 wheels for py>=3.8 on pypi + 'tensorflow<2.0; python_version<"3.8"', # https://stackoverflow.com/a/72493690 + 'protobuf==3.20.*; python_version=="3.7"', ] full = [ - 'torch', - 'tensorflow<2.0; python_version<"3.8"', - 'protobuf==3.20.*; python_version=="3.7"', + 'torch', + 'tensorflow<2.0; python_version<"3.8"', + 'protobuf==3.20.*; python_version=="3.7"', ] [tool.setuptools.dynamic] -version = {attr = "gunpowder.version_info.__version__"} +version = { attr = "gunpowder.version_info.__version__" } [tool.black] target_version = ['py36', 'py37', 'py38', 'py39', 'py310'] From 7a10397300acd4ebd33445bf4b1bc8a6e1949e4e Mon Sep 17 00:00:00 2001 From: William Patton Date: Wed, 1 Nov 2023 20:57:44 -0700 Subject: [PATCH 12/74] black format --- gunpowder/nodes/random_location.py | 4 +--- gunpowder/nodes/zarr_source.py | 5 ++--- gunpowder/torch/nodes/train.py | 6 +++++- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/gunpowder/nodes/random_location.py b/gunpowder/nodes/random_location.py index fccbd6cb..a86a2c69 100644 --- a/gunpowder/nodes/random_location.py +++ b/gunpowder/nodes/random_location.py @@ -383,9 +383,7 @@ def __select_random_location_with_points( logger.debug("belongs to lcm voxel %s", lcm_location) # align the point request ROI with lcm voxel grid - lcm_roi = request_points_roi.snap_to_grid( - lcm_voxel_size, - mode="shrink") + lcm_roi = request_points_roi.snap_to_grid(lcm_voxel_size, mode="shrink") lcm_roi = lcm_roi / lcm_voxel_size logger.debug("Point request ROI: %s", request_points_roi) logger.debug("Point request lcm ROI shape: %s", lcm_roi.shape) diff --git a/gunpowder/nodes/zarr_source.py b/gunpowder/nodes/zarr_source.py index 812769f3..b7133580 100644 --- a/gunpowder/nodes/zarr_source.py +++ b/gunpowder/nodes/zarr_source.py @@ -107,9 +107,8 @@ def _get_offset(self, dataset): def _rev_metadata(self): with ZarrFile(self.store, mode="a") as store: - return ( - isinstance(store.chunk_store, N5Store) or - isinstance(store.chunk_store, N5FSStore) + return isinstance(store.chunk_store, N5Store) or isinstance( + store.chunk_store, N5FSStore ) def _open_file(self, store): diff --git a/gunpowder/torch/nodes/train.py b/gunpowder/torch/nodes/train.py index 1103ce86..2eebcff4 100644 --- a/gunpowder/torch/nodes/train.py +++ b/gunpowder/torch/nodes/train.py @@ -114,7 +114,11 @@ def __init__( } super(Train, self).__init__( - all_inputs, outputs, gradients, array_specs, spawn_subprocess=spawn_subprocess + all_inputs, + outputs, + gradients, + array_specs, + spawn_subprocess=spawn_subprocess, ) self.model = model From 1615a7d7adddc88cb916f2e0e59d846c7beaac3c Mon Sep 17 00:00:00 2001 From: William Patton Date: Wed, 1 Nov 2023 22:04:51 -0700 Subject: [PATCH 13/74] move register hooks to the start method This is to get around local functions (i.e. the hooks) not being pickle-able which we need for the "spawn" start function (spawn is the default on windows and recent macs) --- gunpowder/torch/nodes/predict.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gunpowder/torch/nodes/predict.py b/gunpowder/torch/nodes/predict.py index 9db42a4c..585ebc34 100644 --- a/gunpowder/torch/nodes/predict.py +++ b/gunpowder/torch/nodes/predict.py @@ -83,7 +83,6 @@ def __init__( self.checkpoint = checkpoint self.intermediate_layers: dict[ArrayKey, Any] = {} - self.register_hooks() def start(self): self.use_cuda = torch.cuda.is_available() and self.device_string == "cuda" @@ -106,6 +105,8 @@ def start(self): else: self.model.load_state_dict(checkpoint) + self.register_hooks() + def predict(self, batch, request): inputs = self.get_inputs(batch) with torch.no_grad(): From 4a8ccf43ae7c5c7a19fef887c2b4c4f3fcfeef6d Mon Sep 17 00:00:00 2001 From: William Patton Date: Wed, 1 Nov 2023 22:05:00 -0700 Subject: [PATCH 14/74] fix typo --- gunpowder/torch/nodes/train.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/gunpowder/torch/nodes/train.py b/gunpowder/torch/nodes/train.py index 2eebcff4..3f688929 100644 --- a/gunpowder/torch/nodes/train.py +++ b/gunpowder/torch/nodes/train.py @@ -106,11 +106,9 @@ def __init__( # not yet implemented gradients = gradients all_inputs = { - { - k: v - for k, v in itertools.chain(inputs.items(), loss_inputs.items()) - if v not in outputs.values() - } + k: v + for k, v in itertools.chain(inputs.items(), loss_inputs.items()) + if v not in outputs.values() } super(Train, self).__init__( From a001d7d3f1d58b195d9835ab6a0ea05351b0459d Mon Sep 17 00:00:00 2001 From: William Patton Date: Wed, 1 Nov 2023 22:05:11 -0700 Subject: [PATCH 15/74] support non-spatial arrays in ArraySource --- tests/cases/helper_sources.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/cases/helper_sources.py b/tests/cases/helper_sources.py index 219044b1..630333d6 100644 --- a/tests/cases/helper_sources.py +++ b/tests/cases/helper_sources.py @@ -13,7 +13,10 @@ def setup(self): def provide(self, request): outputs = Batch() - outputs[self.key] = copy.deepcopy(self.array.crop(request[self.key].roi)) + if self.array.spec.nonspatial: + outputs[self.key] = copy.deepcopy(self.array) + else: + outputs[self.key] = copy.deepcopy(self.array.crop(request[self.key].roi)) return outputs From 98a2e34062b6e6e50ed251f66499e5652b9ed474 Mon Sep 17 00:00:00 2001 From: William Patton Date: Wed, 1 Nov 2023 22:05:25 -0700 Subject: [PATCH 16/74] overhaul torch tests --- tests/cases/torch_train.py | 465 +++++++++++++++++-------------------- 1 file changed, 212 insertions(+), 253 deletions(-) diff --git a/tests/cases/torch_train.py b/tests/cases/torch_train.py index c213a1c9..c9eed28c 100644 --- a/tests/cases/torch_train.py +++ b/tests/cases/torch_train.py @@ -1,4 +1,4 @@ -from .provider_test import ProviderTest +from .helper_sources import ArraySource from gunpowder import ( BatchProvider, BatchRequest, @@ -11,6 +11,7 @@ Batch, Scan, PreCache, + MergeProvider, build, ) from gunpowder.ext import torch, NoSuchModule @@ -21,210 +22,168 @@ import logging -class ExampleTorchTrain2DSource(BatchProvider): - def __init__(self): - pass +# Example 2D source +def example_2d_source(array_key: ArrayKey): + array_spec = ArraySpec( + roi=Roi((0, 0), (17, 17)), + dtype=np.float32, + interpolatable=True, + voxel_size=(1, 1), + ) + data = np.array(list(range(17)), dtype=np.float32).reshape([17, 1]) + data = data + data.T + array = Array(data, array_spec) + return ArraySource(array_key, array) - def setup(self): - spec = ArraySpec( - roi=Roi((0, 0), (17, 17)), - dtype=np.float32, - interpolatable=True, - voxel_size=(1, 1), - ) - self.provides(ArrayKeys.A, spec) - def provide(self, request): - batch = Batch() +def example_train_source(a_key, b_key, c_key): + spec1 = ArraySpec( + roi=Roi((0, 0), (2, 2)), + dtype=np.float32, + interpolatable=True, + voxel_size=(1, 1), + ) + spec2 = ArraySpec(nonspatial=True) - spec = self.spec[ArrayKeys.A] + data1 = np.array([[0, 1], [2, 3]], dtype=np.float32) + data2 = np.array([1], dtype=np.float32) - x = np.array(list(range(17)), dtype=np.float32).reshape([17, 1]) - x = x + x.T + source_a = ArraySource(a_key, Array(data1, spec1)) + source_b = ArraySource(b_key, Array(data1, spec1)) + source_c = ArraySource(c_key, Array(data2, spec2)) - batch.arrays[ArrayKeys.A] = Array(x, spec).crop(request[ArrayKeys.A].roi) + return (source_a, source_b, source_c) + MergeProvider() - return batch +if torch is not NoSuchModule: -class ExampleTorchTrainSource(BatchProvider): - def setup(self): - spec = ArraySpec( - roi=Roi((0, 0), (2, 2)), - dtype=np.float32, - interpolatable=True, - voxel_size=(1, 1), - ) - self.provides(ArrayKeys.A, spec) - self.provides(ArrayKeys.B, spec) - - spec = ArraySpec(nonspatial=True) - self.provides(ArrayKeys.C, spec) - - def provide(self, request): - batch = Batch() - - spec = self.spec[ArrayKeys.A] - spec.roi = request[ArrayKeys.A].roi - - batch.arrays[ArrayKeys.A] = Array( - np.array([[0, 1], [2, 3]], dtype=np.float32), spec - ) - - spec = self.spec[ArrayKeys.B] - spec.roi = request[ArrayKeys.B].roi - - batch.arrays[ArrayKeys.B] = Array( - np.array([[0, 1], [2, 3]], dtype=np.float32), spec - ) - - spec = self.spec[ArrayKeys.C] - - batch.arrays[ArrayKeys.C] = Array(np.array([1], dtype=np.float32), spec) + class ExampleLinearModel(torch.nn.Module): + def __init__(self): + super(ExampleLinearModel, self).__init__() + self.linear = torch.nn.Linear(4, 1, False) + self.linear.weight.data = torch.Tensor([0, 1, 2, 3]) - return batch + def forward(self, a, b): + a = a.reshape(-1) + b = b.reshape(-1) + c_pred = self.linear(a * b) + d_pred = c_pred * 2 + return d_pred @skipIf(isinstance(torch, NoSuchModule), "torch is not installed") -class TestTorchTrain(ProviderTest): - def test_output(self): - logging.getLogger("gunpowder.torch.nodes.train").setLevel(logging.INFO) - - checkpoint_basename = self.path_to("model") - - ArrayKey("A") - ArrayKey("B") - ArrayKey("C") - ArrayKey("C_PREDICTED") - ArrayKey("C_GRADIENT") - - class ExampleModel(torch.nn.Module): - def __init__(self): - super(ExampleModel, self).__init__() - self.linear = torch.nn.Linear(4, 1, False) - - def forward(self, a, b): - a = a.reshape(-1) - b = b.reshape(-1) - return self.linear(a * b) - - model = ExampleModel() - loss = torch.nn.MSELoss() - optimizer = torch.optim.SGD(model.parameters(), lr=1e-7, momentum=0.999) - - source = ExampleTorchTrainSource() - train = Train( - model=model, - optimizer=optimizer, - loss=loss, - inputs={"a": ArrayKeys.A, "b": ArrayKeys.B}, - loss_inputs={0: ArrayKeys.C_PREDICTED, 1: ArrayKeys.C}, - outputs={0: ArrayKeys.C_PREDICTED}, - gradients={0: ArrayKeys.C_GRADIENT}, - array_specs={ - ArrayKeys.C_PREDICTED: ArraySpec(nonspatial=True), - ArrayKeys.C_GRADIENT: ArraySpec(nonspatial=True), - }, - checkpoint_basename=checkpoint_basename, - save_every=100, - spawn_subprocess=True, - ) - pipeline = source + train - - request = BatchRequest( - { - ArrayKeys.A: ArraySpec(roi=Roi((0, 0), (2, 2))), - ArrayKeys.B: ArraySpec(roi=Roi((0, 0), (2, 2))), - ArrayKeys.C: ArraySpec(nonspatial=True), - ArrayKeys.C_PREDICTED: ArraySpec(nonspatial=True), - ArrayKeys.C_GRADIENT: ArraySpec(nonspatial=True), - } - ) - - # train for a couple of iterations - with build(pipeline): +def test_loss_drops(tmpdir): + checkpoint_basename = str(tmpdir / "model") + + a_key = ArrayKey("A") + b_key = ArrayKey("B") + c_key = ArrayKey("C") + c_predicted_key = ArrayKey("C_PREDICTED") + c_gradient_key = ArrayKey("C_GRADIENT") + + model = ExampleLinearModel() + loss = torch.nn.MSELoss() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-7, momentum=0.999) + + source = example_train_source(a_key, b_key, c_key) + train = Train( + model=model, + optimizer=optimizer, + loss=loss, + inputs={"a": a_key, "b": b_key}, + loss_inputs={0: c_predicted_key, 1: c_key}, + outputs={0: c_predicted_key}, + gradients={0: c_gradient_key}, + array_specs={ + c_predicted_key: ArraySpec(nonspatial=True), + c_gradient_key: ArraySpec(nonspatial=True), + }, + checkpoint_basename=checkpoint_basename, + save_every=100, + spawn_subprocess=False, + ) + pipeline = source + train + + request = BatchRequest( + { + a_key: ArraySpec(roi=Roi((0, 0), (2, 2))), + b_key: ArraySpec(roi=Roi((0, 0), (2, 2))), + c_key: ArraySpec(nonspatial=True), + c_predicted_key: ArraySpec(nonspatial=True), + c_gradient_key: ArraySpec(nonspatial=True), + } + ) + + # train for a couple of iterations + with build(pipeline): + batch = pipeline.request_batch(request) + + for i in range(200 - 1): + loss1 = batch.loss batch = pipeline.request_batch(request) + loss2 = batch.loss + assert loss2 < loss1 - for i in range(200 - 1): - loss1 = batch.loss - batch = pipeline.request_batch(request) - loss2 = batch.loss - self.assertLess(loss2, loss1) - - # resume training - with build(pipeline): - for i in range(100): - loss1 = batch.loss - batch = pipeline.request_batch(request) - loss2 = batch.loss - self.assertLess(loss2, loss1) + # resume training + with build(pipeline): + for i in range(100): + loss1 = batch.loss + batch = pipeline.request_batch(request) + loss2 = batch.loss + assert loss2 < loss1 @skipIf(isinstance(torch, NoSuchModule), "torch is not installed") -class TestTorchPredict(ProviderTest): - def test_output(self): - logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) - - a = ArrayKey("A") - b = ArrayKey("B") - c = ArrayKey("C") - c_pred = ArrayKey("C_PREDICTED") - d_pred = ArrayKey("D_PREDICTED") - - class ExampleModel(torch.nn.Module): - def __init__(self): - super(ExampleModel, self).__init__() - self.linear = torch.nn.Linear(4, 1, False) - self.linear.weight.data = torch.Tensor([1, 1, 1, 1]) - - def forward(self, a, b): - a = a.reshape(-1) - b = b.reshape(-1) - c_pred = self.linear(a * b) - d_pred = c_pred * 2 - return d_pred - - model = ExampleModel() - - source = ExampleTorchTrainSource() - predict = Predict( - model=model, - inputs={"a": a, "b": b}, - outputs={"linear": c_pred, 0: d_pred}, - array_specs={ - c: ArraySpec(nonspatial=True), - c_pred: ArraySpec(nonspatial=True), - d_pred: ArraySpec(nonspatial=True), - }, - spawn_subprocess=True, - ) - pipeline = source + predict - - request = BatchRequest( - { - a: ArraySpec(roi=Roi((0, 0), (2, 2))), - b: ArraySpec(roi=Roi((0, 0), (2, 2))), - c: ArraySpec(nonspatial=True), - c_pred: ArraySpec(nonspatial=True), - d_pred: ArraySpec(nonspatial=True), - } - ) - - # train for a couple of iterations - with build(pipeline): - batch1 = pipeline.request_batch(request) - batch2 = pipeline.request_batch(request) - - assert np.isclose(batch1[c_pred].data, batch2[c_pred].data) - assert np.isclose(batch1[c_pred].data, 1 + 4 + 9) - assert np.isclose(batch2[d_pred].data, 2 * (1 + 4 + 9)) - - -if not isinstance(torch, NoSuchModule): - - class ExampleModel(torch.nn.Module): +def test_output(): + logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) + + a_key = ArrayKey("A") + b_key = ArrayKey("B") + c_key = ArrayKey("C") + c_pred = ArrayKey("C_PREDICTED") + d_pred = ArrayKey("D_PREDICTED") + + model = ExampleLinearModel() + + source = example_train_source(a_key, b_key, c_key) + predict = Predict( + model=model, + inputs={"a": a_key, "b": b_key}, + outputs={"linear": c_pred, 0: d_pred}, + array_specs={ + c_key: ArraySpec(nonspatial=True), + c_pred: ArraySpec(nonspatial=True), + d_pred: ArraySpec(nonspatial=True), + }, + spawn_subprocess=True, + ) + pipeline = source + predict + + request = BatchRequest( + { + a_key: ArraySpec(roi=Roi((0, 0), (2, 2))), + b_key: ArraySpec(roi=Roi((0, 0), (2, 2))), + c_key: ArraySpec(nonspatial=True), + c_pred: ArraySpec(nonspatial=True), + d_pred: ArraySpec(nonspatial=True), + } + ) + + # train for a couple of iterations + with build(pipeline): + batch1 = pipeline.request_batch(request) + batch2 = pipeline.request_batch(request) + + assert np.isclose(batch1[c_pred].data, batch2[c_pred].data) + assert np.isclose(batch1[c_pred].data, 1 + 4 * 2 + 9 * 3) + assert np.isclose(batch2[d_pred].data, 2 * (1 + 4 * 2 + 9 * 3)) + + +if torch is not NoSuchModule: + + class Example2DModel(torch.nn.Module): def __init__(self): - super(ExampleModel, self).__init__() + super(Example2DModel, self).__init__() self.linear = torch.nn.Conv2d(1, 1, 3) def forward(self, a): @@ -236,69 +195,69 @@ def forward(self, a): @skipIf(isinstance(torch, NoSuchModule), "torch is not installed") -class TestTorchPredictMultiprocessing(ProviderTest): - def test_scan(self): - logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) - - a = ArrayKey("A") - pred = ArrayKey("PRED") - - model = ExampleModel() - - reference_request = BatchRequest() - reference_request[a] = ArraySpec(roi=Roi((0, 0), (7, 7))) - reference_request[pred] = ArraySpec(roi=Roi((1, 1), (5, 5))) - - source = ExampleTorchTrain2DSource() - predict = Predict( - model=model, - inputs={"a": a}, - outputs={0: pred}, - array_specs={pred: ArraySpec()}, - ) - pipeline = source + predict + Scan(reference_request, num_workers=2) - - request = BatchRequest( - { - a: ArraySpec(roi=Roi((0, 0), (17, 17))), - pred: ArraySpec(roi=Roi((0, 0), (15, 15))), - } - ) - - # train for a couple of iterations - with build(pipeline): - batch = pipeline.request_batch(request) - assert pred in batch - - def test_precache(self): - logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) - - a = ArrayKey("A") - pred = ArrayKey("PRED") - - model = ExampleModel() - - reference_request = BatchRequest() - reference_request[a] = ArraySpec(roi=Roi((0, 0), (7, 7))) - reference_request[pred] = ArraySpec(roi=Roi((1, 1), (5, 5))) - - source = ExampleTorchTrain2DSource() - predict = Predict( - model=model, - inputs={"a": a}, - outputs={0: pred}, - array_specs={pred: ArraySpec()}, - ) - pipeline = source + predict + PreCache(cache_size=3, num_workers=2) - - request = BatchRequest( - { - a: ArraySpec(roi=Roi((0, 0), (17, 17))), - pred: ArraySpec(roi=Roi((0, 0), (15, 15))), - } - ) - - # train for a couple of iterations - with build(pipeline): - batch = pipeline.request_batch(request) - assert pred in batch +def test_scan(): + logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) + + a_key = ArrayKey("A") + pred = ArrayKey("PRED") + + model = Example2DModel() + + reference_request = BatchRequest() + reference_request[a_key] = ArraySpec(roi=Roi((0, 0), (7, 7))) + reference_request[pred] = ArraySpec(roi=Roi((1, 1), (5, 5))) + + source = example_2d_source(a_key) + predict = Predict( + model=model, + inputs={"a": a_key}, + outputs={0: pred}, + array_specs={pred: ArraySpec()}, + ) + pipeline = source + predict + Scan(reference_request, num_workers=2) + + request = BatchRequest( + { + a_key: ArraySpec(roi=Roi((0, 0), (17, 17))), + pred: ArraySpec(roi=Roi((0, 0), (15, 15))), + } + ) + + # train for a couple of iterations + with build(pipeline): + batch = pipeline.request_batch(request) + assert pred in batch + + +def test_precache(): + logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) + + a_key = ArrayKey("A") + pred = ArrayKey("PRED") + + model = Example2DModel() + + reference_request = BatchRequest() + reference_request[a_key] = ArraySpec(roi=Roi((0, 0), (7, 7))) + reference_request[pred] = ArraySpec(roi=Roi((1, 1), (5, 5))) + + source = example_2d_source(a_key) + predict = Predict( + model=model, + inputs={"a": a_key}, + outputs={0: pred}, + array_specs={pred: ArraySpec()}, + ) + pipeline = source + predict + PreCache(cache_size=3, num_workers=2) + + request = BatchRequest( + { + a_key: ArraySpec(roi=Roi((0, 0), (17, 17))), + pred: ArraySpec(roi=Roi((0, 0), (15, 15))), + } + ) + + # train for a couple of iterations + with build(pipeline): + batch = pipeline.request_batch(request) + assert pred in batch From 96ff2f13c34b0aa6fba9d96d5ccea93d625882e4 Mon Sep 17 00:00:00 2001 From: William Patton Date: Wed, 1 Nov 2023 22:06:40 -0700 Subject: [PATCH 17/74] remove multiprocess set start method monkey patch We want to test with both fork and spawn start methods, but this seems to interfere with the torch tests --- tests/conftest.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index a8f65ea1..1386c6b8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,10 +6,10 @@ # cannot parametrize unittest.TestCase. We should test both # fork and spawn but I'm not sure how to. # @pytest.fixture(params=["fork", "spawn"], autouse=True) -@pytest.fixture(autouse=True) -def context(monkeypatch): - ctx = mp.get_context("spawn") - monkeypatch.setattr(mp, "Queue", ctx.Queue) - monkeypatch.setattr(mp, "Process", ctx.Process) - monkeypatch.setattr(mp, "Event", ctx.Event) - monkeypatch.setattr(mp, "Value", ctx.Value) +# @pytest.fixture(autouse=True) +# def context(monkeypatch): +# ctx = mp.get_context("spawn") +# monkeypatch.setattr(mp, "Queue", ctx.Queue) +# monkeypatch.setattr(mp, "Process", ctx.Process) +# monkeypatch.setattr(mp, "Event", ctx.Event) +# monkeypatch.setattr(mp, "Value", ctx.Value) From 75ff6ffd372fcfc2c8dc4222f68f157f13f9aea7 Mon Sep 17 00:00:00 2001 From: William Patton Date: Thu, 2 Nov 2023 10:09:33 -0700 Subject: [PATCH 18/74] only deploy docs on tagged commits to main --- .github/workflows/publish-docs.yaml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/publish-docs.yaml b/.github/workflows/publish-docs.yaml index 0f9b78bb..efeb8e85 100644 --- a/.github/workflows/publish-docs.yaml +++ b/.github/workflows/publish-docs.yaml @@ -3,8 +3,7 @@ name: Deploy Docs to GitHub Pages on: push: branches: [main] - pull_request: - branches: [main] + tags: "*" workflow_dispatch: # Allow this job to clone the repo and create a page deployment From 865fe539d38a859dce090a0f84c471362f3afa73 Mon Sep 17 00:00:00 2001 From: William Patton Date: Thu, 2 Nov 2023 10:18:37 -0700 Subject: [PATCH 19/74] minor black formatting and configuration changes --- gunpowder/nodes/batch_provider.py | 1 - gunpowder/nodes/deform_augment.py | 1 - gunpowder/nodes/elastic_augment.py | 1 - gunpowder/nodes/noise_augment.py | 2 -- gunpowder/nodes/random_location.py | 1 - gunpowder/nodes/random_provider.py | 1 - gunpowder/nodes/reject.py | 1 - gunpowder/nodes/shift_augment.py | 1 - gunpowder/nodes/simple_augment.py | 1 - gunpowder/nodes/stack.py | 1 - pyproject.toml | 2 +- 11 files changed, 1 insertion(+), 12 deletions(-) diff --git a/gunpowder/nodes/batch_provider.py b/gunpowder/nodes/batch_provider.py index dc641c8e..1f6b9dc8 100644 --- a/gunpowder/nodes/batch_provider.py +++ b/gunpowder/nodes/batch_provider.py @@ -174,7 +174,6 @@ def request_batch(self, request): batch = None try: - self.set_seeds(request) logger.debug("%s got request %s", self.name(), request) diff --git a/gunpowder/nodes/deform_augment.py b/gunpowder/nodes/deform_augment.py index 0291a367..6d7e23af 100644 --- a/gunpowder/nodes/deform_augment.py +++ b/gunpowder/nodes/deform_augment.py @@ -130,7 +130,6 @@ def setup(self): self.provides(self.transform_key, spec) def prepare(self, request): - # get the total ROI of all requests total_roi = request.get_total_roi() logger.debug("total ROI is %s" % total_roi) diff --git a/gunpowder/nodes/elastic_augment.py b/gunpowder/nodes/elastic_augment.py index a70f7866..d999d6fe 100644 --- a/gunpowder/nodes/elastic_augment.py +++ b/gunpowder/nodes/elastic_augment.py @@ -124,7 +124,6 @@ def __init__( self.recompute_missing_points = recompute_missing_points def prepare(self, request): - # get the voxel size self.voxel_size = self.__get_common_voxel_size(request) diff --git a/gunpowder/nodes/noise_augment.py b/gunpowder/nodes/noise_augment.py index f4bfb5ba..5275a2c0 100644 --- a/gunpowder/nodes/noise_augment.py +++ b/gunpowder/nodes/noise_augment.py @@ -57,13 +57,11 @@ def process(self, batch, request): seed = request.random_seed try: - raw.data = skimage.util.random_noise( raw.data, mode=self.mode, rng=seed, clip=self.clip, **self.kwargs ).astype(raw.data.dtype) except ValueError: - # legacy version of skimage random_noise raw.data = skimage.util.random_noise( raw.data, mode=self.mode, seed=seed, clip=self.clip, **self.kwargs diff --git a/gunpowder/nodes/random_location.py b/gunpowder/nodes/random_location.py index a86a2c69..d5b6c1e2 100644 --- a/gunpowder/nodes/random_location.py +++ b/gunpowder/nodes/random_location.py @@ -172,7 +172,6 @@ def setup(self): self.provides(self.random_shift_key, ArraySpec(nonspatial=True)) def prepare(self, request): - logger.debug("request: %s", request.array_specs) logger.debug("my spec: %s", self.spec) diff --git a/gunpowder/nodes/random_provider.py b/gunpowder/nodes/random_provider.py index dfb086f8..a9ae1081 100644 --- a/gunpowder/nodes/random_provider.py +++ b/gunpowder/nodes/random_provider.py @@ -69,7 +69,6 @@ def setup(self): self.provides(self.random_provider_key, ArraySpec(nonspatial=True)) def provide(self, request): - if self.random_provider_key is not None: del request[self.random_provider_key] diff --git a/gunpowder/nodes/reject.py b/gunpowder/nodes/reject.py index b6a47436..87bb83aa 100644 --- a/gunpowder/nodes/reject.py +++ b/gunpowder/nodes/reject.py @@ -55,7 +55,6 @@ def setup(self): self.upstream_provider = self.get_upstream_provider() def provide(self, request): - report_next_timeout = 10 num_rejected = 0 diff --git a/gunpowder/nodes/shift_augment.py b/gunpowder/nodes/shift_augment.py index 8fe6524b..8761a563 100644 --- a/gunpowder/nodes/shift_augment.py +++ b/gunpowder/nodes/shift_augment.py @@ -24,7 +24,6 @@ def __init__(self, prob_slip=0, prob_shift=0, sigma=0, shift_axis=0): self.lcm_voxel_size = None def prepare(self, request): - self.ndim = request.get_total_roi().dims assert self.shift_axis in range(self.ndim) diff --git a/gunpowder/nodes/simple_augment.py b/gunpowder/nodes/simple_augment.py index dc756e3a..f5a97333 100644 --- a/gunpowder/nodes/simple_augment.py +++ b/gunpowder/nodes/simple_augment.py @@ -106,7 +106,6 @@ def setup(self): self.permutation_dict[k] = v def prepare(self, request): - self.mirror = [ random.random() < self.mirror_probs[d] if self.mirror_mask[d] else 0 for d in range(self.dims) diff --git a/gunpowder/nodes/stack.py b/gunpowder/nodes/stack.py index 21f53acc..5d7feabd 100644 --- a/gunpowder/nodes/stack.py +++ b/gunpowder/nodes/stack.py @@ -25,7 +25,6 @@ def __init__(self, num_repetitions): self.num_repetitions = num_repetitions def provide(self, request): - batches = [] for _ in range(self.num_repetitions): upstream_request = request.copy() diff --git a/pyproject.toml b/pyproject.toml index 7168e7b2..349b8623 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,7 @@ full = [ version = { attr = "gunpowder.version_info.__version__" } [tool.black] -target_version = ['py36', 'py37', 'py38', 'py39', 'py310'] +target_version = ['py38', 'py39', 'py310'] [tool.setuptools.packages.find] include = ["gunpowder*"] From 20780360ff730b6c540f7517cbef8624da3e2832 Mon Sep 17 00:00:00 2001 From: William Patton Date: Thu, 2 Nov 2023 10:27:48 -0700 Subject: [PATCH 20/74] properly skip torch tests if torch not installed --- tests/cases/torch_train.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/cases/torch_train.py b/tests/cases/torch_train.py index c9eed28c..2098f384 100644 --- a/tests/cases/torch_train.py +++ b/tests/cases/torch_train.py @@ -1,14 +1,10 @@ from .helper_sources import ArraySource from gunpowder import ( - BatchProvider, BatchRequest, ArraySpec, Roi, - Coordinate, - ArrayKeys, ArrayKey, Array, - Batch, Scan, PreCache, MergeProvider, @@ -16,8 +12,9 @@ ) from gunpowder.ext import torch, NoSuchModule from gunpowder.torch import Train, Predict -from unittest import skipIf, expectedFailure +from unittest import skipIf import numpy as np +import pytest import logging @@ -55,8 +52,7 @@ def example_train_source(a_key, b_key, c_key): return (source_a, source_b, source_c) + MergeProvider() -if torch is not NoSuchModule: - +if not isinstance(torch, NoSuchModule): class ExampleLinearModel(torch.nn.Module): def __init__(self): super(ExampleLinearModel, self).__init__() @@ -179,7 +175,7 @@ def test_output(): assert np.isclose(batch2[d_pred].data, 2 * (1 + 4 * 2 + 9 * 3)) -if torch is not NoSuchModule: +if not isinstance(torch, NoSuchModule): class Example2DModel(torch.nn.Module): def __init__(self): @@ -229,6 +225,7 @@ def test_scan(): assert pred in batch +@skipIf(isinstance(torch, NoSuchModule), "torch is not installed") def test_precache(): logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) From 46676dd82c8405efb2c26a66865048dd238e99b0 Mon Sep 17 00:00:00 2001 From: William Patton Date: Thu, 2 Nov 2023 10:28:04 -0700 Subject: [PATCH 21/74] black formatting --- tests/cases/torch_train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/cases/torch_train.py b/tests/cases/torch_train.py index 2098f384..2c67fda0 100644 --- a/tests/cases/torch_train.py +++ b/tests/cases/torch_train.py @@ -53,6 +53,7 @@ def example_train_source(a_key, b_key, c_key): if not isinstance(torch, NoSuchModule): + class ExampleLinearModel(torch.nn.Module): def __init__(self): super(ExampleLinearModel, self).__init__() From 84593e823efa7d59e8cfe97915286a0992247133 Mon Sep 17 00:00:00 2001 From: William Patton Date: Thu, 2 Nov 2023 10:35:16 -0700 Subject: [PATCH 22/74] avoid testing on python 3.7, instead use 3.11 numpy is no longer releasing updates for python 3.7, they are on 1.24 but the last release for 3.7 was 1.21. I don't think we need to support it either, but we should test on 3.11 --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 706882c7..2a345506 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.7", "3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] platform: [ubuntu-latest] steps: From ed3c7cec6f1c10e0261a0fdcb4df687bb1a5f886 Mon Sep 17 00:00:00 2001 From: William Patton Date: Thu, 2 Nov 2023 13:49:15 -0700 Subject: [PATCH 23/74] add typed libraries to dev dependencies --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 349b8623..a389a33c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,7 @@ dependencies = [ ] [project.optional-dependencies] -dev = ["pytest", "pytest-cov", "flake8", "mypy"] +dev = ["pytest", "pytest-cov", "flake8", "mypy", "types-requests", "types-tqdm"] docs = [ "sphinx", "sphinx_rtd_theme", From ade6bfed80a46a1930fbaa9895bd3894a04da28d Mon Sep 17 00:00:00 2001 From: pattonw Date: Tue, 14 Nov 2023 12:47:17 -0800 Subject: [PATCH 24/74] test subsampling in deform augment test fails --- tests/cases/deform_augment.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/cases/deform_augment.py b/tests/cases/deform_augment.py index 2134a708..6923827d 100644 --- a/tests/cases/deform_augment.py +++ b/tests/cases/deform_augment.py @@ -135,6 +135,7 @@ def test_3d_basics(rotate, spatial_dims, fast_points): rotate=rotate, spatial_dims=spatial_dims, use_fast_points_transform=fast_points, + subsample=2, ) for _ in range(5): From 1ce6f15e16bfa4b58ee7e32d2919d213efaa0da7 Mon Sep 17 00:00:00 2001 From: pattonw Date: Tue, 14 Nov 2023 12:47:36 -0800 Subject: [PATCH 25/74] fix bugs associated with subsampling --- gunpowder/nodes/deform_augment.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/gunpowder/nodes/deform_augment.py b/gunpowder/nodes/deform_augment.py index 13826909..75c72254 100644 --- a/gunpowder/nodes/deform_augment.py +++ b/gunpowder/nodes/deform_augment.py @@ -486,11 +486,12 @@ def __create_transformation(self, target_spec: ArraySpec): rot_transformation = create_rotation_transformation( target_shape, random.random() * math.pi, + subsample=self.subsample, ) else: angle = Rotation.random() rot_transformation = create_3D_rotation_transformation( - target_shape, angle + target_shape, angle, subsample=self.subsample ) local_transformation += rot_transformation @@ -499,6 +500,9 @@ def __create_transformation(self, target_spec: ArraySpec): local_transformation = upscale_transformation( local_transformation, target_shape ) + global_transformation = upscale_transformation( + global_transformation, target_shape + ) # transform into world units global_transformation *= np.array(target_spec.voxel_size).reshape( From 6216f2f32fa13dab2255bf596db9f1cd03935ac2 Mon Sep 17 00:00:00 2001 From: pattonw Date: Tue, 14 Nov 2023 12:48:16 -0800 Subject: [PATCH 26/74] deform augment fix bug with checking dims of graph_raster_voxel_size --- gunpowder/nodes/deform_augment.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/gunpowder/nodes/deform_augment.py b/gunpowder/nodes/deform_augment.py index 75c72254..11292981 100644 --- a/gunpowder/nodes/deform_augment.py +++ b/gunpowder/nodes/deform_augment.py @@ -115,13 +115,21 @@ def __init__( self.use_fast_points_transform = use_fast_points_transform self.recompute_missing_points = recompute_missing_points self.transform_key = transform_key - self.graph_raster_voxel_size = Coordinate(graph_raster_voxel_size) + self.graph_raster_voxel_size = ( + Coordinate(graph_raster_voxel_size) + if graph_raster_voxel_size is not None + else None + ) self.p = p - assert ( - self.control_point_spacing.dims - == self.jitter_sigma.dims - == self.graph_raster_voxel_size.dims + assert self.control_point_spacing.dims == self.jitter_sigma.dims, ( + self.control_point_spacing, + self.jitter_sigma, ) + if self.graph_raster_voxel_size is not None: + assert self.graph_raster_voxel_size.dims == self.jitter_sigma.dims, ( + self.graph_raster_voxel_size, + self.jitter_sigma, + ) def setup(self): if self.transform_key is not None: From 592b35464d1b89b7b31fc06768f7f0121f3cda69 Mon Sep 17 00:00:00 2001 From: Jan Funke Date: Wed, 29 Nov 2023 14:38:12 -0500 Subject: [PATCH 27/74] Add progress callback to Scan node --- docs/source/api.rst | 1 + gunpowder/nodes/__init__.py | 2 +- gunpowder/nodes/scan.py | 78 +++++++++++++++++++++++++++++++++++-- 3 files changed, 76 insertions(+), 5 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 5f120a1c..21b7753b 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -334,6 +334,7 @@ Iterative Processing Nodes Scan ^^^^ .. autoclass:: Scan + .. autoclass:: ScanCallback DaisyRequestBlocks ^^^^^^^^^^^^^^^^^^ diff --git a/gunpowder/nodes/__init__.py b/gunpowder/nodes/__init__.py index 4131b824..1a152f8e 100644 --- a/gunpowder/nodes/__init__.py +++ b/gunpowder/nodes/__init__.py @@ -34,7 +34,7 @@ from .reject import Reject from .renumber_connected_components import RenumberConnectedComponents from .resample import Resample -from .scan import Scan +from .scan import Scan, ScanCallback from .shift_augment import ShiftAugment from .simple_augment import SimpleAugment from .snapshot import Snapshot diff --git a/gunpowder/nodes/scan.py b/gunpowder/nodes/scan.py index ef6b378e..3473764e 100644 --- a/gunpowder/nodes/scan.py +++ b/gunpowder/nodes/scan.py @@ -2,6 +2,7 @@ import multiprocessing import numpy as np import tqdm +from abc import ABC from gunpowder.array import Array from gunpowder.batch import Batch from gunpowder.coordinate import Coordinate @@ -13,6 +14,55 @@ logger = logging.getLogger(__name__) +class ScanCallback(ABC): + """Base class for :class:`Scan` callbacks. Implement any of ``start``, + ``update``, and ``stop`` in a subclass to create your own callback. + """ + + def start(self, num_total): + """Called once before :class:`Scan` starts scanning over chunks. + + Args: + + num_total (int): + + The total number of chunks to process. + """ + pass + + def update(self, num_processed): + """Called periodically by :class:`Scan` while processing chunks. + + Args: + + num_processed (int): + + The number of chunks already processed. + """ + pass + + def stop(self): + """Called once after :class:`Scan` scanned over all chunks.""" + pass + + +class TqdmCallback(ScanCallback): + """A default callback that uses ``tqdm`` to show a progress bar.""" + + def start(self, num_total): + logger.info("scanning over %d chunks", num_total) + + self.progress_bar = tqdm.tqdm(desc="Scan, chunks processed", total=num_total) + self.num_processed = 0 + + def update(self, num_processed): + self.progress_bar.update(num_processed - self.num_processed) + self.num_processed = num_processed + + def stop(self): + self.progress_bar.close() + + class Scan(BatchFilter): """Iteratively requests batches of size ``reference`` from upstream providers in a scanning fashion, until all requested ROIs are covered. If @@ -40,14 +90,24 @@ class Scan(BatchFilter): cache_size (``int``, optional): If multiple workers are used, how many batches to hold at most. + + progress_callback (class:`ScanCallback`, optional): + + A callback instance to get updated from this node while processing + chunks. See :class:`ScanCallback` for details. The default is a + callback that shows a ``tqdm`` progress bar. """ - def __init__(self, reference, num_workers=1, cache_size=50): + def __init__(self, reference, num_workers=1, cache_size=50, progress_callback=None): self.reference = reference.copy() self.num_workers = num_workers self.cache_size = cache_size self.workers = None self.batch = None + if progress_callback is None: + self.progress_callback = TqdmCallback() + else: + self.progress_callback = progress_callback def setup(self): if self.num_workers > 1: @@ -75,7 +135,8 @@ def provide(self, request): shifts = self._enumerate_shifts(shift_roi, stride) num_chunks = len(shifts) - logger.info("scanning over %d chunks", num_chunks) + if self.progress_callback is not None: + self.progress_callback.start(num_chunks) # the batch to return self.batch = Batch() @@ -85,24 +146,33 @@ def provide(self, request): shifted_reference = self._shift_request(self.reference, shift) self.request_queue.put(shifted_reference) - for i in tqdm.tqdm(range(num_chunks)): + for i in range(num_chunks): chunk = self.workers.get() if not empty_request: self._add_to_batch(request, chunk) + if self.progress_callback is not None: + self.progress_callback.update(i + 1) + logger.debug("processed chunk %d/%d", i + 1, num_chunks) else: - for i, shift in enumerate(tqdm.tqdm(shifts)): + for i, shift in enumerate(shifts): shifted_reference = self._shift_request(self.reference, shift) chunk = self._get_chunk(shifted_reference) if not empty_request: self._add_to_batch(request, chunk) + if self.progress_callback is not None: + self.progress_callback.update(i + 1) + logger.debug("processed chunk %d/%d", i + 1, num_chunks) + if self.progress_callback is not None: + self.progress_callback.stop() + batch = self.batch self.batch = None From d49db1f23ce5711fa850cf89eb228bd8cade0e7a Mon Sep 17 00:00:00 2001 From: William Patton Date: Thu, 30 Nov 2023 11:15:26 -0800 Subject: [PATCH 28/74] pass torch train test if using start method = "spawn" and the "start_subprocess" flag for the predict node, we now pass our test. --- gunpowder/nodes/generic_predict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gunpowder/nodes/generic_predict.py b/gunpowder/nodes/generic_predict.py index 524967b8..e3f4ec5b 100644 --- a/gunpowder/nodes/generic_predict.py +++ b/gunpowder/nodes/generic_predict.py @@ -89,7 +89,7 @@ def setup(self): if self.spawn_subprocess: # start prediction as a producer pool, so that we can gracefully # exit if anything goes wrong - self.worker = ProducerPool([self.__produce_predict_batch], queue_size=1) + self.worker = ProducerPool([self._produce_predict_batch], queue_size=1) self.batch_in = multiprocessing.Queue(maxsize=1) self.batch_in_lock = multiprocessing.Lock() self.batch_out_lock = multiprocessing.Lock() @@ -177,7 +177,7 @@ def stop(self): """ pass - def __produce_predict_batch(self): + def _produce_predict_batch(self): """Process one batch.""" if not self.initialized: From 797994c4ce91946f8ccc217fa8d82306d6d503fc Mon Sep 17 00:00:00 2001 From: William Patton Date: Thu, 30 Nov 2023 11:16:04 -0800 Subject: [PATCH 29/74] pass torch train test if using the start method "spawn", and the "spawn_subprocess" flag for the train node, we now pass our test --- gunpowder/nodes/generic_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gunpowder/nodes/generic_train.py b/gunpowder/nodes/generic_train.py index ae93b7de..a26a285f 100644 --- a/gunpowder/nodes/generic_train.py +++ b/gunpowder/nodes/generic_train.py @@ -104,7 +104,7 @@ def setup(self): if self.spawn_subprocess: # start training as a producer pool, so that we can gracefully exit if # anything goes wrong - self.worker = ProducerPool([self.__produce_train_batch], queue_size=1) + self.worker = ProducerPool([self._produce_train_batch], queue_size=1) self.batch_in = multiprocessing.Queue(maxsize=1) self.worker.start() else: @@ -208,7 +208,7 @@ def natural_keys(text): return None, 0 - def __produce_train_batch(self): + def _produce_train_batch(self): """Process one train batch.""" if not self.initialized: From b2f8c2dbedf5bf7b3917c62704f082c3927a092c Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 19 Dec 2023 08:05:51 -0700 Subject: [PATCH 30/74] remove extra error printing --- gunpowder/producer_pool.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/gunpowder/producer_pool.py b/gunpowder/producer_pool.py index 035f0d74..0f6c2888 100644 --- a/gunpowder/producer_pool.py +++ b/gunpowder/producer_pool.py @@ -143,9 +143,7 @@ def _run_worker(self, target): try: result = target() except Exception as e: - logger.error(e, exc_info=True) result = e - traceback.print_exc() # don't stop on normal exceptions -- place them in result queue # and let them be handled by caller except: From 625cb03b82894c4dfdc40162a700662030d76eee Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 19 Dec 2023 08:07:54 -0700 Subject: [PATCH 31/74] switch error printing order Now prints the errors in reverse order of execution so the initial pipeline error is printed first --- gunpowder/nodes/batch_provider.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/gunpowder/nodes/batch_provider.py b/gunpowder/nodes/batch_provider.py index 1f6b9dc8..304e1e3a 100644 --- a/gunpowder/nodes/batch_provider.py +++ b/gunpowder/nodes/batch_provider.py @@ -3,6 +3,8 @@ import copy import logging import random +import traceback +from typing import Optional from gunpowder.coordinate import Coordinate from gunpowder.provider_spec import ProviderSpec @@ -15,17 +17,22 @@ class BatchRequestError(Exception): - def __init__(self, provider, request, batch): + def __init__( + self, provider, request, batch, original_traceback: Optional[list[str]] = None + ): self.provider = provider self.request = request self.batch = batch + self.original_traceback = original_traceback def __str__(self): return ( f"Exception in {self.provider.name()} while processing request" - f"{self.request} \n" + f"{self.request}" "Batch returned so far:\n" - f"{self.batch}" + f"{self.batch}" + ("\n\n" + "".join(self.original_traceback)) + if self.original_traceback is not None + else "" ) @@ -194,7 +201,12 @@ def request_batch(self, request): logger.debug("%s provides %s", self.name(), batch) except Exception as e: - raise BatchRequestError(self, request, batch) from e + tb = traceback.format_exception(type(e), e, e.__traceback__) + if isinstance(e, BatchRequestError): + tb = tb[-1:] + raise BatchRequestError( + self, request, batch, original_traceback=tb + ) from None return batch From 35fcd4343db258f14db59890e32fc1e6409c1453 Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 19 Dec 2023 08:41:20 -0700 Subject: [PATCH 32/74] black format docs and examples --- docs/source/conf.py | 2 +- examples/cremi/mknet.py | 54 ++++++--------- examples/cremi/predict.py | 63 ++++++++---------- examples/cremi/train.py | 137 +++++++++++++++++--------------------- 4 files changed, 111 insertions(+), 145 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 529dbe8b..b0da21cf 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -54,7 +54,7 @@ ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output diff --git a/examples/cremi/mknet.py b/examples/cremi/mknet.py index aac8c0df..fe13a7a5 100644 --- a/examples/cremi/mknet.py +++ b/examples/cremi/mknet.py @@ -2,8 +2,8 @@ import tensorflow as tf import json -def create_network(input_shape, name): +def create_network(input_shape, name): tf.reset_default_graph() # create a placeholder for the 3D raw input tensor @@ -11,20 +11,17 @@ def create_network(input_shape, name): # create a U-Net raw_batched = tf.reshape(raw, (1, 1) + input_shape) - unet_output = unet(raw_batched, 6, 4, [[1,3,3],[1,3,3],[1,3,3]]) + unet_output = unet(raw_batched, 6, 4, [[1, 3, 3], [1, 3, 3], [1, 3, 3]]) # add a convolution layer to create 3 output maps representing affinities # in z, y, and x pred_affs_batched = conv_pass( - unet_output, - kernel_size=1, - num_fmaps=3, - num_repetitions=1, - activation='sigmoid') + unet_output, kernel_size=1, num_fmaps=3, num_repetitions=1, activation="sigmoid" + ) # get the shape of the output output_shape_batched = pred_affs_batched.get_shape().as_list() - output_shape = output_shape_batched[1:] # strip the batch dimension + output_shape = output_shape_batched[1:] # strip the batch dimension # the 4D output tensor (3, depth, height, width) pred_affs = tf.reshape(pred_affs_batched, output_shape) @@ -33,46 +30,39 @@ def create_network(input_shape, name): gt_affs = tf.placeholder(tf.float32, shape=output_shape) # create a placeholder for per-voxel loss weights - loss_weights = tf.placeholder( - tf.float32, - shape=output_shape) + loss_weights = tf.placeholder(tf.float32, shape=output_shape) # compute the loss as the weighted mean squared error between the # predicted and the ground-truth affinities - loss = tf.losses.mean_squared_error( - gt_affs, - pred_affs, - loss_weights) + loss = tf.losses.mean_squared_error(gt_affs, pred_affs, loss_weights) # use the Adam optimizer to minimize the loss opt = tf.train.AdamOptimizer( - learning_rate=0.5e-4, - beta1=0.95, - beta2=0.999, - epsilon=1e-8) + learning_rate=0.5e-4, beta1=0.95, beta2=0.999, epsilon=1e-8 + ) optimizer = opt.minimize(loss) # store the network in a meta-graph file - tf.train.export_meta_graph(filename=name + '.meta') + tf.train.export_meta_graph(filename=name + ".meta") # store network configuration for use in train and predict scripts config = { - 'raw': raw.name, - 'pred_affs': pred_affs.name, - 'gt_affs': gt_affs.name, - 'loss_weights': loss_weights.name, - 'loss': loss.name, - 'optimizer': optimizer.name, - 'input_shape': input_shape, - 'output_shape': output_shape[1:] + "raw": raw.name, + "pred_affs": pred_affs.name, + "gt_affs": gt_affs.name, + "loss_weights": loss_weights.name, + "loss": loss.name, + "optimizer": optimizer.name, + "input_shape": input_shape, + "output_shape": output_shape[1:], } - with open(name + '_config.json', 'w') as f: + with open(name + "_config.json", "w") as f: json.dump(config, f) -if __name__ == "__main__": +if __name__ == "__main__": # create a network for training - create_network((84, 268, 268), 'train_net') + create_network((84, 268, 268), "train_net") # create a larger network for faster prediction - create_network((120, 322, 322), 'test_net') + create_network((120, 322, 322), "test_net") diff --git a/examples/cremi/predict.py b/examples/cremi/predict.py index 8693786f..4f229b14 100644 --- a/examples/cremi/predict.py +++ b/examples/cremi/predict.py @@ -2,29 +2,29 @@ import gunpowder as gp import json -def predict(iteration): +def predict(iteration): ################## # DECLARE ARRAYS # ################## # raw intensities - raw = gp.ArrayKey('RAW') + raw = gp.ArrayKey("RAW") # the predicted affinities - pred_affs = gp.ArrayKey('PRED_AFFS') + pred_affs = gp.ArrayKey("PRED_AFFS") #################### # DECLARE REQUESTS # #################### - with open('test_net_config.json', 'r') as f: + with open("test_net_config.json", "r") as f: net_config = json.load(f) # get the input and output size in world units (nm, in this case) voxel_size = gp.Coordinate((40, 4, 4)) - input_size = gp.Coordinate(net_config['input_shape'])*voxel_size - output_size = gp.Coordinate(net_config['output_shape'])*voxel_size + input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size + output_size = gp.Coordinate(net_config["output_shape"]) * voxel_size context = input_size - output_size # formulate the request for what a batch should contain @@ -37,10 +37,8 @@ def predict(iteration): ############################# source = gp.Hdf5Source( - 'sample_A_padded_20160501.hdf', - datasets = { - raw: 'volumes/raw' - }) + "sample_A_padded_20160501.hdf", datasets={raw: "volumes/raw"} + ) # get the ROI provided for raw (we need it later to calculate the ROI in # which we can make predictions) @@ -48,41 +46,35 @@ def predict(iteration): raw_roi = source.spec[raw].roi pipeline = ( - # read from HDF5 file - source + - + source + + # convert raw to float in [0, 1] - gp.Normalize(raw) + - + gp.Normalize(raw) + + # perform one training iteration for each passing batch (here we use # the tensor names earlier stored in train_net.config) gp.tensorflow.Predict( - graph='test_net.meta', - checkpoint='train_net_checkpoint_%d'%iteration, - inputs={ - net_config['raw']: raw - }, - outputs={ - net_config['pred_affs']: pred_affs - }, - array_specs={ - pred_affs: gp.ArraySpec(roi=raw_roi.grow(-context, -context)) - }) + - + graph="test_net.meta", + checkpoint="train_net_checkpoint_%d" % iteration, + inputs={net_config["raw"]: raw}, + outputs={net_config["pred_affs"]: pred_affs}, + array_specs={pred_affs: gp.ArraySpec(roi=raw_roi.grow(-context, -context))}, + ) + + # store all passing batches in the same HDF5 file gp.Hdf5Write( { - raw: '/volumes/raw', - pred_affs: '/volumes/pred_affs', + raw: "/volumes/raw", + pred_affs: "/volumes/pred_affs", }, - output_filename='predictions_sample_A.hdf', - compression_type='gzip' - ) + - + output_filename="predictions_sample_A.hdf", + compression_type="gzip", + ) + + # show a summary of time spend in each node every 10 iterations - gp.PrintProfilingStats(every=10) + - + gp.PrintProfilingStats(every=10) + + # iterate over the whole dataset in a scanning fashion, emitting # requests that match the size of the network gp.Scan(reference=request) @@ -93,5 +85,6 @@ def predict(iteration): # without keeping the complete dataset in memory pipeline.request_batch(gp.BatchRequest()) + if __name__ == "__main__": predict(200000) diff --git a/examples/cremi/train.py b/examples/cremi/train.py index 8edd12f7..6faf7e50 100644 --- a/examples/cremi/train.py +++ b/examples/cremi/train.py @@ -6,41 +6,41 @@ logging.basicConfig(level=logging.INFO) -def train(iterations): +def train(iterations): ################## # DECLARE ARRAYS # ################## # raw intensities - raw = gp.ArrayKey('RAW') + raw = gp.ArrayKey("RAW") # objects labelled with unique IDs - gt_labels = gp.ArrayKey('LABELS') + gt_labels = gp.ArrayKey("LABELS") # array of per-voxel affinities to direct neighbors - gt_affs= gp.ArrayKey('AFFINITIES') + gt_affs = gp.ArrayKey("AFFINITIES") # weights to use to balance the loss - loss_weights = gp.ArrayKey('LOSS_WEIGHTS') + loss_weights = gp.ArrayKey("LOSS_WEIGHTS") # the predicted affinities - pred_affs = gp.ArrayKey('PRED_AFFS') + pred_affs = gp.ArrayKey("PRED_AFFS") # the gredient of the loss wrt to the predicted affinities - pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS') + pred_affs_gradients = gp.ArrayKey("PRED_AFFS_GRADIENTS") #################### # DECLARE REQUESTS # #################### - with open('train_net_config.json', 'r') as f: + with open("train_net_config.json", "r") as f: net_config = json.load(f) # get the input and output size in world units (nm, in this case) voxel_size = gp.Coordinate((40, 4, 4)) - input_size = gp.Coordinate(net_config['input_shape'])*voxel_size - output_size = gp.Coordinate(net_config['output_shape'])*voxel_size + input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size + output_size = gp.Coordinate(net_config["output_shape"]) * voxel_size # formulate the request for what a batch should (at least) contain request = gp.BatchRequest() @@ -60,44 +60,38 @@ def train(iterations): ############################## pipeline = ( - # a tuple of sources, one for each sample (A, B, and C) provided by the # CREMI challenge tuple( - # read batches from the HDF5 file gp.Hdf5Source( - 'sample_'+s+'_padded_20160501.hdf', - datasets = { - raw: 'volumes/raw', - gt_labels: 'volumes/labels/neuron_ids' - } - ) + - + "sample_" + s + "_padded_20160501.hdf", + datasets={raw: "volumes/raw", gt_labels: "volumes/labels/neuron_ids"}, + ) + + # convert raw to float in [0, 1] gp.Normalize(raw) + - # chose a random location for each requested batch gp.RandomLocation() - - for s in ['A', 'B', 'C'] - ) + - + for s in ["A", "B", "C"] + ) + + # chose a random source (i.e., sample) from the above - gp.RandomProvider() + - + gp.RandomProvider() + + # elastically deform the batch gp.ElasticAugment( - [4,40,40], - [0,2,2], - [0,math.pi/2.0], + [4, 40, 40], + [0, 2, 2], + [0, math.pi / 2.0], prob_slip=0.05, prob_shift=0.05, - max_misalign=25) + - + max_misalign=25, + ) + + # apply transpose and mirror augmentations - gp.SimpleAugment(transpose_only=[1, 2]) + - + gp.SimpleAugment(transpose_only=[1, 2]) + + # scale and shift the intensity of the raw array gp.IntensityAugment( raw, @@ -105,65 +99,54 @@ def train(iterations): scale_max=1.1, shift_min=-0.1, shift_max=0.1, - z_section_wise=True) + - + z_section_wise=True, + ) + + # grow a boundary between labels - gp.GrowBoundary( - gt_labels, - steps=3, - only_xy=True) + - + gp.GrowBoundary(gt_labels, steps=3, only_xy=True) + + # convert labels into affinities between voxels - gp.AddAffinities( - [[-1, 0, 0], [0, -1, 0], [0, 0, -1]], - gt_labels, - gt_affs) + - + gp.AddAffinities([[-1, 0, 0], [0, -1, 0], [0, 0, -1]], gt_labels, gt_affs) + + # create a weight array that balances positive and negative samples in # the affinity array - gp.BalanceLabels( - gt_affs, - loss_weights) + - + gp.BalanceLabels(gt_affs, loss_weights) + + # pre-cache batches from the point upstream - gp.PreCache( - cache_size=10, - num_workers=5) + - + gp.PreCache(cache_size=10, num_workers=5) + + # perform one training iteration for each passing batch (here we use # the tensor names earlier stored in train_net.config) gp.tensorflow.Train( - 'train_net', - net_config['optimizer'], - net_config['loss'], + "train_net", + net_config["optimizer"], + net_config["loss"], inputs={ - net_config['raw']: raw, - net_config['gt_affs']: gt_affs, - net_config['loss_weights']: loss_weights + net_config["raw"]: raw, + net_config["gt_affs"]: gt_affs, + net_config["loss_weights"]: loss_weights, }, - outputs={ - net_config['pred_affs']: pred_affs - }, - gradients={ - net_config['pred_affs']: pred_affs_gradients - }, - save_every=1) + - + outputs={net_config["pred_affs"]: pred_affs}, + gradients={net_config["pred_affs"]: pred_affs_gradients}, + save_every=1, + ) + + # save the passing batch as an HDF5 file for inspection gp.Snapshot( { - raw: '/volumes/raw', - gt_labels: '/volumes/labels/neuron_ids', - gt_affs: '/volumes/labels/affs', - pred_affs: '/volumes/pred_affs', - pred_affs_gradients: '/volumes/pred_affs_gradients' + raw: "/volumes/raw", + gt_labels: "/volumes/labels/neuron_ids", + gt_affs: "/volumes/labels/affs", + pred_affs: "/volumes/pred_affs", + pred_affs_gradients: "/volumes/pred_affs_gradients", }, - output_dir='snapshots', - output_filename='batch_{iteration}.hdf', + output_dir="snapshots", + output_filename="batch_{iteration}.hdf", every=100, additional_request=snapshot_request, - compression_type='gzip') + - + compression_type="gzip", + ) + + # show a summary of time spend in each node every 10 iterations gp.PrintProfilingStats(every=10) ) @@ -180,6 +163,6 @@ def train(iterations): print("Finished") + if __name__ == "__main__": train(200000) - \ No newline at end of file From 076661f81ae09aadb15b877c238ced040afc7cc4 Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 19 Dec 2023 08:45:59 -0700 Subject: [PATCH 33/74] Squashed commit of the following: commit 1686b949766b76960534ede1105751591fd91c9f Author: William Patton Date: Tue Dec 19 08:43:11 2023 -0700 black reformatting commit 26d2c7cfff3f2702f56a5bb4249a0811f54b45ef Author: Mohinta2892 Date: Thu Nov 2 19:09:15 2023 +0000 Revert "black reformatted" This reverts commit 66dd69bb6404eb305c69e7b64cc7d902774a9263. Only format changed files, since black does not consider formatting history commit a273fd3813fc16b516c2438ad5af0c4ee3f0686b Author: Samia Mohinta <44754434+Mohinta2892@users.noreply.github.com> Date: Thu Nov 2 17:12:26 2023 +0000 black reformatted commit bb37769eec33af5921386f283e2579055bb34e6d Author: Samia Mohinta <44754434+Mohinta2892@users.noreply.github.com> Date: Thu Nov 2 16:40:32 2023 +0000 add device arg Allow passing cuda device to Predict. Issue #188 commit a3b3588a1406d609ae95370cf2c5339872616011 Author: Samia Mohinta <44754434+Mohinta2892@users.noreply.github.com> Date: Thu Nov 2 16:39:09 2023 +0000 add device arg allow passing cuda device to Train --- gunpowder/torch/nodes/predict.py | 9 ++++++--- gunpowder/torch/nodes/train.py | 13 +++++++++++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/gunpowder/torch/nodes/predict.py b/gunpowder/torch/nodes/predict.py index 585ebc34..89c9ac0c 100644 --- a/gunpowder/torch/nodes/predict.py +++ b/gunpowder/torch/nodes/predict.py @@ -85,10 +85,13 @@ def __init__( self.intermediate_layers: dict[ArrayKey, Any] = {} def start(self): - self.use_cuda = torch.cuda.is_available() and self.device_string == "cuda" - logger.info(f"Predicting on {'gpu' if self.use_cuda else 'cpu'}") - self.device = torch.device("cuda" if self.use_cuda else "cpu") + # Issue #188 + self.use_cuda = torch.cuda.is_available() and self.device_string.__contains__( + "cuda" + ) + logger.info(f"Predicting on {'gpu' if self.use_cuda else 'cpu'}") + self.device = torch.device(self.device_string if self.use_cuda else "cpu") try: self.model = self.model.to(self.device) except RuntimeError as e: diff --git a/gunpowder/torch/nodes/train.py b/gunpowder/torch/nodes/train.py index 3f688929..e913d8ad 100644 --- a/gunpowder/torch/nodes/train.py +++ b/gunpowder/torch/nodes/train.py @@ -79,6 +79,12 @@ class Train(GenericTrain): spawn_subprocess (``bool``, optional): Whether to run the ``train_step`` in a separate process. Default is false. + + device (``str``, optional): + + Accepts a cuda gpu specifically to train on (e.g. `cuda:1`, `cuda:2`), helps in multi-card systems. + defaults to ``cuda`` + """ def __init__( @@ -93,9 +99,10 @@ def __init__( array_specs: Optional[Dict[ArrayKey, ArraySpec]] = None, checkpoint_basename: str = "model", save_every: int = 2000, - log_dir: Optional[str] = None, + log_dir: str = None, log_every: int = 1, spawn_subprocess: bool = False, + device: str = "cuda", ): if not model.training: logger.warning( @@ -125,6 +132,7 @@ def __init__( self.loss_inputs = loss_inputs self.checkpoint_basename = checkpoint_basename self.save_every = save_every + self.dev = device self.iteration = 0 @@ -167,7 +175,8 @@ def retain_gradients(self, request, outputs): def start(self): self.use_cuda = torch.cuda.is_available() - self.device = torch.device("cuda" if self.use_cuda else "cpu") + # Issue: #188 + self.device = torch.device(self.dev if self.use_cuda else "cpu") try: self.model = self.model.to(self.device) From b6c425f033252efbe3df46b073ca1658cfffe482 Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 19 Dec 2023 09:32:47 -0700 Subject: [PATCH 34/74] parameterize tests for cuda devices currently failing a few of them, some are expected failures. --- tests/cases/torch_train.py | 77 +++++++++++++++++++++++++++++++++++--- 1 file changed, 72 insertions(+), 5 deletions(-) diff --git a/tests/cases/torch_train.py b/tests/cases/torch_train.py index 2c67fda0..b368f9d6 100644 --- a/tests/cases/torch_train.py +++ b/tests/cases/torch_train.py @@ -68,8 +68,20 @@ def forward(self, a, b): return d_pred +@pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param( + "cuda:0", + marks=pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" + ), + ), + ], +) @skipIf(isinstance(torch, NoSuchModule), "torch is not installed") -def test_loss_drops(tmpdir): +def test_loss_drops(tmpdir, device): checkpoint_basename = str(tmpdir / "model") a_key = ArrayKey("A") @@ -80,7 +92,7 @@ def test_loss_drops(tmpdir): model = ExampleLinearModel() loss = torch.nn.MSELoss() - optimizer = torch.optim.SGD(model.parameters(), lr=1e-7, momentum=0.999) + optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.999) source = example_train_source(a_key, b_key, c_key) train = Train( @@ -98,6 +110,7 @@ def test_loss_drops(tmpdir): checkpoint_basename=checkpoint_basename, save_every=100, spawn_subprocess=False, + device=device, ) pipeline = source + train @@ -130,8 +143,25 @@ def test_loss_drops(tmpdir): assert loss2 < loss1 +@pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param( + "cuda:0", + marks=[ + pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" + ), + pytest.mark.xfail( + reason="failing to move model to device when using a subprocess" + ), + ], + ), + ], +) @skipIf(isinstance(torch, NoSuchModule), "torch is not installed") -def test_output(): +def test_output(device): logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) a_key = ArrayKey("A") @@ -153,6 +183,7 @@ def test_output(): d_pred: ArraySpec(nonspatial=True), }, spawn_subprocess=True, + device=device, ) pipeline = source + predict @@ -191,8 +222,25 @@ def forward(self, a): return pred +@pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param( + "cuda:0", + marks=[ + pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" + ), + pytest.mark.xfail( + reason="failing to move model to device in multiprocessing context" + ), + ], + ), + ], +) @skipIf(isinstance(torch, NoSuchModule), "torch is not installed") -def test_scan(): +def test_scan(device): logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) a_key = ArrayKey("A") @@ -210,6 +258,7 @@ def test_scan(): inputs={"a": a_key}, outputs={0: pred}, array_specs={pred: ArraySpec()}, + device=device, ) pipeline = source + predict + Scan(reference_request, num_workers=2) @@ -226,8 +275,25 @@ def test_scan(): assert pred in batch +@pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param( + "cuda:0", + marks=[ + pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" + ), + pytest.mark.xfail( + reason="failing to move model to device in multiprocessing context" + ), + ], + ), + ], +) @skipIf(isinstance(torch, NoSuchModule), "torch is not installed") -def test_precache(): +def test_precache(device): logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) a_key = ArrayKey("A") @@ -245,6 +311,7 @@ def test_precache(): inputs={"a": a_key}, outputs={0: pred}, array_specs={pred: ArraySpec()}, + device=device, ) pipeline = source + predict + PreCache(cache_size=3, num_workers=2) From 588824f9bb753ca3b698df0d7ba87aec22d86d1e Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 2 Jan 2024 08:56:35 -0800 Subject: [PATCH 35/74] Added support for reflect padding Squashed commit of the following: commit 0fb29c852180227a36c4055c528ad39d956fd175 Author: William Patton Date: Tue Jan 2 08:54:17 2024 -0800 replace custom padding code with np.pad commit c6928bd70565e8bc8911bb414583a176a1112810 Author: William Patton Date: Tue Jan 2 08:54:06 2024 -0800 simplify/expand padding test test padding on both sides commit 37825254ccef75ca34371115809745f5e6f92e9d Author: William Patton Date: Tue Dec 19 11:30:31 2023 -0700 pass the fixed tests commit a7027c6825fd0661074a67cdf9b0406b8982600d Author: William Patton Date: Tue Dec 19 10:37:48 2023 -0700 fix the test case commit 531d81dc7dfad30b0639578ef350368655e0eb42 Author: William Patton Date: Tue Dec 19 10:06:44 2023 -0700 update the pad tests parametrized the use of constant or reflect padding. Now avoids using the unittest framework commit 443c666a35e8ff5784360c2e427931ffca06a741 Author: Manan Lalit <34229641+lmanan@users.noreply.github.com> Date: Fri Nov 3 00:09:33 2023 -0400 Replace .ndim by len() commit a7503d7bf93f3609d7e573b927f7ea9f8c0a6b72 Author: lmanan Date: Thu Nov 2 11:52:27 2023 -0400 Update pad.py to include reflective padding --- gunpowder/nodes/pad.py | 35 ++++++------- tests/cases/pad.py | 111 +++++++++++++++++++++++++---------------- 2 files changed, 86 insertions(+), 60 deletions(-) diff --git a/gunpowder/nodes/pad.py b/gunpowder/nodes/pad.py index 6bbfdc58..f025d9d7 100644 --- a/gunpowder/nodes/pad.py +++ b/gunpowder/nodes/pad.py @@ -7,6 +7,8 @@ from gunpowder.coordinate import Coordinate from gunpowder.batch_request import BatchRequest +from itertools import product + logger = logging.getLogger(__name__) @@ -27,15 +29,22 @@ class Pad(BatchFilter): a coordinate, this amount will be added to the ROI in the positive and negative direction. + mode (string): + + One of 'constant' or 'reflect'. + Default is 'constant' + value (scalar or ``None``): The value to report inside the padding. If not given, 0 is used. + Only used in case of 'constant' mode. Only used for :class:`Array`. """ - def __init__(self, key, size, value=None): + def __init__(self, key, size, mode="constant", value=None): self.key = key self.size = size + self.mode = mode self.value = value def setup(self): @@ -118,19 +127,11 @@ def __expand(self, a, from_roi, to_roi, value): ) num_channels = len(a.shape) - from_roi.dims - channel_shapes = a.shape[:num_channels] - - b = np.zeros(channel_shapes + to_roi.shape, dtype=a.dtype) - if value != 0: - b[:] = value - - shift = -to_roi.offset - logger.debug("shifting 'from' by " + str(shift)) - a_in_b = from_roi.shift(shift).to_slices() - - logger.debug("target shape is " + str(b.shape)) - logger.debug("target slice is " + str(a_in_b)) - - b[(slice(None),) * num_channels + a_in_b] = a - - return b + lower_pad = from_roi.begin - to_roi.begin + upper_pad = to_roi.end - from_roi.end + pad_width = [(0, 0)] * num_channels + list(zip(lower_pad, upper_pad)) + if self.mode == "constant": + padded = np.pad(a, pad_width, "constant", constant_values=value) + elif self.mode == "reflect": + padded = np.pad(a, pad_width, "reflect") + return padded diff --git a/tests/cases/pad.py b/tests/cases/pad.py index 8b7ab179..5efda685 100644 --- a/tests/cases/pad.py +++ b/tests/cases/pad.py @@ -1,71 +1,96 @@ -from .provider_test import ProviderTest +from .helper_sources import ArraySource, GraphSource from gunpowder import ( - BatchProvider, BatchRequest, - Batch, - ArrayKeys, ArraySpec, Roi, Coordinate, + Graph, GraphKey, - GraphKeys, GraphSpec, Array, ArrayKey, Pad, build, + MergeProvider, ) -import numpy as np - -class ExampleSourcePad(BatchProvider): - def setup(self): - self.provides( - ArrayKeys.TEST_LABELS, - ArraySpec(roi=Roi((200, 20, 20), (1800, 180, 180)), voxel_size=(20, 2, 2)), - ) +import pytest +import numpy as np - self.provides( - GraphKeys.TEST_GRAPH, GraphSpec(roi=Roi((200, 20, 20), (1800, 180, 180))) - ) +from itertools import product - def provide(self, request): - batch = Batch() - roi_array = request[ArrayKeys.TEST_LABELS].roi - roi_voxel = roi_array // self.spec[ArrayKeys.TEST_LABELS].voxel_size +@pytest.mark.parametrize("mode", ["constant", "reflect"]) +def test_padding(mode): + array_key = ArrayKey("TEST_ARRAY") + graph_key = GraphKey("TEST_GRAPH") - data = np.zeros(roi_voxel.shape, dtype=np.uint32) - data[:, ::2] = 100 + array_spec = ArraySpec(roi=Roi((200, 20, 20), (600, 60, 60)), voxel_size=(20, 2, 2)) + roi_voxel = array_spec.roi / array_spec.voxel_size + data = np.zeros(roi_voxel.shape, dtype=np.uint32) + data[:, ::2] = 100 + array = Array(data, spec=array_spec) - spec = self.spec[ArrayKeys.TEST_LABELS].copy() - spec.roi = roi_array - batch.arrays[ArrayKeys.TEST_LABELS] = Array(data, spec=spec) + graph_spec = GraphSpec(roi=Roi((200, 20, 20), (600, 60, 60))) + graph = Graph([], [], graph_spec) - return batch + source = ( + ArraySource(array_key, array), + GraphSource(graph_key, graph), + ) + MergeProvider() + pipeline = ( + source + + Pad(array_key, Coordinate((200, 20, 20)), value=1, mode=mode) + + Pad(graph_key, Coordinate((100, 10, 10)), mode=mode) + ) -class TestPad(ProviderTest): - def test_output(self): - graph = GraphKey("TEST_GRAPH") - labels = ArrayKey("TEST_LABELS") + with build(pipeline): + assert pipeline.spec[array_key].roi == Roi((0, 0, 0), (1000, 100, 100)) + assert pipeline.spec[graph_key].roi == Roi((100, 10, 10), (800, 80, 80)) - pipeline = ( - ExampleSourcePad() - + Pad(labels, Coordinate((20, 20, 20)), value=1) - + Pad(graph, Coordinate((10, 10, 10))) + batch = pipeline.request_batch( + BatchRequest({array_key: ArraySpec(Roi((180, 0, 0), (40, 40, 40)))}) ) - with build(pipeline): - self.assertTrue( - pipeline.spec[labels].roi == Roi((180, 0, 0), (1840, 220, 220)) + data = batch.arrays[array_key].data + if mode == "constant": + octants = [ + (1 * 10 * 10) if zi + yi + xi < 3 else 100 * 1 * 5 * 10 + for zi, yi, xi in product(range(2), range(2), range(2)) + ] + assert np.sum(data) == np.sum(octants), ( + np.sum(data), + np.sum(octants), + np.unique(data), ) - self.assertTrue( - pipeline.spec[graph].roi == Roi((190, 10, 10), (1820, 200, 200)) + elif mode == "reflect": + octants = [100 * 1 * 5 * 10 for _ in range(8)] + assert np.sum(data) == np.sum(octants), ( + np.sum(data), + np.sum(octants), + data, ) - batch = pipeline.request_batch( - BatchRequest({labels: ArraySpec(Roi((180, 0, 0), (20, 20, 20)))}) - ) + # 1 x 10 x (10,30,10) + batch = pipeline.request_batch( + BatchRequest({array_key: ArraySpec(Roi((200, 20, 0), (20, 20, 100)))}) + ) + data = batch.arrays[array_key].data - self.assertEqual(np.sum(batch.arrays[labels].data), 1 * 10 * 10) + if mode == "constant": + lower_pad = 1 * 10 * 10 + upper_pad = 1 * 10 * 10 + center = 100 * 1 * 5 * 30 + assert np.sum(data) == np.sum((lower_pad, upper_pad, center)), ( + np.sum(data), + np.sum((lower_pad, upper_pad, center)), + ) + elif mode == "reflect": + lower_pad = 100 * 1 * 5 * 10 + upper_pad = 100 * 1 * 5 * 10 + center = 100 * 1 * 5 * 30 + assert np.sum(data) == np.sum((lower_pad, upper_pad, center)), ( + np.sum(data), + np.sum((lower_pad, upper_pad, center)), + ) From d44903679997cff612ee6bd5ce71c97706b822d6 Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 2 Jan 2024 09:46:44 -0800 Subject: [PATCH 36/74] Fix bug in rasterize graph we were using `graph.data.items()` to iterate over nodes instead of `graph.nodes` Squashed commit of the following: commit d027f5a260a1e2a9cf851efca85b7318434675d6 Author: William Patton Date: Tue Jan 2 09:44:21 2024 -0800 refactor rasterize_points test to use pytest commit eadb0476d8475b55120486df6cf30f95b6df86f4 Author: William Patton Date: Tue Jan 2 09:25:11 2024 -0800 remove extra roi handling The node only needs to request the data it needs for its own operations. If you request a mask for a set of points that extend outside the bounds of your mask you will get an error commit 29507f1f21d69cf76e34e7b0f05cd780100fd68b Author: William Patton Date: Tue Jan 2 09:22:21 2024 -0800 remove type cast we do a bitwise during the `__rasterize` call which results fails if you change the dtype commit 96e93e53ce0bc8240357259dab92f1ca64a08199 Author: William Patton Date: Tue Jan 2 09:21:16 2024 -0800 remove matplotlib commit eb2977a187a1cad95da54a515c84ce44d73b8315 Author: Samia Mohinta <44754434+Mohinta2892@users.noreply.github.com> Date: Thu Dec 14 15:27:41 2023 +0000 fix mask intersection with request outputs must match request rois when a mask is provided commit 682189dac2ef6b94876bd30df813717da6530060 Author: Samia Mohinta <44754434+Mohinta2892@users.noreply.github.com> Date: Thu Dec 14 15:25:12 2023 +0000 Update rasterize_graph.py commit e36dcf179ccd1aec6a5cafd31e7a9a858352faa1 Author: Mohinta2892 Date: Thu Nov 2 19:18:00 2023 +0000 reformat rasterize_graph and rasterize_points commit 42da2702e746f702d3d07144ab9fc1d4352b0c0d Author: Samia Mohinta <44754434+Mohinta2892@users.noreply.github.com> Date: Thu Nov 2 14:23:01 2023 +0000 Test for issue #193 Test added to pass mask to `RasterizeGraph()` via `RasterizationSettings`. commit b17cfad413f5ad7f48045a2167ec20d89674d939 Author: Samia Mohinta <44754434+Mohinta2892@users.noreply.github.com> Date: Thu Nov 2 14:19:42 2023 +0000 fix for issue #193 lines 224-226: replace graph.data.items() with graph.nodes lines 255-257: explicitly cast the boolean mask data to the original dtype of mask_array --- gunpowder/nodes/rasterize_graph.py | 9 +- tests/cases/rasterize_points.py | 412 +++++++++++++++-------------- 2 files changed, 217 insertions(+), 204 deletions(-) diff --git a/gunpowder/nodes/rasterize_graph.py b/gunpowder/nodes/rasterize_graph.py index 1a12335f..b660dd57 100644 --- a/gunpowder/nodes/rasterize_graph.py +++ b/gunpowder/nodes/rasterize_graph.py @@ -221,7 +221,8 @@ def process(self, batch, request): mask_array = batch.arrays[mask].crop(enlarged_vol_roi) # get those component labels in the mask, that contain graph labels = [] - for i, point in graph.data.items(): + # for i, point in graph.data.items(): + for i, point in enumerate(graph.nodes): v = Coordinate(point.location / voxel_size) v -= data_roi.begin labels.append(mask_array.data[v]) @@ -250,11 +251,15 @@ def process(self, batch, request): voxel_size, self.spec[self.array].dtype, self.settings, - Array(data=mask_array.data == label, spec=mask_array.spec), + Array( + data=(mask_array.data == label), + spec=mask_array.spec, + ), ) for label in labels ], axis=0, + dtype=self.spec[self.array].dtype, ) else: diff --git a/tests/cases/rasterize_points.py b/tests/cases/rasterize_points.py index 16c11ae1..a57906f8 100644 --- a/tests/cases/rasterize_points.py +++ b/tests/cases/rasterize_points.py @@ -1,31 +1,30 @@ -from .provider_test import ProviderTest +from .helper_sources import ArraySource, GraphSource from gunpowder import ( - BatchProvider, BatchRequest, - Batch, Roi, Coordinate, GraphSpec, Array, ArrayKey, - ArrayKeys, ArraySpec, RasterizeGraph, + MergeProvider, RasterizationSettings, build, ) -from gunpowder.graph import GraphKeys, GraphKey, Graph, Node, Edge +from gunpowder.graph import GraphKey, Graph, Node, Edge import numpy as np -import math -from random import randint -class GraphTestSource3D(BatchProvider): - def __init__(self): - self.voxel_size = Coordinate((40, 4, 4)) +def test_3d(): + graph_key = GraphKey("TEST_GRAPH") + array_key = ArrayKey("TEST_ARRAY") + rasterized_key = ArrayKey("RASTERIZED_ARRAY") + voxel_size = Coordinate((40, 4, 4)) - self.nodes = [ + graph = Graph( + [ # corners Node(id=1, location=np.array((-200, -200, -200))), Node(id=2, location=np.array((-200, -200, 199))), @@ -38,249 +37,258 @@ def __init__(self): # center Node(id=9, location=np.array((0, 0, 0))), Node(id=10, location=np.array((-1, -1, -1))), - ] - - self.graph_spec = GraphSpec(roi=Roi((-100, -100, -100), (300, 300, 300))) - self.array_spec = ArraySpec( - roi=Roi((-200, -200, -200), (400, 400, 400)), voxel_size=self.voxel_size - ) - - self.graph = Graph(self.nodes, [], self.graph_spec) - - def setup(self): - self.provides( - GraphKeys.TEST_GRAPH, - self.graph_spec, - ) - - self.provides( - ArrayKeys.GT_LABELS, - self.array_spec, - ) - - def provide(self, request): - batch = Batch() - - graph_roi = request[GraphKeys.TEST_GRAPH].roi - - batch.graphs[GraphKeys.TEST_GRAPH] = self.graph.crop(graph_roi).trim(graph_roi) - - roi_array = request[ArrayKeys.GT_LABELS].roi - - image = np.ones(roi_array.shape / self.voxel_size, dtype=np.uint64) - # label half of GT_LABELS differently - depth = image.shape[0] - image[0 : depth // 2] = 2 - - spec = self.spec[ArrayKeys.GT_LABELS].copy() - spec.roi = roi_array - batch.arrays[ArrayKeys.GT_LABELS] = Array(image, spec=spec) - - return batch - - -class GraphTestSourceWithEdge(BatchProvider): - def __init__(self): - self.voxel_size = Coordinate((1, 1, 1)) - - self.nodes = [ - # corners - Node(id=1, location=np.array((0, 4, 4))), - Node(id=2, location=np.array((9, 4, 4))), - ] - self.edges = [Edge(1, 2)] - - self.graph_spec = GraphSpec(roi=Roi((0, 0, 0), (10, 10, 10))) - self.graph = Graph(self.nodes, self.edges, self.graph_spec) - - def setup(self): - self.provides( - GraphKeys.TEST_GRAPH_WITH_EDGE, - self.graph_spec, - ) - - def provide(self, request): - batch = Batch() - - graph_roi = request[GraphKeys.TEST_GRAPH_WITH_EDGE].roi - - batch.graphs[GraphKeys.TEST_GRAPH_WITH_EDGE] = self.graph.crop(graph_roi).trim( - graph_roi - ) - - return batch - - -class TestRasterizePoints(ProviderTest): - def test_3d(self): - GraphKey("TEST_GRAPH") - ArrayKey("RASTERIZED") - - pipeline = GraphTestSource3D() + RasterizeGraph( - GraphKeys.TEST_GRAPH, - ArrayKeys.RASTERIZED, + ], + [], + GraphSpec(roi=Roi((-100, -100, -100), (300, 300, 300))), + ) + + array = Array( + np.ones((10, 100, 100)), + ArraySpec( + roi=Roi((-200, -200, -200), (400, 400, 400)), + voxel_size=voxel_size, + ), + ) + + pipeline = ( + (GraphSource(graph_key, graph), ArraySource(array_key, array)) + + MergeProvider() + + RasterizeGraph( + graph_key, + rasterized_key, ArraySpec(voxel_size=(40, 4, 4)), ) + ) - with build(pipeline): - request = BatchRequest() - roi = Roi((0, 0, 0), (200, 200, 200)) + with build(pipeline): + request = BatchRequest() + roi = Roi((0, 0, 0), (200, 200, 200)) - request[GraphKeys.TEST_GRAPH] = GraphSpec(roi=roi) - request[ArrayKeys.GT_LABELS] = ArraySpec(roi=roi) - request[ArrayKeys.RASTERIZED] = ArraySpec(roi=roi) + request[graph_key] = GraphSpec(roi=roi) + request[array_key] = ArraySpec(roi=roi) + request[rasterized_key] = ArraySpec(roi=roi) - batch = pipeline.request_batch(request) + batch = pipeline.request_batch(request) - rasterized = batch.arrays[ArrayKeys.RASTERIZED].data - self.assertEqual(rasterized[0, 0, 0], 1) - self.assertEqual(rasterized[2, 20, 20], 0) - self.assertEqual(rasterized[4, 49, 49], 1) + rasterized = batch.arrays[rasterized_key].data + assert rasterized[0, 0, 0] == 1 + assert rasterized[2, 20, 20] == 0 + assert rasterized[4, 49, 49] == 1 - # same with different foreground/background labels + # same with different foreground/background labels - pipeline = GraphTestSource3D() + RasterizeGraph( - GraphKeys.TEST_GRAPH, - ArrayKeys.RASTERIZED, + pipeline = ( + (GraphSource(graph_key, graph), ArraySource(array_key, array)) + + MergeProvider() + + RasterizeGraph( + graph_key, + rasterized_key, ArraySpec(voxel_size=(40, 4, 4)), RasterizationSettings(radius=1, fg_value=0, bg_value=1), ) + ) - with build(pipeline): - request = BatchRequest() - roi = Roi((0, 0, 0), (200, 200, 200)) + with build(pipeline): + request = BatchRequest() + roi = Roi((0, 0, 0), (200, 200, 200)) - request[GraphKeys.TEST_GRAPH] = GraphSpec(roi=roi) - request[ArrayKeys.GT_LABELS] = ArraySpec(roi=roi) - request[ArrayKeys.RASTERIZED] = ArraySpec(roi=roi) + request[graph_key] = GraphSpec(roi=roi) + request[array_key] = ArraySpec(roi=roi) + request[rasterized_key] = ArraySpec(roi=roi) - batch = pipeline.request_batch(request) + batch = pipeline.request_batch(request) - rasterized = batch.arrays[ArrayKeys.RASTERIZED].data - self.assertEqual(rasterized[0, 0, 0], 0) - self.assertEqual(rasterized[2, 20, 20], 1) - self.assertEqual(rasterized[4, 49, 49], 0) + rasterized = batch.arrays[rasterized_key].data + assert rasterized[0, 0, 0] == 0 + assert rasterized[2, 20, 20] == 1 + assert rasterized[4, 49, 49] == 0 - # same with different radius and inner radius + # same with different radius and inner radius - pipeline = GraphTestSource3D() + RasterizeGraph( - GraphKeys.TEST_GRAPH, - ArrayKeys.RASTERIZED, + pipeline = ( + (GraphSource(graph_key, graph), ArraySource(array_key, array)) + + MergeProvider() + + RasterizeGraph( + graph_key, + rasterized_key, ArraySpec(voxel_size=(40, 4, 4)), RasterizationSettings( radius=40, inner_radius_fraction=0.25, fg_value=1, bg_value=0 ), ) + ) - with build(pipeline): - request = BatchRequest() - roi = Roi((0, 0, 0), (200, 200, 200)) + with build(pipeline): + request = BatchRequest() + roi = Roi((0, 0, 0), (200, 200, 200)) - request[GraphKeys.TEST_GRAPH] = GraphSpec(roi=roi) - request[ArrayKeys.GT_LABELS] = ArraySpec(roi=roi) - request[ArrayKeys.RASTERIZED] = ArraySpec(roi=roi) + request[graph_key] = GraphSpec(roi=roi) + request[array_key] = ArraySpec(roi=roi) + request[rasterized_key] = ArraySpec(roi=roi) - batch = pipeline.request_batch(request) + batch = pipeline.request_batch(request) - rasterized = batch.arrays[ArrayKeys.RASTERIZED].data + rasterized = batch.arrays[rasterized_key].data - # in the middle of the ball, there should be 0 (since inner radius is set) - self.assertEqual(rasterized[0, 0, 0], 0) - # check larger radius: rasterized point (0, 0, 0) should extend in - # x,y by 10; z, by 1 - self.assertEqual(rasterized[0, 10, 0], 1) - self.assertEqual(rasterized[0, 0, 10], 1) - self.assertEqual(rasterized[1, 0, 0], 1) + # in the middle of the ball, there should be 0 (since inner radius is set) + assert rasterized[0, 0, 0] == 0 + # check larger radius: rasterized point (0, 0, 0) should extend in + # x,y by 10; z, by 1 + assert rasterized[0, 10, 0] == 1 + assert rasterized[0, 0, 10] == 1 + assert rasterized[1, 0, 0] == 1 - self.assertEqual(rasterized[2, 20, 20], 0) - self.assertEqual(rasterized[4, 49, 49], 0) + assert rasterized[2, 20, 20] == 0 + assert rasterized[4, 49, 49] == 0 - # same with anisotropic radius + # same with different foreground/background labels + # and GT_LABELS as mask of type np.uint64. Issue #193 - pipeline = GraphTestSource3D() + RasterizeGraph( - GraphKeys.TEST_GRAPH, - ArrayKeys.RASTERIZED, + pipeline = ( + (GraphSource(graph_key, graph), ArraySource(array_key, array)) + + MergeProvider() + + RasterizeGraph( + graph_key, + rasterized_key, ArraySpec(voxel_size=(40, 4, 4)), - RasterizationSettings(radius=(40, 40, 20), fg_value=1, bg_value=0), + RasterizationSettings(radius=1, fg_value=0, bg_value=1, mask=array_key), ) + ) - with build(pipeline): - request = BatchRequest() - roi = Roi((0, 0, 0), (120, 80, 80)) + with build(pipeline): + request = BatchRequest() + roi = Roi((0, 0, 0), (200, 200, 200)) - request[GraphKeys.TEST_GRAPH] = GraphSpec(roi=roi) - request[ArrayKeys.GT_LABELS] = ArraySpec(roi=roi) - request[ArrayKeys.RASTERIZED] = ArraySpec(roi=roi) + request[graph_key] = GraphSpec(roi=roi) + request[array_key] = ArraySpec(roi=roi) + request[rasterized_key] = ArraySpec(roi=roi) - batch = pipeline.request_batch(request) + batch = pipeline.request_batch(request) - rasterized = batch.arrays[ArrayKeys.RASTERIZED].data + rasterized = batch.arrays[rasterized_key].data + assert rasterized[0, 0, 0] == 0 + assert rasterized[2, 20, 20] == 1 + assert rasterized[4, 49, 49] == 0 - # check larger radius: rasterized point (0, 0, 0) should extend in - # x,y by 10; z, by 1 - self.assertEqual(rasterized[0, 10, 0], 1) - self.assertEqual(rasterized[0, 11, 0], 0) - self.assertEqual(rasterized[0, 0, 5], 1) - self.assertEqual(rasterized[0, 0, 6], 0) - self.assertEqual(rasterized[1, 0, 0], 1) - self.assertEqual(rasterized[2, 0, 0], 0) + # same with anisotropic radius - # same with anisotropic radius and inner radius - - pipeline = GraphTestSource3D() + RasterizeGraph( - GraphKeys.TEST_GRAPH, - ArrayKeys.RASTERIZED, + pipeline = ( + (GraphSource(graph_key, graph), ArraySource(array_key, array)) + + MergeProvider() + + RasterizeGraph( + graph_key, + rasterized_key, ArraySpec(voxel_size=(40, 4, 4)), - RasterizationSettings( - radius=(40, 40, 20), inner_radius_fraction=0.75, fg_value=1, bg_value=0 - ), + RasterizationSettings(radius=(40, 40, 20), fg_value=1, bg_value=0), ) + ) - with build(pipeline): - request = BatchRequest() - roi = Roi((0, 0, 0), (120, 80, 80)) + with build(pipeline): + request = BatchRequest() + roi = Roi((0, 0, 0), (120, 80, 80)) - request[GraphKeys.TEST_GRAPH] = GraphSpec(roi=roi) - request[ArrayKeys.GT_LABELS] = ArraySpec(roi=roi) - request[ArrayKeys.RASTERIZED] = ArraySpec(roi=roi) + request[graph_key] = GraphSpec(roi=roi) + request[array_key] = ArraySpec(roi=roi) + request[rasterized_key] = ArraySpec(roi=roi) - batch = pipeline.request_batch(request) + batch = pipeline.request_batch(request) - rasterized = batch.arrays[ArrayKeys.RASTERIZED].data + rasterized = batch.arrays[rasterized_key].data - # in the middle of the ball, there should be 0 (since inner radius is set) - self.assertEqual(rasterized[0, 0, 0], 0) - # check larger radius: rasterized point (0, 0, 0) should extend in - # x,y by 10; z, by 1 - self.assertEqual(rasterized[0, 10, 0], 1) - self.assertEqual(rasterized[0, 11, 0], 0) - self.assertEqual(rasterized[0, 0, 5], 1) - self.assertEqual(rasterized[0, 0, 6], 0) - self.assertEqual(rasterized[1, 0, 0], 1) - self.assertEqual(rasterized[2, 0, 0], 0) + # check larger radius: rasterized point (0, 0, 0) should extend in + # x,y by 10; z, by 1 + assert rasterized[0, 10, 0] == 1 + assert rasterized[0, 11, 0] == 0 + assert rasterized[0, 0, 5] == 1 + assert rasterized[0, 0, 6] == 0 + assert rasterized[1, 0, 0] == 1 + assert rasterized[2, 0, 0] == 0 - def test_with_edge(self): - graph_with_edge = GraphKey("TEST_GRAPH_WITH_EDGE") - array_with_edge = ArrayKey("RASTERIZED_EDGE") + # same with anisotropic radius and inner radius - pipeline = GraphTestSourceWithEdge() + RasterizeGraph( - GraphKeys.TEST_GRAPH_WITH_EDGE, - ArrayKeys.RASTERIZED_EDGE, + pipeline = ( + (GraphSource(graph_key, graph), ArraySource(array_key, array)) + + MergeProvider() + + RasterizeGraph( + graph_key, + rasterized_key, + ArraySpec(voxel_size=(40, 4, 4)), + RasterizationSettings( + radius=(40, 40, 20), inner_radius_fraction=0.75, fg_value=1, bg_value=0 + ), + ) + ) + + with build(pipeline): + request = BatchRequest() + roi = Roi((0, 0, 0), (120, 80, 80)) + + request[graph_key] = GraphSpec(roi=roi) + request[array_key] = ArraySpec(roi=roi) + request[rasterized_key] = ArraySpec(roi=roi) + + batch = pipeline.request_batch(request) + + rasterized = batch.arrays[rasterized_key].data + + # in the middle of the ball, there should be 0 (since inner radius is set) + assert rasterized[0, 0, 0] == 0 + # check larger radius: rasterized point (0, 0, 0) should extend in + # x,y by 10; z, by 1 + assert rasterized[0, 10, 0] == 1 + assert rasterized[0, 11, 0] == 0 + assert rasterized[0, 0, 5] == 1 + assert rasterized[0, 0, 6] == 0 + assert rasterized[1, 0, 0] == 1 + assert rasterized[2, 0, 0] == 0 + + +def test_with_edge(): + graph_key = GraphKey("TEST_GRAPH") + array_key = ArrayKey("TEST_ARRAY") + rasterized_key = ArrayKey("RASTERIZED_ARRAY") + voxel_size = Coordinate((40, 4, 4)) + + array = Array( + np.ones((10, 100, 100)), + ArraySpec( + roi=Roi((-200, -200, -200), (400, 400, 400)), + voxel_size=voxel_size, + ), + ) + + graph = Graph( + [ + # corners + Node(id=1, location=np.array((0, 4, 4))), + Node(id=2, location=np.array((9, 4, 4))), + ], + [Edge(1, 2)], + GraphSpec(roi=Roi((0, 0, 0), (10, 10, 10))), + ) + + pipeline = ( + (GraphSource(graph_key, graph), ArraySource(array_key, array)) + + MergeProvider() + + RasterizeGraph( + graph_key, + rasterized_key, ArraySpec(voxel_size=(1, 1, 1)), settings=RasterizationSettings(0.5), ) + ) - with build(pipeline): - request = BatchRequest() - roi = Roi((0, 0, 0), (10, 10, 10)) + with build(pipeline): + request = BatchRequest() + roi = Roi((0, 0, 0), (10, 10, 10)) - request[GraphKeys.TEST_GRAPH_WITH_EDGE] = GraphSpec(roi=roi) - request[ArrayKeys.RASTERIZED_EDGE] = ArraySpec(roi=roi) + request[graph_key] = GraphSpec(roi=roi) + request[rasterized_key] = ArraySpec(roi=roi) - batch = pipeline.request_batch(request) + batch = pipeline.request_batch(request) - rasterized = batch.arrays[ArrayKeys.RASTERIZED_EDGE].data + rasterized = batch.arrays[rasterized_key].data - assert ( - rasterized.sum() == 10 - ), f"rasterized has ones at: {np.where(rasterized==1)}" + assert ( + rasterized.sum() == 10 + ), f"rasterized has ones at: {np.where(rasterized==1)}" From d55c97677d0353e4c3e7f0411acda9d9bb85f825 Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 2 Jan 2024 09:54:25 -0800 Subject: [PATCH 37/74] ruff: remove unused imports and fix small typos. --- gunpowder/array.py | 2 -- gunpowder/batch.py | 3 +- .../contrib/nodes/add_blobs_from_points.py | 2 +- .../nodes/add_boundary_distance_gradients.py | 3 +- .../nodes/add_gt_mask_exclusive_zone.py | 3 +- .../nodes/add_nonsymmetric_affinities.py | 2 -- gunpowder/contrib/nodes/hdf5_points_source.py | 3 -- gunpowder/ext/__init__.py | 28 +++++++++---------- gunpowder/graph.py | 3 +- gunpowder/nodes/crop.py | 1 - gunpowder/nodes/elastic_augment.py | 1 - gunpowder/nodes/exclude_labels.py | 2 +- gunpowder/nodes/grow_boundary.py | 1 - gunpowder/nodes/klb_source.py | 1 - gunpowder/nodes/merge_provider.py | 1 - gunpowder/nodes/pad.py | 1 - gunpowder/nodes/precache.py | 3 -- gunpowder/nodes/rasterize_graph.py | 2 -- gunpowder/nodes/shift_augment.py | 1 - gunpowder/nodes/specified_location.py | 2 -- gunpowder/nodes/squeeze.py | 1 - gunpowder/nodes/unsqueeze.py | 1 - gunpowder/producer_pool.py | 2 -- gunpowder/provider_spec.py | 2 -- tests/cases/add_affinities.py | 4 --- tests/cases/dvid_source.py | 1 - tests/cases/elastic_augment_points.py | 2 -- tests/cases/expected_failures.py | 2 +- tests/cases/intensity_scale_shift.py | 1 - tests/cases/jax_train.py | 5 +--- tests/cases/noise_augment.py | 1 - tests/cases/random_location.py | 1 - tests/cases/resample.py | 1 - tests/cases/simple_augment.py | 2 -- tests/cases/snapshot.py | 2 -- tests/cases/tensorflow_train.py | 2 +- tests/cases/zarr_read_write.py | 2 +- 37 files changed, 24 insertions(+), 73 deletions(-) diff --git a/gunpowder/array.py b/gunpowder/array.py index 0177cf5a..a8da1322 100644 --- a/gunpowder/array.py +++ b/gunpowder/array.py @@ -1,7 +1,5 @@ from .freezable import Freezable from copy import deepcopy -from gunpowder.coordinate import Coordinate -from gunpowder.roi import Roi import logging import numpy as np import copy diff --git a/gunpowder/batch.py b/gunpowder/batch.py index 412c891f..ffc97e77 100644 --- a/gunpowder/batch.py +++ b/gunpowder/batch.py @@ -1,7 +1,6 @@ from copy import copy as shallow_copy import logging import multiprocessing -import warnings from .freezable import Freezable from .profiling import ProfilingStats @@ -75,7 +74,7 @@ def __setitem__(self, key, value): elif isinstance(value, Graph): assert isinstance( key, GraphKey - ), f"Only a GraphKey is allowed as key for Graph value." + ), "Only a GraphKey is allowed as key for Graph value." self.graphs[key] = value else: diff --git a/gunpowder/contrib/nodes/add_blobs_from_points.py b/gunpowder/contrib/nodes/add_blobs_from_points.py index 03b063ba..a78eb814 100644 --- a/gunpowder/contrib/nodes/add_blobs_from_points.py +++ b/gunpowder/contrib/nodes/add_blobs_from_points.py @@ -143,7 +143,7 @@ def process(self, batch, request): synapse_ids = [] for point_id, point in points.data.items(): # pdb.set_trace() - if not point.partner_ids[0] in partner_points.data.keys(): + if point.partner_ids[0] not in partner_points.data.keys(): logger.warning( "Point %s has no partner. Deleting..." % point_id ) diff --git a/gunpowder/contrib/nodes/add_boundary_distance_gradients.py b/gunpowder/contrib/nodes/add_boundary_distance_gradients.py index 2ef93870..b2897272 100644 --- a/gunpowder/contrib/nodes/add_boundary_distance_gradients.py +++ b/gunpowder/contrib/nodes/add_boundary_distance_gradients.py @@ -4,7 +4,6 @@ from gunpowder.array import Array from gunpowder.batch_request import BatchRequest from gunpowder.nodes.batch_filter import BatchFilter -from numpy.lib.stride_tricks import as_strided from scipy.ndimage.morphology import distance_transform_edt logger = logging.getLogger(__name__) @@ -83,7 +82,7 @@ def prepare(self, request): return deps def process(self, batch, request): - if not self.gradient_array_key in request: + if self.gradient_array_key not in request: return labels = batch.arrays[self.label_array_key].data diff --git a/gunpowder/contrib/nodes/add_gt_mask_exclusive_zone.py b/gunpowder/contrib/nodes/add_gt_mask_exclusive_zone.py index cff056f7..f50e6a70 100644 --- a/gunpowder/contrib/nodes/add_gt_mask_exclusive_zone.py +++ b/gunpowder/contrib/nodes/add_gt_mask_exclusive_zone.py @@ -1,10 +1,9 @@ import copy import logging import numpy as np -from scipy import ndimage from gunpowder.nodes.batch_filter import BatchFilter -from gunpowder.array import Array, ArrayKeys +from gunpowder.array import Array from gunpowder.nodes.rasterize_graph import RasterizationSettings from gunpowder.morphology import enlarge_binary_map diff --git a/gunpowder/contrib/nodes/add_nonsymmetric_affinities.py b/gunpowder/contrib/nodes/add_nonsymmetric_affinities.py index 6d47201b..ef16398b 100644 --- a/gunpowder/contrib/nodes/add_nonsymmetric_affinities.py +++ b/gunpowder/contrib/nodes/add_nonsymmetric_affinities.py @@ -1,7 +1,5 @@ -import copy import logging import numpy as np -import pdb from gunpowder.array import Array from gunpowder.nodes.batch_filter import BatchFilter diff --git a/gunpowder/contrib/nodes/hdf5_points_source.py b/gunpowder/contrib/nodes/hdf5_points_source.py index a3cd2b44..f78630a1 100644 --- a/gunpowder/contrib/nodes/hdf5_points_source.py +++ b/gunpowder/contrib/nodes/hdf5_points_source.py @@ -5,10 +5,7 @@ from gunpowder.batch import Batch from gunpowder.coordinate import Coordinate from gunpowder.ext import h5py -from gunpowder.graph import GraphKey, Graph -from gunpowder.graph_spec import GraphSpec from gunpowder.profiling import Timing -from gunpowder.roi import Roi from gunpowder.nodes.batch_provider import BatchProvider logger = logging.getLogger(__name__) diff --git a/gunpowder/ext/__init__.py b/gunpowder/ext/__init__.py index fdfcfa02..5b51124d 100644 --- a/gunpowder/ext/__init__.py +++ b/gunpowder/ext/__init__.py @@ -21,73 +21,73 @@ def __getattr__(self, item): try: import dvision -except ImportError as e: +except ImportError: dvision = NoSuchModule("dvision") try: import h5py -except ImportError as e: +except ImportError: h5py = NoSuchModule("h5py") try: import pyklb -except ImportError as e: +except ImportError: pyklb = NoSuchModule("pyklb") try: import tensorflow -except ImportError as e: +except ImportError: tensorflow = NoSuchModule("tensorflow") try: import torch -except ImportError as e: +except ImportError: torch = NoSuchModule("torch") try: import tensorboardX -except ImportError as e: +except ImportError: tensorboardX = NoSuchModule("tensorboardX") try: import malis -except ImportError as e: +except ImportError: malis = NoSuchModule("malis") try: import augment -except ImportError as e: +except ImportError: augment = NoSuchModule("augment") ZarrFile: Optional[Any] = None try: import zarr from .zarr_file import ZarrFile -except ImportError as e: +except ImportError: zarr = NoSuchModule("zarr") ZarrFile = None try: import daisy -except ImportError as e: +except ImportError: daisy = NoSuchModule("daisy") try: import jax -except ImportError as e: +except ImportError: jax = NoSuchModule("jax") try: import jax.numpy as jnp -except ImportError as e: +except ImportError: jnp = NoSuchModule("jnp") try: import haiku -except ImportError as e: +except ImportError: haiku = NoSuchModule("haiku") try: import optax -except ImportError as e: +except ImportError: optax = NoSuchModule("optax") diff --git a/gunpowder/graph.py b/gunpowder/graph.py index 91fdb883..3321c5ac 100644 --- a/gunpowder/graph.py +++ b/gunpowder/graph.py @@ -9,7 +9,6 @@ from typing import Dict, Optional, Set, Iterator, Any import logging import itertools -import warnings logger = logging.getLogger(__name__) @@ -485,7 +484,7 @@ def _roi_intercept( offset = outside - inside distance = np.linalg.norm(offset) - assert not np.isclose(distance, 0), f"Inside and Outside are the same location" + assert not np.isclose(distance, 0), "Inside and Outside are the same location" direction = offset / distance # `offset` can be 0 on some but not all axes leaving a 0 in the denominator. diff --git a/gunpowder/nodes/crop.py b/gunpowder/nodes/crop.py index 3e4cdeb5..0584335e 100644 --- a/gunpowder/nodes/crop.py +++ b/gunpowder/nodes/crop.py @@ -1,4 +1,3 @@ -import copy import logging from .batch_filter import BatchFilter diff --git a/gunpowder/nodes/elastic_augment.py b/gunpowder/nodes/elastic_augment.py index d999d6fe..a4413a44 100644 --- a/gunpowder/nodes/elastic_augment.py +++ b/gunpowder/nodes/elastic_augment.py @@ -9,7 +9,6 @@ from gunpowder.coordinate import Coordinate from gunpowder.ext import augment from gunpowder.roi import Roi -from gunpowder.array import ArrayKey import warnings diff --git a/gunpowder/nodes/exclude_labels.py b/gunpowder/nodes/exclude_labels.py index ae38d43a..2592dc25 100644 --- a/gunpowder/nodes/exclude_labels.py +++ b/gunpowder/nodes/exclude_labels.py @@ -71,7 +71,7 @@ def process(self, batch, request): include_mask[gt.data == label] = 0 # if no ignore mask is provided or requested, we are done - if not self.ignore_mask or not self.ignore_mask in request: + if not self.ignore_mask or self.ignore_mask not in request: return voxel_size = self.spec[self.labels].voxel_size diff --git a/gunpowder/nodes/grow_boundary.py b/gunpowder/nodes/grow_boundary.py index 08d20abf..d793345f 100644 --- a/gunpowder/nodes/grow_boundary.py +++ b/gunpowder/nodes/grow_boundary.py @@ -2,7 +2,6 @@ from scipy import ndimage from .batch_filter import BatchFilter -from gunpowder.array import Array class GrowBoundary(BatchFilter): diff --git a/gunpowder/nodes/klb_source.py b/gunpowder/nodes/klb_source.py index e2a3f758..d4776bba 100644 --- a/gunpowder/nodes/klb_source.py +++ b/gunpowder/nodes/klb_source.py @@ -1,4 +1,3 @@ -import copy import logging import numpy as np import glob diff --git a/gunpowder/nodes/merge_provider.py b/gunpowder/nodes/merge_provider.py index 0d32300e..6df979b8 100644 --- a/gunpowder/nodes/merge_provider.py +++ b/gunpowder/nodes/merge_provider.py @@ -1,4 +1,3 @@ -from gunpowder.provider_spec import ProviderSpec from gunpowder.batch import Batch from gunpowder.batch_request import BatchRequest diff --git a/gunpowder/nodes/pad.py b/gunpowder/nodes/pad.py index f025d9d7..758fd04a 100644 --- a/gunpowder/nodes/pad.py +++ b/gunpowder/nodes/pad.py @@ -7,7 +7,6 @@ from gunpowder.coordinate import Coordinate from gunpowder.batch_request import BatchRequest -from itertools import product logger = logging.getLogger(__name__) diff --git a/gunpowder/nodes/precache.py b/gunpowder/nodes/precache.py index ac35d32f..9c58ae53 100644 --- a/gunpowder/nodes/precache.py +++ b/gunpowder/nodes/precache.py @@ -1,7 +1,4 @@ import logging -import multiprocessing -import time -import random from .batch_filter import BatchFilter from gunpowder.profiling import Timing diff --git a/gunpowder/nodes/rasterize_graph.py b/gunpowder/nodes/rasterize_graph.py index b660dd57..bb2473f6 100644 --- a/gunpowder/nodes/rasterize_graph.py +++ b/gunpowder/nodes/rasterize_graph.py @@ -1,4 +1,3 @@ -import copy import logging import numpy as np from scipy.ndimage.filters import gaussian_filter @@ -12,7 +11,6 @@ from gunpowder.freezable import Freezable from gunpowder.morphology import enlarge_binary_map, create_ball_kernel from gunpowder.ndarray import replace -from gunpowder.graph import GraphKey from gunpowder.graph_spec import GraphSpec from gunpowder.roi import Roi diff --git a/gunpowder/nodes/shift_augment.py b/gunpowder/nodes/shift_augment.py index 8761a563..d42b1434 100644 --- a/gunpowder/nodes/shift_augment.py +++ b/gunpowder/nodes/shift_augment.py @@ -4,7 +4,6 @@ import random from gunpowder.roi import Roi from gunpowder.coordinate import Coordinate -from gunpowder.batch_request import BatchRequest from .batch_filter import BatchFilter diff --git a/gunpowder/nodes/specified_location.py b/gunpowder/nodes/specified_location.py index b209e078..cc5e5844 100644 --- a/gunpowder/nodes/specified_location.py +++ b/gunpowder/nodes/specified_location.py @@ -1,10 +1,8 @@ from random import randrange -from random import choice, seed import logging import numpy as np from gunpowder.coordinate import Coordinate -from gunpowder.batch_request import BatchRequest from .batch_filter import BatchFilter diff --git a/gunpowder/nodes/squeeze.py b/gunpowder/nodes/squeeze.py index 2e999714..d0a1469b 100644 --- a/gunpowder/nodes/squeeze.py +++ b/gunpowder/nodes/squeeze.py @@ -1,4 +1,3 @@ -import copy from typing import List import logging diff --git a/gunpowder/nodes/unsqueeze.py b/gunpowder/nodes/unsqueeze.py index 3df03ec4..9f019978 100644 --- a/gunpowder/nodes/unsqueeze.py +++ b/gunpowder/nodes/unsqueeze.py @@ -1,4 +1,3 @@ -import copy from typing import List import logging diff --git a/gunpowder/producer_pool.py b/gunpowder/producer_pool.py index 0f6c2888..73df6b6d 100644 --- a/gunpowder/producer_pool.py +++ b/gunpowder/producer_pool.py @@ -6,8 +6,6 @@ import multiprocessing import os import sys -import time -import traceback import numpy as np diff --git a/gunpowder/provider_spec.py b/gunpowder/provider_spec.py index 7c34324d..6a1ab818 100644 --- a/gunpowder/provider_spec.py +++ b/gunpowder/provider_spec.py @@ -6,7 +6,6 @@ from gunpowder.graph_spec import GraphSpec from gunpowder.roi import Roi from .freezable import Freezable -import time import logging import copy @@ -14,7 +13,6 @@ import logging -import warnings logger = logging.getLogger(__file__) diff --git a/tests/cases/add_affinities.py b/tests/cases/add_affinities.py index 792be290..bd6ebc93 100644 --- a/tests/cases/add_affinities.py +++ b/tests/cases/add_affinities.py @@ -1,10 +1,6 @@ -from .provider_test import ProviderTest from gunpowder import * from itertools import product -from unittest import skipIf -import itertools import numpy as np -import logging class ExampleSource(BatchProvider): diff --git a/tests/cases/dvid_source.py b/tests/cases/dvid_source.py index 8f2c31e6..ac206909 100644 --- a/tests/cases/dvid_source.py +++ b/tests/cases/dvid_source.py @@ -2,7 +2,6 @@ from unittest import skipIf from gunpowder import * from gunpowder.ext import dvision, NoSuchModule -import numpy as np import socket import logging diff --git a/tests/cases/elastic_augment_points.py b/tests/cases/elastic_augment_points.py index 76e9ae2b..ddb99741 100644 --- a/tests/cases/elastic_augment_points.py +++ b/tests/cases/elastic_augment_points.py @@ -1,4 +1,3 @@ -import unittest from gunpowder import ( BatchProvider, Batch, @@ -25,7 +24,6 @@ import numpy as np import math import time -import unittest class PointTestSource3D(BatchProvider): diff --git a/tests/cases/expected_failures.py b/tests/cases/expected_failures.py index 39a2d21e..8adf48af 100644 --- a/tests/cases/expected_failures.py +++ b/tests/cases/expected_failures.py @@ -2,7 +2,7 @@ from gunpowder.nodes.batch_provider import BatchRequestError from .helper_sources import ArraySource -from funlib.geometry import Roi, Coordinate +from funlib.geometry import Coordinate import numpy as np import pytest diff --git a/tests/cases/intensity_scale_shift.py b/tests/cases/intensity_scale_shift.py index d65df3dd..c64b4ec3 100644 --- a/tests/cases/intensity_scale_shift.py +++ b/tests/cases/intensity_scale_shift.py @@ -3,7 +3,6 @@ IntensityScaleShift, ArrayKey, build, - Normalize, Array, ArraySpec, Roi, diff --git a/tests/cases/jax_train.py b/tests/cases/jax_train.py index 14fbad46..2ff55be6 100644 --- a/tests/cases/jax_train.py +++ b/tests/cases/jax_train.py @@ -4,18 +4,15 @@ BatchRequest, ArraySpec, Roi, - Coordinate, ArrayKeys, ArrayKey, Array, Batch, - Scan, - PreCache, build, ) from gunpowder.ext import jax, haiku, optax, NoSuchModule from gunpowder.jax import Train, Predict, GenericJaxModel -from unittest import skipIf, expectedFailure +from unittest import skipIf import numpy as np import logging diff --git a/tests/cases/noise_augment.py b/tests/cases/noise_augment.py index 6e9f635c..1a2bcf63 100644 --- a/tests/cases/noise_augment.py +++ b/tests/cases/noise_augment.py @@ -1,7 +1,6 @@ from .provider_test import ProviderTest from gunpowder import IntensityAugment, ArrayKeys, build, Normalize, NoiseAugment -import numpy as np class TestIntensityAugment(ProviderTest): diff --git a/tests/cases/random_location.py b/tests/cases/random_location.py index df3f1cca..611289a8 100644 --- a/tests/cases/random_location.py +++ b/tests/cases/random_location.py @@ -3,7 +3,6 @@ BatchProvider, Roi, Coordinate, - ArrayKeys, ArrayKey, ArraySpec, Array, diff --git a/tests/cases/resample.py b/tests/cases/resample.py index d7a057b0..9784b152 100644 --- a/tests/cases/resample.py +++ b/tests/cases/resample.py @@ -5,7 +5,6 @@ ArraySpec, Roi, Coordinate, - Batch, BatchRequest, Array, MergeProvider, diff --git a/tests/cases/simple_augment.py b/tests/cases/simple_augment.py index c77709c0..0696213c 100644 --- a/tests/cases/simple_augment.py +++ b/tests/cases/simple_augment.py @@ -1,6 +1,4 @@ from gunpowder import ( - Batch, - BatchProvider, BatchRequest, Array, ArrayKey, diff --git a/tests/cases/snapshot.py b/tests/cases/snapshot.py index 8dcc4443..928076ea 100644 --- a/tests/cases/snapshot.py +++ b/tests/cases/snapshot.py @@ -4,10 +4,8 @@ GraphSpec, Graph, ArrayKey, - ArrayKeys, ArraySpec, Array, - RasterizeGraph, Snapshot, BatchProvider, BatchRequest, diff --git a/tests/cases/tensorflow_train.py b/tests/cases/tensorflow_train.py index f4eae06e..079be0d3 100644 --- a/tests/cases/tensorflow_train.py +++ b/tests/cases/tensorflow_train.py @@ -11,7 +11,7 @@ build, ) from gunpowder.ext import tensorflow, NoSuchModule -from gunpowder.tensorflow import Train, Predict, LocalServer +from gunpowder.tensorflow import Train import multiprocessing import numpy as np from unittest import skipIf diff --git a/tests/cases/zarr_read_write.py b/tests/cases/zarr_read_write.py index 64303174..c6cdb39b 100644 --- a/tests/cases/zarr_read_write.py +++ b/tests/cases/zarr_read_write.py @@ -1,7 +1,7 @@ from .helper_sources import ArraySource from gunpowder import * -from gunpowder.ext import zarr, ZarrFile, NoSuchModule +from gunpowder.ext import zarr, NoSuchModule import pytest import numpy as np From 00543bfff422e88c0f8d5722e3625d2bb148d436 Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 2 Jan 2024 15:14:31 -0800 Subject: [PATCH 38/74] Custom BatchRequestError handling in pipeline.request_batch We can filter out some more of the excess error traceback that isn't helpful to the readers. --- gunpowder/pipeline.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/gunpowder/pipeline.py b/gunpowder/pipeline.py index f9976c26..3e4bba8f 100644 --- a/gunpowder/pipeline.py +++ b/gunpowder/pipeline.py @@ -1,5 +1,8 @@ -import logging from gunpowder.nodes import BatchProvider +from gunpowder.nodes.batch_provider import BatchRequestError + +import logging +import traceback logger = logging.getLogger(__name__) @@ -21,13 +24,19 @@ def __str__(self): class PipelineRequestError(Exception): - def __init__(self, pipeline, request): + def __init__(self, pipeline, request, original_traceback=None): self.pipeline = pipeline self.request = request + self.original_traceback = original_traceback def __str__(self): return ( - "Exception in pipeline:\n" + ( + ("".join(self.original_traceback) ) + if self.original_traceback is not None + else "" + ) + + "Exception in pipeline:\n" f"{self.pipeline}\n" "while trying to process request\n" f"{self.request}" @@ -123,6 +132,11 @@ def request_batch(self, request): try: return self.output.request_batch(request) + except BatchRequestError as e: + tb = traceback.format_exception(type(e), e, e.__traceback__) + if isinstance(e, BatchRequestError): + tb = tb[-1:] + raise PipelineRequestError(self, request, original_traceback=tb) from None except Exception as e: raise PipelineRequestError(self, request) from e From 431f1068e81079ea288ce6cf887145aa62a95677 Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 2 Jan 2024 20:15:29 -0800 Subject: [PATCH 39/74] black formatting --- gunpowder/pipeline.py | 2 +- tests/cases/noise_augment.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/gunpowder/pipeline.py b/gunpowder/pipeline.py index 3e4bba8f..cad87f1a 100644 --- a/gunpowder/pipeline.py +++ b/gunpowder/pipeline.py @@ -32,7 +32,7 @@ def __init__(self, pipeline, request, original_traceback=None): def __str__(self): return ( ( - ("".join(self.original_traceback) ) + ("".join(self.original_traceback)) if self.original_traceback is not None else "" ) diff --git a/tests/cases/noise_augment.py b/tests/cases/noise_augment.py index 1a2bcf63..57768091 100644 --- a/tests/cases/noise_augment.py +++ b/tests/cases/noise_augment.py @@ -2,7 +2,6 @@ from gunpowder import IntensityAugment, ArrayKeys, build, Normalize, NoiseAugment - class TestIntensityAugment(ProviderTest): def test_shift(self): pipeline = ( From c069592e73021553fd65049c2b1c383229003fdc Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 2 Jan 2024 20:18:25 -0800 Subject: [PATCH 40/74] mypy workflow use dev dependencies --- .github/workflows/mypy.yaml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index b5439db2..074fa9e9 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -15,6 +15,5 @@ jobs: uses: actions/checkout@v2 - name: mypy run: | - pip install . - pip install --upgrade mypy + pip install ".[dev]" mypy gunpowder From ecb55e7452a24cb45ebb3632fe7679f5a4148eb9 Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 2 Jan 2024 20:19:03 -0800 Subject: [PATCH 41/74] avoid testing on python 3.8, it doesn't support typing very well --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2a345506..4a2139c5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11"] platform: [ubuntu-latest] steps: From d7b5673a1d0a55d41529d761af177269ce6ae17f Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 2 Jan 2024 20:19:26 -0800 Subject: [PATCH 42/74] fix type hint for logdir in torch train --- gunpowder/torch/nodes/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gunpowder/torch/nodes/train.py b/gunpowder/torch/nodes/train.py index e913d8ad..676b2c71 100644 --- a/gunpowder/torch/nodes/train.py +++ b/gunpowder/torch/nodes/train.py @@ -99,7 +99,7 @@ def __init__( array_specs: Optional[Dict[ArrayKey, ArraySpec]] = None, checkpoint_basename: str = "model", save_every: int = 2000, - log_dir: str = None, + log_dir: Optional[str] = None, log_every: int = 1, spawn_subprocess: bool = False, device: str = "cuda", From bcfb7d70b2f66ec815bbe266504ac4589a3be744 Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 2 Jan 2024 20:24:02 -0800 Subject: [PATCH 43/74] switch order of decorators to avoid trying to determine if cuda is available if torch isn't installed --- tests/cases/torch_train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/cases/torch_train.py b/tests/cases/torch_train.py index b368f9d6..707e1b91 100644 --- a/tests/cases/torch_train.py +++ b/tests/cases/torch_train.py @@ -68,6 +68,7 @@ def forward(self, a, b): return d_pred +@skipIf(isinstance(torch, NoSuchModule), "torch is not installed") @pytest.mark.parametrize( "device", [ @@ -80,7 +81,6 @@ def forward(self, a, b): ), ], ) -@skipIf(isinstance(torch, NoSuchModule), "torch is not installed") def test_loss_drops(tmpdir, device): checkpoint_basename = str(tmpdir / "model") @@ -143,6 +143,7 @@ def test_loss_drops(tmpdir, device): assert loss2 < loss1 +@skipIf(isinstance(torch, NoSuchModule), "torch is not installed") @pytest.mark.parametrize( "device", [ @@ -160,7 +161,6 @@ def test_loss_drops(tmpdir, device): ), ], ) -@skipIf(isinstance(torch, NoSuchModule), "torch is not installed") def test_output(device): logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) @@ -222,6 +222,7 @@ def forward(self, a): return pred +@skipIf(isinstance(torch, NoSuchModule), "torch is not installed") @pytest.mark.parametrize( "device", [ @@ -239,7 +240,6 @@ def forward(self, a): ), ], ) -@skipIf(isinstance(torch, NoSuchModule), "torch is not installed") def test_scan(device): logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) @@ -275,6 +275,7 @@ def test_scan(device): assert pred in batch +@skipIf(isinstance(torch, NoSuchModule), "torch is not installed") @pytest.mark.parametrize( "device", [ @@ -292,7 +293,6 @@ def test_scan(device): ), ], ) -@skipIf(isinstance(torch, NoSuchModule), "torch is not installed") def test_precache(device): logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) From 4be523e9dafdf0de3cb4a19588ed863ff4c67e06 Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 2 Jan 2024 20:42:05 -0800 Subject: [PATCH 44/74] check if torch is installed before checking if cuda is available --- tests/cases/torch_train.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/tests/cases/torch_train.py b/tests/cases/torch_train.py index 707e1b91..0196c67d 100644 --- a/tests/cases/torch_train.py +++ b/tests/cases/torch_train.py @@ -18,6 +18,8 @@ import logging +TORCH_AVAILABLE = isinstance(torch, NoSuchModule) + # Example 2D source def example_2d_source(array_key: ArrayKey): @@ -52,7 +54,7 @@ def example_train_source(a_key, b_key, c_key): return (source_a, source_b, source_c) + MergeProvider() -if not isinstance(torch, NoSuchModule): +if not TORCH_AVAILABLE: class ExampleLinearModel(torch.nn.Module): def __init__(self): @@ -68,7 +70,7 @@ def forward(self, a, b): return d_pred -@skipIf(isinstance(torch, NoSuchModule), "torch is not installed") +@skipIf(TORCH_AVAILABLE, "torch is not installed") @pytest.mark.parametrize( "device", [ @@ -76,7 +78,8 @@ def forward(self, a, b): pytest.param( "cuda:0", marks=pytest.mark.skipif( - not torch.cuda.is_available(), reason="CUDA not available" + TORCH_AVAILABLE or not torch.cuda.is_available(), + reason="CUDA not available", ), ), ], @@ -143,7 +146,7 @@ def test_loss_drops(tmpdir, device): assert loss2 < loss1 -@skipIf(isinstance(torch, NoSuchModule), "torch is not installed") +@skipIf(TORCH_AVAILABLE, "torch is not installed") @pytest.mark.parametrize( "device", [ @@ -152,7 +155,8 @@ def test_loss_drops(tmpdir, device): "cuda:0", marks=[ pytest.mark.skipif( - not torch.cuda.is_available(), reason="CUDA not available" + TORCH_AVAILABLE or not torch.cuda.is_available(), + reason="CUDA not available", ), pytest.mark.xfail( reason="failing to move model to device when using a subprocess" @@ -207,7 +211,7 @@ def test_output(device): assert np.isclose(batch2[d_pred].data, 2 * (1 + 4 * 2 + 9 * 3)) -if not isinstance(torch, NoSuchModule): +if not TORCH_AVAILABLE: class Example2DModel(torch.nn.Module): def __init__(self): @@ -222,7 +226,7 @@ def forward(self, a): return pred -@skipIf(isinstance(torch, NoSuchModule), "torch is not installed") +@skipIf(TORCH_AVAILABLE, "torch is not installed") @pytest.mark.parametrize( "device", [ @@ -231,7 +235,8 @@ def forward(self, a): "cuda:0", marks=[ pytest.mark.skipif( - not torch.cuda.is_available(), reason="CUDA not available" + TORCH_AVAILABLE or not torch.cuda.is_available(), + reason="CUDA not available", ), pytest.mark.xfail( reason="failing to move model to device in multiprocessing context" @@ -275,7 +280,7 @@ def test_scan(device): assert pred in batch -@skipIf(isinstance(torch, NoSuchModule), "torch is not installed") +@skipIf(TORCH_AVAILABLE, "torch is not installed") @pytest.mark.parametrize( "device", [ @@ -284,7 +289,8 @@ def test_scan(device): "cuda:0", marks=[ pytest.mark.skipif( - not torch.cuda.is_available(), reason="CUDA not available" + TORCH_AVAILABLE or not torch.cuda.is_available(), + reason="CUDA not available", ), pytest.mark.xfail( reason="failing to move model to device in multiprocessing context" From cceeb41422b4c5d31ac35e0d303dcb635f133ce6 Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 2 Jan 2024 21:20:49 -0800 Subject: [PATCH 45/74] update funlib.geometry version for mypy typing --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a389a33c..01441435 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ "requests", "augment-nd>=0.1.3", "tqdm", - "funlib.geometry>=0.2", + "funlib.geometry>=0.3", "zarr", "networkx>=3.1", ] From ed445f1ecbe0a3cd93d2ee331e55371291a23ac2 Mon Sep 17 00:00:00 2001 From: William Patton Date: Tue, 2 Jan 2024 22:12:05 -0800 Subject: [PATCH 46/74] remove batch.id replaced in tensorflow predict node debug statments with the request. This better indicates the roi being predicted on. replaced in snapshot node with an internal counter --- gunpowder/batch.py | 10 ---------- gunpowder/nodes/snapshot.py | 6 ++++-- gunpowder/tensorflow/nodes/predict.py | 6 +++--- 3 files changed, 7 insertions(+), 15 deletions(-) diff --git a/gunpowder/batch.py b/gunpowder/batch.py index ffc97e77..1ddf200c 100644 --- a/gunpowder/batch.py +++ b/gunpowder/batch.py @@ -44,17 +44,7 @@ class Batch(Freezable): Contains all graphs that have been requested for this batch. """ - __next_id = multiprocessing.Value("L") - - @staticmethod - def get_next_id(): - with Batch.__next_id.get_lock(): - next_id = Batch.__next_id.value - Batch.__next_id.value += 1 - return next_id - def __init__(self): - self.id = Batch.get_next_id() self.profiling_stats = ProfilingStats() self.arrays = {} self.graphs = {} diff --git a/gunpowder/nodes/snapshot.py b/gunpowder/nodes/snapshot.py index 8a9a453e..78ba8f0c 100644 --- a/gunpowder/nodes/snapshot.py +++ b/gunpowder/nodes/snapshot.py @@ -74,7 +74,7 @@ def __init__( self, dataset_names, output_dir="snapshots", - output_filename="{id}.zarr", + output_filename="{iteration}.zarr", every=1, additional_request=None, compression_type=None, @@ -97,6 +97,7 @@ def __init__( self.dataset_dtypes = dataset_dtypes self.mode = "w" + self.id = 0 def write_if(self, batch): """To be implemented in subclasses. @@ -157,6 +158,7 @@ def prepare(self, request): return deps def process(self, batch, request): + self.id += 1 if self.record_snapshot and self.write_if(batch): try: os.makedirs(self.output_dir) @@ -166,7 +168,7 @@ def process(self, batch, request): snapshot_name = os.path.join( self.output_dir, self.output_filename.format( - id=str(batch.id).zfill(8), iteration=int(batch.iteration or 0) + id=str(self.id).zfill(8), iteration=int(batch.iteration or self.id) ), ) logger.info("saving to %s" % snapshot_name) diff --git a/gunpowder/tensorflow/nodes/predict.py b/gunpowder/tensorflow/nodes/predict.py index 0a92a0f6..d2c03498 100644 --- a/gunpowder/tensorflow/nodes/predict.py +++ b/gunpowder/tensorflow/nodes/predict.py @@ -112,7 +112,7 @@ def predict(self, batch, request): break if can_skip: - logger.info("Skipping batch %i (all inputs are 0)" % batch.id) + logger.info(f"Skipping batch for request: {request} (all inputs are 0)") for name, array_key in self.outputs.items(): shape = self.shared_output_arrays[name].shape @@ -124,7 +124,7 @@ def predict(self, batch, request): return - logger.debug("predicting in batch %i", batch.id) + logger.debug(f"predicting for request: {request}") output_tensors = self.__collect_outputs(request) input_data = self.__collect_provided_inputs(batch) @@ -160,7 +160,7 @@ def predict(self, batch, request): spec.roi = request[array_key].roi batch.arrays[array_key] = Array(output_data[array_key], spec) - logger.debug("predicted in batch %i", batch.id) + logger.debug("predicted") def __predict(self): """The background predict process.""" From 23bde3effaa55bc062cdaead97df5bd911391532 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Wed, 15 May 2024 11:50:15 +0200 Subject: [PATCH 47/74] Provide the separator to the csv points source --- gunpowder/nodes/csv_points_source.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/gunpowder/nodes/csv_points_source.py b/gunpowder/nodes/csv_points_source.py index 59e6c193..1410b9ba 100644 --- a/gunpowder/nodes/csv_points_source.py +++ b/gunpowder/nodes/csv_points_source.py @@ -47,10 +47,14 @@ class CsvPointsSource(BatchProvider): Each line may optionally contain an id for each point. This parameter specifies its location, has to come after the position values. + + sep (``str``): + + Separator in the csv file. Defaults to None """ def __init__( - self, filename, points, points_spec=None, scale=None, ndims=None, id_dim=None + self, filename, points, points_spec=None, scale=None, ndims=None, id_dim=None, sep=None ): self.filename = filename self.points = points @@ -59,6 +63,7 @@ def __init__( self.ndims = ndims self.id_dim = id_dim self.data = None + self.sep = sep def setup(self): self._parse_csv() @@ -117,13 +122,13 @@ def _parse_csv(self): """ with open(self.filename, "r") as f: - self.data = np.array( - [[float(t.strip(",")) for t in line.split()] for line in f], - dtype=np.float32, - ) - - if self.ndims is None: - self.ndims = self.data.shape[1] - + data = [] + for line in f: + split = line.split(self.sep) + if self.ndims is not None: + split = split[0:self.ndims] + data.append(list(map(float, split))) + self.data = np.array(data, dtype=np.float32) + if self.scale is not None: self.data[:, : self.ndims] *= self.scale From 177abf07c37d4c5a9840d1aac5f42ed039968686 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Thu, 16 May 2024 16:43:38 +0200 Subject: [PATCH 48/74] Use csv reader in csv points source --- gunpowder/nodes/csv_points_source.py | 94 +++++++++++++++------------- tests/cases/shift_augment.py | 4 +- 2 files changed, 55 insertions(+), 43 deletions(-) diff --git a/gunpowder/nodes/csv_points_source.py b/gunpowder/nodes/csv_points_source.py index 1410b9ba..3e397baa 100644 --- a/gunpowder/nodes/csv_points_source.py +++ b/gunpowder/nodes/csv_points_source.py @@ -1,19 +1,23 @@ +from typing import Union, Optional import numpy as np import logging from gunpowder.batch import Batch from gunpowder.coordinate import Coordinate from gunpowder.nodes.batch_provider import BatchProvider -from gunpowder.graph import Node, Graph +from gunpowder.graph import Node, Graph, GraphKey from gunpowder.graph_spec import GraphSpec from gunpowder.profiling import Timing from gunpowder.roi import Roi +import csv logger = logging.getLogger(__name__) class CsvPointsSource(BatchProvider): """Read a set of points from a comma-separated-values text file. Each line - in the file represents one point, e.g. z y x (id) + in the file represents one point, e.g. z y x (id). Note: this reads all + points into memory and finds the ones in the given roi by iterating + over all the points. For large datasets, this may be too slow. Args: @@ -25,6 +29,11 @@ class CsvPointsSource(BatchProvider): The key of the points set to create. + spatial_cols (list[``int``]): + + The columns of the csv that hold the coordinates of the points + (in the order that you want them to be used in training) + points_spec (:class:`GraphSpec`, optional): An optional :class:`GraphSpec` to overwrite the points specs @@ -37,33 +46,35 @@ class CsvPointsSource(BatchProvider): from the CSV file. This is useful if the points refer to voxel positions to convert them to world units. - ndims (``int``): + id_col (``int``, optional): - If ``ndims`` is None, all values in one line are considered as the - location of the point. If positive, only the first ``ndims`` are used. - If negative, all but the last ``-ndims`` are used. + The column of the csv that holds an id for each point. If not + provided, the index of the rows are used as the ids. - id_dim (``int``): + delimiter (``str``, optional): - Each line may optionally contain an id for each point. This parameter - specifies its location, has to come after the position values. - - sep (``str``): - - Separator in the csv file. Defaults to None + Delimiter to pass to the csv reader. Defaults to ",". """ def __init__( - self, filename, points, points_spec=None, scale=None, ndims=None, id_dim=None, sep=None + self, + filename: str, + points: GraphKey, + spatial_cols: list[int], + points_spec: Optional[GraphSpec] = None, + scale: Optional[Union[int, float, tuple, list, np.ndarray]] = None, + id_col: Optional[int] = None, + delimiter: str = ",", ): self.filename = filename self.points = points self.points_spec = points_spec self.scale = scale - self.ndims = ndims - self.id_dim = id_dim - self.data = None - self.sep = sep + self.spatial_cols = spatial_cols + self.id_dim = id_col + self.delimiter = delimiter + self.data: Optional[np.ndarray] = None + self.ids: Optional[list] = None def setup(self): self._parse_csv() @@ -72,8 +83,8 @@ def setup(self): self.provides(self.points, self.points_spec) return - min_bb = Coordinate(np.floor(np.amin(self.data[:, : self.ndims], 0))) - max_bb = Coordinate(np.ceil(np.amax(self.data[:, : self.ndims], 0)) + 1) + min_bb = Coordinate(np.floor(np.amin(self.data, 0))) + max_bb = Coordinate(np.ceil(np.amax(self.data, 0)) + 1) roi = Roi(min_bb, max_bb - min_bb) @@ -89,7 +100,7 @@ def provide(self, request): logger.debug("CSV points source got request for %s", request[self.points].roi) point_filter = np.ones((self.data.shape[0],), dtype=bool) - for d in range(self.ndims): + for d in range(len(self.spatial_cols)): point_filter = np.logical_and(point_filter, self.data[:, d] >= min_bb[d]) point_filter = np.logical_and(point_filter, self.data[:, d] < max_bb[d]) @@ -105,30 +116,29 @@ def provide(self, request): return batch def _get_points(self, point_filter): - filtered = self.data[point_filter][:, : self.ndims] - - if self.id_dim is not None: - ids = self.data[point_filter][:, self.id_dim] - else: - ids = np.arange(len(self.data))[point_filter] - + filtered = self.data[point_filter] + ids = self.ids[point_filter] return [Node(id=i, location=p) for i, p in zip(ids, filtered)] def _parse_csv(self): - """Read one point per line. If ``ndims`` is None, all values in one line - are considered as the location of the point. If positive, only the - first ``ndims`` are used. If negative, all but the last ``-ndims`` are - used. + """Read one point per line, with spatial and id columns determined by + self.spatial_cols and self.id_col. """ + data = [] + ids = [] + with open(self.filename, "r", newline="") as f: + reader = csv.reader(f, delimiter=self.delimiter) + for line in reader: + space = [line[c] for c in self.spatial_cols] + if self.id_dim is not None: + ids.append(line[self.id_dim]) + data.append(list(map(float, space))) + + self.data = np.array(data, dtype=np.float32) + if self.id_dim: + self.ids = np.array(data) + else: + self.ids = np.arange(len(self.data)) - with open(self.filename, "r") as f: - data = [] - for line in f: - split = line.split(self.sep) - if self.ndims is not None: - split = split[0:self.ndims] - data.append(list(map(float, split))) - self.data = np.array(data, dtype=np.float32) - if self.scale is not None: - self.data[:, : self.ndims] *= self.scale + self.data *= self.scale diff --git a/tests/cases/shift_augment.py b/tests/cases/shift_augment.py index 75ab40e3..f35862b6 100644 --- a/tests/cases/shift_augment.py +++ b/tests/cases/shift_augment.py @@ -143,7 +143,9 @@ def test_pipeline3(test_points): csv_source = CsvPointsSource( fake_points_file, points_key, - GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))), + spatial_cols=[0,1,], + delimiter="\t", + points_spec=GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))), ) request = BatchRequest() From 0bb86c10a9244363baa4b56fa48ccf88f86cc1f2 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Thu, 16 May 2024 16:45:23 +0200 Subject: [PATCH 49/74] Test csv points source with new dev dependencies --- pyproject.toml | 12 ++++- tests/cases/csv_points_source.py | 75 ++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 tests/cases/csv_points_source.py diff --git a/pyproject.toml b/pyproject.toml index 01441435..0a05c01c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,17 @@ dependencies = [ ] [project.optional-dependencies] -dev = ["pytest", "pytest-cov", "flake8", "mypy", "types-requests", "types-tqdm"] +dev = [ + "pytest", + "pytest-cov", + "pyest_unordered", + "flake8", + "mypy", + "types-requests", + "types-tqdm", + "black", + "ruff", +] docs = [ "sphinx", "sphinx_rtd_theme", diff --git a/tests/cases/csv_points_source.py b/tests/cases/csv_points_source.py new file mode 100644 index 00000000..f53950b1 --- /dev/null +++ b/tests/cases/csv_points_source.py @@ -0,0 +1,75 @@ +import random + +import numpy as np +import pytest +from pytest_unordered import unordered +import unittest + +from gunpowder import ( + ArraySpec, + BatchRequest, + CsvPointsSource, + GraphKey, + GraphSpec, + build, + Coordinate, + Roi +) + +# automatically set the seed for all tests +@pytest.fixture(autouse=True) +def seeds(): + random.seed(12345) + np.random.seed(12345) + + +@pytest.fixture +def test_points(tmpdir): + random.seed(1234) + np.random.seed(1234) + + fake_points_file = tmpdir / "shift_test.csv" + fake_points = np.random.randint(0, 100, size=(2, 2)) + with open(fake_points_file, "w") as f: + for point in fake_points: + f.write(str(point[0]) + "\t" + str(point[1]) + "\n") + + # This fixture will run after seeds since it is set + # with autouse=True. So make sure to reset the seeds properly at the end + # of this fixture + random.seed(12345) + np.random.seed(12345) + + yield fake_points_file, fake_points + + +def test_pipeline3(test_points): + fake_points_file, fake_points = test_points + + points_key = GraphKey("TEST_POINTS") + voxel_size = Coordinate((1, 1)) + spec = ArraySpec(voxel_size=voxel_size, interpolatable=True) + + csv_source = CsvPointsSource( + fake_points_file, + points_key, + spatial_cols=[0,1,], + delimiter="\t", + points_spec=GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))), + ) + + request = BatchRequest() + shape = Coordinate((100, 100)) + request.add(points_key, shape) + + pipeline = ( + csv_source + ) + with build(pipeline) as b: + request = b.request_batch(request) + + target_locs = [list(fake_point) for fake_point in fake_points] + result_points = list(request[points_key].nodes) + result_locs = [list(point.location) for point in result_points] + + assert result_locs == unordered(target_locs) \ No newline at end of file From 7d57f92a28f64b442a840356bdc812b7e6ef0b0c Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Thu, 16 May 2024 16:45:45 +0200 Subject: [PATCH 50/74] Update required python to 3.9 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0a05c01c..945003c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dynamic = ["version"] classifiers = ["Programming Language :: Python :: 3"] keywords = [] -requires-python = ">=3.7" +requires-python = ">=3.9" dependencies = [ "numpy>=1.24", From a2cdf719f710ce0e7c713872333ea2ec80b1d7e8 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Thu, 16 May 2024 16:48:41 +0200 Subject: [PATCH 51/74] Black test cases --- tests/cases/csv_points_source.py | 16 +++++++++------- tests/cases/shift_augment.py | 5 ++++- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/tests/cases/csv_points_source.py b/tests/cases/csv_points_source.py index f53950b1..90c67cf9 100644 --- a/tests/cases/csv_points_source.py +++ b/tests/cases/csv_points_source.py @@ -13,9 +13,10 @@ GraphSpec, build, Coordinate, - Roi + Roi, ) + # automatically set the seed for all tests @pytest.fixture(autouse=True) def seeds(): @@ -53,7 +54,10 @@ def test_pipeline3(test_points): csv_source = CsvPointsSource( fake_points_file, points_key, - spatial_cols=[0,1,], + spatial_cols=[ + 0, + 1, + ], delimiter="\t", points_spec=GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))), ) @@ -62,14 +66,12 @@ def test_pipeline3(test_points): shape = Coordinate((100, 100)) request.add(points_key, shape) - pipeline = ( - csv_source - ) + pipeline = csv_source with build(pipeline) as b: request = b.request_batch(request) target_locs = [list(fake_point) for fake_point in fake_points] result_points = list(request[points_key].nodes) result_locs = [list(point.location) for point in result_points] - - assert result_locs == unordered(target_locs) \ No newline at end of file + + assert result_locs == unordered(target_locs) diff --git a/tests/cases/shift_augment.py b/tests/cases/shift_augment.py index f35862b6..b92c71f9 100644 --- a/tests/cases/shift_augment.py +++ b/tests/cases/shift_augment.py @@ -143,7 +143,10 @@ def test_pipeline3(test_points): csv_source = CsvPointsSource( fake_points_file, points_key, - spatial_cols=[0,1,], + spatial_cols=[ + 0, + 1, + ], delimiter="\t", points_spec=GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))), ) From 807e68244eb35d358788f7faa2c8419707b913d6 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Thu, 16 May 2024 16:50:46 +0200 Subject: [PATCH 52/74] Fix typos in pytest unordered dependency --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 945003c0..a3861d0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ dependencies = [ dev = [ "pytest", "pytest-cov", - "pyest_unordered", + "pytest-unordered", "flake8", "mypy", "types-requests", From 1b1a6983d7a2918e0f14c097bbf8a433cf7ce5eb Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Thu, 16 May 2024 17:02:12 +0200 Subject: [PATCH 53/74] Remove pytest unordered dependency --- pyproject.toml | 1 - tests/cases/csv_points_source.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a3861d0a..81789f6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,6 @@ dependencies = [ dev = [ "pytest", "pytest-cov", - "pytest-unordered", "flake8", "mypy", "types-requests", diff --git a/tests/cases/csv_points_source.py b/tests/cases/csv_points_source.py index 90c67cf9..95969213 100644 --- a/tests/cases/csv_points_source.py +++ b/tests/cases/csv_points_source.py @@ -2,7 +2,6 @@ import numpy as np import pytest -from pytest_unordered import unordered import unittest from gunpowder import ( @@ -74,4 +73,4 @@ def test_pipeline3(test_points): result_points = list(request[points_key].nodes) result_locs = [list(point.location) for point in result_points] - assert result_locs == unordered(target_locs) + assert sorted(result_locs) == sorted(target_locs) From 011e5ef7a49583aa4cf972926cc54d74ba3c543c Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Thu, 16 May 2024 17:04:33 +0200 Subject: [PATCH 54/74] Black and ruff CSVPointsSource tests --- tests/cases/csv_points_source.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tests/cases/csv_points_source.py b/tests/cases/csv_points_source.py index 95969213..1fec0833 100644 --- a/tests/cases/csv_points_source.py +++ b/tests/cases/csv_points_source.py @@ -2,10 +2,8 @@ import numpy as np import pytest -import unittest from gunpowder import ( - ArraySpec, BatchRequest, CsvPointsSource, GraphKey, @@ -47,16 +45,11 @@ def test_pipeline3(test_points): fake_points_file, fake_points = test_points points_key = GraphKey("TEST_POINTS") - voxel_size = Coordinate((1, 1)) - spec = ArraySpec(voxel_size=voxel_size, interpolatable=True) csv_source = CsvPointsSource( fake_points_file, points_key, - spatial_cols=[ - 0, - 1, - ], + spatial_cols=[0, 1], delimiter="\t", points_spec=GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))), ) From 6b8bfeaae5a32759f2eaec811ff7b59e355d7048 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Thu, 16 May 2024 17:39:05 +0200 Subject: [PATCH 55/74] Correctly read and document ids in CsvPointsSource --- gunpowder/nodes/csv_points_source.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/gunpowder/nodes/csv_points_source.py b/gunpowder/nodes/csv_points_source.py index 3e397baa..4f5d2516 100644 --- a/gunpowder/nodes/csv_points_source.py +++ b/gunpowder/nodes/csv_points_source.py @@ -49,7 +49,8 @@ class CsvPointsSource(BatchProvider): id_col (``int``, optional): The column of the csv that holds an id for each point. If not - provided, the index of the rows are used as the ids. + provided, the index of the rows are used as the ids. When read + from file, ids are left as strings and not cast to anything. delimiter (``str``, optional): @@ -136,7 +137,7 @@ def _parse_csv(self): self.data = np.array(data, dtype=np.float32) if self.id_dim: - self.ids = np.array(data) + self.ids = np.array(ids) else: self.ids = np.arange(len(self.data)) From 4c54298e84e83c0ba53e30d17479f698b6128f59 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Thu, 16 May 2024 17:39:43 +0200 Subject: [PATCH 56/74] Automatically detect header in CSVPointsSource --- gunpowder/nodes/csv_points_source.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/gunpowder/nodes/csv_points_source.py b/gunpowder/nodes/csv_points_source.py index 4f5d2516..7f40df32 100644 --- a/gunpowder/nodes/csv_points_source.py +++ b/gunpowder/nodes/csv_points_source.py @@ -128,12 +128,18 @@ def _parse_csv(self): data = [] ids = [] with open(self.filename, "r", newline="") as f: + has_header = csv.Sniffer().has_header(f.read(1024)) + f.seek(0) + first_line = True reader = csv.reader(f, delimiter=self.delimiter) for line in reader: - space = [line[c] for c in self.spatial_cols] + if first_line and has_header: + first_line = False + continue + space = [float(line[c]) for c in self.spatial_cols] + data.append(space) if self.id_dim is not None: ids.append(line[self.id_dim]) - data.append(list(map(float, space))) self.data = np.array(data, dtype=np.float32) if self.id_dim: From 18409b755d72340da145e0f81f00c11431261b69 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Thu, 16 May 2024 17:40:04 +0200 Subject: [PATCH 57/74] Test all CSVPointsSource functionality --- tests/cases/csv_points_source.py | 62 ++++++++++++++++++++++++++++++-- 1 file changed, 59 insertions(+), 3 deletions(-) diff --git a/tests/cases/csv_points_source.py b/tests/cases/csv_points_source.py index 1fec0833..66ba613a 100644 --- a/tests/cases/csv_points_source.py +++ b/tests/cases/csv_points_source.py @@ -2,6 +2,7 @@ import numpy as np import pytest +import csv from gunpowder import ( BatchRequest, @@ -22,7 +23,7 @@ def seeds(): @pytest.fixture -def test_points(tmpdir): +def test_points_2d(tmpdir): random.seed(1234) np.random.seed(1234) @@ -41,8 +42,31 @@ def test_points(tmpdir): yield fake_points_file, fake_points -def test_pipeline3(test_points): - fake_points_file, fake_points = test_points +@pytest.fixture +def test_points_3d(tmpdir): + random.seed(1234) + np.random.seed(1234) + + fake_points_file = tmpdir / "shift_test.csv" + fake_points = np.random.randint(0, 100, size=(3, 3)).astype(float) + with open(fake_points_file, "w") as f: + writer = csv.DictWriter(f, fieldnames=["x", "y", "z", "id"]) + writer.writeheader() + for i, point in enumerate(fake_points): + pointdict = {"x": point[0], "y": point[1], "z": point[2], "id": i} + writer.writerow(pointdict) + + # This fixture will run after seeds since it is set + # with autouse=True. So make sure to reset the seeds properly at the end + # of this fixture + random.seed(12345) + np.random.seed(12345) + + yield fake_points_file, fake_points + + +def test_pipeline_2d(test_points_2d): + fake_points_file, fake_points = test_points_2d points_key = GraphKey("TEST_POINTS") @@ -67,3 +91,35 @@ def test_pipeline3(test_points): result_locs = [list(point.location) for point in result_points] assert sorted(result_locs) == sorted(target_locs) + + +def test_pipeline_3d(test_points_3d): + fake_points_file, fake_points = test_points_3d + + points_key = GraphKey("TEST_POINTS") + scale = 2 + csv_source = CsvPointsSource( + fake_points_file, + points_key, + spatial_cols=[0, 2, 1], + delimiter=",", + id_col=3, + points_spec=GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))), + scale=scale, + ) + + request = BatchRequest() + shape = Coordinate((100, 100, 100)) + request.add(points_key, shape) + + pipeline = csv_source + with build(pipeline) as b: + request = b.request_batch(request) + + result_points = list(request[points_key].nodes) + for node in result_points: + orig_loc = fake_points[int(node.id)] + reordered_loc = orig_loc.copy() + reordered_loc[1] = orig_loc[2] + reordered_loc[2] = orig_loc[1] + assert list(node.location) == list(reordered_loc * scale) From 832e475cc6da6f3e3901a0120e038ef8da5f9835 Mon Sep 17 00:00:00 2001 From: pattonw Date: Thu, 30 May 2024 09:57:53 -0700 Subject: [PATCH 58/74] add support for args as inputs to predict.py Its often not so straightforward to know the key word argument name for the forward function of your model. Especially if you use something like `torch.nn.Sequential` --- gunpowder/torch/nodes/predict.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/gunpowder/torch/nodes/predict.py b/gunpowder/torch/nodes/predict.py index 89c9ac0c..3bc58344 100644 --- a/gunpowder/torch/nodes/predict.py +++ b/gunpowder/torch/nodes/predict.py @@ -18,10 +18,10 @@ class Predict(GenericPredict): The model to use for prediction. - inputs (``dict``, ``string`` -> :class:`ArrayKey`): + inputs (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`): - Dictionary from the names of input tensors (argument names of the - ``forward`` method) in the model to array keys. + Dictionary from the position (for args) and names (for kwargs) of input + tensors (argument names of the ``forward`` method) in the model to array keys. outputs (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`): @@ -58,7 +58,7 @@ class Predict(GenericPredict): def __init__( self, model, - inputs: Dict[str, ArrayKey], + inputs: Dict[Union[str, int], ArrayKey], outputs: Dict[Union[str, int], ArrayKey], array_specs: Optional[Dict[ArrayKey, ArraySpec]] = None, checkpoint: Optional[str] = None, @@ -111,18 +111,24 @@ def start(self): self.register_hooks() def predict(self, batch, request): - inputs = self.get_inputs(batch) + input_args, input_kwargs = self.get_inputs(batch) with torch.no_grad(): - out = self.model.forward(**inputs) + out = self.model.forward(*input_args, **input_kwargs) outputs = self.get_outputs(out, request) self.update_batch(batch, request, outputs) def get_inputs(self, batch): - model_inputs = { + model_args = [ + torch.as_tensor(batch[self.inputs[ii]].data, device=self.device) + for ii in range(len(self.inputs)) + if ii in self.inputs + ] + model_kwargs = { key: torch.as_tensor(batch[value].data, device=self.device) for key, value in self.inputs.items() + if isinstance(key, str) } - return model_inputs + return model_args, model_kwargs def register_hooks(self): for key in self.outputs: From 9bd61281da055090bc853fb85e2ad10647bcb05b Mon Sep 17 00:00:00 2001 From: William Patton Date: Thu, 13 Jun 2024 17:26:30 -0700 Subject: [PATCH 59/74] black reformat pad.py test --- tests/cases/pad.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/cases/pad.py b/tests/cases/pad.py index e2c1b7cd..a20861ec 100644 --- a/tests/cases/pad.py +++ b/tests/cases/pad.py @@ -23,6 +23,7 @@ from itertools import product + @pytest.mark.parametrize("mode", ["constant", "reflect"]) def test_padding(mode): array_key = ArrayKey("TEST_ARRAY") From fcdee74e486a41cf0594496fefe7540d056a9150 Mon Sep 17 00:00:00 2001 From: William Patton Date: Thu, 13 Jun 2024 17:41:35 -0700 Subject: [PATCH 60/74] remove excessive seed setting. I don't think this is necessary since as soon as the seeds are set, the rest of the tests are determanistic --- tests/cases/csv_points_source.py | 16 ---------------- tests/cases/shift_augment.py | 8 -------- 2 files changed, 24 deletions(-) diff --git a/tests/cases/csv_points_source.py b/tests/cases/csv_points_source.py index 66ba613a..ec8d88ec 100644 --- a/tests/cases/csv_points_source.py +++ b/tests/cases/csv_points_source.py @@ -24,8 +24,6 @@ def seeds(): @pytest.fixture def test_points_2d(tmpdir): - random.seed(1234) - np.random.seed(1234) fake_points_file = tmpdir / "shift_test.csv" fake_points = np.random.randint(0, 100, size=(2, 2)) @@ -33,19 +31,11 @@ def test_points_2d(tmpdir): for point in fake_points: f.write(str(point[0]) + "\t" + str(point[1]) + "\n") - # This fixture will run after seeds since it is set - # with autouse=True. So make sure to reset the seeds properly at the end - # of this fixture - random.seed(12345) - np.random.seed(12345) - yield fake_points_file, fake_points @pytest.fixture def test_points_3d(tmpdir): - random.seed(1234) - np.random.seed(1234) fake_points_file = tmpdir / "shift_test.csv" fake_points = np.random.randint(0, 100, size=(3, 3)).astype(float) @@ -56,12 +46,6 @@ def test_points_3d(tmpdir): pointdict = {"x": point[0], "y": point[1], "z": point[2], "id": i} writer.writerow(pointdict) - # This fixture will run after seeds since it is set - # with autouse=True. So make sure to reset the seeds properly at the end - # of this fixture - random.seed(12345) - np.random.seed(12345) - yield fake_points_file, fake_points diff --git a/tests/cases/shift_augment.py b/tests/cases/shift_augment.py index b92c71f9..b53b9b4a 100644 --- a/tests/cases/shift_augment.py +++ b/tests/cases/shift_augment.py @@ -33,8 +33,6 @@ def seeds(): @pytest.fixture def test_points(tmpdir): - random.seed(1234) - np.random.seed(1234) fake_points_file = tmpdir / "shift_test.csv" fake_data_file = tmpdir / "shift_test.hdf5" @@ -46,12 +44,6 @@ def test_points(tmpdir): for point in fake_points: f.write(str(point[0]) + "\t" + str(point[1]) + "\n") - # This fixture will run after seeds since it is set - # with autouse=True. So make sure to reset the seeds properly at the end - # of this fixture - random.seed(12345) - np.random.seed(12345) - yield fake_points_file, fake_data_file, fake_points, fake_data From 89a0354861ccd72d8cdbca5106d8b7bb3dbd1d49 Mon Sep 17 00:00:00 2001 From: pattonw Date: Thu, 13 Jun 2024 17:55:05 -0700 Subject: [PATCH 61/74] Pytorch Train: let users specify model inputs as args instead of kwargs --- gunpowder/torch/nodes/train.py | 45 ++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/gunpowder/torch/nodes/train.py b/gunpowder/torch/nodes/train.py index 676b2c71..d5be0f9a 100644 --- a/gunpowder/torch/nodes/train.py +++ b/gunpowder/torch/nodes/train.py @@ -29,7 +29,7 @@ class Train(GenericTrain): The torch optimizer to use. - inputs (``dict``, ``string`` -> :class:`ArrayKey`): + inputs (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`): Dictionary from the names of input tensors (argument names of the ``forward`` method) in the model to array keys. @@ -92,7 +92,7 @@ def __init__( model, loss, optimizer, - inputs: Dict[str, ArrayKey], + inputs: Dict[Union[str, int], ArrayKey], outputs: Dict[Union[int, str], ArrayKey], loss_inputs: Dict[Union[int, str], ArrayKey], gradients: Dict[Union[int, str], ArrayKey] = {}, @@ -112,11 +112,11 @@ def __init__( # not yet implemented gradients = gradients - all_inputs = { - k: v - for k, v in itertools.chain(inputs.items(), loss_inputs.items()) - if v not in outputs.values() - } + loss_inputs = {f"loss_{k}": v for k, v in loss_inputs.items()} + all_inputs = {f"{k}": v for k, v in inputs.items() if v not in outputs.values()} + all_inputs.update( + {k: v for k, v in loss_inputs.items() if v not in outputs.values()} + ) super(Train, self).__init__( all_inputs, @@ -208,16 +208,22 @@ def start(self): def train_step(self, batch, request): inputs = self.__collect_provided_inputs(batch) + inputs = {k: torch.as_tensor(v, device=self.device) for k, v in inputs.items()} requested_outputs = self.__collect_requested_outputs(request) # keys are argument names of model forward pass - device_inputs = { - k: torch.as_tensor(v, device=self.device) for k, v in inputs.items() - } + device_input_args = [] + for i in range(len(inputs)): + key = f"{i}" + if key in inputs: + device_input_args.append(inputs.pop(key)) + else: + break + device_input_kwargs = {k: v for k, v in inputs.items() if isinstance(k, str)} # get outputs. Keys are tuple indices or model attr names as in self.outputs self.optimizer.zero_grad() - model_outputs = self.model(**device_inputs) + model_outputs = self.model(*device_input_args, **device_input_kwargs) if isinstance(model_outputs, tuple): outputs = {i: model_outputs[i] for i in range(len(model_outputs))} elif isinstance(model_outputs, torch.Tensor): @@ -247,8 +253,9 @@ def train_step(self, batch, request): device_loss_args = [] for i in range(len(device_loss_inputs)): - if i in device_loss_inputs: - device_loss_args.append(device_loss_inputs.pop(i)) + key = f"loss_{i}" + if key in device_loss_inputs: + device_loss_args.append(device_loss_inputs.pop(key)) else: break device_loss_kwargs = {} @@ -327,7 +334,12 @@ def __collect_requested_outputs(self, request): def __collect_provided_inputs(self, batch): return self.__collect_provided_arrays( - {k: v for k, v in self.inputs.items() if k not in self.loss_inputs}, batch + { + k: v + for k, v in self.inputs.items() + if (isinstance(k, int) or k not in self.loss_inputs) + }, + batch, ) def __collect_provided_loss_inputs(self, batch): @@ -353,8 +365,9 @@ def __collect_provided_arrays(self, reference, batch, expect_missing_arrays=Fals arrays[array_name] = getattr(batch, array_key) else: raise Exception( - "Unknown network array key {}, can't be given to " - "network".format(array_key) + "Unknown network array key {}, can't be given to " "network".format( + array_key + ) ) return arrays From 1d4bf9ab4db2398ad7a9febf7ea44cc9933c9dfc Mon Sep 17 00:00:00 2001 From: pattonw Date: Thu, 13 Jun 2024 18:00:48 -0700 Subject: [PATCH 62/74] PyTorch Train: add tests for using arg indexes for model inputs --- tests/cases/torch_train.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/cases/torch_train.py b/tests/cases/torch_train.py index 8fb9e8ec..c66b56f2 100644 --- a/tests/cases/torch_train.py +++ b/tests/cases/torch_train.py @@ -86,7 +86,8 @@ def forward(self, a, b): ), ], ) -def test_loss_drops(tmpdir, device): +@pytest.mark.parametrize("input_args", [True, False]) +def test_loss_drops(tmpdir, device, input_args): checkpoint_basename = str(tmpdir / "model") a_key = ArrayKey("A") @@ -104,7 +105,7 @@ def test_loss_drops(tmpdir, device): model=model, optimizer=optimizer, loss=loss, - inputs={"a": a_key, "b": b_key}, + inputs={"a": a_key, "b": b_key} if not input_args else {0: a_key, 1: b_key}, loss_inputs={0: c_predicted_key, 1: c_key}, outputs={0: c_predicted_key}, gradients={0: c_gradient_key}, @@ -167,7 +168,8 @@ def test_loss_drops(tmpdir, device): ), ], ) -def test_output(device): +@pytest.mark.parametrize("input_args", [True, False]) +def test_spawn_subprocess(device, input_args): logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO) a_key = ArrayKey("A") @@ -181,7 +183,7 @@ def test_output(device): source = example_train_source(a_key, b_key, c_key) predict = Predict( model=model, - inputs={"a": a_key, "b": b_key}, + inputs={"a": a_key, "b": b_key} if not input_args else {0: a_key, 1: b_key}, outputs={"linear": c_pred, 0: d_pred}, array_specs={ c_key: ArraySpec(nonspatial=True), From 9d3058cdb58c402dd4b919e25a8229adad1bc449 Mon Sep 17 00:00:00 2001 From: William Patton Date: Thu, 29 Aug 2024 14:06:58 -0400 Subject: [PATCH 63/74] depend on overhauled funlib.persistence --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 061935d7..ecf3ddcd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "funlib.geometry>=0.3", "zarr", "networkx>=3.1", + "funlib.persistence>=0.5", ] [project.optional-dependencies] From 7236845bc7969d50ecc020695a4439fa96b84dee Mon Sep 17 00:00:00 2001 From: William Patton Date: Thu, 29 Aug 2024 14:07:07 -0400 Subject: [PATCH 64/74] add funlib.persistence array source --- gunpowder/nodes/array_source.py | 55 +++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 gunpowder/nodes/array_source.py diff --git a/gunpowder/nodes/array_source.py b/gunpowder/nodes/array_source.py new file mode 100644 index 00000000..4ebbbb26 --- /dev/null +++ b/gunpowder/nodes/array_source.py @@ -0,0 +1,55 @@ +from funlib.persistence.arrays import Array as PersistenceArray +from gunpowder import Array, ArrayKey, Batch, BatchProvider, ArraySpec + + +class ArraySource(BatchProvider): + """A `array `_ source. + + Provides a source for any array that can fit into the funkelab + funlib.persistence.Array format. This class comes with assumptions about + the available metadata and convenient methods for indexing the data + with a :class:`Roi` in world units. + + Args: + + key (:class:`ArrayKey`): + + The ArrayKey for accessing this array. + + array (``Array``): + + A `funlib.persistence.Array` object. + """ + + def __init__( + self, + key: ArrayKey, + array: PersistenceArray, + interpolatable: bool | None = None, + nonspatial: bool = False, + ): + self.key = key + self.array = array + self.array_spec = ArraySpec( + self.array.roi, + self.array.voxel_size, + self.interpolatable, + self.nonspatial, + self.array.dtype, + ) + + self.interpolatable = interpolatable + self.nonspatial = nonspatial + + def setup(self): + self.provides(self.key, self.array_spec) + + def provide(self, request): + outputs = Batch() + if self.nonspatial: + outputs[self.key] = Array(self.array[:], self.array_spec.copy()) + else: + out_spec = self.array_spec.copy() + out_spec.roi = request[self.key].roi + outputs[self.key] = Array(self.array[out_spec.roi], out_spec) + return outputs From 470a2386039c28f41085e9e89200820855ffac17 Mon Sep 17 00:00:00 2001 From: William Patton Date: Thu, 29 Aug 2024 14:07:26 -0400 Subject: [PATCH 65/74] black formatting fix --- gunpowder/torch/nodes/train.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/gunpowder/torch/nodes/train.py b/gunpowder/torch/nodes/train.py index d5be0f9a..445eaf43 100644 --- a/gunpowder/torch/nodes/train.py +++ b/gunpowder/torch/nodes/train.py @@ -365,9 +365,8 @@ def __collect_provided_arrays(self, reference, batch, expect_missing_arrays=Fals arrays[array_name] = getattr(batch, array_key) else: raise Exception( - "Unknown network array key {}, can't be given to " "network".format( - array_key - ) + "Unknown network array key {}, can't be given to " + "network".format(array_key) ) return arrays From 59fea58b8179cd2f43a57bef45671b44b2935902 Mon Sep 17 00:00:00 2001 From: William Patton Date: Thu, 29 Aug 2024 14:29:38 -0400 Subject: [PATCH 66/74] Add basic `ArraySource` node that accepts any `funlib.persistence.Array` --- gunpowder/nodes/__init__.py | 1 + gunpowder/nodes/array_source.py | 14 +++++++------- tests/cases/array_source.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 7 deletions(-) create mode 100644 tests/cases/array_source.py diff --git a/gunpowder/nodes/__init__.py b/gunpowder/nodes/__init__.py index 3c2a410f..30dbf848 100644 --- a/gunpowder/nodes/__init__.py +++ b/gunpowder/nodes/__init__.py @@ -1,5 +1,6 @@ from __future__ import absolute_import +from .array_source import ArraySource from .add_affinities import AddAffinities from .astype import AsType from .balance_labels import BalanceLabels diff --git a/gunpowder/nodes/array_source.py b/gunpowder/nodes/array_source.py index 4ebbbb26..e6f05507 100644 --- a/gunpowder/nodes/array_source.py +++ b/gunpowder/nodes/array_source.py @@ -1,5 +1,8 @@ from funlib.persistence.arrays import Array as PersistenceArray -from gunpowder import Array, ArrayKey, Batch, BatchProvider, ArraySpec +from gunpowder.array import Array, ArrayKey +from gunpowder.array_spec import ArraySpec +from gunpowder.batch import Batch +from .batch_provider import BatchProvider class ArraySource(BatchProvider): @@ -33,20 +36,17 @@ def __init__( self.array_spec = ArraySpec( self.array.roi, self.array.voxel_size, - self.interpolatable, - self.nonspatial, + interpolatable, + nonspatial, self.array.dtype, ) - self.interpolatable = interpolatable - self.nonspatial = nonspatial - def setup(self): self.provides(self.key, self.array_spec) def provide(self, request): outputs = Batch() - if self.nonspatial: + if self.array_spec.nonspatial: outputs[self.key] = Array(self.array[:], self.array_spec.copy()) else: out_spec = self.array_spec.copy() diff --git a/tests/cases/array_source.py b/tests/cases/array_source.py new file mode 100644 index 00000000..f7cb666b --- /dev/null +++ b/tests/cases/array_source.py @@ -0,0 +1,29 @@ +from funlib.persistence import prepare_ds +from funlib.geometry import Roi +from gunpowder.nodes import ArraySource +from gunpowder import ArrayKey, build, BatchRequest, ArraySpec + +import numpy as np + + +def test_array_source(tmpdir): + array = prepare_ds( + tmpdir / "data.zarr", + shape=(100, 102, 108), + offset=(100, 50, 0), + voxel_size=(1, 2, 3), + dtype="uint8", + ) + array[:] = np.arange(100 * 102 * 108).reshape((100, 102, 108)) % 255 + + key = ArrayKey("TEST") + + source = ArraySource(key=key, array=array) + + with build(source): + request = BatchRequest() + + roi = Roi((100, 100, 102), (30, 30, 30)) + request[key] = ArraySpec(roi) + + assert np.array_equal(source.request_batch(request)[key].data, array[roi]) From aac278b045c22cae5f4410931fb916ba024a477e Mon Sep 17 00:00:00 2001 From: William Patton Date: Thu, 29 Aug 2024 14:30:26 -0400 Subject: [PATCH 67/74] add ArraySource to docs --- docs/source/api.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/source/api.rst b/docs/source/api.rst index 21b7753b..ea0cfa39 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -91,6 +91,11 @@ BatchFilter Source Nodes ------------ +ArraySource +^^^^^^^^^^^ + + .. autoclass:: ArraySource + ZarrSource ^^^^^^^^^^ .. autoclass:: ZarrSource From 3a8be9ba22c79dc5bb6414097249df64ff669983 Mon Sep 17 00:00:00 2001 From: William Patton Date: Thu, 29 Aug 2024 16:31:36 -0400 Subject: [PATCH 68/74] fix dtype checking for float types for numpy >= 2.0 --- gunpowder/nodes/array_source.py | 18 ++++++++++-------- gunpowder/nodes/dvid_source.py | 7 ++----- gunpowder/nodes/hdf5like_source_base.py | 7 ++----- gunpowder/nodes/klb_source.py | 7 ++----- gunpowder/nodes/zarr_source.py | 7 ++----- 5 files changed, 18 insertions(+), 28 deletions(-) diff --git a/gunpowder/nodes/array_source.py b/gunpowder/nodes/array_source.py index e6f05507..ad4573e1 100644 --- a/gunpowder/nodes/array_source.py +++ b/gunpowder/nodes/array_source.py @@ -22,6 +22,12 @@ class ArraySource(BatchProvider): array (``Array``): A `funlib.persistence.Array` object. + + interpolatable (``bool``, optional): + + Whether the array is interpolatable. If not given it is + guessed based on dtype. + """ def __init__( @@ -29,7 +35,6 @@ def __init__( key: ArrayKey, array: PersistenceArray, interpolatable: bool | None = None, - nonspatial: bool = False, ): self.key = key self.array = array @@ -37,7 +42,7 @@ def __init__( self.array.roi, self.array.voxel_size, interpolatable, - nonspatial, + False, self.array.dtype, ) @@ -46,10 +51,7 @@ def setup(self): def provide(self, request): outputs = Batch() - if self.array_spec.nonspatial: - outputs[self.key] = Array(self.array[:], self.array_spec.copy()) - else: - out_spec = self.array_spec.copy() - out_spec.roi = request[self.key].roi - outputs[self.key] = Array(self.array[out_spec.roi], out_spec) + out_spec = self.array_spec.copy() + out_spec.roi = request[self.key].roi + outputs[self.key] = Array(self.array[out_spec.roi], out_spec) return outputs diff --git a/gunpowder/nodes/dvid_source.py b/gunpowder/nodes/dvid_source.py index d285a502..312dd59e 100644 --- a/gunpowder/nodes/dvid_source.py +++ b/gunpowder/nodes/dvid_source.py @@ -182,11 +182,8 @@ def __get_spec(self, array_key): spec.dtype = data_dtype if spec.interpolatable is None: - spec.interpolatable = spec.dtype in ( - np.sctypes["float"] - + [ - np.uint8, # assuming this is not used for labels - ] + spec.interpolatable = np.issubdtype(spec.dtype, np.floating) or ( + spec.dtype == np.uint8 ) logger.warning( "WARNING: You didn't set 'interpolatable' for %s. " diff --git a/gunpowder/nodes/hdf5like_source_base.py b/gunpowder/nodes/hdf5like_source_base.py index d7c63149..f5a8e58b 100644 --- a/gunpowder/nodes/hdf5like_source_base.py +++ b/gunpowder/nodes/hdf5like_source_base.py @@ -174,11 +174,8 @@ def __read_spec(self, array_key, data_file, ds_name): spec.dtype = dataset.dtype if spec.interpolatable is None: - spec.interpolatable = spec.dtype in ( - np.sctypes["float"] - + [ - np.uint8, # assuming this is not used for labels - ] + spec.interpolatable = np.issubdtype(spec.dtype, np.floating) or ( + spec.dtype == np.uint8 ) logger.warning( "WARNING: You didn't set 'interpolatable' for %s " diff --git a/gunpowder/nodes/klb_source.py b/gunpowder/nodes/klb_source.py index 53eca5c4..d4a55049 100644 --- a/gunpowder/nodes/klb_source.py +++ b/gunpowder/nodes/klb_source.py @@ -155,11 +155,8 @@ def __read_spec(self, headers): spec.dtype = dtype if spec.interpolatable is None: - spec.interpolatable = spec.dtype in ( - np.sctypes["float"] - + [ - np.uint8, # assuming this is not used for labels - ] + spec.interpolatable = np.issubdtype(spec.dtype, np.floating) or ( + spec.dtype == np.uint8 ) logger.warning( "WARNING: You didn't set 'interpolatable' for %s. " diff --git a/gunpowder/nodes/zarr_source.py b/gunpowder/nodes/zarr_source.py index 2f1c15fc..82831fa3 100644 --- a/gunpowder/nodes/zarr_source.py +++ b/gunpowder/nodes/zarr_source.py @@ -206,11 +206,8 @@ def __read_spec(self, array_key, data_file, ds_name): spec.dtype = dataset.dtype if spec.interpolatable is None: - spec.interpolatable = spec.dtype in ( - np.sctypes["float"] - + [ - np.uint8, # assuming this is not used for labels - ] + spec.interpolatable = np.issubdtype(spec.dtype, np.floating) or ( + spec.dtype == np.uint8 ) logger.warning( "WARNING: You didn't set 'interpolatable' for %s " From bc5e6b0be091c266a70141cf2db3a4d188b33469 Mon Sep 17 00:00:00 2001 From: William Patton Date: Thu, 29 Aug 2024 16:47:47 -0400 Subject: [PATCH 69/74] add documentation for gradients argument of torch `Train` node --- gunpowder/torch/nodes/train.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/gunpowder/torch/nodes/train.py b/gunpowder/torch/nodes/train.py index 445eaf43..71fc9c48 100644 --- a/gunpowder/torch/nodes/train.py +++ b/gunpowder/torch/nodes/train.py @@ -52,6 +52,16 @@ class Train(GenericTrain): New arrays will be generated by this node for each entry (if requested downstream). + gradients (``dict``, ``string`` or ``int`` -> :class:`ArrayKey`, optional): + + Dictionary from the names of tensors in the network to array + keys. If the key is a string, the tensor will be retrieved + by checking the model for an attribute with they key as its name. + If the key is an integer, it is interpreted as a tuple index of + the outputs of the network. + Instead of the actual array, the gradient of the array with respect + to the loss will be generated and saved. + array_specs (``dict``, :class:`ArrayKey` -> :class:`ArraySpec`, optional): Used to set the specs of generated arrays (at the moment only From b438fc0ca88c371a490770a10a9804597a0eec6f Mon Sep 17 00:00:00 2001 From: William Patton Date: Thu, 29 Aug 2024 19:41:34 -0400 Subject: [PATCH 70/74] add typehint for dict --- gunpowder/torch/nodes/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gunpowder/torch/nodes/train.py b/gunpowder/torch/nodes/train.py index 71fc9c48..9bf17c70 100644 --- a/gunpowder/torch/nodes/train.py +++ b/gunpowder/torch/nodes/train.py @@ -123,7 +123,7 @@ def __init__( # not yet implemented gradients = gradients loss_inputs = {f"loss_{k}": v for k, v in loss_inputs.items()} - all_inputs = {f"{k}": v for k, v in inputs.items() if v not in outputs.values()} + all_inputs: dict[str | int, Any] = {f"{k}": v for k, v in inputs.items() if v not in outputs.values()} all_inputs.update( {k: v for k, v in loss_inputs.items() if v not in outputs.values()} ) From c0e9cc551428a2a5751f79bb06d1d3b38a78d789 Mon Sep 17 00:00:00 2001 From: William Patton Date: Thu, 29 Aug 2024 19:42:22 -0400 Subject: [PATCH 71/74] black reformatting --- gunpowder/torch/nodes/train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gunpowder/torch/nodes/train.py b/gunpowder/torch/nodes/train.py index 9bf17c70..241332db 100644 --- a/gunpowder/torch/nodes/train.py +++ b/gunpowder/torch/nodes/train.py @@ -123,7 +123,9 @@ def __init__( # not yet implemented gradients = gradients loss_inputs = {f"loss_{k}": v for k, v in loss_inputs.items()} - all_inputs: dict[str | int, Any] = {f"{k}": v for k, v in inputs.items() if v not in outputs.values()} + all_inputs: dict[str | int, Any] = { + f"{k}": v for k, v in inputs.items() if v not in outputs.values() + } all_inputs.update( {k: v for k, v in loss_inputs.items() if v not in outputs.values()} ) From 1cb77328e217e1c9ff2d2f05f49311f48b402de2 Mon Sep 17 00:00:00 2001 From: William Patton Date: Thu, 29 Aug 2024 20:46:18 -0400 Subject: [PATCH 72/74] add support for python 3.12 --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4a2139c5..5966aa92 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.10", "3.11", "3.12"] platform: [ubuntu-latest] steps: From 65ed7f4dfee698cd9ac91a3404e154a1c9569f86 Mon Sep 17 00:00:00 2001 From: William Patton Date: Thu, 29 Aug 2024 20:56:38 -0400 Subject: [PATCH 73/74] remove distutils --- .../contrib/nodes/dvid_partner_annotation_source.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/gunpowder/contrib/nodes/dvid_partner_annotation_source.py b/gunpowder/contrib/nodes/dvid_partner_annotation_source.py index f7f08599..400f6c68 100644 --- a/gunpowder/contrib/nodes/dvid_partner_annotation_source.py +++ b/gunpowder/contrib/nodes/dvid_partner_annotation_source.py @@ -1,4 +1,3 @@ -import distutils.util import numpy as np import logging import requests @@ -14,6 +13,14 @@ logger = logging.getLogger(__name__) +def strtobool(val): + val = val.lower() + if val in ('y', 'yes', 't', 'true', 'on', '1'): + return 1 + elif val in ('n', 'no', 'f', 'false', 'off', '0'): + return 0 + else: + raise ValueError(f"Invalid truth value: {val}") class DvidPartnerAnnoationSourceReadException(Exception): pass @@ -198,10 +205,10 @@ def __read_syn_points(self, roi): props["agent"] = str(node["Prop"]["agent"]) if "flagged" in node["Prop"]: str_value_flagged = str(node["Prop"]["flagged"]) - props["flagged"] = bool(distutils.util.strtobool(str_value_flagged)) + props["flagged"] = bool(strtobool(str_value_flagged)) if "multi" in node["Prop"]: str_value_multi = str(node["Prop"]["multi"]) - props["multi"] = bool(distutils.util.strtobool(str_value_multi)) + props["multi"] = bool(strtobool(str_value_multi)) # create synPoint with information collected so far (partner_ids not completed yet) if kind == "PreSyn": From 9693615510e7de002f4400ed31c96930ad6b35ad Mon Sep 17 00:00:00 2001 From: William Patton Date: Thu, 29 Aug 2024 20:57:01 -0400 Subject: [PATCH 74/74] black formatting --- gunpowder/contrib/nodes/dvid_partner_annotation_source.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/gunpowder/contrib/nodes/dvid_partner_annotation_source.py b/gunpowder/contrib/nodes/dvid_partner_annotation_source.py index 400f6c68..36e182e1 100644 --- a/gunpowder/contrib/nodes/dvid_partner_annotation_source.py +++ b/gunpowder/contrib/nodes/dvid_partner_annotation_source.py @@ -13,15 +13,17 @@ logger = logging.getLogger(__name__) + def strtobool(val): val = val.lower() - if val in ('y', 'yes', 't', 'true', 'on', '1'): + if val in ("y", "yes", "t", "true", "on", "1"): return 1 - elif val in ('n', 'no', 'f', 'false', 'off', '0'): + elif val in ("n", "no", "f", "false", "off", "0"): return 0 else: raise ValueError(f"Invalid truth value: {val}") + class DvidPartnerAnnoationSourceReadException(Exception): pass