
Commit

Refactor img windows (#267)
* refactor: extract fixed image patch in features

* fixed performance gap

---------

Co-authored-by: anna-grim <[email protected]>
anna-grim authored on Oct 11, 2024
1 parent 8389bab · commit ab934fa
Showing 2 changed files with 8 additions and 15 deletions.
src/deep_neurographs/inference.py (1 addition, 1 deletion)
@@ -162,7 +162,7 @@ def run(self, fragments_pointer):
         print(f"Total Runtime: {round(t, 4)} {unit}\n")
 
     def run_schedule(
-        self, fragments_pointer, radius_schedule, save_all_rounds=False
+        self, fragments_pointer, search_radius_schedule, save_all_rounds=False
     ):
         t0 = time()
         self.report_experiment()
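
Since this renames a keyword parameter, callers that pass it by keyword must be updated; positional callers are unaffected. A minimal sketch of an updated call site, assuming `pipeline` is an already-constructed instance of the class in inference.py that defines run_schedule (the schedule values are hypothetical):

# Hypothetical call site updated for this commit.
# Before: pipeline.run_schedule(fragments_pointer, radius_schedule=[...])
pipeline.run_schedule(
    fragments_pointer,
    search_radius_schedule=[5, 10, 15],  # hypothetical values
    save_all_rounds=True,
)
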
src/deep_neurographs/machine_learning/heterograph_models.py (7 additions, 14 deletions)
@@ -57,8 +57,10 @@ def __init__(
         self.dropout = dropout
 
         # Feature vector sizes
-        hidden_dim = scale_hidden_dim * np.max(list(node_dict.values()))
-        output_dim = heads_1 * heads_2 * hidden_dim
+        node_dict = ml.feature_generation.get_node_dict()
+        edge_dict = ml.feature_generation.get_edge_dict()
+        hidden = scale_hidden * np.max(list(node_dict.values()))
+        self.dropout = dropout
 
         # Linear layers
         self.input_nodes = nn.ModuleDict()
@@ -69,12 +71,9 @@
             self.input_edges[key] = nn.Linear(d, hidden_dim, device=device)
         self.output = Linear(output_dim, 1).to(device)
 
-        # Message passing layers
-        self.conv1 = self.init_gat_layer(hidden_dim, hidden_dim, heads_1)  # change name
-        edge_dim = hidden_dim
-        hidden_dim = heads_1 * hidden_dim
-
-        self.conv2 = self.init_gat_layer(hidden_dim, edge_dim, heads_2)  # change name
+        # Convolutional layers
+        self.conv1 = self.init_gat_layer(hidden, hidden, heads_1)
+        self.conv2 = self.init_gat_layer(heads_1 * hidden, hidden, heads_2)
 
         # Nonlinear activation
         self.dropout = Dropout(dropout)  # change name
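
A side note on the dimension arithmetic above: with concatenated attention heads, each GAT layer multiplies its output width by its head count, which is why conv2 consumes heads_1 * hidden features. A minimal sketch, assuming init_gat_layer wraps torch_geometric's GATConv with the default concat=True (all sizes hypothetical):

import torch
from torch_geometric.nn import GATConv

hidden, heads_1, heads_2 = 32, 2, 2
conv1 = GATConv(hidden, hidden, heads=heads_1)                      # out: heads_1 * hidden
conv2 = GATConv(heads_1 * hidden, heads_1 * hidden, heads=heads_2)  # out: heads_1 * heads_2 * hidden

x = torch.randn(10, hidden)                 # 10 nodes with `hidden` features each
edge_index = torch.randint(0, 10, (2, 40))  # 40 random edges
x = conv1(x, edge_index)
x = conv2(x, edge_index)
assert x.shape == (10, heads_1 * heads_2 * hidden)

This is consistent with the output_dim = heads_1 * heads_2 * hidden_dim that the final linear layer expects.
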
@@ -89,12 +88,6 @@ def get_relation_types(cls):
         return cls.relation_types
 
     # --- Architecture ---
-    def init_linear_layer(self, hidden_dim, my_dict):
-        linear_layer = dict()
-        for key, dim in my_dict.items():
-            linear_layer[key] = nn.Linear(dim, hidden_dim, device=self.device)
-        return linear_layer
-
     def init_gat_layer(self, hidden_dim, edge_dim, heads):
         gat_dict = dict()
         for r in self.get_relation_types():
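
The deleted helper built a plain Python dict of nn.Linear layers, which is why each layer needed an explicit device= argument: modules stored in a plain dict are not registered with the parent nn.Module, so .to(device), .parameters(), and state_dict() all miss them. The nn.ModuleDict used for self.input_nodes avoids this. A minimal sketch of the distinction (the "proxy" key and sizes are hypothetical):

import torch.nn as nn

class Projections(nn.Module):
    def __init__(self):
        super().__init__()
        # Registered: visible to .parameters(), .to(device), state_dict().
        self.registered = nn.ModuleDict({"proxy": nn.Linear(16, 32)})
        # Not registered: a plain dict hides its modules from PyTorch.
        self.unregistered = {"proxy": nn.Linear(16, 32)}

print(len(list(Projections().parameters())))  # 2: only the ModuleDict's weight and bias
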
