diff --git a/tests/cases/add_affinities.py b/tests/cases/add_affinities.py index bd6ebc93..95bd491f 100644 --- a/tests/cases/add_affinities.py +++ b/tests/cases/add_affinities.py @@ -1,7 +1,21 @@ -from gunpowder import * from itertools import product + import numpy as np +from gunpowder import ( + AddAffinities, + Array, + ArrayKey, + ArrayKeys, + ArraySpec, + Batch, + BatchProvider, + BatchRequest, + Coordinate, + Roi, + build, +) + class ExampleSource(BatchProvider): def setup(self): diff --git a/tests/cases/add_boundary_distance_gradients.py b/tests/cases/add_boundary_distance_gradients.py index ca3efc6e..ea7b2e3d 100644 --- a/tests/cases/add_boundary_distance_gradients.py +++ b/tests/cases/add_boundary_distance_gradients.py @@ -1,84 +1,54 @@ -from .provider_test import ProviderTest -from gunpowder import * -from gunpowder.contrib import AddBoundaryDistanceGradients import numpy as np +from gunpowder import Array, ArrayKey, ArraySpec, BatchRequest, Roi, build +from gunpowder.contrib import AddBoundaryDistanceGradients -class ExampleSource(BatchProvider): - def setup(self): - self.provides( - ArrayKeys.GT_LABELS, - ArraySpec( - roi=Roi((-40, -40, -40), (160, 160, 160)), - voxel_size=(20, 4, 8), - interpolatable=False, - ), - ) - - def provide(self, request): - batch = Batch() - - roi = request[ArrayKeys.GT_LABELS].roi - shape = (roi / self.spec[ArrayKeys.GT_LABELS].voxel_size).shape - - spec = self.spec[ArrayKeys.GT_LABELS].copy() - spec.roi = roi - data = np.ones(shape) - data[shape[0] // 2 :, :, :] += 2 - data[:, shape[1] // 2 :, :] += 4 - data[:, :, shape[2] // 2 :] += 8 - batch.arrays[ArrayKeys.GT_LABELS] = Array(data, spec) +from .helper_sources import ArraySource - return batch +def test_output(): + labels_key = ArrayKey("LABELS") + dist_key = ArrayKey("BOUNDARY_DISTANCES") + grad_key = ArrayKey("BOUNDARY_GRADIENTS") -class TestAddBoundaryDistanceGradients(ProviderTest): - def test_output(self): - ArrayKey("GT_BOUNDARY_DISTANCES") - ArrayKey("GT_BOUNDARY_GRADIENTS") + labels_spec = ArraySpec( + roi=Roi((0, 0, 0), (120, 16, 64)), + voxel_size=(20, 4, 8), + interpolatable=False, + ) + shape = (labels_spec.roi / labels_spec.voxel_size).shape + labels_data = np.ones(shape) + labels_data[shape[0] // 2 :, :, :] += 2 + labels_data[:, shape[1] // 2 :, :] += 4 + labels_data[:, :, shape[2] // 2 :] += 8 + labels_array = Array(labels_data, labels_spec) - pipeline = ExampleSource() + AddBoundaryDistanceGradients( - label_array_key=ArrayKeys.GT_LABELS, - distance_array_key=ArrayKeys.GT_BOUNDARY_DISTANCES, - gradient_array_key=ArrayKeys.GT_BOUNDARY_GRADIENTS, - ) + labels_source = ArraySource(labels_key, labels_array) - with build(pipeline): - request = BatchRequest() - request.add(ArrayKeys.GT_LABELS, (120, 16, 64)) - request.add(ArrayKeys.GT_BOUNDARY_DISTANCES, (120, 16, 64)) - request.add(ArrayKeys.GT_BOUNDARY_GRADIENTS, (120, 16, 64)) + pipeline = labels_source + AddBoundaryDistanceGradients( + label_array_key=labels_key, + distance_array_key=dist_key, + gradient_array_key=grad_key, + ) - batch = pipeline.request_batch(request) + with build(pipeline): + request = BatchRequest() + request.add(labels_key, (120, 16, 64)) + request.add(dist_key, (120, 16, 64)) + request.add(grad_key, (120, 16, 64)) - labels = batch.arrays[ArrayKeys.GT_LABELS].data - distances = batch.arrays[ArrayKeys.GT_BOUNDARY_DISTANCES].data - gradients = batch.arrays[ArrayKeys.GT_BOUNDARY_GRADIENTS].data - shape = distances.shape + batch = pipeline.request_batch(request) - l_001 = labels[: shape[0] // 2, : shape[1] // 2, shape[2] // 2 :] - l_101 = labels[shape[0] // 2 :, : shape[1] // 2, shape[2] // 2 :] - d_001 = distances[: shape[0] // 2, : shape[1] // 2, shape[2] // 2 :] - d_101 = distances[shape[0] // 2 :, : shape[1] // 2, shape[2] // 2 :] - g_001 = gradients[:, : shape[0] // 2, : shape[1] // 2, shape[2] // 2 :] - g_101 = gradients[:, shape[0] // 2 :, : shape[1] // 2, shape[2] // 2 :] + distances = batch.arrays[dist_key].data + gradients = batch.arrays[grad_key].data + shape = distances.shape - # print labels - # print - # print distances - # print - # print l_001 - # print l_101 - # print - # print d_001 - # print d_101 - # print - # print g_001 - # print g_101 + g_001 = gradients[:, : shape[0] // 2, : shape[1] // 2, shape[2] // 2 :] + g_101 = gradients[:, shape[0] // 2 :, : shape[1] // 2, shape[2] // 2 :] - self.assertTrue((g_001 == g_101).all()) + assert (g_001 == g_101).all() - top = gradients[:, 0 : shape[0] // 2, :] - bot = gradients[:, shape[0] : shape[0] // 2 - 1 : -1, :] + top = gradients[:, 0 : shape[0] // 2, :] + bot = gradients[:, shape[0] : shape[0] // 2 - 1 : -1, :] - self.assertTrue((top == bot).all()) + assert (top == bot).all() diff --git a/tests/cases/add_vector_map.py b/tests/cases/add_vector_map.py index 70fdcf78..53d927f6 100644 --- a/tests/cases/add_vector_map.py +++ b/tests/cases/add_vector_map.py @@ -1,32 +1,38 @@ -import unittest -from .provider_test import ProviderTest +import itertools +from copy import deepcopy + +import numpy as np + from gunpowder import ( - ArrayKeys, - ArraySpec, - GraphSpec, - Roi, Array, - GraphKeys, - GraphKey, + ArrayKey, + ArraySpec, Batch, BatchProvider, + BatchRequest, + Coordinate, Graph, + GraphKey, + GraphSpec, Node, - Coordinate, - ArrayKey, - BatchRequest, + Roi, build, ) from gunpowder.contrib import AddVectorMap -from copy import deepcopy -import itertools -import numpy as np - +# TODO: Simplify the source node. The data being generated should not be defined +# in the provide method. Instead the source should be simple arrays and graphs. class AddVectorMapTestSource(BatchProvider): + def __init__(self, raw_key, labels_key, presyn_key, postsyn_key, vector_map_key): + self.raw_key = raw_key + self.labels_key = labels_key + self.presyn_key = presyn_key + self.postsyn_key = postsyn_key + self.vector_map_key = vector_map_key + def setup(self): - for identifier in [ArrayKeys.RAW, ArrayKeys.GT_LABELS]: + for identifier in [self.raw_key, self.labels_key]: self.provides( identifier, ArraySpec( @@ -34,7 +40,7 @@ def setup(self): ), ) - for identifier in [GraphKeys.PRESYN, GraphKeys.POSTSYN]: + for identifier in [self.presyn_key, self.postsyn_key]: self.provides( identifier, GraphSpec(roi=Roi((1000, 1000, 1000), (400, 400, 400))) ) @@ -43,10 +49,10 @@ def provide(self, request): batch = Batch() # have the pixels encode their position - if ArrayKeys.RAW in request: + if self.raw_key in request: # the z,y,x coordinates of the ROI - roi = request[ArrayKeys.RAW].roi - roi_voxel = roi // self.spec[ArrayKeys.RAW].voxel_size + roi = request[self.raw_key].roi + roi_voxel = roi // self.spec[self.raw_key].voxel_size meshgrids = np.meshgrid( range(roi_voxel.begin[0], roi_voxel.end[0]), range(roi_voxel.begin[1], roi_voxel.end[1]), @@ -55,34 +61,33 @@ def provide(self, request): ) data = meshgrids[0] + meshgrids[1] + meshgrids[2] - spec = self.spec[ArrayKeys.RAW].copy() + spec = self.spec[self.raw_key].copy() spec.roi = roi - batch.arrays[ArrayKeys.RAW] = Array(data, spec) + batch.arrays[self.raw_key] = Array(data, spec) - if ArrayKeys.GT_LABELS in request: - roi = request[ArrayKeys.GT_LABELS].roi - roi_voxel_shape = (roi // self.spec[ArrayKeys.GT_LABELS].voxel_size).shape + if self.labels_key in request: + roi = request[self.labels_key].roi + roi_voxel_shape = (roi // self.spec[self.labels_key].voxel_size).shape data = np.ones(roi_voxel_shape) data[roi_voxel_shape[0] // 2 :, roi_voxel_shape[1] // 2 :, :] = 2 data[roi_voxel_shape[0] // 2 :, -(roi_voxel_shape[1] // 2) :, :] = 3 - spec = self.spec[ArrayKeys.GT_LABELS].copy() + spec = self.spec[self.labels_key].copy() spec.roi = roi - batch.arrays[ArrayKeys.GT_LABELS] = Array(data, spec) + batch.arrays[self.labels_key] = Array(data, spec) - if GraphKeys.PRESYN in request: + if self.presyn_key in request: data_presyn, data_postsyn = self.__get_pre_and_postsyn_locations( - roi=request[GraphKeys.PRESYN].roi + roi=request[self.presyn_key].roi ) - elif GraphKeys.POSTSYN in request: + elif self.postsyn_key in request: data_presyn, data_postsyn = self.__get_pre_and_postsyn_locations( - roi=request[GraphKeys.POSTSYN].roi + roi=request[self.postsyn_key].roi ) - voxel_size_points = self.spec[ArrayKeys.RAW].voxel_size for graph_key, spec in request.graph_specs.items(): - if graph_key == GraphKeys.PRESYN: + if graph_key == self.presyn_key: data = data_presyn - if graph_key == GraphKeys.POSTSYN: + if graph_key == self.postsyn_key: data = data_postsyn batch.graphs[graph_key] = Graph( list(data.values()), [], GraphSpec(spec.roi) @@ -93,7 +98,7 @@ def provide(self, request): def __get_pre_and_postsyn_locations(self, roi): presyn_locs, postsyn_locs = {}, {} min_dist_between_presyn_locs = 250 - voxel_size_points = self.spec[ArrayKeys.RAW].voxel_size + voxel_size_points = self.spec[self.raw_key].voxel_size min_dist_pre_to_postsyn_loc, max_dist_pre_to_postsyn_loc = 60, 120 num_presyn_locations = roi.size // ( np.prod(50 * np.asarray(voxel_size_points)) @@ -170,209 +175,189 @@ def __get_pre_and_postsyn_locations(self, roi): return presyn_locs, postsyn_locs -class TestAddVectorMap(ProviderTest): - def test_output_min_distance(self): - voxel_size = Coordinate((20, 2, 2)) - - ArrayKey("GT_VECTORS_MAP_PRESYN") - GraphKey("PRESYN") - GraphKey("POSTSYN") - - arraytypes_to_source_target_pointstypes = { - ArrayKeys.GT_VECTORS_MAP_PRESYN: (GraphKeys.PRESYN, GraphKeys.POSTSYN) - } - arraytypes_to_stayinside_arraytypes = { - ArrayKeys.GT_VECTORS_MAP_PRESYN: ArrayKeys.GT_LABELS - } - - # test for partner criterion 'min_distance' - radius_phys = 30 - pipeline_min_distance = AddVectorMapTestSource() + AddVectorMap( - src_and_trg_points=arraytypes_to_source_target_pointstypes, - voxel_sizes={ArrayKeys.GT_VECTORS_MAP_PRESYN: voxel_size}, - radius_phys=radius_phys, - partner_criterion="min_distance", - stayinside_array_keys=arraytypes_to_stayinside_arraytypes, - pad_for_partners=(0, 0, 0), - ) - - with build(pipeline_min_distance): - request = BatchRequest() - raw_roi = pipeline_min_distance.spec[ArrayKeys.RAW].roi - gt_labels_roi = pipeline_min_distance.spec[ArrayKeys.GT_LABELS].roi - presyn_roi = pipeline_min_distance.spec[GraphKeys.PRESYN].roi - - request.add(ArrayKeys.RAW, raw_roi.shape) - request.add(ArrayKeys.GT_LABELS, gt_labels_roi.shape) - request.add(GraphKeys.PRESYN, presyn_roi.shape) - request.add(GraphKeys.POSTSYN, presyn_roi.shape) - request.add(ArrayKeys.GT_VECTORS_MAP_PRESYN, presyn_roi.shape) - for identifier, spec in request.items(): - spec.roi = spec.roi.shift(Coordinate(1000, 1000, 1000)) - - batch = pipeline_min_distance.request_batch(request) - - presyn_locs = {n.id: n for n in batch.graphs[GraphKeys.PRESYN].nodes} - postsyn_locs = {n.id: n for n in batch.graphs[GraphKeys.POSTSYN].nodes} - vector_map_presyn = batch.arrays[ArrayKeys.GT_VECTORS_MAP_PRESYN].data - offset_vector_map_presyn = request[ArrayKeys.GT_VECTORS_MAP_PRESYN].roi.offset - - self.assertTrue(len(presyn_locs) > 0) - self.assertTrue(len(postsyn_locs) > 0) - - for loc_id, point in presyn_locs.items(): - if request[ArrayKeys.GT_VECTORS_MAP_PRESYN].roi.contains( +def test_output_min_distance(): + voxel_size = Coordinate((20, 2, 2)) + + raw_key = ArrayKey("RAW") + labels_key = ArrayKey("LABELS") + vectors_map_key = ArrayKey("VECTORS_MAP_PRESYN") + pre_key = GraphKey("PRESYN") + post_key = GraphKey("POSTSYN") + + arraytypes_to_source_target_pointstypes = {vectors_map_key: (pre_key, post_key)} + arraytypes_to_stayinside_arraytypes = {vectors_map_key: labels_key} + + # test for partner criterion 'min_distance' + radius_phys = 30 + pipeline_min_distance = AddVectorMapTestSource( + raw_key, labels_key, pre_key, post_key, vectors_map_key + ) + AddVectorMap( + src_and_trg_points=arraytypes_to_source_target_pointstypes, + voxel_sizes={vectors_map_key: voxel_size}, + radius_phys=radius_phys, + partner_criterion="min_distance", + stayinside_array_keys=arraytypes_to_stayinside_arraytypes, + pad_for_partners=(0, 0, 0), + ) + + with build(pipeline_min_distance): + request = BatchRequest() + raw_roi = pipeline_min_distance.spec[raw_key].roi + gt_labels_roi = pipeline_min_distance.spec[labels_key].roi + presyn_roi = pipeline_min_distance.spec[pre_key].roi + + request.add(raw_key, raw_roi.shape) + request.add(labels_key, gt_labels_roi.shape) + request.add(pre_key, presyn_roi.shape) + request.add(post_key, presyn_roi.shape) + request.add(vectors_map_key, presyn_roi.shape) + for identifier, spec in request.items(): + spec.roi = spec.roi.shift(Coordinate(1000, 1000, 1000)) + + batch = pipeline_min_distance.request_batch(request) + + presyn_locs = {n.id: n for n in batch.graphs[pre_key].nodes} + postsyn_locs = {n.id: n for n in batch.graphs[post_key].nodes} + vector_map_presyn = batch.arrays[vectors_map_key].data + offset_vector_map_presyn = request[vectors_map_key].roi.offset + + assert len(presyn_locs) > 0 + assert len(postsyn_locs) > 0 + + for loc_id, point in presyn_locs.items(): + if request[vectors_map_key].roi.contains(Coordinate(point.location)): + assert batch.arrays[vectors_map_key].spec.roi.contains( Coordinate(point.location) - ): - self.assertTrue( - batch.arrays[ArrayKeys.GT_VECTORS_MAP_PRESYN].spec.roi.contains( - Coordinate(point.location) - ) - ) + ) - dist_to_loc = {} - for partner_id in point.attrs["partner_ids"]: - if partner_id in postsyn_locs.keys(): - partner_location = postsyn_locs[partner_id].location - dist_to_loc[ - np.linalg.norm(partner_location - point.location) - ] = partner_location - min_dist = np.min(list(dist_to_loc.keys())) - relevant_partner_loc = dist_to_loc[min_dist] - - presyn_loc_shifted_vx = ( - point.location - offset_vector_map_presyn - ) // voxel_size - radius_vx = [(radius_phys // vx_dim) for vx_dim in voxel_size] - region_to_check = np.clip( - [ - (presyn_loc_shifted_vx - radius_vx), - (presyn_loc_shifted_vx + radius_vx), - ], - a_min=(0, 0, 0), - a_max=vector_map_presyn.shape[-3:], - ) - for x, y, z in itertools.product( - range(int(region_to_check[0][0]), int(region_to_check[1][0])), - range(int(region_to_check[0][1]), int(region_to_check[1][1])), - range(int(region_to_check[0][2]), int(region_to_check[1][2])), + dist_to_loc = {} + for partner_id in point.attrs["partner_ids"]: + if partner_id in postsyn_locs.keys(): + partner_location = postsyn_locs[partner_id].location + dist_to_loc[np.linalg.norm(partner_location - point.location)] = ( + partner_location + ) + min_dist = np.min(list(dist_to_loc.keys())) + relevant_partner_loc = dist_to_loc[min_dist] + + presyn_loc_shifted_vx = ( + point.location - offset_vector_map_presyn + ) // voxel_size + radius_vx = [(radius_phys // vx_dim) for vx_dim in voxel_size] + region_to_check = np.clip( + [ + (presyn_loc_shifted_vx - radius_vx), + (presyn_loc_shifted_vx + radius_vx), + ], + a_min=(0, 0, 0), + a_max=vector_map_presyn.shape[-3:], + ) + for x, y, z in itertools.product( + range(int(region_to_check[0][0]), int(region_to_check[1][0])), + range(int(region_to_check[0][1]), int(region_to_check[1][1])), + range(int(region_to_check[0][2]), int(region_to_check[1][2])), + ): + if ( + np.linalg.norm((np.array((x, y, z)) - np.asarray(point.location))) + < radius_phys ): - if ( - np.linalg.norm( - (np.array((x, y, z)) - np.asarray(point.location)) + vector = [ + vector_map_presyn[dim][x, y, z] + for dim in range(vector_map_presyn.shape[0]) + ] + if not np.sum(vector) == 0: + trg_loc_of_vector_phys = ( + np.asarray(offset_vector_map_presyn) + + (voxel_size * np.array([x, y, z])) + + np.asarray(vector) ) - < radius_phys - ): - vector = [ - vector_map_presyn[dim][x, y, z] - for dim in range(vector_map_presyn.shape[0]) - ] - if not np.sum(vector) == 0: - trg_loc_of_vector_phys = ( - np.asarray(offset_vector_map_presyn) - + (voxel_size * np.array([x, y, z])) - + np.asarray(vector) - ) - self.assertTrue( - np.array_equal( - trg_loc_of_vector_phys, relevant_partner_loc - ) - ) - - # test for partner criterion 'all' - pipeline_all = AddVectorMapTestSource() + AddVectorMap( - src_and_trg_points=arraytypes_to_source_target_pointstypes, - voxel_sizes={ArrayKeys.GT_VECTORS_MAP_PRESYN: voxel_size}, - radius_phys=radius_phys, - partner_criterion="all", - stayinside_array_keys=arraytypes_to_stayinside_arraytypes, - pad_for_partners=(0, 0, 0), - ) - - with build(pipeline_all): - batch = pipeline_all.request_batch(request) - - presyn_locs = {n.id: n for n in batch.graphs[GraphKeys.PRESYN].nodes} - postsyn_locs = {n.id: n for n in batch.graphs[GraphKeys.POSTSYN].nodes} - vector_map_presyn = batch.arrays[ArrayKeys.GT_VECTORS_MAP_PRESYN].data - offset_vector_map_presyn = request[ArrayKeys.GT_VECTORS_MAP_PRESYN].roi.offset - - self.assertTrue(len(presyn_locs) > 0) - self.assertTrue(len(postsyn_locs) > 0) - - for loc_id, point in presyn_locs.items(): - if request[ArrayKeys.GT_VECTORS_MAP_PRESYN].roi.contains( + assert np.array_equal( + trg_loc_of_vector_phys, relevant_partner_loc + ) + + # test for partner criterion 'all' + pipeline_all = AddVectorMapTestSource( + raw_key, labels_key, pre_key, post_key, vectors_map_key + ) + AddVectorMap( + src_and_trg_points=arraytypes_to_source_target_pointstypes, + voxel_sizes={vectors_map_key: voxel_size}, + radius_phys=radius_phys, + partner_criterion="all", + stayinside_array_keys=arraytypes_to_stayinside_arraytypes, + pad_for_partners=(0, 0, 0), + ) + + with build(pipeline_all): + batch = pipeline_all.request_batch(request) + + presyn_locs = {n.id: n for n in batch.graphs[pre_key].nodes} + postsyn_locs = {n.id: n for n in batch.graphs[post_key].nodes} + vector_map_presyn = batch.arrays[vectors_map_key].data + offset_vector_map_presyn = request[vectors_map_key].roi.offset + + assert len(presyn_locs) > 0 + assert len(postsyn_locs) > 0 + + for loc_id, point in presyn_locs.items(): + if request[vectors_map_key].roi.contains(Coordinate(point.location)): + assert batch.arrays[vectors_map_key].spec.roi.contains( Coordinate(point.location) - ): - self.assertTrue( - batch.arrays[ArrayKeys.GT_VECTORS_MAP_PRESYN].spec.roi.contains( - Coordinate(point.location) - ) - ) + ) - partner_ids_to_locs_per_src, count_vectors_per_partner = {}, {} - for partner_id in point.attrs["partner_ids"]: - if partner_id in postsyn_locs.keys(): - partner_ids_to_locs_per_src[partner_id] = postsyn_locs[ - partner_id - ].location.tolist() - count_vectors_per_partner[partner_id] = 0 - - presyn_loc_shifted_vx = ( - point.location - offset_vector_map_presyn - ) // voxel_size - radius_vx = [(radius_phys // vx_dim) for vx_dim in voxel_size] - region_to_check = np.clip( - [ - (presyn_loc_shifted_vx - radius_vx), - (presyn_loc_shifted_vx + radius_vx), - ], - a_min=(0, 0, 0), - a_max=vector_map_presyn.shape[-3:], - ) - for x, y, z in itertools.product( - range(int(region_to_check[0][0]), int(region_to_check[1][0])), - range(int(region_to_check[0][1]), int(region_to_check[1][1])), - range(int(region_to_check[0][2]), int(region_to_check[1][2])), + partner_ids_to_locs_per_src, count_vectors_per_partner = {}, {} + for partner_id in point.attrs["partner_ids"]: + if partner_id in postsyn_locs.keys(): + partner_ids_to_locs_per_src[partner_id] = postsyn_locs[ + partner_id + ].location.tolist() + count_vectors_per_partner[partner_id] = 0 + + presyn_loc_shifted_vx = ( + point.location - offset_vector_map_presyn + ) // voxel_size + radius_vx = [(radius_phys // vx_dim) for vx_dim in voxel_size] + region_to_check = np.clip( + [ + (presyn_loc_shifted_vx - radius_vx), + (presyn_loc_shifted_vx + radius_vx), + ], + a_min=(0, 0, 0), + a_max=vector_map_presyn.shape[-3:], + ) + for x, y, z in itertools.product( + range(int(region_to_check[0][0]), int(region_to_check[1][0])), + range(int(region_to_check[0][1]), int(region_to_check[1][1])), + range(int(region_to_check[0][2]), int(region_to_check[1][2])), + ): + if ( + np.linalg.norm((np.array((x, y, z)) - np.asarray(point.location))) + < radius_phys ): - if ( - np.linalg.norm( - (np.array((x, y, z)) - np.asarray(point.location)) + vector = [ + vector_map_presyn[dim][x, y, z] + for dim in range(vector_map_presyn.shape[0]) + ] + if not np.sum(vector) == 0: + trg_loc_of_vector_phys = ( + np.asarray(offset_vector_map_presyn) + + (voxel_size * np.array([x, y, z])) + + np.asarray(vector) + ) + assert ( + trg_loc_of_vector_phys.tolist() + in partner_ids_to_locs_per_src.values() ) - < radius_phys - ): - vector = [ - vector_map_presyn[dim][x, y, z] - for dim in range(vector_map_presyn.shape[0]) - ] - if not np.sum(vector) == 0: - trg_loc_of_vector_phys = ( - np.asarray(offset_vector_map_presyn) - + (voxel_size * np.array([x, y, z])) - + np.asarray(vector) - ) - self.assertTrue( - trg_loc_of_vector_phys.tolist() - in partner_ids_to_locs_per_src.values() - ) - - for ( - partner_id, - partner_loc, - ) in partner_ids_to_locs_per_src.items(): - if np.array_equal( - np.asarray(trg_loc_of_vector_phys), partner_loc - ): - count_vectors_per_partner[partner_id] += 1 - self.assertTrue( - ( - list(count_vectors_per_partner.values()) - - np.min(list(count_vectors_per_partner.values())) - <= len(count_vectors_per_partner.keys()) - ).all() - ) - -if __name__ == "__main__": - suite = unittest.TestLoader().loadTestsFromTestCase(TestAddVectorMap) - unittest.TextTestRunner(verbosity=2).run(suite) + for ( + partner_id, + partner_loc, + ) in partner_ids_to_locs_per_src.items(): + if np.array_equal( + np.asarray(trg_loc_of_vector_phys), partner_loc + ): + count_vectors_per_partner[partner_id] += 1 + assert ( + list(count_vectors_per_partner.values()) + - np.min(list(count_vectors_per_partner.values())) + <= len(count_vectors_per_partner.keys()) + ).all() diff --git a/tests/cases/astype.py b/tests/cases/astype.py index 104b11f5..fd7a0c26 100644 --- a/tests/cases/astype.py +++ b/tests/cases/astype.py @@ -1,99 +1,77 @@ -from .provider_test import ProviderTest -from gunpowder import * import numpy as np - -class AsTypeTestSource(BatchProvider): - def setup(self): - self.provides( - ArrayKeys.RAW, - ArraySpec(roi=Roi((0, 0, 0), (1000, 1000, 1000)), voxel_size=(4, 4, 4)), - ) - - self.provides( - ArrayKeys.GT_LABELS, - ArraySpec(roi=Roi((0, 0, 0), (1000, 1000, 1000)), voxel_size=(4, 4, 4)), - ) - - def provide(self, request): - batch = Batch() - - # have the pixels encode their position - for array_key, spec in request.array_specs.items(): - roi = spec.roi - - data_roi = roi / 4 - +from gunpowder import ( + Array, + ArrayKey, + ArraySpec, + AsType, + BatchRequest, + MergeProvider, + Roi, + build, +) + +from .helper_sources import ArraySource + + +def test_output(): + raw_key = ArrayKey("RAW") + labels_key = ArrayKey("LABELS") + raw_typed_key = ArrayKey("RAW_TYPECAST") + labels_typed_key = ArrayKey("LABELS_TYPECAST") + + raw_spec = ArraySpec(roi=Roi((0, 0, 0), (1000, 1000, 1000)), voxel_size=(4, 4, 4)) + labels_spec = ArraySpec( + roi=Roi((0, 0, 0), (1000, 1000, 1000)), voxel_size=(4, 4, 4) + ) + + roi = raw_spec.roi / raw_spec.voxel_size + meshgrids = np.meshgrid( + range(roi.get_begin()[0], roi.get_end()[0]), + range(roi.get_begin()[1], roi.get_end()[1]), + range(roi.get_begin()[2], roi.get_end()[2]), + indexing="ij", + ) + data = meshgrids[0] + meshgrids[1] + meshgrids[2] + raw_array = Array(data, raw_spec) + labels_array = Array(data, labels_spec) + + request = BatchRequest() + request.add(raw_key, (200, 200, 200)) + request.add(raw_typed_key, (120, 120, 120)) + request.add(labels_key, (200, 200, 200)) + request.add(labels_typed_key, (200, 200, 200)) + + pipeline = ( + (ArraySource(raw_key, raw_array), ArraySource(labels_key, labels_array)) + + MergeProvider() + + AsType(raw_key, np.float16, raw_typed_key) + + AsType(labels_key, np.int16, labels_typed_key) + ) + + with build(pipeline): + batch = pipeline.request_batch(request) + + for array_key, array in batch.arrays.items(): + # assert that pixels encode their position for supposedly unaltered + # arrays + if array_key in [raw_key, labels_key]: # the z,y,x coordinates of the ROI + roi = array.spec.roi / 4 meshgrids = np.meshgrid( - range(data_roi.get_begin()[0], data_roi.get_end()[0]), - range(data_roi.get_begin()[1], data_roi.get_end()[1]), - range(data_roi.get_begin()[2], data_roi.get_end()[2]), + range(roi.get_begin()[0], roi.get_end()[0]), + range(roi.get_begin()[1], roi.get_end()[1]), + range(roi.get_begin()[2], roi.get_end()[2]), indexing="ij", ) data = meshgrids[0] + meshgrids[1] + meshgrids[2] - spec = self.spec[array_key].copy() - spec.roi = roi - batch.arrays[array_key] = Array(data, spec) - return batch - - -class TestAsType(ProviderTest): - def test_output(self): - ArrayKey("RAW_TYPECAST") - ArrayKey("GT_LABELS_TYPECAST") - - request = BatchRequest() - request.add(ArrayKeys.RAW, (200, 200, 200)) - request.add(ArrayKeys.RAW_TYPECAST, (120, 120, 120)) - request.add(ArrayKeys.GT_LABELS, (200, 200, 200)) - request.add(ArrayKeys.GT_LABELS_TYPECAST, (200, 200, 200)) - - pipeline = ( - AsTypeTestSource() - + AsType(ArrayKeys.RAW, np.float16, ArrayKeys.RAW_TYPECAST) - + AsType(ArrayKeys.GT_LABELS, np.int16, ArrayKeys.GT_LABELS_TYPECAST) - ) - - with build(pipeline): - batch = pipeline.request_batch(request) - - for array_key, array in batch.arrays.items(): - # assert that pixels encode their position for supposedly unaltered - # arrays - if array_key in [ArrayKeys.RAW, ArrayKeys.GT_LABELS]: - # the z,y,x coordinates of the ROI - roi = array.spec.roi / 4 - meshgrids = np.meshgrid( - range(roi.get_begin()[0], roi.get_end()[0]), - range(roi.get_begin()[1], roi.get_end()[1]), - range(roi.get_begin()[2], roi.get_end()[2]), - indexing="ij", - ) - data = meshgrids[0] + meshgrids[1] + meshgrids[2] - - self.assertTrue(np.array_equal(array.data, data), str(array_key)) - - elif array_key == ArrayKeys.RAW_TYPECAST: - self.assertTrue( - array.data.dtype == np.float16, - f"RAW_TYPECAST dtype: {array.data.dtype} does not equal expected: np.float16", - ) - self.assertTrue( - int(array.data[1, 11, 1]) == 43, - f"RAW_TYPECAST[1,11,1]: int({array.data[1,11,1]}) does not equal expected: 43", - ) + assert np.array_equal(array.data, data) - elif array_key == ArrayKeys.GT_LABELS_TYPECAST: - self.assertTrue( - array.data.dtype == np.int16, - f"GT_LABELS_TYPECAST dtype: {array.data.dtype} does not equal expected: np.int16", - ) - self.assertTrue( - int(array.data[1, 11, 1]) == 13, - f"GT_LABELS_TYPECAST[1,11,1]: int({array.data[1,11,1]}) does not equal expected: 13", - ) + elif array_key == raw_typed_key: + assert array.data.dtype == np.float16 + assert int(array.data[1, 11, 1]) == 43 - else: - self.assertTrue(False, "unexpected array type") + elif array_key == labels_typed_key: + assert array.data.dtype == np.int16 + assert int(array.data[1, 11, 1]) == 13 diff --git a/tests/cases/balance_labels.py b/tests/cases/balance_labels.py index ec0d50e2..cae7f7b4 100644 --- a/tests/cases/balance_labels.py +++ b/tests/cases/balance_labels.py @@ -1,73 +1,136 @@ -from .provider_test import ProviderTest -from gunpowder import * import numpy as np +from gunpowder import ( + Array, + ArrayKey, + ArraySpec, + BalanceLabels, + BatchRequest, + MergeProvider, + Roi, + build, +) + +from .helper_sources import ArraySource + + +def test_output(): + affs_key = ArrayKey("AFFS") + affs_mask_key = ArrayKey("AFFS_MASK") + ignore_key = ArrayKey("IGNORE") + loss_scale_key = ArrayKey("LOSS_SCALE") + + array_spec = ArraySpec(roi=Roi((0, 0, 0), (2000, 200, 200)), voxel_size=(20, 2, 2)) + + data_shape = array_spec.roi.shape // array_spec.voxel_size + affs_data = np.random.randint(0, 2, (3,) + data_shape) + affs_mask_data = np.random.randint(0, 2, (3,) + data_shape) + ignore_data = np.random.randint(0, 2, (3,) + data_shape) + + affs_array = Array(affs_data, array_spec.copy()) + affs_mask_array = Array(affs_mask_data, array_spec.copy()) + ignore_array = Array(ignore_data, array_spec.copy()) + + pipeline = ( + ( + ArraySource(affs_key, affs_array), + ArraySource(affs_mask_key, affs_mask_array), + ArraySource(ignore_key, ignore_array), + ) + + MergeProvider() + + BalanceLabels( + labels=affs_key, + scales=loss_scale_key, + mask=[affs_mask_key, ignore_key], + ) + ) -class ExampleSource(BatchProvider): - def setup(self): - for identifier in [ - ArrayKeys.GT_AFFINITIES, - ArrayKeys.GT_AFFINITIES_MASK, - ArrayKeys.GT_IGNORE, - ]: - self.provides( - identifier, - ArraySpec(roi=Roi((0, 0, 0), (2000, 200, 200)), voxel_size=(20, 2, 2)), - ) + with build(pipeline): + # check correct scaling on 10 random samples + for i in range(10): + request = BatchRequest() + request.add(affs_key, (400, 30, 34)) + request.add(affs_mask_key, (400, 30, 34)) + request.add(ignore_key, (400, 30, 34)) + request.add(loss_scale_key, (400, 30, 34)) - def provide(self, request): - batch = Batch() + batch = pipeline.request_batch(request) - roi = request[ArrayKeys.GT_AFFINITIES].roi - shape_vx = roi.shape // self.spec[ArrayKeys.GT_AFFINITIES].voxel_size + assert loss_scale_key in batch.arrays - spec = self.spec[ArrayKeys.GT_AFFINITIES].copy() - spec.roi = roi + affs = batch.arrays[affs_key].data + scale = batch.arrays[loss_scale_key].data + mask = batch.arrays[affs_mask_key].data + ignore = batch.arrays[ignore_key].data - batch.arrays[ArrayKeys.GT_AFFINITIES] = Array( - np.random.randint(0, 2, (3,) + shape_vx), spec - ) - batch.arrays[ArrayKeys.GT_AFFINITIES_MASK] = Array( - np.random.randint(0, 2, (3,) + shape_vx), spec - ) - batch.arrays[ArrayKeys.GT_IGNORE] = Array( - np.random.randint(0, 2, (3,) + shape_vx), spec - ) + # combine mask and ignore + mask *= ignore + + assert (scale[mask == 1] > 0).all() + assert (scale[mask == 0] == 0).all() - return batch + num_masked_out = affs.size - mask.sum() + num_masked_in = affs.size - num_masked_out + num_pos = (affs * mask).sum() + num_neg = affs.size - num_masked_out - num_pos + frac_pos = float(num_pos) / num_masked_in if num_masked_in > 0 else 0 + frac_pos = min(0.95, max(0.05, frac_pos)) + frac_neg = 1.0 - frac_pos -class TestBalanceLabels(ProviderTest): - def test_output(self): - pipeline = ExampleSource() + BalanceLabels( - labels=ArrayKeys.GT_AFFINITIES, - scales=ArrayKeys.LOSS_SCALE, - mask=[ArrayKeys.GT_AFFINITIES_MASK, ArrayKeys.GT_IGNORE], + w_pos = 1.0 / (2.0 * frac_pos) + w_neg = 1.0 / (2.0 * frac_neg) + + assert abs((scale * mask * affs).sum() - w_pos * num_pos) < 1e-3 + assert abs((scale * mask * (1 - affs)).sum() - w_neg * num_neg < 1e-3) + + # check if LOSS_SCALE is omitted if not requested + del request[loss_scale_key] + + batch = pipeline.request_batch(request) + assert loss_scale_key not in batch.arrays + + # same using a slab for balancing + + pipeline = ( + ( + ArraySource(affs_key, affs_array), + ArraySource(affs_mask_key, affs_mask_array), + ArraySource(ignore_key, ignore_array), + ) + + MergeProvider() + + BalanceLabels( + labels=affs_key, + scales=loss_scale_key, + mask=[affs_mask_key, ignore_key], + slab=(1, -1, -1, -1), # every channel individually ) + ) - with build(pipeline): - # check correct scaling on 10 random samples - for i in range(10): - request = BatchRequest() - request.add(ArrayKeys.GT_AFFINITIES, (400, 30, 34)) - request.add(ArrayKeys.GT_AFFINITIES_MASK, (400, 30, 34)) - request.add(ArrayKeys.GT_IGNORE, (400, 30, 34)) - request.add(ArrayKeys.LOSS_SCALE, (400, 30, 34)) + with build(pipeline): + # check correct scaling on 10 random samples + for i in range(10): + request = BatchRequest() + request.add(affs_key, (400, 30, 34)) + request.add(affs_mask_key, (400, 30, 34)) + request.add(ignore_key, (400, 30, 34)) + request.add(loss_scale_key, (400, 30, 34)) - batch = pipeline.request_batch(request) + batch = pipeline.request_batch(request) - self.assertTrue(ArrayKeys.LOSS_SCALE in batch.arrays) + assert loss_scale_key in batch.arrays - affs = batch.arrays[ArrayKeys.GT_AFFINITIES].data - scale = batch.arrays[ArrayKeys.LOSS_SCALE].data - mask = batch.arrays[ArrayKeys.GT_AFFINITIES_MASK].data - ignore = batch.arrays[ArrayKeys.GT_IGNORE].data + for c in range(3): + affs = batch.arrays[affs_key].data[c] + scale = batch.arrays[loss_scale_key].data[c] + mask = batch.arrays[affs_mask_key].data[c] + ignore = batch.arrays[ignore_key].data[c] # combine mask and ignore mask *= ignore - self.assertTrue((scale[mask == 1] > 0).all()) - self.assertTrue((scale[mask == 0] == 0).all()) + assert (scale[mask == 1] > 0).all() + assert (scale[mask == 0] == 0).all() num_masked_out = affs.size - mask.sum() num_masked_in = affs.size - num_masked_out @@ -81,68 +144,5 @@ def test_output(self): w_pos = 1.0 / (2.0 * frac_pos) w_neg = 1.0 / (2.0 * frac_neg) - self.assertAlmostEqual((scale * mask * affs).sum(), w_pos * num_pos, 3) - self.assertAlmostEqual( - (scale * mask * (1 - affs)).sum(), w_neg * num_neg, 3 - ) - - # check if LOSS_SCALE is omitted if not requested - del request[ArrayKeys.LOSS_SCALE] - - batch = pipeline.request_batch(request) - self.assertTrue(ArrayKeys.LOSS_SCALE not in batch.arrays) - - # same using a slab for balancing - - pipeline = ExampleSource() + BalanceLabels( - labels=ArrayKeys.GT_AFFINITIES, - scales=ArrayKeys.LOSS_SCALE, - mask=[ArrayKeys.GT_AFFINITIES_MASK, ArrayKeys.GT_IGNORE], - slab=(1, -1, -1, -1), - ) # every channel individually - - with build(pipeline): - # check correct scaling on 10 random samples - for i in range(10): - request = BatchRequest() - request.add(ArrayKeys.GT_AFFINITIES, (400, 30, 34)) - request.add(ArrayKeys.GT_AFFINITIES_MASK, (400, 30, 34)) - request.add(ArrayKeys.GT_IGNORE, (400, 30, 34)) - request.add(ArrayKeys.LOSS_SCALE, (400, 30, 34)) - - batch = pipeline.request_batch(request) - - self.assertTrue(ArrayKeys.LOSS_SCALE in batch.arrays) - - for c in range(3): - affs = batch.arrays[ArrayKeys.GT_AFFINITIES].data[c] - scale = batch.arrays[ArrayKeys.LOSS_SCALE].data[c] - mask = batch.arrays[ArrayKeys.GT_AFFINITIES_MASK].data[c] - ignore = batch.arrays[ArrayKeys.GT_IGNORE].data[c] - - # combine mask and ignore - mask *= ignore - - self.assertTrue((scale[mask == 1] > 0).all()) - self.assertTrue((scale[mask == 0] == 0).all()) - - num_masked_out = affs.size - mask.sum() - num_masked_in = affs.size - num_masked_out - num_pos = (affs * mask).sum() - num_neg = affs.size - num_masked_out - num_pos - - frac_pos = ( - float(num_pos) / num_masked_in if num_masked_in > 0 else 0 - ) - frac_pos = min(0.95, max(0.05, frac_pos)) - frac_neg = 1.0 - frac_pos - - w_pos = 1.0 / (2.0 * frac_pos) - w_neg = 1.0 / (2.0 * frac_neg) - - self.assertAlmostEqual( - (scale * mask * affs).sum(), w_pos * num_pos, 3 - ) - self.assertAlmostEqual( - (scale * mask * (1 - affs)).sum(), w_neg * num_neg, 3 - ) + assert abs((scale * mask * affs).sum() - w_pos * num_pos) < 1e-3 + assert abs((scale * mask * (1 - affs)).sum() - w_neg * num_neg) < 1e-3 diff --git a/tests/cases/batch.py b/tests/cases/batch.py index 5b3c31f9..b8f49419 100644 --- a/tests/cases/batch.py +++ b/tests/cases/batch.py @@ -1,4 +1,5 @@ import logging + import numpy as np from gunpowder import ( diff --git a/tests/cases/crop.py b/tests/cases/crop.py index 47526d13..fd9e6ed5 100644 --- a/tests/cases/crop.py +++ b/tests/cases/crop.py @@ -1,60 +1,63 @@ -from .provider_test import ProviderTest +import logging + +import numpy as np + from gunpowder import ( - BatchProvider, - ArrayKeys, + Array, + ArrayKey, ArraySpec, - Roi, + Crop, + Graph, GraphKey, - GraphKeys, GraphSpec, - Crop, + MergeProvider, + Roi, build, ) -import logging -logger = logging.getLogger(__name__) +from .helper_sources import ArraySource, GraphSource +logger = logging.getLogger(__name__) -class ExampleSourceCrop(BatchProvider): - def setup(self): - self.provides( - ArrayKeys.RAW, - ArraySpec(roi=Roi((200, 20, 20), (1800, 180, 180)), voxel_size=(20, 2, 2)), - ) - self.provides( - GraphKeys.PRESYN, GraphSpec(roi=Roi((200, 20, 20), (1800, 180, 180))) - ) +def test_output(): + raw_key = ArrayKey("RAW") + pre_key = GraphKey("PRESYN") - def provide(self, request): - pass + raw_spec = ArraySpec( + roi=Roi((200, 20, 20), (1800, 180, 180)), voxel_size=(20, 2, 2) + ) + pre_spec = GraphSpec(roi=Roi((200, 20, 20), (1800, 180, 180))) + raw_data = np.zeros(raw_spec.roi.shape / raw_spec.voxel_size) -class TestCrop(ProviderTest): - def test_output(self): - cropped_roi_raw = Roi((400, 40, 40), (1000, 100, 100)) - cropped_roi_presyn = Roi((800, 80, 80), (800, 80, 80)) + raw_array = Array(raw_data, raw_spec) + pre_graph = Graph([], [], pre_spec) - GraphKey("PRESYN") + cropped_roi_raw = Roi((400, 40, 40), (1000, 100, 100)) + cropped_roi_presyn = Roi((800, 80, 80), (800, 80, 80)) - pipeline = ( - ExampleSourceCrop() - + Crop(ArrayKeys.RAW, cropped_roi_raw) - + Crop(GraphKeys.PRESYN, cropped_roi_presyn) - ) + pipeline = ( + (ArraySource(raw_key, raw_array), GraphSource(pre_key, pre_graph)) + + MergeProvider() + + Crop(raw_key, cropped_roi_raw) + + Crop(pre_key, cropped_roi_presyn) + ) - with build(pipeline): - self.assertTrue(pipeline.spec[ArrayKeys.RAW].roi == cropped_roi_raw) - self.assertTrue(pipeline.spec[GraphKeys.PRESYN].roi == cropped_roi_presyn) + with build(pipeline): + assert pipeline.spec[raw_key].roi == cropped_roi_raw + assert pipeline.spec[pre_key].roi == cropped_roi_presyn - pipeline = ExampleSourceCrop() + Crop( - ArrayKeys.RAW, + pipeline = ( + (ArraySource(raw_key, raw_array), GraphSource(pre_key, pre_graph)) + + MergeProvider() + + Crop( + raw_key, fraction_negative=(0.25, 0, 0), fraction_positive=(0.25, 0, 0), ) - expected_roi_raw = Roi((650, 20, 20), (900, 180, 180)) + ) + expected_roi_raw = Roi((650, 20, 20), (900, 180, 180)) - with build(pipeline): - logger.info(pipeline.spec[ArrayKeys.RAW].roi) - logger.info(expected_roi_raw) - self.assertTrue(pipeline.spec[ArrayKeys.RAW].roi == expected_roi_raw) + with build(pipeline): + assert pipeline.spec[raw_key].roi == expected_roi_raw diff --git a/tests/cases/deform_augment.py b/tests/cases/deform_augment.py index f722b0bb..8a527849 100644 --- a/tests/cases/deform_augment.py +++ b/tests/cases/deform_augment.py @@ -1,23 +1,23 @@ +import numpy as np +import pytest +from scipy.ndimage import center_of_mass + from gunpowder import ( - BatchProvider, - GraphSpec, - Roi, - Coordinate, - ArraySpec, - Batch, Array, ArrayKey, - GraphKey, + ArraySpec, + Batch, + BatchProvider, BatchRequest, + Coordinate, DeformAugment, + GraphKey, + GraphSpec, + Roi, build, ) from gunpowder.graph import Graph, Node -from scipy.ndimage import center_of_mass -import pytest -import numpy as np - class GraphTestSource3D(BatchProvider): def __init__(self, graph_key: GraphKey, array_key: ArrayKey, array_key2: ArrayKey): diff --git a/tests/cases/downsample.py b/tests/cases/downsample.py index eec33564..ec8be60c 100644 --- a/tests/cases/downsample.py +++ b/tests/cases/downsample.py @@ -1,8 +1,18 @@ -from .helper_sources import ArraySource - -from gunpowder import * import numpy as np +from gunpowder import ( + Array, + ArrayKey, + ArraySpec, + BatchRequest, + DownSample, + MergeProvider, + Roi, + build, +) + +from .helper_sources import ArraySource + def test_output(): raw = ArrayKey("RAW") diff --git a/tests/cases/dvid_source.py b/tests/cases/dvid_source.py index ac206909..8f5f9d25 100644 --- a/tests/cases/dvid_source.py +++ b/tests/cases/dvid_source.py @@ -1,9 +1,18 @@ -from .provider_test import ProviderTest -from unittest import skipIf -from gunpowder import * -from gunpowder.ext import dvision, NoSuchModule -import socket import logging +import socket + +import pytest + +from gunpowder import ( + ArrayKey, + ArraySpec, + BatchRequest, + DvidSource, + Roi, + Snapshot, + build, +) +from gunpowder.ext import NoSuchModule, dvision logger = logging.getLogger(__name__) @@ -21,45 +30,46 @@ def is_dvid_unavailable(server): return True -class TestDvidSource(ProviderTest): - @skipIf(is_dvid_unavailable(DVID_SERVER), "DVID server not available") - def test_output_3d(self): - # create array keys - raw = ArrayKey("RAW") - seg = ArrayKey("SEG") - mask = ArrayKey("MASK") +@pytest.mark.skipif( + is_dvid_unavailable(DVID_SERVER), reason="DVID server not available" +) +def test_output_3d(tmpdir): + # create array keys + raw = ArrayKey("RAW") + seg = ArrayKey("SEG") + mask = ArrayKey("MASK") - pipeline = DvidSource( - DVID_SERVER, - 32768, - "2ad1d8f0f172425c9f87b60fd97331e6", - datasets={raw: "grayscale", seg: "groundtruth"}, - masks={mask: "seven_column"}, - ) + Snapshot( - { - raw: "/volumes/raw", - seg: "/volumes/labels/neuron_ids", - mask: "/volumes/labels/mask", - }, - output_dir=self.path_to(), - output_filename="dvid_source_test{id}-{iteration}.hdf", - ) + pipeline = DvidSource( + DVID_SERVER, + 32768, + "2ad1d8f0f172425c9f87b60fd97331e6", + datasets={raw: "grayscale", seg: "groundtruth"}, + masks={mask: "seven_column"}, + ) + Snapshot( + { + raw: "/volumes/raw", + seg: "/volumes/labels/neuron_ids", + mask: "/volumes/labels/mask", + }, + output_dir=tmpdir, + output_filename="dvid_source_test{id}-{iteration}.hdf", + ) - with build(pipeline): - batch = pipeline.request_batch( - BatchRequest( - { - raw: ArraySpec(roi=Roi((33000, 15000, 20000), (32000, 8, 80))), - seg: ArraySpec(roi=Roi((33000, 15000, 20000), (32000, 8, 80))), - mask: ArraySpec(roi=Roi((33000, 15000, 20000), (32000, 8, 80))), - } - ) + with build(pipeline): + batch = pipeline.request_batch( + BatchRequest( + { + raw: ArraySpec(roi=Roi((33000, 15000, 20000), (32000, 8, 80))), + seg: ArraySpec(roi=Roi((33000, 15000, 20000), (32000, 8, 80))), + mask: ArraySpec(roi=Roi((33000, 15000, 20000), (32000, 8, 80))), + } ) + ) - self.assertTrue(batch.arrays[raw].spec.interpolatable) - self.assertFalse(batch.arrays[seg].spec.interpolatable) - self.assertFalse(batch.arrays[mask].spec.interpolatable) + assert batch.arrays[raw].spec.interpolatable + assert not batch.arrays[seg].spec.interpolatable + assert not batch.arrays[mask].spec.interpolatable - self.assertEqual(batch.arrays[raw].spec.voxel_size, (8, 8, 8)) - self.assertEqual(batch.arrays[seg].spec.voxel_size, (8, 8, 8)) - self.assertEqual(batch.arrays[mask].spec.voxel_size, (8, 8, 8)) + assert batch.arrays[raw].spec.voxel_size == (8, 8, 8) + assert batch.arrays[seg].spec.voxel_size == (8, 8, 8) + assert batch.arrays[mask].spec.voxel_size == (8, 8, 8) diff --git a/tests/cases/elastic_augment.py b/tests/cases/elastic_augment.py index 694d351c..502b5b14 100644 --- a/tests/cases/elastic_augment.py +++ b/tests/cases/elastic_augment.py @@ -1,29 +1,32 @@ +import math + +import numpy as np + from gunpowder import ( - BatchProvider, - GraphSpec, - Roi, - Coordinate, - ArrayKeys, - ArraySpec, - Batch, Array, ArrayKey, - GraphKey, + ArraySpec, + Batch, + BatchProvider, BatchRequest, + Coordinate, + ElasticAugment, + GraphKey, + GraphSpec, RasterizationSettings, RasterizeGraph, + Roi, Snapshot, - ElasticAugment, build, ) -from gunpowder.graph import GraphKeys, Graph, Node -from .provider_test import ProviderTest - -import numpy as np -import math +from gunpowder.graph import Graph, Node class GraphTestSource3D(BatchProvider): + def __init__(self, points_key, labels_key): + self.points_key = points_key + self.labels_key = labels_key + def setup(self): self.nodes = [ Node(id=0, location=np.array([0, 0, 0])), @@ -35,12 +38,12 @@ def setup(self): ] self.provides( - GraphKeys.TEST_GRAPH, + self.points_key, GraphSpec(roi=Roi((-100, -100, -100), (200, 200, 200))), ) self.provides( - ArrayKeys.TEST_LABELS, + self.labels_key, ArraySpec( roi=Roi((-100, -100, -100), (200, 200, 200)), voxel_size=Coordinate((4, 1, 1)), @@ -50,19 +53,19 @@ def setup(self): def node_to_voxel(self, array_roi, location): # location is in world units, get it into voxels - location = location / self.spec[ArrayKeys.TEST_LABELS].voxel_size + location = location / self.spec[self.labels_key].voxel_size # shift location relative to beginning of array roi - location -= array_roi.begin / self.spec[ArrayKeys.TEST_LABELS].voxel_size + location -= array_roi.begin / self.spec[self.labels_key].voxel_size return tuple(slice(int(l - 2), int(l + 3)) for l in location) def provide(self, request): batch = Batch() - roi_graph = request[GraphKeys.TEST_GRAPH].roi - roi_array = request[ArrayKeys.TEST_LABELS].roi - roi_voxel = roi_array // self.spec[ArrayKeys.TEST_LABELS].voxel_size + roi_graph = request[self.points_key].roi + roi_array = request[self.labels_key].roi + roi_voxel = roi_array // self.spec[self.labels_key].voxel_size data = np.zeros(roi_voxel.shape, dtype=np.uint32) data[:, ::2] = 100 @@ -71,79 +74,78 @@ def provide(self, request): loc = self.node_to_voxel(roi_array, node.location) data[loc] = node.id - spec = self.spec[ArrayKeys.TEST_LABELS].copy() + spec = self.spec[self.labels_key].copy() spec.roi = roi_array - batch.arrays[ArrayKeys.TEST_LABELS] = Array(data, spec=spec) + batch.arrays[self.labels_key] = Array(data, spec=spec) nodes = [] for node in self.nodes: if roi_graph.contains(node.location): nodes.append(node) - batch.graphs[GraphKeys.TEST_GRAPH] = Graph( + batch.graphs[self.points_key] = Graph( nodes=nodes, edges=[], spec=GraphSpec(roi=roi_graph) ) return batch -class TestElasticAugment(ProviderTest): - def test_3d_basics(self): - test_labels = ArrayKey("TEST_LABELS") - test_graph = GraphKey("TEST_GRAPH") - test_raster = ArrayKey("TEST_RASTER") - - pipeline = ( - GraphTestSource3D() - + ElasticAugment( - [10, 10, 10], - [0.1, 0.1, 0.1], - # [0, 0, 0], # no jitter - [0, 2.0 * math.pi], - ) # rotate randomly - + - # [math.pi/4, math.pi/4]) + # rotate by 45 deg - # [0, 0]) + # no rotation - RasterizeGraph( - test_graph, - test_raster, - settings=RasterizationSettings(radius=2, mode="peak"), - ) - + Snapshot( - {test_labels: "volumes/labels", test_raster: "volumes/raster"}, - dataset_dtypes={test_raster: np.float32}, - output_dir=self.path_to(), - output_filename="elastic_augment_test{id}-{iteration}.hdf", - ) +def test_3d_basics(tmpdir): + test_labels = ArrayKey("TEST_LABELS") + test_graph = GraphKey("TEST_GRAPH") + test_raster = ArrayKey("TEST_RASTER") + + pipeline = ( + GraphTestSource3D(test_graph, test_labels) + + ElasticAugment( + [10, 10, 10], + [0.1, 0.1, 0.1], + # [0, 0, 0], # no jitter + [0, 2.0 * math.pi], + ) # rotate randomly + + + # [math.pi/4, math.pi/4]) + # rotate by 45 deg + # [0, 0]) + # no rotation + RasterizeGraph( + test_graph, + test_raster, + settings=RasterizationSettings(radius=2, mode="peak"), ) - - for _ in range(5): - with build(pipeline): - request_roi = Roi((-20, -20, -20), (40, 40, 40)) - - request = BatchRequest() - request[test_labels] = ArraySpec(roi=request_roi) - request[test_graph] = GraphSpec(roi=request_roi) - request[test_raster] = ArraySpec(roi=request_roi) - - batch = pipeline.request_batch(request) - labels = batch[test_labels] - graph = batch[test_graph] - - # the node at (0, 0, 0) should not have moved - # The node at (0,0,0) seems to have moved - # self.assertIn( - # Node(id=0, location=np.array([0, 0, 0])), list(graph.nodes) - # ) - self.assertIn(0, [v.id for v in graph.nodes]) - - labels_data_roi = ( - labels.spec.roi - labels.spec.roi.begin - ) / labels.spec.voxel_size - - # graph should have moved together with the voxels - for node in graph.nodes: - loc = node.location - labels.spec.roi.begin - loc = loc / labels.spec.voxel_size - loc = Coordinate(int(round(x)) for x in loc) - if labels_data_roi.contains(loc): - self.assertEqual(labels.data[loc], node.id) + + Snapshot( + {test_labels: "volumes/labels", test_raster: "volumes/raster"}, + dataset_dtypes={test_raster: np.float32}, + output_dir=tmpdir, + output_filename="elastic_augment_test{id}-{iteration}.hdf", + ) + ) + + for _ in range(5): + with build(pipeline): + request_roi = Roi((-20, -20, -20), (40, 40, 40)) + + request = BatchRequest() + request[test_labels] = ArraySpec(roi=request_roi) + request[test_graph] = GraphSpec(roi=request_roi) + request[test_raster] = ArraySpec(roi=request_roi) + + batch = pipeline.request_batch(request) + labels = batch[test_labels] + graph = batch[test_graph] + + # the node at (0, 0, 0) should not have moved + # The node at (0,0,0) seems to have moved + # self.assertIn( + # Node(id=0, location=np.array([0, 0, 0])), list(graph.nodes) + # ) + assert 0 in [v.id for v in graph.nodes] + + labels_data_roi = ( + labels.spec.roi - labels.spec.roi.begin + ) / labels.spec.voxel_size + + # graph should have moved together with the voxels + for node in graph.nodes: + loc = node.location - labels.spec.roi.begin + loc = loc / labels.spec.voxel_size + loc = Coordinate(int(round(x)) for x in loc) + if labels_data_roi.contains(loc): + assert labels.data[loc] == node.id diff --git a/tests/cases/elastic_augment_points.py b/tests/cases/elastic_augment_points.py index 0ec8a13b..b632117a 100644 --- a/tests/cases/elastic_augment_points.py +++ b/tests/cases/elastic_augment_points.py @@ -1,32 +1,34 @@ +import math +import time + +import numpy as np + from gunpowder import ( - BatchProvider, + Array, + ArrayKey, + ArraySpec, Batch, + BatchProvider, BatchRequest, - GraphSpec, - GraphKeys, - GraphKey, - Graph, - Node, - ArraySpec, - ArrayKeys, - ArrayKey, - Array, - Roi, Coordinate, ElasticAugment, - RasterizeGraph, + Graph, + GraphKey, + GraphSpec, + Node, RasterizationSettings, + RasterizeGraph, + Roi, Snapshot, build, ) -from .provider_test import ProviderTest - -import numpy as np -import math -import time class PointTestSource3D(BatchProvider): + def __init__(self, points_key, labels_key): + self.points_key = points_key + self.labels_key = labels_key + def setup(self): self.points = [ Node(0, np.array([0, 0, 0])), @@ -38,12 +40,12 @@ def setup(self): ] self.provides( - GraphKeys.TEST_POINTS, + self.points_key, GraphSpec(roi=Roi((-100, -100, -100), (200, 200, 200))), ) self.provides( - ArrayKeys.TEST_LABELS, + self.labels_key, ArraySpec( roi=Roi((-100, -100, -100), (200, 200, 200)), voxel_size=Coordinate((4, 1, 1)), @@ -53,19 +55,19 @@ def setup(self): def point_to_voxel(self, array_roi, location): # location is in world units, get it into voxels - location = location / self.spec[ArrayKeys.TEST_LABELS].voxel_size + location = location / self.spec[self.labels_key].voxel_size # shift location relative to beginning of array roi - location -= array_roi.begin / self.spec[ArrayKeys.TEST_LABELS].voxel_size + location -= array_roi.begin / self.spec[self.labels_key].voxel_size return tuple(slice(int(l - 2), int(l + 3)) for l in location) def provide(self, request): batch = Batch() - roi_points = request[GraphKeys.TEST_POINTS].roi - roi_array = request[ArrayKeys.TEST_LABELS].roi - roi_voxel = roi_array // self.spec[ArrayKeys.TEST_LABELS].voxel_size + roi_points = request[self.points_key].roi + roi_array = request[self.labels_key].roi + roi_voxel = roi_array // self.spec[self.labels_key].voxel_size data = np.zeros(roi_voxel.shape, dtype=np.uint32) data[:, ::2] = 100 @@ -74,22 +76,24 @@ def provide(self, request): loc = self.point_to_voxel(roi_array, node.location) data[loc] = node.id - spec = self.spec[ArrayKeys.TEST_LABELS].copy() + spec = self.spec[self.labels_key].copy() spec.roi = roi_array - batch.arrays[ArrayKeys.TEST_LABELS] = Array(data, spec=spec) + batch.arrays[self.labels_key] = Array(data, spec=spec) points = [] for node in self.points: if roi_points.contains(node.location): points.append(node) - batch.graphs[GraphKeys.TEST_POINTS] = Graph( - points, [], GraphSpec(roi=roi_points) - ) + batch.graphs[self.points_key] = Graph(points, [], GraphSpec(roi=roi_points)) return batch class DensePointTestSource3D(BatchProvider): + def __init__(self, points_key, labels_key): + self.points_key = points_key + self.labels_key = labels_key + def setup(self): self.points = [ Node(i, np.array([(i // 100) % 10 * 4, (i // 10) % 10 * 4, i % 10 * 4])) @@ -97,12 +101,12 @@ def setup(self): ] self.provides( - GraphKeys.TEST_POINTS, + self.points_key, GraphSpec(roi=Roi((-40, -40, -40), (120, 120, 120))), ) self.provides( - ArrayKeys.TEST_LABELS, + self.labels_key, ArraySpec( roi=Roi((-40, -40, -40), (120, 120, 120)), voxel_size=Coordinate((4, 1, 1)), @@ -112,19 +116,19 @@ def setup(self): def point_to_voxel(self, array_roi, location): # location is in world units, get it into voxels - location = location / self.spec[ArrayKeys.TEST_LABELS].voxel_size + location = location / self.spec[self.labels_key].voxel_size # shift location relative to beginning of array roi - location -= array_roi.begin / self.spec[ArrayKeys.TEST_LABELS].voxel_size + location -= array_roi.begin / self.spec[self.labels_key].voxel_size return tuple(slice(int(l - 2), int(l + 3)) for l in location) def provide(self, request): batch = Batch() - roi_points = request[GraphKeys.TEST_POINTS].roi - roi_array = request[ArrayKeys.TEST_LABELS].roi - roi_voxel = roi_array // self.spec[ArrayKeys.TEST_LABELS].voxel_size + roi_points = request[self.points_key].roi + roi_array = request[self.labels_key].roi + roi_voxel = roi_array // self.spec[self.labels_key].voxel_size data = np.zeros(roi_voxel.shape, dtype=np.uint32) data[:, ::2] = 100 @@ -133,301 +137,289 @@ def provide(self, request): loc = self.point_to_voxel(roi_array, node.location) data[loc] = node.id - spec = self.spec[ArrayKeys.TEST_LABELS].copy() + spec = self.spec[self.labels_key].copy() spec.roi = roi_array - batch.arrays[ArrayKeys.TEST_LABELS] = Array(data, spec=spec) + batch.arrays[self.labels_key] = Array(data, spec=spec) points = [] for point in self.points: if roi_points.contains(point.location): points.append(point) - batch[GraphKeys.TEST_POINTS] = Graph(points, [], GraphSpec(roi=roi_points)) + batch[self.points_key] = Graph(points, [], GraphSpec(roi=roi_points)) return batch -class TestElasticAugment(ProviderTest): - def test_3d_basics(self): - test_labels = ArrayKey("TEST_LABELS") - test_points = GraphKey("TEST_POINTS") - test_raster = ArrayKey("TEST_RASTER") +def test_3d_basics(tmpdir): + test_labels = ArrayKey("TEST_LABELS") + test_points = GraphKey("TEST_POINTS") + test_raster = ArrayKey("TEST_RASTER") - pipeline = ( - PointTestSource3D() - + ElasticAugment( - [10, 10, 10], - [0.1, 0.1, 0.1], - # [0, 0, 0], # no jitter - [0, 2.0 * math.pi], - ) - + RasterizeGraph( - test_points, - test_raster, - settings=RasterizationSettings(radius=2, mode="peak"), - ) - + Snapshot( - {test_labels: "volumes/labels", test_raster: "volumes/raster"}, - dataset_dtypes={test_raster: np.float32}, - output_dir=self.path_to(), - output_filename="elastic_augment_test{id}-{iteration}.hdf", - ) + pipeline = ( + PointTestSource3D(test_points, test_labels) + + ElasticAugment( + [10, 10, 10], + [0.1, 0.1, 0.1], + # [0, 0, 0], # no jitter + [0, 2.0 * math.pi], ) - - for _ in range(5): - with build(pipeline): - request_roi = Roi((-20, -20, -20), (40, 40, 40)) - - request = BatchRequest() - request[test_labels] = ArraySpec(roi=request_roi) - request[test_points] = GraphSpec(roi=request_roi) - request[test_raster] = ArraySpec(roi=request_roi) - - batch = pipeline.request_batch(request) - labels = batch[test_labels] - points = batch[test_points] - - # the point at (0, 0, 0) should not have moved - self.assertTrue(points.contains(0)) - - labels_data_roi = ( - labels.spec.roi - labels.spec.roi.begin - ) / labels.spec.voxel_size - - # points should have moved together with the voxels - for point in points.nodes: - loc = point.location - labels.spec.roi.begin - loc = loc / labels.spec.voxel_size - loc = Coordinate(int(round(x)) for x in loc) - if labels_data_roi.contains(loc): - self.assertEqual(labels.data[loc], point.id) - - def test_random_seed(self): - test_labels = ArrayKey("TEST_LABELS") - test_points = GraphKey("TEST_POINTS") - test_raster = ArrayKey("TEST_RASTER") - - pipeline = ( - PointTestSource3D() - + ElasticAugment( - [10, 10, 10], - [0.1, 0.1, 0.1], - # [0, 0, 0], # no jitter - [0, 2.0 * math.pi], - ) # rotate randomly - + - # [math.pi/4, math.pi/4]) + # rotate by 45 deg - # [0, 0]) + # no rotation - RasterizeGraph( - test_points, - test_raster, - settings=RasterizationSettings(radius=2, mode="peak"), - ) - + Snapshot( - {test_labels: "volumes/labels", test_raster: "volumes/raster"}, - dataset_dtypes={test_raster: np.float32}, - output_dir=self.path_to(), - output_filename="elastic_augment_test{id}-{iteration}.hdf", - ) + + RasterizeGraph( + test_points, + test_raster, + settings=RasterizationSettings(radius=2, mode="peak"), ) - - batch_points = [] - for _ in range(5): - with build(pipeline): - request_roi = Roi((-20, -20, -20), (40, 40, 40)) - - request = BatchRequest(random_seed=10) - request[test_labels] = ArraySpec(roi=request_roi) - request[test_points] = GraphSpec(roi=request_roi) - request[test_raster] = ArraySpec(roi=request_roi) - batch = pipeline.request_batch(request) - labels = batch[test_labels] - points = batch[test_points] - batch_points.append( - tuple((node.id, tuple(node.location)) for node in points.nodes) - ) - - # the point at (0, 0, 0) should not have moved - data = {node.id: node for node in points.nodes} - self.assertTrue(0 in data) - - labels_data_roi = ( - labels.spec.roi - labels.spec.roi.begin - ) / labels.spec.voxel_size - - # points should have moved together with the voxels - for node in points.nodes: - loc = node.location - labels.spec.roi.begin - loc = loc / labels.spec.voxel_size - loc = Coordinate(int(round(x)) for x in loc) - if labels_data_roi.contains(loc): - self.assertEqual(labels.data[loc], node.id) - - for point_data in zip(*batch_points): - self.assertEqual(len(set(point_data)), 1) - - def test_fast_transform(self): - test_labels = ArrayKey("TEST_LABELS") - test_points = GraphKey("TEST_POINTS") - test_raster = ArrayKey("TEST_RASTER") - fast_pipeline = ( - DensePointTestSource3D() - + ElasticAugment( - [10, 10, 10], - [0.1, 0.1, 0.1], - [0, 2.0 * math.pi], - use_fast_points_transform=True, - ) - + RasterizeGraph( - test_points, - test_raster, - settings=RasterizationSettings(radius=2, mode="peak"), - ) + + Snapshot( + {test_labels: "volumes/labels", test_raster: "volumes/raster"}, + dataset_dtypes={test_raster: np.float32}, + output_dir=tmpdir, + output_filename="elastic_augment_test{id}-{iteration}.hdf", ) - - reference_pipeline = ( - DensePointTestSource3D() - + ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi]) - + RasterizeGraph( - test_points, - test_raster, - settings=RasterizationSettings(radius=2, mode="peak"), - ) + ) + + for _ in range(5): + with build(pipeline): + request_roi = Roi((-20, -20, -20), (40, 40, 40)) + + request = BatchRequest() + request[test_labels] = ArraySpec(roi=request_roi) + request[test_points] = GraphSpec(roi=request_roi) + request[test_raster] = ArraySpec(roi=request_roi) + + batch = pipeline.request_batch(request) + labels = batch[test_labels] + points = batch[test_points] + + # the point at (0, 0, 0) should not have moved + assert points.contains(0) + + labels_data_roi = ( + labels.spec.roi - labels.spec.roi.begin + ) / labels.spec.voxel_size + + # points should have moved together with the voxels + for point in points.nodes: + loc = point.location - labels.spec.roi.begin + loc = loc / labels.spec.voxel_size + loc = Coordinate(int(round(x)) for x in loc) + if labels_data_roi.contains(loc): + assert labels.data[loc] == point.id + + +def test_random_seed(tmpdir): + test_labels = ArrayKey("TEST_LABELS") + test_points = GraphKey("TEST_POINTS") + test_raster = ArrayKey("TEST_RASTER") + + pipeline = ( + PointTestSource3D(test_points, test_labels) + + ElasticAugment( + [10, 10, 10], + [0.1, 0.1, 0.1], + # [0, 0, 0], # no jitter + [0, 2.0 * math.pi], + ) # rotate randomly + + + # [math.pi/4, math.pi/4]) + # rotate by 45 deg + # [0, 0]) + # no rotation + RasterizeGraph( + test_points, + test_raster, + settings=RasterizationSettings(radius=2, mode="peak"), ) - - timings = [] - for i in range(5): - points_fast = {} - points_reference = {} - # seed chosen specifically to make this test fail - seed = i + 15 - with build(fast_pipeline): - request_roi = Roi((0, 0, 0), (40, 40, 40)) - - request = BatchRequest(random_seed=seed) - request[test_labels] = ArraySpec(roi=request_roi) - request[test_points] = GraphSpec(roi=request_roi) - request[test_raster] = ArraySpec(roi=request_roi) - - t1_fast = time.time() - batch = fast_pipeline.request_batch(request) - t2_fast = time.time() - points_fast = {node.id: node for node in batch[test_points].nodes} - - with build(reference_pipeline): - request_roi = Roi((0, 0, 0), (40, 40, 40)) - - request = BatchRequest(random_seed=seed) - request[test_labels] = ArraySpec(roi=request_roi) - request[test_points] = GraphSpec(roi=request_roi) - request[test_raster] = ArraySpec(roi=request_roi) - - t1_ref = time.time() - batch = reference_pipeline.request_batch(request) - t2_ref = time.time() - points_reference = {node.id: node for node in batch[test_points].nodes} - - timings.append((t2_fast - t1_fast, t2_ref - t1_ref)) - diffs = [] - missing = 0 - for point_id, point in points_reference.items(): - if point_id not in points_fast: - missing += 1 - continue - diff = point.location - points_fast[point_id].location - diffs.append(tuple(diff)) - self.assertAlmostEqual( - np.linalg.norm(diff), - 0, - delta=1, - msg="fast transform returned location {} but expected {} for point {}".format( - point.location, points_fast[point_id].location, point_id - ), - ) - - t_fast, t_ref = [np.mean(x) for x in zip(*timings)] - self.assertLess(t_fast, t_ref) - self.assertEqual(missing, 0) - - def test_fast_transform_no_recompute(self): - test_labels = ArrayKey("TEST_LABELS") - test_points = GraphKey("TEST_POINTS") - test_raster = ArrayKey("TEST_RASTER") - fast_pipeline = ( - DensePointTestSource3D() - + ElasticAugment( - [10, 10, 10], - [0.1, 0.1, 0.1], - [0, 2.0 * math.pi], - use_fast_points_transform=True, - recompute_missing_points=False, - ) - + RasterizeGraph( - test_points, - test_raster, - settings=RasterizationSettings(radius=2, mode="peak"), - ) + + Snapshot( + {test_labels: "volumes/labels", test_raster: "volumes/raster"}, + dataset_dtypes={test_raster: np.float32}, + output_dir=tmpdir, + output_filename="elastic_augment_test{id}-{iteration}.hdf", ) - - reference_pipeline = ( - DensePointTestSource3D() - + ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi]) - + RasterizeGraph( - test_points, - test_raster, - settings=RasterizationSettings(radius=2, mode="peak"), + ) + + batch_points = [] + for _ in range(5): + with build(pipeline): + request_roi = Roi((-20, -20, -20), (40, 40, 40)) + + request = BatchRequest(random_seed=10) + request[test_labels] = ArraySpec(roi=request_roi) + request[test_points] = GraphSpec(roi=request_roi) + request[test_raster] = ArraySpec(roi=request_roi) + batch = pipeline.request_batch(request) + labels = batch[test_labels] + points = batch[test_points] + batch_points.append( + tuple((node.id, tuple(node.location)) for node in points.nodes) ) - ) - timings = [] - for i in range(5): - points_fast = {} - points_reference = {} - # seed chosen specifically to make this test fail - seed = i + 15 - with build(fast_pipeline): - request_roi = Roi((0, 0, 0), (40, 40, 40)) - - request = BatchRequest(random_seed=seed) - request[test_labels] = ArraySpec(roi=request_roi) - request[test_points] = GraphSpec(roi=request_roi) - request[test_raster] = ArraySpec(roi=request_roi) - - t1_fast = time.time() - batch = fast_pipeline.request_batch(request) - t2_fast = time.time() - points_fast = {node.id: node for node in batch[test_points].nodes} - - with build(reference_pipeline): - request_roi = Roi((0, 0, 0), (40, 40, 40)) - - request = BatchRequest(random_seed=seed) - request[test_labels] = ArraySpec(roi=request_roi) - request[test_points] = GraphSpec(roi=request_roi) - request[test_raster] = ArraySpec(roi=request_roi) - - t1_ref = time.time() - batch = reference_pipeline.request_batch(request) - t2_ref = time.time() - points_reference = {node.id: node for node in batch[test_points].nodes} - - timings.append((t2_fast - t1_fast, t2_ref - t1_ref)) - diffs = [] - missing = 0 - for point_id, point in points_reference.items(): - if point_id not in points_fast: - missing += 1 - continue - diff = point.location - points_fast[point_id].location - diffs.append(tuple(diff)) - self.assertAlmostEqual( - np.linalg.norm(diff), - 0, - delta=1, - msg="fast transform returned location {} but expected {} for point {}".format( - point.location, points_fast[point_id].location, point_id - ), - ) - - t_fast, t_ref = [np.mean(x) for x in zip(*timings)] - self.assertLess(t_fast, t_ref) - self.assertGreater(missing, 0) + # the point at (0, 0, 0) should not have moved + data = {node.id: node for node in points.nodes} + assert 0 in data + + labels_data_roi = ( + labels.spec.roi - labels.spec.roi.begin + ) / labels.spec.voxel_size + + # points should have moved together with the voxels + for node in points.nodes: + loc = node.location - labels.spec.roi.begin + loc = loc / labels.spec.voxel_size + loc = Coordinate(int(round(x)) for x in loc) + if labels_data_roi.contains(loc): + assert labels.data[loc] == node.id + + for point_data in zip(*batch_points): + assert len(set(point_data)) == 1 + + +def test_fast_transform(tmpdir): + test_labels = ArrayKey("TEST_LABELS") + test_points = GraphKey("TEST_POINTS") + test_raster = ArrayKey("TEST_RASTER") + fast_pipeline = ( + DensePointTestSource3D(test_points, test_labels) + + ElasticAugment( + [10, 10, 10], + [0.1, 0.1, 0.1], + [0, 2.0 * math.pi], + use_fast_points_transform=True, + ) + + RasterizeGraph( + test_points, + test_raster, + settings=RasterizationSettings(radius=2, mode="peak"), + ) + ) + + reference_pipeline = ( + DensePointTestSource3D(test_points, test_labels) + + ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi]) + + RasterizeGraph( + test_points, + test_raster, + settings=RasterizationSettings(radius=2, mode="peak"), + ) + ) + + timings = [] + for i in range(5): + points_fast = {} + points_reference = {} + # seed chosen specifically to make this test fail + seed = i + 15 + with build(fast_pipeline): + request_roi = Roi((0, 0, 0), (40, 40, 40)) + + request = BatchRequest(random_seed=seed) + request[test_labels] = ArraySpec(roi=request_roi) + request[test_points] = GraphSpec(roi=request_roi) + request[test_raster] = ArraySpec(roi=request_roi) + + t1_fast = time.time() + batch = fast_pipeline.request_batch(request) + t2_fast = time.time() + points_fast = {node.id: node for node in batch[test_points].nodes} + + with build(reference_pipeline): + request_roi = Roi((0, 0, 0), (40, 40, 40)) + + request = BatchRequest(random_seed=seed) + request[test_labels] = ArraySpec(roi=request_roi) + request[test_points] = GraphSpec(roi=request_roi) + request[test_raster] = ArraySpec(roi=request_roi) + + t1_ref = time.time() + batch = reference_pipeline.request_batch(request) + t2_ref = time.time() + points_reference = {node.id: node for node in batch[test_points].nodes} + + timings.append((t2_fast - t1_fast, t2_ref - t1_ref)) + diffs = [] + missing = 0 + for point_id, point in points_reference.items(): + if point_id not in points_fast: + missing += 1 + continue + diff = point.location - points_fast[point_id].location + diffs.append(tuple(diff)) + assert np.linalg.norm(diff) < 1.5 + + t_fast, t_ref = [np.mean(x) for x in zip(*timings)] + assert t_fast < t_ref + assert missing == 0 + + +def test_fast_transform_no_recompute(tmpdir): + test_labels = ArrayKey("TEST_LABELS") + test_points = GraphKey("TEST_POINTS") + test_raster = ArrayKey("TEST_RASTER") + fast_pipeline = ( + DensePointTestSource3D(test_points, test_labels) + + ElasticAugment( + [10, 10, 10], + [0.1, 0.1, 0.1], + [0, 2.0 * math.pi], + use_fast_points_transform=True, + recompute_missing_points=False, + ) + + RasterizeGraph( + test_points, + test_raster, + settings=RasterizationSettings(radius=2, mode="peak"), + ) + ) + + reference_pipeline = ( + DensePointTestSource3D(test_points, test_labels) + + ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi]) + + RasterizeGraph( + test_points, + test_raster, + settings=RasterizationSettings(radius=2, mode="peak"), + ) + ) + + timings = [] + for i in range(5): + points_fast = {} + points_reference = {} + # seed chosen specifically to make this test fail + seed = i + 15 + with build(fast_pipeline): + request_roi = Roi((0, 0, 0), (40, 40, 40)) + + request = BatchRequest(random_seed=seed) + request[test_labels] = ArraySpec(roi=request_roi) + request[test_points] = GraphSpec(roi=request_roi) + request[test_raster] = ArraySpec(roi=request_roi) + + t1_fast = time.time() + batch = fast_pipeline.request_batch(request) + t2_fast = time.time() + points_fast = {node.id: node for node in batch[test_points].nodes} + + with build(reference_pipeline): + request_roi = Roi((0, 0, 0), (40, 40, 40)) + + request = BatchRequest(random_seed=seed) + request[test_labels] = ArraySpec(roi=request_roi) + request[test_points] = GraphSpec(roi=request_roi) + request[test_raster] = ArraySpec(roi=request_roi) + + t1_ref = time.time() + batch = reference_pipeline.request_batch(request) + t2_ref = time.time() + points_reference = {node.id: node for node in batch[test_points].nodes} + + timings.append((t2_fast - t1_fast, t2_ref - t1_ref)) + diffs = [] + missing = 0 + for point_id, point in points_reference.items(): + if point_id not in points_fast: + missing += 1 + continue + diff = point.location - points_fast[point_id].location + diffs.append(tuple(diff)) + assert np.linalg.norm(diff) < 1.5 + + t_fast, t_ref = [np.mean(x) for x in zip(*timings)] + assert t_fast < t_ref + assert missing > 0 diff --git a/tests/cases/expected_failures.py b/tests/cases/expected_failures.py index 8adf48af..9a72bf74 100644 --- a/tests/cases/expected_failures.py +++ b/tests/cases/expected_failures.py @@ -1,11 +1,11 @@ +import numpy as np +import pytest +from funlib.geometry import Coordinate + import gunpowder as gp from gunpowder.nodes.batch_provider import BatchRequestError -from .helper_sources import ArraySource -from funlib.geometry import Coordinate - -import numpy as np -import pytest +from .helper_sources import ArraySource @pytest.mark.xfail() diff --git a/tests/cases/graph.py b/tests/cases/graph.py index 20b09b5f..57b5f2c7 100644 --- a/tests/cases/graph.py +++ b/tests/cases/graph.py @@ -1,25 +1,24 @@ -from .provider_test import ProviderTest +import numpy as np + from gunpowder import ( + Batch, + BatchFilter, BatchProvider, BatchRequest, - BatchFilter, - Batch, - Node, + Coordinate, Edge, Graph, - GraphSpec, GraphKey, - GraphKeys, - build, + GraphSpec, + Node, Roi, - Coordinate, + build, ) -import numpy as np - class ExampleGraphSource(BatchProvider): - def __init__(self): + def __init__(self, graph_key): + self.graph_key = graph_key self.dtype = float self.__vertices = [ Node(id=1, location=np.array([1, 1, 1], dtype=self.dtype)), @@ -33,21 +32,24 @@ def __init__(self): self.graph = Graph(self.__vertices, self.__edges, self.__spec) def setup(self): - self.provides(GraphKeys.TEST_GRAPH, self.__spec) + self.provides(self.graph_key, self.__spec) def provide(self, request): batch = Batch() - roi = request[GraphKeys.TEST_GRAPH].roi + roi = request[self.graph_key].roi sub_graph = self.graph.crop(roi) - batch[GraphKeys.TEST_GRAPH] = sub_graph + batch[self.graph_key] = sub_graph return batch class GrowFilter(BatchFilter): + def __init__(self, graph_key): + self.graph_key = graph_key + def prepare(self, request): grow = Coordinate([50, 50, 50]) for key, spec in request.items(): @@ -61,120 +63,119 @@ def process(self, batch, request): return batch -class TestGraphs(ProviderTest): - @property - def edges(self): - return [Edge(0, 1), Edge(1, 2), Edge(2, 3), Edge(3, 4), Edge(4, 0)] +def edges(): + return [Edge(0, 1), Edge(1, 2), Edge(2, 3), Edge(3, 4), Edge(4, 0)] - @property - def nodes(self): - return [ - Node(0, location=np.array([0, 0, 0], dtype=self.spec.dtype)), - Node(1, location=np.array([1, 1, 1], dtype=self.spec.dtype)), - Node(2, location=np.array([2, 2, 2], dtype=self.spec.dtype)), - Node(3, location=np.array([3, 3, 3], dtype=self.spec.dtype)), - Node(4, location=np.array([4, 4, 4], dtype=self.spec.dtype)), - ] - @property - def spec(self): - return GraphSpec( - roi=Roi(Coordinate([0, 0, 0]), Coordinate([5, 5, 5])), directed=True - ) +def nodes(): + return [ + Node(0, location=np.array([0, 0, 0])), + Node(1, location=np.array([1, 1, 1])), + Node(2, location=np.array([2, 2, 2])), + Node(3, location=np.array([3, 3, 3])), + Node(4, location=np.array([4, 4, 4])), + ] + + +def spec(): + return GraphSpec( + roi=Roi(Coordinate([0, 0, 0]), Coordinate([5, 5, 5])), directed=True + ) + - def test_output(self): - GraphKey("TEST_GRAPH") - - pipeline = ExampleGraphSource() + GrowFilter() - - with build(pipeline): - batch = pipeline.request_batch( - BatchRequest( - {GraphKeys.TEST_GRAPH: GraphSpec(roi=Roi((0, 0, 0), (50, 50, 50)))} - ) - ) - - graph = batch[GraphKeys.TEST_GRAPH] - expected_vertices = ( - Node(id=1, location=np.array([1.0, 1.0, 1.0], dtype=float)), - Node( - id=2, - location=np.array([50.0, 50.0, 50.0], dtype=float), - temporary=True, - ), - ) - seen_vertices = tuple(graph.nodes) - self.assertCountEqual( - [v.original_id for v in expected_vertices], - [v.original_id for v in seen_vertices], - ) - for expected, actual in zip( - sorted(expected_vertices, key=lambda v: tuple(v.location)), - sorted(seen_vertices, key=lambda v: tuple(v.location)), - ): - assert all(np.isclose(expected.location, actual.location)) - - batch = pipeline.request_batch( - BatchRequest( - { - GraphKeys.TEST_GRAPH: GraphSpec( - roi=Roi((25, 25, 25), (500, 500, 500)) - ) - } - ) - ) - - graph = batch[GraphKeys.TEST_GRAPH] - expected_vertices = ( - Node( - id=1, - location=np.array([25.0, 25.0, 25.0], dtype=float), - temporary=True, - ), - Node(id=2, location=np.array([500.0, 500.0, 500.0], dtype=float)), - Node( - id=3, - location=np.array([525.0, 525.0, 525.0], dtype=float), - temporary=True, - ), - ) - seen_vertices = tuple(graph.nodes) - self.assertCountEqual( - [v.original_id for v in expected_vertices], - [v.original_id for v in seen_vertices], - ) - for expected, actual in zip( - sorted(expected_vertices, key=lambda v: tuple(v.location)), - sorted(seen_vertices, key=lambda v: tuple(v.location)), - ): - assert all(np.isclose(expected.location, actual.location)) - - def test_neighbors(self): - # directed - d_spec = self.spec - # undirected - ud_spec = self.spec - ud_spec.directed = False - - directed = Graph(self.nodes, self.edges, d_spec) - undirected = Graph(self.nodes, self.edges, ud_spec) - - self.assertCountEqual( - directed.neighbors(self.nodes[0]), undirected.neighbors(self.nodes[0]) +def test_output(): + graph_key = GraphKey("TEST_GRAPH") + + pipeline = ExampleGraphSource(graph_key) + GrowFilter(graph_key) + + with build(pipeline): + batch = pipeline.request_batch( + BatchRequest({graph_key: GraphSpec(roi=Roi((0, 0, 0), (50, 50, 50)))}) ) - def test_crop(self): - g = Graph(self.nodes, self.edges, self.spec) + graph = batch[graph_key] + expected_vertices = ( + Node(id=1, location=np.array([1.0, 1.0, 1.0], dtype=float)), + Node( + id=2, + location=np.array([50.0, 50.0, 50.0], dtype=float), + temporary=True, + ), + ) + seen_vertices = tuple(graph.nodes) + assert sorted( + [ + v.original_id if v.original_id is not None else -1 + for v in expected_vertices + ] + ) == sorted( + [v.original_id if v.original_id is not None else -1 for v in seen_vertices] + ) + for expected, actual in zip( + sorted(expected_vertices, key=lambda v: tuple(v.location)), + sorted(seen_vertices, key=lambda v: tuple(v.location)), + ): + assert all(np.isclose(expected.location, actual.location)) + + batch = pipeline.request_batch( + BatchRequest({graph_key: GraphSpec(roi=Roi((25, 25, 25), (500, 500, 500)))}) + ) - sub_g = g.crop(Roi(Coordinate([1, 1, 1]), Coordinate([3, 3, 3]))) - self.assertEqual(g.spec.roi, self.spec.roi) - self.assertEqual( - sub_g.spec.roi, Roi(Coordinate([1, 1, 1]), Coordinate([3, 3, 3])) + graph = batch[graph_key] + expected_vertices = ( + Node( + id=1, + location=np.array([25.0, 25.0, 25.0], dtype=float), + temporary=True, + ), + Node(id=2, location=np.array([500.0, 500.0, 500.0], dtype=float)), + Node( + id=3, + location=np.array([525.0, 525.0, 525.0], dtype=float), + temporary=True, + ), ) + seen_vertices = tuple(graph.nodes) + assert sorted( + [ + v.original_id if v.original_id is not None else -1 + for v in expected_vertices + ] + ) == sorted( + [v.original_id if v.original_id is not None else -1 for v in seen_vertices] + ) + for expected, actual in zip( + sorted(expected_vertices, key=lambda v: tuple(v.location)), + sorted(seen_vertices, key=lambda v: tuple(v.location)), + ): + assert all(np.isclose(expected.location, actual.location)) + + +def test_neighbors(): + # directed + d_spec = spec() + # undirected + ud_spec = spec() + ud_spec.directed = False + + directed = Graph(nodes(), edges(), d_spec) + undirected = Graph(nodes(), edges(), ud_spec) + + assert [x for x in directed.neighbors(nodes()[0])] == [ + x for x in undirected.neighbors(nodes()[0]) + ] + + +def test_crop(): + g = Graph(nodes(), edges(), spec()) + + sub_g = g.crop(Roi(Coordinate([1, 1, 1]), Coordinate([3, 3, 3]))) + assert g.spec.roi == spec().roi + assert sub_g.spec.roi == Roi(Coordinate([1, 1, 1]), Coordinate([3, 3, 3])) - sub_g.spec.directed = False - self.assertTrue(g.spec.directed) - self.assertFalse(sub_g.spec.directed) + sub_g.spec.directed = False + assert g.spec.directed + assert not sub_g.spec.directed def test_nodes(): diff --git a/tests/cases/graph_keys.py b/tests/cases/graph_keys.py index 56e6e207..421d735f 100644 --- a/tests/cases/graph_keys.py +++ b/tests/cases/graph_keys.py @@ -1,11 +1,13 @@ from __future__ import print_function + +import pytest + from gunpowder import GraphKey, GraphKeys -import unittest -class TestGraphKeys(unittest.TestCase): - def test_register(self): - GraphKey("TEST_GRAPH") +def test_register(): + GraphKey("TEST_GRAPH") - self.assertTrue(GraphKeys.TEST_GRAPH) - self.assertRaises(AttributeError, getattr, GraphKeys, "TEST_GRAPH_2") + assert GraphKeys.TEST_GRAPH + with pytest.raises(AttributeError): + getattr(GraphKeys, "TEST_GRAPH_2") diff --git a/tests/cases/graph_source.py b/tests/cases/graph_source.py index 1a61c074..3cd65a94 100644 --- a/tests/cases/graph_source.py +++ b/tests/cases/graph_source.py @@ -1,20 +1,18 @@ -from .provider_test import ProviderTest +import networkx as nx +import numpy as np + from gunpowder import ( BatchRequest, - Node, + Coordinate, Edge, - GraphSpec, GraphKey, - GraphKeys, GraphSource, - build, + GraphSpec, + Node, Roi, - Coordinate, + build, ) -import numpy as np -import networkx as nx - class DummyDaisyGraphProvider: """Dummy graph provider mimicing daisy.SharedGraphProvider. @@ -33,84 +31,73 @@ def __getitem__(self, roi): graph = nx.DiGraph() else: graph = nx.Graph() - for node in self.nodes: + for node in nodes(): if roi.contains(node.location): graph.add_node(node.id, location=node.location) - for edge in self.edges: + for edge in edges(): if edge.u in graph.nodes: graph.add_edge(edge.u, edge.v) return graph -class TestGraphSource(ProviderTest): - @property - def edges(self): - return [Edge(0, 1), Edge(1, 2), Edge(2, 3), Edge(3, 4), Edge(4, 0)] - - @property - def nodes(self): - return [ - Node(0, location=np.array([0, 0, 0], dtype=self.spec.dtype)), - Node(1, location=np.array([1, 1, 1], dtype=self.spec.dtype)), - Node(2, location=np.array([2, 2, 2], dtype=self.spec.dtype)), - Node(3, location=np.array([3, 3, 3], dtype=self.spec.dtype)), - Node(4, location=np.array([4, 4, 4], dtype=self.spec.dtype)), - ] - - @property - def spec(self): - return GraphSpec( - roi=Roi(Coordinate([0, 0, 0]), Coordinate([5, 5, 5])), directed=True +def edges(): + return [Edge(0, 1), Edge(1, 2), Edge(2, 3), Edge(3, 4), Edge(4, 0)] + + +def nodes(): + return [ + Node(0, location=np.array([0, 0, 0], dtype=spec().dtype)), + Node(1, location=np.array([1, 1, 1], dtype=spec().dtype)), + Node(2, location=np.array([2, 2, 2], dtype=spec().dtype)), + Node(3, location=np.array([3, 3, 3], dtype=spec().dtype)), + Node(4, location=np.array([4, 4, 4], dtype=spec().dtype)), + ] + + +def spec(): + return GraphSpec( + roi=Roi(Coordinate([0, 0, 0]), Coordinate([5, 5, 5])), directed=True + ) + + +def test_output(): + graph_key = GraphKey("TEST_GRAPH") + + dummy_provider = DummyDaisyGraphProvider(nodes(), edges(), directed=True) + graph_source = GraphSource(dummy_provider, graph_key, spec()) + + pipeline = graph_source + + with build(pipeline): + batch = pipeline.request_batch( + BatchRequest({graph_key: GraphSpec(roi=Roi((0, 0, 0), (5, 5, 5)))}) ) - def test_output(self): - GraphKey("TEST_GRAPH") - - dummy_provider = DummyDaisyGraphProvider(self.nodes, self.edges, directed=True) - graph_source = GraphSource(dummy_provider, GraphKeys.TEST_GRAPH, self.spec) - - pipeline = graph_source - - with build(pipeline): - batch = pipeline.request_batch( - BatchRequest( - {GraphKeys.TEST_GRAPH: GraphSpec(roi=Roi((0, 0, 0), (5, 5, 5)))} - ) - ) - - graph = batch[GraphKeys.TEST_GRAPH] - expected_vertices = self.nodes - seen_vertices = tuple(graph.nodes) - self.assertCountEqual( - [v.id for v in expected_vertices], - [v.id for v in seen_vertices], - ) - for expected, actual in zip( - sorted(expected_vertices, key=lambda v: tuple(v.location)), - sorted(seen_vertices, key=lambda v: tuple(v.location)), - ): - assert all(np.isclose(expected.location, actual.location)) - - batch = pipeline.request_batch( - BatchRequest( - {GraphKeys.TEST_GRAPH: GraphSpec(roi=Roi((2, 2, 2), (3, 3, 3)))} - ) - ) - - graph = batch[GraphKeys.TEST_GRAPH] - expected_vertices = ( - Node(2, location=np.array([2, 2, 2], dtype=self.spec.dtype)), - Node(3, location=np.array([3, 3, 3], dtype=self.spec.dtype)), - Node(4, location=np.array([4, 4, 4], dtype=self.spec.dtype)), - ) - seen_vertices = tuple(graph.nodes) - print(seen_vertices) - self.assertCountEqual( - [v.id for v in expected_vertices], - [v.id for v in seen_vertices], - ) - for expected, actual in zip( - sorted(expected_vertices, key=lambda v: tuple(v.location)), - sorted(seen_vertices, key=lambda v: tuple(v.location)), - ): - assert all(np.isclose(expected.location, actual.location)) + graph = batch[graph_key] + expected_vertices = nodes() + seen_vertices = tuple(graph.nodes) + assert [v.id for v in expected_vertices] == [v.id for v in seen_vertices] + for expected, actual in zip( + sorted(expected_vertices, key=lambda v: tuple(v.location)), + sorted(seen_vertices, key=lambda v: tuple(v.location)), + ): + assert all(np.isclose(expected.location, actual.location)) + + batch = pipeline.request_batch( + BatchRequest({graph_key: GraphSpec(roi=Roi((2, 2, 2), (3, 3, 3)))}) + ) + + graph = batch[graph_key] + expected_vertices = ( + Node(2, location=np.array([2, 2, 2], dtype=spec().dtype)), + Node(3, location=np.array([3, 3, 3], dtype=spec().dtype)), + Node(4, location=np.array([4, 4, 4], dtype=spec().dtype)), + ) + seen_vertices = tuple(graph.nodes) + print(seen_vertices) + assert [v.id for v in expected_vertices] == [v.id for v in seen_vertices] + for expected, actual in zip( + sorted(expected_vertices, key=lambda v: tuple(v.location)), + sorted(seen_vertices, key=lambda v: tuple(v.location)), + ): + assert all(np.isclose(expected.location, actual.location)) diff --git a/tests/cases/hdf5_source.py b/tests/cases/hdf5_source.py index e9661b2d..945bfb3b 100644 --- a/tests/cases/hdf5_source.py +++ b/tests/cases/hdf5_source.py @@ -1,117 +1,134 @@ -from unittest import skipIf - -from .provider_test import ProviderTest -from gunpowder import * import numpy as np -from gunpowder.ext import h5py, zarr, ZarrFile, NoSuchModule - - -class Hdf5LikeSourceTestMixin(object): - """This class is to be used as a mixin for ProviderTest classes testing HDF5, N5 and Zarr - batch providers. - - Subclasses must define ``extension`` and ``SourceUnderTest`` class variables, and an - ``_open_writable_file(self, path)`` method. See TestHdf5Source for examples. - """ - - extension = None - SourceUnderTest = None - - def _open_writable_file(self, path): - raise NotImplementedError("_open_writable_file should be overridden") - - def _create_dataset(self, data_file, key, data, chunks=None, **kwargs): - chunks = chunks or data.shape - d = data_file.create_dataset( - key, shape=data.shape, dtype=data.dtype, chunks=chunks +import pytest + +from gunpowder import ( + ArrayKey, + ArraySpec, + BatchRequest, + Hdf5Source, + Roi, + ZarrSource, + build, +) +from gunpowder.ext import NoSuchModule, ZarrFile, h5py, zarr + +extension = None +SourceUnderTest = None + + +def open_zarr(path): + return ZarrFile(path, mode="w") + + +def open_hdf(path): + return h5py.File(path, "w") + + +open_writable_file_func = { + "hdf": open_hdf, + "zarr": open_zarr, +} +source_node = { + "hdf": Hdf5Source, + "zarr": ZarrSource, +} + + +def create_dataset(data_file, key, data, chunks=None, **kwargs): + chunks = chunks or data.shape + d = data_file.create_dataset(key, shape=data.shape, dtype=data.dtype, chunks=chunks) + d[:] = data + for key, value in kwargs.items(): + d.attrs[key] = value + + +@pytest.mark.parametrize( + "extension", + [ + "hdf", + pytest.param( + "zarr", + marks=pytest.mark.skipif( + isinstance(zarr, NoSuchModule), reason="zarr is not installed" + ), + ), + ], +) +def test_output_2d(extension, tmpdir): + path = tmpdir / f"test_{extension}_source.{extension}" + + with open_writable_file_func[extension](path) as f: + create_dataset(f, "raw", np.zeros((100, 100), dtype=np.float32)) + create_dataset( + f, "raw_low", np.zeros((10, 10), dtype=np.float32), resolution=(10, 10) ) - d[:] = data - for key, value in kwargs.items(): - d.attrs[key] = value - - def test_output_2d(self): - path = self.path_to("test_{0}_source.{0}".format(self.extension)) - - with self._open_writable_file(path) as f: - self._create_dataset(f, "raw", np.zeros((100, 100), dtype=np.float32)) - self._create_dataset( - f, "raw_low", np.zeros((10, 10), dtype=np.float32), resolution=(10, 10) + create_dataset(f, "seg", np.ones((100, 100), dtype=np.uint64)) + + # read arrays + raw = ArrayKey("RAW") + raw_low = ArrayKey("RAW_LOW") + seg = ArrayKey("SEG") + source = source_node[extension](path, {raw: "raw", raw_low: "raw_low", seg: "seg"}) + + with build(source): + batch = source.request_batch( + BatchRequest( + { + raw: ArraySpec(roi=Roi((0, 0), (100, 100))), + raw_low: ArraySpec(roi=Roi((0, 0), (100, 100))), + seg: ArraySpec(roi=Roi((0, 0), (100, 100))), + } ) - self._create_dataset(f, "seg", np.ones((100, 100), dtype=np.uint64)) - - # read arrays - raw = ArrayKey("RAW") - raw_low = ArrayKey("RAW_LOW") - seg = ArrayKey("SEG") - source = self.SourceUnderTest( - path, {raw: "raw", raw_low: "raw_low", seg: "seg"} ) - with build(source): - batch = source.request_batch( - BatchRequest( - { - raw: ArraySpec(roi=Roi((0, 0), (100, 100))), - raw_low: ArraySpec(roi=Roi((0, 0), (100, 100))), - seg: ArraySpec(roi=Roi((0, 0), (100, 100))), - } - ) - ) - - self.assertTrue(batch.arrays[raw].spec.interpolatable) - self.assertTrue(batch.arrays[raw_low].spec.interpolatable) - self.assertFalse(batch.arrays[seg].spec.interpolatable) - - def test_output_3d(self): - path = self.path_to("test_{0}_source.{0}".format(self.extension)) - - # create a test file - with self._open_writable_file(path) as f: - self._create_dataset(f, "raw", np.zeros((100, 100, 100), dtype=np.float32)) - self._create_dataset( - f, - "raw_low", - np.zeros((10, 10, 10), dtype=np.float32), - resolution=(10, 10, 10), - ) - self._create_dataset(f, "seg", np.ones((100, 100, 100), dtype=np.uint64)) - - # read arrays - raw = ArrayKey("RAW") - raw_low = ArrayKey("RAW_LOW") - seg = ArrayKey("SEG") - source = self.SourceUnderTest( - path, {raw: "raw", raw_low: "raw_low", seg: "seg"} + assert batch.arrays[raw].spec.interpolatable + assert batch.arrays[raw_low].spec.interpolatable + assert not (batch.arrays[seg].spec.interpolatable) + + +@pytest.mark.parametrize( + "extension", + [ + "hdf", + pytest.param( + "zarr", + marks=pytest.mark.skipif( + isinstance(zarr, NoSuchModule), reason="zarr is not installed" + ), + ), + ], +) +def test_output_3d(extension, tmpdir): + path = tmpdir / f"test_{extension}_source.{extension}" + + # create a test file + with open_writable_file_func[extension](path) as f: + create_dataset(f, "raw", np.zeros((100, 100, 100), dtype=np.float32)) + create_dataset( + f, + "raw_low", + np.zeros((10, 10, 10), dtype=np.float32), + resolution=(10, 10, 10), ) - - with build(source): - batch = source.request_batch( - BatchRequest( - { - raw: ArraySpec(roi=Roi((0, 0, 0), (100, 100, 100))), - raw_low: ArraySpec(roi=Roi((0, 0, 0), (100, 100, 100))), - seg: ArraySpec(roi=Roi((0, 0, 0), (100, 100, 100))), - } - ) + create_dataset(f, "seg", np.ones((100, 100, 100), dtype=np.uint64)) + + # read arrays + raw = ArrayKey("RAW") + raw_low = ArrayKey("RAW_LOW") + seg = ArrayKey("SEG") + source = source_node[extension](path, {raw: "raw", raw_low: "raw_low", seg: "seg"}) + + with build(source): + batch = source.request_batch( + BatchRequest( + { + raw: ArraySpec(roi=Roi((0, 0, 0), (100, 100, 100))), + raw_low: ArraySpec(roi=Roi((0, 0, 0), (100, 100, 100))), + seg: ArraySpec(roi=Roi((0, 0, 0), (100, 100, 100))), + } ) + ) - self.assertTrue(batch.arrays[raw].spec.interpolatable) - self.assertTrue(batch.arrays[raw_low].spec.interpolatable) - self.assertFalse(batch.arrays[seg].spec.interpolatable) - - -class TestHdf5Source(ProviderTest, Hdf5LikeSourceTestMixin): - extension = "hdf" - SourceUnderTest = Hdf5Source - - def _open_writable_file(self, path): - return h5py.File(path, "w") - - -@skipIf(isinstance(zarr, NoSuchModule), "zarr is not installed") -class TestZarrSource(ProviderTest, Hdf5LikeSourceTestMixin): - extension = "zarr" - SourceUnderTest = ZarrSource - - def _open_writable_file(self, path): - return ZarrFile(path, mode="w") + assert batch.arrays[raw].spec.interpolatable + assert batch.arrays[raw_low].spec.interpolatable + assert not (batch.arrays[seg].spec.interpolatable) diff --git a/tests/cases/hdf5_write.py b/tests/cases/hdf5_write.py index 3a564277..d5710cd5 100644 --- a/tests/cases/hdf5_write.py +++ b/tests/cases/hdf5_write.py @@ -1,19 +1,34 @@ -from .provider_test import ProviderTest -from gunpowder import * import numpy as np + +from gunpowder import ( + Array, + ArrayKey, + ArraySpec, + Batch, + BatchProvider, + BatchRequest, + Hdf5Write, + Roi, + Scan, + build, +) from gunpowder.ext import h5py class Hdf5WriteTestSource(BatchProvider): + def __init__(self, raw_key, labels_key): + self.raw_key = raw_key + self.labels_key = labels_key + def setup(self): self.provides( - ArrayKeys.RAW, + self.raw_key, ArraySpec( roi=Roi((20000, 2000, 2000), (2000, 200, 200)), voxel_size=(20, 2, 2) ), ) self.provides( - ArrayKeys.GT_LABELS, + self.labels_key, ArraySpec( roi=Roi((20100, 2010, 2010), (1800, 180, 180)), voxel_size=(20, 2, 2) ), @@ -47,44 +62,44 @@ def provide(self, request): return batch -class TestHdf5Write(ProviderTest): - def test_output(self): - path = self.path_to("hdf5_write_test.hdf") +def test_output(tmpdir): + path = tmpdir / "hdf5_write_test.hdf" - source = Hdf5WriteTestSource() + raw_key = ArrayKey("RAW") + labels_key = ArrayKey("LABELS") - chunk_request = BatchRequest() - chunk_request.add(ArrayKeys.RAW, (400, 30, 34)) - chunk_request.add(ArrayKeys.GT_LABELS, (200, 10, 14)) + source = Hdf5WriteTestSource(raw_key, labels_key) - pipeline = ( - source - + Hdf5Write({ArrayKeys.RAW: "arrays/raw"}, output_filename=path) - + Scan(chunk_request) - ) + chunk_request = BatchRequest() + chunk_request.add(raw_key, (400, 30, 34)) + chunk_request.add(labels_key, (200, 10, 14)) - with build(pipeline): - raw_spec = pipeline.spec[ArrayKeys.RAW] - labels_spec = pipeline.spec[ArrayKeys.GT_LABELS] + pipeline = ( + source + + Hdf5Write({raw_key: "arrays/raw"}, output_filename=path) + + Scan(chunk_request) + ) - full_request = BatchRequest( - {ArrayKeys.RAW: raw_spec, ArrayKeys.GT_LABELS: labels_spec} - ) + with build(pipeline): + raw_spec = pipeline.spec[raw_key] + labels_spec = pipeline.spec[labels_key] - batch = pipeline.request_batch(full_request) + full_request = BatchRequest({raw_key: raw_spec, labels_key: labels_spec}) - # assert that stored HDF dataset equals batch array + batch = pipeline.request_batch(full_request) - with h5py.File(path, "r") as f: - ds = f["arrays/raw"] + # assert that stored HDF dataset equals batch array - batch_raw = batch.arrays[ArrayKeys.RAW] - stored_raw = np.array(ds) + with h5py.File(path, "r") as f: + ds = f["arrays/raw"] - self.assertEqual( - stored_raw.shape[-3:], - batch_raw.spec.roi.shape // batch_raw.spec.voxel_size, - ) - self.assertEqual(tuple(ds.attrs["offset"]), batch_raw.spec.roi.offset) - self.assertEqual(tuple(ds.attrs["resolution"]), batch_raw.spec.voxel_size) - self.assertTrue((stored_raw == batch.arrays[ArrayKeys.RAW].data).all()) + batch_raw = batch.arrays[raw_key] + stored_raw = np.array(ds) + + assert ( + stored_raw.shape[-3:] + == batch_raw.spec.roi.shape // batch_raw.spec.voxel_size + ) + assert tuple(ds.attrs["offset"]) == batch_raw.spec.roi.offset + assert tuple(ds.attrs["resolution"]) == batch_raw.spec.voxel_size + assert (stored_raw == batch.arrays[raw_key].data).all() diff --git a/tests/cases/helper_sources.py b/tests/cases/helper_sources.py index 630333d6..135516db 100644 --- a/tests/cases/helper_sources.py +++ b/tests/cases/helper_sources.py @@ -1,7 +1,7 @@ -from gunpowder import BatchProvider, GraphKey, Graph, ArrayKey, Array, Batch - import copy +from gunpowder import Array, ArrayKey, Batch, BatchProvider, Graph, GraphKey + class ArraySource(BatchProvider): def __init__(self, key: ArrayKey, array: Array): diff --git a/tests/cases/intensity_augment.py b/tests/cases/intensity_augment.py index bc7ce785..fd8d21cb 100644 --- a/tests/cases/intensity_augment.py +++ b/tests/cases/intensity_augment.py @@ -1,23 +1,36 @@ -from .provider_test import ProviderTest -from gunpowder import IntensityAugment, ArrayKeys, build, Normalize - import numpy as np +from gunpowder import ( + Array, + ArrayKey, + ArraySpec, + BatchRequest, + IntensityAugment, + Roi, + build, +) + +from .helper_sources import ArraySource + + +def test_shift(): + raw_key = ArrayKey("RAW") + raw_spec = ArraySpec( + roi=Roi((0, 0, 0), (10, 10, 10)), voxel_size=(1, 1, 1), dtype=np.float32 + ) + raw_data = np.zeros(raw_spec.roi.shape / raw_spec.voxel_size, dtype=np.float32) + raw_array = Array(raw_data, raw_spec) + + pipeline = ArraySource(raw_key, raw_array) + IntensityAugment( + raw_key, scale_min=0, scale_max=0, shift_min=0.5, shift_max=0.5 + ) -class TestIntensityAugment(ProviderTest): - def test_shift(self): - pipeline = ( - self.test_source - + Normalize(ArrayKeys.RAW) - + IntensityAugment( - ArrayKeys.RAW, scale_min=0, scale_max=0, shift_min=0.5, shift_max=0.5 - ) - ) + request = BatchRequest() + request.add(raw_key, (10, 10, 10)) - with build(pipeline): - for i in range(100): - batch = pipeline.request_batch(self.test_request) + with build(pipeline): + batch = pipeline.request_batch(request) - x = batch.arrays[ArrayKeys.RAW].data - assert np.isclose(x.min(), 0.5) - assert np.isclose(x.max(), 0.5) + x = batch.arrays[raw_key].data + assert np.isclose(x.min(), 0.5) + assert np.isclose(x.max(), 0.5) diff --git a/tests/cases/intensity_scale_shift.py b/tests/cases/intensity_scale_shift.py index c64b4ec3..12a8caf2 100644 --- a/tests/cases/intensity_scale_shift.py +++ b/tests/cases/intensity_scale_shift.py @@ -1,16 +1,17 @@ -from .helper_sources import ArraySource +import numpy as np + from gunpowder import ( - IntensityScaleShift, - ArrayKey, - build, Array, + ArrayKey, ArraySpec, - Roi, - Coordinate, BatchRequest, + Coordinate, + IntensityScaleShift, + Roi, + build, ) -import numpy as np +from .helper_sources import ArraySource def test_shift(): diff --git a/tests/cases/iterate_locations.py b/tests/cases/iterate_locations.py index f1f7be24..6e13b785 100644 --- a/tests/cases/iterate_locations.py +++ b/tests/cases/iterate_locations.py @@ -1,24 +1,21 @@ -from .provider_test import ProviderTest +import networkx as nx +import numpy as np + from gunpowder import ( ArrayKey, - ArrayKeys, ArraySpec, BatchRequest, - Node, + Coordinate, Edge, - GraphSpec, GraphKey, - GraphKeys, GraphSource, + GraphSpec, IterateLocations, - build, + Node, Roi, - Coordinate, + build, ) -import numpy as np -import networkx as nx - class DummyDaisyGraphProvider: """Dummy graph provider mimicing daisy.SharedGraphProvider. @@ -46,58 +43,53 @@ def __getitem__(self, roi): return graph -class TestIterateLocation(ProviderTest): - @property - def edges(self): - return [Edge(0, 1), Edge(1, 2), Edge(2, 3), Edge(3, 4), Edge(4, 0)] +def edges(): + return [Edge(0, 1), Edge(1, 2), Edge(2, 3), Edge(3, 4), Edge(4, 0)] + + +def nodes(): + return [ + Node(0, location=np.array([0, 0, 0])), + Node(1, location=np.array([1, 1, 1])), + Node(2, location=np.array([2, 2, 2])), + Node(3, location=np.array([3, 3, 3])), + Node(4, location=np.array([4, 4, 4])), + ] + + +def spec(): + return GraphSpec( + roi=Roi(Coordinate([0, 0, 0]), Coordinate([5, 5, 5])), directed=True + ) - @property - def nodes(self): - return [ - Node(0, location=np.array([0, 0, 0], dtype=self.spec.dtype)), - Node(1, location=np.array([1, 1, 1], dtype=self.spec.dtype)), - Node(2, location=np.array([2, 2, 2], dtype=self.spec.dtype)), - Node(3, location=np.array([3, 3, 3], dtype=self.spec.dtype)), - Node(4, location=np.array([4, 4, 4], dtype=self.spec.dtype)), - ] - @property - def spec(self): - return GraphSpec( - roi=Roi(Coordinate([0, 0, 0]), Coordinate([5, 5, 5])), directed=True - ) +def test_output(): + graph_key = GraphKey("TEST_GRAPH") + node_key = ArrayKey("NODE_ID") - def test_output(self): - GraphKey("TEST_GRAPH") - ArrayKey("NODE_ID") + dummy_provider = DummyDaisyGraphProvider(nodes(), edges(), directed=True) + graph_source = GraphSource(dummy_provider, graph_key, spec()) + iterate_locations = IterateLocations(graph_key, node_id=node_key) + pipeline = graph_source + iterate_locations + request = BatchRequest( + { + graph_key: GraphSpec(roi=Roi((0, 0, 0), (1, 1, 1))), + node_key: ArraySpec(nonspatial=True), + } + ) + node_ids = [] + seen_vertices = [] + expected_vertices = nodes() + with build(pipeline): + for _ in range(len(nodes())): + batch = pipeline.request_batch(request) + node_ids.extend(batch[node_key].data) + graph = batch[graph_key] + assert graph.num_vertices() == 1 + node = next(graph.nodes) + seen_vertices.append(node) - dummy_provider = DummyDaisyGraphProvider(self.nodes, self.edges, directed=True) - graph_source = GraphSource(dummy_provider, GraphKeys.TEST_GRAPH, self.spec) - iterate_locations = IterateLocations( - GraphKeys.TEST_GRAPH, node_id=ArrayKeys.NODE_ID - ) - pipeline = graph_source + iterate_locations - request = BatchRequest( - { - GraphKeys.TEST_GRAPH: GraphSpec(roi=Roi((0, 0, 0), (1, 1, 1))), - ArrayKeys.NODE_ID: ArraySpec(nonspatial=True), - } - ) - node_ids = [] - seen_vertices = [] - expected_vertices = self.nodes - with build(pipeline): - for _ in range(len(self.nodes)): - batch = pipeline.request_batch(request) - node_ids.extend(batch[ArrayKeys.NODE_ID].data) - graph = batch[GraphKeys.TEST_GRAPH] - self.assertEqual(graph.num_vertices(), 1) - node = next(graph.nodes) - seen_vertices.append(node) - self.assertCountEqual( - [v.id for v in expected_vertices], - node_ids, - ) - for vertex in seen_vertices: - # locations are shifted to lie in roi (so, (0, 0, 0)) - assert all(np.isclose(np.array([0.0, 0.0, 0.0]), vertex.location)) + assert [v.id for v in expected_vertices] == node_ids + for vertex in seen_vertices: + # locations are shifted to lie in roi (so, (0, 0, 0)) + assert all(np.isclose(np.array([0.0, 0.0, 0.0]), vertex.location)) diff --git a/tests/cases/jax_train.py b/tests/cases/jax_train.py index 2ff55be6..20309fd9 100644 --- a/tests/cases/jax_train.py +++ b/tests/cases/jax_train.py @@ -1,269 +1,245 @@ -from .provider_test import ProviderTest +import logging + +import numpy as np +import pytest + from gunpowder import ( + Array, + ArrayKey, + ArraySpec, + Batch, BatchProvider, BatchRequest, - ArraySpec, Roi, - ArrayKeys, - ArrayKey, - Array, - Batch, build, ) -from gunpowder.ext import jax, haiku, optax, NoSuchModule -from gunpowder.jax import Train, Predict, GenericJaxModel -from unittest import skipIf -import numpy as np - -import logging +from gunpowder.ext import NoSuchModule, haiku, jax, optax +from gunpowder.jax import GenericJaxModel, Predict, Train # use CPU for JAX tests and avoid GPU compatibility if not isinstance(jax, NoSuchModule): jax.config.update("jax_platform_name", "cpu") -class ExampleJaxTrain2DSource(BatchProvider): - def __init__(self): - pass +class ExampleJaxTrainSource(BatchProvider): + def __init__(self, a_key, b_key, c_key): + self.a_key = a_key + self.b_key = b_key + self.c_key = c_key def setup(self): spec = ArraySpec( - roi=Roi((0, 0), (17, 17)), + roi=Roi((0, 0), (2, 2)), dtype=np.float32, interpolatable=True, voxel_size=(1, 1), ) - self.provides(ArrayKeys.A, spec) + self.provides(self.a_key, spec) + self.provides(self.b_key, spec) + + spec = ArraySpec(nonspatial=True) + self.provides(self.c_key, spec) def provide(self, request): batch = Batch() - spec = self.spec[ArrayKeys.A] + spec = self.spec[self.a_key] + spec.roi = request[self.a_key].roi + + batch.arrays[self.a_key] = Array( + np.array([[0, 1], [2, 3]], dtype=np.float32), spec + ) + + spec = self.spec[self.b_key] + spec.roi = request[self.b_key].roi + + batch.arrays[self.b_key] = Array( + np.array([[0, 1], [2, 3]], dtype=np.float32), spec + ) - x = np.array(list(range(17)), dtype=np.float32).reshape([17, 1]) - x = x + x.T + spec = self.spec[self.c_key] - batch.arrays[ArrayKeys.A] = Array(x, spec).crop(request[ArrayKeys.A].roi) + batch.arrays[self.c_key] = Array(np.array([1], dtype=np.float32), spec) return batch -class ExampleJaxTrainSource(BatchProvider): - def setup(self): - spec = ArraySpec( - roi=Roi((0, 0), (2, 2)), - dtype=np.float32, - interpolatable=True, - voxel_size=(1, 1), - ) - self.provides(ArrayKeys.A, spec) - self.provides(ArrayKeys.B, spec) +@pytest.mark.skipif(isinstance(jax, NoSuchModule), reason="Jax is not installed") +def test_train(tmpdir): + logging.getLogger("gunpowder.jax.nodes.train").setLevel(logging.INFO) - spec = ArraySpec(nonspatial=True) - self.provides(ArrayKeys.C, spec) + checkpoint_basename = tmpdir / "model" - def provide(self, request): - batch = Batch() + a_key = ArrayKey("A") + b_key = ArrayKey("B") + c_key = ArrayKey("C") + c_pred_key = ArrayKey("C_PREDICTED") + c_grad_key = ArrayKey("C_GRADIENT") - spec = self.spec[ArrayKeys.A] - spec.roi = request[ArrayKeys.A].roi + class ExampleModel(GenericJaxModel): + def __init__(self, is_training): + super().__init__(is_training) - batch.arrays[ArrayKeys.A] = Array( - np.array([[0, 1], [2, 3]], dtype=np.float32), spec - ) + def _linear(x): + return haiku.Linear(1, False)(x) - spec = self.spec[ArrayKeys.B] - spec.roi = request[ArrayKeys.B].roi + self.linear = haiku.without_apply_rng(haiku.transform(_linear)) + self.opt = optax.sgd(learning_rate=1e-7, momentum=0.999) - batch.arrays[ArrayKeys.B] = Array( - np.array([[0, 1], [2, 3]], dtype=np.float32), spec - ) + def initialize(self, rng_key, inputs): + a = inputs["a"].reshape(-1) + b = inputs["b"].reshape(-1) + weight = self.linear.init(rng_key, a * b) + opt_state = self.opt.init(weight) + return (weight, opt_state) - spec = self.spec[ArrayKeys.C] + def forward(self, params, inputs): + a = inputs["a"].reshape(-1) + b = inputs["b"].reshape(-1) + return {"c": self.linear.apply(params[0], a * b)} - batch.arrays[ArrayKeys.C] = Array(np.array([1], dtype=np.float32), spec) + def _loss_fn(self, weight, a, b, c): + c_pred = self.linear.apply(weight, a * b) + loss = optax.l2_loss(predictions=c_pred, targets=c) * 2 + loss_mean = loss.mean() + return loss_mean, (c_pred, loss, loss_mean) - return batch + def _apply_optimizer(self, params, grads): + updates, new_opt_state = self.opt.update(grads, params[1]) + new_weight = optax.apply_updates(params[0], updates) + return new_weight, new_opt_state + def train_step(self, params, inputs, pmapped=False): + a = inputs["a"].reshape(-1) + b = inputs["b"].reshape(-1) + c = inputs["c"].reshape(-1) -@skipIf(isinstance(jax, NoSuchModule), "Jax is not installed") -class TestJaxTrain(ProviderTest): - def test_output(self): - logging.getLogger("gunpowder.jax.nodes.train").setLevel(logging.INFO) - - checkpoint_basename = self.path_to("model") - - ArrayKey("A") - ArrayKey("B") - ArrayKey("C") - ArrayKey("C_PREDICTED") - ArrayKey("C_GRADIENT") - - class ExampleModel(GenericJaxModel): - def __init__(self, is_training): - super().__init__(is_training) - - def _linear(x): - return haiku.Linear(1, False)(x) - - self.linear = haiku.without_apply_rng(haiku.transform(_linear)) - self.opt = optax.sgd(learning_rate=1e-7, momentum=0.999) - - def initialize(self, rng_key, inputs): - a = inputs["a"].reshape(-1) - b = inputs["b"].reshape(-1) - weight = self.linear.init(rng_key, a * b) - opt_state = self.opt.init(weight) - return (weight, opt_state) - - def forward(self, params, inputs): - a = inputs["a"].reshape(-1) - b = inputs["b"].reshape(-1) - return {"c": self.linear.apply(params[0], a * b)} - - def _loss_fn(self, weight, a, b, c): - c_pred = self.linear.apply(weight, a * b) - loss = optax.l2_loss(predictions=c_pred, targets=c) * 2 - loss_mean = loss.mean() - return loss_mean, (c_pred, loss, loss_mean) - - def _apply_optimizer(self, params, grads): - updates, new_opt_state = self.opt.update(grads, params[1]) - new_weight = optax.apply_updates(params[0], updates) - return new_weight, new_opt_state - - def train_step(self, params, inputs, pmapped=False): - a = inputs["a"].reshape(-1) - b = inputs["b"].reshape(-1) - c = inputs["c"].reshape(-1) - - grads, (c_pred, loss, loss_mean) = jax.grad( - self._loss_fn, has_aux=True - )(params[0], a, b, c) - - new_weight, new_opt_state = self._apply_optimizer(params, grads) - new_params = (new_weight, new_opt_state) - - outputs = { - "c_pred": c_pred, - "grad": loss, - } - return new_params, outputs, loss_mean - - model = ExampleModel(is_training=False) - - source = ExampleJaxTrainSource() - train = Train( - model=model, - inputs={"a": ArrayKeys.A, "b": ArrayKeys.B, "c": ArrayKeys.C}, - outputs={"c_pred": ArrayKeys.C_PREDICTED, "grad": ArrayKeys.C_GRADIENT}, - array_specs={ - ArrayKeys.C_PREDICTED: ArraySpec(nonspatial=True), - ArrayKeys.C_GRADIENT: ArraySpec(nonspatial=True), - }, - checkpoint_basename=checkpoint_basename, - save_every=100, - spawn_subprocess=True, - n_devices=1, - ) - pipeline = source + train - - request = BatchRequest( - { - ArrayKeys.A: ArraySpec(roi=Roi((0, 0), (2, 2))), - ArrayKeys.B: ArraySpec(roi=Roi((0, 0), (2, 2))), - ArrayKeys.C: ArraySpec(nonspatial=True), - ArrayKeys.C_PREDICTED: ArraySpec(nonspatial=True), - ArrayKeys.C_GRADIENT: ArraySpec(nonspatial=True), - } - ) + grads, (c_pred, loss, loss_mean) = jax.grad(self._loss_fn, has_aux=True)( + params[0], a, b, c + ) - # train for a couple of iterations - with build(pipeline): - batch = pipeline.request_batch(request) + new_weight, new_opt_state = self._apply_optimizer(params, grads) + new_params = (new_weight, new_opt_state) - for i in range(200 - 1): - loss1 = batch.loss - batch = pipeline.request_batch(request) - loss2 = batch.loss - self.assertLess(loss2, loss1) - - # resume training - with build(pipeline): - for i in range(100): - loss1 = batch.loss - batch = pipeline.request_batch(request) - loss2 = batch.loss - self.assertLess(loss2, loss1) - - -@skipIf(isinstance(jax, NoSuchModule), "Jax is not installed") -class TestJaxPredict(ProviderTest): - def test_output(self): - logging.getLogger("gunpowder.jax.nodes.predict").setLevel(logging.INFO) - - a = ArrayKey("A") - b = ArrayKey("B") - c = ArrayKey("C") - c_pred = ArrayKey("C_PREDICTED") - d_pred = ArrayKey("D_PREDICTED") - - class ExampleModel(GenericJaxModel): - def __init__(self, is_training): - super().__init__(is_training) - - def _linear(x): - return haiku.Linear(1, False)(x) - - self.linear = haiku.without_apply_rng(haiku.transform(_linear)) - - def initialize(self, rng_key, inputs): - a = inputs["a"].reshape(-1) - b = inputs["b"].reshape(-1) - weight = self.linear.init(rng_key, a * b) - weight["linear"]["w"] = ( - weight["linear"]["w"].at[:].set(np.array([[1], [1], [1], [1]])) - ) - return weight - - def forward(self, params, inputs): - a = inputs["a"].reshape(-1) - b = inputs["b"].reshape(-1) - c_pred = self.linear.apply(params, a * b) - d_pred = c_pred * 2 - return {"c": c_pred, "d": d_pred} - - model = ExampleModel(is_training=False) - - source = ExampleJaxTrainSource() - predict = Predict( - model=model, - inputs={"a": a, "b": b}, - outputs={"c": c_pred, "d": d_pred}, - array_specs={ - c: ArraySpec(nonspatial=True), - c_pred: ArraySpec(nonspatial=True), - d_pred: ArraySpec(nonspatial=True), - }, - spawn_subprocess=True, - ) - pipeline = source + predict - - request = BatchRequest( - { - a: ArraySpec(roi=Roi((0, 0), (2, 2))), - b: ArraySpec(roi=Roi((0, 0), (2, 2))), - c: ArraySpec(nonspatial=True), - c_pred: ArraySpec(nonspatial=True), - d_pred: ArraySpec(nonspatial=True), + outputs = { + "c_pred": c_pred, + "grad": loss, } - ) - - # train for a couple of iterations - with build(pipeline): - batch1 = pipeline.request_batch(request) - batch2 = pipeline.request_batch(request) + return new_params, outputs, loss_mean + + model = ExampleModel(is_training=False) + + source = ExampleJaxTrainSource(a_key, b_key, c_key) + train = Train( + model=model, + inputs={"a": a_key, "b": b_key, "c": c_key}, + outputs={"c_pred": c_pred_key, "grad": c_grad_key}, + array_specs={ + c_pred_key: ArraySpec(nonspatial=True), + c_grad_key: ArraySpec(nonspatial=True), + }, + checkpoint_basename=checkpoint_basename, + save_every=100, + spawn_subprocess=True, + n_devices=1, + ) + pipeline = source + train + + request = BatchRequest( + { + a_key: ArraySpec(roi=Roi((0, 0), (2, 2))), + b_key: ArraySpec(roi=Roi((0, 0), (2, 2))), + c_key: ArraySpec(nonspatial=True), + c_pred_key: ArraySpec(nonspatial=True), + c_grad_key: ArraySpec(nonspatial=True), + } + ) + + # train for a couple of iterations + with build(pipeline): + batch = pipeline.request_batch(request) + + for i in range(200 - 1): + loss1 = batch.loss + batch = pipeline.request_batch(request) + loss2 = batch.loss + assert loss2 < loss1 - assert np.isclose(batch1[c_pred].data, batch2[c_pred].data) - assert np.isclose(batch1[c_pred].data, 1 + 4 + 9) - assert np.isclose(batch2[d_pred].data, 2 * (1 + 4 + 9)) + # resume training + with build(pipeline): + for i in range(100): + loss1 = batch.loss + batch = pipeline.request_batch(request) + loss2 = batch.loss + assert loss2 < loss1 + + +@pytest.mark.skipif(isinstance(jax, NoSuchModule), reason="Jax is not installed") +def test_pred(): + logging.getLogger("gunpowder.jax.nodes.predict").setLevel(logging.INFO) + + a = ArrayKey("A") + b = ArrayKey("B") + c = ArrayKey("C") + c_pred = ArrayKey("C_PREDICTED") + d_pred = ArrayKey("D_PREDICTED") + + class ExampleModel(GenericJaxModel): + def __init__(self, is_training): + super().__init__(is_training) + + def _linear(x): + return haiku.Linear(1, False)(x) + + self.linear = haiku.without_apply_rng(haiku.transform(_linear)) + + def initialize(self, rng_key, inputs): + a = inputs["a"].reshape(-1) + b = inputs["b"].reshape(-1) + weight = self.linear.init(rng_key, a * b) + weight["linear"]["w"] = ( + weight["linear"]["w"].at[:].set(np.array([[1], [1], [1], [1]])) + ) + return weight + + def forward(self, params, inputs): + a = inputs["a"].reshape(-1) + b = inputs["b"].reshape(-1) + c_pred = self.linear.apply(params, a * b) + d_pred = c_pred * 2 + return {"c": c_pred, "d": d_pred} + + model = ExampleModel(is_training=False) + + source = ExampleJaxTrainSource(a, b, c) + predict = Predict( + model=model, + inputs={"a": a, "b": b}, + outputs={"c": c_pred, "d": d_pred}, + array_specs={ + c: ArraySpec(nonspatial=True), + c_pred: ArraySpec(nonspatial=True), + d_pred: ArraySpec(nonspatial=True), + }, + spawn_subprocess=True, + ) + pipeline = source + predict + + request = BatchRequest( + { + a: ArraySpec(roi=Roi((0, 0), (2, 2))), + b: ArraySpec(roi=Roi((0, 0), (2, 2))), + c: ArraySpec(nonspatial=True), + c_pred: ArraySpec(nonspatial=True), + d_pred: ArraySpec(nonspatial=True), + } + ) + + # train for a couple of iterations + with build(pipeline): + batch1 = pipeline.request_batch(request) + batch2 = pipeline.request_batch(request) + + assert np.isclose(batch1[c_pred].data, batch2[c_pred].data) + assert np.isclose(batch1[c_pred].data, 1 + 4 + 9) + assert np.isclose(batch2[d_pred].data, 2 * (1 + 4 + 9)) diff --git a/tests/cases/merge_provider.py b/tests/cases/merge_provider.py index 1bc48e1e..0b47d2c1 100644 --- a/tests/cases/merge_provider.py +++ b/tests/cases/merge_provider.py @@ -1,25 +1,25 @@ -import unittest +import numpy as np +import pytest + from gunpowder import ( - GraphSpec, - GraphKey, - Roi, - Coordinate, - Batch, - BatchProvider, - ArrayKeys, + Array, ArrayKey, + ArrayKeys, ArraySpec, - Array, + Batch, + BatchProvider, BatchRequest, + Coordinate, + GraphKey, + GraphSpec, MergeProvider, RandomLocation, + Roi, build, ) -from gunpowder.graph import GraphKeys, Graph +from gunpowder.graph import Graph, GraphKeys from gunpowder.pipeline import PipelineSetupError -import numpy as np - class GraphTestSource(BatchProvider): def __init__(self, voxel_size): @@ -56,34 +56,33 @@ def provide(self, request): return batch -class TestMergeProvider(unittest.TestCase): - def test_merge_basics(self): - voxel_size = (1, 1, 1) - GraphKey("PRESYN") - ArrayKey("GT_LABELS") - graphsource = GraphTestSource(voxel_size) - arraysource = ArrayTestSoure(voxel_size) - pipeline = (graphsource, arraysource) + MergeProvider() + RandomLocation() - window_request = Coordinate((50, 50, 50)) - with build(pipeline): - # Check basic merging. - request = BatchRequest() - request.add((GraphKeys.PRESYN), window_request) - request.add((ArrayKeys.GT_LABELS), window_request) - batch_res = pipeline.request_batch(request) - self.assertTrue(ArrayKeys.GT_LABELS in batch_res.arrays) - self.assertTrue(GraphKeys.PRESYN in batch_res.graphs) +def test_merge_basics(): + voxel_size = (1, 1, 1) + GraphKey("PRESYN") + ArrayKey("GT_LABELS") + graphsource = GraphTestSource(voxel_size) + arraysource = ArrayTestSoure(voxel_size) + pipeline = (graphsource, arraysource) + MergeProvider() + RandomLocation() + window_request = Coordinate((50, 50, 50)) + with build(pipeline): + # Check basic merging. + request = BatchRequest() + request.add((GraphKeys.PRESYN), window_request) + request.add((ArrayKeys.GT_LABELS), window_request) + batch_res = pipeline.request_batch(request) + assert ArrayKeys.GT_LABELS in batch_res.arrays + assert GraphKeys.PRESYN in batch_res.graphs - # Check that request of only one source also works. - request = BatchRequest() - request.add((GraphKeys.PRESYN), window_request) - batch_res = pipeline.request_batch(request) - self.assertFalse(ArrayKeys.GT_LABELS in batch_res.arrays) - self.assertTrue(GraphKeys.PRESYN in batch_res.graphs) + # Check that request of only one source also works. + request = BatchRequest() + request.add((GraphKeys.PRESYN), window_request) + batch_res = pipeline.request_batch(request) + assert ArrayKeys.GT_LABELS not in batch_res.arrays + assert GraphKeys.PRESYN in batch_res.graphs - # Check that it fails, when having two sources that provide the same type. - arraysource2 = ArrayTestSoure(voxel_size) - pipeline_fail = (arraysource, arraysource2) + MergeProvider() + RandomLocation() - with self.assertRaises(PipelineSetupError): - with build(pipeline_fail): - pass + # Check that it fails, when having two sources that provide the same type. + arraysource2 = ArrayTestSoure(voxel_size) + pipeline_fail = (arraysource, arraysource2) + MergeProvider() + RandomLocation() + with pytest.raises(PipelineSetupError): + with build(pipeline_fail): + pass diff --git a/tests/cases/node_dependencies.py b/tests/cases/node_dependencies.py index 76a09b46..dcfe71d5 100644 --- a/tests/cases/node_dependencies.py +++ b/tests/cases/node_dependencies.py @@ -1,28 +1,31 @@ -from .provider_test import ProviderTest +import numpy as np + from gunpowder import ( - BatchProvider, + Array, + ArrayKey, + ArraySpec, + Batch, BatchFilter, + BatchProvider, BatchRequest, - Batch, - ArrayKeys, - ArraySpec, - ArrayKey, - Array, Roi, build, ) -import numpy as np class NodeDependenciesTestSource(BatchProvider): + def __init__(self, a_key, b_key): + self.a_key = a_key + self.b_key = b_key + def setup(self): self.provides( - ArrayKeys.A, + self.a_key, ArraySpec(roi=Roi((0, 0, 0), (1000, 1000, 1000)), voxel_size=(4, 4, 4)), ) self.provides( - ArrayKeys.B, + self.b_key, ArraySpec(roi=Roi((0, 0, 0), (1000, 1000, 1000)), voxel_size=(4, 4, 4)), ) @@ -56,18 +59,20 @@ def provide(self, request): class NodeDependenciesTestNode(BatchFilter): """Creates C from B.""" - def __init__(self): + def __init__(self, b_key, c_key): + self.b_key = b_key + self.c_key = c_key self.context = (20, 20, 20) def setup(self): - self.provides(ArrayKeys.C, self.spec[ArrayKeys.B]) + self.provides(self.c_key, self.spec[self.b_key]) def prepare(self, request): - assert ArrayKeys.C in request + assert self.c_key in request dependencies = BatchRequest() - dependencies[ArrayKeys.B] = ArraySpec( - request[ArrayKeys.C].roi.grow(self.context, self.context) + dependencies[self.b_key] = ArraySpec( + request[self.c_key].roi.grow(self.context, self.context) ) return dependencies @@ -76,87 +81,86 @@ def process(self, batch, request): outputs = Batch() # make sure a ROI is what we requested - b_roi = request[ArrayKeys.C].roi.grow(self.context, self.context) - assert batch[ArrayKeys.B].spec.roi == b_roi + b_roi = request[self.c_key].roi.grow(self.context, self.context) + assert batch[self.b_key].spec.roi == b_roi # add C to batch - c = batch[ArrayKeys.B].crop(request[ArrayKeys.C].roi) - outputs[ArrayKeys.C] = c + c = batch[self.b_key].crop(request[self.c_key].roi) + outputs[self.c_key] = c return outputs -class TestNodeDependencies(ProviderTest): - def test_dependecies(self): - ArrayKey("A") - ArrayKey("B") - ArrayKey("C") +def test_dependecies(): + a_key = ArrayKey("A") + b_key = ArrayKey("B") + c_key = ArrayKey("C") - pipeline = NodeDependenciesTestSource() - pipeline += NodeDependenciesTestNode() + pipeline = NodeDependenciesTestSource(a_key, b_key) + pipeline += NodeDependenciesTestNode(b_key, c_key) - c_roi = Roi((100, 100, 100), (100, 100, 100)) + c_roi = Roi((100, 100, 100), (100, 100, 100)) - # simple test, ask only for C + # simple test, ask only for C - request = BatchRequest() - request[ArrayKeys.C] = ArraySpec(roi=c_roi) + request = BatchRequest() + request[c_key] = ArraySpec(roi=c_roi) - with build(pipeline): - batch = pipeline.request_batch(request) + with build(pipeline): + batch = pipeline.request_batch(request) - assert ArrayKeys.A not in batch - assert ArrayKeys.B not in batch - assert batch[ArrayKeys.C].spec.roi == c_roi + assert a_key not in batch + assert b_key not in batch + assert batch[c_key].spec.roi == c_roi - # ask for C and B of same size as needed by node + # ask for C and B of same size as needed by node - b_roi = c_roi.grow((20, 20, 20), (20, 20, 20)) + b_roi = c_roi.grow((20, 20, 20), (20, 20, 20)) - request = BatchRequest() - request[ArrayKeys.C] = ArraySpec(roi=c_roi) - request[ArrayKeys.B] = ArraySpec(roi=b_roi) + request = BatchRequest() + request[c_key] = ArraySpec(roi=c_roi) + request[b_key] = ArraySpec(roi=b_roi) - with build(pipeline): - batch = pipeline.request_batch(request) + with build(pipeline): + batch = pipeline.request_batch(request) - c = batch[ArrayKeys.C] - b = batch[ArrayKeys.B] - assert b.spec.roi == b_roi - assert c.spec.roi == c_roi - assert np.equal(b.crop(c.spec.roi).data, c.data).all() + c = batch[c_key] + b = batch[b_key] + assert b.spec.roi == b_roi + assert c.spec.roi == c_roi + assert np.equal(b.crop(c.spec.roi).data, c.data).all() - # ask for C and B of larger size + # ask for C and B of larger size - b_roi = c_roi.grow((40, 40, 40), (40, 40, 40)) + b_roi = c_roi.grow((40, 40, 40), (40, 40, 40)) - request = BatchRequest() - request[ArrayKeys.B] = ArraySpec(roi=b_roi) - request[ArrayKeys.C] = ArraySpec(roi=c_roi) + request = BatchRequest() + request[b_key] = ArraySpec(roi=b_roi) + request[c_key] = ArraySpec(roi=c_roi) - with build(pipeline): - batch = pipeline.request_batch(request) + with build(pipeline): + batch = pipeline.request_batch(request) - b = batch[ArrayKeys.B] - c = batch[ArrayKeys.C] - assert ArrayKeys.A not in batch - assert b.spec.roi == b_roi - assert c.spec.roi == c_roi - assert np.equal(b.crop(c.spec.roi).data, c.data).all() + b = batch[b_key] + c = batch[c_key] + assert a_key not in batch + assert b.spec.roi == b_roi + assert c.spec.roi == c_roi + assert np.equal(b.crop(c.spec.roi).data, c.data).all() - # ask for C and B of smaller size + # ask for C and B of smaller size - b_roi = c_roi.grow((-40, -40, -40), (-40, -40, -40)) + b_roi = c_roi.grow((-40, -40, -40), (-40, -40, -40)) - request = BatchRequest() - request[ArrayKeys.B] = ArraySpec(roi=b_roi) - request[ArrayKeys.C] = ArraySpec(roi=c_roi) + request = BatchRequest() + request[b_key] = ArraySpec(roi=b_roi) + request[c_key] = ArraySpec(roi=c_roi) - with build(pipeline): - batch = pipeline.request_batch(request) + with build(pipeline): + batch = pipeline.request_batch(request) - b = batch[ArrayKeys.B] - c = batch[ArrayKeys.C] - assert ArrayKeys.A not in batch - assert b.spec.roi == b_roi - assert c.spec.roi == c_roi - assert np.equal(c.crop(b.spec.roi).data, b.data).all() + b = batch[b_key] + c = batch[c_key] + assert a_key not in batch + assert b.spec.roi == b_roi + assert c.spec.roi == c_roi + assert np.equal(c.crop(b.spec.roi).data, b.data).all() diff --git a/tests/cases/noise_augment.py b/tests/cases/noise_augment.py index 57768091..e145b166 100644 --- a/tests/cases/noise_augment.py +++ b/tests/cases/noise_augment.py @@ -1,24 +1,44 @@ -from .provider_test import ProviderTest -from gunpowder import IntensityAugment, ArrayKeys, build, Normalize, NoiseAugment +import numpy as np +from gunpowder import ( + Array, + ArrayKey, + ArraySpec, + BatchRequest, + IntensityAugment, + NoiseAugment, + Normalize, + Roi, + build, +) -class TestIntensityAugment(ProviderTest): - def test_shift(self): - pipeline = ( - self.test_source - + Normalize(ArrayKeys.RAW) - + IntensityAugment( - ArrayKeys.RAW, scale_min=0, scale_max=0, shift_min=0.5, shift_max=0.5 - ) - + NoiseAugment(ArrayKeys.RAW, clip=True) +from .helper_sources import ArraySource + + +def test_noise(): + raw_key = ArrayKey("RAW") + raw_spec = ArraySpec( + roi=Roi((0, 0, 0), (10, 10, 10)), voxel_size=(1, 1, 1), dtype=np.float32 + ) + raw_data = np.zeros(raw_spec.roi.shape / raw_spec.voxel_size, dtype=np.float32) + raw_array = Array(raw_data, raw_spec) + pipeline = ( + ArraySource(raw_key, raw_array) + + Normalize(raw_key) + + IntensityAugment( + raw_key, scale_min=0, scale_max=0, shift_min=0.5, shift_max=0.5 ) + + NoiseAugment(raw_key, clip=True) + ) + + request = BatchRequest() + request.add(raw_key, (10, 10, 10)) - with build(pipeline): - for i in range(100): - batch = pipeline.request_batch(self.test_request) + with build(pipeline): + batch = pipeline.request_batch(request) - x = batch.arrays[ArrayKeys.RAW].data - assert x.min() < 0.5 - assert x.min() >= 0 - assert x.max() > 0.5 - assert x.max() <= 1 + x = batch.arrays[raw_key].data + assert x.min() < 0.5 + assert x.min() >= 0 + assert x.max() > 0.5 + assert x.max() <= 1 diff --git a/tests/cases/normalize.py b/tests/cases/normalize.py index e37cdfc7..4ad04569 100644 --- a/tests/cases/normalize.py +++ b/tests/cases/normalize.py @@ -1,14 +1,25 @@ -from .provider_test import ProviderTest -from gunpowder import * +import numpy as np +from gunpowder import Array, ArrayKey, ArraySpec, BatchRequest, Normalize, Roi, build -class TestNormalize(ProviderTest): - def test_output(self): - pipeline = self.test_source + Normalize(ArrayKeys.RAW) +from .helper_sources import ArraySource - with build(pipeline): - batch = pipeline.request_batch(self.test_request) - raw = batch.arrays[ArrayKeys.RAW] - self.assertTrue(raw.data.min() >= 0) - self.assertTrue(raw.data.max() <= 1) +def test_output(): + raw_key = ArrayKey("RAW") + raw_spec = ArraySpec( + roi=Roi((0, 0, 0), (10, 10, 10)), voxel_size=(1, 1, 1), dtype=np.uint8 + ) + raw_data = np.zeros(raw_spec.roi.shape / raw_spec.voxel_size, dtype=np.uint8) + 128 + raw_array = Array(raw_data, raw_spec) + pipeline = ArraySource(raw_key, raw_array) + Normalize(raw_key) + + request = BatchRequest() + request.add(raw_key, (10, 10, 10)) + + with build(pipeline): + batch = pipeline.request_batch(request) + + raw = batch.arrays[raw_key] + assert raw.data.min() >= 0 + assert raw.data.max() <= 1 diff --git a/tests/cases/pad.py b/tests/cases/pad.py index 5efda685..31c0a848 100644 --- a/tests/cases/pad.py +++ b/tests/cases/pad.py @@ -1,23 +1,24 @@ -from .helper_sources import ArraySource, GraphSource +from itertools import product + +import numpy as np +import pytest + from gunpowder import ( - BatchRequest, + Array, + ArrayKey, ArraySpec, - Roi, + BatchRequest, Coordinate, Graph, GraphKey, GraphSpec, - Array, - ArrayKey, + MergeProvider, Pad, + Roi, build, - MergeProvider, ) -import pytest -import numpy as np - -from itertools import product +from .helper_sources import ArraySource, GraphSource @pytest.mark.parametrize("mode", ["constant", "reflect"]) diff --git a/tests/cases/placeholder_requests.py b/tests/cases/placeholder_requests.py index 8503a0c4..f8319f31 100644 --- a/tests/cases/placeholder_requests.py +++ b/tests/cases/placeholder_requests.py @@ -1,32 +1,34 @@ +import copy +import math + +import numpy as np +import pytest + from gunpowder import ( - PipelineRequestError, + Array, + ArrayKey, + ArraySpec, + Batch, BatchProvider, BatchRequest, - Batch, - Roi, Coordinate, - GraphSpec, - GraphKey, - ArrayKeys, - ArrayKey, - ArraySpec, - Array, ElasticAugment, + GraphKey, + GraphSpec, + PipelineRequestError, RandomLocation, + Roi, Snapshot, build, ) -from gunpowder.graph import Graph, GraphKeys, Node -from .provider_test import ProviderTest - -import pytest -import numpy as np - -import math -import copy +from gunpowder.graph import Graph, Node class PointTestSource3D(BatchProvider): + def __init__(self, points_key, labels_key): + self.points_key = points_key + self.labels_key = labels_key + def setup(self): self.points = [ Node(0, np.array([0, 10, 0])), @@ -37,12 +39,12 @@ def setup(self): ] self.provides( - GraphKeys.TEST_POINTS, + self.points_key, GraphSpec(roi=Roi((-100, -100, -100), (300, 300, 300))), ) self.provides( - ArrayKeys.TEST_LABELS, + self.labels_key, ArraySpec( roi=Roi((-100, -100, -100), (300, 300, 300)), voxel_size=Coordinate((4, 1, 1)), @@ -52,30 +54,30 @@ def setup(self): def point_to_voxel(self, array_roi, location): # location is in world units, get it into voxels - location = location / self.spec[ArrayKeys.TEST_LABELS].voxel_size + location = location / self.spec[self.labels_key].voxel_size # shift location relative to beginning of array roi - location -= array_roi.begin / self.spec[ArrayKeys.TEST_LABELS].voxel_size + location -= array_roi.begin / self.spec[self.labels_key].voxel_size return tuple(slice(int(l - 2), int(l + 3)) for l in location) def provide(self, request): batch = Batch() - if GraphKeys.TEST_POINTS in request: - roi_points = request[GraphKeys.TEST_POINTS].roi + if self.points_key in request: + roi_points = request[self.points_key].roi contained_points = [] for point in self.points: if roi_points.contains(point.location): contained_points.append(copy.deepcopy(point)) - batch[GraphKeys.TEST_POINTS] = Graph( + batch[self.points_key] = Graph( contained_points, [], GraphSpec(roi=roi_points) ) - if ArrayKeys.TEST_LABELS in request: - roi_array = request[ArrayKeys.TEST_LABELS].roi - roi_voxel = roi_array // self.spec[ArrayKeys.TEST_LABELS].voxel_size + if self.labels_key in request: + roi_array = request[self.labels_key].roi + roi_voxel = roi_array // self.spec[self.labels_key].voxel_size data = np.zeros(roi_voxel.shape, dtype=np.uint32) data[:, ::2] = 100 @@ -84,79 +86,79 @@ def provide(self, request): loc = self.point_to_voxel(roi_array, point.location) data[loc] = point.id - spec = self.spec[ArrayKeys.TEST_LABELS].copy() + spec = self.spec[self.labels_key].copy() spec.roi = roi_array - batch.arrays[ArrayKeys.TEST_LABELS] = Array(data, spec=spec) + batch.arrays[self.labels_key] = Array(data, spec=spec) return batch -class TestPlaceholderRequest(ProviderTest): - def test_without_placeholder(self): - test_labels = ArrayKey("TEST_LABELS") - test_points = GraphKey("TEST_POINTS") +def test_without_placeholder(tmpdir): + test_labels = ArrayKey("TEST_LABELS") + test_points = GraphKey("TEST_POINTS") - pipeline = ( - PointTestSource3D() - + RandomLocation(ensure_nonempty=test_points) - + ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi]) - + Snapshot( - {test_labels: "volumes/labels"}, - output_dir=self.path_to(), - output_filename="elastic_augment_test{id}-{iteration}.hdf", - ) + pipeline = ( + PointTestSource3D(points_key=test_points, labels_key=test_labels) + + RandomLocation(ensure_nonempty=test_points) + + ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi]) + + Snapshot( + {test_labels: "volumes/labels"}, + output_dir=tmpdir, + output_filename="elastic_augment_test{id}-{iteration}.hdf", ) + ) - with build(pipeline): - for i in range(2): - request_size = Coordinate((40, 40, 40)) + with build(pipeline): + for i in range(2): + request_size = Coordinate((40, 40, 40)) - request_a = BatchRequest(random_seed=i) - request_a.add(test_points, request_size) + request_a = BatchRequest(random_seed=i) + request_a.add(test_points, request_size) - request_b = BatchRequest(random_seed=i) - request_b.add(test_points, request_size) - request_b.add(test_labels, request_size) + request_b = BatchRequest(random_seed=i) + request_b.add(test_points, request_size) + request_b.add(test_labels, request_size) - # No array to provide a voxel size to ElasticAugment - with pytest.raises(PipelineRequestError): - pipeline.request_batch(request_a) - batch_b = pipeline.request_batch(request_b) + # No array to provide a voxel size to ElasticAugment + with pytest.raises(PipelineRequestError): + pipeline.request_batch(request_a) + batch_b = pipeline.request_batch(request_b) - self.assertIn(test_labels, batch_b) + assert test_labels in batch_b - def test_placeholder(self): - test_labels = ArrayKey("TEST_LABELS") - test_points = GraphKey("TEST_POINTS") - pipeline = ( - PointTestSource3D() - + RandomLocation(ensure_nonempty=test_points) - + ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi]) - + Snapshot( - {test_labels: "volumes/labels"}, - output_dir=self.path_to(), - output_filename="elastic_augment_test{id}-{iteration}.hdf", - ) +def test_placeholder(tmpdir): + test_labels = ArrayKey("TEST_LABELS") + test_points = GraphKey("TEST_POINTS") + + pipeline = ( + PointTestSource3D(points_key=test_points, labels_key=test_labels) + + RandomLocation(ensure_nonempty=test_points) + + ElasticAugment([10, 10, 10], [0.1, 0.1, 0.1], [0, 2.0 * math.pi]) + + Snapshot( + {test_labels: "volumes/labels"}, + output_dir=tmpdir, + output_filename="elastic_augment_test{id}-{iteration}.hdf", ) + ) - with build(pipeline): - for i in range(2): - request_size = Coordinate((40, 40, 40)) + with build(pipeline): + for i in range(2): + request_size = Coordinate((40, 40, 40)) - request_a = BatchRequest(random_seed=i) - request_a.add(test_points, request_size) - request_a.add(test_labels, request_size, placeholder=True) + request_a = BatchRequest(random_seed=i) + request_a.add(test_points, request_size) + request_a.add(test_labels, request_size, placeholder=True) - request_b = BatchRequest(random_seed=i) - request_b.add(test_points, request_size) - request_b.add(test_labels, request_size) + request_b = BatchRequest(random_seed=i) + request_b.add(test_points, request_size) + request_b.add(test_labels, request_size) - batch_a = pipeline.request_batch(request_a) - batch_b = pipeline.request_batch(request_b) + batch_a = pipeline.request_batch(request_a) + batch_b = pipeline.request_batch(request_b) - points_a = batch_a[test_points].nodes - points_b = batch_b[test_points].nodes + points_a = batch_a[test_points].nodes + points_b = batch_b[test_points].nodes - for a, b in zip(points_a, points_b): - assert all(np.isclose(a.location, b.location)) + for a, b in zip(points_a, points_b): + assert all(np.isclose(a.location, b.location)) diff --git a/tests/cases/precache.py b/tests/cases/precache.py index ae2a53e8..d824170c 100644 --- a/tests/cases/precache.py +++ b/tests/cases/precache.py @@ -1,10 +1,21 @@ -from .helper_sources import ArraySource -from gunpowder import * +import time -import pytest import numpy as np +import pytest -import time +from gunpowder import ( + Array, + ArrayKey, + ArraySpec, + BatchFilter, + BatchRequest, + Coordinate, + PreCache, + Roi, + build, +) + +from .helper_sources import ArraySource class Delay(BatchFilter): diff --git a/tests/cases/prepare_malis.py b/tests/cases/prepare_malis.py index b4b9dc32..c1862e51 100755 --- a/tests/cases/prepare_malis.py +++ b/tests/cases/prepare_malis.py @@ -1,13 +1,26 @@ -from gunpowder import * -from gunpowder.contrib import PrepareMalis import numpy as np -from .provider_test import ProviderTest + +from gunpowder import ( + Array, + ArrayKey, + ArraySpec, + Batch, + BatchProvider, + BatchRequest, + Roi, + build, +) +from gunpowder.contrib import PrepareMalis class ExampleSourcePrepareMalis(BatchProvider): + def __init__(self, labels_key, ignore_key): + self.labels_key = labels_key + self.ignore_key = ignore_key + def setup(self): self.provides( - ArrayKeys.GT_LABELS, + self.labels_key, ArraySpec( roi=Roi((0, 0, 0), (90, 90, 90)), voxel_size=(1, 1, 1), @@ -15,7 +28,7 @@ def setup(self): ), ) self.provides( - ArrayKeys.GT_IGNORE, + self.ignore_key, ArraySpec( roi=Roi((0, 0, 0), (90, 90, 90)), voxel_size=(1, 1, 1), @@ -26,177 +39,138 @@ def setup(self): def provide(self, request): batch = Batch() - if ArrayKeys.GT_LABELS in request: - gt_labels_roi = request[ArrayKeys.GT_LABELS].roi + if self.labels_key in request: + gt_labels_roi = request[self.labels_key].roi gt_labels_shape = gt_labels_roi.shape data_labels = np.ones(gt_labels_shape) data_labels[gt_labels_shape[0] // 2 :, :, :] = 2 - spec = self.spec[ArrayKeys.GT_LABELS].copy() + spec = self.spec[self.labels_key].copy() spec.roi = gt_labels_roi - batch.arrays[ArrayKeys.GT_LABELS] = Array(data_labels, spec) + batch.arrays[self.labels_key] = Array(data_labels, spec) - if ArrayKeys.GT_IGNORE in request: - gt_ignore_roi = request[ArrayKeys.GT_IGNORE].roi + if self.ignore_key in request: + gt_ignore_roi = request[self.ignore_key].roi gt_ignore_shape = gt_ignore_roi.shape data_gt_ignore = np.ones(gt_ignore_shape) data_gt_ignore[:, gt_ignore_shape[1] // 6 :, :] = 0 - spec = self.spec[ArrayKeys.GT_IGNORE].copy() + spec = self.spec[self.ignore_key].copy() spec.roi = gt_ignore_roi - batch.arrays[ArrayKeys.GT_IGNORE] = Array(data_gt_ignore, spec) + batch.arrays[self.ignore_key] = Array(data_gt_ignore, spec) return batch -class TestPrepareMalis(ProviderTest): - def test_output(self): - ArrayKey("MALIS_COMP_LABEL") - - pipeline_with_ignore = ExampleSourcePrepareMalis() + PrepareMalis( - ArrayKeys.GT_LABELS, - ArrayKeys.MALIS_COMP_LABEL, - ignore_array_key=ArrayKeys.GT_IGNORE, +def test_malis(): + malis_key = ArrayKey("MALIS_COMP_LABEL") + labels_key = ArrayKey("LABELS") + ignore_key = ArrayKey("GT_IGNORE") + + pipeline_with_ignore = ExampleSourcePrepareMalis( + labels_key, ignore_key + ) + PrepareMalis( + labels_key, + malis_key, + ignore_array_key=ignore_key, + ) + pipeline_without_ignore = ExampleSourcePrepareMalis( + labels_key, ignore_key + ) + PrepareMalis( + labels_key, + malis_key, + ) + + # test that MALIS_COMP_LABEL not in batch if not in request + with build(pipeline_with_ignore): + request = BatchRequest() + request.add(labels_key, (90, 90, 90)) + request.add(ignore_key, (90, 90, 90)) + + batch = pipeline_with_ignore.request_batch(request) + + # test if array added to batch + assert malis_key not in batch.arrays + + # test usage with gt_ignore + with build(pipeline_with_ignore): + request = BatchRequest() + request.add(labels_key, (90, 90, 90)) + request.add(ignore_key, (90, 90, 90)) + request.add(malis_key, (90, 90, 90)) + + batch = pipeline_with_ignore.request_batch(request) + + # test if array added to batch + assert malis_key in batch.arrays + + # test if gt_ignore considered for gt_neg_pass ([0, ...]) and not for gt_pos_pass ([1, ...]) + ignored_locations = np.where(batch.arrays[ignore_key].data == 0) + # gt_neg_pass + assert (batch.arrays[malis_key].data[0, ...][ignored_locations] == 3).all() + assert not ( + np.array_equal( + batch.arrays[malis_key].data[0, ...], + batch.arrays[labels_key].data, + ) ) - pipeline_without_ignore = ExampleSourcePrepareMalis() + PrepareMalis( - ArrayKeys.GT_LABELS, - ArrayKeys.MALIS_COMP_LABEL, + # gt_pos_pass + assert not ( + (batch.arrays[malis_key].data[1, ...][ignored_locations] == 3).all() + ) + assert np.array_equal( + batch.arrays[malis_key].data[1, ...], + batch.arrays[labels_key].data, ) - # test that MALIS_COMP_LABEL not in batch if not in request - with build(pipeline_with_ignore): - request = BatchRequest() - request.add(ArrayKeys.GT_LABELS, (90, 90, 90)) - request.add(ArrayKeys.GT_IGNORE, (90, 90, 90)) - - batch = pipeline_with_ignore.request_batch(request) - - # test if array added to batch - self.assertTrue(ArrayKeys.MALIS_COMP_LABEL not in batch.arrays) - - # test usage with gt_ignore - with build(pipeline_with_ignore): - request = BatchRequest() - request.add(ArrayKeys.GT_LABELS, (90, 90, 90)) - request.add(ArrayKeys.GT_IGNORE, (90, 90, 90)) - request.add(ArrayKeys.MALIS_COMP_LABEL, (90, 90, 90)) - - batch = pipeline_with_ignore.request_batch(request) - - # test if array added to batch - self.assertTrue(ArrayKeys.MALIS_COMP_LABEL in batch.arrays) - - # test if gt_ignore considered for gt_neg_pass ([0, ...]) and not for gt_pos_pass ([1, ...]) - ignored_locations = np.where(batch.arrays[ArrayKeys.GT_IGNORE].data == 0) - # gt_neg_pass - self.assertTrue( - ( - batch.arrays[ArrayKeys.MALIS_COMP_LABEL].data[0, ...][ - ignored_locations - ] - == 3 - ).all() - ) - self.assertFalse( - ( - np.array_equal( - batch.arrays[ArrayKeys.MALIS_COMP_LABEL].data[0, ...], - batch.arrays[ArrayKeys.GT_LABELS].data, - ) - ) - ) - # gt_pos_pass - self.assertFalse( - ( - batch.arrays[ArrayKeys.MALIS_COMP_LABEL].data[1, ...][ - ignored_locations - ] - == 3 - ).all() - ) - self.assertTrue( - ( - np.array_equal( - batch.arrays[ArrayKeys.MALIS_COMP_LABEL].data[1, ...], - batch.arrays[ArrayKeys.GT_LABELS].data, - ) - ) - ) - - # Test ignore without requesting ignore array - request = BatchRequest() - request.add(ArrayKeys.GT_LABELS, (90, 90, 90)) - request.add(ArrayKeys.MALIS_COMP_LABEL, (90, 90, 90)) + # Test ignore without requesting ignore array + request = BatchRequest() + request.add(labels_key, (90, 90, 90)) + request.add(malis_key, (90, 90, 90)) - batch = pipeline_with_ignore.request_batch(request) + batch = pipeline_with_ignore.request_batch(request) - # test if array added to batch - self.assertTrue(ArrayKeys.MALIS_COMP_LABEL in batch.arrays) + # test if array added to batch + assert malis_key in batch.arrays - # gt_neg_pass - self.assertTrue( - ( - batch.arrays[ArrayKeys.MALIS_COMP_LABEL].data[0, ...][ - ignored_locations - ] - == 3 - ).all() - ) - self.assertFalse( - ( - np.array_equal( - batch.arrays[ArrayKeys.MALIS_COMP_LABEL].data[0, ...], - batch.arrays[ArrayKeys.GT_LABELS].data, - ) - ) - ) - # gt_pos_pass - self.assertFalse( - ( - batch.arrays[ArrayKeys.MALIS_COMP_LABEL].data[1, ...][ - ignored_locations - ] - == 3 - ).all() - ) - self.assertTrue( - ( - np.array_equal( - batch.arrays[ArrayKeys.MALIS_COMP_LABEL].data[1, ...], - batch.arrays[ArrayKeys.GT_LABELS].data, - ) - ) + # gt_neg_pass + assert (batch.arrays[malis_key].data[0, ...][ignored_locations] == 3).all() + assert not ( + np.array_equal( + batch.arrays[malis_key].data[0, ...], + batch.arrays[labels_key].data, ) + ) + # gt_pos_pass + assert not ( + (batch.arrays[malis_key].data[1, ...][ignored_locations] == 3).all() + ) + assert np.array_equal( + batch.arrays[malis_key].data[1, ...], + batch.arrays[labels_key].data, + ) - # test usage without gt_ignore - with build(pipeline_without_ignore): - request = BatchRequest() - request.add(ArrayKeys.GT_LABELS, (90, 90, 90)) - request.add(ArrayKeys.MALIS_COMP_LABEL, (90, 90, 90)) - - batch = pipeline_without_ignore.request_batch(request) - - # test if array added to batch - self.assertTrue(ArrayKeys.MALIS_COMP_LABEL in batch.arrays) - - # test if gt_ignore considered for gt_neg_pass ([0, ;;;]) and not for gt_pos_pass ([1, ...]) - # gt_neg_pass - self.assertTrue( - ( - np.array_equal( - batch.arrays[ArrayKeys.MALIS_COMP_LABEL].data[0, ...], - batch.arrays[ArrayKeys.GT_LABELS].data, - ) - ) - ) - # gt_pos_pass - self.assertTrue( - ( - np.array_equal( - batch.arrays[ArrayKeys.MALIS_COMP_LABEL].data[1, ...], - batch.arrays[ArrayKeys.GT_LABELS].data, - ) - ) - ) + # test usage without gt_ignore + with build(pipeline_without_ignore): + request = BatchRequest() + request.add(labels_key, (90, 90, 90)) + request.add(malis_key, (90, 90, 90)) + + batch = pipeline_without_ignore.request_batch(request) + + # test if array added to batch + assert malis_key in batch.arrays + + # test if gt_ignore considered for gt_neg_pass ([0, ;;;]) and not for gt_pos_pass ([1, ...]) + # gt_neg_pass + assert np.array_equal( + batch.arrays[malis_key].data[0, ...], + batch.arrays[labels_key].data, + ) + # gt_pos_pass + assert np.array_equal( + batch.arrays[malis_key].data[1, ...], + batch.arrays[labels_key].data, + ) diff --git a/tests/cases/profiling.py b/tests/cases/profiling.py index 967f240f..75af1d67 100644 --- a/tests/cases/profiling.py +++ b/tests/cases/profiling.py @@ -1,7 +1,20 @@ -from .provider_test import ProviderTest -from gunpowder import * import time +import numpy as np + +from gunpowder import ( + Array, + ArrayKey, + ArraySpec, + BatchFilter, + BatchRequest, + PrintProfilingStats, + Roi, + build, +) + +from .helper_sources import ArraySource + class DelayNode(BatchFilter): def __init__(self, time_prepare, time_process): @@ -18,36 +31,43 @@ def process(self, batch, request): time.sleep(self.time_process) -class TestProfiling(ProviderTest): - def test_profiling(self): - pipeline = ( - self.test_source - + DelayNode(0.1, 0.2) - + PrintProfilingStats(every=2) - + DelayNode(0.2, 0.3) - ) +def test_profiling(): + raw_key = ArrayKey("RAW") + raw_data = np.random.rand(100, 100, 100) + raw_spec = ArraySpec(Roi((0, 0, 0), (100, 100, 100)), voxel_size=(1, 1, 1)) + raw_array = Array(raw_data, raw_spec) + raw_source = ArraySource(raw_key, raw_array) + pipeline = ( + raw_source + + DelayNode(0.1, 0.2) + + PrintProfilingStats(every=2) + + DelayNode(0.2, 0.3) + ) + + request = BatchRequest() + request.add(raw_key, (100, 100, 100)) - with build(pipeline): - for i in range(5): - batch = pipeline.request_batch(self.test_request) + with build(pipeline): + for i in range(5): + batch = pipeline.request_batch(request) - profiling_stats = batch.profiling_stats + profiling_stats = batch.profiling_stats - summary = profiling_stats.get_timing_summary("DelayNode", "prepare") + summary = profiling_stats.get_timing_summary("DelayNode", "prepare") - # is the timing for each pass correct? - self.assertGreaterEqual(summary.min(), 0.1) - self.assertLessEqual(summary.min(), 0.2 + 0.1) # bit of tolerance + # is the timing for each pass correct? + assert summary.min() >= 0.1 + assert summary.min() <= 0.2 + 0.1 # bit of tolerance - summary = profiling_stats.get_timing_summary("DelayNode", "process") + summary = profiling_stats.get_timing_summary("DelayNode", "process") - self.assertGreaterEqual(summary.min(), 0.2) - self.assertLessEqual(summary.min(), 0.3 + 0.1) # bit of tolerance + assert summary.min() >= 0.2 + assert summary.min() <= 0.3 + 0.1 # bit of tolerance - # is the upstream time correct? - self.assertGreaterEqual( - profiling_stats.span_time(), 0.1 + 0.2 + 0.2 + 0.3 - ) # total time spend upstream - self.assertLessEqual( - profiling_stats.span_time(), 0.1 + 0.2 + 0.2 + 0.3 + 0.1 - ) # plus bit of tolerance + # is the upstream time correct? + assert ( + profiling_stats.span_time() >= 0.1 + 0.2 + 0.2 + 0.3 + ) # total time spend upstream + assert ( + profiling_stats.span_time() <= 0.1 + 0.2 + 0.2 + 0.3 + 0.1 + ) # plus bit of tolerance diff --git a/tests/cases/provider_test.py b/tests/cases/provider_test.py deleted file mode 100644 index 095334d7..00000000 --- a/tests/cases/provider_test.py +++ /dev/null @@ -1,131 +0,0 @@ -from gunpowder import * -import shutil -import os -import copy -from warnings import warn -import unittest -from datetime import datetime -from tempfile import mkdtemp -import numpy as np - - -class ExampleSource(BatchProvider): - def setup(self): - self.provides( - ArrayKeys.RAW, - ArraySpec( - roi=Roi((0, 0, 0), (100, 100, 100)), - voxel_size=Coordinate((1, 1, 1)), - dtype=np.uint8, - interpolatable=True, - ), - ) - - def provide(self, request): - data = np.zeros( - request[ArrayKeys.RAW].roi.shape / self.spec[ArrayKeys.RAW].voxel_size, - dtype=np.uint8, - ) - spec = copy.deepcopy(self.spec[ArrayKeys.RAW]) - spec.roi = request[ArrayKeys.RAW].roi - - batch = Batch() - batch.arrays[ArrayKeys.RAW] = Array(data, spec) - return batch - - -class TestWithTempFiles(unittest.TestCase): - """ - Usage: - - If your test case dumps out any files, use ``self.path_to("path", "to", "my.file")`` to get the path to a directory - in your temporary directory. This will be namespaced by the test class, timestamp and test method, e.g. - - >>> self.path_to("path", "to", "my.file") - /tmp/gunpowder_MyTestCase_2018-03-08T18:32:18.967927_r4nd0m/my_test_method/path/to/my.file - - Each test method's data will be deleted after the test case is run (regardless of pass, fail or error). - To disable test method data deletion, set ``self._cleanup = False`` anywhere in the test. - - The test case directory will be deleted after every test method is run, unless there is data left in it. - Any files written directly to the class output directory (rather than the test output subdirectory) should therefore - be explicitly removed before tearDownClass is called. - To disable data deletion for the whole class (the test case directory and all tests), set ``_cleanup = False`` in the - class definition. N.B. doing this in a method (``type(self)._cleanup = False``) will have unexpected results - depending on the order of test execution. - - Subclasses implementing their own setUp, setUpClass, tearDown and tearDownClass should explicitly call the - ``super`` method in the method definition. - """ - - _output_root = "" - _cleanup = True - - def path_to(self, *args): - return type(self).path_to_cls(self._testMethodName, *args) - - @classmethod - def path_to_cls(cls, *args): - return os.path.join(cls._output_root, *args) - - @classmethod - def setUpClass(cls): - timestamp = datetime.now().isoformat() - cls._output_root = mkdtemp( - prefix="gunpowder_{}_{}_".format(cls.__name__, timestamp) - ) - - def setUp(self): - os.mkdir(self.path_to()) - - def tearDown(self): - path = self.path_to() - try: - if self._cleanup: - shutil.rmtree(path) - else: - warn("Directory {} was not deleted".format(path)) - except OSError as e: - if "[Errno 2]" in str(e): - pass - else: - raise - - @classmethod - def tearDownClass(cls): - try: - if cls._cleanup: - os.rmdir(cls.path_to_cls()) - else: - warn("Directory {} was not deleted".format(cls.path_to_cls())) - except OSError as e: - if "[Errno 39]" in str(e): - warn( - "Directory {} could not be deleted as it still had data in it".format( - cls.path_to_cls() - ) - ) - elif "[Errno 2]" in str(e): - pass - else: - raise - - -class ProviderTest(TestWithTempFiles): - def setUp(self): - super(ProviderTest, self).setUp() - # create some common array keys to be used by concrete tests - ArrayKey("RAW") - ArrayKey("GT_LABELS") - ArrayKey("GT_AFFINITIES") - ArrayKey("GT_AFFINITIES_MASK") - ArrayKey("GT_MASK") - ArrayKey("GT_IGNORE") - ArrayKey("LOSS_SCALE") - GraphKey("GT_GRAPH") - - self.test_source = ExampleSource() - self.test_request = BatchRequest() - self.test_request[ArrayKeys.RAW] = ArraySpec( - roi=Roi((20, 20, 20), (10, 10, 10)) - ) diff --git a/tests/cases/random_location.py b/tests/cases/random_location.py index 611289a8..02679da2 100644 --- a/tests/cases/random_location.py +++ b/tests/cases/random_location.py @@ -1,25 +1,21 @@ +import numpy as np +import pytest + from gunpowder import ( - RandomLocation, - BatchProvider, - Roi, - Coordinate, + Array, ArrayKey, ArraySpec, - Array, - Roi, - Coordinate, Batch, - BatchRequest, BatchProvider, - RandomLocation, + BatchRequest, + Coordinate, MergeProvider, + RandomLocation, + Roi, build, ) -import numpy as np from gunpowder.pipeline import PipelineRequestError -import pytest - class ExampleSourceRandomLocation(BatchProvider): def __init__(self, array): diff --git a/tests/cases/random_location_graph.py b/tests/cases/random_location_graph.py index c04d759a..c7b6b759 100644 --- a/tests/cases/random_location_graph.py +++ b/tests/cases/random_location_graph.py @@ -1,20 +1,21 @@ -from .provider_test import ProviderTest +import logging + +import numpy as np + from gunpowder import ( + Batch, + BatchFilter, BatchProvider, + BatchRequest, Graph, - Node, - GraphSpec, GraphKey, GraphKeys, - Roi, - Batch, - BatchRequest, + GraphSpec, + Node, RandomLocation, + Roi, build, - BatchFilter, ) -import numpy as np -import logging logger = logging.getLogger(__name__) @@ -48,7 +49,8 @@ def process(self, batch, request): class SourceGraphLocation(BatchProvider): - def __init__(self): + def __init__(self, graph_key): + self.graph_key = graph_key self.graph = Graph( [Node(id=1, location=np.array([500, 500, 500]))], [], @@ -56,105 +58,89 @@ def __init__(self): ) def setup(self): - self.provides(GraphKeys.TEST_GRAPH, self.graph.spec) + self.provides(self.graph_key, self.graph.spec) def provide(self, request): batch = Batch() - roi = request[GraphKeys.TEST_GRAPH].roi - batch[GraphKeys.TEST_GRAPH] = self.graph.crop(roi).trim(roi) + roi = request[self.graph_key].roi + batch[self.graph_key] = self.graph.crop(roi).trim(roi) return batch -class TestRandomLocationGraph(ProviderTest): - def test_dim_size_1(self): - GraphKey("TEST_GRAPH") - upstream_roi = Roi((500, 401, 401), (1, 200, 200)) - pipeline = ( - SourceGraphLocation() - + BatchTester(upstream_roi, exact=False) - + RandomLocation(ensure_nonempty=GraphKeys.TEST_GRAPH) - ) +def test_dim_size_1(): + graph_key = GraphKey("TEST_GRAPH") + upstream_roi = Roi((500, 401, 401), (1, 200, 200)) + pipeline = ( + SourceGraphLocation(graph_key) + + BatchTester(upstream_roi, exact=False) + + RandomLocation(ensure_nonempty=graph_key) + ) - # count the number of times we get each node - with build(pipeline): - for i in range(500): - batch = pipeline.request_batch( - BatchRequest( - { - GraphKeys.TEST_GRAPH: GraphSpec( - roi=Roi((0, 0, 0), (1, 100, 100)) - ) - } - ) - ) + # count the number of times we get each node + with build(pipeline): + for _ in range(50): + batch = pipeline.request_batch( + BatchRequest({graph_key: GraphSpec(roi=Roi((0, 0, 0), (1, 100, 100)))}) + ) - assert len(list(batch[GraphKeys.TEST_GRAPH].nodes)) == 1 + assert len(list(batch[graph_key].nodes)) == 1 - def test_req_full_roi(self): - GraphKey("TEST_GRAPH") - possible_roi = Roi((0, 0, 0), (1000, 1000, 1000)) +def test_req_full_roi(): + graph_key = GraphKey("TEST_GRAPH") - pipeline = ( - SourceGraphLocation() - + BatchTester(possible_roi, exact=False) - + RandomLocation(ensure_nonempty=GraphKeys.TEST_GRAPH) + possible_roi = Roi((0, 0, 0), (1000, 1000, 1000)) + + pipeline = ( + SourceGraphLocation(graph_key) + + BatchTester(possible_roi, exact=False) + + RandomLocation(ensure_nonempty=graph_key) + ) + with build(pipeline): + batch = pipeline.request_batch( + BatchRequest({graph_key: GraphSpec(roi=Roi((0, 0, 0), (1000, 1000, 1000)))}) ) - with build(pipeline): - batch = pipeline.request_batch( - BatchRequest( - { - GraphKeys.TEST_GRAPH: GraphSpec( - roi=Roi((0, 0, 0), (1000, 1000, 1000)) - ) - } - ) - ) - assert len(list(batch[GraphKeys.TEST_GRAPH].nodes)) == 1 + assert len(list(batch[graph_key].nodes)) == 1 - def test_roi_one_point(self): - GraphKey("TEST_GRAPH") - upstream_roi = Roi((500, 500, 500), (1, 1, 1)) - pipeline = ( - SourceGraphLocation() - + BatchTester(upstream_roi, exact=True) - + RandomLocation(ensure_nonempty=GraphKeys.TEST_GRAPH) - ) +def test_roi_one_point(): + graph_key = GraphKey("TEST_GRAPH") + upstream_roi = Roi((500, 500, 500), (1, 1, 1)) - with build(pipeline): - for i in range(500): - batch = pipeline.request_batch( - BatchRequest( - {GraphKeys.TEST_GRAPH: GraphSpec(roi=Roi((0, 0, 0), (1, 1, 1)))} - ) - ) + pipeline = ( + SourceGraphLocation(graph_key) + + BatchTester(upstream_roi, exact=True) + + RandomLocation(ensure_nonempty=graph_key) + ) - assert len(list(batch[GraphKeys.TEST_GRAPH].nodes)) == 1 + with build(pipeline): + for _ in range(50): + batch = pipeline.request_batch( + BatchRequest({graph_key: GraphSpec(roi=Roi((0, 0, 0), (1, 1, 1)))}) + ) - def test_iso_roi(self): - GraphKey("TEST_GRAPH") - upstream_roi = Roi((401, 401, 401), (200, 200, 200)) + assert len(list(batch[graph_key].nodes)) == 1 - pipeline = ( - SourceGraphLocation() - + BatchTester(upstream_roi, exact=False) - + RandomLocation(ensure_nonempty=GraphKeys.TEST_GRAPH) - ) - with build(pipeline): - for i in range(500): - batch = pipeline.request_batch( - BatchRequest( - { - GraphKeys.TEST_GRAPH: GraphSpec( - roi=Roi((0, 0, 0), (100, 100, 100)) - ) - } - ) +def test_iso_roi(): + graph_key = GraphKey("TEST_GRAPH") + upstream_roi = Roi((401, 401, 401), (200, 200, 200)) + + pipeline = ( + SourceGraphLocation(graph_key) + + BatchTester(upstream_roi, exact=False) + + RandomLocation(ensure_nonempty=graph_key) + ) + + with build(pipeline): + for _ in range(50): + batch = pipeline.request_batch( + BatchRequest( + {graph_key: GraphSpec(roi=Roi((0, 0, 0), (100, 100, 100)))} ) + ) - assert len(list(batch[GraphKeys.TEST_GRAPH].nodes)) == 1 + assert len(list(batch[graph_key].nodes)) == 1 diff --git a/tests/cases/random_location_points.py b/tests/cases/random_location_points.py index fa322d80..d57e0992 100644 --- a/tests/cases/random_location_points.py +++ b/tests/cases/random_location_points.py @@ -1,27 +1,24 @@ -from .provider_test import ProviderTest +import numpy as np +import pytest + from gunpowder import ( + Batch, BatchProvider, BatchRequest, - Batch, - Node, + Coordinate, Graph, - GraphSpec, GraphKey, - GraphKeys, + GraphSpec, + Node, RandomLocation, - build, Roi, - Coordinate, + build, ) -import numpy as np -import pytest - -import unittest - class ExampleSourceRandomLocation(BatchProvider): - def __init__(self): + def __init__(self, points_key): + self.points_key = points_key self.graph = Graph( [ Node(1, np.array([1, 1, 1])), @@ -33,170 +30,145 @@ def __init__(self): ) def setup(self): - self.provides(GraphKeys.TEST_POINTS, self.graph.spec) + self.provides(self.points_key, self.graph.spec) def provide(self, request): batch = Batch() - roi = request[GraphKeys.TEST_POINTS].roi - batch[GraphKeys.TEST_POINTS] = self.graph.crop(roi).trim(roi) + roi = request[self.points_key].roi + batch[self.points_key] = self.graph.crop(roi).trim(roi) return batch -class TestRandomLocationPoints(ProviderTest): - @pytest.mark.xfail - def test_output(self): - """ - Fails due to probabilities being calculated in advance, rather than after creating - each roi. The new approach does not account for all possible roi's containing - each point, some of which may not contain its nearest neighbors. - """ +def test_output(): + points_key = GraphKey("TEST_POINTS") - GraphKey("TEST_POINTS") + pipeline = ExampleSourceRandomLocation(points_key) + RandomLocation( + ensure_nonempty=points_key, point_balance_radius=100 + ) - pipeline = ExampleSourceRandomLocation() + RandomLocation( - ensure_nonempty=GraphKeys.TEST_POINTS, point_balance_radius=100 - ) + # count the number of times we get each point + histogram = {} - # count the number of times we get each point - histogram = {} - - with build(pipeline): - for i in range(5000): - batch = pipeline.request_batch( - BatchRequest( - { - GraphKeys.TEST_POINTS: GraphSpec( - roi=Roi((0, 0, 0), (100, 100, 100)) - ) - } - ) + with build(pipeline): + for i in range(500): + batch = pipeline.request_batch( + BatchRequest( + {points_key: GraphSpec(roi=Roi((0, 0, 0), (100, 100, 100)))} ) + ) - points = {node.id: node for node in batch[GraphKeys.TEST_POINTS].nodes} + points = {node.id: node for node in batch[points_key].nodes} - self.assertTrue(len(points) > 0) - self.assertTrue((1 in points) != (2 in points or 3 in points), points) + assert len(points) > 0 + assert (1 in points) != (2 in points or 3 in points) - for node in batch[GraphKeys.TEST_POINTS].nodes: - if node.id not in histogram: - histogram[node.id] = 1 - else: - histogram[node.id] += 1 + for node in batch[points_key].nodes: + if node.id not in histogram: + histogram[node.id] = 1 + else: + histogram[node.id] += 1 - total = sum(histogram.values()) - for k, v in histogram.items(): - histogram[k] = float(v) / total + total = sum(histogram.values()) + for k, v in histogram.items(): + histogram[k] = float(v) / total - # we should get roughly the same count for each point - for i in histogram.keys(): - for j in histogram.keys(): - self.assertAlmostEqual(histogram[i], histogram[j], 1) + # we should get roughly the same count for each point + for i in histogram.keys(): + for j in histogram.keys(): + assert abs(histogram[i] - histogram[j]) < 1 - def test_equal_probability(self): - GraphKey("TEST_POINTS") - pipeline = ExampleSourceRandomLocation() + RandomLocation( - ensure_nonempty=GraphKeys.TEST_POINTS - ) +def test_equal_probability(): + points_key = GraphKey("TEST_POINTS") - # count the number of times we get each point - histogram = {} - - with build(pipeline): - for i in range(5000): - batch = pipeline.request_batch( - BatchRequest( - { - GraphKeys.TEST_POINTS: GraphSpec( - roi=Roi((0, 0, 0), (10, 10, 10)) - ) - } - ) - ) + pipeline = ExampleSourceRandomLocation(points_key) + RandomLocation( + ensure_nonempty=points_key + ) - points = {node.id: node for node in batch[GraphKeys.TEST_POINTS].nodes} - - self.assertTrue(len(points) > 0) - self.assertTrue((1 in points) != (2 in points or 3 in points), points) - - for point in batch[GraphKeys.TEST_POINTS].nodes: - if point.id not in histogram: - histogram[point.id] = 1 - else: - histogram[point.id] += 1 - - total = sum(histogram.values()) - for k, v in histogram.items(): - histogram[k] = float(v) / total - - # we should get roughly the same count for each point - for i in histogram.keys(): - for j in histogram.keys(): - self.assertAlmostEqual(histogram[i], histogram[j], 1) - - @unittest.expectedFailure - def test_ensure_centered(self): - """ - Expected failure due to emergent behavior of two desired rules: - 1) Points on the upper bound of Roi are not considered contained - 2) When considering a point as a center of a random location, - scale by the number of points within some delta distance - - if two points are equally likely to be chosen, and centering - a roi on either of them means the other is on the bounding box - of the roi, then it can be the case that if the roi is centered - one of them, the roi contains only that one, but if the roi is - centered on the second, then both are considered contained, - breaking the equal likelihood of picking each point. - """ - - GraphKey("TEST_POINTS") - - pipeline = ExampleSourceRandomLocation() + RandomLocation( - ensure_nonempty=GraphKeys.TEST_POINTS, ensure_centered=True - ) + # count the number of times we get each point + histogram = {} - # count the number of times we get each point - histogram = {} - - with build(pipeline): - for i in range(5000): - batch = pipeline.request_batch( - BatchRequest( - { - GraphKeys.TEST_POINTS: GraphSpec( - roi=Roi((0, 0, 0), (100, 100, 100)) - ) - } - ) - ) + with build(pipeline): + for i in range(500): + batch = pipeline.request_batch( + BatchRequest({points_key: GraphSpec(roi=Roi((0, 0, 0), (10, 10, 10)))}) + ) - points = batch[GraphKeys.TEST_POINTS].data - roi = batch[GraphKeys.TEST_POINTS].spec.roi + points = {node.id: node for node in batch[points_key].nodes} - locations = tuple( - [Coordinate(point.location) for point in points.values()] - ) - self.assertTrue( - Coordinate([50, 50, 50]) in locations, - f"locations: {tuple([point.location for point in points.values()])}", + assert len(points) > 0 + assert (1 in points) != (2 in points or 3 in points) + + for point in batch[points_key].nodes: + if point.id not in histogram: + histogram[point.id] = 1 + else: + histogram[point.id] += 1 + + total = sum(histogram.values()) + for k, v in histogram.items(): + histogram[k] = float(v) / total + + # we should get roughly the same count for each point + for i in histogram.keys(): + for j in histogram.keys(): + assert abs(histogram[i] - histogram[j]) < 1 + + +@pytest.mark.xfail +def test_ensure_centered(): + """ + Expected failure due to emergent behavior of two desired rules: + 1) Points on the upper bound of Roi are not considered contained + 2) When considering a point as a center of a random location, + scale by the number of points within some delta distance + + if two points are equally likely to be chosen, and centering + a roi on either of them means the other is on the bounding box + of the roi, then it can be the case that if the roi is centered + one of them, the roi contains only that one, but if the roi is + centered on the second, then both are considered contained, + breaking the equal likelihood of picking each point. + """ + + points_key = GraphKey("TEST_POINTS") + + pipeline = ExampleSourceRandomLocation(points_key) + RandomLocation( + ensure_nonempty=points_key, ensure_centered=True + ) + + # count the number of times we get each point + histogram = {} + + with build(pipeline): + for i in range(500): + batch = pipeline.request_batch( + BatchRequest( + {points_key: GraphSpec(roi=Roi((0, 0, 0), (100, 100, 100)))} ) + ) + + points = {node.id: node for node in batch[points_key].nodes} + roi = batch[points_key].spec.roi + + locations = tuple([Coordinate(point.location) for point in points.values()]) + assert Coordinate([50, 50, 50]) in locations - self.assertTrue(len(points) > 0) - self.assertTrue((1 in points) != (2 in points or 3 in points), points) + assert len(points) > 0 + assert (1 in points) != (2 in points or 3 in points) - for point_id in batch[GraphKeys.TEST_POINTS].data.keys(): - if point_id not in histogram: - histogram[point_id] = 1 - else: - histogram[node.id] += 1 + for point_id in batch[points_key].data.keys(): + if point_id not in histogram: + histogram[point_id] = 1 + else: + histogram[point_id] += 1 - total = sum(histogram.values()) - for k, v in histogram.items(): - histogram[k] = float(v) / total + total = sum(histogram.values()) + for k, v in histogram.items(): + histogram[k] = float(v) / total - # we should get roughly the same count for each point - for i in histogram.keys(): - for j in histogram.keys(): - self.assertAlmostEqual(histogram[i], histogram[j], 1, histogram) + # we should get roughly the same count for each point + for i in histogram.keys(): + for j in histogram.keys(): + assert abs(histogram[i] - histogram[j]) < 1 diff --git a/tests/cases/random_provider.py b/tests/cases/random_provider.py index 6bcce871..d66d2f9b 100644 --- a/tests/cases/random_provider.py +++ b/tests/cases/random_provider.py @@ -1,14 +1,14 @@ +import numpy as np + from gunpowder import ( - RandomProvider, - Roi, + Array, ArrayKey, ArraySpec, - Array, - Roi, BatchRequest, + RandomProvider, + Roi, build, ) -import numpy as np from .helper_sources import ArraySource diff --git a/tests/cases/rasterize_points.py b/tests/cases/rasterize_points.py index fc21da39..25062f55 100644 --- a/tests/cases/rasterize_points.py +++ b/tests/cases/rasterize_points.py @@ -1,20 +1,21 @@ -from .helper_sources import ArraySource, GraphSource +import numpy as np + from gunpowder import ( - BatchRequest, - Roi, - Coordinate, - GraphSpec, Array, ArrayKey, ArraySpec, - RasterizeGraph, + BatchRequest, + Coordinate, + GraphSpec, MergeProvider, RasterizationSettings, + RasterizeGraph, + Roi, build, ) -from gunpowder.graph import GraphKey, Graph, Node, Edge +from gunpowder.graph import Edge, Graph, GraphKey, Node -import numpy as np +from .helper_sources import ArraySource, GraphSource def test_rasterize_graph_colors(): diff --git a/tests/cases/resample.py b/tests/cases/resample.py index 9784b152..05b6687f 100644 --- a/tests/cases/resample.py +++ b/tests/cases/resample.py @@ -1,17 +1,18 @@ -from .helper_sources import ArraySource +import numpy as np from gunpowder import ( + Array, ArrayKey, ArraySpec, - Roi, - Coordinate, BatchRequest, - Array, + Coordinate, MergeProvider, Resample, + Roi, build, ) -import numpy as np + +from .helper_sources import ArraySource def test_up_and_downsample(): @@ -73,7 +74,10 @@ def test_up_and_downsample(): assert np.array_equal(array.data, data), str(array_key) elif array_key == raw_resampled_key: - # Note: First assert averages over the voxels in the raw roi: (40:48, 40:48, 40:48), values of [30,31,31,32,31,32,32,33], the average of which is 31.5. Casting to an integer, in this case, rounds down, resulting in 31. + # Note: First assert averages over the voxels in the raw roi: + # (40:48, 40:48, 40:48), values of [30,31,31,32,31,32,32,33], the average of + # which is 31.5. Casting to an integer, in this case, rounds down, resulting + # in 31. assert ( array.data[0, 0, 0, 0] == 31 ), f"RAW_RESAMPLED[0,0,0]: {array.data[0,0,0]} does not equal expected: 31" diff --git a/tests/cases/scan.py b/tests/cases/scan.py index 95f5d33b..c16c44e5 100644 --- a/tests/cases/scan.py +++ b/tests/cases/scan.py @@ -1,23 +1,25 @@ +import itertools + +import numpy as np + from gunpowder import ( - BatchProvider, - BatchRequest, - Batch, - ArrayKeys, + Array, ArrayKey, + ArrayKeys, ArraySpec, - Array, + Batch, + BatchProvider, + BatchRequest, + Coordinate, + Graph, GraphKey, GraphKeys, GraphSpec, - Graph, Node, Roi, - Coordinate, Scan, build, ) -import numpy as np -import itertools def coordinate_to_id(i, j, k): diff --git a/tests/cases/shift_augment.py b/tests/cases/shift_augment.py index f04bb777..75ab40e3 100644 --- a/tests/cases/shift_augment.py +++ b/tests/cases/shift_augment.py @@ -1,564 +1,579 @@ -import unittest -import numpy as np import random + import h5py -import logging -import sys -import os +import numpy as np +import pytest + from gunpowder import ( ArrayKey, ArraySpec, + BatchRequest, + Coordinate, + CsvPointsSource, + Graph, GraphKey, GraphSpec, - Graph, + Hdf5Source, + MergeProvider, Node, RandomLocation, - Coordinate, Roi, - BatchRequest, - Hdf5Source, ShiftAugment, - CsvPointsSource, - MergeProvider, build, ) from gunpowder.pipeline import PipelineRequestError -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) -stream_handler = logging.StreamHandler(sys.stdout) -logger.addHandler(stream_handler) - - -class TestShiftAugment2D(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.fake_points_file = "shift_test.csv" - cls.fake_data_file = "shift_test.hdf5" - random.seed(1234) - np.random.seed(1234) - cls.fake_data = np.array([[i + j for i in range(100)] for j in range(100)]) - with h5py.File(cls.fake_data_file, "w") as f: - f.create_dataset("testdata", shape=cls.fake_data.shape, data=cls.fake_data) - cls.fake_points = np.random.randint(0, 100, size=(2, 2)) - with open(cls.fake_points_file, "w") as f: - for point in cls.fake_points: - f.write(str(point[0]) + "\t" + str(point[1]) + "\n") - - def setUp(self): - random.seed(12345) - np.random.seed(12345) - - @classmethod - def tearDownClass(cls): - os.remove(cls.fake_data_file) - os.remove(cls.fake_points_file) - - ################## - # full pipeline # - ################## - - def test_prepare1(self): - key = ArrayKey("TEST_ARRAY") - spec = ArraySpec(voxel_size=Coordinate((1, 1)), interpolatable=True) - - hdf5_source = Hdf5Source( - self.fake_data_file, {key: "testdata"}, array_specs={key: spec} - ) - - request = BatchRequest() - shape = Coordinate((3, 3)) - request.add(key, shape, voxel_size=Coordinate((1, 1))) - - shift_node = ShiftAugment(sigma=1, shift_axis=0) - with build((hdf5_source + shift_node)): - shift_node.prepare(request) - self.assertTrue(shift_node.ndim == 2) - self.assertTrue(shift_node.shift_sigmas == tuple([0.0, 1.0])) - - def test_prepare2(self): - key = ArrayKey("TEST_ARRAY") - spec = ArraySpec(voxel_size=Coordinate((1, 1)), interpolatable=True) - - hdf5_source = Hdf5Source( - self.fake_data_file, {key: "testdata"}, array_specs={key: spec} - ) - - request = BatchRequest() - shape = Coordinate((3, 3)) - request.add(key, shape) - - shift_node = ShiftAugment(sigma=1, shift_axis=0) - - with build((hdf5_source + shift_node)): - shift_node.prepare(request) - self.assertTrue(shift_node.ndim == 2) - self.assertTrue(shift_node.shift_sigmas == tuple([0.0, 1.0])) - - def test_pipeline1(self): - key = ArrayKey("TEST_ARRAY") - spec = ArraySpec(voxel_size=Coordinate((2, 1)), interpolatable=True) - - hdf5_source = Hdf5Source( - self.fake_data_file, {key: "testdata"}, array_specs={key: spec} - ) - - request = BatchRequest() - shape = Coordinate((3, 3)) - request.add(key, shape, voxel_size=Coordinate((3, 1))) - - shift_node = ShiftAugment(prob_slip=0.2, prob_shift=0.2, sigma=1, shift_axis=0) - with build((hdf5_source + shift_node)) as b: - with self.assertRaises(PipelineRequestError): - b.request_batch(request) - - def test_pipeline2(self): - key = ArrayKey("TEST_ARRAY") - spec = ArraySpec(voxel_size=Coordinate((3, 1)), interpolatable=True) - - hdf5_source = Hdf5Source( - self.fake_data_file, {key: "testdata"}, array_specs={key: spec} - ) - - request = BatchRequest() - shape = Coordinate((3, 3)) - request[key] = ArraySpec(roi=Roi((9, 9), shape), voxel_size=Coordinate((3, 1))) - - shift_node = ShiftAugment(prob_slip=0.2, prob_shift=0.2, sigma=1, shift_axis=0) - with build((hdf5_source + shift_node)) as b: + +# automatically set the seed for all tests +@pytest.fixture(autouse=True) +def seeds(): + random.seed(12345) + np.random.seed(12345) + + +@pytest.fixture +def test_points(tmpdir): + random.seed(1234) + np.random.seed(1234) + + fake_points_file = tmpdir / "shift_test.csv" + fake_data_file = tmpdir / "shift_test.hdf5" + fake_data = np.array([[i + j for i in range(100)] for j in range(100)]) + with h5py.File(fake_data_file, "w") as f: + f.create_dataset("testdata", shape=fake_data.shape, data=fake_data) + fake_points = np.random.randint(0, 100, size=(2, 2)) + with open(fake_points_file, "w") as f: + for point in fake_points: + f.write(str(point[0]) + "\t" + str(point[1]) + "\n") + + # This fixture will run after seeds since it is set + # with autouse=True. So make sure to reset the seeds properly at the end + # of this fixture + random.seed(12345) + np.random.seed(12345) + + yield fake_points_file, fake_data_file, fake_points, fake_data + + +def test_prepare1(test_points): + _, fake_data_file, *_ = test_points + + key = ArrayKey("TEST_ARRAY") + spec = ArraySpec(voxel_size=Coordinate((1, 1)), interpolatable=True) + + hdf5_source = Hdf5Source(fake_data_file, {key: "testdata"}, array_specs={key: spec}) + + request = BatchRequest() + shape = Coordinate((3, 3)) + request.add(key, shape, voxel_size=Coordinate((1, 1))) + + shift_node = ShiftAugment(sigma=1, shift_axis=0) + with build((hdf5_source + shift_node)): + shift_node.prepare(request) + assert shift_node.ndim == 2 + assert shift_node.shift_sigmas == tuple([0.0, 1.0]) + + +def test_prepare2(test_points): + _, fake_data_file, *_ = test_points + + key = ArrayKey("TEST_ARRAY") + spec = ArraySpec(voxel_size=Coordinate((1, 1)), interpolatable=True) + + hdf5_source = Hdf5Source(fake_data_file, {key: "testdata"}, array_specs={key: spec}) + + request = BatchRequest() + shape = Coordinate((3, 3)) + request.add(key, shape) + + shift_node = ShiftAugment(sigma=1, shift_axis=0) + + with build((hdf5_source + shift_node)): + shift_node.prepare(request) + assert shift_node.ndim == 2 + assert shift_node.shift_sigmas == tuple([0.0, 1.0]) + + +def test_pipeline1(test_points): + _, fake_data_file, *_ = test_points + + key = ArrayKey("TEST_ARRAY") + spec = ArraySpec(voxel_size=Coordinate((2, 1)), interpolatable=True) + + hdf5_source = Hdf5Source(fake_data_file, {key: "testdata"}, array_specs={key: spec}) + + request = BatchRequest() + shape = Coordinate((3, 3)) + request.add(key, shape, voxel_size=Coordinate((3, 1))) + + shift_node = ShiftAugment(prob_slip=0.2, prob_shift=0.2, sigma=1, shift_axis=0) + with build((hdf5_source + shift_node)) as b: + with pytest.raises(PipelineRequestError): b.request_batch(request) - def test_pipeline3(self): - array_key = ArrayKey("TEST_ARRAY") - points_key = GraphKey("TEST_POINTS") - voxel_size = Coordinate((1, 1)) - spec = ArraySpec(voxel_size=voxel_size, interpolatable=True) - - hdf5_source = Hdf5Source( - self.fake_data_file, {array_key: "testdata"}, array_specs={array_key: spec} - ) - csv_source = CsvPointsSource( - self.fake_points_file, - points_key, - GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))), - ) - - request = BatchRequest() - shape = Coordinate((60, 60)) - request.add(array_key, shape, voxel_size=Coordinate((1, 1))) - request.add(points_key, shape) - - shift_node = ShiftAugment(prob_slip=0.2, prob_shift=0.2, sigma=4, shift_axis=0) - pipeline = ( - (hdf5_source, csv_source) - + MergeProvider() - + RandomLocation(ensure_nonempty=points_key) - + shift_node - ) - with build(pipeline) as b: - request = b.request_batch(request) - # print(request[points_key]) - - target_vals = [self.fake_data[point[0]][point[1]] for point in self.fake_points] - result_data = request[array_key].data - result_points = list(request[points_key].nodes) - result_vals = [ - result_data[int(point.location[0])][int(point.location[1])] - for point in result_points - ] - - for result_val in result_vals: - self.assertTrue( - result_val in target_vals, - msg="result value {} at points {} not in target values {} at points {}".format( - result_val, - list(result_points), - target_vals, - self.fake_points, - ), - ) - - ################## - # shift_and_crop # - ################## - - def test_shift_and_crop_static(self): - shift_node = ShiftAugment(sigma=1, shift_axis=0) - shift_node.ndim = 2 - upstream_arr = np.arange(16).reshape(4, 4) - sub_shift_array = np.zeros(8, dtype=int).reshape(4, 2) - roi_shape = (4, 4) - voxel_size = Coordinate((1, 1)) - - downstream_arr = np.arange(16).reshape(4, 4) - - result = shift_node.shift_and_crop( - upstream_arr, roi_shape, sub_shift_array, voxel_size - ) - self.assertTrue(np.array_equal(result, downstream_arr)) - - def test_shift_and_crop1(self): - shift_node = ShiftAugment(sigma=1, shift_axis=0) - shift_node.ndim = 2 - upstream_arr = np.arange(16).reshape(4, 4) - sub_shift_array = np.zeros(8, dtype=int).reshape(4, 2) - sub_shift_array[:, 1] = np.array([0, -1, 1, 0], dtype=int) - roi_shape = (4, 2) - voxel_size = Coordinate((1, 1)) - - downstream_arr = np.array([[1, 2], [6, 7], [8, 9], [13, 14]], dtype=int) - - result = shift_node.shift_and_crop( - upstream_arr, roi_shape, sub_shift_array, voxel_size - ) - self.assertTrue(np.array_equal(result, downstream_arr)) - - def test_shift_and_crop2(self): - shift_node = ShiftAugment(sigma=1, shift_axis=0) - shift_node.ndim = 2 - upstream_arr = np.arange(16).reshape(4, 4) - sub_shift_array = np.zeros(8, dtype=int).reshape(4, 2) - sub_shift_array[:, 1] = np.array([0, -1, -2, 0], dtype=int) - roi_shape = (4, 2) - voxel_size = Coordinate((1, 1)) - - downstream_arr = np.array([[0, 1], [5, 6], [10, 11], [12, 13]], dtype=int) - - result = shift_node.shift_and_crop( - upstream_arr, roi_shape, sub_shift_array, voxel_size - ) - self.assertTrue(np.array_equal(result, downstream_arr)) - - def test_shift_and_crop3(self): - shift_node = ShiftAugment(sigma=1, shift_axis=1) - shift_node.ndim = 2 - upstream_arr = np.arange(16).reshape(4, 4) - sub_shift_array = np.zeros(8, dtype=int).reshape(4, 2) - sub_shift_array[:, 0] = np.array([0, 1, 0, 2], dtype=int) - roi_shape = (2, 4) - voxel_size = Coordinate((1, 1)) - - downstream_arr = np.array([[8, 5, 10, 3], [12, 9, 14, 7]], dtype=int) - - result = shift_node.shift_and_crop( - upstream_arr, roi_shape, sub_shift_array, voxel_size - ) - # print(result) - self.assertTrue(np.array_equal(result, downstream_arr)) - - def test_shift_and_crop4(self): - shift_node = ShiftAugment(sigma=1, shift_axis=1) - shift_node.ndim = 2 - upstream_arr = np.arange(16).reshape(4, 4) - sub_shift_array = np.zeros(8, dtype=int).reshape(4, 2) - sub_shift_array[:, 0] = np.array([0, 2, 0, 4], dtype=int) - roi_shape = (4, 4) - voxel_size = Coordinate((2, 1)) - - downstream_arr = np.array([[8, 5, 10, 3], [12, 9, 14, 7]], dtype=int) - - result = shift_node.shift_and_crop( - upstream_arr, roi_shape, sub_shift_array, voxel_size - ) - # print(result) - self.assertTrue(np.array_equal(result, downstream_arr)) - - result = shift_node.shift_and_crop( - upstream_arr, roi_shape, sub_shift_array, voxel_size - ) - # print(result) - self.assertTrue(np.array_equal(result, downstream_arr)) - - ################## - # shift_points # - ################## - - @staticmethod - def points_equal(vertices1, vertices2): - vs1 = sorted(list(vertices1), key=lambda v: tuple(v.location)) - vs2 = sorted(list(vertices2), key=lambda v: tuple(v.location)) - - for v1, v2 in zip(vs1, vs2): - if not v1.id == v2.id: - print(f"{vs1}, {vs2}") - return False - if not all(np.isclose(v1.location, v2.location)): - print(f"{vs1}, {vs2}") - return False - return True - - def test_points_equal(self): - points1 = [Node(id=1, location=np.array([0, 1]))] - points2 = [Node(id=1, location=np.array([0, 1]))] - self.assertTrue(self.points_equal(points1, points2)) - - points1 = [Node(id=2, location=np.array([1, 2]))] - points2 = [Node(id=2, location=np.array([2, 1]))] - self.assertFalse(self.points_equal(points1, points2)) - - def test_shift_points1(self): - data = [Node(id=1, location=np.array([0, 1]))] - spec = GraphSpec(Roi(offset=(0, 0), shape=(5, 5))) - points = Graph(data, [], spec) - request_roi = Roi(offset=(0, 1), shape=(5, 3)) - shift_array = np.array([[0, -1], [0, -1], [0, 0], [0, 0], [0, 1]], dtype=int) - lcm_voxel_size = Coordinate((1, 1)) - - shifted_points = Graph([], [], GraphSpec(request_roi)) - result = ShiftAugment.shift_points( - points, - request_roi, - shift_array, - shift_axis=0, - lcm_voxel_size=lcm_voxel_size, - ) - # print(result) - self.assertTrue(self.points_equal(result.nodes, shifted_points.nodes)) - self.assertTrue(result.spec == GraphSpec(request_roi)) - - def test_shift_points2(self): - data = [Node(id=1, location=np.array([0, 1]))] - spec = GraphSpec(Roi(offset=(0, 0), shape=(5, 5))) - points = Graph(data, [], spec) - request_roi = Roi(offset=(0, 1), shape=(5, 3)) - shift_array = np.array([[0, 0], [0, -1], [0, 0], [0, 0], [0, 1]], dtype=int) - lcm_voxel_size = Coordinate((1, 1)) - - result = ShiftAugment.shift_points( - points, - request_roi, - shift_array, - shift_axis=0, - lcm_voxel_size=lcm_voxel_size, - ) - # print("test 2", result.data, data) - self.assertTrue(self.points_equal(result.nodes, data)) - self.assertTrue(result.spec == GraphSpec(request_roi)) - - def test_shift_points3(self): - data = [Node(id=1, location=np.array([0, 1]))] - spec = GraphSpec(Roi(offset=(0, 0), shape=(5, 5))) - points = Graph(data, [], spec) - request_roi = Roi(offset=(0, 1), shape=(5, 3)) - shift_array = np.array([[0, 1], [0, -1], [0, 0], [0, 0], [0, 1]], dtype=int) - lcm_voxel_size = Coordinate((1, 1)) - - shifted_points = Graph( - [Node(id=1, location=np.array([0, 2]))], [], GraphSpec(request_roi) - ) - result = ShiftAugment.shift_points( - points, - request_roi, - shift_array, - shift_axis=0, - lcm_voxel_size=lcm_voxel_size, - ) - # print("test 3", result.data, shifted_points.data) - self.assertTrue(self.points_equal(result.nodes, shifted_points.nodes)) - self.assertTrue(result.spec == GraphSpec(request_roi)) - - def test_shift_points4(self): - data = [ - Node(id=0, location=np.array([1, 0])), - Node(id=1, location=np.array([1, 1])), - Node(id=2, location=np.array([1, 2])), - Node(id=3, location=np.array([1, 3])), - Node(id=4, location=np.array([1, 4])), - ] - spec = GraphSpec(Roi(offset=(0, 0), shape=(5, 5))) - points = Graph(data, [], spec) - request_roi = Roi(offset=(1, 0), shape=(3, 5)) - shift_array = np.array([[1, 0], [-1, 0], [0, 0], [-1, 0], [1, 0]], dtype=int) - - lcm_voxel_size = Coordinate((1, 1)) - shifted_data = [ - Node(id=0, location=np.array([2, 0])), - Node(id=2, location=np.array([1, 2])), - Node(id=4, location=np.array([2, 4])), - ] - result = ShiftAugment.shift_points( - points, - request_roi, - shift_array, - shift_axis=1, - lcm_voxel_size=lcm_voxel_size, - ) - # print("test 4", result.data, shifted_data) - self.assertTrue(self.points_equal(result.nodes, shifted_data)) - self.assertTrue(result.spec == GraphSpec(request_roi)) - - def test_shift_points5(self): - data = [ - Node(id=0, location=np.array([3, 0])), - Node(id=1, location=np.array([3, 2])), - Node(id=2, location=np.array([3, 4])), - Node(id=3, location=np.array([3, 6])), - Node(id=4, location=np.array([3, 8])), - ] - spec = GraphSpec(Roi(offset=(0, 0), shape=(15, 10))) - points = Graph(data, [], spec) - request_roi = Roi(offset=(3, 0), shape=(9, 10)) - shift_array = np.array([[3, 0], [-3, 0], [0, 0], [-3, 0], [3, 0]], dtype=int) - - lcm_voxel_size = Coordinate((3, 2)) - shifted_data = [ - Node(id=0, location=np.array([6, 0])), - Node(id=2, location=np.array([3, 4])), - Node(id=4, location=np.array([6, 8])), - ] - result = ShiftAugment.shift_points( - points, - request_roi, - shift_array, - shift_axis=1, - lcm_voxel_size=lcm_voxel_size, - ) - # print("test 4", result.data, shifted_data) - self.assertTrue(self.points_equal(result.nodes, shifted_data)) - self.assertTrue(result.spec == GraphSpec(request_roi)) - - ####################### - # get_sub_shift_array # - ####################### - - def test_get_sub_shift_array1(self): - total_roi = Roi(offset=(0, 0), shape=(6, 6)) - item_roi = Roi(offset=(1, 2), shape=(3, 3)) - shift_array = np.arange(12).reshape(6, 2).astype(int) - shift_axis = 1 - lcm_voxel_size = Coordinate((1, 1)) - - sub_shift_array = np.array([[4, 5], [6, 7], [8, 9]], dtype=int) - result = ShiftAugment.get_sub_shift_array( - total_roi, item_roi, shift_array, shift_axis, lcm_voxel_size - ) - # print(result) - self.assertTrue(np.array_equal(result, sub_shift_array)) - - def test_get_sub_shift_array2(self): - total_roi = Roi(offset=(0, 0), shape=(6, 6)) - item_roi = Roi(offset=(1, 2), shape=(3, 3)) - shift_array = np.arange(12).reshape(6, 2).astype(int) - shift_axis = 0 - lcm_voxel_size = Coordinate((1, 1)) - - sub_shift_array = np.array([[2, 3], [4, 5], [6, 7]], dtype=int) - result = ShiftAugment.get_sub_shift_array( - total_roi, item_roi, shift_array, shift_axis, lcm_voxel_size - ) - self.assertTrue(np.array_equal(result, sub_shift_array)) - - def test_get_sub_shift_array3(self): - total_roi = Roi(offset=(0, 0), shape=(18, 12)) - item_roi = Roi(offset=(3, 4), shape=(9, 6)) - shift_array = np.arange(12).reshape(6, 2).astype(int) - shift_axis = 0 - lcm_voxel_size = Coordinate((3, 2)) - - sub_shift_array = np.array([[2, 3], [4, 5], [6, 7]], dtype=int) - result = ShiftAugment.get_sub_shift_array( - total_roi, item_roi, shift_array, shift_axis, lcm_voxel_size - ) - # print(result) - self.assertTrue(np.array_equal(result, sub_shift_array)) - - ################################ - # construct_global_shift_array # - ################################ - - def test_construct_global_shift_array_static(self): - shift_axis_len = 5 - shift_sigmas = (0.0, 1.0) - prob_slip = 0 - prob_shift = 0 - lcm_voxel_size = Coordinate((1, 1)) - - shift_array = np.zeros(shape=(shift_axis_len, len(shift_sigmas)), dtype=int) - result = ShiftAugment.construct_global_shift_array( - shift_axis_len, shift_sigmas, prob_shift, prob_slip, lcm_voxel_size - ) - self.assertTrue(np.array_equal(result, shift_array)) - - def test_construct_global_shift_array1(self): - shift_axis_len = 5 - shift_sigmas = (0.0, 1.0) - prob_slip = 1 - prob_shift = 0 - lcm_voxel_size = Coordinate((1, 1)) - - shift_array = np.array([[0, 0], [0, -1], [0, 1], [0, 0], [0, 1]], dtype=int) - result = ShiftAugment.construct_global_shift_array( - shift_axis_len, shift_sigmas, prob_slip, prob_shift, lcm_voxel_size - ) - # print(result) - self.assertTrue(len(result) == shift_axis_len) - for position_shift in result: - self.assertTrue(position_shift[0] == 0) - self.assertTrue(np.array_equal(shift_array, result)) - - def test_construct_global_shift_array2(self): - shift_axis_len = 5 - shift_sigmas = (0.0, 1.0) - prob_slip = 0 - prob_shift = 1 - lcm_voxel_size = Coordinate((1, 1)) - - shift_array = np.array([[0, 0], [0, -1], [0, 0], [0, 0], [0, 1]], dtype=int) - result = ShiftAugment.construct_global_shift_array( - shift_axis_len, shift_sigmas, prob_slip, prob_shift, lcm_voxel_size - ) - self.assertTrue(len(result) == shift_axis_len) - for position_shift in result: - self.assertTrue(position_shift[0] == 0) - self.assertTrue(np.array_equal(shift_array, result)) - - def test_construct_global_shift_array3(self): - shift_axis_len = 5 - shift_sigmas = (0.0, 4.0) - prob_slip = 0 - prob_shift = 1 - lcm_voxel_size = Coordinate((1, 3)) - - shift_array = np.array([[0, 3], [0, 0], [0, 6], [0, 6], [0, 12]], dtype=int) - result = ShiftAugment.construct_global_shift_array( - shift_axis_len, shift_sigmas, prob_slip, prob_shift, lcm_voxel_size - ) - # print(result) - self.assertTrue(len(result) == shift_axis_len) - for position_shift in result: - self.assertTrue(position_shift[0] == 0) - self.assertTrue(np.array_equal(shift_array, result)) - - ######################## - # compute_upstream_roi # - ######################## - - def test_compute_upstream_roi_static(self): - request_roi = Roi(offset=(0, 0), shape=(5, 10)) - sub_shift_array = np.array([[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], dtype=int) - - upstream_roi = Roi(offset=(0, 0), shape=(5, 10)) - result = ShiftAugment.compute_upstream_roi(request_roi, sub_shift_array) - self.assertTrue(upstream_roi == result) - - def test_compute_upstream_roi1(self): - request_roi = Roi(offset=(0, 0), shape=(5, 10)) - sub_shift_array = np.array([[0, 0], [0, -1], [0, 0], [0, 0], [0, 1]], dtype=int) - - upstream_roi = Roi(offset=(0, -1), shape=(5, 12)) - result = ShiftAugment.compute_upstream_roi(request_roi, sub_shift_array) - self.assertTrue(upstream_roi == result) - - def test_compute_upstream_roi2(self): - request_roi = Roi(offset=(0, 0), shape=(5, 10)) - sub_shift_array = np.array( - [[2, 0], [-1, 0], [5, 0], [-2, 0], [0, 0]], dtype=int - ) - - upstream_roi = Roi(offset=(-5, 0), shape=(12, 10)) - result = ShiftAugment.compute_upstream_roi(request_roi, sub_shift_array) - self.assertTrue(upstream_roi == result) - - -if __name__ == "__main__": - unittest.main() + +def test_pipeline2(test_points): + _, fake_data_file, *_ = test_points + + key = ArrayKey("TEST_ARRAY") + spec = ArraySpec(voxel_size=Coordinate((3, 1)), interpolatable=True) + + hdf5_source = Hdf5Source(fake_data_file, {key: "testdata"}, array_specs={key: spec}) + + request = BatchRequest() + shape = Coordinate((3, 3)) + request[key] = ArraySpec(roi=Roi((9, 9), shape), voxel_size=Coordinate((3, 1))) + + shift_node = ShiftAugment(prob_slip=0.2, prob_shift=0.2, sigma=1, shift_axis=0) + with build((hdf5_source + shift_node)) as b: + b.request_batch(request) + + +def test_pipeline3(test_points): + fake_points_file, fake_data_file, fake_points, fake_data = test_points + + array_key = ArrayKey("TEST_ARRAY") + points_key = GraphKey("TEST_POINTS") + voxel_size = Coordinate((1, 1)) + spec = ArraySpec(voxel_size=voxel_size, interpolatable=True) + + hdf5_source = Hdf5Source( + fake_data_file, {array_key: "testdata"}, array_specs={array_key: spec} + ) + csv_source = CsvPointsSource( + fake_points_file, + points_key, + GraphSpec(roi=Roi(shape=Coordinate((100, 100)), offset=(0, 0))), + ) + + request = BatchRequest() + shape = Coordinate((60, 60)) + request.add(array_key, shape, voxel_size=Coordinate((1, 1))) + request.add(points_key, shape) + + shift_node = ShiftAugment(prob_slip=0.2, prob_shift=0.2, sigma=4, shift_axis=0) + pipeline = ( + (hdf5_source, csv_source) + + MergeProvider() + + RandomLocation(ensure_nonempty=points_key) + + shift_node + ) + with build(pipeline) as b: + request = b.request_batch(request) + # print(request[points_key]) + + target_vals = [fake_data[point[0]][point[1]] for point in fake_points] + result_data = request[array_key].data + result_points = list(request[points_key].nodes) + result_vals = [ + result_data[int(point.location[0])][int(point.location[1])] + for point in result_points + ] + + for result_val in result_vals: + assert result_val in target_vals + + +################## +# shift_and_crop # +################## + + +def test_shift_and_crop_static(): + shift_node = ShiftAugment(sigma=1, shift_axis=0) + shift_node.ndim = 2 + upstream_arr = np.arange(16).reshape(4, 4) + sub_shift_array = np.zeros(8, dtype=int).reshape(4, 2) + roi_shape = (4, 4) + voxel_size = Coordinate((1, 1)) + + downstream_arr = np.arange(16).reshape(4, 4) + + result = shift_node.shift_and_crop( + upstream_arr, roi_shape, sub_shift_array, voxel_size + ) + assert np.array_equal(result, downstream_arr) + + +def test_shift_and_crop1(): + shift_node = ShiftAugment(sigma=1, shift_axis=0) + shift_node.ndim = 2 + upstream_arr = np.arange(16).reshape(4, 4) + sub_shift_array = np.zeros(8, dtype=int).reshape(4, 2) + sub_shift_array[:, 1] = np.array([0, -1, 1, 0], dtype=int) + roi_shape = (4, 2) + voxel_size = Coordinate((1, 1)) + + downstream_arr = np.array([[1, 2], [6, 7], [8, 9], [13, 14]], dtype=int) + + result = shift_node.shift_and_crop( + upstream_arr, roi_shape, sub_shift_array, voxel_size + ) + assert np.array_equal(result, downstream_arr) + + +def test_shift_and_crop2(): + shift_node = ShiftAugment(sigma=1, shift_axis=0) + shift_node.ndim = 2 + upstream_arr = np.arange(16).reshape(4, 4) + sub_shift_array = np.zeros(8, dtype=int).reshape(4, 2) + sub_shift_array[:, 1] = np.array([0, -1, -2, 0], dtype=int) + roi_shape = (4, 2) + voxel_size = Coordinate((1, 1)) + + downstream_arr = np.array([[0, 1], [5, 6], [10, 11], [12, 13]], dtype=int) + + result = shift_node.shift_and_crop( + upstream_arr, roi_shape, sub_shift_array, voxel_size + ) + assert np.array_equal(result, downstream_arr) + + +def test_shift_and_crop3(): + shift_node = ShiftAugment(sigma=1, shift_axis=1) + shift_node.ndim = 2 + upstream_arr = np.arange(16).reshape(4, 4) + sub_shift_array = np.zeros(8, dtype=int).reshape(4, 2) + sub_shift_array[:, 0] = np.array([0, 1, 0, 2], dtype=int) + roi_shape = (2, 4) + voxel_size = Coordinate((1, 1)) + + downstream_arr = np.array([[8, 5, 10, 3], [12, 9, 14, 7]], dtype=int) + + result = shift_node.shift_and_crop( + upstream_arr, roi_shape, sub_shift_array, voxel_size + ) + # print(result) + assert np.array_equal(result, downstream_arr) + + +def test_shift_and_crop4(): + shift_node = ShiftAugment(sigma=1, shift_axis=1) + shift_node.ndim = 2 + upstream_arr = np.arange(16).reshape(4, 4) + sub_shift_array = np.zeros(8, dtype=int).reshape(4, 2) + sub_shift_array[:, 0] = np.array([0, 2, 0, 4], dtype=int) + roi_shape = (4, 4) + voxel_size = Coordinate((2, 1)) + + downstream_arr = np.array([[8, 5, 10, 3], [12, 9, 14, 7]], dtype=int) + + result = shift_node.shift_and_crop( + upstream_arr, roi_shape, sub_shift_array, voxel_size + ) + # print(result) + assert np.array_equal(result, downstream_arr) + + result = shift_node.shift_and_crop( + upstream_arr, roi_shape, sub_shift_array, voxel_size + ) + # print(result) + assert np.array_equal(result, downstream_arr) + + +################## +# shift_points # +################## + + +def points_equal(vertices1, vertices2): + vs1 = sorted(list(vertices1), key=lambda v: tuple(v.location)) + vs2 = sorted(list(vertices2), key=lambda v: tuple(v.location)) + + for v1, v2 in zip(vs1, vs2): + if not v1.id == v2.id: + print(f"{vs1}, {vs2}") + return False + if not all(np.isclose(v1.location, v2.location)): + print(f"{vs1}, {vs2}") + return False + return True + + +def test_points_equal(): + points1 = [Node(id=1, location=np.array([0, 1]))] + points2 = [Node(id=1, location=np.array([0, 1]))] + assert points_equal(points1, points2) + + points1 = [Node(id=2, location=np.array([1, 2]))] + points2 = [Node(id=2, location=np.array([2, 1]))] + assert not points_equal(points1, points2) + + +def test_shift_points1(): + data = [Node(id=1, location=np.array([0, 1]))] + spec = GraphSpec(Roi(offset=(0, 0), shape=(5, 5))) + points = Graph(data, [], spec) + request_roi = Roi(offset=(0, 1), shape=(5, 3)) + shift_array = np.array([[0, -1], [0, -1], [0, 0], [0, 0], [0, 1]], dtype=int) + lcm_voxel_size = Coordinate((1, 1)) + + shifted_points = Graph([], [], GraphSpec(request_roi)) + result = ShiftAugment.shift_points( + points, + request_roi, + shift_array, + shift_axis=0, + lcm_voxel_size=lcm_voxel_size, + ) + # print(result) + assert points_equal(result.nodes, shifted_points.nodes) + assert result.spec == GraphSpec(request_roi) + + +def test_shift_points2(): + data = [Node(id=1, location=np.array([0, 1]))] + spec = GraphSpec(Roi(offset=(0, 0), shape=(5, 5))) + points = Graph(data, [], spec) + request_roi = Roi(offset=(0, 1), shape=(5, 3)) + shift_array = np.array([[0, 0], [0, -1], [0, 0], [0, 0], [0, 1]], dtype=int) + lcm_voxel_size = Coordinate((1, 1)) + + result = ShiftAugment.shift_points( + points, + request_roi, + shift_array, + shift_axis=0, + lcm_voxel_size=lcm_voxel_size, + ) + # print("test 2", result.data, data) + assert points_equal(result.nodes, data) + assert result.spec == GraphSpec(request_roi) + + +def test_shift_points3(): + data = [Node(id=1, location=np.array([0, 1]))] + spec = GraphSpec(Roi(offset=(0, 0), shape=(5, 5))) + points = Graph(data, [], spec) + request_roi = Roi(offset=(0, 1), shape=(5, 3)) + shift_array = np.array([[0, 1], [0, -1], [0, 0], [0, 0], [0, 1]], dtype=int) + lcm_voxel_size = Coordinate((1, 1)) + + shifted_points = Graph( + [Node(id=1, location=np.array([0, 2]))], [], GraphSpec(request_roi) + ) + result = ShiftAugment.shift_points( + points, + request_roi, + shift_array, + shift_axis=0, + lcm_voxel_size=lcm_voxel_size, + ) + # print("test 3", result.data, shifted_points.data) + assert points_equal(result.nodes, shifted_points.nodes) + assert result.spec == GraphSpec(request_roi) + + +def test_shift_points4(): + data = [ + Node(id=0, location=np.array([1, 0])), + Node(id=1, location=np.array([1, 1])), + Node(id=2, location=np.array([1, 2])), + Node(id=3, location=np.array([1, 3])), + Node(id=4, location=np.array([1, 4])), + ] + spec = GraphSpec(Roi(offset=(0, 0), shape=(5, 5))) + points = Graph(data, [], spec) + request_roi = Roi(offset=(1, 0), shape=(3, 5)) + shift_array = np.array([[1, 0], [-1, 0], [0, 0], [-1, 0], [1, 0]], dtype=int) + + lcm_voxel_size = Coordinate((1, 1)) + shifted_data = [ + Node(id=0, location=np.array([2, 0])), + Node(id=2, location=np.array([1, 2])), + Node(id=4, location=np.array([2, 4])), + ] + result = ShiftAugment.shift_points( + points, + request_roi, + shift_array, + shift_axis=1, + lcm_voxel_size=lcm_voxel_size, + ) + # print("test 4", result.data, shifted_data) + assert points_equal(result.nodes, shifted_data) + assert result.spec == GraphSpec(request_roi) + + +def test_shift_points5(): + data = [ + Node(id=0, location=np.array([3, 0])), + Node(id=1, location=np.array([3, 2])), + Node(id=2, location=np.array([3, 4])), + Node(id=3, location=np.array([3, 6])), + Node(id=4, location=np.array([3, 8])), + ] + spec = GraphSpec(Roi(offset=(0, 0), shape=(15, 10))) + points = Graph(data, [], spec) + request_roi = Roi(offset=(3, 0), shape=(9, 10)) + shift_array = np.array([[3, 0], [-3, 0], [0, 0], [-3, 0], [3, 0]], dtype=int) + + lcm_voxel_size = Coordinate((3, 2)) + shifted_data = [ + Node(id=0, location=np.array([6, 0])), + Node(id=2, location=np.array([3, 4])), + Node(id=4, location=np.array([6, 8])), + ] + result = ShiftAugment.shift_points( + points, + request_roi, + shift_array, + shift_axis=1, + lcm_voxel_size=lcm_voxel_size, + ) + # print("test 4", result.data, shifted_data) + assert points_equal(result.nodes, shifted_data) + assert result.spec == GraphSpec(request_roi) + + +####################### +# get_sub_shift_array # +####################### + + +def test_get_sub_shift_array1(): + total_roi = Roi(offset=(0, 0), shape=(6, 6)) + item_roi = Roi(offset=(1, 2), shape=(3, 3)) + shift_array = np.arange(12).reshape(6, 2).astype(int) + shift_axis = 1 + lcm_voxel_size = Coordinate((1, 1)) + + sub_shift_array = np.array([[4, 5], [6, 7], [8, 9]], dtype=int) + result = ShiftAugment.get_sub_shift_array( + total_roi, item_roi, shift_array, shift_axis, lcm_voxel_size + ) + # print(result) + assert np.array_equal(result, sub_shift_array) + + +def test_get_sub_shift_array2(): + total_roi = Roi(offset=(0, 0), shape=(6, 6)) + item_roi = Roi(offset=(1, 2), shape=(3, 3)) + shift_array = np.arange(12).reshape(6, 2).astype(int) + shift_axis = 0 + lcm_voxel_size = Coordinate((1, 1)) + + sub_shift_array = np.array([[2, 3], [4, 5], [6, 7]], dtype=int) + result = ShiftAugment.get_sub_shift_array( + total_roi, item_roi, shift_array, shift_axis, lcm_voxel_size + ) + assert np.array_equal(result, sub_shift_array) + + +def test_get_sub_shift_array3(): + total_roi = Roi(offset=(0, 0), shape=(18, 12)) + item_roi = Roi(offset=(3, 4), shape=(9, 6)) + shift_array = np.arange(12).reshape(6, 2).astype(int) + shift_axis = 0 + lcm_voxel_size = Coordinate((3, 2)) + + sub_shift_array = np.array([[2, 3], [4, 5], [6, 7]], dtype=int) + result = ShiftAugment.get_sub_shift_array( + total_roi, item_roi, shift_array, shift_axis, lcm_voxel_size + ) + # print(result) + assert np.array_equal(result, sub_shift_array) + + +################################ +# construct_global_shift_array # +################################ + + +def test_construct_global_shift_array_static(): + shift_axis_len = 5 + shift_sigmas = (0.0, 1.0) + prob_slip = 0 + prob_shift = 0 + lcm_voxel_size = Coordinate((1, 1)) + + shift_array = np.zeros(shape=(shift_axis_len, len(shift_sigmas)), dtype=int) + result = ShiftAugment.construct_global_shift_array( + shift_axis_len, shift_sigmas, prob_shift, prob_slip, lcm_voxel_size + ) + assert np.array_equal(result, shift_array) + + +def test_construct_global_shift_array1(): + shift_axis_len = 5 + shift_sigmas = (0.0, 1.0) + prob_slip = 1 + prob_shift = 0 + lcm_voxel_size = Coordinate((1, 1)) + + shift_array = np.array([[0, 0], [0, -1], [0, 1], [0, 0], [0, 1]], dtype=int) + result = ShiftAugment.construct_global_shift_array( + shift_axis_len, shift_sigmas, prob_slip, prob_shift, lcm_voxel_size + ) + # print(result) + assert len(result) == shift_axis_len + for position_shift in result: + assert position_shift[0] == 0 + assert np.array_equal(shift_array, result) + + +def test_construct_global_shift_array2(): + shift_axis_len = 5 + shift_sigmas = (0.0, 1.0) + prob_slip = 0 + prob_shift = 1 + lcm_voxel_size = Coordinate((1, 1)) + + shift_array = np.array([[0, 0], [0, -1], [0, 0], [0, 0], [0, 1]], dtype=int) + result = ShiftAugment.construct_global_shift_array( + shift_axis_len, shift_sigmas, prob_slip, prob_shift, lcm_voxel_size + ) + assert len(result) == shift_axis_len + for position_shift in result: + assert position_shift[0] == 0 + assert np.array_equal(shift_array, result) + + +def test_construct_global_shift_array3(): + shift_axis_len = 5 + shift_sigmas = (0.0, 4.0) + prob_slip = 0 + prob_shift = 1 + lcm_voxel_size = Coordinate((1, 3)) + + shift_array = np.array([[0, 3], [0, 0], [0, 6], [0, 6], [0, 12]], dtype=int) + result = ShiftAugment.construct_global_shift_array( + shift_axis_len, shift_sigmas, prob_slip, prob_shift, lcm_voxel_size + ) + # print(result) + assert len(result) == shift_axis_len + for position_shift in result: + assert position_shift[0] == 0 + assert np.array_equal(shift_array, result) + + +######################## +# compute_upstream_roi # +######################## + + +def test_compute_upstream_roi_static(): + request_roi = Roi(offset=(0, 0), shape=(5, 10)) + sub_shift_array = np.array([[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]], dtype=int) + + upstream_roi = Roi(offset=(0, 0), shape=(5, 10)) + result = ShiftAugment.compute_upstream_roi(request_roi, sub_shift_array) + assert upstream_roi == result + + +def test_compute_upstream_roi1(): + request_roi = Roi(offset=(0, 0), shape=(5, 10)) + sub_shift_array = np.array([[0, 0], [0, -1], [0, 0], [0, 0], [0, 1]], dtype=int) + + upstream_roi = Roi(offset=(0, -1), shape=(5, 12)) + result = ShiftAugment.compute_upstream_roi(request_roi, sub_shift_array) + assert upstream_roi == result + + +def test_compute_upstream_roi2(): + request_roi = Roi(offset=(0, 0), shape=(5, 10)) + sub_shift_array = np.array([[2, 0], [-1, 0], [5, 0], [-2, 0], [0, 0]], dtype=int) + + upstream_roi = Roi(offset=(-5, 0), shape=(12, 10)) + result = ShiftAugment.compute_upstream_roi(request_roi, sub_shift_array) + assert upstream_roi == result diff --git a/tests/cases/simple_augment.py b/tests/cases/simple_augment.py index 9243b5d2..5e976859 100644 --- a/tests/cases/simple_augment.py +++ b/tests/cases/simple_augment.py @@ -1,22 +1,22 @@ +import numpy as np + from gunpowder import ( - BatchRequest, Array, ArrayKey, ArraySpec, + BatchRequest, + Coordinate, Graph, GraphKey, GraphSpec, + MergeProvider, Node, - Coordinate, Roi, SimpleAugment, - MergeProvider, build, ) -import numpy as np - -from .helper_sources import GraphSource, ArraySource +from .helper_sources import ArraySource, GraphSource def test_mirror(): diff --git a/tests/cases/snapshot.py b/tests/cases/snapshot.py index 928076ea..13c0fcbc 100644 --- a/tests/cases/snapshot.py +++ b/tests/cases/snapshot.py @@ -1,26 +1,24 @@ +from pathlib import Path + +import h5py +import numpy as np + from gunpowder import ( - GraphKey, - GraphKeys, - GraphSpec, - Graph, + Array, ArrayKey, ArraySpec, - Array, - Snapshot, + Batch, BatchProvider, BatchRequest, - Batch, Coordinate, + Graph, + GraphKey, + GraphKeys, + GraphSpec, Roi, + Snapshot, build, ) -import numpy as np - -import unittest -import tempfile -import shutil -from pathlib import Path -import h5py class ExampleSource(BatchProvider): @@ -50,60 +48,53 @@ def provide(self, request): return outputs -class TestSnapshot(unittest.TestCase): - def setUp(self): - self.test_dir = tempfile.mkdtemp() - - def tearDown(self): - shutil.rmtree(self.test_dir) - - def test_3d(self): - test_graph = GraphKey("TEST_GRAPH") - graph_spec = GraphSpec(roi=Roi((0, 0, 0), (5, 5, 5))) - test_array = ArrayKey("TEST_ARRAY") - array_spec = ArraySpec( - roi=Roi((0, 0, 0), (5, 5, 5)), voxel_size=Coordinate((1, 1, 1)) - ) - test_array2 = ArrayKey("TEST_ARRAY2") - array2_spec = ArraySpec( - roi=Roi((0, 0, 0), (5, 5, 5)), voxel_size=Coordinate((1, 1, 1)) - ) - - snapshot_request = BatchRequest() - snapshot_request.add(test_graph, Coordinate((5, 5, 5))) - - pipeline = ExampleSource( - [test_graph, test_array, test_array2], [graph_spec, array_spec, array2_spec] - ) + Snapshot( - { - test_graph: "graphs/graph", - test_array: "volumes/array", - test_array2: "volumes/array2", - }, - output_dir=str(self.test_dir), - every=2, - additional_request=snapshot_request, - output_filename="snapshot.hdf", - ) - - snapshot_file_path = Path(self.test_dir, "snapshot.hdf") - - with build(pipeline): - request = BatchRequest() - roi = Roi((0, 0, 0), (5, 5, 5)) - - request[test_array] = ArraySpec(roi=roi) - request[test_array2] = ArraySpec(roi=roi) - - pipeline.request_batch(request) - - assert snapshot_file_path.exists() - f = h5py.File(snapshot_file_path, "r+") - assert f["volumes/array"] is not None - assert f["graphs/graph-ids"] is not None - - snapshot_file_path.unlink() - - pipeline.request_batch(request) - - assert not snapshot_file_path.exists() +def test_3d(tmpdir): + test_graph = GraphKey("TEST_GRAPH") + graph_spec = GraphSpec(roi=Roi((0, 0, 0), (5, 5, 5))) + test_array = ArrayKey("TEST_ARRAY") + array_spec = ArraySpec( + roi=Roi((0, 0, 0), (5, 5, 5)), voxel_size=Coordinate((1, 1, 1)) + ) + test_array2 = ArrayKey("TEST_ARRAY2") + array2_spec = ArraySpec( + roi=Roi((0, 0, 0), (5, 5, 5)), voxel_size=Coordinate((1, 1, 1)) + ) + + snapshot_request = BatchRequest() + snapshot_request.add(test_graph, Coordinate((5, 5, 5))) + + pipeline = ExampleSource( + [test_graph, test_array, test_array2], [graph_spec, array_spec, array2_spec] + ) + Snapshot( + { + test_graph: "graphs/graph", + test_array: "volumes/array", + test_array2: "volumes/array2", + }, + output_dir=tmpdir, + every=2, + additional_request=snapshot_request, + output_filename="snapshot.hdf", + ) + + snapshot_file_path = Path(tmpdir, "snapshot.hdf") + + with build(pipeline): + request = BatchRequest() + roi = Roi((0, 0, 0), (5, 5, 5)) + + request[test_array] = ArraySpec(roi=roi) + request[test_array2] = ArraySpec(roi=roi) + + pipeline.request_batch(request) + + assert snapshot_file_path.exists() + f = h5py.File(snapshot_file_path, "r+") + assert f["volumes/array"] is not None + assert f["graphs/graph-ids"] is not None + + snapshot_file_path.unlink() + + pipeline.request_batch(request) + + assert not snapshot_file_path.exists() diff --git a/tests/cases/specified_location.py b/tests/cases/specified_location.py index fb5ea04a..00d7f526 100644 --- a/tests/cases/specified_location.py +++ b/tests/cases/specified_location.py @@ -1,19 +1,19 @@ # from .provider_test import ProviderTest, ExampleSource +import numpy as np + from gunpowder import ( - BatchProvider, + Array, + ArrayKey, ArrayKeys, ArraySpec, - Roi, Batch, + BatchProvider, + BatchRequest, Coordinate, + Roi, SpecifiedLocation, build, - BatchRequest, - Array, - ArrayKey, ) -import numpy as np -import unittest class ExampleSourceSpecifiedLocation(BatchProvider): @@ -42,70 +42,63 @@ def provide(self, request): return batch -class TestSpecifiedLocation(unittest.TestCase): - def setUp(self): - ArrayKey("RAW") +def test_simple(): + ArrayKey("RAW") + locations = [[0, 0, 0], [100, 100, 100], [91, 20, 20], [42, 24, 57]] - def test_simple(self): - locations = [[0, 0, 0], [100, 100, 100], [91, 20, 20], [42, 24, 57]] + pipeline = ExampleSourceSpecifiedLocation( + roi=Roi((0, 0, 0), (100, 100, 100)), voxel_size=(1, 1, 1) + ) + SpecifiedLocation( + locations, choose_randomly=False, extra_data=None, jitter=None + ) - pipeline = ExampleSourceSpecifiedLocation( - roi=Roi((0, 0, 0), (100, 100, 100)), voxel_size=(1, 1, 1) - ) + SpecifiedLocation( - locations, choose_randomly=False, extra_data=None, jitter=None + with build(pipeline): + batch = pipeline.request_batch( + BatchRequest({ArrayKeys.RAW: ArraySpec(roi=Roi((0, 0, 0), (20, 20, 20)))}) ) + # first three locations are skipped + # fourth should start at [32, 14, 47] of self.data + assert batch.arrays[ArrayKeys.RAW].data[0, 0, 0] == 321447 + + +def test_voxel_size(): + ArrayKey("RAW") + locations = [[0, 0, 0], [91, 20, 20], [42, 24, 57]] - with build(pipeline): - batch = pipeline.request_batch( - BatchRequest( - {ArrayKeys.RAW: ArraySpec(roi=Roi((0, 0, 0), (20, 20, 20)))} - ) - ) - # first three locations are skipped - # fourth should start at [32, 14, 47] of self.data - self.assertEqual(batch.arrays[ArrayKeys.RAW].data[0, 0, 0], 321447) - - def test_voxel_size(self): - locations = [[0, 0, 0], [91, 20, 20], [42, 24, 57]] - - pipeline = ExampleSourceSpecifiedLocation( - roi=Roi((0, 0, 0), (100, 100, 100)), voxel_size=(5, 2, 2) - ) + SpecifiedLocation( - locations, choose_randomly=False, extra_data=None, jitter=None + pipeline = ExampleSourceSpecifiedLocation( + roi=Roi((0, 0, 0), (100, 100, 100)), voxel_size=(5, 2, 2) + ) + SpecifiedLocation( + locations, choose_randomly=False, extra_data=None, jitter=None + ) + + with build(pipeline): + batch = pipeline.request_batch( + BatchRequest({ArrayKeys.RAW: ArraySpec(roi=Roi((0, 0, 0), (20, 20, 20)))}) ) + # first locations is skipped + # second should start at [80/5, 10/2, 10/2] = [16, 5, 5] + assert batch.arrays[ArrayKeys.RAW].data[0, 0, 0] == 40255 - with build(pipeline): - batch = pipeline.request_batch( - BatchRequest( - {ArrayKeys.RAW: ArraySpec(roi=Roi((0, 0, 0), (20, 20, 20)))} - ) - ) - # first locations is skipped - # second should start at [80/5, 10/2, 10/2] = [16, 5, 5] - self.assertEqual(batch.arrays[ArrayKeys.RAW].data[0, 0, 0], 40255) - - batch = pipeline.request_batch( - BatchRequest( - {ArrayKeys.RAW: ArraySpec(roi=Roi((0, 0, 0), (20, 20, 20)))} - ) - ) - # third should start at [30/5, 14/2, 48/2] = [6, 7, 23] - self.assertEqual(batch.arrays[ArrayKeys.RAW].data[0, 0, 0], 15374) - - def test_jitter_and_random(self): - locations = [[0, 0, 0], [91, 20, 20], [42, 24, 57]] - - pipeline = ExampleSourceSpecifiedLocation( - roi=Roi((0, 0, 0), (100, 100, 100)), voxel_size=(5, 2, 2) - ) + SpecifiedLocation( - locations, choose_randomly=True, extra_data=None, jitter=(5, 5, 5) + batch = pipeline.request_batch( + BatchRequest({ArrayKeys.RAW: ArraySpec(roi=Roi((0, 0, 0), (20, 20, 20)))}) ) + # third should start at [30/5, 14/2, 48/2] = [6, 7, 23] + assert batch.arrays[ArrayKeys.RAW].data[0, 0, 0] == 15374 + + +def test_jitter_and_random(): + ArrayKey("RAW") + locations = [[0, 0, 0], [91, 20, 20], [42, 24, 57]] - with build(pipeline): - batch = pipeline.request_batch( - BatchRequest( - {ArrayKeys.RAW: ArraySpec(roi=Roi((0, 0, 0), (20, 20, 20)))} - ) - ) - # Unclear what result should be, so no errors means passing - self.assertTrue(batch.arrays[ArrayKeys.RAW].data[0, 0, 0] > 0) + pipeline = ExampleSourceSpecifiedLocation( + roi=Roi((0, 0, 0), (100, 100, 100)), voxel_size=(5, 2, 2) + ) + SpecifiedLocation( + locations, choose_randomly=True, extra_data=None, jitter=(5, 5, 5) + ) + + with build(pipeline): + batch = pipeline.request_batch( + BatchRequest({ArrayKeys.RAW: ArraySpec(roi=Roi((0, 0, 0), (20, 20, 20)))}) + ) + # Unclear what result should be, so no errors means passing + assert batch.arrays[ArrayKeys.RAW].data[0, 0, 0] > 0 diff --git a/tests/cases/squeeze.py b/tests/cases/squeeze.py index 1c5a07e1..2a8add72 100644 --- a/tests/cases/squeeze.py +++ b/tests/cases/squeeze.py @@ -1,17 +1,18 @@ import copy + import numpy as np -import gunpowder as gp +import pytest -from .provider_test import ProviderTest +import gunpowder as gp class ExampleSourceSqueeze(gp.BatchProvider): - def __init__(self, voxel_size): + def __init__(self, voxel_size, raw_key, labels_key): self.voxel_size = gp.Coordinate(voxel_size) self.roi = gp.Roi((0, 0, 0), (10, 10, 10)) * self.voxel_size - self.raw = gp.ArrayKey("RAW") - self.labels = gp.ArrayKey("LABELS") + self.raw = raw_key + self.labels = labels_key self.array_spec_raw = gp.ArraySpec( roi=self.roi, voxel_size=self.voxel_size, dtype="uint8", interpolatable=True @@ -60,44 +61,44 @@ def provide(self, request): return outputs -class TestSqueeze(ProviderTest): - def test_squeeze(self): - raw = gp.ArrayKey("RAW") - labels = gp.ArrayKey("LABELS") +def test_squeeze(): + raw = gp.ArrayKey("RAW") + labels = gp.ArrayKey("LABELS") - voxel_size = gp.Coordinate((50, 5, 5)) - input_voxels = gp.Coordinate((5, 5, 5)) - input_size = input_voxels * voxel_size + voxel_size = gp.Coordinate((50, 5, 5)) + input_voxels = gp.Coordinate((5, 5, 5)) + input_size = input_voxels * voxel_size - request = gp.BatchRequest() - request.add(raw, input_size) - request.add(labels, input_size) + request = gp.BatchRequest() + request.add(raw, input_size) + request.add(labels, input_size) - pipeline = ( - ExampleSourceSqueeze(voxel_size) - + gp.Squeeze([raw], axis=1) - + gp.Squeeze([raw, labels]) - ) + pipeline = ( + ExampleSourceSqueeze(voxel_size, raw, labels) + + gp.Squeeze([raw], axis=1) + + gp.Squeeze([raw, labels]) + ) + + with gp.build(pipeline) as p: + batch = p.request_batch(request) + assert batch[raw].data.shape == input_voxels + assert batch[labels].data.shape == input_voxels - with gp.build(pipeline) as p: - batch = p.request_batch(request) - assert batch[raw].data.shape == input_voxels - assert batch[labels].data.shape == input_voxels - def test_squeeze_not_possible(self): - raw = gp.ArrayKey("RAW") - labels = gp.ArrayKey("LABELS") +def test_squeeze_not_possible(): + raw = gp.ArrayKey("RAW") + labels = gp.ArrayKey("LABELS") - voxel_size = gp.Coordinate((50, 5, 5)) - input_voxels = gp.Coordinate((5, 5, 5)) - input_size = input_voxels * voxel_size + voxel_size = gp.Coordinate((50, 5, 5)) + input_voxels = gp.Coordinate((5, 5, 5)) + input_size = input_voxels * voxel_size - request = gp.BatchRequest() - request.add(raw, input_size) - request.add(labels, input_size) + request = gp.BatchRequest() + request.add(raw, input_size) + request.add(labels, input_size) - pipeline = ExampleSourceSqueeze(voxel_size) + gp.Squeeze([raw], axis=2) + pipeline = ExampleSourceSqueeze(voxel_size, raw, labels) + gp.Squeeze([raw], axis=2) - with self.assertRaises(gp.PipelineRequestError): - with gp.build(pipeline) as p: - batch = p.request_batch(request) + with pytest.raises(gp.PipelineRequestError): + with gp.build(pipeline) as p: + p.request_batch(request) diff --git a/tests/cases/tensorflow_train.py b/tests/cases/tensorflow_train.py index 079be0d3..9fc8f9cb 100644 --- a/tests/cases/tensorflow_train.py +++ b/tests/cases/tensorflow_train.py @@ -1,23 +1,27 @@ -from .provider_test import ProviderTest +import multiprocessing + +import numpy as np +import pytest + from gunpowder import ( - ArraySpec, - ArrayKeys, - ArrayKey, Array, - Roi, - BatchProvider, + ArrayKey, + ArraySpec, Batch, + BatchProvider, BatchRequest, + Roi, build, ) -from gunpowder.ext import tensorflow, NoSuchModule +from gunpowder.ext import NoSuchModule, tensorflow from gunpowder.tensorflow import Train -import multiprocessing -import numpy as np -from unittest import skipIf class ExampleTensorflowTrainSource(BatchProvider): + def __init__(self, a_key, b_key): + self.a_key = a_key + self.b_key = b_key + def setup(self): spec = ArraySpec( roi=Roi((0, 0), (2, 2)), @@ -25,149 +29,151 @@ def setup(self): interpolatable=True, voxel_size=(1, 1), ) - self.provides(ArrayKeys.A, spec) - self.provides(ArrayKeys.B, spec) + self.provides(self.a_key, spec) + self.provides(self.b_key, spec) def provide(self, request): batch = Batch() - spec = self.spec[ArrayKeys.A] - spec.roi = request[ArrayKeys.A].roi + spec = self.spec[self.a_key] + spec.roi = request[self.a_key].roi - batch.arrays[ArrayKeys.A] = Array( + batch.arrays[self.a_key] = Array( np.array([[0, 1], [2, 3]], dtype=np.float32), spec ) - spec = self.spec[ArrayKeys.B] - spec.roi = request[ArrayKeys.B].roi + spec = self.spec[self.b_key] + spec.roi = request[self.b_key].roi - batch.arrays[ArrayKeys.B] = Array( + batch.arrays[self.b_key] = Array( np.array([[0, 1], [2, 3]], dtype=np.float32), spec ) return batch -@skipIf(isinstance(tensorflow, NoSuchModule), "tensorflow is not installed") -class TestTensorflowTrain(ProviderTest): - def create_meta_graph(self, meta_base): - """ - - :param meta_base: Base name (no extension) for meta graph path - :return: - """ - - def mknet(): - import tensorflow as tf +def create_meta_graph(meta_base): + """ - # create a tf graph - a = tf.placeholder(tf.float32, shape=(2, 2)) - b = tf.placeholder(tf.float32, shape=(2, 2)) - v = tf.Variable(1, dtype=tf.float32) - c = a * b * v + :param meta_base: Base name (no extension) for meta graph path + :return: + """ - # dummy "loss" - loss = tf.norm(c) + def mknet(): + import tensorflow as tf - # dummy optimizer - opt = tf.train.AdamOptimizer() - optimizer = opt.minimize(loss) + # create a tf graph + a = tf.placeholder(tf.float32, shape=(2, 2)) + b = tf.placeholder(tf.float32, shape=(2, 2)) + v = tf.Variable(1, dtype=tf.float32) + c = a * b * v - tf.train.export_meta_graph(filename=meta_base + ".meta") + # dummy "loss" + loss = tf.norm(c) - with open(meta_base + ".names", "w") as f: - for x in [a, b, c, optimizer, loss]: - f.write(x.name + "\n") + # dummy optimizer + opt = tf.train.AdamOptimizer() + optimizer = opt.minimize(loss) - mknet_proc = multiprocessing.Process(target=mknet) - mknet_proc.start() - mknet_proc.join() + tf.train.export_meta_graph(filename=meta_base + ".meta") - with open(meta_base + ".names") as f: - names = [line.strip("\n") for line in f] + with open(meta_base + ".names", "w") as f: + for x in [a, b, c, optimizer, loss]: + f.write(x.name + "\n") - return names + mknet_proc = multiprocessing.Process(target=mknet) + mknet_proc.start() + mknet_proc.join() - def test_output(self): - meta_base = self.path_to("tf_graph") + with open(meta_base + ".names") as f: + names = [line.strip("\n") for line in f] - ArrayKey("A") - ArrayKey("B") - ArrayKey("C") - ArrayKey("GRADIENT_A") + return names - # create model meta graph file and get input/output names - (a, b, c, optimizer, loss) = self.create_meta_graph(meta_base) - source = ExampleTensorflowTrainSource() - train = Train( - meta_base, - optimizer=optimizer, - loss=loss, - inputs={a: ArrayKeys.A, b: ArrayKeys.B}, - outputs={c: ArrayKeys.C}, - gradients={a: ArrayKeys.GRADIENT_A}, - save_every=100, - ) - pipeline = source + train - - request = BatchRequest( - { - ArrayKeys.A: ArraySpec(roi=Roi((0, 0), (2, 2))), - ArrayKeys.B: ArraySpec(roi=Roi((0, 0), (2, 2))), - ArrayKeys.C: ArraySpec(roi=Roi((0, 0), (2, 2))), - ArrayKeys.GRADIENT_A: ArraySpec(roi=Roi((0, 0), (2, 2))), - } - ) - - # train for a couple of iterations - with build(pipeline): +@pytest.mark.skipif( + isinstance(tensorflow, NoSuchModule), reason="tensorflow is not installed" +) +def test_output(tmpdir): + meta_base = tmpdir / "tf_graph" + + a_key = ArrayKey("A") + b_key = ArrayKey("B") + c_key = ArrayKey("C") + a_grad_key = ArrayKey("GRADIENT_A") + + # create model meta graph file and get input/output names + (a, b, c, optimizer, loss) = create_meta_graph(meta_base) + + source = ExampleTensorflowTrainSource() + train = Train( + meta_base, + optimizer=optimizer, + loss=loss, + inputs={a: a_key, b: b_key}, + outputs={c: c_key}, + gradients={a: a_grad_key}, + save_every=100, + ) + pipeline = source + train + + request = BatchRequest( + { + a_key: ArraySpec(roi=Roi((0, 0), (2, 2))), + b_key: ArraySpec(roi=Roi((0, 0), (2, 2))), + c_key: ArraySpec(roi=Roi((0, 0), (2, 2))), + a_grad_key: ArraySpec(roi=Roi((0, 0), (2, 2))), + } + ) + + # train for a couple of iterations + with build(pipeline): + batch = pipeline.request_batch(request) + + assert abs(batch.loss - 9.8994951) < 1e-3 + + gradient_a = batch.arrays[a_grad_key].data + assert gradient_a[0, 0] < gradient_a[0, 1] + assert gradient_a[0, 1] < gradient_a[1, 0] + assert gradient_a[1, 0] < gradient_a[1, 1] + + for i in range(200 - 1): + loss1 = batch.loss batch = pipeline.request_batch(request) + loss2 = batch.loss + assert loss2 < loss1 - self.assertAlmostEqual(batch.loss, 9.8994951) - - gradient_a = batch.arrays[ArrayKeys.GRADIENT_A].data - self.assertTrue(gradient_a[0, 0] < gradient_a[0, 1]) - self.assertTrue(gradient_a[0, 1] < gradient_a[1, 0]) - self.assertTrue(gradient_a[1, 0] < gradient_a[1, 1]) - - for i in range(200 - 1): - loss1 = batch.loss - batch = pipeline.request_batch(request) - loss2 = batch.loss - self.assertLess(loss2, loss1) - - # resume training - with build(pipeline): - for i in range(100): - loss1 = batch.loss - batch = pipeline.request_batch(request) - loss2 = batch.loss - self.assertLess(loss2, loss1) - - # predict - # source = ExampleTensorflowTrainSource() - # predict = Predict( - # meta_base + '_checkpoint_300', - # inputs={a: ArrayKeys.A, b: ArrayKeys.B}, - # outputs={c: ArrayKeys.C}, - # max_shared_memory=1024*1024) - # pipeline = source + predict - - # request = BatchRequest({ - # ArrayKeys.A: ArraySpec(roi=Roi((0, 0), (2, 2))), - # ArrayKeys.B: ArraySpec(roi=Roi((0, 0), (2, 2))), - # ArrayKeys.C: ArraySpec(roi=Roi((0, 0), (2, 2))), - # }) - - # with build(pipeline): - - # prev_c = None - - # for i in range(100): - # batch = pipeline.request_batch(request) - # c = batch.arrays[ArrayKeys.C].data - - # if prev_c is not None: - # self.assertTrue(np.equal(c, prev_c)) - # prev_c = c + # resume training + with build(pipeline): + for i in range(100): + loss1 = batch.loss + batch = pipeline.request_batch(request) + loss2 = batch.loss + assert loss2 < loss1 + + # predict + # source = ExampleTensorflowTrainSource() + # predict = Predict( + # meta_base + '_checkpoint_300', + # inputs={a: a_key, b: b_key}, + # outputs={c: c_key}, + # max_shared_memory=1024*1024) + # pipeline = source + predict + + # request = BatchRequest({ + # a_key: ArraySpec(roi=Roi((0, 0), (2, 2))), + # b_key: ArraySpec(roi=Roi((0, 0), (2, 2))), + # c_key: ArraySpec(roi=Roi((0, 0), (2, 2))), + # }) + + # with build(pipeline): + + # prev_c = None + + # for i in range(100): + # batch = pipeline.request_batch(request) + # c = batch.arrays[c_key].data + + # if prev_c is not None: + # assert (np.equal(c, prev_c)) + # prev_c = c diff --git a/tests/cases/torch_train.py b/tests/cases/torch_train.py index 0196c67d..8fb9e8ec 100644 --- a/tests/cases/torch_train.py +++ b/tests/cases/torch_train.py @@ -1,22 +1,24 @@ -from .helper_sources import ArraySource +import logging +from unittest import skipIf + +import numpy as np +import pytest + from gunpowder import ( - BatchRequest, + Array, + ArrayKey, ArraySpec, + BatchRequest, + MergeProvider, + PreCache, Roi, - ArrayKey, - Array, Scan, - PreCache, - MergeProvider, build, ) -from gunpowder.ext import torch, NoSuchModule -from gunpowder.torch import Train, Predict -from unittest import skipIf -import numpy as np -import pytest +from gunpowder.ext import NoSuchModule, torch +from gunpowder.torch import Predict, Train -import logging +from .helper_sources import ArraySource TORCH_AVAILABLE = isinstance(torch, NoSuchModule) diff --git a/tests/cases/unsqueeze.py b/tests/cases/unsqueeze.py index 413d1a28..cc3f00ab 100644 --- a/tests/cases/unsqueeze.py +++ b/tests/cases/unsqueeze.py @@ -1,12 +1,16 @@ import copy + import numpy as np -import gunpowder as gp +import pytest -from .provider_test import ProviderTest +import gunpowder as gp class ExampleSourceUnsqueeze(gp.BatchProvider): - def __init__(self, voxel_size): + def __init__(self, voxel_size, raw_key, labels_key): + self.raw_key = raw_key + self.labels_key = labels_key + self.voxel_size = gp.Coordinate(voxel_size) self.roi = gp.Roi((0, 0, 0), (10, 10, 10)) * self.voxel_size @@ -53,44 +57,46 @@ def provide(self, request): return outputs -class TestUnsqueeze(ProviderTest): - def test_unsqueeze(self): - raw = gp.ArrayKey("RAW") - labels = gp.ArrayKey("LABELS") +def test_unsqueeze(): + raw = gp.ArrayKey("RAW") + labels = gp.ArrayKey("LABELS") - voxel_size = gp.Coordinate((50, 5, 5)) - input_voxels = gp.Coordinate((10, 10, 10)) - input_size = input_voxels * voxel_size + voxel_size = gp.Coordinate((50, 5, 5)) + input_voxels = gp.Coordinate((10, 10, 10)) + input_size = input_voxels * voxel_size - request = gp.BatchRequest() - request.add(raw, input_size) - request.add(labels, input_size) + request = gp.BatchRequest() + request.add(raw, input_size) + request.add(labels, input_size) - pipeline = ( - ExampleSourceUnsqueeze(voxel_size) - + gp.Unsqueeze([raw, labels]) - + gp.Unsqueeze([raw], axis=1) - ) + pipeline = ( + ExampleSourceUnsqueeze(voxel_size, raw, labels) + + gp.Unsqueeze([raw, labels]) + + gp.Unsqueeze([raw], axis=1) + ) - with gp.build(pipeline) as p: - batch = p.request_batch(request) - assert batch[raw].data.shape == (1,) + (1,) + input_voxels - assert batch[labels].data.shape == (1,) + input_voxels + with gp.build(pipeline) as p: + batch = p.request_batch(request) + assert batch[raw].data.shape == (1,) + (1,) + input_voxels + assert batch[labels].data.shape == (1,) + input_voxels - def test_unsqueeze_not_possible(self): - raw = gp.ArrayKey("RAW") - labels = gp.ArrayKey("LABELS") - voxel_size = gp.Coordinate((50, 5, 5)) - input_voxels = gp.Coordinate((5, 5, 5)) - input_size = input_voxels * voxel_size +def test_unsqueeze_not_possible(): + raw = gp.ArrayKey("RAW") + labels = gp.ArrayKey("LABELS") - request = gp.BatchRequest() - request.add(raw, input_size) - request.add(labels, input_size) + voxel_size = gp.Coordinate((50, 5, 5)) + input_voxels = gp.Coordinate((5, 5, 5)) + input_size = input_voxels * voxel_size - pipeline = ExampleSourceUnsqueeze(voxel_size) + gp.Unsqueeze([raw], axis=1) + request = gp.BatchRequest() + request.add(raw, input_size) + request.add(labels, input_size) - with self.assertRaises(gp.PipelineRequestError): - with gp.build(pipeline) as p: - batch = p.request_batch(request) + pipeline = ExampleSourceUnsqueeze(voxel_size, raw, labels) + gp.Unsqueeze( + [raw], axis=1 + ) + + with pytest.raises(gp.PipelineRequestError): + with gp.build(pipeline) as p: + p.request_batch(request) diff --git a/tests/cases/update_with.py b/tests/cases/update_with.py index 82cf4535..e8f7675e 100644 --- a/tests/cases/update_with.py +++ b/tests/cases/update_with.py @@ -1,21 +1,21 @@ import numpy as np +import pytest from gunpowder import ( - BatchProvider, - BatchFilter, Array, - ArraySpec, ArrayKey, - Graph, - GraphSpec, - GraphKey, + ArraySpec, Batch, + BatchFilter, + BatchProvider, BatchRequest, - Roi, + Graph, + GraphKey, + GraphSpec, PipelineRequestError, + Roi, build, ) -import pytest class ArrayTestSource(BatchProvider): diff --git a/tests/cases/upsample.py b/tests/cases/upsample.py index 47f08b44..c9fbc6ba 100644 --- a/tests/cases/upsample.py +++ b/tests/cases/upsample.py @@ -1,7 +1,18 @@ -from .helper_sources import ArraySource -from gunpowder import * import numpy as np +from gunpowder import ( + Array, + ArrayKey, + ArraySpec, + BatchRequest, + MergeProvider, + Roi, + UpSample, + build, +) + +from .helper_sources import ArraySource + def test_output(): raw = ArrayKey("RAW") diff --git a/tests/cases/zarr_read_write.py b/tests/cases/zarr_read_write.py index c6cdb39b..348d4a1d 100644 --- a/tests/cases/zarr_read_write.py +++ b/tests/cases/zarr_read_write.py @@ -1,10 +1,22 @@ -from .helper_sources import ArraySource +import numpy as np +import pytest -from gunpowder import * -from gunpowder.ext import zarr, NoSuchModule +from gunpowder import ( + Array, + ArrayKey, + ArraySpec, + BatchRequest, + Coordinate, + MergeProvider, + Roi, + Scan, + ZarrSource, + ZarrWrite, + build, +) +from gunpowder.ext import NoSuchModule, zarr -import pytest -import numpy as np +from .helper_sources import ArraySource @pytest.mark.skipif(isinstance(zarr, NoSuchModule), reason="zarr is not installed") diff --git a/tests/conftest.py b/tests/conftest.py index 1386c6b8..723f8535 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,15 +1,14 @@ -import pytest - import multiprocessing as mp +import pytest + # cannot parametrize unittest.TestCase. We should test both # fork and spawn but I'm not sure how to. -# @pytest.fixture(params=["fork", "spawn"], autouse=True) -# @pytest.fixture(autouse=True) -# def context(monkeypatch): -# ctx = mp.get_context("spawn") -# monkeypatch.setattr(mp, "Queue", ctx.Queue) -# monkeypatch.setattr(mp, "Process", ctx.Process) -# monkeypatch.setattr(mp, "Event", ctx.Event) -# monkeypatch.setattr(mp, "Value", ctx.Value) +@pytest.fixture(params=["fork", "spawn"], autouse=True) +def context(monkeypatch): + ctx = mp.get_context("spawn") + monkeypatch.setattr(mp, "Queue", ctx.Queue) + monkeypatch.setattr(mp, "Process", ctx.Process) + monkeypatch.setattr(mp, "Event", ctx.Event) + monkeypatch.setattr(mp, "Value", ctx.Value)