Skip to content

Commit

Permalink
minor upds and documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Sep 24, 2024
1 parent cfa490b commit a2f3acb
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 46 deletions.
2 changes: 1 addition & 1 deletion src/deep_neurographs/neurograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def generate_proposals(
long_range_bool=False,
proposals_per_leaf=3,
return_trimmed_proposals=False,
trim_endpoints_bool=False,
trim_endpoints_bool=False,
):
"""
Generates proposals from leaf nodes.
Expand Down
7 changes: 0 additions & 7 deletions src/deep_neurographs/utils/ml_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,3 @@ def get_kfolds(filenames, k):
if n_samples > len(samples):
break
return folds


def get_batches(my_list, batch_size):
batches = list()
for start in range(0, len(my_list), batch_size):
batches.append(my_list[start: min(start + batch_size, len(my_list))])
return batches
94 changes: 87 additions & 7 deletions src/deep_neurographs/utils/swc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from deep_neurographs.utils import util


# --- Read ---
class Reader:
"""
Class that reads swc files that are stored as (1) local directory of swcs,
Expand Down Expand Up @@ -336,20 +337,39 @@ def read_xyz(self, xyz_str, offset=[0.0, 0.0, 0.0]):
return xyz


# --- Write ---
def write(path, content, color=None):
"""
Write content to a specified file in a format based on the type o
f content.
Parameters
----------
path : str
File path where the content will be written.
content : list, dict, nx.Graph
The content to be written.
color : str, optional
Color of swc to be written. The default is None.
Returns
-------
None
"""
if type(content) is list:
write_list(path, content, color=color)
elif type(content) is dict:
write_dict(path, content, color=color)
elif type(content) is nx.Graph:
write_graph(path, content, color=color)
else:
assert True, "Unable to write {} to swc".format(type(content))
raise ExceptionType("Unable to write {} to swc".format(type(content)))


def write_list(path, entry_list, color=None):
"""
Writes an swc file.
Writes a list of swc entries to a file at path.
Parameters
----------
Expand All @@ -358,7 +378,7 @@ def write_list(path, entry_list, color=None):
entry_list : list[str]
List of entries that will be written to an swc file.
color : str, optional
Color of nodes. The default is None.
Color of swc to be written. The default is None.
Returns
-------
Expand All @@ -378,8 +398,26 @@ def write_list(path, entry_list, color=None):


def write_dict(path, swc_dict, color=None):
"""
Writes the dictionary to an swc file.
Parameters
----------
path : str
Path that swc will be written to.
swc_dict : dict
Dictionaries whose keys and values are the attribute name and values
from an swc file.
color : str, optional
Color of swc to be written. The default is None.
Returns
-------
None
"""
graph, _ = to_graph(swc_dict, set_attrs=True)
return write_graph(path, graph, color=color)
write_graph(path, graph, color=color)


def write_graph(path, graph, color=None):
Expand Down Expand Up @@ -546,9 +584,28 @@ def make_simple_entry(node, parent, xyz, radius=8):
return f"{node} 2 {x} {y} {z} {radius} {parent}"


# -- Conversions --
def to_graph(swc_dict, graph_id=None, set_attrs=False):
graph = nx.Graph(graph_id=graph_id)
# --- Miscellaneous ---
def to_graph(swc_dict, swc_id=None, set_attrs=False):
"""
Converts an dictionary containing swc attributes to a graph.
Parameters
----------
swc_dict : dict
Dictionaries whose keys and values are the attribute name and values
from an swc file.
swc_id : str, optional
Identifier that dictionary was generated from. The default is None.
set_attrs : bool, optional
Indication of whether to set attributes. The default is False.
Returns
-------
networkx.Graph
Graph generated from "swc_dict".
"""
graph = nx.Graph(graph_id=swc_id)
graph.add_edges_from(zip(swc_dict["id"][1:], swc_dict["pid"][1:]))
if set_attrs:
xyz = swc_dict["xyz"]
Expand All @@ -561,6 +618,29 @@ def to_graph(swc_dict, graph_id=None, set_attrs=False):


def __add_attributes(swc_dict, graph):
"""
Adds node attributes to a NetworkX graph based on information from
"swc_dict".
Parameters:
----------
swc_dict : dict
A dictionary containing SWC data. It must have the following keys:
- "id": A list of node identifiers (unique for each node).
- "xyz": A list of 3D coordinates (x, y, z) for each node.
- "radius": A list of radii for each node.
graph : networkx.Graph
A NetworkX graph object to which the attributes will be added.
The graph must contain nodes that correspond to the IDs in
"swc_dict["id"]".
Returns:
-------
networkx.Graph
The modified graph with added node attributes for each node.
"""
attrs = dict()
for idx, node_id in enumerate(swc_dict["id"]):
attrs[node_id] = {
Expand Down
54 changes: 23 additions & 31 deletions src/deep_neurographs/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

import boto3
import json
import math
import os
Expand Down Expand Up @@ -487,6 +488,28 @@ def write_txt(path, contents):
f.close()


def write_to_s3(local_path, bucket_name, s3_key):
"""
Writes a single file on local machine to an s3 bucket.
Parameters
----------
local_path : str
Path to file to be written to s3.
bucket_name : str
Name of s3 bucket.
s3_key : str
Path within s3 bucket.
Returns
-------
None
"""
s3 = boto3.client('s3')
s3.upload_file(local_path, bucket_name, s3_key)


# --- math utils ---
def get_avg_std(data, weights=None):
"""
Expand Down Expand Up @@ -623,37 +646,6 @@ def time_writer(t, unit="seconds"):
return t, unit


def progress_bar(current, total, bar_length=50, eta=None, runtime=None):
progress = int(current / total * bar_length)
n_completed = f"Completed: {current}/{total}"
bar = f"[{'=' * progress}{' ' * (bar_length - progress)}]"
eta = f"Time Remaining: {eta}" if eta else ""
runtime = f"Estimated Total Runtime: {runtime}" if runtime else ""
print(f"\r{bar} {n_completed} | {eta} | {runtime} ", end="", flush=True)


def report_progress(current, total, chunk_size, cnt, t0, t1):
eta = get_eta(current, total, chunk_size, t1)
runtime = get_runtime(current, total, chunk_size, t0, t1)
progress_bar(current, total, eta=eta, runtime=runtime)
return cnt + 1, time()


def get_eta(current, total, chunk_size, t0, return_str=True):
chunk_runtime = time() - t0
remaining = total - current
eta = remaining * (chunk_runtime / max(chunk_size, 1))
t, unit = time_writer(eta)
return f"{round(t, 4)} {unit}" if return_str else eta


def get_runtime(current, total, chunk_size, t0, t1):
eta = get_eta(current, total, chunk_size, t1, return_str=False)
total_runtime = time() - t0 + eta
t, unit = time_writer(total_runtime)
return f"{round(t, 4)} {unit}"


# --- miscellaneous ---
def get_swc_id(path):
"""
Expand Down

0 comments on commit a2f3acb

Please sign in to comment.