[SPARK-50812][ML][PYTHON][CONNECT] Add support PolynomialExpansion
### What changes were proposed in this pull request?

Support `PolynomialExpansion` on Spark Connect.

### Why are the changes needed?

Feature parity with classic (non-Connect) Spark ML.

### Does this PR introduce _any_ user-facing change?
Yes

### How was this patch tested?
CI passes

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #49702 from wbo4958/px.

Authored-by: Bobby Wang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
wbo4958 authored and zhengruifeng committed Jan 28, 2025
1 parent e891627 commit aa24a9a
Showing 2 changed files with 27 additions and 0 deletions.
@@ -36,6 +36,7 @@ org.apache.spark.ml.feature.FeatureHasher
org.apache.spark.ml.feature.ElementwiseProduct
org.apache.spark.ml.feature.HashingTF
org.apache.spark.ml.feature.IndexToString
org.apache.spark.ml.feature.PolynomialExpansion

########### Model for loading
# classification
26 changes: 26 additions & 0 deletions python/pyspark/ml/tests/test_feature.py
@@ -77,6 +77,7 @@
    MinHashLSH,
    MinHashLSHModel,
    IndexToString,
    PolynomialExpansion,
)
from pyspark.ml.linalg import DenseVector, SparseVector, Vectors
from pyspark.sql import Row
@@ -85,6 +86,31 @@


class FeatureTestsMixin:
    def test_polynomial_expansion(self):
        df = self.spark.createDataFrame([(Vectors.dense([0.5, 2.0]),)], ["dense"])
        px = PolynomialExpansion(degree=2)
        px.setInputCol("dense")
        px.setOutputCol("expanded")
        self.assertTrue(
            np.allclose(
                px.transform(df).head().expanded.toArray(), [0.5, 0.25, 2.0, 1.0, 4.0], atol=1e-4
            )
        )

        def check(p: PolynomialExpansion) -> None:
            self.assertEqual(p.getInputCol(), "dense")
            self.assertEqual(p.getOutputCol(), "expanded")
            self.assertEqual(p.getDegree(), 2)

        check(px)

        # save & load
        with tempfile.TemporaryDirectory(prefix="px") as d:
            px.write().overwrite().save(d)
            px2 = PolynomialExpansion.load(d)
            self.assertEqual(str(px), str(px2))
            check(px2)

    def test_index_string(self):
        dataset = self.spark.createDataFrame(
            [
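Reviewer note: the expected vector in `test_polynomial_expansion`, `[0.5, 0.25, 2.0, 1.0, 4.0]`, is the degree-2 expansion of `(x, y) = (0.5, 2.0)` in the order `x, x^2, y, x*y, y^2`, with the constant term omitted. The sketch below is a plain-Python illustration of that ordering and does not use Spark. The helper name `expand` and its recursive structure are assumptions made for illustration, not the code in this commit or in Spark itself.

```python
from typing import List


def expand(values: List[float], degree: int) -> List[float]:
    """Hypothetical helper: polynomial expansion without the constant term.

    Monomials are emitted in the order the test above expects for a
    two-element input: x, x^2, y, x*y, y^2.
    """

    def rec(last_idx: int, deg: int, multiplier: float) -> List[float]:
        # No variables left, or no degree budget left: emit the monomial.
        if deg == 0 or last_idx < 0:
            return [multiplier]
        out: List[float] = []
        alpha = multiplier
        # Raise values[last_idx] to the powers 0..deg; the remaining
        # variables share whatever degree budget is left.
        for power in range(deg + 1):
            out.extend(rec(last_idx - 1, deg - power, alpha))
            alpha *= values[last_idx]
        return out

    # The first entry is always the constant term 1.0; drop it.
    return rec(len(values) - 1, degree, 1.0)[1:]


result = expand([0.5, 2.0], 2)
print(result)  # [0.5, 0.25, 2.0, 1.0, 4.0]
```

For an input with `n` features and degree `d`, this sketch yields `C(n + d, d) - 1` values, so three features at degree 2 give 9 outputs.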
