[SPARK-50918][ML][PYTHON][CONNECT] Refactor read/write for Pipeline
### What changes were proposed in this pull request?

Use the built-in `Pipeline`/`PipelineModel` readers and writers (`PipelineReader`/`PipelineWriter` and `PipelineModelReader`/`PipelineModelWriter`) to implement read/write on connect, instead of reimplementing the persistence logic with `PipelineSharedReadWrite` in the connect client.
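
For context, the user-facing API is unchanged; a minimal sketch of the round trip this wires up on connect (the stage and path below are illustrative, not part of the patch):

```python
from pyspark.ml import Pipeline
from pyspark.ml.feature import Tokenizer

# Illustrative pipeline and path; any MLWritable stages behave the same way.
pipeline = Pipeline(stages=[Tokenizer(inputCol="text", outputCol="words")])
pipeline.save("/tmp/demo_pipeline")  # on connect, now delegates to PipelineWriter
loaded = Pipeline.load("/tmp/demo_pipeline")  # now delegates to PipelineReader
```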

### Why are the changes needed?

Code reuse: the connect client should share the persistence logic that already exists in `pyspark.ml.pipeline` rather than duplicate it.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passes

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #49706 from wbo4958/pipeline-read-write.

Authored-by: Bobby Wang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
wbo4958 authored and zhengruifeng committed Jan 28, 2025
1 parent b49ef2a commit dd51f0e
Showing 1 changed file with 27 additions and 26 deletions.
```diff
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 import warnings
-from typing import cast, Type, TYPE_CHECKING, Union, List, Dict, Any, Optional
+from typing import cast, Type, TYPE_CHECKING, Union, Dict, Any, Optional
 
 import pyspark.sql.connect.proto as pb2
 from pyspark.ml.connect.serialize import serialize_ml_params, deserialize, deserialize_param
@@ -139,26 +139,26 @@ def saveInstance(
             command.ml_command.write.CopyFrom(writer)
             session.client.execute_command(command)
 
-        elif isinstance(instance, (Pipeline, PipelineModel)):
-            from pyspark.ml.pipeline import PipelineSharedReadWrite
+        elif isinstance(instance, Pipeline):
+            from pyspark.ml.pipeline import PipelineWriter
 
             if shouldOverwrite:
                 # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for Pipeline and PipelineModel")
+                warnings.warn("Overwrite doesn't take effect for Pipeline")
 
-            if isinstance(instance, Pipeline):
-                stages = instance.getStages()  # type: ignore[attr-defined]
-            else:
-                stages = instance.stages
-
-            PipelineSharedReadWrite.validateStages(stages)
-            PipelineSharedReadWrite.saveImpl(
-                instance,  # type: ignore[arg-type]
-                stages,
-                session,  # type: ignore[arg-type]
-                path,
-            )
+            pl_writer = PipelineWriter(instance)
+            pl_writer.session(session)  # type: ignore[arg-type]
+            pl_writer.save(path)
+        elif isinstance(instance, PipelineModel):
+            from pyspark.ml.pipeline import PipelineModelWriter
+
+            if shouldOverwrite:
+                # TODO(SPARK-50954): Support client side model path overwrite
+                warnings.warn("Overwrite doesn't take effect for PipelineModel")
+
+            plm_writer = PipelineModelWriter(instance)
+            plm_writer.session(session)  # type: ignore[arg-type]
+            plm_writer.save(path)
         elif isinstance(instance, CrossValidator):
             from pyspark.ml.tuning import CrossValidatorWriter
@@ -231,7 +231,6 @@ def loadInstance(
         path: str,
         session: "SparkSession",
     ) -> RL:
-        from pyspark.ml.base import Transformer
         from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
         from pyspark.ml.evaluation import JavaEvaluator
         from pyspark.ml.pipeline import Pipeline, PipelineModel
@@ -289,17 +288,19 @@ def _get_class() -> Type[RL]:
             else:
                 raise RuntimeError(f"Unsupported python type {py_type}")
 
-        elif issubclass(clazz, Pipeline) or issubclass(clazz, PipelineModel):
-            from pyspark.ml.pipeline import PipelineSharedReadWrite
-            from pyspark.ml.util import DefaultParamsReader
+        elif issubclass(clazz, Pipeline):
+            from pyspark.ml.pipeline import PipelineReader
 
-            metadata = DefaultParamsReader.loadMetadata(path, session)
-            uid, stages = PipelineSharedReadWrite.load(metadata, session, path)
+            pl_reader = PipelineReader(Pipeline)
+            pl_reader.session(session)
+            return pl_reader.load(path)
 
-            if issubclass(clazz, Pipeline):
-                return Pipeline(stages=stages)._resetUid(uid)
-            else:
-                return PipelineModel(stages=cast(List[Transformer], stages))._resetUid(uid)
+        elif issubclass(clazz, PipelineModel):
+            from pyspark.ml.pipeline import PipelineModelReader
+
+            plm_reader = PipelineModelReader(PipelineModel)
+            plm_reader.session(session)
+            return plm_reader.load(path)
 
         elif issubclass(clazz, CrossValidator):
             from pyspark.ml.tuning import CrossValidatorReader
```
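
Each refactored branch follows the same bind-session-then-persist shape. A condensed sketch of that pattern for the `PipelineModel` case (a paraphrase of the diff above, assuming `session` is the caller's connect `SparkSession`; not a drop-in replacement for the patched code):

```python
from pyspark.ml.pipeline import PipelineModel, PipelineModelReader, PipelineModelWriter
from pyspark.sql import SparkSession


def save_pipeline_model(instance: PipelineModel, session: SparkSession, path: str) -> None:
    # Bind the caller's session before saving so the writer reuses it
    # instead of resolving a default session on its own.
    writer = PipelineModelWriter(instance)
    writer.session(session)
    writer.save(path)


def load_pipeline_model(session: SparkSession, path: str) -> PipelineModel:
    reader = PipelineModelReader(PipelineModel)
    reader.session(session)
    return reader.load(path)
```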
