Skip to content

Commit

Permalink
bug: merge proposals and check cycles (#258)
Browse files Browse the repository at this point in the history
Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Oct 2, 2024
1 parent 84a8a73 commit b09c086
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 56 deletions.
28 changes: 8 additions & 20 deletions src/deep_neurographs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,14 +479,9 @@ def run(self, neurograph, proposals):
preds = self.predict(dataset)

# Update graph
batch_accepts = get_accepted_proposals(
neurograph, preds, self.threshold
)
for proposal in batch_accepts:
neurograph.merge_proposal(proposal)

# Finish
accepts.extend(batch_accepts)
for p in get_accepts(neurograph, preds, self.threshold):
neurograph.merge_proposal(p)
accepts.append(p)
pbar.update(len(batch["proposals"]))
neurograph.absorb_reducibles()
return neurograph, accepts
Expand Down Expand Up @@ -591,7 +586,7 @@ def predict(self, dataset):


# --- Accepting Proposals ---
def get_accepted_proposals(neurograph, preds, threshold, high_threshold=0.9):
def get_accepts(neurograph, preds, threshold, high_threshold=0.9):
"""
Determines which proposals to accept based on prediction scores and the
specified threshold.
Expand Down Expand Up @@ -623,6 +618,7 @@ def get_accepted_proposals(neurograph, preds, threshold, high_threshold=0.9):
accepts = list()
accepts.extend(filter_proposals(neurograph, best_proposals))
accepts.extend(filter_proposals(neurograph, proposals))
neurograph.remove_edges_from(map(tuple, accepts))
return accepts


Expand Down Expand Up @@ -685,17 +681,9 @@ def filter_proposals(graph, proposals):
"""
accepts = list()
for i, j in proposals:
nodes_i = set(gutil.get_component(graph, i))
nodes_j = set(gutil.get_component(graph, j))
if nodes_i.isdisjoint(nodes_j):
subgraph_i = graph.subgraph(nodes_i)
subgraph_j = graph.subgraph(nodes_j)
subgraph = nx.union(subgraph_i, subgraph_j)
created_cycle, _ = gutil.creates_cycle(subgraph, (i, j))
if not created_cycle:
graph.add_edge(i, j)
accepts.append(frozenset({i, j}))
graph.remove_edges_from(map(tuple, accepts))
if not nx.has_path(graph, i, j):
graph.add_edge(i, j)
accepts.append(frozenset({i, j}))
return accepts


Expand Down
11 changes: 7 additions & 4 deletions src/deep_neurographs/neurograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,13 @@ def absorb_reducibles(self):
for i in nodes:
nbs = list(self.neighbors(i))
if len(nbs) == 2 and len(self.nodes[i]["proposals"]) == 0:
# Get attributes
# Concatenate attributes
len_1 = self.edges[i, nbs[0]]["length"]
len_2 = self.edges[i, nbs[1]]["length"]
xyz = self.get_branches(i, key="xyz")
radius = self.get_branches(i, key="radius")
attrs = {
"length": len_1 + len_2,
"radius": concatenate([np.flip(radius[0]), radius[1]]),
"xyz": concatenate([np.flip(xyz[0], axis=0), xyz[1]]),
}
Expand Down Expand Up @@ -649,7 +652,7 @@ def proposal_directionals(self, proposal, window):
def merge_proposal(self, proposal):
i, j = tuple(proposal)
somas_check = not (self.is_soma(i) and self.is_soma(j))
degrees_check = self.degree[i] == 2 and self.degree[j] == 2
degrees_check = not (self.degree[i] == 2 and self.degree[j] == 2)
if somas_check and degrees_check:
# Dense attributes
attrs = dict()
Expand All @@ -665,10 +668,10 @@ def merge_proposal(self, proposal):
attrs["length"] = len_ij
elif self.degree[i] == 2:
e_j = (j, self.leaf_neighbor(j))
attrs["length"] = self.edges[e_i]["length"]
attrs["length"] = self.edges[e_j]["length"]
else:
e_i = (i, self.leaf_neighbor(i))
attrs["length"] = self.edges[e_j]["length"]
attrs["length"] = self.edges[e_i]["length"]

swc_id_i = self.nodes[i]["swc_id"]
swc_id_j = self.nodes[j]["swc_id"]
Expand Down
37 changes: 5 additions & 32 deletions src/deep_neurographs/utils/graph_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,32 +798,6 @@ def upd_node_attrs(swc_dict, leafs, junctions, i):


# -- miscellaneous --
def creates_cycle(graph, edge):
"""
Checks whether adding "edge" to "graph" creates a cycle.
Paramaters
----------
graph : networkx.Graph
Graph to be checked for cycles.
edge : tuple
Edge to be added to "graph"
Returns
-------
bool
Indication of whether adding "edge" to graph creates a cycle.
"""
graph.add_edges_from([edge])
exists = cycle_exists(graph)
graph.remove_edges_from([edge])
if exists:
return True, edge
else:
return False, edge


def cycle_exists(graph):
"""
Checks whether a cycle exists in "graph".
Expand Down Expand Up @@ -900,14 +874,13 @@ def get_component(graph, root):
"""
queue = [root]
component = set()
visited = set()
while len(queue):
i = queue.pop()
component.add(i)
for j in [j for j in graph.neighbors(i) if j not in component]:
if (i, j) in graph.edges:
queue.append(j)
return component
visited.add(i)
for j in [j for j in graph.neighbors(i) if j not in visited]:
queue.append(j)
return visited


def count_components(graph):
Expand Down

0 comments on commit b09c086

Please sign in to comment.