Skip to content

Commit

Permalink
[SPARK-51060][ML][PYTHON][CONNECT] Support QuantileDiscretizer on C…
Browse files Browse the repository at this point in the history
…onnect

### What changes were proposed in this pull request?
Support `QuantileDiscretizer` on Spark Connect.

### Why are the changes needed?
for feature parity

### Does this PR introduce _any_ user-facing change?
yes

### How was this patch tested?
added tests

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #49761 from zhengruifeng/ml_connect_qd.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
(cherry picked from commit 9ad8d22)
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Feb 3, 2025
1 parent 53844cd commit 836a632
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ org.apache.spark.ml.recommendation.ALS
org.apache.spark.ml.fpm.FPGrowth

# feature
org.apache.spark.ml.feature.QuantileDiscretizer
org.apache.spark.ml.feature.RFormula
org.apache.spark.ml.feature.Imputer
org.apache.spark.ml.feature.StandardScaler
Expand Down
18 changes: 18 additions & 0 deletions python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from pyspark.ml.util import JavaMLReadable, JavaMLWritable, try_remote_attribute_relation
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaTransformer, _jvm
from pyspark.ml.common import inherit_doc
from pyspark.sql.utils import is_remote

if TYPE_CHECKING:
from py4j.java_gateway import JavaObject
Expand Down Expand Up @@ -3754,6 +3755,23 @@ def _create_model(self, java_model: "JavaObject") -> Bucketizer:
"""
Private method to convert the java_model to a Python model.
"""
if is_remote():
remote_model = JavaModel(java_model)
if self.isSet(self.inputCol):
return Bucketizer(
splits=remote_model._call_java("getSplits"),
inputCol=self.getInputCol(),
outputCol=self.getOutputCol(),
handleInvalid=self.getHandleInvalid(),
)
else:
return Bucketizer(
splitsArray=remote_model._call_java("getSplitsArray"),
inputCols=self.getInputCols(),
outputCols=self.getOutputCols(),
handleInvalid=self.getHandleInvalid(),
)

