diff --git a/python/pyspark/ml/connect/readwrite.py b/python/pyspark/ml/connect/readwrite.py
index 9e5587ebc5d90..de70a410dbc7c 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 import warnings
-from typing import cast, Type, TYPE_CHECKING, Union, List, Dict, Any, Optional
+from typing import cast, Type, TYPE_CHECKING, Union, Dict, Any, Optional
 
 import pyspark.sql.connect.proto as pb2
 from pyspark.ml.connect.serialize import serialize_ml_params, deserialize, deserialize_param
@@ -139,26 +139,26 @@ def saveInstance(
             command.ml_command.write.CopyFrom(writer)
             session.client.execute_command(command)
 
-        elif isinstance(instance, (Pipeline, PipelineModel)):
-            from pyspark.ml.pipeline import PipelineSharedReadWrite
+        elif isinstance(instance, Pipeline):
+            from pyspark.ml.pipeline import PipelineWriter
 
             if shouldOverwrite:
                 # TODO(SPARK-50954): Support client side model path overwrite
-                warnings.warn("Overwrite doesn't take effect for Pipeline and PipelineModel")
+                warnings.warn("Overwrite doesn't take effect for Pipeline")
 
-            if isinstance(instance, Pipeline):
-                stages = instance.getStages()  # type: ignore[attr-defined]
-            else:
-                stages = instance.stages
-
-            PipelineSharedReadWrite.validateStages(stages)
-            PipelineSharedReadWrite.saveImpl(
-                instance,  # type: ignore[arg-type]
-                stages,
-                session,  # type: ignore[arg-type]
-                path,
-            )
+            pl_writer = PipelineWriter(instance)
+            pl_writer.session(session)  # type: ignore[arg-type]
+            pl_writer.save(path)
+        elif isinstance(instance, PipelineModel):
+            from pyspark.ml.pipeline import PipelineModelWriter
+
+            if shouldOverwrite:
+                # TODO(SPARK-50954): Support client side model path overwrite
+                warnings.warn("Overwrite doesn't take effect for PipelineModel")
+
+            plm_writer = PipelineModelWriter(instance)
+            plm_writer.session(session)  # type: ignore[arg-type]
+            plm_writer.save(path)
         elif isinstance(instance, CrossValidator):
             from pyspark.ml.tuning import CrossValidatorWriter
@@ -231,7 +231,6 @@ def loadInstance(
         path: str,
         session: "SparkSession",
     ) -> RL:
-        from pyspark.ml.base import Transformer
         from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
         from pyspark.ml.evaluation import JavaEvaluator
         from pyspark.ml.pipeline import Pipeline, PipelineModel
@@ -289,17 +288,19 @@ def _get_class() -> Type[RL]:
             else:
                 raise RuntimeError(f"Unsupported python type {py_type}")
 
-        elif issubclass(clazz, Pipeline) or issubclass(clazz, PipelineModel):
-            from pyspark.ml.pipeline import PipelineSharedReadWrite
-            from pyspark.ml.util import DefaultParamsReader
+        elif issubclass(clazz, Pipeline):
+            from pyspark.ml.pipeline import PipelineReader
 
-            metadata = DefaultParamsReader.loadMetadata(path, session)
-            uid, stages = PipelineSharedReadWrite.load(metadata, session, path)
+            pl_reader = PipelineReader(Pipeline)
+            pl_reader.session(session)
+            return pl_reader.load(path)
 
-            if issubclass(clazz, Pipeline):
-                return Pipeline(stages=stages)._resetUid(uid)
-            else:
-                return PipelineModel(stages=cast(List[Transformer], stages))._resetUid(uid)
+        elif issubclass(clazz, PipelineModel):
+            from pyspark.ml.pipeline import PipelineModelReader
+
+            plm_reader = PipelineModelReader(PipelineModel)
+            plm_reader.session(session)
+            return plm_reader.load(path)
         elif issubclass(clazz, CrossValidator):
             from pyspark.ml.tuning import CrossValidatorReader
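
For context, a minimal sketch of the user-facing flow this change reroutes: on a Spark Connect session, Pipeline.save/Pipeline.load and PipelineModel.save/PipelineModel.load end up in the saveInstance / loadInstance helpers patched above, which now delegate to PipelineWriter/PipelineReader and PipelineModelWriter/PipelineModelReader instead of calling PipelineSharedReadWrite directly. The remote address, column names, and output paths below are placeholders, not part of the patch.

    from pyspark.sql import SparkSession
    from pyspark.ml import Pipeline, PipelineModel
    from pyspark.ml.feature import VectorAssembler
    from pyspark.ml.classification import LogisticRegression

    # Assumes a Spark Connect session; the remote address is a placeholder.
    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

    df = spark.createDataFrame(
        [(0.0, 1.0, 0.0), (1.0, 2.0, 1.0)],
        ["f1", "f2", "label"],
    )
    pipeline = Pipeline(
        stages=[
            VectorAssembler(inputCols=["f1", "f2"], outputCol="features"),
            LogisticRegression(featuresCol="features", labelCol="label"),
        ]
    )

    # Unfitted Pipeline: handled by the new PipelineWriter / PipelineReader branch.
    pipeline.save("/tmp/pipeline_demo")
    restored_pipeline = Pipeline.load("/tmp/pipeline_demo")

    # Fitted PipelineModel: handled by PipelineModelWriter / PipelineModelReader.
    model = pipeline.fit(df)
    model.save("/tmp/pipeline_model_demo")
    restored_model = PipelineModel.load("/tmp/pipeline_model_demo")

Note that, per the TODO(SPARK-50954) comments above, the overwrite flag is still ignored for Pipeline and PipelineModel on Connect; the writer only emits a warning.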