From 0b9641a8d99c309ef3171464b1d97de6a6cc7053 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Fri, 13 Sep 2024 01:34:40 +0000 Subject: [PATCH 1/4] refactor: deleted remove_connector option --- src/deep_neurographs/geometry.py | 2 +- src/deep_neurographs/intake.py | 60 +---- .../machine_learning/datasets.py | 201 ++++++++------ .../machine_learning/graph_datasets.py | 37 +-- .../machine_learning/heterograph_datasets.py | 39 +-- .../machine_learning/inference.py | 45 +--- .../machine_learning/trainer.py | 8 +- src/deep_neurographs/utils/graph_util.py | 248 +++++------------- src/deep_neurographs/utils/img_util.py | 2 +- src/deep_neurographs/utils/ml_util.py | 38 +-- 10 files changed, 226 insertions(+), 454 deletions(-) diff --git a/src/deep_neurographs/geometry.py b/src/deep_neurographs/geometry.py index 754330e..dd701f7 100644 --- a/src/deep_neurographs/geometry.py +++ b/src/deep_neurographs/geometry.py @@ -316,7 +316,7 @@ def fill_path(img, path, val=-1): """ for xyz in path: x, y, z = tuple(np.floor(xyz).astype(int)) - img[x - 1 : x + 2, y - 1 : y + 2, z - 1 : z + 2] = val + img[x - 1: x + 2, y - 1: y + 2, z - 1: z + 2] = val return img diff --git a/src/deep_neurographs/intake.py b/src/deep_neurographs/intake.py index 6899230..fcf08ba 100644 --- a/src/deep_neurographs/intake.py +++ b/src/deep_neurographs/intake.py @@ -25,10 +25,8 @@ MIN_SIZE = 30 NODE_SPACING = 2 SMOOTH = True -PRUNE_CONNECTORS = False PRUNE_DEPTH = 25 TRIM_DEPTH = 0 -CONNECTOR_LENGTH = 8 # --- Build graph wrappers --- @@ -40,8 +38,6 @@ def build_neurograph_from_local( min_size=MIN_SIZE, node_spacing=NODE_SPACING, progress_bar=False, - prune_connectors=PRUNE_CONNECTORS, - connector_length=CONNECTOR_LENGTH, prune_depth=PRUNE_DEPTH, trim_depth=TRIM_DEPTH, smooth=SMOOTH, @@ -74,13 +70,6 @@ def build_neurograph_from_local( progress_bar : bool, optional Indication of whether to print out a progress bar during build. The default is False. - prune_connectors : bool, optional - Indication of whether to prune connectors (see graph_util.py), sites - that are likely to be false merges. The default is the global variable - "PRUNE_CONNECTORS". - connector_length : int, optional - Maximum length of connecting paths pruned (see graph_util.py). The - default is the global variable "CONNECTOR_LENGTH". prune_depth : int, optional Branches less than "prune_depth" microns are pruned if "prune" is True. The default is the global variable "PRUNE_DEPTH". @@ -122,8 +111,6 @@ def build_neurograph_from_local( min_size=min_size, node_spacing=node_spacing, progress_bar=progress_bar, - prune_connectors=prune_connectors, - connector_length=connector_length, prune_depth=prune_depth, trim_depth=trim_depth, smooth=smooth, @@ -139,8 +126,6 @@ def build_neurograph_from_gcs_zips( img_path=None, min_size=MIN_SIZE, node_spacing=NODE_SPACING, - prune_connectors=PRUNE_CONNECTORS, - connector_length=CONNECTOR_LENGTH, prune_depth=PRUNE_DEPTH, trim_depth=TRIM_DEPTH, smooth=SMOOTH, @@ -166,13 +151,6 @@ def build_neurograph_from_gcs_zips( node_spacing : int, optional Spacing (in microns) between nodes. The default is the global variable "NODE_SPACING". - prune_connectors : bool, optional - Indication of whether to prune connectors (see graph_util.py), sites - that are likely to be false merges. The default is the global variable - "PRUNE_CONNECTORS". - connector_length : int, optional - Maximum length of connecting paths pruned (see graph_util.py). The - default is the global variable "CONNECTOR_LENGTH". 
prune_depth : int, optional Branches less than "prune_depth" microns are pruned if "prune" is True. The default is the global variable "PRUNE_DEPTH". @@ -186,23 +164,14 @@ def build_neurograph_from_gcs_zips( Neurograph generated from zips of swc files stored in a GCS bucket. """ - # Process swc files - print("Process swc files...") - total_runtime, t0 = util.init_timers() - swc_dicts = download_gcs_zips(bucket_name, gcs_path, min_size, anisotropy) - t, unit = util.time_writer(time() - t0) - print(f"\nModule Runtime: {round(t, 4)} {unit} \n") - - # Build neurograph - print("Build NeuroGraph...") + print("\nBuild NeuroGraph...") t0 = time() + swc_dicts = download_gcs_zips(bucket_name, gcs_path, min_size, anisotropy) neurograph = build_neurograph( swc_dicts, img_path=img_path, min_size=min_size, node_spacing=node_spacing, - prune_connectors=prune_connectors, - connector_length=connector_length, prune_depth=prune_depth, trim_depth=trim_depth, smooth=smooth, @@ -272,23 +241,16 @@ def build_neurograph( node_spacing=NODE_SPACING, swc_paths=None, progress_bar=True, - prune_connectors=PRUNE_CONNECTORS, - connector_length=CONNECTOR_LENGTH, prune_depth=PRUNE_DEPTH, trim_depth=TRIM_DEPTH, smooth=SMOOTH, ): # Extract irreducibles - n_components = len(swc_dicts) - if progress_bar: - print("# swcs downloaded:", util.reformat_number(n_components)) irreducibles, n_nodes, n_edges = get_irreducibles( swc_dicts, bbox=img_bbox, min_size=min_size, progress_bar=progress_bar, - prune_connectors=prune_connectors, - connector_length=connector_length, prune_depth=prune_depth, trim_depth=trim_depth, smooth=smooth, @@ -317,18 +279,14 @@ def get_irreducibles( bbox=None, min_size=MIN_SIZE, progress_bar=True, - prune_connectors=PRUNE_CONNECTORS, - connector_length=CONNECTOR_LENGTH, prune_depth=PRUNE_DEPTH, trim_depth=TRIM_DEPTH, smooth=SMOOTH, ): - n_components = len(swc_dicts) - chunk_size = int(n_components * 0.02) with ProcessPoolExecutor() as executor: # Assign Processes i = 0 - processes = [None] * n_components + processes = [None] * len(swc_dicts) while swc_dicts: swc_dict = swc_dicts.pop() processes[i] = executor.submit( @@ -336,8 +294,6 @@ def get_irreducibles( swc_dict, min_size, bbox, - prune_connectors, - connector_length, prune_depth, trim_depth, smooth, @@ -345,19 +301,13 @@ def get_irreducibles( i += 1 # Store results - t0, t1 = util.init_timers() - n_nodes, n_edges = 0, 0 - cnt = 1 irreducibles = [] - for i, process in enumerate(as_completed(processes)): + n_nodes, n_edges = 0, 0 + for process in tqdm(as_completed(processes), desc="Extract Graphs"): irreducibles_i = process.result() irreducibles.extend(irreducibles_i) n_nodes += count_nodes(irreducibles_i) n_edges += count_edges(irreducibles_i) - if i >= cnt * chunk_size and progress_bar: - cnt, t1 = util.report_progress( - i + 1, n_components, chunk_size, cnt, t0, t1 - ) return irreducibles, n_nodes, n_edges diff --git a/src/deep_neurographs/machine_learning/datasets.py b/src/deep_neurographs/machine_learning/datasets.py index fe4bd6c..3601a4d 100644 --- a/src/deep_neurographs/machine_learning/datasets.py +++ b/src/deep_neurographs/machine_learning/datasets.py @@ -10,113 +10,121 @@ import numpy as np import torchio as tio -from torch.utils.data import Dataset +from torch.utils.data import Dataset as TorchDataset +from deep_neurographs.machine_learning import feature_generation -# Custom datasets -class ProposalDataset(Dataset): + +# Wrapper +def init(neurograph, features, model_type, sample_ids=None): """ - Custom dataset that contains feature vectors 
that correspond to edge
-    proposals.
+    Initializes a dataset that can be used to train a machine learning
+    model.
+
+    Parameters
+    ----------
+    neurograph : NeuroGraph
+        Graph that dataset is built from.
+    features : dict
+        Feature vectors corresponding to proposals such that the keys are
+        "proposals" and "branches". The values are a dictionary containing
+        different types of features for edges and branches.
+    model_type : str
+        Type of machine learning model that the dataset is built for.
+    sample_ids : list[str]
+        ...
+
+    Returns
+    -------
+    Dataset
+        Custom dataset.
 
     """
+    # Extract features
+    x_proposals, y_proposals, idxs_proposals = feature_generation.get_matrix(
+        neurograph, features["proposals"], model_type, sample_ids=sample_ids
+    )
+
+    # Initialize dataset
+    proposals = list(features["proposals"]["skel"].keys())
+    dataset = Dataset(
+        proposals,
+        x_proposals,
+        y_proposals,
+        idxs_proposals,
+    )
+    return dataset
+
+
+class Dataset:
+    """
+    Dataset class that contains feature vectors of edge proposals. The
+    feature vectors may be either unimodal or multimodal.
+
+    """
     def __init__(
-        self, inputs, targets, search_radius=10, transform=False, lengths=[]
+        self,
+        proposals,
+        x_proposals,
+        y_proposals,
+        idxs_proposals,
     ):
         """
-        Constructs ProposalDataset object.
+        Constructs a Dataset object.
 
         Parameters
         ----------
-        inputs : np.array
-            Feature matrix where each row corresponds to the feature vector of
-            an edge proposal.
-        targets : np.array
-            Binary vector where each entry indicates whether an edge proposal
-            should be added or omitted from a reconstruction.
+        proposals : list
+            List of proposals to be classified.
+        x_proposals : numpy.ndarray
+            Feature matrix generated from "proposals".
+        y_proposals : numpy.ndarray
+            Ground truth of proposals (i.e. accept or reject).
+        idxs_proposals : dict
+            Dictionary that maps "proposals" to an index that represents the
+            proposal's position in "x_proposals".
 
         Returns
         -------
         None
 
         """
-        self.inputs = inputs.astype(np.float32)
-        self.targets = reformat(targets)
-        self.lengths = lengths
-        self.transform = transform
+        # Conversion idxs
+        self.block_to_idxs = idxs_proposals["block_to_idxs"]
+        self.idxs_proposals = init_idxs(idxs_proposals)
+        self.proposals = proposals
 
-    def __len__(self):
-        """
-        Computes number of examples in dataset.
-
-        Returns
-        -------
-        int
-            Number of examples in dataset.
-
-        """
-        return len(self.targets)
-
-    def __getitem__(self, idx):
-        """
-        Gets example (i.e. input and label) corresponding to "idx".
-
-        Parameters
-        ----------
-        idx : int
-            Index of example to be returned.
-
-        Returns
-        -------
-        dict
-            Example corresponding to "idx".
-
-        """
-        inputs_i = self.inputs[idx]
-        if self.transform:
-            if np.random.random() > 0.6:
-                p = 100 * np.random.random()
-                inputs_i[0] = np.percentile(self.lengths, p)
-        return {"inputs": inputs_i, "targets": self.targets[idx]}
-
-
-class ImgProposalDataset(Dataset):
+        # Features
+        self.data = ProposalDataset(x=x_proposals, y=y_proposals)
+
+
+class ProposalDataset(TorchDataset):
     """
-    Custom dataset that contains image chunks that correspond to edge
+    Custom dataset that contains feature vectors that correspond to edge
     proposals.
 
     """
 
-    def __init__(self, inputs, targets, transform=True):
+    def __init__(self, x, y):
         """
-        Constructs ImgProposalDataset object.
+        Constructs ProposalDataset object.
Parameters
         ----------
-        inputs : numpy.array
-            Feature tensor where each submatrix corresponds to an image chunk
-            that contains an edge proposal. Note that the midpoint of the edge
-            proposal is the center point of the chunk.
-        targets : np.array
-            Binary vector where each entry indicates whether an edge proposal
-            should be added or omitted from a reconstruction.
-        transform : bool, optional
-            Indication of whether to apply data augmentation to the inputs.
-            The default is True.
+        x : np.array
+            Feature matrix where each row corresponds to the feature vector of
+            a proposal.
+        y : np.array
+            Ground truth of proposals (i.e. accept or reject).
 
         Returns
         -------
         None
 
         """
-        self.inputs = inputs.astype(np.float32)
-        self.targets = reformat(targets)
-        self.transform = AugmentImages() if transform else None
+        self.x = x.astype(np.float32)
+        self.y = reformat(y)
 
     def __len__(self):
         """
         Computes number of examples in dataset.
 
         Parameters
         ----------
         None
 
         Returns
         -------
         int
             Number of examples in dataset.
 
         """
-        return len(self.targets)
+        return len(self.y)
 
     def __getitem__(self, idx):
         """
         Gets example (i.e. input and label) corresponding to "idx".
 
         Parameters
         ----------
         idx : int
             Index of example to be returned.
 
         Returns
         -------
         dict
             Example corresponding to "idx".
 
         """
-        if self.transform:
-            inputs = self.transform.run(self.inputs[idx])
-        else:
-            inputs = self.inputs[idx]
-        return {"inputs": inputs, "targets": self.targets[idx]}
+        return {"inputs": self.x[idx], "targets": self.y[idx]}
 
 
 class MultiModalDataset(Dataset):
     """
 
 
     """
 
-    def __init__(self, inputs, targets, transform=True):
+    def __init__(self, x, y, transform=True):
         """
         Constructs MultiModalDataset object.
 
         Parameters
         ----------
-        inputs : dict
+        x : dict
             Feature dictionary where each key-value is the type of feature and
             corresponding value. The keys of this dictionary are (1) "imgs"
             and (2) "features" which correspond to a (1) feature tensor
             containing image chunks and (2) feature vector.
-        targets : np.array
+        y : np.array
             Binary vector where each entry indicates whether an edge proposal
             should be added or omitted from a reconstruction.
         transform : bool, optional
             Indication of whether to apply augmentation to the image chunks.
             The default is True.
 
         Returns
         -------
         None
 
         """
-        self.img_inputs = inputs["imgs"].astype(np.float32)
-        self.feature_inputs = inputs["features"].astype(np.float32)
-        self.targets = reformat(targets)
+        self.x_imgs = x["imgs"].astype(np.float32)
+        self.x_features = x["features"].astype(np.float32)
+        self.y = reformat(y)
         self.transform = AugmentImages() if transform else None
 
     def __len__(self):
         """
         Computes number of examples in dataset.
 
         Parameters
         ----------
         None
 
         Returns
         -------
         int
             Number of examples in dataset.
 
         """
-        return len(self.targets)
+        return len(self.y)
 
     def __getitem__(self, idx):
         """
         Gets example (i.e. input and label) corresponding to "idx".
 
         Parameters
         ----------
         idx : int
             Index of example to be returned.
 
         Returns
         -------
         dict
             Example corresponding to "idx".
 
         """
         if self.transform:
-            img_inputs = self.transform.run(self.img_inputs[idx])
+            x_img = self.transform.run(self.x_imgs[idx])
         else:
-            img_inputs = self.img_inputs[idx]
-        inputs = [self.feature_inputs[idx], img_inputs]
-        return {"inputs": inputs, "targets": self.targets[idx]}
+            x_img = self.x_imgs[idx]
+        inputs = [self.x_features[idx], x_img]
+        return {"inputs": inputs, "targets": self.y[idx]}
 
@@ -297,3 +301,26 @@ def reformat(arr):
     """
     return np.expand_dims(arr, axis=1).astype(np.float32)
+
+
+def init_idxs(idxs):
+    """
+    Adds dictionary item called "edge_to_idx" which maps an edge in a
+    neurograph to an index that represents the edge's position in the feature
+    matrix.
+
+    Parameters
+    ----------
+    idxs : dict
+        Dictionary that maps indices to edges in some neurograph.
+
+    Returns
+    -------
+    dict
+        Updated dictionary.
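+
+    For example, if "idx_to_edge" contains the item 0: (1, 2), then the added
+    "edge_to_idx" contains the inverse item (1, 2): 0.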
+ + """ + idxs["edge_to_idx"] = dict() + for idx, edge in idxs["idx_to_edge"].items(): + idxs["edge_to_idx"][edge] = idx + return idxs diff --git a/src/deep_neurographs/machine_learning/graph_datasets.py b/src/deep_neurographs/machine_learning/graph_datasets.py index d6dbc8d..b7737a9 100644 --- a/src/deep_neurographs/machine_learning/graph_datasets.py +++ b/src/deep_neurographs/machine_learning/graph_datasets.py @@ -15,7 +15,7 @@ import torch from torch_geometric.data import Data as GraphData -from deep_neurographs.machine_learning import feature_generation +from deep_neurographs.machine_learning import feature_generation, datasets from deep_neurographs.utils import gnn_util @@ -41,14 +41,14 @@ def init(neurograph, features): """ # Extract features x_branches, _, idxs_branches = feature_generation.get_matrix( - neurograph, features["branch"], "GraphNeuralNet" + neurograph, features["branches"], "GraphNeuralNet" ) x_proposals, y_proposals, idxs_proposals = feature_generation.get_matrix( - neurograph, features["proposal"], "GraphNeuralNet" + neurograph, features["proposals"], "GraphNeuralNet" ) - # Initialize data - proposals = features["proposal"]["skel"].keys() + # Initialize dataset + proposals = list(features["proposals"]["skel"].keys()) graph_dataset = GraphDataset( neurograph, proposals, @@ -112,8 +112,8 @@ def __init__( # Set edges idxs_branches = shift_idxs(idxs_branches, x_proposals.shape[0]) - self.idxs_branches = init_idxs(idxs_branches) - self.idxs_proposals = init_idxs(idxs_proposals) + self.idxs_branches = datasets.init_idxs(idxs_branches) + self.idxs_proposals = datasets.init_idxs(idxs_proposals) self.proposals = proposals # Initialize data @@ -246,26 +246,3 @@ def shift_idxs(idxs, shift): shifted_idxs[key + shift] = value idxs["idx_to_edge"] = shifted_idxs return idxs - - -def init_idxs(idxs): - """ - Adds dictionary item called "edge_to_index" which maps an edge in a - neurograph to an that represents the edge's position in the feature - matrix. - - Parameters - ---------- - idxs : dict - Dictionary that maps indices to edges in some neurograph. - - Returns - ------- - dict - Updated dictionary. - - """ - idxs["edge_to_idx"] = dict() - for idx, edge in idxs["idx_to_edge"].items(): - idxs["edge_to_idx"][edge] = idx - return idxs diff --git a/src/deep_neurographs/machine_learning/heterograph_datasets.py b/src/deep_neurographs/machine_learning/heterograph_datasets.py index 6eedc71..67499af 100644 --- a/src/deep_neurographs/machine_learning/heterograph_datasets.py +++ b/src/deep_neurographs/machine_learning/heterograph_datasets.py @@ -17,14 +17,14 @@ import torch from torch_geometric.data import HeteroData as HeteroGraphData -from deep_neurographs.machine_learning import feature_generation +from deep_neurographs.machine_learning import feature_generation, datasets from deep_neurographs.utils import gnn_util DTYPE = torch.float32 # Wrapper -def init(neurograph, computation_graph, features): +def init(neurograph, features, computation_graph): """ Initializes a dataset that can be used to train a graph neural network. @@ -32,11 +32,11 @@ def init(neurograph, computation_graph, features): ---------- neurograph : NeuroGraph Graph that dataset is built from. + features : dict + Dictionary that contains different types of feature vectors for nodes, + edges, and proposals. computation_graph : networkx.Graph Graph used by gnn to classify proposals. - features : dict - Dictionary that contains different types of feature vectors for - nodes, edges, and proposals. 
Returns ------- @@ -53,7 +53,7 @@ def init(neurograph, computation_graph, features): ) x_nodes = feature_generation.combine_features(features["nodes"]) - # Initialize data + # Initialize dataset proposals = list(features["proposals"]["skel"].keys()) heterograph_dataset = HeteroGraphDataset( computation_graph, @@ -116,8 +116,8 @@ def __init__( """ # Conversion idxs - self.idxs_branches = init_idxs(idxs_branches) - self.idxs_proposals = init_idxs(idxs_proposals) + self.idxs_branches = datasets.init_idxs(idxs_branches) + self.idxs_proposals = datasets.init_idxs(idxs_proposals) self.computation_graph = computation_graph self.proposals = proposals @@ -369,29 +369,6 @@ def set_hetero_edge_attrs(self, x_nodes, edge_type, idx_map_1, idx_map_2): # -- util -- -def init_idxs(idxs): - """ - Adds dictionary item called "edge_to_index" which maps an edge in a - neurograph to an that represents the edge's position in the feature - matrix. - - Parameters - ---------- - idxs : dict - Dictionary that maps indices to edges in some neurograph. - - Returns - ------- - dict - Updated dictionary. - - """ - idxs["edge_to_idx"] = dict() - for idx, edge in idxs["idx_to_edge"].items(): - idxs["edge_to_idx"][edge] = idx - return idxs - - def node_intersection(idx_map, e1, e2): """ Computes the common node between "e1" and "e2" in the case where these diff --git a/src/deep_neurographs/machine_learning/inference.py b/src/deep_neurographs/machine_learning/inference.py index b4e6132..f130f1a 100644 --- a/src/deep_neurographs/machine_learning/inference.py +++ b/src/deep_neurographs/machine_learning/inference.py @@ -81,8 +81,8 @@ def __init__( # Load image and model driver = "n5" if ".n5" in img_path else "zarr" - self.img = img_util.open(img_path, driver) - self.model = ml_util.load_model(model_type, model_path) + self.img = img_util.open_tensorstore(img_path, driver=driver) + self.model = ml_util.load_model(model_path) def run(self, neurograph, proposals): """ @@ -195,7 +195,7 @@ def run_model(self, dataset): Parameters ---------- - dataset : ... + data : ... Dataset on which the model inference is to be run. Returns @@ -210,25 +210,23 @@ def run_model(self, dataset): if self.is_gnn: preds = run_gnn_model(dataset.data, self.model, self.model_type) elif "Net" in self.model_type: - preds = run_nn_model(dataset, self.model) + preds = run_nn_model(dataset.data, self.model) else: - data = dataset["dataset"]["inputs"] - preds = np.array(self.model.predict_proba(data)[:, 1]) + preds = np.array(self.model.predict_proba(dataset.data.x)[:, 1]) # Filter preds - idxs = get_idxs(dataset, self.model_type) + idxs = dataset.idxs_proposals["idx_to_edge"] return {idxs[i]: p for i, p in enumerate(preds) if p > self.threshold} # --- run machine learning model --- -def run_nn_model(dataset, model): +def run_nn_model(data, model): hat_y = [] model.eval() with torch.no_grad(): - for batch in DataLoader(dataset["dataset"], batch_size=32): + for batch in DataLoader(data, batch_size=32): # Run model - x_i = batch["inputs"] - hat_y_i = sigmoid(model(x_i)) + hat_y_i = sigmoid(model(batch["inputs"])) # Postprocess hat_y_i = np.array(hat_y_i) @@ -358,28 +356,3 @@ def filter_proposals(graph, proposals): accepts.append((i, j)) graph.remove_edges_from(accepts) return accepts - - -# --- util --- -def get_idxs(dataset, model_type): - """ - Gets dictionary from "dataset" that maps indices (from feature matrix) to - proposal ids. - - Parameters - ---------- - dataset : ProposalDataset - Dataset that contains features generated from proposals. 
- model_type : str - Type of model used to perform inference. - - Returns - ------- - dict - Dictionary that maps indices (from feature matrix) to proposal ids. - - """ - if "Graph" in model_type: - return dataset.idxs_proposals["idx_to_edge"] - else: - return dataset["idx_to_edge"] diff --git a/src/deep_neurographs/machine_learning/trainer.py b/src/deep_neurographs/machine_learning/trainer.py index d1c179a..954f125 100644 --- a/src/deep_neurographs/machine_learning/trainer.py +++ b/src/deep_neurographs/machine_learning/trainer.py @@ -37,9 +37,7 @@ def fit_model(model, dataset): - inputs = dataset["dataset"]["inputs"] - targets = dataset["dataset"]["targets"] - model.fit(inputs, targets) + model.fit(dataset.data.x, dataset.data.y) return model @@ -76,8 +74,7 @@ def fit_deep_model( ... """ # Load data - dataset = dataset["dataset"] - train_set, valid_set = random_split(dataset) + train_set, valid_set = random_split(dataset.data) train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True) valid_loader = DataLoader(valid_set, batch_size=batch_size) @@ -99,7 +96,6 @@ def fit_deep_model( pylightning_trainer.fit(lit_model, train_loader, valid_loader) # Return best model - print(ckpt_callback.best_model_path) ckpt = torch.load(ckpt_callback.best_model_path) lit_model.model.load_state_dict(ckpt["state_dict"]) return lit_model.model diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py index a5f9ab3..e79e3c5 100644 --- a/src/deep_neurographs/utils/graph_util.py +++ b/src/deep_neurographs/utils/graph_util.py @@ -35,11 +35,9 @@ def get_irreducibles( swc_dict, min_size, bbox=None, - prune_connectors=False, - connector_length=8, - prune_depth=16, - trim_depth=0, + prune_depth=16.0, smooth=True, + trim_depth=0.0, ): """ Gets irreducible components of the graph stored in "swc_dict" by building @@ -51,18 +49,19 @@ def get_irreducibles( ---------- swc_dict : dict Contents of an swc file. + min_size : float + Minimum cardinality of swc files that are stored in NeuroGraph. bbox : dict, optional - ... - prune_connectors : bool, optional - Indication of whether to prune short paths connecting branches. - The default is False. + Dictionary with the keys "min" and "max" which specify a bounding box + in the image. The default is None. prune_depth : float, optional Path length microns that determines whether a branch is short and - should be pruned. The default is 16. - trim_depth : float, optional - Depth in microns to trim branch. The default is 0. + should be pruned. The default is 16.0. smooth : bool, optional Indication of whether to smooth each branch. The default is True. + trim_depth : float, optional + Maximum path length (in microns) to trim from "branch". The default is + 0.0. 
Returns
    -------
    list
        List of irreducibles stored in a dictionary where key-values are type
        of irreducible (i.e. leaf, junction, or edge) and corresponding set of
        all irreducibles from the graph of that type.

    """
    # Build dense graph
    swc_dict["idx"] = dict(zip(swc_dict["id"], range(len(swc_dict["id"]))))
    graph, _ = swc_util.to_graph(swc_dict, set_attrs=True)
    graph = clip_branches(graph, bbox)
-    graph, n_nodes_trimmed = prune_branches(
-        graph,
-        connector_length=connector_length,
-        prune_connectors=prune_connectors,
-        prune_depth=prune_depth,
-        trim_depth=trim_depth,
-    )
+    graph, n_trimmed = prune_trim_branches(graph, prune_depth, trim_depth)

    # Extract irreducibles
    irreducibles = []
    for node_subset in nx.connected_components(graph):
-        if len(node_subset) + n_nodes_trimmed > min_size:
+        if len(node_subset) + n_trimmed > min_size:
            subgraph = graph.subgraph(node_subset)
            irreducibles_i = __get_irreducibles(subgraph, swc_dict, smooth)
            if irreducibles_i:
@@ -104,7 +97,8 @@ def clip_branches(graph, bbox):
    graph : networkx.Graph
        Graph to be searched
    bbox : dict
-        Bounding box.
+        Dictionary with the keys "min" and "max" which specify a bounding box
+        in the image. The default is None.

    Returns
    -------
@@ -122,50 +116,35 @@ def clip_branches(graph, bbox):
    return graph


-def prune_branches(
-    graph,
-    connector_length=8,
-    prune_connectors=False,
-    prune_depth=16,
-    trim_depth=0,
-):
+def prune_trim_branches(graph, prune_depth, trim_depth):
    """
-    Prunes spurious branches and short paths connecting branches
-    (i.e. possible merge mistakes).
+    Prunes all short branches from "graph" and trims branches if applicable. A
+    short branch is a path between a leaf and junction node with a path length
+    smaller than "prune_depth".

    Parameters
    ----------
    graph : networkx.Graph
-        Graph to be pruned.
-    connector_length : float, optional
-        ...
-    prune_connectors : bool, optional
-        Indication of whether to prune short paths connecting branches.
-        The default is False.
-    prune_depth : float, optional
-        Path length microns that determines whether a branch is short and
-        should be pruned. The default is 16.
-    trim_depth : float, optional
-        Depth in microns to trim branch. The default is 0.
+        Graph to be pruned and trimmed.
+    prune_depth : float
+        Path length (in microns) that determines whether a branch is short.
+        The default is 16.0.
+    trim_depth : float
+        Maximum path length (in microns) to trim from "branch".

    Returns
    -------
    networkx.Graph
-        Pruned graph.
+        Graph with branches trimmed and short branches pruned.
    int
        Number of nodes that were trimmed.

    """
-    # Prune/Trim branches
-    assert prune_depth > 0 if prune_connectors else True, "prune_depth == 0"
-    if prune_depth > 0:
-        graph, n_nodes_trimmed = prune_trim_branches(
-            graph, prune_depth, trim_depth
-        )
-
-    # Prune connectors
-    if prune_connectors:
-        graph = prune_short_connectors(graph, connector_length)
+    remove_nodes = []
+    n_nodes_trimmed = 0
+    for leaf in get_leafs(graph):
+        nodes, n_trimmed = prune_trim(graph, leaf, prune_depth, trim_depth)
+        remove_nodes.extend(nodes)
+        n_nodes_trimmed += n_trimmed
+    graph.remove_nodes_from(remove_nodes)
    return graph, n_nodes_trimmed


def __get_irreducibles(graph, swc_dict, smooth):
    """
    Parameters
    ----------
    graph : networkx.Graph
        Graph to be searched.
    swc_dict : dict
-        Dictionary that was used to build "graph".
+        Dictionary used to build "graph".
    smooth : bool
-        Indication of whether to smooth irreducible edges.
+        Indication of whether to smooth edges.

    Returns
    -------
@@ -259,40 +238,11 @@ def get_irreducible_nodes(graph):


# --- Refine graph ---
-def prune_trim_branches(graph, depth, trim_depth):
-    """
-    Prunes all short branches from "graph" and trims branchs if applicable.
A
-    short branch is a path between a leaf and junction node with a path length
-    smaller than depth.
-
-    Parameters
-    ----------
-    graph : networkx.Graph
-        Graph to be searched.
-    depth : int
-        Path length that determines whether a branch is short.
-
-    Returns
-    -------
-    networkx.Graph
-        Graph with short branches pruned.
-
-    """
-    remove_nodes = []
-    n_nodes_trimmed = 0
-    for leaf in get_leafs(graph):
-        nodes, n_trimmed = inspect_branch(graph, leaf, depth, trim_depth)
-        remove_nodes.extend(nodes)
-        n_nodes_trimmed += n_trimmed
-    graph.remove_nodes_from(remove_nodes)
-    return graph, n_nodes_trimmed
-
-
-def inspect_branch(graph, leaf, depth, trim_depth):
+def prune_trim(graph, leaf, prune_depth, trim_depth):
    """
    Determines whether the branch emanating from "leaf" should be pruned and
    returns nodes that should be pruned. If applicable (i.e. trim_depth > 0),
-    trims the branch by "trim_depth" microns.
+    every branch is trimmed by "trim_depth" microns.

    Parameters
    ----------
    graph : networkx.Graph
        Graph to be searched.
    leaf : int
        Leaf node being inspected to determine whether it is the endpoint of
        a short branch that should be pruned.
-    depth : int
+    prune_depth : float
        Path length (in microns) that determines whether a branch is short.
    trim_depth : float
-        Depth in microns to trim branch.
+        Maximum path length (in microns) to trim from "branch".

    Returns
    -------
    list[int]
        List of nodes to be pruned.

    """
    # Check whether to prune
-    path = [leaf]
+    branch = [leaf]
    node_spacing = []
-    for (i, j) in nx.dfs_edges(graph, source=leaf, depth_limit=depth):
+    for (i, j) in nx.dfs_edges(graph, source=leaf, depth_limit=prune_depth):
        node_spacing.append(compute_dist(graph, i, j))
        if graph.degree(j) > 2:
-            return path, 0
+            return branch, 0
        elif graph.degree(j) == 2:
-            path.append(j)
-        elif np.sum(node_spacing) > depth:
+            branch.append(j)
+        elif np.sum(node_spacing) > prune_depth:
            break

    # Check whether to trim
    spacing = np.mean(node_spacing)
    if trim_depth > 0 and graph.number_of_nodes() > 3 * trim_depth / spacing:
-        trim_nodes = trim_branch(graph, path, trim_depth)
+        trim_nodes = trim_branch(graph, branch, trim_depth)
        return trim_nodes, len(trim_nodes)
    else:
        return [], 0


-def trim_branch(graph, path, trim_depth):
-    branch_length = 0
-    for i in range(1, len(path)):
-        xyz_1 = graph.nodes[path[i - 1]]["xyz"]
-        xyz_2 = graph.nodes[path[i]]["xyz"]
-        branch_length += geometry.dist(xyz_1, xyz_2)
-        if branch_length > trim_depth:
+def trim_branch(graph, branch, trim_depth):
+    """
+    Trims a branch of a graph based on a specified depth.
+
+    Parameters
+    ----------
+    graph : networkx.Graph
+        Graph to be searched.
+    branch : list[int]
+        List of nodes representing a path through the graph to be trimmed.
+    trim_depth : float
+        Maximum path length (in microns) to trim from "branch".
+
+    Returns
+    -------
+    list[int]
+        Trimmed branch.
+
+    """
+    path_length = 0
+    for i in range(1, len(branch)):
+        xyz_1 = graph.nodes[branch[i - 1]]["xyz"]
+        xyz_2 = graph.nodes[branch[i]]["xyz"]
+        path_length += geometry.dist(xyz_1, xyz_2)
+        if path_length > trim_depth:
            break
-    return path[0:i]
+    return branch[0:i]


def compute_dist(graph, i, j):
    """
    Computes Euclidean distance between nodes i and j.

    Parameters
    ----------
    graph : networkx.Graph
        Graph containing nodes i and j.
    i : int
        Node.
    j : int
        Node.

    Returns
    -------
    float
        Euclidean distance between nodes i and j.

    """
    return geometry.dist(graph.nodes[i]["xyz"], graph.nodes[j]["xyz"])


-def prune_short_connectors(graph, length=8):
-    """ "
-    Prunes shorts paths (i.e. connectors) between junctions nodes and the nbhd
-    about the junctions.
-
-    Parameters
-    ----------
-    graph : netowrkx.Graph
-        Graph to be inspected.
- length : int, optional - Upper bound on the distance that defines a connector path to be - pruned. The default is 8. - - Returns - ------- - list[tuple] - Graph with connectors pruned. - list[np.ndarray] - List of xyz coordinates of centroids of connectors. - - """ - junctions = [j for j in graph.nodes if graph.degree[j] > 2] - pruned_centroids = [] - pruned_nodes = set() - cnt = 0 - while len(junctions): - # Search nbhd - j = junctions.pop() - junction_nbs = [] - for _, i in nx.dfs_edges(graph, source=j, depth_limit=length): - if i in junctions: - junction_nbs.append(i) - - # Store nodes to be pruned - for nb in junction_nbs: - connector = list(nx.shortest_path(graph, source=j, target=nb)) - nbhd = set(nx.dfs_tree(graph, source=nb, depth_limit=5)) - centroid = connector[len(connector) // 2] - if not ignore_connector(graph, centroid, 16 + length // 2): - pruned_nodes.update(nbhd.union(set(connector))) - pruned_centroids.append(graph.nodes[centroid]["xyz"]) - - if len(junction_nbs) > 0: - nbhd = set(nx.dfs_tree(graph, source=j, depth_limit=5)) - pruned_nodes.update(nbhd) - cnt += 1 - - # Finish - graph.remove_nodes_from(list(pruned_nodes)) - return graph - - -def ignore_connector(graph, root, depth): - """ - Determines whether the connector is in a region with lots of branching. - - Parameters - ---------- - graph : networkx.Graph - Graph to be searched. - root : int - Midpoint of connector. - - Returns - ------- - bool - Indication of whether connector is in a region with lots of branching. - - """ - n_branching_points = 0 - for i in nx.dfs_tree(graph, source=root, depth_limit=depth): - if graph.degree[i] > 2: - n_branching_points += 1 - return True if n_branching_points > 2 else False - - def __smooth_branch(swc_dict, attrs, edges, nbs, root, j): """ Smoothes a branch then updates "swc_dict" and "edges" with the new xyz diff --git a/src/deep_neurographs/utils/img_util.py b/src/deep_neurographs/utils/img_util.py index d336e20..39fa48d 100644 --- a/src/deep_neurographs/utils/img_util.py +++ b/src/deep_neurographs/utils/img_util.py @@ -89,7 +89,7 @@ def read(img, voxel, shape, from_center=True): """ start, end = get_start_end(voxel, shape, from_center=from_center) return deepcopy( - img[start[0] : end[0], start[1] : end[1], start[2] : end[2]] + img[start[0]: end[0], start[1]: end[1], start[2]: end[2]] ) diff --git a/src/deep_neurographs/utils/ml_util.py b/src/deep_neurographs/utils/ml_util.py index 6001499..b78bb65 100644 --- a/src/deep_neurographs/utils/ml_util.py +++ b/src/deep_neurographs/utils/ml_util.py @@ -17,13 +17,10 @@ from deep_neurographs.machine_learning import ( feature_generation, + datasets, graph_datasets, heterograph_datasets, ) -from deep_neurographs.machine_learning.datasets import ( - MultiModalDataset, - ProposalDataset, -) from deep_neurographs.machine_learning.models import ( FeedForwardNet, MultiModalNet, @@ -67,25 +64,21 @@ def init_model(model_type): return MultiModalNet(n_features) -def load_model(model_type, path): +def load_model(path): """ Loads the parameters of a machine learning model. Parameters ---------- - model_type : str - Type of machine learning model. path : str Path to the model parameters. Returns ------- ... + """ - if model_type in ["AdaBoost", "RandomForest"]: - return joblib.load(path) - else: - return torch.load(path) + return joblib.load(path) if ".joblib" in path else torch.load(path) # --- dataset utils --- @@ -120,32 +113,19 @@ def init_dataset( if "Hetero" in model_type: assert computation_graph, "Must provide computation graph!" 
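        # Heterogeneous GNN datasets are indexed against the computation
        # graph (branch and proposal nodes), so it must be provided here.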
dataset = heterograph_datasets.init( - neurograph, computation_graph, features + neurograph, features, computation_graph ) elif "Graph" in model_type: dataset = graph_datasets.init(neurograph, features) else: - dataset = init_proposal_dataset( + dataset = datasets.init( neurograph, features, model_type, sample_ids=sample_ids ) return dataset -def init_proposal_dataset(neurographs, features, model_type, sample_ids=None): - # Extract features - inputs, targets, idx_transforms = feature_generation.get_matrix( - neurographs, features["proposals"], model_type, sample_ids=sample_ids - ) - dataset = { - "dataset": get_dataset(inputs, targets, model_type), - "block_to_idxs": idx_transforms["block_to_idxs"], - "idx_to_edge": idx_transforms["idx_to_edge"], - } - return dataset - - +""" def get_dataset(inputs, targets, model_type): - """ Gets classification model to be fit. Parameters @@ -164,7 +144,6 @@ def get_dataset(inputs, targets, model_type): ------- ... - """ if model_type == "FeedForwardNet": dataset = ProposalDataset(inputs, targets) elif model_type == "MultiModalNet": @@ -172,6 +151,7 @@ def get_dataset(inputs, targets, model_type): else: dataset = {"inputs": inputs, "targets": targets} return dataset +""" # --- miscellaneous --- @@ -226,5 +206,5 @@ def get_kfolds(filenames, k): def get_batches(my_list, batch_size): batches = list() for start in range(0, len(my_list), batch_size): - batches.append(my_list[start : min(start + batch_size, len(my_list))]) + batches.append(my_list[start: min(start + batch_size, len(my_list))]) return batches From 7b844e06a2023aacbda9ae7b3d40113a907b6e98 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Sat, 14 Sep 2024 20:54:53 +0000 Subject: [PATCH 2/4] simplifications --- src/deep_neurographs/generate_proposals.py | 1 - src/deep_neurographs/intake.py | 87 +++++++++++++++------- src/deep_neurographs/neurograph.py | 61 +++++++-------- src/deep_neurographs/utils/graph_util.py | 5 +- src/deep_neurographs/utils/img_util.py | 11 +-- src/deep_neurographs/utils/swc_util.py | 32 ++------ src/deep_neurographs/utils/util.py | 28 ++++--- src/deep_neurographs/visualization.py | 8 +- 8 files changed, 124 insertions(+), 109 deletions(-) diff --git a/src/deep_neurographs/generate_proposals.py b/src/deep_neurographs/generate_proposals.py index fbacec9..fd7e532 100644 --- a/src/deep_neurographs/generate_proposals.py +++ b/src/deep_neurographs/generate_proposals.py @@ -297,7 +297,6 @@ def run_trimming(neurograph, proposals, radius): elif neurograph.dist(i, j) > radius: neurograph.remove_proposal(proposal) n_endpoints_trimmed += 1 if trim_bool else 0 - print("# Endpoints Trimmed:", n_endpoints_trimmed) return neurograph diff --git a/src/deep_neurographs/intake.py b/src/deep_neurographs/intake.py index fcf08ba..f576dd0 100644 --- a/src/deep_neurographs/intake.py +++ b/src/deep_neurographs/intake.py @@ -29,12 +29,71 @@ TRIM_DEPTH = 0 +class GraphBuilder: + """ + Class that is used to build an instance of FragmentsGraph. + + """ + def __init__( + self, + anisotropy=[1.0, 1.0, 1.0], + img_patch_origin=None, + img_patch_shape=None, + min_size=MIN_SIZE, + node_spacing=NODE_SPACING, + progress_bar=False, + prune_depth=PRUNE_DEPTH, + smooth=SMOOTH, + trim_depth=TRIM_DEPTH, + ): + """ + Builds a FragmentsGraph by reading swc files stored either on the + cloud or local machine. + + Parameters + ---------- + swc_pointer : dict, list, str + Pointer to swc files used to build an instance of FragmentsGraph, + see "swc_util.Reader" for further documentation. 
+ anisotropy : list[float], optional + Scaling factors applied to xyz coordinates to account for + anisotropy of microscope. The default is [1.0, 1.0, 1.0]. + image_patch_origin : list[float], optional + An xyz coordinate which is the upper, left, front corner of the + image patch that contains the swc files. The default is None. + image_patch_shape : list[float], optional + Shape of the image patch which contains the swc files. The default + is None. + min_size : int, optional + Minimum cardinality of swc files that are stored in NeuroGraph. The + default is the global variable "MIN_SIZE". + node_spacing : int, optional + Spacing (in microns) between nodes. The default is the global + variable "NODE_SPACING". + progress_bar : bool, optional + Indication of whether to print out a progress bar during build. + The default is False. + prune_depth : int, optional + Branches less than "prune_depth" microns are pruned if "prune" is + True. The default is the global variable "PRUNE_DEPTH". + smooth : bool, optional + Indication of whether to smooth branches from swc files. The + default is the global variable "SMOOTH". + + Returns + ------- + NeuroGraph + Neurograph generated from swc files. + + """ + pass + + # --- Build graph wrappers --- def build_neurograph_from_local( anisotropy=[1.0, 1.0, 1.0], img_patch_origin=None, img_patch_shape=None, - img_path=None, min_size=MIN_SIZE, node_spacing=NODE_SPACING, progress_bar=False, @@ -58,9 +117,6 @@ def build_neurograph_from_local( image_patch_shape : list[float], optional The xyz dimensions of the bounding box which contains the swc files. The default is None. - img_path : str, optional - Path to image which is assumed to be stored in a Google Bucket. The - default is None. min_size : int, optional Minimum cardinality of swc files that are stored in NeuroGraph. The default is the global variable "MIN_SIZE". @@ -91,23 +147,12 @@ def build_neurograph_from_local( assert swc_dir or swc_paths, "Provide swc_dir or swc_paths!" img_bbox = img_util.get_bbox(img_patch_origin, img_patch_shape) paths = util.list_paths(swc_dir, ext=".swc") if swc_dir else swc_paths - swc_dicts, paths = process_local_paths( - paths, anisotropy=anisotropy, min_size=min_size, img_bbox=img_bbox - ) - - # Filter swc_dicts - if img_bbox: - filtered_swc_dicts = [] - for swc_dict in swc_dicts: - if util.is_list_contained(img_bbox, swc_dict["xyz"]): - filtered_swc_dicts.append(swc_dict) - swc_dicts = filtered_swc_dicts + swc_dicts, paths = process_local_paths(paths, anisotropy, min_size) # Build neurograph neurograph = build_neurograph( swc_dicts, img_bbox=img_bbox, - img_path=img_path, min_size=min_size, node_spacing=node_spacing, progress_bar=progress_bar, @@ -123,7 +168,6 @@ def build_neurograph_from_gcs_zips( bucket_name, gcs_path, anisotropy=[1.0, 1.0, 1.0], - img_path=None, min_size=MIN_SIZE, node_spacing=NODE_SPACING, prune_depth=PRUNE_DEPTH, @@ -142,9 +186,6 @@ def build_neurograph_from_gcs_zips( anisotropy : list[float], optional Scaling factors applied to xyz coordinates to account for anisotropy of microscope. The default is [1.0, 1.0, 1.0]. - img_path : str, optional - Path to image stored GCS Bucket that swc files were generated from. - The default is None. min_size : int, optional Minimum cardinality of swc files that are stored in NeuroGraph. The default is the global variable "MIN_SIZE". 
@@ -169,7 +210,6 @@ def build_neurograph_from_gcs_zips(
     swc_dicts = download_gcs_zips(bucket_name, gcs_path, min_size, anisotropy)
     neurograph = build_neurograph(
         swc_dicts,
-        img_path=img_path,
         min_size=min_size,
         node_spacing=node_spacing,
         prune_depth=prune_depth,
@@ -236,7 +276,6 @@ def build_neurograph(
     swc_dicts,
     img_bbox=None,
-    img_path=None,
     min_size=MIN_SIZE,
     node_spacing=NODE_SPACING,
     swc_paths=None,
@@ -265,9 +304,7 @@ def build_neurograph(
         print("# nodes:", util.reformat_number(n_nodes))
         print("# edges:", util.reformat_number(n_edges))
 
-    neurograph = NeuroGraph(
-        img_path=img_path, node_spacing=node_spacing, swc_paths=swc_paths
-    )
+    neurograph = NeuroGraph(node_spacing=node_spacing)
     while len(irreducibles):
         irreducible_set = irreducibles.pop()
         neurograph.add_component(irreducible_set)
diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py
index 796e2cb..b0afedd 100644
--- a/src/deep_neurographs/neurograph.py
+++ b/src/deep_neurographs/neurograph.py
@@ -15,20 +15,15 @@
 import networkx as nx
 import numpy as np
-import tensorstore as ts
 from scipy.spatial import KDTree
 
 from deep_neurographs import generate_proposals, geometry
 from deep_neurographs.geometry import dist as get_dist
 from deep_neurographs.geometry import get_midpoint
-from deep_neurographs.machine_learning.groundtruth_generation import (
-    init_targets,
-)
+from deep_neurographs.machine_learning.groundtruth_generation import init_targets
 from deep_neurographs.utils import graph_util as gutil
 from deep_neurographs.utils import img_util, swc_util, util
 
-SUPPORTED_LABEL_MASK_TYPES = [dict, np.array, ts.TensorStore]
-
 
 class NeuroGraph(nx.Graph):
     """
@@ -37,38 +32,39 @@ class NeuroGraph(nx.Graph):
 
     """
 
-    def __init__(
-        self,
-        img_bbox=None,
-        swc_paths=None,
-        img_path=None,
-        label_mask=None,
-        node_spacing=1,
-        train_model=False,
-    ):
+    def __init__(self, img_bbox=None, node_spacing=1):
+        """
+        Initializes an instance of NeuroGraph.
+
+        Parameters
+        ----------
+        img_bbox : dict or None, optional
+            Dictionary with the keys "min" and "max" which specify a bounding
+            box in an image. The default is None.
+        node_spacing : int, optional
+            Spacing (in microns) between nodes. The default is 1.
+
+        Returns
+        -------
+        None
+
+        """
         super(NeuroGraph, self).__init__()
-        # Initialize paths
-        self.img_path = img_path
-        self.label_mask = label_mask
-        self.swc_paths = swc_paths
+        # General class attributes
+        self.node_spacing = node_spacing
+        self.merged_ids = set()
+        self.soma_ids = dict()
         self.swc_ids = set()
+        self.xyz_to_edge = dict()
 
-        # Initialize node and edge sets
+        # Nodes and Edges
         self.leafs = set()
         self.junctions = set()
         self.proposals = set()
         self.target_edges = set()
         self.node_cnt = 0
-        self.node_spacing = node_spacing
-        self.soma_ids = dict()
 
-        # Initialize data structures for proposals
-        self.xyz_to_edge = dict()
-        self.kdtree = None
-        self.leaf_kdtree = None
-        self.merged_ids = set()
-
-        # Initialize bounding box (if exists)
+        # Bounding box (if applicable)
         self.bbox = img_bbox
         if self.bbox:
             self.origin = img_bbox["min"].astype(int)
@@ -279,8 +275,6 @@ class NeuroGraph(nx.Graph):
         complex_bool=False,
         long_range_bool=False,
         proposals_per_leaf=3,
-        optimize=False,
-        optimization_depth=10,
         return_trimmed_proposals=False,
         trim_endpoints_bool=False,
     ):
@@ -300,11 +294,6 @@ class NeuroGraph(nx.Graph):
         proposals_per_leaf : int, optional
             Maximum number of proposals generated for each leaf. The default
             is 3.
-        optimize : bool, optional
-            Indication of whether to optimize proposal alignment to image. The
-            default is False.
-        optimization_depth : int, optional
-            Depth to check during optimization. The default is 10.
         return_trimmed_proposals : bool, optional
            Indication of whether to return trimmed proposal ids. The default
            is False.
diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py
index e79e3c5..d5d2961 100644
--- a/src/deep_neurographs/utils/graph_util.py
+++ b/src/deep_neurographs/utils/graph_util.py
@@ -28,7 +28,7 @@
 import numpy as np
 
 from deep_neurographs import geometry
-from deep_neurographs.utils import swc_util, util
+from deep_neurographs.utils import img_util, swc_util, util
 
 
 def get_irreducibles(
@@ -109,7 +109,7 @@ def clip_branches(graph, bbox):
     if bbox:
         delete_nodes = set()
         for i in graph.nodes:
-            xyz = util.to_voxels(graph.nodes[i]["xyz"])
+            xyz = img_util.to_voxels(graph.nodes[i]["xyz"])
             if not util.is_contained(bbox, xyz):
                 delete_nodes.add(i)
         graph.remove_nodes_from(delete_nodes)
@@ -303,6 +303,7 @@ def trim_branch(graph, branch, trim_depth):
         Trimmed branch.
 
     """
+    i = 1
     path_length = 0
     for i in range(1, len(branch)):
         xyz_1 = graph.nodes[branch[i - 1]]["xyz"]
diff --git a/src/deep_neurographs/utils/img_util.py b/src/deep_neurographs/utils/img_util.py
index 39fa48d..94f96e0 100644
--- a/src/deep_neurographs/utils/img_util.py
+++ b/src/deep_neurographs/utils/img_util.py
@@ -15,7 +15,7 @@
 import tensorstore as ts
 from skimage.color import label2rgb
 
-ANISOTROPY = np.array([0.748, 0.748, 1.0])
+ANISOTROPY = [0.748, 0.748, 1.0]
 SUPPORTED_DRIVERS = ["neuroglancer_precomputed", "n5", "zarr"]
 
@@ -330,7 +330,7 @@ def patch_to_img(voxel, patch_centroid, patch_dims):
     return np.round(voxel + patch_centroid - half_patch_dims).astype(int)
 
 
-def to_world(voxel, shift=[0, 0, 0]):
+def to_world(voxel, anisotropy=ANISOTROPY, shift=[0, 0, 0]):
     """
     Converts coordinates from voxels to world.
 
     Parameters
     ----------
...
     Returns
     -------
     tuple
         Converted coordinates.
 
     """
-    return tuple([voxel[i] * ANISOTROPY[i] - shift[i] for i in range(3)])
+    return tuple([voxel[i] * anisotropy[i] - shift[i] for i in range(3)])
 
 
 def to_voxels(xyz, anisotropy=ANISOTROPY, downsample_factor=0):
     """
...
     Returns
     -------
     numpy.ndarray
         Coordinates converted to voxels.
 
     """
-    downsample_factor = 1 / 2 ** downsample_factor
-    return (downsample_factor * (xyz / np.array(anisotropy))).astype(int)
+    downsample_factor = 1.0 / 2 ** downsample_factor
+    voxel = downsample_factor * (xyz / np.array(anisotropy))
+    return np.round(voxel).astype(int)
 
 
 # -- utils --
diff --git a/src/deep_neurographs/utils/swc_util.py b/src/deep_neurographs/utils/swc_util.py
index f19e49e..be283c3 100644
--- a/src/deep_neurographs/utils/swc_util.py
+++ b/src/deep_neurographs/utils/swc_util.py
@@ -16,14 +16,11 @@
 import networkx as nx
 import numpy as np
 
-from deep_neurographs import geometry
 from deep_neurographs.utils import util
 
 
 # -- io util --
-def process_local_paths(
-    paths, anisotropy=[1.0, 1.0, 1.0], min_size=5, img_bbox=None
-):
+def process_local_paths(paths, anisotropy=[1.0, 1.0, 1.0], min_size=5):
     """
     Iterates over a list of paths to swc files, then builds a dictionary
     where the keys are swc attributes (i.e. id, xyz, radius, pid) and values
 
     Parameters
     ----------
     paths : list[str]
         List of paths to swc files to be parsed.
+    anisotropy : list[float], optional
+        Scaling factors applied to xyz coordinates to account for anisotropy
+        of microscope. The default is [1.0, 1.0, 1.0].
     min_size : int, optional
         Threshold on the number of nodes contained in an swc file. Only swc
         files with more than "min_size" nodes are stored in "swc_dicts". The
         default is 5.
-    img_bbox : dict, optional
-        Dictionary with the keys "min" and "max" which specify a bounding box
-        in an image. Only swc files with at least one node contained in
-        "img_bbox" are stored in "swc_dicts". The default is None.
 
     Returns
     -------
@@ -87,13 +83,12 @@ def parse_gcs_zip(zip_file, path, anisotropy=[1.0, 1.0, 1.0], min_size=0):
     # Parse contents
     contents = read_from_gcs_zip(zip_file, path)
     if len(contents) > min_size:
-        swc_dict = parse(contents, anisotropy=anisotropy)
+        swc_dict = parse(contents, anisotropy)
     else:
         swc_dict = {"id": []}
 
     # Store id
-    swc_id = util.get_swc_id(path)
-    swc_dict["swc_id"] = swc_id
+    swc_dict["swc_id"] = util.get_swc_id(path)
     return swc_dict
 
@@ -127,9 +122,7 @@ def parse(contents, anisotropy=[1.0, 1.0, 1.0]):
         swc_dict["id"][i] = parts[0]
         swc_dict["radius"][i] = float(parts[-2])
         swc_dict["pid"][i] = parts[-1]
-        swc_dict["xyz"][i] = read_xyz(
-            parts[2:5], anisotropy=anisotropy, offset=offset
-        )
+        swc_dict["xyz"][i] = read_xyz(parts[2:5], anisotropy, offset)
 
     # Check whether radius is in nanometers
     if swc_dict["radius"][0] > 100:
@@ -137,10 +130,6 @@ def parse(contents, anisotropy=[1.0, 1.0, 1.0]):
     return swc_dict
 
 
-def reindex(arr, idxs):
-    return arr[idxs]
-
-
 def get_contents(swc_contents):
     offset = [0, 0, 0]
     for i, line in enumerate(swc_contents):
@@ -437,10 +426,3 @@ def __add_attributes(swc_dict, graph):
     }
     nx.set_node_attributes(graph, attrs)
     return graph
-
-
-# -- miscellaneous --
-def upd_edge(xyz, idxs):
-    idxs = np.array(idxs)
-    xyz[idxs] = geometry.smooth_branch(xyz[idxs], s=10)
-    return xyz
diff --git a/src/deep_neurographs/utils/util.py b/src/deep_neurographs/utils/util.py
index e7a3bc8..9d5061e 100644
--- a/src/deep_neurographs/utils/util.py
+++ b/src/deep_neurographs/utils/util.py
@@ -20,8 +20,6 @@
 import numpy as np
 import psutil
 
-from deep_neurographs.utils import img_util
-
 
 # --- dictionary utils ---
 def remove_item(my_set, item):
@@ -409,7 +407,7 @@ def read_txt(path):
     return f.read()
 
 
-def read_metadata(path, anisotropy=[1.0, 1.0, 1.0]):
+def read_metadata(path):
     """
     Parses metadata file to extract the "chunk_origin" and "chunk_shape".
 
     Parameters
     ----------
     path : str
         Path to metadata file to be read.
-    anisotropy : list[float], optional
-        Anisotropy to be applied to values of interest that converts
-        coordinates from voxels to world. The default is [1.0, 1.0, 1.0].
 
     Returns
     -------
 
     """
     metadata = read_json(path)
-    origin = metadata["chunk_origin"]
-    chunk_origin = img_util.to_voxels(origin, anisotropy=anisotropy)
-    return chunk_origin.tolist(), metadata["chunk_shape"]
+    return metadata["chunk_origin"], metadata["chunk_shape"]
 
 
 def write_json(path, contents):
@@ -500,7 +493,20 @@ def get_avg_std(data, weights=None):
 
 def is_contained(bbox, voxel):
     """
-    Checks whether "xyz" is contained within "bbox".
+    Checks whether "voxel" is contained within "bbox".
+
+    Parameters
+    ----------
+    bbox : dict
+        Dictionary with the keys "min" and "max" which specify a bounding box
+        in an image.
+    voxel : ArrayLike
+        Voxel coordinate to be checked.
+
+    Returns
+    -------
+    bool
+        Indication of whether "voxel" is contained in "bbox".
 
     """
     above = any(voxel >= bbox["max"])
@@ -516,7 +522,7 @@ def is_list_contained(bbox, voxels):
     ----------
     bbox : dict
         Dictionary with the keys "min" and "max" which specify a bounding box
-        in the image.
+        in an image.
     voxels
         List of xyz coordinates to be checked.
 
diff --git a/src/deep_neurographs/visualization.py b/src/deep_neurographs/visualization.py
index 40ee03e..c560dd4 100644
--- a/src/deep_neurographs/visualization.py
+++ b/src/deep_neurographs/visualization.py
@@ -198,7 +198,7 @@ def plot_nodes(graph):
         z=xyz[:, 2],
         mode="markers",
         name="Nodes",
-        marker=dict(size=3, color="red"),
+        marker=dict(size=2, color="red"),
     )
 
@@ -228,9 +228,9 @@ def plot_edges(graph, edges, color=None, line_width=3.5):
     )
     for i, j in edges:
         trace = go.Scatter3d(
-            x=graph.edges[i, j]["xyz"][::2, 0],
-            y=graph.edges[i, j]["xyz"][::2, 1],
-            z=graph.edges[i, j]["xyz"][::2, 2],
+            x=graph.edges[i, j]["xyz"][:, 0],
+            y=graph.edges[i, j]["xyz"][:, 1],
+            z=graph.edges[i, j]["xyz"][:, 2],
             mode="lines",
             line=line,
             name=f"({i},{j})",

From 9e31236eda13e2bf384433dabc9e9ec092a1a8fe Mon Sep 17 00:00:00 2001
From: anna-grim
Date: Sun, 15 Sep 2024 03:09:56 +0000
Subject: [PATCH 3/4] refactor: swc_reader class and graph_builder class

---
 src/deep_neurographs/generate_proposals.py |   5 +-
 src/deep_neurographs/geometry.py           |   9 +-
 src/deep_neurographs/intake.py             | 356 ++++-----------
 src/deep_neurographs/neurograph.py         |   8 +-
 src/deep_neurographs/utils/graph_util.py   |   4 +-
 src/deep_neurographs/utils/swc_util.py     | 488 +++++++++++++--------
 src/deep_neurographs/utils/util.py         |  24 +-
 7 files changed, 433 insertions(+), 461 deletions(-)

diff --git a/src/deep_neurographs/generate_proposals.py b/src/deep_neurographs/generate_proposals.py
index fd7e532..e2036ea 100644
--- a/src/deep_neurographs/generate_proposals.py
+++ b/src/deep_neurographs/generate_proposals.py
@@ -436,9 +436,8 @@ def compute_dot(branch_1, branch_2, idx_1, idx_2):
     """
     # Initializations
-    origin = geometry.get_midpoint(branch_1[idx_1], branch_2[idx_2])
-    b1 = branch_1 - origin
-    b2 = branch_2 - origin
+    b1 = branch_1 - geometry.midpoint(branch_1[idx_1], branch_2[idx_2])
+    b2 = branch_2 - geometry.midpoint(branch_1[idx_1], branch_2[idx_2])
 
     # Main
     dot_10 = np.dot(tangent(b1, idx_1, 10), tangent(b2, idx_2, 10))
diff --git a/src/deep_neurographs/geometry.py b/src/deep_neurographs/geometry.py
index dd701f7..989806a 100644
--- a/src/deep_neurographs/geometry.py
+++ b/src/deep_neurographs/geometry.py
@@ -44,13 +44,10 @@ def get_directional(branches, i, origin, depth):
     branches = shift_branches(branches, origin)
     if len(branches) == 1:
         return tangent(get_subarray(branches[0], depth))
-    elif len(branches) == 2:
+    else:
         branch_1 = get_subarray(branches[0], depth)
         branch_2 = get_subarray(branches[1], depth)
-        branch = np.concatenate((branch_1, branch_2))
-        return tangent(branch)
-    else:
-        return np.array([0, 0, 0])
+        return tangent(np.concatenate((branch_1, branch_2)))
 
 
 def get_subarray(arr, depth):
@@ -153,7 +150,7 @@ def normal(xyz):
     return VT[-1] / np.linalg.norm(VT[-1])
 
 
-def get_midpoint(xyz_1, xyz_2):
+def midpoint(xyz_1, xyz_2):
     """
     Computes the midpoint between "xyz_1" and "xyz_2".
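    Specifically, the coordinate-wise mean, i.e. 0.5 * (xyz_1 + xyz_2).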
diff --git a/src/deep_neurographs/intake.py b/src/deep_neurographs/intake.py index f576dd0..f48bbb8 100644 --- a/src/deep_neurographs/intake.py +++ b/src/deep_neurographs/intake.py @@ -10,21 +10,16 @@ from concurrent.futures import ProcessPoolExecutor, as_completed from time import time - -from google.cloud import storage from tqdm import tqdm from deep_neurographs.neurograph import NeuroGraph from deep_neurographs.utils import graph_util as gutil -from deep_neurographs.utils import img_util, util -from deep_neurographs.utils.swc_util import ( - process_gcs_zip, - process_local_paths, -) +from deep_neurographs.utils import img_util, swc_util, util + MIN_SIZE = 30 NODE_SPACING = 2 -SMOOTH = True +SMOOTH_BOOL = True PRUNE_DEPTH = 25 TRIM_DEPTH = 0 @@ -43,18 +38,15 @@ def __init__( node_spacing=NODE_SPACING, progress_bar=False, prune_depth=PRUNE_DEPTH, - smooth=SMOOTH, + smooth_bool=SMOOTH_BOOL, trim_depth=TRIM_DEPTH, ): """ Builds a FragmentsGraph by reading swc files stored either on the - cloud or local machine. + cloud or local machine, then extracting the irreducible components. Parameters ---------- - swc_pointer : dict, list, str - Pointer to swc files used to build an instance of FragmentsGraph, - see "swc_util.Reader" for further documentation. anisotropy : list[float], optional Scaling factors applied to xyz coordinates to account for anisotropy of microscope. The default is [1.0, 1.0, 1.0]. @@ -79,6 +71,9 @@ def __init__( smooth : bool, optional Indication of whether to smooth branches from swc files. The default is the global variable "SMOOTH". + trim_depth : float, optional + Maximum path length (in microns) to trim from "branch". The default + is the global variable "TRIM_DEPTH". Returns ------- @@ -86,268 +81,91 @@ def __init__( Neurograph generated from swc files. """ - pass - - -# --- Build graph wrappers --- -def build_neurograph_from_local( - anisotropy=[1.0, 1.0, 1.0], - img_patch_origin=None, - img_patch_shape=None, - min_size=MIN_SIZE, - node_spacing=NODE_SPACING, - progress_bar=False, - prune_depth=PRUNE_DEPTH, - trim_depth=TRIM_DEPTH, - smooth=SMOOTH, - swc_dir=None, - swc_paths=None, -): - """ - Builds a neurograph from swc files on the local machine. - - Parameters - ---------- - anisotropy : list[float], optional - Scaling factors applied to xyz coordinates to account for anisotropy - of microscope. The default is [1.0, 1.0, 1.0]. - image_patch_origin : list[float], optional - An xyz coordinate in the image which is the upper, left, front corner - of am image patch that contains the swc files. The default is None. - image_patch_shape : list[float], optional - The xyz dimensions of the bounding box which contains the swc files. - The default is None. - min_size : int, optional - Minimum cardinality of swc files that are stored in NeuroGraph. The - default is the global variable "MIN_SIZE". - node_spacing : int, optional - Spacing (in microns) between nodes. The default is the global variable - "NODE_SPACING". - progress_bar : bool, optional - Indication of whether to print out a progress bar during build. The - default is False. - prune_depth : int, optional - Branches less than "prune_depth" microns are pruned if "prune" is - True. The default is the global variable "PRUNE_DEPTH". - smooth : bool, optional - Indication of whether to smooth branches from swc files. The default - is the global variable "SMOOTH". - swc_dir : str, optional - Path to a directory containing swc files. The default is None. - swc_paths : list[str], optional - List of paths to swc files. 
The default is None. - - Returns - ------- - NeuroGraph - Neurograph generated from swc files stored on local machine. - - """ - # Process swc files - assert swc_dir or swc_paths, "Provide swc_dir or swc_paths!" - img_bbox = img_util.get_bbox(img_patch_origin, img_patch_shape) - paths = util.list_paths(swc_dir, ext=".swc") if swc_dir else swc_paths - swc_dicts, paths = process_local_paths(paths, anisotropy, min_size) - - # Build neurograph - neurograph = build_neurograph( - swc_dicts, - img_bbox=img_bbox, - min_size=min_size, - node_spacing=node_spacing, - progress_bar=progress_bar, - prune_depth=prune_depth, - trim_depth=trim_depth, - smooth=smooth, - swc_paths=paths, - ) - return neurograph - - -def build_neurograph_from_gcs_zips( - bucket_name, - gcs_path, - anisotropy=[1.0, 1.0, 1.0], - min_size=MIN_SIZE, - node_spacing=NODE_SPACING, - prune_depth=PRUNE_DEPTH, - trim_depth=TRIM_DEPTH, - smooth=SMOOTH, -): - """ - Builds a neurograph from a GCS bucket that contain of zips of swc files. - - Parameters - ---------- - bucket_name : str - Name of GCS bucket where zips of swc files are stored. - gcs_path : str - Path within GCS bucket to directory containing zips. - anisotropy : list[float], optional - Scaling factors applied to xyz coordinates to account for anisotropy - of microscope. The default is [1.0, 1.0, 1.0]. - min_size : int, optional - Minimum cardinality of swc files that are stored in NeuroGraph. The - default is the global variable "MIN_SIZE". - node_spacing : int, optional - Spacing (in microns) between nodes. The default is the global variable - "NODE_SPACING". - prune_depth : int, optional - Branches less than "prune_depth" microns are pruned if "prune" is - True. The default is the global variable "PRUNE_DEPTH". - smooth : bool, optional - Indication of whether to smooth branches from swc files. The default - is the global variable "SMOOTH". - - Returns - ------- - NeuroGraph - Neurograph generated from zips of swc files stored in a GCS bucket. - - """ - print("\nBuild NeuroGraph...") - t0 = time() - swc_dicts = download_gcs_zips(bucket_name, gcs_path, min_size, anisotropy) - neurograph = build_neurograph( - swc_dicts, - min_size=min_size, - node_spacing=node_spacing, - prune_depth=prune_depth, - trim_depth=trim_depth, - smooth=smooth, - ) - t, unit = util.time_writer(time() - t0) - print(f"Memory Consumption: {round(util.get_memory_usage(), 4)} GBs") - print(f"Module Runtime: {round(t, 4)} {unit} \n") - - return neurograph - - -# -- Read swc files -- -def download_gcs_zips(bucket_name, gcs_path, min_size, anisotropy): - """ - Downloads swc files from zips stored in a GCS bucket. + self.anisotropy = anisotropy + self.min_size = min_size + self.node_spacing = node_spacing + self.progress_bar = progress_bar + self.prune_depth = prune_depth + self.smooth_bool = smooth_bool + self.trim_depth = trim_depth + + self.img_bbox = img_util.get_bbox(img_patch_origin, img_patch_shape) + self.reader = swc_util.Reader(anisotropy, min_size) + + def run(self, swc_pointer): + """ + Builds a FragmentsGraph by reading swc files stored either on the + cloud or local machine, then extracting the irreducible components. - Parameters - ---------- - bucket_name : str - Name of GCS bucket where zips are stored. - gcs_path : str - Path within GCS bucket to directory containing zips. - min_size : int - Minimum cardinality of swc files that are stored in NeuroGraph. - anisotropy : list[float] - Scaling factors applied to xyz coordinates to account for anisotropy - of microscope. 
+ Parameters + ---------- + swc_pointer : dict, list, str + Pointer to swc files used to build an instance of FragmentsGraph, + see "swc_util.Reader" for further documentation. - Returns - ------- - swc_dicts : list + Returns + ------- + NeuroGraph + Neurograph generated from swc files. - """ - # Initializations - bucket = storage.Client().bucket(bucket_name) - zip_paths = util.list_gcs_filenames(bucket, gcs_path, ".zip") - - # Main - with ProcessPoolExecutor() as executor: - # Assign processes - processes = [] - for path in tqdm(zip_paths, desc="Download SWCs"): - zip_content = bucket.blob(path).download_as_bytes() - processes.append( - executor.submit( - process_gcs_zip, zip_content, anisotropy, min_size + """ + # Initializations + t0 = time() + swc_dicts = self.reader.load(swc_pointer) + irreducibles, n_nodes, n_edges = self.get_irreducibles(swc_dicts) + + # Build FragmentsGraph + neurograph = NeuroGraph(node_spacing=self.node_spacing) + while len(irreducibles): + irreducible_set = irreducibles.pop() + neurograph.add_component(irreducible_set) + + # Report results + if self.progress_bar: + # Graph size + n_components = util.reformat_number(len(irreducibles)) + print("\nGraph Overview...") + print("# connected components:", n_connected_components) + print("# nodes:", util.reformat_number(n_nodes)) + print("# edges:", util.reformat_number(n_edges)) + + # Memory and runtime + usage = round(util.get_memory_usage(), 2) + t, unit = util.time_writer(time() - t0) + print(f"Memory Consumption: {usage} GBs") + print(f"Module Runtime: {round(t, 4)} {unit} \n") + return neurograph + + def get_irreducibles(self, swc_dicts): + with ProcessPoolExecutor() as executor: + # Assign Processes + i = 0 + processes = [None] * len(swc_dicts) + while swc_dicts: + swc_dict = swc_dicts.pop() + processes[i] = executor.submit( + gutil.get_irreducibles, + swc_dict, + self.min_size, + self.img_bbox, + self.prune_depth, + self.smooth_bool, + self.trim_depth, ) - ) - - # Store result - swc_dicts = [] - for process in as_completed(processes): - try: - result = process.result() - swc_dicts.extend(result) - except Exception as e: - print(type(e), e) - return swc_dicts - - -# -- Build neurograph --- -def build_neurograph( - swc_dicts, - img_bbox=None, - min_size=MIN_SIZE, - node_spacing=NODE_SPACING, - swc_paths=None, - progress_bar=True, - prune_depth=PRUNE_DEPTH, - trim_depth=TRIM_DEPTH, - smooth=SMOOTH, -): - # Extract irreducibles - irreducibles, n_nodes, n_edges = get_irreducibles( - swc_dicts, - bbox=img_bbox, - min_size=min_size, - progress_bar=progress_bar, - prune_depth=prune_depth, - trim_depth=trim_depth, - smooth=smooth, - ) - - # Build neurograph - if progress_bar: - print("\nGraph Overview...") - print( - "# connected components:", util.reformat_number(len(irreducibles)) - ) - print("# nodes:", util.reformat_number(n_nodes)) - print("# edges:", util.reformat_number(n_edges)) - - neurograph = NeuroGraph(node_spacing=node_spacing) - while len(irreducibles): - irreducible_set = irreducibles.pop() - neurograph.add_component(irreducible_set) - return neurograph - - -def get_irreducibles( - swc_dicts, - bbox=None, - min_size=MIN_SIZE, - progress_bar=True, - prune_depth=PRUNE_DEPTH, - trim_depth=TRIM_DEPTH, - smooth=SMOOTH, -): - with ProcessPoolExecutor() as executor: - # Assign Processes - i = 0 - processes = [None] * len(swc_dicts) - while swc_dicts: - swc_dict = swc_dicts.pop() - processes[i] = executor.submit( - gutil.get_irreducibles, - swc_dict, - min_size, - bbox, - prune_depth, - trim_depth, - smooth, - ) - 
i += 1 - - # Store results - irreducibles = [] - n_nodes, n_edges = 0, 0 - for process in tqdm(as_completed(processes), desc="Extract Graphs"): - irreducibles_i = process.result() - irreducibles.extend(irreducibles_i) - n_nodes += count_nodes(irreducibles_i) - n_edges += count_edges(irreducibles_i) - return irreducibles, n_nodes, n_edges + i += 1 + + # Store results + irreducibles = [] + n_nodes, n_edges = 0, 0 + for process in tqdm(as_completed(processes), desc="Extract Graphs"): + irreducibles_i = process.result() + irreducibles.extend(irreducibles_i) + n_nodes += count_nodes(irreducibles_i) + n_edges += count_edges(irreducibles_i) + return irreducibles, n_nodes, n_edges +# --- utils --- def count_nodes(irreducibles): """ Counts the number of nodes in "irreducibles". diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py index b0afedd..11628ee 100644 --- a/src/deep_neurographs/neurograph.py +++ b/src/deep_neurographs/neurograph.py @@ -19,7 +19,6 @@ from deep_neurographs import generate_proposals, geometry from deep_neurographs.geometry import dist as get_dist -from deep_neurographs.geometry import get_midpoint from deep_neurographs.machine_learning.groundtruth_generation import init_targets from deep_neurographs.utils import graph_util as gutil from deep_neurographs.utils import img_util, swc_util, util @@ -51,6 +50,7 @@ def __init__(self, img_bbox=None, node_spacing=1): """ super(NeuroGraph, self).__init__() # General class attributes + self.leaf_kdtree = None self.node_spacing = node_spacing self.merged_ids = set() self.soma_ids = dict() @@ -534,7 +534,7 @@ def query_kdtree(self, xyz, d, node_type): # --- Proposal util --- def n_proposals(self): """ - Computes number of edges proposals in the graph. + Counts the number of proposals. Parameters ---------- @@ -543,7 +543,7 @@ def n_proposals(self): Returns ------- int - Number of edge proposals in the graph. + Number of proposals in the graph. """ return len(self.proposals) @@ -582,7 +582,7 @@ def proposal_length(self, proposal): def proposal_midpoint(self, proposal): i, j = tuple(proposal) - return get_midpoint(self.nodes[i]["xyz"], self.nodes[j]["xyz"]) + return geometry.midpoint(self.nodes[i]["xyz"], self.nodes[j]["xyz"]) def proposal_radii(self, proposal): """ diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py index d5d2961..363fd24 100644 --- a/src/deep_neurographs/utils/graph_util.py +++ b/src/deep_neurographs/utils/graph_util.py @@ -148,7 +148,7 @@ def prune_trim_branches(graph, prune_depth, trim_depth): return graph, n_nodes_trimmed -def __get_irreducibles(graph, swc_dict, smooth): +def __get_irreducibles(graph, swc_dict, smooth_bool): """ Gets the irreducible components of "graph". 
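
The submit/as_completed pattern that GraphBuilder.get_irreducibles uses above
can be summarized with the following sketch; the worker here is a stand-in for
"gutil.get_irreducibles" (which returns a list of irreducible components per
fragment), not the library's actual implementation:

    from concurrent.futures import ProcessPoolExecutor, as_completed

    def extract(swc_dict):
        # Stand-in worker; the real code calls gutil.get_irreducibles with
        # the min_size, bbox, pruning, trimming, and smoothing settings.
        return [swc_dict]

    def fan_out(swc_dicts):
        with ProcessPoolExecutor() as executor:
            futures = [executor.submit(extract, d) for d in swc_dicts]
            results = []
            # as_completed yields futures in completion order, not submit
            # order, so results are unordered across fragments.
            for future in as_completed(futures):
                results.extend(future.result())
        return results

    if __name__ == "__main__":
        print(fan_out([{"swc_id": "a"}, {"swc_id": "b"}]))
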
@@ -189,7 +189,7 @@ def __get_irreducibles(graph, swc_dict, smooth):
             attrs = upd_edge_attrs(swc_dict, attrs, j)
         if j in leafs or j in junctions:
             attrs = to_numpy(attrs)
-            if smooth:
+            if smooth_bool:
                 swc_dict, edges = __smooth_branch(
                     swc_dict, attrs, edges, nbs, root, j
                 )
diff --git a/src/deep_neurographs/utils/swc_util.py b/src/deep_neurographs/utils/swc_util.py
index be283c3..bdddc50 100644
--- a/src/deep_neurographs/utils/swc_util.py
+++ b/src/deep_neurographs/utils/swc_util.py
@@ -9,199 +9,337 @@
 
 """
 
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
+from google.cloud import storage
 from io import BytesIO
+from tqdm import tqdm
 from zipfile import ZipFile
 
 import networkx as nx
 import numpy as np
+import os
 
 from deep_neurographs.utils import util
 
 
-# -- io util --
-def process_local_paths(paths, anisotropy=[1.0, 1.0, 1.0], min_size=5):
+class Reader:
     """
-    Iterates over a list of swc paths to swc file, then builds a dictionary
-    where the keys are swc attributes (i.e. id, xyz, radius, pid) and values
-    are the corresponding contents within the swc file.
-
-    Parameters
-    ----------
-    paths : list[str]
-        List of paths to swc files to be parsed.
-    anisotropy : list[float], optional
-        Scaling factors applied to xyz coordinates to account for anisotropy
-        of microscope. The default is [1.0, 1.0, 1.0].
-    min_size : int, optional
-        Threshold on the number of nodes contained in an swc file. Only swc
-        files with more than "min_size" nodes are stored in "swc_dicts". The
-        default is 3.
-
-    Returns
-    -------
-    swc_dicts : list
-        List of dictionaries where the keys are swc attributes (i.e. id, xyz,
-        radius, pid) and values are the corresponding contents within the swc
-        file.
+    Class that reads swc files that are stored as (1) local directory of swcs,
+    (2) GCS directory of zips containing swcs, (3) local zip containing swcs,
+    (4) list of local paths to swcs, or (5) single path to a local swc.
 
     """
-    valid_paths = []
-    swc_dicts = []
-    for path in paths:
-        # Read contents
-        contents = read_from_local(path)
-        if len(contents) > min_size:
-            swc_dict = parse(contents, anisotropy=anisotropy)
-            swc_dict["swc_id"] = util.get_swc_id(path)
-            swc_dicts.append(swc_dict)
-            valid_paths.append(path)
-    return swc_dicts, valid_paths
-
-
-def process_gcs_zip(zip_content, anisotropy=[1.0, 1.0, 1.0], min_size=0):
-    swc_dicts = []
-    with ZipFile(BytesIO(zip_content)) as zip_file:
-        with ThreadPoolExecutor() as executor:
-            # Assign threads
-            threads = [
-                executor.submit(
-                    parse_gcs_zip, zip_file, path, anisotropy, min_size
+
+    def __init__(self, anisotropy=[1.0, 1.0, 1.0], min_size=0):
+        """
+        Initializes a Reader object that loads swc files.
+
+        Parameters
+        ----------
+        anisotropy : list[float], optional
+            Image to world scaling factors applied to xyz coordinates to
+            account for anisotropy of the microscope. The default is
+            [1.0, 1.0, 1.0].
+        min_size : int, optional
+            Threshold on the number of nodes in an swc file. Only swc files
+            with more than "min_size" nodes are loaded. The default
+            is 0.
+
+        Returns
+        -------
+        None
+
+        """
+        self.anisotropy = anisotropy
+        self.min_size = min_size
+
+    def load(self, swc_pointer):
+        """
+        Load data based on the type and format of the provided "swc_pointer".
+
+        Parameters
+        ----------
+        swc_pointer : dict, list, str
+            Object that points to swcs to be read, see class documentation for
+            details.
+
+        Returns
+        -------
+        list[dict]
+            List of dictionaries whose keys and values are the attribute names
+            and values from each swc file.
+
+        """
+        if type(swc_pointer) is dict:
+            return self.load_from_gcs(swc_pointer)
+        if type(swc_pointer) is list:
+            return self.load_from_local_paths(swc_pointer)
+        if type(swc_pointer) is str:
+            if ".zip" in swc_pointer:
+                return self.load_from_local_zip(swc_pointer)
+            if ".swc" in swc_pointer:
+                return self.load_from_local_path(swc_pointer)
+            if os.path.isdir(swc_pointer):
+                paths = util.list_paths(swc_pointer, ext=".swc")
+                return self.load_from_local_paths(paths)
+        raise Exception("SWC Pointer is not Valid!")
+
+    # --- Load subroutines ---
+    def load_from_local_paths(self, swc_paths):
+        """
+        Reads swc files from the local machine, then returns the contents
+        of each file as a dictionary.
+
+        Parameters
+        ----------
+        swc_paths : list
+            List of paths to swc files stored on the local machine.
+
+        Returns
+        -------
+        list[dict]
+            List of dictionaries whose keys and values are the attribute names
+            and values from each swc file.
+
+        """
+        with ProcessPoolExecutor(max_workers=1) as executor:
+            # Assign processes
+            processes = list()
+            for path in swc_paths:
+                processes.append(
+                    executor.submit(self.load_from_local_path, path)
                 )
-                for path in util.list_files_in_gcs_zip(zip_content)
-            ]
 
-            # Process results
-            for thread in as_completed(threads):
-                result = thread.result()
-                if len(result["id"]) > 0:
+            # Store results
+            swc_dicts = list()
+            for process in as_completed(processes):
+                result = process.result()
+                if result:
                     swc_dicts.append(result)
-    return swc_dicts
-
-
-def parse_gcs_zip(zip_file, path, anisotropy=[1.0, 1.0, 1.0], min_size=0):
-    # Parse contents
-    contents = read_from_gcs_zip(zip_file, path)
-    if len(contents) > min_size:
-        swc_dict = parse(contents, anisotropy)
-    else:
-        swc_dict = {"id": []}
-
-    # Store id
-    swc_dict["swc_id"] = util.get_swc_id(path)
-    return swc_dict
-
-
-def parse(contents, anisotropy=[1.0, 1.0, 1.0]):
-    """
-    Parses an swc file to extract the contents which is stored in a dict. Note
-    that node_ids from swc are refactored to index from 0 to n-1 where n is
-    the number of entries in the swc file.
-
-    Parameters
-    ----------
-    contents : list[str]
-        List of entries from an swc file.
-    anisotropy : list[float]
-
-    Returns
-    -------
-    ...
+            return swc_dicts
+
+    def load_from_local_path(self, path):
+        """
+        Reads a single swc file from the local machine, then returns its
+        contents as a dictionary.
+
+        Parameters
+        ----------
+        path : str
+            Path to swc file stored on the local machine.
+
+        Returns
+        -------
+        dict or bool
+            Dictionary whose keys and values are the attribute names and
+            values from an swc file; False if the file is below "min_size".
+
+        """
+        content = util.read_txt(path)
+        if len(content) > self.min_size:
+            result = self.parse(content)
+            result["swc_id"] = util.get_swc_id(path)
+            return result
+        else:
+            return False
+
+    def load_from_gcs(self, gcs_dict):
+        """
+        Reads swc files from zips on a GCS bucket.
+
+        Parameters
+        ----------
+        gcs_dict : dict
+            Dictionary where keys are "bucket_name" and "path".
+
+        Returns
+        -------
+        list[dict]
+            List of dictionaries whose keys and values are the attribute names
+            and values from the swc files stored in the GCS bucket.
+
+        """
+        # Initializations
+        bucket = storage.Client().bucket(gcs_dict["bucket_name"])
+        zip_paths = util.list_gcs_filenames(bucket, gcs_dict["path"], ".zip")
+
+        # Main
+        with ProcessPoolExecutor() as executor:
+            # Assign processes
+            processes = []
+            for path in zip_paths:
+                zip_content = bucket.blob(path).download_as_bytes()
+                processes.append(
+                    executor.submit(self.load_from_cloud_zip, zip_content)
+                )
 
-    """
-    # Compile swc content
-    contents, offset = get_contents(contents)
-    swc_dict = {
-        "id": np.zeros((len(contents)), dtype=np.int32),
-        "radius": np.zeros((len(contents)), dtype=np.float32),
-        "pid": np.zeros((len(contents)), dtype=np.int32),
-        "xyz": np.zeros((len(contents), 3), dtype=np.float32),
-    }
-    for i, line in enumerate(contents):
-        parts = line.split()
-        swc_dict["id"][i] = parts[0]
-        swc_dict["radius"][i] = float(parts[-2])
-        swc_dict["pid"][i] = parts[-1]
-        swc_dict["xyz"][i] = read_xyz(parts[2:5], anisotropy, offset)
-
-    # Check whether radius is in nanometers
-    if swc_dict["radius"][0] > 100:
-        swc_dict["radius"] /= 1000
-    return swc_dict
-
-
-def get_contents(swc_contents):
-    offset = [0, 0, 0]
-    for i, line in enumerate(swc_contents):
-        if line.startswith("# OFFSET"):
+        # Store results
+        swc_dicts = list()
+        desc = "Downloading SWCs"
+        for process in tqdm(as_completed(processes), desc=desc):
+            swc_dicts.extend(process.result())
+        return swc_dicts
+
+    def load_from_cloud_zip(self, zip_content):
+        """
+        Reads swc files from a zip that has been downloaded from a cloud
+        bucket.
+
+        Parameters
+        ----------
+        zip_content : ...
+            Content of a zip file, stored as bytes.
+
+        Returns
+        -------
+        list[dict]
+            List of dictionaries whose keys and values are the attribute names
+            and values from the swc files in the zip.
+
+        """
+        with ZipFile(BytesIO(zip_content)) as zip_file:
+            with ThreadPoolExecutor() as executor:
+                # Assign threads
+                threads = []
+                for f in util.list_files_in_zip(zip_content):
+                    threads.append(
+                        executor.submit(
+                            self.load_from_cloud_zipped_file, zip_file, f
+                        )
+                    )
+
+                # Process results
+                swc_dicts = list()
+                for thread in as_completed(threads):
+                    result = thread.result()
+                    if result:
+                        swc_dicts.append(result)
+        return swc_dicts
+
+    def load_from_cloud_zipped_file(self, zip_file, path):
+        """
+        Reads swc file stored at "path" which points to a file in a zip.
+
+        Parameters
+        ----------
+        zip_file : ZipFile
+            Zip containing swc file to be read.
+        path : str
+            Path to swc file to be read.
+
+        Returns
+        -------
+        dict
+            Dictionary whose keys and values are the attribute names and
+            values read from that swc file.
+
+        """
+        content = util.read_zip(zip_file, path).splitlines()
+        if len(content) > self.min_size:
+            result = self.parse(content)
+            result["swc_id"] = util.get_swc_id(path)
+            return result
+        else:
+            return False
+
+    # --- Process swc content ---
+    def parse(self, content):
+        """
+        Parses an swc file to extract the content which is stored in a dict.
+        Note that node_ids from swc are refactored to index from 0 to n-1,
+        where n is the number of entries in the swc file.
+
+        Parameters
+        ----------
+        content : list[str]
+            List of entries from an swc file.
+
+        Returns
+        -------
+        dict
+            Dictionary whose keys and values are the attribute names
+            and values from an swc file.
+
+        """
+        # Parse swc content
+        content, offset = self.process_content(content)
+        swc_dict = {
+            "id": np.zeros((len(content)), dtype=np.int32),
+            "radius": np.zeros((len(content)), dtype=np.float32),
+            "pid": np.zeros((len(content)), dtype=np.int32),
+            "xyz": np.zeros((len(content), 3), dtype=np.float32),
+        }
+        for i, line in enumerate(content):
             parts = line.split()
-            offset = read_xyz(parts[2:5])
-        if not line.startswith("#"):
-            break
-    return swc_contents[i:], offset
-
-
-def read_from_local(path):
-    """
-    Reads swc file stored at "path" on local machine.
-
-    Parameters
-    ----------
-    Path : str
-        Path to swc file to be read.
-
-    Returns
-    -------
-    list
-        List such that each entry is a line from the swc file.
-
-    """
-    with open(path, "r") as file:
-        return file.readlines()
-
-
-def read_from_gcs_zip(zip_file, path):
-    """
-    Reads the content of an swc file from a zip file in a GCS bucket.
-
-    """
-    try:
-        with zip_file.open(path) as txt_file:
-            return txt_file.read().decode("utf-8").splitlines()
-    except:
-        print(f"Failed to read {path}")
-        return []
-
-
-def read_xyz(xyz, anisotropy=[1.0, 1.0, 1.0], offset=[0, 0, 0]):
-    """
-    Reads the (x,y,z)) coordinates from an swc file, then shift and scales
-    them if application.
-
-    Parameters
-    ----------
-    xyz : str
-        (z,y,x) coordinates.
-
-    Returns
-    -------
-    tuple
-        The (x,y,z) coordinates from an swc file.
-
-    """
-    xyz = [float(xyz[i]) + offset[i] for i in range(3)]
-    return tuple([xyz[i] * anisotropy[i] for i in range(3)])
-
-
-def write(path, contents, color=None):
-    if type(contents) is list:
-        write_list(path, contents, color=color)
-    elif type(contents) is dict:
-        write_dict(path, contents, color=color)
-    elif type(contents) is nx.Graph:
-        write_graph(path, contents, color=color)
+            swc_dict["id"][i] = parts[0]
+            swc_dict["radius"][i] = float(parts[-2])
+            swc_dict["pid"][i] = parts[-1]
+            swc_dict["xyz"][i] = self.read_xyz(parts[2:5], offset)
+
+        # Check whether radius is in nanometers
+        if swc_dict["radius"][0] > 100:
+            swc_dict["radius"] /= 1000
+        return swc_dict
+
+    def process_content(self, content):
+        """
+        Processes lines of text from a content source, extracting an offset
+        value and returning the remaining content starting from the line
+        immediately after the last commented line.
+
+        Parameters
+        ----------
+        content : List[str]
+            List of strings where each string represents a line of text.
+
+        Returns
+        -------
+        List[str]
+            A list of strings representing the lines of text starting from the
+            line immediately after the last commented line.
+        List[float]
+            Offset of swc file.
+
+        """
+        offset = [0.0, 0.0, 0.0]
+        for i, line in enumerate(content):
+            if line.startswith("# OFFSET"):
+                offset = self.read_xyz(line.split()[2:5])
+            if not line.startswith("#"):
+                return content[i:], offset
+
+    def read_xyz(self, xyz_str, offset=[0.0, 0.0, 0.0]):
+        """
+        Reads coordinates from a string and transforms them (if applicable).
+
+        Parameters
+        ----------
+        xyz_str : list[str]
+            Coordinates stored as strings.
+        offset : list[float], optional
+            Offset of coordinates in swc file. The default is [0.0, 0.0, 0.0].
+
+        Returns
+        -------
+        numpy.ndarray
+            xyz coordinates of an entry from an swc file.
+
+        """
+        xyz = np.zeros((3))
+        for i in range(3):
+            xyz[i] = self.anisotropy[i] * (float(xyz_str[i]) + offset[i])
+        return xyz
+
+
+def write(path, content, color=None):
+    if type(content) is list:
+        write_list(path, content, color=color)
+    elif type(content) is dict:
+        write_dict(path, content, color=color)
+    elif type(content) is nx.Graph:
+        write_graph(path, content, color=color)
     else:
-        assert True, "Unable to write {} to swc".format(contents)
+        assert False, "Unable to write {} to swc".format(content)
 
 
 def write_list(path, entry_list, color=None):
diff --git a/src/deep_neurographs/utils/util.py b/src/deep_neurographs/utils/util.py
index 9d5061e..4d3a358 100644
--- a/src/deep_neurographs/utils/util.py
+++ b/src/deep_neurographs/utils/util.py
@@ -268,6 +268,7 @@ def list_subdirs(path, keyword=None):
             subdirs.append(d)
         elif keyword in d:
             subdirs.append(d)
+    subdirs.sort()
    return subdirs
 
 
@@ -326,7 +327,7 @@ def set_path(dir_name, filename, ext):
 
 
 # -- gcs utils --
-def list_files_in_gcs_zip(zip_content):
+def list_files_in_zip(zip_content):
     """
     Lists all files in a zip file stored in a GCS bucket.
 
@@ -404,7 +405,7 @@ def read_txt(path):
 
     """
     with open(path, "r") as f:
-        return f.read()
+        return f.read().splitlines()
 
 
 def read_metadata(path):
@@ -426,6 +427,25 @@ def read_metadata(path):
     return metadata["chunk_origin"], metadata["chunk_shape"]
 
 
+def read_zip(zip_file, path):
+    """
+    Reads the content of an swc file from a zip file.
+
+    Parameters
+    ----------
+    zip_file : ZipFile
+        Zip containing text file to be read.
+
+    Returns
+    -------
+    str
+        Contents of a txt file.
+
+    """
+    with zip_file.open(path) as f:
+        return f.read().decode("utf-8")
+
+
 def write_json(path, contents):
     """
     Writes "contents" to a json file at "path".

From 17dbd5b843949aae151d733d8f71a21a5b096b5c Mon Sep 17 00:00:00 2001
From: anna-grim
Date: Sun, 15 Sep 2024 03:15:48 +0000
Subject: [PATCH 4/4] minor upds

---
 src/deep_neurographs/intake.py                                |  8 +++++---
 src/deep_neurographs/machine_learning/graph_datasets.py       |  2 +-
 src/deep_neurographs/machine_learning/heterograph_datasets.py |  2 +-
 src/deep_neurographs/neurograph.py                            |  4 +++-
 src/deep_neurographs/utils/ml_util.py                         |  2 +-
 src/deep_neurographs/utils/swc_util.py                        | 12 ++++++++----
 6 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/src/deep_neurographs/intake.py b/src/deep_neurographs/intake.py
index a02f635..2f88ee6 100644
--- a/src/deep_neurographs/intake.py
+++ b/src/deep_neurographs/intake.py
@@ -10,13 +10,13 @@
 
 from concurrent.futures import ProcessPoolExecutor, as_completed
 from time import time
+
 from tqdm import tqdm
 
 from deep_neurographs.neurograph import NeuroGraph
 from deep_neurographs.utils import graph_util as gutil
 from deep_neurographs.utils import img_util, swc_util, util
 
-
 MIN_SIZE = 30
 NODE_SPACING = 2
 SMOOTH_BOOL = True
@@ -125,7 +125,7 @@ def run(self, swc_pointer):
             # Graph size
             n_components = util.reformat_number(len(irreducibles))
             print("\nGraph Overview...")
-            print("# connected components:", n_connected_components)
+            print("# connected components:", n_components)
             print("# nodes:", util.reformat_number(n_nodes))
             print("# edges:", util.reformat_number(n_edges))
 
@@ -155,15 +155,17 @@ def get_irreducibles(self, swc_dicts):
                 i += 1
 
             # Store results
+            desc = "Extract Graphs"
             irreducibles = []
             n_nodes, n_edges = 0, 0
-            for process in tqdm(as_completed(processes), desc="Extract Graphs"):
+            for process in tqdm(as_completed(processes), desc=desc):
                 irreducibles_i = process.result()
                 irreducibles.extend(irreducibles_i)
                 n_nodes += count_nodes(irreducibles_i)
                 n_edges += count_edges(irreducibles_i)
             return irreducibles, n_nodes, n_edges
 
+
 # --- utils ---
 def count_nodes(irreducibles):
     """
diff --git a/src/deep_neurographs/machine_learning/graph_datasets.py b/src/deep_neurographs/machine_learning/graph_datasets.py
index b7737a9..2777eb9 100644
--- a/src/deep_neurographs/machine_learning/graph_datasets.py
+++ b/src/deep_neurographs/machine_learning/graph_datasets.py
@@ -15,7 +15,7 @@
 import torch
 from torch_geometric.data import Data as GraphData
 
-from deep_neurographs.machine_learning import feature_generation, datasets
+from deep_neurographs.machine_learning import datasets, feature_generation
 from deep_neurographs.utils import gnn_util
 
 
diff --git a/src/deep_neurographs/machine_learning/heterograph_datasets.py b/src/deep_neurographs/machine_learning/heterograph_datasets.py
index 67499af..57acbea 100644
--- a/src/deep_neurographs/machine_learning/heterograph_datasets.py
+++ b/src/deep_neurographs/machine_learning/heterograph_datasets.py
@@ -17,7 +17,7 @@
 import torch
 from torch_geometric.data import HeteroData as HeteroGraphData
 
-from deep_neurographs.machine_learning import feature_generation, datasets
+from deep_neurographs.machine_learning import datasets, feature_generation
 from deep_neurographs.utils import gnn_util
 
 DTYPE = torch.float32
diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py
index bbb66b1..8bc449c 100644
--- a/src/deep_neurographs/neurograph.py
+++ b/src/deep_neurographs/neurograph.py
@@ -19,7 +19,9 @@
 
 from deep_neurographs import generate_proposals, geometry
 from deep_neurographs.geometry import dist as get_dist
-from deep_neurographs.machine_learning.groundtruth_generation import init_targets
+from deep_neurographs.machine_learning.groundtruth_generation import (
+    init_targets,
+)
 from deep_neurographs.utils import graph_util as gutil
 from deep_neurographs.utils import img_util, swc_util, util
 
diff --git a/src/deep_neurographs/utils/ml_util.py b/src/deep_neurographs/utils/ml_util.py
index b78bb65..8fc2270 100644
--- a/src/deep_neurographs/utils/ml_util.py
+++ b/src/deep_neurographs/utils/ml_util.py
@@ -16,8 +16,8 @@
 from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier
 
 from deep_neurographs.machine_learning import (
-    feature_generation,
     datasets,
+    feature_generation,
     graph_datasets,
     heterograph_datasets,
 )
diff --git a/src/deep_neurographs/utils/swc_util.py b/src/deep_neurographs/utils/swc_util.py
index bdddc50..82a6c82 100644
--- a/src/deep_neurographs/utils/swc_util.py
+++ b/src/deep_neurographs/utils/swc_util.py
@@ -9,15 +9,19 @@
 
 """
 
-from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
-from google.cloud import storage
+import os
+from concurrent.futures import (
+    ProcessPoolExecutor,
+    ThreadPoolExecutor,
+    as_completed,
+)
 from io import BytesIO
-from tqdm import tqdm
 from zipfile import ZipFile
 
 import networkx as nx
 import numpy as np
-import os
+from google.cloud import storage
+from tqdm import tqdm
 
 from deep_neurographs.utils import util
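
Taken together, the refactored API in this series reads as follows; a
hypothetical usage sketch based on the signatures shown above (the class name
"GraphBuilder" is an assumption inferred from the patch subject, and the path
is a placeholder):

    from deep_neurographs.intake import GraphBuilder

    # "swc_pointer" may be a directory of swcs, a local zip, a single .swc
    # path, a list of paths, or {"bucket_name": ..., "path": ...} for zips
    # stored on GCS; dispatch happens inside swc_util.Reader.load.
    builder = GraphBuilder(min_size=30, node_spacing=2, progress_bar=True)
    neurograph = builder.run("/path/to/swc_dir")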