diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index fc6a8166442a4..9372255980a87 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -71,6 +71,7 @@ org.apache.spark.ml.recommendation.ALSModel
# fpm
org.apache.spark.ml.fpm.FPGrowthModel
+org.apache.spark.ml.fpm.PrefixSpanWrapper
# feature
org.apache.spark.ml.feature.RFormulaModel
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
index 3ea76658d1a92..099e42ee27496 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.fpm
import org.apache.spark.annotation.Since
+import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util.Instrumentation.instrumented
@@ -26,23 +27,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType}
-/**
- * A parallel PrefixSpan algorithm to mine frequent sequential patterns.
- * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
- * Efficiently by Prefix-Projected Pattern Growth
- * (see <a href="https://doi.org/10.1109/ICDE.2001.914830">here</a>).
- * This class is not yet an Estimator/Transformer, use `findFrequentSequentialPatterns` method to
- * run the PrefixSpan algorithm.
- *
- * @see <a href="https://en.wikipedia.org/wiki/Sequential_Pattern_Mining">Sequential Pattern Mining
- * (Wikipedia)</a>
- */
-@Since("2.4.0")
-final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params {
-
- @Since("2.4.0")
- def this() = this(Identifiable.randomUID("prefixSpan"))
-
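+/**
+ * Params shared by `PrefixSpan` and its Spark Connect helper `PrefixSpanWrapper`.
+ */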
+private[fpm] trait PrefixSpanParams extends Params {
/**
* Param for the minimal support level (default: `0.1`).
* Sequential patterns that appear more than (minSupport * size-of-the-dataset) times are
@@ -59,10 +44,6 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
@Since("2.4.0")
def getMinSupport: Double = $(minSupport)
- /** @group setParam */
- @Since("2.4.0")
- def setMinSupport(value: Double): this.type = set(minSupport, value)
-
/**
* Param for the maximal pattern length (default: `10`).
* @group param
@@ -76,10 +57,6 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
@Since("2.4.0")
def getMaxPatternLength: Int = $(maxPatternLength)
- /** @group setParam */
- @Since("2.4.0")
- def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value)
-
/**
* Param for the maximum number of items (including delimiters used in the internal storage
* format) allowed in a projected database before local processing (default: `32000000`).
@@ -90,18 +67,14 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
@Since("2.4.0")
val maxLocalProjDBSize = new LongParam(this, "maxLocalProjDBSize",
"The maximum number of items (including delimiters used in the internal storage format) " +
- "allowed in a projected database before local processing. If a projected database exceeds " +
- "this size, another iteration of distributed prefix growth is run.",
+ "allowed in a projected database before local processing. If a projected database exceeds " +
+ "this size, another iteration of distributed prefix growth is run.",
ParamValidators.gt(0))
/** @group getParam */
@Since("2.4.0")
def getMaxLocalProjDBSize: Long = $(maxLocalProjDBSize)
- /** @group setParam */
- @Since("2.4.0")
- def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value)
-
/**
* Param for the name of the sequence column in dataset (default "sequence"), rows with
* nulls in this column are ignored.
@@ -115,12 +88,42 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
@Since("2.4.0")
def getSequenceCol: String = $(sequenceCol)
+ setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000,
+ sequenceCol -> "sequence")
+}
+
+/**
+ * A parallel PrefixSpan algorithm to mine frequent sequential patterns.
+ * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
+ * Efficiently by Prefix-Projected Pattern Growth
+ * (see <a href="https://doi.org/10.1109/ICDE.2001.914830">here</a>).
+ * This class is not yet an Estimator/Transformer; use the `findFrequentSequentialPatterns`
+ * method to run the PrefixSpan algorithm.
+ *
+ * @see <a href="https://en.wikipedia.org/wiki/Sequential_Pattern_Mining">Sequential Pattern Mining
+ * (Wikipedia)</a>
+ */
+@Since("2.4.0")
+final class PrefixSpan(@Since("2.4.0") override val uid: String) extends PrefixSpanParams {
+
+ @Since("2.4.0")
+ def this() = this(Identifiable.randomUID("prefixSpan"))
+
/** @group setParam */
@Since("2.4.0")
- def setSequenceCol(value: String): this.type = set(sequenceCol, value)
+ def setMinSupport(value: Double): this.type = set(minSupport, value)
- setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000,
- sequenceCol -> "sequence")
+ /** @group setParam */
+ @Since("2.4.0")
+ def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value)
+
+ /** @group setParam */
+ @Since("2.4.0")
+ def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value)
+
+ /** @group setParam */
+ @Since("2.4.0")
+ def setSequenceCol(value: String): this.type = set(sequenceCol, value)
/**
* Finds the complete set of frequent sequential patterns in the input sequences of itemsets.
@@ -167,3 +170,27 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
override def copy(extra: ParamMap): PrefixSpan = defaultCopy(extra)
}
+
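+/**
+ * A thin `Transformer` shim over `PrefixSpan`, registered in the
+ * META-INF/services file above so that Spark Connect can look it up by class
+ * name and run `findFrequentSequentialPatterns` server-side. Not intended for
+ * direct use.
+ */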
+private[spark] class PrefixSpanWrapper(override val uid: String)
+ extends Transformer with PrefixSpanParams {
+
+ def this() = this(Identifiable.randomUID("prefixSpanWrapper"))
+
+ override def transformSchema(schema: StructType): StructType = {
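+    // Mirrors the documented output of findFrequentSequentialPatterns: the mined
+    // pattern keeps the input column's ArrayType(ArrayType(T)) type, plus a Long
+    // frequency column.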
+ new StructType()
+ .add("sequence", ArrayType(schema($(sequenceCol)).dataType), nullable = false)
+ .add("freq", LongType, nullable = false)
+ }
+
+ override def transform(dataset: Dataset[_]): DataFrame = {
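+    // Delegate to a plain PrefixSpan configured with this wrapper's params; the
+    // mined patterns become the "transformed" DataFrame.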
+ val prefixSpan = new PrefixSpan(uid)
+ prefixSpan
+ .setMinSupport($(minSupport))
+ .setMaxPatternLength($(maxPatternLength))
+ .setMaxLocalProjDBSize($(maxLocalProjDBSize))
+ .setSequenceCol($(sequenceCol))
+ .findFrequentSequentialPatterns(dataset)
+ }
+
+ override def copy(extra: ParamMap): PrefixSpanWrapper = defaultCopy(extra)
+}
diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py
index c068b5f26ba84..0e46ecc45e934 100644
--- a/python/pyspark/ml/fpm.py
+++ b/python/pyspark/ml/fpm.py
@@ -506,9 +506,25 @@ def findFrequentSequentialPatterns(self, dataset: DataFrame) -> DataFrame:
- `sequence: ArrayType(ArrayType(T))` (T is the item type)
- `freq: Long`
"""
+ from pyspark.sql.utils import is_remote
- self._transfer_params_to_java()
assert self._java_obj is not None
+
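+        # On Spark Connect there is no local JVM; route the call through the
+        # server-side helper org.apache.spark.ml.fpm.PrefixSpanWrapper instead.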
+ if is_remote():
+ from pyspark.ml.wrapper import JavaTransformer
+ from pyspark.ml.connect.serialize import serialize_ml_params
+
+ instance = JavaTransformer()
+ instance._java_obj = "org.apache.spark.ml.fpm.PrefixSpanWrapper"
+            # The helper object is just a JavaTransformer without any Param mixin,
+            # so copying the params via .copy() or directly assigning _paramMap
+            # won't work.
+ instance._serialized_ml_params = serialize_ml_params( # type: ignore[attr-defined]
+ self,
+ dataset.sparkSession.client, # type: ignore[arg-type,operator]
+ )
+ return instance.transform(dataset)
+
+ self._transfer_params_to_java()
jdf = self._java_obj.findFrequentSequentialPatterns(dataset._jdf)
return DataFrame(jdf, dataset.sparkSession)
diff --git a/python/pyspark/ml/tests/test_fpm.py b/python/pyspark/ml/tests/test_fpm.py
index cc8ead7127d69..1c2b717b1c854 100644
--- a/python/pyspark/ml/tests/test_fpm.py
+++ b/python/pyspark/ml/tests/test_fpm.py
@@ -18,11 +18,12 @@
import tempfile
import unittest
-from pyspark.sql import SparkSession
+from pyspark.sql import SparkSession, Row
import pyspark.sql.functions as sf
from pyspark.ml.fpm import (
FPGrowth,
FPGrowthModel,
+ PrefixSpan,
)
@@ -71,6 +72,32 @@ def test_fp_growth(self):
model2 = FPGrowthModel.load(d)
self.assertEqual(str(model), str(model2))
+ def test_prefix_span(self):
+ spark = self.spark
+ df = spark.createDataFrame(
+ [
+ Row(sequence=[[1, 2], [3]]),
+ Row(sequence=[[1], [3, 2], [1, 2]]),
+ Row(sequence=[[1, 2], [5]]),
+ Row(sequence=[[6]]),
+ ]
+ )
+
+ ps = PrefixSpan()
+ ps.setMinSupport(0.5)
+ ps.setMaxPatternLength(5)
+
+ self.assertEqual(ps.getMinSupport(), 0.5)
+ self.assertEqual(ps.getMaxPatternLength(), 5)
+
+ output = ps.findFrequentSequentialPatterns(df)
+ self.assertEqual(output.columns, ["sequence", "freq"])
+ self.assertEqual(output.count(), 5)
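+        # The five frequent patterns at minSupport=0.5 (>= 2 of the 4 sequences):
+        # [[1]]: 3, [[2]]: 3, [[1, 2]]: 3, [[3]]: 2, [[1], [3]]: 2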
+
+ head = output.sort("sequence").head()
+ self.assertEqual(head.sequence, [[1]])
+ self.assertEqual(head.freq, 3)
+
class FPMTests(FPMTestsMixin, unittest.TestCase):
def setUp(self) -> None:
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 04a91cf2f8b28..b895a86ae5350 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -156,9 +156,14 @@ def wrapped(self: "JavaWrapper", dataset: "ConnectDataFrame") -> Any:
session = dataset.sparkSession
assert session is not None
+
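+        # A caller may stash pre-serialized params on the instance (see
+        # PrefixSpan.findFrequentSequentialPatterns); prefer those over
+        # serializing this instance's own params.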
+ if hasattr(self, "_serialized_ml_params"):
+ params = self._serialized_ml_params
+ else:
+ params = serialize_ml_params(self, session.client) # type: ignore[arg-type]
+
# Model is also a Transformer, so we must match Model first
if isinstance(self, Model):
- params = serialize_ml_params(self, session.client)
from pyspark.ml.connect.proto import TransformerRelation
assert isinstance(self._java_obj, str)
@@ -169,7 +174,6 @@ def wrapped(self: "JavaWrapper", dataset: "ConnectDataFrame") -> Any:
session,
)
elif isinstance(self, Transformer):
- params = serialize_ml_params(self, session.client)
from pyspark.ml.connect.proto import TransformerRelation
assert isinstance(self._java_obj, str)