diff --git a/ignite/metrics/clustering/calinski_harabasz_score.py b/ignite/metrics/clustering/calinski_harabasz_score.py index c32b079ac6b..ddea539d9bd 100644 --- a/ignite/metrics/clustering/calinski_harabasz_score.py +++ b/ignite/metrics/clustering/calinski_harabasz_score.py @@ -6,16 +6,13 @@ from ignite.metrics.clustering._base import _ClusteringMetricBase -def _get_calinski_harabasz_score() -> Callable[[Tensor, Tensor], float]: +def _calinski_harabasz_score(features: Tensor, labels: Tensor) -> float: from sklearn.metrics import calinski_harabasz_score - def _calinski_harabasz_score(features: Tensor, labels: Tensor) -> float: - np_features = features.numpy() - np_labels = labels.numpy() - score = calinski_harabasz_score(np_features, np_labels) - return score - - return _calinski_harabasz_score + np_features = features.numpy() + np_labels = labels.numpy() + score = calinski_harabasz_score(np_features, np_labels) + return score class CalinskiHarabaszScore(_ClusteringMetricBase): @@ -104,4 +101,4 @@ def __init__( except ImportError: raise ModuleNotFoundError("This module requires scikit-learn to be installed.") - super().__init__(_get_calinski_harabasz_score(), output_transform, check_compute_fn, device, skip_unrolling) + super().__init__(_calinski_harabasz_score, output_transform, check_compute_fn, device, skip_unrolling) diff --git a/ignite/metrics/clustering/davies_bouldin_score.py b/ignite/metrics/clustering/davies_bouldin_score.py index f117b69b092..c6433e7264f 100644 --- a/ignite/metrics/clustering/davies_bouldin_score.py +++ b/ignite/metrics/clustering/davies_bouldin_score.py @@ -6,16 +6,13 @@ from ignite.metrics.clustering._base import _ClusteringMetricBase -def _get_davies_bouldin_score() -> Callable[[Tensor, Tensor], float]: +def _davies_bouldin_score(features: Tensor, labels: Tensor) -> float: from sklearn.metrics import davies_bouldin_score - def _davies_bouldin_score(features: Tensor, labels: Tensor) -> float: - np_features = features.numpy() - np_labels = labels.numpy() - score = davies_bouldin_score(np_features, np_labels) - return score - - return _davies_bouldin_score + np_features = features.numpy() + np_labels = labels.numpy() + score = davies_bouldin_score(np_features, np_labels) + return score class DaviesBouldinScore(_ClusteringMetricBase): @@ -104,4 +101,4 @@ def __init__( except ImportError: raise ModuleNotFoundError("This module requires scikit-learn to be installed.") - super().__init__(_get_davies_bouldin_score(), output_transform, check_compute_fn, device, skip_unrolling) + super().__init__(_davies_bouldin_score, output_transform, check_compute_fn, device, skip_unrolling)