Skip to content

Commit

Permalink
120 use hinge loss for svm example (#124)
Browse files Browse the repository at this point in the history
* Switch to hinge loss in SVM example

* remove named arguments in loss calls to allow the use of more losses
  • Loading branch information
jpaillard authored Jan 13, 2025
1 parent e710ea9 commit 19e6cf7
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 10 deletions.
6 changes: 3 additions & 3 deletions examples/plot_variable_importance_classif.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from scipy.stats import ttest_1samp
from sklearn.base import clone
from sklearn.linear_model import RidgeCV
from sklearn.metrics import log_loss
from sklearn.metrics import hinge_loss
from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold
from sklearn.svm import SVC

Expand Down Expand Up @@ -163,7 +163,7 @@
imputation_model=clone(imputation_model),
n_permutations=50,
n_jobs=5,
loss=log_loss,
loss=hinge_loss,
random_state=seed,
method="decision_function",
)
Expand All @@ -177,7 +177,7 @@
imputation_model=clone(imputation_model),
n_permutations=50,
n_jobs=5,
loss=log_loss,
loss=hinge_loss,
random_state=seed,
method="decision_function",
)
Expand Down
6 changes: 4 additions & 2 deletions src/hidimstat/cpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,9 @@ def score(self, X, y):

y_pred = getattr(self.estimator, self.method)(X)

loss_reference = self.loss(y_true=y, y_pred=y_pred)
        # In the sklearn API, y_true is the first positional argument. Not
        # specifying `y_true=...` allows the use of other losses such as
        # `hinge_loss` (whose first parameter is named `y_true` but whose
        # second is `pred_decision`, not `y_pred`).
loss_reference = self.loss(y, y_pred)
out_dict["loss_reference"] = loss_reference

y_pred_perm = self.predict(X, y)
Expand All @@ -245,7 +247,7 @@ def score(self, X, y):
for j, y_pred_j in enumerate(y_pred_perm):
list_loss_perm = []
for y_pred_perm in y_pred_j:
list_loss_perm.append(self.loss(y_true=y, y_pred=y_pred_perm))
list_loss_perm.append(self.loss(y, y_pred_perm))
out_dict["loss_perm"][j] = np.array(list_loss_perm)

out_dict["importance"] = np.array(
Expand Down
6 changes: 3 additions & 3 deletions src/hidimstat/loco.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def predict(self, X, y):
output_dict = dict()

y_pred = getattr(self.estimator, self.method)(X)
loss_reference = self.loss(y_true=y, y_pred=y_pred)
loss_reference = self.loss(y, y_pred)
output_dict["loss_reference"] = loss_reference
output_dict["loss_loco"] = dict()

Expand Down Expand Up @@ -183,13 +183,13 @@ def score(self, X, y):
out_dict = dict()
y_pred = getattr(self.estimator, self.method)(X)

loss_reference = self.loss(y_true=y, y_pred=y_pred)
loss_reference = self.loss(y, y_pred)
out_dict["loss_reference"] = loss_reference

y_pred_loco = self.predict(X, y)

out_dict["loss_loco"] = np.array(
[self.loss(y_true=y, y_pred=y_pred_loco[j]) for j in range(self.n_groups)]
[self.loss(y, y_pred_loco[j]) for j in range(self.n_groups)]
)

out_dict["importance"] = out_dict["loss_loco"] - loss_reference
Expand Down
4 changes: 2 additions & 2 deletions src/hidimstat/permutation_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def score(self, X, y):

output_dict = dict()
y_pred = getattr(self.estimator, self.method)(X)
loss_reference = self.loss(y_true=y, y_pred=y_pred)
loss_reference = self.loss(y, y_pred)
output_dict["loss_reference"] = loss_reference
output_dict["loss_perm"] = dict()

Expand All @@ -182,7 +182,7 @@ def score(self, X, y):
for j, y_pred_j in enumerate(y_pred_perm):
list_loss_perm = []
for y_pred_perm in y_pred_j:
list_loss_perm.append(self.loss(y_true=y, y_pred=y_pred_perm))
list_loss_perm.append(self.loss(y, y_pred_perm))
output_dict["loss_perm"][j] = np.array(list_loss_perm)

output_dict["importance"] = np.array(
Expand Down

0 comments on commit 19e6cf7

Please sign in to comment.