Skip to content

Commit

Permalink
[SPARK-50975][ML][PYTHON][CONNECT] Support `CountVectorizerModel.from_vocabulary` on connect
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?
Support `CountVectorizerModel.from_vocabulary` on connect

### Why are the changes needed?
For feature parity between Spark Connect and classic PySpark.

### Does this PR introduce _any_ user-facing change?
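Yes. `CountVectorizerModel.from_vocabulary` is now supported on Spark Connect. In addition, passing an empty vocabulary now raises a Python `ValueError` ("Vocabulary list cannot be empty") instead of the previous `vocabSize`-is-invalid error.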

### How was this patch tested?
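Enabled the previously skipped parity test `test_count_vectorizer_from_vocab` and updated the empty-vocabulary assertion in `test_feature.py`.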

### Was this patch authored or co-authored using generative AI tooling?
no

Closes #49806 from zhengruifeng/ml_connect_from_voc.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed Feb 5, 2025
1 parent 6ed7733 commit 8b092be
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package org.apache.spark.ml.util

import org.apache.spark.ml.Model
import org.apache.spark.ml.feature.StringIndexerModel
import org.apache.spark.ml.feature.{CountVectorizerModel, StringIndexerModel}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType
Expand All @@ -35,6 +35,11 @@ private[spark] class ConnectHelper(override val uid: String) extends Model[Conne
new StringIndexerModel(uid, labelsArray)
}

/**
 * Builds a [[CountVectorizerModel]] directly from a uid and a vocabulary.
 *
 * The Python client calls this method by name through `_call_java` when running
 * remotely, and the same name appears in the set of method names registered for
 * `ConnectHelper` in `MLUtils`, so keep the name in sync with both places.
 *
 * @param uid identifier given to the new model
 * @param vocabulary terms passed to the model constructor
 * @return a new `CountVectorizerModel` built from `uid` and `vocabulary`
 */
def countVectorizerModelFromVocabulary(
uid: String, vocabulary: Array[String]): CountVectorizerModel = {
new CountVectorizerModel(uid, vocabulary)
}

// Delegates to `defaultCopy`, forwarding the caller-supplied extra params.
override def copy(extra: ParamMap): ConnectHelper = defaultCopy(extra)

// Returns the input dataset converted with `toDF()`; no other transformation is applied here.
override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF()
Expand Down
32 changes: 23 additions & 9 deletions python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,15 +1220,29 @@ def from_vocabulary(
Construct the model directly from a vocabulary list of strings.
Requires an active SparkContext unless running on Spark Connect.
"""
from pyspark.core.context import SparkContext

sc = SparkContext._active_spark_context
assert sc is not None and sc._gateway is not None
java_class = getattr(sc._gateway.jvm, "java.lang.String")
jvocab = CountVectorizerModel._new_java_array(vocabulary, java_class)
model = CountVectorizerModel._create_from_java_class(
"org.apache.spark.ml.feature.CountVectorizerModel", jvocab
)
if len(vocabulary) == 0:
raise ValueError("Vocabulary list cannot be empty")

if is_remote():
model = CountVectorizerModel()
helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
model._java_obj = helper._call_java(
"countVectorizerModelFromVocabulary",
model.uid,
list(vocabulary),
)

else:
from pyspark.core.context import SparkContext

sc = SparkContext._active_spark_context
assert sc is not None and sc._gateway is not None
java_class = getattr(sc._gateway.jvm, "java.lang.String")
jvocab = CountVectorizerModel._new_java_array(vocabulary, java_class)
model = CountVectorizerModel._create_from_java_class(
"org.apache.spark.ml.feature.CountVectorizerModel", jvocab
)

model.setInputCol(inputCol)
if outputCol is not None:
model.setOutputCol(outputCol)
Expand Down
4 changes: 0 additions & 4 deletions python/pyspark/ml/tests/connect/test_parity_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@


class FeatureParityTests(FeatureTestsMixin, ReusedConnectTestCase):
@unittest.skip("Need to support.")
def test_count_vectorizer_from_vocab(self):
super().test_count_vectorizer_from_vocab()

@unittest.skip("Need to support.")
def test_stop_words_lengague_selection(self):
super().test_stop_words_lengague_selection()
Expand Down
6 changes: 2 additions & 4 deletions python/pyspark/ml/tests/test_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@
)
from pyspark.ml.linalg import DenseVector, SparseVector, Vectors
from pyspark.sql import Row
from pyspark.testing.utils import QuietTest
from pyspark.testing.mlutils import SparkSessionTestCase


Expand Down Expand Up @@ -1356,9 +1355,8 @@ def test_count_vectorizer_from_vocab(self):
self.assertEqual(feature, expected)

# Test an empty vocabulary
with QuietTest(self.sc):
with self.assertRaisesRegex(Exception, "vocabSize.*invalid.*0"):
CountVectorizerModel.from_vocabulary([], inputCol="words")
with self.assertRaisesRegex(Exception, "Vocabulary list cannot be empty"):
CountVectorizerModel.from_vocabulary([], inputCol="words")

# Test model with default settings can transform
model_default = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,10 @@ private[ml] object MLUtils {
// Utils
(
classOf[ConnectHelper],
Set("stringIndexerModelFromLabels", "stringIndexerModelFromLabelsArray")))
Set(
"stringIndexerModelFromLabels",
"stringIndexerModelFromLabelsArray",
"countVectorizerModelFromVocabulary")))

private def validate(obj: Any, method: String): Unit = {
assert(obj != null)
Expand Down

0 comments on commit 8b092be

Please sign in to comment.