diff --git a/src/deep_neurographs/fragment_filtering.py b/src/deep_neurographs/fragment_filtering.py index ae6b931..362360b 100644 --- a/src/deep_neurographs/fragment_filtering.py +++ b/src/deep_neurographs/fragment_filtering.py @@ -10,12 +10,13 @@ """ from collections import defaultdict +from tqdm import tqdm import networkx as nx import numpy as np -from tqdm import tqdm from deep_neurographs import geometry +from deep_neurographs.utils import util QUERY_DIST = 15 @@ -46,15 +47,14 @@ def remove_curvy(fragments_graph, max_length, ratio=0.5): """ deleted_ids = set() - components = get_line_components(fragments_graph) - for nodes in tqdm(components, desc="Filter Curvy Fragments"): + for nodes in get_line_components(fragments_graph): i, j = tuple(nodes) length = fragments_graph.edges[i, j]["length"] endpoint_dist = fragments_graph.dist(i, j) if endpoint_dist / length < ratio and length < max_length: deleted_ids.add(fragments_graph.edges[i, j]["swc_id"]) delete_fragment(fragments_graph, i, j) - return len(deleted_ids) + return util.reformat_number(len(deleted_ids)) # --- Doubles Removal --- @@ -96,7 +96,7 @@ def remove_doubles(fragments_graph, max_length, node_spacing): if check_doubles_criteria(hits, n_points): delete_fragment(fragments_graph, i, j) deleted_ids.add(swc_id) - return len(deleted_ids) + return util.reformat_number(len(deleted_ids)) def compute_projections(fragments_graph, kdtree, edge): diff --git a/src/deep_neurographs/inference.py b/src/deep_neurographs/inference.py index 1e559a8..893afbb 100644 --- a/src/deep_neurographs/inference.py +++ b/src/deep_neurographs/inference.py @@ -29,9 +29,6 @@ from deep_neurographs.utils.gnn_util import toCPU from deep_neurographs.utils.graph_util import GraphLoader -BATCH_SIZE = 2000 -CONFIDENCE_THRESHOLD = 0.7 - class InferencePipeline: """ @@ -132,9 +129,9 @@ def __init__( self.model_path, self.ml_config.model_type, self.graph_config.search_radius, + accept_threshold=self.ml_config.threshold, 
anisotropy=self.ml_config.anisotropy, batch_size=self.ml_config.batch_size, - confidence_threshold=self.ml_config.threshold, device=device, multiscale=self.ml_config.multiscale, labels_path=labels_path, @@ -178,21 +175,27 @@ def run(self, fragments_pointer): # Finish self.report("Final Graph...") self.report_graph() - t, unit = util.time_writer(time() - t0) self.report(f"Total Runtime: {round(t, 4)} {unit}\n") def run_schedule(self, fragments_pointer, radius_schedule): - t0 = time() + # Initializations self.log_experiment() + self.write_metadata() + t0 = time() + + # Main self.build_graph(fragments_pointer) for round_id, radius in enumerate(radius_schedule): - self.report(f"--- Round {round_id + 1}: Radius = {radius} ---") round_id += 1 + self.report(f"--- Round {round_id}: Radius = {radius} ---") self.generate_proposals(radius) self.run_inference() self.save_results(round_id=round_id) + # Finish + self.report("Final Graph...") + self.report_graph() t, unit = util.time_writer(time() - t0) self.report(f"Total Runtime: {round(t, 4)} {unit}\n") @@ -212,7 +215,7 @@ def build_graph(self, fragments_pointer): None """ - self.report("(1) Building FragmentGraph") + self.report("Step 1: Building FragmentGraph") t0 = time() # Initialize Graph @@ -233,31 +236,27 @@ def build_graph(self, fragments_pointer): self.graph.save_labels(labels_path) self.report(f"# SWCs Saved: {n_saved}") - # Report runtime + # Report results t, unit = util.time_writer(time() - t0) self.report(f"Module Runtime: {round(t, 4)} {unit}") - - # Report graph overview self.report("\nInitial Graph...") self.report_graph() def filter_fragments(self): - # Filter curvy fragments + # Curvy fragments n_curvy = fragment_filtering.remove_curvy(self.graph, 200) - n_curvy = util.reformat_number(n_curvy) - # Filter doubles + # Double fragments if self.graph_config.remove_doubles_bool: n_doubles = fragment_filtering.remove_doubles( self.graph, 200, self.graph_config.node_spacing ) - n_doubles = 
util.reformat_number(n_doubles) self.report(f"# Double Fragments Deleted: {n_doubles}") self.report(f"# Curvy Fragments Deleted: {n_curvy}") def generate_proposals(self, radius=None): """ - Generates proposals for the fragment graph based on the specified + Generates proposals for the fragments graph based on the specified configuration. Parameters @@ -270,7 +269,7 @@ def generate_proposals(self, radius=None): """ # Initializations - self.report("(2) Generate Proposals") + self.report("Step 2: Generate Proposals") if radius is None: radius = self.graph_config.search_radius @@ -307,17 +306,21 @@ def run_inference(self): None """ - self.report("(3) Run Inference") + # Initializations + self.report("Step 3: Run Inference") + proposals = self.graph.list_proposals() + n_proposals = max(len(proposals), 1) + + # Main t0 = time() - n_proposals = max(self.graph.n_proposals(), 1) - self.graph, accepts = self.inference_engine.run( - self.graph, self.graph.list_proposals() - ) + self.graph, accepts = self.inference_engine.run(self.graph, proposals) self.accepted_proposals.extend(accepts) - self.report(f"# Accepted: {util.reformat_number(len(accepts))}") - self.report(f"% Accepted: {round(len(accepts) / n_proposals, 4)}") + # Report results t, unit = util.time_writer(time() - t0) + n_accepts = len(self.accepted_proposals) + self.report(f"# Accepted: {util.reformat_number(n_accepts)}") + self.report(f"% Accepted: {round(n_accepts / n_proposals, 4)}") self.report(f"Module Runtime: {round(t, 4)} {unit}\n") def save_results(self, round_id=None): @@ -334,7 +337,7 @@ def save_results(self, round_id=None): None """ - # Save result locally + # Save result on local machine suffix = f"-{round_id}" if round_id else "" filename = f"corrected-processed-swcs{suffix}.zip" path = os.path.join(self.output_dir, filename) @@ -342,7 +345,7 @@ def save_results(self, round_id=None): self.save_connections(round_id=round_id) self.write_metadata() - # Save result on s3 + # Save result on s3 (if 
applicable) filename = f"corrected-processed-swcs-s3.zip" path = os.path.join(self.output_dir, filename) self.graph.to_zipped_swcs(path, min_size=50) @@ -373,7 +376,8 @@ def save_to_s3(self): # --- io --- def save_connections(self, round_id=None): """ - Saves predicted connections between connected components in a txt file. + Writes the accepted proposals from the graph to a text file. Each line + contains the two swc ids as comma separated values. Parameters ---------- @@ -414,7 +418,7 @@ def write_metadata(self): "long_range_bool": self.graph_config.long_range_bool, "proposals_per_leaf": self.graph_config.proposals_per_leaf, "search_radius": f"{self.graph_config.search_radius}um", - "confidence_threshold": self.ml_config.threshold, + "accept_threshold": self.ml_config.threshold, "node_spacing": self.graph_config.node_spacing, "remove_doubles": self.graph_config.remove_doubles_bool, } @@ -475,9 +479,9 @@ def __init__( model_path, model_type, radius, + accept_threshold=0.7, anisotropy=[1.0, 1.0, 1.0], - batch_size=BATCH_SIZE, - confidence_threshold=CONFIDENCE_THRESHOLD, + batch_size=2000, device=None, multiscale=1, labels_path=None, @@ -490,22 +494,27 @@ def __init__( Parameters ---------- img_path : str - Path to image stored in a GCS bucket. + Path to image. model_path : str - Path to machine learning model parameters. + Path to machine learning model weights. model_type : str Type of machine learning model used to perform inference. radius : float Search radius used to generate proposals. + accept_threshold : float, optional + Threshold for accepting proposals, where proposals with predicted + likelihood above this threshold are accepted. The default is 0.7. + anisotropy : List[float], optional + ... batch_size : int, optional - Number of proposals to generate features and classify per batch. - The default is the global varaible "BATCH_SIZE". - confidence_threshold : float, optional - Threshold on acceptance probability for proposals. 
The default is - the global variable "CONFIDENCE_THRESHOLD". + Number of proposals to classify in each batch. The default is 2000. multiscale : int, optional Level in the image pyramid that voxel coordinates must index into. The default is 1. + labels_path : str or None, optional + ... + is_multimodal : bool, optional + ... Returns ------- @@ -517,7 +526,7 @@ def __init__( self.device = "cpu" if device is None else device self.is_gnn = True if "Graph" in model_type else False self.radius = radius - self.threshold = confidence_threshold + self.threshold = accept_threshold # Features self.feature_generator = FeatureGenerator(