Get encoded columns from original via Dataset
alessiamarcolini committed Nov 22, 2020
1 parent d995679 commit db5bcb6
Showing 2 changed files with 81 additions and 1 deletion.
29 changes: 29 additions & 0 deletions src/trousse/dataset.py
@@ -12,6 +12,7 @@
import pandas as pd
from joblib import Parallel, delayed

from . import feature_operations as fop
from .exceptions import MultipleObjectsInFileError, NotShelveFileError
from .operations_list import OperationsList
from .settings import CATEG_COL_THRESHOLD
@@ -20,6 +21,7 @@
if typing.TYPE_CHECKING:  # pragma: no cover
    from .feature_operations import FeatureOperation


logger = logging.getLogger(__name__)


@@ -487,6 +489,33 @@ def operations_history(self) -> OperationsList:
        """
        return self._operations_history

    def encoded_columns_from_original(self, column: str) -> List[str]:
        """Return the list of encoded column names derived from ``column``.

        Parameters
        ----------
        column : str
            Name of the original column.

        Returns
        -------
        List[str]
            Names of the columns that hold values encoded from ``column``.
        """
        encoders_on_column = self._operations_history.operations_from_original_column(
            column, [fop.OrdinalEncoder, fop.OneHotEncoder]
        )

        encoded_columns = []

        for encoder in encoders_on_column:
            if encoder.derived_columns is None:
                # The encoder worked in place, so the original columns hold
                # the encoded values.
                encoded_columns.extend(encoder.columns)
            else:
                encoded_columns.extend(encoder.derived_columns)

        return encoded_columns

    def _get_categorical_cols(self, col_list: Tuple[str]) -> Set[str]:
        """
        Identify every categorical column in dataset.
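For context, a minimal usage sketch of the new method (the CSV path, column names, and the way the encoders end up in the history are illustrative assumptions, not taken from this commit):

from trousse.dataset import Dataset

# Hypothetical dataset with a categorical "color" column.
dataset = Dataset(data_file="data/fruits.csv")

# ... run the usual Trousse encoding steps so that an OrdinalEncoder or
# OneHotEncoder applied to "color" gets recorded in dataset.operations_history ...

# The new method walks the operations history and collects every column that
# holds values encoded from "color": derived columns when the encoder created
# new ones, the original columns when it encoded in place, in the order the
# encoders appear in the history.
encoded = dataset.encoded_columns_from_original("color")
# e.g. ["color_enc", "color"]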
53 changes: 52 additions & 1 deletion tests/unit/test_dataset.py
@@ -242,7 +242,6 @@ def it_knows_how_to_track_history(
        self, request, metadata_cols, derived_columns, expected_metadata_cols
    ):
        operations_list_iadd_ = method_mock(request, OperationsList, "__iadd__")

        expected_df = DataFrameMock.df_generic(10)
        get_df_from_csv_ = function_mock(request, "trousse.dataset.get_df_from_csv")
        get_df_from_csv_.return_value = expected_df
@@ -256,6 +255,58 @@
        assert dataset.metadata_cols == expected_metadata_cols
        operations_list_iadd_.assert_called_once_with(ANY, feat_op)

    @pytest.mark.parametrize(
        "op_from_original_column_ret_value, expected_columns",
        [
            (
                [
                    fop.OrdinalEncoder(
                        columns=["col"], derived_columns=["encoded_col"]
                    ),
                    fop.OrdinalEncoder(columns=["col"], derived_columns=None),
                ],
                ["encoded_col", "col"],
            ),
            (
                [
                    fop.OrdinalEncoder(
                        columns=["col"], derived_columns=["encoded_col"]
                    ),
                    fop.OrdinalEncoder(columns=["col"], derived_columns=None),
                    fop.OrdinalEncoder(
                        columns=["col"], derived_columns=["encoded_col2"]
                    ),
                ],
                ["encoded_col", "col", "encoded_col2"],
            ),
            (
                [],
                [],
            ),
        ],
    )
    def it_knows_how_to_get_encoded_columns_from_original(
        self, request, op_from_original_column_ret_value, expected_columns
    ):
        op_list__op_from_original_column_ = method_mock(
            request, OperationsList, "operations_from_original_column"
        )
        op_list__op_from_original_column_.return_value = (
            op_from_original_column_ret_value
        )
        expected_df = DataFrameMock.df_generic(10)
        get_df_from_csv_ = function_mock(request, "trousse.dataset.get_df_from_csv")
        get_df_from_csv_.return_value = expected_df
        dataset = Dataset(data_file="fake/path")

        columns = dataset.encoded_columns_from_original("col")

        assert type(columns) == list
        assert columns == expected_columns
        op_list__op_from_original_column_.assert_called_once_with(
            dataset.operations_history, "col", [fop.OrdinalEncoder, fop.OneHotEncoder]
        )


class DescribeColumnListByType:
    def it_knows_its_str(self, request):
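Since the test mocks OperationsList.operations_from_original_column, here is a rough sketch of the contract the new Dataset method relies on, using a hypothetical history; the OrdinalEncoder constructor calls mirror the ones in the test, everything else is an assumption:

from trousse import feature_operations as fop

# Hypothetical operations recorded in a dataset's history.
in_place = fop.OrdinalEncoder(columns=["color"], derived_columns=None)
derived = fop.OrdinalEncoder(columns=["size"], derived_columns=["size_enc"])

# operations_from_original_column("color", [fop.OrdinalEncoder, fop.OneHotEncoder])
# is expected to return only the operations whose original column is "color",
# so encoded_columns_from_original("color") would report ["color"] here (the
# in-place encoder), while encoded_columns_from_original("size") would report
# ["size_enc"].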
