Skip to content

Commit

Permalink
Feat gnn whole brain inference (#225)
Browse files Browse the repository at this point in the history
* refactor: inference is performed with class object

* bug: minor updates to adapt code

---------

Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Sep 12, 2024
1 parent 9dfd526 commit 8d00fbf
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 17 deletions.
12 changes: 6 additions & 6 deletions src/deep_neurographs/machine_learning/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ def run_evaluation(neurograph, accepts, proposals):
overall_stats = get_stats(neurograph, proposals, accepts)

simple_stats = get_stats(
neurograph, neurograph.get_simple_proposals(), accepts
neurograph, neurograph.simple_proposals(), accepts
)

complex_stats = get_stats(
neurograph, neurograph.get_complex_proposals(), accepts
neurograph, neurograph.complex_proposals(), accepts
)

# Store results
Expand Down Expand Up @@ -124,23 +124,23 @@ def run_evaluation_blocks(neurographs, blocks, accepts):

simple_stats_i = get_stats(
neurographs[block_id],
neurographs[block_id].get_simple_proposals(),
neurographs[block_id].simple_proposals(),
accepts[block_id],
)

complex_stats_i = get_stats(
neurographs[block_id],
neurographs[block_id].get_complex_proposals(),
neurographs[block_id].complex_proposals(),
accepts[block_id],
)

# Store results
avg_wgts["Overall"].append(len(neurographs[block_id].proposals))
avg_wgts["Simple"].append(
len(neurographs[block_id].get_simple_proposals())
len(neurographs[block_id].simple_proposals())
)
avg_wgts["Complex"].append(
len(neurographs[block_id].get_complex_proposals())
len(neurographs[block_id].complex_proposals())
)
for metric in METRICS_LIST:
stats["Overall"][metric].append(overall_stats_i[metric])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,12 +236,12 @@ def proposal_skeletal(neurograph, proposals, radius):
proposal_skeletal_features[proposal] = np.concatenate(
(
neurograph.proposal_length(proposal),
feats.n_nearby_leafs(neurograph, proposal, radius),
feats.get_radii(neurograph, proposal),
feats.get_directionals(neurograph, proposal, 8),
feats.get_directionals(neurograph, proposal, 16),
feats.get_directionals(neurograph, proposal, 32),
feats.get_directionals(neurograph, proposal, 64),
neurograph.n_nearby_leafs(proposal, radius),
neurograph.proposal_radii(proposal),
neurograph.proposal_directionals(proposal, 8),
neurograph.proposal_directionals(proposal, 16),
neurograph.proposal_directionals(proposal, 32),
neurograph.proposal_directionals(proposal, 64),
),
axis=None,
)
Expand Down
10 changes: 5 additions & 5 deletions src/deep_neurographs/machine_learning/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ class InferenceEngine:
def __init__(
self,
img_path,
model_type,
model_path,
model_type,
search_radius,
batch_size=BATCH_SIZE,
confidence_threshold=CONFIDENCE_THRESHOLD,
Expand All @@ -56,10 +56,10 @@ def __init__(
----------
img_path : str
Path to image stored in a GCS bucket.
model_type : str
Type of machine learning model used to perform inference.
model_path : str
Path to model parameters.
model_type : str
Type of machine learning model used to perform inference.
search_radius : float
Search radius used to generate proposals.
batch_size : int, optional
Expand Down Expand Up @@ -211,7 +211,7 @@ def run_model(self, dataset):
"""
# Get predictions
if self.is_gnn:
preds = run_gnn_model(dataset.data, self.model)
preds = run_gnn_model(dataset.data, self.model, self.model_type)
elif "Net" in self.model_type:
preds = run_nn_model(dataset, self.model)
else:
Expand Down Expand Up @@ -242,7 +242,7 @@ def run_nn_model(dataset, model):
def run_gnn_model(data, model, model_type):
model.eval()
with torch.no_grad():
if "Hetero":
if "Hetero" in model_type:
x_dict, edge_index_dict, edge_attr_dict = gnn_utils.get_inputs(
data, model_type
)
Expand Down

0 comments on commit 8d00fbf

Please sign in to comment.