diff --git a/.github/workflows/deploy-pypi.yml b/.github/workflows/deploy-pypi.yml index 43a49706..65e11e95 100644 --- a/.github/workflows/deploy-pypi.yml +++ b/.github/workflows/deploy-pypi.yml @@ -30,7 +30,7 @@ jobs: python -m pip install . - name: Test with pytest run: | - pytest -v + pytest -v --flake8 --pydocstyle --cov=hiclass --cov-fail-under=90 --cov-report html coverage xml - name: Upload Coverage to Codecov if: matrix.os == 'ubuntu-latest' diff --git a/.github/workflows/test-pr.yml b/.github/workflows/test-pr.yml index 221cc142..8ffd1e1f 100644 --- a/.github/workflows/test-pr.yml +++ b/.github/workflows/test-pr.yml @@ -6,7 +6,12 @@ on: - main jobs: - build: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: psf/black@stable + test: runs-on: ${{ matrix.os }} strategy: fail-fast: false @@ -29,4 +34,4 @@ jobs: python -m pip install . - name: Test with pytest run: | - pytest -v + pytest -v --flake8 --pydocstyle --cov=hiclass --cov-fail-under=90 --cov-report html diff --git a/README.md b/README.md index 35fe8de4..cec2b2bf 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ HiClass is an open-source Python library for hierarchical classification compatible with scikit-learn. -[![Deploy PyPI](https://github.com/mirand863/hiclass/actions/workflows/deploy-pypi.yml/badge.svg?event=push)](https://github.com/mirand863/hiclass/actions/workflows/deploy-pypi.yml) [![Documentation Status](https://readthedocs.org/projects/hiclass/badge/?version=latest)](https://hiclass.readthedocs.io/en/latest/?badge=latest) [![codecov](https://codecov.io/gh/mirand863/hiclass/branch/main/graph/badge.svg?token=PR8VLBMMNR)](https://codecov.io/gh/mirand863/hiclass) [![Downloads PyPI](https://static.pepy.tech/personalized-badge/hiclass?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=pypi)](https://pypi.org/project/hiclass/) [![Downloads Conda](https://img.shields.io/conda/dn/conda-forge/hiclass?label=conda)](https://anaconda.org/conda-forge/hiclass) [![License](https://img.shields.io/badge/License-BSD_3--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause) +[![Deploy PyPI](https://github.com/mirand863/hiclass/actions/workflows/deploy-pypi.yml/badge.svg?event=push)](https://github.com/mirand863/hiclass/actions/workflows/deploy-pypi.yml) [![Documentation Status](https://readthedocs.org/projects/hiclass/badge/?version=latest)](https://hiclass.readthedocs.io/en/latest/?badge=latest) [![codecov](https://codecov.io/gh/mirand863/hiclass/branch/main/graph/badge.svg?token=PR8VLBMMNR)](https://codecov.io/gh/mirand863/hiclass) [![Downloads PyPI](https://static.pepy.tech/personalized-badge/hiclass?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=pypi)](https://pypi.org/project/hiclass/) [![Downloads Conda](https://img.shields.io/conda/dn/conda-forge/hiclass?label=conda)](https://anaconda.org/conda-forge/hiclass) [![License](https://img.shields.io/badge/License-BSD_3--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) ✨ Here is a **demo** that shows HiClass in action on hierarchical data: @@ -16,7 +16,7 @@ HiClass is an open-source Python library for hierarchical classification compati - [Who is using HiClass?](#who-is-using-hiclass) - [Install](#install) - [Quick start](#quick-start) -- [Step-by-step- walk-through](#step-by-step-walk-through) +- [Step-by-step 
walk-through](#step-by-step-walk-through) - [API documentation](#api-documentation) - [FAQ](#faq) - [Support](#support) @@ -34,7 +34,7 @@ HiClass is an open-source Python library for hierarchical classification compati - **Hierarchical metrics:** HiClass supports the computation of hierarchical precision, recall and f-score, which are more appropriate for hierarchical data than traditional metrics. - **Compatible with pickle:** Easily store trained models on disk for future use. -**Don't see a feature on this list?** Search our [issue tracker](https://github.com/mirand863/hiclass/issues) if someone has already requested it and add a comment to it explaining your use-case, or open a new issue if not. We prioritize our roadmap based on user feedback, so we'd love to hear from you. +**Any feature missing on this list?** Search our [issue tracker](https://github.com/mirand863/hiclass/issues) to see if someone has already requested it and add a comment to it explaining your use-case. Otherwise, please open a new issue describing the requested feature and possible use-case scenario. We prioritize our roadmap based on user feedback, so we would love to hear from you. ## Benchmarks @@ -85,7 +85,7 @@ We would love to benchmark with larger datasets, if we can find them in the publ Here is our public roadmap: https://github.com/mirand863/hiclass/projects/1. -We do Just-In-Time planning, and we tend to reprioritize based on your feedback. Hence, items you see on this roadmap are subject to change. We prioritize features based on the number of people asking for it, features/fixes that are small enough and can be addressed while we work on other related features, features/fixes that help improve stability & relevance and features that address interesting use cases that excite us! If you'd like to have a request prioritized, we ask that you add a detailed use-case for it, either as a comment on an existing issue (besides a thumbs-up) or in a new issue. The detailed context helps. +We do Just-In-Time planning, and we tend to reprioritize based on your feedback. Hence, items you see on this roadmap are subject to change. We prioritize features based on the number of people asking for it, features/fixes that are small enough and can be addressed while we work on other related features, features/fixes that help improve stability & relevance and features that address interesting use cases that excite us! If you would like to have a request prioritized, we ask that you add a detailed use-case for it, either as a comment on an existing issue (besides a thumbs-up) or in a new issue. The detailed context helps. ## Who is using HiClass? @@ -123,7 +123,7 @@ Here's a quick example showcasing how you can train and predict using a local cl from hiclass import LocalClassifierPerNode from sklearn.ensemble import RandomForestClassifier -# define data +# Define data X_train = [[1], [2], [3], [4]] X_test = [[4], [3], [2], [1]] Y_train = [ @@ -152,7 +152,7 @@ from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer from sklearn.linear_model import LogisticRegression from sklearn.pipeline import Pipeline -# define data +# Define data X_train = [ 'Struggling to repay loan', 'Unable to get annual report', @@ -220,7 +220,9 @@ Please reach out to fabio.malchermiranda@hpi.de. ## Contributing -We are a small team on a mission to democratize hierarchical classification, and we'll take all the help we can get! 
If you'd like to get involved, here's information on [contribution guidelines and how to test the code locally](https://github.com/mirand863/hiclass/blob/main/CONTRIBUTING.md).
+We are a small team on a mission to democratize hierarchical classification, and we will take all the help we can get! If you would like to get involved, here is information on [contribution guidelines and how to test the code locally](https://github.com/mirand863/hiclass/blob/main/CONTRIBUTING.md).
+
+You can contribute in multiple ways, e.g., reporting bugs, writing or translating documentation, reviewing or refactoring code, requesting or implementing new features, etc.
 
 ## Getting the latest updates
 
diff --git a/docs/examples/README.rst b/docs/examples/README.rst
index 90ce2230..7cc24de4 100644
--- a/docs/examples/README.rst
+++ b/docs/examples/README.rst
@@ -1,4 +1,7 @@
 Gallery of Examples
 ===================
 
-These examples illustrate the main features of HiClass.
\ No newline at end of file
+These examples illustrate the main features of HiClass.
+
+.. toctree::
+   :hidden:
diff --git a/docs/examples/plot_empty_levels.py b/docs/examples/plot_empty_levels.py
new file mode 100644
index 00000000..f3af3469
--- /dev/null
+++ b/docs/examples/plot_empty_levels.py
@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+"""
+==========================
+Different Number of Levels
+==========================
+
+HiClass supports different numbers of levels in the hierarchy.
+For this example, we will train a local classifier per node
+with a hierarchy similar to the following image:
+
+.. figure:: ../algorithms/local_classifier_per_node.svg
+   :align: center
+"""
+from sklearn.linear_model import LogisticRegression
+
+from hiclass import LocalClassifierPerNode
+
+# Define data
+X_train = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
+X_test = [[9, 10], [7, 8], [5, 6], [3, 4], [1, 2]]
+Y_train = [
+    ["Bird"],
+    ["Reptile", "Snake"],
+    ["Reptile", "Lizard"],
+    ["Mammal", "Cat"],
+    ["Mammal", "Wolf", "Dog"],
+]
+
+# Use logistic regression classifiers for every node
+lr = LogisticRegression()
+classifier = LocalClassifierPerNode(local_classifier=lr)
+
+# Train local classifier per node
+classifier.fit(X_train, Y_train)
+
+# Predict
+predictions = classifier.predict(X_test)
+print(predictions)
diff --git a/docs/examples/plot_parallel_training.py b/docs/examples/plot_parallel_training.py
index e90712af..6ecbab09 100644
--- a/docs/examples/plot_parallel_training.py
+++ b/docs/examples/plot_parallel_training.py
@@ -7,8 +7,9 @@
 Larger datasets require more time for training.
 While by default the models in HiClass are trained using a single core,
 it is possible to train each local classifier in parallel by leveraging the library Ray [1]_.
-In this example, we demonstrate how to train a hierarchical classifier in parallel,
-using all the cores available, on a mock dataset from Kaggle [2]_.
+In this example, we demonstrate how to train a hierarchical classifier in parallel by
+setting the parameter :literal:`n_jobs` to use all the cores available. Training
+is performed on a mock dataset from Kaggle [2]_.
 
 .. [1] https://www.ray.io/
 .. [2] https://www.kaggle.com/datasets/kashnitsky/hierarchical-text-classification
@@ -25,29 +26,15 @@
 from hiclass import LocalClassifierPerParentNode
 
 
-def download(url: str, path: str) -> None:
-    """
-    Download a file from the internet.
-
-    Parameters
-    ----------
-    url : str
-        The address of the file to be downloaded.
-    path : str
-        The path to store the downloaded file. 
- """ - response = requests.get(url) - with open(path, "wb") as file: - file.write(response.content) - - # Download training data -training_data_url = "https://zenodo.org/record/6657410/files/train_40k.csv?download=1" -training_data_path = "train_40k.csv" -download(training_data_url, training_data_path) +url = "https://zenodo.org/record/6657410/files/train_40k.csv?download=1" +path = "train_40k.csv" +response = requests.get(url) +with open(path, "wb") as file: + file.write(response.content) # Load training data into pandas dataframe -training_data = pd.read_csv(training_data_path).fillna(" ") +training_data = pd.read_csv(path).fillna(" ") # We will use logistic regression classifiers for every parent node lr = LogisticRegression(max_iter=1000) diff --git a/docs/source/conf.py b/docs/source/conf.py index b39f35cd..aecd85fd 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -13,17 +13,18 @@ # import os import sys -sys.path.insert(0, os.path.abspath('./../..')) -sys.path.insert(0, os.path.abspath('./../../hiclass')) + +sys.path.insert(0, os.path.abspath("./../..")) +sys.path.insert(0, os.path.abspath("./../../hiclass")) print(sys.path) import sphinx_code_tabs # -- Project information ----------------------------------------------------- -project = 'hiclass' -copyright = '2022, Fabio Malcher Miranda, Niklas Köhnecke' -author = 'Fabio Malcher Miranda, Niklas Köhnecke' +project = "hiclass" +copyright = "2022, Fabio Malcher Miranda, Niklas Köhnecke" +author = "Fabio Malcher Miranda, Niklas Köhnecke" # -- General configuration --------------------------------------------------- @@ -32,15 +33,15 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.napoleon', - 'sphinx.ext.autosectionlabel', - 'sphinx_code_tabs', - 'sphinx_gallery.gen_gallery', + "sphinx.ext.autodoc", + "sphinx.ext.napoleon", + "sphinx.ext.autosectionlabel", + "sphinx_code_tabs", + "sphinx_gallery.gen_gallery", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. @@ -55,12 +56,13 @@ use_rtd_scheme = False try: import sphinx_rtd_theme + extensions.extend(["sphinx_rtd_theme"]) use_rtd_scheme = True except ImportError: print("sphinx_rtd_theme was not installed, using alabaster as fallback!") -html_theme = 'sphinx_rtd_theme' if use_rtd_scheme else 'alabaster' +html_theme = "sphinx_rtd_theme" if use_rtd_scheme else "alabaster" # Add any paths that contain custom static files (such as style sheets) here, @@ -76,6 +78,6 @@ html_theme_options["sidebar_width"] = "230px" sphinx_gallery_conf = { - 'examples_dirs': '../examples', - 'gallery_dirs': 'auto_examples', -} \ No newline at end of file + "examples_dirs": "../examples", + "gallery_dirs": "auto_examples", +} diff --git a/docs/source/index.rst b/docs/source/index.rst index 443dd6a2..ccfe48cb 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -30,15 +30,15 @@ Welcome to hiclass' documentation! :target: https://opensource.org/licenses/BSD-3-Clause :alt: License +.. image:: https://img.shields.io/badge/code%20style-black-000000.svg + :target: https://github.com/psf/black + .. toctree:: - :titlesonly: + :includehidden: + :maxdepth: 3 introduction/index get_started/index auto_examples/index algorithms/index - -.. 
toctree::
-   :maxdepth: 3
-
-   api/index
+   api/index
diff --git a/hiclass/HierarchicalClassifier.py b/hiclass/HierarchicalClassifier.py
index bf5633bc..1075877c 100644
--- a/hiclass/HierarchicalClassifier.py
+++ b/hiclass/HierarchicalClassifier.py
@@ -4,9 +4,44 @@
 
 import networkx as nx
 import numpy as np
+import ray
 from sklearn.base import BaseEstimator
 from sklearn.linear_model import LogisticRegression
-from sklearn.utils.validation import check_X_y
+
+
+def make_leveled(y):
+    """
+    Add empty cells so that all rows have the same number of columns.
+
+    Parameters
+    ----------
+    y : array-like of shape (n_samples, n_levels)
+        The target values, i.e., hierarchical class labels for classification.
+
+    Returns
+    -------
+    leveled_y : array-like of shape (n_samples, n_levels)
+        The leveled target values, i.e., hierarchical class labels for classification.
+
+    Notes
+    -----
+    If rows are not iterable, returns the current y without modifications.
+
+    Examples
+    --------
+    >>> from hiclass.HierarchicalClassifier import make_leveled
+    >>> y = [['a'], ['b', 'c']]
+    >>> make_leveled(y)
+    array([['a', ''],
+           ['b', 'c']])
+    """
+    try:
+        depth = max([len(row) for row in y])
+    except TypeError:
+        return y
+    y = np.array(y)
+    leveled_y = [[i for i in row] + [""] * (depth - len(row)) for row in y]
+    return np.array(leveled_y)
 
 
 class HierarchicalClassifier(abc.ABC):
@@ -74,10 +109,7 @@ def fit(self, X, y):
             Fitted estimator.
         """
         # Fit local classifiers in DAG
-        if self.n_jobs > 1:
-            self._fit_digraph_parallel()
-        else:
-            self._fit_digraph()
+        self._fit_digraph()
 
         # Delete unnecessary variables
         self._clean_up()
@@ -85,10 +117,13 @@ def _pre_fit(self, X, y):
         # Check that X and y have correct shape
         # and convert them to np.ndarray if need be
+
         self.X_, self.y_ = self._validate_data(
             X, y, multi_output=True, accept_sparse="csr"
         )
+
+        self.y_ = make_leveled(self.y_)
+
         # Create and configure logger
         self._create_logger()
@@ -140,9 +175,11 @@ def _disambiguate(self):
         if self.y_.ndim == 2:
             new_y = []
             for i in range(self.y_.shape[0]):
-                row = [self.y_[i, 0]]
+                row = [str(self.y_[i, 0])]
                 for j in range(1, self.y_.shape[1]):
-                    row.append(str(row[-1]) + self.separator_ + str(self.y_[i, j]))
+                    parent = str(row[-1])
+                    child = str(self.y_[i, j])
+                    row.append(parent + self.separator_ + child)
                 new_y.append(np.asarray(row, dtype=np.str_))
             self.y_ = np.array(new_y)
@@ -153,38 +190,46 @@ def _create_digraph(self):
         # Save dtype of y_
         self.dtype_ = self.y_.dtype
 
-        # 1D disguised as 2D
+        self._create_digraph_1d()
+
+        self._create_digraph_2d()
+
+        if self.y_.ndim > 2:
+            # Unsupported dimension
+            self.logger_.error(f"y with {self.y_.ndim} dimensions detected")
+            raise ValueError(
+                f"Creating graph from y with {self.y_.ndim} dimensions is not supported"
+            )
+
+    def _create_digraph_1d(self):
+        # Flatten 1D disguised as 2D
         if self.y_.ndim == 2 and self.y_.shape[1] == 1:
             self.logger_.info("Converting y to 1D")
             self.y_ = self.y_.flatten()
+        if self.y_.ndim == 1:
+            # Create max_levels_ variable
+            self.max_levels_ = 1
+            self.logger_.info(f"Creating digraph from {self.y_.size} 1D labels")
+            for label in self.y_:
+                self.hierarchy_.add_node(label)
 
-        # Check dimension of labels
+    def _create_digraph_2d(self):
         if self.y_.ndim == 2:
-            # 2D labels
             # Create max_levels variable
             self.max_levels_ = self.y_.shape[1]
             rows, columns = self.y_.shape
             self.logger_.info(f"Creating digraph from {rows} 2D labels")
             for row in range(rows):
                 for column in range(columns - 1):
-                    self.hierarchy_.add_edge(
-                        self.y_[row, column], self.y_[row, column + 1]
-                    )
-
-        elif 
self.y_.ndim == 1: - # 1D labels - # Create max_levels_ variable - self.max_levels_ = 1 - self.logger_.info(f"Creating digraph from {self.y_.size} 1D labels") - for label in self.y_: - self.hierarchy_.add_node(label) - - else: - # Unsuported dimension - self.logger_.error(f"y with {self.y_.ndim} dimensions detected") - raise ValueError( - f"Creating graph from y with {self.y_.ndim} dimensions is not supported" - ) + parent = self.y_[row, column].split(self.separator_)[-1] + child = self.y_[row, column + 1].split(self.separator_)[-1] + if parent != "" and child != "": + # Only add edge if both parent and child are not empty + self.hierarchy_.add_edge( + self.y_[row, column], self.y_[row, column + 1] + ) + elif parent != "" and column == 0: + self.hierarchy_.add_node(parent) def _export_digraph(self): # Check if edge_list is set @@ -230,6 +275,37 @@ def _initialize_local_classifiers(self): else: self.local_classifier_ = self.local_classifier + def _convert_to_1d(self, y): + # Convert predictions to 1D if there is only 1 column + if self.max_levels_ == 1: + y = y.flatten() + return y + + def _remove_separator(self, y): + # Remove separator from predictions + if y.ndim == 2: + for i in range(y.shape[0]): + for j in range(1, y.shape[1]): + y[i, j] = y[i, j].split(self.separator_)[-1] + + def _fit_node_classifier(self, nodes, local_mode): + if self.n_jobs > 1: + ray.init( + num_cpus=self.n_jobs, local_mode=local_mode, ignore_reinit_error=True + ) + lcppn = ray.put(self) + _parallel_fit = ray.remote(self._fit_classifier) + results = [_parallel_fit.remote(lcppn, node) for node in nodes] + classifiers = ray.get(results) + else: + classifiers = [self._fit_classifier(self, node) for node in nodes] + for classifier, node in zip(classifiers, nodes): + self.hierarchy_.nodes[node]["classifier"] = classifier + + @staticmethod + def _fit_classifier(self, node): + raise NotImplementedError("Method should be implemented in the LCPN and LCPPN") + def _clean_up(self): self.logger_.info("Cleaning up variables that can take a lot of disk space") del self.X_ diff --git a/hiclass/LocalClassifierPerLevel.py b/hiclass/LocalClassifierPerLevel.py index 7bc6a239..6db1ab94 100644 --- a/hiclass/LocalClassifierPerLevel.py +++ b/hiclass/LocalClassifierPerLevel.py @@ -8,31 +8,29 @@ import numpy as np import ray from sklearn.base import BaseEstimator -from sklearn.metrics import euclidean_distances from sklearn.utils.validation import check_array, check_is_fitted from hiclass.ConstantClassifier import ConstantClassifier from hiclass.HierarchicalClassifier import HierarchicalClassifier -@ray.remote -def _parallel_fit(lcpl, level): - classifier = lcpl.local_classifiers_[level] - X = lcpl.X_ - y = lcpl.y_[:, level] - unique_y = np.unique(y) - if len(unique_y) == 1 and lcpl.replace_classifiers: - classifier = ConstantClassifier() - classifier.fit(X, y) - return classifier - - class LocalClassifierPerLevel(BaseEstimator, HierarchicalClassifier): """ Assign local classifiers to each level of the hierarchy, except the root node. A local classifier per level is a local hierarchical classifier that fits one local multi-class classifier for each level of the class hierarchy, except for the root node. 
+ + Examples + -------- + >>> from hiclass import LocalClassifierPerLevel + >>> y = [['1', '1.1'], ['2', '2.1']] + >>> X = [[1, 2], [3, 4]] + >>> lcpl = LocalClassifierPerLevel() + >>> lcpl.fit(X, y) + >>> lcpl.predict(X) + array([['1', '1.1'], + ['2', '2.1']]) """ def __init__( @@ -102,8 +100,6 @@ def fit(self, X, y): # TODO: Add parameter to receive hierarchy as parameter in constructor - # TODO: Add support to empty labels in some levels - # Return the classifier return self @@ -130,100 +126,107 @@ def predict(self, X): # Input validation X = check_array(X, accept_sparse="csr") + # Initialize array that holds predictions y = np.empty((X.shape[0], self.max_levels_), dtype=self.dtype_) # TODO: Add threshold to stop prediction halfway if need be self.logger_.info("Predicting") - for level, classifier in enumerate(self.local_classifiers_): - self.logger_.info(f"Predicting level {level}") - if level == 0: - y[:, level] = classifier.predict(X).flatten() - else: - all_probabilities = classifier.predict_proba(X) - successors = np.array( - [ - list(self.hierarchy_.successors(node)) - for node in y[:, level - 1] - ], - dtype=object, - ) - classes_masks = np.array( - [ - np.isin(classifier.classes_, successors[i]) - for i in range(len(successors)) - ] - ) - probabilities = np.array( - [ - all_probabilities[i, classes_masks[i]] - for i in range(len(classes_masks)) - ], - dtype=object, - ) - highest_probabilities = [ - np.argmax(probabilities[i], axis=0) - for i in range(len(probabilities)) - ] - classes = np.array( - [ - classifier.classes_[classes_masks[i]] - for i in range(len(classes_masks)) - ], - dtype=object, - ) - predictions = np.array( - [ - classes[i][highest_probabilities[i]] - for i in range(len(highest_probabilities)) - ] - ) - y[:, level] = predictions - - # Convert back to 1D if there is only 1 column to pass all sklearn's checks - if self.max_levels_ == 1: - y = y.flatten() - - # Remove separator from predictions - if y.ndim == 2: - for i in range(y.shape[0]): - for j in range(1, y.shape[1]): - y[i, j] = y[i, j].split(self.separator_)[-1] + # Predict first level + classifier = self.local_classifiers_[0] + y[:, 0] = classifier.predict(X).flatten() + + self._predict_remaining_levels(X, y) + + y = self._convert_to_1d(y) + + self._remove_separator(y) return y + def _predict_remaining_levels(self, X, y): + for level in range(1, y.shape[1]): + classifier = self.local_classifiers_[level] + probabilities = classifier.predict_proba(X) + classes = self.local_classifiers_[level].classes_ + probabilities_dict = [dict(zip(classes, prob)) for prob in probabilities] + successors = self._get_successors(y[:, level - 1]) + successors_prob = self._get_successors_probability( + probabilities_dict, successors + ) + index_max_probability = [ + np.argmax(prob) if len(prob) > 0 else None for prob in successors_prob + ] + y[:, level] = [ + successors_list[index_max_probability[i]] + if index_max_probability[i] is not None + else "" + for i, successors_list in enumerate(successors) + ] + + @staticmethod + def _get_successors_probability(probabilities_dict, successors): + successors_probability = [ + np.array( + [probabilities_dict[i][successor] for successor in successors_list] + ) + for i, successors_list in enumerate(successors) + ] + return successors_probability + + def _get_successors(self, level): + successors = [ + list(self.hierarchy_.successors(node)) + if self.hierarchy_.has_node(node) + else [] + for node in level + ] + return successors + def _initialize_local_classifiers(self): 
super()._initialize_local_classifiers() self.local_classifiers_ = [ deepcopy(self.local_classifier_) for _ in range(self.y_.shape[1]) ] + self.masks_ = [None for _ in range(self.y_.shape[1])] - def _fit_digraph(self): + def _fit_digraph(self, local_mode: bool = False): self.logger_.info("Fitting local classifiers") - for level, classifier in enumerate(self.local_classifiers_): - self.logger_.info( - f"Fitting local classifier for level '{level + 1}' ({level + 1}/{len(self.local_classifiers_)})" + if self.n_jobs > 1: + ray.init( + num_cpus=self.n_jobs, local_mode=local_mode, ignore_reinit_error=True ) - X = self.X_ - y = self.y_[:, level] - unique_y = np.unique(y) - if len(unique_y) == 1 and self.replace_classifiers: - self.logger_.warning( - f"Fitting ConstantClassifier for level '{level + 1}'" - ) - self.local_classifiers_[level] = ConstantClassifier() - classifier = self.local_classifiers_[level] - classifier.fit(X, y) - - def _fit_digraph_parallel(self, local_mode: bool = False): - self.logger_.info("Fitting local classifiers") - ray.init(num_cpus=self.n_jobs, local_mode=local_mode, ignore_reinit_error=True) - lcpl = ray.put(self) - results = [ - _parallel_fit.remote(lcpl, level) - for level in range(len(self.local_classifiers_)) - ] - classifiers = ray.get(results) + lcpl = ray.put(self) + _parallel_fit = ray.remote(self._fit_classifier) + results = [ + _parallel_fit.remote(lcpl, level, self.separator_) + for level in range(len(self.local_classifiers_)) + ] + classifiers = ray.get(results) + else: + classifiers = [ + self._fit_classifier(self, level, self.separator_) + for level in range(len(self.local_classifiers_)) + ] for level, classifier in enumerate(classifiers): self.local_classifiers_[level] = classifier + + @staticmethod + def _fit_classifier(self, level, separator): + classifier = self.local_classifiers_[level] + + X, y = self._remove_empty_leaves(separator, self.X_, self.y_[:, level]) + + unique_y = np.unique(y) + if len(unique_y) == 1 and self.replace_classifiers: + classifier = ConstantClassifier() + classifier.fit(X, y) + return classifier + + @staticmethod + def _remove_empty_leaves(separator, X, y): + # Detect rows where leaves are not empty + leaves = np.array([str(i).split(separator)[-1] for i in y]) + mask = leaves != "" + return X[mask], y[mask] diff --git a/hiclass/LocalClassifierPerNode.py b/hiclass/LocalClassifierPerNode.py index 78a5a3f9..1bdcf705 100644 --- a/hiclass/LocalClassifierPerNode.py +++ b/hiclass/LocalClassifierPerNode.py @@ -7,7 +7,6 @@ import networkx as nx import numpy as np -import ray from sklearn.base import BaseEstimator from sklearn.utils.validation import check_array, check_is_fitted @@ -16,23 +15,23 @@ from hiclass.HierarchicalClassifier import HierarchicalClassifier -@ray.remote -def _parallel_fit(lcpn, node): - classifier = lcpn.hierarchy_.nodes[node]["classifier"] - X, y = lcpn.binary_policy_.get_binary_examples(node) - unique_y = np.unique(y) - if len(unique_y) == 1 and lcpn.replace_classifiers: - classifier = ConstantClassifier() - classifier.fit(X, y) - return classifier - - class LocalClassifierPerNode(BaseEstimator, HierarchicalClassifier): """ Assign local classifiers to each node of the graph, except the root node. A local classifier per node is a local hierarchical classifier that fits one local binary classifier for each node of the class hierarchy, except for the root node. 
+ + Examples + -------- + >>> from hiclass import LocalClassifierPerNode + >>> y = [['1', '1.1'], ['2', '2.1']] + >>> X = [[1, 2], [3, 4]] + >>> lcpn = LocalClassifierPerNode() + >>> lcpn.fit(X, y) + >>> lcpn.predict(X) + array([['1', '1.1'], + ['2', '2.1']]) """ def __init__( @@ -109,8 +108,6 @@ def fit(self, X, y): # TODO: Add parameter to receive hierarchy as parameter in constructor - # TODO: Add support to empty labels in some levels - # Return the classifier return self @@ -137,6 +134,7 @@ def predict(self, X): # Input validation X = check_array(X, accept_sparse="csr") + # Initialize array that holds predictions y = np.empty((X.shape[0], self.max_levels_), dtype=self.dtype_) # TODO: Add threshold to stop prediction halfway if need be @@ -172,15 +170,9 @@ def predict(self, X): prediction = np.array(prediction) y[mask, level] = prediction - # Convert back to 1D if there is only 1 column to pass all sklearn's checks - if self.max_levels_ == 1: - y = y.flatten() + y = self._convert_to_1d(y) - # Remove separator from predictions - if y.ndim == 2: - for i in range(y.shape[0]): - for j in range(1, y.shape[1]): - y[i, j] = y[i, j].split(self.separator_)[-1] + self._remove_separator(y) return y @@ -214,39 +206,22 @@ def _initialize_local_classifiers(self): } nx.set_node_attributes(self.hierarchy_, local_classifiers) - def _fit_digraph_parallel(self, local_mode: bool = False): + def _fit_digraph(self, local_mode: bool = False): self.logger_.info("Fitting local classifiers") - ray.init(num_cpus=self.n_jobs, local_mode=local_mode, ignore_reinit_error=True) nodes = list(self.hierarchy_.nodes) # Remove root because it does not need to be fitted nodes.remove(self.root_) - lcpn = ray.put(self) - results = [_parallel_fit.remote(lcpn, node) for node in nodes] - classifiers = ray.get(results) - for classifier, node in zip(classifiers, nodes): - self.hierarchy_.nodes[node]["classifier"] = classifier - - def _fit_digraph(self): - self.logger_.info("Fitting local classifiers") - nodes = list(self.hierarchy_.nodes) - # Remove root because it does not need to be fitted - nodes.remove(self.root_) - for index, node in enumerate(nodes): - node_name = str(node).split(self.separator_)[-1] - self.logger_.info( - f"Fitting local classifier for node '{node_name}' ({index + 1}/{len(nodes)})" - ) - classifier = self.hierarchy_.nodes[node]["classifier"] - X, y = self.binary_policy_.get_binary_examples(node) - unique_y = np.unique(y) - if len(unique_y) == 1 and self.replace_classifiers: - node_name = str(node).split(self.separator_)[-1] - self.logger_.warning( - f"Fitting ConstantClassifier for node '{node_name}'" - ) - self.hierarchy_.nodes[node]["classifier"] = ConstantClassifier() - classifier = self.hierarchy_.nodes[node]["classifier"] - classifier.fit(X, y) + self._fit_node_classifier(nodes, local_mode) + + @staticmethod + def _fit_classifier(self, node): + classifier = self.hierarchy_.nodes[node]["classifier"] + X, y = self.binary_policy_.get_binary_examples(node) + unique_y = np.unique(y) + if len(unique_y) == 1 and self.replace_classifiers: + classifier = ConstantClassifier() + classifier.fit(X, y) + return classifier def _clean_up(self): super()._clean_up() diff --git a/hiclass/LocalClassifierPerParentNode.py b/hiclass/LocalClassifierPerParentNode.py index 72ac398e..b7a8e5f8 100644 --- a/hiclass/LocalClassifierPerParentNode.py +++ b/hiclass/LocalClassifierPerParentNode.py @@ -7,7 +7,6 @@ import networkx as nx import numpy as np -import ray from sklearn.base import BaseEstimator from sklearn.utils.validation 
import check_array, check_is_fitted @@ -15,24 +14,23 @@ from hiclass.HierarchicalClassifier import HierarchicalClassifier -@ray.remote -def _parallel_fit(lcppn, node): - classifier = lcppn.hierarchy_.nodes[node]["classifier"] - # get children examples - X, y = lcppn._get_successors(node) - unique_y = np.unique(y) - if len(unique_y) == 1 and lcppn.replace_classifiers: - classifier = ConstantClassifier() - classifier.fit(X, y) - return classifier - - class LocalClassifierPerParentNode(BaseEstimator, HierarchicalClassifier): """ Assign local classifiers to each parent node of the graph. A local classifier per parent node is a local hierarchical classifier that fits one multi-class classifier for each parent node of the class hierarchy. + + Examples + -------- + >>> from hiclass import LocalClassifierPerParentNode + >>> y = [['1', '1.1'], ['2', '2.1']] + >>> X = [[1, 2], [3, 4]] + >>> lcppn = LocalClassifierPerParentNode() + >>> lcppn.fit(X, y) + >>> lcppn.predict(X) + array([['1', '1.1'], + ['2', '2.1']]) """ def __init__( @@ -102,8 +100,6 @@ def fit(self, X, y): # TODO: Add parameter to receive hierarchy as parameter in constructor - # TODO: Add support to empty labels in some levels - # Return the classifier return self @@ -130,43 +126,38 @@ def predict(self, X): # Input validation X = check_array(X, accept_sparse="csr") + # Initialize array that holds predictions y = np.empty((X.shape[0], self.max_levels_), dtype=self.dtype_) # TODO: Add threshold to stop prediction halfway if need be - bfs = nx.bfs_successors(self.hierarchy_, source=self.root_) - self.logger_.info("Predicting") - for predecessor, successors in bfs: - if predecessor == self.root_: - mask = [True] * X.shape[0] - subset_x = X[mask] - else: - mask = np.isin(y, predecessor).any(axis=1) - subset_x = X[mask] - if subset_x.shape[0] > 0: - classifier = self.hierarchy_.nodes[predecessor]["classifier"] - prediction = classifier.predict(subset_x) - level = nx.shortest_path_length( - self.hierarchy_, self.root_, predecessor - ) - if prediction.ndim == 2 and prediction.shape[1] == 1: - prediction = prediction.flatten() - y[mask, level] = prediction - - # Convert back to 1D if there is only 1 column to pass all sklearn's checks - if self.max_levels_ == 1: - y = y.flatten() - - # Remove separator from predictions - if y.ndim == 2: - for i in range(y.shape[0]): - for j in range(1, y.shape[1]): - y[i, j] = y[i, j].split(self.separator_)[-1] + # Predict first level + classifier = self.hierarchy_.nodes[self.root_]["classifier"] + y[:, 0] = classifier.predict(X).flatten() + + self._predict_remaining_levels(X, y) + + y = self._convert_to_1d(y) + + self._remove_separator(y) return y + def _predict_remaining_levels(self, X, y): + for level in range(1, y.shape[1]): + predecessors = set(y[:, level - 1]) + predecessors.discard("") + for predecessor in predecessors: + mask = np.isin(y[:, level - 1], predecessor) + predecessor_x = X[mask] + if predecessor_x.shape[0] > 0: + successors = list(self.hierarchy_.successors(predecessor)) + if len(successors) > 0: + classifier = self.hierarchy_.nodes[predecessor]["classifier"] + y[mask, level] = classifier.predict(predecessor_x).flatten() + def _initialize_local_classifiers(self): super()._initialize_local_classifiers() local_classifiers = {} @@ -197,33 +188,18 @@ def _get_successors(self, node): y = np.array(y) return X, y - def _fit_digraph_parallel(self, local_mode: bool = False): - self.logger_.info("Fitting local classifiers") - ray.init(num_cpus=self.n_jobs, local_mode=local_mode, 
ignore_reinit_error=True) - nodes = self._get_parents() - lcppn = ray.put(self) - results = [_parallel_fit.remote(lcppn, node) for node in nodes] - classifiers = ray.get(results) - for classifier, node in zip(classifiers, nodes): - self.hierarchy_.nodes[node]["classifier"] = classifier - - def _fit_digraph(self): + @staticmethod + def _fit_classifier(self, node): + classifier = self.hierarchy_.nodes[node]["classifier"] + # get children examples + X, y = self._get_successors(node) + unique_y = np.unique(y) + if len(unique_y) == 1 and self.replace_classifiers: + classifier = ConstantClassifier() + classifier.fit(X, y) + return classifier + + def _fit_digraph(self, local_mode: bool = False): self.logger_.info("Fitting local classifiers") nodes = self._get_parents() - for index, node in enumerate(nodes): - node_name = str(node).split(self.separator_)[-1] - self.logger_.info( - f"Fitting local classifier for node '{node_name}' ({index + 1}/{len(nodes)})" - ) - classifier = self.hierarchy_.nodes[node]["classifier"] - # get children examples - X, y = self._get_successors(node) - unique_y = np.unique(y) - if len(unique_y) == 1 and self.replace_classifiers: - node_name = str(node).split(self.separator_)[-1] - self.logger_.warning( - f"Fitting ConstantClassifier for node '{node_name}'" - ) - self.hierarchy_.nodes[node]["classifier"] = ConstantClassifier() - classifier = self.hierarchy_.nodes[node]["classifier"] - classifier.fit(X, y) + self._fit_node_classifier(nodes, local_mode) diff --git a/hiclass/metrics.py b/hiclass/metrics.py index 46efb9af..d8926700 100644 --- a/hiclass/metrics.py +++ b/hiclass/metrics.py @@ -2,6 +2,17 @@ import numpy as np from sklearn.utils import check_array +from hiclass.HierarchicalClassifier import make_leveled + + +def _validate_input(y_true, y_pred): + assert len(y_true) == len(y_pred) + y_pred = make_leveled(y_pred) + y_true = make_leveled(y_true) + y_true = check_array(y_true, dtype=None) + y_pred = check_array(y_pred, dtype=None) + return y_true, y_pred + def precision(y_true: np.ndarray, y_pred: np.ndarray): r""" @@ -24,17 +35,19 @@ def precision(y_true: np.ndarray, y_pred: np.ndarray): precision : float What proportion of positive identifications was actually correct? """ - assert len(y_true) == len(y_pred) - y_true = check_array(y_true, dtype=None) - y_pred = check_array(y_pred, dtype=None) + y_true, y_pred = _validate_input(y_true, y_pred) sum_intersection = 0 sum_prediction_and_ancestors = 0 for ground_truth, prediction in zip(y_true, y_pred): + ground_truth_set = set(ground_truth) + ground_truth_set.discard("") + prediction_set = set(prediction) + prediction_set.discard("") sum_intersection = sum_intersection + len( - set(ground_truth).intersection(set(prediction)) + ground_truth_set.intersection(prediction_set) ) sum_prediction_and_ancestors = sum_prediction_and_ancestors + len( - set(prediction) + prediction_set ) precision = sum_intersection / sum_prediction_and_ancestors return precision @@ -61,17 +74,19 @@ def recall(y_true: np.ndarray, y_pred: np.ndarray): recall : float What proportion of actual positives was identified correctly? 
""" - assert len(y_true) == len(y_pred) - y_true = check_array(y_true, dtype=None) - y_pred = check_array(y_pred, dtype=None) + y_true, y_pred = _validate_input(y_true, y_pred) sum_intersection = 0 sum_prediction_and_ancestors = 0 for ground_truth, prediction in zip(y_true, y_pred): + ground_truth_set = set(ground_truth) + ground_truth_set.discard("") + prediction_set = set(prediction) + prediction_set.discard("") sum_intersection = sum_intersection + len( - set(ground_truth).intersection(set(prediction)) + ground_truth_set.intersection(prediction_set) ) sum_prediction_and_ancestors = sum_prediction_and_ancestors + len( - set(ground_truth) + ground_truth_set ) recall = sum_intersection / sum_prediction_and_ancestors return recall @@ -95,9 +110,7 @@ def f1(y_true: np.ndarray, y_pred: np.ndarray): f1 : float Weighted average of the precision and recall """ - assert len(y_true) == len(y_pred) - y_true = check_array(y_true, dtype=None) - y_pred = check_array(y_pred, dtype=None) + y_true, y_pred = _validate_input(y_true, y_pred) prec = precision(y_true, y_pred) rec = recall(y_true, y_pred) f1 = 2 * prec * rec / (prec + rec) diff --git a/setup.cfg b/setup.cfg index 777b5950..df142dd7 100755 --- a/setup.cfg +++ b/setup.cfg @@ -1,11 +1,7 @@ [tool:pytest] testpaths=hiclass tests -addopts = --flake8 - --pydocstyle - --cov=hiclass - --cov-fail-under=90 - --cov-report html - --disable-warnings +addopts = --disable-warnings + --color=yes --ignore=hiclass/_version.py, [flake8] diff --git a/tests/test_HierarchicalClassifier.py b/tests/test_HierarchicalClassifier.py index 241e2443..80e5fbd3 100644 --- a/tests/test_HierarchicalClassifier.py +++ b/tests/test_HierarchicalClassifier.py @@ -1,12 +1,13 @@ import logging +import tempfile + import networkx as nx import numpy as np import pytest -import tempfile from numpy.testing import assert_array_equal from sklearn.linear_model import LogisticRegression -from hiclass.HierarchicalClassifier import HierarchicalClassifier +from hiclass.HierarchicalClassifier import HierarchicalClassifier, make_leveled @pytest.fixture @@ -175,3 +176,40 @@ def test_clean_up(digraph_multiple_roots): assert digraph_multiple_roots.X_ is None with pytest.raises(AttributeError): assert digraph_multiple_roots.y_ is None + + +@pytest.fixture +def empty_levels(): + y = [ + ["a"], + ["b", "c"], + ["d", "e", "f"], + ] + return y + + +def test_make_leveled(empty_levels): + ground_truth = np.array( + [ + ["a", "", ""], + ["b", "c", ""], + ["d", "e", "f"], + ] + ) + result = make_leveled(empty_levels) + assert_array_equal(ground_truth, result) + + +@pytest.fixture +def noniterable_y(): + y = [1, 2, 3] + return y + + +def test_make_leveled_non_iterable_y(noniterable_y): + assert noniterable_y == make_leveled(noniterable_y) + + +def test_fit_classifier(): + with pytest.raises(NotImplementedError): + HierarchicalClassifier._fit_classifier(None, None) diff --git a/tests/test_LocalClassifierPerLevel.py b/tests/test_LocalClassifierPerLevel.py index 5089d86a..0184e9a3 100644 --- a/tests/test_LocalClassifierPerLevel.py +++ b/tests/test_LocalClassifierPerLevel.py @@ -27,6 +27,10 @@ def digraph_logistic_regression(): digraph.logger_ = logging.getLogger("LCPL") digraph.root_ = "a" digraph.separator_ = "::HiClass::Separator::" + digraph.masks_ = [ + [True, True], + [True, True], + ] return digraph @@ -40,28 +44,13 @@ def test_initialize_local_classifiers(digraph_logistic_regression): def test_fit_digraph(digraph_logistic_regression): - classifiers = [ - LogisticRegression(), - LogisticRegression(), - ] - 
digraph_logistic_regression.local_classifiers_ = classifiers - digraph_logistic_regression._fit_digraph() - for classifier in digraph_logistic_regression.local_classifiers_: - try: - check_is_fitted(classifier) - except NotFittedError as e: - pytest.fail(repr(e)) - assert 1 - - -def test_fit_digraph_parallel(digraph_logistic_regression): classifiers = [ LogisticRegression(), LogisticRegression(), ] digraph_logistic_regression.n_jobs = 2 digraph_logistic_regression.local_classifiers_ = classifiers - digraph_logistic_regression._fit_digraph_parallel(local_mode=True) + digraph_logistic_regression._fit_digraph(local_mode=True) for classifier in digraph_logistic_regression.local_classifiers_: try: check_is_fitted(classifier) @@ -93,6 +82,10 @@ def fitted_logistic_regression(): digraph.dtype_ = "