Skip to content

Commit

Permalink
Feat whole brain pipeline (#234)
Browse files Browse the repository at this point in the history
* feat: pipeline routine

* feat: improved config and full pipeline works

* refactor: simplified and improved

* refactor: simplified saving results

---------

Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Sep 17, 2024
1 parent 0feba76 commit 503094a
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 33 deletions.
2 changes: 0 additions & 2 deletions src/deep_neurographs/intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
"""

from time import time

from deep_neurographs.neurograph import NeuroGraph
from deep_neurographs.utils import graph_util as gutil
from deep_neurographs.utils import img_util, swc_util
Expand Down
117 changes: 86 additions & 31 deletions src/deep_neurographs/run_graphtrace_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from deep_neurographs.graph_artifact_removal import remove_doubles
from deep_neurographs.intake import GraphBuilder
from deep_neurographs.machine_learning.inference import InferenceEngine
from deep_neurographs.utils import io_util, util
from deep_neurographs.utils import util


class GraphTracePipeline:
Expand Down Expand Up @@ -88,14 +88,15 @@ def __init__(
self.output_dir = f"{output_dir}/{pred_id}-{date}"
util.mkdir(self.output_dir, delete=True)

# --- Core ---
def run(self, fragments_pointer):
"""
Executes the full inference pipeline.
Parameters
----------
fragments_pointer : dict, list, str
Pointer to swc files used to build an instance of FragmentsGraph,
Pointer to swc files used to build an instance of FragmentGraph,
see "swc_util.Reader" for further documentation.
Returns
Expand All @@ -109,6 +110,7 @@ def run(self, fragments_pointer):
print("Dataset:", self.dataset)
print("Pred_Id:", self.pred_id)
print("")
self.write_metadata()
t0 = time()

self.build_graph(fragments_pointer)
Expand All @@ -127,15 +129,15 @@ def build_graph(self, fragments_pointer):
Parameters
----------
fragment_pointer : dict, list, str
Pointer to swc files used to build an instance of FragmentsGraph,
Pointer to swc files used to build an instance of FragmentGraph,
see "swc_util.Reader" for further documentation.
Returns
-------
None
"""
print("1. Building FragmentsGraph...")
print("(1) Building FragmentGraph")
t0 = time()

# Initialize Graph
Expand All @@ -161,21 +163,9 @@ def build_graph(self, fragments_pointer):
print(f"Module Runtime: {round(t, 4)} {unit}\n")
self.print_graph_overview()

def print_graph_overview(self):
# Compute values
n_components = nx.number_connected_components(self.graph)
usage = round(util.get_memory_usage(), 2)

# Print overview
print("Graph Overview...")
print("# connected components:", util.reformat_number(n_components))
print("# nodes:", util.reformat_number(self.graph.number_of_nodes()))
print("# edges:", util.reformat_number(self.graph.number_of_edges()))
print(f"Memory Consumption: {usage} GBs\n")

def generate_proposals(self):
"""
Generates proposals for the fragments graph based on the specified
Generates proposals for the fragment graph based on the specified
configuration.
Parameters
Expand All @@ -187,7 +177,7 @@ def generate_proposals(self):
None
"""
print("2. Generate Proposals")
print("(2) Generate Proposals")
t0 = time()
self.graph.generate_proposals(
self.graph_config.search_radius,
Expand Down Expand Up @@ -217,7 +207,7 @@ def run_inference(self):
None
"""
print("3. Run Inference")
print("(3) Run Inference")
t0 = time()
inference_engine = InferenceEngine(
self.img_path,
Expand All @@ -227,31 +217,71 @@ def run_inference(self):
confidence_threshold=self.ml_config.threshold,
downsample_factor=self.ml_config.downsample_factor,
)
self.graph, self.proposal_preds = inference_engine.run(
self.graph, self.accepted_proposals = inference_engine.run(
self.graph, self.graph.list_proposals()
)

t, unit = util.time_writer(time() - t0)
print(f"Module Runtime: {round(t, 4)} {unit}\n")

def save_results(self):
print("4. Save Predictions")
t0 = time()
io_util.save_prediction(
self.graph, self.proposal_preds, self.output_dir
)
t, unit = util.time_writer(time() - t0)
print(f"Module Runtime: {round(t, 4)} {unit}\n")
"""
Saves the processed results from running the inference pipeline,
namely the corrected swc files and a list of the merged swc ids.
Parameters
----------
None
Returns
-------
None
"""
print("(4) Saving Results")
path = os.path.join(self.output_dir, "corrected-processed-swcs.zip")
self.graph.to_zipped_swcs(path)
self.save_connections()

# --- io ---
def save_connections(self):
"""
Saves predicted connections between connected components in a txt file.
Parameters
----------
None
Returns
-------
None
"""
path = os.path.join(self.output_dir, "connections.txt")
with open(path, "w") as f:
for id_1, id_2 in self.graph.merged_ids:
f.write(f"{id_1}, {id_2}" + "\n")

def write_metadata(self):
"""
Writes metadata about the current pipeline run to a JSON file.
Parameters
----------
None
Returns
-------
None
"""
metadata = {
"date": self.date,
"date": datetime.today().strftime("%Y-%m-%d"),
"dataset": self.dataset,
"pred_id": self.pred_id,
"min_fragment_input_size": f"{self.graph_config.min_size}um",
"min_fragment_output_size": f"{self.graph_config.min_size}um",
"min_fragment_size": f"{self.graph_config.min_size}um",
"model_type": self.ml_config.model_type,
"model_name": self.model_name,
"model_name": os.path.basename(self.model_path),
"complex_proposals": self.graph_config.complex_bool,
"long_range_bool": self.graph_config.long_range_bool,
"proposals_per_leaf": self.graph_config.proposals_per_leaf,
Expand All @@ -263,3 +293,28 @@ def write_metadata(self):
}
path = os.path.join(self.output_dir, "metadata.json")
util.write_json(path, metadata)

# --- Summaries ---
def print_graph_overview(self):
"""
Prints an overview of the graph's structure and memory usage.
Parameters
----------
None
Returns
-------
None
"""
# Compute values
n_components = nx.number_connected_components(self.graph)
usage = round(util.get_memory_usage(), 2)

# Print overview
print("Graph Overview...")
print("# connected components:", util.reformat_number(n_components))
print("# nodes:", util.reformat_number(self.graph.number_of_nodes()))
print("# edges:", util.reformat_number(self.graph.number_of_edges()))
print(f"Memory Consumption: {usage} GBs\n")

0 comments on commit 503094a

Please sign in to comment.