Skip to content

Commit

Permalink
simplify imports for sklearn.metrics functions
Browse files Browse the repository at this point in the history
  • Loading branch information
kzkadc committed Oct 13, 2024
1 parent 6aaa3c8 commit 9eaf764
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 18 deletions.
15 changes: 6 additions & 9 deletions ignite/metrics/clustering/calinski_harabasz_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,13 @@
from ignite.metrics.clustering._base import _ClusteringMetricBase


def _get_calinski_harabasz_score() -> Callable[[Tensor, Tensor], float]:
def _calinski_harabasz_score(features: Tensor, labels: Tensor) -> float:
from sklearn.metrics import calinski_harabasz_score

def _calinski_harabasz_score(features: Tensor, labels: Tensor) -> float:
np_features = features.numpy()
np_labels = labels.numpy()
score = calinski_harabasz_score(np_features, np_labels)
return score

return _calinski_harabasz_score
np_features = features.numpy()
np_labels = labels.numpy()
score = calinski_harabasz_score(np_features, np_labels)
return score


class CalinskiHarabaszScore(_ClusteringMetricBase):
Expand Down Expand Up @@ -104,4 +101,4 @@ def __init__(
except ImportError:
raise ModuleNotFoundError("This module requires scikit-learn to be installed.")

super().__init__(_get_calinski_harabasz_score(), output_transform, check_compute_fn, device, skip_unrolling)
super().__init__(_calinski_harabasz_score, output_transform, check_compute_fn, device, skip_unrolling)
15 changes: 6 additions & 9 deletions ignite/metrics/clustering/davies_bouldin_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,13 @@
from ignite.metrics.clustering._base import _ClusteringMetricBase


def _get_davies_bouldin_score() -> Callable[[Tensor, Tensor], float]:
def _davies_bouldin_score(features: Tensor, labels: Tensor) -> float:
from sklearn.metrics import davies_bouldin_score

def _davies_bouldin_score(features: Tensor, labels: Tensor) -> float:
np_features = features.numpy()
np_labels = labels.numpy()
score = davies_bouldin_score(np_features, np_labels)
return score

return _davies_bouldin_score
np_features = features.numpy()
np_labels = labels.numpy()
score = davies_bouldin_score(np_features, np_labels)
return score


class DaviesBouldinScore(_ClusteringMetricBase):
Expand Down Expand Up @@ -104,4 +101,4 @@ def __init__(
except ImportError:
raise ModuleNotFoundError("This module requires scikit-learn to be installed.")

super().__init__(_get_davies_bouldin_score(), output_transform, check_compute_fn, device, skip_unrolling)
super().__init__(_davies_bouldin_score, output_transform, check_compute_fn, device, skip_unrolling)

0 comments on commit 9eaf764

Please sign in to comment.