Skip to content

Commit

Permalink
Merge pull request #40 from jpaillard/main
Browse files Browse the repository at this point in the history
Add Green model to benchmark - follow up
  • Loading branch information
bruAristimunha authored Jan 3, 2025
2 parents 4e4fefe + b0b33eb commit 9c20d6d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
1 change: 0 additions & 1 deletion objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ class Objective(BaseObjective):
"pip:git+https://github.com/braindecode/braindecode#egg=braindecode", # noqa
"pip:optuna",
"pip:optuna-integration",
"pip:git+https://github.com/Roche/neuro-green", # noqa
]

parameters = {
Expand Down
16 changes: 11 additions & 5 deletions solvers/green.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,16 @@ class Solver(BaseSolver):
"bi_out": [[16]],
"hidden_dim": [[8]],
"n_freqs": [10],
"random_state": [0],
"kernel_width_s": [0.5],
"dropout": [0.5],

}

sampling_strategy = "run_once"
requirements = [
"pip:git+https://github.com/Roche/neuro-green",
]

def set_objective(self, X, y, sfreq, extra_info):
"""Set the objective information from Objective.get_objective.
Expand All @@ -50,11 +57,11 @@ def set_objective(self, X, y, sfreq, extra_info):

model = get_green(
n_freqs=self.n_freqs,
kernel_width_s=0.5,
kernel_width_s=self.kernel_width_s,
n_ch=n_channels,
sfreq=sfreq,
orth_weights=True,
dropout=0.5,
dropout=self.dropout,
hidden_dim=self.hidden_dim,
logref="logeuclid",
pool_layer=RealCovariance(),
Expand All @@ -66,7 +73,6 @@ def set_objective(self, X, y, sfreq, extra_info):
if torch.cuda.is_available():
torch.backends.cudnn.benchmark = True

seed = 0
callbacks = get_braindecode_callbacks(
dataset_name=extra_info["dataset_name"],
patience=self.patience,
Expand All @@ -84,7 +90,7 @@ def set_objective(self, X, y, sfreq, extra_info):
train_split=ValidSplit(
cv=self.valid_set,
stratified=True,
random_state=seed,
random_state=self.random_state,
),
batch_size=self.batch_size,
device=device,
Expand All @@ -100,7 +106,7 @@ def set_objective(self, X, y, sfreq, extra_info):
self.X = X
self.y = y

def run(self, n_iter):
def run(self, _):
# This is the function that is called to evaluate the solver
self.clf.fit(self.X, y=self.y)

Expand Down

0 comments on commit 9c20d6d

Please sign in to comment.