From dd51f0e7592de569ce0e4db9c0eca3f05c160cba Mon Sep 17 00:00:00 2001
From: Bobby Wang
Date: Tue, 28 Jan 2025 18:55:56 +0800
Subject: [PATCH] [SPARK-50918][ML][PYTHON][CONNECT] Refactor read/write for
 Pipeline

### What changes were proposed in this pull request?

We can use the built-in Pipeline/PipelineModel reader and writer to support read/write on connect

### Why are the changes needed?

Reusing code

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

CI passes

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #49706 from wbo4958/pipeline-read-write.

Authored-by: Bobby Wang
Signed-off-by: Ruifeng Zheng
---
 python/pyspark/ml/connect/readwrite.py | 53 +++++++++++++-------------
 1 file changed, 27 insertions(+), 26 deletions(-)

diff --git a/python/pyspark/ml/connect/readwrite.py b/python/pyspark/ml/connect/readwrite.py
index 9e5587ebc5d90..de70a410dbc7c 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 import warnings
-from typing import cast, Type, TYPE_CHECKING, Union, List, Dict, Any, Optional
+from typing import cast, Type, TYPE_CHECKING, Union, Dict, Any, Optional
 
 import pyspark.sql.connect.proto as pb2
 from pyspark.ml.connect.serialize import serialize_ml_params, deserialize, deserialize_param
@@ -139,26 +139,26 @@ def saveInstance(
             command.ml_command.write.CopyFrom(writer)
             session.client.execute_command(command)
 
-        elif isinstance(instance, (Pipeline, PipelineModel)):
-            from pyspark.ml.pipeline import PipelineSharedReadWrite
+        elif isinstance(instance, Pipeline):
+            from pyspark.ml.pipeline import PipelineWriter
 
             if shouldOverwrite:
                 # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for Pipeline and PipelineModel")
+                warnings.warn("Overwrite doesn't take effect for Pipeline")
 
-            if isinstance(instance, Pipeline):
-                stages = instance.getStages()  # type: ignore[attr-defined]
-            else:
-                stages = instance.stages
-
-            PipelineSharedReadWrite.validateStages(stages)
-            PipelineSharedReadWrite.saveImpl(
-                instance,  # type: ignore[arg-type]
-                stages,
-                session,  # type: ignore[arg-type]
-                path,
-            )
+            pl_writer = PipelineWriter(instance)
+            pl_writer.session(session)  # type: ignore[arg-type]
+            pl_writer.save(path)
+        elif isinstance(instance, PipelineModel):
+            from pyspark.ml.pipeline import PipelineModelWriter
+            if shouldOverwrite:
+                # TODO(SPARK-50954): Support client side model path overwrite
+                warnings.warn("Overwrite doesn't take effect for PipelineModel")
+
+            plm_writer = PipelineModelWriter(instance)
+            plm_writer.session(session)  # type: ignore[arg-type]
+            plm_writer.save(path)
 
         elif isinstance(instance, CrossValidator):
             from pyspark.ml.tuning import CrossValidatorWriter
 
@@ -231,7 +231,6 @@ def loadInstance(
         path: str,
         session: "SparkSession",
     ) -> RL:
-        from pyspark.ml.base import Transformer
         from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
         from pyspark.ml.evaluation import JavaEvaluator
         from pyspark.ml.pipeline import Pipeline, PipelineModel
@@ -289,17 +288,19 @@ def _get_class() -> Type[RL]:
         else:
             raise RuntimeError(f"Unsupported python type {py_type}")
 
-        elif issubclass(clazz, Pipeline) or issubclass(clazz, PipelineModel):
-            from pyspark.ml.pipeline import PipelineSharedReadWrite
-            from pyspark.ml.util import DefaultParamsReader
+        elif issubclass(clazz, Pipeline):
+            from pyspark.ml.pipeline import PipelineReader
 
-            metadata = DefaultParamsReader.loadMetadata(path, session)
-            uid, stages = PipelineSharedReadWrite.load(metadata, session, path)
+            pl_reader = PipelineReader(Pipeline)
+            pl_reader.session(session)
+            return pl_reader.load(path)
 
-            if issubclass(clazz, Pipeline):
-                return Pipeline(stages=stages)._resetUid(uid)
-            else:
-                return PipelineModel(stages=cast(List[Transformer], stages))._resetUid(uid)
+        elif issubclass(clazz, PipelineModel):
+            from pyspark.ml.pipeline import PipelineModelReader
+
+            plm_reader = PipelineModelReader(PipelineModel)
+            plm_reader.session(session)
+            return plm_reader.load(path)
 
         elif issubclass(clazz, CrossValidator):
             from pyspark.ml.tuning import CrossValidatorReader
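
For context, a minimal round-trip sketch (not part of the patch; the save paths and the `train_df` DataFrame are assumed for illustration) of the user-facing flow that this refactor now routes through the built-in `PipelineWriter`/`PipelineReader` and `PipelineModelWriter`/`PipelineModelReader` on Spark Connect. The public API is unchanged:

```python
# Hypothetical usage sketch; assumes an active Spark Connect session and a
# DataFrame `train_df` with columns "a", "b", and "label".
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import VectorAssembler

assembler = VectorAssembler(inputCols=["a", "b"], outputCol="features")
lr = LogisticRegression(featuresCol="features", labelCol="label")
pipeline = Pipeline(stages=[assembler, lr])

# Saving the unfitted pipeline delegates to PipelineWriter on the client.
pipeline.write().save("/tmp/pipeline-demo")            # illustrative path
loaded_pipeline = Pipeline.load("/tmp/pipeline-demo")  # -> PipelineReader

# The fitted model takes the PipelineModelWriter/PipelineModelReader path.
model = pipeline.fit(train_df)
model.write().save("/tmp/pipeline-model-demo")
loaded_model = PipelineModel.load("/tmp/pipeline-model-demo")

# Note: per the TODO(SPARK-50954) in the diff above, write().overwrite()
# does not yet take effect for Pipeline/PipelineModel on Connect.
```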