From 1a492377321939b59313a78a367375dc016b466f Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 23 Jan 2025 16:58:06 +0800 Subject: [PATCH] [SPARK-50939][ML][PYTHON][CONNECT] Support Word2Vec on Connect ### What changes were proposed in this pull request? Support Word2Vec on Connect ### Why are the changes needed? for feature parity ### Does this PR introduce _any_ user-facing change? yes, new algorithm supported ### How was this patch tested? added test ### Was this patch authored or co-authored using generative AI tooling? no Closes #49614 from zhengruifeng/ml_connect_w2v. Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- .../services/org.apache.spark.ml.Estimator | 1 + .../services/org.apache.spark.ml.Transformer | 1 + .../apache/spark/ml/feature/Word2Vec.scala | 2 + python/pyspark/ml/feature.py | 4 +- python/pyspark/ml/tests/test_feature.py | 44 +++++++++++++++++++ .../apache/spark/sql/connect/ml/MLUtils.scala | 3 +- 6 files changed, 53 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator index 2cc3c3b6382aa..6c5bbd858d9cc 100644 --- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator @@ -52,3 +52,4 @@ org.apache.spark.ml.feature.MinMaxScaler org.apache.spark.ml.feature.RobustScaler org.apache.spark.ml.feature.StringIndexer org.apache.spark.ml.feature.PCA +org.apache.spark.ml.feature.Word2Vec diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer index 886c4cca209bf..0448117468198 100644 --- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer @@ -50,3 +50,4 @@ org.apache.spark.ml.feature.MinMaxScalerModel org.apache.spark.ml.feature.RobustScalerModel org.apache.spark.ml.feature.StringIndexerModel org.apache.spark.ml.feature.PCAModel +org.apache.spark.ml.feature.Word2VecModel diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 0329190a239ec..c3eeb394c5d47 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -211,6 +211,8 @@ class Word2VecModel private[ml] ( import Word2VecModel._ + private[ml] def this() = this(Identifiable.randomUID("w2v"), null) + /** * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and * and the vector the DenseVector that it is mapped to. diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index cf12a5390746f..ff8555fadbd12 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -50,7 +50,7 @@ Param, Params, ) -from pyspark.ml.util import JavaMLReadable, JavaMLWritable +from pyspark.ml.util import JavaMLReadable, JavaMLWritable, try_remote_attribute_relation from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaTransformer, _jvm from pyspark.ml.common import inherit_doc @@ -6381,6 +6381,7 @@ class Word2VecModel(JavaModel, _Word2VecParams, JavaMLReadable["Word2VecModel"], """ @since("1.5.0") + @try_remote_attribute_relation def getVectors(self) -> DataFrame: """ Returns the vector representation of the words as a dataframe @@ -6401,6 +6402,7 @@ def setOutputCol(self, value: str) -> "Word2VecModel": return self._set(outputCol=value) @since("1.5.0") + @try_remote_attribute_relation def findSynonyms(self, word: Union[str, Vector], num: int) -> DataFrame: """ Find "num" number of words closest in similarity to "word". diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py index faad7c941cb5a..9766ab1b02438 100644 --- a/python/pyspark/ml/tests/test_feature.py +++ b/python/pyspark/ml/tests/test_feature.py @@ -46,6 +46,8 @@ VectorAssembler, PCA, PCAModel, + Word2Vec, + Word2VecModel, ) from pyspark.ml.linalg import DenseVector, SparseVector, Vectors from pyspark.sql import Row @@ -357,6 +359,48 @@ def test_robust_scaler(self): self.assertEqual(str(model), str(model2)) self.assertEqual(model2.getOutputCol(), "scaled") + def test_word2vec(self): + sent = ("a b " * 100 + "a c " * 10).split(" ") + df = self.spark.createDataFrame([(sent,), (sent,)], ["sentence"]).coalesce(1) + + w2v = Word2Vec(vectorSize=3, seed=42, inputCol="sentence", outputCol="model") + w2v.setMaxIter(1) + self.assertEqual(w2v.getInputCol(), "sentence") + self.assertEqual(w2v.getOutputCol(), "model") + self.assertEqual(w2v.getVectorSize(), 3) + self.assertEqual(w2v.getSeed(), 42) + self.assertEqual(w2v.getMaxIter(), 1) + + model = w2v.fit(df) + self.assertEqual(model.getVectors().columns, ["word", "vector"]) + self.assertEqual(model.getVectors().count(), 3) + + synonyms = model.findSynonyms("a", 2) + self.assertEqual(synonyms.columns, ["word", "similarity"]) + self.assertEqual(synonyms.count(), 2) + + # TODO(SPARK-50958): Support Word2VecModel.findSynonymsArray + # synonyms = model.findSynonymsArray("a", 2) + # self.assertEqual(len(synonyms), 2) + # self.assertEqual(synonyms[0][0], "b") + # self.assertTrue(np.allclose(synonyms[0][1], -0.024012837558984756, atol=1e-4)) + # self.assertEqual(synonyms[1][0], "c") + # self.assertTrue(np.allclose(synonyms[1][1], -0.19355154037475586, atol=1e-4)) + + output = model.transform(df) + self.assertEqual(output.columns, ["sentence", "model"]) + self.assertEqual(output.count(), 2) + + # save & load + with tempfile.TemporaryDirectory(prefix="word2vec") as d: + w2v.write().overwrite().save(d) + w2v2 = Word2Vec.load(d) + self.assertEqual(str(w2v), str(w2v2)) + + model.write().overwrite().save(d) + model2 = Word2VecModel.load(d) + self.assertEqual(str(model), str(model2)) + def test_binarizer(self): b0 = Binarizer() self.assertListEqual( diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala index e94ad0db41aa6..dd961a3415cb5 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala @@ -584,7 +584,8 @@ private[ml] object MLUtils { (classOf[MaxAbsScalerModel], Set("maxAbs")), (classOf[MinMaxScalerModel], Set("originalMax", "originalMin")), (classOf[RobustScalerModel], Set("range", "median")), - (classOf[PCAModel], Set("pc", "explainedVariance"))) + (classOf[PCAModel], Set("pc", "explainedVariance")), + (classOf[Word2VecModel], Set("getVectors", "findSynonyms", "findSynonymsArray"))) private def validate(obj: Any, method: String): Unit = { assert(obj != null)