Updated sklearn column_builders to handle import failures
aschonfeld committed Apr 28, 2023
1 parent 69a1f1c commit 4eb1fbd
Showing 2 changed files with 96 additions and 48 deletions.
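
The change below removes dtale's module-level scikit-learn imports and instead imports each transformer inside the branch that uses it, turning a failed import into an error message that names the minimum scikit-learn release required. A minimal standalone sketch of that guarded-import pattern follows; the helper name and signature are hypothetical and not part of dtale, which simply inlines the try/except in each branch.

# Hypothetical helper illustrating the guarded-import pattern applied in this commit.
def require_sklearn_preprocessing(class_name, min_version):
    """Return a scikit-learn preprocessing class, or raise a friendly error
    naming the minimum scikit-learn version that ships it."""
    try:
        from sklearn import preprocessing

        return getattr(preprocessing, class_name)
    except (ImportError, AttributeError):
        raise Exception(
            "You must have at least scikit-learn {} installed in order to use the {}!".format(
                min_version, class_name
            )
        )

# Succeeds on any scikit-learn new enough to ship RobustScaler, raises the friendly error otherwise.
RobustScaler = require_sklearn_preprocessing("RobustScaler", "0.17.0")

Deferring the imports this way keeps scikit-learn effectively optional: only the column builders that need a transformer require it, and users on older releases are told which version to install.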
100 changes: 70 additions & 30 deletions dtale/column_builders.py
@@ -9,14 +9,6 @@
 import numpy as np
 import pandas as pd
 from scipy.stats import mstats
-from sklearn.preprocessing import (
-    LabelEncoder,
-    OrdinalEncoder,
-    PowerTransformer,
-    QuantileTransformer,
-    RobustScaler,
-)
-from sklearn.feature_extraction import FeatureHasher
 from strsimpy.jaro_winkler import JaroWinkler
 
 import dtale.global_state as global_state
@@ -897,11 +889,32 @@ def __init__(self, name, cfg):
     def build_column(self, data):
         col, algo = (self.cfg.get(p) for p in ["col", "algo"])
         if algo == "robust":
-            transformer = RobustScaler()
+            try:
+                from sklearn.preprocessing import RobustScaler
+
+                transformer = RobustScaler()
+            except ImportError:
+                raise Exception(
+                    "You must have at least scikit-learn 0.17.0 installed in order to use the RobustScaler!"
+                )
         elif algo == "quantile":
-            transformer = QuantileTransformer()
+            try:
+                from sklearn.preprocessing import QuantileTransformer
+
+                transformer = QuantileTransformer()
+            except ImportError:
+                raise Exception(
+                    "You must have at least scikit-learn 0.19.0 installed in order to use the QuantileTransformer!"
+                )
         elif algo == "power":
-            transformer = PowerTransformer(method="yeo-johnson", standardize=True)
+            try:
+                from sklearn.preprocessing import PowerTransformer
+
+                transformer = PowerTransformer(method="yeo-johnson", standardize=True)
+            except ImportError:
+                raise Exception(
+                    "You must have at least scikit-learn 0.20.0 installed in order to use the PowerTransformer!"
+                )
         standardized = transformer.fit_transform(data[[col]]).reshape(-1)
         return pd.Series(standardized, index=data.index, name=self.name)
 
@@ -941,27 +954,54 @@ def build_column(self, data):
         if algo == "one_hot":
             return pd.get_dummies(data, columns=[col], drop_first=True)
         elif algo == "ordinal":
-            is_nan = data[col].isnull()
-            ordinals = (
-                OrdinalEncoder().fit_transform(data[[col]].astype("str")).reshape(-1)
-            )
-            return pd.Series(ordinals, index=data.index, name=self.name).where(
-                ~is_nan, 0
-            )
+            try:
+                from sklearn.preprocessing import OrdinalEncoder
+
+                is_nan = data[col].isnull()
+                ordinals = (
+                    OrdinalEncoder()
+                    .fit_transform(data[[col]].astype("str"))
+                    .reshape(-1)
+                )
+                return pd.Series(ordinals, index=data.index, name=self.name).where(
+                    ~is_nan, 0
+                )
+            except ImportError:
+                raise Exception(
+                    "You must have at least scikit-learn 0.20.0 installed in order to use the OrdinalEncoder!"
+                )
         elif algo == "label":
-            is_nan = data[col].isnull()
-            labels = LabelEncoder().fit_transform(data[col].astype("str"))
-            return pd.Series(labels, index=data.index, name=self.name).where(~is_nan, 0)
+            try:
+                from sklearn.preprocessing import LabelEncoder
+
+                is_nan = data[col].isnull()
+                labels = LabelEncoder().fit_transform(data[col].astype("str"))
+                return pd.Series(labels, index=data.index, name=self.name).where(
+                    ~is_nan, 0
+                )
+            except ImportError:
+                raise Exception(
+                    "You must have at least scikit-learn 0.12.0 installed in order to use the LabelEncoder!"
+                )
         elif algo == "feature_hasher":
-            n = int(self.cfg.get("n"))
-            features = (
-                FeatureHasher(n_features=n, input_type="string")
-                .transform(data[[col]].astype("str").values)
-                .toarray()
-            )
-            features = pd.DataFrame(features, index=data.index)
-            features.columns = ["{}_{}".format(col, col2) for col2 in features.columns]
-            return features
+            try:
+                from sklearn.feature_extraction import FeatureHasher
+
+                n = int(self.cfg.get("n"))
+                features = (
+                    FeatureHasher(n_features=n, input_type="string")
+                    .transform(data[[col]].astype("str").values)
+                    .toarray()
+                )
+                features = pd.DataFrame(features, index=data.index)
+                features.columns = [
+                    "{}_{}".format(col, col2) for col2 in features.columns
+                ]
+                return features
+            except ImportError:
+                raise Exception(
+                    "You must have at least scikit-learn 0.13.0 installed in order to use the FeatureHasher!"
+                )
         raise NotImplementedError("{} not implemented yet!".format(algo))
 
     def build_code(self):
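
For context, here is a small standalone illustration (a sketch, not part of the commit or of dtale's API) of what the "ordinal", "label", and "feature_hasher" branches above compute on a toy column, assuming a scikit-learn recent enough to provide all three classes:

import pandas as pd
from sklearn.feature_extraction import FeatureHasher
from sklearn.preprocessing import LabelEncoder, OrdinalEncoder

df = pd.DataFrame({"car": ["ford", "toyota", None, "ford"]})
is_nan = df["car"].isnull()

# "ordinal": one numeric code per distinct string; NaN rows are zeroed afterwards
ordinals = OrdinalEncoder().fit_transform(df[["car"]].astype("str")).reshape(-1)
print(pd.Series(ordinals, index=df.index, name="car_ordinal").where(~is_nan, 0))

# "label": same idea via LabelEncoder on the 1-D column
labels = LabelEncoder().fit_transform(df["car"].astype("str"))
print(pd.Series(labels, index=df.index, name="car_label").where(~is_nan, 0))

# "feature_hasher": n hashed numeric features per row, named <col>_<i>
hashed = (
    FeatureHasher(n_features=1, input_type="string")
    .transform(df[["car"]].astype("str").values)
    .toarray()
)
print(pd.DataFrame(hashed, index=df.index, columns=["car_0"]))

The "one_hot" branch needs no guarded import because pandas handles it directly via pd.get_dummies.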
44 changes: 26 additions & 18 deletions tests/dtale/column_builders/test_column_builders.py
@@ -1,7 +1,9 @@
 import numpy as np
 import pandas as pd
 import pytest
+import sklearn as skl
 from numpy.random import randn
+from pkg_resources import parse_version
 from six import PY3
 
 import dtale.pandas_util as pandas_util
@@ -263,17 +265,20 @@ def test_standardize():
     data_id, column_type = "1", "standardize"
     build_data_inst({data_id: df})
 
-    cfg = {"col": "a", "algo": "power"}
-    builder = ColumnBuilder(data_id, column_type, "Col1", cfg)
-    verify_builder(builder, lambda col: col.isnull().sum() == 0)
+    if parse_version(skl.__version__) >= parse_version("0.20.0"):
+        cfg = {"col": "a", "algo": "power"}
+        builder = ColumnBuilder(data_id, column_type, "Col1", cfg)
+        verify_builder(builder, lambda col: col.isnull().sum() == 0)
 
-    cfg = {"col": "a", "algo": "quantile"}
-    builder = ColumnBuilder(data_id, column_type, "Col1", cfg)
-    verify_builder(builder, lambda col: col.isnull().sum() == 0)
+    if parse_version(skl.__version__) >= parse_version("0.19.0"):
+        cfg = {"col": "a", "algo": "quantile"}
+        builder = ColumnBuilder(data_id, column_type, "Col1", cfg)
+        verify_builder(builder, lambda col: col.isnull().sum() == 0)
 
-    cfg = {"col": "a", "algo": "robust"}
-    builder = ColumnBuilder(data_id, column_type, "Col1", cfg)
-    verify_builder(builder, lambda col: col.isnull().sum() == 0)
+    if parse_version(skl.__version__) >= parse_version("0.17.0"):
+        cfg = {"col": "a", "algo": "robust"}
+        builder = ColumnBuilder(data_id, column_type, "Col1", cfg)
+        verify_builder(builder, lambda col: col.isnull().sum() == 0)
 
 
 @pytest.mark.unit
@@ -293,17 +298,20 @@ def test_encoder():
         ),
     )
 
-    cfg = {"col": "car", "algo": "ordinal"}
-    builder = ColumnBuilder(data_id, column_type, "Col1", cfg)
-    verify_builder(builder, lambda col: col.isnull().sum() == 0)
+    if parse_version(skl.__version__) >= parse_version("0.20.0"):
+        cfg = {"col": "car", "algo": "ordinal"}
+        builder = ColumnBuilder(data_id, column_type, "Col1", cfg)
+        verify_builder(builder, lambda col: col.isnull().sum() == 0)
 
-    cfg = {"col": "car", "algo": "label"}
-    builder = ColumnBuilder(data_id, column_type, "Col1", cfg)
-    verify_builder(builder, lambda col: col.isnull().sum() == 0)
+    if parse_version(skl.__version__) >= parse_version("0.12.0"):
+        cfg = {"col": "car", "algo": "label"}
+        builder = ColumnBuilder(data_id, column_type, "Col1", cfg)
+        verify_builder(builder, lambda col: col.isnull().sum() == 0)
 
-    cfg = {"col": "car", "algo": "feature_hasher", "n": 1}
-    builder = ColumnBuilder(data_id, column_type, "Col1", cfg)
-    verify_builder(builder, lambda col: col["car_0"].isnull().sum() == 0)
+    if parse_version(skl.__version__) >= parse_version("0.13.0"):
+        cfg = {"col": "car", "algo": "feature_hasher", "n": 1}
+        builder = ColumnBuilder(data_id, column_type, "Col1", cfg)
+        verify_builder(builder, lambda col: col["car_0"].isnull().sum() == 0)
 
 
 @pytest.mark.unit
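
The updated tests gate each assertion with an inline check against skl.__version__. An equivalent way to express the same guard, shown here only as an alternative sketch and not what the commit does, is a pytest skipif marker, which also reports the skip reason in the test output:

import pytest
import sklearn as skl
from pkg_resources import parse_version
from sklearn.preprocessing import PowerTransformer

# Skip the test entirely when the installed scikit-learn predates PowerTransformer.
requires_sklearn_020 = pytest.mark.skipif(
    parse_version(skl.__version__) < parse_version("0.20.0"),
    reason="requires scikit-learn >= 0.20.0",
)


@requires_sklearn_020
def test_power_standardize_sketch():
    transformed = PowerTransformer(method="yeo-johnson", standardize=True).fit_transform(
        [[1.0], [2.0], [3.0]]
    )
    assert transformed.shape == (3, 1)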
