Skip to content

Commit

Permalink
feat: rounds of inference (#235)
Browse files Browse the repository at this point in the history
Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Sep 18, 2024
1 parent a7b0299 commit 8d52581
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 106 deletions.
Binary file added imgs/result.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
53 changes: 23 additions & 30 deletions src/deep_neurographs/neurograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,9 @@ def add_component(self, irreducibles):
# Edges
for (i, j), attrs in irreducibles["edges"].items():
edge = (ids[i], ids[j])
idxs = np.arange(0, attrs["xyz"].shape[0], self.node_spacing)
if idxs[-1] != attrs["xyz"].shape[0] - 1:
idxs = np.append(idxs, attrs["xyz"].shape[0] - 1)
self.__add_edge(edge, attrs, idxs, swc_id)
idxs = util.spaced_idxs(attrs["radius"], self.node_spacing)
attrs = {key: value[idxs] for key, value in attrs.items()}
self.__add_edge(edge, attrs, swc_id)

def __add_nodes(self, irreducibles, node_type, node_ids):
"""
Expand Down Expand Up @@ -178,7 +177,7 @@ def __add_nodes(self, irreducibles, node_type, node_ids):
node_ids[i] = cur_id
return node_ids

def __add_edge(self, edge, attrs, idxs, swc_id):
def __add_edge(self, edge, attrs, swc_id):
"""
Adds an edge to "self".
Expand All @@ -187,11 +186,7 @@ def __add_edge(self, edge, attrs, idxs, swc_id):
edge : tuple
Edge to be added.
attrs : dict
Dictionary of attributes of "edge" that were obtained from an swc
file.
idxs : dict
Indices of attributes to store in order to reduce the amount of
memory required to store "self".
Dictionary of attributes of "edge" obtained from an swc file.
swc_id : str
swc id corresponding to edge.
Expand All @@ -202,14 +197,9 @@ def __add_edge(self, edge, attrs, idxs, swc_id):
"""
i, j = tuple(edge)
self.add_edge(
i,
j,
radius=attrs["radius"][idxs],
xyz=attrs["xyz"][idxs],
swc_id=swc_id,
i, j, radius=attrs["radius"], xyz=attrs["xyz"], swc_id=swc_id,
)
for xyz in attrs["xyz"][idxs]:
self.xyz_to_edge[tuple(xyz)] = edge
self.xyz_to_edge.update({tuple(xyz): edge for xyz in attrs["xyz"]})

"""
def absorb_node(self, i, nb_1, nb_2):
Expand Down Expand Up @@ -265,10 +255,11 @@ def split_edge(self, edge, attrs, idx):
self.node_cnt += 1

# Create edges
idxs_1 = np.arange(0, idx + 1)
idxs_2 = np.arange(idx, len(attrs["xyz"]))
self.__add_edge((i, node_id), attrs, idxs_1, swc_id)
self.__add_edge((node_id, j), attrs, idxs_2, swc_id)
n = len(attrs["xyz"])
attrs_1 = {k: v[np.arange(0, idx + 1)] for k, v in attrs.items()}
attrs_2 = {k: v[np.arange(idx, n)] for k, v in attrs.items()}
self.__add_edge((i, node_id), attrs_1, swc_id)
self.__add_edge((node_id, j), attrs_2, swc_id)
return node_id

# --- Proposal Generation ---
Expand Down Expand Up @@ -656,16 +647,16 @@ def merge_proposal(self, edge):
swc_id_j = self.nodes[j]["swc_id"]
if not (self.is_soma(i) and self.is_soma(j)):
# Attributes
xyz = np.vstack([self.nodes[i]["xyz"], self.nodes[j]["xyz"]])
radius = np.array(
[self.nodes[i]["radius"], self.nodes[j]["radius"]]
)
attrs = dict()
for k in ["xyz", "radius"]:
combine = np.vstack if k == "xyz" else np.array
attrs[k] = combine([self.nodes[i][k], self.nodes[j][k]])
swc_id = swc_id_i if self.is_soma(i) else swc_id_j

# Update graph
self.merged_ids.add((swc_id_i, swc_id_j))
self.upd_ids(swc_id, j if swc_id == swc_id_i else i)
self.add_edge(i, j, xyz=xyz, radius=radius, swc_id=swc_id)
self.__add_edge((i, j), attrs, swc_id)
if i in self.leafs:
self.leafs.remove(i)
if j in self.leafs:
Expand Down Expand Up @@ -744,7 +735,7 @@ def dist(self, i, j):
return get_dist(self.nodes[i]["xyz"], self.nodes[j]["xyz"])

def get_branches(self, i, ignore_reducibles=False, key="xyz"):
branches = []
branches = list()
for j in self.neighbors(i):
branch = self.oriented_edge((i, j), i, key=key)
if ignore_reducibles:
Expand Down Expand Up @@ -884,12 +875,14 @@ def leaf_neighbor(self, i):
assert self.is_leaf(i)
return list(self.neighbors(i))[0]

"""
def get_edge_attr(self, edge, key):
xyz_arr = gutil.get_edge_attr(self, edge, key)
return xyz_arr[0], xyz_arr[-1]
"""

def to_patch_coords(self, edge, midpoint, chunk_size):
patch_coords = []
patch_coords = list()
for xyz in self.edges[edge]["xyz"]:
coord = self.to_voxels(xyz)
local_coord = util.voxels_to_patch(coord, midpoint, chunk_size)
Expand Down Expand Up @@ -976,7 +969,7 @@ def to_swcs(self, swc_dir):
"""
with ThreadPoolExecutor() as executor:
threads = []
threads = list()
for i, nodes in enumerate(nx.connected_components(self)):
threads.append(executor.submit(self.to_swc, swc_dir, nodes))

Expand All @@ -998,7 +991,7 @@ def to_swc(self, swc_dir, nodes, color=None):
None.
"""
entry_list = []
entry_list = list()
node_to_idx = dict()
for i, j in nx.dfs_edges(self.subgraph(nodes)):
# Initialize
Expand Down
42 changes: 36 additions & 6 deletions src/deep_neurographs/run_graphtrace_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(
"""
# Class attributes
self.accepted_proposals = list()
self.dataset = dataset
self.pred_id = pred_id
self.img_path = img_path
Expand Down Expand Up @@ -113,6 +114,7 @@ def run(self, fragments_pointer):
self.write_metadata()
t0 = time()

# Main
self.build_graph(fragments_pointer)
self.generate_proposals()
self.run_inference()
Expand All @@ -121,6 +123,24 @@ def run(self, fragments_pointer):
t, unit = util.time_writer(time() - t0)
print(f"Total Runtime: {round(t, 4)} {unit}\n")

def run_schedule(self, fragments_pointer, search_radius_schedule):
# Initializations
print("\nExperiment Details")
print("-----------------------------------------------")
print("Dataset:", self.dataset)
print("Pred_Id:", self.pred_id)
print("")
t0 = time()

# Main
self.build_graph(fragments_pointer)
for search_radius in search_radius_schedule:
self.generate_proposals(search_radius=search_radius)
self.run_inference()
self.save_results()
t, unit = util.time_writer(time() - t0)
print(f"Total Runtime: {round(t, 4)} {unit}\n")

def build_graph(self, fragments_pointer):
"""
Initializes and constructs the fragments graph based on the provided
Expand Down Expand Up @@ -163,7 +183,7 @@ def build_graph(self, fragments_pointer):
print(f"Module Runtime: {round(t, 4)} {unit}\n")
self.print_graph_overview()

def generate_proposals(self):
def generate_proposals(self, search_radius=None):
"""
Generates proposals for the fragment graph based on the specified
configuration.
Expand All @@ -177,18 +197,23 @@ def generate_proposals(self):
None
"""
# Initializations
print("(2) Generate Proposals")
if not search_radius:
search_radius = self.graph_config.search_radius,

# Main
t0 = time()
self.graph.generate_proposals(
self.graph_config.search_radius,
search_radius,
complex_bool=self.graph_config.complex_bool,
long_range_bool=self.graph_config.long_range_bool,
proposals_per_leaf=self.graph_config.proposals_per_leaf,
trim_endpoints_bool=self.graph_config.trim_endpoints_bool,
)
self.graph.xyz_to_edge = dict()
n_proposals = util.reformat_number(self.graph.n_proposals())

# Report results
t, unit = util.time_writer(time() - t0)
print("# Proposals:", n_proposals)
print(f"Module Runtime: {round(t, 4)} {unit}\n")
Expand All @@ -209,6 +234,7 @@ def run_inference(self):
"""
print("(3) Run Inference")
t0 = time()
n_proposals = self.graph.n_proposals()
inference_engine = InferenceEngine(
self.img_path,
self.model_path,
Expand All @@ -217,14 +243,16 @@ def run_inference(self):
confidence_threshold=self.ml_config.threshold,
downsample_factor=self.ml_config.downsample_factor,
)
self.graph, self.accepted_proposals = inference_engine.run(
self.graph, accepts = inference_engine.run(
self.graph, self.graph.list_proposals()
)
self.accepted_proposals.extend(accepts)
print("% Accepted:", len(accepts) / n_proposals)

t, unit = util.time_writer(time() - t0)
print(f"Module Runtime: {round(t, 4)} {unit}\n")

def save_results(self):
def save_results(self, round_id=None):
"""
Saves the processed results from running the inference pipeline,
namely the corrected swc files and a list of the merged swc ids.
Expand All @@ -239,9 +267,11 @@ def save_results(self):
"""
print("(4) Saving Results")
path = os.path.join(self.output_dir, "corrected-processed-swcs.zip")
name = "corrected-processed-swcs.zip"
path = os.path.join(self.output_dir, name + ".zip")
self.graph.to_zipped_swcs(path)
self.save_connections()
self.write_metadata()

# --- io ---
def save_connections(self):
Expand Down
70 changes: 0 additions & 70 deletions src/deep_neurographs/utils/io_util.py

This file was deleted.

27 changes: 27 additions & 0 deletions src/deep_neurographs/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,3 +725,30 @@ def get_memory_usage():
"""
return psutil.virtual_memory().used / 1e9


def spaced_idxs(container, k):
    """
    Generates an array of indices spaced by "k" over "container", ensuring
    the last index is included exactly once.

    Parameters:
    ----------
    container : iterable
        An iterable (e.g., list, array) from which the length is determined.
        The length of this container dictates the range of generated indices.
    k : int
        Step size for generating indices.

    Returns:
    -------
    numpy.ndarray
        Array of indices 0, k, 2k, ... strictly less than len(container),
        with the final index len(container) - 1 guaranteed to be included.
        Empty if "container" is empty.
    """
    # np.arange(0, n + k, k)[:-1] is equivalent to this simpler form
    idxs = np.arange(0, len(container), k)
    if idxs.size == 0:
        # Empty container: nothing to index (avoids IndexError below)
        return idxs
    # Append the last valid index unless the stride already landed on it.
    # Note: the previous parity check (len(container) % 2 == 0) wrongly
    # skipped the last index for odd lengths and duplicated it when the
    # stride already hit it for even lengths.
    if idxs[-1] != len(container) - 1:
        idxs = np.append(idxs, len(container) - 1)
    return idxs

0 comments on commit 8d52581

Please sign in to comment.