Skip to content

Commit

Permalink
[SPARK-50934][ML][PYTHON][CONNECT] Support CountVectorizer and OneHot…
Browse files Browse the repository at this point in the history
…Encoder on Connect

### What changes were proposed in this pull request?
Support CountVectorizer and OneHotEncoder on Connect

### Why are the changes needed?
for feature parity

### Does this PR introduce _any_ user-facing change?
Yes. `CountVectorizer` and `OneHotEncoder` (and their models) are now supported on Connect.

### How was this patch tested?
Added `test_count_vectorizer` and `test_one_hot_encoder` to `python/pyspark/ml/tests/test_feature.py`.

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #49647 from zhengruifeng/ml_connect_cv.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Jan 24, 2025
1 parent 39e9b3b commit 5c1f7c2
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,6 @@ org.apache.spark.ml.feature.VarianceThresholdSelector
org.apache.spark.ml.feature.StringIndexer
org.apache.spark.ml.feature.PCA
org.apache.spark.ml.feature.Word2Vec
org.apache.spark.ml.feature.CountVectorizer
org.apache.spark.ml.feature.OneHotEncoder

Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,6 @@ org.apache.spark.ml.feature.VarianceThresholdSelectorModel
org.apache.spark.ml.feature.StringIndexerModel
org.apache.spark.ml.feature.PCAModel
org.apache.spark.ml.feature.Word2VecModel
org.apache.spark.ml.feature.CountVectorizerModel
org.apache.spark.ml.feature.OneHotEncoderModel

Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,8 @@ class CountVectorizerModel(

import CountVectorizerModel._

// Zero-argument constructor: builds a model with a freshly generated
// "cntVecModel" uid and an empty vocabulary.
// NOTE(review): presumably needed so the class can be instantiated without
// arguments for Spark Connect support -- confirm against the caller.
private[ml] def this() = {
  this(Identifiable.randomUID("cntVecModel"), Array.empty)
}

@Since("1.5.0")
def this(vocabulary: Array[String]) = {
this(Identifiable.randomUID("cntVecModel"), vocabulary)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ class OneHotEncoderModel private[ml] (

import OneHotEncoderModel._

// Zero-argument constructor: builds a model with a freshly generated
// "oneHotEncoder" uid and an empty category-size array.
// The uid prefix previously contained a stray ')' inside the string literal
// ("oneHotEncoder)"), which leaked into every uid generated here.
// NOTE(review): presumably needed so the class can be instantiated without
// arguments for Spark Connect support -- confirm against the caller.
private[ml] def this() = this(Identifiable.randomUID("oneHotEncoder"), Array.emptyIntArray)

// Returns the category size for each index with `dropLast` and `handleInvalid`
// taken into account.
private def getConfigedCategorySizes: Array[Int] = {
Expand Down
57 changes: 57 additions & 0 deletions python/pyspark/ml/tests/test_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
Bucketizer,
CountVectorizer,
CountVectorizerModel,
OneHotEncoder,
OneHotEncoderModel,
HashingTF,
IDF,
NGram,
Expand Down Expand Up @@ -536,6 +538,61 @@ def test_word2vec(self):
model2 = Word2VecModel.load(d)
self.assertEqual(str(model), str(model2))

def test_count_vectorizer(self):
    """Exercise CountVectorizer end to end on a two-row token dataset.

    Checks the input/output column setters and getters, the fitted
    vocabulary, the columns and row count of the transformed DataFrame, and
    that both the estimator and the fitted model keep the same string
    representation after a save/load round trip.
    """
    rows = [
        (0, ["a", "b", "c"]),
        (1, ["a", "b", "b", "c", "a"]),
    ]
    df = self.spark.createDataFrame(rows, ["label", "raw"])

    # Configure through the setters so the getters can be checked against them.
    vectorizer = CountVectorizer()
    vectorizer.setInputCol("raw")
    vectorizer.setOutputCol("vectors")
    self.assertEqual(vectorizer.getInputCol(), "raw")
    self.assertEqual(vectorizer.getOutputCol(), "vectors")

    fitted = vectorizer.fit(df)
    # Compare the sorted vocabulary so the assertion does not depend on term order.
    self.assertEqual(sorted(fitted.vocabulary), ["a", "b", "c"])

    transformed = fitted.transform(df)
    self.assertEqual(transformed.columns, ["label", "raw", "vectors"])
    self.assertEqual(transformed.count(), 2)

    # save & load: the estimator first, then the model, reusing the same directory
    with tempfile.TemporaryDirectory(prefix="count_vectorizer") as tmp_dir:
        vectorizer.write().overwrite().save(tmp_dir)
        reloaded_vectorizer = CountVectorizer.load(tmp_dir)
        self.assertEqual(str(vectorizer), str(reloaded_vectorizer))

        fitted.write().overwrite().save(tmp_dir)
        reloaded_model = CountVectorizerModel.load(tmp_dir)
        self.assertEqual(str(fitted), str(reloaded_model))

def test_one_hot_encoder(self):
    """Exercise OneHotEncoder end to end on a three-row single-column dataset.

    Checks the input/output column setters and getters, the fitted
    ``categorySizes``, the columns and row count of the transformed
    DataFrame, and that both the estimator and the fitted model keep the
    same string representation after a save/load round trip.
    """
    df = self.spark.createDataFrame([(0.0,), (1.0,), (2.0,)], ["input"])

    encoder = OneHotEncoder()
    encoder.setInputCols(["input"])
    encoder.setOutputCols(["output"])
    self.assertEqual(encoder.getInputCols(), ["input"])
    self.assertEqual(encoder.getOutputCols(), ["output"])

    model = encoder.fit(df)
    # Three distinct values (0.0, 1.0, 2.0) in the single input column.
    self.assertEqual(model.categorySizes, [3])

    output = model.transform(df)
    self.assertEqual(output.columns, ["input", "output"])
    self.assertEqual(output.count(), 3)

    # save & load: the estimator first, then the model, reusing the same directory.
    # The prefix names this test (it was previously copy-pasted as
    # "count_vectorizer"), so leftover temp directories can be attributed correctly.
    with tempfile.TemporaryDirectory(prefix="one_hot_encoder") as d:
        encoder.write().overwrite().save(d)
        encoder2 = OneHotEncoder.load(d)
        self.assertEqual(str(encoder), str(encoder2))

        model.write().overwrite().save(d)
        model2 = OneHotEncoderModel.load(d)
        self.assertEqual(str(model), str(model2))

def test_tokenizer(self):
df = self.spark.createDataFrame([("a b c",)], ["text"])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,9 @@ private[ml] object MLUtils {
(classOf[UnivariateFeatureSelectorModel], Set("selectedFeatures")),
(classOf[VarianceThresholdSelectorModel], Set("selectedFeatures")),
(classOf[PCAModel], Set("pc", "explainedVariance")),
(classOf[Word2VecModel], Set("getVectors", "findSynonyms", "findSynonymsArray")))
(classOf[Word2VecModel], Set("getVectors", "findSynonyms", "findSynonymsArray")),
(classOf[CountVectorizerModel], Set("vocabulary")),
(classOf[OneHotEncoderModel], Set("categorySizes")))

private def validate(obj: Any, method: String): Unit = {
assert(obj != null)
Expand Down

0 comments on commit 5c1f7c2

Please sign in to comment.