From e3ce926e7ccf431d32ede47cb7fd8b84f20fc523 Mon Sep 17 00:00:00 2001 From: jpaillard Date: Thu, 19 Dec 2024 09:47:23 +0100 Subject: [PATCH] add parameters random_state, kernel_width, dropout as parameters --- objective.py | 1 - solvers/green.py | 16 +++++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/objective.py b/objective.py index 503f34b..5bfcd08 100644 --- a/objective.py +++ b/objective.py @@ -27,7 +27,6 @@ class Objective(BaseObjective): "pip:git+https://github.com/braindecode/braindecode#egg=braindecode", # noqa "pip:optuna", "pip:optuna-integration", - "pip:git+https://github.com/Roche/neuro-green", # noqa ] parameters = { diff --git a/solvers/green.py b/solvers/green.py index 030c885..ab735f6 100644 --- a/solvers/green.py +++ b/solvers/green.py @@ -27,9 +27,16 @@ class Solver(BaseSolver): "bi_out": [[16]], "hidden_dim": [[8]], "n_freqs": [10], + "random_state": [0], + "kernel_width_s": [0.5], + "dropout": [0.5], + } sampling_strategy = "run_once" + requirements = [ + "pip:git+https://github.com/Roche/neuro-green", + ] def set_objective(self, X, y, sfreq, extra_info): """Set the objective information from Objective.get_objective. @@ -50,11 +57,11 @@ def set_objective(self, X, y, sfreq, extra_info): model = get_green( n_freqs=self.n_freqs, - kernel_width_s=0.5, + kernel_width_s=self.kernel_width_s, n_ch=n_channels, sfreq=sfreq, orth_weights=True, - dropout=0.5, + dropout=self.dropout, hidden_dim=self.hidden_dim, logref="logeuclid", pool_layer=RealCovariance(), @@ -66,7 +73,6 @@ def set_objective(self, X, y, sfreq, extra_info): if torch.cuda.is_available(): torch.backends.cudnn.benchmark = True - seed = 0 callbacks = get_braindecode_callbacks( dataset_name=extra_info["dataset_name"], patience=self.patience, @@ -84,7 +90,7 @@ def set_objective(self, X, y, sfreq, extra_info): train_split=ValidSplit( cv=self.valid_set, stratified=True, - random_state=seed, + random_state=self.random_state, ), batch_size=self.batch_size, device=device, @@ -100,7 +106,7 @@ def set_objective(self, X, y, sfreq, extra_info): self.X = X self.y = y - def run(self, n_iter): + def run(self, _): # This is the function that is called to evaluate the solver self.clf.fit(self.X, y=self.y)