From 54a79a8c36a69f14f08bec6e95b4c35cddd3e48a Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Sun, 2 Feb 2025 16:29:16 +0800
Subject: [PATCH 1/3] init

---
 .../services/org.apache.spark.ml.Transformer |  1 +
 .../org/apache/spark/ml/fpm/PrefixSpan.scala | 95 ++++++++++++-------
 python/pyspark/ml/fpm.py                     | 15 ++-
 python/pyspark/ml/tests/test_fpm.py          | 29 +++++-
 python/pyspark/ml/util.py                    |  8 +-
 5 files changed, 110 insertions(+), 38 deletions(-)

diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index fc6a8166442a4..9372255980a87 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -71,6 +71,7 @@ org.apache.spark.ml.recommendation.ALSModel
 
 # fpm
 org.apache.spark.ml.fpm.FPGrowthModel
+org.apache.spark.ml.fpm.PrefixSpanWrapper
 
 # feature
 org.apache.spark.ml.feature.RFormulaModel
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
index 3ea76658d1a92..099e42ee27496 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
@@ -18,6 +18,7 @@ package org.apache.spark.ml.fpm
 
 import org.apache.spark.annotation.Since
+import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.ml.util.Instrumentation.instrumented
@@ -26,23 +27,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType}
 
-/**
- * A parallel PrefixSpan algorithm to mine frequent sequential patterns.
- * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
- * Efficiently by Prefix-Projected Pattern Growth
- * (see here).
- * This class is not yet an Estimator/Transformer, use `findFrequentSequentialPatterns` method to
- * run the PrefixSpan algorithm.
- *
- * @see Sequential Pattern Mining
- * (Wikipedia)
- */
-@Since("2.4.0")
-final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params {
-
-  @Since("2.4.0")
-  def this() = this(Identifiable.randomUID("prefixSpan"))
-
+private[fpm] trait PrefixSpanParams extends Params {
   /**
    * Param for the minimal support level (default: `0.1`).
    * Sequential patterns that appear more than (minSupport * size-of-the-dataset) times are
    * identified as frequent sequential patterns.
    * @group param
    */
   @Since("2.4.0")
@@ -59,10 +44,6 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
   @Since("2.4.0")
   def getMinSupport: Double = $(minSupport)
 
-  /** @group setParam */
-  @Since("2.4.0")
-  def setMinSupport(value: Double): this.type = set(minSupport, value)
-
   /**
    * Param for the maximal pattern length (default: `10`).
    * @group param
    */
   @Since("2.4.0")
@@ -76,10 +57,6 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
   @Since("2.4.0")
   def getMaxPatternLength: Int = $(maxPatternLength)
 
-  /** @group setParam */
-  @Since("2.4.0")
-  def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value)
-
   /**
    * Param for the maximum number of items (including delimiters used in the internal storage
    * format) allowed in a projected database before local processing (default: `32000000`).
@@ -90,18 +67,14 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
   @Since("2.4.0")
   val maxLocalProjDBSize = new LongParam(this, "maxLocalProjDBSize",
     "The maximum number of items (including delimiters used in the internal storage format) " +
-    "allowed in a projected database before local processing. If a projected database exceeds " +
-    "this size, another iteration of distributed prefix growth is run.",
+      "allowed in a projected database before local processing. If a projected database exceeds " +
+      "this size, another iteration of distributed prefix growth is run.",
     ParamValidators.gt(0))
 
   /** @group getParam */
   @Since("2.4.0")
   def getMaxLocalProjDBSize: Long = $(maxLocalProjDBSize)
 
-  /** @group setParam */
-  @Since("2.4.0")
-  def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value)
-
   /**
    * Param for the name of the sequence column in dataset (default "sequence"), rows with
    * nulls in this column are ignored.
@@ -115,12 +88,42 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
   @Since("2.4.0")
   def getSequenceCol: String = $(sequenceCol)
 
+  setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000,
+    sequenceCol -> "sequence")
+}
+
+/**
+ * A parallel PrefixSpan algorithm to mine frequent sequential patterns.
+ * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
+ * Efficiently by Prefix-Projected Pattern Growth
+ * (see here).
+ * This class is not yet an Estimator/Transformer, use `findFrequentSequentialPatterns` method to
+ * run the PrefixSpan algorithm.
+ *
+ * @see Sequential Pattern Mining
+ * (Wikipedia)
+ */
+@Since("2.4.0")
+final class PrefixSpan(@Since("2.4.0") override val uid: String) extends PrefixSpanParams {
+
+  @Since("2.4.0")
+  def this() = this(Identifiable.randomUID("prefixSpan"))
+
   /** @group setParam */
   @Since("2.4.0")
-  def setSequenceCol(value: String): this.type = set(sequenceCol, value)
+  def setMinSupport(value: Double): this.type = set(minSupport, value)
 
-  setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000,
-    sequenceCol -> "sequence")
+  /** @group setParam */
+  @Since("2.4.0")
+  def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value)
+
+  /** @group setParam */
+  @Since("2.4.0")
+  def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value)
+
+  /** @group setParam */
+  @Since("2.4.0")
+  def setSequenceCol(value: String): this.type = set(sequenceCol, value)
 
   /**
    * Finds the complete set of frequent sequential patterns in the input sequences of itemsets.
@@ -167,3 +170,27 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
 
   override def copy(extra: ParamMap): PrefixSpan = defaultCopy(extra)
 }
+
+private[spark] class PrefixSpanWrapper(override val uid: String)
+  extends Transformer with PrefixSpanParams {
+
+  def this() = this(Identifiable.randomUID("prefixSpanWrapper"))
+
+  override def transformSchema(schema: StructType): StructType = {
+    new StructType()
+      .add("sequence", ArrayType(schema($(sequenceCol)).dataType), nullable = false)
+      .add("freq", LongType, nullable = false)
+  }
+
+  override def transform(dataset: Dataset[_]): DataFrame = {
+    val prefixSpan = new PrefixSpan(uid)
+    prefixSpan
+      .setMinSupport($(minSupport))
+      .setMaxPatternLength($(maxPatternLength))
+      .setMaxLocalProjDBSize($(maxLocalProjDBSize))
+      .setSequenceCol($(sequenceCol))
+      .findFrequentSequentialPatterns(dataset)
+  }
+
+  override def copy(extra: ParamMap): PrefixSpanWrapper = defaultCopy(extra)
+}
diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py
index c068b5f26ba84..98aa65be3a8bf 100644
--- a/python/pyspark/ml/fpm.py
+++ b/python/pyspark/ml/fpm.py
@@ -16,6 +16,7 @@
 #
 
 import sys
+import copy
 from typing import Any, Dict, Optional, TYPE_CHECKING
 
 from pyspark import keyword_only, since
@@ -23,6 +24,7 @@
 from pyspark.ml.util import JavaMLWritable, JavaMLReadable, try_remote_attribute_relation
 from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
 from pyspark.ml.param.shared import HasPredictionCol, Param, TypeConverters, Params
+from pyspark.sql.utils import is_remote
 
 if TYPE_CHECKING:
     from py4j.java_gateway import JavaObject
@@ -506,9 +508,20 @@ def findFrequentSequentialPatterns(self, dataset: DataFrame) -> DataFrame:
         - `sequence: ArrayType(ArrayType(T))` (T is the item type)
         - `freq: Long`
         """
+        assert self._java_obj is not None
+
+        if is_remote():
+            from python.pyspark.ml.wrapper import JavaTransformer
+            from pyspark.ml.connect.serialize import serialize_ml_params
+
+            instance = JavaTransformer(self._java_obj)
+            instance._java_obj = "org.apache.spark.ml.fpm.PrefixSpanWrapper"
+            # The helper object is just a JavaTransformer without any Param Mixin, so
+            # copying the params via .copy() or directly assigning the _paramMap won't work.
+            instance._serialized_ml_params = serialize_ml_params(self, dataset.sparkSession.client)
+            return instance.transform(dataset)
         self._transfer_params_to_java()
-        assert self._java_obj is not None
         jdf = self._java_obj.findFrequentSequentialPatterns(dataset._jdf)
         return DataFrame(jdf, dataset.sparkSession)
diff --git a/python/pyspark/ml/tests/test_fpm.py b/python/pyspark/ml/tests/test_fpm.py
index cc8ead7127d69..1c2b717b1c854 100644
--- a/python/pyspark/ml/tests/test_fpm.py
+++ b/python/pyspark/ml/tests/test_fpm.py
@@ -18,11 +18,12 @@
 import tempfile
 import unittest
 
-from pyspark.sql import SparkSession
+from pyspark.sql import SparkSession, Row
 import pyspark.sql.functions as sf
 from pyspark.ml.fpm import (
     FPGrowth,
     FPGrowthModel,
+    PrefixSpan,
 )
@@ -71,6 +72,32 @@ def test_fp_growth(self):
             model2 = FPGrowthModel.load(d)
             self.assertEqual(str(model), str(model2))
 
+    def test_prefix_span(self):
+        spark = self.spark
+        df = spark.createDataFrame(
+            [
+                Row(sequence=[[1, 2], [3]]),
+                Row(sequence=[[1], [3, 2], [1, 2]]),
+                Row(sequence=[[1, 2], [5]]),
+                Row(sequence=[[6]]),
+            ]
+        )
+
+        ps = PrefixSpan()
+        ps.setMinSupport(0.5)
+        ps.setMaxPatternLength(5)
+
+        self.assertEqual(ps.getMinSupport(), 0.5)
+        self.assertEqual(ps.getMaxPatternLength(), 5)
+
+        output = ps.findFrequentSequentialPatterns(df)
+        self.assertEqual(output.columns, ["sequence", "freq"])
+        self.assertEqual(output.count(), 5)
+
+        head = output.sort("sequence").head()
+        self.assertEqual(head.sequence, [[1]])
+        self.assertEqual(head.freq, 3)
+
 
 class FPMTests(FPMTestsMixin, unittest.TestCase):
     def setUp(self) -> None:
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 04a91cf2f8b28..7895d796fd97f 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -156,9 +156,14 @@ def wrapped(self: "JavaWrapper", dataset: "ConnectDataFrame") -> Any:
         session = dataset.sparkSession
         assert session is not None
+
+        if hasattr(self, "_serialized_ml_params"):
+            params = self._serialized_ml_params
+        else:
+            params = serialize_ml_params(self, session.client)
+
         # Model is also a Transformer, so we must match Model first
         if isinstance(self, Model):
-            params = serialize_ml_params(self, session.client)
             from pyspark.ml.connect.proto import TransformerRelation
 
             assert isinstance(self._java_obj, str)
@@ -169,7 +174,6 @@ def wrapped(self: "JavaWrapper", dataset: "ConnectDataFrame") -> Any:
                 session,
             )
         elif isinstance(self, Transformer):
-            params = serialize_ml_params(self, session.client)
             from pyspark.ml.connect.proto import TransformerRelation
 
             assert isinstance(self._java_obj, str)

From 28ee05eadfdac6f7b49967b196abcd6b8b506722 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Sun, 2 Feb 2025 16:32:07 +0800
Subject: [PATCH 2/3] nit

---
 python/pyspark/ml/fpm.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py
index 98aa65be3a8bf..3a497e2dba469 100644
--- a/python/pyspark/ml/fpm.py
+++ b/python/pyspark/ml/fpm.py
@@ -16,7 +16,6 @@
 #
 
 import sys
-import copy
 from typing import Any, Dict, Optional, TYPE_CHECKING
 
 from pyspark import keyword_only, since

From 3e3a3c64a25e5ffbf8311de8daeb91632962daf4 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Sun, 2 Feb 2025 21:09:48 +0800
Subject: [PATCH 3/3] fix lint

---
 python/pyspark/ml/fpm.py  | 12 ++++++++----
 python/pyspark/ml/util.py |  2 +-
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py
index 3a497e2dba469..0e46ecc45e934 100644
--- a/python/pyspark/ml/fpm.py
+++ b/python/pyspark/ml/fpm.py
@@ -23,7 +23,6 @@
 from pyspark.ml.util import JavaMLWritable, JavaMLReadable, try_remote_attribute_relation
 from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
 from pyspark.ml.param.shared import HasPredictionCol, Param, TypeConverters, Params
-from pyspark.sql.utils import is_remote
 
 if TYPE_CHECKING:
     from py4j.java_gateway import JavaObject
@@ -507,17 +506,22 @@ def findFrequentSequentialPatterns(self, dataset: DataFrame) -> DataFrame:
         - `sequence: ArrayType(ArrayType(T))` (T is the item type)
         - `freq: Long`
         """
+        from pyspark.sql.utils import is_remote
+
         assert self._java_obj is not None
 
         if is_remote():
-            from python.pyspark.ml.wrapper import JavaTransformer
+            from pyspark.ml.wrapper import JavaTransformer
             from pyspark.ml.connect.serialize import serialize_ml_params
 
-            instance = JavaTransformer(self._java_obj)
+            instance = JavaTransformer()
             instance._java_obj = "org.apache.spark.ml.fpm.PrefixSpanWrapper"
             # The helper object is just a JavaTransformer without any Param Mixin, so
             # copying the params via .copy() or directly assigning the _paramMap won't work.
-            instance._serialized_ml_params = serialize_ml_params(self, dataset.sparkSession.client)
+            instance._serialized_ml_params = serialize_ml_params(  # type: ignore[attr-defined]
+                self,
+                dataset.sparkSession.client,  # type: ignore[arg-type,operator]
+            )
             return instance.transform(dataset)
         self._transfer_params_to_java()
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 7895d796fd97f..b895a86ae5350 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -160,7 +160,7 @@ def wrapped(self: "JavaWrapper", dataset: "ConnectDataFrame") -> Any:
         if hasattr(self, "_serialized_ml_params"):
             params = self._serialized_ml_params
         else:
-            params = serialize_ml_params(self, session.client)
+            params = serialize_ml_params(self, session.client)  # type: ignore[arg-type]
 
         # Model is also a Transformer, so we must match Model first
         if isinstance(self, Model):
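
Usage sketch (illustrative, not part of the patch series): the snippet below mirrors
test_prefix_span and assumes a running PySpark session. The same call works on both
backends; on a Spark Connect session (where is_remote() returns true after PATCH 3/3),
findFrequentSequentialPatterns is routed through the PrefixSpanWrapper registered in
the services file, instead of the Py4J path.

    from pyspark.sql import Row, SparkSession
    from pyspark.ml.fpm import PrefixSpan

    spark = SparkSession.builder.getOrCreate()

    # Same toy data as test_prefix_span: four sequences of itemsets.
    df = spark.createDataFrame(
        [
            Row(sequence=[[1, 2], [3]]),
            Row(sequence=[[1], [3, 2], [1, 2]]),
            Row(sequence=[[1, 2], [5]]),
            Row(sequence=[[6]]),
        ]
    )

    ps = PrefixSpan()
    ps.setMinSupport(0.5)  # a pattern must occur in at least 2 of the 4 sequences
    ps.setMaxPatternLength(5)

    # Returns a DataFrame with `sequence` and `freq` columns; for this data it has
    # five rows, the first after sorting being sequence=[[1]], freq=3.
    ps.findFrequentSequentialPatterns(df).sort("sequence").show()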