[SPARK-50899][ML][PYTHON][CONNECT] Support PrefixSpan on connect
### What changes were proposed in this pull request?
Support `PrefixSpan` on Spark Connect: `findFrequentSequentialPatterns` is routed through a new server-side `PrefixSpanWrapper` transformer in remote mode.

### Why are the changes needed?
Feature parity between classic Spark ML and Spark Connect.

### Does this PR introduce _any_ user-facing change?
Yes, `PrefixSpan.findFrequentSequentialPatterns` now works on Spark Connect.

### How was this patch tested?
Added a `test_prefix_span` test to `pyspark.ml.tests.test_fpm`.

### Was this patch authored or co-authored using generative AI tooling?
No.
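
For reference, a minimal sketch of the user-facing behavior, mirroring the test added in this PR. The same code now runs both in classic mode and against a Spark Connect server; the `remote(...)` address below is illustrative:

```python
from pyspark.sql import SparkSession, Row
from pyspark.ml.fpm import PrefixSpan

# Connect mode; a plain SparkSession.builder.getOrCreate() behaves the same.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

df = spark.createDataFrame(
    [
        Row(sequence=[[1, 2], [3]]),
        Row(sequence=[[1], [3, 2], [1, 2]]),
        Row(sequence=[[1, 2], [5]]),
        Row(sequence=[[6]]),
    ]
)

ps = PrefixSpan().setMinSupport(0.5).setMaxPatternLength(5)
ps.findFrequentSequentialPatterns(df).sort("sequence").show(truncate=False)
# +----------+----+
# |sequence  |freq|
# +----------+----+
# |[[1]]     |3   |
# |[[1], [3]]|2   |
# |[[1, 2]]  |3   |
# |[[2]]     |3   |
# |[[3]]     |2   |
# +----------+----+
```

The expected rows above match the PrefixSpan example in the Spark ML user guide; the new test asserts the row count (5) and the first row after sorting ([[1]] with freq 3).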

Closes #49763 from zhengruifeng/ml_connect_prefix_span.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
(cherry picked from commit f840abb)
Signed-off-by: Ruifeng Zheng <[email protected]>
zhengruifeng committed Feb 3, 2025
1 parent 836a632 commit 9dfd1ff
Showing 5 changed files with 113 additions and 38 deletions.
1 change: 1 addition & 0 deletions mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -71,6 +71,7 @@ org.apache.spark.ml.recommendation.ALSModel

# fpm
org.apache.spark.ml.fpm.FPGrowthModel
org.apache.spark.ml.fpm.PrefixSpanWrapper

# feature
org.apache.spark.ml.feature.RFormulaModel
95 changes: 61 additions & 34 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.fpm

import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util.Instrumentation.instrumented
@@ -26,23 +27,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType}

/**
* A parallel PrefixSpan algorithm to mine frequent sequential patterns.
* The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
* Efficiently by Prefix-Projected Pattern Growth
* (see <a href="https://doi.org/10.1109/ICDE.2001.914830">here</a>).
* This class is not yet an Estimator/Transformer, use `findFrequentSequentialPatterns` method to
* run the PrefixSpan algorithm.
*
* @see <a href="https://en.wikipedia.org/wiki/Sequential_Pattern_Mining">Sequential Pattern Mining
* (Wikipedia)</a>
*/
@Since("2.4.0")
final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params {

@Since("2.4.0")
def this() = this(Identifiable.randomUID("prefixSpan"))

private[fpm] trait PrefixSpanParams extends Params {
/**
* Param for the minimal support level (default: `0.1`).
* Sequential patterns that appear more than (minSupport * size-of-the-dataset) times are
@@ -59,10 +44,6 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
@Since("2.4.0")
def getMinSupport: Double = $(minSupport)

/** @group setParam */
@Since("2.4.0")
def setMinSupport(value: Double): this.type = set(minSupport, value)

/**
* Param for the maximal pattern length (default: `10`).
* @group param
@@ -76,10 +57,6 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
@Since("2.4.0")
def getMaxPatternLength: Int = $(maxPatternLength)

/** @group setParam */
@Since("2.4.0")
def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value)

/**
* Param for the maximum number of items (including delimiters used in the internal storage
* format) allowed in a projected database before local processing (default: `32000000`).
@@ -90,18 +67,14 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
@Since("2.4.0")
val maxLocalProjDBSize = new LongParam(this, "maxLocalProjDBSize",
"The maximum number of items (including delimiters used in the internal storage format) " +
"allowed in a projected database before local processing. If a projected database exceeds " +
"this size, another iteration of distributed prefix growth is run.",
"allowed in a projected database before local processing. If a projected database exceeds " +
"this size, another iteration of distributed prefix growth is run.",
ParamValidators.gt(0))

/** @group getParam */
@Since("2.4.0")
def getMaxLocalProjDBSize: Long = $(maxLocalProjDBSize)

/** @group setParam */
@Since("2.4.0")
def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value)

/**
* Param for the name of the sequence column in dataset (default "sequence"); rows with
* nulls in this column are ignored.
@@ -115,12 +88,42 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
@Since("2.4.0")
def getSequenceCol: String = $(sequenceCol)

setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000,
sequenceCol -> "sequence")
}

/**
* A parallel PrefixSpan algorithm to mine frequent sequential patterns.
* The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
* Efficiently by Prefix-Projected Pattern Growth
* (see <a href="https://doi.org/10.1109/ICDE.2001.914830">here</a>).
* This class is not yet an Estimator/Transformer, use `findFrequentSequentialPatterns` method to
* run the PrefixSpan algorithm.
*
* @see <a href="https://en.wikipedia.org/wiki/Sequential_Pattern_Mining">Sequential Pattern Mining
* (Wikipedia)</a>
*/
@Since("2.4.0")
final class PrefixSpan(@Since("2.4.0") override val uid: String) extends PrefixSpanParams {

@Since("2.4.0")
def this() = this(Identifiable.randomUID("prefixSpan"))

/** @group setParam */
@Since("2.4.0")
def setSequenceCol(value: String): this.type = set(sequenceCol, value)
def setMinSupport(value: Double): this.type = set(minSupport, value)

setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000,
sequenceCol -> "sequence")
/** @group setParam */
@Since("2.4.0")
def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value)

/** @group setParam */
@Since("2.4.0")
def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value)

/** @group setParam */
@Since("2.4.0")
def setSequenceCol(value: String): this.type = set(sequenceCol, value)

/**
* Finds the complete set of frequent sequential patterns in the input sequences of itemsets.
@@ -167,3 +170,27 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
override def copy(extra: ParamMap): PrefixSpan = defaultCopy(extra)

}

private[spark] class PrefixSpanWrapper(override val uid: String)
extends Transformer with PrefixSpanParams {

def this() = this(Identifiable.randomUID("prefixSpanWrapper"))

override def transformSchema(schema: StructType): StructType = {
new StructType()
.add("sequence", ArrayType(schema($(sequenceCol)).dataType), nullable = false)
.add("freq", LongType, nullable = false)
}

override def transform(dataset: Dataset[_]): DataFrame = {
val prefixSpan = new PrefixSpan(uid)
prefixSpan
.setMinSupport($(minSupport))
.setMaxPatternLength($(maxPatternLength))
.setMaxLocalProjDBSize($(maxLocalProjDBSize))
.setSequenceCol($(sequenceCol))
.findFrequentSequentialPatterns(dataset)
}

override def copy(extra: ParamMap): PrefixSpanWrapper = defaultCopy(extra)
}
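
`PrefixSpan` itself is not an Estimator/Transformer, but connect's ML dispatch works in terms of transformers addressed by class name, so this PR adds the thin `Transformer` facade above and registers it in the service list touched by the first hunk of this diff. In PySpark terms, its `transform` is roughly equivalent to the sketch below (a paraphrase of the Scala, not actual server code; the defaults come from `setDefault` in `PrefixSpanParams`):

```python
from pyspark.ml.fpm import PrefixSpan
from pyspark.sql import DataFrame


def wrapper_transform(
    dataset: DataFrame,
    min_support: float = 0.1,
    max_pattern_length: int = 10,
    max_local_proj_db_size: int = 32000000,
    sequence_col: str = "sequence",
) -> DataFrame:
    """Rebuild a PrefixSpan from the wrapper's params and delegate to it."""
    ps = (
        PrefixSpan()
        .setMinSupport(min_support)
        .setMaxPatternLength(max_pattern_length)
        .setMaxLocalProjDBSize(max_local_proj_db_size)
        .setSequenceCol(sequence_col)
    )
    return ps.findFrequentSequentialPatterns(dataset)
```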
18 changes: 17 additions & 1 deletion python/pyspark/ml/fpm.py
@@ -506,9 +506,25 @@ def findFrequentSequentialPatterns(self, dataset: DataFrame) -> DataFrame:
        - `sequence: ArrayType(ArrayType(T))` (T is the item type)
        - `freq: Long`
        """
        from pyspark.sql.utils import is_remote

        self._transfer_params_to_java()
        assert self._java_obj is not None

        if is_remote():
            from pyspark.ml.wrapper import JavaTransformer
            from pyspark.ml.connect.serialize import serialize_ml_params

            instance = JavaTransformer()
            instance._java_obj = "org.apache.spark.ml.fpm.PrefixSpanWrapper"
            # The helper object is just a JavaTransformer without any param
            # mixin, so copying the params via .copy() or assigning _paramMap
            # directly won't work.
            instance._serialized_ml_params = serialize_ml_params(  # type: ignore[attr-defined]
                self,
                dataset.sparkSession.client,  # type: ignore[arg-type,operator]
            )
            return instance.transform(dataset)

        self._transfer_params_to_java()
        jdf = self._java_obj.findFrequentSequentialPatterns(dataset._jdf)
        return DataFrame(jdf, dataset.sparkSession)

29 changes: 28 additions & 1 deletion python/pyspark/ml/tests/test_fpm.py
@@ -18,11 +18,12 @@
import tempfile
import unittest

from pyspark.sql import SparkSession, Row
import pyspark.sql.functions as sf
from pyspark.ml.fpm import (
    FPGrowth,
    FPGrowthModel,
    PrefixSpan,
)

@@ -71,6 +72,32 @@ def test_fp_growth(self):
                model2 = FPGrowthModel.load(d)
                self.assertEqual(str(model), str(model2))

    def test_prefix_span(self):
        spark = self.spark
        df = spark.createDataFrame(
            [
                Row(sequence=[[1, 2], [3]]),
                Row(sequence=[[1], [3, 2], [1, 2]]),
                Row(sequence=[[1, 2], [5]]),
                Row(sequence=[[6]]),
            ]
        )

        ps = PrefixSpan()
        ps.setMinSupport(0.5)
        ps.setMaxPatternLength(5)

        self.assertEqual(ps.getMinSupport(), 0.5)
        self.assertEqual(ps.getMaxPatternLength(), 5)

        output = ps.findFrequentSequentialPatterns(df)
        self.assertEqual(output.columns, ["sequence", "freq"])
        self.assertEqual(output.count(), 5)

        head = output.sort("sequence").head()
        self.assertEqual(head.sequence, [[1]])
        self.assertEqual(head.freq, 3)


class FPMTests(FPMTestsMixin, unittest.TestCase):
def setUp(self) -> None:
8 changes: 6 additions & 2 deletions python/pyspark/ml/util.py
@@ -156,9 +156,14 @@ def wrapped(self: "JavaWrapper", dataset: "ConnectDataFrame") -> Any:

        session = dataset.sparkSession
        assert session is not None

        if hasattr(self, "_serialized_ml_params"):
            params = self._serialized_ml_params
        else:
            params = serialize_ml_params(self, session.client)  # type: ignore[arg-type]

        # Model is also a Transformer, so we must match Model first
        if isinstance(self, Model):
            params = serialize_ml_params(self, session.client)
            from pyspark.ml.connect.proto import TransformerRelation

            assert isinstance(self._java_obj, str)

@@ -169,7 +174,6 @@ def wrapped(self: "JavaWrapper", dataset: "ConnectDataFrame") -> Any:
                session,
            )
        elif isinstance(self, Transformer):
            params = serialize_ml_params(self, session.client)
            from pyspark.ml.connect.proto import TransformerRelation

            assert isinstance(self._java_obj, str)
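
Putting the pieces together, the remote call path added by this PR looks roughly as follows (a condensed sketch stitched from the diffs above, not literal framework code):

```python
from pyspark.sql import DataFrame
from pyspark.sql.utils import is_remote


def find_frequent_sequential_patterns(ps, dataset):
    """Sketch of PrefixSpan.findFrequentSequentialPatterns after this PR."""
    if is_remote():
        from pyspark.ml.wrapper import JavaTransformer
        from pyspark.ml.connect.serialize import serialize_ml_params

        # fpm.py: a bare JavaTransformer addressed by the JVM class name of
        # the new server-side wrapper; no Python wrapper class is needed.
        instance = JavaTransformer()
        instance._java_obj = "org.apache.spark.ml.fpm.PrefixSpanWrapper"
        # The helper carries no Params of its own, so the real PrefixSpan
        # params are serialized up front and attached to the instance.
        instance._serialized_ml_params = serialize_ml_params(
            ps, dataset.sparkSession.client
        )
        # util.py: the generic transform machinery prefers
        # _serialized_ml_params over serializing the helper itself, builds a
        # TransformerRelation, and the server instantiates PrefixSpanWrapper
        # (hence the services-file registration) to run the algorithm.
        return instance.transform(dataset)

    # Classic mode: call into the JVM object directly.
    ps._transfer_params_to_java()
    jdf = ps._java_obj.findFrequentSequentialPatterns(dataset._jdf)
    return DataFrame(jdf, dataset.sparkSession)
```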
