From db5bcb62377c2a3b8210e75418845180875684b3 Mon Sep 17 00:00:00 2001 From: alessiamarcolini <98marcolini@gmail.com> Date: Thu, 19 Nov 2020 17:25:13 +0100 Subject: [PATCH] Get encoded columns from original via Dataset --- src/trousse/dataset.py | 29 +++++++++++++++++++++ tests/unit/test_dataset.py | 53 +++++++++++++++++++++++++++++++++++++- 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/src/trousse/dataset.py b/src/trousse/dataset.py index 9568e2d..16039d3 100644 --- a/src/trousse/dataset.py +++ b/src/trousse/dataset.py @@ -12,6 +12,7 @@ import pandas as pd from joblib import Parallel, delayed +from . import feature_operations as fop from .exceptions import MultipleObjectsInFileError, NotShelveFileError from .operations_list import OperationsList from .settings import CATEG_COL_THRESHOLD @@ -20,6 +21,7 @@ if typing.TYPE_CHECKING: # pragma: no cover from .feature_operations import FeatureOperation + logger = logging.getLogger(__name__) @@ -487,6 +489,33 @@ def operations_history(self) -> OperationsList: """ return self._operations_history + def encoded_columns_from_original(self, column: str) -> List[str]: + """Return the list of encoded columns name from ``column``. + + Parameters + ---------- + column : str + Column name + + Returns + ------- + List[str] + List of encoded columns name from ``column`` + """ + encoders_on_column = self._operations_history.operations_from_original_column( + column, [fop.OrdinalEncoder, fop.OneHotEncoder] + ) + + encoded_columns = [] + + for encoder in encoders_on_column: + if encoder.derived_columns is None: + encoded_columns.extend(encoder.columns) + else: + encoded_columns.extend(encoder.derived_columns) + + return encoded_columns + def _get_categorical_cols(self, col_list: Tuple[str]) -> Set[str]: """ Identify every categorical column in dataset. diff --git a/tests/unit/test_dataset.py b/tests/unit/test_dataset.py index ff17e7b..317bb9c 100644 --- a/tests/unit/test_dataset.py +++ b/tests/unit/test_dataset.py @@ -242,7 +242,6 @@ def it_knows_how_to_track_history( self, request, metadata_cols, derived_columns, expected_metadata_cols ): operations_list_iadd_ = method_mock(request, OperationsList, "__iadd__") - expected_df = DataFrameMock.df_generic(10) get_df_from_csv_ = function_mock(request, "trousse.dataset.get_df_from_csv") get_df_from_csv_.return_value = expected_df @@ -256,6 +255,58 @@ def it_knows_how_to_track_history( assert dataset.metadata_cols == expected_metadata_cols operations_list_iadd_.assert_called_once_with(ANY, feat_op) + @pytest.mark.parametrize( + "op_from_original_column_ret_value, expected_columns", + [ + ( + [ + fop.OrdinalEncoder( + columns=["col"], derived_columns=["encoded_col"] + ), + fop.OrdinalEncoder(columns=["col"], derived_columns=None), + ], + ["encoded_col", "col"], + ), + ( + [ + fop.OrdinalEncoder( + columns=["col"], derived_columns=["encoded_col"] + ), + fop.OrdinalEncoder(columns=["col"], derived_columns=None), + fop.OrdinalEncoder( + columns=["col"], derived_columns=["encoded_col2"] + ), + ], + ["encoded_col", "col", "encoded_col2"], + ), + ( + [], + [], + ), + ], + ) + def it_knows_how_to_get_encoded_columns_from_original( + self, request, op_from_original_column_ret_value, expected_columns + ): + op_list__op_from_original_column_ = method_mock( + request, OperationsList, "operations_from_original_column" + ) + op_list__op_from_original_column_.return_value = ( + op_from_original_column_ret_value + ) + expected_df = DataFrameMock.df_generic(10) + get_df_from_csv_ = function_mock(request, "trousse.dataset.get_df_from_csv") + get_df_from_csv_.return_value = expected_df + dataset = Dataset(data_file="fake/path") + + columns = dataset.encoded_columns_from_original("col") + + assert type(columns) == list + assert columns == expected_columns + op_list__op_from_original_column_.assert_called_once_with( + dataset.operations_history, "col", [fop.OrdinalEncoder, fop.OneHotEncoder] + ) + class DescribeColumnListByType: def it_knows_its_str(self, request):