Skip to content

Commit

Permalink
r.learn.train.py: changed sparse keyword in OneHotEncoder (#1180)
Browse files Browse the repository at this point in the history
Ensure scikit-learn version is at least 1.2.2

Use sparse_output instead of sparse
  • Loading branch information
rohannallamadge authored Sep 7, 2024
1 parent 5762dd9 commit b6c90cc
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/raster/r.learn.ml2/r.learn.train/r.learn.train.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,11 +445,11 @@ def main():
try:
import sklearn

if sklearn.__version__ < "0.20":
gs.fatal("Package python3-scikit-learn 0.20 or newer is not installed")
if sklearn.__version__ < "1.2.2":
gs.fatal("Package python3-scikit-learn 1.2.2 or newer is not installed")

except ImportError:
gs.fatal("Package python3-scikit-learn 0.20 or newer is not installed")
gs.fatal("Package python3-scikit-learn 1.2.2 or newer is not installed")

try:
import pandas as pd
Expand Down Expand Up @@ -683,15 +683,15 @@ def main():

# one-hot encoding
elif norm_data is False and category_maps is not None:
enc = OneHotEncoder(handle_unknown="ignore", sparse=False)
enc = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
trans = ColumnTransformer(
remainder="passthrough", transformers=[("onehot", enc, stack.categorical)]
)

# standardization and one-hot encoding
elif norm_data is True and category_maps is not None:
scaler = StandardScaler()
enc = OneHotEncoder(handle_unknown="ignore", sparse=False)
enc = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
trans = ColumnTransformer(
remainder="passthrough",
transformers=[
Expand Down

0 comments on commit b6c90cc

Please sign in to comment.