refactor: improve inference clarity
anna-grim committed Jan 16, 2025
1 parent f7b924d commit c3b20d8
Showing 2 changed files with 52 additions and 43 deletions.
src/deep_neurographs/fragment_filtering.py (10 changes: 5 additions & 5 deletions)
@@ -10,12 +10,13 @@
"""

from collections import defaultdict
from tqdm import tqdm

import networkx as nx
import numpy as np
from tqdm import tqdm

from deep_neurographs import geometry
from deep_neurographs.utils import util

QUERY_DIST = 15

@@ -46,15 +47,14 @@ def remove_curvy(fragments_graph, max_length, ratio=0.5):
"""
deleted_ids = set()
components = get_line_components(fragments_graph)
for nodes in tqdm(components, desc="Filter Curvy Fragments"):
for nodes in get_line_components(fragments_graph):
i, j = tuple(nodes)
length = fragments_graph.edges[i, j]["length"]
endpoint_dist = fragments_graph.dist(i, j)
if endpoint_dist / length < ratio and length < max_length:
deleted_ids.add(fragments_graph.edges[i, j]["swc_id"])
delete_fragment(fragments_graph, i, j)
return len(deleted_ids)
return util.reformat_number(len(deleted_ids))
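
For context, the ratio test above compares a fragment's straight-line endpoint distance against its traced path length: a near-straight fragment scores close to 1, a meandering one scores low, and only short, low-ratio fragments are deleted. A minimal standalone sketch of that criterion, with invented lengths (none of these values come from the commit):

# Sketch of the curviness criterion in remove_curvy; all numbers are made up.
path_length = 100.0    # arc length stored on the fragment's edge
endpoint_dist = 40.0   # straight-line distance between the two endpoints
max_length, ratio = 200, 0.5

# Short fragment with a low straightness ratio -> filtered out.
is_curvy = (endpoint_dist / path_length < ratio) and (path_length < max_length)
print(is_curvy)  # True, since 0.4 < 0.5 and 100 < 200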


# --- Doubles Removal ---
@@ -96,7 +96,7 @@ def remove_doubles(fragments_graph, max_length, node_spacing):
if check_doubles_criteria(hits, n_points):
delete_fragment(fragments_graph, i, j)
deleted_ids.add(swc_id)
return len(deleted_ids)
return util.reformat_number(len(deleted_ids))


def compute_projections(fragments_graph, kdtree, edge):
src/deep_neurographs/inference.py (85 changes: 47 additions & 38 deletions)
@@ -29,9 +29,6 @@
from deep_neurographs.utils.gnn_util import toCPU
from deep_neurographs.utils.graph_util import GraphLoader

BATCH_SIZE = 2000
CONFIDENCE_THRESHOLD = 0.7


class InferencePipeline:
"""
@@ -132,9 +129,9 @@ def __init__(
self.model_path,
self.ml_config.model_type,
self.graph_config.search_radius,
accept_threshold=self.ml_config.threshold,
anisotropy=self.ml_config.anisotropy,
batch_size=self.ml_config.batch_size,
confidence_threshold=self.ml_config.threshold,
device=device,
multiscale=self.ml_config.multiscale,
labels_path=labels_path,
@@ -178,21 +175,27 @@ def run(self, fragments_pointer):
# Finish
self.report("Final Graph...")
self.report_graph()

t, unit = util.time_writer(time() - t0)
self.report(f"Total Runtime: {round(t, 4)} {unit}\n")

def run_schedule(self, fragments_pointer, radius_schedule):
t0 = time()
# Initializations
self.log_experiment()
self.write_metadata()
t0 = time()

# Main
self.build_graph(fragments_pointer)
for round_id, radius in enumerate(radius_schedule):
self.report(f"--- Round {round_id + 1}: Radius = {radius} ---")
round_id += 1
self.report(f"--- Round {round_id}: Radius = {radius} ---")
self.generate_proposals(radius)
self.run_inference()
self.save_results(round_id=round_id)

# Finish
self.report("Final Graph...")
self.report_graph()
t, unit = util.time_writer(time() - t0)
self.report(f"Total Runtime: {round(t, 4)} {unit}\n")

@@ -212,7 +215,7 @@ def build_graph(self, fragments_pointer):
None
"""
self.report("(1) Building FragmentGraph")
self.report("Step 1: Building FragmentGraph")
t0 = time()

# Initialize Graph
@@ -233,31 +236,27 @@
self.graph.save_labels(labels_path)
self.report(f"# SWCs Saved: {n_saved}")

# Report runtime
# Report results
t, unit = util.time_writer(time() - t0)
self.report(f"Module Runtime: {round(t, 4)} {unit}")

# Report graph overview
self.report("\nInitial Graph...")
self.report_graph()

def filter_fragments(self):
# Filter curvy fragments
# Curvy fragments
n_curvy = fragment_filtering.remove_curvy(self.graph, 200)
n_curvy = util.reformat_number(n_curvy)

# Filter doubles
# Double fragments
if self.graph_config.remove_doubles_bool:
n_doubles = fragment_filtering.remove_doubles(
self.graph, 200, self.graph_config.node_spacing
)
n_doubles = util.reformat_number(n_doubles)
self.report(f"# Double Fragments Deleted: {n_doubles}")
self.report(f"# Curvy Fragments Deleted: {n_curvy}")

def generate_proposals(self, radius=None):
"""
Generates proposals for the fragment graph based on the specified
Generates proposals for the fragments graph based on the specified
configuration.
Parameters
@@ -270,7 +269,7 @@ def generate_proposals(self, radius=None):
"""
# Initializations
self.report("(2) Generate Proposals")
self.report("Step 2: Generate Proposals")
if radius is None:
radius = self.graph_config.search_radius

@@ -307,17 +306,21 @@ def run_inference(self):
None
"""
self.report("(3) Run Inference")
# Initializations
self.report("Step 3: Run Inference")
proposals = self.graph.list_proposals()
n_proposals = max(len(proposals), 1)

# Main
t0 = time()
n_proposals = max(self.graph.n_proposals(), 1)
self.graph, accepts = self.inference_engine.run(
self.graph, self.graph.list_proposals()
)
self.graph, accepts = self.inference_engine.run(self.graph, proposals)
self.accepted_proposals.extend(accepts)
self.report(f"# Accepted: {util.reformat_number(len(accepts))}")
self.report(f"% Accepted: {round(len(accepts) / n_proposals, 4)}")

# Report results
t, unit = util.time_writer(time() - t0)
n_accepts = len(self.accepted_proposals)
self.report(f"# Accepted: {util.reformat_number(n_accepts)}")
self.report(f"% Accepted: {round(n_accepts / n_proposals, 4)}")
self.report(f"Module Runtime: {round(t, 4)} {unit}\n")

def save_results(self, round_id=None):
@@ -334,15 +337,15 @@
None
"""
# Save result locally
# Save result on local machine
suffix = f"-{round_id}" if round_id else ""
filename = f"corrected-processed-swcs{suffix}.zip"
path = os.path.join(self.output_dir, filename)
self.graph.to_zipped_swcs(path)
self.save_connections(round_id=round_id)
self.write_metadata()

# Save result on s3
# Save result on s3 (if applicable)
filename = f"corrected-processed-swcs-s3.zip"
path = os.path.join(self.output_dir, filename)
self.graph.to_zipped_swcs(path, min_size=50)
@@ -373,7 +376,8 @@ def save_to_s3(self):
# --- io ---
def save_connections(self, round_id=None):
"""
Saves predicted connections between connected components in a txt file.
Writes the accepted proposals from the graph to a text file. Each line
        contains the two swc ids as comma-separated values.
Parameters
----------
@@ -414,7 +418,7 @@ def write_metadata(self):
"long_range_bool": self.graph_config.long_range_bool,
"proposals_per_leaf": self.graph_config.proposals_per_leaf,
"search_radius": f"{self.graph_config.search_radius}um",
"confidence_threshold": self.ml_config.threshold,
"accept_threshold": self.ml_config.threshold,
"node_spacing": self.graph_config.node_spacing,
"remove_doubles": self.graph_config.remove_doubles_bool,
}
@@ -475,9 +479,9 @@ def __init__(
model_path,
model_type,
radius,
accept_threshold=0.7,
anisotropy=[1.0, 1.0, 1.0],
batch_size=BATCH_SIZE,
confidence_threshold=CONFIDENCE_THRESHOLD,
batch_size=2000,
device=None,
multiscale=1,
labels_path=None,
@@ -490,22 +494,27 @@
Parameters
----------
img_path : str
Path to image stored in a GCS bucket.
Path to image.
model_path : str
Path to machine learning model parameters.
Path to machine learning model weights.
model_type : str
Type of machine learning model used to perform inference.
radius : float
Search radius used to generate proposals.
accept_threshold : float, optional
Threshold for accepting proposals, where proposals with predicted
likelihood above this threshold are accepted. The default is 0.7.
anisotropy : List[float], optional
...
batch_size : int, optional
Number of proposals to generate features and classify per batch.
            The default is the global variable "BATCH_SIZE".
confidence_threshold : float, optional
Threshold on acceptance probability for proposals. The default is
the global variable "CONFIDENCE_THRESHOLD".
            Number of proposals to classify in each batch. The default is 2000.
multiscale : int, optional
Level in the image pyramid that voxel coordinates must index into.
The default is 1.
labels_path : str or None, optional
...
is_multimodal : bool, optional
...
Returns
-------
@@ -517,7 +526,7 @@
self.device = "cpu" if device is None else device
self.is_gnn = True if "Graph" in model_type else False
self.radius = radius
self.threshold = confidence_threshold
self.threshold = accept_threshold

# Features
self.feature_generator = FeatureGenerator(

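With the rename, callers pass accept_threshold where they previously passed confidence_threshold, and the 2000 and 0.7 defaults now live in the signature instead of module-level globals. A construction sketch based only on the signature shown above; every value, including the model_type string, is a placeholder:

# Hypothetical call against the new signature; all arguments are placeholders.
engine = InferenceEngine(
    img_path="path/to/image",
    model_path="path/to/model/weights",
    model_type="GraphNeuralNet",  # any name containing "Graph" makes is_gnn True
    radius=20.0,
    accept_threshold=0.7,         # renamed from confidence_threshold in this commit
    batch_size=2000,
    multiscale=1,
)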