[SPARK-50899][ML][PYTHON][CONNECT] Support PrefixSpan on connect #49763

Closed
@@ -71,6 +71,7 @@ org.apache.spark.ml.recommendation.ALSModel

# fpm
org.apache.spark.ml.fpm.FPGrowthModel
org.apache.spark.ml.fpm.PrefixSpanWrapper

# feature
org.apache.spark.ml.feature.RFormulaModel
95 changes: 61 additions & 34 deletions mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.fpm

import org.apache.spark.annotation.Since
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util.Instrumentation.instrumented
@@ -26,23 +27,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType}

private[fpm] trait PrefixSpanParams extends Params {
/**
* Param for the minimal support level (default: `0.1`).
* Sequential patterns that appear more than (minSupport * size-of-the-dataset) times are
@@ -59,10 +44,6 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
@Since("2.4.0")
def getMinSupport: Double = $(minSupport)

/**
* Param for the maximal pattern length (default: `10`).
* @group param
@@ -76,10 +57,6 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
@Since("2.4.0")
def getMaxPatternLength: Int = $(maxPatternLength)

/**
* Param for the maximum number of items (including delimiters used in the internal storage
* format) allowed in a projected database before local processing (default: `32000000`).
@@ -90,18 +67,14 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
@Since("2.4.0")
val maxLocalProjDBSize = new LongParam(this, "maxLocalProjDBSize",
"The maximum number of items (including delimiters used in the internal storage format) " +
"allowed in a projected database before local processing. If a projected database exceeds " +
"this size, another iteration of distributed prefix growth is run.",
"allowed in a projected database before local processing. If a projected database exceeds " +
"this size, another iteration of distributed prefix growth is run.",
ParamValidators.gt(0))

/** @group getParam */
@Since("2.4.0")
def getMaxLocalProjDBSize: Long = $(maxLocalProjDBSize)

/**
* Param for the name of the sequence column in dataset (default "sequence"), rows with
* nulls in this column are ignored.
@@ -115,12 +88,42 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
@Since("2.4.0")
def getSequenceCol: String = $(sequenceCol)

setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000,
sequenceCol -> "sequence")
}

/**
* A parallel PrefixSpan algorithm to mine frequent sequential patterns.
* The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
* Efficiently by Prefix-Projected Pattern Growth
* (see <a href="https://doi.org/10.1109/ICDE.2001.914830">here</a>).
* This class is not yet an Estimator/Transformer, use `findFrequentSequentialPatterns` method to
* run the PrefixSpan algorithm.
*
* @see <a href="https://en.wikipedia.org/wiki/Sequential_Pattern_Mining">Sequential Pattern Mining
* (Wikipedia)</a>
*/
@Since("2.4.0")
final class PrefixSpan(@Since("2.4.0") override val uid: String) extends PrefixSpanParams {

@Since("2.4.0")
def this() = this(Identifiable.randomUID("prefixSpan"))

/** @group setParam */
@Since("2.4.0")
def setMinSupport(value: Double): this.type = set(minSupport, value)

/** @group setParam */
@Since("2.4.0")
def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value)

/** @group setParam */
@Since("2.4.0")
def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value)

/** @group setParam */
@Since("2.4.0")
def setSequenceCol(value: String): this.type = set(sequenceCol, value)

/**
* Finds the complete set of frequent sequential patterns in the input sequences of itemsets.
@@ -167,3 +170,27 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
override def copy(extra: ParamMap): PrefixSpan = defaultCopy(extra)

}

private[spark] class PrefixSpanWrapper(override val uid: String)
[Review comment (Contributor Author)] PrefixSpan is not a Transformer or Estimator; this wrapper is added so that it can be treated as a Transformer on Spark Connect.

extends Transformer with PrefixSpanParams {

def this() = this(Identifiable.randomUID("prefixSpanWrapper"))

override def transformSchema(schema: StructType): StructType = {
new StructType()
.add("sequence", ArrayType(schema($(sequenceCol)).dataType), nullable = false)
.add("freq", LongType, nullable = false)
}

override def transform(dataset: Dataset[_]): DataFrame = {
val prefixSpan = new PrefixSpan(uid)
prefixSpan
.setMinSupport($(minSupport))
.setMaxPatternLength($(maxPatternLength))
.setMaxLocalProjDBSize($(maxLocalProjDBSize))
.setSequenceCol($(sequenceCol))
.findFrequentSequentialPatterns(dataset)
}

override def copy(extra: ParamMap): PrefixSpanWrapper = defaultCopy(extra)
}
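For orientation, a minimal PySpark sketch of what the wrapper's transform amounts to: forward the four params to a plain PrefixSpan and run the one-shot miner. The input DataFrame df is assumed here; the values shown are the defaults set in PrefixSpanParams above.

```python
from pyspark.ml.fpm import PrefixSpan

# Client-side equivalent of PrefixSpanWrapper.transform: forward the four
# params unchanged, then run the one-shot pattern miner. The values are
# the PrefixSpanParams defaults; df is an assumed input DataFrame with an
# array<array<T>> "sequence" column.
ps = (
    PrefixSpan()
    .setMinSupport(0.1)
    .setMaxPatternLength(10)
    .setMaxLocalProjDBSize(32000000)
    .setSequenceCol("sequence")
)
result = ps.findFrequentSequentialPatterns(df)  # columns: sequence, freq
```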
18 changes: 17 additions & 1 deletion python/pyspark/ml/fpm.py
@@ -506,9 +506,25 @@ def findFrequentSequentialPatterns(self, dataset: DataFrame) -> DataFrame:
- `sequence: ArrayType(ArrayType(T))` (T is the item type)
- `freq: Long`
"""
from pyspark.sql.utils import is_remote

if is_remote():
from pyspark.ml.wrapper import JavaTransformer
from pyspark.ml.connect.serialize import serialize_ml_params

instance = JavaTransformer()
instance._java_obj = "org.apache.spark.ml.fpm.PrefixSpanWrapper"
# The helper object is just a JavaTransformer without any Param mixin, so
# copying the params via .copy() or assigning _paramMap directly won't work
instance._serialized_ml_params = serialize_ml_params( # type: ignore[attr-defined]
self,
dataset.sparkSession.client, # type: ignore[arg-type,operator]
)
return instance.transform(dataset)

self._transfer_params_to_java()
assert self._java_obj is not None
jdf = self._java_obj.findFrequentSequentialPatterns(dataset._jdf)
return DataFrame(jdf, dataset.sparkSession)

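Together with the Scala wrapper above, this makes PrefixSpan usable end to end on a Connect session. A hedged sketch (the remote URL is a placeholder; the data mirrors the test below):

```python
from pyspark.sql import SparkSession, Row
from pyspark.ml.fpm import PrefixSpan

# Connect-mode session; "sc://localhost" is a placeholder URL.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

df = spark.createDataFrame(
    [
        Row(sequence=[[1, 2], [3]]),
        Row(sequence=[[1], [3, 2], [1, 2]]),
        Row(sequence=[[1, 2], [5]]),
        Row(sequence=[[6]]),
    ]
)

# On Connect the call is routed through org.apache.spark.ml.fpm.PrefixSpanWrapper
# server-side; the client-facing API is unchanged from classic mode.
ps = PrefixSpan(minSupport=0.5, maxPatternLength=5)
ps.findFrequentSequentialPatterns(df).sort("sequence").show()
```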
29 changes: 28 additions & 1 deletion python/pyspark/ml/tests/test_fpm.py
@@ -18,11 +18,12 @@
import tempfile
import unittest

from pyspark.sql import SparkSession, Row
import pyspark.sql.functions as sf
from pyspark.ml.fpm import (
FPGrowth,
FPGrowthModel,
PrefixSpan,
)


@@ -71,6 +72,32 @@ def test_fp_growth(self):
model2 = FPGrowthModel.load(d)
self.assertEqual(str(model), str(model2))

def test_prefix_span(self):
spark = self.spark
df = spark.createDataFrame(
[
Row(sequence=[[1, 2], [3]]),
Row(sequence=[[1], [3, 2], [1, 2]]),
Row(sequence=[[1, 2], [5]]),
Row(sequence=[[6]]),
]
)

ps = PrefixSpan()
ps.setMinSupport(0.5)
ps.setMaxPatternLength(5)

self.assertEqual(ps.getMinSupport(), 0.5)
self.assertEqual(ps.getMaxPatternLength(), 5)

output = ps.findFrequentSequentialPatterns(df)
self.assertEqual(output.columns, ["sequence", "freq"])
self.assertEqual(output.count(), 5)

head = output.sort("sequence").head()
self.assertEqual(head.sequence, [[1]])
self.assertEqual(head.freq, 3)
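For reference: with minSupport = 0.5 over four input sequences, the support threshold is 2, so the five frequent patterns are [[1]] (freq 3), [[2]] (freq 3), [[1, 2]] (freq 3), [[3]] (freq 2), and [[1], [3]] (freq 2). The count and head assertions above follow from this enumeration (worked out by hand here, not stated in the PR).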
[Review comment (Contributor Author)] PrefixSpan doesn't support save/load, hence no save/load round-trip in this test.

class FPMTests(FPMTestsMixin, unittest.TestCase):
def setUp(self) -> None:
8 changes: 6 additions & 2 deletions python/pyspark/ml/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,14 @@ def wrapped(self: "JavaWrapper", dataset: "ConnectDataFrame") -> Any:

session = dataset.sparkSession
assert session is not None

if hasattr(self, "_serialized_ml_params"):
params = self._serialized_ml_params
else:
params = serialize_ml_params(self, session.client) # type: ignore[arg-type]

# Model is also a Transformer, so we must match Model first
if isinstance(self, Model):
from pyspark.ml.connect.proto import TransformerRelation

assert isinstance(self._java_obj, str)
@@ -169,7 +174,6 @@ def wrapped(self: "JavaWrapper", dataset: "ConnectDataFrame") -> Any:
session,
)
elif isinstance(self, Transformer):
from pyspark.ml.connect.proto import TransformerRelation

assert isinstance(self._java_obj, str)
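The hasattr branch above is what allows fpm.py to hand a bare JavaTransformer pre-serialized params. A minimal sketch of that pattern, assuming a live Connect client handle; the helper name make_remote_helper is hypothetical:

```python
from pyspark.ml.wrapper import JavaTransformer
from pyspark.ml.connect.serialize import serialize_ml_params

def make_remote_helper(source, java_class, client):
    # Hypothetical helper illustrating the pattern used in fpm.py: attach
    # params serialized from `source` to a bare JavaTransformer that only
    # carries the server-side class name. The wrapped transform() above
    # then picks up _serialized_ml_params instead of re-serializing from
    # the param-less helper itself.
    helper = JavaTransformer()
    helper._java_obj = java_class  # e.g. "org.apache.spark.ml.fpm.PrefixSpanWrapper"
    helper._serialized_ml_params = serialize_ml_params(source, client)
    return helper
```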