if self.isSet(self.inputCol):
return Bucketizer(
splits=list(java_model.getSplits()),
Expand Down
93 changes: 93 additions & 0 deletions python/pyspark/ml/tests/test_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
DCT,
Binarizer,
Bucketizer,
QuantileDiscretizer,
CountVectorizer,
CountVectorizerModel,
OneHotEncoder,
Expand Down Expand Up @@ -978,6 +979,98 @@ def test_binarizer(self):
binarizer2 = Binarizer.load(d)
self.assertEqual(str(binarizer), str(binarizer2))

def test_quantile_discretizer_single_column(self):
    """Fit a single-column QuantileDiscretizer and check the resulting Bucketizer.

    Covers the param setters/getters, the learned splits, the transformed
    output, and a save/load round trip of both the estimator and the model.
    """
    nan = float("nan")
    rows = [(0.1,), (0.4,), (1.2,), (1.5,), (nan,), (nan,)]
    frame = self.spark.createDataFrame(rows, ["values"])

    discretizer = QuantileDiscretizer(inputCol="values", outputCol="buckets")
    discretizer.setNumBuckets(2)
    discretizer.setRelativeError(0.01)
    discretizer.setHandleInvalid("keep")

    # Every param configured above must be readable through its getter.
    self.assertEqual(discretizer.getInputCol(), "values")
    self.assertEqual(discretizer.getOutputCol(), "buckets")
    self.assertEqual(discretizer.getNumBuckets(), 2)
    self.assertEqual(discretizer.getRelativeError(), 0.01)
    self.assertEqual(discretizer.getHandleInvalid(), "keep")

    model = discretizer.fit(frame)
    # NOTE: the fitted Bucketizer doesn't inherit its uid from the
    # QuantileDiscretizer, so the two uids are deliberately not compared.
    self.assertIsInstance(model, Bucketizer)

    # Learned split points.
    expected_splits = [float("-inf"), 0.4, float("inf")]
    self.assertEqual(model.getSplits(), expected_splits)

    transformed = model.transform(frame)
    self.assertEqual(transformed.columns, ["values", "buckets"])
    self.assertEqual(transformed.count(), 6)

    # Persistence round trip: estimator first, then the fitted model,
    # both written to the same directory with overwrite enabled.
    with tempfile.TemporaryDirectory(prefix="quantile_discretizer_single_column") as path:
        discretizer.write().overwrite().save(path)
        reloaded_discretizer = QuantileDiscretizer.load(path)
        self.assertEqual(str(discretizer), str(reloaded_discretizer))

        model.write().overwrite().save(path)
        reloaded_model = Bucketizer.load(path)
        self.assertEqual(str(model), str(reloaded_model))

def test_quantile_discretizer_multiple_columns(self):
    """Fit a multi-column QuantileDiscretizer and check the resulting Bucketizer.

    Covers constructor-supplied params, the per-column learned splits, the
    transformed output, and a save/load round trip of both the estimator
    and the model.
    """
    nan = float("nan")
    rows = [
        (0.1, 0.0),
        (0.4, 1.0),
        (1.2, 1.3),
        (1.5, 1.5),
        (nan, nan),
        (nan, nan),
    ]
    frame = self.spark.createDataFrame(rows, ["input1", "input2"])

    discretizer = QuantileDiscretizer(
        relativeError=0.01,
        handleInvalid="keep",
        numBuckets=2,
        inputCols=["input1", "input2"],
        outputCols=["output1", "output2"],
    )

    # Every param passed to the constructor must be readable through its getter.
    self.assertEqual(discretizer.getInputCols(), ["input1", "input2"])
    self.assertEqual(discretizer.getOutputCols(), ["output1", "output2"])
    self.assertEqual(discretizer.getNumBuckets(), 2)
    self.assertEqual(discretizer.getRelativeError(), 0.01)
    self.assertEqual(discretizer.getHandleInvalid(), "keep")

    model = discretizer.fit(frame)
    # NOTE: the fitted Bucketizer doesn't inherit its uid from the
    # QuantileDiscretizer, so the two uids are deliberately not compared.
    self.assertIsInstance(model, Bucketizer)

    # Learned split points, one list per input column.
    neg_inf, pos_inf = float("-inf"), float("inf")
    expected_splits = [
        [neg_inf, 0.4, pos_inf],
        [neg_inf, 1.0, pos_inf],
    ]
    self.assertEqual(model.getSplitsArray(), expected_splits)

    transformed = model.transform(frame)
    self.assertEqual(transformed.columns, ["input1", "input2", "output1", "output2"])
    self.assertEqual(transformed.count(), 6)

    # Persistence round trip: estimator first, then the fitted model,
    # both written to the same directory with overwrite enabled.
    with tempfile.TemporaryDirectory(prefix="quantile_discretizer_multiple_columns") as path:
        discretizer.write().overwrite().save(path)
        reloaded_discretizer = QuantileDiscretizer.load(path)
        self.assertEqual(str(discretizer), str(reloaded_discretizer))

        model.write().overwrite().save(path)
        reloaded_model = Bucketizer.load(path)
        self.assertEqual(str(model), str(reloaded_model))

def test_bucketizer(self):
df = self.spark.createDataFrame(
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,7 @@ private[ml] object MLUtils {
(classOf[FPGrowthModel], Set("associationRules", "freqItemsets")),

// Feature Models
(classOf[Bucketizer], Set("getSplits", "getSplitsArray")),
(classOf[ImputerModel], Set("surrogateDF")),
(classOf[StandardScalerModel], Set("mean", "std")),
(classOf[MaxAbsScalerModel], Set("maxAbs")),
Expand Down

0 comments on commit 836a632

Please sign in to comment.