Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim authored Oct 12, 2024
1 parent ab934fa commit 4b6e674
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 62 deletions.
3 changes: 1 addition & 2 deletions src/deep_neurographs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,10 @@ class MLConfig:
batch_size: int = 2000
downsample_factor: int = 1
high_threshold: float = 0.9
lr: float = 1e-4
lr: float = 1e-3
threshold: float = 0.6
model_type: str = "GraphNeuralNet"
n_epochs: int = 1000
use_img_embedding: bool = False
validation_split: float = 0.15
weight_decay: float = 1e-3

Expand Down
75 changes: 34 additions & 41 deletions src/deep_neurographs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,10 @@
from tqdm import tqdm

from deep_neurographs.graph_artifact_removal import remove_doubles
from deep_neurographs.machine_learning.feature_generation import (
FeatureGenerator,
)
from deep_neurographs.machine_learning import feature_generation
from deep_neurographs.utils import gnn_util
from deep_neurographs.utils import graph_util as gutil
from deep_neurographs.utils import ml_util, util
from deep_neurographs.utils import img_util, ml_util, util
from deep_neurographs.utils.gnn_util import toCPU
from deep_neurographs.utils.graph_util import GraphLoader

Expand Down Expand Up @@ -67,8 +65,6 @@ def __init__(
output_dir,
config,
device=None,
label_path=None,
use_img_embedding=False,
):
"""
Initializes an object that executes the full GraphTrace inference
Expand All @@ -83,7 +79,7 @@ def __init__(
Identifier for the predicted segmentation to be processed by the
inference pipeline.
img_path : str
Path to the raw image assumed to be stored in a GCS bucket.
Path to the whole-brain raw image stored in a GCS bucket.
model_path : str
Path to machine learning model parameters.
output_dir : str
Expand All @@ -93,10 +89,6 @@ def __init__(
for the inference pipeline.
device : str, optional
...
label_path : str, optional
Path to the segmentation assumed to be stored on a GCS bucket.
use_img_embedding : bool, optional
...
Returns
-------
Expand All @@ -107,6 +99,7 @@ def __init__(
self.accepted_proposals = list()
self.sample_id = sample_id
self.segmentation_id = segmentation_id
self.img_path = img_path
self.model_path = model_path

# Extract config settings
Expand All @@ -115,15 +108,13 @@ def __init__(

# Inference engine
self.inference_engine = InferenceEngine(
img_path,
self.img_path,
self.model_path,
self.ml_config.model_type,
self.graph_config.search_radius,
confidence_threshold=self.ml_config.threshold,
device=device,
downsample_factor=self.ml_config.downsample_factor,
label_path=label_path,
use_img_embedding=use_img_embedding,
)

# Set output directory
Expand Down Expand Up @@ -167,10 +158,10 @@ def run_schedule(
t0 = time()
self.report_experiment()
self.build_graph(fragments_pointer)
for round_id, radius in enumerate(radius_schedule):
print(f"--- Round {round_id + 1}: Radius = {radius} ---")
for round_id, search_radius in enumerate(search_radius_schedule):
print(f"--- Round {round_id + 1}: Radius = {search_radius} ---")
round_id += 1
self.generate_proposals(radius)
self.generate_proposals(search_radius)
self.run_inference()
if save_all_rounds:
self.save_results(round_id=round_id)
Expand Down Expand Up @@ -222,7 +213,7 @@ def build_graph(self, fragments_pointer):
print(f"Module Runtime: {round(t, 4)} {unit}\n")
self.print_graph_overview()

def generate_proposals(self, radius=None):
def generate_proposals(self, search_radius=None):
"""
Generates proposals for the fragment graph based on the specified
configuration.
Expand All @@ -238,13 +229,13 @@ def generate_proposals(self, radius=None):
"""
# Initializations
print("(2) Generate Proposals")
if radius is None:
radius = self.graph_config.radius
if search_radius is None:
search_radius = self.graph_config.search_radius

# Main
t0 = time()
self.graph.generate_proposals(
radius,
search_radius,
complex_bool=self.graph_config.complex_bool,
long_range_bool=self.graph_config.long_range_bool,
proposals_per_leaf=self.graph_config.proposals_per_leaf,
Expand Down Expand Up @@ -401,13 +392,11 @@ def __init__(
img_path,
model_path,
model_type,
radius,
search_radius,
batch_size=BATCH_SIZE,
confidence_threshold=CONFIDENCE_THRESHOLD,
device=None,
downsample_factor=1,
label_path=None,
use_img_embedding=False
):
"""
Initializes an inference engine by loading images and setting class
Expand All @@ -421,7 +410,7 @@ def __init__(
Path to machine learning model parameters.
model_type : str
Type of machine learning model used to perform inference.
radius : float
search_radius : float
Search radius used to generate proposals.
batch_size : int, optional
Number of proposals to generate features and classify per batch.
Expand All @@ -440,20 +429,16 @@ def __init__(
"""
# Set class attributes
self.batch_size = batch_size
self.downsample_factor = downsample_factor
self.device = "cpu" if device is None else device
self.is_gnn = True if "Graph" in model_type else False
self.radius = radius
self.model_type = model_type
self.search_radius = search_radius
self.threshold = confidence_threshold

# Features
self.feature_generator = FeatureGenerator(
img_path,
downsample_factor,
label_path=label_path,
use_img_embedding=use_img_embedding
)

# Model
# Load image and model
driver = "n5" if ".n5" in img_path else "zarr"
self.img = img_util.open_tensorstore(img_path, driver=driver)
self.model = ml_util.load_model(model_path)
if self.is_gnn:
self.model = self.model.to(self.device)
Expand Down Expand Up @@ -547,14 +532,22 @@ def get_batch_dataset(self, neurograph, batch):
...
"""
t0 = time()
features = self.feature_generator.run(neurograph, batch, self.radius)
print("Feature Generation:", time() - t0)
# Generate features
features = feature_generation.run(
neurograph,
self.img,
self.model_type,
batch,
self.search_radius,
downsample_factor=self.downsample_factor,
)

# Initialize dataset
computation_graph = batch["graph"] if type(batch) is dict else None
dataset = ml_util.init_dataset(
neurograph,
features,
self.is_gnn,
self.model_type,
computation_graph=computation_graph,
)
return dataset
Expand All @@ -577,7 +570,7 @@ def predict(self, dataset):
"""
# Get predictions
if self.is_gnn:
if self.model_type == "GraphNeuralNet":
with torch.no_grad():
# Get inputs
n = len(dataset.data["proposal"]["y"])
Expand All @@ -592,7 +585,7 @@ def predict(self, dataset):
preds = np.array(self.model.predict_proba(dataset.data.x)[:, 1])

# Reformat prediction
idxs = dataset.idxs_proposals["idx_to_id"]
idxs = dataset.idxs_proposals["idx_to_edge"]
return {idxs[i]: p for i, p in enumerate(preds)}


Expand Down
31 changes: 12 additions & 19 deletions src/deep_neurographs/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter

from deep_neurographs.machine_learning.feature_generation import FeatureGenerator
from deep_neurographs.machine_learning import feature_generation
from deep_neurographs.utils import gnn_util, img_util, ml_util, util
from deep_neurographs.utils.gnn_util import toCPU
from deep_neurographs.utils.graph_util import GraphLoader

LR = 1e-3
N_EPOCHS = 500
SCHEDULER_GAMMA = 0.7
SCHEDULER_STEP_SIZE = 100
N_EPOCHS = 200
SCHEDULER_GAMMA = 0.5
SCHEDULER_STEP_SIZE = 1000
WEIGHT_DECAY = 1e-3


Expand All @@ -50,7 +50,6 @@ def __init__(
model_type,
criterion=None,
output_dir=None,
use_img_embedding=False,
validation_ids=None,
save_model_bool=True,
):
Expand All @@ -59,18 +58,17 @@ def __init__(
raise ValueError("Must provide output_dir to save model.")

# Set class attributes
self.feature_generators = dict()
self.idx_to_ids = list()
self.model = model
self.model_type = model_type
self.output_dir = output_dir
self.save_model_bool = save_model_bool
self.use_img_embedding = use_img_embedding
self.validation_ids = validation_ids

# Set data structures for training examples
self.gt_graphs = list()
self.pred_graphs = list()
self.imgs = dict()
self.train_dataset_list = list()
self.validation_dataset_list = list()

Expand Down Expand Up @@ -144,16 +142,9 @@ def load_example(
}
)

def load_img(
self, sample_id, img_path, downsample_factor, label_path=None
):
if sample_id not in self.feature_generators:
self.feature_generators[sample_id] = FeatureGenerator(
img_path,
downsample_factor,
label_path=label_path,
use_img_embedding=self.use_img_embedding,
)
def load_img(self, path, sample_id):
if sample_id not in self.imgs:
self.imgs[sample_id] = img_util.open_tensorstore(path, "zarr")

# --- main pipeline ---
def run(self):
Expand Down Expand Up @@ -209,8 +200,10 @@ def generate_features(self):

# Generate features
sample_id = self.idx_to_ids[i]["sample_id"]
features = self.feature_generators[sample_id].run(
features = feature_generation.run(
self.pred_graphs[i],
self.imgs[sample_id],
self.model_type,
proposals_dict,
self.graph_config.search_radius,
)
Expand Down Expand Up @@ -470,4 +463,4 @@ def get_predictions(hat_y, threshold=0.5):
Binary predictions based on the given threshold.
"""
return (ml_util.sigmoid(np.array(hat_y)) > threshold).tolist()
return (ml_util.sigmoid(np.array(hat_y)) > threshold).tolist()

0 comments on commit 4b6e674

Please sign in to comment.